use anyhow::{anyhow, Result};
use fastembed::{InitOptions, TextEmbedding};
use parking_lot::Mutex;
use std::path::PathBuf;
use std::sync::Arc;
pub use fastembed::EmbeddingModel as Model;
pub type Embedding = Vec<f32>;
#[derive(Clone)]
pub struct EmbeddingEngine {
model: Model,
cache_dir: Option<PathBuf>,
inner: Arc<Mutex<Option<TextEmbedding>>>,
}
impl EmbeddingEngine {
pub fn new(model: Model, cache_dir: Option<PathBuf>) -> Result<Self> {
Ok(Self {
model,
cache_dir,
inner: Arc::new(Mutex::new(None)),
})
}
#[allow(dead_code)]
pub fn is_loaded(&self) -> bool {
self.inner.lock().is_some()
}
fn with_model<R>(
&self,
f: impl FnOnce(&mut TextEmbedding) -> Result<R>,
) -> Result<R> {
let mut guard = self.inner.lock();
if guard.is_none() {
let options = {
let opts = InitOptions::new(self.model.clone());
match &self.cache_dir {
Some(dir) => opts.with_cache_dir(dir.clone()),
None => opts,
}
};
let model = TextEmbedding::try_new(options).map_err(|e| {
anyhow!("failed to initialise embedding model: {e}")
})?;
*guard = Some(model);
}
let model = guard
.as_mut()
.expect("embedding model set in the block immediately above");
f(model)
}
pub fn embed(&self, text: &str) -> Result<Embedding> {
self.with_model(|model| {
model
.embed(vec![text], None)
.map_err(|e| anyhow!("embedding failed: {e}"))?
.into_iter()
.next()
.ok_or_else(|| anyhow!("model returned no embedding"))
})
}
pub fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Embedding>> {
if texts.is_empty() {
return Ok(vec![]);
}
self.with_model(|model| {
model
.embed(texts.to_vec(), None)
.map_err(|e| anyhow!("batch embedding failed: {e}"))
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an engine straight from `new()`, before any embedding call.
    fn fresh_engine() -> EmbeddingEngine {
        EmbeddingEngine::new(Model::MultilingualE5Small, None).unwrap()
    }

    /// Construction alone must not pull the ONNX model into memory.
    #[test]
    fn new_does_not_load_the_model() {
        let engine = fresh_engine();
        let resident = engine.is_loaded();
        assert!(
            !resident,
            "new() must defer the ONNX model load (I.1.4) — \
             the engine reported the model as already resident",
        );
    }

    /// An empty batch is answered without ever touching the model.
    #[test]
    fn empty_batch_does_not_load_the_model() {
        let engine = fresh_engine();
        let vectors = engine.embed_batch(&[]).unwrap();
        assert!(vectors.is_empty());
        let resident = engine.is_loaded();
        assert!(
            !resident,
            "embed_batch([]) must short-circuit before the model load",
        );
    }
}