Skip to main content

memvid_core/
models.rs

1use std::fs;
2use std::fs::File;
3use std::io::{BufReader, Read};
4use std::path::{Component, Path, PathBuf};
5
6#[cfg(feature = "vec")]
7use std::time::Instant;
8
9use serde::Deserialize;
10
11use crate::error::{MemvidError, Result};
12
/// Overall outcome of verifying one model directory.
///
/// The variants are declared from least to most severe.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelVerificationStatus {
    /// Nothing noteworthy was recorded.
    Ok,
    /// At least one warning was recorded, but no error.
    Warn,
    /// At least one error was recorded.
    Fail,
}

impl ModelVerificationStatus {
    /// Severity rank used to compare two statuses (`Ok` < `Warn` < `Fail`).
    const fn severity(self) -> u8 {
        match self {
            Self::Ok => 0,
            Self::Warn => 1,
            Self::Fail => 2,
        }
    }

    /// Raises `self` to `other` when `other` is more severe.
    ///
    /// The status is never downgraded: elevating `Fail` with `Ok` leaves `Fail` in place.
    fn elevate(&mut self, other: ModelVerificationStatus) {
        if other.severity() > self.severity() {
            *self = other;
        }
    }
}
34
/// Verification report for a single model directory.
#[derive(Debug, Clone)]
pub struct ModelVerification {
    /// Model digest; successful reports render it as `sha256:<hex>`.
    pub digest: String,
    /// Embedding dimensions declared by the manifest; `None` when verification aborted early.
    pub dims: Option<u32>,
    /// Quantisation label copied from the manifest, if any.
    pub quant: Option<String>,
    /// Context length copied from the manifest, if any.
    pub context_length: Option<u32>,
    /// Aggregated outcome of all checks.
    pub status: ModelVerificationStatus,
    /// Milliseconds taken by the ONNX smoke test; `None` when it did not run or did not succeed.
    pub load_latency_ms: Option<u128>,
    /// Model directory this report describes (canonicalised when that succeeds).
    pub path: PathBuf,
    /// Human-readable descriptions of non-fatal findings.
    pub warnings: Vec<String>,
    /// Human-readable descriptions of failures.
    pub errors: Vec<String>,
}
47
48impl ModelVerification {
49    fn from_error(path: PathBuf, err: MemvidError) -> Self {
50        let digest = digest_from_dir_name(&path).unwrap_or_else(|| "sha256:unknown".to_string());
51        Self {
52            digest,
53            dims: None,
54            quant: None,
55            context_length: None,
56            status: ModelVerificationStatus::Fail,
57            load_latency_ms: None,
58            path,
59            warnings: Vec::new(),
60            errors: vec![err.to_string()],
61        }
62    }
63}
64
/// Options controlling how much work model verification performs.
#[derive(Debug, Clone, Default)]
pub struct ModelVerifyOptions {
    /// When `true`, a model whose file checks did not fail is additionally loaded through the
    /// ONNX smoke test. Defaults to `false`.
    pub run_onnx_smoke: bool,
}
69
/// Contents of a model directory's `manifest.json`.
///
/// Keys are read in kebab-case (for example `schema-version`, `context-length`), and any key
/// missing from the JSON takes its value from this type's `Default` implementation.
#[derive(Debug, Clone, Deserialize)]
#[serde(default, rename_all = "kebab-case")]
pub struct ModelManifest {
    /// Manifest schema version; `1` when absent.
    pub schema_version: u32,
    /// Declared model digest; verification requires it to be non-empty.
    pub digest: String,
    /// Embedding dimensions; verification requires a value greater than zero.
    pub dims: u32,
    /// Optional quantisation label.
    pub quant: Option<String>,
    /// Optional context length.
    pub context_length: Option<u32>,
    /// Files that make up the model, each with its expected checksum.
    pub files: Vec<ModelManifestEntry>,
    /// Free-form JSON metadata; `null` when absent.
    pub metadata: serde_json::Value,
}
81
82impl Default for ModelManifest {
83    fn default() -> Self {
84        Self {
85            schema_version: 1,
86            digest: String::new(),
87            dims: 0,
88            quant: None,
89            context_length: None,
90            files: Vec::new(),
91            metadata: serde_json::Value::Null,
92        }
93    }
94}
95
/// One file listed in a model manifest.
///
/// Any key missing from the JSON takes the field type's default value.
#[derive(Debug, Clone, Deserialize)]
#[serde(default, rename_all = "kebab-case")]
#[derive(Default)]
pub struct ModelManifestEntry {
    /// Location of the file relative to the model directory, written with forward slashes.
    pub path: String,
    /// Expected SHA-256 of the file contents as 64 hex digits, optionally prefixed with
    /// `sha256:` or `sha256-`.
    pub sha256: String,
    /// When `true`, a missing file yields a warning instead of an error.
    pub optional: bool,
    /// Role tags; `"weights"` marks a candidate for the ONNX smoke test.
    pub roles: Vec<String>,
    /// Optional kind tag; `"onnx"` also marks a candidate for the ONNX smoke test.
    pub kind: Option<String>,
}
106
107pub fn verify_models(root: &Path, options: &ModelVerifyOptions) -> Result<Vec<ModelVerification>> {
108    if !root.exists() {
109        return Ok(Vec::new());
110    }
111
112    let mut dirs: Vec<PathBuf> = fs::read_dir(root)?
113        .filter_map(std::result::Result::ok)
114        .filter_map(|entry| {
115            let path = entry.path();
116            entry
117                .file_type()
118                .ok()
119                .filter(std::fs::FileType::is_dir)
120                .and_then(|_| digest_from_dir_name(&path).map(|_| path))
121        })
122        .collect();
123    dirs.sort();
124
125    let mut reports = Vec::with_capacity(dirs.len());
126    for dir in dirs {
127        match verify_model_dir(&dir, options) {
128            Ok(report) => reports.push(report),
129            Err(err) => reports.push(ModelVerification::from_error(dir, err)),
130        }
131    }
132
133    reports.sort_by(|a, b| a.digest.cmp(&b.digest));
134    Ok(reports)
135}
136
/// Verifies a single model directory against its `manifest.json`.
///
/// The directory must be named `sha256-<digest>` and that digest must equal the one declared
/// in the manifest. Every file listed in the manifest is then hashed and compared with its
/// expected SHA-256. Missing files and checksum mismatches are recorded in the returned
/// report (`warnings` / `errors` / `status`) rather than aborting. When
/// `options.run_onnx_smoke` is set and nothing has failed so far, the selected weights file is
/// additionally loaded through the ONNX smoke test.
///
/// # Errors
///
/// Returns an error, instead of a report, when:
/// - `manifest.json` is missing, unreadable or not valid JSON for [`ModelManifest`];
/// - the manifest digest is empty or malformed, or `dims` is zero;
/// - the directory is not named `sha256-<digest>`, or its digest is malformed or differs
///   from the manifest digest;
/// - a file entry is malformed (empty path, backslashes, missing or invalid `sha256`,
///   absolute path or `..` component);
/// - an existing file cannot be read while hashing.
pub fn verify_model_dir(dir: &Path, options: &ModelVerifyOptions) -> Result<ModelVerification> {
    let manifest_path = dir.join("manifest.json");
    if !manifest_path.exists() {
        return Err(MemvidError::ModelIntegrity {
            reason: format!("missing manifest.json in {}", dir.display()).into_boxed_str(),
        });
    }

    let manifest_data = fs::read_to_string(&manifest_path)?;
    let manifest: ModelManifest =
        serde_json::from_str(&manifest_data).map_err(|err| MemvidError::ModelManifestInvalid {
            reason: format!(
                "failed to parse manifest {}: {err}",
                manifest_path.display()
            )
            .into_boxed_str(),
        })?;

    // Missing keys deserialize to defaults (empty digest, zero dims), so these two checks
    // also catch manifests that omit the fields entirely.
    if manifest.digest.trim().is_empty() {
        return Err(MemvidError::ModelManifestInvalid {
            reason: "manifest digest is empty".into(),
        });
    }

    if manifest.dims == 0 {
        return Err(MemvidError::ModelManifestInvalid {
            reason: "embedding dimensions must be > 0".into(),
        });
    }

    // Both digests are normalised to bare lowercase hex before being compared.
    let manifest_digest_hex = normalize_sha256(&manifest.digest, "manifest digest")?;
    let dir_digest_hex = digest_from_dir_name(dir).ok_or_else(|| MemvidError::ModelIntegrity {
        reason: format!(
            "directory {} is not named as sha256-<digest>",
            dir.display()
        )
        .into_boxed_str(),
    })?;

    let dir_digest_hex = normalize_sha256(&dir_digest_hex, "directory digest")?;
    if manifest_digest_hex != dir_digest_hex {
        return Err(MemvidError::ModelIntegrity {
            reason: format!(
                "manifest digest sha256:{manifest_digest_hex} does not match directory sha256:{dir_digest_hex}"
            )
            .into_boxed_str(),
        });
    }

    let digest = format!("sha256:{manifest_digest_hex}");

    let mut status = ModelVerificationStatus::Ok;
    let mut warnings = Vec::new();
    let mut errors = Vec::new();
    let mut load_latency_ms = None;

    // Per-file checks: malformed entries abort with `?`, while missing files and checksum
    // mismatches are collected so that a single report lists every problem found.
    for entry in &manifest.files {
        validate_entry(entry)?;
        let expected_hex = normalize_sha256(&entry.sha256, &entry.path)?;
        let resolved_path = resolve_entry_path(dir, &entry.path)?;
        if !resolved_path.exists() {
            if entry.optional {
                warnings.push(format!("optional file missing: {}", entry.path));
                status.elevate(ModelVerificationStatus::Warn);
            } else {
                errors.push(format!("required file missing: {}", entry.path));
                status.elevate(ModelVerificationStatus::Fail);
            }
            continue;
        }

        let actual_hex = compute_sha256_hex(&resolved_path)?;
        if actual_hex != expected_hex {
            errors.push(format!(
                "checksum mismatch for {} (expected {}, got {})",
                entry.path, expected_hex, actual_hex
            ));
            status.elevate(ModelVerificationStatus::Fail);
        }
    }

    // The smoke test is opt-in and pointless once the files themselves failed verification.
    if status != ModelVerificationStatus::Fail && options.run_onnx_smoke {
        if let Some(weights_entry) = select_weights_entry(&manifest) {
            let weights_path = resolve_entry_path(dir, &weights_entry.path)?;
            // A selected file that is absent here must be an optional one (a missing required
            // file would already have set `Fail`); the loop above has recorded its warning.
            if weights_path.exists() {
                match run_onnx_smoke_test(&weights_path) {
                    Ok(latency) => {
                        load_latency_ms = Some(latency.max(1));
                    }
                    Err(OnnxSmokeError::FeatureUnavailable(feature)) => {
                        warnings.push(format!(
                            "feature '{feature}' not enabled; skipping ONNX smoke test"
                        ));
                        status.elevate(ModelVerificationStatus::Warn);
                    }
                    Err(OnnxSmokeError::Engine(err)) => {
                        errors.push(format!("ONNX initialisation failed: {err}"));
                        status.elevate(ModelVerificationStatus::Fail);
                    }
                }
            }
        } else {
            warnings.push(
                "manifest does not declare a model .onnx file; skipping ONNX smoke test".into(),
            );
            status.elevate(ModelVerificationStatus::Warn);
        }
    }

    // Report the canonical location when possible; fall back to the path as given.
    let resolved_dir = fs::canonicalize(dir).unwrap_or_else(|_| dir.to_path_buf());

    Ok(ModelVerification {
        digest,
        dims: Some(manifest.dims),
        quant: manifest.quant.clone(),
        context_length: manifest.context_length,
        status,
        load_latency_ms,
        path: resolved_dir,
        warnings,
        errors,
    })
}
260
261fn validate_entry(entry: &ModelManifestEntry) -> Result<()> {
262    if entry.path.trim().is_empty() {
263        return Err(MemvidError::ModelManifestInvalid {
264            reason: "file entry path is empty".into(),
265        });
266    }
267    if entry.path.contains('\\') {
268        return Err(MemvidError::ModelManifestInvalid {
269            reason: format!("file entry path must use forward slashes: {}", entry.path)
270                .into_boxed_str(),
271        });
272    }
273    if entry.sha256.trim().is_empty() {
274        return Err(MemvidError::ModelManifestInvalid {
275            reason: format!("file entry '{}' missing sha256", entry.path).into_boxed_str(),
276        });
277    }
278    Ok(())
279}
280
281fn resolve_entry_path(base: &Path, relative: &str) -> Result<PathBuf> {
282    let path = Path::new(relative);
283    if path.is_absolute() {
284        return Err(MemvidError::ModelManifestInvalid {
285            reason: format!("file entry '{relative}' must be relative").into_boxed_str(),
286        });
287    }
288
289    for component in path.components() {
290        if matches!(component, Component::ParentDir) {
291            return Err(MemvidError::ModelManifestInvalid {
292                reason: format!("file entry '{relative}' attempts directory traversal")
293                    .into_boxed_str(),
294            });
295        }
296    }
297
298    Ok(base.join(path))
299}
300
301fn normalize_sha256(value: &str, context: &str) -> Result<String> {
302    let trimmed = value.trim();
303    let trimmed = trimmed
304        .strip_prefix("sha256:")
305        .or_else(|| trimmed.strip_prefix("sha256-"))
306        .unwrap_or(trimmed);
307    if trimmed.len() != 64 || !trimmed.chars().all(|c| c.is_ascii_hexdigit()) {
308        return Err(MemvidError::ModelManifestInvalid {
309            reason: format!("invalid sha256 value for {context}").into_boxed_str(),
310        });
311    }
312    Ok(trimmed.to_ascii_lowercase())
313}
314
/// Extracts the digest part of a directory named `sha256-<digest>`.
///
/// Returns whatever follows the `sha256-` prefix of the final path component, without
/// validating it. Returns `None` when the path has no final component, the name is not valid
/// UTF-8, or the prefix is absent.
fn digest_from_dir_name(path: &Path) -> Option<String> {
    let file_name = path.file_name()?;
    let name = file_name.to_str()?;
    let digest = name.strip_prefix("sha256-")?;
    Some(digest.to_owned())
}
320
321fn compute_sha256_hex(path: &Path) -> Result<String> {
322    use sha2::{Digest, Sha256};
323
324    let file = File::open(path)?;
325    let mut reader = BufReader::new(file);
326    let mut hasher = Sha256::new();
327    let mut buffer = [0u8; 8192];
328    loop {
329        let read = reader.read(&mut buffer)?;
330        if read == 0 {
331            break;
332        }
333        hasher.update(&buffer[..read]);
334    }
335    Ok(hex::encode(hasher.finalize()))
336}
337
338fn select_weights_entry(manifest: &ModelManifest) -> Option<&ModelManifestEntry> {
339    if let Some(quant) = manifest.quant.as_deref() {
340        if let Some(entry) = manifest
341            .files
342            .iter()
343            .find(|entry| entry.path.ends_with(".onnx") && entry.path.contains(quant))
344        {
345            return Some(entry);
346        }
347    }
348
349    manifest
350        .files
351        .iter()
352        .find(|entry| entry.roles.iter().any(|role| role == "weights"))
353        .or_else(|| {
354            manifest
355                .files
356                .iter()
357                .find(|entry| entry.kind.as_deref() == Some("onnx"))
358        })
359        .or_else(|| {
360            manifest
361                .files
362                .iter()
363                .find(|entry| entry.path.ends_with(".onnx"))
364        })
365}
366
/// Reasons the ONNX smoke test can fail to produce a load latency.
// Each build of `run_onnx_smoke_test` constructs only one of the two variants (depending on
// the `vec` feature), so the other one would otherwise trigger a dead-code warning.
#[allow(dead_code)]
#[derive(Debug)]
enum OnnxSmokeError {
    /// The named Cargo feature required for the smoke test is not enabled.
    FeatureUnavailable(&'static str),
    /// The ONNX runtime reported an error; the payload is its message.
    Engine(String),
}
373
/// Loads the ONNX model at `path` once and returns how long that took, in milliseconds.
///
/// Timing starts after the session builder has been created and covers loading the model
/// and dropping the session. The result is never below 1.
///
/// # Errors
///
/// Returns `OnnxSmokeError::Engine` with the runtime's message when the builder cannot be
/// created or the model cannot be loaded.
#[cfg(feature = "vec")]
fn run_onnx_smoke_test(path: &Path) -> std::result::Result<u128, OnnxSmokeError> {
    use ort::session::Session;

    let builder = match Session::builder() {
        Ok(builder) => builder,
        Err(err) => return Err(OnnxSmokeError::Engine(err.to_string())),
    };

    let started = Instant::now();
    match builder.commit_from_file(path) {
        // Release the session before reading the clock so teardown is part of the figure.
        Ok(session) => drop(session),
        Err(err) => return Err(OnnxSmokeError::Engine(err.to_string())),
    }
    let millis = started.elapsed().as_millis();

    Ok(millis.max(1))
}
387
388#[cfg(not(feature = "vec"))]
389fn run_onnx_smoke_test(_path: &Path) -> std::result::Result<u128, OnnxSmokeError> {
390    Err(OnnxSmokeError::FeatureUnavailable("vec"))
391}
392
/// Tests that build throwaway model directories on disk and run the verifier over them.
#[cfg(test)]
mod tests {
    use super::*;
    use sha2::{Digest, Sha256};
    use tempfile::tempdir;

