1use std::fs;
2use std::fs::File;
3use std::io::{BufReader, Read};
4use std::path::{Component, Path, PathBuf};
5
6#[cfg(feature = "vec")]
7use std::time::Instant;
8
9use serde::Deserialize;
10
11use crate::error::{MemvidError, Result};
12
/// Overall outcome of verifying one model directory, ordered by severity
/// (`Ok` < `Warn` < `Fail`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelVerificationStatus {
    Ok,
    Warn,
    Fail,
}

impl ModelVerificationStatus {
    /// Raises `self` to `other` when `other` is more severe; never lowers it.
    fn elevate(&mut self, other: ModelVerificationStatus) {
        // Numeric rank makes the "keep the worst outcome" rule a plain comparison.
        fn severity(status: ModelVerificationStatus) -> u8 {
            match status {
                ModelVerificationStatus::Ok => 0,
                ModelVerificationStatus::Warn => 1,
                ModelVerificationStatus::Fail => 2,
            }
        }

        if severity(other) > severity(*self) {
            *self = other;
        }
    }
}
34
/// Report produced for a single model directory.
#[derive(Debug, Clone)]
pub struct ModelVerification {
    /// Model digest; successful verification reports it as `sha256:<lowercase hex>`.
    pub digest: String,
    /// Embedding dimensions from the manifest; `None` when verification errored out.
    pub dims: Option<u32>,
    /// Quantisation label copied from the manifest, if any.
    pub quant: Option<String>,
    /// Context length copied from the manifest, if any.
    pub context_length: Option<u32>,
    /// Worst outcome encountered while verifying the directory.
    pub status: ModelVerificationStatus,
    /// Time taken to load the model in the ONNX smoke test, in milliseconds;
    /// `None` when the smoke test did not run or did not succeed.
    pub load_latency_ms: Option<u128>,
    /// Directory that was verified (canonicalised when verification completes
    /// and canonicalisation succeeds).
    pub path: PathBuf,
    /// Non-fatal findings, e.g. missing optional files.
    pub warnings: Vec<String>,
    /// Fatal findings, e.g. missing required files or checksum mismatches.
    pub errors: Vec<String>,
}
47
48impl ModelVerification {
49 fn from_error(path: PathBuf, err: MemvidError) -> Self {
50 let digest = digest_from_dir_name(&path).unwrap_or_else(|| "sha256:unknown".to_string());
51 Self {
52 digest,
53 dims: None,
54 quant: None,
55 context_length: None,
56 status: ModelVerificationStatus::Fail,
57 load_latency_ms: None,
58 path,
59 warnings: Vec::new(),
60 errors: vec![err.to_string()],
61 }
62 }
63}
64
/// Knobs controlling how thoroughly a model directory is verified.
#[derive(Debug, Clone)]
pub struct ModelVerifyOptions {
    /// Whether to load the model weights with the ONNX runtime as a smoke test.
    pub run_onnx_smoke: bool,
}

