use std::sync::Arc;
use std::sync::OnceLock;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::{Duration, Instant};
use anyhow::{Context, Result, anyhow};
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use candle_transformers::models::xlm_roberta::{Config, XLMRobertaModel};
use tokenizers::Tokenizer;
use tokio::sync::Mutex;
use tokio_stream::StreamExt;
use crate::sessions::{EmbeddedMessage, PendingMessage, Store, embedding_dim};
/// Maximum number of tokens per input; the tokenizer is configured to truncate longer
/// texts to this length.
pub(crate) const MAX_TOKENS: usize = 512;
/// [`Embedder`] that runs an XLM-RoBERTa model through Candle and returns mean-pooled,
/// L2-normalised sentence vectors.
pub struct CandleEmbedder {
    /// The transformer encoder producing per-token hidden states.
    model: XLMRobertaModel,
    /// Tokenizer configured with batch-longest padding and truncation to [`MAX_TOKENS`].
    tokenizer: Tokenizer,
    /// Device the weights were loaded onto; input tensors are created on it too.
    device: Device,
}
impl CandleEmbedder {
pub fn load() -> Result<Self> {
let device = select_device();
let id = model_id();
let api = hf_hub::api::sync::Api::new().context("init HuggingFace hub client")?;
let repo = api.model(id.to_owned());
let fetch = |file: &str| {
repo.get(file)
.with_context(|| format!("fetch {file} for {id}"))
};
let config: Config =
serde_json::from_str(&std::fs::read_to_string(fetch("config.json")?)?)?;
if config.hidden_size != embedding_dim() {
return Err(anyhow!(
"[embeddings].dim = {} but model {id:?} reports hidden_size = {}; \
set [embeddings].dim to match the model's output width.",
embedding_dim(),
config.hidden_size,
));
}
let model_path = fetch("model.safetensors")?;
#[allow(unsafe_code)]
let vb =
unsafe { VarBuilder::from_mmaped_safetensors(&[model_path], DType::F16, &device)? };
let model = XLMRobertaModel::new(&config, vb)
.map_err(|error| anyhow!("load {id} weights: {error}"))?;
let mut tokenizer = Tokenizer::from_file(fetch("tokenizer.json")?)
.map_err(|error| anyhow!("load e5 tokenizer: {error}"))?;
tokenizer.with_padding(Some(tokenizers::PaddingParams {
strategy: tokenizers::PaddingStrategy::BatchLongest,
pad_id: config.pad_token_id,
..Default::default()
}));
tokenizer
.with_truncation(Some(tokenizers::TruncationParams {
max_length: MAX_TOKENS,
..Default::default()
}))
.map_err(|error| anyhow!("configure e5 tokenizer: {error}"))?;
tracing::info!(model = %id, device = device_label(&device), "loaded embedding model");
Ok(Self {
model,
tokenizer,
device,
})
}
}
impl Embedder for CandleEmbedder {
    fn device(&self) -> &str {
        device_label(&self.device)
    }

    /// Embeds `texts` into attention-mask-weighted, mean-pooled, L2-normalised vectors.
    ///
    /// Returns one vector per input, in input order. An empty input yields an empty result
    /// without running the model.
    ///
    /// # Errors
    ///
    /// Fails if tokenization, tensor construction, the forward pass or the final read-back
    /// fails, or if the tokenizer returns rows of differing lengths.
    fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        // Borrow the inputs rather than cloning every string for the tokenizer.
        let inputs: Vec<&str> = texts.iter().map(String::as_str).collect();
        let encodings = self
            .tokenizer
            .encode_batch(inputs, true)
            .map_err(|error| anyhow!("tokenize embedding batch: {error}"))?;

        // Lay ids and masks out as one row-major (batch, seq_len) buffer each, so each is
        // uploaded to the device in a single transfer instead of one tensor per row.
        let batch = encodings.len();
        let seq_len = encodings.first().map_or(0, |encoding| encoding.len());
        let mut ids = Vec::with_capacity(batch * seq_len);
        let mut masks = Vec::with_capacity(batch * seq_len);
        for encoding in &encodings {
            // Padding should make every row equally long; a ragged row would misalign the
            // flattened buffer, so reject it explicitly.
            if encoding.len() != seq_len {
                return Err(anyhow!(
                    "tokenizer produced a ragged batch ({} vs {seq_len} tokens); is padding enabled?",
                    encoding.len(),
                ));
            }
            ids.extend_from_slice(encoding.get_ids());
            masks.extend_from_slice(encoding.get_attention_mask());
        }
        let input_ids = Tensor::from_vec(ids, (batch, seq_len), &self.device)?;
        let attention_mask = Tensor::from_vec(masks, (batch, seq_len), &self.device)?;
        let token_type_ids = input_ids.zeros_like()?;

