1use crate::{EmbedError, EmbeddingBackend};
2use async_trait::async_trait;
3use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
4use std::sync::Arc;
5use tokio::sync::Mutex;
6
7pub struct FastembedBackend {
16 id: String,
17 dimension: u16,
18 inner: Arc<Mutex<TextEmbedding>>,
19}
20
21impl std::fmt::Debug for FastembedBackend {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 f.debug_struct("FastembedBackend")
24 .field("id", &self.id)
25 .field("dimension", &self.dimension)
26 .finish()
27 }
28}
29
30impl FastembedBackend {
31 pub async fn new(model_id: &str, model_path: Option<&str>) -> Result<Self, EmbedError> {
35 let em = pick_model(model_id)?;
36 let dimension = em_dimension(&em);
37 let mut opts = InitOptions::new(em);
38 if let Some(path) = model_path {
39 opts = opts.with_cache_dir(path.into());
40 }
41 let model = tokio::task::spawn_blocking(move || TextEmbedding::try_new(opts))
42 .await
43 .map_err(|e| EmbedError::ModelLoad(e.to_string()))?
44 .map_err(|e| EmbedError::ModelLoad(e.to_string()))?;
45 Ok(Self {
46 id: format!("fastembed/{model_id}"),
47 dimension,
48 inner: Arc::new(Mutex::new(model)),
49 })
50 }
51}
52
53#[async_trait]
54impl EmbeddingBackend for FastembedBackend {
55 fn id(&self) -> &str {
56 &self.id
57 }
58 fn dimension(&self) -> u16 {
59 self.dimension
60 }
61 async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbedError> {
62 if texts.is_empty() {
63 return Ok(vec![]);
64 }
65 let inner = self.inner.clone();
66 let owned: Vec<String> = texts.to_vec();
67 let dim = self.dimension;
68 tokio::task::spawn_blocking(move || {
69 let guard = inner.blocking_lock();
70 let vecs = guard
71 .embed(owned, None)
72 .map_err(|e| EmbedError::Request(e.to_string()))?;
73 for v in &vecs {
74 if v.len() != dim as usize {
75 return Err(EmbedError::DimensionMismatch {
76 got: v.len() as u16,
77 expected: dim,
78 });
79 }
80 }
81 Ok(vecs)
82 })
83 .await
84 .map_err(|e| EmbedError::Internal(e.to_string()))?
85 }
86}
87
88fn pick_model(id: &str) -> Result<EmbeddingModel, EmbedError> {
89 match id {
90 "bge-small-en-v1.5" => Ok(EmbeddingModel::BGESmallENV15),
91 "bge-base-en-v1.5" => Ok(EmbeddingModel::BGEBaseENV15),
92 "all-MiniLM-L6-v2" => Ok(EmbeddingModel::AllMiniLML6V2),
93 other => Err(EmbedError::NotConfigured(format!(
94 "unknown fastembed model: {other}"
95 ))),
96 }
97}
98
99fn em_dimension(em: &EmbeddingModel) -> u16 {
100 match em {
101 EmbeddingModel::BGESmallENV15 => 384,
102 EmbeddingModel::BGEBaseENV15 => 768,
103 EmbeddingModel::AllMiniLML6V2 => 384,
104 other => unreachable!(
105 "em_dimension: pick_model accepted model {other:?} but em_dimension has no arm. Add the dimension here."),
106 }
107}