Skip to main content

entrenar/sovereign/registry/
offline.rs

1//! Offline model registry implementation.
2
3use sha2::{Digest, Sha256};
4use std::fs;
5use std::io::Read;
6use std::path::{Path, PathBuf};
7
8use crate::error::{Error, Result};
9
10use super::manifest::RegistryManifest;
11use super::types::{ModelEntry, ModelSource};
12
13/// Offline model registry
14#[derive(Debug)]
15pub struct OfflineModelRegistry {
16    /// Root path for model storage
17    pub root_path: PathBuf,
18    /// Registry manifest
19    pub manifest: RegistryManifest,
20    /// Manifest file path
21    manifest_path: PathBuf,
22}
23
24impl OfflineModelRegistry {
25    /// Create a new registry at the given root path
26    pub fn new(root: PathBuf) -> Self {
27        let manifest_path = root.join("manifest.json");
28        let manifest = if manifest_path.exists() {
29            Self::load_manifest(&manifest_path).unwrap_or_default()
30        } else {
31            RegistryManifest::new()
32        };
33
34        Self { root_path: root, manifest, manifest_path }
35    }
36
37    /// Create registry at default location (~/.entrenar/models/)
38    pub fn default_location() -> Self {
39        let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
40        Self::new(home.join(".entrenar").join("models"))
41    }
42
43    /// Load manifest from file
44    fn load_manifest(path: &Path) -> Result<RegistryManifest> {
45        let content = fs::read_to_string(path)?;
46        serde_json::from_str(&content).map_err(|e| Error::Io(format!("Invalid manifest data: {e}")))
47    }
48
49    /// Save manifest to file
50    pub fn save_manifest(&self) -> Result<()> {
51        // Ensure parent directory exists
52        if let Some(parent) = self.manifest_path.parent() {
53            fs::create_dir_all(parent)?;
54        }
55
56        let content = serde_json::to_string_pretty(&self.manifest)
57            .map_err(|e| Error::Io(format!("Failed to serialize manifest: {e}")))?;
58        fs::write(&self.manifest_path, content)?;
59        Ok(())
60    }
61
62    /// Add a model entry to the registry
63    pub fn add_model(&mut self, entry: ModelEntry) {
64        self.manifest.add(entry);
65    }
66
67    /// Mirror a model from HuggingFace Hub (simulated for offline scenarios)
68    ///
69    /// In a real implementation, this would download the model.
70    /// For air-gapped scenarios, models are pre-downloaded and registered.
71    pub fn mirror_from_hub(&mut self, repo_id: &str) -> Result<ModelEntry> {
72        // Create model entry with HuggingFace source
73        let name = repo_id.split('/').next_back().unwrap_or(repo_id);
74        let local_path = self.root_path.join(name);
75
76        let entry = ModelEntry::new(
77            name,
78            "1.0",
79            "", // Checksum computed after download
80            0,  // Size computed after download
81            ModelSource::huggingface(repo_id),
82        )
83        .with_local_path(&local_path);
84
85        self.manifest.add(entry.clone());
86        Ok(entry)
87    }
88
89    /// Register a local model file
90    pub fn register_local(&mut self, name: &str, path: &Path) -> Result<ModelEntry> {
91        if !path.exists() {
92            return Err(Error::ConfigError(format!("Model file not found: {}", path.display())));
93        }
94
95        let metadata = fs::metadata(path)?;
96        let size_bytes = metadata.len();
97
98        // Compute SHA-256
99        let sha256 = Self::compute_file_sha256(path)?;
100
101        // Determine format from extension
102        let format = path.extension().and_then(|e| e.to_str()).map(String::from);
103
104        let entry = ModelEntry::new(name, "local", sha256, size_bytes, ModelSource::local(path))
105            .with_local_path(path);
106
107        let entry = if let Some(fmt) = format { entry.with_format(fmt) } else { entry };
108
109        self.manifest.add(entry.clone());
110        self.manifest.mark_synced();
111        self.save_manifest()?;
112
113        Ok(entry)
114    }
115
116    /// Compute SHA-256 hash of a file
117    fn compute_file_sha256(path: &Path) -> Result<String> {
118        let mut file = fs::File::open(path)?;
119        let mut hasher = Sha256::new();
120        let mut buffer = [0u8; 8192];
121
122        loop {
123            let bytes_read = file.read(&mut buffer)?;
124            if bytes_read == 0 {
125                break;
126            }
127            hasher.update(&buffer[..bytes_read]);
128        }
129
130        Ok(format!("{:x}", hasher.finalize()))
131    }
132
133    /// Load a model by name, returning its local path
134    pub fn load(&self, name: &str) -> Result<PathBuf> {
135        let entry = self
136            .manifest
137            .find(name)
138            .ok_or_else(|| Error::ConfigError(format!("Model not found: {name}")))?;
139
140        let path = entry
141            .local_path
142            .as_ref()
143            .ok_or_else(|| Error::ConfigError(format!("Model not available locally: {name}")))?;
144
145        if !path.exists() {
146            return Err(Error::ConfigError(format!("Model file missing: {}", path.display())));
147        }
148
149        Ok(path.clone())
150    }
151
152    /// Verify a model entry's checksum
153    pub fn verify(&self, entry: &ModelEntry) -> Result<bool> {
154        let path = entry
155            .local_path
156            .as_ref()
157            .ok_or_else(|| Error::ConfigError("Model has no local path".into()))?;
158
159        if !path.exists() {
160            return Ok(false);
161        }
162
163        if entry.sha256.is_empty() {
164            // No checksum to verify against
165            return Ok(true);
166        }
167
168        let computed = Self::compute_file_sha256(path)?;
169        Ok(computed == entry.sha256)
170    }
171
172    /// List all available (locally cached) models
173    pub fn list_available(&self) -> Vec<&ModelEntry> {
174        self.manifest.available()
175    }
176
177    /// List all models in registry
178    pub fn list_all(&self) -> &[ModelEntry] {
179        &self.manifest.models
180    }
181
182    /// Get a model entry by name
183    pub fn get(&self, name: &str) -> Option<&ModelEntry> {
184        self.manifest.find(name)
185    }
186
187    /// Remove a model from registry (does not delete files)
188    pub fn remove(&mut self, name: &str) -> Option<ModelEntry> {
189        let pos = self.manifest.models.iter().position(|m| m.name == name)?;
190        Some(self.manifest.models.remove(pos))
191    }
192
193    /// Get total size of all models
194    pub fn total_size(&self) -> u64 {
195        self.manifest.total_size_bytes()
196    }
197
198    /// Get root path
199    pub fn root(&self) -> &Path {
200        &self.root_path
201    }
202}
203
#[cfg(test)]
// Tests unwrap freely: a failed unwrap is simply a failed test.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use std::io::Write;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Distinguishes the directories of tests running in the same process.
    static TEST_COUNTER: AtomicU64 = AtomicU64::new(0);

    /// Per-test scratch directory that is removed when dropped, so it is
    /// cleaned up even when an assertion fails part-way through a test.
    struct TempRegistryDir {
        path: PathBuf,
    }

    impl TempRegistryDir {
        /// Path of the scratch directory.
        fn path(&self) -> &Path {
            &self.path
        }

        /// Path of `name` inside the scratch directory.
        fn join(&self, name: &str) -> PathBuf {
            self.path.join(name)
        }

        /// Open a registry rooted at the scratch directory.
        fn registry(&self) -> OfflineModelRegistry {
            OfflineModelRegistry::new(self.path.clone())
        }
    }

    impl Drop for TempRegistryDir {
        fn drop(&mut self) {
            // Best-effort cleanup; a leftover directory must not fail a test.
            let _ = std::fs::remove_dir_all(&self.path);
        }
    }

    /// Create an empty scratch directory unique to this process and call.
    fn temp_registry_dir() -> TempRegistryDir {
        let id = TEST_COUNTER.fetch_add(1, Ordering::SeqCst);
        let path =
            std::env::temp_dir().join(format!("entrenar_offline_test_{}_{id}", std::process::id()));
        // Clear leftovers from an earlier run that reused this process id.
        let _ = std::fs::remove_dir_all(&path);
        std::fs::create_dir_all(&path).unwrap();
        TempRegistryDir { path }
    }

    #[test]
    fn test_new_empty_registry() {
        let dir = temp_registry_dir();
        let reg = dir.registry();
        assert!(reg.manifest.models.is_empty());
        assert_eq!(reg.root(), dir.path());
        assert_eq!(reg.total_size(), 0);
        assert!(reg.list_all().is_empty());
        assert!(reg.list_available().is_empty());
    }

    #[test]
    fn test_new_loads_existing_manifest() {
        let dir = temp_registry_dir();
        let manifest = RegistryManifest::new();
        let content = serde_json::to_string_pretty(&manifest).unwrap();
        std::fs::write(dir.join("manifest.json"), content).unwrap();

        let reg = dir.registry();
        assert!(reg.manifest.models.is_empty());
    }

    #[test]
    fn test_new_with_corrupted_manifest_falls_back() {
        let dir = temp_registry_dir();
        std::fs::write(dir.join("manifest.json"), "not valid json").unwrap();

        let reg = dir.registry();
        assert!(reg.manifest.models.is_empty()); // falls back to default
    }

    #[test]
    fn test_add_model() {
        let dir = temp_registry_dir();
        let mut reg = dir.registry();
        let entry =
            ModelEntry::new("test-model", "1.0", "abc123", 1024, ModelSource::local("/tmp/model"));
        reg.add_model(entry);
        assert_eq!(reg.list_all().len(), 1);
        assert_eq!(reg.list_all()[0].name, "test-model");
    }

    #[test]
    fn test_get_model() {
        let dir = temp_registry_dir();
        let mut reg = dir.registry();
        let entry = ModelEntry::new("mymodel", "2.0", "sha", 2048, ModelSource::local("/tmp/m"));
        reg.add_model(entry);

        assert!(reg.get("mymodel").is_some());
        assert_eq!(reg.get("mymodel").unwrap().version, "2.0");
        assert!(reg.get("nonexistent").is_none());
    }

    #[test]
    fn test_remove_model() {
        let dir = temp_registry_dir();
        let mut reg = dir.registry();
        let entry = ModelEntry::new("removeme", "1.0", "hash", 512, ModelSource::local("/tmp"));
        reg.add_model(entry);
        assert_eq!(reg.list_all().len(), 1);

        let removed = reg.remove("removeme");
        assert!(removed.is_some());
        assert_eq!(removed.unwrap().name, "removeme");
        assert!(reg.list_all().is_empty());

        // Remove nonexistent
        assert!(reg.remove("nonexistent").is_none());
    }

    #[test]
    fn test_save_manifest() {
        let dir = temp_registry_dir();
        let mut reg = dir.registry();
        let entry = ModelEntry::new("saved", "1.0", "sha256", 100, ModelSource::local("/tmp"));
        reg.add_model(entry);
        reg.save_manifest().unwrap();

        // Verify file was written
        assert!(dir.join("manifest.json").exists());

        // Load it back
        let reg2 = dir.registry();
        assert_eq!(reg2.list_all().len(), 1);
        assert_eq!(reg2.list_all()[0].name, "saved");
    }

    #[test]
    fn test_mirror_from_hub() {
        let dir = temp_registry_dir();
        let mut reg = dir.registry();
        let entry = reg.mirror_from_hub("org/my-model").unwrap();
        assert_eq!(entry.name, "my-model");
        assert_eq!(reg.list_all().len(), 1);
    }

    #[test]
    fn test_mirror_from_hub_no_slash() {
        let dir = temp_registry_dir();
        let mut reg = dir.registry();
        let entry = reg.mirror_from_hub("simple-model").unwrap();
        assert_eq!(entry.name, "simple-model");
    }

    #[test]
    fn test_register_local_file() {
        let dir = temp_registry_dir();
        let model_file = dir.join("model.safetensors");
        // The handle is a temporary, so the file is closed before registration.
        std::fs::File::create(&model_file)
            .unwrap()
            .write_all(b"fake model data for testing")
            .unwrap();

        let mut reg = dir.registry();
        let entry = reg.register_local("local-model", &model_file).unwrap();
        assert_eq!(entry.name, "local-model");
        assert_eq!(entry.version, "local");
        assert!(!entry.sha256.is_empty());
        assert!(entry.size_bytes > 0);
        assert_eq!(entry.format, Some("safetensors".to_string()));
        assert_eq!(reg.list_all().len(), 1);
    }

    #[test]
    fn test_register_local_file_not_found() {
        let dir = temp_registry_dir();
        let mut reg = dir.registry();
        // A path inside the fresh scratch directory is guaranteed not to exist.
        let missing = dir.join("nonexistent_model_xyz");
        let result = reg.register_local("missing", &missing);
        assert!(result.is_err());
    }

    #[test]
    fn test_register_local_no_extension() {
        let dir = temp_registry_dir();
        let model_file = dir.join("model_no_ext");
        std::fs::write(&model_file, b"data").unwrap();

        let mut reg = dir.registry();
        let entry = reg.register_local("noext", &model_file).unwrap();
        assert!(entry.format.is_none());
    }

    #[test]
    fn test_load_model_found() {
        let dir = temp_registry_dir();
        let model_file = dir.join("loadable.bin");
        std::fs::write(&model_file, b"model content").unwrap();

        let mut reg = dir.registry();
        let entry = ModelEntry::new("loadable", "1.0", "", 100, ModelSource::local(&model_file))
            .with_local_path(&model_file);
        reg.add_model(entry);

        let path = reg.load("loadable").unwrap();
        assert_eq!(path, model_file);
    }

    #[test]
    fn test_load_model_not_found() {
        let dir = temp_registry_dir();
        let reg = dir.registry();
        assert!(reg.load("nonexistent").is_err());
    }

    #[test]
    fn test_load_model_no_local_path() {
        let dir = temp_registry_dir();
        let mut reg = dir.registry();
        let entry = ModelEntry::new("no-path", "1.0", "", 0, ModelSource::huggingface("org/model"));
        reg.add_model(entry);
        assert!(reg.load("no-path").is_err());
    }

    #[test]
    fn test_load_model_file_missing() {
        let dir = temp_registry_dir();
        let mut reg = dir.registry();
        let missing = dir.join("gone_xyz");
        let entry = ModelEntry::new("gone", "1.0", "", 0, ModelSource::local(&missing))
            .with_local_path(&missing);
        reg.add_model(entry);
        assert!(reg.load("gone").is_err());
    }

    #[test]
    fn test_verify_no_local_path() {
        let dir = temp_registry_dir();
        let reg = dir.registry();
        let entry = ModelEntry::new("no-path", "1.0", "sha", 0, ModelSource::huggingface("org/m"));
        assert!(reg.verify(&entry).is_err());
    }

    #[test]
    fn test_verify_file_missing() {
        let dir = temp_registry_dir();
        let reg = dir.registry();
        let missing = dir.join("nope_xyz_verify");
        let entry = ModelEntry::new("missing", "1.0", "sha", 0, ModelSource::local(&missing))
            .with_local_path(&missing);
        let result = reg.verify(&entry).unwrap();
        assert!(!result); // file doesn't exist
    }

    #[test]
    fn test_verify_empty_checksum() {
        let dir = temp_registry_dir();
        let model_file = dir.join("verify_empty.bin");
        std::fs::write(&model_file, b"data").unwrap();

        let reg = dir.registry();
        let entry = ModelEntry::new("verify-empty", "1.0", "", 0, ModelSource::local(&model_file))
            .with_local_path(&model_file);
        let result = reg.verify(&entry).unwrap();
        assert!(result); // empty checksum always passes
    }

    #[test]
    fn test_verify_checksum_match() {
        let dir = temp_registry_dir();
        let model_file = dir.join("verify_match.bin");
        std::fs::write(&model_file, b"test content for sha256").unwrap();

        // Compute actual sha256
        let computed = OfflineModelRegistry::compute_file_sha256(&model_file).unwrap();

        let reg = dir.registry();
        let entry =
            ModelEntry::new("verify-match", "1.0", &computed, 0, ModelSource::local(&model_file))
                .with_local_path(&model_file);
        let result = reg.verify(&entry).unwrap();
        assert!(result);
    }

    #[test]
    fn test_verify_checksum_mismatch() {
        let dir = temp_registry_dir();
        let model_file = dir.join("verify_mismatch.bin");
        std::fs::write(&model_file, b"some data").unwrap();

        let reg = dir.registry();
        let entry =
            ModelEntry::new("mismatch", "1.0", "wrong_hash", 0, ModelSource::local(&model_file))
                .with_local_path(&model_file);
        let result = reg.verify(&entry).unwrap();
        assert!(!result);
    }

    #[test]
    fn test_total_size() {
        let dir = temp_registry_dir();
        let mut reg = dir.registry();
        reg.add_model(ModelEntry::new("m1", "1.0", "", 100, ModelSource::local("/tmp")));
        reg.add_model(ModelEntry::new("m2", "1.0", "", 200, ModelSource::local("/tmp")));
        reg.add_model(ModelEntry::new("m3", "1.0", "", 300, ModelSource::local("/tmp")));
        assert_eq!(reg.total_size(), 600);
    }
}