1use std::io::{self, Write as _};
26use std::path::{Path, PathBuf};
27
28use sha2::{Digest, Sha256};
29use thiserror::Error;
30
31use crate::model_zoo::registry::ModelZoo;
32use crate::EmbeddingModel;
33
34#[derive(Debug, Error)]
40pub enum ModelZooError {
41 #[error("model '{0}' not found in registry")]
43 NotFound(String),
44
45 #[error("license '{license}' requires acceptance — set accept_license=true")]
47 LicenseNotAccepted { license: String },
48
49 #[error("SHA256 mismatch: expected {expected}, got {actual}")]
51 ChecksumMismatch { expected: String, actual: String },
52
53 #[error(
56 "unsupported model type '{0}' — supported: TransE, DistMult, ComplEx, RotatE, HoLE, GNNEmbedding"
57 )]
58 UnsupportedModelType(String),
59
60 #[error("manifest parse error: {0}")]
62 ManifestParse(String),
63
64 #[error(transparent)]
66 Io(#[from] io::Error),
67
68 #[error(transparent)]
70 Persistence(#[from] anyhow::Error),
71}
72
/// Model type identifiers the loader accepts. Comparison against this list is
/// exact and case-sensitive.
const SUPPORTED_MODEL_TYPES: &[&str] = &[
    "TransE",
    "DistMult",
    "ComplEx",
    "RotatE",
    "HoLE",
    "GNN",
    "GNNEmbedding",
];

/// Returns `true` when `model_type` is exactly one of [`SUPPORTED_MODEL_TYPES`].
fn is_supported_model_type(model_type: &str) -> bool {
    SUPPORTED_MODEL_TYPES
        .iter()
        .any(|&known| known == model_type)
}
91
/// License identifiers that may be loaded without explicit acceptance.
const PERMISSIVE_LICENSES: &[&str] = &[
    "Apache-2.0",
    "MIT",
    "MIT OR Apache-2.0",
    "BSD-2-Clause",
    "BSD-3-Clause",
    "ISC",
    "CC0-1.0",
    "Unlicense",
    "WTFPL",
];

/// Returns `true` when `license` equals an entry of [`PERMISSIVE_LICENSES`],
/// ignoring ASCII case. The whole string must match; expressions are not parsed.
fn is_permissive_license(license: &str) -> bool {
    for candidate in PERMISSIVE_LICENSES {
        if candidate.eq_ignore_ascii_case(license) {
            return true;
        }
    }
    false
}
114
/// Loads models described by a [`ModelZoo`] registry, materialising their
/// checkpoints under a base directory.
pub struct ModelZooLoader {
    /// Registry used to look up manifests by model name.
    zoo: &'static ModelZoo,
    /// Directory under which each model's files are written (in a sub-directory
    /// named after the model) and against which `file:///` sources may be
    /// resolved as a fallback.
    base_dir: PathBuf,
    /// When `false`, `load` rejects manifests whose license is not in
    /// `PERMISSIVE_LICENSES`.
    accept_license: bool,
}
135
136impl ModelZooLoader {
137 pub fn new(base_dir: impl Into<PathBuf>) -> Self {
139 Self {
140 zoo: ModelZoo::registry(),
141 base_dir: base_dir.into(),
142 accept_license: false,
143 }
144 }
145
146 pub fn with_zoo(zoo: &'static ModelZoo, base_dir: impl Into<PathBuf>) -> Self {
148 Self {
149 zoo,
150 base_dir: base_dir.into(),
151 accept_license: false,
152 }
153 }
154
155 pub fn accept_license(mut self) -> Self {
158 self.accept_license = true;
159 self
160 }
161
162 pub fn load(&self, name: &str) -> Result<Box<dyn EmbeddingModel>, ModelZooError> {
166 let manifest = self
168 .zoo
169 .get(name)
170 .ok_or_else(|| ModelZooError::NotFound(name.to_string()))?;
171
172 if !self.accept_license && !is_permissive_license(&manifest.license) {
174 return Err(ModelZooError::LicenseNotAccepted {
175 license: manifest.license.clone(),
176 });
177 }
178
179 if !is_supported_model_type(&manifest.model_type) {
181 return Err(ModelZooError::UnsupportedModelType(
182 manifest.model_type.clone(),
183 ));
184 }
185
186 let source_path = resolve_source_path(&manifest.source, &self.base_dir)?;
188
189 let bytes = std::fs::read(&source_path)?;
191
192 if manifest.sha256 != "PLACEHOLDER" {
194 Self::verify_sha256(&bytes, &manifest.sha256)?;
195 }
196
197 let model_dir = self.base_dir.join(&manifest.name);
199 materialise_checkpoint(&model_dir, &bytes, &manifest.model_type)?;
200
201 let repo = crate::persistence::ModelRepository::new(&self.base_dir)?;
203 let model = repo.load_model(&manifest.name)?;
204 Ok(model)
205 }
206
207 fn verify_sha256(data: &[u8], expected: &str) -> Result<(), ModelZooError> {
209 let mut hasher = Sha256::new();
210 hasher.update(data);
211 let digest = hasher.finalize();
212 let actual = hex::encode(digest);
213 if actual != expected.to_lowercase() {
214 return Err(ModelZooError::ChecksumMismatch {
215 expected: expected.to_string(),
216 actual,
217 });
218 }
219 Ok(())
220 }
221}
222
223pub fn sha256_hex(data: &[u8]) -> String {
229 let mut hasher = Sha256::new();
230 hasher.update(data);
231 hex::encode(hasher.finalize())
232}
233
234fn resolve_source_path(source: &str, base_dir: &Path) -> Result<PathBuf, ModelZooError> {
240 if let Some(rest) = source.strip_prefix("file:///") {
241 let absolute = Path::new("/").join(rest);
242 if absolute.exists() {
243 return Ok(absolute);
244 }
245 let relative = base_dir.join(rest);
247 if relative.exists() {
248 return Ok(relative);
249 }
250 return Ok(absolute);
253 }
254
255 if source.starts_with("https://") || source.starts_with("http://") {
256 return Err(ModelZooError::Io(io::Error::new(
257 io::ErrorKind::Unsupported,
258 "HTTP download requires the 'download' feature (not enabled in default build). \
259 Use a file:/// source or enable the feature.",
260 )));
261 }
262
263 Err(ModelZooError::Io(io::Error::new(
264 io::ErrorKind::InvalidInput,
265 format!("unrecognised source scheme: {source}"),
266 )))
267}
268
269fn materialise_checkpoint(
278 model_dir: &Path,
279 bytes: &[u8],
280 model_type: &str,
281) -> Result<(), ModelZooError> {
282 std::fs::create_dir_all(model_dir)?;
283
284 let mut f = std::fs::File::create(model_dir.join("model.bin"))?;
286 f.write_all(bytes)?;
287
288 let type_json = serde_json::to_string(model_type)
290 .map_err(|e| ModelZooError::Io(io::Error::new(io::ErrorKind::Other, e.to_string())))?;
291 std::fs::write(model_dir.join("model_type.json"), &type_json)?;
292
293 let metadata = crate::persistence::ModelMetadata::default();
295 let meta_json = serde_json::to_string_pretty(&metadata)
296 .map_err(|e| ModelZooError::Io(io::Error::new(io::ErrorKind::Other, e.to_string())))?;
297 std::fs::write(model_dir.join("metadata.json"), &meta_json)?;
298
299 Ok(())
300}
301
/// Unit tests for hashing, allow-lists, source resolution, checkpoint
/// materialisation and the loader's early error paths.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sha256_hex_deterministic() {
        let data = b"hello world";
        let h1 = sha256_hex(data);
        let h2 = sha256_hex(data);
        assert_eq!(h1, h2);
        // 32 digest bytes -> 64 hex characters.
        assert_eq!(h1.len(), 64);
    }

    #[test]
    fn test_verify_sha256_ok() {
        let data = b"test data for hashing";
        let expected = sha256_hex(data);
        ModelZooLoader::verify_sha256(data, &expected).expect("verification should pass");
    }

    #[test]
    fn test_verify_sha256_mismatch() {
        let data = b"test data for hashing";
        let wrong_hash = "0".repeat(64);
        // The match is exhaustive over the outcome, so no separate `is_err` check is needed.
        match ModelZooLoader::verify_sha256(data, &wrong_hash) {
            Err(ModelZooError::ChecksumMismatch { expected, actual }) => {
                assert_eq!(expected, wrong_hash);
                assert_ne!(actual, wrong_hash);
            }
            other => panic!("Expected ChecksumMismatch, got {other:?}"),
        }
    }

    #[test]
    fn test_is_supported_model_type() {
        for ty in SUPPORTED_MODEL_TYPES {
            assert!(is_supported_model_type(ty), "{ty} should be supported");
        }
        assert!(!is_supported_model_type("Bogus"));
        assert!(!is_supported_model_type("TransE2"));
    }

    #[test]
    fn test_is_permissive_license() {
        assert!(is_permissive_license("Apache-2.0"));
        assert!(is_permissive_license("MIT"));
        assert!(is_permissive_license("MIT OR Apache-2.0"));
        assert!(!is_permissive_license("CC-BY-NC-4.0"));
        assert!(!is_permissive_license("Proprietary"));
    }

    #[test]
    fn test_resolve_source_path_file_scheme() {
        let base = std::env::temp_dir();
        let existing = base.to_string_lossy().to_string();
        // Always produce exactly one `file:///` prefix. The temp dir starts
        // with `/` on Unix but with a drive letter on Windows, where plain
        // `file://{existing}` would never match the `file:///` scheme.
        let source = format!("file:///{}", existing.trim_start_matches('/'));
        let result = resolve_source_path(&source, &base);
        assert!(result.is_ok(), "expected Ok for {source}, got {result:?}");
    }

    #[test]
    fn test_resolve_source_path_http_error() {
        let base = std::env::temp_dir();
        let result = resolve_source_path("https://example.com/model.ckpt", &base);
        assert!(result.is_err());
        let msg = result.err().map(|e| e.to_string()).unwrap_or_default();
        assert!(msg.contains("download") || msg.contains("HTTP"));
    }

    #[test]
    fn test_resolve_source_path_unknown_scheme() {
        let base = std::env::temp_dir();
        let result = resolve_source_path("s3://bucket/model.ckpt", &base);
        assert!(result.is_err());
    }

    #[test]
    fn test_materialise_checkpoint_creates_files() {
        let tmp = std::env::temp_dir().join("oxirs_materialise_test");
        let model_dir = tmp.join("test_model");
        let bytes = b"fake checkpoint bytes";

        materialise_checkpoint(&model_dir, bytes, "TransE").expect("materialise ok");

        assert!(model_dir.join("model.bin").exists());
        assert!(model_dir.join("model_type.json").exists());
        assert!(model_dir.join("metadata.json").exists());

        // The type file must round-trip as a JSON string.
        let raw = std::fs::read_to_string(model_dir.join("model_type.json")).expect("read");
        let ty: String = serde_json::from_str(&raw).expect("parse");
        assert_eq!(ty, "TransE");

        // Best-effort cleanup; a failure here must not fail the test.
        std::fs::remove_dir_all(&tmp).ok();
    }

    #[test]
    fn test_load_not_found() {
        let loader = ModelZooLoader::new(std::env::temp_dir()).accept_license();
        let result = loader.load("definitely-does-not-exist");
        assert!(matches!(result, Err(ModelZooError::NotFound(_))));
    }

    #[test]
    fn test_loader_license_refused() {
        use crate::model_zoo::manifest::ModelManifest;
        use crate::model_zoo::registry::ModelZoo;

        let tmp_dir = std::env::temp_dir().join("oxirs_zoo_license_test");
        std::fs::create_dir_all(&tmp_dir).expect("create temp dir");

        // Non-permissive license; the source is never read because the
        // license gate fires first.
        let manifest = ModelManifest {
            name: "restricted-model".to_string(),
            model_type: "TransE".to_string(),
            dataset: "TestDS".to_string(),
            dimensions: 10,
            entities: 5,
            relations: 2,
            sha256: "PLACEHOLDER".to_string(),
            source: "file:///nonexistent.ckpt".to_string(),
            license: "CC-BY-NC-4.0".to_string(),
            citation: "Test".to_string(),
            version: "1.0.0".to_string(),
            created: "2026-05-01".to_string(),
            notes: None,
        };

        let toml_str = toml::to_string(&manifest).expect("serialize");
        std::fs::write(tmp_dir.join("restricted-model.toml"), toml_str).expect("write");

        let zoo = ModelZoo::with_manifest_dir(&tmp_dir).expect("build zoo");
        // The loader requires a `'static` registry; leaking is acceptable in a test.
        let zoo_ref: &'static ModelZoo = Box::leak(Box::new(zoo));

        let loader = ModelZooLoader::with_zoo(zoo_ref, std::env::temp_dir());
        let result = loader.load("restricted-model");

        assert!(matches!(
            result,
            Err(ModelZooError::LicenseNotAccepted { .. })
        ));

        std::fs::remove_dir_all(&tmp_dir).ok();
    }
}