// memory_mcp/embedding/candle.rs

use std::panic::{catch_unwind, AssertUnwindSafe};
2use std::path::PathBuf;
3use std::sync::{Arc, Mutex};
4
5use candle_core::{Device, Tensor};
6use candle_nn::VarBuilder;
7use candle_transformers::models::bert::{BertModel, Config as BertConfig};
8use hf_hub::{api::sync::ApiBuilder, Cache, Repo, RepoType};
9use tokenizers::{PaddingParams, Tokenizer, TruncationParams};
10
11use super::EmbeddingBackend;
12use crate::error::MemoryError;
13
14pub const MODEL_ID: &str = "BAAI/bge-small-en-v1.5";
16
17pub struct CandleEmbeddingEngine {
22 inner: Arc<Mutex<CandleInner>>,
23 dim: usize,
24}
25
26struct CandleInner {
27 model: BertModel,
28 tokenizer: Tokenizer,
29 device: Device,
30}
31
32impl CandleEmbeddingEngine {
33 pub fn new() -> Result<Self, MemoryError> {
38 let device = Device::Cpu;
39
40 let (config, mut tokenizer, weights_path) =
41 load_model_files().map_err(|e| MemoryError::Embedding(e.to_string()))?;
42
43 tokenizer.with_padding(Some(PaddingParams {
45 strategy: tokenizers::PaddingStrategy::BatchLongest,
46 ..Default::default()
47 }));
48 tokenizer
49 .with_truncation(Some(TruncationParams {
50 max_length: 512,
51 ..Default::default()
52 }))
53 .map_err(|e| MemoryError::Embedding(format!("failed to set truncation: {e}")))?;
54
55 let vb = unsafe {
60 VarBuilder::from_mmaped_safetensors(&[weights_path], candle_core::DType::F32, &device)
61 .map_err(|e| MemoryError::Embedding(format!("failed to load weights: {e}")))?
62 };
63
64 let model = BertModel::load(vb, &config)
65 .map_err(|e| MemoryError::Embedding(format!("failed to build BERT model: {e}")))?;
66
67 let dim = config.hidden_size;
68
69 Ok(Self {
70 inner: Arc::new(Mutex::new(CandleInner {
71 model,
72 tokenizer,
73 device,
74 })),
75 dim,
76 })
77 }
78}
79
80#[async_trait::async_trait]
81impl EmbeddingBackend for CandleEmbeddingEngine {
82 async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, MemoryError> {
83 let arc = Arc::clone(&self.inner);
84 let texts = texts.to_vec();
85 tokio::task::spawn_blocking(move || {
86 let guard = arc.lock().unwrap_or_else(|poisoned| {
87 tracing::warn!("embedding mutex was poisoned — clearing poison and continuing");
88 poisoned.into_inner()
89 });
90 catch_unwind(AssertUnwindSafe(|| embed_batch(&guard, &texts))).unwrap_or_else(
91 |panic_payload| {
92 let msg = if let Some(s) = panic_payload.downcast_ref::<&str>() {
93 (*s).to_string()
94 } else if let Some(s) = panic_payload.downcast_ref::<String>() {
95 s.clone()
96 } else {
97 "unknown panic in embedding engine".to_string()
98 };
99 Err(MemoryError::Embedding(format!(
100 "embedding engine panicked: {msg}"
101 )))
102 },
103 )
104 })
105 .await
106 .map_err(|e| MemoryError::Join(e.to_string()))?
107 }
108
109 fn dimensions(&self) -> usize {
110 self.dim
111 }
112}
113
114fn load_model_files() -> anyhow::Result<(BertConfig, Tokenizer, PathBuf)> {
125 let cache = Cache::from_env();
126 let hf_repo = Repo::new(MODEL_ID.to_string(), RepoType::Model);
127
128 let cached = cache.repo(hf_repo.clone()).get("model.safetensors");
130 if cached.is_none() {
131 tracing::warn!(
132 model = MODEL_ID,
133 "embedding model not found in cache — downloading from HuggingFace Hub \
134 (this may take a minute on first run; use `memory-mcp warmup` to pre-populate)"
135 );
136 } else {
137 tracing::info!(model = MODEL_ID, "loading embedding model from cache");
138 }
139
140 let api = ApiBuilder::from_env().with_progress(false).build()?;
143 let repo = api.repo(hf_repo);
144
145 let start = std::time::Instant::now();
146 let config_path = repo.get("config.json")?;
147 let tokenizer_path = repo.get("tokenizer.json")?;
148 let weights_path = repo.get("model.safetensors")?;
149 tracing::info!(
150 elapsed_ms = start.elapsed().as_millis(),
151 "model files ready"
152 );
153
154 let config: BertConfig = serde_json::from_str(&std::fs::read_to_string(&config_path)?)?;
155 let tokenizer = Tokenizer::from_file(&tokenizer_path)
156 .map_err(|e| anyhow::anyhow!("failed to load tokenizer: {e}"))?;
157
158 Ok((config, tokenizer, weights_path))
159}
160
161const MAX_BATCH_SIZE: usize = 64;
169
170fn embed_batch(inner: &CandleInner, texts: &[String]) -> Result<Vec<Vec<f32>>, MemoryError> {
175 if texts.is_empty() {
176 return Ok(Vec::new());
177 }
178
179 let mut results = Vec::with_capacity(texts.len());
180 for chunk in texts.chunks(MAX_BATCH_SIZE) {
181 results.extend(embed_chunk(inner, chunk)?);
182 }
183 Ok(results)
184}
185
186fn embed_chunk(inner: &CandleInner, texts: &[String]) -> Result<Vec<Vec<f32>>, MemoryError> {
194 debug_assert!(!texts.is_empty(), "embed_chunk called with empty texts");
195
196 let encodings = inner
197 .tokenizer
198 .encode_batch(texts.to_vec(), true)
199 .map_err(|e| MemoryError::Embedding(format!("tokenization failed: {e}")))?;
200
201 let batch_size = encodings.len();
202 let seq_len = encodings[0].get_ids().len();
203
204 if let Some((i, enc)) = encodings
208 .iter()
209 .enumerate()
210 .find(|(_, e)| e.get_ids().len() != seq_len)
211 {
212 return Err(MemoryError::Embedding(format!(
213 "padding invariant violated: encoding[0] has {seq_len} tokens \
214 but encoding[{i}] has {} — check tokenizer padding config",
215 enc.get_ids().len(),
216 )));
217 }
218
219 let all_ids: Vec<u32> = encodings
220 .iter()
221 .flat_map(|e| e.get_ids().to_vec())
222 .collect();
223 let all_type_ids: Vec<u32> = encodings
224 .iter()
225 .flat_map(|e| e.get_type_ids().to_vec())
226 .collect();
227 let all_masks: Vec<u32> = encodings
228 .iter()
229 .flat_map(|e| e.get_attention_mask().to_vec())
230 .collect();
231
232 let input_ids = Tensor::new(all_ids.as_slice(), &inner.device)
233 .and_then(|t| t.reshape((batch_size, seq_len)))
234 .map_err(|e| MemoryError::Embedding(format!("tensor creation failed: {e}")))?;
235
236 let token_type_ids = Tensor::new(all_type_ids.as_slice(), &inner.device)
237 .and_then(|t| t.reshape((batch_size, seq_len)))
238 .map_err(|e| MemoryError::Embedding(format!("tensor creation failed: {e}")))?;
239
240 let attention_mask = Tensor::new(all_masks.as_slice(), &inner.device)
241 .and_then(|t| t.reshape((batch_size, seq_len)))
242 .map_err(|e| MemoryError::Embedding(format!("tensor creation failed: {e}")))?;
243
244 let embeddings = inner
245 .model
246 .forward(&input_ids, &token_type_ids, Some(&attention_mask))
247 .map_err(|e| MemoryError::Embedding(format!("BERT forward pass failed: {e}")))?;
248
249 let mut results = Vec::with_capacity(batch_size);
251 for i in 0..batch_size {
252 let cls = embeddings
253 .get(i)
254 .and_then(|seq| seq.get(0))
255 .map_err(|e| MemoryError::Embedding(format!("CLS extraction failed: {e}")))?;
256
257 let norm = cls
260 .sqr()
261 .and_then(|s| s.sum_all())
262 .and_then(|s| s.sqrt())
263 .and_then(|n| n.maximum(1e-12))
264 .and_then(|n| cls.broadcast_div(&n))
265 .map_err(|e| MemoryError::Embedding(format!("L2 normalisation failed: {e}")))?;
266
267 let vector: Vec<f32> = norm
268 .to_vec1()
269 .map_err(|e| MemoryError::Embedding(format!("tensor to vec failed: {e}")))?;
270
271 results.push(vector);
272 }
273
274 Ok(results)
275}