anamnesis_embedder/
local.rs1use std::path::{Path, PathBuf};
7use std::sync::Mutex;
8
9use anamnesis_core::embedding::{EmbeddingProvider, EmbeddingTask, ModelId};
10use anamnesis_core::error::{Error, Result};
11use async_trait::async_trait;
12
13use crate::registry::CuratedModel;
14
15pub struct LocalFastembedProvider {
17 model_info: &'static CuratedModel,
18 model_id: ModelId,
19 cache_dir: PathBuf,
20 inner: Mutex<fastembed::TextEmbedding>,
24}
25
26impl std::fmt::Debug for LocalFastembedProvider {
27 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
28 f.debug_struct("LocalFastembedProvider")
29 .field("model_id", &self.model_id)
30 .field("dim", &self.model_info.dim)
31 .field("cache_dir", &self.cache_dir)
32 .finish()
33 }
34}
35
36impl LocalFastembedProvider {
37 pub fn new(key: &str, cache_dir: impl AsRef<Path>) -> Result<Self> {
40 let info = crate::registry::by_key(key).ok_or_else(|| {
41 Error::Other(format!(
42 "unknown curated model: {key} (try one of: {})",
43 crate::registry::available().join(", ")
44 ))
45 })?;
46 if !info.is_local {
47 return Err(Error::Other(format!(
48 "model {key} is a cloud provider; use the cloud provider instead"
49 )));
50 }
51 let cache_dir = cache_dir.as_ref().to_path_buf();
52 std::fs::create_dir_all(&cache_dir).map_err(Error::Io)?;
53 let fast_model = map_to_fastembed(info)?;
54 let opts = fastembed::InitOptions::new(fast_model).with_cache_dir(cache_dir.clone());
55 let inner = fastembed::TextEmbedding::try_new(opts)
56 .map_err(|e| Error::Other(format!("fastembed init {key}: {e}")))?;
57 Ok(Self {
58 model_info: info,
59 model_id: ModelId::new("local", info.key, 1),
60 cache_dir,
61 inner: Mutex::new(inner),
62 })
63 }
64
65 pub fn cache_dir(&self) -> &Path {
67 &self.cache_dir
68 }
69
70 pub fn model_info(&self) -> &'static CuratedModel {
72 self.model_info
73 }
74
75 fn prefixed(&self, texts: &[&str], task: EmbeddingTask) -> Vec<String> {
76 let prefix = match task {
77 EmbeddingTask::Query => self.model_info.query_prefix,
78 EmbeddingTask::Document => self.model_info.doc_prefix,
79 };
80 match prefix {
81 Some(p) => texts.iter().map(|t| format!("{p}{t}")).collect(),
82 None => texts.iter().map(|t| (*t).to_owned()).collect(),
83 }
84 }
85}
86
87#[async_trait]
88impl EmbeddingProvider for LocalFastembedProvider {
89 fn model_id(&self) -> ModelId {
90 self.model_id.clone()
91 }
92
93 fn dim(&self) -> u16 {
94 self.model_info.dim
95 }
96
97 async fn embed_batch(&self, texts: &[&str], task: EmbeddingTask) -> Result<Vec<Vec<f32>>> {
98 let inputs = self.prefixed(texts, task);
99 let guard = self.inner.lock().expect("provider inner mutex poisoned");
103 guard
104 .embed(inputs, None)
105 .map_err(|e| Error::Other(format!("fastembed embed: {e}")))
106 }
107}
108
109fn map_to_fastembed(info: &CuratedModel) -> Result<fastembed::EmbeddingModel> {
110 use fastembed::EmbeddingModel as FE;
111 Ok(match info.key {
112 "default" => FE::MultilingualE5Small,
113 "tiny" => FE::AllMiniLML6V2Q,
114 "en" => FE::BGESmallENV15,
115 "multi-strong" => FE::MultilingualE5Base,
116 other => {
117 return Err(Error::Other(format!(
118 "no fastembed mapping for curated model: {other}"
119 )))
120 }
121 })
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Per-process counter mixed into temp directory names so that two calls
    /// landing on the same timestamp still get distinct paths.
    static FE_CACHE_TMP_NONCE: AtomicU64 = AtomicU64::new(0);

    /// Creates and returns a fresh directory under the system temp dir, named
    /// from the current time in nanoseconds, the process id and a sequence number.
    fn tmp_cache() -> PathBuf {
        let since_epoch = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap();
        let nonce = since_epoch.as_nanos();
        let seq = FE_CACHE_TMP_NONCE.fetch_add(1, Ordering::Relaxed);
        let dir_name = format!(
            "anamnesis-fe-cache-{nonce}-{pid}-{seq}",
            pid = std::process::id()
        );
        let dir = std::env::temp_dir().join(dir_name);
        std::fs::create_dir_all(&dir).unwrap();
        dir
    }

    /// An unregistered key is rejected, and the error lists the available keys.
    #[test]
    fn unknown_key_errors() {
        let r = LocalFastembedProvider::new("nope-not-a-model", tmp_cache());
        assert!(r.is_err());
        let msg = r.unwrap_err().to_string();
        assert!(msg.contains("unknown curated model"));
        // "default" is expected to appear in the list of suggested keys.
        assert!(msg.contains("default"));
    }

    /// A registry entry that is not local must be refused by this provider.
    #[test]
    fn cloud_voyage_rejected_by_local_provider() {
        let err = LocalFastembedProvider::new("cloud-voyage", tmp_cache()).unwrap_err();
        assert!(err.to_string().contains("cloud provider"));
    }

    /// Guards against a local registry entry being added without a matching arm
    /// in `map_to_fastembed`.
    #[test]
    fn every_local_key_has_a_fastembed_mapping() {
        for model in crate::registry::local_only() {
            assert!(
                map_to_fastembed(model).is_ok(),
                "missing fastembed mapping for {}",
                model.key
            );
        }
    }

    /// The end-to-end test only runs when `FASTEMBED_DOWNLOAD` is exactly `1`.
    fn allow_download() -> bool {
        matches!(std::env::var("FASTEMBED_DOWNLOAD").as_deref(), Ok("1"))
    }

    /// Loads the real "default" model and checks dimension, model id, batch size
    /// and that the first vector has roughly unit length.
    #[tokio::test]
    async fn end_to_end_embed_with_real_model() {
        if !allow_download() {
            eprintln!("skipping: FASTEMBED_DOWNLOAD != 1");
            return;
        }

        let provider = LocalFastembedProvider::new("default", tmp_cache()).unwrap();
        assert_eq!(provider.dim(), 384);
        assert_eq!(provider.model_id().as_str(), "local:default:1");

        let batch = ["hello", "用户偏好"];
        let v = provider
            .embed_batch(&batch, EmbeddingTask::Document)
            .await
            .unwrap();
        assert_eq!(v.len(), 2);
        assert_eq!(v[0].len(), 384);
        assert_eq!(v[1].len(), 384);

        let squared_sum = v[0].iter().map(|x| x * x).sum::<f32>();
        let mag = squared_sum.sqrt();
        assert!(
            (mag - 1.0).abs() < 0.1,
            "expected ~L2-normalized vector, got mag {mag}"
        );
    }
}