entrenar/sovereign/registry/
offline.rs1use sha2::{Digest, Sha256};
4use std::fs;
5use std::io::Read;
6use std::path::{Path, PathBuf};
7
8use crate::error::{Error, Result};
9
10use super::manifest::RegistryManifest;
11use super::types::{ModelEntry, ModelSource};
12
13#[derive(Debug)]
15pub struct OfflineModelRegistry {
16 pub root_path: PathBuf,
18 pub manifest: RegistryManifest,
20 manifest_path: PathBuf,
22}
23
24impl OfflineModelRegistry {
25 pub fn new(root: PathBuf) -> Self {
27 let manifest_path = root.join("manifest.json");
28 let manifest = if manifest_path.exists() {
29 Self::load_manifest(&manifest_path).unwrap_or_default()
30 } else {
31 RegistryManifest::new()
32 };
33
34 Self { root_path: root, manifest, manifest_path }
35 }
36
37 pub fn default_location() -> Self {
39 let home = dirs::home_dir().unwrap_or_else(|| PathBuf::from("."));
40 Self::new(home.join(".entrenar").join("models"))
41 }
42
43 fn load_manifest(path: &Path) -> Result<RegistryManifest> {
45 let content = fs::read_to_string(path)?;
46 serde_json::from_str(&content).map_err(|e| Error::Io(format!("Invalid manifest data: {e}")))
47 }
48
49 pub fn save_manifest(&self) -> Result<()> {
51 if let Some(parent) = self.manifest_path.parent() {
53 fs::create_dir_all(parent)?;
54 }
55
56 let content = serde_json::to_string_pretty(&self.manifest)
57 .map_err(|e| Error::Io(format!("Failed to serialize manifest: {e}")))?;
58 fs::write(&self.manifest_path, content)?;
59 Ok(())
60 }
61
62 pub fn add_model(&mut self, entry: ModelEntry) {
64 self.manifest.add(entry);
65 }
66
67 pub fn mirror_from_hub(&mut self, repo_id: &str) -> Result<ModelEntry> {
72 let name = repo_id.split('/').next_back().unwrap_or(repo_id);
74 let local_path = self.root_path.join(name);
75
76 let entry = ModelEntry::new(
77 name,
78 "1.0",
79 "", 0, ModelSource::huggingface(repo_id),
82 )
83 .with_local_path(&local_path);
84
85 self.manifest.add(entry.clone());
86 Ok(entry)
87 }
88
89 pub fn register_local(&mut self, name: &str, path: &Path) -> Result<ModelEntry> {
91 if !path.exists() {
92 return Err(Error::ConfigError(format!("Model file not found: {}", path.display())));
93 }
94
95 let metadata = fs::metadata(path)?;
96 let size_bytes = metadata.len();
97
98 let sha256 = Self::compute_file_sha256(path)?;
100
101 let format = path.extension().and_then(|e| e.to_str()).map(String::from);
103
104 let entry = ModelEntry::new(name, "local", sha256, size_bytes, ModelSource::local(path))
105 .with_local_path(path);
106
107 let entry = if let Some(fmt) = format { entry.with_format(fmt) } else { entry };
108
109 self.manifest.add(entry.clone());
110 self.manifest.mark_synced();
111 self.save_manifest()?;
112
113 Ok(entry)
114 }
115
116 fn compute_file_sha256(path: &Path) -> Result<String> {
118 let mut file = fs::File::open(path)?;
119 let mut hasher = Sha256::new();
120 let mut buffer = [0u8; 8192];
121
122 loop {
123 let bytes_read = file.read(&mut buffer)?;
124 if bytes_read == 0 {
125 break;
126 }
127 hasher.update(&buffer[..bytes_read]);
128 }
129
130 Ok(format!("{:x}", hasher.finalize()))
131 }
132
133 pub fn load(&self, name: &str) -> Result<PathBuf> {
135 let entry = self
136 .manifest
137 .find(name)
138 .ok_or_else(|| Error::ConfigError(format!("Model not found: {name}")))?;
139
140 let path = entry
141 .local_path
142 .as_ref()
143 .ok_or_else(|| Error::ConfigError(format!("Model not available locally: {name}")))?;
144
145 if !path.exists() {
146 return Err(Error::ConfigError(format!("Model file missing: {}", path.display())));
147 }
148
149 Ok(path.clone())
150 }
151
152 pub fn verify(&self, entry: &ModelEntry) -> Result<bool> {
154 let path = entry
155 .local_path
156 .as_ref()
157 .ok_or_else(|| Error::ConfigError("Model has no local path".into()))?;
158
159 if !path.exists() {
160 return Ok(false);
161 }
162
163 if entry.sha256.is_empty() {
164 return Ok(true);
166 }
167
168 let computed = Self::compute_file_sha256(path)?;
169 Ok(computed == entry.sha256)
170 }
171
172 pub fn list_available(&self) -> Vec<&ModelEntry> {
174 self.manifest.available()
175 }
176
177 pub fn list_all(&self) -> &[ModelEntry] {
179 &self.manifest.models
180 }
181
182 pub fn get(&self, name: &str) -> Option<&ModelEntry> {
184 self.manifest.find(name)
185 }
186
187 pub fn remove(&mut self, name: &str) -> Option<ModelEntry> {
189 let pos = self.manifest.models.iter().position(|m| m.name == name)?;
190 Some(self.manifest.models.remove(pos))
191 }
192
193 pub fn total_size(&self) -> u64 {
195 self.manifest.total_size_bytes()
196 }
197
198 pub fn root(&self) -> &Path {
200 &self.root_path
201 }
202}
203
#[cfg(test)]
#[allow(clippy::unwrap_used)] // Panicking on setup failure is the desired behavior in tests.
mod tests {
    use super::*;
    use std::io::Write;
    use std::sync::atomic::{AtomicU64, Ordering};

