use std::panic::{catch_unwind, AssertUnwindSafe};
use std::path::PathBuf;
use std::sync::mpsc;

use candle_core::{Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config as BertConfig};
use hf_hub::{api::sync::ApiBuilder, Cache, Repo, RepoType};
use tokenizers::{PaddingParams, Tokenizer, TruncationParams};
use tokio::sync::oneshot;
use tokio::time::{timeout, Duration};

use super::EmbeddingBackend;
use crate::error::MemoryError;
use crate::health::SubsystemReporter;

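/// HuggingFace Hub ID of the sentence-embedding model this engine loads.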
pub const MODEL_ID: &str = "BAAI/bge-small-en-v1.5";

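/// One unit of work for the embed worker: a batch of texts plus the one-shot
/// channel on which the resulting vectors (or an error) are sent back.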
type EmbedRequest = (
    Vec<String>,
    oneshot::Sender<Result<Vec<Vec<f32>>, MemoryError>>,
);

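/// CPU embedding engine backed by candle's BERT implementation.
///
/// Inference runs on a dedicated OS thread; async callers hand batches over a
/// bounded channel and await the reply, so a slow forward pass never blocks
/// the async runtime.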
pub struct CandleEmbeddingEngine {
    tx: Option<mpsc::SyncSender<EmbedRequest>>,
    worker: Option<std::thread::JoinHandle<()>>,
    dim: usize,
    embed_timeout: Duration,
    reporter: SubsystemReporter,
}

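/// Model state owned exclusively by the worker thread.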
struct CandleInner {
    model: BertModel,
    tokenizer: Tokenizer,
    device: Device,
}

impl CandleEmbeddingEngine {
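    /// Loads the model files (downloading them on first use), configures the
    /// tokenizer, and spawns the dedicated embedding worker thread.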
    pub fn new(
        embed_timeout: Duration,
        queue_size: usize,
        reporter: SubsystemReporter,
    ) -> Result<Self, MemoryError> {
        let device = Device::Cpu;

        let (config, mut tokenizer, weights_path) =
            load_model_files().map_err(|e| MemoryError::Embedding(e.to_string()))?;

        tokenizer.with_padding(Some(PaddingParams {
            strategy: tokenizers::PaddingStrategy::BatchLongest,
            ..Default::default()
        }));
        tokenizer
            .with_truncation(Some(TruncationParams {
                max_length: 512,
                ..Default::default()
            }))
            .map_err(|e| MemoryError::Embedding(format!("failed to set truncation: {e}")))?;

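        // SAFETY: `from_mmaped_safetensors` memory-maps the weights file,
        // which is unsound if the file is modified while mapped; we assume
        // the cached file stays untouched after download.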
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[weights_path], candle_core::DType::F32, &device)
                .map_err(|e| MemoryError::Embedding(format!("failed to load weights: {e}")))?
        };

        let model = BertModel::load(vb, &config)
            .map_err(|e| MemoryError::Embedding(format!("failed to build BERT model: {e}")))?;

        let dim = config.hidden_size;

        let (tx, rx) = mpsc::sync_channel::<EmbedRequest>(queue_size);

        let worker = std::thread::Builder::new()
            .name("embed-worker".into())
            .spawn(move || {
                let inner = CandleInner {
                    model,
                    tokenizer,
                    device,
                };
                worker_loop(inner, dim, rx);
            })
            .map_err(|e| MemoryError::Embedding(format!("failed to spawn embed worker: {e}")))?;

        Ok(Self {
            tx: Some(tx),
            worker: Some(worker),
            dim,
            embed_timeout,
            reporter,
        })
    }

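    /// Test-only constructor: wires the engine to a caller-supplied channel
    /// instead of spawning a real model worker.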
    #[cfg(test)]
    fn with_worker(
        tx: mpsc::SyncSender<EmbedRequest>,
        dim: usize,
        embed_timeout: Duration,
    ) -> Self {
        Self {
            tx: Some(tx),
            worker: None,
            dim,
            embed_timeout,
            reporter: SubsystemReporter::new(),
        }
    }
}

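// Dropping the sender closes the channel so the worker loop drains and exits;
// joining the thread keeps shutdown deterministic.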
impl Drop for CandleEmbeddingEngine {
    fn drop(&mut self) {
        drop(self.tx.take());
        if let Some(handle) = self.worker.take() {
            let _ = handle.join();
        }
    }
}

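/// Blocking worker loop: serves requests until every sender is dropped.
/// Panics during inference are caught and converted into errors so one bad
/// batch cannot kill the worker thread.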
fn worker_loop(mut inner: CandleInner, dim: usize, rx: mpsc::Receiver<EmbedRequest>) {
    for (texts, reply_tx) in rx {
        let span = tracing::debug_span!(
            "embedding.embed",
            batch_size = texts.len(),
            dimensions = dim,
            model = MODEL_ID,
        );
        let _enter = span.enter();

        let mut panicked = false;
        let result = catch_unwind(AssertUnwindSafe(|| embed_batch(&inner, &texts))).unwrap_or_else(
            |panic_payload| {
                panicked = true;
                let msg = if let Some(s) = panic_payload.downcast_ref::<&str>() {
                    (*s).to_string()
                } else if let Some(s) = panic_payload.downcast_ref::<String>() {
                    s.clone()
                } else {
                    "unknown panic in embedding engine".to_string()
                };
                tracing::warn!(error = %msg, "embedding engine panicked — recovering");
                Err(MemoryError::Embedding(format!(
                    "embedding engine panicked: {msg}"
                )))
            },
        );

        let _ = reply_tx.send(result);

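        // A panic may have left the tokenizer partially configured, so
        // defensively re-apply the padding and truncation settings.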
        if panicked {
            inner.tokenizer.with_padding(Some(PaddingParams {
                strategy: tokenizers::PaddingStrategy::BatchLongest,
                ..Default::default()
            }));
            let _ = inner.tokenizer.with_truncation(Some(TruncationParams {
                max_length: 512,
                ..Default::default()
            }));
        }
    }
}

#[async_trait::async_trait]
impl EmbeddingBackend for CandleEmbeddingEngine {
    async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, MemoryError> {
        let (reply_tx, reply_rx) = oneshot::channel();

        let tx = self
            .tx
            .as_ref()
            .ok_or_else(|| MemoryError::Embedding("embedding engine has been shut down".into()))?;

        tx.try_send((texts.to_vec(), reply_tx))
            .map_err(|e| match e {
                mpsc::TrySendError::Full(_) => {
                    MemoryError::Embedding("embedding worker is busy — try again".into())
                }
                mpsc::TrySendError::Disconnected(_) => {
                    MemoryError::Embedding("embedding worker has exited — restart required".into())
                }
            })?;

        let result = match timeout(self.embed_timeout, reply_rx).await {
            Ok(Ok(result)) => result,
            Ok(Err(_)) => Err(MemoryError::Embedding(
                "embedding worker dropped the reply channel unexpectedly".into(),
            )),
            Err(_elapsed) => Err(MemoryError::Embedding(format!(
                "embedding timed out after {:.1}s — the worker will recover automatically",
                self.embed_timeout.as_secs_f64(),
            ))),
        };

        match &result {
            Ok(_) => self.reporter.report_ok(),
            Err(_) => self.reporter.report_err("embed failed"),
        }

        result
    }

    fn dimensions(&self) -> usize {
        self.dim
    }
}

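/// Resolves `config.json`, `tokenizer.json`, and `model.safetensors` through
/// the HuggingFace Hub cache, downloading whatever is missing.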
fn load_model_files() -> anyhow::Result<(BertConfig, Tokenizer, PathBuf)> {
    let _span = tracing::info_span!("embedding.load_model", model = MODEL_ID).entered();

    let cache = Cache::from_env();
    let hf_repo = Repo::new(MODEL_ID.to_string(), RepoType::Model);

    let cached = cache.repo(hf_repo.clone()).get("model.safetensors");
    if cached.is_none() {
        tracing::warn!(
            model = MODEL_ID,
            "embedding model not found in cache — downloading from HuggingFace Hub \
             (this may take a minute on first run; use `memory-mcp warmup` to pre-populate)"
        );
    } else {
        tracing::info!(model = MODEL_ID, "loading embedding model from cache");
    }

    let api = ApiBuilder::from_env().with_progress(false).build()?;
    let repo = api.repo(hf_repo);

    let start = std::time::Instant::now();
    let config_path = repo.get("config.json")?;
    let tokenizer_path = repo.get("tokenizer.json")?;
    let weights_path = repo.get("model.safetensors")?;
    tracing::info!(
        elapsed_ms = start.elapsed().as_millis(),
        "model files ready"
    );

    let config: BertConfig = serde_json::from_str(&std::fs::read_to_string(&config_path)?)?;
    let tokenizer = Tokenizer::from_file(&tokenizer_path)
        .map_err(|e| anyhow::anyhow!("failed to load tokenizer: {e}"))?;

    Ok((config, tokenizer, weights_path))
}

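/// Maximum number of texts per forward pass; `embed_batch` splits larger
/// requests into chunks of this size.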
const MAX_BATCH_SIZE: usize = 64;

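/// Embeds a batch of texts, chunking it so no single forward pass exceeds
/// `MAX_BATCH_SIZE` inputs.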
fn embed_batch(inner: &CandleInner, texts: &[String]) -> Result<Vec<Vec<f32>>, MemoryError> {
    let _span = tracing::debug_span!("embedding.embed_batch", batch_size = texts.len()).entered();

    if texts.is_empty() {
        return Ok(Vec::new());
    }

    let mut results = Vec::with_capacity(texts.len());
    for chunk in texts.chunks(MAX_BATCH_SIZE) {
        results.extend(embed_chunk(inner, chunk)?);
    }
    Ok(results)
}

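/// Tokenizes one chunk, runs the BERT forward pass, and returns the
/// L2-normalised CLS embedding for each input.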
fn embed_chunk(inner: &CandleInner, texts: &[String]) -> Result<Vec<Vec<f32>>, MemoryError> {
    let _span = tracing::debug_span!("embedding.embed_chunk", chunk_size = texts.len()).entered();
    debug_assert!(!texts.is_empty(), "embed_chunk called with empty texts");

    let encodings = inner
        .tokenizer
        .encode_batch(texts.to_vec(), true)
        .map_err(|e| MemoryError::Embedding(format!("tokenization failed: {e}")))?;

    let batch_size = encodings.len();
    let seq_len = encodings[0].get_ids().len();

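    // With `BatchLongest` padding every encoding must share one sequence
    // length; verify that before reshaping into a rectangular tensor.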
    if let Some((i, enc)) = encodings
        .iter()
        .enumerate()
        .find(|(_, e)| e.get_ids().len() != seq_len)
    {
        return Err(MemoryError::Embedding(format!(
            "padding invariant violated: encoding[0] has {seq_len} tokens \
             but encoding[{i}] has {} — check tokenizer padding config",
            enc.get_ids().len(),
        )));
    }

    let all_ids: Vec<u32> = encodings
        .iter()
        .flat_map(|e| e.get_ids().to_vec())
        .collect();
    let all_type_ids: Vec<u32> = encodings
        .iter()
        .flat_map(|e| e.get_type_ids().to_vec())
        .collect();
    let all_masks: Vec<u32> = encodings
        .iter()
        .flat_map(|e| e.get_attention_mask().to_vec())
        .collect();

    let input_ids = Tensor::new(all_ids.as_slice(), &inner.device)
        .and_then(|t| t.reshape((batch_size, seq_len)))
        .map_err(|e| MemoryError::Embedding(format!("tensor creation failed: {e}")))?;

    let token_type_ids = Tensor::new(all_type_ids.as_slice(), &inner.device)
        .and_then(|t| t.reshape((batch_size, seq_len)))
        .map_err(|e| MemoryError::Embedding(format!("tensor creation failed: {e}")))?;

    let attention_mask = Tensor::new(all_masks.as_slice(), &inner.device)
        .and_then(|t| t.reshape((batch_size, seq_len)))
        .map_err(|e| MemoryError::Embedding(format!("tensor creation failed: {e}")))?;

    let embeddings = inner
        .model
        .forward(&input_ids, &token_type_ids, Some(&attention_mask))
        .map_err(|e| MemoryError::Embedding(format!("BERT forward pass failed: {e}")))?;

    let mut results = Vec::with_capacity(batch_size);
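    // CLS pooling: the first token's hidden state serves as the sentence vector.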
    for i in 0..batch_size {
        let cls = embeddings
            .get(i)
            .and_then(|seq| seq.get(0))
            .map_err(|e| MemoryError::Embedding(format!("CLS extraction failed: {e}")))?;

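        // L2-normalise, flooring the norm at 1e-12 to avoid division by zero.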
        let norm = cls
            .sqr()
            .and_then(|s| s.sum_all())
            .and_then(|s| s.sqrt())
            .and_then(|n| n.maximum(1e-12))
            .and_then(|n| cls.broadcast_div(&n))
            .map_err(|e| MemoryError::Embedding(format!("L2 normalisation failed: {e}")))?;

        let vector: Vec<f32> = norm
            .to_vec1()
            .map_err(|e| MemoryError::Embedding(format!("tensor to vec failed: {e}")))?;

        results.push(vector);
    }

    Ok(results)
}

#[cfg(test)]
mod tests {
    use std::sync::{Arc, Barrier};
    use std::time::Duration;

    use super::*;

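    /// Builds an engine whose "worker" is the given closure on a plain thread,
    /// so tests can script replies without loading a real model.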
    fn fake_engine<F>(timeout: Duration, handler: F) -> CandleEmbeddingEngine
    where
        F: Fn(Vec<String>, oneshot::Sender<Result<Vec<Vec<f32>>, MemoryError>>) + Send + 'static,
    {
        let (tx, rx) = mpsc::sync_channel::<EmbedRequest>(1);
        std::thread::spawn(move || {
            for (texts, reply_tx) in rx {
                handler(texts, reply_tx);
            }
        });
        CandleEmbeddingEngine::with_worker(tx, 4, timeout)
    }

    fn ok_handler(
        texts: Vec<String>,
        reply_tx: oneshot::Sender<Result<Vec<Vec<f32>>, MemoryError>>,
    ) {
        let vecs = texts.iter().map(|_| vec![0.0f32; 4]).collect();
        let _ = reply_tx.send(Ok(vecs));
    }

    #[tokio::test]
    async fn happy_path_returns_vectors() {
        let engine = fake_engine(Duration::from_secs(5), ok_handler);
        let result = engine
            .embed(&["hello".to_string(), "world".to_string()])
            .await;
        let vecs = result.expect("embed should succeed");
        assert_eq!(vecs.len(), 2);
        assert_eq!(vecs[0].len(), 4);
    }

    #[tokio::test]
    async fn timeout_returns_error_and_worker_recovers() {
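        // Two-phase barrier: the worker parks until the test has observed the
        // timeout, then sends its stale reply and parks again until the test
        // has synchronised with it.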
        let barrier = Arc::new(Barrier::new(2));
        let barrier2 = Arc::clone(&barrier);

        let engine = fake_engine(Duration::from_millis(50), move |texts, reply_tx| {
            if texts[0] == "slow" {
                barrier2.wait();
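                // By now the client has timed out and dropped the receiver,
                // so this send fails and is deliberately ignored.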
                let _ = reply_tx.send(Ok(vec![vec![0.0; 4]]));
                barrier2.wait();
            } else {
                ok_handler(texts, reply_tx);
            }
        });

        let err = engine
            .embed(&["slow".to_string()])
            .await
            .expect_err("slow embed should time out");
        assert!(
            err.to_string().contains("timed out"),
            "expected timeout error, got: {err}"
        );

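        // Release the worker, then wait until its stale reply has been flushed.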
        barrier.wait();
        barrier.wait();

        let result = engine.embed(&["fast".to_string()]).await;
        assert!(
            result.is_ok(),
            "engine should recover after timeout: {result:?}"
        );
    }

    #[tokio::test]
    async fn disconnected_worker_returns_error() {
        let (tx, rx) = mpsc::sync_channel::<EmbedRequest>(1);
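        // Dropping the receiver immediately disconnects the channel.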
        drop(rx);
        let engine = CandleEmbeddingEngine::with_worker(tx, 4, Duration::from_secs(5));

        let err = engine
            .embed(&["anything".to_string()])
            .await
            .expect_err("disconnected worker should error");
        assert!(
            err.to_string().contains("exited"),
            "expected 'exited' in error, got: {err}"
        );
    }

    #[tokio::test]
    async fn busy_worker_returns_error() {
        let (tx, rx) = mpsc::sync_channel::<EmbedRequest>(1);

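        // Occupy the single buffered slot so the next try_send returns Full.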
        let (filler_tx, _filler_rx) = oneshot::channel::<Result<Vec<Vec<f32>>, MemoryError>>();
        tx.send((vec!["fill".to_string()], filler_tx)).unwrap();

        let engine = CandleEmbeddingEngine::with_worker(tx, 4, Duration::from_secs(5));
        let err = engine
            .embed(&["overflow".to_string()])
            .await
            .expect_err("full channel should error");
        assert!(
            err.to_string().contains("busy"),
            "expected 'busy' in error, got: {err}"
        );

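        // rx stayed alive until here so the failure mode above was Full
        // rather than Disconnected.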
        drop(rx);
    }
}