Skip to main content

blazen_embed_fastembed/
provider.rs

1//! The [`FastEmbedModel`] type providing local embeddings via fastembed.
2
3use std::fmt;
4use std::sync::{Arc, Mutex};
5
6use crate::FastEmbedOptions;
7
/// Failure modes of the fastembed-backed embedding provider.
///
/// Every variant carries a human-readable detail string; the [`fmt::Display`]
/// impl renders it after a short, variant-specific prefix.
#[derive(Debug)]
pub enum FastEmbedError {
    /// The requested model name could not be parsed into a fastembed model.
    UnknownModel(String),
    /// Setting up the fastembed model failed (model lookup or initialisation).
    Init(String),
    /// The fastembed embedding call returned an error.
    Embed(String),
    /// The mutex guarding the model handle was found poisoned.
    MutexPoisoned(String),
    /// The blocking task that ran the embedding did not complete normally.
    TaskPanicked(String),
}

impl fmt::Display for FastEmbedError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            FastEmbedError::UnknownModel(ref msg) => write!(f, "unknown fastembed model: {msg}"),
            FastEmbedError::Init(ref msg) => write!(f, "fastembed init failed: {msg}"),
            FastEmbedError::Embed(ref msg) => write!(f, "fastembed embed failed: {msg}"),
            FastEmbedError::MutexPoisoned(ref msg) => {
                write!(f, "fastembed mutex poisoned: {msg}")
            }
            FastEmbedError::TaskPanicked(ref msg) => {
                write!(f, "fastembed blocking task panicked: {msg}")
            }
        }
    }
}

