1use std::fs;
2use std::fs::File;
3use std::io::{BufReader, Read};
4use std::path::{Component, Path, PathBuf};
5
6#[cfg(feature = "vec")]
7use std::time::Instant;
8
9use serde::Deserialize;
10
11use crate::error::{MemvidError, Result};
12
/// Overall outcome of verifying one model directory.
///
/// Variants are listed from least to most severe; `elevate` only ever moves a
/// status towards `Fail`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelVerificationStatus {
    /// No warnings or errors were recorded.
    Ok,
    /// Verification completed but at least one warning was recorded.
    Warn,
    /// At least one error was recorded.
    Fail,
}
19
20impl ModelVerificationStatus {
21 fn elevate(&mut self, other: ModelVerificationStatus) {
22 use ModelVerificationStatus::{Fail, Ok, Warn};
23 match (*self, other) {
24 (Fail, _) | (_, Fail) => *self = Fail,
25 (Warn, _) | (_, Warn) => {
26 if matches!(*self, Ok) {
27 *self = Warn;
28 }
29 }
30 _ => {}
31 }
32 }
33}
34
/// Report produced for one model directory.
#[derive(Debug, Clone)]
pub struct ModelVerification {
    /// Digest identifying the model; `sha256:<hex>` when the manifest was
    /// parsed and matched the directory name.
    pub digest: String,
    /// Embedding dimensions from the manifest; `None` when verification
    /// aborted before the manifest could be used.
    pub dims: Option<u32>,
    /// Quantisation label copied from the manifest, if any.
    pub quant: Option<String>,
    /// Context length copied from the manifest, if any.
    pub context_length: Option<u32>,
    /// Aggregated outcome of all checks.
    pub status: ModelVerificationStatus,
    /// Duration of the ONNX smoke test in milliseconds; `None` when the smoke
    /// test was not run or did not succeed.
    pub load_latency_ms: Option<u128>,
    /// Model directory the report refers to (canonicalised when possible).
    pub path: PathBuf,
    /// Human-readable warnings collected during verification.
    pub warnings: Vec<String>,
    /// Human-readable errors collected during verification.
    pub errors: Vec<String>,
}
47
48impl ModelVerification {
49 fn from_error(path: PathBuf, err: MemvidError) -> Self {
50 let digest = digest_from_dir_name(&path).unwrap_or_else(|| "sha256:unknown".to_string());
51 Self {
52 digest,
53 dims: None,
54 quant: None,
55 context_length: None,
56 status: ModelVerificationStatus::Fail,
57 load_latency_ms: None,
58 path,
59 warnings: Vec::new(),
60 errors: vec![err.to_string()],
61 }
62 }
63}
64
/// Switches controlling how much work verification performs.
#[derive(Debug, Clone, Default)]
pub struct ModelVerifyOptions {
    /// When `true`, and no check has failed, `verify_model_dir` additionally
    /// tries to load the selected weights file via `run_onnx_smoke_test`.
    pub run_onnx_smoke: bool,
}
69
/// Deserialised contents of a model directory's `manifest.json`.
///
/// Keys are read in kebab-case (`rename_all = "kebab-case"`), and any key that
/// is absent takes its value from this type's `Default` implementation
/// (container-level `#[serde(default)]`).
#[derive(Debug, Clone, Deserialize)]
#[serde(default, rename_all = "kebab-case")]
pub struct ModelManifest {
    /// Manifest schema version.
    pub schema_version: u32,
    /// SHA-256 digest of the model; must match the directory name.
    pub digest: String,
    /// Embedding dimensions; `verify_model_dir` rejects `0`.
    pub dims: u32,
    /// Optional quantisation label; also used as a path hint when choosing the
    /// `.onnx` file for the smoke test.
    pub quant: Option<String>,
    /// Optional context length, copied into the verification report.
    pub context_length: Option<u32>,
    /// Files whose presence and checksums are verified.
    pub files: Vec<ModelManifestEntry>,
    /// Free-form JSON; not inspected by the verification code in this file.
    pub metadata: serde_json::Value,
}
81
82impl Default for ModelManifest {
83 fn default() -> Self {
84 Self {
85 schema_version: 1,
86 digest: String::new(),
87 dims: 0,
88 quant: None,
89 context_length: None,
90 files: Vec::new(),
91 metadata: serde_json::Value::Null,
92 }
93 }
94}
95
96#[derive(Debug, Clone, Deserialize)]
97#[serde(default, rename_all = "kebab-case")]
98#[derive(Default)]
99pub struct ModelManifestEntry {
100 pub path: String,
101 pub sha256: String,
102 pub optional: bool,
103 pub roles: Vec<String>,
104 pub kind: Option<String>,
105}
106
107pub fn verify_models(root: &Path, options: &ModelVerifyOptions) -> Result<Vec<ModelVerification>> {
108 if !root.exists() {
109 return Ok(Vec::new());
110 }
111
112 let mut dirs: Vec<PathBuf> = fs::read_dir(root)?
113 .filter_map(std::result::Result::ok)
114 .filter_map(|entry| {
115 let path = entry.path();
116 entry
117 .file_type()
118 .ok()
119 .filter(std::fs::FileType::is_dir)
120 .and_then(|_| digest_from_dir_name(&path).map(|_| path))
121 })
122 .collect();
123 dirs.sort();
124
125 let mut reports = Vec::with_capacity(dirs.len());
126 for dir in dirs {
127 match verify_model_dir(&dir, options) {
128 Ok(report) => reports.push(report),
129 Err(err) => reports.push(ModelVerification::from_error(dir, err)),
130 }
131 }
132
133 reports.sort_by(|a, b| a.digest.cmp(&b.digest));
134 Ok(reports)
135}
136
137pub fn verify_model_dir(dir: &Path, options: &ModelVerifyOptions) -> Result<ModelVerification> {
138 let manifest_path = dir.join("manifest.json");
139 if !manifest_path.exists() {
140 return Err(MemvidError::ModelIntegrity {
141 reason: format!("missing manifest.json in {}", dir.display()).into_boxed_str(),
142 });
143 }
144
145 let manifest_data = fs::read_to_string(&manifest_path)?;
146 let manifest: ModelManifest =
147 serde_json::from_str(&manifest_data).map_err(|err| MemvidError::ModelManifestInvalid {
148 reason: format!(
149 "failed to parse manifest {}: {err}",
150 manifest_path.display()
151 )
152 .into_boxed_str(),
153 })?;
154
155 if manifest.digest.trim().is_empty() {
156 return Err(MemvidError::ModelManifestInvalid {
157 reason: "manifest digest is empty".into(),
158 });
159 }
160
161 if manifest.dims == 0 {
162 return Err(MemvidError::ModelManifestInvalid {
163 reason: "embedding dimensions must be > 0".into(),
164 });
165 }
166
167 let manifest_digest_hex = normalize_sha256(&manifest.digest, "manifest digest")?;
168 let dir_digest_hex = digest_from_dir_name(dir).ok_or_else(|| MemvidError::ModelIntegrity {
169 reason: format!(
170 "directory {} is not named as sha256-<digest>",
171 dir.display()
172 )
173 .into_boxed_str(),
174 })?;
175
176 let dir_digest_hex = normalize_sha256(&dir_digest_hex, "directory digest")?;
177 if manifest_digest_hex != dir_digest_hex {
178 return Err(MemvidError::ModelIntegrity {
179 reason: format!(
180 "manifest digest sha256:{manifest_digest_hex} does not match directory sha256:{dir_digest_hex}"
181 )
182 .into_boxed_str(),
183 });
184 }
185
186 let digest = format!("sha256:{manifest_digest_hex}");
187
188 let mut status = ModelVerificationStatus::Ok;
189 let mut warnings = Vec::new();
190 let mut errors = Vec::new();
191 let mut load_latency_ms = None;
192
193 for entry in &manifest.files {
194 validate_entry(entry)?;
195 let expected_hex = normalize_sha256(&entry.sha256, &entry.path)?;
196 let resolved_path = resolve_entry_path(dir, &entry.path)?;
197 if !resolved_path.exists() {
198 if entry.optional {
199 warnings.push(format!("optional file missing: {}", entry.path));
200 status.elevate(ModelVerificationStatus::Warn);
201 } else {
202 errors.push(format!("required file missing: {}", entry.path));
203 status.elevate(ModelVerificationStatus::Fail);
204 }
205 continue;
206 }
207
208 let actual_hex = compute_sha256_hex(&resolved_path)?;
209 if actual_hex != expected_hex {
210 errors.push(format!(
211 "checksum mismatch for {} (expected {}, got {})",
212 entry.path, expected_hex, actual_hex
213 ));
214 status.elevate(ModelVerificationStatus::Fail);
215 }
216 }
217
218 if status != ModelVerificationStatus::Fail && options.run_onnx_smoke {
219 if let Some(weights_entry) = select_weights_entry(&manifest) {
220 let weights_path = resolve_entry_path(dir, &weights_entry.path)?;
221 if weights_path.exists() {
222 match run_onnx_smoke_test(&weights_path) {
223 Ok(latency) => {
224 load_latency_ms = Some(latency.max(1));
225 }
226 Err(OnnxSmokeError::FeatureUnavailable(feature)) => {
227 warnings.push(format!(
228 "feature '{feature}' not enabled; skipping ONNX smoke test"
229 ));
230 status.elevate(ModelVerificationStatus::Warn);
231 }
232 Err(OnnxSmokeError::Engine(err)) => {
233 errors.push(format!("ONNX initialisation failed: {err}"));
234 status.elevate(ModelVerificationStatus::Fail);
235 }
236 }
237 }
238 } else {
239 warnings.push(
240 "manifest does not declare a model .onnx file; skipping ONNX smoke test".into(),
241 );
242 status.elevate(ModelVerificationStatus::Warn);
243 }
244 }
245
246 let resolved_dir = fs::canonicalize(dir).unwrap_or_else(|_| dir.to_path_buf());
247
248 Ok(ModelVerification {
249 digest,
250 dims: Some(manifest.dims),
251 quant: manifest.quant.clone(),
252 context_length: manifest.context_length,
253 status,
254 load_latency_ms,
255 path: resolved_dir,
256 warnings,
257 errors,
258 })
259}
260
261fn validate_entry(entry: &ModelManifestEntry) -> Result<()> {
262 if entry.path.trim().is_empty() {
263 return Err(MemvidError::ModelManifestInvalid {
264 reason: "file entry path is empty".into(),
265 });
266 }
267 if entry.path.contains('\\') {
268 return Err(MemvidError::ModelManifestInvalid {
269 reason: format!("file entry path must use forward slashes: {}", entry.path)
270 .into_boxed_str(),
271 });
272 }
273 if entry.sha256.trim().is_empty() {
274 return Err(MemvidError::ModelManifestInvalid {
275 reason: format!("file entry '{}' missing sha256", entry.path).into_boxed_str(),
276 });
277 }
278 Ok(())
279}
280
281fn resolve_entry_path(base: &Path, relative: &str) -> Result<PathBuf> {
282 let path = Path::new(relative);
283 if path.is_absolute() {
284 return Err(MemvidError::ModelManifestInvalid {
285 reason: format!("file entry '{relative}' must be relative").into_boxed_str(),
286 });
287 }
288
289 for component in path.components() {
290 if matches!(component, Component::ParentDir) {
291 return Err(MemvidError::ModelManifestInvalid {
292 reason: format!("file entry '{relative}' attempts directory traversal")
293 .into_boxed_str(),
294 });
295 }
296 }
297
298 Ok(base.join(path))
299}
300
301fn normalize_sha256(value: &str, context: &str) -> Result<String> {
302 let trimmed = value.trim();
303 let trimmed = trimmed
304 .strip_prefix("sha256:")
305 .or_else(|| trimmed.strip_prefix("sha256-"))
306 .unwrap_or(trimmed);
307 if trimmed.len() != 64 || !trimmed.chars().all(|c| c.is_ascii_hexdigit()) {
308 return Err(MemvidError::ModelManifestInvalid {
309 reason: format!("invalid sha256 value for {context}").into_boxed_str(),
310 });
311 }
312 Ok(trimmed.to_ascii_lowercase())
313}
314
/// Extracts the digest part of a `sha256-<digest>` directory name.
///
/// Returns `None` when `path` has no final component, when that component is
/// not valid UTF-8, or when it does not start with `sha256-`. The remainder is
/// returned as-is; it is not validated as hex here.
fn digest_from_dir_name(path: &Path) -> Option<String> {
    let dir_name = path.file_name().and_then(|os_name| os_name.to_str())?;
    let suffix = dir_name.strip_prefix("sha256-")?;
    Some(suffix.to_owned())
}
320
321fn compute_sha256_hex(path: &Path) -> Result<String> {
322 use sha2::{Digest, Sha256};
323
324 let file = File::open(path)?;
325 let mut reader = BufReader::new(file);
326 let mut hasher = Sha256::new();
327 let mut buffer = [0u8; 8192];
328 loop {
329 let read = reader.read(&mut buffer)?;
330 if read == 0 {
331 break;
332 }
333 hasher.update(&buffer[..read]);
334 }
335 Ok(hex::encode(hasher.finalize()))
336}
337
338fn select_weights_entry(manifest: &ModelManifest) -> Option<&ModelManifestEntry> {
339 if let Some(quant) = manifest.quant.as_deref() {
340 if let Some(entry) = manifest
341 .files
342 .iter()
343 .find(|entry| entry.path.ends_with(".onnx") && entry.path.contains(quant))
344 {
345 return Some(entry);
346 }
347 }
348
349 manifest
350 .files
351 .iter()
352 .find(|entry| entry.roles.iter().any(|role| role == "weights"))
353 .or_else(|| {
354 manifest
355 .files
356 .iter()
357 .find(|entry| entry.kind.as_deref() == Some("onnx"))
358 })
359 .or_else(|| {
360 manifest
361 .files
362 .iter()
363 .find(|entry| entry.path.ends_with(".onnx"))
364 })
365}
366
/// Ways the ONNX smoke test can fail to produce a latency figure.
// Each build constructs only one of the variants: `FeatureUnavailable` is
// produced by the build without the "vec" feature, `Engine` by the build with
// it, so the other variant is otherwise flagged as dead code.
#[allow(dead_code)]
#[derive(Debug)]
enum OnnxSmokeError {
    /// The named Cargo feature is not enabled, so the test could not run.
    FeatureUnavailable(&'static str),
    /// The ONNX engine reported an error; the payload is its message.
    Engine(String),
}
373
/// Loads the model at `path` into an ONNX session and reports how long the
/// load took, in milliseconds (never less than 1).
///
/// Only the load itself is timed; creating the session builder is not.
///
/// # Errors
///
/// Returns `OnnxSmokeError::Engine` with the engine's message when the
/// builder cannot be created or the model cannot be loaded.
#[cfg(feature = "vec")]
fn run_onnx_smoke_test(path: &Path) -> std::result::Result<u128, OnnxSmokeError> {
    use ort::session::Session;

    /// Wraps any displayable engine error in `OnnxSmokeError::Engine`.
    fn engine_failure<E: std::fmt::Display>(err: E) -> OnnxSmokeError {
        OnnxSmokeError::Engine(err.to_string())
    }

    let builder = Session::builder().map_err(engine_failure)?;

    let started = Instant::now();
    match builder.commit_from_file(path) {
        // Release the session before stopping the clock.
        Ok(session) => drop(session),
        Err(err) => return Err(engine_failure(err)),
    }
    let millis = started.elapsed().as_millis();

    Ok(std::cmp::max(millis, 1))
}
387
388#[cfg(not(feature = "vec"))]
389fn run_onnx_smoke_test(_path: &Path) -> std::result::Result<u128, OnnxSmokeError> {
390 Err(OnnxSmokeError::FeatureUnavailable("vec"))
391}
392
#[cfg(test)]
mod tests {
    use super::*;
    use sha2::{Digest, Sha256};
    use tempfile::tempdir;

