use std::path::PathBuf;
use std::sync::{Condvar, Mutex};
use std::time::{Duration, Instant};
use indexmap::IndexMap;
use super::catalog::EmbeddingModel;
use super::fastembed::FastembedProvider;
use super::types::text_preview;
use super::types::{EmbeddingProvider, EmbeddingResult};
/// Lifecycle state of the lazily loaded embedding model.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ModelState {
/// Nothing is loaded and the model folder was not found in the cache
/// directory when the provider was constructed.
NotLoaded,
/// One thread is currently constructing the provider; other callers of
/// `ensure_model` wait on the condition variable until this state changes.
Loading,
/// No provider is held in memory, but the model folder exists in the cache
/// directory. Also the state a `Ready` model drops back to after idle eviction.
Cached,
/// A provider is loaded and available for embedding requests.
Ready,
/// `ensure_model` refuses to load and returns "model is disabled".
/// NOTE(review): nothing in the code visible here ever sets this state.
Disabled,
/// The last load attempt failed, or a waiter gave up after the load timeout.
/// Carries the error message that `ensure_model` hands back to callers.
Failed(String),
}
/// Tunables for [`LazyFastembedProvider`].
#[derive(Debug, Clone)]
pub struct LazyOpts {
    /// Seconds a `Ready` model may go unused before the next `ensure_model`
    /// call drops it. `0` turns idle eviction off.
    pub idle_timeout_secs: u64,
    /// Maximum seconds a caller waits for a load that another thread has
    /// already started.
    pub load_timeout_secs: u64,
    /// Maximum number of entries kept in the query cache. `0` disables caching.
    pub query_cache_capacity: usize,
}

impl Default for LazyOpts {
    /// Ten-minute idle timeout, five-minute load timeout, 64 cached queries.
    fn default() -> Self {
        const IDLE_TIMEOUT_SECS: u64 = 600;
        const LOAD_TIMEOUT_SECS: u64 = 300;
        const QUERY_CACHE_CAPACITY: usize = 64;

        LazyOpts {
            idle_timeout_secs: IDLE_TIMEOUT_SECS,
            load_timeout_secs: LOAD_TIMEOUT_SECS,
            query_cache_capacity: QUERY_CACHE_CAPACITY,
        }
    }
}
/// Bounded least-recently-used cache from input text to its embedding vector.
pub struct EmbeddingCache {
// Maximum number of entries; `insert` is a no-op when this is 0.
capacity: usize,
// Position in the map encodes recency: index 0 is the next entry to be
// evicted, the last index is the most recently read or written entry.
map: IndexMap<String, Vec<f32>>,
}
impl EmbeddingCache {
    /// Creates an empty cache holding at most `capacity` entries.
    ///
    /// A capacity of `0` yields a cache that never stores anything.
    pub fn new(capacity: usize) -> Self {
        let map = IndexMap::new();
        Self { capacity, map }
    }

    /// Looks up `key`, marking the entry as most recently used on a hit.
    ///
    /// Takes `&mut self` because a hit reorders the underlying map.
    pub fn get(&mut self, key: &str) -> Option<&Vec<f32>> {
        let found_at = self.map.get_index_of(key)?;
        // A hit implies the map is non-empty, so `len() - 1` cannot underflow.
        let newest = self.map.len() - 1;
        self.map.move_index(found_at, newest);
        self.map.get_index(newest).map(|(_, vector)| vector)
    }

