1use crate::types::{AppError, Result};
22use serde::{Deserialize, Serialize};
23use std::collections::HashMap;
24use std::fmt::Display;
25use std::str::FromStr;
26use std::sync::{Arc, Mutex, OnceLock};
27use tokio::task::spawn_blocking;
29
30pub use fastembed::{EmbeddingModel as FastEmbedModel, InitOptions, SparseModel, TextEmbedding};
32
/// Process-wide registry of per-model initialization locks, keyed by model name.
///
/// The registry itself is created lazily on the first call to [`get_model_lock`].
static MODEL_INIT_LOCKS: OnceLock<Mutex<HashMap<String, Arc<Mutex<()>>>>> = OnceLock::new();

/// Returns the initialization lock registered for `model_name`, creating it on first use.
///
/// Every call with the same name returns a handle to the same mutex, so callers that
/// lock it are serialized per model name; different names get independent locks.
fn get_model_lock(model_name: &str) -> Arc<Mutex<()>> {
    let registry = MODEL_INIT_LOCKS.get_or_init(|| Mutex::new(HashMap::new()));
    // The map is only ever inserted into, so even a poisoned registry still holds a
    // consistent map. Recover the guard instead of panicking on every later call.
    let mut map = registry
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    map.entry(model_name.to_string())
        .or_insert_with(|| Arc::new(Mutex::new(())))
        .clone()
}
45
/// Dense text-embedding models that can be selected for embedding generation.
///
/// Each variant maps to a `fastembed` model through `to_fastembed_model`. Variants
/// whose name ends in `Q` are the ones `is_quantized` reports as quantized. The
/// quoted names in the per-variant docs are the strings produced by the `Display`
/// impl (and accepted by `FromStr`).
///
/// NOTE(review): the serde representation is derived with
/// `rename_all = "kebab-case"`, which presumably does not reproduce the dotted
/// `Display` names (e.g. `v1.5`) — confirm which spelling serialized configs use.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "kebab-case")]
pub enum EmbeddingModelType {
    /// `bge-small-en-v1.5` — the default model.
    #[default]
    BgeSmallEnV15,
    /// `bge-small-en-v1.5-q`
    BgeSmallEnV15Q,
    /// `all-minilm-l6-v2`
    AllMiniLmL6V2,
    /// `all-minilm-l6-v2-q`
    AllMiniLmL6V2Q,
    /// `all-minilm-l12-v2`
    AllMiniLmL12V2,
    /// `all-minilm-l12-v2-q`
    AllMiniLmL12V2Q,
    /// `all-mpnet-base-v2`
    AllMpnetBaseV2,

    /// `bge-base-en-v1.5`
    BgeBaseEnV15,
    /// `bge-base-en-v1.5-q`
    BgeBaseEnV15Q,
    /// `bge-large-en-v1.5`
    BgeLargeEnV15,
    /// `bge-large-en-v1.5-q`
    BgeLargeEnV15Q,

    /// `multilingual-e5-small`
    MultilingualE5Small,
    /// `multilingual-e5-base`
    MultilingualE5Base,
    /// `multilingual-e5-large`
    MultilingualE5Large,
    /// `paraphrase-minilm-l12-v2`
    ParaphraseMiniLmL12V2,
    /// `paraphrase-minilm-l12-v2-q`
    ParaphraseMiniLmL12V2Q,
    /// `paraphrase-multilingual-mpnet-base-v2`
    ParaphraseMultilingualMpnetBaseV2,

    /// `bge-small-zh-v1.5`
    BgeSmallZhV15,
    /// `bge-large-zh-v1.5`
    BgeLargeZhV15,

    /// `nomic-embed-text-v1`
    NomicEmbedTextV1,
    /// `nomic-embed-text-v1.5`
    NomicEmbedTextV15,
    /// `nomic-embed-text-v1.5-q`
    NomicEmbedTextV15Q,

    /// `mxbai-embed-large-v1`
    MxbaiEmbedLargeV1,
    /// `mxbai-embed-large-v1-q`
    MxbaiEmbedLargeV1Q,
    /// `gte-base-en-v1.5`
    GteBaseEnV15,
    /// `gte-base-en-v1.5-q`
    GteBaseEnV15Q,
    /// `gte-large-en-v1.5`
    GteLargeEnV15,
    /// `gte-large-en-v1.5-q`
    GteLargeEnV15Q,
    /// `clip-vit-b-32`
    ClipVitB32,

    /// `jina-embeddings-v2-base-code`
    JinaEmbeddingsV2BaseCode,
    /// `embedding-gemma-300m`
    EmbeddingGemma300M,
    /// `modernbert-embed-large`
    ModernBertEmbedLarge,

