Skip to main content

oxirs_embed/model_zoo/
loader.rs

1//! SHA-256-verified checkpoint loading via the existing `ModelRepository`.
2//!
3//! # Load flow (Approach A — materialize-then-load)
4//!
5//! 1. Look up the manifest in the registry.
6//! 2. Optionally enforce the license gate.
7//! 3. Validate the `model_type` string against the set that
8//!    `persistence::ModelRepository` can dispatch.
9//! 4. Resolve the checkpoint source path (`file:///...` only in default
10//!    features).
11//! 5. Read the raw bytes from disk.
12//! 6. Verify SHA-256 (unless the manifest carries the sentinel `"PLACEHOLDER"`
13//!    which marks catalog entries whose checkpoint does not ship with the crate).
14//! 7. Write the materialised checkpoint into a subdirectory of `base_dir` that looks
15//!    exactly like what `ModelRepository::new + scan_models` expects:
16//!    ```text
17//!    <base_dir>/<name>/
18//!        model.bin          ← the checkpoint bytes
19//!        model_type.json    ← JSON-encoded model_type string
20//!        metadata.json      ← minimal ModelMetadata
21//!    ```
22//! 8. Construct a `ModelRepository` rooted at `<base_dir>` and call
23//!    `load_model(&manifest.name)`.
24
25use std::io::{self, Write as _};
26use std::path::{Path, PathBuf};
27
28use sha2::{Digest, Sha256};
29use thiserror::Error;
30
31use crate::model_zoo::registry::ModelZoo;
32use crate::EmbeddingModel;
33
34// ---------------------------------------------------------------------------
35// Error type
36// ---------------------------------------------------------------------------
37
38/// Errors that can occur during model zoo operations.
39#[derive(Debug, Error)]
40pub enum ModelZooError {
41    /// The requested model name is not present in the registry.
42    #[error("model '{0}' not found in registry")]
43    NotFound(String),
44
45    /// The model's license requires explicit acceptance.
46    #[error("license '{license}' requires acceptance — set accept_license=true")]
47    LicenseNotAccepted { license: String },
48
49    /// The downloaded/read file's SHA-256 does not match the manifest.
50    #[error("SHA256 mismatch: expected {expected}, got {actual}")]
51    ChecksumMismatch { expected: String, actual: String },
52
53    /// The manifest declares a model type that the persistence layer cannot
54    /// handle.
55    #[error(
56        "unsupported model type '{0}' — supported: TransE, DistMult, ComplEx, RotatE, HoLE, GNNEmbedding"
57    )]
58    UnsupportedModelType(String),
59
60    /// Failed to parse a TOML manifest.
61    #[error("manifest parse error: {0}")]
62    ManifestParse(String),
63
64    /// I/O error (file read, temp dir creation, etc.).
65    #[error(transparent)]
66    Io(#[from] io::Error),
67
68    /// Error propagated from the underlying persistence layer.
69    #[error(transparent)]
70    Persistence(#[from] anyhow::Error),
71}
72
73// ---------------------------------------------------------------------------
74// Supported model types (mirrors persistence.rs dispatch table)
75// ---------------------------------------------------------------------------
76
/// Model type strings accepted by `persistence::ModelRepository::load_model`.
///
/// This list must be kept in step with the dispatch table in `persistence.rs`.
const SUPPORTED_MODEL_TYPES: &[&str] = &[
    "TransE",
    "DistMult",
    "ComplEx",
    "RotatE",
    "HoLE",
    "GNN",
    "GNNEmbedding",
];

/// Exact (case-sensitive) membership test against [`SUPPORTED_MODEL_TYPES`].
fn is_supported_model_type(model_type: &str) -> bool {
    SUPPORTED_MODEL_TYPES
        .iter()
        .any(|&known| known == model_type)
}
91
92// ---------------------------------------------------------------------------
93// Permissive license set
94// ---------------------------------------------------------------------------
95
/// Licenses that are considered permissive (no acceptance gate required).
///
/// Entries are SPDX identifiers (plus one pre-combined expression) and are
/// compared ASCII-case-insensitively.
const PERMISSIVE_LICENSES: &[&str] = &[
    "Apache-2.0",
    "MIT",
    "MIT OR Apache-2.0",
    "BSD-2-Clause",
    "BSD-3-Clause",
    "ISC",
    "CC0-1.0",
    "Unlicense",
    "WTFPL",
];

/// Return `true` when `license` can be used without explicit acceptance.
///
/// A license is permissive when either of these holds:
/// - Ignoring surrounding whitespace and ASCII case, it matches an entry of
///   [`PERMISSIVE_LICENSES`].
/// - It is a plain SPDX disjunction such as `"Apache-2.0 OR MIT"` in which at
///   least one alternative is permissive. SPDX `OR` lets the licensee choose
///   any single alternative.
///
/// Compound terms are not interpreted: parentheses, `AND`, or `WITH`
/// exceptions. Such a license counts as permissive only if it appears verbatim
/// in the list, so the gate errs towards requiring acceptance.
fn is_permissive_license(license: &str) -> bool {
    let is_listed = |id: &str| {
        PERMISSIVE_LICENSES
            .iter()
            .any(|&known| id.eq_ignore_ascii_case(known))
    };

    let license = license.trim();
    if is_listed(license) {
        return true;
    }
    if license.contains(|c: char| c == '(' || c == ')') {
        return false;
    }

    let alternatives: Vec<&str> = license.split(" OR ").map(str::trim).collect();
    // A single alternative means there was no `OR` at all. An alternative with
    // inner whitespace is a compound term (`X AND Y`, `X WITH exc`) we do not
    // evaluate, so the whole expression is treated as non-permissive.
    alternatives.len() > 1
        && alternatives
            .iter()
            .all(|&alt| !alt.is_empty() && !alt.contains(char::is_whitespace))
        && alternatives.iter().any(|&alt| is_listed(alt))
}
114
115// ---------------------------------------------------------------------------
116// ModelZooLoader
117// ---------------------------------------------------------------------------
118
119/// Loads pretrained (or synthetic-seed) models from the [`ModelZoo`] registry.
120///
121/// # Example
122///
123/// ```rust,no_run
124/// use oxirs_embed::model_zoo::{ModelZoo, ModelZooLoader};
125///
126/// let loader = ModelZooLoader::new(std::env::temp_dir()).accept_license();
127/// // Would load a real checkpoint when source resolves to a real file:
128/// // let model = loader.load("transe-fb15k237")?;
129/// ```
130pub struct ModelZooLoader {
131    zoo: &'static ModelZoo,
132    base_dir: PathBuf,
133    accept_license: bool,
134}
135
136impl ModelZooLoader {
137    /// Create a loader that stores materialised checkpoints under `base_dir`.
138    pub fn new(base_dir: impl Into<PathBuf>) -> Self {
139        Self {
140            zoo: ModelZoo::registry(),
141            base_dir: base_dir.into(),
142            accept_license: false,
143        }
144    }
145
146    /// Create a loader backed by a custom (non-global) `ModelZoo`.
147    pub fn with_zoo(zoo: &'static ModelZoo, base_dir: impl Into<PathBuf>) -> Self {
148        Self {
149            zoo,
150            base_dir: base_dir.into(),
151            accept_license: false,
152        }
153    }
154
155    /// Signal that the caller accepts all licenses (including restrictive ones
156    /// such as `CC-BY-NC-4.0`).
157    pub fn accept_license(mut self) -> Self {
158        self.accept_license = true;
159        self
160    }
161
162    /// Load the model identified by `name` from the registry.
163    ///
164    /// Returns a heap-allocated, type-erased [`EmbeddingModel`].
165    pub fn load(&self, name: &str) -> Result<Box<dyn EmbeddingModel>, ModelZooError> {
166        // 1. Look up manifest
167        let manifest = self
168            .zoo
169            .get(name)
170            .ok_or_else(|| ModelZooError::NotFound(name.to_string()))?;
171
172        // 2. License check
173        if !self.accept_license && !is_permissive_license(&manifest.license) {
174            return Err(ModelZooError::LicenseNotAccepted {
175                license: manifest.license.clone(),
176            });
177        }
178
179        // 3. Model type validation
180        if !is_supported_model_type(&manifest.model_type) {
181            return Err(ModelZooError::UnsupportedModelType(
182                manifest.model_type.clone(),
183            ));
184        }
185
186        // 4. Resolve source path
187        let source_path = resolve_source_path(&manifest.source, &self.base_dir)?;
188
189        // 5. Read bytes
190        let bytes = std::fs::read(&source_path)?;
191
192        // 6. SHA-256 verification (skip sentinel "PLACEHOLDER")
193        if manifest.sha256 != "PLACEHOLDER" {
194            Self::verify_sha256(&bytes, &manifest.sha256)?;
195        }
196
197        // 7. Materialise the repository structure in <base_dir>/<name>/
198        let model_dir = self.base_dir.join(&manifest.name);
199        materialise_checkpoint(&model_dir, &bytes, &manifest.model_type)?;
200
201        // 8. Construct ModelRepository and load
202        let repo = crate::persistence::ModelRepository::new(&self.base_dir)?;
203        let model = repo.load_model(&manifest.name)?;
204        Ok(model)
205    }
206
207    /// Verify that `data` hashes to `expected` (hex-encoded SHA-256).
208    fn verify_sha256(data: &[u8], expected: &str) -> Result<(), ModelZooError> {
209        let mut hasher = Sha256::new();
210        hasher.update(data);
211        let digest = hasher.finalize();
212        let actual = hex::encode(digest);
213        if actual != expected.to_lowercase() {
214            return Err(ModelZooError::ChecksumMismatch {
215                expected: expected.to_string(),
216                actual,
217            });
218        }
219        Ok(())
220    }
221}
222
223// ---------------------------------------------------------------------------
224// Helpers
225// ---------------------------------------------------------------------------
226
227/// Compute the hex-encoded SHA-256 digest of `data`.
228pub fn sha256_hex(data: &[u8]) -> String {
229    let mut hasher = Sha256::new();
230    hasher.update(data);
231    hex::encode(hasher.finalize())
232}
233
234/// Resolve a `file:///` URL into a [`PathBuf`].
235///
236/// When the literal path under `file:///` does not exist on the filesystem the
237/// function tries to resolve it relative to `base_dir` (stripping the
238/// `file:///` prefix).
239fn resolve_source_path(source: &str, base_dir: &Path) -> Result<PathBuf, ModelZooError> {
240    if let Some(rest) = source.strip_prefix("file:///") {
241        let absolute = Path::new("/").join(rest);
242        if absolute.exists() {
243            return Ok(absolute);
244        }
245        // Try relative to base_dir (e.g. "seeds/transe-fb15k237.ckpt")
246        let relative = base_dir.join(rest);
247        if relative.exists() {
248            return Ok(relative);
249        }
250        // Return the absolute path even if it doesn't exist yet — callers that
251        // want to test the path-not-found code path can catch the IO error.
252        return Ok(absolute);
253    }
254
255    if source.starts_with("https://") || source.starts_with("http://") {
256        return Err(ModelZooError::Io(io::Error::new(
257            io::ErrorKind::Unsupported,
258            "HTTP download requires the 'download' feature (not enabled in default build). \
259             Use a file:/// source or enable the feature.",
260        )));
261    }
262
263    Err(ModelZooError::Io(io::Error::new(
264        io::ErrorKind::InvalidInput,
265        format!("unrecognised source scheme: {source}"),
266    )))
267}
268
269/// Write the three files that `ModelRepository::scan_models` expects.
270///
271/// ```text
272/// <model_dir>/
273///     model.bin          ← checkpoint bytes
274///     model_type.json    ← JSON-encoded type string
275///     metadata.json      ← minimal ModelMetadata
276/// ```
277fn materialise_checkpoint(
278    model_dir: &Path,
279    bytes: &[u8],
280    model_type: &str,
281) -> Result<(), ModelZooError> {
282    std::fs::create_dir_all(model_dir)?;
283
284    // model.bin
285    let mut f = std::fs::File::create(model_dir.join("model.bin"))?;
286    f.write_all(bytes)?;
287
288    // model_type.json
289    let type_json = serde_json::to_string(model_type)
290        .map_err(|e| ModelZooError::Io(io::Error::new(io::ErrorKind::Other, e.to_string())))?;
291    std::fs::write(model_dir.join("model_type.json"), &type_json)?;
292
293    // metadata.json — use the same struct that ModelRepository expects
294    let metadata = crate::persistence::ModelMetadata::default();
295    let meta_json = serde_json::to_string_pretty(&metadata)
296        .map_err(|e| ModelZooError::Io(io::Error::new(io::ErrorKind::Other, e.to_string())))?;
297    std::fs::write(model_dir.join("metadata.json"), &meta_json)?;
298
299    Ok(())
300}
301
302// ---------------------------------------------------------------------------
303// Unit tests
304// ---------------------------------------------------------------------------
305
#[cfg(test)]
mod tests {
    use super::*;

    /// Per-process scratch directory under the system temp dir.
    ///
    /// Including the process id keeps concurrent test binaries from stomping
    /// on each other.
    fn scratch_dir(tag: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!("oxirs_{tag}_{}", std::process::id()))
    }

    #[test]
    fn test_sha256_hex_deterministic() {
        let data = b"hello world";
        let h1 = sha256_hex(data);
        let h2 = sha256_hex(data);
        assert_eq!(h1, h2);
        assert_eq!(h1.len(), 64); // 32 bytes → 64 hex chars
    }

    #[test]
    fn test_sha256_hex_known_vector() {
        // FIPS 180-2 test vector for the message "abc".
        assert_eq!(
            sha256_hex(b"abc"),
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }

    #[test]
    fn test_verify_sha256_ok() {
        let data = b"test data for hashing";
        let expected = sha256_hex(data);
        ModelZooLoader::verify_sha256(data, &expected).expect("verification should pass");
    }

    #[test]
    fn test_verify_sha256_accepts_uppercase_hex() {
        let data = b"test data for hashing";
        let expected = sha256_hex(data).to_uppercase();
        ModelZooLoader::verify_sha256(data, &expected)
            .expect("hex digest comparison should ignore case");
    }

    #[test]
    fn test_verify_sha256_mismatch() {
        let data = b"test data for hashing";
        let wrong_hash = "0".repeat(64);
        let result = ModelZooLoader::verify_sha256(data, &wrong_hash);
        assert!(result.is_err());
        match result {
            Err(ModelZooError::ChecksumMismatch { expected, actual }) => {
                assert_eq!(expected, wrong_hash);
                assert_ne!(actual, wrong_hash);
            }
            other => panic!("Expected ChecksumMismatch, got {other:?}"),
        }
    }

    #[test]
    fn test_is_supported_model_type() {
        for ty in SUPPORTED_MODEL_TYPES {
            assert!(is_supported_model_type(ty), "{ty} should be supported");
        }
        assert!(!is_supported_model_type("Bogus"));
        assert!(!is_supported_model_type("TransE2"));
    }

    #[test]
    fn test_is_permissive_license() {
        assert!(is_permissive_license("Apache-2.0"));
        assert!(is_permissive_license("MIT"));
        assert!(is_permissive_license("MIT OR Apache-2.0"));
        assert!(!is_permissive_license("CC-BY-NC-4.0"));
        assert!(!is_permissive_license("Proprietary"));
    }

    #[test]
    fn test_resolve_source_path_file_scheme() {
        // Use a path that definitely exists. Build the URL as `file:///` plus
        // the path without its leading `/`:
        // - Unix: yields `file:///tmp`.
        // - Windows: yields `file:///C:\...`. Plain `"file://" + path` would
        //   lack the third slash there and be rejected as an unknown scheme.
        let base = std::env::temp_dir();
        let existing = base.to_string_lossy();
        let source = format!("file:///{}", existing.trim_start_matches('/'));
        let resolved = resolve_source_path(&source, &base).expect("file:/// source should resolve");
        assert!(resolved.exists(), "{} should exist", resolved.display());
    }

    #[test]
    fn test_resolve_source_path_http_error() {
        let base = std::env::temp_dir();
        let result = resolve_source_path("https://example.com/model.ckpt", &base);
        assert!(result.is_err());
        let msg = result.err().map(|e| e.to_string()).unwrap_or_default();
        assert!(msg.contains("download") || msg.contains("HTTP"));
    }

    #[test]
    fn test_resolve_source_path_unknown_scheme() {
        let base = std::env::temp_dir();
        let result = resolve_source_path("s3://bucket/model.ckpt", &base);
        assert!(result.is_err());
    }

    #[test]
    fn test_materialise_checkpoint_creates_files() {
        let tmp = scratch_dir("materialise_test");
        let model_dir = tmp.join("test_model");
        let bytes = b"fake checkpoint bytes";

        materialise_checkpoint(&model_dir, bytes, "TransE").expect("materialise ok");

        assert!(model_dir.join("model.bin").exists());
        assert!(model_dir.join("model_type.json").exists());
        assert!(model_dir.join("metadata.json").exists());

        let written = std::fs::read(model_dir.join("model.bin")).expect("read model.bin");
        assert_eq!(written, bytes);

        // Verify model_type.json content
        let raw = std::fs::read_to_string(model_dir.join("model_type.json")).expect("read");
        let ty: String = serde_json::from_str(&raw).expect("parse");
        assert_eq!(ty, "TransE");

        std::fs::remove_dir_all(&tmp).ok();
    }

    #[test]
    fn test_load_not_found() {
        let loader = ModelZooLoader::new(std::env::temp_dir()).accept_license();
        let result = loader.load("definitely-does-not-exist");
        assert!(matches!(result, Err(ModelZooError::NotFound(_))));
    }

    #[test]
    fn test_loader_license_refused() {
        use crate::model_zoo::manifest::ModelManifest;
        use crate::model_zoo::registry::ModelZoo;

        // Build a custom zoo with a non-permissive license entry
        let tmp_dir = scratch_dir("zoo_license_test");
        std::fs::create_dir_all(&tmp_dir).expect("create temp dir");

        let manifest = ModelManifest {
            name: "restricted-model".to_string(),
            model_type: "TransE".to_string(),
            dataset: "TestDS".to_string(),
            dimensions: 10,
            entities: 5,
            relations: 2,
            sha256: "PLACEHOLDER".to_string(),
            source: "file:///nonexistent.ckpt".to_string(),
            license: "CC-BY-NC-4.0".to_string(),
            citation: "Test".to_string(),
            version: "1.0.0".to_string(),
            created: "2026-05-01".to_string(),
            notes: None,
        };

        let toml_str = toml::to_string(&manifest).expect("serialize");
        std::fs::write(tmp_dir.join("restricted-model.toml"), toml_str).expect("write");

        let zoo = ModelZoo::with_manifest_dir(&tmp_dir).expect("build zoo");
        // Leak the zoo to get a 'static reference for testing
        let zoo_ref: &'static ModelZoo = Box::leak(Box::new(zoo));

        let loader = ModelZooLoader::with_zoo(zoo_ref, std::env::temp_dir());
        let result = loader.load("restricted-model");

        assert!(matches!(
            result,
            Err(ModelZooError::LicenseNotAccepted { .. })
        ));

        std::fs::remove_dir_all(&tmp_dir).ok();
    }
}