    /// Distinguishes scratch directories created by tests running in the same process.
    static TEST_COUNTER: AtomicU64 = AtomicU64::new(0);

    /// Scratch directory for one test.
    ///
    /// The directory is removed when the guard is dropped, so it is cleaned up
    /// even when an assertion in the test panics.
    struct ScratchDir(PathBuf);

    impl ScratchDir {
        /// Creates a fresh, empty directory unique to this process and call.
        fn new() -> Self {
            let id = TEST_COUNTER.fetch_add(1, Ordering::SeqCst);
            let dir = std::env::temp_dir()
                .join(format!("entrenar_offline_test_{}_{id}", std::process::id()));
            let _ = std::fs::remove_dir_all(&dir);
            std::fs::create_dir_all(&dir).unwrap();
            Self(dir)
        }

        /// Path of the scratch directory itself.
        fn path(&self) -> &Path {
            &self.0
        }

        /// Path of `name` inside the scratch directory (the file is not created).
        fn join(&self, name: &str) -> PathBuf {
            self.0.join(name)
        }

        /// Opens a registry rooted at the scratch directory.
        fn registry(&self) -> OfflineModelRegistry {
            OfflineModelRegistry::new(self.0.clone())
        }
    }

    impl Drop for ScratchDir {
        fn drop(&mut self) {
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    #[test]
    fn test_new_empty_registry() {
        let scratch = ScratchDir::new();
        let reg = scratch.registry();
        assert!(reg.manifest.models.is_empty());
        assert_eq!(reg.root(), scratch.path());
        assert_eq!(reg.total_size(), 0);
        assert!(reg.list_all().is_empty());
        assert!(reg.list_available().is_empty());
    }

    #[test]
    fn test_new_loads_existing_manifest() {
        let scratch = ScratchDir::new();
        let content = serde_json::to_string_pretty(&RegistryManifest::new()).unwrap();
        std::fs::write(scratch.join("manifest.json"), content).unwrap();

        let reg = scratch.registry();
        assert!(reg.manifest.models.is_empty());
    }

    #[test]
    fn test_new_with_corrupted_manifest_falls_back() {
        let scratch = ScratchDir::new();
        std::fs::write(scratch.join("manifest.json"), "not valid json").unwrap();

        // An unparsable manifest must not prevent the registry from opening.
        let reg = scratch.registry();
        assert!(reg.manifest.models.is_empty());
    }

    #[test]
    fn test_add_model() {
        let scratch = ScratchDir::new();
        let mut reg = scratch.registry();
        let entry =
            ModelEntry::new("test-model", "1.0", "abc123", 1024, ModelSource::local("/tmp/model"));
        reg.add_model(entry);
        assert_eq!(reg.list_all().len(), 1);
        assert_eq!(reg.list_all()[0].name, "test-model");
    }

    #[test]
    fn test_get_model() {
        let scratch = ScratchDir::new();
        let mut reg = scratch.registry();
        reg.add_model(ModelEntry::new("mymodel", "2.0", "sha", 2048, ModelSource::local("/tmp/m")));

        assert!(reg.get("mymodel").is_some());
        assert_eq!(reg.get("mymodel").unwrap().version, "2.0");
        assert!(reg.get("nonexistent").is_none());
    }

    #[test]
    fn test_remove_model() {
        let scratch = ScratchDir::new();
        let mut reg = scratch.registry();
        reg.add_model(ModelEntry::new("removeme", "1.0", "hash", 512, ModelSource::local("/tmp")));
        assert_eq!(reg.list_all().len(), 1);

        let removed = reg.remove("removeme");
        assert!(removed.is_some());
        assert_eq!(removed.unwrap().name, "removeme");
        assert!(reg.list_all().is_empty());

        // Removing an unknown name is a no-op.
        assert!(reg.remove("nonexistent").is_none());
    }

    #[test]
    fn test_save_manifest() {
        let scratch = ScratchDir::new();
        let mut reg = scratch.registry();
        reg.add_model(ModelEntry::new("saved", "1.0", "sha256", 100, ModelSource::local("/tmp")));
        reg.save_manifest().unwrap();

        assert!(scratch.join("manifest.json").exists());

        // A second registry opened on the same directory sees the saved entry.
        let reloaded = scratch.registry();
        assert_eq!(reloaded.list_all().len(), 1);
        assert_eq!(reloaded.list_all()[0].name, "saved");
    }

    #[test]
    fn test_mirror_from_hub() {
        let scratch = ScratchDir::new();
        let mut reg = scratch.registry();
        let entry = reg.mirror_from_hub("org/my-model").unwrap();
        assert_eq!(entry.name, "my-model");
        assert_eq!(reg.list_all().len(), 1);
    }

    #[test]
    fn test_mirror_from_hub_no_slash() {
        let scratch = ScratchDir::new();
        let mut reg = scratch.registry();
        let entry = reg.mirror_from_hub("simple-model").unwrap();
        assert_eq!(entry.name, "simple-model");
    }

    #[test]
    fn test_register_local_file() {
        let scratch = ScratchDir::new();
        let model_file = scratch.join("model.safetensors");
        {
            let mut f = std::fs::File::create(&model_file).unwrap();
            f.write_all(b"fake model data for testing").unwrap();
        }

        let mut reg = scratch.registry();
        let entry = reg.register_local("local-model", &model_file).unwrap();
        assert_eq!(entry.name, "local-model");
        assert_eq!(entry.version, "local");
        assert!(!entry.sha256.is_empty());
        assert!(entry.size_bytes > 0);
        assert_eq!(entry.format, Some("safetensors".to_string()));
        assert_eq!(reg.list_all().len(), 1);
    }

    #[test]
    fn test_register_local_file_not_found() {
        let scratch = ScratchDir::new();
        let mut reg = scratch.registry();
        // A path inside the fresh scratch directory is guaranteed not to exist.
        let missing = scratch.join("nonexistent_model_xyz");
        let result = reg.register_local("missing", &missing);
        assert!(result.is_err());
    }

    #[test]
    fn test_register_local_no_extension() {
        let scratch = ScratchDir::new();
        let model_file = scratch.join("model_no_ext");
        std::fs::write(&model_file, b"data").unwrap();

        let mut reg = scratch.registry();
        let entry = reg.register_local("noext", &model_file).unwrap();
        assert!(entry.format.is_none());
    }

    #[test]
    fn test_load_model_found() {
        let scratch = ScratchDir::new();
        let model_file = scratch.join("loadable.bin");
        std::fs::write(&model_file, b"model content").unwrap();

        let mut reg = scratch.registry();
        let entry = ModelEntry::new("loadable", "1.0", "", 100, ModelSource::local(&model_file))
            .with_local_path(&model_file);
        reg.add_model(entry);

        let path = reg.load("loadable").unwrap();
        assert_eq!(path, model_file);
    }

    #[test]
    fn test_load_model_not_found() {
        let scratch = ScratchDir::new();
        let reg = scratch.registry();
        assert!(reg.load("nonexistent").is_err());
    }

    #[test]
    fn test_load_model_no_local_path() {
        let scratch = ScratchDir::new();
        let mut reg = scratch.registry();
        reg.add_model(ModelEntry::new(
            "no-path",
            "1.0",
            "",
            0,
            ModelSource::huggingface("org/model"),
        ));
        assert!(reg.load("no-path").is_err());
    }

    #[test]
    fn test_load_model_file_missing() {
        let scratch = ScratchDir::new();
        let mut reg = scratch.registry();
        let gone = scratch.join("gone_xyz");
        let entry = ModelEntry::new("gone", "1.0", "", 0, ModelSource::local(&gone))
            .with_local_path(&gone);
        reg.add_model(entry);
        assert!(reg.load("gone").is_err());
    }

    #[test]
    fn test_verify_no_local_path() {
        let scratch = ScratchDir::new();
        let reg = scratch.registry();
        let entry = ModelEntry::new("no-path", "1.0", "sha", 0, ModelSource::huggingface("org/m"));
        assert!(reg.verify(&entry).is_err());
    }

    #[test]
    fn test_verify_file_missing() {
        let scratch = ScratchDir::new();
        let reg = scratch.registry();
        let missing = scratch.join("nope_xyz_verify");
        let entry = ModelEntry::new("missing", "1.0", "sha", 0, ModelSource::local(&missing))
            .with_local_path(&missing);
        // A missing file is a failed verification, not an error.
        assert!(!reg.verify(&entry).unwrap());
    }

    #[test]
    fn test_verify_empty_checksum() {
        let scratch = ScratchDir::new();
        let model_file = scratch.join("verify_empty.bin");
        std::fs::write(&model_file, b"data").unwrap();

        let reg = scratch.registry();
        let entry = ModelEntry::new("verify-empty", "1.0", "", 0, ModelSource::local(&model_file))
            .with_local_path(&model_file);
        // With no recorded checksum, an existing file verifies successfully.
        assert!(reg.verify(&entry).unwrap());
    }

    #[test]
    fn test_verify_checksum_match() {
        let scratch = ScratchDir::new();
        let model_file = scratch.join("verify_match.bin");
        std::fs::write(&model_file, b"test content for sha256").unwrap();

        let computed = OfflineModelRegistry::compute_file_sha256(&model_file).unwrap();

        let reg = scratch.registry();
        let entry =
            ModelEntry::new("verify-match", "1.0", &computed, 0, ModelSource::local(&model_file))
                .with_local_path(&model_file);
        assert!(reg.verify(&entry).unwrap());
    }

    #[test]
    fn test_verify_checksum_mismatch() {
        let scratch = ScratchDir::new();
        let model_file = scratch.join("verify_mismatch.bin");
        std::fs::write(&model_file, b"some data").unwrap();

        let reg = scratch.registry();
        let entry =
            ModelEntry::new("mismatch", "1.0", "wrong_hash", 0, ModelSource::local(&model_file))
                .with_local_path(&model_file);
        assert!(!reg.verify(&entry).unwrap());
    }

    #[test]
    fn test_total_size() {
        let scratch = ScratchDir::new();
        let mut reg = scratch.registry();
        reg.add_model(ModelEntry::new("m1", "1.0", "", 100, ModelSource::local("/tmp")));
        reg.add_model(ModelEntry::new("m2", "1.0", "", 200, ModelSource::local("/tmp")));
        reg.add_model(ModelEntry::new("m3", "1.0", "", 300, ModelSource::local("/tmp")));
        assert_eq!(reg.total_size(), 600);
    }
}