    /// Options used by every test here: the ONNX smoke test stays disabled.
    fn smoke_disabled() -> ModelVerifyOptions {
        ModelVerifyOptions {
            run_onnx_smoke: false,
        }
    }

    /// Creates `<root>/sha256-<digest>` and returns its path.
    fn create_model_dir(root: &Path, digest: &str) -> Result<PathBuf> {
        let model_dir = root.join(format!("sha256-{digest}"));
        fs::create_dir_all(&model_dir)?;
        Ok(model_dir)
    }

    /// Serialises `value` as pretty-printed JSON into the file at `path`.
    fn write_manifest(path: &Path, value: &serde_json::Value) -> Result<()> {
        let encoded = match serde_json::to_vec_pretty(value) {
            Ok(bytes) => bytes,
            Err(err) => {
                return Err(MemvidError::ModelManifestInvalid {
                    reason: format!("failed to encode manifest: {err}").into_boxed_str(),
                });
            }
        };
        fs::write(path, encoded)?;
        Ok(())
    }

    /// Writes `contents` to `path`, creating parent directories as needed.
    fn write_file(path: &Path, contents: &[u8]) -> Result<()> {
        if let Some(parent) = path.parent() {
            fs::create_dir_all(parent)?;
        }
        fs::write(path, contents)?;
        Ok(())
    }

