Skip to main content

memory_mcp/embedding/
candle.rs

1use std::panic::{catch_unwind, AssertUnwindSafe};
2use std::path::PathBuf;
3use std::sync::mpsc;
4
5use candle_core::{Device, Tensor};
6use candle_nn::VarBuilder;
7use candle_transformers::models::bert::{BertModel, Config as BertConfig};
8use hf_hub::{api::sync::ApiBuilder, Cache, Repo, RepoType};
9use tokenizers::{PaddingParams, Tokenizer, TruncationParams};
10use tokio::sync::oneshot;
11use tokio::time::{timeout, Duration};
12
13use super::EmbeddingBackend;
14use crate::error::MemoryError;
15
/// HuggingFace Hub repository that the embedding model is loaded from.
///
/// BGE-small-en-v1.5 is currently the only supported model.
pub const MODEL_ID: &str = "BAAI/bge-small-en-v1.5";
18
19// ---------------------------------------------------------------------------
20// Worker thread
21// ---------------------------------------------------------------------------
22
23type EmbedRequest = (
24    Vec<String>,
25    oneshot::Sender<Result<Vec<Vec<f32>>, MemoryError>>,
26);
27
28/// Pure-Rust embedding engine using candle for BERT inference.
29///
30/// Uses candle-transformers' BERT implementation with tokenizers for
31/// tokenisation. No C/C++ FFI dependencies — compiles on all platforms.
32///
33/// A dedicated OS thread owns the model exclusively. Async callers send work
34/// via a channel and await a oneshot reply. If a call times out, the caller
35/// gets an error immediately; the worker finishes its current task, discards
36/// the stale reply channel, and picks up the next request — no restart needed.
37pub struct CandleEmbeddingEngine {
38    // Option so Drop can take ownership of tx to close it before joining.
39    tx: Option<mpsc::SyncSender<EmbedRequest>>,
40    worker: Option<std::thread::JoinHandle<()>>,
41    dim: usize,
42    embed_timeout: Duration,
43}
44
45struct CandleInner {
46    model: BertModel,
47    tokenizer: Tokenizer,
48    device: Device,
49}
50
51impl CandleEmbeddingEngine {
52    /// Initialise the candle embedding engine.
53    ///
54    /// Downloads model weights from HuggingFace Hub on first use (cached
55    /// in the standard HF cache directory, respects `HF_HOME`).
56    ///
57    /// `embed_timeout` caps how long a single [`embed`](Self::embed) call may
58    /// block. If the worker is still running when the timeout fires, the caller
59    /// gets an error but the engine recovers automatically — no restart needed.
60    ///
61    /// `queue_size` sets the bounded channel capacity — how many requests can
62    /// queue behind the one being processed. Extra callers get an immediate
63    /// "busy" error.
64    pub fn new(embed_timeout: Duration, queue_size: usize) -> Result<Self, MemoryError> {
65        let device = Device::Cpu;
66
67        let (config, mut tokenizer, weights_path) =
68            load_model_files().map_err(|e| MemoryError::Embedding(e.to_string()))?;
69
70        // Enable padding so encode_batch produces equal-length sequences.
71        tokenizer.with_padding(Some(PaddingParams {
72            strategy: tokenizers::PaddingStrategy::BatchLongest,
73            ..Default::default()
74        }));
75        tokenizer
76            .with_truncation(Some(TruncationParams {
77                max_length: 512,
78                ..Default::default()
79            }))
80            .map_err(|e| MemoryError::Embedding(format!("failed to set truncation: {e}")))?;
81
82        // SAFETY: `from_mmaped_safetensors` memory-maps the weights file. The
83        // caller must ensure the file is not modified for the lifetime of the
84        // resulting tensors. HuggingFace Hub writes cache files atomically and
85        // never modifies them in-place, so the mapping is stable.
86        let vb = unsafe {
87            VarBuilder::from_mmaped_safetensors(&[weights_path], candle_core::DType::F32, &device)
88                .map_err(|e| MemoryError::Embedding(format!("failed to load weights: {e}")))?
89        };
90
91        let model = BertModel::load(vb, &config)
92            .map_err(|e| MemoryError::Embedding(format!("failed to build BERT model: {e}")))?;
93
94        let dim = config.hidden_size;
95
96        let (tx, rx) = mpsc::sync_channel::<EmbedRequest>(queue_size);
97
98        let worker = std::thread::Builder::new()
99            .name("embed-worker".into())
100            .spawn(move || {
101                let inner = CandleInner {
102                    model,
103                    tokenizer,
104                    device,
105                };
106                worker_loop(inner, dim, rx);
107            })
108            .map_err(|e| MemoryError::Embedding(format!("failed to spawn embed worker: {e}")))?;
109
110        Ok(Self {
111            tx: Some(tx),
112            worker: Some(worker),
113            dim,
114            embed_timeout,
115        })
116    }
117
118    /// Construct an engine backed by a caller-supplied worker sender.
119    ///
120    /// Bypasses model loading entirely. The caller is responsible for spawning
121    /// a thread that reads from the other end of the channel. Used only in
122    /// tests to exercise the channel mechanics (timeout, disconnect, busy)
123    /// without needing the HuggingFace model cache.
124    #[cfg(test)]
125    fn with_worker(
126        tx: mpsc::SyncSender<EmbedRequest>,
127        dim: usize,
128        embed_timeout: Duration,
129    ) -> Self {
130        Self {
131            tx: Some(tx),
132            worker: None,
133            dim,
134            embed_timeout,
135        }
136    }
137}
138
139impl Drop for CandleEmbeddingEngine {
140    fn drop(&mut self) {
141        // Close the channel first so the worker's `for ... in rx` loop exits.
142        drop(self.tx.take());
143        if let Some(handle) = self.worker.take() {
144            let _ = handle.join();
145        }
146    }
147}
148
149/// Main loop for the dedicated embedding worker thread.
150///
151/// Receives `(texts, reply_tx)` pairs, runs inference, and sends the result
152/// back on `reply_tx`. If the receiver was dropped (the async caller timed
153/// out), the send fails silently and the loop continues — this is the
154/// self-healing path.
155fn worker_loop(mut inner: CandleInner, dim: usize, rx: mpsc::Receiver<EmbedRequest>) {
156    for (texts, reply_tx) in rx {
157        let span = tracing::debug_span!(
158            "embedding.embed",
159            batch_size = texts.len(),
160            dimensions = dim,
161            model = MODEL_ID,
162        );
163        let _enter = span.enter();
164
165        let mut panicked = false;
166        let result = catch_unwind(AssertUnwindSafe(|| embed_batch(&inner, &texts))).unwrap_or_else(
167            |panic_payload| {
168                panicked = true;
169                let msg = if let Some(s) = panic_payload.downcast_ref::<&str>() {
170                    (*s).to_string()
171                } else if let Some(s) = panic_payload.downcast_ref::<String>() {
172                    s.clone()
173                } else {
174                    "unknown panic in embedding engine".to_string()
175                };
176                tracing::warn!(error = %msg, "embedding engine panicked — recovering");
177                Err(MemoryError::Embedding(format!(
178                    "embedding engine panicked: {msg}"
179                )))
180            },
181        );
182
183        let _ = reply_tx.send(result);
184
185        if panicked {
186            inner.tokenizer.with_padding(Some(PaddingParams {
187                strategy: tokenizers::PaddingStrategy::BatchLongest,
188                ..Default::default()
189            }));
190            let _ = inner.tokenizer.with_truncation(Some(TruncationParams {
191                max_length: 512,
192                ..Default::default()
193            }));
194        }
195    }
196}
197
198#[async_trait::async_trait]
199impl EmbeddingBackend for CandleEmbeddingEngine {
200    async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, MemoryError> {
201        let (reply_tx, reply_rx) = oneshot::channel();
202
203        let tx = self
204            .tx
205            .as_ref()
206            .ok_or_else(|| MemoryError::Embedding("embedding engine has been shut down".into()))?;
207
208        tx.try_send((texts.to_vec(), reply_tx))
209            .map_err(|e| match e {
210                mpsc::TrySendError::Full(_) => {
211                    MemoryError::Embedding("embedding worker is busy — try again".into())
212                }
213                mpsc::TrySendError::Disconnected(_) => {
214                    MemoryError::Embedding("embedding worker has exited — restart required".into())
215                }
216            })?;
217
218        match timeout(self.embed_timeout, reply_rx).await {
219            Ok(Ok(result)) => result,
220            // Fires if the worker drops reply_tx without sending (e.g. a
221            // double-panic that escapes catch_unwind, or a panic in span setup).
222            Ok(Err(_)) => Err(MemoryError::Embedding(
223                "embedding worker dropped the reply channel unexpectedly".into(),
224            )),
225            Err(_elapsed) => Err(MemoryError::Embedding(format!(
226                "embedding timed out after {:.1}s — the worker will recover automatically",
227                self.embed_timeout.as_secs_f64(),
228            ))),
229        }
230    }
231
232    fn dimensions(&self) -> usize {
233        self.dim
234    }
235}
236
237// ---------------------------------------------------------------------------
238// Model loading
239// ---------------------------------------------------------------------------
240
241/// Download (or retrieve from cache) the model files from HuggingFace Hub.
242///
243/// On first run (cold start), this downloads ~130 MB of model files from
244/// HuggingFace Hub. Subsequent starts use the local cache (`HF_HOME`).
245/// Use the `warmup` subcommand or a k8s init container to pre-populate the
246/// cache and avoid blocking the first server startup.
247fn load_model_files() -> anyhow::Result<(BertConfig, Tokenizer, PathBuf)> {
248    let _span = tracing::info_span!("embedding.load_model", model = MODEL_ID).entered();
249
250    let cache = Cache::from_env();
251    let hf_repo = Repo::new(MODEL_ID.to_string(), RepoType::Model);
252
253    // Check whether the heaviest file (model weights) is already cached.
254    let cached = cache.repo(hf_repo.clone()).get("model.safetensors");
255    if cached.is_none() {
256        tracing::warn!(
257            model = MODEL_ID,
258            "embedding model not found in cache — downloading from HuggingFace Hub \
259             (this may take a minute on first run; use `memory-mcp warmup` to pre-populate)"
260        );
261    } else {
262        tracing::info!(model = MODEL_ID, "loading embedding model from cache");
263    }
264
265    // Respect HF_HOME and HF_ENDPOINT env vars; disable indicatif progress
266    // bars since we are a headless server.
267    let api = ApiBuilder::from_env().with_progress(false).build()?;
268    let repo = api.repo(hf_repo);
269
270    let start = std::time::Instant::now();
271    let config_path = repo.get("config.json")?;
272    let tokenizer_path = repo.get("tokenizer.json")?;
273    let weights_path = repo.get("model.safetensors")?;
274    tracing::info!(
275        elapsed_ms = start.elapsed().as_millis(),
276        "model files ready"
277    );
278
279    let config: BertConfig = serde_json::from_str(&std::fs::read_to_string(&config_path)?)?;
280    let tokenizer = Tokenizer::from_file(&tokenizer_path)
281        .map_err(|e| anyhow::anyhow!("failed to load tokenizer: {e}"))?;
282
283    Ok((config, tokenizer, weights_path))
284}
285
286// ---------------------------------------------------------------------------
287// Inference
288// ---------------------------------------------------------------------------
289
/// Upper bound on how many texts go through a single BERT forward pass.
///
/// Attention memory grows as O(batch × seq²), so big reindex jobs are cut
/// into pieces no larger than this to avoid running out of memory. The value
/// 64 is conservative for CPU inference while still amortising per-batch
/// overhead.
const MAX_BATCH_SIZE: usize = 64;
294
295/// Embed texts through the BERT model, chunking into bounded forward passes.
296///
297/// Splits the input into chunks of at most [`MAX_BATCH_SIZE`] texts and runs
298/// each chunk through [`embed_chunk`], concatenating the results.
299fn embed_batch(inner: &CandleInner, texts: &[String]) -> Result<Vec<Vec<f32>>, MemoryError> {
300    let _span = tracing::debug_span!("embedding.embed_batch", batch_size = texts.len()).entered();
301
302    if texts.is_empty() {
303        return Ok(Vec::new());
304    }
305
306    let mut results = Vec::with_capacity(texts.len());
307    for chunk in texts.chunks(MAX_BATCH_SIZE) {
308        results.extend(embed_chunk(inner, chunk)?);
309    }
310    Ok(results)
311}
312
313/// Embed a single chunk of texts through the BERT model in one forward pass.
314///
315/// Texts are tokenised with padding (to the longest sequence in the chunk)
316/// and truncation (to 512 tokens), then passed through BERT together.
317/// An attention mask ensures padding tokens do not affect the output.
318/// CLS pooling extracts the first token's hidden state, which is then
319/// L2-normalised to produce unit vectors.
320fn embed_chunk(inner: &CandleInner, texts: &[String]) -> Result<Vec<Vec<f32>>, MemoryError> {
321    let _span = tracing::debug_span!("embedding.embed_chunk", chunk_size = texts.len()).entered();
322    debug_assert!(!texts.is_empty(), "embed_chunk called with empty texts");
323
324    let encodings = inner
325        .tokenizer
326        .encode_batch(texts.to_vec(), true)
327        .map_err(|e| MemoryError::Embedding(format!("tokenization failed: {e}")))?;
328
329    let batch_size = encodings.len();
330    let seq_len = encodings[0].get_ids().len();
331
332    // Verify padding produced uniform sequence lengths before allocating
333    // the flat token vectors. A mismatch here means the tokenizer's
334    // padding config was not applied (e.g. silently reset).
335    if let Some((i, enc)) = encodings
336        .iter()
337        .enumerate()
338        .find(|(_, e)| e.get_ids().len() != seq_len)
339    {
340        return Err(MemoryError::Embedding(format!(
341            "padding invariant violated: encoding[0] has {seq_len} tokens \
342             but encoding[{i}] has {} — check tokenizer padding config",
343            enc.get_ids().len(),
344        )));
345    }
346
347    let all_ids: Vec<u32> = encodings
348        .iter()
349        .flat_map(|e| e.get_ids().to_vec())
350        .collect();
351    let all_type_ids: Vec<u32> = encodings
352        .iter()
353        .flat_map(|e| e.get_type_ids().to_vec())
354        .collect();
355    let all_masks: Vec<u32> = encodings
356        .iter()
357        .flat_map(|e| e.get_attention_mask().to_vec())
358        .collect();
359
360    let input_ids = Tensor::new(all_ids.as_slice(), &inner.device)
361        .and_then(|t| t.reshape((batch_size, seq_len)))
362        .map_err(|e| MemoryError::Embedding(format!("tensor creation failed: {e}")))?;
363
364    let token_type_ids = Tensor::new(all_type_ids.as_slice(), &inner.device)
365        .and_then(|t| t.reshape((batch_size, seq_len)))
366        .map_err(|e| MemoryError::Embedding(format!("tensor creation failed: {e}")))?;
367
368    let attention_mask = Tensor::new(all_masks.as_slice(), &inner.device)
369        .and_then(|t| t.reshape((batch_size, seq_len)))
370        .map_err(|e| MemoryError::Embedding(format!("tensor creation failed: {e}")))?;
371
372    let embeddings = inner
373        .model
374        .forward(&input_ids, &token_type_ids, Some(&attention_mask))
375        .map_err(|e| MemoryError::Embedding(format!("BERT forward pass failed: {e}")))?;
376
377    // CLS pooling + L2 normalise each vector in the batch.
378    let mut results = Vec::with_capacity(batch_size);
379    for i in 0..batch_size {
380        let cls = embeddings
381            .get(i)
382            .and_then(|seq| seq.get(0))
383            .map_err(|e| MemoryError::Embedding(format!("CLS extraction failed: {e}")))?;
384
385        // L2 normalise with epsilon guard against division by zero
386        // (e.g. malformed model weights producing an all-zero CLS vector).
387        let norm = cls
388            .sqr()
389            .and_then(|s| s.sum_all())
390            .and_then(|s| s.sqrt())
391            .and_then(|n| n.maximum(1e-12))
392            .and_then(|n| cls.broadcast_div(&n))
393            .map_err(|e| MemoryError::Embedding(format!("L2 normalisation failed: {e}")))?;
394
395        let vector: Vec<f32> = norm
396            .to_vec1()
397            .map_err(|e| MemoryError::Embedding(format!("tensor to vec failed: {e}")))?;
398
399        results.push(vector);
400    }
401
402    Ok(results)
403}
404
405// ---------------------------------------------------------------------------
406// Tests
407// ---------------------------------------------------------------------------
408
#[cfg(test)]
mod tests {
    use std::sync::{Arc, Barrier};
    use std::time::Duration;

