use std::panic::{catch_unwind, AssertUnwindSafe};
use std::path::PathBuf;
use std::sync::mpsc;

use candle_core::{Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config as BertConfig};
use hf_hub::{api::sync::ApiBuilder, Cache, Repo, RepoType};
use tokenizers::{PaddingParams, Tokenizer, TruncationParams};
use tokio::sync::oneshot;
use tokio::time::{timeout, Duration};

use super::EmbeddingBackend;
use crate::error::MemoryError;

pub const MODEL_ID: &str = "BAAI/bge-small-en-v1.5";

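/// A request to the embedding worker: the batch of texts to embed, plus the
/// oneshot sender on which the resulting vectors (or an error) are returned.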
type EmbedRequest = (
    Vec<String>,
    oneshot::Sender<Result<Vec<Vec<f32>>, MemoryError>>,
);

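/// Embedding engine backed by the bge-small-en-v1.5 BERT model running on
/// candle. Inference happens on a dedicated OS thread that owns the model;
/// async callers talk to it through a bounded request channel and receive
/// results over per-request oneshot channels, so `embed` never blocks the
/// tokio executor.
///
/// A minimal usage sketch (illustrative only; assumes a tokio runtime and
/// that the model files are cached locally or downloadable):
///
/// ```ignore
/// let engine = CandleEmbeddingEngine::new(Duration::from_secs(30), 16)?;
/// let vectors = engine.embed(&["hello world".to_string()]).await?;
/// assert_eq!(vectors[0].len(), engine.dimensions());
/// ```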
pub struct CandleEmbeddingEngine {
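    /// Request channel to the worker; `None` once the engine is shut down.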
    tx: Option<mpsc::SyncSender<EmbedRequest>>,
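    /// Join handle for the worker thread, taken (and joined) on drop.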
    worker: Option<std::thread::JoinHandle<()>>,
    dim: usize,
    embed_timeout: Duration,
}

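/// Everything the worker thread needs; moved into the thread at spawn time
/// and never shared afterwards, so the model and tokenizer need no
/// synchronisation.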
struct CandleInner {
    model: BertModel,
    tokenizer: Tokenizer,
    device: Device,
}

impl CandleEmbeddingEngine {
    pub fn new(embed_timeout: Duration, queue_size: usize) -> Result<Self, MemoryError> {
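        // All of the heavy lifting happens up front: resolve the model files
        // (cache or download), configure the tokenizer, mmap the weights, and
        // build the model, then hand everything to a dedicated worker thread.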
        let device = Device::Cpu;

        let (config, mut tokenizer, weights_path) =
            load_model_files().map_err(|e| MemoryError::Embedding(e.to_string()))?;

        tokenizer.with_padding(Some(PaddingParams {
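            // Pad each batch to its longest sequence so the token matrix is
            // rectangular; embed_chunk checks this invariant before reshaping.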
            strategy: tokenizers::PaddingStrategy::BatchLongest,
            ..Default::default()
        }));
        tokenizer
            .with_truncation(Some(TruncationParams {
                max_length: 512,
                ..Default::default()
            }))
            .map_err(|e| MemoryError::Embedding(format!("failed to set truncation: {e}")))?;

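        // SAFETY: `from_mmaped_safetensors` memory-maps the weight file, so
        // the file must not be modified or truncated while the mapping is
        // alive. We only read files managed by the hf-hub cache.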
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[weights_path], candle_core::DType::F32, &device)
                .map_err(|e| MemoryError::Embedding(format!("failed to load weights: {e}")))?
        };

        let model = BertModel::load(vb, &config)
            .map_err(|e| MemoryError::Embedding(format!("failed to build BERT model: {e}")))?;

        let dim = config.hidden_size;

        let (tx, rx) = mpsc::sync_channel::<EmbedRequest>(queue_size);

        let worker = std::thread::Builder::new()
            .name("embed-worker".into())
            .spawn(move || {
                let inner = CandleInner {
                    model,
                    tokenizer,
                    device,
                };
                worker_loop(inner, dim, rx);
            })
            .map_err(|e| MemoryError::Embedding(format!("failed to spawn embed worker: {e}")))?;

        Ok(Self {
            tx: Some(tx),
            worker: Some(worker),
            dim,
            embed_timeout,
        })
    }

    #[cfg(test)]
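    /// Test-only constructor: wires the engine to a caller-supplied channel
    /// instead of spawning a real model worker.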
    fn with_worker(
        tx: mpsc::SyncSender<EmbedRequest>,
        dim: usize,
        embed_timeout: Duration,
    ) -> Self {
        Self {
            tx: Some(tx),
            worker: None,
            dim,
            embed_timeout,
        }
    }
}

impl Drop for CandleEmbeddingEngine {
    fn drop(&mut self) {
        drop(self.tx.take());
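        // Dropping the sender ends the worker's receiver loop, so the join
        // below cannot deadlock.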
        if let Some(handle) = self.worker.take() {
            let _ = handle.join();
        }
    }
}

fn worker_loop(mut inner: CandleInner, dim: usize, rx: mpsc::Receiver<EmbedRequest>) {
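    // Serve requests until every sender is dropped. Each batch runs under
    // catch_unwind so a panic inside candle or tokenizers is reported to the
    // caller instead of killing the thread.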
    for (texts, reply_tx) in rx {
        let span = tracing::debug_span!(
            "embedding.embed",
            batch_size = texts.len(),
            dimensions = dim,
            model = MODEL_ID,
        );
        let _enter = span.enter();

        let mut panicked = false;
        let result = catch_unwind(AssertUnwindSafe(|| embed_batch(&inner, &texts))).unwrap_or_else(
            |panic_payload| {
                panicked = true;
                let msg = if let Some(s) = panic_payload.downcast_ref::<&str>() {
                    (*s).to_string()
                } else if let Some(s) = panic_payload.downcast_ref::<String>() {
                    s.clone()
                } else {
                    "unknown panic in embedding engine".to_string()
                };
                tracing::warn!(error = %msg, "embedding engine panicked — recovering");
                Err(MemoryError::Embedding(format!(
                    "embedding engine panicked: {msg}"
                )))
            },
        );

        let _ = reply_tx.send(result);

        if panicked {
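            // The panic may have unwound mid-encode; re-apply the padding and
            // truncation settings so later batches see a consistent tokenizer.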
            inner.tokenizer.with_padding(Some(PaddingParams {
                strategy: tokenizers::PaddingStrategy::BatchLongest,
                ..Default::default()
            }));
            let _ = inner.tokenizer.with_truncation(Some(TruncationParams {
                max_length: 512,
                ..Default::default()
            }));
        }
    }
}

#[async_trait::async_trait]
impl EmbeddingBackend for CandleEmbeddingEngine {
    async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, MemoryError> {
        let (reply_tx, reply_rx) = oneshot::channel();

        let tx = self
            .tx
            .as_ref()
            .ok_or_else(|| MemoryError::Embedding("embedding engine has been shut down".into()))?;

        tx.try_send((texts.to_vec(), reply_tx))
            .map_err(|e| match e {
                mpsc::TrySendError::Full(_) => {
                    MemoryError::Embedding("embedding worker is busy — try again".into())
                }
                mpsc::TrySendError::Disconnected(_) => {
                    MemoryError::Embedding("embedding worker has exited — restart required".into())
                }
            })?;

        match timeout(self.embed_timeout, reply_rx).await {
            Ok(Ok(result)) => result,
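            // RecvError: the worker dropped reply_tx without answering, which
            // should only happen if the worker thread itself died.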
            Ok(Err(_)) => Err(MemoryError::Embedding(
                "embedding worker dropped the reply channel unexpectedly".into(),
            )),
            Err(_elapsed) => Err(MemoryError::Embedding(format!(
                "embedding timed out after {:.1}s — the worker will recover automatically",
                self.embed_timeout.as_secs_f64(),
            ))),
        }
    }

    fn dimensions(&self) -> usize {
        self.dim
    }
}

fn load_model_files() -> anyhow::Result<(BertConfig, Tokenizer, PathBuf)> {
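    // Resolve config, tokenizer, and weights through the hf-hub cache,
    // hitting the network only for files that are missing locally.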
    let _span = tracing::info_span!("embedding.load_model", model = MODEL_ID).entered();

    let cache = Cache::from_env();
    let hf_repo = Repo::new(MODEL_ID.to_string(), RepoType::Model);

    let cached = cache.repo(hf_repo.clone()).get("model.safetensors");

    if cached.is_none() {
        tracing::warn!(
            model = MODEL_ID,
            "embedding model not found in cache — downloading from HuggingFace Hub \
             (this may take a minute on first run; use `memory-mcp warmup` to pre-populate)"
        );
    } else {
        tracing::info!(model = MODEL_ID, "loading embedding model from cache");
    }

    let api = ApiBuilder::from_env().with_progress(false).build()?;
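    // `get` returns the cached path when the file is already present and
    // downloads it otherwise.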
    let repo = api.repo(hf_repo);

    let start = std::time::Instant::now();
    let config_path = repo.get("config.json")?;
    let tokenizer_path = repo.get("tokenizer.json")?;
    let weights_path = repo.get("model.safetensors")?;
    tracing::info!(
        elapsed_ms = start.elapsed().as_millis(),
        "model files ready"
    );

    let config: BertConfig = serde_json::from_str(&std::fs::read_to_string(&config_path)?)?;
    let tokenizer = Tokenizer::from_file(&tokenizer_path)
        .map_err(|e| anyhow::anyhow!("failed to load tokenizer: {e}"))?;

    Ok((config, tokenizer, weights_path))
}

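/// Upper bound on how many texts go through a single forward pass; larger
/// requests are split into chunks to keep peak memory bounded.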
const MAX_BATCH_SIZE: usize = 64;

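/// Embeds `texts` in chunks of at most `MAX_BATCH_SIZE`, returning one
/// vector per input text in the original order.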
fn embed_batch(inner: &CandleInner, texts: &[String]) -> Result<Vec<Vec<f32>>, MemoryError> {
    let _span = tracing::debug_span!("embedding.embed_batch", batch_size = texts.len()).entered();

    if texts.is_empty() {
        return Ok(Vec::new());
    }

    let mut results = Vec::with_capacity(texts.len());
    for chunk in texts.chunks(MAX_BATCH_SIZE) {
        results.extend(embed_chunk(inner, chunk)?);
    }
    Ok(results)
}

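/// Tokenizes one chunk, runs a single BERT forward pass, and extracts the
/// L2-normalised `[CLS]` embedding for each text.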
fn embed_chunk(inner: &CandleInner, texts: &[String]) -> Result<Vec<Vec<f32>>, MemoryError> {
    let _span = tracing::debug_span!("embedding.embed_chunk", chunk_size = texts.len()).entered();
    debug_assert!(!texts.is_empty(), "embed_chunk called with empty texts");

    let encodings = inner
        .tokenizer
        .encode_batch(texts.to_vec(), true)
        .map_err(|e| MemoryError::Embedding(format!("tokenization failed: {e}")))?;

    let batch_size = encodings.len();
    let seq_len = encodings[0].get_ids().len();

    if let Some((i, enc)) = encodings
        .iter()
        .enumerate()
        .find(|(_, e)| e.get_ids().len() != seq_len)
    {
        return Err(MemoryError::Embedding(format!(
            "padding invariant violated: encoding[0] has {seq_len} tokens \
             but encoding[{i}] has {} — check tokenizer padding config",
            enc.get_ids().len(),
        )));
    }

    let all_ids: Vec<u32> = encodings
        .iter()
        .flat_map(|e| e.get_ids().to_vec())
        .collect();
    let all_type_ids: Vec<u32> = encodings
        .iter()
        .flat_map(|e| e.get_type_ids().to_vec())
        .collect();
    let all_masks: Vec<u32> = encodings
        .iter()
        .flat_map(|e| e.get_attention_mask().to_vec())
        .collect();

    let input_ids = Tensor::new(all_ids.as_slice(), &inner.device)
        .and_then(|t| t.reshape((batch_size, seq_len)))
        .map_err(|e| MemoryError::Embedding(format!("tensor creation failed: {e}")))?;

    let token_type_ids = Tensor::new(all_type_ids.as_slice(), &inner.device)
        .and_then(|t| t.reshape((batch_size, seq_len)))
        .map_err(|e| MemoryError::Embedding(format!("tensor creation failed: {e}")))?;

    let attention_mask = Tensor::new(all_masks.as_slice(), &inner.device)
        .and_then(|t| t.reshape((batch_size, seq_len)))
        .map_err(|e| MemoryError::Embedding(format!("tensor creation failed: {e}")))?;

    let embeddings = inner
        .model
        .forward(&input_ids, &token_type_ids, Some(&attention_mask))
        .map_err(|e| MemoryError::Embedding(format!("BERT forward pass failed: {e}")))?;

    let mut results = Vec::with_capacity(batch_size);
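    // Take the hidden state of the first token ([CLS]) as the sentence
    // embedding; the bge models are trained for CLS pooling.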
    for i in 0..batch_size {
        let cls = embeddings
            .get(i)
            .and_then(|seq| seq.get(0))
            .map_err(|e| MemoryError::Embedding(format!("CLS extraction failed: {e}")))?;

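        // Clamp the norm at 1e-12 so an all-zero vector cannot cause a
        // division by zero.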
        let normalized = cls
            .sqr()
            .and_then(|s| s.sum_all())
            .and_then(|s| s.sqrt())
            .and_then(|n| n.maximum(1e-12))
            .and_then(|n| cls.broadcast_div(&n))
            .map_err(|e| MemoryError::Embedding(format!("L2 normalisation failed: {e}")))?;

        let vector: Vec<f32> = normalized
            .to_vec1()
            .map_err(|e| MemoryError::Embedding(format!("tensor to vec failed: {e}")))?;

        results.push(vector);
    }

    Ok(results)
}

#[cfg(test)]
mod tests {
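    //! These tests exercise the request/reply plumbing with a stubbed worker
    //! thread; no model files are loaded or downloaded.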
    use std::sync::{Arc, Barrier};
    use std::time::Duration;

    use super::*;

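    /// Builds an engine whose "worker" is the given closure running on a
    /// plain thread, with a fixed dimension of 4.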
    fn fake_engine<F>(timeout: Duration, handler: F) -> CandleEmbeddingEngine
    where
        F: Fn(Vec<String>, oneshot::Sender<Result<Vec<Vec<f32>>, MemoryError>>) + Send + 'static,
    {
        let (tx, rx) = mpsc::sync_channel::<EmbedRequest>(1);
        std::thread::spawn(move || {
            for (texts, reply_tx) in rx {
                handler(texts, reply_tx);
            }
        });
        CandleEmbeddingEngine::with_worker(tx, 4, timeout)
    }

    fn ok_handler(
        texts: Vec<String>,
        reply_tx: oneshot::Sender<Result<Vec<Vec<f32>>, MemoryError>>,
    ) {
        let vecs = texts.iter().map(|_| vec![0.0f32; 4]).collect();
        let _ = reply_tx.send(Ok(vecs));
    }

    #[tokio::test]
    async fn happy_path_returns_vectors() {
        let engine = fake_engine(Duration::from_secs(5), ok_handler);
        let result = engine
            .embed(&["hello".to_string(), "world".to_string()])
            .await;
        let vecs = result.expect("embed should succeed");
        assert_eq!(vecs.len(), 2);
        assert_eq!(vecs[0].len(), 4);
    }

    #[tokio::test]
    async fn timeout_returns_error_and_worker_recovers() {
        let barrier = Arc::new(Barrier::new(2));
        let barrier2 = Arc::clone(&barrier);

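        // The "slow" branch parks on the barrier until the test has observed
        // the timeout, replies into the (by then dropped) oneshot, and parks
        // again so the test can synchronise before sending the next request.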
        let engine = fake_engine(Duration::from_millis(50), move |texts, reply_tx| {
            if texts[0] == "slow" {
                barrier2.wait();
                let _ = reply_tx.send(Ok(vec![vec![0.0; 4]]));
                barrier2.wait();
            } else {
                ok_handler(texts, reply_tx);
            }
        });

        let err = engine
            .embed(&["slow".to_string()])
            .await
            .expect_err("slow embed should time out");
        assert!(
            err.to_string().contains("timed out"),
            "expected timeout error, got: {err}"
        );

        barrier.wait();
        barrier.wait();

        let result = engine.embed(&["fast".to_string()]).await;
        assert!(
            result.is_ok(),
            "engine should recover after timeout: {result:?}"
        );
    }

    #[tokio::test]
    async fn disconnected_worker_returns_error() {
        let (tx, rx) = mpsc::sync_channel::<EmbedRequest>(1);
        drop(rx);
        let engine = CandleEmbeddingEngine::with_worker(tx, 4, Duration::from_secs(5));

        let err = engine
            .embed(&["anything".to_string()])
            .await
            .expect_err("disconnected worker should error");
        assert!(
            err.to_string().contains("exited"),
            "expected 'exited' in error, got: {err}"
        );
    }

    #[tokio::test]
    async fn busy_worker_returns_error() {
        let (tx, rx) = mpsc::sync_channel::<EmbedRequest>(1);

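        // Occupy the only slot in the bounded queue so the engine's try_send
        // sees a full channel.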
        let (filler_tx, _filler_rx) = oneshot::channel::<Result<Vec<Vec<f32>>, MemoryError>>();
        tx.send((vec!["fill".to_string()], filler_tx)).unwrap();

        let engine = CandleEmbeddingEngine::with_worker(tx, 4, Duration::from_secs(5));
        let err = engine
            .embed(&["overflow".to_string()])
            .await
            .expect_err("full channel should error");
        assert!(
            err.to_string().contains("busy"),
            "expected 'busy' in error, got: {err}"
        );

        drop(rx);
    }
}