memvid_core/
models.rs

1use std::fs;
2use std::fs::File;
3use std::io::{BufReader, Read};
4use std::path::{Component, Path, PathBuf};
5
6#[cfg(feature = "vec")]
7use std::time::Instant;
8
9use serde::Deserialize;
10
11use crate::error::{MemvidError, Result};
12
/// Overall outcome of verifying one model directory.
///
/// The variants are listed from least to most severe.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelVerificationStatus {
    /// No warning or error has been recorded.
    Ok,
    /// A non-fatal finding has been recorded.
    Warn,
    /// A check failed.
    Fail,
}

impl ModelVerificationStatus {
    /// Raises `self` to `other` when `other` is more severe.
    ///
    /// The status is never lowered: elevating `Fail` with `Ok` leaves it at `Fail`.
    fn elevate(&mut self, other: ModelVerificationStatus) {
        // Numeric severity so the comparison below reads as "keep the maximum".
        fn severity(status: ModelVerificationStatus) -> u8 {
            match status {
                ModelVerificationStatus::Ok => 0,
                ModelVerificationStatus::Warn => 1,
                ModelVerificationStatus::Fail => 2,
            }
        }

        if severity(other) > severity(*self) {
            *self = other;
        }
    }
}
34
35#[derive(Debug, Clone)]
36pub struct ModelVerification {
37    pub digest: String,
38    pub dims: Option<u32>,
39    pub quant: Option<String>,
40    pub context_length: Option<u32>,
41    pub status: ModelVerificationStatus,
42    pub load_latency_ms: Option<u128>,
43    pub path: PathBuf,
44    pub warnings: Vec<String>,
45    pub errors: Vec<String>,
46}
47
48impl ModelVerification {
49    fn from_error(path: PathBuf, err: MemvidError) -> Self {
50        let digest = digest_from_dir_name(&path).unwrap_or_else(|| "sha256:unknown".to_string());
51        Self {
52            digest,
53            dims: None,
54            quant: None,
55            context_length: None,
56            status: ModelVerificationStatus::Fail,
57            load_latency_ms: None,
58            path,
59            warnings: Vec::new(),
60            errors: vec![err.to_string()],
61        }
62    }
63}
64
/// Options controlling how model directories are verified.
#[derive(Debug, Clone)]
pub struct ModelVerifyOptions {
    /// When `true`, verification also attempts the ONNX smoke test on the model weights.
    pub run_onnx_smoke: bool,
}