    use super::*;

    /// Build a fake engine whose worker applies `handler` to every request.
    ///
    /// `handler` receives the texts and the reply sender. It can sleep, panic,
    /// drop the sender, or send any result, which allows controlled fault
    /// injection without touching the model loading path. The fake worker
    /// thread exits once the engine is dropped, because dropping the engine
    /// closes the request channel.
    fn fake_engine<F>(timeout: Duration, handler: F) -> CandleEmbeddingEngine
    where
        F: Fn(Vec<String>, oneshot::Sender<Result<Vec<Vec<f32>>, MemoryError>>) + Send + 'static,
    {
        let (tx, rx) = mpsc::sync_channel::<EmbedRequest>(1);
        std::thread::spawn(move || {
            for (texts, reply_tx) in rx {
                handler(texts, reply_tx);
            }
        });
        CandleEmbeddingEngine::with_worker(tx, 4, timeout)
    }

    /// Worker that immediately returns a fixed-size zero vector per input.
    fn ok_handler(
        texts: Vec<String>,
        reply_tx: oneshot::Sender<Result<Vec<Vec<f32>>, MemoryError>>,
    ) {
        let vecs = texts.iter().map(|_| vec![0.0f32; 4]).collect();
        let _ = reply_tx.send(Ok(vecs));
    }

