memory_mcp/embedding/candle.rs

use std::panic::{catch_unwind, AssertUnwindSafe};
use std::path::PathBuf;
use std::sync::{Arc, Mutex};

use candle_core::{Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config as BertConfig};
use hf_hub::{api::sync::ApiBuilder, Cache, Repo, RepoType};
use tokenizers::{PaddingParams, Tokenizer, TruncationParams};

use super::EmbeddingBackend;
use crate::error::MemoryError;

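/// Hugging Face model ID for the embedding model. bge-small-en-v1.5 is a
/// 384-dimensional BERT-style encoder; `embed_chunk` below applies the CLS
/// pooling + L2 normalisation its model card recommends.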
pub const MODEL_ID: &str = "BAAI/bge-small-en-v1.5";

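/// CPU-only embedding engine backed by candle's BERT implementation.
///
/// A minimal usage sketch, assuming a tokio runtime and that these paths and
/// re-exports exist in the crate (adjust to the actual layout):
///
/// ```ignore
/// use memory_mcp::embedding::{candle::CandleEmbeddingEngine, EmbeddingBackend};
///
/// let engine = CandleEmbeddingEngine::new()?;
/// let vectors = engine.embed(&["hello world".to_string()]).await?;
/// assert_eq!(vectors[0].len(), engine.dimensions());
/// ```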
pub struct CandleEmbeddingEngine {
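    /// Shared model state; the mutex serializes access so that concurrent
    /// `embed` calls run their forward passes one at a time.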
    inner: Arc<Mutex<CandleInner>>,
    dim: usize,
}

struct CandleInner {
    model: BertModel,
    tokenizer: Tokenizer,
    device: Device,
}

impl CandleEmbeddingEngine {
    pub fn new() -> Result<Self, MemoryError> {
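        // CPU-only inference; no GPU feature flags are required, and the
        // model is small enough for this to stay practical.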
        let device = Device::Cpu;

        let (config, mut tokenizer, weights_path) =
            load_model_files().map_err(|e| MemoryError::Embedding(e.to_string()))?;

        tokenizer.with_padding(Some(PaddingParams {
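            // Pad each encoding to the longest sequence in the batch so the
            // ids can be packed into rectangular (batch, seq_len) tensors.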
            strategy: tokenizers::PaddingStrategy::BatchLongest,
            ..Default::default()
        }));
        tokenizer
            .with_truncation(Some(TruncationParams {
                max_length: 512,
                ..Default::default()
            }))
            .map_err(|e| MemoryError::Embedding(format!("failed to set truncation: {e}")))?;

        let vb = unsafe {
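            // SAFETY: the safetensors file is memory-mapped; the mapping is
            // only valid while the file is not modified or truncated on disk.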
            VarBuilder::from_mmaped_safetensors(&[weights_path], candle_core::DType::F32, &device)
                .map_err(|e| MemoryError::Embedding(format!("failed to load weights: {e}")))?
        };

        let model = BertModel::load(vb, &config)
            .map_err(|e| MemoryError::Embedding(format!("failed to build BERT model: {e}")))?;

        let dim = config.hidden_size;

        Ok(Self {
            inner: Arc::new(Mutex::new(CandleInner {
                model,
                tokenizer,
                device,
            })),
            dim,
        })
    }
}

#[async_trait::async_trait]
impl EmbeddingBackend for CandleEmbeddingEngine {
    async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, MemoryError> {
        let arc = Arc::clone(&self.inner);
        let texts = texts.to_vec();
        let batch_size = texts.len();
        let dim = self.dim;

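        // Build the span here to capture the request's fields; it is entered
        // inside the blocking task so tokenization and the forward pass are
        // attributed to it.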
        let span = tracing::debug_span!(
            "embedding.embed",
            batch_size,
            dimensions = dim,
            model = MODEL_ID,
        );

        tokio::task::spawn_blocking(move || {
            let _enter = span.enter();
            let guard = arc.lock().unwrap_or_else(|poisoned| {
                tracing::warn!("embedding mutex was poisoned — clearing poison and continuing");
                poisoned.into_inner()
            });
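            // candle and the tokenizer can panic (e.g. on shape mismatches);
            // catch the panic and surface it as a MemoryError instead of an
            // opaque JoinError from the runtime.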
            catch_unwind(AssertUnwindSafe(|| embed_batch(&guard, &texts))).unwrap_or_else(
                |panic_payload| {
                    let msg = if let Some(s) = panic_payload.downcast_ref::<&str>() {
                        (*s).to_string()
                    } else if let Some(s) = panic_payload.downcast_ref::<String>() {
                        s.clone()
                    } else {
                        "unknown panic in embedding engine".to_string()
                    };
                    Err(MemoryError::Embedding(format!(
                        "embedding engine panicked: {msg}"
                    )))
                },
            )
        })
        .await
        .map_err(|e| MemoryError::Join(e.to_string()))?
    }

    fn dimensions(&self) -> usize {
        self.dim
    }
}

fn load_model_files() -> anyhow::Result<(BertConfig, Tokenizer, PathBuf)> {
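    // Probe the local cache first purely for logging, then resolve all three
    // files through the hub API, which reuses cached files and downloads any
    // that are missing.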
    let _span = tracing::info_span!("embedding.load_model", model = MODEL_ID).entered();

    let cache = Cache::from_env();
    let hf_repo = Repo::new(MODEL_ID.to_string(), RepoType::Model);

    let cached = cache.repo(hf_repo.clone()).get("model.safetensors");
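    // The weights file is by far the largest artifact, so its presence is
    // used as the proxy for "model already cached".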
    if cached.is_none() {
        tracing::warn!(
            model = MODEL_ID,
            "embedding model not found in cache — downloading from HuggingFace Hub \
             (this may take a minute on first run; use `memory-mcp warmup` to pre-populate)"
        );
    } else {
        tracing::info!(model = MODEL_ID, "loading embedding model from cache");
    }

    let api = ApiBuilder::from_env().with_progress(false).build()?;
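    // repo.get() returns the cached path when the file is already present and
    // downloads it otherwise.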
    let repo = api.repo(hf_repo);

    let start = std::time::Instant::now();
    let config_path = repo.get("config.json")?;
    let tokenizer_path = repo.get("tokenizer.json")?;
    let weights_path = repo.get("model.safetensors")?;
    tracing::info!(
        elapsed_ms = start.elapsed().as_millis(),
        "model files ready"
    );

    let config: BertConfig = serde_json::from_str(&std::fs::read_to_string(&config_path)?)?;
    let tokenizer = Tokenizer::from_file(&tokenizer_path)
        .map_err(|e| anyhow::anyhow!("failed to load tokenizer: {e}"))?;

    Ok((config, tokenizer, weights_path))
}

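/// Upper bound on how many texts go through one forward pass. Peak memory for
/// a BERT batch grows with batch_size * seq_len, so larger inputs are embedded
/// in chunks of this size instead.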
const MAX_BATCH_SIZE: usize = 64;

fn embed_batch(inner: &CandleInner, texts: &[String]) -> Result<Vec<Vec<f32>>, MemoryError> {
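    // Embed in MAX_BATCH_SIZE chunks, preserving input order in the output.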
    let _span = tracing::debug_span!("embedding.embed_batch", batch_size = texts.len()).entered();

    if texts.is_empty() {
        return Ok(Vec::new());
    }

    let mut results = Vec::with_capacity(texts.len());
    for chunk in texts.chunks(MAX_BATCH_SIZE) {
        results.extend(embed_chunk(inner, chunk)?);
    }
    Ok(results)
}

fn embed_chunk(inner: &CandleInner, texts: &[String]) -> Result<Vec<Vec<f32>>, MemoryError> {
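    // Tokenize with batch padding, pack the encodings into (batch, seq_len)
    // tensors, run the BERT forward pass, then CLS-pool and L2-normalise each
    // row, which is the usage BGE models expect.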
    let _span = tracing::debug_span!("embedding.embed_chunk", chunk_size = texts.len()).entered();
    debug_assert!(!texts.is_empty(), "embed_chunk called with empty texts");

    let encodings = inner
        .tokenizer
        .encode_batch(texts.to_vec(), true)
        .map_err(|e| MemoryError::Embedding(format!("tokenization failed: {e}")))?;

    let batch_size = encodings.len();
    let seq_len = encodings[0].get_ids().len();

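    // BatchLongest padding should already make every encoding the same
    // length; verify the invariant before reshaping instead of silently
    // producing a garbled batch.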
    if let Some((i, enc)) = encodings
        .iter()
        .enumerate()
        .find(|(_, e)| e.get_ids().len() != seq_len)
    {
        return Err(MemoryError::Embedding(format!(
            "padding invariant violated: encoding[0] has {seq_len} tokens \
             but encoding[{i}] has {} — check tokenizer padding config",
            enc.get_ids().len(),
        )));
    }

    let all_ids: Vec<u32> = encodings
        .iter()
        .flat_map(|e| e.get_ids().to_vec())
        .collect();
    let all_type_ids: Vec<u32> = encodings
        .iter()
        .flat_map(|e| e.get_type_ids().to_vec())
        .collect();
    let all_masks: Vec<u32> = encodings
        .iter()
        .flat_map(|e| e.get_attention_mask().to_vec())
        .collect();

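    // Pack the flattened id/type/mask buffers into (batch, seq_len) tensors
    // on the engine's device.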
    let input_ids = Tensor::new(all_ids.as_slice(), &inner.device)
        .and_then(|t| t.reshape((batch_size, seq_len)))
        .map_err(|e| MemoryError::Embedding(format!("tensor creation failed: {e}")))?;

    let token_type_ids = Tensor::new(all_type_ids.as_slice(), &inner.device)
        .and_then(|t| t.reshape((batch_size, seq_len)))
        .map_err(|e| MemoryError::Embedding(format!("tensor creation failed: {e}")))?;

    let attention_mask = Tensor::new(all_masks.as_slice(), &inner.device)
        .and_then(|t| t.reshape((batch_size, seq_len)))
        .map_err(|e| MemoryError::Embedding(format!("tensor creation failed: {e}")))?;

    let embeddings = inner
        .model
        .forward(&input_ids, &token_type_ids, Some(&attention_mask))
        .map_err(|e| MemoryError::Embedding(format!("BERT forward pass failed: {e}")))?;

    let mut results = Vec::with_capacity(batch_size);
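    // Take each row's CLS token (position 0) as the sentence embedding.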
    for i in 0..batch_size {
        let cls = embeddings
            .get(i)
            .and_then(|seq| seq.get(0))
            .map_err(|e| MemoryError::Embedding(format!("CLS extraction failed: {e}")))?;

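        // L2-normalise, clamping the norm at 1e-12 to avoid dividing a
        // degenerate all-zero vector by zero.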
        let normalized = cls
            .sqr()
            .and_then(|s| s.sum_all())
            .and_then(|s| s.sqrt())
            .and_then(|n| n.maximum(1e-12))
            .and_then(|n| cls.broadcast_div(&n))
            .map_err(|e| MemoryError::Embedding(format!("L2 normalisation failed: {e}")))?;

        let vector: Vec<f32> = normalized
            .to_vec1()
            .map_err(|e| MemoryError::Embedding(format!("tensor to vec failed: {e}")))?;

        results.push(vector);
    }

    Ok(results)
}