use dashmap::DashMap;
use fastembed::{EmbeddingModel as FastEmbedModel, InitOptions, TextEmbedding};
use parking_lot::RwLock;
use super::models::EmbeddingModel;
/// Cache of loaded `fastembed` text-embedding models, keyed by [`EmbeddingModel`],
/// plus a configurable default model selection.
pub struct ModelCache {
    // One loaded model per key. The per-entry `RwLock` is only ever taken for
    // writing in this file, so it effectively serializes calls on the same model.
    models: DashMap<EmbeddingModel, RwLock<TextEmbedding>>,
    // Value returned by `default_model()`; `embed` does not consult it (callers
    // pass the model explicitly).
    default_model: RwLock<EmbeddingModel>,
}
impl ModelCache {
pub fn new() -> Self {
Self {
models: DashMap::new(),
default_model: RwLock::new(EmbeddingModel::default()),
}
}
pub fn embed(&self, model: EmbeddingModel, texts: Vec<&str>) -> Result<Vec<Vec<f32>>, String> {
if let Some(cached) = self.models.get(&model) {
let mut embedding = cached.write();
return embedding
.embed(texts, None)
.map_err(|e| format!("Embedding failed: {}", e));
}
let embedding = self.load_model(model)?;
let mut embedding_model = embedding;
let result = embedding_model
.embed(texts, None)
.map_err(|e| format!("Embedding failed: {}", e));
self.models.insert(model, RwLock::new(embedding_model));
result
}
fn load_model(&self, model: EmbeddingModel) -> Result<TextEmbedding, String> {
let fastembed_model = match model {
EmbeddingModel::AllMiniLmL6V2 => FastEmbedModel::AllMiniLML6V2,
EmbeddingModel::BgeSmallEnV15 => FastEmbedModel::BGESmallENV15,
EmbeddingModel::BgeBaseEnV15 => FastEmbedModel::BGEBaseENV15,
EmbeddingModel::BgeLargeEnV15 => FastEmbedModel::BGELargeENV15,
EmbeddingModel::AllMpnetBaseV2 => FastEmbedModel::AllMiniLML6V2, EmbeddingModel::NomicEmbedTextV15 => FastEmbedModel::NomicEmbedTextV15,
};
let options = InitOptions::new(fastembed_model).with_show_download_progress(false);
TextEmbedding::try_new(options)
.map_err(|e| format!("Failed to load model '{}': {}", model.name(), e))
}
pub fn preload(&self, model: EmbeddingModel) -> Result<(), String> {
if self.models.contains_key(&model) {
return Ok(());
}
let embedding = self.load_model(model)?;
self.models.insert(model, RwLock::new(embedding));
Ok(())
}
pub fn is_loaded(&self, model: EmbeddingModel) -> bool {
self.models.contains_key(&model)
}
pub fn loaded_models(&self) -> Vec<EmbeddingModel> {
self.models.iter().map(|r| *r.key()).collect()
}
pub fn unload(&self, model: EmbeddingModel) -> bool {
self.models.remove(&model).is_some()
}
pub fn clear(&self) {
self.models.clear();
}
pub fn default_model(&self) -> EmbeddingModel {
*self.default_model.read()
}
pub fn set_default_model(&self, model: EmbeddingModel) {
*self.default_model.write() = model;
}
pub fn estimated_memory_usage(&self) -> usize {
self.models
.iter()
.map(|r| r.key().memory_mb() * 1024 * 1024)
.sum()
}
}
impl Default for ModelCache {
fn default() -> Self {
Self::new()
}
}
lazy_static::lazy_static! {
    /// Shared [`ModelCache`] instance, constructed lazily on first access.
    pub static ref GLOBAL_CACHE: ModelCache = ModelCache::new();
}

/// Returns a `'static` reference to the shared [`GLOBAL_CACHE`] instance.
pub fn global_cache() -> &'static ModelCache {
    &*GLOBAL_CACHE
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly constructed cache reports no loaded models.
    #[test]
    fn test_cache_creation() {
        let fresh = ModelCache::new();
        let minilm_loaded = fresh.is_loaded(EmbeddingModel::AllMiniLmL6V2);
        assert!(!minilm_loaded);
        let loaded = fresh.loaded_models();
        assert!(loaded.is_empty());
    }

    /// The default model starts as AllMiniLmL6V2 and can be replaced.
    #[test]
    fn test_default_model() {
        let cache = ModelCache::new();
        let initial = cache.default_model();
        assert_eq!(initial, EmbeddingModel::AllMiniLmL6V2);

        let replacement = EmbeddingModel::BgeSmallEnV15;
        cache.set_default_model(replacement);
        assert_eq!(cache.default_model(), EmbeddingModel::BgeSmallEnV15);
    }
}