// memory_mcp/embedding/candle.rs
1use std::panic::{catch_unwind, AssertUnwindSafe};
2use std::path::PathBuf;
3use std::sync::{Arc, Mutex};
4
5use candle_core::{Device, Tensor};
6use candle_nn::VarBuilder;
7use candle_transformers::models::bert::{BertModel, Config as BertConfig};
8use hf_hub::{api::sync::ApiBuilder, Cache, Repo, RepoType};
9use tokenizers::{PaddingParams, Tokenizer, TruncationParams};
10
11use super::EmbeddingBackend;
12use crate::error::MemoryError;
13
/// HuggingFace model ID. Only BGE-small-en-v1.5 is supported currently.
///
/// Used by `load_model_files` both for the Hub download and for the
/// cache-status log messages.
pub const MODEL_ID: &str = "BAAI/bge-small-en-v1.5";
/// Pure-Rust embedding engine using candle for BERT inference.
///
/// Uses candle-transformers' BERT implementation with tokenizers for
/// tokenisation. No C/C++ FFI dependencies — compiles on all platforms.
pub struct CandleEmbeddingEngine {
    // Model, tokenizer and device behind one mutex: inference runs on a
    // blocking thread, so the `Arc` lets `embed` move a handle into
    // `spawn_blocking` while the `Mutex` serialises access to the model.
    inner: Arc<Mutex<CandleInner>>,
    // Embedding dimensionality, taken from the model config's `hidden_size`
    // at construction time so `dimensions()` needs no lock.
    dim: usize,
}
25
/// The mutable inference state guarded by [`CandleEmbeddingEngine`]'s mutex:
/// the loaded BERT weights, the configured tokenizer, and the compute device
/// tensors are created on.
struct CandleInner {
    model: BertModel,
    tokenizer: Tokenizer,
    device: Device,
}
31
32impl CandleEmbeddingEngine {
33    /// Initialise the candle embedding engine.
34    ///
35    /// Downloads model weights from HuggingFace Hub on first use (cached
36    /// in the standard HF cache directory, respects `HF_HOME`).
37    pub fn new() -> Result<Self, MemoryError> {
38        let device = Device::Cpu;
39
40        let (config, mut tokenizer, weights_path) =
41            load_model_files().map_err(|e| MemoryError::Embedding(e.to_string()))?;
42
43        // Enable padding so encode_batch produces equal-length sequences.
44        tokenizer.with_padding(Some(PaddingParams {
45            strategy: tokenizers::PaddingStrategy::BatchLongest,
46            ..Default::default()
47        }));
48        tokenizer
49            .with_truncation(Some(TruncationParams {
50                max_length: 512,
51                ..Default::default()
52            }))
53            .map_err(|e| MemoryError::Embedding(format!("failed to set truncation: {e}")))?;
54
55        // SAFETY: `from_mmaped_safetensors` memory-maps the weights file. The
56        // caller must ensure the file is not modified for the lifetime of the
57        // resulting tensors. HuggingFace Hub writes cache files atomically and
58        // never modifies them in-place, so the mapping is stable.
59        let vb = unsafe {
60            VarBuilder::from_mmaped_safetensors(&[weights_path], candle_core::DType::F32, &device)
61                .map_err(|e| MemoryError::Embedding(format!("failed to load weights: {e}")))?
62        };
63
64        let model = BertModel::load(vb, &config)
65            .map_err(|e| MemoryError::Embedding(format!("failed to build BERT model: {e}")))?;
66
67        let dim = config.hidden_size;
68
69        Ok(Self {
70            inner: Arc::new(Mutex::new(CandleInner {
71                model,
72                tokenizer,
73                device,
74            })),
75            dim,
76        })
77    }
78}
79
80#[async_trait::async_trait]
81impl EmbeddingBackend for CandleEmbeddingEngine {
82    async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, MemoryError> {
83        let arc = Arc::clone(&self.inner);
84        let texts = texts.to_vec();
85        tokio::task::spawn_blocking(move || {
86            let guard = arc.lock().unwrap_or_else(|poisoned| {
87                tracing::warn!("embedding mutex was poisoned — clearing poison and continuing");
88                poisoned.into_inner()
89            });
90            catch_unwind(AssertUnwindSafe(|| embed_batch(&guard, &texts))).unwrap_or_else(
91                |panic_payload| {
92                    let msg = if let Some(s) = panic_payload.downcast_ref::<&str>() {
93                        (*s).to_string()
94                    } else if let Some(s) = panic_payload.downcast_ref::<String>() {
95                        s.clone()
96                    } else {
97                        "unknown panic in embedding engine".to_string()
98                    };
99                    Err(MemoryError::Embedding(format!(
100                        "embedding engine panicked: {msg}"
101                    )))
102                },
103            )
104        })
105        .await
106        .map_err(|e| MemoryError::Join(e.to_string()))?
107    }
108
109    fn dimensions(&self) -> usize {
110        self.dim
111    }
112}
113
114// ---------------------------------------------------------------------------
115// Model loading
116// ---------------------------------------------------------------------------
117
118/// Download (or retrieve from cache) the model files from HuggingFace Hub.
119///
120/// On first run (cold start), this downloads ~130 MB of model files from
121/// HuggingFace Hub. Subsequent starts use the local cache (`HF_HOME`).
122/// Use the `warmup` subcommand or a k8s init container to pre-populate the
123/// cache and avoid blocking the first server startup.
124fn load_model_files() -> anyhow::Result<(BertConfig, Tokenizer, PathBuf)> {
125    let cache = Cache::from_env();
126    let hf_repo = Repo::new(MODEL_ID.to_string(), RepoType::Model);
127
128    // Check whether the heaviest file (model weights) is already cached.
129    let cached = cache.repo(hf_repo.clone()).get("model.safetensors");
130    if cached.is_none() {
131        tracing::warn!(
132            model = MODEL_ID,
133            "embedding model not found in cache — downloading from HuggingFace Hub \
134             (this may take a minute on first run; use `memory-mcp warmup` to pre-populate)"
135        );
136    } else {
137        tracing::info!(model = MODEL_ID, "loading embedding model from cache");
138    }
139
140    // Respect HF_HOME and HF_ENDPOINT env vars; disable indicatif progress
141    // bars since we are a headless server.
142    let api = ApiBuilder::from_env().with_progress(false).build()?;
143    let repo = api.repo(hf_repo);
144
145    let start = std::time::Instant::now();
146    let config_path = repo.get("config.json")?;
147    let tokenizer_path = repo.get("tokenizer.json")?;
148    let weights_path = repo.get("model.safetensors")?;
149    tracing::info!(
150        elapsed_ms = start.elapsed().as_millis(),
151        "model files ready"
152    );
153
154    let config: BertConfig = serde_json::from_str(&std::fs::read_to_string(&config_path)?)?;
155    let tokenizer = Tokenizer::from_file(&tokenizer_path)
156        .map_err(|e| anyhow::anyhow!("failed to load tokenizer: {e}"))?;
157
158    Ok((config, tokenizer, weights_path))
159}
160
161// ---------------------------------------------------------------------------
162// Inference
163// ---------------------------------------------------------------------------
164
/// Maximum texts per forward pass. BERT attention is O(batch × seq²) in memory;
/// capping the batch avoids OOM on large reindex operations. 64 is conservative
/// enough for CPU inference while still amortising per-batch overhead.
const MAX_BATCH_SIZE: usize = 64;
169
170/// Embed texts through the BERT model, chunking into bounded forward passes.
171///
172/// Splits the input into chunks of at most [`MAX_BATCH_SIZE`] texts and runs
173/// each chunk through [`embed_chunk`], concatenating the results.
174fn embed_batch(inner: &CandleInner, texts: &[String]) -> Result<Vec<Vec<f32>>, MemoryError> {
175    if texts.is_empty() {
176        return Ok(Vec::new());
177    }
178
179    let mut results = Vec::with_capacity(texts.len());
180    for chunk in texts.chunks(MAX_BATCH_SIZE) {
181        results.extend(embed_chunk(inner, chunk)?);
182    }
183    Ok(results)
184}
185
186/// Embed a single chunk of texts through the BERT model in one forward pass.
187///
188/// Texts are tokenised with padding (to the longest sequence in the chunk)
189/// and truncation (to 512 tokens), then passed through BERT together.
190/// An attention mask ensures padding tokens do not affect the output.
191/// CLS pooling extracts the first token's hidden state, which is then
192/// L2-normalised to produce unit vectors.
193fn embed_chunk(inner: &CandleInner, texts: &[String]) -> Result<Vec<Vec<f32>>, MemoryError> {
194    debug_assert!(!texts.is_empty(), "embed_chunk called with empty texts");
195
196    let encodings = inner
197        .tokenizer
198        .encode_batch(texts.to_vec(), true)
199        .map_err(|e| MemoryError::Embedding(format!("tokenization failed: {e}")))?;
200
201    let batch_size = encodings.len();
202    let seq_len = encodings[0].get_ids().len();
203
204    // Verify padding produced uniform sequence lengths before allocating
205    // the flat token vectors. A mismatch here means the tokenizer's
206    // padding config was not applied (e.g. silently reset).
207    if let Some((i, enc)) = encodings
208        .iter()
209        .enumerate()
210        .find(|(_, e)| e.get_ids().len() != seq_len)
211    {
212        return Err(MemoryError::Embedding(format!(
213            "padding invariant violated: encoding[0] has {seq_len} tokens \
214             but encoding[{i}] has {} — check tokenizer padding config",
215            enc.get_ids().len(),
216        )));
217    }
218
219    let all_ids: Vec<u32> = encodings
220        .iter()
221        .flat_map(|e| e.get_ids().to_vec())
222        .collect();
223    let all_type_ids: Vec<u32> = encodings
224        .iter()
225        .flat_map(|e| e.get_type_ids().to_vec())
226        .collect();
227    let all_masks: Vec<u32> = encodings
228        .iter()
229        .flat_map(|e| e.get_attention_mask().to_vec())
230        .collect();
231
232    let input_ids = Tensor::new(all_ids.as_slice(), &inner.device)
233        .and_then(|t| t.reshape((batch_size, seq_len)))
234        .map_err(|e| MemoryError::Embedding(format!("tensor creation failed: {e}")))?;
235
236    let token_type_ids = Tensor::new(all_type_ids.as_slice(), &inner.device)
237        .and_then(|t| t.reshape((batch_size, seq_len)))
238        .map_err(|e| MemoryError::Embedding(format!("tensor creation failed: {e}")))?;
239
240    let attention_mask = Tensor::new(all_masks.as_slice(), &inner.device)
241        .and_then(|t| t.reshape((batch_size, seq_len)))
242        .map_err(|e| MemoryError::Embedding(format!("tensor creation failed: {e}")))?;
243
244    let embeddings = inner
245        .model
246        .forward(&input_ids, &token_type_ids, Some(&attention_mask))
247        .map_err(|e| MemoryError::Embedding(format!("BERT forward pass failed: {e}")))?;
248
249    // CLS pooling + L2 normalise each vector in the batch.
250    let mut results = Vec::with_capacity(batch_size);
251    for i in 0..batch_size {
252        let cls = embeddings
253            .get(i)
254            .and_then(|seq| seq.get(0))
255            .map_err(|e| MemoryError::Embedding(format!("CLS extraction failed: {e}")))?;
256
257        // L2 normalise with epsilon guard against division by zero
258        // (e.g. malformed model weights producing an all-zero CLS vector).
259        let norm = cls
260            .sqr()
261            .and_then(|s| s.sum_all())
262            .and_then(|s| s.sqrt())
263            .and_then(|n| n.maximum(1e-12))
264            .and_then(|n| cls.broadcast_div(&n))
265            .map_err(|e| MemoryError::Embedding(format!("L2 normalisation failed: {e}")))?;
266
267        let vector: Vec<f32> = norm
268            .to_vec1()
269            .map_err(|e| MemoryError::Embedding(format!("tensor to vec failed: {e}")))?;
270
271        results.push(vector);
272    }
273
274    Ok(results)
275}