use std::num::NonZeroUsize;
use std::sync::Arc;
use anyhow::{Context, Result};
use async_trait::async_trait;
use fastembed::{EmbeddingModel, TextEmbedding, TextInitOptions};
use lru::LruCache;
use parking_lot::Mutex;
/// Embedding vector length (number of `f32` components) that [`FastEmbedder`] is configured for.
pub const EMBED_DIM: usize = 384;
/// Number of cache entries used when [`FastEmbedder::new`] delegates to
/// [`FastEmbedder::with_cache_size`].
pub const DEFAULT_CACHE_CAPACITY: usize = 256;
/// Common interface for anything that turns text into embedding vectors.
///
/// The `Send + Sync` supertraits allow an implementation to be shared between
/// threads and async tasks.
#[async_trait]
pub trait Embedder: Send + Sync {
/// Embeds every string in `texts`.
///
/// The implementations in this file return one vector per input, in input
/// order, and an empty vector for an empty slice.
///
/// # Errors
///
/// Returns an error when the underlying backend fails to produce embeddings.
async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>>;
/// Dimension this embedder reports for its vectors.
fn dimension(&self) -> usize;
}
/// Embeds a single piece of text through any [`Embedder`].
///
/// The text is wrapped in a one-element batch and handed to
/// [`Embedder::embed_batch`]; the last vector of the returned batch is the
/// result.
///
/// # Errors
///
/// Propagates any error from `embed_batch`, and fails if the embedder returns
/// an empty batch.
pub async fn embed_one(embedder: &dyn Embedder, text: &str) -> Result<Vec<f32>> {
    let batch = [text.to_string()];
    let embeddings = embedder.embed_batch(&batch).await?;
    embeddings
        .into_iter()
        .next_back()
        .context("embedder returned no embedding for non-empty input")
}
/// [`Embedder`] backed by a fastembed [`TextEmbedding`] model, with an LRU
/// cache of previously computed embeddings keyed by the input text.
pub struct FastEmbedder {
/// The fastembed model; the mutex is locked inside `spawn_blocking` for each embed call.
model: Arc<Mutex<TextEmbedding>>,
/// LRU cache mapping an input string to the embedding computed for it.
cache: Arc<Mutex<LruCache<String, Vec<f32>>>>,
/// Value returned by `dimension()`.
dim: usize,
}
impl FastEmbedder {
    /// Builds an embedder with a cache of [`DEFAULT_CACHE_CAPACITY`] entries.
    ///
    /// # Errors
    ///
    /// See [`FastEmbedder::with_cache_size`].
    pub async fn new() -> Result<Self> {
        Self::with_cache_size(DEFAULT_CACHE_CAPACITY).await
    }

