// memory_mcp/embedding/candle.rs

use std::panic::{catch_unwind, AssertUnwindSafe};
use std::path::PathBuf;
use std::sync::mpsc;

use candle_core::{Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config as BertConfig};
use hf_hub::{api::sync::ApiBuilder, Cache, Repo, RepoType};
use tokenizers::{PaddingParams, Tokenizer, TruncationParams};
use tokio::sync::oneshot;
use tokio::time::{timeout, Duration};

use super::EmbeddingBackend;
use crate::error::MemoryError;
use crate::health::SubsystemReporter;

/// HuggingFace model ID. Only BGE-small-en-v1.5 is currently supported.
pub const MODEL_ID: &str = "BAAI/bge-small-en-v1.5";

// ---------------------------------------------------------------------------
// Worker thread
// ---------------------------------------------------------------------------

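/// One unit of work for the embedding worker: the texts to embed plus a
/// oneshot sender for the reply. If the async caller times out and drops its
/// receiver, the worker's send simply fails and the request is discarded.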
type EmbedRequest = (
    Vec<String>,
    oneshot::Sender<Result<Vec<Vec<f32>>, MemoryError>>,
);

/// Pure-Rust embedding engine using candle for BERT inference.
///
/// Uses candle-transformers' BERT implementation with tokenizers for
/// tokenisation. No C/C++ FFI dependencies — compiles on all platforms.
///
/// A dedicated OS thread owns the model exclusively. Async callers send work
/// via a channel and await a oneshot reply. If a call times out, the caller
/// gets an error immediately; the worker finishes its current task, discards
/// the stale reply channel, and picks up the next request — no restart needed.
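///
/// Data flow for one request (a sketch of the mechanism described above):
///
/// ```text
/// async caller --(texts, oneshot::Sender)--> bounded mpsc --> worker thread
/// async caller <--------- Result<Vec<Vec<f32>>> via oneshot -- worker thread
/// ```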
pub struct CandleEmbeddingEngine {
    // Option so Drop can take ownership of tx to close it before joining.
    tx: Option<mpsc::SyncSender<EmbedRequest>>,
    worker: Option<std::thread::JoinHandle<()>>,
    dim: usize,
    embed_timeout: Duration,
    reporter: SubsystemReporter,
}

struct CandleInner {
    model: BertModel,
    tokenizer: Tokenizer,
    device: Device,
}

impl CandleEmbeddingEngine {
    /// Initialise the candle embedding engine.
    ///
    /// Downloads model weights from HuggingFace Hub on first use (cached
    /// in the standard HF cache directory, respects `HF_HOME`).
    ///
    /// `embed_timeout` caps how long a single [`embed`](Self::embed) call may
    /// block. If the worker is still running when the timeout fires, the caller
    /// gets an error but the engine recovers automatically — no restart needed.
    ///
    /// `queue_size` sets the bounded channel capacity — how many requests can
    /// queue behind the one being processed. Extra callers get an immediate
    /// "busy" error.
    ///
    /// `reporter` receives `report_ok`/`report_err` after each embed call so
    /// the `/readyz` handler can reflect the engine's operational state without
    /// active probing.
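    ///
    /// A minimal usage sketch (`no_run` because construction downloads model
    /// weights; the exact crate paths are an assumption based on this
    /// module's location):
    ///
    /// ```no_run
    /// # use std::time::Duration;
    /// # use memory_mcp::embedding::{CandleEmbeddingEngine, EmbeddingBackend};
    /// # use memory_mcp::health::SubsystemReporter;
    /// # async fn demo() -> Result<(), memory_mcp::error::MemoryError> {
    /// let engine =
    ///     CandleEmbeddingEngine::new(Duration::from_secs(30), 8, SubsystemReporter::new())?;
    /// let vectors = engine.embed(&["hello world".to_string()]).await?;
    /// assert_eq!(vectors[0].len(), engine.dimensions());
    /// # Ok(())
    /// # }
    /// ```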
    pub fn new(
        embed_timeout: Duration,
        queue_size: usize,
        reporter: SubsystemReporter,
    ) -> Result<Self, MemoryError> {
        let device = Device::Cpu;

        let (config, mut tokenizer, weights_path) =
            load_model_files().map_err(|e| MemoryError::Embedding(e.to_string()))?;

        // Enable padding so encode_batch produces equal-length sequences.
        tokenizer.with_padding(Some(PaddingParams {
            strategy: tokenizers::PaddingStrategy::BatchLongest,
            ..Default::default()
        }));
        tokenizer
            .with_truncation(Some(TruncationParams {
                max_length: 512,
                ..Default::default()
            }))
            .map_err(|e| MemoryError::Embedding(format!("failed to set truncation: {e}")))?;

        // SAFETY: `from_mmaped_safetensors` memory-maps the weights file. The
        // caller must ensure the file is not modified for the lifetime of the
        // resulting tensors. HuggingFace Hub writes cache files atomically and
        // never modifies them in-place, so the mapping is stable.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[weights_path], candle_core::DType::F32, &device)
                .map_err(|e| MemoryError::Embedding(format!("failed to load weights: {e}")))?
        };

        let model = BertModel::load(vb, &config)
            .map_err(|e| MemoryError::Embedding(format!("failed to build BERT model: {e}")))?;

        let dim = config.hidden_size;

        let (tx, rx) = mpsc::sync_channel::<EmbedRequest>(queue_size);

        let worker = std::thread::Builder::new()
            .name("embed-worker".into())
            .spawn(move || {
                let inner = CandleInner {
                    model,
                    tokenizer,
                    device,
                };
                worker_loop(inner, dim, rx);
            })
            .map_err(|e| MemoryError::Embedding(format!("failed to spawn embed worker: {e}")))?;

        Ok(Self {
            tx: Some(tx),
            worker: Some(worker),
            dim,
            embed_timeout,
            reporter,
        })
    }

    /// Construct an engine backed by a caller-supplied worker sender.
    ///
    /// Bypasses model loading entirely. The caller is responsible for spawning
    /// a thread that reads from the other end of the channel. Used only in
    /// tests to exercise the channel mechanics (timeout, disconnect, busy)
    /// without needing the HuggingFace model cache.
    #[cfg(test)]
    fn with_worker(
        tx: mpsc::SyncSender<EmbedRequest>,
        dim: usize,
        embed_timeout: Duration,
    ) -> Self {
        Self {
            tx: Some(tx),
            worker: None,
            dim,
            embed_timeout,
            reporter: SubsystemReporter::new(),
        }
    }
}

impl Drop for CandleEmbeddingEngine {
    fn drop(&mut self) {
        // Close the channel first so the worker's `for ... in rx` loop exits.
        drop(self.tx.take());
        if let Some(handle) = self.worker.take() {
            let _ = handle.join();
        }
    }
}

/// Main loop for the dedicated embedding worker thread.
///
/// Receives `(texts, reply_tx)` pairs, runs inference, and sends the result
/// back on `reply_tx`. If the receiver was dropped (the async caller timed
/// out), the send fails silently and the loop continues — this is the
/// self-healing path.
fn worker_loop(mut inner: CandleInner, dim: usize, rx: mpsc::Receiver<EmbedRequest>) {
    for (texts, reply_tx) in rx {
        let span = tracing::debug_span!(
            "embedding.embed",
            batch_size = texts.len(),
            dimensions = dim,
            model = MODEL_ID,
        );
        let _enter = span.enter();

        let mut panicked = false;
        let result = catch_unwind(AssertUnwindSafe(|| embed_batch(&inner, &texts))).unwrap_or_else(
            |panic_payload| {
                panicked = true;
                let msg = if let Some(s) = panic_payload.downcast_ref::<&str>() {
                    (*s).to_string()
                } else if let Some(s) = panic_payload.downcast_ref::<String>() {
                    s.clone()
                } else {
                    "unknown panic in embedding engine".to_string()
                };
                tracing::warn!(error = %msg, "embedding engine panicked — recovering");
                Err(MemoryError::Embedding(format!(
                    "embedding engine panicked: {msg}"
                )))
            },
        );

        let _ = reply_tx.send(result);

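        // After a panic, defensively re-apply the padding and truncation
        // config in case the panicked call left the tokenizer partially
        // reconfigured (an assumption about the failure mode; harmless if
        // the config was already intact).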
        if panicked {
            inner.tokenizer.with_padding(Some(PaddingParams {
                strategy: tokenizers::PaddingStrategy::BatchLongest,
                ..Default::default()
            }));
            let _ = inner.tokenizer.with_truncation(Some(TruncationParams {
                max_length: 512,
                ..Default::default()
            }));
        }
    }
}

#[async_trait::async_trait]
impl EmbeddingBackend for CandleEmbeddingEngine {
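    /// Send the texts to the worker and await its reply, enforcing
    /// `embed_timeout` and mapping queue-full and worker-gone conditions
    /// to `MemoryError::Embedding`.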
    async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, MemoryError> {
        let (reply_tx, reply_rx) = oneshot::channel();

        let tx = self
            .tx
            .as_ref()
            .ok_or_else(|| MemoryError::Embedding("embedding engine has been shut down".into()))?;

        tx.try_send((texts.to_vec(), reply_tx))
            .map_err(|e| match e {
                mpsc::TrySendError::Full(_) => {
                    MemoryError::Embedding("embedding worker is busy — try again".into())
                }
                mpsc::TrySendError::Disconnected(_) => {
                    MemoryError::Embedding("embedding worker has exited — restart required".into())
                }
            })?;

        let result = match timeout(self.embed_timeout, reply_rx).await {
            Ok(Ok(result)) => result,
            // Fires if the worker drops reply_tx without sending (e.g. a
            // double-panic that escapes catch_unwind, or a panic in span setup).
            Ok(Err(_)) => Err(MemoryError::Embedding(
                "embedding worker dropped the reply channel unexpectedly".into(),
            )),
            Err(_elapsed) => Err(MemoryError::Embedding(format!(
                "embedding timed out after {:.1}s — the worker will recover automatically",
                self.embed_timeout.as_secs_f64(),
            ))),
        };

        // Report operational state passively so /readyz reflects reality without probing.
        match &result {
            Ok(_) => self.reporter.report_ok(),
            Err(_) => self.reporter.report_err("embed failed"),
        }

        result
    }

    fn dimensions(&self) -> usize {
        self.dim
    }
}

// ---------------------------------------------------------------------------
// Model loading
// ---------------------------------------------------------------------------

/// Download (or retrieve from cache) the model files from HuggingFace Hub.
///
/// On first run (cold start), this downloads ~130 MB of model files from
/// HuggingFace Hub. Subsequent starts use the local cache (`HF_HOME`).
/// Use the `warmup` subcommand or a k8s init container to pre-populate the
/// cache and avoid blocking the first server startup.
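///
/// For example (command name taken from the log hint below; any flags it
/// accepts are not shown here):
///
/// ```text
/// memory-mcp warmup
/// ```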
fn load_model_files() -> anyhow::Result<(BertConfig, Tokenizer, PathBuf)> {
    let _span = tracing::info_span!("embedding.load_model", model = MODEL_ID).entered();

    let cache = Cache::from_env();
    let hf_repo = Repo::new(MODEL_ID.to_string(), RepoType::Model);

    // Check whether the heaviest file (model weights) is already cached.
    let cached = cache.repo(hf_repo.clone()).get("model.safetensors");
    if cached.is_none() {
        tracing::warn!(
            model = MODEL_ID,
            "embedding model not found in cache — downloading from HuggingFace Hub \
             (this may take a minute on first run; use `memory-mcp warmup` to pre-populate)"
        );
    } else {
        tracing::info!(model = MODEL_ID, "loading embedding model from cache");
    }

    // Respect HF_HOME and HF_ENDPOINT env vars; disable indicatif progress
    // bars since we are a headless server.
    let api = ApiBuilder::from_env().with_progress(false).build()?;
    let repo = api.repo(hf_repo);

    let start = std::time::Instant::now();
    let config_path = repo.get("config.json")?;
    let tokenizer_path = repo.get("tokenizer.json")?;
    let weights_path = repo.get("model.safetensors")?;
    tracing::info!(
        elapsed_ms = start.elapsed().as_millis(),
        "model files ready"
    );

    let config: BertConfig = serde_json::from_str(&std::fs::read_to_string(&config_path)?)?;
    let tokenizer = Tokenizer::from_file(&tokenizer_path)
        .map_err(|e| anyhow::anyhow!("failed to load tokenizer: {e}"))?;

    Ok((config, tokenizer, weights_path))
}

// ---------------------------------------------------------------------------
// Inference
// ---------------------------------------------------------------------------

/// Maximum texts per forward pass. BERT attention is O(batch × seq²) in memory;
/// capping the batch avoids OOM on large reindex operations. 64 is conservative
/// enough for CPU inference while still amortising per-batch overhead.
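///
/// Back-of-envelope (worst case at the 512-token truncation cap): each
/// attention head materialises a `batch × seq²` score matrix per layer,
/// i.e. 64 × 512² × 4 B = 64 MiB of transient f32s per head, so bounding
/// the batch bounds peak memory even for pathological inputs.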
const MAX_BATCH_SIZE: usize = 64;

/// Embed texts through the BERT model, chunking into bounded forward passes.
///
/// Splits the input into chunks of at most [`MAX_BATCH_SIZE`] texts and runs
/// each chunk through [`embed_chunk`], concatenating the results.
fn embed_batch(inner: &CandleInner, texts: &[String]) -> Result<Vec<Vec<f32>>, MemoryError> {
    let _span = tracing::debug_span!("embedding.embed_batch", batch_size = texts.len()).entered();

    if texts.is_empty() {
        return Ok(Vec::new());
    }

    let mut results = Vec::with_capacity(texts.len());
    for chunk in texts.chunks(MAX_BATCH_SIZE) {
        results.extend(embed_chunk(inner, chunk)?);
    }
    Ok(results)
}

/// Embed a single chunk of texts through the BERT model in one forward pass.
///
/// Texts are tokenised with padding (to the longest sequence in the chunk)
/// and truncation (to 512 tokens), then passed through BERT together.
/// An attention mask ensures padding tokens do not affect the output.
/// CLS pooling extracts the first token's hidden state, which is then
/// L2-normalised to produce unit vectors.
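///
/// The forward pass yields a `(batch, seq, hidden)` tensor; `get(i)` then
/// `get(0)` selects text `i`'s first (CLS) token as a hidden-size vector,
/// which is then mapped to `v / max(‖v‖₂, 1e-12)`.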
fn embed_chunk(inner: &CandleInner, texts: &[String]) -> Result<Vec<Vec<f32>>, MemoryError> {
    let _span = tracing::debug_span!("embedding.embed_chunk", chunk_size = texts.len()).entered();
    debug_assert!(!texts.is_empty(), "embed_chunk called with empty texts");

    let encodings = inner
        .tokenizer
        .encode_batch(texts.to_vec(), true)
        .map_err(|e| MemoryError::Embedding(format!("tokenization failed: {e}")))?;

    let batch_size = encodings.len();
    let seq_len = encodings[0].get_ids().len();

    // Verify padding produced uniform sequence lengths before allocating
    // the flat token vectors. A mismatch here means the tokenizer's
    // padding config was not applied (e.g. silently reset).
    if let Some((i, enc)) = encodings
        .iter()
        .enumerate()
        .find(|(_, e)| e.get_ids().len() != seq_len)
    {
        return Err(MemoryError::Embedding(format!(
            "padding invariant violated: encoding[0] has {seq_len} tokens \
             but encoding[{i}] has {} — check tokenizer padding config",
            enc.get_ids().len(),
        )));
    }

    let all_ids: Vec<u32> = encodings
        .iter()
        .flat_map(|e| e.get_ids().to_vec())
        .collect();
    let all_type_ids: Vec<u32> = encodings
        .iter()
        .flat_map(|e| e.get_type_ids().to_vec())
        .collect();
    let all_masks: Vec<u32> = encodings
        .iter()
        .flat_map(|e| e.get_attention_mask().to_vec())
        .collect();

    let input_ids = Tensor::new(all_ids.as_slice(), &inner.device)
        .and_then(|t| t.reshape((batch_size, seq_len)))
        .map_err(|e| MemoryError::Embedding(format!("tensor creation failed: {e}")))?;

    let token_type_ids = Tensor::new(all_type_ids.as_slice(), &inner.device)
        .and_then(|t| t.reshape((batch_size, seq_len)))
        .map_err(|e| MemoryError::Embedding(format!("tensor creation failed: {e}")))?;

    let attention_mask = Tensor::new(all_masks.as_slice(), &inner.device)
        .and_then(|t| t.reshape((batch_size, seq_len)))
        .map_err(|e| MemoryError::Embedding(format!("tensor creation failed: {e}")))?;

    let embeddings = inner
        .model
        .forward(&input_ids, &token_type_ids, Some(&attention_mask))
        .map_err(|e| MemoryError::Embedding(format!("BERT forward pass failed: {e}")))?;

    // CLS pooling + L2 normalise each vector in the batch.
    let mut results = Vec::with_capacity(batch_size);
    for i in 0..batch_size {
        let cls = embeddings
            .get(i)
            .and_then(|seq| seq.get(0))
            .map_err(|e| MemoryError::Embedding(format!("CLS extraction failed: {e}")))?;

        // L2 normalise with epsilon guard against division by zero
        // (e.g. malformed model weights producing an all-zero CLS vector).
        let norm = cls
            .sqr()
            .and_then(|s| s.sum_all())
            .and_then(|s| s.sqrt())
            .and_then(|n| n.maximum(1e-12))
            .and_then(|n| cls.broadcast_div(&n))
            .map_err(|e| MemoryError::Embedding(format!("L2 normalisation failed: {e}")))?;

        let vector: Vec<f32> = norm
            .to_vec1()
            .map_err(|e| MemoryError::Embedding(format!("tensor to vec failed: {e}")))?;

        results.push(vector);
    }

    Ok(results)
}

// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

#[cfg(test)]
mod tests {
    use std::sync::{Arc, Barrier};
    use std::time::Duration;

    use super::*;

    /// Build a fake engine whose worker applies `handler` to every request.
    ///
    /// `handler` receives the texts and the reply sender; it can sleep, panic,
    /// drop the sender, or send any result — enabling controlled fault injection
    /// without touching the model loading path.
    fn fake_engine<F>(timeout: Duration, handler: F) -> CandleEmbeddingEngine
    where
        F: Fn(Vec<String>, oneshot::Sender<Result<Vec<Vec<f32>>, MemoryError>>) + Send + 'static,
    {
        let (tx, rx) = mpsc::sync_channel::<EmbedRequest>(1);
        std::thread::spawn(move || {
            for (texts, reply_tx) in rx {
                handler(texts, reply_tx);
            }
        });
        CandleEmbeddingEngine::with_worker(tx, 4, timeout)
    }

    /// Worker that immediately returns a fixed-size zero vector per input.
    fn ok_handler(
        texts: Vec<String>,
        reply_tx: oneshot::Sender<Result<Vec<Vec<f32>>, MemoryError>>,
    ) {
        let vecs = texts.iter().map(|_| vec![0.0f32; 4]).collect();
        let _ = reply_tx.send(Ok(vecs));
    }

    #[tokio::test]
    async fn happy_path_returns_vectors() {
        let engine = fake_engine(Duration::from_secs(5), ok_handler);
        let result = engine
            .embed(&["hello".to_string(), "world".to_string()])
            .await;
        let vecs = result.expect("embed should succeed");
        assert_eq!(vecs.len(), 2);
        assert_eq!(vecs[0].len(), 4);
    }

    #[tokio::test]
    async fn timeout_returns_error_and_worker_recovers() {
        // Barrier lets us prove the worker is still alive after the timeout
        // fires on the first request.
        let barrier = Arc::new(Barrier::new(2));
        let barrier2 = Arc::clone(&barrier);

        let engine = fake_engine(Duration::from_millis(50), move |texts, reply_tx| {
            if texts[0] == "slow" {
                // Block until the test signals us to proceed (after timeout fires).
                barrier2.wait();
                // Reply arrives after the caller's receiver was dropped — send
                // fails silently, which is the self-healing path.
                let _ = reply_tx.send(Ok(vec![vec![0.0; 4]]));
                // Signal the test that we've finished processing the stale request.
                barrier2.wait();
            } else {
                ok_handler(texts, reply_tx);
            }
        });

        // First call times out.
        let err = engine
            .embed(&["slow".to_string()])
            .await
            .expect_err("slow embed should time out");
        assert!(
            err.to_string().contains("timed out"),
            "expected timeout error, got: {err}"
        );

        // Unblock the worker and wait for it to finish the stale request.
        barrier.wait();
        barrier.wait();

        // Second call should succeed — the worker recovered.
        let result = engine.embed(&["fast".to_string()]).await;
        assert!(
            result.is_ok(),
            "engine should recover after timeout: {result:?}"
        );
    }

    #[tokio::test]
    async fn disconnected_worker_returns_error() {
        let (tx, rx) = mpsc::sync_channel::<EmbedRequest>(1);
        // Drop the receiver immediately — worker is "dead".
        drop(rx);
        let engine = CandleEmbeddingEngine::with_worker(tx, 4, Duration::from_secs(5));

        let err = engine
            .embed(&["anything".to_string()])
            .await
            .expect_err("disconnected worker should error");
        assert!(
            err.to_string().contains("exited"),
            "expected 'exited' in error, got: {err}"
        );
    }

    #[tokio::test]
    async fn busy_worker_returns_error() {
        // A zero-capacity sync_channel would be a rendezvous channel, which
        // complicates the setup; instead, use capacity 1 and pre-fill the
        // single slot so the engine's try_send sees a full queue.
        let (tx, rx) = mpsc::sync_channel::<EmbedRequest>(1);

        // Pre-fill the single channel slot by sending directly, bypassing embed().
        let (filler_tx, _filler_rx) = oneshot::channel::<Result<Vec<Vec<f32>>, MemoryError>>();
        tx.send((vec!["fill".to_string()], filler_tx)).unwrap();

        // Now embed() hits try_send on a full channel.
        let engine = CandleEmbeddingEngine::with_worker(tx, 4, Duration::from_secs(5));
        let err = engine
            .embed(&["overflow".to_string()])
            .await
            .expect_err("full channel should error");
        assert!(
            err.to_string().contains("busy"),
            "expected 'busy' in error, got: {err}"
        );

        drop(rx); // clean up
    }
}