Skip to main content

anamnesis_embedder/
local.rs

//! Local `EmbeddingProvider` implemented on top of `fastembed-rs`.
//!
//! Gated behind the `local-fastembed` cargo feature so that development
//! iterations that don't need the ONNX runtime compile quickly.
5
6use std::path::{Path, PathBuf};
7use std::sync::Mutex;
8
9use anamnesis_core::embedding::{EmbeddingProvider, EmbeddingTask, ModelId};
10use anamnesis_core::error::{Error, Result};
11use async_trait::async_trait;
12
13use crate::registry::CuratedModel;
14
15/// `EmbeddingProvider` backed by a fastembed-managed ONNX model.
16pub struct LocalFastembedProvider {
17    model_info: &'static CuratedModel,
18    model_id: ModelId,
19    cache_dir: PathBuf,
20    // `TextEmbedding::embed` is `&mut`, so wrap in Mutex for `&self` access
21    // from the async trait. The queue worker calls one batch at a time;
22    // contention is a non-issue.
23    inner: Mutex<fastembed::TextEmbedding>,
24}
25
26impl std::fmt::Debug for LocalFastembedProvider {
27    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28        f.debug_struct("LocalFastembedProvider")
29            .field("model_id", &self.model_id)
30            .field("dim", &self.model_info.dim)
31            .field("cache_dir", &self.cache_dir)
32            .finish()
33    }
34}
35
36impl LocalFastembedProvider {
37    /// Build a provider for the curated `key` (see `registry::REGISTRY`).
38    /// Downloads the model on first use; subsequent runs read from cache.
39    pub fn new(key: &str, cache_dir: impl AsRef<Path>) -> Result<Self> {
40        let info = crate::registry::by_key(key).ok_or_else(|| {
41            Error::Other(format!(
42                "unknown curated model: {key} (try one of: {})",
43                crate::registry::available().join(", ")
44            ))
45        })?;
46        if !info.is_local {
47            return Err(Error::Other(format!(
48                "model {key} is a cloud provider; use the cloud provider instead"
49            )));
50        }
51        let cache_dir = cache_dir.as_ref().to_path_buf();
52        std::fs::create_dir_all(&cache_dir).map_err(Error::Io)?;
53        let fast_model = map_to_fastembed(info)?;
54        let opts = fastembed::InitOptions::new(fast_model).with_cache_dir(cache_dir.clone());
55        let inner = fastembed::TextEmbedding::try_new(opts)
56            .map_err(|e| Error::Other(format!("fastembed init {key}: {e}")))?;
57        Ok(Self {
58            model_info: info,
59            model_id: ModelId::new("local", info.key, 1),
60            cache_dir,
61            inner: Mutex::new(inner),
62        })
63    }
64
65    /// Where the model files are cached.
66    pub fn cache_dir(&self) -> &Path {
67        &self.cache_dir
68    }
69
70    /// The curated model entry this provider serves.
71    pub fn model_info(&self) -> &'static CuratedModel {
72        self.model_info
73    }
74
75    fn prefixed(&self, texts: &[&str], task: EmbeddingTask) -> Vec<String> {
76        let prefix = match task {
77            EmbeddingTask::Query => self.model_info.query_prefix,
78            EmbeddingTask::Document => self.model_info.doc_prefix,
79        };
80        match prefix {
81            Some(p) => texts.iter().map(|t| format!("{p}{t}")).collect(),
82            None => texts.iter().map(|t| (*t).to_owned()).collect(),
83        }
84    }
85}
86
87#[async_trait]
88impl EmbeddingProvider for LocalFastembedProvider {
89    fn model_id(&self) -> ModelId {
90        self.model_id.clone()
91    }
92
93    fn dim(&self) -> u16 {
94        self.model_info.dim
95    }
96
97    async fn embed_batch(&self, texts: &[&str], task: EmbeddingTask) -> Result<Vec<Vec<f32>>> {
98        let inputs = self.prefixed(texts, task);
99        // Synchronous CPU-bound call. The embedding worker is single-batch
100        // at a time so blocking the runtime is acceptable; users who run
101        // many parallel embedders should drive each on its own runtime.
102        let guard = self.inner.lock().expect("provider inner mutex poisoned");
103        guard
104            .embed(inputs, None)
105            .map_err(|e| Error::Other(format!("fastembed embed: {e}")))
106    }
107}
108
109fn map_to_fastembed(info: &CuratedModel) -> Result<fastembed::EmbeddingModel> {
110    use fastembed::EmbeddingModel as FE;
111    Ok(match info.key {
112        "default" => FE::MultilingualE5Small,
113        "tiny" => FE::AllMiniLML6V2Q,
114        "en" => FE::BGESmallENV15,
115        "multi-strong" => FE::MultilingualE5Base,
116        other => {
117            return Err(Error::Other(format!(
118                "no fastembed mapping for curated model: {other}"
119            )))
120        }
121    })
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};

    static FE_CACHE_TMP_NONCE: AtomicU64 = AtomicU64::new(0);

    /// Scratch cache directory that is removed (recursively) when dropped.
    ///
    /// Each `tmp_cache()` call yields a brand-new directory, so nothing is
    /// reused between runs; without cleanup every run — and in particular
    /// every opt-in download run — would leave its directory and any
    /// downloaded model files behind in the system temp dir.
    struct TmpCache(PathBuf);

    impl AsRef<Path> for TmpCache {
        fn as_ref(&self) -> &Path {
            &self.0
        }
    }

    impl Drop for TmpCache {
        fn drop(&mut self) {
            // Best-effort: a cleanup failure must not panic inside a test.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    /// Create a fresh, uniquely named cache directory under the system temp
    /// dir (wall-clock nanos + pid + process-wide counter).
    fn tmp_cache() -> TmpCache {
        let nonce = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_nanos();
        let seq = FE_CACHE_TMP_NONCE.fetch_add(1, Ordering::Relaxed);
        let p = std::env::temp_dir().join(format!(
            "anamnesis-fe-cache-{nonce}-{pid}-{seq}",
            pid = std::process::id()
        ));
        std::fs::create_dir_all(&p).unwrap();
        TmpCache(p)
    }

    #[test]
    fn unknown_key_errors() {
        let cache = tmp_cache();
        let r = LocalFastembedProvider::new("nope-not-a-model", &cache);
        assert!(r.is_err());
        let msg = r.unwrap_err().to_string();
        assert!(msg.contains("unknown curated model"));
        assert!(msg.contains("default")); // suggestion list rendered
    }

    #[test]
    fn cloud_voyage_rejected_by_local_provider() {
        let cache = tmp_cache();
        let r = LocalFastembedProvider::new("cloud-voyage", &cache);
        let err = r.unwrap_err();
        assert!(format!("{err}").contains("cloud provider"));
    }

    #[test]
    fn every_local_key_has_a_fastembed_mapping() {
        for m in crate::registry::local_only() {
            assert!(
                map_to_fastembed(m).is_ok(),
                "missing fastembed mapping for {}",
                m.key
            );
        }
    }

    // The instantiation + embed tests actually download the model
    // (~120 MB for `default`). They're gated behind FASTEMBED_DOWNLOAD=1
    // so plain `cargo test` stays fast and CI can opt in.
    fn allow_download() -> bool {
        std::env::var("FASTEMBED_DOWNLOAD").ok().as_deref() == Some("1")
    }

    #[tokio::test]
    async fn end_to_end_embed_with_real_model() {
        if !allow_download() {
            eprintln!("skipping: FASTEMBED_DOWNLOAD != 1");
            return;
        }
        // `cache` is declared before `provider`, so the provider (and its
        // ONNX session) is dropped first and the directory is removed last.
        let cache = tmp_cache();
        let provider = LocalFastembedProvider::new("default", &cache).unwrap();
        assert_eq!(provider.dim(), 384);
        assert_eq!(provider.model_id().as_str(), "local:default:1");
        let v = provider
            .embed_batch(&["hello", "用户偏好"], EmbeddingTask::Document)
            .await
            .unwrap();
        assert_eq!(v.len(), 2);
        assert_eq!(v[0].len(), 384);
        assert_eq!(v[1].len(), 384);
        // E5 returns L2-normalized vectors → magnitude ~1.0
        let mag = (v[0].iter().map(|x| x * x).sum::<f32>()).sqrt();
        assert!(
            (mag - 1.0).abs() < 0.1,
            "expected ~L2-normalized vector, got mag {mag}"
        );
    }
}