use std::panic::{catch_unwind, AssertUnwindSafe};
use std::path::PathBuf;
use std::sync::mpsc;
use candle_core::{Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::bert::{BertModel, Config as BertConfig};
use hf_hub::{api::sync::ApiBuilder, Cache, Repo, RepoType};
use tokenizers::{PaddingParams, Tokenizer, TruncationParams};
use tokio::sync::oneshot;
use tokio::time::{timeout, Duration};
use super::EmbeddingBackend;
use crate::error::MemoryError;
use crate::health::SubsystemReporter;
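/// Hugging Face Hub ID of the embedding model served by this engine.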
pub const MODEL_ID: &str = "BAAI/bge-small-en-v1.5";
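/// A batch of texts to embed plus the channel on which the worker sends the result back.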
type EmbedRequest = (
Vec<String>,
oneshot::Sender<Result<Vec<Vec<f32>>, MemoryError>>,
);
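/// CPU embedding engine backed by a Candle BERT model running on a dedicated worker thread.
///
/// Requests are queued on a bounded channel; the async `embed` call waits for the reply with a
/// configurable timeout, so a slow or wedged forward pass cannot stall callers indefinitely.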
pub struct CandleEmbeddingEngine {
tx: Option<mpsc::SyncSender<EmbedRequest>>,
worker: Option<std::thread::JoinHandle<()>>,
dim: usize,
embed_timeout: Duration,
reporter: SubsystemReporter,
}
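/// Model, tokenizer, and device state owned exclusively by the worker thread.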
struct CandleInner {
model: BertModel,
tokenizer: Tokenizer,
device: Device,
}
impl CandleEmbeddingEngine {
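    /// Loads the model files (downloading them on first use), configures the tokenizer, and
    /// spawns the `embed-worker` thread that serves requests from a bounded queue.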
pub fn new(
embed_timeout: Duration,
queue_size: usize,
reporter: SubsystemReporter,
) -> Result<Self, MemoryError> {
let device = Device::Cpu;
let (config, mut tokenizer, weights_path) =
load_model_files().map_err(|e| MemoryError::Embedding(e.to_string()))?;
tokenizer.with_padding(Some(PaddingParams {
strategy: tokenizers::PaddingStrategy::BatchLongest,
..Default::default()
}));
tokenizer
.with_truncation(Some(TruncationParams {
max_length: 512,
..Default::default()
}))
.map_err(|e| MemoryError::Embedding(format!("failed to set truncation: {e}")))?;
let vb = unsafe {
VarBuilder::from_mmaped_safetensors(&[weights_path], candle_core::DType::F32, &device)
.map_err(|e| MemoryError::Embedding(format!("failed to load weights: {e}")))?
};
let model = BertModel::load(vb, &config)
.map_err(|e| MemoryError::Embedding(format!("failed to build BERT model: {e}")))?;
let dim = config.hidden_size;
let (tx, rx) = mpsc::sync_channel::<EmbedRequest>(queue_size);
let worker = std::thread::Builder::new()
.name("embed-worker".into())
.spawn(move || {
let inner = CandleInner {
model,
tokenizer,
device,
};
worker_loop(inner, dim, rx);
})
.map_err(|e| MemoryError::Embedding(format!("failed to spawn embed worker: {e}")))?;
Ok(Self {
tx: Some(tx),
worker: Some(worker),
dim,
embed_timeout,
reporter,
})
}
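    /// Test-only constructor that wires the engine to an externally owned worker channel.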
#[cfg(test)]
fn with_worker(
tx: mpsc::SyncSender<EmbedRequest>,
dim: usize,
embed_timeout: Duration,
) -> Self {
Self {
tx: Some(tx),
worker: None,
dim,
embed_timeout,
reporter: SubsystemReporter::new(),
}
}
}
impl Drop for CandleEmbeddingEngine {
fn drop(&mut self) {
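        // Dropping the sender closes the channel so the worker loop exits, then join the thread.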
drop(self.tx.take());
if let Some(handle) = self.worker.take() {
let _ = handle.join();
}
}
}
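/// Worker-thread loop: serves embedding requests until the sender side is dropped, catching
/// panics from the model so a single bad batch cannot kill the engine.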
fn worker_loop(mut inner: CandleInner, dim: usize, rx: mpsc::Receiver<EmbedRequest>) {
for (texts, reply_tx) in rx {
let span = tracing::debug_span!(
"embedding.embed",
batch_size = texts.len(),
dimensions = dim,
model = MODEL_ID,
);
let _enter = span.enter();
let mut panicked = false;
let result = catch_unwind(AssertUnwindSafe(|| embed_batch(&inner, &texts))).unwrap_or_else(
|panic_payload| {
panicked = true;
let msg = if let Some(s) = panic_payload.downcast_ref::<&str>() {
(*s).to_string()
} else if let Some(s) = panic_payload.downcast_ref::<String>() {
s.clone()
} else {
"unknown panic in embedding engine".to_string()
};
tracing::warn!(error = %msg, "embedding engine panicked — recovering");
Err(MemoryError::Embedding(format!(
"embedding engine panicked: {msg}"
)))
},
);
let _ = reply_tx.send(result);
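        // Defensively re-apply the tokenizer configuration in case the panic left it inconsistent.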
if panicked {
inner.tokenizer.with_padding(Some(PaddingParams {
strategy: tokenizers::PaddingStrategy::BatchLongest,
..Default::default()
}));
let _ = inner.tokenizer.with_truncation(Some(TruncationParams {
max_length: 512,
..Default::default()
}));
}
}
}
#[async_trait::async_trait]
impl EmbeddingBackend for CandleEmbeddingEngine {
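    /// Queues the batch on the worker thread without blocking and awaits the reply, failing fast
    /// when the queue is full and erroring after `embed_timeout` if no reply arrives.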
async fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, MemoryError> {
let (reply_tx, reply_rx) = oneshot::channel();
let tx = self
.tx
.as_ref()
.ok_or_else(|| MemoryError::Embedding("embedding engine has been shut down".into()))?;
tx.try_send((texts.to_vec(), reply_tx))
.map_err(|e| match e {
mpsc::TrySendError::Full(_) => {
MemoryError::Embedding("embedding worker is busy — try again".into())
}
mpsc::TrySendError::Disconnected(_) => {
MemoryError::Embedding("embedding worker has exited — restart required".into())
}
})?;
let result = match timeout(self.embed_timeout, reply_rx).await {
Ok(Ok(result)) => result,
Ok(Err(_)) => Err(MemoryError::Embedding(
"embedding worker dropped the reply channel unexpectedly".into(),
)),
Err(_elapsed) => Err(MemoryError::Embedding(format!(
"embedding timed out after {:.1}s — the worker will recover automatically",
self.embed_timeout.as_secs_f64(),
))),
};
match &result {
Ok(_) => self.reporter.report_ok(),
Err(_) => self.reporter.report_err("embed failed"),
}
result
}
fn dimensions(&self) -> usize {
self.dim
}
}
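/// Resolves config, tokenizer, and safetensors weights for [`MODEL_ID`], using the local
/// Hugging Face cache when available and downloading from the Hub otherwise.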
fn load_model_files() -> anyhow::Result<(BertConfig, Tokenizer, PathBuf)> {
let _span = tracing::info_span!("embedding.load_model", model = MODEL_ID).entered();
let cache = Cache::from_env();
let hf_repo = Repo::new(MODEL_ID.to_string(), RepoType::Model);
let cached = cache.repo(hf_repo.clone()).get("model.safetensors");
if cached.is_none() {
tracing::warn!(
model = MODEL_ID,
"embedding model not found in cache — downloading from HuggingFace Hub \
(this may take a minute on first run; use `memory-mcp warmup` to pre-populate)"
);
} else {
tracing::info!(model = MODEL_ID, "loading embedding model from cache");
}
let api = ApiBuilder::from_env().with_progress(false).build()?;
let repo = api.repo(hf_repo);
let start = std::time::Instant::now();
let config_path = repo.get("config.json")?;
let tokenizer_path = repo.get("tokenizer.json")?;
let weights_path = repo.get("model.safetensors")?;
tracing::info!(
elapsed_ms = start.elapsed().as_millis(),
"model files ready"
);
let config: BertConfig = serde_json::from_str(&std::fs::read_to_string(&config_path)?)?;
let tokenizer = Tokenizer::from_file(&tokenizer_path)
.map_err(|e| anyhow::anyhow!("failed to load tokenizer: {e}"))?;
Ok((config, tokenizer, weights_path))
}
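/// Upper bound on texts per forward pass; larger requests are split into chunks of this size.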
const MAX_BATCH_SIZE: usize = 64;
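/// Embeds `texts` in chunks of at most [`MAX_BATCH_SIZE`], preserving input order.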
fn embed_batch(inner: &CandleInner, texts: &[String]) -> Result<Vec<Vec<f32>>, MemoryError> {
let _span = tracing::debug_span!("embedding.embed_batch", batch_size = texts.len()).entered();
if texts.is_empty() {
return Ok(Vec::new());
}
let mut results = Vec::with_capacity(texts.len());
for chunk in texts.chunks(MAX_BATCH_SIZE) {
results.extend(embed_chunk(inner, chunk)?);
}
Ok(results)
}
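/// Tokenizes one chunk, runs the BERT forward pass, and returns L2-normalised CLS embeddings.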
fn embed_chunk(inner: &CandleInner, texts: &[String]) -> Result<Vec<Vec<f32>>, MemoryError> {
let _span = tracing::debug_span!("embedding.embed_chunk", chunk_size = texts.len()).entered();
debug_assert!(!texts.is_empty(), "embed_chunk called with empty texts");
let encodings = inner
.tokenizer
.encode_batch(texts.to_vec(), true)
.map_err(|e| MemoryError::Embedding(format!("tokenization failed: {e}")))?;
let batch_size = encodings.len();
let seq_len = encodings[0].get_ids().len();
if let Some((i, enc)) = encodings
.iter()
.enumerate()
.find(|(_, e)| e.get_ids().len() != seq_len)
{
return Err(MemoryError::Embedding(format!(
"padding invariant violated: encoding[0] has {seq_len} tokens \
but encoding[{i}] has {} — check tokenizer padding config",
enc.get_ids().len(),
)));
}
    // Flatten the per-sequence token ids, type ids, and attention masks into contiguous buffers.
    let all_ids: Vec<u32> = encodings
        .iter()
        .flat_map(|e| e.get_ids().iter().copied())
        .collect();
    let all_type_ids: Vec<u32> = encodings
        .iter()
        .flat_map(|e| e.get_type_ids().iter().copied())
        .collect();
    let all_masks: Vec<u32> = encodings
        .iter()
        .flat_map(|e| e.get_attention_mask().iter().copied())
        .collect();
let input_ids = Tensor::new(all_ids.as_slice(), &inner.device)
.and_then(|t| t.reshape((batch_size, seq_len)))
.map_err(|e| MemoryError::Embedding(format!("tensor creation failed: {e}")))?;
let token_type_ids = Tensor::new(all_type_ids.as_slice(), &inner.device)
.and_then(|t| t.reshape((batch_size, seq_len)))
.map_err(|e| MemoryError::Embedding(format!("tensor creation failed: {e}")))?;
let attention_mask = Tensor::new(all_masks.as_slice(), &inner.device)
.and_then(|t| t.reshape((batch_size, seq_len)))
.map_err(|e| MemoryError::Embedding(format!("tensor creation failed: {e}")))?;
let embeddings = inner
.model
.forward(&input_ids, &token_type_ids, Some(&attention_mask))
.map_err(|e| MemoryError::Embedding(format!("BERT forward pass failed: {e}")))?;
let mut results = Vec::with_capacity(batch_size);
for i in 0..batch_size {
        // CLS pooling: take the hidden state of the first ([CLS]) token for sequence `i`.
        let cls = embeddings
            .get(i)
            .and_then(|seq| seq.get(0))
            .map_err(|e| MemoryError::Embedding(format!("CLS extraction failed: {e}")))?;
        // L2-normalise, clamping the norm away from zero to avoid division by zero.
        let normalized = cls
            .sqr()
            .and_then(|s| s.sum_all())
            .and_then(|s| s.sqrt())
            .and_then(|n| n.maximum(1e-12))
            .and_then(|n| cls.broadcast_div(&n))
            .map_err(|e| MemoryError::Embedding(format!("L2 normalisation failed: {e}")))?;
        let vector: Vec<f32> = normalized
            .to_vec1()
            .map_err(|e| MemoryError::Embedding(format!("tensor to vec failed: {e}")))?;
results.push(vector);
}
Ok(results)
}
#[cfg(test)]
mod tests {
use std::sync::{Arc, Barrier};
use std::time::Duration;
use super::*;
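    /// Builds an engine whose worker is a plain thread driven by `handler`, so tests can script
    /// replies without loading a real model.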
fn fake_engine<F>(timeout: Duration, handler: F) -> CandleEmbeddingEngine
where
F: Fn(Vec<String>, oneshot::Sender<Result<Vec<Vec<f32>>, MemoryError>>) + Send + 'static,
{
let (tx, rx) = mpsc::sync_channel::<EmbedRequest>(1);
std::thread::spawn(move || {
for (texts, reply_tx) in rx {
handler(texts, reply_tx);
}
});
CandleEmbeddingEngine::with_worker(tx, 4, timeout)
}
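    /// Replies with a 4-dimensional zero vector for every input text.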
fn ok_handler(
texts: Vec<String>,
reply_tx: oneshot::Sender<Result<Vec<Vec<f32>>, MemoryError>>,
) {
let vecs = texts.iter().map(|_| vec![0.0f32; 4]).collect();
let _ = reply_tx.send(Ok(vecs));
}
#[tokio::test]
async fn happy_path_returns_vectors() {
let engine = fake_engine(Duration::from_secs(5), ok_handler);
let result = engine
.embed(&["hello".to_string(), "world".to_string()])
.await;
let vecs = result.expect("embed should succeed");
assert_eq!(vecs.len(), 2);
assert_eq!(vecs[0].len(), 4);
}
#[tokio::test]
async fn timeout_returns_error_and_worker_recovers() {
let barrier = Arc::new(Barrier::new(2));
let barrier2 = Arc::clone(&barrier);
let engine = fake_engine(Duration::from_millis(50), move |texts, reply_tx| {
if texts[0] == "slow" {
barrier2.wait();
let _ = reply_tx.send(Ok(vec![vec![0.0; 4]]));
barrier2.wait();
} else {
ok_handler(texts, reply_tx);
}
});
let err = engine
.embed(&["slow".to_string()])
.await
.expect_err("slow embed should time out");
assert!(
err.to_string().contains("timed out"),
"expected timeout error, got: {err}"
);
barrier.wait();
barrier.wait();
let result = engine.embed(&["fast".to_string()]).await;
assert!(
result.is_ok(),
"engine should recover after timeout: {result:?}"
);
}
#[tokio::test]
async fn disconnected_worker_returns_error() {
let (tx, rx) = mpsc::sync_channel::<EmbedRequest>(1);
drop(rx);
let engine = CandleEmbeddingEngine::with_worker(tx, 4, Duration::from_secs(5));
let err = engine
.embed(&["anything".to_string()])
.await
.expect_err("disconnected worker should error");
assert!(
err.to_string().contains("exited"),
"expected 'exited' in error, got: {err}"
);
}
#[tokio::test]
async fn busy_worker_returns_error() {
let (tx, rx) = mpsc::sync_channel::<EmbedRequest>(1);
let (filler_tx, _filler_rx) = oneshot::channel::<Result<Vec<Vec<f32>>, MemoryError>>();
tx.send((vec!["fill".to_string()], filler_tx)).unwrap();
let engine = CandleEmbeddingEngine::with_worker(tx, 4, Duration::from_secs(5));
let err = engine
.embed(&["overflow".to_string()])
.await
.expect_err("full channel should error");
assert!(
err.to_string().contains("busy"),
"expected 'busy' in error, got: {err}"
);
        drop(rx);
    }
}