1use crate::error::{MemvidError, Result};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::path::{Path, PathBuf};
10
/// Category of a model tracked by the model manager.
///
/// Derives `Eq` and `Hash`, so it can be compared and used as a map/set key.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
pub enum ModelType {
    /// Used by the built-in `all-MiniLM-L6-v2` entry.
    SentenceTransformer,
    /// Used by the built-in `bert-base-uncased` entry.
    Bert,
    /// Any other kind of model.
    // NOTE(review): the meaning of the payload string is not established in
    // this file — presumably a free-form architecture/type label; confirm.
    Custom(String),
}
21
/// Registry entry describing a single model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Model name; `ModelManager::add_model` uses it as the registry key.
    pub name: String,
    /// Category of the model.
    pub model_type: ModelType,
    /// Directory holding the model files on disk, if one has been recorded.
    /// Set by `ModelManager::download_model`, cleared by `ModelManager::remove_model`.
    pub local_path: Option<PathBuf>,
    /// HuggingFace Hub repository id the files are fetched from, if any.
    pub hub_id: Option<String>,
    /// Numeric settings, cache flag, and free-form parameters for the model.
    pub config: ModelConfig,
}
36
/// Per-model settings stored alongside a [`ModelInfo`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Presumably the size of the vectors the model produces (384 and 768 for
    /// the built-in entries) — TODO confirm against the consumer of this field.
    pub dimension: usize,
    /// Presumably the maximum input sequence length in tokens — TODO confirm.
    pub max_length: usize,
    /// Whether the manager considers the model files to be available locally.
    /// Set to `true` by `ModelManager::download_model` and to `false` by
    /// `ModelManager::remove_model`.
    pub cached: bool,
    /// Free-form string parameters; not interpreted anywhere in this file.
    pub params: HashMap<String, String>,
}
49
50pub struct ModelManager {
52 cache_dir: PathBuf,
54 models: HashMap<String, ModelInfo>,
56}
57
58impl ModelManager {
59 pub fn new(cache_dir: Option<PathBuf>) -> Result<Self> {
61 let cache_dir = cache_dir.unwrap_or_else(|| {
62 let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
63 PathBuf::from(home)
64 .join(".cache")
65 .join("memvid-rs")
66 .join("models")
67 });
68
69 std::fs::create_dir_all(&cache_dir)?;
71
72 let mut manager = Self {
73 cache_dir,
74 models: HashMap::new(),
75 };
76
77 manager.register_default_models()?;
79
80 Ok(manager)
81 }
82
83 fn register_default_models(&mut self) -> Result<()> {
85 let mini_lm = ModelInfo {
87 name: "all-MiniLM-L6-v2".to_string(),
88 model_type: ModelType::SentenceTransformer,
89 local_path: None,
90 hub_id: Some("sentence-transformers/all-MiniLM-L6-v2".to_string()),
91 config: ModelConfig {
92 dimension: 384,
93 max_length: 384,
94 cached: false,
95 params: HashMap::new(),
96 },
97 };
98
99 self.models
101 .insert("all-MiniLM-L6-v2".to_string(), mini_lm.clone());
102 self.models.insert(
103 "sentence-transformers/all-MiniLM-L6-v2".to_string(),
104 mini_lm,
105 );
106
107 let bert_base = ModelInfo {
109 name: "bert-base-uncased".to_string(),
110 model_type: ModelType::Bert,
111 local_path: None,
112 hub_id: Some("bert-base-uncased".to_string()),
113 config: ModelConfig {
114 dimension: 768,
115 max_length: 512,
116 cached: false,
117 params: HashMap::new(),
118 },
119 };
120
121 self.models
122 .insert("bert-base-uncased".to_string(), bert_base);
123
124 Ok(())
125 }
126
127 pub fn get_model(&self, name: &str) -> Option<&ModelInfo> {
129 self.models.get(name)
130 }
131
132 pub fn list_models(&self) -> Vec<&ModelInfo> {
134 self.models.values().collect()
135 }
136
137 pub fn is_cached(&self, name: &str) -> bool {
139 if let Some(model) = self.models.get(name) {
140 model.config.cached && model.local_path.is_some()
141 } else {
142 false
143 }
144 }
145
146 pub fn cache_dir(&self) -> &PathBuf {
148 &self.cache_dir
149 }
150
151 pub async fn download_model(&mut self, name: &str) -> Result<PathBuf> {
153 let model = self
154 .models
155 .get_mut(name)
156 .ok_or_else(|| MemvidError::MachineLearning(format!("Model '{}' not found", name)))?;
157
158 if let Some(local_path) = &model.local_path {
159 if local_path.exists() && Self::validate_model_files_static(local_path)? {
160 log::info!("Model '{}' already cached at {:?}", name, local_path);
161 return Ok(local_path.clone());
162 }
163 }
164
165 let model_dir = self.cache_dir.join(name);
166 std::fs::create_dir_all(&model_dir)?;
167
168 if let Some(hub_id) = &model.hub_id {
169 log::info!(
170 "Downloading model '{}' from HuggingFace Hub: {}",
171 name,
172 hub_id
173 );
174
175 let files_to_download = vec![
177 "config.json",
178 "tokenizer.json",
179 "tokenizer_config.json",
180 "model.safetensors",
181 "vocab.txt", ];
183
184 let mut downloaded_any = false;
185 for file_name in files_to_download {
186 match Self::download_file_static(hub_id, file_name, &model_dir) {
187 Ok(_) => {
188 downloaded_any = true;
189 log::debug!("Downloaded {}/{}", hub_id, file_name);
190 }
191 Err(e) => {
192 log::warn!("Failed to download {}/{}: {}", hub_id, file_name, e);
194 }
195 }
196 }
197
198 if downloaded_any {
199 log::info!("Successfully downloaded model files for '{}'", name);
200 model.local_path = Some(model_dir.clone());
201 model.config.cached = true;
202 } else {
203 log::error!("Failed to download any files for model '{}'", name);
204 return Err(MemvidError::MachineLearning(format!(
205 "Failed to download model '{}'",
206 name
207 )));
208 }
209 } else {
210 log::warn!(
212 "No HuggingFace Hub ID for model '{}', creating placeholder",
213 name
214 );
215 model.local_path = Some(model_dir.clone());
216 model.config.cached = true;
217 }
218
219 Ok(model_dir)
220 }
221
222 fn download_file_static(repo_id: &str, filename: &str, target_dir: &Path) -> Result<()> {
224 use hf_hub::api::sync::Api;
225
226 let api = Api::new()
227 .map_err(|e| MemvidError::MachineLearning(format!("Failed to create HF API: {}", e)))?;
228
229 let repo = api.model(repo_id.to_string());
230 let target_path = target_dir.join(filename);
231
232 if target_path.exists() && target_path.metadata()?.len() > 0 {
234 return Ok(());
235 }
236
237 match repo.get(filename) {
238 Ok(downloaded_path) => {
239 std::fs::copy(&downloaded_path, &target_path).map_err(|e| {
241 MemvidError::MachineLearning(format!("Failed to copy file: {}", e))
242 })?;
243 log::debug!("Downloaded and copied {} to {:?}", filename, target_path);
244 Ok(())
245 }
246 Err(e) => Err(MemvidError::MachineLearning(format!(
247 "Failed to download {}: {}",
248 filename, e
249 ))),
250 }
251 }
252
253 fn validate_model_files_static(model_dir: &Path) -> Result<bool> {
255 let essential_files = vec!["config.json"];
256
257 for file_name in essential_files {
258 let file_path = model_dir.join(file_name);
259 if !file_path.exists() || file_path.metadata()?.len() == 0 {
260 return Ok(false);
261 }
262 }
263
264 Ok(true)
265 }
266
267 pub fn add_model(&mut self, model_info: ModelInfo) {
269 self.models.insert(model_info.name.clone(), model_info);
270 }
271
272 pub fn remove_model(&mut self, name: &str) -> Result<()> {
274 if let Some(model) = self.models.get_mut(name) {
275 if let Some(local_path) = &model.local_path {
276 if local_path.exists() {
277 std::fs::remove_dir_all(local_path)?;
278 }
279 }
280 model.local_path = None;
281 model.config.cached = false;
282 }
283 Ok(())
284 }
285}
286
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a manager whose cache lives in a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned as well so the directory stays alive
    /// for as long as the test holds on to it.
    fn manager_in_temp_dir() -> (TempDir, ModelManager) {
        let dir = TempDir::new().unwrap();
        let manager = ModelManager::new(Some(dir.path().to_path_buf())).unwrap();
        (dir, manager)
    }

    /// A new manager creates its cache directory and knows the built-in models.
    #[tokio::test]
    async fn test_model_manager_creation() {
        let (_dir, manager) = manager_in_temp_dir();

        assert!(manager.cache_dir().exists());
        for expected in ["all-MiniLM-L6-v2", "bert-base-uncased"] {
            assert!(manager.get_model(expected).is_some());
        }
    }

    /// Listing returns at least the two built-in models by name.
    #[tokio::test]
    async fn test_model_listing() {
        let (_dir, manager) = manager_in_temp_dir();

        let listed = manager.list_models();
        assert!(listed.len() >= 2);

        assert!(listed.iter().any(|m| m.name == "all-MiniLM-L6-v2"));
        assert!(listed.iter().any(|m| m.name == "bert-base-uncased"));
    }

    /// Downloading a model flips it from uncached to cached.
    // Goes through the real Hub download path, so this presumably needs
    // network access to pass — confirm for offline CI.
    #[tokio::test]
    async fn test_model_caching() {
        let (_dir, mut manager) = manager_in_temp_dir();
        let model_name = "all-MiniLM-L6-v2";

        assert!(!manager.is_cached(model_name));

        let downloaded_to = manager.download_model(model_name).await.unwrap();
        assert!(downloaded_to.exists());
        assert!(manager.is_cached(model_name));
    }
}