use parking_lot::Mutex;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
use crate::cache::{EmbeddingCache, EmbeddingCacheConfig};
/// Version tag of the built-in embedding model.
///
/// Stored as `FastEmbedInner::model_version` and passed, together with the
/// input text, to every embedding-cache lookup and insert in this file.
const BUILTIN_MODEL_VERSION: &str = "all-MiniLM-L6-v2";
/// State shared by every clone of [`FastEmbedder`] (held behind an `Arc`).
struct FastEmbedInner {
    /// The loaded model; locked for the duration of each `embed` call on it.
    model: Mutex<TextEmbedding>,
    /// Embedding dimensionality reported by `Embedder::dim`.
    dim: usize,
    /// Cache consulted before the model is invoked and filled after a miss.
    cache: EmbeddingCache,
    /// Model version passed to the cache alongside the text on get/put.
    model_version: &'static str,
    /// Number of cache lookups that returned a stored vector.
    cache_hits: AtomicU64,
    /// Number of cache lookups that found nothing and fell through to the model.
    cache_misses: AtomicU64,
}
/// Handle to the built-in embedder.
///
/// Cloning copies only the `Arc`, so all clones share the same model, cache
/// and hit/miss counters.
#[derive(Clone)]
pub struct FastEmbedder {
    /// Shared model, cache and counters.
    inner: Arc<FastEmbedInner>,
}
impl FastEmbedder {
pub fn new() -> anyhow::Result<Self> {
Self::with_cache_config(EmbeddingCacheConfig::default())
}
pub fn with_cache_config(cache_cfg: EmbeddingCacheConfig) -> anyhow::Result<Self> {
tracing::info!("loading embedding model (all-MiniLM-L6-v2)...");
let model = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
TextEmbedding::try_new(
InitOptions::new(EmbeddingModel::AllMiniLML6V2)
.with_show_download_progress(true),
)
}))
.map_err(|panic_info| {
let msg = panic_info
.downcast_ref::<String>()
.cloned()
.or_else(|| panic_info.downcast_ref::<&str>().map(|s| s.to_string()))
.unwrap_or_else(|| "unknown panic".to_string());
let hint = if msg.contains("dlopen") || msg.contains("libonnxruntime") || msg.contains("onnxruntime.dll") {
"\n\n\
ONNX Runtime library not found. The built-in embedder requires it.\n\
\n\
Linux: wget https://github.com/microsoft/onnxruntime/releases/download/v1.24.4/onnxruntime-linux-x64-1.24.4.tgz\n\
tar xzf onnxruntime-linux-x64-1.24.4.tgz\n\
sudo cp onnxruntime-linux-x64-1.24.4/lib/libonnxruntime*.so* /usr/local/lib/\n\
sudo ldconfig\n\
export ORT_DYLIB_PATH=/usr/local/lib/libonnxruntime.so.1.24.4\n\
\n\
macOS: brew install onnxruntime\n\
\n\
Windows: Download from https://github.com/microsoft/onnxruntime/releases\n\
and place onnxruntime.dll alongside the binary.\n\
\n\
Or use the Docker image (ghcr.io/yantrikos/yantrikdb) which bundles ONNX Runtime.\n\
\n\
Or skip the built-in embedder by setting [embedding] strategy = \"client_only\"\n\
in your config file (you'll need to provide pre-computed embeddings to remember()).\n"
} else {
""
};
anyhow::anyhow!(
"embedder initialization failed: {}{}",
msg,
hint,
)
})??;
tracing::info!(
"embedding model loaded (384 dim, cache_capacity={})",
cache_cfg.max_entries
);
Ok(Self {
inner: Arc::new(FastEmbedInner {
model: Mutex::new(model),
dim: 384,
cache: EmbeddingCache::new(cache_cfg),
model_version: BUILTIN_MODEL_VERSION,
cache_hits: AtomicU64::new(0),
cache_misses: AtomicU64::new(0),
}),
})
}
pub fn model_version(&self) -> &'static str {
self.inner.model_version
}
pub fn cache_hits(&self) -> u64 {
self.inner.cache_hits.load(Ordering::Relaxed)
}
pub fn cache_misses(&self) -> u64 {
self.inner.cache_misses.load(Ordering::Relaxed)
}
pub fn cache_hit_rate(&self) -> f64 {
let h = self.cache_hits();
let m = self.cache_misses();
let total = h + m;
if total == 0 {
0.0
} else {
h as f64 / total as f64
}
}
pub fn boxed(&self) -> Box<dyn yantrikdb::types::Embedder + Send + Sync> {
Box::new(self.clone())
}
}
impl yantrikdb::types::Embedder for FastEmbedder {
    /// Embeds a single text, serving it from the cache when possible.
    ///
    /// A cache hit returns immediately; a miss runs the model under its lock
    /// and stores the result in the cache before returning it.
    ///
    /// # Errors
    ///
    /// Propagates any model error, and fails with
    /// `"empty embedding result"` if the model yields no vector.
    fn embed(
        &self,
        text: &str,
    ) -> std::result::Result<Vec<f32>, Box<dyn std::error::Error + Send + Sync>> {
        let model_ver = self.inner.model_version;

        if let Some(hit) = self.inner.cache.get_for_text(text, model_ver) {
            self.inner.cache_hits.fetch_add(1, Ordering::Relaxed);
            return Ok(hit);
        }
        self.inner.cache_misses.fetch_add(1, Ordering::Relaxed);

        // The inner scope releases the model lock before the cache is touched.
        let computed = {
            let mut model = self.inner.model.lock();
            let mut results = model.embed(vec![text], None)?;
            results
                .pop()
                .ok_or_else(|| -> Box<dyn std::error::Error + Send + Sync> {
                    "empty embedding result".into()
                })?
        };

        self.inner
            .cache
            .put_for_text(text, model_ver, computed.clone());
        Ok(computed)
    }