        // Per-token hidden states, promoted to F32 for pooling.
        let hidden = self
            .model
            .forward(
                &input_ids,
                &attention_mask,
                &token_type_ids,
                None,
                None,
                None,
            )?
            .to_dtype(DType::F32)?;

        // Mean-pool over real tokens only: zero out padding, sum over the sequence axis and
        // divide by the number of unmasked tokens per row.
        let mask = attention_mask.to_dtype(DType::F32)?.unsqueeze(2)?;
        let summed = hidden.broadcast_mul(&mask)?.sum(1)?;
        let counts = mask.sum(1)?;
        let mean = summed.broadcast_div(&counts)?;
        // L2-normalise each row so cosine similarity reduces to a dot product.
        let norm = mean.sqr()?.sum_keepdim(1)?.sqrt()?;
        mean.broadcast_div(&norm)?
            .to_vec2::<f32>()
            .map_err(|error| anyhow!("read embedding vectors: {error}"))
    }
}
/// Picks the platform accelerator (Metal on macOS, CUDA elsewhere) and falls back to the
/// CPU, with a warning, when constructing that device returns an error.
fn select_device() -> Device {
    #[cfg(target_os = "macos")]
    let accelerator = Device::metal_if_available(0);
    #[cfg(not(target_os = "macos"))]
    let accelerator = Device::cuda_if_available(0);

    match accelerator {
        Ok(chosen) => chosen,
        Err(error) => {
            tracing::warn!(%error, "GPU device unavailable, falling back to CPU");
            Device::Cpu
        }
    }
}
/// Short static name for `device`, used in the load log line and by `Embedder::device`.
fn device_label(device: &Device) -> &'static str {
    match *device {
        Device::Metal(_) => "metal",
        Device::Cuda(_) => "cuda",
        Device::Cpu => "cpu",
    }
}
/// Factory that builds a fresh embedder backend; [`LazyEmbedder`] runs it on a blocking
/// thread whenever no backend is cached.
type EmbedLoader = Arc<dyn Fn() -> Result<Arc<dyn Embedder>> + Sync + Send>;

/// Default idle time (five minutes) after which [`LazyEmbedder`] treats its cached backend as
/// stale and evicts it.
pub const DEFAULT_IDLE_EVICTION: Duration = Duration::from_secs(5 * 60);

/// A loaded backend paired with the instant it was last handed out.
struct CachedBackend {
    backend: Arc<dyn Embedder>,
    last_use: Instant,
}

/// Loads an [`Embedder`] on demand and caches it until it has been idle too long.
pub struct LazyEmbedder {
    /// Builds the backend when the cache is empty.
    loader: EmbedLoader,
    /// Cached backend, if any; an async mutex so the lock can be held across the load.
    state: Mutex<Option<CachedBackend>>,
    /// Idle time after which the cached backend is evicted.
    idle_threshold: Duration,
}
impl LazyEmbedder {
pub fn candle() -> Self {
Self::with_loader(Arc::new(|| {
Ok(Arc::new(CandleEmbedder::load()?) as Arc<dyn Embedder>)
}))
}
pub fn with_loader(loader: EmbedLoader) -> Self {
Self {
loader,
state: Mutex::new(None),
idle_threshold: DEFAULT_IDLE_EVICTION,
}
}
#[must_use]
pub fn with_idle_threshold(mut self, threshold: Duration) -> Self {
self.idle_threshold = threshold;
self
}
pub fn from_loaded(backend: Arc<dyn Embedder>) -> Self {
let preloaded = Arc::clone(&backend);
let loader: EmbedLoader = Arc::new(move || Ok(Arc::clone(&preloaded)));
Self {
loader,
state: Mutex::new(Some(CachedBackend {
backend,
last_use: Instant::now(),
})),
idle_threshold: Duration::MAX,
}
}
pub async fn get(&self) -> Result<Arc<dyn Embedder>> {
let mut state = self.state.lock().await;
let now = Instant::now();
if let Some(cached) = &*state
&& now.duration_since(cached.last_use) > self.idle_threshold
{
tracing::info!(
idle_secs = self.idle_threshold.as_secs(),
"evicting idle embedder",
);
*state = None;
}
if let Some(cached) = state.as_mut() {
cached.last_use = now;
return Ok(Arc::clone(&cached.backend));
}
let loader = Arc::clone(&self.loader);
let backend = tokio::task::spawn_blocking(move || loader())
.await
.map_err(|join_error| anyhow!("embedder load panicked: {join_error}"))??;
*state = Some(CachedBackend {
backend: Arc::clone(&backend),
last_use: now,
});
Ok(backend)
}
}
pub const DEFAULT_MODEL_ID: &str = "intfloat/multilingual-e5-small";
static MODEL_ID_RUNTIME: OnceLock<String> = OnceLock::new();
pub fn model_id() -> &'static str {
MODEL_ID_RUNTIME
.get()
.map(String::as_str)
.unwrap_or(DEFAULT_MODEL_ID)
}
pub fn init_model_id(id: String) {
MODEL_ID_RUNTIME.get_or_init(|| id);
}
/// Number of texts sent to the backend per `embed` call by the worker.
pub const DEFAULT_BATCH_SIZE: usize = 32;
/// Default number of pending messages buffered and sorted by length before embedding.
pub const DEFAULT_SORT_WINDOW: usize = 2048;