    /// Builds an embedder whose LRU cache holds at most `capacity` entries.
    ///
    /// A `capacity` of zero is raised to one. Model initialisation and a small
    /// warm-up batch run on tokio's blocking pool. The quantised
    /// `AllMiniLML6V2Q` model is tried first, with `AllMiniLML6V2` as the
    /// fallback.
    ///
    /// The reported dimension is taken from the warm-up output, so
    /// `dimension()` matches what the loaded model actually produces. It falls
    /// back to [`EMBED_DIM`] only if the warm-up batch comes back empty.
    ///
    /// # Errors
    ///
    /// Fails when both models fail to initialise, when the warm-up batch
    /// fails, or when the blocking task cannot be joined.
    pub async fn with_cache_size(capacity: usize) -> Result<Self> {
        let capacity =
            NonZeroUsize::new(capacity.max(1)).expect("capacity.max(1) is always non-zero");
        let (model, dim) = tokio::task::spawn_blocking(|| -> Result<(TextEmbedding, usize)> {
            let mut m =
                TextEmbedding::try_new(TextInitOptions::new(EmbeddingModel::AllMiniLML6V2Q))
                    .or_else(|q_err| {
                        tracing::warn!(
                            "AllMiniLML6V2Q init failed ({q_err:#}), falling back to AllMiniLML6V2"
                        );
                        TextEmbedding::try_new(TextInitOptions::new(EmbeddingModel::AllMiniLML6V2))
                    })
                    .context(
                        "failed to initialise fastembed (tried AllMiniLML6V2Q and AllMiniLML6V2)",
                    )?;
            let warmup: Vec<&str> = vec![
                "hello world",
                "the quick brown fox",
                "memory palace warmup",
                "embedding model ready",
                "trusty common warmup",
            ];
            let warmed = m
                .embed(warmup, None)
                .context("fastembed warmup batch failed")?;
            // Use the model's real output width instead of trusting the constant blindly.
            let dim = warmed.first().map_or(EMBED_DIM, |embedding| embedding.len());
            if dim != EMBED_DIM {
                tracing::warn!(
                    "fastembed warmup produced embeddings of dimension {}, expected {}",
                    dim,
                    EMBED_DIM
                );
            }
            Ok((m, dim))
        })
        .await
        .context("spawn_blocking joined with error during embedder init")??;
        Ok(Self {
            model: Arc::new(Mutex::new(model)),
            cache: Arc::new(Mutex::new(LruCache::new(capacity))),
            dim,
        })
    }
}
#[async_trait]
impl Embedder for FastEmbedder {
    /// Embeds `texts`, serving repeats from the LRU cache.
    ///
    /// Returns one vector per input, in input order. Texts already in the
    /// cache are copied from it. The remaining texts are embedded in a single
    /// model call on tokio's blocking pool and then stored in the cache. A
    /// text that occurs several times in the same batch is embedded only once
    /// and its vector is copied to every position that asked for it.
    ///
    /// # Errors
    ///
    /// Fails when the model call fails, when the blocking task cannot be
    /// joined, or when the model returns a different number of embeddings
    /// than were requested.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        if texts.is_empty() {
            return Ok(Vec::new());
        }
        let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
        // Distinct cache misses in first-seen order, each paired with every
        // output slot that needs its embedding.
        let mut misses: Vec<(String, Vec<usize>)> = Vec::new();
        {
            // Maps a missed text to its position in `misses`.
            let mut first_seen: std::collections::HashMap<&str, usize> =
                std::collections::HashMap::new();
            // The guard lives only in this block so it is released before the `.await` below.
            let mut cache = self.cache.lock();
            for (slot, text) in texts.iter().enumerate() {
                let earlier_miss = first_seen.get(text.as_str()).copied();
                if let Some(miss_idx) = earlier_miss {
                    misses[miss_idx].1.push(slot);
                    continue;
                }
                if let Some(hit) = cache.get(text) {
                    results[slot] = Some(hit.clone());
                    continue;
                }
                first_seen.insert(text.as_str(), misses.len());
                misses.push((text.clone(), vec![slot]));
            }
        }
        if !misses.is_empty() {
            let model = Arc::clone(&self.model);
            let owned: Vec<String> = misses.iter().map(|(text, _)| text.clone()).collect();
            // Inference is CPU-bound, so it runs off the async executor threads.
            let computed = tokio::task::spawn_blocking(move || -> Result<Vec<Vec<f32>>> {
                let mut guard = model.lock();
                guard
                    .embed(owned, None)
                    .context("fastembed embed call failed")
            })
            .await
            .context("spawn_blocking joined with error during embed")??;
            if computed.len() != misses.len() {
                anyhow::bail!(
                    "fastembed returned {} embeddings, expected {}",
                    computed.len(),
                    misses.len()
                );
            }
            let mut cache = self.cache.lock();
            for ((key, slots), vector) in misses.into_iter().zip(computed) {
                for &slot in &slots {
                    results[slot] = Some(vector.clone());
                }
                cache.put(key, vector);
            }
        }
        results
            .into_iter()
            .map(|opt| opt.context("missing embedding slot after batch"))
            .collect()
    }

    /// Returns the dimension stored when this embedder was constructed.
    fn dimension(&self) -> usize {
        self.dim
    }
}
/// Deterministic, model-free embedder for tests.
///
/// Vectors are derived purely from the bytes of the input text, so equal
/// inputs always yield equal vectors.
#[cfg(any(test, feature = "test-support"))]
pub struct MockEmbedder {
    /// Length of every vector this mock produces.
    dim: usize,
}