    /// Embeds several texts, returning one vector per input in input order.
    ///
    /// Cached texts are answered from the cache; the remaining texts are sent
    /// to the model in a single call and their vectors are cached afterwards.
    ///
    /// # Errors
    ///
    /// Propagates any model error. Also fails, instead of panicking, when the
    /// model returns a different number of vectors than texts it was given.
    fn embed_batch(
        &self,
        texts: &[&str],
    ) -> std::result::Result<Vec<Vec<f32>>, Box<dyn std::error::Error + Send + Sync>> {
        let model_ver = self.inner.model_version;
        let mut results: Vec<Option<Vec<f32>>> = Vec::with_capacity(texts.len());
        let mut miss_indices: Vec<usize> = Vec::new();

        // First pass: fill what the cache already knows and remember the gaps.
        for (idx, &t) in texts.iter().enumerate() {
            match self.inner.cache.get_for_text(t, model_ver) {
                Some(v) => {
                    results.push(Some(v));
                    self.inner.cache_hits.fetch_add(1, Ordering::Relaxed);
                }
                None => {
                    results.push(None);
                    miss_indices.push(idx);
                    self.inner.cache_misses.fetch_add(1, Ordering::Relaxed);
                }
            }
        }

        if !miss_indices.is_empty() {
            // Borrow the missing texts straight from the input slice; no owned
            // copies are needed to call the model.
            let miss_texts: Vec<&str> = miss_indices.iter().map(|&idx| texts[idx]).collect();
            let computed = {
                let mut model = self.inner.model.lock();
                model.embed(miss_texts, None)?
            };

            // `zip` below would silently drop unmatched items, so verify the
            // model produced exactly one vector per requested text first.
            if computed.len() != miss_indices.len() {
                return Err(format!(
                    "embedding model returned {} vectors for {} inputs",
                    computed.len(),
                    miss_indices.len()
                )
                .into());
            }

            for (&dest_idx, vec) in miss_indices.iter().zip(computed) {
                self.inner
                    .cache
                    .put_for_text(texts[dest_idx], model_ver, vec.clone());
                results[dest_idx] = Some(vec);
            }
        }

        // Every slot is either a cache hit or was filled above; surface a
        // violation of that invariant as an error rather than a panic.
        results
            .into_iter()
            .collect::<Option<Vec<Vec<f32>>>>()
            .ok_or_else(|| -> Box<dyn std::error::Error + Send + Sync> {
                "embedding batch left a position unpopulated".into()
            })
    }

    /// Returns the embedding dimensionality.
    fn dim(&self) -> usize {
        self.inner.dim
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Placeholder helper: ignores `_entries` and always yields `None`.
    fn embedder_with_seeded_cache(_entries: &[(&str, Vec<f32>)]) -> Option<FastEmbedder> {
        None
    }

    /// Same formula as `FastEmbedder::cache_hit_rate`, restated on plain
    /// integers so it can be checked without constructing an embedder.
    fn hit_rate(hits: u64, misses: u64) -> f64 {
        match hits + misses {
            0 => 0.0,
            total => hits as f64 / total as f64,
        }
    }

    /// The placeholder helper type-checks and reports that no embedder exists.
    #[test]
    fn cache_seeded_helper_compiles() {
        assert!(embedder_with_seeded_cache(&[]).is_none());
    }

    /// The same text under two model versions must map to distinct cache keys.
    #[test]
    fn cache_key_includes_model_version() {
        use crate::cache::embedding::EmbeddingCacheKey;

        let builtin_key = EmbeddingCacheKey::for_text("hi", BUILTIN_MODEL_VERSION);
        let other_key = EmbeddingCacheKey::for_text("hi", "different-model");
        assert_ne!(builtin_key, other_key);
    }

    /// Pins the version string so an accidental rename is caught.
    #[test]
    fn model_version_is_stable_string() {
        assert_eq!(BUILTIN_MODEL_VERSION, "all-MiniLM-L6-v2");
    }

    /// With no hits and no misses the rate is defined as zero.
    #[test]
    fn cache_hit_rate_zero_on_no_calls() {
        assert_eq!(hit_rate(0, 0), 0.0);
    }

    /// Seven hits out of ten lookups gives a rate of 0.7.
    #[test]
    fn cache_hit_rate_correct_for_mixed() {
        let rate = hit_rate(7, 3);
        assert!((rate - 0.7).abs() < 1e-9);
    }
}