use std::path::{Path, PathBuf};
use std::sync::Arc;
use parking_lot::RwLock;
use tracing::{info, warn};
use crate::search::embedder::{Embedder, EmbedderError, EmbedderResult};
use crate::search::fastembed_embedder::FastEmbedder;
use crate::search::fastembed_reranker::FastEmbedReranker;
use crate::search::hash_embedder::HashEmbedder;
use crate::search::reranker::{Reranker, RerankerError, RerankerResult, rerank_texts};
/// Lazily loads and caches the embedding and reranking models used by search.
///
/// Both model slots start empty and are filled on demand by `warm_embedder` /
/// `warm_reranker` (or implicitly by `embed`, `embed_batch` and `rerank`).
/// Every mutable piece of state sits behind its own `RwLock`, so all methods
/// take `&self`.
pub struct ModelManager {
// Base directory handed to `FastEmbedder::default_model_dir` and
// `FastEmbedReranker::default_model_dir` to locate the model files.
data_dir: PathBuf,
// Active embedder: `None` until `warm_embedder` runs, then either the loaded
// `FastEmbedder` or the hash fallback. Reset to `None` by `unload_all`.
embedder: RwLock<Option<Arc<dyn Embedder>>>,
// Active reranker: `None` until `warm_reranker` succeeds; there is no
// fallback for reranking. Reset to `None` by `unload_all`.
reranker: RwLock<Option<Arc<dyn Reranker>>>,
// Display name for the embedder slot: "not-loaded", "MiniLM-L6-v2" or
// "hash-fallback".
embedder_name: RwLock<String>,
// Display name for the reranker slot: "not-loaded" or
// "ms-marco-MiniLM-L-6-v2".
reranker_name: RwLock<String>,
// Hash-based embedder built with `HashEmbedder::new(384)`; installed as the
// active embedder when the `FastEmbedder` model cannot be loaded.
fallback_embedder: Arc<HashEmbedder>,
}
impl ModelManager {
pub fn new(data_dir: &Path) -> Self {
Self {
data_dir: data_dir.to_path_buf(),
embedder: RwLock::new(None),
reranker: RwLock::new(None),
embedder_name: RwLock::new("not-loaded".to_string()),
reranker_name: RwLock::new("not-loaded".to_string()),
fallback_embedder: Arc::new(HashEmbedder::new(384)),
}
}
pub fn is_ready(&self) -> bool {
self.embedder.read().is_some()
}
pub fn embedder_id(&self) -> String {
self.embedder
.read()
.as_ref()
.map(|e| e.id().to_string())
.unwrap_or_else(|| "hash-384".to_string())
}
pub fn embedder_name(&self) -> String {
self.embedder_name.read().clone()
}
pub fn embedder_dimension(&self) -> usize {
self.embedder
.read()
.as_ref()
.map(|e| e.dimension())
.unwrap_or(384)
}
pub fn embedder_loaded(&self) -> bool {
self.embedder.read().is_some()
}
pub fn reranker_id(&self) -> String {
self.reranker
.read()
.as_ref()
.map(|r| r.id().to_string())
.unwrap_or_else(|| "none".to_string())
}
pub fn reranker_name(&self) -> String {
self.reranker_name.read().clone()
}
pub fn reranker_loaded(&self) -> bool {
self.reranker.read().is_some()
}
pub fn warm_embedder(&self) -> EmbedderResult<()> {
if self.embedder.read().is_some() {
return Ok(());
}
let mut embedder_guard = self.embedder.write();
if embedder_guard.is_some() {
return Ok(());
}
let model_dir = FastEmbedder::default_model_dir(&self.data_dir);
info!(model_dir = %model_dir.display(), "Loading embedder");
match FastEmbedder::load_from_dir(&model_dir) {
Ok(embedder) => {
let id = embedder.id().to_string();
let dimension = embedder.dimension();
*embedder_guard = Some(Arc::new(embedder));
*self.embedder_name.write() = "MiniLM-L6-v2".to_string();
info!(id = %id, dimension = dimension, "Embedder loaded");
Ok(())
}
Err(e) => {
warn!(error = %e, "Failed to load embedder, using hash fallback");
*embedder_guard = Some(self.fallback_embedder.clone());
*self.embedder_name.write() = "hash-fallback".to_string();
Ok(())
}
}
}
pub fn warm_reranker(&self) -> RerankerResult<()> {
if self.reranker.read().is_some() {
return Ok(());
}
let mut reranker_guard = self.reranker.write();
if reranker_guard.is_some() {
return Ok(());
}
let model_dir = FastEmbedReranker::default_model_dir(&self.data_dir);
info!(model_dir = %model_dir.display(), "Loading reranker");
match FastEmbedReranker::load_from_dir(&model_dir) {
Ok(reranker) => {
let id = reranker.id().to_string();
*reranker_guard = Some(Arc::new(reranker));
*self.reranker_name.write() = "ms-marco-MiniLM-L-6-v2".to_string();
info!(id = %id, "Reranker loaded");
Ok(())
}
Err(e) => {
warn!(error = %e, "Failed to load reranker, reranking unavailable");
Err(e)
}
}
}
pub fn embed_batch(&self, texts: &[String]) -> EmbedderResult<Vec<Vec<f32>>> {
if self.embedder.read().is_none() {
self.warm_embedder()?;
}
let embedder = self.embedder.read();
let embedder = embedder
.as_ref()
.ok_or_else(|| EmbedderError::EmbedderUnavailable {
model: "unknown".to_string(),
reason: "embedder not loaded".to_string(),
})?;
let text_refs: Vec<&str> = texts.iter().map(|s| s.as_str()).collect();
embedder.embed_batch_sync(&text_refs)
}
pub fn embed(&self, text: &str) -> EmbedderResult<Vec<f32>> {
if self.embedder.read().is_none() {
self.warm_embedder()?;
}
let embedder = self.embedder.read();
let embedder = embedder
.as_ref()
.ok_or_else(|| EmbedderError::EmbedderUnavailable {
model: "unknown".to_string(),
reason: "embedder not loaded".to_string(),
})?;
embedder.embed_sync(text)
}
pub fn rerank(&self, query: &str, documents: &[String]) -> RerankerResult<Vec<f32>> {
if self.reranker.read().is_none() {
self.warm_reranker()?;
}
let reranker = self.reranker.read();
let reranker = reranker
.as_ref()
.ok_or_else(|| RerankerError::RerankerUnavailable {
model: "reranker".to_string(),
})?;
let doc_refs: Vec<&str> = documents.iter().map(|s| s.as_str()).collect();
rerank_texts(&**reranker, query, &doc_refs)
}
pub fn unload_all(&self) {
*self.embedder.write() = None;
*self.reranker.write() = None;
*self.embedder_name.write() = "not-loaded".to_string();
*self.reranker_name.write() = "not-loaded".to_string();
info!("All models unloaded");
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Root of the crate's test fixtures, used as the manager's data directory.
    fn test_data_dir() -> PathBuf {
        let mut dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        dir.push("tests/fixtures");
        dir
    }

    /// Location of the bundled model fixture; not referenced by the tests
    /// below, hence the `dead_code` allowance.
    #[allow(dead_code)]
    fn model_fixture_dir() -> PathBuf {
        let mut dir = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        dir.push("tests/fixtures/models/xenova-paraphrase-minilm-l3-v2-int8");
        dir
    }

    /// Manager rooted at a directory that holds no model files, which forces
    /// the hash-fallback path when the embedder is warmed.
    fn manager_without_models() -> ModelManager {
        ModelManager::new(Path::new("/tmp/nonexistent"))
    }

    #[test]
    fn test_model_manager_creation() {
        let fresh = ModelManager::new(&test_data_dir());

        // Nothing is loaded until a warm-up or an embed/rerank call.
        assert!(!fresh.is_ready());
        assert!(!fresh.embedder_loaded());
        assert!(!fresh.reranker_loaded());
    }

    #[test]
    fn test_embedder_fallback_on_missing_model() {
        let manager = manager_without_models();

        let outcome = manager.warm_embedder();

        assert!(outcome.is_ok());
        assert!(manager.embedder_loaded());
        assert_eq!(manager.embedder_name(), "hash-fallback");
    }

    #[test]
    fn test_embedder_dimension() {
        // Before any load, the reported dimension is the default of 384.
        let dimension = ModelManager::new(&test_data_dir()).embedder_dimension();
        assert_eq!(dimension, 384);
    }

    #[test]
    fn test_unload_all() {
        let manager = ModelManager::new(&test_data_dir());

        let _ = manager.warm_embedder();
        assert!(manager.embedder_loaded());

        manager.unload_all();
        assert!(!manager.embedder_loaded());
        assert!(!manager.reranker_loaded());
    }

    #[test]
    fn test_embed_with_fallback() {
        let manager = manager_without_models();

        let outcome = manager.embed("test text");
        assert!(outcome.is_ok());

        let vector = outcome.unwrap();
        assert_eq!(vector.len(), 384);
    }
}