impl Default for ModelVerifyOptions {
    /// Enables the ONNX smoke test exactly when the crate is built with the
    /// `vec` feature.
    fn default() -> Self {
        let smoke_available = cfg!(feature = "vec");
        ModelVerifyOptions {
            run_onnx_smoke: smoke_available,
        }
    }
}
77
/// Contents of a model directory's `manifest.json`.
///
/// Field names are deserialised in kebab-case (e.g. `schema-version`,
/// `context-length`), and any field absent from the JSON takes the value from
/// this type's `Default` implementation.
#[derive(Debug, Clone, Deserialize)]
#[serde(default, rename_all = "kebab-case")]
pub struct ModelManifest {
    /// Manifest schema version.
    pub schema_version: u32,
    /// Model digest; verification requires it to match the directory name.
    pub digest: String,
    /// Embedding dimensions; verification rejects `0`.
    pub dims: u32,
    /// Optional quantisation label; also used to pick the weights file.
    pub quant: Option<String>,
    /// Optional context length.
    pub context_length: Option<u32>,
    /// Files that make up the model, each with its expected checksum.
    pub files: Vec<ModelManifestEntry>,
    /// Free-form extra data; not interpreted by the verification code in this file.
    pub metadata: serde_json::Value,
}
89
90impl Default for ModelManifest {
91 fn default() -> Self {
92 Self {
93 schema_version: 1,
94 digest: String::new(),
95 dims: 0,
96 quant: None,
97 context_length: None,
98 files: Vec::new(),
99 metadata: serde_json::Value::Null,
100 }
101 }
102}
103
104#[derive(Debug, Clone, Deserialize)]
105#[serde(default, rename_all = "kebab-case")]
106pub struct ModelManifestEntry {
107 pub path: String,
108 pub sha256: String,
109 pub optional: bool,
110 pub roles: Vec<String>,
111 pub kind: Option<String>,
112}
113
114impl Default for ModelManifestEntry {
115 fn default() -> Self {
116 Self {
117 path: String::new(),
118 sha256: String::new(),
119 optional: false,
120 roles: Vec::new(),
121 kind: None,
122 }
123 }
124}
125
126pub fn verify_models(root: &Path, options: &ModelVerifyOptions) -> Result<Vec<ModelVerification>> {
127 if !root.exists() {
128 return Ok(Vec::new());
129 }
130
131 let mut dirs: Vec<PathBuf> = fs::read_dir(root)?
132 .filter_map(|entry| entry.ok())
133 .filter_map(|entry| {
134 let path = entry.path();
135 entry
136 .file_type()
137 .ok()
138 .filter(|ft| ft.is_dir())
139 .and_then(|_| digest_from_dir_name(&path).map(|_| path))
140 })
141 .collect();
142 dirs.sort();
143
144 let mut reports = Vec::with_capacity(dirs.len());
145 for dir in dirs {
146 match verify_model_dir(&dir, options) {
147 Ok(report) => reports.push(report),
148 Err(err) => reports.push(ModelVerification::from_error(dir, err)),
149 }
150 }
151
152 reports.sort_by(|a, b| a.digest.cmp(&b.digest));
153 Ok(reports)
154}
155
/// Verifies a single model directory against its `manifest.json`.
///
/// Checks, in order: the manifest exists and parses, its digest and dimensions
/// are present, the manifest digest matches the `sha256-<digest>` directory
/// name, every listed file exists with the expected SHA-256 checksum, and
/// (when `options.run_onnx_smoke` is set and nothing has failed so far) that
/// the weights file loads in the ONNX smoke test.
///
/// Missing files, checksum mismatches and smoke-test problems are recorded in
/// the returned report's `warnings`/`errors` and reflected in its `status`.
///
/// # Errors
///
/// Returns `Err` (instead of a report) when the manifest is missing, unreadable
/// or malformed, when the digest does not match the directory name, when a file
/// entry is invalid (empty path, backslashes, absolute or `..` path, bad
/// checksum string), or when hashing an existing file fails.
pub fn verify_model_dir(dir: &Path, options: &ModelVerifyOptions) -> Result<ModelVerification> {
    let manifest_path = dir.join("manifest.json");
    if !manifest_path.exists() {
        return Err(MemvidError::ModelIntegrity {
            reason: format!("missing manifest.json in {}", dir.display()).into_boxed_str(),
        });
    }

    let manifest_data = fs::read_to_string(&manifest_path)?;
    let manifest: ModelManifest =
        serde_json::from_str(&manifest_data).map_err(|err| MemvidError::ModelManifestInvalid {
            reason: format!(
                "failed to parse manifest {}: {err}",
                manifest_path.display()
            )
            .into_boxed_str(),
        })?;

    // Missing fields deserialise to their defaults, so an absent digest or
    // dims shows up here as an empty string / zero.
    if manifest.digest.trim().is_empty() {
        return Err(MemvidError::ModelManifestInvalid {
            reason: "manifest digest is empty".into(),
        });
    }

    if manifest.dims == 0 {
        return Err(MemvidError::ModelManifestInvalid {
            reason: "embedding dimensions must be > 0".into(),
        });
    }

    // The directory name and the manifest must agree on the model digest.
    let manifest_digest_hex = normalize_sha256(&manifest.digest, "manifest digest")?;
    let dir_digest_hex = digest_from_dir_name(dir).ok_or_else(|| MemvidError::ModelIntegrity {
        reason: format!(
            "directory {} is not named as sha256-<digest>",
            dir.display()
        )
        .into_boxed_str(),
    })?;

    let dir_digest_hex = normalize_sha256(&dir_digest_hex, "directory digest")?;
    if manifest_digest_hex != dir_digest_hex {
        return Err(MemvidError::ModelIntegrity {
            reason: format!(
                "manifest digest sha256:{manifest_digest_hex} does not match directory sha256:{dir_digest_hex}"
            )
            .into_boxed_str(),
        });
    }

    let digest = format!("sha256:{manifest_digest_hex}");

    let mut status = ModelVerificationStatus::Ok;
    let mut warnings = Vec::new();
    let mut errors = Vec::new();
    let mut load_latency_ms = None;

    // Per-file checks: a malformed entry aborts the whole verification, while
    // missing files and checksum mismatches are collected into the report.
    for entry in &manifest.files {
        validate_entry(entry)?;
        let expected_hex = normalize_sha256(&entry.sha256, &entry.path)?;
        let resolved_path = resolve_entry_path(dir, &entry.path)?;
        if !resolved_path.exists() {
            if entry.optional {
                warnings.push(format!("optional file missing: {}", entry.path));
                status.elevate(ModelVerificationStatus::Warn);
            } else {
                errors.push(format!("required file missing: {}", entry.path));
                status.elevate(ModelVerificationStatus::Fail);
            }
            continue;
        }

        let actual_hex = compute_sha256_hex(&resolved_path)?;
        if actual_hex != expected_hex {
            errors.push(format!(
                "checksum mismatch for {} (expected {}, got {})",
                entry.path, expected_hex, actual_hex
            ));
            status.elevate(ModelVerificationStatus::Fail);
        }
    }

    // Only attempt to load the model when the files themselves checked out.
    if status != ModelVerificationStatus::Fail && options.run_onnx_smoke {
        if let Some(weights_entry) = select_weights_entry(&manifest) {
            let weights_path = resolve_entry_path(dir, &weights_entry.path)?;
            // A missing weights file was already reported by the loop above,
            // so it is skipped here without a further message.
            if weights_path.exists() {
                match run_onnx_smoke_test(&weights_path) {
                    Ok(latency) => {
                        // Report at least 1 ms so a successful load is never shown as 0.
                        load_latency_ms = Some(latency.max(1));
                    }
                    Err(OnnxSmokeError::FeatureUnavailable(feature)) => {
                        warnings.push(format!(
                            "feature '{feature}' not enabled; skipping ONNX smoke test"
                        ));
                        status.elevate(ModelVerificationStatus::Warn);
                    }
                    Err(OnnxSmokeError::Engine(err)) => {
                        errors.push(format!("ONNX initialisation failed: {err}"));
                        status.elevate(ModelVerificationStatus::Fail);
                    }
                }
            }
        } else {
            warnings.push(
                "manifest does not declare a model .onnx file; skipping ONNX smoke test".into(),
            );
            status.elevate(ModelVerificationStatus::Warn);
        }
    }

    // Prefer the canonical path in the report; fall back to the path as given.
    let resolved_dir = fs::canonicalize(dir).unwrap_or_else(|_| dir.to_path_buf());

    Ok(ModelVerification {
        digest,
        dims: Some(manifest.dims),
        quant: manifest.quant.clone(),
        context_length: manifest.context_length,
        status,
        load_latency_ms,
        path: resolved_dir,
        warnings,
        errors,
    })
}
279
280fn validate_entry(entry: &ModelManifestEntry) -> Result<()> {
281 if entry.path.trim().is_empty() {
282 return Err(MemvidError::ModelManifestInvalid {
283 reason: "file entry path is empty".into(),
284 });
285 }
286 if entry.path.contains("\\") {
287 return Err(MemvidError::ModelManifestInvalid {
288 reason: format!("file entry path must use forward slashes: {}", entry.path)
289 .into_boxed_str(),
290 });
291 }
292 if entry.sha256.trim().is_empty() {
293 return Err(MemvidError::ModelManifestInvalid {
294 reason: format!("file entry '{}' missing sha256", entry.path).into_boxed_str(),
295 });
296 }
297 Ok(())
298}
299
300fn resolve_entry_path(base: &Path, relative: &str) -> Result<PathBuf> {
301 let path = Path::new(relative);
302 if path.is_absolute() {
303 return Err(MemvidError::ModelManifestInvalid {
304 reason: format!("file entry '{}' must be relative", relative).into_boxed_str(),
305 });
306 }
307
308 for component in path.components() {
309 if matches!(component, Component::ParentDir) {
310 return Err(MemvidError::ModelManifestInvalid {
311 reason: format!("file entry '{}' attempts directory traversal", relative)
312 .into_boxed_str(),
313 });
314 }
315 }
316
317 Ok(base.join(path))
318}
319
320fn normalize_sha256(value: &str, context: &str) -> Result<String> {
321 let trimmed = value.trim();
322 let trimmed = trimmed
323 .strip_prefix("sha256:")
324 .or_else(|| trimmed.strip_prefix("sha256-"))
325 .unwrap_or(trimmed);
326 if trimmed.len() != 64 || !trimmed.chars().all(|c| c.is_ascii_hexdigit()) {
327 return Err(MemvidError::ModelManifestInvalid {
328 reason: format!("invalid sha256 value for {context}").into_boxed_str(),
329 });
330 }
331 Ok(trimmed.to_ascii_lowercase())
332}
333
/// Extracts the digest from a directory named `sha256-<digest>`.
///
/// Returns the text after the `sha256-` prefix of the final path component
/// (unvalidated, possibly empty), or `None` when the path has no final
/// component, the name is not valid UTF-8, or the prefix is absent.
fn digest_from_dir_name(path: &Path) -> Option<String> {
    path.file_name()
        .and_then(|os_name| os_name.to_str())
        .and_then(|name| name.strip_prefix("sha256-"))
        .map(str::to_owned)
}
338
339fn compute_sha256_hex(path: &Path) -> Result<String> {
340 use sha2::{Digest, Sha256};
341
342 let file = File::open(path)?;
343 let mut reader = BufReader::new(file);
344 let mut hasher = Sha256::new();
345 let mut buffer = [0u8; 8192];
346 loop {
347 let read = reader.read(&mut buffer)?;
348 if read == 0 {
349 break;
350 }
351 hasher.update(&buffer[..read]);
352 }
353 Ok(hex::encode(hasher.finalize()))
354}
355
356fn select_weights_entry<'a>(manifest: &'a ModelManifest) -> Option<&'a ModelManifestEntry> {
357 if let Some(quant) = manifest.quant.as_deref() {
358 if let Some(entry) = manifest
359 .files
360 .iter()
361 .find(|entry| entry.path.ends_with(".onnx") && entry.path.contains(quant))
362 {
363 return Some(entry);
364 }
365 }
366
367 manifest
368 .files
369 .iter()
370 .find(|entry| entry.roles.iter().any(|role| role == "weights"))
371 .or_else(|| {
372 manifest
373 .files
374 .iter()
375 .find(|entry| entry.kind.as_deref() == Some("onnx"))
376 })
377 .or_else(|| {
378 manifest
379 .files
380 .iter()
381 .find(|entry| entry.path.ends_with(".onnx"))
382 })
383}
384
/// Reasons the ONNX smoke test could not report a load latency.
// Each build configuration constructs only one of the variants, so the other
// would otherwise trigger a dead-code warning.
#[allow(dead_code)]
#[derive(Debug)]
enum OnnxSmokeError {
    /// The named crate feature is disabled, so no ONNX runtime is available.
    FeatureUnavailable(&'static str),
    /// The ONNX runtime failed; carries the runtime's error message.
    Engine(String),
}

/// Loads the model at `path` into an ONNX session and returns how long the
/// load took in milliseconds (never less than 1).
#[cfg(feature = "vec")]
fn run_onnx_smoke_test(path: &Path) -> std::result::Result<u128, OnnxSmokeError> {
    use ort::session::Session;

    fn engine_failure<E: std::fmt::Display>(err: E) -> OnnxSmokeError {
        OnnxSmokeError::Engine(err.to_string())
    }

    let builder = Session::builder().map_err(engine_failure)?;
    // Timing starts after the builder exists and stops after the loaded
    // session has been dropped.
    let timer = Instant::now();
    let loaded = builder.commit_from_file(path).map_err(engine_failure)?;
    drop(loaded);
    let millis = timer.elapsed().as_millis();
    Ok(millis.max(1))
}

/// Fallback used when the `vec` feature is disabled: always reports that the
/// feature is unavailable.
#[cfg(not(feature = "vec"))]
fn run_onnx_smoke_test(_path: &Path) -> std::result::Result<u128, OnnxSmokeError> {
    let missing_feature: &'static str = "vec";
    Err(OnnxSmokeError::FeatureUnavailable(missing_feature))
}
410
// Tests build throwaway model directories on disk and run the verifier with
// the ONNX smoke test disabled, so only manifest and checksum handling is
// exercised.
#[cfg(test)]
mod tests {
    use super::*;
    use sha2::{Digest, Sha256};
    use tempfile::tempdir;