    /// Serialises `value` as pretty-printed JSON and writes it to `path`.
    fn write_manifest(path: &Path, value: &serde_json::Value) -> Result<()> {
        let bytes =
            serde_json::to_vec_pretty(value).map_err(|err| MemvidError::ModelManifestInvalid {
                reason: format!("failed to encode manifest: {err}").into_boxed_str(),
            })?;
        fs::write(path, bytes)?;
        Ok(())
    }

    /// Writes `contents` to `path`, creating any missing parent directories first.
    fn write_file(path: &Path, contents: &[u8]) -> Result<()> {
        if let Some(parent) = path.parent() {
            fs::create_dir_all(parent)?;
        }
        fs::write(path, contents)?;
        Ok(())
    }

    /// Returns the SHA-256 of `data` as lowercase hex, for use as an expected checksum.
    fn checksum_hex(data: &[u8]) -> String {
        let mut hasher = Sha256::new();
        hasher.update(data);
        hex::encode(hasher.finalize())
    }

    /// A directory whose files all match the manifest verifies as `Ok` with no errors.
    #[test]
    fn verify_model_success() -> Result<()> {
        let temp = tempdir()?;
        let digest = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
        let model_dir = temp.path().join(format!("sha256-{digest}"));
        fs::create_dir_all(&model_dir)?;

        let model_bytes = b"ONNX";
        let tokenizer_bytes = b"{}";

        write_file(
            &model_dir.join("models/encoder/model_int8.onnx"),
            model_bytes,
        )?;
        write_file(
            &model_dir.join("models/encoder/tokenizer.json"),
            tokenizer_bytes,
        )?;

        let manifest = serde_json::json!({
            "digest": format!("sha256:{digest}"),
            "dims": 384,
            "quant": "int8",
            "files": [
                {
                    "path": "models/encoder/model_int8.onnx",
                    "sha256": checksum_hex(model_bytes),
                    "roles": ["weights"],
                },
                {
                    "path": "models/encoder/tokenizer.json",
                    "sha256": checksum_hex(tokenizer_bytes),
                }
            ]
        });
        write_manifest(&model_dir.join("manifest.json"), &manifest)?;

        // The smoke test is disabled so the placeholder "ONNX" bytes are never loaded.
        let options = ModelVerifyOptions {
            run_onnx_smoke: false,
        };
        let report = verify_model_dir(&model_dir, &options)?;
        assert_eq!(report.digest, format!("sha256:{digest}"));
        assert_eq!(report.status, ModelVerificationStatus::Ok);
        assert_eq!(report.dims, Some(384));
        assert!(report.errors.is_empty());
        Ok(())
    }