    /// Stores `value` under `key` as the most recently used entry.
    ///
    /// An existing entry is overwritten and promoted. A new entry evicts the
    /// least recently used one when the cache is already full.
    pub fn insert(&mut self, key: String, value: Vec<f32>) {
        if self.capacity == 0 {
            return;
        }
        match self.map.get_index_of(&key) {
            Some(found_at) => {
                let newest = self.map.len() - 1;
                self.map.move_index(found_at, newest);
                self.map[newest] = value;
            }
            None => {
                // Index 0 is always the least recently used entry.
                if self.map.len() >= self.capacity {
                    self.map.shift_remove_index(0);
                }
                self.map.insert(key, value);
            }
        }
    }
}
// Mutable provider state; only accessed through the `inner` mutex of
// `LazyFastembedProvider`.
struct Inner {
// Which embedding model to load; also answers `dim()` and `name()`.
model: EmbeddingModel,
// Directory handed to `FastembedProvider::new` and probed by `is_model_cached`.
cache_dir: PathBuf,
// Current lifecycle state; see `ModelState`.
state: ModelState,
// The loaded provider. Set when a load succeeds, cleared on idle eviction.
provider: Option<FastembedProvider>,
// When the provider was last loaded or used for an embedding request;
// compared against `idle_timeout_secs` in `ensure_model`.
last_used: Option<Instant>,
}
/// Embedding provider that defers constructing the underlying
/// `FastembedProvider` until the first request and caches query embeddings.
pub struct LazyFastembedProvider {
// Model identity, lifecycle state and the loaded provider.
inner: Mutex<Inner>,
// Notified on every state transition so threads waiting in `ensure_model`
// can observe the end of a load.
cvar: Condvar,
// Idle/load timeouts and the query cache capacity.
opts: LazyOpts,
// Text -> vector cache, guarded by its own mutex so cache hits never
// touch `inner`.
query_cache: Mutex<EmbeddingCache>,
}
impl LazyFastembedProvider {
pub fn new(model: EmbeddingModel, cache_dir: PathBuf, opts: LazyOpts) -> Self {
let initial_state = if is_model_cached(model, &cache_dir) {
ModelState::Cached
} else {
ModelState::NotLoaded
};
Self {
inner: Mutex::new(Inner {
model,
cache_dir,
state: initial_state,
provider: None,
last_used: None,
}),
cvar: Condvar::new(),
opts: opts.clone(),
query_cache: Mutex::new(EmbeddingCache::new(opts.query_cache_capacity)),
}
}
pub fn state(&self) -> ModelState {
let guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
guard.state.clone()
}
pub fn is_ready(&self) -> bool {
let guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
matches!(guard.state, ModelState::Ready)
}
pub fn ensure_model(&self) -> Result<(), String> {
let mut guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
if matches!(guard.state, ModelState::Ready)
&& let Some(last) = guard.last_used
&& self.opts.idle_timeout_secs > 0
&& last.elapsed().as_secs() > self.opts.idle_timeout_secs
{
guard.provider = None;
guard.state = ModelState::Cached;
guard.last_used = None;
}
match &guard.state {
ModelState::Ready => Ok(()),
ModelState::Disabled => Err("model is disabled".into()),
ModelState::Failed(e) => Err(e.clone()),
ModelState::Loading => {
let timeout = Duration::from_secs(self.opts.load_timeout_secs);
let result = self
.cvar
.wait_timeout_while(guard, timeout, |g| matches!(g.state, ModelState::Loading));
match result {
Ok((mut g, timeout_result)) => {
if timeout_result.timed_out() {
if matches!(g.state, ModelState::Loading) {
g.state = ModelState::Failed("model loading timed out".into());
self.cvar.notify_all();
}
return Err("model loading timed out".into());
}
guard = g;
}
Err(e) => {
(guard, _) = e.into_inner();
}
}
match &guard.state {
ModelState::Ready => Ok(()),
ModelState::Failed(e) => Err(e.clone()),
other => Err(format!("unexpected state after wait: {other:?}")),
}
}
ModelState::NotLoaded | ModelState::Cached => {
guard.state = ModelState::Loading;
let model = guard.model;
let cache_dir = guard.cache_dir.clone();
self.cvar.notify_all();
drop(guard);
let result = FastembedProvider::new(model, Some(cache_dir));
let mut guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
match result {
Ok(provider) => {
guard.state = ModelState::Ready;
guard.provider = Some(provider);
guard.last_used = Some(Instant::now());
self.cvar.notify_all();
Ok(())
}
Err(e) => {
guard.state = ModelState::Failed(e.to_string());
self.cvar.notify_all();
Err(e.to_string())
}
}
}
}
}
}
impl EmbeddingProvider for LazyFastembedProvider {
fn dim(&self) -> usize {
let guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
guard.model.dimension()
}
fn name(&self) -> &str {
let guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
guard.model.as_str()
}
fn embed(&self, text: &str) -> anyhow::Result<EmbeddingResult> {
{
let mut cache = self.query_cache.lock().unwrap_or_else(|e| e.into_inner());
if let Some(cached) = cache.get(text) {
return Ok(EmbeddingResult {
vector: cached.clone(),
text_preview: text_preview(text),
});
}
}
self.ensure_model().map_err(|e| anyhow::anyhow!("{e}"))?;
let result = {
let mut guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
guard.last_used = Some(Instant::now());
match &guard.provider {
Some(p) => p.embed(text),
None => Err(anyhow::anyhow!("provider not available")),
}
}?;
{
let mut cache = self.query_cache.lock().unwrap_or_else(|e| e.into_inner());
cache.insert(text.to_string(), result.vector.clone());
}
Ok(result)
}
fn embed_batch(&self, texts: &[&str]) -> anyhow::Result<Vec<EmbeddingResult>> {
if texts.is_empty() {
return Ok(vec![]);
}
let mut results: Vec<Option<Vec<f32>>> = vec![None; texts.len()];
let mut miss_indices: Vec<usize> = Vec::with_capacity(texts.len());
let mut miss_texts: Vec<&str> = Vec::with_capacity(texts.len());
{
let mut cache = self.query_cache.lock().unwrap_or_else(|e| e.into_inner());
for (i, text) in texts.iter().enumerate() {
if let Some(cached) = cache.get(text) {
results[i] = Some(cached.clone());
} else {
miss_indices.push(i);
miss_texts.push(*text);
}
}
}
if miss_indices.is_empty() {
return Ok(results
.into_iter()
.zip(texts.iter())
.map(|(opt, text)| EmbeddingResult {
vector: opt.unwrap(),
text_preview: text_preview(text),
})
.collect());
}
self.ensure_model().map_err(|e| anyhow::anyhow!("{e}"))?;
let batch_results = {
let mut guard = self.inner.lock().unwrap_or_else(|e| e.into_inner());
guard.last_used = Some(Instant::now());
match &guard.provider {
Some(p) => p.embed_batch(&miss_texts),
None => Err(anyhow::anyhow!("provider not available")),
}
}?;
{
let mut cache = self.query_cache.lock().unwrap_or_else(|e| e.into_inner());
for (batch_idx, &result_idx) in miss_indices.iter().enumerate() {
let vector = batch_results[batch_idx].vector.clone();
cache.insert(texts[result_idx].to_string(), vector.clone());
results[result_idx] = Some(vector);
}
}
Ok(results
.into_iter()
.zip(texts.iter())
.map(|(opt, text)| EmbeddingResult {
vector: opt.unwrap(),
text_preview: text_preview(text),
})
.collect())
}
}
/// Reports whether a folder for `model` exists under `cache_dir`.
///
/// The model code is expected to look like `org/repo`; the folder probed is
/// `models--{org}--{repo}`. Returns `false` when the code has no `/` or when
/// either side of the first `/` is empty.
///
/// Only the folder's existence is checked, not its contents.
pub fn is_model_cached(model: EmbeddingModel, cache_dir: &std::path::Path) -> bool {
    let model_code = model.model_code();
    let Some((org, repo)) = model_code.split_once('/') else {
        return false;
    };
    if org.is_empty() || repo.is_empty() {
        return false;
    }
    cache_dir.join(format!("models--{org}--{repo}")).is_dir()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates the `models--{org}--{repo}` folder for BGESmallENV15 under `root`.
    fn seed_model_folder(root: &std::path::Path) {
        let code = EmbeddingModel::BGESmallENV15.model_code();
        let (org, repo) = code
            .split_once('/')
            .expect("model code has the form org/repo");
        std::fs::create_dir_all(root.join(format!("models--{org}--{repo}"))).unwrap();
    }

    /// Builds a provider for BGESmallENV15 rooted at `root` with default options.
    fn provider_in(root: &std::path::Path) -> LazyFastembedProvider {
        LazyFastembedProvider::new(
            EmbeddingModel::BGESmallENV15,
            root.to_path_buf(),
            LazyOpts::default(),
        )
    }

    #[test]
    fn model_state_debug() {
        let expected = [
            (ModelState::NotLoaded, "NotLoaded"),
            (ModelState::Ready, "Ready"),
        ];
        for (state, text) in expected {
            assert_eq!(format!("{:?}", state), text);
        }
    }

    #[test]
    fn lazy_opts_default() {
        let LazyOpts {
            idle_timeout_secs,
            load_timeout_secs,
            query_cache_capacity,
        } = LazyOpts::default();
        assert_eq!(idle_timeout_secs, 600);
        assert_eq!(load_timeout_secs, 300);
        assert_eq!(query_cache_capacity, 64);
    }

    #[test]
    fn cache_hit_miss_eviction() {
        let mut cache = EmbeddingCache::new(2);
        assert!(cache.get("a").is_none());

        cache.insert("a".into(), vec![1.0]);
        assert!(cache.get("a").is_some());
        assert!(cache.get("b").is_none());

        cache.insert("b".into(), vec![2.0]);
        // Reading "a" makes "b" the least recently used entry.
        assert!(cache.get("a").is_some());

        cache.insert("c".into(), vec![3.0]);
        assert!(cache.get("b").is_none());
        assert!(cache.get("a").is_some());
        assert!(cache.get("c").is_some());
    }

    #[test]
    fn cache_zero_capacity() {
        let mut cache = EmbeddingCache::new(0);
        cache.insert("a".into(), vec![1.0]);
        assert!(cache.get("a").is_none());
    }

    #[test]
    fn cache_update_existing() {
        let mut cache = EmbeddingCache::new(2);
        for value in [1.0, 2.0] {
            cache.insert("a".into(), vec![value]);
        }
        assert_eq!(cache.get("a").unwrap(), &vec![2.0]);
    }

    #[test]
    fn is_model_cached_nonexistent() {
        let dir = tempfile::tempdir().unwrap();
        assert!(!is_model_cached(EmbeddingModel::BGESmallENV15, dir.path()));
    }

    #[test]
    fn is_model_cached_existing() {
        let dir = tempfile::tempdir().unwrap();
        seed_model_folder(dir.path());
        assert!(is_model_cached(EmbeddingModel::BGESmallENV15, dir.path()));
    }

    #[test]
    fn constructor_is_instant_no_model() {
        let dir = tempfile::tempdir().unwrap();
        let provider = provider_in(dir.path());
        assert_eq!(provider.state(), ModelState::NotLoaded);
    }

    #[test]
    fn constructor_detects_cached() {
        let dir = tempfile::tempdir().unwrap();
        seed_model_folder(dir.path());
        let provider = provider_in(dir.path());
        assert_eq!(provider.state(), ModelState::Cached);
    }
}