    /// Hex-encoded SHA-256 of `data`.
    fn checksum_hex(data: &[u8]) -> String {
        hex::encode(Sha256::digest(data))
    }

    #[test]
    fn verify_model_success() -> Result<()> {
        let scratch = tempdir()?;
        let digest = "a".repeat(64);
        let model_dir = create_model_dir(scratch.path(), &digest)?;

        let weights: &[u8] = b"ONNX";
        let tokenizer: &[u8] = b"{}";
        write_file(&model_dir.join("models/encoder/model_int8.onnx"), weights)?;
        write_file(&model_dir.join("models/encoder/tokenizer.json"), tokenizer)?;

        write_manifest(
            &model_dir.join("manifest.json"),
            &serde_json::json!({
                "digest": format!("sha256:{digest}"),
                "dims": 384,
                "quant": "int8",
                "files": [
                    {
                        "path": "models/encoder/model_int8.onnx",
                        "sha256": checksum_hex(weights),
                        "roles": ["weights"],
                    },
                    {
                        "path": "models/encoder/tokenizer.json",
                        "sha256": checksum_hex(tokenizer),
                    }
                ]
            }),
        )?;

        let report = verify_model_dir(&model_dir, &smoke_disabled())?;
        assert_eq!(report.digest, format!("sha256:{digest}"));
        assert_eq!(report.status, ModelVerificationStatus::Ok);
        assert_eq!(report.dims, Some(384));
        assert!(report.errors.is_empty());
        Ok(())
    }