impl std::error::Error for FastEmbedError {
    /// Underlying causes are flattened into the detail string, so there is
    /// never a chained source error.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        None
    }
}
36
/// Output of a single [`FastEmbedModel::embed`] call.
#[derive(Debug, Clone)]
pub struct FastEmbedResponse {
    /// One embedding vector for each input text.
    pub embeddings: Vec<Vec<f32>>,
    /// Identifier of the model that generated `embeddings`.
    pub model: String,
}
45
46/// A local embedding model backed by [`fastembed`] (ONNX Runtime).
47///
48/// Constructed via [`FastEmbedModel::from_options`]. The underlying
49/// `fastembed::TextEmbedding` is synchronous, so all calls to
50/// [`FastEmbedModel::embed`] are dispatched onto Tokio's blocking thread pool
51/// via [`tokio::task::spawn_blocking`].
52pub struct FastEmbedModel {
53    /// The fastembed model handle. Wrapped in `Arc<Mutex<...>>` because
54    /// `TextEmbedding::embed` takes `&mut self` and we need to move the
55    /// handle into `spawn_blocking` closures.
56    model: Arc<Mutex<fastembed::TextEmbedding>>,
57    /// The model identifier string returned by `model_id()`.
58    model_id: String,
59    /// Embedding dimensionality for this model.
60    dims: usize,
61    /// Optional batch size override.
62    batch_size: Option<usize>,
63}
64
65// `fastembed::TextEmbedding` is `Send` (it contains `ort::Session` which is
66// `Send`). `Arc<Mutex<T: Send>>` is `Send + Sync`, so `FastEmbedModel`
67// auto-derives both traits.
68
69impl FastEmbedModel {
70    /// Construct a new [`FastEmbedModel`] from the given options.
71    ///
72    /// This is a blocking operation that may download model weights from
73    /// `HuggingFace` on first use. Call from a context where blocking is
74    /// acceptable (e.g. application startup), or wrap in
75    /// [`tokio::task::spawn_blocking`].
76    ///
77    /// # Errors
78    ///
79    /// Returns [`FastEmbedError`] if the fastembed model fails to initialise
80    /// (e.g. unknown model name, network error during download).
81    pub fn from_options(opts: FastEmbedOptions) -> Result<Self, FastEmbedError> {
82        // Resolve the fastembed EmbeddingModel enum variant.
83        let fe_model = if let Some(ref name) = opts.model_name {
84            name.parse::<fastembed::EmbeddingModel>()
85                .map_err(|e| FastEmbedError::UnknownModel(format!("\"{name}\": {e}")))?
86        } else {
87            fastembed::EmbeddingModel::default()
88        };
89
90        // Look up the model info to get dimensions.
91        let model_info =
92            <fastembed::EmbeddingModel as fastembed::ModelTrait>::get_model_info(&fe_model)
93                .ok_or_else(|| {
94                    FastEmbedError::Init(format!("no model info found for {fe_model:?}"))
95                })?;
96        let dims = model_info.dim;
97        let model_code = model_info.model_code.clone();
98
99        // Build init options.
100        let mut init_opts = fastembed::TextInitOptions::new(fe_model);
101        if let Some(cache_dir) = opts.cache_dir {
102            init_opts = init_opts.with_cache_dir(cache_dir);
103        }
104        if let Some(show) = opts.show_download_progress {
105            init_opts = init_opts.with_show_download_progress(show);
106        }
107
108        let te = fastembed::TextEmbedding::try_new(init_opts)
109            .map_err(|e| FastEmbedError::Init(e.to_string()))?;
110
111        Ok(Self {
112            model: Arc::new(Mutex::new(te)),
113            model_id: model_code,
114            dims,
115            batch_size: opts.max_batch_size,
116        })
117    }
118
119    /// The model identifier (e.g. `"Xenova/bge-small-en-v1.5"`).
120    #[must_use]
121    pub fn model_id(&self) -> &str {
122        &self.model_id
123    }
124
125    /// Embedding vector dimensionality for this model.
126    #[must_use]
127    pub fn dimensions(&self) -> usize {
128        self.dims
129    }
130
131    /// Embed one or more texts, returning one vector per input text.
132    ///
133    /// The fastembed crate's embed method is synchronous -- this function
134    /// dispatches the work onto Tokio's blocking thread pool via
135    /// [`tokio::task::spawn_blocking`] to avoid starving the async runtime.
136    ///
137    /// # Errors
138    ///
139    /// Returns [`FastEmbedError`] if the underlying fastembed call fails or
140    /// the blocking task panics.
141    pub async fn embed(&self, texts: &[String]) -> Result<FastEmbedResponse, FastEmbedError> {
142        if texts.is_empty() {
143            return Ok(FastEmbedResponse {
144                embeddings: vec![],
145                model: self.model_id.clone(),
146            });
147        }
148
149        // Clone inputs and the Arc handle so we can move them into the
150        // blocking closure.
151        let texts_owned: Vec<String> = texts.to_vec();
152        let batch_size = self.batch_size;
153        let model_id = self.model_id.clone();
154        let model_handle = Arc::clone(&self.model);
155
156        let embeddings = tokio::task::spawn_blocking(move || {
157            let mut model = model_handle
158                .lock()
159                .map_err(|e| FastEmbedError::MutexPoisoned(e.to_string()))?;
160            let result: Vec<Vec<f32>> = model
161                .embed(&texts_owned, batch_size)
162                .map_err(|e| FastEmbedError::Embed(e.to_string()))?;
163            Ok::<Vec<Vec<f32>>, FastEmbedError>(result)
164        })
165        .await
166        .map_err(|e| FastEmbedError::TaskPanicked(e.to_string()))??;
167
168        Ok(FastEmbedResponse {
169            embeddings,
170            model: model_id,
171        })
172    }
173}
174
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the user-facing error messages; needs no model.
    #[test]
    fn error_display_prefixes_each_variant() {
        let cases = [
            (FastEmbedError::UnknownModel("m".into()), "unknown fastembed model: m"),
            (FastEmbedError::Init("i".into()), "fastembed init failed: i"),
            (FastEmbedError::Embed("e".into()), "fastembed embed failed: e"),
            (FastEmbedError::MutexPoisoned("p".into()), "fastembed mutex poisoned: p"),
            (FastEmbedError::TaskPanicked("t".into()), "fastembed blocking task panicked: t"),
        ];
        for (err, expected) in &cases {
            assert_eq!(err.to_string(), *expected);
        }
    }

    #[test]
    #[ignore = "requires model download from HuggingFace"]
    fn from_options_default_loads_model() {
        let model = FastEmbedModel::from_options(FastEmbedOptions::default())
            .expect("should create model with default options");
        assert!(model.dimensions() > 0);
        assert!(!model.model_id().is_empty());
    }

    #[tokio::test]
    #[ignore = "requires model download from HuggingFace"]
    async fn embed_returns_correct_count() {
        let model = FastEmbedModel::from_options(FastEmbedOptions::default())
            .expect("should create model with default options");
        let response = model
            .embed(&["hello".into(), "world".into()])
            .await
            .expect("embedding should succeed");
        assert_eq!(response.embeddings.len(), 2);
        assert!(response
            .embeddings
            .iter()
            .all(|vector| vector.len() == model.dimensions()));
        assert!(!response.embeddings[0].is_empty());
    }

    #[tokio::test]
    async fn embed_empty_input_returns_empty() {
        // `embed` short-circuits on empty input without touching the model.
        // Building a model instance still goes through `from_options`, which
        // loads the default model and downloads it if it is not cached. If
        // that fails (e.g. offline with no cached copy), the test is skipped.
        let Ok(model) = FastEmbedModel::from_options(FastEmbedOptions::default()) else {
            eprintln!("skipping embed_empty_input_returns_empty: model not available");
            return;
        };
        let response = model.embed(&[]).await.expect("empty embed should succeed");
        assert!(response.embeddings.is_empty());
        assert_eq!(response.model, model.model_id());
    }
}