/// Prepends the `query: ` marker used on the search side of E5-style retrieval.
pub fn format_query(query: &str) -> String {
    ["query: ", query].concat()
}

/// Prepends the `passage: ` marker used on the document side of E5-style retrieval.
pub fn format_passage(text: &str) -> String {
    ["passage: ", text].concat()
}
/// A text-embedding backend that can be shared across threads.
pub trait Embedder: Sync + Send {
    /// Embeds `texts`. Callers expect exactly one vector per input, in input order.
    fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
    /// Short human-readable name of the compute device in use (e.g. `"cpu"`).
    fn device(&self) -> &str;
}
/// Totals returned by `EmbedWorker::run`.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct EmbedSummary {
    /// Messages embedded across all batches.
    pub messages: usize,
    /// Backend batches executed.
    pub batches: usize,
    /// Whether the cancel flag was set when the run finished.
    pub cancelled: bool,
}

/// Snapshot handed to the progress callback after each embedded batch.
#[derive(Copy, Clone, Debug)]
pub struct BatchProgress {
    /// Messages in the batch that just finished.
    pub batch_messages: usize,
    /// Running total of embedded messages, including this batch.
    pub total_messages: usize,
    /// Running total of batches, including this one.
    pub total_batches: usize,
}

/// Boxed progress callback stored by the embed worker.
type ProgressFn = Box<dyn Fn(BatchProgress) + Sync + Send>;
/// Embeds messages pending in a [`Store`] in length-sorted batches and writes the vectors
/// back. Built with `EmbedWorker::new`, configured with the builder methods, then driven by
/// `EmbedWorker::run`.
pub struct EmbedWorker<'a, B: Embedder> {
    /// Source of pending messages and sink for the resulting embeddings.
    store: &'a Store,
    /// Backend that turns texts into vectors.
    backend: &'a B,
    /// When true, pull from `Store::pending_or_stale_messages` instead of
    /// `Store::pending_embedding_messages`.
    include_stale: bool,
    /// Maximum number of messages to pull in one run, if any.
    limit: Option<usize>,
    /// Number of messages buffered and sorted by text length before embedding.
    sort_window: usize,
    /// Optional callback invoked after each embedded batch.
    progress: Option<ProgressFn>,
    /// Optional shared flag that requests the run to stop early.
    cancel: Option<Arc<AtomicBool>>,
}
impl<'a, B: Embedder> EmbedWorker<'a, B> {
    /// Creates a worker that reads from `store` and embeds with `backend`.
    ///
    /// By default it pulls from `Store::pending_embedding_messages`, has no message limit,
    /// uses a [`DEFAULT_SORT_WINDOW`]-message sort window, and has no progress callback or
    /// cancel flag.
    pub fn new(store: &'a Store, backend: &'a B) -> Self {
        Self {
            store,
            backend,
            include_stale: false,
            limit: None,
            sort_window: DEFAULT_SORT_WINDOW,
            progress: None,
            cancel: None,
        }
    }

    /// Registers a shared flag. Once it is set, the run stops pulling messages and stops
    /// embedding further batches, and the summary reports `cancelled = true`.
    #[must_use]
    pub fn with_cancel(mut self, flag: Arc<AtomicBool>) -> Self {
        self.cancel = Some(flag);
        self
    }