impl Default for ModelVerifyOptions {
    /// Enables the ONNX smoke test exactly when the crate is built with the `vec` feature.
    fn default() -> Self {
        let run_onnx_smoke = cfg!(feature = "vec");
        Self { run_onnx_smoke }
    }
}
77
78#[derive(Debug, Clone, Deserialize)]
79#[serde(default, rename_all = "kebab-case")]
80pub struct ModelManifest {
81    pub schema_version: u32,
82    pub digest: String,
83    pub dims: u32,
84    pub quant: Option<String>,
85    pub context_length: Option<u32>,
86    pub files: Vec<ModelManifestEntry>,
87    pub metadata: serde_json::Value,
88}
89
90impl Default for ModelManifest {
91    fn default() -> Self {
92        Self {
93            schema_version: 1,
94            digest: String::new(),
95            dims: 0,
96            quant: None,
97            context_length: None,
98            files: Vec::new(),
99            metadata: serde_json::Value::Null,
100        }
101    }
102}
103
104#[derive(Debug, Clone, Deserialize)]
105#[serde(default, rename_all = "kebab-case")]
106pub struct ModelManifestEntry {
107    pub path: String,
108    pub sha256: String,
109    pub optional: bool,
110    pub roles: Vec<String>,
111    pub kind: Option<String>,
112}
113
114impl Default for ModelManifestEntry {
115    fn default() -> Self {
116        Self {
117            path: String::new(),
118            sha256: String::new(),
119            optional: false,
120            roles: Vec::new(),
121            kind: None,
122        }
123    }
124}
125
126pub fn verify_models(root: &Path, options: &ModelVerifyOptions) -> Result<Vec<ModelVerification>> {
127    if !root.exists() {
128        return Ok(Vec::new());
129    }
130
131    let mut dirs: Vec<PathBuf> = fs::read_dir(root)?
132        .filter_map(|entry| entry.ok())
133        .filter_map(|entry| {
134            let path = entry.path();
135            entry
136                .file_type()
137                .ok()
138                .filter(|ft| ft.is_dir())
139                .and_then(|_| digest_from_dir_name(&path).map(|_| path))
140        })
141        .collect();
142    dirs.sort();
143
144    let mut reports = Vec::with_capacity(dirs.len());
145    for dir in dirs {
146        match verify_model_dir(&dir, options) {
147            Ok(report) => reports.push(report),
148            Err(err) => reports.push(ModelVerification::from_error(dir, err)),
149        }
150    }
151
152    reports.sort_by(|a, b| a.digest.cmp(&b.digest));
153    Ok(reports)
154}
155
156pub fn verify_model_dir(dir: &Path, options: &ModelVerifyOptions) -> Result<ModelVerification> {
157    let manifest_path = dir.join("manifest.json");
158    if !manifest_path.exists() {
159        return Err(MemvidError::ModelIntegrity {
160            reason: format!("missing manifest.json in {}", dir.display()).into_boxed_str(),
161        });
162    }
163
164    let manifest_data = fs::read_to_string(&manifest_path)?;
165    let manifest: ModelManifest =
166        serde_json::from_str(&manifest_data).map_err(|err| MemvidError::ModelManifestInvalid {
167            reason: format!(
168                "failed to parse manifest {}: {err}",
169                manifest_path.display()
170            )
171            .into_boxed_str(),
172        })?;
173
174    if manifest.digest.trim().is_empty() {
175        return Err(MemvidError::ModelManifestInvalid {
176            reason: "manifest digest is empty".into(),
177        });
178    }
179
180    if manifest.dims == 0 {
181        return Err(MemvidError::ModelManifestInvalid {
182            reason: "embedding dimensions must be > 0".into(),
183        });
184    }
185
186    let manifest_digest_hex = normalize_sha256(&manifest.digest, "manifest digest")?;
187    let dir_digest_hex = digest_from_dir_name(dir).ok_or_else(|| MemvidError::ModelIntegrity {
188        reason: format!(
189            "directory {} is not named as sha256-<digest>",
190            dir.display()
191        )
192        .into_boxed_str(),
193    })?;
194
195    let dir_digest_hex = normalize_sha256(&dir_digest_hex, "directory digest")?;
196    if manifest_digest_hex != dir_digest_hex {
197        return Err(MemvidError::ModelIntegrity {
198            reason: format!(
199                "manifest digest sha256:{manifest_digest_hex} does not match directory sha256:{dir_digest_hex}"
200            )
201            .into_boxed_str(),
202        });
203    }
204
205    let digest = format!("sha256:{manifest_digest_hex}");
206
207    let mut status = ModelVerificationStatus::Ok;
208    let mut warnings = Vec::new();
209    let mut errors = Vec::new();
210    let mut load_latency_ms = None;
211
212    for entry in &manifest.files {
213        validate_entry(entry)?;
214        let expected_hex = normalize_sha256(&entry.sha256, &entry.path)?;
215        let resolved_path = resolve_entry_path(dir, &entry.path)?;
216        if !resolved_path.exists() {
217            if entry.optional {
218                warnings.push(format!("optional file missing: {}", entry.path));
219                status.elevate(ModelVerificationStatus::Warn);
220            } else {
221                errors.push(format!("required file missing: {}", entry.path));
222                status.elevate(ModelVerificationStatus::Fail);
223            }
224            continue;
225        }
226
227        let actual_hex = compute_sha256_hex(&resolved_path)?;
228        if actual_hex != expected_hex {
229            errors.push(format!(
230                "checksum mismatch for {} (expected {}, got {})",
231                entry.path, expected_hex, actual_hex
232            ));
233            status.elevate(ModelVerificationStatus::Fail);
234        }
235    }
236
237    if status != ModelVerificationStatus::Fail && options.run_onnx_smoke {
238        if let Some(weights_entry) = select_weights_entry(&manifest) {
239            let weights_path = resolve_entry_path(dir, &weights_entry.path)?;
240            if weights_path.exists() {
241                match run_onnx_smoke_test(&weights_path) {
242                    Ok(latency) => {
243                        load_latency_ms = Some(latency.max(1));
244                    }
245                    Err(OnnxSmokeError::FeatureUnavailable(feature)) => {
246                        warnings.push(format!(
247                            "feature '{feature}' not enabled; skipping ONNX smoke test"
248                        ));
249                        status.elevate(ModelVerificationStatus::Warn);
250                    }
251                    Err(OnnxSmokeError::Engine(err)) => {
252                        errors.push(format!("ONNX initialisation failed: {err}"));
253                        status.elevate(ModelVerificationStatus::Fail);
254                    }
255                }
256            }
257        } else {
258            warnings.push(
259                "manifest does not declare a model .onnx file; skipping ONNX smoke test".into(),
260            );
261            status.elevate(ModelVerificationStatus::Warn);
262        }
263    }
264
265    let resolved_dir = fs::canonicalize(dir).unwrap_or_else(|_| dir.to_path_buf());
266
267    Ok(ModelVerification {
268        digest,
269        dims: Some(manifest.dims),
270        quant: manifest.quant.clone(),
271        context_length: manifest.context_length,
272        status,
273        load_latency_ms,
274        path: resolved_dir,
275        warnings,
276        errors,
277    })
278}
279
280fn validate_entry(entry: &ModelManifestEntry) -> Result<()> {
281    if entry.path.trim().is_empty() {
282        return Err(MemvidError::ModelManifestInvalid {
283            reason: "file entry path is empty".into(),
284        });
285    }
286    if entry.path.contains("\\") {
287        return Err(MemvidError::ModelManifestInvalid {
288            reason: format!("file entry path must use forward slashes: {}", entry.path)
289                .into_boxed_str(),
290        });
291    }
292    if entry.sha256.trim().is_empty() {
293        return Err(MemvidError::ModelManifestInvalid {
294            reason: format!("file entry '{}' missing sha256", entry.path).into_boxed_str(),
295        });
296    }
297    Ok(())
298}
299
300fn resolve_entry_path(base: &Path, relative: &str) -> Result<PathBuf> {
301    let path = Path::new(relative);
302    if path.is_absolute() {
303        return Err(MemvidError::ModelManifestInvalid {
304            reason: format!("file entry '{}' must be relative", relative).into_boxed_str(),
305        });
306    }
307
308    for component in path.components() {
309        if matches!(component, Component::ParentDir) {
310            return Err(MemvidError::ModelManifestInvalid {
311                reason: format!("file entry '{}' attempts directory traversal", relative)
312                    .into_boxed_str(),
313            });
314        }
315    }
316
317    Ok(base.join(path))
318}
319
320fn normalize_sha256(value: &str, context: &str) -> Result<String> {
321    let trimmed = value.trim();
322    let trimmed = trimmed
323        .strip_prefix("sha256:")
324        .or_else(|| trimmed.strip_prefix("sha256-"))
325        .unwrap_or(trimmed);
326    if trimmed.len() != 64 || !trimmed.chars().all(|c| c.is_ascii_hexdigit()) {
327        return Err(MemvidError::ModelManifestInvalid {
328            reason: format!("invalid sha256 value for {context}").into_boxed_str(),
329        });
330    }
331    Ok(trimmed.to_ascii_lowercase())
332}
333
/// Extracts the digest portion of a model directory name.
///
/// Returns whatever follows the `sha256-` prefix of the final path component, without
/// validating it. Returns `None` when the path has no final component, the component is not
/// valid UTF-8, or it does not start with `sha256-`.
fn digest_from_dir_name(path: &Path) -> Option<String> {
    path.file_name()
        .and_then(|name| name.to_str())
        .and_then(|name| name.strip_prefix("sha256-"))
        .map(str::to_owned)
}
338
339fn compute_sha256_hex(path: &Path) -> Result<String> {
340    use sha2::{Digest, Sha256};
341
342    let file = File::open(path)?;
343    let mut reader = BufReader::new(file);
344    let mut hasher = Sha256::new();
345    let mut buffer = [0u8; 8192];
346    loop {
347        let read = reader.read(&mut buffer)?;
348        if read == 0 {
349            break;
350        }
351        hasher.update(&buffer[..read]);
352    }
353    Ok(hex::encode(hasher.finalize()))
354}
355
356fn select_weights_entry<'a>(manifest: &'a ModelManifest) -> Option<&'a ModelManifestEntry> {
357    if let Some(quant) = manifest.quant.as_deref() {
358        if let Some(entry) = manifest
359            .files
360            .iter()
361            .find(|entry| entry.path.ends_with(".onnx") && entry.path.contains(quant))
362        {
363            return Some(entry);
364        }
365    }
366
367    manifest
368        .files
369        .iter()
370        .find(|entry| entry.roles.iter().any(|role| role == "weights"))
371        .or_else(|| {
372            manifest
373                .files
374                .iter()
375                .find(|entry| entry.kind.as_deref() == Some("onnx"))
376        })
377        .or_else(|| {
378            manifest
379                .files
380                .iter()
381                .find(|entry| entry.path.ends_with(".onnx"))
382        })
383}
384
/// Reasons the ONNX smoke test did not produce a latency measurement.
// Each build configuration of `run_onnx_smoke_test` constructs only one of the two variants,
// so the other one would otherwise be reported as dead code.
#[allow(dead_code)]
#[derive(Debug)]
enum OnnxSmokeError {
    /// The named crate feature is not enabled, so the test cannot run.
    FeatureUnavailable(&'static str),
    /// The ONNX runtime reported an error; the payload is its message.
    Engine(String),
}

/// Loads the model at `path` into an ONNX session and reports how long that took.
///
/// The measurement covers creating the session from the file and dropping it again. The result
/// is in milliseconds and is never less than 1.
#[cfg(feature = "vec")]
fn run_onnx_smoke_test(path: &Path) -> std::result::Result<u128, OnnxSmokeError> {
    use ort::session::Session;

    let builder = match Session::builder() {
        Ok(builder) => builder,
        Err(err) => return Err(OnnxSmokeError::Engine(err.to_string())),
    };

    let started = Instant::now();
    match builder.commit_from_file(path) {
        // Release the session before the clock is read.
        Ok(session) => drop(session),
        Err(err) => return Err(OnnxSmokeError::Engine(err.to_string())),
    }
    let elapsed_ms = started.elapsed().as_millis();

    Ok(elapsed_ms.max(1))
}

/// Stand-in used when the `vec` feature is disabled: always reports the feature as unavailable.
#[cfg(not(feature = "vec"))]
fn run_onnx_smoke_test(_path: &Path) -> std::result::Result<u128, OnnxSmokeError> {
    let missing_feature = "vec";
    Err(OnnxSmokeError::FeatureUnavailable(missing_feature))
}
410
411#[cfg(test)]
412mod tests {
413    use super::*;
414    use sha2::{Digest, Sha256};
415    use tempfile::tempdir;
416
417    fn write_manifest(path: &Path, value: &serde_json::Value) -> Result<()> {
418        let bytes =
419            serde_json::to_vec_pretty(value).map_err(|err| MemvidError::ModelManifestInvalid {
420                reason: format!("failed to encode manifest: {err}").into_boxed_str(),
421            })?;
422        fs::write(path, bytes)?;
423        Ok(())
424    }
425
426    fn write_file(path: &Path, contents: &[u8]) -> Result<()> {
427        if let Some(parent) = path.parent() {
428            fs::create_dir_all(parent)?;
429        }
430        fs::write(path, contents)?;
431        Ok(())
432    }
433
434    fn checksum_hex(data: &[u8]) -> String {
435        let mut hasher = Sha256::new();
436        hasher.update(data);
437        hex::encode(hasher.finalize())
438    }
439
440    #[test]
441    fn verify_model_success() -> Result<()> {
442        let temp = tempdir()?;
443        let digest = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
444        let model_dir = temp.path().join(format!("sha256-{digest}"));
445        fs::create_dir_all(&model_dir)?;
446
447        let model_bytes = b"ONNX";
448        let tokenizer_bytes = b"{}";
449
450        write_file(
451            &model_dir.join("models/encoder/model_int8.onnx"),
452            model_bytes,
453        )?;
454        write_file(
455            &model_dir.join("models/encoder/tokenizer.json"),
456            tokenizer_bytes,
457        )?;
458
459        let manifest = serde_json::json!({
460            "digest": format!("sha256:{digest}"),
461            "dims": 384,
462            "quant": "int8",
463            "files": [
464                {
465                    "path": "models/encoder/model_int8.onnx",
466                    "sha256": checksum_hex(model_bytes),
467                    "roles": ["weights"],
468                },
469                {
470                    "path": "models/encoder/tokenizer.json",
471                    "sha256": checksum_hex(tokenizer_bytes),
472                }
473            ]
474        });
475        write_manifest(&model_dir.join("manifest.json"), &manifest)?;
476
477        let options = ModelVerifyOptions {
478            run_onnx_smoke: false,
479        };
480        let report = verify_model_dir(&model_dir, &options)?;
481        assert_eq!(report.digest, format!("sha256:{digest}"));
482        assert_eq!(report.status, ModelVerificationStatus::Ok);
483        assert_eq!(report.dims, Some(384));
484        assert!(report.errors.is_empty());
485        Ok(())
486    }
487
488    #[test]
489    fn verify_model_missing_optional_warns() -> Result<()> {
490        let temp = tempdir()?;
491        let digest = "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb";
492        let model_dir = temp.path().join(format!("sha256-{digest}"));
493        fs::create_dir_all(&model_dir)?;
494
495        let model_bytes = b"ONNX";
496
497        write_file(&model_dir.join("models/model.onnx"), model_bytes)?;
498
499        let manifest = serde_json::json!({
500            "digest": format!("sha256:{digest}"),
501            "dims": 256,
502            "files": [
503                {
504                    "path": "models/model.onnx",
505                    "sha256": checksum_hex(model_bytes),
506                    "roles": ["weights"],
507                },
508                {
509                    "path": "models/tokenizer.json",
510                    "sha256": checksum_hex(b"missing"),
511                    "optional": true
512                }
513            ]
514        });
515        write_manifest(&model_dir.join("manifest.json"), &manifest)?;
516
517        let options = ModelVerifyOptions {
518            run_onnx_smoke: false,
519        };
520        let report = verify_model_dir(&model_dir, &options)?;
521        assert_eq!(report.status, ModelVerificationStatus::Warn);
522        assert!(report.errors.is_empty());
523        assert_eq!(report.warnings.len(), 1);
524        Ok(())
525    }
526
527    #[test]
528    fn verify_models_directory_listing() -> Result<()> {
529        let temp = tempdir()?;
530        let digest = "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc";
531        let model_dir = temp.path().join(format!("sha256-{digest}"));
532        fs::create_dir_all(&model_dir)?;
533
534        let model_bytes = b"ONNX";
535        fs::write(model_dir.join("model.onnx"), model_bytes)?;
536
537        let manifest = serde_json::json!({
538            "digest": format!("sha256:{digest}"),
539            "dims": 128,
540            "files": [
541                {
542                    "path": "model.onnx",
543                    "sha256": checksum_hex(model_bytes),
544                    "roles": ["weights"],
545                }
546            ]
547        });
548        write_manifest(&model_dir.join("manifest.json"), &manifest)?;
549
550        let options = ModelVerifyOptions {
551            run_onnx_smoke: false,
552        };
553        let reports = verify_models(temp.path(), &options)?;
554        assert_eq!(reports.len(), 1);
555        assert_eq!(reports[0].digest, format!("sha256:{digest}"));
556        Ok(())
557    }
558}