    /// Serialises `value` as pretty-printed JSON and writes it to `path`.
    fn write_manifest(path: &Path, value: &serde_json::Value) -> Result<()> {
        let bytes =
            serde_json::to_vec_pretty(value).map_err(|err| MemvidError::ModelManifestInvalid {
                reason: format!("failed to encode manifest: {err}").into_boxed_str(),
            })?;
        fs::write(path, bytes)?;
        Ok(())
    }

    /// Writes `contents` to `path`, creating parent directories as needed.
    fn write_file(path: &Path, contents: &[u8]) -> Result<()> {
        if let Some(parent) = path.parent() {
            fs::create_dir_all(parent)?;
        }
        fs::write(path, contents)?;
        Ok(())
    }

    /// SHA-256 of `data` as lowercase hex, used as the expected checksum in manifests.
    fn checksum_hex(data: &[u8]) -> String {
        let mut hasher = Sha256::new();
        hasher.update(data);
        hex::encode(hasher.finalize())
    }

    /// All listed files are present with matching checksums: status is `Ok`.
    #[test]
    fn verify_model_success() -> Result<()> {
        let temp = tempdir()?;
        let digest = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa";
        let model_dir = temp.path().join(format!("sha256-{digest}"));
        fs::create_dir_all(&model_dir)?;

        let model_bytes = b"ONNX";
        let tokenizer_bytes = b"{}";

        write_file(
            &model_dir.join("models/encoder/model_int8.onnx"),
            model_bytes,
        )?;
        write_file(
            &model_dir.join("models/encoder/tokenizer.json"),
            tokenizer_bytes,
        )?;

        let manifest = serde_json::json!({
            "digest": format!("sha256:{digest}"),
            "dims": 384,
            "quant": "int8",
            "files": [
                {
                    "path": "models/encoder/model_int8.onnx",
                    "sha256": checksum_hex(model_bytes),
                    "roles": ["weights"],
                },
                {
                    "path": "models/encoder/tokenizer.json",
                    "sha256": checksum_hex(tokenizer_bytes),
                }
            ]
        });
        write_manifest(&model_dir.join("manifest.json"), &manifest)?;

        let options = ModelVerifyOptions {
            run_onnx_smoke: false,
        };
        let report = verify_model_dir(&model_dir, &options)?;
        assert_eq!(report.digest, format!("sha256:{digest}"));
        assert_eq!(report.status, ModelVerificationStatus::Ok);
        assert_eq!(report.dims, Some(384));
        assert!(report.errors.is_empty());
        Ok(())
    }