    #[test]
    fn verify_model_missing_optional_warns() -> Result<()> {
        let scratch = tempdir()?;
        let digest = "b".repeat(64);
        let model_dir = create_model_dir(scratch.path(), &digest)?;

        let weights: &[u8] = b"ONNX";
        write_file(&model_dir.join("models/model.onnx"), weights)?;

        // The tokenizer is listed as optional but deliberately never written.
        write_manifest(
            &model_dir.join("manifest.json"),
            &serde_json::json!({
                "digest": format!("sha256:{digest}"),
                "dims": 256,
                "files": [
                    {
                        "path": "models/model.onnx",
                        "sha256": checksum_hex(weights),
                        "roles": ["weights"],
                    },
                    {
                        "path": "models/tokenizer.json",
                        "sha256": checksum_hex(b"missing"),
                        "optional": true
                    }
                ]
            }),
        )?;

        let report = verify_model_dir(&model_dir, &smoke_disabled())?;
        assert_eq!(report.status, ModelVerificationStatus::Warn);
        assert!(report.errors.is_empty());
        assert_eq!(report.warnings.len(), 1);
        Ok(())
    }

    #[test]
    fn verify_models_directory_listing() -> Result<()> {
        let scratch = tempdir()?;
        let digest = "c".repeat(64);
        let model_dir = create_model_dir(scratch.path(), &digest)?;

        let weights: &[u8] = b"ONNX";
        fs::write(model_dir.join("model.onnx"), weights)?;

        write_manifest(
            &model_dir.join("manifest.json"),
            &serde_json::json!({
                "digest": format!("sha256:{digest}"),
                "dims": 128,
                "files": [
                    {
                        "path": "model.onnx",
                        "sha256": checksum_hex(weights),
                        "roles": ["weights"],
                    }
                ]
            }),
        )?;

        let reports = verify_models(scratch.path(), &smoke_disabled())?;
        assert_eq!(reports.len(), 1);
        assert_eq!(reports[0].digest, format!("sha256:{digest}"));
        Ok(())
    }
}