ares/rag/
embeddings.rs

1//! Embedding Service for RAG
2//!
3//! This module provides a comprehensive embedding service with support for:
4//! - 30+ text embedding models (BGE, Qwen3, Gemma, E5, Jina, etc.)
//! - Sparse embeddings for hybrid search (SPLADE++; BGE-M3 sparse mode is not available in fastembed 5.5.0)
6//! - Reranking models (BGE, Jina)
7//! - Async embedding via `spawn_blocking`
8//!
9//! # GPU Acceleration (TODO)
10//! GPU acceleration is planned for future iterations. See `docs/FUTURE_ENHANCEMENTS.md`.
11//! Potential approach:
12//! - Add feature flags: `cuda`, `metal`, `vulkan`
13//! - Use ORT execution providers for ONNX models
14//! - Use Candle GPU features for Qwen3 models
15//!
16//! # Embedding Cache (TODO)
17//! Embedding caching is deferred. See `docs/FUTURE_ENHANCEMENTS.md` and `src/rag/cache.rs`.
18
19use crate::types::{AppError, Result};
20use serde::{Deserialize, Serialize};
21use std::fmt::Display;
22use std::str::FromStr;
23use tokio::task::spawn_blocking;
24
25// Re-export fastembed types for convenience
26pub use fastembed::{
27    EmbeddingModel as FastEmbedModel, InitOptions, SparseModel, TextEmbedding,
28};
29
30// ============================================================================
31// Embedding Model Configuration
32// ============================================================================
33
34/// Supported embedding models with their metadata.
35///
36/// This enum wraps fastembed's EmbeddingModel with additional metadata
37/// for easier configuration and selection.
38#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
39#[serde(rename_all = "kebab-case")]
40pub enum EmbeddingModelType {
41    // Fast English models (recommended defaults)
42    /// BAAI/bge-small-en-v1.5 - Fast, 384 dimensions (DEFAULT)
43    #[default]
44    BgeSmallEnV15,
45    /// Quantized BAAI/bge-small-en-v1.5
46    BgeSmallEnV15Q,
47    /// sentence-transformers/all-MiniLM-L6-v2 - Very fast, 384 dimensions
48    AllMiniLmL6V2,
49    /// Quantized all-MiniLM-L6-v2
50    AllMiniLmL6V2Q,
51    /// sentence-transformers/all-MiniLM-L12-v2 - Better quality, 384 dimensions
52    AllMiniLmL12V2,
53    /// Quantized all-MiniLM-L12-v2
54    AllMiniLmL12V2Q,
55    /// sentence-transformers/all-mpnet-base-v2 - 768 dimensions
56    AllMpnetBaseV2,
57
58    // High quality English models
59    /// BAAI/bge-base-en-v1.5 - 768 dimensions
60    BgeBaseEnV15,
61    /// Quantized BAAI/bge-base-en-v1.5
62    BgeBaseEnV15Q,
63    /// BAAI/bge-large-en-v1.5 - 1024 dimensions
64    BgeLargeEnV15,
65    /// Quantized BAAI/bge-large-en-v1.5
66    BgeLargeEnV15Q,
67
68    // Multilingual models
69    // NOTE: BGE-M3 is not available in fastembed 5.5.0, use MultilingualE5 instead
70    /// intfloat/multilingual-e5-small - 384 dimensions
71    MultilingualE5Small,
72    /// intfloat/multilingual-e5-base - 768 dimensions
73    MultilingualE5Base,
74    /// intfloat/multilingual-e5-large - 1024 dimensions
75    MultilingualE5Large,
76    /// sentence-transformers/paraphrase-MiniLM-L12-v2
77    ParaphraseMiniLmL12V2,
78    /// Quantized paraphrase-MiniLM-L12-v2
79    ParaphraseMiniLmL12V2Q,
80    /// sentence-transformers/paraphrase-multilingual-mpnet-base-v2 - 768 dimensions
81    ParaphraseMultilingualMpnetBaseV2,
82
83    // Chinese models
84    /// BAAI/bge-small-zh-v1.5 - 512 dimensions
85    BgeSmallZhV15,
86    /// BAAI/bge-large-zh-v1.5 - 1024 dimensions
87    BgeLargeZhV15,
88
89    // Long context models
90    /// nomic-ai/nomic-embed-text-v1 - 768 dimensions, 8192 context
91    NomicEmbedTextV1,
92    /// nomic-ai/nomic-embed-text-v1.5 - 768 dimensions, 8192 context
93    NomicEmbedTextV15,
94    /// Quantized nomic-embed-text-v1.5
95    NomicEmbedTextV15Q,
96
97    // Specialized models
98    /// mixedbread-ai/mxbai-embed-large-v1 - 1024 dimensions
99    MxbaiEmbedLargeV1,
100    /// Quantized mxbai-embed-large-v1
101    MxbaiEmbedLargeV1Q,
102    /// Alibaba-NLP/gte-base-en-v1.5 - 768 dimensions
103    GteBaseEnV15,
104    /// Quantized gte-base-en-v1.5
105    GteBaseEnV15Q,
106    /// Alibaba-NLP/gte-large-en-v1.5 - 1024 dimensions
107    GteLargeEnV15,
108    /// Quantized gte-large-en-v1.5
109    GteLargeEnV15Q,
110    /// Qdrant/clip-ViT-B-32-text - 512 dimensions, pairs with vision model
111    ClipVitB32,
112
113    // Code models
114    /// jinaai/jina-embeddings-v2-base-code - 768 dimensions
115    JinaEmbeddingsV2BaseCode,
116    // NOTE: JinaEmbeddingsV2BaseEN is not available in fastembed 5.5.0
117
118    // Modern models
119    /// google/embeddinggemma-300m - 768 dimensions
120    EmbeddingGemma300M,
121    /// lightonai/modernbert-embed-large - 1024 dimensions
122    ModernBertEmbedLarge,
123
124    // Snowflake Arctic models
125    /// snowflake/snowflake-arctic-embed-xs - 384 dimensions
126    SnowflakeArcticEmbedXs,
127    /// Quantized snowflake-arctic-embed-xs
128    SnowflakeArcticEmbedXsQ,
129    /// snowflake/snowflake-arctic-embed-s - 384 dimensions
130    SnowflakeArcticEmbedS,
131    /// Quantized snowflake-arctic-embed-s
132    SnowflakeArcticEmbedSQ,
133    /// snowflake/snowflake-arctic-embed-m - 768 dimensions
134    SnowflakeArcticEmbedM,
135    /// Quantized snowflake-arctic-embed-m
136    SnowflakeArcticEmbedMQ,
137    /// snowflake/snowflake-arctic-embed-m-long - 768 dimensions, 2048 context
138    SnowflakeArcticEmbedMLong,
139    /// Quantized snowflake-arctic-embed-m-long
140    SnowflakeArcticEmbedMLongQ,
141    /// snowflake/snowflake-arctic-embed-l - 1024 dimensions
142    SnowflakeArcticEmbedL,
143    /// Quantized snowflake-arctic-embed-l
144    SnowflakeArcticEmbedLQ,
145}
146
147impl EmbeddingModelType {
148    /// Convert to fastembed's EmbeddingModel enum
149    pub fn to_fastembed_model(&self) -> FastEmbedModel {
150        match self {
151            // Fast English
152            Self::BgeSmallEnV15 => FastEmbedModel::BGESmallENV15,
153            Self::BgeSmallEnV15Q => FastEmbedModel::BGESmallENV15Q,
154            Self::AllMiniLmL6V2 => FastEmbedModel::AllMiniLML6V2,
155            Self::AllMiniLmL6V2Q => FastEmbedModel::AllMiniLML6V2Q,
156            Self::AllMiniLmL12V2 => FastEmbedModel::AllMiniLML12V2,
157            Self::AllMiniLmL12V2Q => FastEmbedModel::AllMiniLML12V2Q,
158            Self::AllMpnetBaseV2 => FastEmbedModel::AllMpnetBaseV2,
159
160            // High quality English
161            Self::BgeBaseEnV15 => FastEmbedModel::BGEBaseENV15,
162            Self::BgeBaseEnV15Q => FastEmbedModel::BGEBaseENV15Q,
163            Self::BgeLargeEnV15 => FastEmbedModel::BGELargeENV15,
164            Self::BgeLargeEnV15Q => FastEmbedModel::BGELargeENV15Q,
165
166            // Multilingual
167            Self::MultilingualE5Small => FastEmbedModel::MultilingualE5Small,
168            Self::MultilingualE5Base => FastEmbedModel::MultilingualE5Base,
169            Self::MultilingualE5Large => FastEmbedModel::MultilingualE5Large,
170            Self::ParaphraseMiniLmL12V2 => FastEmbedModel::ParaphraseMLMiniLML12V2,
171            Self::ParaphraseMiniLmL12V2Q => FastEmbedModel::ParaphraseMLMiniLML12V2Q,
172            Self::ParaphraseMultilingualMpnetBaseV2 => FastEmbedModel::ParaphraseMLMpnetBaseV2,
173
174            // Chinese
175            Self::BgeSmallZhV15 => FastEmbedModel::BGESmallZHV15,
176            Self::BgeLargeZhV15 => FastEmbedModel::BGELargeZHV15,
177
178            // Long context
179            Self::NomicEmbedTextV1 => FastEmbedModel::NomicEmbedTextV1,
180            Self::NomicEmbedTextV15 => FastEmbedModel::NomicEmbedTextV15,
181            Self::NomicEmbedTextV15Q => FastEmbedModel::NomicEmbedTextV15Q,
182
183            // Specialized
184            Self::MxbaiEmbedLargeV1 => FastEmbedModel::MxbaiEmbedLargeV1,
185            Self::MxbaiEmbedLargeV1Q => FastEmbedModel::MxbaiEmbedLargeV1Q,
186            Self::GteBaseEnV15 => FastEmbedModel::GTEBaseENV15,
187            Self::GteBaseEnV15Q => FastEmbedModel::GTEBaseENV15Q,
188            Self::GteLargeEnV15 => FastEmbedModel::GTELargeENV15,
189            Self::GteLargeEnV15Q => FastEmbedModel::GTELargeENV15Q,
190            Self::ClipVitB32 => FastEmbedModel::ClipVitB32,
191
192            // Code
193            Self::JinaEmbeddingsV2BaseCode => FastEmbedModel::JinaEmbeddingsV2BaseCode,
194
195            // Modern
196            Self::EmbeddingGemma300M => FastEmbedModel::EmbeddingGemma300M,
197            Self::ModernBertEmbedLarge => FastEmbedModel::ModernBertEmbedLarge,
198
199            // Snowflake Arctic
200            Self::SnowflakeArcticEmbedXs => FastEmbedModel::SnowflakeArcticEmbedXS,
201            Self::SnowflakeArcticEmbedXsQ => FastEmbedModel::SnowflakeArcticEmbedXSQ,
202            Self::SnowflakeArcticEmbedS => FastEmbedModel::SnowflakeArcticEmbedS,
203            Self::SnowflakeArcticEmbedSQ => FastEmbedModel::SnowflakeArcticEmbedSQ,
204            Self::SnowflakeArcticEmbedM => FastEmbedModel::SnowflakeArcticEmbedM,
205            Self::SnowflakeArcticEmbedMQ => FastEmbedModel::SnowflakeArcticEmbedMQ,
206            Self::SnowflakeArcticEmbedMLong => FastEmbedModel::SnowflakeArcticEmbedMLong,
207            Self::SnowflakeArcticEmbedMLongQ => FastEmbedModel::SnowflakeArcticEmbedMLongQ,
208            Self::SnowflakeArcticEmbedL => FastEmbedModel::SnowflakeArcticEmbedL,
209            Self::SnowflakeArcticEmbedLQ => FastEmbedModel::SnowflakeArcticEmbedLQ,
210        }
211    }
212
213    /// Get the dimension of the embedding output
214    pub fn dimensions(&self) -> usize {
215        match self {
216            // 384 dimensions
217            Self::BgeSmallEnV15
218            | Self::BgeSmallEnV15Q
219            | Self::AllMiniLmL6V2
220            | Self::AllMiniLmL6V2Q
221            | Self::AllMiniLmL12V2
222            | Self::AllMiniLmL12V2Q
223            | Self::MultilingualE5Small
224            | Self::SnowflakeArcticEmbedXs
225            | Self::SnowflakeArcticEmbedXsQ
226            | Self::SnowflakeArcticEmbedS
227            | Self::SnowflakeArcticEmbedSQ => 384,
228
229            // 512 dimensions
230            Self::BgeSmallZhV15 | Self::ClipVitB32 => 512,
231
232            // 768 dimensions
233            Self::AllMpnetBaseV2
234            | Self::BgeBaseEnV15
235            | Self::BgeBaseEnV15Q
236            | Self::MultilingualE5Base
237            | Self::ParaphraseMiniLmL12V2
238            | Self::ParaphraseMiniLmL12V2Q
239            | Self::ParaphraseMultilingualMpnetBaseV2
240            | Self::NomicEmbedTextV1
241            | Self::NomicEmbedTextV15
242            | Self::NomicEmbedTextV15Q
243            | Self::GteBaseEnV15
244            | Self::GteBaseEnV15Q
245            | Self::JinaEmbeddingsV2BaseCode
246            | Self::EmbeddingGemma300M
247            | Self::SnowflakeArcticEmbedM
248            | Self::SnowflakeArcticEmbedMQ
249            | Self::SnowflakeArcticEmbedMLong
250            | Self::SnowflakeArcticEmbedMLongQ => 768,
251
252            // 1024 dimensions
253            Self::BgeLargeEnV15
254            | Self::BgeLargeEnV15Q
255            | Self::BgeLargeZhV15
256            | Self::MultilingualE5Large
257            | Self::MxbaiEmbedLargeV1
258            | Self::MxbaiEmbedLargeV1Q
259            | Self::GteLargeEnV15
260            | Self::GteLargeEnV15Q
261            | Self::ModernBertEmbedLarge
262            | Self::SnowflakeArcticEmbedL
263            | Self::SnowflakeArcticEmbedLQ => 1024,
264        }
265    }
266
267    /// Check if this is a quantized model
268    pub fn is_quantized(&self) -> bool {
269        matches!(
270            self,
271            Self::BgeSmallEnV15Q
272                | Self::AllMiniLmL6V2Q
273                | Self::AllMiniLmL12V2Q
274                | Self::BgeBaseEnV15Q
275                | Self::BgeLargeEnV15Q
276                | Self::ParaphraseMiniLmL12V2Q
277                | Self::NomicEmbedTextV15Q
278                | Self::MxbaiEmbedLargeV1Q
279                | Self::GteBaseEnV15Q
280                | Self::GteLargeEnV15Q
281                | Self::SnowflakeArcticEmbedXsQ
282                | Self::SnowflakeArcticEmbedSQ
283                | Self::SnowflakeArcticEmbedMQ
284                | Self::SnowflakeArcticEmbedMLongQ
285                | Self::SnowflakeArcticEmbedLQ
286        )
287    }
288
289    /// Check if this model supports multilingual text
290    pub fn is_multilingual(&self) -> bool {
291        matches!(
292            self,
293            Self::MultilingualE5Small
294                | Self::MultilingualE5Base
295                | Self::MultilingualE5Large
296                | Self::ParaphraseMultilingualMpnetBaseV2
297                | Self::BgeSmallZhV15
298                | Self::BgeLargeZhV15
299        )
300    }
301
302    /// Get the maximum context length in tokens
303    pub fn max_context_length(&self) -> usize {
304        match self {
305            Self::NomicEmbedTextV1 | Self::NomicEmbedTextV15 | Self::NomicEmbedTextV15Q => 8192,
306            Self::SnowflakeArcticEmbedMLong | Self::SnowflakeArcticEmbedMLongQ => 2048,
307            _ => 512,
308        }
309    }
310
311    /// List all available models
312    pub fn all() -> Vec<Self> {
313        vec![
314            Self::BgeSmallEnV15,
315            Self::BgeSmallEnV15Q,
316            Self::AllMiniLmL6V2,
317            Self::AllMiniLmL6V2Q,
318            Self::AllMiniLmL12V2,
319            Self::AllMiniLmL12V2Q,
320            Self::AllMpnetBaseV2,
321            Self::BgeBaseEnV15,
322            Self::BgeBaseEnV15Q,
323            Self::BgeLargeEnV15,
324            Self::BgeLargeEnV15Q,
325            Self::MultilingualE5Small,
326            Self::MultilingualE5Base,
327            Self::MultilingualE5Large,
328            Self::ParaphraseMiniLmL12V2,
329            Self::ParaphraseMiniLmL12V2Q,
330            Self::ParaphraseMultilingualMpnetBaseV2,
331            Self::BgeSmallZhV15,
332            Self::BgeLargeZhV15,
333            Self::NomicEmbedTextV1,
334            Self::NomicEmbedTextV15,
335            Self::NomicEmbedTextV15Q,
336            Self::MxbaiEmbedLargeV1,
337            Self::MxbaiEmbedLargeV1Q,
338            Self::GteBaseEnV15,
339            Self::GteBaseEnV15Q,
340            Self::GteLargeEnV15,
341            Self::GteLargeEnV15Q,
342            Self::ClipVitB32,
343            Self::JinaEmbeddingsV2BaseCode,
344            Self::EmbeddingGemma300M,
345            Self::ModernBertEmbedLarge,
346            Self::SnowflakeArcticEmbedXs,
347            Self::SnowflakeArcticEmbedXsQ,
348            Self::SnowflakeArcticEmbedS,
349            Self::SnowflakeArcticEmbedSQ,
350            Self::SnowflakeArcticEmbedM,
351            Self::SnowflakeArcticEmbedMQ,
352            Self::SnowflakeArcticEmbedMLong,
353            Self::SnowflakeArcticEmbedMLongQ,
354            Self::SnowflakeArcticEmbedL,
355            Self::SnowflakeArcticEmbedLQ,
356        ]
357    }
358}
359
360impl Display for EmbeddingModelType {
361    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
362        let name = match self {
363            Self::BgeSmallEnV15 => "bge-small-en-v1.5",
364            Self::BgeSmallEnV15Q => "bge-small-en-v1.5-q",
365            Self::AllMiniLmL6V2 => "all-minilm-l6-v2",
366            Self::AllMiniLmL6V2Q => "all-minilm-l6-v2-q",
367            Self::AllMiniLmL12V2 => "all-minilm-l12-v2",
368            Self::AllMiniLmL12V2Q => "all-minilm-l12-v2-q",
369            Self::AllMpnetBaseV2 => "all-mpnet-base-v2",
370            Self::BgeBaseEnV15 => "bge-base-en-v1.5",
371            Self::BgeBaseEnV15Q => "bge-base-en-v1.5-q",
372            Self::BgeLargeEnV15 => "bge-large-en-v1.5",
373            Self::BgeLargeEnV15Q => "bge-large-en-v1.5-q",
374            Self::MultilingualE5Small => "multilingual-e5-small",
375            Self::MultilingualE5Base => "multilingual-e5-base",
376            Self::MultilingualE5Large => "multilingual-e5-large",
377            Self::ParaphraseMiniLmL12V2 => "paraphrase-minilm-l12-v2",
378            Self::ParaphraseMiniLmL12V2Q => "paraphrase-minilm-l12-v2-q",
379            Self::ParaphraseMultilingualMpnetBaseV2 => "paraphrase-multilingual-mpnet-base-v2",
380            Self::BgeSmallZhV15 => "bge-small-zh-v1.5",
381            Self::BgeLargeZhV15 => "bge-large-zh-v1.5",
382            Self::NomicEmbedTextV1 => "nomic-embed-text-v1",
383            Self::NomicEmbedTextV15 => "nomic-embed-text-v1.5",
384            Self::NomicEmbedTextV15Q => "nomic-embed-text-v1.5-q",
385            Self::MxbaiEmbedLargeV1 => "mxbai-embed-large-v1",
386            Self::MxbaiEmbedLargeV1Q => "mxbai-embed-large-v1-q",
387            Self::GteBaseEnV15 => "gte-base-en-v1.5",
388            Self::GteBaseEnV15Q => "gte-base-en-v1.5-q",
389            Self::GteLargeEnV15 => "gte-large-en-v1.5",
390            Self::GteLargeEnV15Q => "gte-large-en-v1.5-q",
391            Self::ClipVitB32 => "clip-vit-b-32",
392            Self::JinaEmbeddingsV2BaseCode => "jina-embeddings-v2-base-code",
393            Self::EmbeddingGemma300M => "embedding-gemma-300m",
394            Self::ModernBertEmbedLarge => "modernbert-embed-large",
395            Self::SnowflakeArcticEmbedXs => "snowflake-arctic-embed-xs",
396            Self::SnowflakeArcticEmbedXsQ => "snowflake-arctic-embed-xs-q",
397            Self::SnowflakeArcticEmbedS => "snowflake-arctic-embed-s",
398            Self::SnowflakeArcticEmbedSQ => "snowflake-arctic-embed-s-q",
399            Self::SnowflakeArcticEmbedM => "snowflake-arctic-embed-m",
400            Self::SnowflakeArcticEmbedMQ => "snowflake-arctic-embed-m-q",
401            Self::SnowflakeArcticEmbedMLong => "snowflake-arctic-embed-m-long",
402            Self::SnowflakeArcticEmbedMLongQ => "snowflake-arctic-embed-m-long-q",
403            Self::SnowflakeArcticEmbedL => "snowflake-arctic-embed-l",
404            Self::SnowflakeArcticEmbedLQ => "snowflake-arctic-embed-l-q",
405        };
406        write!(f, "{}", name)
407    }
408}
409
410impl FromStr for EmbeddingModelType {
411    type Err = AppError;
412
413    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
414        match s.to_lowercase().as_str() {
415            "bge-small-en-v1.5" | "bge-small-en" | "bge-small" => Ok(Self::BgeSmallEnV15),
416            "bge-small-en-v1.5-q" => Ok(Self::BgeSmallEnV15Q),
417            "all-minilm-l6-v2" | "minilm-l6" => Ok(Self::AllMiniLmL6V2),
418            "all-minilm-l6-v2-q" => Ok(Self::AllMiniLmL6V2Q),
419            "all-minilm-l12-v2" | "minilm-l12" => Ok(Self::AllMiniLmL12V2),
420            "all-minilm-l12-v2-q" => Ok(Self::AllMiniLmL12V2Q),
421            "all-mpnet-base-v2" | "mpnet" => Ok(Self::AllMpnetBaseV2),
422            "bge-base-en-v1.5" | "bge-base-en" | "bge-base" => Ok(Self::BgeBaseEnV15),
423            "bge-base-en-v1.5-q" => Ok(Self::BgeBaseEnV15Q),
424            "bge-large-en-v1.5" | "bge-large-en" | "bge-large" => Ok(Self::BgeLargeEnV15),
425            "bge-large-en-v1.5-q" => Ok(Self::BgeLargeEnV15Q),
426            "multilingual-e5-small" | "e5-small" => Ok(Self::MultilingualE5Small),
427            "multilingual-e5-base" | "e5-base" => Ok(Self::MultilingualE5Base),
428            "multilingual-e5-large" | "e5-large" => Ok(Self::MultilingualE5Large),
429            "paraphrase-minilm-l12-v2" => Ok(Self::ParaphraseMiniLmL12V2),
430            "paraphrase-minilm-l12-v2-q" => Ok(Self::ParaphraseMiniLmL12V2Q),
431            "paraphrase-multilingual-mpnet-base-v2" => Ok(Self::ParaphraseMultilingualMpnetBaseV2),
432            "bge-small-zh-v1.5" | "bge-small-zh" => Ok(Self::BgeSmallZhV15),
433            "bge-large-zh-v1.5" | "bge-large-zh" => Ok(Self::BgeLargeZhV15),
434            "nomic-embed-text-v1" | "nomic-v1" => Ok(Self::NomicEmbedTextV1),
435            "nomic-embed-text-v1.5" | "nomic-v1.5" | "nomic" => Ok(Self::NomicEmbedTextV15),
436            "nomic-embed-text-v1.5-q" => Ok(Self::NomicEmbedTextV15Q),
437            "mxbai-embed-large-v1" | "mxbai" => Ok(Self::MxbaiEmbedLargeV1),
438            "mxbai-embed-large-v1-q" => Ok(Self::MxbaiEmbedLargeV1Q),
439            "gte-base-en-v1.5" | "gte-base" => Ok(Self::GteBaseEnV15),
440            "gte-base-en-v1.5-q" => Ok(Self::GteBaseEnV15Q),
441            "gte-large-en-v1.5" | "gte-large" => Ok(Self::GteLargeEnV15),
442            "gte-large-en-v1.5-q" => Ok(Self::GteLargeEnV15Q),
443            "clip-vit-b-32" | "clip" => Ok(Self::ClipVitB32),
444            "jina-embeddings-v2-base-code" | "jina-code" => Ok(Self::JinaEmbeddingsV2BaseCode),
445            "embedding-gemma-300m" | "gemma-300m" | "gemma" => Ok(Self::EmbeddingGemma300M),
446            "modernbert-embed-large" | "modernbert" => Ok(Self::ModernBertEmbedLarge),
447            "snowflake-arctic-embed-xs" => Ok(Self::SnowflakeArcticEmbedXs),
448            "snowflake-arctic-embed-xs-q" => Ok(Self::SnowflakeArcticEmbedXsQ),
449            "snowflake-arctic-embed-s" => Ok(Self::SnowflakeArcticEmbedS),
450            "snowflake-arctic-embed-s-q" => Ok(Self::SnowflakeArcticEmbedSQ),
451            "snowflake-arctic-embed-m" => Ok(Self::SnowflakeArcticEmbedM),
452            "snowflake-arctic-embed-m-q" => Ok(Self::SnowflakeArcticEmbedMQ),
453            "snowflake-arctic-embed-m-long" => Ok(Self::SnowflakeArcticEmbedMLong),
454            "snowflake-arctic-embed-m-long-q" => Ok(Self::SnowflakeArcticEmbedMLongQ),
455            "snowflake-arctic-embed-l" | "snowflake-l" => Ok(Self::SnowflakeArcticEmbedL),
456            "snowflake-arctic-embed-l-q" => Ok(Self::SnowflakeArcticEmbedLQ),
457            _ => Err(AppError::Internal(format!(
458                "Unknown embedding model: {}. Use one of: {}",
459                s,
460                EmbeddingModelType::all()
461                    .iter()
462                    .map(|m| m.to_string())
463                    .collect::<Vec<_>>()
464                    .join(", ")
465            ))),
466        }
467    }
468}
469
470// ============================================================================
471// Sparse Embedding Model Configuration
472// ============================================================================
473
474/// Supported sparse embedding models for hybrid search
475#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
476#[serde(rename_all = "kebab-case")]
477pub enum SparseModelType {
478    /// SPLADE++ v1 - English sparse embeddings
479    #[default]
480    SpladePpV1,
481    // NOTE: BGE-M3 sparse mode is not available in fastembed 5.5.0
482}
483
484impl SparseModelType {
485    /// Convert to fastembed's SparseModel enum
486    pub fn to_fastembed_model(&self) -> SparseModel {
487        match self {
488            Self::SpladePpV1 => SparseModel::SPLADEPPV1,
489        }
490    }
491}
492
493impl Display for SparseModelType {
494    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
495        let name = match self {
496            Self::SpladePpV1 => "splade-pp-v1",
497        };
498        write!(f, "{}", name)
499    }
500}
501
502impl FromStr for SparseModelType {
503    type Err = AppError;
504
505    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
506        match s.to_lowercase().as_str() {
507            "splade-pp-v1" | "splade" => Ok(Self::SpladePpV1),
508            _ => Err(AppError::Internal(format!(
509                "Unknown sparse model: {}. Use: splade-pp-v1",
510                s
511            ))),
512        }
513    }
514}
515
516// ============================================================================
517// Embedding Service Configuration
518// ============================================================================
519
520/// Configuration for the embedding service
521#[derive(Debug, Clone, Serialize, Deserialize)]
522pub struct EmbeddingConfig {
523    /// The embedding model to use
524    #[serde(default)]
525    pub model: EmbeddingModelType,
526
527    /// Batch size for embedding multiple texts
528    #[serde(default = "default_batch_size")]
529    pub batch_size: usize,
530
531    /// Show download progress for first-time model downloads
532    #[serde(default = "default_show_progress")]
533    pub show_download_progress: bool,
534
535    /// Enable sparse embeddings for hybrid search
536    #[serde(default)]
537    pub sparse_enabled: bool,
538
539    /// Sparse embedding model to use
540    #[serde(default)]
541    pub sparse_model: SparseModelType,
542}
543
/// Serde/`Default` fallback for `EmbeddingConfig::batch_size`.
fn default_batch_size() -> usize {
    const DEFAULT_BATCH_SIZE: usize = 32;
    DEFAULT_BATCH_SIZE
}
547
/// Serde/`Default` fallback for `EmbeddingConfig::show_download_progress`: progress is shown.
fn default_show_progress() -> bool {
    const SHOW_PROGRESS_BY_DEFAULT: bool = true;
    SHOW_PROGRESS_BY_DEFAULT
}
551
552impl Default for EmbeddingConfig {
553    fn default() -> Self {
554        Self {
555            model: EmbeddingModelType::default(),
556            batch_size: default_batch_size(),
557            show_download_progress: default_show_progress(),
558            sparse_enabled: false,
559            sparse_model: SparseModelType::default(),
560        }
561    }
562}
563
564// ============================================================================
565// Embedding Service
566// ============================================================================
567
568/// Main embedding service for generating text embeddings
569///
570/// Uses `spawn_blocking` to run fastembed's synchronous operations
571/// without blocking the async runtime.
572pub struct EmbeddingService {
573    #[allow(dead_code)]
574    model: TextEmbedding,
575    #[allow(dead_code)]
576    sparse_model: Option<fastembed::SparseTextEmbedding>,
577    config: EmbeddingConfig,
578}
579
580impl EmbeddingService {
581    /// Create a new embedding service with the given configuration
582    pub fn new(config: EmbeddingConfig) -> Result<Self> {
583        let model = TextEmbedding::try_new(
584            InitOptions::new(config.model.to_fastembed_model())
585                .with_show_download_progress(config.show_download_progress),
586        )
587        .map_err(|e| AppError::Internal(format!("Failed to initialize embedding model: {}", e)))?;
588
589        let sparse_model = if config.sparse_enabled {
590            Some(
591                fastembed::SparseTextEmbedding::try_new(
592                    fastembed::SparseInitOptions::new(config.sparse_model.to_fastembed_model())
593                        .with_show_download_progress(config.show_download_progress),
594                )
595                .map_err(|e| {
596                    AppError::Internal(format!("Failed to initialize sparse embedding model: {}", e))
597                })?,
598            )
599        } else {
600            None
601        };
602
603        Ok(Self {
604            model,
605            sparse_model,
606            config,
607        })
608    }
609
610    /// Create a new embedding service with the default model
611    pub fn with_default_model() -> Result<Self> {
612        Self::new(EmbeddingConfig::default())
613    }
614
615    /// Create a new embedding service with a specific model
616    pub fn with_model(model: EmbeddingModelType) -> Result<Self> {
617        Self::new(EmbeddingConfig {
618            model,
619            ..Default::default()
620        })
621    }
622
623    /// Get the current model type
624    pub fn model_type(&self) -> EmbeddingModelType {
625        self.config.model
626    }
627
628    /// Get the embedding dimensions
629    pub fn dimensions(&self) -> usize {
630        self.config.model.dimensions()
631    }
632
633    /// Get the configuration
634    pub fn config(&self) -> &EmbeddingConfig {
635        &self.config
636    }
637
638    /// Embed a single text (async via spawn_blocking)
639    pub async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
640        let embeddings = self.embed_texts(&[text.to_string()]).await?;
641        embeddings
642            .into_iter()
643            .next()
644            .ok_or_else(|| AppError::Internal("No embedding generated".to_string()))
645    }
646
647    /// Embed multiple texts in batches (async via spawn_blocking)
648    ///
649    /// This is more efficient than calling `embed_text` multiple times
650    /// as it batches the texts and processes them together.
651    pub async fn embed_texts<S: AsRef<str> + Send + Sync + 'static>(
652        &self,
653        texts: &[S],
654    ) -> Result<Vec<Vec<f32>>> {
655        if texts.is_empty() {
656            return Ok(vec![]);
657        }
658
659        // Clone texts to owned strings for the spawn_blocking closure
660        let texts_owned: Vec<String> = texts.iter().map(|s| s.as_ref().to_string()).collect();
661        let batch_size = self.config.batch_size;
662
663        // Clone the model config for the blocking task
664        let model_type = self.config.model.to_fastembed_model();
665        let show_progress = self.config.show_download_progress;
666
667        spawn_blocking(move || {
668            // Create model in the blocking context
669            let mut model = TextEmbedding::try_new(
670                InitOptions::new(model_type).with_show_download_progress(show_progress),
671            )
672            .map_err(|e| {
673                AppError::Internal(format!("Failed to initialize embedding model: {}", e))
674            })?;
675
676            let refs: Vec<&str> = texts_owned.iter().map(|s| s.as_str()).collect();
677            model
678                .embed(refs, Some(batch_size))
679                .map_err(|e| AppError::Internal(format!("Embedding failed: {}", e)))
680        })
681        .await
682        .map_err(|e| AppError::Internal(format!("Blocking task failed: {}", e)))?
683    }
684
685    /// Generate sparse embeddings for hybrid search
686    pub async fn embed_sparse<S: AsRef<str> + Send + Sync + 'static>(
687        &self,
688        texts: &[S],
689    ) -> Result<Vec<fastembed::SparseEmbedding>> {
690        if self.sparse_model.is_none() {
691            return Err(AppError::Internal(
692                "Sparse embeddings not enabled. Set sparse_enabled: true in config.".to_string(),
693            ));
694        }
695
696        let texts_owned: Vec<String> = texts.iter().map(|s| s.as_ref().to_string()).collect();
697        let batch_size = self.config.batch_size;
698        let sparse_model_type = self.config.sparse_model.to_fastembed_model();
699        let show_progress = self.config.show_download_progress;
700
701        spawn_blocking(move || {
702            let mut model = fastembed::SparseTextEmbedding::try_new(
703                fastembed::SparseInitOptions::new(sparse_model_type)
704                    .with_show_download_progress(show_progress),
705            )
706            .map_err(|e| {
707                AppError::Internal(format!("Failed to initialize sparse model: {}", e))
708            })?;
709
710            let refs: Vec<&str> = texts_owned.iter().map(|s| s.as_str()).collect();
711            model
712                .embed(refs, Some(batch_size))
713                .map_err(|e| AppError::Internal(format!("Sparse embedding failed: {}", e)))
714        })
715        .await
716        .map_err(|e| AppError::Internal(format!("Blocking task failed: {}", e)))?
717    }
718}
719
720// ============================================================================
721// GPU Acceleration Stubs (TODO)
722// ============================================================================
723
724/// GPU acceleration backend (STUB - see docs/FUTURE_ENHANCEMENTS.md)
725///
726/// This enum represents potential GPU acceleration options for embedding models.
727/// Currently not implemented - all models run on CPU.
728///
729/// # Future Implementation
730///
731/// - **CUDA**: NVIDIA GPU acceleration via ONNX Runtime CUDA provider
732/// - **Metal**: Apple Silicon GPU acceleration via ONNX Runtime CoreML provider
733/// - **Vulkan**: Cross-platform GPU acceleration via ONNX Runtime Vulkan provider
734/// - **Candle**: GPU support for Qwen3 models via Candle's CUDA backend
735#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
736#[serde(rename_all = "lowercase")]
737#[allow(dead_code)]
738pub enum AccelerationBackend {
739    /// CPU execution (default, always available)
740    Cpu,
741    /// NVIDIA CUDA acceleration
742    Cuda { device_id: usize },
743    /// Apple Metal acceleration
744    Metal,
745    /// Vulkan GPU acceleration
746    Vulkan,
747}
748
749impl Default for AccelerationBackend {
750    fn default() -> Self {
751        Self::Cpu
752    }
753}
754
755// ============================================================================
756// Legacy API Compatibility
757// ============================================================================
758
759/// Legacy embedding service for backward compatibility
760///
761/// This preserves the original API for existing code.
762#[deprecated(note = "Use EmbeddingService instead")]
763pub struct LegacyEmbeddingService {
764    inner: EmbeddingService,
765}
766
767#[allow(deprecated)]
768impl LegacyEmbeddingService {
769    /// Create a new legacy embedding service
770    pub fn new(_model_name: &str) -> Result<Self> {
771        Ok(Self {
772            inner: EmbeddingService::with_default_model()?,
773        })
774    }
775
776    /// Embed texts (synchronous API)
777    pub fn embed(&mut self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
778        let model_type = self.inner.config.model.to_fastembed_model();
779        let mut model = TextEmbedding::try_new(
780            InitOptions::new(model_type).with_show_download_progress(true),
781        )
782        .map_err(|e| AppError::Internal(e.to_string()))?;
783
784        model
785            .embed(texts, None)
786            .map_err(|e| AppError::Internal(e.to_string()))
787    }
788}
789
790// ============================================================================
791// Tests
792// ============================================================================
793
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_dimensions() {
        // (model, expected embedding width)
        let cases = [
            (EmbeddingModelType::BgeSmallEnV15, 384),
            (EmbeddingModelType::BgeBaseEnV15, 768),
            (EmbeddingModelType::BgeLargeEnV15, 1024),
            (EmbeddingModelType::MultilingualE5Large, 1024),
        ];
        for (model, expected) in cases {
            assert_eq!(model.dimensions(), expected, "dimensions of {:?}", model);
        }
    }

    #[test]
    fn test_model_from_str() {
        // (user-facing name, variant it must parse to)
        let cases = [
            ("bge-small-en-v1.5", EmbeddingModelType::BgeSmallEnV15),
            ("multilingual-e5-large", EmbeddingModelType::MultilingualE5Large),
            ("minilm-l6", EmbeddingModelType::AllMiniLmL6V2),
        ];
        for (name, expected) in cases {
            let parsed = name.parse::<EmbeddingModelType>().unwrap();
            assert_eq!(parsed, expected, "parsing {:?}", name);
        }
    }

    #[test]
    fn test_model_is_multilingual() {
        for model in [
            EmbeddingModelType::MultilingualE5Small,
            EmbeddingModelType::MultilingualE5Large,
        ] {
            assert!(model.is_multilingual(), "{:?} should be multilingual", model);
        }
        assert!(!EmbeddingModelType::BgeSmallEnV15.is_multilingual());
    }

    #[test]
    fn test_model_max_context() {
        // (model, expected maximum context length)
        let cases = [
            (EmbeddingModelType::NomicEmbedTextV15, 8192),
            (EmbeddingModelType::NomicEmbedTextV1, 8192),
            (EmbeddingModelType::BgeSmallEnV15, 512),
        ];
        for (model, expected) in cases {
            assert_eq!(
                model.max_context_length(),
                expected,
                "context length of {:?}",
                model
            );
        }
    }

    #[test]
    fn test_default_config() {
        let cfg = EmbeddingConfig::default();
        assert_eq!(cfg.model, EmbeddingModelType::BgeSmallEnV15);
        assert_eq!(cfg.batch_size, 32);
        assert!(cfg.show_download_progress);
        assert!(!cfg.sparse_enabled);
    }

    #[test]
    fn test_all_models_listed() {
        let models = EmbeddingModelType::all();
        // The catalogue currently holds 38+ models.
        assert!(models.len() >= 38);
        for required in [
            EmbeddingModelType::BgeSmallEnV15,
            EmbeddingModelType::MultilingualE5Large,
        ] {
            assert!(models.contains(&required), "{:?} missing from all()", required);
        }
    }
}