    #[tokio::test]
    async fn happy_path_returns_vectors() {
        let engine = fake_engine(Duration::from_secs(5), ok_handler);
        let result = engine
            .embed(&["hello".to_string(), "world".to_string()])
            .await;
        let vecs = result.expect("embed should succeed");
        assert_eq!(vecs.len(), 2);
        assert_eq!(vecs[0].len(), 4);
    }

    #[tokio::test]
    async fn timeout_returns_error_and_worker_recovers() {
        // Barrier lets us prove the worker is still alive after the timeout
        // fires on the first request.
        let barrier = Arc::new(Barrier::new(2));
        let barrier2 = Arc::clone(&barrier);

        let engine = fake_engine(Duration::from_millis(50), move |texts, reply_tx| {
            if texts[0] == "slow" {
                // Block until the test signals us to proceed (after timeout fires).
                barrier2.wait();
                // Reply arrives after the caller's receiver was dropped — send
                // fails silently, which is the self-healing path.
                let _ = reply_tx.send(Ok(vec![vec![0.0; 4]]));
                // Signal the test that we've finished processing the stale request.
                barrier2.wait();
            } else {
                ok_handler(texts, reply_tx);
            }
        });

        // First call times out.
        let err = engine
            .embed(&["slow".to_string()])
            .await
            .expect_err("slow embed should time out");
        assert!(
            err.to_string().contains("timed out"),
            "expected timeout error, got: {err}"
        );

        // Unblock the worker and wait for it to finish the stale request.
        barrier.wait();
        barrier.wait();

        // Second call should succeed — the worker recovered.
        let result = engine.embed(&["fast".to_string()]).await;
        assert!(
            result.is_ok(),
            "engine should recover after timeout: {result:?}"
        );
    }