    /// Whether the registered cancel flag (if any) has been set.
    fn cancelled(&self) -> bool {
        // Relaxed suffices: the flag carries no data, it only requests an early stop.
        self.cancel
            .as_ref()
            .is_some_and(|flag| flag.load(Ordering::Relaxed))
    }

    /// Sets how many messages are buffered and sorted by text length before embedding.
    /// Values below [`DEFAULT_BATCH_SIZE`] are raised to it.
    #[must_use]
    pub fn with_sort_window(mut self, window: usize) -> Self {
        self.sort_window = window.max(DEFAULT_BATCH_SIZE);
        self
    }

    /// Registers a callback invoked after every embedded batch. It runs before that
    /// window's vectors are written to the store.
    #[must_use]
    pub fn with_progress(
        mut self,
        callback: impl Fn(BatchProgress) + Send + Sync + 'static,
    ) -> Self {
        self.progress = Some(Box::new(callback));
        self
    }

    /// Caps the number of messages pulled from the store; a limit of 0 is treated as 1.
    #[must_use]
    pub fn with_limit(mut self, limit: usize) -> Self {
        self.limit = Some(limit.max(1));
        self
    }

    /// Pulls from `Store::pending_or_stale_messages` instead of
    /// `Store::pending_embedding_messages`.
    #[must_use]
    pub fn include_stale(mut self) -> Self {
        self.include_stale = true;
        self
    }

    /// Streams messages from the store, embeds them in length-sorted batches, and writes
    /// the vectors back.
    ///
    /// Messages are buffered into windows of `sort_window` entries and sorted by text length,
    /// so similar-length texts share a batch. They are then embedded [`DEFAULT_BATCH_SIZE`] at
    /// a time, and each window's vectors are written in one store call.
    ///
    /// # Errors
    ///
    /// Propagates stream, backend and store errors. Vectors embedded for the failing window
    /// but not yet written are discarded.
    pub async fn run(&self) -> Result<EmbedSummary> {
        type MessageStream<'s> =
            std::pin::Pin<Box<dyn tokio_stream::Stream<Item = Result<PendingMessage>> + 's>>;

        let mut summary = EmbedSummary::default();
        let mut window: Vec<PendingMessage> = Vec::with_capacity(self.sort_window);
        let mut pulled = 0usize;
        let mut stream = if self.include_stale {
            Box::pin(self.store.pending_or_stale_messages()) as MessageStream<'_>
        } else {
            Box::pin(self.store.pending_embedding_messages()) as MessageStream<'_>
        };
        loop {
            // Check the limit and cancel flag *before* polling, so a finished or cancelled run
            // never waits on (and then discards) one more row from the store.
            if self.limit.is_some_and(|limit| pulled >= limit) || self.cancelled() {
                break;
            }
            let Some(pending) = stream.next().await else {
                break;
            };
            window.push(pending?);
            pulled += 1;
            if window.len() >= self.sort_window {
                self.drain_window(&mut window, &mut summary).await?;
            }
        }
        // Flush the final partial window; `drain_window` skips batches once cancelled.
        self.drain_window(&mut window, &mut summary).await?;
        summary.cancelled = self.cancelled();
        tracing::info!(
            model = model_id(),
            messages = summary.messages,
            batches = summary.batches,
            cancelled = summary.cancelled,
            "embed worker finished",
        );
        Ok(summary)
    }

    /// Sorts `window` by text length, embeds it batch by batch, and writes the vectors.
    ///
    /// The cancel flag is checked before each batch. When it is set, the rest of the window
    /// is dropped unembedded (presumably left pending in the store for a later run —
    /// confirm), and whatever was already embedded is still written. `window` is always
    /// left empty.
    async fn drain_window(
        &self,
        window: &mut Vec<PendingMessage>,
        summary: &mut EmbedSummary,
    ) -> Result<()> {
        if window.is_empty() {
            return Ok(());
        }
        window.sort_unstable_by_key(|message| message.search_text.len());
        let mut accumulator: Vec<EmbeddedMessage> = Vec::with_capacity(window.len());
        let mut remaining = window.drain(..);
        while !self.cancelled() {
            let mut batch: Vec<PendingMessage> =
                remaining.by_ref().take(DEFAULT_BATCH_SIZE).collect();
            if batch.is_empty() {
                break;
            }
            accumulator.extend(self.embed_batch(&mut batch, summary).await?);
        }
        // Dropping the drain removes any messages skipped because of cancellation.
        drop(remaining);
        if !accumulator.is_empty() {
            self.store.write_embeddings(&accumulator).await?;
        }
        Ok(())
    }

