Skip to main content

ares/rag/
embeddings.rs

1//! Embedding Service for RAG
2//!
3//! This module provides a comprehensive embedding service with support for:
4//! - 30+ text embedding models (BGE, Qwen3, Gemma, E5, Jina, etc.)
//! - Sparse embeddings for hybrid search (SPLADE++; BGE-M3 sparse mode is not available in fastembed 5.5.0)
6//! - Reranking models (BGE, Jina)
7//! - Async embedding via `spawn_blocking`
8//! - In-memory LRU caching to avoid recomputing embeddings
9//!
10//! # GPU Acceleration (TODO)
11//! GPU acceleration is planned for future iterations. See `docs/FUTURE_ENHANCEMENTS.md`.
12//! Potential approach:
13//! - Add feature flags: `cuda`, `metal`, `vulkan`
14//! - Use ORT execution providers for ONNX models
15//! - Use Candle GPU features for Qwen3 models
16//!
17//! # Embedding Cache
18//! Use `CachedEmbeddingService` to wrap the `EmbeddingService` with an LRU cache.
19//! See [`crate::rag::cache`] for cache configuration options.
20
21use crate::types::{AppError, Result};
22use serde::{Deserialize, Serialize};
23use std::collections::HashMap;
24use std::fmt::Display;
25use std::str::FromStr;
26use std::sync::{Arc, Mutex, OnceLock};
27// Note: Arc is now used both for MODEL_INIT_LOCKS and for wrapping the embedding models
28use tokio::task::spawn_blocking;
29
30// Re-export fastembed types for convenience
31pub use fastembed::{EmbeddingModel as FastEmbedModel, InitOptions, SparseModel, TextEmbedding};
32
/// Global registry of per-model initialization locks.
///
/// Used to serialize concurrent initialization (and first-time download) of the
/// same model. The key is the model name (the `Debug` representation of the
/// fastembed model enum).
static MODEL_INIT_LOCKS: OnceLock<Mutex<HashMap<String, Arc<Mutex<()>>>>> = OnceLock::new();