    #[tokio::test]
    async fn disconnected_worker_returns_error() {
        let (tx, rx) = mpsc::sync_channel::<EmbedRequest>(1);
        // Drop the receiver immediately — worker is "dead".
        drop(rx);
        let engine = CandleEmbeddingEngine::with_worker(tx, 4, Duration::from_secs(5));

        let err = engine
            .embed(&["anything".to_string()])
            .await
            .expect_err("disconnected worker should error");
        assert!(
            err.to_string().contains("exited"),
            "expected 'exited' in error, got: {err}"
        );
    }

    #[tokio::test]
    async fn dropped_reply_channel_returns_error() {
        // Worker that discards the reply sender without answering. This
        // simulates a failure that escapes the real worker's panic handling.
        let engine = fake_engine(Duration::from_secs(5), |_texts, reply_tx| drop(reply_tx));

        let err = engine
            .embed(&["anything".to_string()])
            .await
            .expect_err("dropped reply channel should error");
        assert!(
            err.to_string().contains("dropped the reply channel"),
            "expected 'dropped the reply channel' in error, got: {err}"
        );
    }

    #[tokio::test]
    async fn busy_worker_returns_error() {
        // Use a capacity-1 channel whose receiver is never read. Filling its
        // only slot directly (bypassing `embed()`) makes the `try_send` inside
        // `embed()` deterministically see a full queue.
        let (tx, rx) = mpsc::sync_channel::<EmbedRequest>(1);

        let (filler_tx, _filler_rx) = oneshot::channel::<Result<Vec<Vec<f32>>, MemoryError>>();
        tx.send((vec!["fill".to_string()], filler_tx)).unwrap();

        // Now embed() hits try_send on a full channel.
        let engine = CandleEmbeddingEngine::with_worker(tx, 4, Duration::from_secs(5));
        let err = engine
            .embed(&["overflow".to_string()])
            .await
            .expect_err("full channel should error");
        assert!(
            err.to_string().contains("busy"),
            "expected 'busy' in error, got: {err}"
        );

        // `rx` had to stay alive until now so the channel reported `Full`
        // rather than `Disconnected`.
        drop(rx);
    }
}