use crate::error::{MemvidError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
/// Architecture family of a registered model.
///
/// The built-in registry tags `all-MiniLM-L6-v2` as `SentenceTransformer` and
/// `bert-base-uncased` as `Bert`; `Custom` carries a free-form identifier for any
/// other kind (its expected contents are not defined in this file —
/// NOTE(review): document the convention callers use).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum ModelType {
/// Sentence-transformers style model (e.g. `all-MiniLM-L6-v2`).
SentenceTransformer,
/// BERT model (e.g. `bert-base-uncased`).
Bert,
/// Any other model kind, identified by a caller-chosen string.
Custom(String),
}
/// Registry entry describing one model known to [`ModelManager`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
/// Canonical model name; `ModelManager::add_model` registers the entry under this key.
pub name: String,
/// Architecture family of the model.
pub model_type: ModelType,
/// Directory holding the model's files once `ModelManager::download_model` has run;
/// reset to `None` by `ModelManager::remove_model`.
pub local_path: Option<PathBuf>,
/// HuggingFace Hub repository id to download from. When `None`,
/// `ModelManager::download_model` only creates an empty placeholder directory.
pub hub_id: Option<String>,
/// Model parameters and cache status.
pub config: ModelConfig,
}
/// Per-model configuration and cache bookkeeping.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
/// Presumably the size of the embedding vectors the model produces
/// (384 and 768 for the built-in entries) — confirm against the consumer.
pub dimension: usize,
/// Maximum input length — presumably measured in tokens; TODO confirm.
pub max_length: usize,
/// Set to `true` by `ModelManager::download_model` and to `false` by
/// `ModelManager::remove_model`; `ModelManager::is_cached` also requires `local_path`.
pub cached: bool,
/// Free-form extra parameters; not read anywhere in this file.
pub params: HashMap<String, String>,
}
/// Registry of embedding models plus an on-disk cache of their downloaded files.
///
/// Models are keyed by name; a model may also be reachable under alias keys whose
/// entries share the same `ModelInfo::name`.
#[derive(Debug)]
pub struct ModelManager {
    /// Root directory under which each model gets its own subdirectory.
    cache_dir: PathBuf,
    /// Registered models, keyed by canonical name or alias.
    models: HashMap<String, ModelInfo>,
}
impl ModelManager {
pub fn new(cache_dir: Option<PathBuf>) -> Result<Self> {
let cache_dir = cache_dir.unwrap_or_else(|| {
let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
PathBuf::from(home)
.join(".cache")
.join("memvid-rs")
.join("models")
});
std::fs::create_dir_all(&cache_dir)?;
let mut manager = Self {
cache_dir,
models: HashMap::new(),
};
manager.register_default_models()?;
Ok(manager)
}
fn register_default_models(&mut self) -> Result<()> {
let mini_lm = ModelInfo {
name: "all-MiniLM-L6-v2".to_string(),
model_type: ModelType::SentenceTransformer,
local_path: None,
hub_id: Some("sentence-transformers/all-MiniLM-L6-v2".to_string()),
config: ModelConfig {
dimension: 384,
max_length: 384,
cached: false,
params: HashMap::new(),
},
};
self.models
.insert("all-MiniLM-L6-v2".to_string(), mini_lm.clone());
self.models.insert(
"sentence-transformers/all-MiniLM-L6-v2".to_string(),
mini_lm,
);
let bert_base = ModelInfo {
name: "bert-base-uncased".to_string(),
model_type: ModelType::Bert,
local_path: None,
hub_id: Some("bert-base-uncased".to_string()),
config: ModelConfig {
dimension: 768,
max_length: 512,
cached: false,
params: HashMap::new(),
},
};
self.models
.insert("bert-base-uncased".to_string(), bert_base);
Ok(())
}
pub fn get_model(&self, name: &str) -> Option<&ModelInfo> {
self.models.get(name)
}
pub fn list_models(&self) -> Vec<&ModelInfo> {
self.models.values().collect()
}
pub fn is_cached(&self, name: &str) -> bool {
if let Some(model) = self.models.get(name) {
model.config.cached && model.local_path.is_some()
} else {
false
}
}
pub fn cache_dir(&self) -> &PathBuf {
&self.cache_dir
}
pub async fn download_model(&mut self, name: &str) -> Result<PathBuf> {
let model = self
.models
.get_mut(name)
.ok_or_else(|| MemvidError::MachineLearning(format!("Model '{}' not found", name)))?;
if let Some(local_path) = &model.local_path {
if local_path.exists() && Self::validate_model_files_static(local_path)? {
log::info!("Model '{}' already cached at {:?}", name, local_path);
return Ok(local_path.clone());
}
}
let model_dir = self.cache_dir.join(name);
std::fs::create_dir_all(&model_dir)?;
if let Some(hub_id) = &model.hub_id {
log::info!(
"Downloading model '{}' from HuggingFace Hub: {}",
name,
hub_id
);
let files_to_download = vec![
"config.json",
"tokenizer.json",
"tokenizer_config.json",
"model.safetensors",
"vocab.txt", ];
let mut downloaded_any = false;
for file_name in files_to_download {
match Self::download_file_static(hub_id, file_name, &model_dir) {
Ok(_) => {
downloaded_any = true;
log::debug!("Downloaded {}/{}", hub_id, file_name);
}
Err(e) => {
log::warn!("Failed to download {}/{}: {}", hub_id, file_name, e);
}
}
}
if downloaded_any {
log::info!("Successfully downloaded model files for '{}'", name);
model.local_path = Some(model_dir.clone());
model.config.cached = true;
} else {
log::error!("Failed to download any files for model '{}'", name);
return Err(MemvidError::MachineLearning(format!(
"Failed to download model '{}'",
name
)));
}
} else {
log::warn!(
"No HuggingFace Hub ID for model '{}', creating placeholder",
name
);
model.local_path = Some(model_dir.clone());
model.config.cached = true;
}
Ok(model_dir)
}
fn download_file_static(repo_id: &str, filename: &str, target_dir: &Path) -> Result<()> {
use hf_hub::api::sync::Api;
let api = Api::new()
.map_err(|e| MemvidError::MachineLearning(format!("Failed to create HF API: {}", e)))?;
let repo = api.model(repo_id.to_string());
let target_path = target_dir.join(filename);
if target_path.exists() && target_path.metadata()?.len() > 0 {
return Ok(());
}
match repo.get(filename) {
Ok(downloaded_path) => {
std::fs::copy(&downloaded_path, &target_path).map_err(|e| {
MemvidError::MachineLearning(format!("Failed to copy file: {}", e))
})?;
log::debug!("Downloaded and copied {} to {:?}", filename, target_path);
Ok(())
}
Err(e) => Err(MemvidError::MachineLearning(format!(
"Failed to download {}: {}",
filename, e
))),
}
}
fn validate_model_files_static(model_dir: &Path) -> Result<bool> {
let essential_files = vec!["config.json"];
for file_name in essential_files {
let file_path = model_dir.join(file_name);
if !file_path.exists() || file_path.metadata()?.len() == 0 {
return Ok(false);
}
}
Ok(true)
}
pub fn add_model(&mut self, model_info: ModelInfo) {
self.models.insert(model_info.name.clone(), model_info);
}
pub fn remove_model(&mut self, name: &str) -> Result<()> {
if let Some(model) = self.models.get_mut(name) {
if let Some(local_path) = &model.local_path {
if local_path.exists() {
std::fs::remove_dir_all(local_path)?;
}
}
model.local_path = None;
model.config.cached = false;
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// A fresh manager creates its cache directory and knows the built-in models.
    #[test]
    fn test_model_manager_creation() {
        let temp_dir = TempDir::new().unwrap();
        let manager = ModelManager::new(Some(temp_dir.path().to_path_buf())).unwrap();
        assert!(manager.cache_dir().exists());
        assert!(manager.get_model("all-MiniLM-L6-v2").is_some());
        assert!(manager.get_model("bert-base-uncased").is_some());
    }

    /// `list_models` reports both built-in models by name.
    #[test]
    fn test_model_listing() {
        let temp_dir = TempDir::new().unwrap();
        let manager = ModelManager::new(Some(temp_dir.path().to_path_buf())).unwrap();
        let models = manager.list_models();
        assert!(models.len() >= 2);
        let model_names: Vec<&str> = models.iter().map(|m| m.name.as_str()).collect();
        assert!(model_names.contains(&"all-MiniLM-L6-v2"));
        assert!(model_names.contains(&"bert-base-uncased"));
    }

    /// Downloading a model marks it cached. This hits the real HuggingFace Hub, so it is
    /// opt-in: run with `cargo test -- --ignored` when network access is available.
    #[tokio::test]
    #[ignore = "requires network access to the HuggingFace Hub"]
    async fn test_model_caching() {
        let temp_dir = TempDir::new().unwrap();
        let mut manager = ModelManager::new(Some(temp_dir.path().to_path_buf())).unwrap();
        assert!(!manager.is_cached("all-MiniLM-L6-v2"));
        let model_path = manager.download_model("all-MiniLM-L6-v2").await.unwrap();
        assert!(model_path.exists());
        assert!(manager.is_cached("all-MiniLM-L6-v2"));
    }
}