    /// A missing file marked `optional` produces exactly one warning and no error.
    #[test]
    fn verify_model_missing_optional_warns() -> Result<()> {
        let temp = tempdir()?;
        let digest = "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb";
        let model_dir = temp.path().join(format!("sha256-{digest}"));
        fs::create_dir_all(&model_dir)?;

        let model_bytes = b"ONNX";

        write_file(&model_dir.join("models/model.onnx"), model_bytes)?;

        // The tokenizer is listed in the manifest but deliberately never written to disk.
        let manifest = serde_json::json!({
            "digest": format!("sha256:{digest}"),
            "dims": 256,
            "files": [
                {
                    "path": "models/model.onnx",
                    "sha256": checksum_hex(model_bytes),
                    "roles": ["weights"],
                },
                {
                    "path": "models/tokenizer.json",
                    "sha256": checksum_hex(b"missing"),
                    "optional": true
                }
            ]
        });
        write_manifest(&model_dir.join("manifest.json"), &manifest)?;

        let options = ModelVerifyOptions {
            run_onnx_smoke: false,
        };
        let report = verify_model_dir(&model_dir, &options)?;
        assert_eq!(report.status, ModelVerificationStatus::Warn);
        assert!(report.errors.is_empty());
        assert_eq!(report.warnings.len(), 1);
        Ok(())
    }

    /// Scanning a root directory finds and reports the single model directory inside it.
    #[test]
    fn verify_models_directory_listing() -> Result<()> {
        let temp = tempdir()?;
        let digest = "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc";
        let model_dir = temp.path().join(format!("sha256-{digest}"));
        fs::create_dir_all(&model_dir)?;

        let model_bytes = b"ONNX";
        fs::write(model_dir.join("model.onnx"), model_bytes)?;

        let manifest = serde_json::json!({
            "digest": format!("sha256:{digest}"),
            "dims": 128,
            "files": [
                {
                    "path": "model.onnx",
                    "sha256": checksum_hex(model_bytes),
                    "roles": ["weights"],
                }
            ]
        });
        write_manifest(&model_dir.join("manifest.json"), &manifest)?;

        let options = ModelVerifyOptions {
            run_onnx_smoke: false,
        };
        let reports = verify_models(temp.path(), &options)?;
        assert_eq!(reports.len(), 1);
        assert_eq!(reports[0].digest, format!("sha256:{digest}"));
        Ok(())
    }
}