    /// `snowflake-arctic-embed-xs`
    SnowflakeArcticEmbedXs,
    /// `snowflake-arctic-embed-xs-q`
    SnowflakeArcticEmbedXsQ,
    /// `snowflake-arctic-embed-s`
    SnowflakeArcticEmbedS,
    /// `snowflake-arctic-embed-s-q`
    SnowflakeArcticEmbedSQ,
    /// `snowflake-arctic-embed-m`
    SnowflakeArcticEmbedM,
    /// `snowflake-arctic-embed-m-q`
    SnowflakeArcticEmbedMQ,
    /// `snowflake-arctic-embed-m-long`
    SnowflakeArcticEmbedMLong,
    /// `snowflake-arctic-embed-m-long-q`
    SnowflakeArcticEmbedMLongQ,
    /// `snowflake-arctic-embed-l`
    SnowflakeArcticEmbedL,
    /// `snowflake-arctic-embed-l-q`
    SnowflakeArcticEmbedLQ,
}
162
163impl EmbeddingModelType {
164 pub fn to_fastembed_model(&self) -> FastEmbedModel {
166 match self {
167 Self::BgeSmallEnV15 => FastEmbedModel::BGESmallENV15,
169 Self::BgeSmallEnV15Q => FastEmbedModel::BGESmallENV15Q,
170 Self::AllMiniLmL6V2 => FastEmbedModel::AllMiniLML6V2,
171 Self::AllMiniLmL6V2Q => FastEmbedModel::AllMiniLML6V2Q,
172 Self::AllMiniLmL12V2 => FastEmbedModel::AllMiniLML12V2,
173 Self::AllMiniLmL12V2Q => FastEmbedModel::AllMiniLML12V2Q,
174 Self::AllMpnetBaseV2 => FastEmbedModel::AllMpnetBaseV2,
175
176 Self::BgeBaseEnV15 => FastEmbedModel::BGEBaseENV15,
178 Self::BgeBaseEnV15Q => FastEmbedModel::BGEBaseENV15Q,
179 Self::BgeLargeEnV15 => FastEmbedModel::BGELargeENV15,
180 Self::BgeLargeEnV15Q => FastEmbedModel::BGELargeENV15Q,
181
182 Self::MultilingualE5Small => FastEmbedModel::MultilingualE5Small,
184 Self::MultilingualE5Base => FastEmbedModel::MultilingualE5Base,
185 Self::MultilingualE5Large => FastEmbedModel::MultilingualE5Large,
186 Self::ParaphraseMiniLmL12V2 => FastEmbedModel::ParaphraseMLMiniLML12V2,
187 Self::ParaphraseMiniLmL12V2Q => FastEmbedModel::ParaphraseMLMiniLML12V2Q,
188 Self::ParaphraseMultilingualMpnetBaseV2 => FastEmbedModel::ParaphraseMLMpnetBaseV2,
189
190 Self::BgeSmallZhV15 => FastEmbedModel::BGESmallZHV15,
192 Self::BgeLargeZhV15 => FastEmbedModel::BGELargeZHV15,
193
194 Self::NomicEmbedTextV1 => FastEmbedModel::NomicEmbedTextV1,
196 Self::NomicEmbedTextV15 => FastEmbedModel::NomicEmbedTextV15,
197 Self::NomicEmbedTextV15Q => FastEmbedModel::NomicEmbedTextV15Q,
198
199 Self::MxbaiEmbedLargeV1 => FastEmbedModel::MxbaiEmbedLargeV1,
201 Self::MxbaiEmbedLargeV1Q => FastEmbedModel::MxbaiEmbedLargeV1Q,
202 Self::GteBaseEnV15 => FastEmbedModel::GTEBaseENV15,
203 Self::GteBaseEnV15Q => FastEmbedModel::GTEBaseENV15Q,
204 Self::GteLargeEnV15 => FastEmbedModel::GTELargeENV15,
205 Self::GteLargeEnV15Q => FastEmbedModel::GTELargeENV15Q,
206 Self::ClipVitB32 => FastEmbedModel::ClipVitB32,
207
208 Self::JinaEmbeddingsV2BaseCode => FastEmbedModel::JinaEmbeddingsV2BaseCode,
210
211 Self::EmbeddingGemma300M => FastEmbedModel::EmbeddingGemma300M,
213 Self::ModernBertEmbedLarge => FastEmbedModel::ModernBertEmbedLarge,
214
215 Self::SnowflakeArcticEmbedXs => FastEmbedModel::SnowflakeArcticEmbedXS,
217 Self::SnowflakeArcticEmbedXsQ => FastEmbedModel::SnowflakeArcticEmbedXSQ,
218 Self::SnowflakeArcticEmbedS => FastEmbedModel::SnowflakeArcticEmbedS,
219 Self::SnowflakeArcticEmbedSQ => FastEmbedModel::SnowflakeArcticEmbedSQ,
220 Self::SnowflakeArcticEmbedM => FastEmbedModel::SnowflakeArcticEmbedM,
221 Self::SnowflakeArcticEmbedMQ => FastEmbedModel::SnowflakeArcticEmbedMQ,
222 Self::SnowflakeArcticEmbedMLong => FastEmbedModel::SnowflakeArcticEmbedMLong,
223 Self::SnowflakeArcticEmbedMLongQ => FastEmbedModel::SnowflakeArcticEmbedMLongQ,
224 Self::SnowflakeArcticEmbedL => FastEmbedModel::SnowflakeArcticEmbedL,
225 Self::SnowflakeArcticEmbedLQ => FastEmbedModel::SnowflakeArcticEmbedLQ,
226 }
227 }
228
229 pub fn dimensions(&self) -> usize {
231 match self {
232 Self::BgeSmallEnV15
234 | Self::BgeSmallEnV15Q
235 | Self::AllMiniLmL6V2
236 | Self::AllMiniLmL6V2Q
237 | Self::AllMiniLmL12V2
238 | Self::AllMiniLmL12V2Q
239 | Self::MultilingualE5Small
240 | Self::SnowflakeArcticEmbedXs
241 | Self::SnowflakeArcticEmbedXsQ
242 | Self::SnowflakeArcticEmbedS
243 | Self::SnowflakeArcticEmbedSQ => 384,
244
245 Self::BgeSmallZhV15 | Self::ClipVitB32 => 512,
247
248 Self::AllMpnetBaseV2
250 | Self::BgeBaseEnV15
251 | Self::BgeBaseEnV15Q
252 | Self::MultilingualE5Base
253 | Self::ParaphraseMiniLmL12V2
254 | Self::ParaphraseMiniLmL12V2Q
255 | Self::ParaphraseMultilingualMpnetBaseV2
256 | Self::NomicEmbedTextV1
257 | Self::NomicEmbedTextV15
258 | Self::NomicEmbedTextV15Q
259 | Self::GteBaseEnV15
260 | Self::GteBaseEnV15Q
261 | Self::JinaEmbeddingsV2BaseCode
262 | Self::EmbeddingGemma300M
263 | Self::SnowflakeArcticEmbedM
264 | Self::SnowflakeArcticEmbedMQ
265 | Self::SnowflakeArcticEmbedMLong
266 | Self::SnowflakeArcticEmbedMLongQ => 768,
267
268 Self::BgeLargeEnV15
270 | Self::BgeLargeEnV15Q
271 | Self::BgeLargeZhV15
272 | Self::MultilingualE5Large
273 | Self::MxbaiEmbedLargeV1
274 | Self::MxbaiEmbedLargeV1Q
275 | Self::GteLargeEnV15
276 | Self::GteLargeEnV15Q
277 | Self::ModernBertEmbedLarge
278 | Self::SnowflakeArcticEmbedL
279 | Self::SnowflakeArcticEmbedLQ => 1024,
280 }
281 }
282
283 pub fn is_quantized(&self) -> bool {
285 matches!(
286 self,
287 Self::BgeSmallEnV15Q
288 | Self::AllMiniLmL6V2Q
289 | Self::AllMiniLmL12V2Q
290 | Self::BgeBaseEnV15Q
291 | Self::BgeLargeEnV15Q
292 | Self::ParaphraseMiniLmL12V2Q
293 | Self::NomicEmbedTextV15Q
294 | Self::MxbaiEmbedLargeV1Q
295 | Self::GteBaseEnV15Q
296 | Self::GteLargeEnV15Q
297 | Self::SnowflakeArcticEmbedXsQ
298 | Self::SnowflakeArcticEmbedSQ
299 | Self::SnowflakeArcticEmbedMQ
300 | Self::SnowflakeArcticEmbedMLongQ
301 | Self::SnowflakeArcticEmbedLQ
302 )
303 }
304
305 pub fn is_multilingual(&self) -> bool {
307 matches!(
308 self,
309 Self::MultilingualE5Small
310 | Self::MultilingualE5Base
311 | Self::MultilingualE5Large
312 | Self::ParaphraseMultilingualMpnetBaseV2
313 | Self::BgeSmallZhV15
314 | Self::BgeLargeZhV15
315 )
316 }
317
318 pub fn max_context_length(&self) -> usize {
320 match self {
321 Self::NomicEmbedTextV1 | Self::NomicEmbedTextV15 | Self::NomicEmbedTextV15Q => 8192,
322 Self::SnowflakeArcticEmbedMLong | Self::SnowflakeArcticEmbedMLongQ => 2048,
323 _ => 512,
324 }
325 }
326
327 pub fn all() -> Vec<Self> {
329 vec![
330 Self::BgeSmallEnV15,
331 Self::BgeSmallEnV15Q,
332 Self::AllMiniLmL6V2,
333 Self::AllMiniLmL6V2Q,
334 Self::AllMiniLmL12V2,
335 Self::AllMiniLmL12V2Q,
336 Self::AllMpnetBaseV2,
337 Self::BgeBaseEnV15,
338 Self::BgeBaseEnV15Q,
339 Self::BgeLargeEnV15,
340 Self::BgeLargeEnV15Q,
341 Self::MultilingualE5Small,
342 Self::MultilingualE5Base,
343 Self::MultilingualE5Large,
344 Self::ParaphraseMiniLmL12V2,
345 Self::ParaphraseMiniLmL12V2Q,
346 Self::ParaphraseMultilingualMpnetBaseV2,
347 Self::BgeSmallZhV15,
348 Self::BgeLargeZhV15,
349 Self::NomicEmbedTextV1,
350 Self::NomicEmbedTextV15,
351 Self::NomicEmbedTextV15Q,
352 Self::MxbaiEmbedLargeV1,
353 Self::MxbaiEmbedLargeV1Q,
354 Self::GteBaseEnV15,
355 Self::GteBaseEnV15Q,
356 Self::GteLargeEnV15,
357 Self::GteLargeEnV15Q,
358 Self::ClipVitB32,
359 Self::JinaEmbeddingsV2BaseCode,
360 Self::EmbeddingGemma300M,
361 Self::ModernBertEmbedLarge,
362 Self::SnowflakeArcticEmbedXs,
363 Self::SnowflakeArcticEmbedXsQ,
364 Self::SnowflakeArcticEmbedS,
365 Self::SnowflakeArcticEmbedSQ,
366 Self::SnowflakeArcticEmbedM,
367 Self::SnowflakeArcticEmbedMQ,
368 Self::SnowflakeArcticEmbedMLong,
369 Self::SnowflakeArcticEmbedMLongQ,
370 Self::SnowflakeArcticEmbedL,
371 Self::SnowflakeArcticEmbedLQ,
372 ]
373 }
374}
375
376impl Display for EmbeddingModelType {
377 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
378 let name = match self {
379 Self::BgeSmallEnV15 => "bge-small-en-v1.5",
380 Self::BgeSmallEnV15Q => "bge-small-en-v1.5-q",
381 Self::AllMiniLmL6V2 => "all-minilm-l6-v2",
382 Self::AllMiniLmL6V2Q => "all-minilm-l6-v2-q",
383 Self::AllMiniLmL12V2 => "all-minilm-l12-v2",
384 Self::AllMiniLmL12V2Q => "all-minilm-l12-v2-q",
385 Self::AllMpnetBaseV2 => "all-mpnet-base-v2",
386 Self::BgeBaseEnV15 => "bge-base-en-v1.5",
387 Self::BgeBaseEnV15Q => "bge-base-en-v1.5-q",
388 Self::BgeLargeEnV15 => "bge-large-en-v1.5",
389 Self::BgeLargeEnV15Q => "bge-large-en-v1.5-q",
390 Self::MultilingualE5Small => "multilingual-e5-small",
391 Self::MultilingualE5Base => "multilingual-e5-base",
392 Self::MultilingualE5Large => "multilingual-e5-large",
393 Self::ParaphraseMiniLmL12V2 => "paraphrase-minilm-l12-v2",
394 Self::ParaphraseMiniLmL12V2Q => "paraphrase-minilm-l12-v2-q",
395 Self::ParaphraseMultilingualMpnetBaseV2 => "paraphrase-multilingual-mpnet-base-v2",
396 Self::BgeSmallZhV15 => "bge-small-zh-v1.5",
397 Self::BgeLargeZhV15 => "bge-large-zh-v1.5",
398 Self::NomicEmbedTextV1 => "nomic-embed-text-v1",
399 Self::NomicEmbedTextV15 => "nomic-embed-text-v1.5",
400 Self::NomicEmbedTextV15Q => "nomic-embed-text-v1.5-q",
401 Self::MxbaiEmbedLargeV1 => "mxbai-embed-large-v1",
402 Self::MxbaiEmbedLargeV1Q => "mxbai-embed-large-v1-q",
403 Self::GteBaseEnV15 => "gte-base-en-v1.5",
404 Self::GteBaseEnV15Q => "gte-base-en-v1.5-q",
405 Self::GteLargeEnV15 => "gte-large-en-v1.5",
406 Self::GteLargeEnV15Q => "gte-large-en-v1.5-q",
407 Self::ClipVitB32 => "clip-vit-b-32",
408 Self::JinaEmbeddingsV2BaseCode => "jina-embeddings-v2-base-code",
409 Self::EmbeddingGemma300M => "embedding-gemma-300m",
410 Self::ModernBertEmbedLarge => "modernbert-embed-large",
411 Self::SnowflakeArcticEmbedXs => "snowflake-arctic-embed-xs",
412 Self::SnowflakeArcticEmbedXsQ => "snowflake-arctic-embed-xs-q",
413 Self::SnowflakeArcticEmbedS => "snowflake-arctic-embed-s",
414 Self::SnowflakeArcticEmbedSQ => "snowflake-arctic-embed-s-q",
415 Self::SnowflakeArcticEmbedM => "snowflake-arctic-embed-m",
416 Self::SnowflakeArcticEmbedMQ => "snowflake-arctic-embed-m-q",
417 Self::SnowflakeArcticEmbedMLong => "snowflake-arctic-embed-m-long",
418 Self::SnowflakeArcticEmbedMLongQ => "snowflake-arctic-embed-m-long-q",
419 Self::SnowflakeArcticEmbedL => "snowflake-arctic-embed-l",
420 Self::SnowflakeArcticEmbedLQ => "snowflake-arctic-embed-l-q",
421 };
422 write!(f, "{}", name)
423 }
424}
425
426impl FromStr for EmbeddingModelType {
427 type Err = AppError;
428
429 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
430 match s.to_lowercase().as_str() {
431 "bge-small-en-v1.5" | "bge-small-en" | "bge-small" => Ok(Self::BgeSmallEnV15),
432 "bge-small-en-v1.5-q" => Ok(Self::BgeSmallEnV15Q),
433 "all-minilm-l6-v2" | "minilm-l6" => Ok(Self::AllMiniLmL6V2),
434 "all-minilm-l6-v2-q" => Ok(Self::AllMiniLmL6V2Q),
435 "all-minilm-l12-v2" | "minilm-l12" => Ok(Self::AllMiniLmL12V2),
436 "all-minilm-l12-v2-q" => Ok(Self::AllMiniLmL12V2Q),
437 "all-mpnet-base-v2" | "mpnet" => Ok(Self::AllMpnetBaseV2),
438 "bge-base-en-v1.5" | "bge-base-en" | "bge-base" => Ok(Self::BgeBaseEnV15),
439 "bge-base-en-v1.5-q" => Ok(Self::BgeBaseEnV15Q),
440 "bge-large-en-v1.5" | "bge-large-en" | "bge-large" => Ok(Self::BgeLargeEnV15),
441 "bge-large-en-v1.5-q" => Ok(Self::BgeLargeEnV15Q),
442 "multilingual-e5-small" | "e5-small" => Ok(Self::MultilingualE5Small),
443 "multilingual-e5-base" | "e5-base" => Ok(Self::MultilingualE5Base),
444 "multilingual-e5-large" | "e5-large" => Ok(Self::MultilingualE5Large),
445 "paraphrase-minilm-l12-v2" => Ok(Self::ParaphraseMiniLmL12V2),
446 "paraphrase-minilm-l12-v2-q" => Ok(Self::ParaphraseMiniLmL12V2Q),
447 "paraphrase-multilingual-mpnet-base-v2" => Ok(Self::ParaphraseMultilingualMpnetBaseV2),
448 "bge-small-zh-v1.5" | "bge-small-zh" => Ok(Self::BgeSmallZhV15),
449 "bge-large-zh-v1.5" | "bge-large-zh" => Ok(Self::BgeLargeZhV15),
450 "nomic-embed-text-v1" | "nomic-v1" => Ok(Self::NomicEmbedTextV1),
451 "nomic-embed-text-v1.5" | "nomic-v1.5" | "nomic" => Ok(Self::NomicEmbedTextV15),
452 "nomic-embed-text-v1.5-q" => Ok(Self::NomicEmbedTextV15Q),
453 "mxbai-embed-large-v1" | "mxbai" => Ok(Self::MxbaiEmbedLargeV1),
454 "mxbai-embed-large-v1-q" => Ok(Self::MxbaiEmbedLargeV1Q),
455 "gte-base-en-v1.5" | "gte-base" => Ok(Self::GteBaseEnV15),
456 "gte-base-en-v1.5-q" => Ok(Self::GteBaseEnV15Q),
457 "gte-large-en-v1.5" | "gte-large" => Ok(Self::GteLargeEnV15),
458 "gte-large-en-v1.5-q" => Ok(Self::GteLargeEnV15Q),
459 "clip-vit-b-32" | "clip" => Ok(Self::ClipVitB32),
460 "jina-embeddings-v2-base-code" | "jina-code" => Ok(Self::JinaEmbeddingsV2BaseCode),
461 "embedding-gemma-300m" | "gemma-300m" | "gemma" => Ok(Self::EmbeddingGemma300M),
462 "modernbert-embed-large" | "modernbert" => Ok(Self::ModernBertEmbedLarge),
463 "snowflake-arctic-embed-xs" => Ok(Self::SnowflakeArcticEmbedXs),
464 "snowflake-arctic-embed-xs-q" => Ok(Self::SnowflakeArcticEmbedXsQ),
465 "snowflake-arctic-embed-s" => Ok(Self::SnowflakeArcticEmbedS),
466 "snowflake-arctic-embed-s-q" => Ok(Self::SnowflakeArcticEmbedSQ),
467 "snowflake-arctic-embed-m" => Ok(Self::SnowflakeArcticEmbedM),
468 "snowflake-arctic-embed-m-q" => Ok(Self::SnowflakeArcticEmbedMQ),
469 "snowflake-arctic-embed-m-long" => Ok(Self::SnowflakeArcticEmbedMLong),
470 "snowflake-arctic-embed-m-long-q" => Ok(Self::SnowflakeArcticEmbedMLongQ),
471 "snowflake-arctic-embed-l" | "snowflake-l" => Ok(Self::SnowflakeArcticEmbedL),
472 "snowflake-arctic-embed-l-q" => Ok(Self::SnowflakeArcticEmbedLQ),
473 _ => Err(AppError::Internal(format!(
474 "Unknown embedding model: {}. Use one of: {}",
475 s,
476 EmbeddingModelType::all()
477 .iter()
478 .map(|m| m.to_string())
479 .collect::<Vec<_>>()
480 .join(", ")
481 ))),
482 }
483 }
484}
485
/// Sparse embedding models that can be enabled alongside the dense model.
///
/// Currently a single model is available, which is also the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
#[serde(rename_all = "kebab-case")]
pub enum SparseModelType {
    /// `splade-pp-v1` — maps to `SparseModel::SPLADEPPV1`.
    #[default]
    SpladePpV1,
}
499
500impl SparseModelType {
501 pub fn to_fastembed_model(&self) -> SparseModel {
503 match self {
504 Self::SpladePpV1 => SparseModel::SPLADEPPV1,
505 }
506 }
507}
508
509impl Display for SparseModelType {
510 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
511 let name = match self {
512 Self::SpladePpV1 => "splade-pp-v1",
513 };
514 write!(f, "{}", name)
515 }
516}
517
518impl FromStr for SparseModelType {
519 type Err = AppError;
520
521 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
522 match s.to_lowercase().as_str() {
523 "splade-pp-v1" | "splade" => Ok(Self::SpladePpV1),
524 _ => Err(AppError::Internal(format!(
525 "Unknown sparse model: {}. Use: splade-pp-v1",
526 s
527 ))),
528 }
529 }
530}
531
/// Settings used to construct an [`EmbeddingService`].
///
/// Every field has a serde default, so an empty document deserializes to the
/// same values as [`EmbeddingConfig::default`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    /// Dense embedding model to load. Defaults to `EmbeddingModelType::default()`.
    #[serde(default)]
    pub model: EmbeddingModelType,

    /// Batch size passed to the model when embedding. Defaults to 32.
    #[serde(default = "default_batch_size")]
    pub batch_size: usize,

    /// Forwarded to the model's `with_show_download_progress` init option.
    /// Defaults to `true`.
    #[serde(default = "default_show_progress")]
    pub show_download_progress: bool,

    /// Whether to also load the sparse model. Defaults to `false`.
    #[serde(default)]
    pub sparse_enabled: bool,

    /// Sparse model to load when `sparse_enabled` is set.
    /// Defaults to `SparseModelType::default()`.
    #[serde(default)]
    pub sparse_model: SparseModelType,
}
559
/// Serde default for the `batch_size` field of `EmbeddingConfig`.
fn default_batch_size() -> usize {
    const DEFAULT_BATCH_SIZE: usize = 32;
    DEFAULT_BATCH_SIZE
}
563
/// Serde default for the `show_download_progress` field of `EmbeddingConfig`.
fn default_show_progress() -> bool {
    const SHOW_PROGRESS_BY_DEFAULT: bool = true;
    SHOW_PROGRESS_BY_DEFAULT
}
567
568impl Default for EmbeddingConfig {
569 fn default() -> Self {
570 Self {
571 model: EmbeddingModelType::default(),
572 batch_size: default_batch_size(),
573 show_download_progress: default_show_progress(),
574 sparse_enabled: false,
575 sparse_model: SparseModelType::default(),
576 }
577 }
578}
579
/// Embedding generator backed by `fastembed` models.
///
/// The models are wrapped in `Arc<Mutex<_>>` so they can be moved into blocking
/// tasks and used by one embedding call at a time.
pub struct EmbeddingService {
    /// Dense text-embedding model.
    model: Arc<Mutex<TextEmbedding>>,
    /// Sparse embedding model; `None` unless `sparse_enabled` was set in the config.
    sparse_model: Option<Arc<Mutex<fastembed::SparseTextEmbedding>>>,
    /// Configuration the service was created with.
    config: EmbeddingConfig,
}
598
599impl EmbeddingService {
600 pub fn new(config: EmbeddingConfig) -> Result<Self> {
605 let model_name = format!("{:?}", config.model.to_fastembed_model());
606 let model_lock = get_model_lock(&model_name);
607
608 let _guard = model_lock.lock().map_err(|e| {
610 AppError::Internal(format!(
611 "Failed to acquire model initialization lock: {}",
612 e
613 ))
614 })?;
615
616 let model = TextEmbedding::try_new(
617 InitOptions::new(config.model.to_fastembed_model())
618 .with_show_download_progress(config.show_download_progress),
619 )
620 .map_err(|e| AppError::Internal(format!("Failed to initialize embedding model: {}", e)))?;
621
622 let sparse_model = if config.sparse_enabled {
623 let sparse_model_name = format!("{:?}", config.sparse_model.to_fastembed_model());
624 let sparse_lock = get_model_lock(&sparse_model_name);
625 let _sparse_guard = sparse_lock.lock().map_err(|e| {
626 AppError::Internal(format!("Failed to acquire sparse model lock: {}", e))
627 })?;
628
629 Some(
630 fastembed::SparseTextEmbedding::try_new(
631 fastembed::SparseInitOptions::new(config.sparse_model.to_fastembed_model())
632 .with_show_download_progress(config.show_download_progress),
633 )
634 .map_err(|e| {
635 AppError::Internal(format!(
636 "Failed to initialize sparse embedding model: {}",
637 e
638 ))
639 })?,
640 )
641 } else {
642 None
643 };
644
645 Ok(Self {
646 model: Arc::new(Mutex::new(model)),
647 sparse_model: sparse_model.map(|m| Arc::new(Mutex::new(m))),
648 config,
649 })
650 }
651
652 pub fn with_default_model() -> Result<Self> {
654 Self::new(EmbeddingConfig::default())
655 }
656
657 pub fn with_model(model: EmbeddingModelType) -> Result<Self> {
659 Self::new(EmbeddingConfig {
660 model,
661 ..Default::default()
662 })
663 }
664
665 pub fn model_type(&self) -> EmbeddingModelType {
667 self.config.model
668 }
669
670 pub fn dimensions(&self) -> usize {
672 self.config.model.dimensions()
673 }
674
675 pub fn config(&self) -> &EmbeddingConfig {
677 &self.config
678 }
679
680 pub async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
682 let embeddings = self.embed_texts(&[text.to_string()]).await?;
683 embeddings
684 .into_iter()
685 .next()
686 .ok_or_else(|| AppError::Internal("No embedding generated".to_string()))
687 }
688
689 pub async fn embed_texts<S: AsRef<str> + Send + Sync + 'static>(
696 &self,
697 texts: &[S],
698 ) -> Result<Vec<Vec<f32>>> {
699 if texts.is_empty() {
700 return Ok(vec![]);
701 }
702
703 let texts_owned: Vec<String> = texts.iter().map(|s| s.as_ref().to_string()).collect();
705 let batch_size = self.config.batch_size;
706
707 let model = Arc::clone(&self.model);
709
710 spawn_blocking(move || {
711 let mut model_guard = model
713 .lock()
714 .map_err(|e| AppError::Internal(format!("Failed to acquire model lock: {}", e)))?;
715
716 let refs: Vec<&str> = texts_owned.iter().map(|s| s.as_str()).collect();
717 model_guard
718 .embed(refs, Some(batch_size))
719 .map_err(|e| AppError::Internal(format!("Embedding failed: {}", e)))
720 })
721 .await
722 .map_err(|e| AppError::Internal(format!("Blocking task failed: {}", e)))?
723 }
724
725 pub async fn embed_sparse<S: AsRef<str> + Send + Sync + 'static>(
729 &self,
730 texts: &[S],
731 ) -> Result<Vec<fastembed::SparseEmbedding>> {
732 let sparse_model = self.sparse_model.as_ref().ok_or_else(|| {
733 AppError::Internal(
734 "Sparse embeddings not enabled. Set sparse_enabled: true in config.".to_string(),
735 )
736 })?;
737
738 let texts_owned: Vec<String> = texts.iter().map(|s| s.as_ref().to_string()).collect();
739 let batch_size = self.config.batch_size;
740
741 let model = Arc::clone(sparse_model);
743
744 spawn_blocking(move || {
745 let mut model_guard = model.lock().map_err(|e| {
747 AppError::Internal(format!("Failed to acquire sparse model lock: {}", e))
748 })?;
749
750 let refs: Vec<&str> = texts_owned.iter().map(|s| s.as_str()).collect();
751 model_guard
752 .embed(refs, Some(batch_size))
753 .map_err(|e| AppError::Internal(format!("Sparse embedding failed: {}", e)))
754 })
755 .await
756 .map_err(|e| AppError::Internal(format!("Blocking task failed: {}", e)))?
757 }
758}
759
760use crate::rag::cache::{CacheConfig, CacheStats, EmbeddingCache, LruEmbeddingCache, NoOpCache};
765
/// An [`EmbeddingService`] fronted by an embedding cache.
///
/// Cache entries are keyed by `compute_key(text, model_name)`, so the key depends
/// on both the text and the configured model.
pub struct CachedEmbeddingService {
    /// Service that computes embeddings on a cache miss.
    inner: EmbeddingService,
    /// Cache implementation: an LRU cache when caching is enabled, otherwise a no-op.
    cache: Box<dyn EmbeddingCache>,
}
796
797impl CachedEmbeddingService {
798 pub fn new(embedding_config: EmbeddingConfig, cache_config: CacheConfig) -> Result<Self> {
800 let inner = EmbeddingService::new(embedding_config)?;
801 let cache: Box<dyn EmbeddingCache> = if cache_config.enabled {
802 Box::new(LruEmbeddingCache::new(cache_config))
803 } else {
804 Box::new(NoOpCache::new())
805 };
806
807 Ok(Self { inner, cache })
808 }
809
810 pub fn with_defaults() -> Result<Self> {
812 Self::new(EmbeddingConfig::default(), CacheConfig::default())
813 }
814
815 pub fn with_model(model: EmbeddingModelType) -> Result<Self> {
817 Self::new(
818 EmbeddingConfig {
819 model,
820 ..Default::default()
821 },
822 CacheConfig::default(),
823 )
824 }
825
826 pub fn without_cache(embedding_config: EmbeddingConfig) -> Result<Self> {
828 Self::new(
829 embedding_config,
830 CacheConfig {
831 enabled: false,
832 ..Default::default()
833 },
834 )
835 }
836
837 fn model_name(&self) -> String {
839 self.inner.model_type().to_string()
840 }
841
842 pub async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
844 let cache_key = self.cache.compute_key(text, &self.model_name());
845
846 if let Some(cached) = self.cache.get(&cache_key) {
848 return Ok(cached);
849 }
850
851 let embedding = self.inner.embed_text(text).await?;
853
854 self.cache.set(&cache_key, embedding.clone(), None)?;
856
857 Ok(embedding)
858 }
859
860 pub async fn embed_texts<S: AsRef<str> + Send + Sync + 'static>(
865 &self,
866 texts: &[S],
867 ) -> Result<Vec<Vec<f32>>> {
868 if texts.is_empty() {
869 return Ok(vec![]);
870 }
871
872 let model_name = self.model_name();
873 let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
874 let mut uncached_indices: Vec<usize> = Vec::new();
875 let mut uncached_texts: Vec<String> = Vec::new();
876
877 for (i, text) in texts.iter().enumerate() {
879 let text_str = text.as_ref();
880 let cache_key = self.cache.compute_key(text_str, &model_name);
881
882 if let Some(cached) = self.cache.get(&cache_key) {
883 results[i] = Some(cached);
884 } else {
885 uncached_indices.push(i);
886 uncached_texts.push(text_str.to_string());
887 }
888 }
889
890 if !uncached_texts.is_empty() {
892 let new_embeddings = self.inner.embed_texts(&uncached_texts).await?;
893
894 for (j, embedding) in new_embeddings.into_iter().enumerate() {
896 let idx = uncached_indices[j];
897 let cache_key = self.cache.compute_key(&uncached_texts[j], &model_name);
898 self.cache.set(&cache_key, embedding.clone(), None)?;
899 results[idx] = Some(embedding);
900 }
901 }
902
903 Ok(results.into_iter().flatten().collect())
905 }
906
907 pub fn model_type(&self) -> EmbeddingModelType {
909 self.inner.model_type()
910 }
911
912 pub fn dimensions(&self) -> usize {
914 self.inner.dimensions()
915 }
916
917 pub fn config(&self) -> &EmbeddingConfig {
919 self.inner.config()
920 }
921
922 pub fn cache_stats(&self) -> CacheStats {
924 self.cache.stats()
925 }
926
927 pub fn clear_cache(&self) -> Result<()> {
929 self.cache.clear()
930 }
931
932 pub fn invalidate(&self, text: &str) -> Result<()> {
934 let cache_key = self.cache.compute_key(text, &self.model_name());
935 self.cache.invalidate(&cache_key)
936 }
937
938 pub fn is_cache_enabled(&self) -> bool {
940 self.cache.is_enabled()
941 }
942}
943
/// Hardware acceleration backend selector.
///
/// Variant names are serialized in lowercase. The type is not referenced by the
/// code visible in this file, hence the `dead_code` allowance.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
#[allow(dead_code)]
#[derive(Default)]
pub enum AccelerationBackend {
    /// CPU execution — the default.
    #[default]
    Cpu,
    /// CUDA execution.
    Cuda {
        /// Presumably the index of the CUDA device to use — confirm against consumers.
        device_id: usize,
    },
    /// Metal execution.
    Metal,
    /// Vulkan execution.
    Vulkan,
}
977
/// Deprecated synchronous wrapper kept for older call sites.
///
/// Use [`EmbeddingService`] in new code.
#[deprecated(note = "Use EmbeddingService instead")]
pub struct LegacyEmbeddingService {
    /// Service that supplies the model configuration.
    inner: EmbeddingService,
}
989
990#[allow(deprecated)]
991impl LegacyEmbeddingService {
992 pub fn new(_model_name: &str) -> Result<Self> {
994 Ok(Self {
995 inner: EmbeddingService::with_default_model()?,
996 })
997 }
998
999 pub fn embed(&mut self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
1001 let model_type = self.inner.config.model.to_fastembed_model();
1002 let mut model =
1003 TextEmbedding::try_new(InitOptions::new(model_type).with_show_download_progress(true))
1004 .map_err(|e| AppError::Internal(e.to_string()))?;
1005
1006 model
1007 .embed(texts, None)
1008 .map_err(|e| AppError::Internal(e.to_string()))
1009 }
1010}
1011
/// Unit tests for the model metadata, parsing, and configuration defaults.
///
/// None of these tests load a model, so they run without network access.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_dimensions() {
        assert_eq!(EmbeddingModelType::BgeSmallEnV15.dimensions(), 384);
        assert_eq!(EmbeddingModelType::BgeBaseEnV15.dimensions(), 768);
        assert_eq!(EmbeddingModelType::BgeLargeEnV15.dimensions(), 1024);
        assert_eq!(EmbeddingModelType::MultilingualE5Large.dimensions(), 1024);
    }

    #[test]
    fn test_model_from_str() {
        assert_eq!(
            "bge-small-en-v1.5".parse::<EmbeddingModelType>().unwrap(),
            EmbeddingModelType::BgeSmallEnV15
        );
        assert_eq!(
            "multilingual-e5-large"
                .parse::<EmbeddingModelType>()
                .unwrap(),
            EmbeddingModelType::MultilingualE5Large
        );
        assert_eq!(
            "minilm-l6".parse::<EmbeddingModelType>().unwrap(),
            EmbeddingModelType::AllMiniLmL6V2
        );
    }

    #[test]
    fn test_model_from_str_is_case_insensitive() {
        assert_eq!(
            "BGE-Small-EN-v1.5".parse::<EmbeddingModelType>().unwrap(),
            EmbeddingModelType::BgeSmallEnV15
        );
    }

    #[test]
    fn test_unknown_model_is_rejected() {
        assert!("not-a-model".parse::<EmbeddingModelType>().is_err());
    }

    #[test]
    fn test_model_is_multilingual() {
        assert!(EmbeddingModelType::MultilingualE5Small.is_multilingual());
        assert!(EmbeddingModelType::MultilingualE5Large.is_multilingual());
        assert!(!EmbeddingModelType::BgeSmallEnV15.is_multilingual());
    }

    #[test]
    fn test_model_max_context() {
        assert_eq!(
            EmbeddingModelType::NomicEmbedTextV15.max_context_length(),
            8192
        );
        assert_eq!(
            EmbeddingModelType::NomicEmbedTextV1.max_context_length(),
            8192
        );
        assert_eq!(EmbeddingModelType::BgeSmallEnV15.max_context_length(), 512);
    }

    #[test]
    fn test_default_config() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.model, EmbeddingModelType::BgeSmallEnV15);
        assert_eq!(config.batch_size, 32);
        assert!(config.show_download_progress);
        assert!(!config.sparse_enabled);
    }

    #[test]
    fn test_all_models_listed() {
        let all = EmbeddingModelType::all();
        assert!(all.len() >= 38);
        assert!(all.contains(&EmbeddingModelType::BgeSmallEnV15));
        assert!(all.contains(&EmbeddingModelType::MultilingualE5Large));
    }

    #[test]
    fn test_all_models_unique() {
        // `all()` is maintained by hand; guard against accidental duplicates.
        let all = EmbeddingModelType::all();
        let unique: std::collections::HashSet<EmbeddingModelType> =
            all.iter().copied().collect();
        assert_eq!(unique.len(), all.len());
    }

    #[test]
    fn test_display_from_str_round_trip() {
        // Every canonical name printed by `Display` must parse back to the same variant.
        for model in EmbeddingModelType::all() {
            let parsed: EmbeddingModelType = model.to_string().parse().unwrap();
            assert_eq!(parsed, model);
        }
    }

    #[test]
    fn test_sparse_model_from_str() {
        assert_eq!(
            "splade".parse::<SparseModelType>().unwrap(),
            SparseModelType::SpladePpV1
        );
        assert_eq!(
            SparseModelType::SpladePpV1
                .to_string()
                .parse::<SparseModelType>()
                .unwrap(),
            SparseModelType::SpladePpV1
        );
        assert!("unknown".parse::<SparseModelType>().is_err());
    }
}