    /// A missing file marked `optional` yields one warning and status `Warn`, not an error.
    #[test]
    fn verify_model_missing_optional_warns() -> Result<()> {
        let temp = tempdir()?;
        let digest = "bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb";
        let model_dir = temp.path().join(format!("sha256-{digest}"));
        fs::create_dir_all(&model_dir)?;

        let model_bytes = b"ONNX";

        write_file(&model_dir.join("models/model.onnx"), model_bytes)?;

        // The tokenizer entry is declared but never written to disk.
        let manifest = serde_json::json!({
            "digest": format!("sha256:{digest}"),
            "dims": 256,
            "files": [
                {
                    "path": "models/model.onnx",
                    "sha256": checksum_hex(model_bytes),
                    "roles": ["weights"],
                },
                {
                    "path": "models/tokenizer.json",
                    "sha256": checksum_hex(b"missing"),
                    "optional": true
                }
            ]
        });
        write_manifest(&model_dir.join("manifest.json"), &manifest)?;

        let options = ModelVerifyOptions {
            run_onnx_smoke: false,
        };
        let report = verify_model_dir(&model_dir, &options)?;
        assert_eq!(report.status, ModelVerificationStatus::Warn);
        assert!(report.errors.is_empty());
        assert_eq!(report.warnings.len(), 1);
        Ok(())
    }

    /// `verify_models` discovers a `sha256-*` directory under the root and reports it.
    #[test]
    fn verify_models_directory_listing() -> Result<()> {
        let temp = tempdir()?;
        let digest = "cccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccccc";
        let model_dir = temp.path().join(format!("sha256-{digest}"));
        fs::create_dir_all(&model_dir)?;

        let model_bytes = b"ONNX";
        fs::write(model_dir.join("model.onnx"), model_bytes)?;

        let manifest = serde_json::json!({
            "digest": format!("sha256:{digest}"),
            "dims": 128,
            "files": [
                {
                    "path": "model.onnx",
                    "sha256": checksum_hex(model_bytes),
                    "roles": ["weights"],
                }
            ]
        });
        write_manifest(&model_dir.join("manifest.json"), &manifest)?;

        let options = ModelVerifyOptions {
            run_onnx_smoke: false,
        };
        let reports = verify_models(temp.path(), &options)?;
        assert_eq!(reports.len(), 1);
        assert_eq!(reports[0].digest, format!("sha256:{digest}"));
        Ok(())
    }
}