    /// Embeds and empties `batch`, updates `summary`, and reports progress.
    ///
    /// Returns one [`EmbeddedMessage`] per input message, in input order.
    ///
    /// # Errors
    ///
    /// Fails if the backend errors or returns a different number of vectors than messages.
    async fn embed_batch(
        &self,
        batch: &mut Vec<PendingMessage>,
        summary: &mut EmbedSummary,
    ) -> Result<Vec<EmbeddedMessage>> {
        if batch.is_empty() {
            return Ok(Vec::new());
        }
        let pending = std::mem::take(batch);
        let texts = pending
            .iter()
            .map(|message| format_passage(&message.search_text))
            .collect::<Vec<_>>();
        // NOTE(review): `Embedder::embed` is synchronous. With a compute-heavy backend it
        // blocks this executor thread for the whole batch; consider
        // `tokio::task::block_in_place` (multi-thread runtime only) if that becomes a problem.
        let vectors = self.backend.embed(&texts)?;
        if vectors.len() != pending.len() {
            return Err(anyhow!(
                "backend returned {} vectors for {} messages",
                vectors.len(),
                pending.len(),
            ));
        }
        let rows = pending
            .into_iter()
            .zip(vectors)
            .map(|(message, vector)| EmbeddedMessage {
                session_id: message.session_id,
                id: message.id,
                vector,
            })
            .collect::<Vec<_>>();
        let batch_messages = rows.len();
        summary.messages += batch_messages;
        summary.batches += 1;
        if let Some(progress) = &self.progress {
            progress(BatchProgress {
                batch_messages,
                total_messages: summary.messages,
                total_batches: summary.batches,
            });
        }
        Ok(rows)
    }
}
#[cfg(test)]
// Unwrapping is the simplest way for a test to fail loudly on an unexpected error.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering as AtomicOrdering};

    #[test]
    fn e5_prefixes_apply_the_asymmetric_retrieval_pair() {
        assert_eq!(
            format_query("how does retry backoff work"),
            "query: how does retry backoff work",
        );
        assert_eq!(
            format_passage("retry uses exponential backoff"),
            "passage: retry uses exponential backoff",
        );
    }

    /// Stub backend that returns no vectors; used only to observe load/evict behaviour.
    struct CountingEmbedder;

    impl Embedder for CountingEmbedder {
        fn device(&self) -> &str {
            "test"
        }

        fn embed(&self, _texts: &[String]) -> Result<Vec<Vec<f32>>> {
            Ok(vec![])
        }
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn lazy_embedder_evicts_after_idle_threshold() {
        let loads = Arc::new(AtomicUsize::new(0));
        let counter = Arc::clone(&loads);
        let loader: EmbedLoader = Arc::new(move || {
            counter.fetch_add(1, AtomicOrdering::SeqCst);
            Ok(Arc::new(CountingEmbedder) as Arc<dyn Embedder>)
        });
        // The threshold leaves headroom for `spawn_blocking` scheduling between the first two
        // calls on a loaded machine; the sleep comfortably exceeds it.
        let embedder =
            LazyEmbedder::with_loader(loader).with_idle_threshold(Duration::from_millis(100));
        embedder.get().await.unwrap();
        assert_eq!(
            loads.load(AtomicOrdering::SeqCst),
            1,
            "first get loads once"
        );
        embedder.get().await.unwrap();
        assert_eq!(
            loads.load(AtomicOrdering::SeqCst),
            1,
            "back-to-back get reuses the cached backend",
        );
        tokio::time::sleep(Duration::from_millis(300)).await;
        embedder.get().await.unwrap();
        assert_eq!(
            loads.load(AtomicOrdering::SeqCst),
            2,
            "get after the idle threshold triggers a reload",
        );
    }

    #[tokio::test(flavor = "multi_thread")]
    async fn lazy_embedder_from_loaded_never_evicts() {
        let backend: Arc<dyn Embedder> = Arc::new(CountingEmbedder);
        let preloaded = LazyEmbedder::from_loaded(Arc::clone(&backend));
        assert_eq!(
            preloaded.idle_threshold,
            Duration::MAX,
            "preloaded embedders opt out of idle eviction",
        );
        let first = preloaded.get().await.unwrap();
        tokio::time::sleep(Duration::from_millis(60)).await;
        let second = preloaded.get().await.unwrap();
        assert!(Arc::ptr_eq(&first, &backend), "first get returns the preloaded backend");
        assert!(Arc::ptr_eq(&second, &backend), "later get still returns the preloaded backend");
    }
}