
memory_mcp/embedding/candle.rs

use std::panic::{catch_unwind, AssertUnwindSafe};
use std::path::PathBuf;
use std::sync::{Arc, Mutex};

use candle_core::{Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config as BertConfig};
use hf_hub::{api::sync::ApiBuilder, Cache, Repo, RepoType};
use tokenizers::{PaddingParams, Tokenizer, TruncationParams};

use super::EmbeddingBackend;
use crate::error::MemoryError;

/// HuggingFace model ID. Only BGE-small-en-v1.5 is supported currently.
pub const MODEL_ID: &str = "BAAI/bge-small-en-v1.5";

/// Pure-Rust embedding engine using candle for BERT inference.
///
/// Uses candle-transformers' BERT implementation with tokenizers for
/// tokenisation. No C/C++ FFI dependencies — compiles on all platforms.
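///
/// # Example
///
/// A minimal usage sketch (`no_run`, since construction downloads model
/// weights; the `memory_mcp` crate path is assumed from the module layout,
/// not confirmed):
///
/// ```no_run
/// # async fn demo() -> Result<(), memory_mcp::error::MemoryError> {
/// use memory_mcp::embedding::candle::CandleEmbeddingEngine;
/// use memory_mcp::embedding::EmbeddingBackend;
///
/// let engine = CandleEmbeddingEngine::new()?; // downloads/caches weights
/// let vectors = engine.embed(&["hello world".to_string()]).await?;
/// assert_eq!(vectors[0].len(), engine.dimensions()); // 384 for BGE-small
/// # Ok(())
/// # }
/// ```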
pub struct CandleEmbeddingEngine {
    inner: Arc<Mutex<CandleInner>>,
    dim: usize,
}

struct CandleInner {
    model: BertModel,
    tokenizer: Tokenizer,
    device: Device,
}

impl CandleEmbeddingEngine {
    /// Initialise the candle embedding engine.
    ///
    /// Downloads model weights from HuggingFace Hub on first use (cached
    /// in the standard HF cache directory, respects `HF_HOME`).
    pub fn new() -> Result<Self, MemoryError> {
        let device = Device::Cpu;

        let (config, mut tokenizer, weights_path) =
            load_model_files().map_err(|e| MemoryError::Embedding(e.to_string()))?;

        // Enable padding so encode_batch produces equal-length sequences.
        tokenizer.with_padding(Some(PaddingParams {
            strategy: tokenizers::PaddingStrategy::BatchLongest,
            ..Default::default()
        }));
        tokenizer
            .with_truncation(Some(TruncationParams {
                max_length: 512,
                ..Default::default()
            }))
            .map_err(|e| MemoryError::Embedding(format!("failed to set truncation: {e}")))?;

        // SAFETY: `from_mmaped_safetensors` memory-maps the weights file. The
        // caller must ensure the file is not modified for the lifetime of the
        // resulting tensors. HuggingFace Hub writes cache files atomically and
        // never modifies them in-place, so the mapping is stable.
        let vb = unsafe {
            VarBuilder::from_mmaped_safetensors(&[weights_path], candle_core::DType::F32, &device)
                .map_err(|e| MemoryError::Embedding(format!("failed to load weights: {e}")))?
        };

        let model = BertModel::load(vb, &config)
            .map_err(|e| MemoryError::Embedding(format!("failed to build BERT model: {e}")))?;

        let dim = config.hidden_size;

        Ok(Self {
            inner: Arc::new(Mutex::new(CandleInner {
                model,
                tokenizer,
                device,
            })),
            dim,
        })
    }
}

#[async_trait::async_trait]
impl EmbeddingBackend for CandleEmbeddingEngine {
    async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, MemoryError> {
        let arc = Arc::clone(&self.inner);
        let texts = texts.to_vec();
        let batch_size = texts.len();
        let dim = self.dim;
        // Create the span here, on the async side, so it parents off the
        // current span; the spawn_blocking thread then enters it to keep
        // the span hierarchy intact.
        let span = tracing::debug_span!(
            "embedding.embed",
            batch_size,
            dimensions = dim,
            model = MODEL_ID,
        );

        tokio::task::spawn_blocking(move || {
            let _enter = span.enter();
            let guard = arc.lock().unwrap_or_else(|poisoned| {
                tracing::warn!("embedding mutex was poisoned — clearing poison and continuing");
                poisoned.into_inner()
            });
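            // The inference path can panic (e.g. deep inside candle on an
            // unexpected shape) rather than return Err. Catching the panic
            // here stops the unwind from crossing the lock scope, so one bad
            // batch becomes a MemoryError instead of poisoning the mutex for
            // every later caller.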
            catch_unwind(AssertUnwindSafe(|| embed_batch(&guard, &texts))).unwrap_or_else(
                |panic_payload| {
                    let msg = if let Some(s) = panic_payload.downcast_ref::<&str>() {
                        (*s).to_string()
                    } else if let Some(s) = panic_payload.downcast_ref::<String>() {
                        s.clone()
                    } else {
                        "unknown panic in embedding engine".to_string()
                    };
                    Err(MemoryError::Embedding(format!(
                        "embedding engine panicked: {msg}"
                    )))
                },
            )
        })
        .await
        .map_err(|e| MemoryError::Join(e.to_string()))?
    }

    fn dimensions(&self) -> usize {
        self.dim
    }
}

// ---------------------------------------------------------------------------
// Model loading
// ---------------------------------------------------------------------------

/// Download (or retrieve from cache) the model files from HuggingFace Hub.
///
/// On first run (cold start), this downloads ~130 MB of model files from
/// HuggingFace Hub. Subsequent starts use the local cache (`HF_HOME`).
/// Use the `warmup` subcommand or a k8s init container to pre-populate the
/// cache and avoid blocking the first server startup.
fn load_model_files() -> anyhow::Result<(BertConfig, Tokenizer, PathBuf)> {
    let _span = tracing::info_span!("embedding.load_model", model = MODEL_ID).entered();

    let cache = Cache::from_env();
    let hf_repo = Repo::new(MODEL_ID.to_string(), RepoType::Model);

    // Check whether the heaviest file (model weights) is already cached.
    let cached = cache.repo(hf_repo.clone()).get("model.safetensors");
    if cached.is_none() {
        tracing::warn!(
            model = MODEL_ID,
            "embedding model not found in cache — downloading from HuggingFace Hub \
             (this may take a minute on first run; use `memory-mcp warmup` to pre-populate)"
        );
    } else {
        tracing::info!(model = MODEL_ID, "loading embedding model from cache");
    }

    // Respect HF_HOME and HF_ENDPOINT env vars; disable indicatif progress
    // bars since we are a headless server.
    let api = ApiBuilder::from_env().with_progress(false).build()?;
    let repo = api.repo(hf_repo);

    let start = std::time::Instant::now();
    let config_path = repo.get("config.json")?;
    let tokenizer_path = repo.get("tokenizer.json")?;
    let weights_path = repo.get("model.safetensors")?;
    tracing::info!(
        elapsed_ms = start.elapsed().as_millis(),
        "model files ready"
    );

    let config: BertConfig = serde_json::from_str(&std::fs::read_to_string(&config_path)?)?;
    let tokenizer = Tokenizer::from_file(&tokenizer_path)
        .map_err(|e| anyhow::anyhow!("failed to load tokenizer: {e}"))?;

    Ok((config, tokenizer, weights_path))
}

// ---------------------------------------------------------------------------
// Inference
// ---------------------------------------------------------------------------

/// Maximum texts per forward pass. BERT attention is O(batch × seq²) in memory;
/// capping the batch avoids OOM on large reindex operations. 64 is conservative
/// enough for CPU inference while still amortising per-batch overhead.
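///
/// As a rough worst case for one materialised attention-score tensor at this
/// cap (assuming the BGE-small config of 12 heads and full 512-token
/// sequences): 64 × 12 × 512 × 512 scores × 4 bytes ≈ 0.8 GB of transient
/// f32 per layer.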
const MAX_BATCH_SIZE: usize = 64;

/// Embed texts through the BERT model, chunking into bounded forward passes.
///
/// Splits the input into chunks of at most [`MAX_BATCH_SIZE`] texts and runs
/// each chunk through [`embed_chunk`], concatenating the results.
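/// For example, 150 texts would run as three forward passes of 64, 64, and
/// 22 texts, in order, so the output still lines up index-for-index with
/// the input.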
fn embed_batch(inner: &CandleInner, texts: &[String]) -> Result<Vec<Vec<f32>>, MemoryError> {
    let _span = tracing::debug_span!("embedding.embed_batch", batch_size = texts.len()).entered();

    if texts.is_empty() {
        return Ok(Vec::new());
    }

    let mut results = Vec::with_capacity(texts.len());
    for chunk in texts.chunks(MAX_BATCH_SIZE) {
        results.extend(embed_chunk(inner, chunk)?);
    }
    Ok(results)
}

/// Embed a single chunk of texts through the BERT model in one forward pass.
///
/// Texts are tokenised with padding (to the longest sequence in the chunk)
/// and truncation (to 512 tokens), then passed through BERT together.
/// An attention mask ensures padding tokens do not affect the output.
/// CLS pooling extracts the first token's hidden state, which is then
/// L2-normalised to produce unit vectors.
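///
/// In symbols, for the CLS hidden state `h` of each text, the returned
/// vector is `v = h / max(‖h‖₂, 1e-12)`, so every output has (up to the
/// epsilon guard) unit length, and dot products equal cosine similarity.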
fn embed_chunk(inner: &CandleInner, texts: &[String]) -> Result<Vec<Vec<f32>>, MemoryError> {
    let _span = tracing::debug_span!("embedding.embed_chunk", chunk_size = texts.len()).entered();
    debug_assert!(!texts.is_empty(), "embed_chunk called with empty texts");

    let encodings = inner
        .tokenizer
        .encode_batch(texts.to_vec(), true)
        .map_err(|e| MemoryError::Embedding(format!("tokenization failed: {e}")))?;

    let batch_size = encodings.len();
    let seq_len = encodings[0].get_ids().len();

    // Verify padding produced uniform sequence lengths before allocating
    // the flat token vectors. A mismatch here means the tokenizer's
    // padding config was not applied (e.g. silently reset).
    if let Some((i, enc)) = encodings
        .iter()
        .enumerate()
        .find(|(_, e)| e.get_ids().len() != seq_len)
    {
        return Err(MemoryError::Embedding(format!(
            "padding invariant violated: encoding[0] has {seq_len} tokens \
             but encoding[{i}] has {} — check tokenizer padding config",
            enc.get_ids().len(),
        )));
    }

    let all_ids: Vec<u32> = encodings
        .iter()
        .flat_map(|e| e.get_ids().to_vec())
        .collect();
    let all_type_ids: Vec<u32> = encodings
        .iter()
        .flat_map(|e| e.get_type_ids().to_vec())
        .collect();
    let all_masks: Vec<u32> = encodings
        .iter()
        .flat_map(|e| e.get_attention_mask().to_vec())
        .collect();

    let input_ids = Tensor::new(all_ids.as_slice(), &inner.device)
        .and_then(|t| t.reshape((batch_size, seq_len)))
        .map_err(|e| MemoryError::Embedding(format!("tensor creation failed: {e}")))?;

    let token_type_ids = Tensor::new(all_type_ids.as_slice(), &inner.device)
        .and_then(|t| t.reshape((batch_size, seq_len)))
        .map_err(|e| MemoryError::Embedding(format!("tensor creation failed: {e}")))?;

    let attention_mask = Tensor::new(all_masks.as_slice(), &inner.device)
        .and_then(|t| t.reshape((batch_size, seq_len)))
        .map_err(|e| MemoryError::Embedding(format!("tensor creation failed: {e}")))?;

    let embeddings = inner
        .model
        .forward(&input_ids, &token_type_ids, Some(&attention_mask))
        .map_err(|e| MemoryError::Embedding(format!("BERT forward pass failed: {e}")))?;

    // CLS pooling + L2 normalise each vector in the batch.
    let mut results = Vec::with_capacity(batch_size);
    for i in 0..batch_size {
        let cls = embeddings
            .get(i)
            .and_then(|seq| seq.get(0))
            .map_err(|e| MemoryError::Embedding(format!("CLS extraction failed: {e}")))?;

        // L2 normalise with epsilon guard against division by zero
        // (e.g. malformed model weights producing an all-zero CLS vector).
        let normalised = cls
            .sqr()
            .and_then(|s| s.sum_all())
            .and_then(|s| s.sqrt())
            .and_then(|n| n.maximum(1e-12))
            .and_then(|n| cls.broadcast_div(&n))
            .map_err(|e| MemoryError::Embedding(format!("L2 normalisation failed: {e}")))?;

        let vector: Vec<f32> = normalised
            .to_vec1()
            .map_err(|e| MemoryError::Embedding(format!("tensor to vec failed: {e}")))?;

        results.push(vector);
    }

    Ok(results)
}
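
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------

// An illustrative smoke-test sketch, not part of the original module: it
// assumes a tokio dev-dependency with the `macros` feature, and it is
// `#[ignore]`d because constructing the engine downloads the real model
// weights. The assertions follow from the code above (BGE-small-en-v1.5
// has a hidden size of 384, and `embed_chunk` L2-normalises every vector).
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    #[ignore = "downloads model weights from HuggingFace Hub"]
    async fn embeds_unit_vectors() {
        let engine = CandleEmbeddingEngine::new().expect("engine init");
        assert_eq!(engine.dimensions(), 384);

        let out = engine
            .embed(&["hello world".to_string(), "goodbye".to_string()])
            .await
            .expect("embed");
        assert_eq!(out.len(), 2);

        for vector in &out {
            assert_eq!(vector.len(), 384);
            // Unit length up to float error, per the epsilon-guarded
            // normalisation in `embed_chunk`.
            let sq_norm: f32 = vector.iter().map(|x| x * x).sum();
            assert!((sq_norm - 1.0).abs() < 1e-3, "norm² was {sq_norm}");
        }
    }
}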