Skip to main content

memoir_core/embedding/
onnx.rs

1//! [`EmbeddingModel`] implementation backed by `fastembed`.
2
3use std::sync::{Arc, Mutex};
4
5use super::{EmbeddingError, EmbeddingModel};
6
/// Length of every embedding vector reported by [`OnnxEmbedding`].
const ONNX_DIMENSIONS: usize = 384;
8
9/// Default [`EmbeddingModel`] backed by `fastembed`'s BGE-small-en-v1.5.
10///
11/// Produces 384-dimension vectors. The model file is downloaded on first
12/// construction (~50 MB) and cached locally by `fastembed`.
13pub struct OnnxEmbedding {
14    model: Arc<Mutex<fastembed::TextEmbedding>>,
15}
16
17impl std::fmt::Debug for OnnxEmbedding {
18    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
19        f.debug_struct("OnnxEmbedding").finish_non_exhaustive()
20    }
21}
22
23impl OnnxEmbedding {
24    /// Initializes the embedder, downloading the model file if not cached.
25    ///
26    /// # Errors
27    ///
28    /// Returns [`EmbeddingError::Init`] when the model cannot be loaded —
29    /// typically a download failure on first use or a corrupted cache.
30    pub fn new() -> Result<Self, EmbeddingError> {
31        let options = fastembed::InitOptions::new(fastembed::EmbeddingModel::BGESmallENV15);
32        let model = fastembed::TextEmbedding::try_new(options)
33            .map_err(|e| EmbeddingError::Init(e.to_string()))?;
34        Ok(Self {
35            model: Arc::new(Mutex::new(model)),
36        })
37    }
38}
39
40impl EmbeddingModel for OnnxEmbedding {
41    fn embed(&self, text: &str) -> impl std::future::Future<Output = Result<Vec<f32>, EmbeddingError>> + Send {
42        let model = self.model.clone();
43        let text = text.to_owned();
44        async move {
45            tokio::task::spawn_blocking(move || {
46                let mut guard = model
47                    .lock()
48                    .map_err(|e| EmbeddingError::Embed(format!("model lock poisoned: {e}")))?;
49                let mut results = guard
50                    .embed(vec![&text], None)
51                    .map_err(|e| EmbeddingError::Embed(e.to_string()))?;
52                results
53                    .pop()
54                    .ok_or_else(|| EmbeddingError::Embed("empty result from model".into()))
55            })
56            .await
57            .map_err(|e| EmbeddingError::Embed(format!("join error: {e}")))?
58        }
59    }
60
61    fn dimensions(&self) -> usize {
62        ONNX_DIMENSIONS
63    }
64}
65
#[cfg(test)]
mod tests {
    use super::*;

    /// Pins the module constant to the value the type's documentation states.
    #[test]
    fn should_report_onnx_dimensions_as_384() {
        let expected: usize = 384;
        assert_eq!(ONNX_DIMENSIONS, expected);
    }
}