/// Get or create the initialization lock for `model_name`.
///
/// Every call with the same name returns a clone of the same `Arc`, so all
/// callers contend on a single mutex per model.
///
/// The registry mutex is held only for one `entry` lookup/insert. The only
/// mutation ever performed is inserting a freshly built `Arc`. So if the mutex
/// was poisoned by a panic elsewhere, the poison is ignored. Otherwise every
/// later model initialization in the process would panic as well.
fn get_model_lock(model_name: &str) -> Arc<Mutex<()>> {
    let registry = MODEL_INIT_LOCKS.get_or_init(|| Mutex::new(HashMap::new()));
    let mut locks = registry
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    Arc::clone(
        locks
            .entry(model_name.to_owned())
            .or_insert_with(|| Arc::new(Mutex::new(()))),
    )
}
45
46// ============================================================================
47// Embedding Model Configuration
48// ============================================================================
49
50/// Supported embedding models with their metadata.
51///
52/// This enum wraps fastembed's EmbeddingModel with additional metadata
53/// for easier configuration and selection.
54#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
55#[serde(rename_all = "kebab-case")]
56pub enum EmbeddingModelType {
57    // Fast English models (recommended defaults)
58    /// BAAI/bge-small-en-v1.5 - Fast, 384 dimensions (DEFAULT)
59    #[default]
60    BgeSmallEnV15,
61    /// Quantized BAAI/bge-small-en-v1.5
62    BgeSmallEnV15Q,
63    /// sentence-transformers/all-MiniLM-L6-v2 - Very fast, 384 dimensions
64    AllMiniLmL6V2,
65    /// Quantized all-MiniLM-L6-v2
66    AllMiniLmL6V2Q,
67    /// sentence-transformers/all-MiniLM-L12-v2 - Better quality, 384 dimensions
68    AllMiniLmL12V2,
69    /// Quantized all-MiniLM-L12-v2
70    AllMiniLmL12V2Q,
71    /// sentence-transformers/all-mpnet-base-v2 - 768 dimensions
72    AllMpnetBaseV2,
73
74    // High quality English models
75    /// BAAI/bge-base-en-v1.5 - 768 dimensions
76    BgeBaseEnV15,
77    /// Quantized BAAI/bge-base-en-v1.5
78    BgeBaseEnV15Q,
79    /// BAAI/bge-large-en-v1.5 - 1024 dimensions
80    BgeLargeEnV15,
81    /// Quantized BAAI/bge-large-en-v1.5
82    BgeLargeEnV15Q,
83
84    // Multilingual models
85    // NOTE: BGE-M3 is not available in fastembed 5.5.0, use MultilingualE5 instead
86    /// intfloat/multilingual-e5-small - 384 dimensions
87    MultilingualE5Small,
88    /// intfloat/multilingual-e5-base - 768 dimensions
89    MultilingualE5Base,
90    /// intfloat/multilingual-e5-large - 1024 dimensions
91    MultilingualE5Large,
92    /// sentence-transformers/paraphrase-MiniLM-L12-v2
93    ParaphraseMiniLmL12V2,
94    /// Quantized paraphrase-MiniLM-L12-v2
95    ParaphraseMiniLmL12V2Q,
96    /// sentence-transformers/paraphrase-multilingual-mpnet-base-v2 - 768 dimensions
97    ParaphraseMultilingualMpnetBaseV2,
98
99    // Chinese models
100    /// BAAI/bge-small-zh-v1.5 - 512 dimensions
101    BgeSmallZhV15,
102    /// BAAI/bge-large-zh-v1.5 - 1024 dimensions
103    BgeLargeZhV15,
104
105    // Long context models
106    /// nomic-ai/nomic-embed-text-v1 - 768 dimensions, 8192 context
107    NomicEmbedTextV1,
108    /// nomic-ai/nomic-embed-text-v1.5 - 768 dimensions, 8192 context
109    NomicEmbedTextV15,
110    /// Quantized nomic-embed-text-v1.5
111    NomicEmbedTextV15Q,
112
113    // Specialized models
114    /// mixedbread-ai/mxbai-embed-large-v1 - 1024 dimensions
115    MxbaiEmbedLargeV1,
116    /// Quantized mxbai-embed-large-v1
117    MxbaiEmbedLargeV1Q,
118    /// Alibaba-NLP/gte-base-en-v1.5 - 768 dimensions
119    GteBaseEnV15,
120    /// Quantized gte-base-en-v1.5
121    GteBaseEnV15Q,
122    /// Alibaba-NLP/gte-large-en-v1.5 - 1024 dimensions
123    GteLargeEnV15,
124    /// Quantized gte-large-en-v1.5
125    GteLargeEnV15Q,
126    /// Qdrant/clip-ViT-B-32-text - 512 dimensions, pairs with vision model
127    ClipVitB32,
128
129    // Code models
130    /// jinaai/jina-embeddings-v2-base-code - 768 dimensions
131    JinaEmbeddingsV2BaseCode,
132    // NOTE: JinaEmbeddingsV2BaseEN is not available in fastembed 5.5.0
133
134    // Modern models
135    /// google/embeddinggemma-300m - 768 dimensions
136    EmbeddingGemma300M,
137    /// lightonai/modernbert-embed-large - 1024 dimensions
138    ModernBertEmbedLarge,
139
140    // Snowflake Arctic models
141    /// snowflake/snowflake-arctic-embed-xs - 384 dimensions
142    SnowflakeArcticEmbedXs,
143    /// Quantized snowflake-arctic-embed-xs
144    SnowflakeArcticEmbedXsQ,
145    /// snowflake/snowflake-arctic-embed-s - 384 dimensions
146    SnowflakeArcticEmbedS,
147    /// Quantized snowflake-arctic-embed-s
148    SnowflakeArcticEmbedSQ,
149    /// snowflake/snowflake-arctic-embed-m - 768 dimensions
150    SnowflakeArcticEmbedM,
151    /// Quantized snowflake-arctic-embed-m
152    SnowflakeArcticEmbedMQ,
153    /// snowflake/snowflake-arctic-embed-m-long - 768 dimensions, 2048 context
154    SnowflakeArcticEmbedMLong,
155    /// Quantized snowflake-arctic-embed-m-long
156    SnowflakeArcticEmbedMLongQ,
157    /// snowflake/snowflake-arctic-embed-l - 1024 dimensions
158    SnowflakeArcticEmbedL,
159    /// Quantized snowflake-arctic-embed-l
160    SnowflakeArcticEmbedLQ,
161}
162
163impl EmbeddingModelType {
164    /// Convert to fastembed's EmbeddingModel enum
165    pub fn to_fastembed_model(&self) -> FastEmbedModel {
166        match self {
167            // Fast English
168            Self::BgeSmallEnV15 => FastEmbedModel::BGESmallENV15,
169            Self::BgeSmallEnV15Q => FastEmbedModel::BGESmallENV15Q,
170            Self::AllMiniLmL6V2 => FastEmbedModel::AllMiniLML6V2,
171            Self::AllMiniLmL6V2Q => FastEmbedModel::AllMiniLML6V2Q,
172            Self::AllMiniLmL12V2 => FastEmbedModel::AllMiniLML12V2,
173            Self::AllMiniLmL12V2Q => FastEmbedModel::AllMiniLML12V2Q,
174            Self::AllMpnetBaseV2 => FastEmbedModel::AllMpnetBaseV2,
175
176            // High quality English
177            Self::BgeBaseEnV15 => FastEmbedModel::BGEBaseENV15,
178            Self::BgeBaseEnV15Q => FastEmbedModel::BGEBaseENV15Q,
179            Self::BgeLargeEnV15 => FastEmbedModel::BGELargeENV15,
180            Self::BgeLargeEnV15Q => FastEmbedModel::BGELargeENV15Q,
181
182            // Multilingual
183            Self::MultilingualE5Small => FastEmbedModel::MultilingualE5Small,
184            Self::MultilingualE5Base => FastEmbedModel::MultilingualE5Base,
185            Self::MultilingualE5Large => FastEmbedModel::MultilingualE5Large,
186            Self::ParaphraseMiniLmL12V2 => FastEmbedModel::ParaphraseMLMiniLML12V2,
187            Self::ParaphraseMiniLmL12V2Q => FastEmbedModel::ParaphraseMLMiniLML12V2Q,
188            Self::ParaphraseMultilingualMpnetBaseV2 => FastEmbedModel::ParaphraseMLMpnetBaseV2,
189
190            // Chinese
191            Self::BgeSmallZhV15 => FastEmbedModel::BGESmallZHV15,
192            Self::BgeLargeZhV15 => FastEmbedModel::BGELargeZHV15,
193
194            // Long context
195            Self::NomicEmbedTextV1 => FastEmbedModel::NomicEmbedTextV1,
196            Self::NomicEmbedTextV15 => FastEmbedModel::NomicEmbedTextV15,
197            Self::NomicEmbedTextV15Q => FastEmbedModel::NomicEmbedTextV15Q,
198
199            // Specialized
200            Self::MxbaiEmbedLargeV1 => FastEmbedModel::MxbaiEmbedLargeV1,
201            Self::MxbaiEmbedLargeV1Q => FastEmbedModel::MxbaiEmbedLargeV1Q,
202            Self::GteBaseEnV15 => FastEmbedModel::GTEBaseENV15,
203            Self::GteBaseEnV15Q => FastEmbedModel::GTEBaseENV15Q,
204            Self::GteLargeEnV15 => FastEmbedModel::GTELargeENV15,
205            Self::GteLargeEnV15Q => FastEmbedModel::GTELargeENV15Q,
206            Self::ClipVitB32 => FastEmbedModel::ClipVitB32,
207
208            // Code
209            Self::JinaEmbeddingsV2BaseCode => FastEmbedModel::JinaEmbeddingsV2BaseCode,
210
211            // Modern
212            Self::EmbeddingGemma300M => FastEmbedModel::EmbeddingGemma300M,
213            Self::ModernBertEmbedLarge => FastEmbedModel::ModernBertEmbedLarge,
214
215            // Snowflake Arctic
216            Self::SnowflakeArcticEmbedXs => FastEmbedModel::SnowflakeArcticEmbedXS,
217            Self::SnowflakeArcticEmbedXsQ => FastEmbedModel::SnowflakeArcticEmbedXSQ,
218            Self::SnowflakeArcticEmbedS => FastEmbedModel::SnowflakeArcticEmbedS,
219            Self::SnowflakeArcticEmbedSQ => FastEmbedModel::SnowflakeArcticEmbedSQ,
220            Self::SnowflakeArcticEmbedM => FastEmbedModel::SnowflakeArcticEmbedM,
221            Self::SnowflakeArcticEmbedMQ => FastEmbedModel::SnowflakeArcticEmbedMQ,
222            Self::SnowflakeArcticEmbedMLong => FastEmbedModel::SnowflakeArcticEmbedMLong,
223            Self::SnowflakeArcticEmbedMLongQ => FastEmbedModel::SnowflakeArcticEmbedMLongQ,
224            Self::SnowflakeArcticEmbedL => FastEmbedModel::SnowflakeArcticEmbedL,
225            Self::SnowflakeArcticEmbedLQ => FastEmbedModel::SnowflakeArcticEmbedLQ,
226        }
227    }
228
229    /// Get the dimension of the embedding output
230    pub fn dimensions(&self) -> usize {
231        match self {
232            // 384 dimensions
233            Self::BgeSmallEnV15
234            | Self::BgeSmallEnV15Q
235            | Self::AllMiniLmL6V2
236            | Self::AllMiniLmL6V2Q
237            | Self::AllMiniLmL12V2
238            | Self::AllMiniLmL12V2Q
239            | Self::MultilingualE5Small
240            | Self::SnowflakeArcticEmbedXs
241            | Self::SnowflakeArcticEmbedXsQ
242            | Self::SnowflakeArcticEmbedS
243            | Self::SnowflakeArcticEmbedSQ => 384,
244
245            // 512 dimensions
246            Self::BgeSmallZhV15 | Self::ClipVitB32 => 512,
247
248            // 768 dimensions
249            Self::AllMpnetBaseV2
250            | Self::BgeBaseEnV15
251            | Self::BgeBaseEnV15Q
252            | Self::MultilingualE5Base
253            | Self::ParaphraseMiniLmL12V2
254            | Self::ParaphraseMiniLmL12V2Q
255            | Self::ParaphraseMultilingualMpnetBaseV2
256            | Self::NomicEmbedTextV1
257            | Self::NomicEmbedTextV15
258            | Self::NomicEmbedTextV15Q
259            | Self::GteBaseEnV15
260            | Self::GteBaseEnV15Q
261            | Self::JinaEmbeddingsV2BaseCode
262            | Self::EmbeddingGemma300M
263            | Self::SnowflakeArcticEmbedM
264            | Self::SnowflakeArcticEmbedMQ
265            | Self::SnowflakeArcticEmbedMLong
266            | Self::SnowflakeArcticEmbedMLongQ => 768,
267
268            // 1024 dimensions
269            Self::BgeLargeEnV15
270            | Self::BgeLargeEnV15Q
271            | Self::BgeLargeZhV15
272            | Self::MultilingualE5Large
273            | Self::MxbaiEmbedLargeV1
274            | Self::MxbaiEmbedLargeV1Q
275            | Self::GteLargeEnV15
276            | Self::GteLargeEnV15Q
277            | Self::ModernBertEmbedLarge
278            | Self::SnowflakeArcticEmbedL
279            | Self::SnowflakeArcticEmbedLQ => 1024,
280        }
281    }
282
283    /// Check if this is a quantized model
284    pub fn is_quantized(&self) -> bool {
285        matches!(
286            self,
287            Self::BgeSmallEnV15Q
288                | Self::AllMiniLmL6V2Q
289                | Self::AllMiniLmL12V2Q
290                | Self::BgeBaseEnV15Q
291                | Self::BgeLargeEnV15Q
292                | Self::ParaphraseMiniLmL12V2Q
293                | Self::NomicEmbedTextV15Q
294                | Self::MxbaiEmbedLargeV1Q
295                | Self::GteBaseEnV15Q
296                | Self::GteLargeEnV15Q
297                | Self::SnowflakeArcticEmbedXsQ
298                | Self::SnowflakeArcticEmbedSQ
299                | Self::SnowflakeArcticEmbedMQ
300                | Self::SnowflakeArcticEmbedMLongQ
301                | Self::SnowflakeArcticEmbedLQ
302        )
303    }
304
305    /// Check if this model supports multilingual text
306    pub fn is_multilingual(&self) -> bool {
307        matches!(
308            self,
309            Self::MultilingualE5Small
310                | Self::MultilingualE5Base
311                | Self::MultilingualE5Large
312                | Self::ParaphraseMultilingualMpnetBaseV2
313                | Self::BgeSmallZhV15
314                | Self::BgeLargeZhV15
315        )
316    }
317
318    /// Get the maximum context length in tokens
319    pub fn max_context_length(&self) -> usize {
320        match self {
321            Self::NomicEmbedTextV1 | Self::NomicEmbedTextV15 | Self::NomicEmbedTextV15Q => 8192,
322            Self::SnowflakeArcticEmbedMLong | Self::SnowflakeArcticEmbedMLongQ => 2048,
323            _ => 512,
324        }
325    }
326
327    /// List all available models
328    pub fn all() -> Vec<Self> {
329        vec![
330            Self::BgeSmallEnV15,
331            Self::BgeSmallEnV15Q,
332            Self::AllMiniLmL6V2,
333            Self::AllMiniLmL6V2Q,
334            Self::AllMiniLmL12V2,
335            Self::AllMiniLmL12V2Q,
336            Self::AllMpnetBaseV2,
337            Self::BgeBaseEnV15,
338            Self::BgeBaseEnV15Q,
339            Self::BgeLargeEnV15,
340            Self::BgeLargeEnV15Q,
341            Self::MultilingualE5Small,
342            Self::MultilingualE5Base,
343            Self::MultilingualE5Large,
344            Self::ParaphraseMiniLmL12V2,
345            Self::ParaphraseMiniLmL12V2Q,
346            Self::ParaphraseMultilingualMpnetBaseV2,
347            Self::BgeSmallZhV15,
348            Self::BgeLargeZhV15,
349            Self::NomicEmbedTextV1,
350            Self::NomicEmbedTextV15,
351            Self::NomicEmbedTextV15Q,
352            Self::MxbaiEmbedLargeV1,
353            Self::MxbaiEmbedLargeV1Q,
354            Self::GteBaseEnV15,
355            Self::GteBaseEnV15Q,
356            Self::GteLargeEnV15,
357            Self::GteLargeEnV15Q,
358            Self::ClipVitB32,
359            Self::JinaEmbeddingsV2BaseCode,
360            Self::EmbeddingGemma300M,
361            Self::ModernBertEmbedLarge,
362            Self::SnowflakeArcticEmbedXs,
363            Self::SnowflakeArcticEmbedXsQ,
364            Self::SnowflakeArcticEmbedS,
365            Self::SnowflakeArcticEmbedSQ,
366            Self::SnowflakeArcticEmbedM,
367            Self::SnowflakeArcticEmbedMQ,
368            Self::SnowflakeArcticEmbedMLong,
369            Self::SnowflakeArcticEmbedMLongQ,
370            Self::SnowflakeArcticEmbedL,
371            Self::SnowflakeArcticEmbedLQ,
372        ]
373    }
374}
375
376impl Display for EmbeddingModelType {
377    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
378        let name = match self {
379            Self::BgeSmallEnV15 => "bge-small-en-v1.5",
380            Self::BgeSmallEnV15Q => "bge-small-en-v1.5-q",
381            Self::AllMiniLmL6V2 => "all-minilm-l6-v2",
382            Self::AllMiniLmL6V2Q => "all-minilm-l6-v2-q",
383            Self::AllMiniLmL12V2 => "all-minilm-l12-v2",
384            Self::AllMiniLmL12V2Q => "all-minilm-l12-v2-q",
385            Self::AllMpnetBaseV2 => "all-mpnet-base-v2",
386            Self::BgeBaseEnV15 => "bge-base-en-v1.5",
387            Self::BgeBaseEnV15Q => "bge-base-en-v1.5-q",
388            Self::BgeLargeEnV15 => "bge-large-en-v1.5",
389            Self::BgeLargeEnV15Q => "bge-large-en-v1.5-q",
390            Self::MultilingualE5Small => "multilingual-e5-small",
391            Self::MultilingualE5Base => "multilingual-e5-base",
392            Self::MultilingualE5Large => "multilingual-e5-large",
393            Self::ParaphraseMiniLmL12V2 => "paraphrase-minilm-l12-v2",
394            Self::ParaphraseMiniLmL12V2Q => "paraphrase-minilm-l12-v2-q",
395            Self::ParaphraseMultilingualMpnetBaseV2 => "paraphrase-multilingual-mpnet-base-v2",
396            Self::BgeSmallZhV15 => "bge-small-zh-v1.5",
397            Self::BgeLargeZhV15 => "bge-large-zh-v1.5",
398            Self::NomicEmbedTextV1 => "nomic-embed-text-v1",
399            Self::NomicEmbedTextV15 => "nomic-embed-text-v1.5",
400            Self::NomicEmbedTextV15Q => "nomic-embed-text-v1.5-q",
401            Self::MxbaiEmbedLargeV1 => "mxbai-embed-large-v1",
402            Self::MxbaiEmbedLargeV1Q => "mxbai-embed-large-v1-q",
403            Self::GteBaseEnV15 => "gte-base-en-v1.5",
404            Self::GteBaseEnV15Q => "gte-base-en-v1.5-q",
405            Self::GteLargeEnV15 => "gte-large-en-v1.5",
406            Self::GteLargeEnV15Q => "gte-large-en-v1.5-q",
407            Self::ClipVitB32 => "clip-vit-b-32",
408            Self::JinaEmbeddingsV2BaseCode => "jina-embeddings-v2-base-code",
409            Self::EmbeddingGemma300M => "embedding-gemma-300m",
410            Self::ModernBertEmbedLarge => "modernbert-embed-large",
411            Self::SnowflakeArcticEmbedXs => "snowflake-arctic-embed-xs",
412            Self::SnowflakeArcticEmbedXsQ => "snowflake-arctic-embed-xs-q",
413            Self::SnowflakeArcticEmbedS => "snowflake-arctic-embed-s",
414            Self::SnowflakeArcticEmbedSQ => "snowflake-arctic-embed-s-q",
415            Self::SnowflakeArcticEmbedM => "snowflake-arctic-embed-m",
416            Self::SnowflakeArcticEmbedMQ => "snowflake-arctic-embed-m-q",
417            Self::SnowflakeArcticEmbedMLong => "snowflake-arctic-embed-m-long",
418            Self::SnowflakeArcticEmbedMLongQ => "snowflake-arctic-embed-m-long-q",
419            Self::SnowflakeArcticEmbedL => "snowflake-arctic-embed-l",
420            Self::SnowflakeArcticEmbedLQ => "snowflake-arctic-embed-l-q",
421        };
422        write!(f, "{}", name)
423    }
424}
425
426impl FromStr for EmbeddingModelType {
427    type Err = AppError;
428
429    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
430        match s.to_lowercase().as_str() {
431            "bge-small-en-v1.5" | "bge-small-en" | "bge-small" => Ok(Self::BgeSmallEnV15),
432            "bge-small-en-v1.5-q" => Ok(Self::BgeSmallEnV15Q),
433            "all-minilm-l6-v2" | "minilm-l6" => Ok(Self::AllMiniLmL6V2),
434            "all-minilm-l6-v2-q" => Ok(Self::AllMiniLmL6V2Q),
435            "all-minilm-l12-v2" | "minilm-l12" => Ok(Self::AllMiniLmL12V2),
436            "all-minilm-l12-v2-q" => Ok(Self::AllMiniLmL12V2Q),
437            "all-mpnet-base-v2" | "mpnet" => Ok(Self::AllMpnetBaseV2),
438            "bge-base-en-v1.5" | "bge-base-en" | "bge-base" => Ok(Self::BgeBaseEnV15),
439            "bge-base-en-v1.5-q" => Ok(Self::BgeBaseEnV15Q),
440            "bge-large-en-v1.5" | "bge-large-en" | "bge-large" => Ok(Self::BgeLargeEnV15),
441            "bge-large-en-v1.5-q" => Ok(Self::BgeLargeEnV15Q),
442            "multilingual-e5-small" | "e5-small" => Ok(Self::MultilingualE5Small),
443            "multilingual-e5-base" | "e5-base" => Ok(Self::MultilingualE5Base),
444            "multilingual-e5-large" | "e5-large" => Ok(Self::MultilingualE5Large),
445            "paraphrase-minilm-l12-v2" => Ok(Self::ParaphraseMiniLmL12V2),
446            "paraphrase-minilm-l12-v2-q" => Ok(Self::ParaphraseMiniLmL12V2Q),
447            "paraphrase-multilingual-mpnet-base-v2" => Ok(Self::ParaphraseMultilingualMpnetBaseV2),
448            "bge-small-zh-v1.5" | "bge-small-zh" => Ok(Self::BgeSmallZhV15),
449            "bge-large-zh-v1.5" | "bge-large-zh" => Ok(Self::BgeLargeZhV15),
450            "nomic-embed-text-v1" | "nomic-v1" => Ok(Self::NomicEmbedTextV1),
451            "nomic-embed-text-v1.5" | "nomic-v1.5" | "nomic" => Ok(Self::NomicEmbedTextV15),
452            "nomic-embed-text-v1.5-q" => Ok(Self::NomicEmbedTextV15Q),
453            "mxbai-embed-large-v1" | "mxbai" => Ok(Self::MxbaiEmbedLargeV1),
454            "mxbai-embed-large-v1-q" => Ok(Self::MxbaiEmbedLargeV1Q),
455            "gte-base-en-v1.5" | "gte-base" => Ok(Self::GteBaseEnV15),
456            "gte-base-en-v1.5-q" => Ok(Self::GteBaseEnV15Q),
457            "gte-large-en-v1.5" | "gte-large" => Ok(Self::GteLargeEnV15),
458            "gte-large-en-v1.5-q" => Ok(Self::GteLargeEnV15Q),
459            "clip-vit-b-32" | "clip" => Ok(Self::ClipVitB32),
460            "jina-embeddings-v2-base-code" | "jina-code" => Ok(Self::JinaEmbeddingsV2BaseCode),
461            "embedding-gemma-300m" | "gemma-300m" | "gemma" => Ok(Self::EmbeddingGemma300M),
462            "modernbert-embed-large" | "modernbert" => Ok(Self::ModernBertEmbedLarge),
463            "snowflake-arctic-embed-xs" => Ok(Self::SnowflakeArcticEmbedXs),
464            "snowflake-arctic-embed-xs-q" => Ok(Self::SnowflakeArcticEmbedXsQ),
465            "snowflake-arctic-embed-s" => Ok(Self::SnowflakeArcticEmbedS),
466            "snowflake-arctic-embed-s-q" => Ok(Self::SnowflakeArcticEmbedSQ),
467            "snowflake-arctic-embed-m" => Ok(Self::SnowflakeArcticEmbedM),
468            "snowflake-arctic-embed-m-q" => Ok(Self::SnowflakeArcticEmbedMQ),
469            "snowflake-arctic-embed-m-long" => Ok(Self::SnowflakeArcticEmbedMLong),
470            "snowflake-arctic-embed-m-long-q" => Ok(Self::SnowflakeArcticEmbedMLongQ),
471            "snowflake-arctic-embed-l" | "snowflake-l" => Ok(Self::SnowflakeArcticEmbedL),
472            "snowflake-arctic-embed-l-q" => Ok(Self::SnowflakeArcticEmbedLQ),
473            _ => Err(AppError::Internal(format!(
474                "Unknown embedding model: {}. Use one of: {}",
475                s,
476                EmbeddingModelType::all()
477                    .iter()
478                    .map(|m| m.to_string())
479                    .collect::<Vec<_>>()
480                    .join(", ")
481            ))),
482        }
483    }
484}
485
486// ============================================================================
487// Sparse Embedding Model Configuration
488// ============================================================================
489
490/// Supported sparse embedding models for hybrid search
491#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
492#[serde(rename_all = "kebab-case")]
493pub enum SparseModelType {
494    /// SPLADE++ v1 - English sparse embeddings
495    #[default]
496    SpladePpV1,
497    // NOTE: BGE-M3 sparse mode is not available in fastembed 5.5.0
498}
499
500impl SparseModelType {
501    /// Convert to fastembed's SparseModel enum
502    pub fn to_fastembed_model(&self) -> SparseModel {
503        match self {
504            Self::SpladePpV1 => SparseModel::SPLADEPPV1,
505        }
506    }
507}
508
509impl Display for SparseModelType {
510    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
511        let name = match self {
512            Self::SpladePpV1 => "splade-pp-v1",
513        };
514        write!(f, "{}", name)
515    }
516}
517
518impl FromStr for SparseModelType {
519    type Err = AppError;
520
521    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
522        match s.to_lowercase().as_str() {
523            "splade-pp-v1" | "splade" => Ok(Self::SpladePpV1),
524            _ => Err(AppError::Internal(format!(
525                "Unknown sparse model: {}. Use: splade-pp-v1",
526                s
527            ))),
528        }
529    }
530}
531
532// ============================================================================
533// Embedding Service Configuration
534// ============================================================================
535
536/// Configuration for the embedding service
537#[derive(Debug, Clone, Serialize, Deserialize)]
538pub struct EmbeddingConfig {
539    /// The embedding model to use
540    #[serde(default)]
541    pub model: EmbeddingModelType,
542
543    /// Batch size for embedding multiple texts
544    #[serde(default = "default_batch_size")]
545    pub batch_size: usize,
546
547    /// Show download progress for first-time model downloads
548    #[serde(default = "default_show_progress")]
549    pub show_download_progress: bool,
550
551    /// Enable sparse embeddings for hybrid search
552    #[serde(default)]
553    pub sparse_enabled: bool,
554
555    /// Sparse embedding model to use
556    #[serde(default)]
557    pub sparse_model: SparseModelType,
558}
559
/// Serde default for `EmbeddingConfig::batch_size`: 32 texts per batch.
fn default_batch_size() -> usize {
    const DEFAULT_BATCH_SIZE: usize = 32;
    DEFAULT_BATCH_SIZE
}

