Skip to main content

memvid_rs/ml/
models.rs

1//! Model management for memvid-rs ML system
2//!
3//! This module handles downloading, caching, and managing ML models
4//! using pure Rust implementations.
5
6use crate::error::{MemvidError, Result};
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::path::{Path, PathBuf};
10
11/// Types of models supported
12#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
13pub enum ModelType {
14    /// Sentence transformer for embeddings
15    SentenceTransformer,
16    /// BERT-based models
17    Bert,
18    /// Custom models
19    Custom(String),
20}
21
22/// Model metadata
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct ModelInfo {
25    /// Model name/identifier
26    pub name: String,
27    /// Model type
28    pub model_type: ModelType,
29    /// Local path to model files
30    pub local_path: Option<PathBuf>,
31    /// HuggingFace model hub identifier
32    pub hub_id: Option<String>,
33    /// Model configuration
34    pub config: ModelConfig,
35}
36
37/// Model configuration
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct ModelConfig {
40    /// Model dimension
41    pub dimension: usize,
42    /// Maximum sequence length
43    pub max_length: usize,
44    /// Whether model is cached locally
45    pub cached: bool,
46    /// Additional parameters
47    pub params: HashMap<String, String>,
48}
49
50/// Model manager for downloading and caching models
51pub struct ModelManager {
52    /// Cache directory for models
53    cache_dir: PathBuf,
54    /// Available models
55    models: HashMap<String, ModelInfo>,
56}
57
58impl ModelManager {
59    /// Create new model manager
60    pub fn new(cache_dir: Option<PathBuf>) -> Result<Self> {
61        let cache_dir = cache_dir.unwrap_or_else(|| {
62            let home = std::env::var("HOME").unwrap_or_else(|_| ".".to_string());
63            PathBuf::from(home)
64                .join(".cache")
65                .join("memvid-rs")
66                .join("models")
67        });
68
69        // Create cache directory if it doesn't exist
70        std::fs::create_dir_all(&cache_dir)?;
71
72        let mut manager = Self {
73            cache_dir,
74            models: HashMap::new(),
75        };
76
77        // Register default models
78        manager.register_default_models()?;
79
80        Ok(manager)
81    }
82
83    /// Register default models
84    fn register_default_models(&mut self) -> Result<()> {
85        // Register all-MiniLM-L6-v2 model
86        let mini_lm = ModelInfo {
87            name: "all-MiniLM-L6-v2".to_string(),
88            model_type: ModelType::SentenceTransformer,
89            local_path: None,
90            hub_id: Some("sentence-transformers/all-MiniLM-L6-v2".to_string()),
91            config: ModelConfig {
92                dimension: 384,
93                max_length: 384,
94                cached: false,
95                params: HashMap::new(),
96            },
97        };
98
99        // Register with both short and full names for compatibility
100        self.models
101            .insert("all-MiniLM-L6-v2".to_string(), mini_lm.clone());
102        self.models.insert(
103            "sentence-transformers/all-MiniLM-L6-v2".to_string(),
104            mini_lm,
105        );
106
107        // Register other common models
108        let bert_base = ModelInfo {
109            name: "bert-base-uncased".to_string(),
110            model_type: ModelType::Bert,
111            local_path: None,
112            hub_id: Some("bert-base-uncased".to_string()),
113            config: ModelConfig {
114                dimension: 768,
115                max_length: 512,
116                cached: false,
117                params: HashMap::new(),
118            },
119        };
120
121        self.models
122            .insert("bert-base-uncased".to_string(), bert_base);
123
124        Ok(())
125    }
126
127    /// Get model info by name
128    pub fn get_model(&self, name: &str) -> Option<&ModelInfo> {
129        self.models.get(name)
130    }
131
132    /// List available models
133    pub fn list_models(&self) -> Vec<&ModelInfo> {
134        self.models.values().collect()
135    }
136
137    /// Check if model is cached locally
138    pub fn is_cached(&self, name: &str) -> bool {
139        if let Some(model) = self.models.get(name) {
140            model.config.cached && model.local_path.is_some()
141        } else {
142            false
143        }
144    }
145
146    /// Get cache directory
147    pub fn cache_dir(&self) -> &PathBuf {
148        &self.cache_dir
149    }
150
151    /// Download and cache model from HuggingFace Hub
152    pub async fn download_model(&mut self, name: &str) -> Result<PathBuf> {
153        let model = self
154            .models
155            .get_mut(name)
156            .ok_or_else(|| MemvidError::MachineLearning(format!("Model '{}' not found", name)))?;
157
158        if let Some(local_path) = &model.local_path {
159            if local_path.exists() && Self::validate_model_files_static(local_path)? {
160                log::info!("Model '{}' already cached at {:?}", name, local_path);
161                return Ok(local_path.clone());
162            }
163        }
164
165        let model_dir = self.cache_dir.join(name);
166        std::fs::create_dir_all(&model_dir)?;
167
168        if let Some(hub_id) = &model.hub_id {
169            log::info!(
170                "Downloading model '{}' from HuggingFace Hub: {}",
171                name,
172                hub_id
173            );
174
175            // Download essential model files
176            let files_to_download = vec![
177                "config.json",
178                "tokenizer.json",
179                "tokenizer_config.json",
180                "model.safetensors",
181                "vocab.txt", // For BERT-based models
182            ];
183
184            let mut downloaded_any = false;
185            for file_name in files_to_download {
186                match Self::download_file_static(hub_id, file_name, &model_dir) {
187                    Ok(_) => {
188                        downloaded_any = true;
189                        log::debug!("Downloaded {}/{}", hub_id, file_name);
190                    }
191                    Err(e) => {
192                        // Some files are optional, only warn
193                        log::warn!("Failed to download {}/{}: {}", hub_id, file_name, e);
194                    }
195                }
196            }
197
198            if downloaded_any {
199                log::info!("Successfully downloaded model files for '{}'", name);
200                model.local_path = Some(model_dir.clone());
201                model.config.cached = true;
202            } else {
203                log::error!("Failed to download any files for model '{}'", name);
204                return Err(MemvidError::MachineLearning(format!(
205                    "Failed to download model '{}'",
206                    name
207                )));
208            }
209        } else {
210            // Create placeholder for models without hub_id
211            log::warn!(
212                "No HuggingFace Hub ID for model '{}', creating placeholder",
213                name
214            );
215            model.local_path = Some(model_dir.clone());
216            model.config.cached = true;
217        }
218
219        Ok(model_dir)
220    }
221
222    /// Download a single file from HuggingFace Hub (static method to avoid borrowing issues)
223    fn download_file_static(repo_id: &str, filename: &str, target_dir: &Path) -> Result<()> {
224        use hf_hub::api::sync::Api;
225
226        let api = Api::new()
227            .map_err(|e| MemvidError::MachineLearning(format!("Failed to create HF API: {}", e)))?;
228
229        let repo = api.model(repo_id.to_string());
230        let target_path = target_dir.join(filename);
231
232        // Skip if file already exists and is valid
233        if target_path.exists() && target_path.metadata()?.len() > 0 {
234            return Ok(());
235        }
236
237        match repo.get(filename) {
238            Ok(downloaded_path) => {
239                // Copy from downloaded path to target path
240                std::fs::copy(&downloaded_path, &target_path).map_err(|e| {
241                    MemvidError::MachineLearning(format!("Failed to copy file: {}", e))
242                })?;
243                log::debug!("Downloaded and copied {} to {:?}", filename, target_path);
244                Ok(())
245            }
246            Err(e) => Err(MemvidError::MachineLearning(format!(
247                "Failed to download {}: {}",
248                filename, e
249            ))),
250        }
251    }
252
253    /// Validate that essential model files exist (static method)
254    fn validate_model_files_static(model_dir: &Path) -> Result<bool> {
255        let essential_files = vec!["config.json"];
256
257        for file_name in essential_files {
258            let file_path = model_dir.join(file_name);
259            if !file_path.exists() || file_path.metadata()?.len() == 0 {
260                return Ok(false);
261            }
262        }
263
264        Ok(true)
265    }
266
267    /// Add custom model
268    pub fn add_model(&mut self, model_info: ModelInfo) {
269        self.models.insert(model_info.name.clone(), model_info);
270    }
271
272    /// Remove model from cache
273    pub fn remove_model(&mut self, name: &str) -> Result<()> {
274        if let Some(model) = self.models.get_mut(name) {
275            if let Some(local_path) = &model.local_path {
276                if local_path.exists() {
277                    std::fs::remove_dir_all(local_path)?;
278                }
279            }
280            model.local_path = None;
281            model.config.cached = false;
282        }
283        Ok(())
284    }
285}
286
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a manager whose cache root is the given temporary directory.
    fn manager_in(dir: &TempDir) -> ModelManager {
        ModelManager::new(Some(dir.path().to_path_buf())).unwrap()
    }

    /// A new manager creates its cache directory and knows the default models.
    #[tokio::test]
    async fn test_model_manager_creation() {
        let scratch = TempDir::new().unwrap();
        let manager = manager_in(&scratch);

        assert!(manager.cache_dir().exists());
        for expected in ["all-MiniLM-L6-v2", "bert-base-uncased"] {
            assert!(manager.get_model(expected).is_some());
        }
    }

    /// Listing returns at least the default models, identified by name.
    #[tokio::test]
    async fn test_model_listing() {
        let scratch = TempDir::new().unwrap();
        let manager = manager_in(&scratch);

        let listed = manager.list_models();
        assert!(listed.len() >= 2); // at least our default models

        let names: Vec<&str> = listed.iter().map(|info| info.name.as_str()).collect();
        for expected in ["all-MiniLM-L6-v2", "bert-base-uncased"] {
            assert!(names.contains(&expected));
        }
    }

    /// Downloading flips a model from "not cached" to "cached".
    ///
    /// The model has a hub id, so this test goes through the HuggingFace Hub
    /// download path and needs network access.
    #[tokio::test]
    async fn test_model_caching() {
        let scratch = TempDir::new().unwrap();
        let mut manager = manager_in(&scratch);
        let model_name = "all-MiniLM-L6-v2";

        assert!(!manager.is_cached(model_name));

        let stored_at = manager.download_model(model_name).await.unwrap();
        assert!(stored_at.exists());
        assert!(manager.is_cached(model_name));
    }
}