#[cfg(any(test, feature = "test-support"))]
impl MockEmbedder {
    /// Creates a mock that produces vectors of length `dim`.
    pub fn new(dim: usize) -> Self {
        Self { dim }
    }

    /// Maps `text` to a vector of length `self.dim`.
    ///
    /// Each byte `b` at index `i` adds `b / 255` to slot `(i + b) % dim`, and
    /// the first byte is additionally added to slot 0. An empty `text` yields
    /// an all-zero vector. A zero `dim` yields an empty vector rather than
    /// panicking on the modulo or the slot-0 write.
    fn hash_to_vec(&self, text: &str) -> Vec<f32> {
        let mut v = vec![0.0_f32; self.dim];
        if self.dim == 0 {
            // No slots to write to; `% 0` and `v[0]` below would both panic.
            return v;
        }
        for (i, b) in text.bytes().enumerate() {
            let slot = (i + b as usize) % self.dim;
            v[slot] += (b as f32) / 255.0;
        }
        if let Some(first) = text.bytes().next() {
            v[0] += first as f32 / 255.0;
        }
        v
    }
}
#[cfg(any(test, feature = "test-support"))]
#[async_trait]
impl Embedder for MockEmbedder {
    /// Hashes each text into a vector, preserving input order. Never fails.
    async fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        let mut vectors = Vec::with_capacity(texts.len());
        for text in texts {
            vectors.push(self.hash_to_vec(text));
        }
        Ok(vectors)
    }

    /// Returns the dimension this mock was constructed with.
    fn dimension(&self) -> usize {
        self.dim
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The mock reports its dimension, produces vectors of that length, and
    /// gives different vectors for different texts.
    #[tokio::test]
    async fn mock_embedder_round_trip() {
        let mock = MockEmbedder::new(EMBED_DIM);
        assert_eq!(mock.dimension(), EMBED_DIM);

        let single = embed_one(&mock, "hello").await.unwrap();
        assert_eq!(single.len(), EMBED_DIM);

        let inputs = ["a".to_string(), "b".to_string()];
        let pair = mock.embed_batch(&inputs).await.unwrap();
        assert_eq!(pair.len(), 2);
        assert_ne!(pair[0], pair[1]);
    }

    /// An empty batch yields an empty result.
    #[tokio::test]
    async fn mock_embedder_empty_input_returns_empty() {
        let mock = MockEmbedder::new(EMBED_DIM);
        let out = mock.embed_batch(&[]).await.unwrap();
        assert!(out.is_empty());
    }

    /// Needs the real fastembed model, hence ignored by default.
    #[tokio::test]
    #[ignore]
    async fn fastembed_returns_correct_dim() {
        let embedder = FastEmbedder::new().await.unwrap();
        assert_eq!(embedder.dimension(), 384);

        let snippet = "fn authenticate(user: &str) -> bool";
        let vector = embed_one(&embedder, snippet).await.unwrap();
        assert_eq!(vector.len(), 384);
        // At least one component must be non-zero.
        assert!(!vector.iter().all(|x| *x == 0.0));
    }

    /// Embedding the same text twice gives identical vectors. Needs the real
    /// fastembed model, hence ignored by default.
    #[tokio::test]
    #[ignore]
    async fn fastembed_cache_hit_is_idempotent() {
        let embedder = FastEmbedder::new().await.unwrap();
        let first = embed_one(&embedder, "cached").await.unwrap();
        let second = embed_one(&embedder, "cached").await.unwrap();
        assert_eq!(first, second);
    }
}