use std::fmt::Display;
use std::str::FromStr;
use std::sync::{Arc, Mutex};

use serde::{Deserialize, Serialize};
use tokio::task::spawn_blocking;

use crate::types::{AppError, Result};

pub use fastembed::{
    EmbeddingModel as FastEmbedModel, InitOptions, SparseModel, TextEmbedding,
};
29
30#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
39#[serde(rename_all = "kebab-case")]
40pub enum EmbeddingModelType {
41 #[default]
44 BgeSmallEnV15,
45 BgeSmallEnV15Q,
47 AllMiniLmL6V2,
49 AllMiniLmL6V2Q,
51 AllMiniLmL12V2,
53 AllMiniLmL12V2Q,
55 AllMpnetBaseV2,
57
58 BgeBaseEnV15,
61 BgeBaseEnV15Q,
63 BgeLargeEnV15,
65 BgeLargeEnV15Q,
67
68 MultilingualE5Small,
72 MultilingualE5Base,
74 MultilingualE5Large,
76 ParaphraseMiniLmL12V2,
78 ParaphraseMiniLmL12V2Q,
80 ParaphraseMultilingualMpnetBaseV2,
82
83 BgeSmallZhV15,
86 BgeLargeZhV15,
88
89 NomicEmbedTextV1,
92 NomicEmbedTextV15,
94 NomicEmbedTextV15Q,
96
97 MxbaiEmbedLargeV1,
100 MxbaiEmbedLargeV1Q,
102 GteBaseEnV15,
104 GteBaseEnV15Q,
106 GteLargeEnV15,
108 GteLargeEnV15Q,
110 ClipVitB32,
112
113 JinaEmbeddingsV2BaseCode,
116 EmbeddingGemma300M,
121 ModernBertEmbedLarge,
123
124 SnowflakeArcticEmbedXs,
127 SnowflakeArcticEmbedXsQ,
129 SnowflakeArcticEmbedS,
131 SnowflakeArcticEmbedSQ,
133 SnowflakeArcticEmbedM,
135 SnowflakeArcticEmbedMQ,
137 SnowflakeArcticEmbedMLong,
139 SnowflakeArcticEmbedMLongQ,
141 SnowflakeArcticEmbedL,
143 SnowflakeArcticEmbedLQ,
145}
146
147impl EmbeddingModelType {
148 pub fn to_fastembed_model(&self) -> FastEmbedModel {
150 match self {
151 Self::BgeSmallEnV15 => FastEmbedModel::BGESmallENV15,
153 Self::BgeSmallEnV15Q => FastEmbedModel::BGESmallENV15Q,
154 Self::AllMiniLmL6V2 => FastEmbedModel::AllMiniLML6V2,
155 Self::AllMiniLmL6V2Q => FastEmbedModel::AllMiniLML6V2Q,
156 Self::AllMiniLmL12V2 => FastEmbedModel::AllMiniLML12V2,
157 Self::AllMiniLmL12V2Q => FastEmbedModel::AllMiniLML12V2Q,
158 Self::AllMpnetBaseV2 => FastEmbedModel::AllMpnetBaseV2,
159
160 Self::BgeBaseEnV15 => FastEmbedModel::BGEBaseENV15,
162 Self::BgeBaseEnV15Q => FastEmbedModel::BGEBaseENV15Q,
163 Self::BgeLargeEnV15 => FastEmbedModel::BGELargeENV15,
164 Self::BgeLargeEnV15Q => FastEmbedModel::BGELargeENV15Q,
165
166 Self::MultilingualE5Small => FastEmbedModel::MultilingualE5Small,
168 Self::MultilingualE5Base => FastEmbedModel::MultilingualE5Base,
169 Self::MultilingualE5Large => FastEmbedModel::MultilingualE5Large,
170 Self::ParaphraseMiniLmL12V2 => FastEmbedModel::ParaphraseMLMiniLML12V2,
171 Self::ParaphraseMiniLmL12V2Q => FastEmbedModel::ParaphraseMLMiniLML12V2Q,
172 Self::ParaphraseMultilingualMpnetBaseV2 => FastEmbedModel::ParaphraseMLMpnetBaseV2,
173
174 Self::BgeSmallZhV15 => FastEmbedModel::BGESmallZHV15,
176 Self::BgeLargeZhV15 => FastEmbedModel::BGELargeZHV15,
177
178 Self::NomicEmbedTextV1 => FastEmbedModel::NomicEmbedTextV1,
180 Self::NomicEmbedTextV15 => FastEmbedModel::NomicEmbedTextV15,
181 Self::NomicEmbedTextV15Q => FastEmbedModel::NomicEmbedTextV15Q,
182
183 Self::MxbaiEmbedLargeV1 => FastEmbedModel::MxbaiEmbedLargeV1,
185 Self::MxbaiEmbedLargeV1Q => FastEmbedModel::MxbaiEmbedLargeV1Q,
186 Self::GteBaseEnV15 => FastEmbedModel::GTEBaseENV15,
187 Self::GteBaseEnV15Q => FastEmbedModel::GTEBaseENV15Q,
188 Self::GteLargeEnV15 => FastEmbedModel::GTELargeENV15,
189 Self::GteLargeEnV15Q => FastEmbedModel::GTELargeENV15Q,
190 Self::ClipVitB32 => FastEmbedModel::ClipVitB32,
191
192 Self::JinaEmbeddingsV2BaseCode => FastEmbedModel::JinaEmbeddingsV2BaseCode,
194
195 Self::EmbeddingGemma300M => FastEmbedModel::EmbeddingGemma300M,
197 Self::ModernBertEmbedLarge => FastEmbedModel::ModernBertEmbedLarge,
198
199 Self::SnowflakeArcticEmbedXs => FastEmbedModel::SnowflakeArcticEmbedXS,
201 Self::SnowflakeArcticEmbedXsQ => FastEmbedModel::SnowflakeArcticEmbedXSQ,
202 Self::SnowflakeArcticEmbedS => FastEmbedModel::SnowflakeArcticEmbedS,
203 Self::SnowflakeArcticEmbedSQ => FastEmbedModel::SnowflakeArcticEmbedSQ,
204 Self::SnowflakeArcticEmbedM => FastEmbedModel::SnowflakeArcticEmbedM,
205 Self::SnowflakeArcticEmbedMQ => FastEmbedModel::SnowflakeArcticEmbedMQ,
206 Self::SnowflakeArcticEmbedMLong => FastEmbedModel::SnowflakeArcticEmbedMLong,
207 Self::SnowflakeArcticEmbedMLongQ => FastEmbedModel::SnowflakeArcticEmbedMLongQ,
208 Self::SnowflakeArcticEmbedL => FastEmbedModel::SnowflakeArcticEmbedL,
209 Self::SnowflakeArcticEmbedLQ => FastEmbedModel::SnowflakeArcticEmbedLQ,
210 }
211 }
212
213 pub fn dimensions(&self) -> usize {
215 match self {
216 Self::BgeSmallEnV15
218 | Self::BgeSmallEnV15Q
219 | Self::AllMiniLmL6V2
220 | Self::AllMiniLmL6V2Q
221 | Self::AllMiniLmL12V2
222 | Self::AllMiniLmL12V2Q
223 | Self::MultilingualE5Small
224 | Self::SnowflakeArcticEmbedXs
225 | Self::SnowflakeArcticEmbedXsQ
226 | Self::SnowflakeArcticEmbedS
227 | Self::SnowflakeArcticEmbedSQ => 384,
228
229 Self::BgeSmallZhV15 | Self::ClipVitB32 => 512,
231
232 Self::AllMpnetBaseV2
234 | Self::BgeBaseEnV15
235 | Self::BgeBaseEnV15Q
236 | Self::MultilingualE5Base
237 | Self::ParaphraseMiniLmL12V2
238 | Self::ParaphraseMiniLmL12V2Q
239 | Self::ParaphraseMultilingualMpnetBaseV2
240 | Self::NomicEmbedTextV1
241 | Self::NomicEmbedTextV15
242 | Self::NomicEmbedTextV15Q
243 | Self::GteBaseEnV15
244 | Self::GteBaseEnV15Q
245 | Self::JinaEmbeddingsV2BaseCode
246 | Self::EmbeddingGemma300M
247 | Self::SnowflakeArcticEmbedM
248 | Self::SnowflakeArcticEmbedMQ
249 | Self::SnowflakeArcticEmbedMLong
250 | Self::SnowflakeArcticEmbedMLongQ => 768,
251
252 Self::BgeLargeEnV15
254 | Self::BgeLargeEnV15Q
255 | Self::BgeLargeZhV15
256 | Self::MultilingualE5Large
257 | Self::MxbaiEmbedLargeV1
258 | Self::MxbaiEmbedLargeV1Q
259 | Self::GteLargeEnV15
260 | Self::GteLargeEnV15Q
261 | Self::ModernBertEmbedLarge
262 | Self::SnowflakeArcticEmbedL
263 | Self::SnowflakeArcticEmbedLQ => 1024,
264 }
265 }
266
267 pub fn is_quantized(&self) -> bool {
269 matches!(
270 self,
271 Self::BgeSmallEnV15Q
272 | Self::AllMiniLmL6V2Q
273 | Self::AllMiniLmL12V2Q
274 | Self::BgeBaseEnV15Q
275 | Self::BgeLargeEnV15Q
276 | Self::ParaphraseMiniLmL12V2Q
277 | Self::NomicEmbedTextV15Q
278 | Self::MxbaiEmbedLargeV1Q
279 | Self::GteBaseEnV15Q
280 | Self::GteLargeEnV15Q
281 | Self::SnowflakeArcticEmbedXsQ
282 | Self::SnowflakeArcticEmbedSQ
283 | Self::SnowflakeArcticEmbedMQ
284 | Self::SnowflakeArcticEmbedMLongQ
285 | Self::SnowflakeArcticEmbedLQ
286 )
287 }
288
289 pub fn is_multilingual(&self) -> bool {
291 matches!(
292 self,
293 Self::MultilingualE5Small
294 | Self::MultilingualE5Base
295 | Self::MultilingualE5Large
296 | Self::ParaphraseMultilingualMpnetBaseV2
297 | Self::BgeSmallZhV15
298 | Self::BgeLargeZhV15
299 )
300 }
301
302 pub fn max_context_length(&self) -> usize {
304 match self {
305 Self::NomicEmbedTextV1 | Self::NomicEmbedTextV15 | Self::NomicEmbedTextV15Q => 8192,
306 Self::SnowflakeArcticEmbedMLong | Self::SnowflakeArcticEmbedMLongQ => 2048,
307 _ => 512,
308 }
309 }
310
311 pub fn all() -> Vec<Self> {
313 vec![
314 Self::BgeSmallEnV15,
315 Self::BgeSmallEnV15Q,
316 Self::AllMiniLmL6V2,
317 Self::AllMiniLmL6V2Q,
318 Self::AllMiniLmL12V2,
319 Self::AllMiniLmL12V2Q,
320 Self::AllMpnetBaseV2,
321 Self::BgeBaseEnV15,
322 Self::BgeBaseEnV15Q,
323 Self::BgeLargeEnV15,
324 Self::BgeLargeEnV15Q,
325 Self::MultilingualE5Small,
326 Self::MultilingualE5Base,
327 Self::MultilingualE5Large,
328 Self::ParaphraseMiniLmL12V2,
329 Self::ParaphraseMiniLmL12V2Q,
330 Self::ParaphraseMultilingualMpnetBaseV2,
331 Self::BgeSmallZhV15,
332 Self::BgeLargeZhV15,
333 Self::NomicEmbedTextV1,
334 Self::NomicEmbedTextV15,
335 Self::NomicEmbedTextV15Q,
336 Self::MxbaiEmbedLargeV1,
337 Self::MxbaiEmbedLargeV1Q,
338 Self::GteBaseEnV15,
339 Self::GteBaseEnV15Q,
340 Self::GteLargeEnV15,
341 Self::GteLargeEnV15Q,
342 Self::ClipVitB32,
343 Self::JinaEmbeddingsV2BaseCode,
344 Self::EmbeddingGemma300M,
345 Self::ModernBertEmbedLarge,
346 Self::SnowflakeArcticEmbedXs,
347 Self::SnowflakeArcticEmbedXsQ,
348 Self::SnowflakeArcticEmbedS,
349 Self::SnowflakeArcticEmbedSQ,
350 Self::SnowflakeArcticEmbedM,
351 Self::SnowflakeArcticEmbedMQ,
352 Self::SnowflakeArcticEmbedMLong,
353 Self::SnowflakeArcticEmbedMLongQ,
354 Self::SnowflakeArcticEmbedL,
355 Self::SnowflakeArcticEmbedLQ,
356 ]
357 }
358}
359
360impl Display for EmbeddingModelType {
361 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
362 let name = match self {
363 Self::BgeSmallEnV15 => "bge-small-en-v1.5",
364 Self::BgeSmallEnV15Q => "bge-small-en-v1.5-q",
365 Self::AllMiniLmL6V2 => "all-minilm-l6-v2",
366 Self::AllMiniLmL6V2Q => "all-minilm-l6-v2-q",
367 Self::AllMiniLmL12V2 => "all-minilm-l12-v2",
368 Self::AllMiniLmL12V2Q => "all-minilm-l12-v2-q",
369 Self::AllMpnetBaseV2 => "all-mpnet-base-v2",
370 Self::BgeBaseEnV15 => "bge-base-en-v1.5",
371 Self::BgeBaseEnV15Q => "bge-base-en-v1.5-q",
372 Self::BgeLargeEnV15 => "bge-large-en-v1.5",
373 Self::BgeLargeEnV15Q => "bge-large-en-v1.5-q",
374 Self::MultilingualE5Small => "multilingual-e5-small",
375 Self::MultilingualE5Base => "multilingual-e5-base",
376 Self::MultilingualE5Large => "multilingual-e5-large",
377 Self::ParaphraseMiniLmL12V2 => "paraphrase-minilm-l12-v2",
378 Self::ParaphraseMiniLmL12V2Q => "paraphrase-minilm-l12-v2-q",
379 Self::ParaphraseMultilingualMpnetBaseV2 => "paraphrase-multilingual-mpnet-base-v2",
380 Self::BgeSmallZhV15 => "bge-small-zh-v1.5",
381 Self::BgeLargeZhV15 => "bge-large-zh-v1.5",
382 Self::NomicEmbedTextV1 => "nomic-embed-text-v1",
383 Self::NomicEmbedTextV15 => "nomic-embed-text-v1.5",
384 Self::NomicEmbedTextV15Q => "nomic-embed-text-v1.5-q",
385 Self::MxbaiEmbedLargeV1 => "mxbai-embed-large-v1",
386 Self::MxbaiEmbedLargeV1Q => "mxbai-embed-large-v1-q",
387 Self::GteBaseEnV15 => "gte-base-en-v1.5",
388 Self::GteBaseEnV15Q => "gte-base-en-v1.5-q",
389 Self::GteLargeEnV15 => "gte-large-en-v1.5",
390 Self::GteLargeEnV15Q => "gte-large-en-v1.5-q",
391 Self::ClipVitB32 => "clip-vit-b-32",
392 Self::JinaEmbeddingsV2BaseCode => "jina-embeddings-v2-base-code",
393 Self::EmbeddingGemma300M => "embedding-gemma-300m",
394 Self::ModernBertEmbedLarge => "modernbert-embed-large",
395 Self::SnowflakeArcticEmbedXs => "snowflake-arctic-embed-xs",
396 Self::SnowflakeArcticEmbedXsQ => "snowflake-arctic-embed-xs-q",
397 Self::SnowflakeArcticEmbedS => "snowflake-arctic-embed-s",
398 Self::SnowflakeArcticEmbedSQ => "snowflake-arctic-embed-s-q",
399 Self::SnowflakeArcticEmbedM => "snowflake-arctic-embed-m",
400 Self::SnowflakeArcticEmbedMQ => "snowflake-arctic-embed-m-q",
401 Self::SnowflakeArcticEmbedMLong => "snowflake-arctic-embed-m-long",
402 Self::SnowflakeArcticEmbedMLongQ => "snowflake-arctic-embed-m-long-q",
403 Self::SnowflakeArcticEmbedL => "snowflake-arctic-embed-l",
404 Self::SnowflakeArcticEmbedLQ => "snowflake-arctic-embed-l-q",
405 };
406 write!(f, "{}", name)
407 }
408}
409
410impl FromStr for EmbeddingModelType {
411 type Err = AppError;
412
413 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
414 match s.to_lowercase().as_str() {
415 "bge-small-en-v1.5" | "bge-small-en" | "bge-small" => Ok(Self::BgeSmallEnV15),
416 "bge-small-en-v1.5-q" => Ok(Self::BgeSmallEnV15Q),
417 "all-minilm-l6-v2" | "minilm-l6" => Ok(Self::AllMiniLmL6V2),
418 "all-minilm-l6-v2-q" => Ok(Self::AllMiniLmL6V2Q),
419 "all-minilm-l12-v2" | "minilm-l12" => Ok(Self::AllMiniLmL12V2),
420 "all-minilm-l12-v2-q" => Ok(Self::AllMiniLmL12V2Q),
421 "all-mpnet-base-v2" | "mpnet" => Ok(Self::AllMpnetBaseV2),
422 "bge-base-en-v1.5" | "bge-base-en" | "bge-base" => Ok(Self::BgeBaseEnV15),
423 "bge-base-en-v1.5-q" => Ok(Self::BgeBaseEnV15Q),
424 "bge-large-en-v1.5" | "bge-large-en" | "bge-large" => Ok(Self::BgeLargeEnV15),
425 "bge-large-en-v1.5-q" => Ok(Self::BgeLargeEnV15Q),
426 "multilingual-e5-small" | "e5-small" => Ok(Self::MultilingualE5Small),
427 "multilingual-e5-base" | "e5-base" => Ok(Self::MultilingualE5Base),
428 "multilingual-e5-large" | "e5-large" => Ok(Self::MultilingualE5Large),
429 "paraphrase-minilm-l12-v2" => Ok(Self::ParaphraseMiniLmL12V2),
430 "paraphrase-minilm-l12-v2-q" => Ok(Self::ParaphraseMiniLmL12V2Q),
431 "paraphrase-multilingual-mpnet-base-v2" => Ok(Self::ParaphraseMultilingualMpnetBaseV2),
432 "bge-small-zh-v1.5" | "bge-small-zh" => Ok(Self::BgeSmallZhV15),
433 "bge-large-zh-v1.5" | "bge-large-zh" => Ok(Self::BgeLargeZhV15),
434 "nomic-embed-text-v1" | "nomic-v1" => Ok(Self::NomicEmbedTextV1),
435 "nomic-embed-text-v1.5" | "nomic-v1.5" | "nomic" => Ok(Self::NomicEmbedTextV15),
436 "nomic-embed-text-v1.5-q" => Ok(Self::NomicEmbedTextV15Q),
437 "mxbai-embed-large-v1" | "mxbai" => Ok(Self::MxbaiEmbedLargeV1),
438 "mxbai-embed-large-v1-q" => Ok(Self::MxbaiEmbedLargeV1Q),
439 "gte-base-en-v1.5" | "gte-base" => Ok(Self::GteBaseEnV15),
440 "gte-base-en-v1.5-q" => Ok(Self::GteBaseEnV15Q),
441 "gte-large-en-v1.5" | "gte-large" => Ok(Self::GteLargeEnV15),
442 "gte-large-en-v1.5-q" => Ok(Self::GteLargeEnV15Q),
443 "clip-vit-b-32" | "clip" => Ok(Self::ClipVitB32),
444 "jina-embeddings-v2-base-code" | "jina-code" => Ok(Self::JinaEmbeddingsV2BaseCode),
445 "embedding-gemma-300m" | "gemma-300m" | "gemma" => Ok(Self::EmbeddingGemma300M),
446 "modernbert-embed-large" | "modernbert" => Ok(Self::ModernBertEmbedLarge),
447 "snowflake-arctic-embed-xs" => Ok(Self::SnowflakeArcticEmbedXs),
448 "snowflake-arctic-embed-xs-q" => Ok(Self::SnowflakeArcticEmbedXsQ),
449 "snowflake-arctic-embed-s" => Ok(Self::SnowflakeArcticEmbedS),
450 "snowflake-arctic-embed-s-q" => Ok(Self::SnowflakeArcticEmbedSQ),
451 "snowflake-arctic-embed-m" => Ok(Self::SnowflakeArcticEmbedM),
452 "snowflake-arctic-embed-m-q" => Ok(Self::SnowflakeArcticEmbedMQ),
453 "snowflake-arctic-embed-m-long" => Ok(Self::SnowflakeArcticEmbedMLong),
454 "snowflake-arctic-embed-m-long-q" => Ok(Self::SnowflakeArcticEmbedMLongQ),
455 "snowflake-arctic-embed-l" | "snowflake-l" => Ok(Self::SnowflakeArcticEmbedL),
456 "snowflake-arctic-embed-l-q" => Ok(Self::SnowflakeArcticEmbedLQ),
457 _ => Err(AppError::Internal(format!(
458 "Unknown embedding model: {}. Use one of: {}",
459 s,
460 EmbeddingModelType::all()
461 .iter()
462 .map(|m| m.to_string())
463 .collect::<Vec<_>>()
464 .join(", ")
465 ))),
466 }
467 }
468}
469
470#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
476#[serde(rename_all = "kebab-case")]
477pub enum SparseModelType {
478 #[default]
480 SpladePpV1,
481 }
483
484impl SparseModelType {
485 pub fn to_fastembed_model(&self) -> SparseModel {
487 match self {
488 Self::SpladePpV1 => SparseModel::SPLADEPPV1,
489 }
490 }
491}
492
493impl Display for SparseModelType {
494 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
495 let name = match self {
496 Self::SpladePpV1 => "splade-pp-v1",
497 };
498 write!(f, "{}", name)
499 }
500}
501
502impl FromStr for SparseModelType {
503 type Err = AppError;
504
505 fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
506 match s.to_lowercase().as_str() {
507 "splade-pp-v1" | "splade" => Ok(Self::SpladePpV1),
508 _ => Err(AppError::Internal(format!(
509 "Unknown sparse model: {}. Use: splade-pp-v1",
510 s
511 ))),
512 }
513 }
514}
515
516#[derive(Debug, Clone, Serialize, Deserialize)]
522pub struct EmbeddingConfig {
523 #[serde(default)]
525 pub model: EmbeddingModelType,
526
527 #[serde(default = "default_batch_size")]
529 pub batch_size: usize,
530
531 #[serde(default = "default_show_progress")]
533 pub show_download_progress: bool,
534
535 #[serde(default)]
537 pub sparse_enabled: bool,
538
539 #[serde(default)]
541 pub sparse_model: SparseModelType,
542}
543
/// Serde default for `EmbeddingConfig::batch_size`.
fn default_batch_size() -> usize {
    const DEFAULT_BATCH_SIZE: usize = 32;
    DEFAULT_BATCH_SIZE
}
547
/// Serde default for `EmbeddingConfig::show_download_progress`: progress is shown.
fn default_show_progress() -> bool {
    const SHOW_PROGRESS_BY_DEFAULT: bool = true;
    SHOW_PROGRESS_BY_DEFAULT
}
551
552impl Default for EmbeddingConfig {
553 fn default() -> Self {
554 Self {
555 model: EmbeddingModelType::default(),
556 batch_size: default_batch_size(),
557 show_download_progress: default_show_progress(),
558 sparse_enabled: false,
559 sparse_model: SparseModelType::default(),
560 }
561 }
562}
563
564pub struct EmbeddingService {
573 #[allow(dead_code)]
574 model: TextEmbedding,
575 #[allow(dead_code)]
576 sparse_model: Option<fastembed::SparseTextEmbedding>,
577 config: EmbeddingConfig,
578}
579
580impl EmbeddingService {
581 pub fn new(config: EmbeddingConfig) -> Result<Self> {
583 let model = TextEmbedding::try_new(
584 InitOptions::new(config.model.to_fastembed_model())
585 .with_show_download_progress(config.show_download_progress),
586 )
587 .map_err(|e| AppError::Internal(format!("Failed to initialize embedding model: {}", e)))?;
588
589 let sparse_model = if config.sparse_enabled {
590 Some(
591 fastembed::SparseTextEmbedding::try_new(
592 fastembed::SparseInitOptions::new(config.sparse_model.to_fastembed_model())
593 .with_show_download_progress(config.show_download_progress),
594 )
595 .map_err(|e| {
596 AppError::Internal(format!("Failed to initialize sparse embedding model: {}", e))
597 })?,
598 )
599 } else {
600 None
601 };
602
603 Ok(Self {
604 model,
605 sparse_model,
606 config,
607 })
608 }
609
610 pub fn with_default_model() -> Result<Self> {
612 Self::new(EmbeddingConfig::default())
613 }
614
615 pub fn with_model(model: EmbeddingModelType) -> Result<Self> {
617 Self::new(EmbeddingConfig {
618 model,
619 ..Default::default()
620 })
621 }
622
623 pub fn model_type(&self) -> EmbeddingModelType {
625 self.config.model
626 }
627
628 pub fn dimensions(&self) -> usize {
630 self.config.model.dimensions()
631 }
632
633 pub fn config(&self) -> &EmbeddingConfig {
635 &self.config
636 }
637
638 pub async fn embed_text(&self, text: &str) -> Result<Vec<f32>> {
640 let embeddings = self.embed_texts(&[text.to_string()]).await?;
641 embeddings
642 .into_iter()
643 .next()
644 .ok_or_else(|| AppError::Internal("No embedding generated".to_string()))
645 }
646
647 pub async fn embed_texts<S: AsRef<str> + Send + Sync + 'static>(
652 &self,
653 texts: &[S],
654 ) -> Result<Vec<Vec<f32>>> {
655 if texts.is_empty() {
656 return Ok(vec![]);
657 }
658
659 let texts_owned: Vec<String> = texts.iter().map(|s| s.as_ref().to_string()).collect();
661 let batch_size = self.config.batch_size;
662
663 let model_type = self.config.model.to_fastembed_model();
665 let show_progress = self.config.show_download_progress;
666
667 spawn_blocking(move || {
668 let mut model = TextEmbedding::try_new(
670 InitOptions::new(model_type).with_show_download_progress(show_progress),
671 )
672 .map_err(|e| {
673 AppError::Internal(format!("Failed to initialize embedding model: {}", e))
674 })?;
675
676 let refs: Vec<&str> = texts_owned.iter().map(|s| s.as_str()).collect();
677 model
678 .embed(refs, Some(batch_size))
679 .map_err(|e| AppError::Internal(format!("Embedding failed: {}", e)))
680 })
681 .await
682 .map_err(|e| AppError::Internal(format!("Blocking task failed: {}", e)))?
683 }
684
685 pub async fn embed_sparse<S: AsRef<str> + Send + Sync + 'static>(
687 &self,
688 texts: &[S],
689 ) -> Result<Vec<fastembed::SparseEmbedding>> {
690 if self.sparse_model.is_none() {
691 return Err(AppError::Internal(
692 "Sparse embeddings not enabled. Set sparse_enabled: true in config.".to_string(),
693 ));
694 }
695
696 let texts_owned: Vec<String> = texts.iter().map(|s| s.as_ref().to_string()).collect();
697 let batch_size = self.config.batch_size;
698 let sparse_model_type = self.config.sparse_model.to_fastembed_model();
699 let show_progress = self.config.show_download_progress;
700
701 spawn_blocking(move || {
702 let mut model = fastembed::SparseTextEmbedding::try_new(
703 fastembed::SparseInitOptions::new(sparse_model_type)
704 .with_show_download_progress(show_progress),
705 )
706 .map_err(|e| {
707 AppError::Internal(format!("Failed to initialize sparse model: {}", e))
708 })?;
709
710 let refs: Vec<&str> = texts_owned.iter().map(|s| s.as_str()).collect();
711 model
712 .embed(refs, Some(batch_size))
713 .map_err(|e| AppError::Internal(format!("Sparse embedding failed: {}", e)))
714 })
715 .await
716 .map_err(|e| AppError::Internal(format!("Blocking task failed: {}", e)))?
717 }
718}
719
720#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
736#[serde(rename_all = "lowercase")]
737#[allow(dead_code)]
738pub enum AccelerationBackend {
739 Cpu,
741 Cuda { device_id: usize },
743 Metal,
745 Vulkan,
747}
748
749impl Default for AccelerationBackend {
750 fn default() -> Self {
751 Self::Cpu
752 }
753}
754
755#[deprecated(note = "Use EmbeddingService instead")]
763pub struct LegacyEmbeddingService {
764 inner: EmbeddingService,
765}
766
767#[allow(deprecated)]
768impl LegacyEmbeddingService {
769 pub fn new(_model_name: &str) -> Result<Self> {
771 Ok(Self {
772 inner: EmbeddingService::with_default_model()?,
773 })
774 }
775
776 pub fn embed(&mut self, texts: Vec<&str>) -> Result<Vec<Vec<f32>>> {
778 let model_type = self.inner.config.model.to_fastembed_model();
779 let mut model = TextEmbedding::try_new(
780 InitOptions::new(model_type).with_show_download_progress(true),
781 )
782 .map_err(|e| AppError::Internal(e.to_string()))?;
783
784 model
785 .embed(texts, None)
786 .map_err(|e| AppError::Internal(e.to_string()))
787 }
788}
789
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    #[test]
    fn test_model_dimensions() {
        assert_eq!(EmbeddingModelType::BgeSmallEnV15.dimensions(), 384);
        assert_eq!(EmbeddingModelType::BgeBaseEnV15.dimensions(), 768);
        assert_eq!(EmbeddingModelType::BgeLargeEnV15.dimensions(), 1024);
        assert_eq!(EmbeddingModelType::MultilingualE5Large.dimensions(), 1024);
        assert_eq!(EmbeddingModelType::AllMpnetBaseV2.dimensions(), 768);
        assert_eq!(EmbeddingModelType::ClipVitB32.dimensions(), 512);
    }

    #[test]
    fn test_every_model_has_a_known_dimension() {
        for model in EmbeddingModelType::all() {
            assert!(
                matches!(model.dimensions(), 384 | 512 | 768 | 1024),
                "{} has unexpected dimension {}",
                model,
                model.dimensions()
            );
        }
    }

    #[test]
    fn test_model_from_str() {
        assert_eq!(
            "bge-small-en-v1.5".parse::<EmbeddingModelType>().unwrap(),
            EmbeddingModelType::BgeSmallEnV15
        );
        assert_eq!(
            "multilingual-e5-large".parse::<EmbeddingModelType>().unwrap(),
            EmbeddingModelType::MultilingualE5Large
        );
        assert_eq!(
            "minilm-l6".parse::<EmbeddingModelType>().unwrap(),
            EmbeddingModelType::AllMiniLmL6V2
        );
        // Parsing is case-insensitive.
        assert_eq!(
            "BGE-Small-EN-v1.5".parse::<EmbeddingModelType>().unwrap(),
            EmbeddingModelType::BgeSmallEnV15
        );
    }

    #[test]
    fn test_display_round_trips_through_from_str() {
        for model in EmbeddingModelType::all() {
            let name = model.to_string();
            assert_eq!(
                name.parse::<EmbeddingModelType>().unwrap(),
                model,
                "round trip failed for {}",
                name
            );
        }
    }

    #[test]
    fn test_unknown_model_is_rejected() {
        assert!("not-a-model".parse::<EmbeddingModelType>().is_err());
        assert!("not-a-model".parse::<SparseModelType>().is_err());
    }

    #[test]
    fn test_quantized_models_use_q_suffix() {
        for model in EmbeddingModelType::all() {
            assert_eq!(
                model.is_quantized(),
                model.to_string().ends_with("-q"),
                "is_quantized disagrees with the display name of {}",
                model
            );
        }
    }

    #[test]
    fn test_model_is_multilingual() {
        assert!(EmbeddingModelType::MultilingualE5Small.is_multilingual());
        assert!(EmbeddingModelType::MultilingualE5Large.is_multilingual());
        assert!(EmbeddingModelType::ParaphraseMultilingualMpnetBaseV2.is_multilingual());
        assert!(!EmbeddingModelType::BgeSmallEnV15.is_multilingual());
    }

    #[test]
    fn test_model_max_context() {
        assert_eq!(
            EmbeddingModelType::NomicEmbedTextV15.max_context_length(),
            8192
        );
        assert_eq!(
            EmbeddingModelType::NomicEmbedTextV1.max_context_length(),
            8192
        );
        assert_eq!(
            EmbeddingModelType::BgeSmallEnV15.max_context_length(),
            512
        );
    }

    #[test]
    fn test_sparse_model_round_trip() {
        let sparse = SparseModelType::default();
        assert_eq!(sparse.to_string(), "splade-pp-v1");
        assert_eq!("splade".parse::<SparseModelType>().unwrap(), sparse);
        assert_eq!(sparse.to_string().parse::<SparseModelType>().unwrap(), sparse);
    }

    #[test]
    fn test_default_config() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.model, EmbeddingModelType::BgeSmallEnV15);
        assert_eq!(config.batch_size, 32);
        assert!(config.show_download_progress);
        assert!(!config.sparse_enabled);
        assert_eq!(config.sparse_model, SparseModelType::SpladePpV1);
    }

    #[test]
    fn test_all_models_listed() {
        let all = EmbeddingModelType::all();
        // Update this count whenever a variant is added to EmbeddingModelType.
        assert_eq!(all.len(), 42);
        let unique: HashSet<EmbeddingModelType> = all.iter().copied().collect();
        assert_eq!(unique.len(), all.len(), "all() contains duplicates");
        assert!(all.contains(&EmbeddingModelType::BgeSmallEnV15));
        assert!(all.contains(&EmbeddingModelType::MultilingualE5Large));
    }
}
861}