/// Serde default for `EmbeddingConfig::show_download_progress`: enabled.
fn default_show_progress() -> bool {
    const SHOW_PROGRESS_BY_DEFAULT: bool = true;
    SHOW_PROGRESS_BY_DEFAULT
}
567
568impl Default for EmbeddingConfig {
569    fn default() -> Self {
570        Self {
571            model: EmbeddingModelType::default(),
572            batch_size: default_batch_size(),
573            show_download_progress: default_show_progress(),
574            sparse_enabled: false,
575            sparse_model: SparseModelType::default(),
576        }
577    }
578}
579
580// ============================================================================
581// Embedding Service
582// ============================================================================
583
584/// Main embedding service for generating text embeddings
585///
586/// Uses `spawn_blocking` to run fastembed's synchronous operations
587/// without blocking the async runtime.
588///
589/// The model is wrapped in `Arc<Mutex<TextEmbedding>>` to allow safe
590/// reuse across async boundaries without recreating the model on each call.
591pub struct EmbeddingService {
592    /// The text embedding model, wrapped for thread-safe access
593    model: Arc<Mutex<TextEmbedding>>,
594    /// Optional sparse embedding model for hybrid search
595    sparse_model: Option<Arc<Mutex<fastembed::SparseTextEmbedding>>>,
596    config: EmbeddingConfig,
597}
598
599impl EmbeddingService {
600    /// Create a new embedding service with the given configuration
601    ///
602    /// Uses a per-model lock to prevent race conditions when multiple threads
603    /// try to download/initialize the same model simultaneously.
604    pub fn new(config: EmbeddingConfig) -> Result<Self> {
605        let model_name = format!("{:?}", config.model.to_fastembed_model());
606        let model_lock = get_model_lock(&model_name);
607
608        // Acquire lock for this specific model to prevent concurrent downloads
609        let _guard = model_lock.lock().map_err(|e| {
610            AppError::Internal(format!(
611                "Failed to acquire model initialization lock: {}",
612                e
613            ))
614        })?;
615
616        let model = TextEmbedding::try_new(
617            InitOptions::new(config.model.to_fastembed_model())
618                .with_show_download_progress(config.show_download_progress),
619        )
620        .map_err(|e| AppError::Internal(format!("Failed to initialize embedding model: {}", e)))?;
621
622        let sparse_model = if config.sparse_enabled {
623            let sparse_model_name = format!("{:?}", config.sparse_model.to_fastembed_model());
624            let sparse_lock = get_model_lock(&sparse_model_name);
625            let _sparse_guard = sparse_lock.lock().map_err(|e| {
626                AppError::Internal(format!("Failed to acquire sparse model lock: {}", e))
627            })?;
628
629            Some(
630                fastembed::SparseTextEmbedding::try_new(
631                    fastembed::SparseInitOptions::new(config.sparse_model.to_fastembed_model())
632                        .with_show_download_progress(config.show_download_progress),
633                )
634                .map_err(|e| {
635                    AppError::Internal(format!(
636                        "Failed to initialize sparse embedding model: {}",
637                        e
638                    ))
639                })?,
640            )
641        } else {
642            None
643        };
644
645        Ok(Self {
646            model: Arc::new(Mutex::new(model)),
647            sparse_model: sparse_model.map(|m| Arc::new(Mutex::new(m))),
648            config,
649        })
650    }
651
652    /// Create a new embedding service with the default model
653    pub fn with_default_model() -> Result<Self> {
654        Self::new(EmbeddingConfig::default())
655    }
656
657    /// Create a new embedding service with a specific model
658    pub fn with_model(model: EmbeddingModelType) -> Result<Self> {
659        Self::new(EmbeddingConfig {
660            model,
661            ..Default::default()
662        })
663    }
664
665    /// Get the current model type
666    pub fn model_type(&self) -> EmbeddingModelType {
667        self.config.model
668    }
669
670    /// Get the embedding dimensions
671    pub fn dimensions(&self) -> usize {
672        self.config.model.dimensions()
673    }
674
675    /// Get the configuration
676    pub fn config(&self) -> &EmbeddingConfig {
677        &self.config
678    }
679
680    /// Embed a single text (async via spawn_blocking)
681    pub async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
682        let embeddings = self.embed_texts(&[text.to_string()]).await?;
683        embeddings
684            .into_iter()
685            .next()
686            .ok_or_else(|| AppError::Internal("No embedding generated".to_string()))
687    }
688
689    /// Embed multiple texts in batches (async via spawn_blocking)
690    ///
691    /// This is more efficient than calling `embed_text` multiple times
692    /// as it batches the texts and processes them together.
693    ///
694    /// The model is reused across calls via `Arc<Mutex<TextEmbedding>>`.
695    pub async fn embed_texts<S: AsRef<str> + Send + Sync + 'static>(
696        &self,
697        texts: &[S],
698    ) -> Result<Vec<Vec<f32>>> {
699        if texts.is_empty() {
700            return Ok(vec![]);
701        }
702
703        // Clone texts to owned strings for the spawn_blocking closure
704        let texts_owned: Vec<String> = texts.iter().map(|s| s.as_ref().to_string()).collect();
705        let batch_size = self.config.batch_size;
706
707        // Clone the Arc to move into the blocking task
708        let model = Arc::clone(&self.model);
709
710        spawn_blocking(move || {
711            // Lock the model for use
712            let mut model_guard = model
713                .lock()
714                .map_err(|e| AppError::Internal(format!("Failed to acquire model lock: {}", e)))?;
715
716            let refs: Vec<&str> = texts_owned.iter().map(|s| s.as_str()).collect();
717            model_guard
718                .embed(refs, Some(batch_size))
719                .map_err(|e| AppError::Internal(format!("Embedding failed: {}", e)))
720        })
721        .await
722        .map_err(|e| AppError::Internal(format!("Blocking task failed: {}", e)))?
723    }
724
725    /// Generate sparse embeddings for hybrid search
726    ///
727    /// The sparse model is reused across calls via `Arc<Mutex<SparseTextEmbedding>>`.
728    pub async fn embed_sparse<S: AsRef<str> + Send + Sync + 'static>(
729        &self,
730        texts: &[S],
731    ) -> Result<Vec<fastembed::SparseEmbedding>> {
732        let sparse_model = self.sparse_model.as_ref().ok_or_else(|| {
733            AppError::Internal(
734                "Sparse embeddings not enabled. Set sparse_enabled: true in config.".to_string(),
735            )
736        })?;
737
738        let texts_owned: Vec<String> = texts.iter().map(|s| s.as_ref().to_string()).collect();
739        let batch_size = self.config.batch_size;
740
741        // Clone the Arc to move into the blocking task
742        let model = Arc::clone(sparse_model);
743
744        spawn_blocking(move || {
745            // Lock the model for use
746            let mut model_guard = model.lock().map_err(|e| {
747                AppError::Internal(format!("Failed to acquire sparse model lock: {}", e))
748            })?;
749
750            let refs: Vec<&str> = texts_owned.iter().map(|s| s.as_str()).collect();
751            model_guard
752                .embed(refs, Some(batch_size))
753                .map_err(|e| AppError::Internal(format!("Sparse embedding failed: {}", e)))
754        })
755        .await
756        .map_err(|e| AppError::Internal(format!("Blocking task failed: {}", e)))?
757    }
758}
759
760// ============================================================================
761// Cached Embedding Service
762// ============================================================================
763
764use crate::rag::cache::{CacheConfig, CacheStats, EmbeddingCache, LruEmbeddingCache, NoOpCache};
765
766/// An embedding service with integrated caching
767///
768/// Wraps an `EmbeddingService` with an `EmbeddingCache` to avoid recomputing
769/// embeddings for previously seen texts. The cache key is computed as a hash
770/// of the text content and model name.
771///
772/// # Example
773///
774/// ```ignore
775/// use ares::rag::embeddings::{CachedEmbeddingService, EmbeddingConfig};
776/// use ares::rag::cache::CacheConfig;
777///
778/// let service = CachedEmbeddingService::new(
779///     EmbeddingConfig::default(),
780///     CacheConfig::default(),
781/// )?;
782///
783/// // First call computes the embedding
784/// let emb1 = service.embed_text("hello world").await?;
785///
786/// // Second call returns cached result
787/// let emb2 = service.embed_text("hello world").await?;
788/// assert_eq!(emb1, emb2);
789/// ```
790pub struct CachedEmbeddingService {
791    /// The underlying embedding service
792    inner: EmbeddingService,
793    /// The embedding cache
794    cache: Box<dyn EmbeddingCache>,
795}
796
797impl CachedEmbeddingService {
798    /// Create a new cached embedding service
799    pub fn new(embedding_config: EmbeddingConfig, cache_config: CacheConfig) -> Result<Self> {
800        let inner = EmbeddingService::new(embedding_config)?;
801        let cache: Box<dyn EmbeddingCache> = if cache_config.enabled {
802            Box::new(LruEmbeddingCache::new(cache_config))
803        } else {
804            Box::new(NoOpCache::new())
805        };
806
807        Ok(Self { inner, cache })
808    }
809
810    /// Create with default configurations
811    pub fn with_defaults() -> Result<Self> {
812        Self::new(EmbeddingConfig::default(), CacheConfig::default())
813    }
814
815    /// Create with a specific model and default cache
816    pub fn with_model(model: EmbeddingModelType) -> Result<Self> {
817        Self::new(
818            EmbeddingConfig {
819                model,
820                ..Default::default()
821            },
822            CacheConfig::default(),
823        )
824    }
825
826    /// Create with caching disabled
827    pub fn without_cache(embedding_config: EmbeddingConfig) -> Result<Self> {
828        Self::new(
829            embedding_config,
830            CacheConfig {
831                enabled: false,
832                ..Default::default()
833            },
834        )
835    }
836
837    /// Get the model name for cache key computation
838    fn model_name(&self) -> String {
839        self.inner.model_type().to_string()
840    }
841
842    /// Embed a single text with caching
843    pub async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
844        let cache_key = self.cache.compute_key(text, &self.model_name());
845
846        // Check cache first
847        if let Some(cached) = self.cache.get(&cache_key) {
848            return Ok(cached);
849        }
850
851        // Compute embedding
852        let embedding = self.inner.embed_text(text).await?;
853
854        // Store in cache
855        self.cache.set(&cache_key, embedding.clone(), None)?;
856
857        Ok(embedding)
858    }
859
860    /// Embed multiple texts with caching
861    ///
862    /// Checks cache for each text individually, computes embeddings only
863    /// for uncached texts, and caches the new results.
864    pub async fn embed_texts<S: AsRef<str> + Send + Sync + 'static>(
865        &self,
866        texts: &[S],
867    ) -> Result<Vec<Vec<f32>>> {
868        if texts.is_empty() {
869            return Ok(vec![]);
870        }
871
872        let model_name = self.model_name();
873        let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
874        let mut uncached_indices: Vec<usize> = Vec::new();
875        let mut uncached_texts: Vec<String> = Vec::new();
876
877        // Check cache for each text
878        for (i, text) in texts.iter().enumerate() {
879            let text_str = text.as_ref();
880            let cache_key = self.cache.compute_key(text_str, &model_name);
881
882            if let Some(cached) = self.cache.get(&cache_key) {
883                results[i] = Some(cached);
884            } else {
885                uncached_indices.push(i);
886                uncached_texts.push(text_str.to_string());
887            }
888        }
889
890        // Compute embeddings for uncached texts
891        if !uncached_texts.is_empty() {
892            let new_embeddings = self.inner.embed_texts(&uncached_texts).await?;
893
894            // Store results and cache them
895            for (j, embedding) in new_embeddings.into_iter().enumerate() {
896                let idx = uncached_indices[j];
897                let cache_key = self.cache.compute_key(&uncached_texts[j], &model_name);
898                self.cache.set(&cache_key, embedding.clone(), None)?;
899                results[idx] = Some(embedding);
900            }
901        }
902
903        // Unwrap all results (should all be Some at this point)
904        Ok(results.into_iter().flatten().collect())
905    }
906
907    /// Get the current model type
908    pub fn model_type(&self) -> EmbeddingModelType {
909        self.inner.model_type()
910    }
911
912    /// Get the embedding dimensions
913    pub fn dimensions(&self) -> usize {
914        self.inner.dimensions()
915    }
916
917    /// Get the embedding configuration
918    pub fn config(&self) -> &EmbeddingConfig {
919        self.inner.config()
920    }
921
922    /// Get cache statistics
923    pub fn cache_stats(&self) -> CacheStats {
924        self.cache.stats()
925    }
926
927    /// Clear the cache
928    pub fn clear_cache(&self) -> Result<()> {
929        self.cache.clear()
930    }
931
932    /// Invalidate a specific cache entry
933    pub fn invalidate(&self, text: &str) -> Result<()> {
934        let cache_key = self.cache.compute_key(text, &self.model_name());
935        self.cache.invalidate(&cache_key)
936    }
937
938    /// Check if caching is enabled
939    pub fn is_cache_enabled(&self) -> bool {
940        self.cache.is_enabled()
941    }
942}
943
944// ============================================================================
945// GPU Acceleration Stubs (TODO)
946// ============================================================================
947
948/// GPU acceleration backend (STUB - see docs/FUTURE_ENHANCEMENTS.md)
949///
950/// This enum represents potential GPU acceleration options for embedding models.
951/// Currently not implemented - all models run on CPU.
952///
953/// # Future Implementation
954///
955/// - **CUDA**: NVIDIA GPU acceleration via ONNX Runtime CUDA provider
956/// - **Metal**: Apple Silicon GPU acceleration via ONNX Runtime CoreML provider
957/// - **Vulkan**: Cross-platform GPU acceleration via ONNX Runtime Vulkan provider
958/// - **Candle**: GPU support for Qwen3 models via Candle's CUDA backend
959#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
960#[serde(rename_all = "lowercase")]
961#[allow(dead_code)]
962#[derive(Default)]
963pub enum AccelerationBackend {
964    /// CPU execution (default, always available)
965    #[default]
966    Cpu,
967    /// NVIDIA CUDA acceleration
968    Cuda {
969        /// The CUDA device ID to use for computation.
970        device_id: usize,
971    },
972    /// Apple Metal acceleration
973    Metal,
974    /// Vulkan GPU acceleration
975    Vulkan,
976}
977
978// ============================================================================
979// Legacy API Compatibility
980// ============================================================================
981
982/// Legacy embedding service for backward compatibility
983///
984/// This preserves the original API for existing code.
985#[deprecated(note = "Use EmbeddingService instead")]
986pub struct LegacyEmbeddingService {
987    inner: EmbeddingService,
988}
989
990#[allow(deprecated)]
991impl LegacyEmbeddingService {
992    /// Create a new legacy embedding service
993    pub fn new(_model_name: &str) -> Result<Self> {
994        Ok(Self {
995            inner: EmbeddingService::with_default_model()?,
996        })
997    }
998
999    /// Embed texts (synchronous API)
1000    pub fn embed(&mut self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
1001        let model_type = self.inner.config.model.to_fastembed_model();
1002        let mut model =
1003            TextEmbedding::try_new(InitOptions::new(model_type).with_show_download_progress(true))
1004                .map_err(|e| AppError::Internal(e.to_string()))?;
1005
1006        model
1007            .embed(texts, None)
1008            .map_err(|e| AppError::Internal(e.to_string()))
1009    }
1010}
1011
1012// ============================================================================
1013// Tests
1014// ============================================================================
1015
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_dimensions() {
        let expected = [
            (EmbeddingModelType::BgeSmallEnV15, 384),
            (EmbeddingModelType::BgeBaseEnV15, 768),
            (EmbeddingModelType::BgeLargeEnV15, 1024),
            (EmbeddingModelType::MultilingualE5Large, 1024),
        ];
        for (model, dims) in expected {
            assert_eq!(model.dimensions(), dims, "dimensions for {:?}", model);
        }
    }

    #[test]
    fn test_model_from_str() {
        let cases = [
            ("bge-small-en-v1.5", EmbeddingModelType::BgeSmallEnV15),
            ("multilingual-e5-large", EmbeddingModelType::MultilingualE5Large),
            ("minilm-l6", EmbeddingModelType::AllMiniLmL6V2),
        ];
        for (name, model) in cases {
            let parsed = name.parse::<EmbeddingModelType>().unwrap();
            assert_eq!(parsed, model, "parsing {:?}", name);
        }
    }

    #[test]
    fn test_model_is_multilingual() {
        for model in [
            EmbeddingModelType::MultilingualE5Small,
            EmbeddingModelType::MultilingualE5Large,
        ] {
            assert!(model.is_multilingual(), "{:?} should be multilingual", model);
        }
        assert!(!EmbeddingModelType::BgeSmallEnV15.is_multilingual());
    }

    #[test]
    fn test_model_max_context() {
        let expected = [
            (EmbeddingModelType::NomicEmbedTextV15, 8192),
            (EmbeddingModelType::NomicEmbedTextV1, 8192),
            (EmbeddingModelType::BgeSmallEnV15, 512),
        ];
        for (model, context) in expected {
            assert_eq!(model.max_context_length(), context, "context for {:?}", model);
        }
    }

    #[test]
    fn test_default_config() {
        let defaults = EmbeddingConfig::default();
        assert!(defaults.show_download_progress);
        assert!(!defaults.sparse_enabled);
        assert_eq!(defaults.batch_size, 32);
        assert_eq!(defaults.model, EmbeddingModelType::BgeSmallEnV15);
    }

    #[test]
    fn test_all_models_listed() {
        let catalogue = EmbeddingModelType::all();
        // The catalogue currently holds 38+ models.
        assert!(
            catalogue.len() >= 38,
            "expected at least 38 models, found {}",
            catalogue.len()
        );
        for required in [
            EmbeddingModelType::BgeSmallEnV15,
            EmbeddingModelType::MultilingualE5Large,
        ] {
            assert!(catalogue.contains(&required), "missing {:?}", required);
        }
    }
}