use std::collections::HashMap;
use std::path::{Path, PathBuf};
use serde::{Deserialize, Serialize};
use crate::model::{Component, ComponentType, HashAlgorithm, NormalizedSbom};
use crate::verification::verify_file_hash;
/// Outcome of checking one model component's recorded hash against the files on disk.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelVerifyResult {
/// A file was located and `verify_file_hash` reported it as verified.
Verified,
/// A file was located but verification did not succeed; this also covers the
/// case where the hash check itself returned an error.
Mismatch,
/// The component has a verifiable hash but no matching file was located.
Missing,
/// The component carries no SHA-256 or SHA-512 hash, so nothing was checked.
NoHash,
}
impl ModelVerifyResult {
    /// Returns the fixed, upper-case tag that identifies this result in textual output.
    #[must_use]
    pub const fn label(&self) -> &'static str {
        match *self {
            Self::NoHash => "NO-HASH",
            Self::Missing => "MISSING",
            Self::Mismatch => "MISMATCH",
            Self::Verified => "VERIFIED",
        }
    }
}
/// Verification record for a single model-like component.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComponentModelVerification {
/// Component name, copied from the SBOM component.
pub name: String,
/// Component version, copied from the SBOM component.
pub version: Option<String>,
/// Outcome of the check.
pub result: ModelVerifyResult,
/// Lower-cased hash value the check used; `None` when the result is `NoHash`.
pub hash: Option<String>,
/// Path of the file that was checked, relative to the model directory when the
/// file lies beneath it; `None` when no file was located.
pub file: Option<String>,
}
/// Aggregated result of verifying every model-like component of an SBOM.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelVerifyReport {
/// Display form of the directory path that was passed to `verify_model_dir`.
pub model_dir: String,
/// Number of components that were considered (machine-learning-model or data components).
pub total_models: usize,
/// Number of components whose result is `ModelVerifyResult::Verified`.
pub verified_count: usize,
/// Number of components whose result is `ModelVerifyResult::Mismatch`.
pub mismatch_count: usize,
/// Number of components whose result is `ModelVerifyResult::Missing`.
pub missing_count: usize,
/// Number of components whose result is `ModelVerifyResult::NoHash`.
pub no_hash_count: usize,
/// One record per considered component, in the order they were processed.
pub components: Vec<ComponentModelVerification>,
}
impl ModelVerifyReport {
    /// Returns `true` when at least one component was a mismatch or was missing.
    ///
    /// Components without a verifiable hash do not count as failures.
    #[must_use]
    pub const fn has_failures(&self) -> bool {
        self.mismatch_count != 0 || self.missing_count != 0
    }
}
/// Returns `true` for the hash algorithms this module checks: SHA-256 and SHA-512.
const fn is_verifiable(alg: &HashAlgorithm) -> bool {
    matches!(*alg, HashAlgorithm::Sha256 | HashAlgorithm::Sha512)
}
#[must_use]
pub fn verify_model_dir(sbom: &NormalizedSbom, model_dir: &Path) -> ModelVerifyReport {
let root = std::fs::canonicalize(model_dir).unwrap_or_else(|_| model_dir.to_path_buf());
let index = FileIndex::build(&root);
let mut report = ModelVerifyReport {
model_dir: model_dir.display().to_string(),
total_models: 0,
verified_count: 0,
mismatch_count: 0,
missing_count: 0,
no_hash_count: 0,
components: Vec::new(),
};
for component in sbom.components.values() {
if !is_model_like(component) {
continue;
}
report.total_models += 1;
let record = verify_component(component, &root, &index);
match record.result {
ModelVerifyResult::Verified => report.verified_count += 1,
ModelVerifyResult::Mismatch => report.mismatch_count += 1,
ModelVerifyResult::Missing => report.missing_count += 1,
ModelVerifyResult::NoHash => report.no_hash_count += 1,
}
report.components.push(record);
}
report
}
/// Returns `true` when the component is typed as a machine-learning model or as data.
fn is_model_like(component: &Component) -> bool {
    let kind = &component.component_type;
    matches!(
        kind,
        ComponentType::MachineLearningModel | ComponentType::Data
    )
}
fn verify_component(
component: &Component,
model_dir: &Path,
index: &FileIndex,
) -> ComponentModelVerification {
let make = |result, hash: Option<String>, file: Option<String>| ComponentModelVerification {
name: component.name.clone(),
version: component.version.clone(),
result,
hash,
file,
};
let verifiable: Vec<_> = component
.hashes
.iter()
.filter(|h| is_verifiable(&h.algorithm))
.collect();
if verifiable.is_empty() {
return make(ModelVerifyResult::NoHash, None, None);
}
let name_candidates = filename_candidates(component);
let mut last_missing_hash: Option<String> = None;
for hash in verifiable {
let hash_hex = hash.value.to_lowercase();
last_missing_hash = Some(hash_hex.clone());
if let Some(path) = index.by_basename(&hash_hex) {
return verify_against(component, &hash_hex, path, model_dir);
}
for candidate in &name_candidates {
if let Some(path) = index.by_basename(candidate) {
return verify_against(component, &hash_hex, path, model_dir);
}
}
}
make(ModelVerifyResult::Missing, last_missing_hash, None)
}
/// Hashes the file at `path`, compares it with `hash_hex`, and builds the record.
///
/// The reported file path is relative to `model_dir` when `path` lies beneath it,
/// otherwise `path` is reported as is. The result is `Verified` only when
/// `verify_file_hash` succeeds and reports a match; a reported non-match and an
/// error from the check are both recorded as `Mismatch`.
fn verify_against(
    component: &Component,
    hash_hex: &str,
    path: &Path,
    model_dir: &Path,
) -> ComponentModelVerification {
    let shown_path = path.strip_prefix(model_dir).unwrap_or(path);

    let verdict = match verify_file_hash(path, hash_hex) {
        Ok(outcome) if outcome.verified => ModelVerifyResult::Verified,
        Ok(_) | Err(_) => ModelVerifyResult::Mismatch,
    };

    ComponentModelVerification {
        name: component.name.clone(),
        version: component.version.clone(),
        result: verdict,
        hash: Some(hash_hex.to_string()),
        file: Some(shown_path.display().to_string()),
    }
}
/// Builds the conventional weight-file names to look for when no file is named
/// after the component's hash.
///
/// Candidates are `<stem>.<ext>` for the stems `model`, `pytorch_model` and the
/// component's own name, in that order, each combined with the known weight-file
/// extensions. Candidates are compared against file *basenames*, so for a name
/// containing `/` (for example `org/model`) only the final non-empty path segment
/// is used; the full name could never match. Duplicate stems are emitted once.
fn filename_candidates(component: &Component) -> Vec<String> {
    const EXTENSIONS: [&str; 10] = [
        "safetensors",
        "bin",
        "pt",
        "pth",
        "onnx",
        "gguf",
        "ggml",
        "h5",
        "pb",
        "tflite",
    ];

    // A basename cannot contain `/`, so keep only the last non-empty segment.
    let own_stem = component
        .name
        .rsplit('/')
        .find(|segment| !segment.is_empty())
        .unwrap_or("");

    let mut stems: Vec<&str> = Vec::with_capacity(3);
    for stem in ["model", "pytorch_model", own_stem] {
        if !stem.is_empty() && !stems.contains(&stem) {
            stems.push(stem);
        }
    }

    let mut out = Vec::with_capacity(stems.len() * EXTENSIONS.len());
    for stem in stems {
        for ext in EXTENSIONS {
            out.push(format!("{stem}.{ext}"));
        }
    }
    out
}
/// Case-insensitive index from file basename to resolved path for the files
/// reachable under a root directory.
struct FileIndex {
    /// Lower-cased basename mapped to the canonical path of the first file seen with that name.
    by_name: HashMap<String, PathBuf>,
}

impl FileIndex {
    /// Walks `root` and indexes every non-directory entry by its lower-cased basename.
    ///
    /// `root` is expected to be canonical already: every entry is canonicalized and
    /// skipped unless its resolved path still starts with `root`, so symlinks that
    /// point outside the tree are ignored while symlinks inside it are followed.
    /// Resolved directories are visited at most once, which also stops symlink cycles.
    /// Entries that cannot be read, resolved or inspected are skipped.
    ///
    /// When several files share a basename the first one visited wins. `read_dir`
    /// yields entries in a platform- and filesystem-dependent order, so each listing
    /// is sorted and the walk is a pre-order traversal: files directly in a directory
    /// come before anything in its subdirectories, and sibling subdirectories are
    /// visited in sorted path order. The same tree therefore always produces the same
    /// index.
    fn build(root: &Path) -> Self {
        let mut by_name = HashMap::new();
        let mut stack = vec![root.to_path_buf()];
        let mut visited: std::collections::HashSet<PathBuf> = std::collections::HashSet::new();
        while let Some(dir) = stack.pop() {
            if !visited.insert(dir.clone()) {
                continue;
            }
            let Ok(entries) = std::fs::read_dir(&dir) else {
                continue;
            };
            // Sort the listing so the outcome does not depend on `read_dir` order.
            let mut paths: Vec<PathBuf> = entries.flatten().map(|entry| entry.path()).collect();
            paths.sort_unstable();

            let mut subdirs = Vec::new();
            for path in paths {
                let Ok(resolved) = std::fs::canonicalize(&path) else {
                    continue;
                };
                // Never index anything that resolves outside the tree.
                if !resolved.starts_with(root) {
                    continue;
                }
                let Ok(meta) = std::fs::metadata(&resolved) else {
                    continue;
                };
                if meta.is_dir() {
                    subdirs.push(resolved);
                } else if let Some(name) = path.file_name().and_then(|n| n.to_str()) {
                    // Keyed by the entry's own name (not the symlink target's name).
                    by_name.entry(name.to_lowercase()).or_insert(resolved);
                }
            }
            // The stack is LIFO: push in reverse so the smallest subdirectory is next.
            stack.extend(subdirs.into_iter().rev());
        }
        Self { by_name }
    }

    /// Looks up a file by basename, ignoring case.
    fn by_basename(&self, name: &str) -> Option<&Path> {
        self.by_name.get(&name.to_lowercase()).map(PathBuf::as_path)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::{DocumentMetadata, Hash};
    use sha2::{Digest, Sha256};
    use std::fs;

    /// Lower-case hex SHA-256 digest of `bytes`.
    fn sha256_hex(bytes: &[u8]) -> String {
        let mut hasher = Sha256::new();
        hasher.update(bytes);
        let digest = hasher.finalize();
        let mut hex = String::with_capacity(digest.len() * 2);
        for byte in digest.iter() {
            hex.push_str(&format!("{byte:02x}"));
        }
        hex
    }

    /// Machine-learning-model component at version 1.0.0 carrying one SHA-256 hash.
    fn model_component(name: &str, hash_hex: &str) -> Component {
        let mut component = Component::new(name.to_string(), format!("{name}-ref"))
            .with_version("1.0.0".to_string());
        component.component_type = ComponentType::MachineLearningModel;
        component
            .hashes
            .push(Hash::new(HashAlgorithm::Sha256, hash_hex.to_string()));
        component
    }

    /// SBOM with default metadata that contains exactly `component`.
    fn sbom_with(component: Component) -> NormalizedSbom {
        let mut sbom = NormalizedSbom::new(DocumentMetadata::default());
        sbom.add_component(component);
        sbom
    }

    #[test]
    fn verifies_against_hf_blob_named_by_sha256() {
        let model_dir = tempfile::tempdir().unwrap();
        let payload = b"fake model weights";
        let digest = sha256_hex(payload);
        let blob_dir = model_dir.path().join("blobs");
        fs::create_dir_all(&blob_dir).unwrap();
        fs::write(blob_dir.join(&digest), payload).unwrap();

        let sbom = sbom_with(model_component("bert", &digest));
        let report = verify_model_dir(&sbom, model_dir.path());

        assert_eq!(report.total_models, 1);
        assert_eq!(report.verified_count, 1);
        assert_eq!(report.components[0].result, ModelVerifyResult::Verified);
        assert!(!report.has_failures());
    }

    #[test]
    fn verifies_against_direct_filename() {
        let model_dir = tempfile::tempdir().unwrap();
        let payload = b"safetensors bytes";
        let digest = sha256_hex(payload);
        fs::write(model_dir.path().join("model.safetensors"), payload).unwrap();

        let sbom = sbom_with(model_component("bert", &digest));
        let report = verify_model_dir(&sbom, model_dir.path());

        assert_eq!(report.verified_count, 1);
        assert_eq!(
            report.components[0].file.as_deref(),
            Some("model.safetensors")
        );
    }

    #[test]
    fn detects_tampering_as_mismatch() {
        let model_dir = tempfile::tempdir().unwrap();
        fs::write(model_dir.path().join("model.safetensors"), b"tampered bytes").unwrap();
        let claimed_digest = sha256_hex(b"original bytes");

        let sbom = sbom_with(model_component("bert", &claimed_digest));
        let report = verify_model_dir(&sbom, model_dir.path());

        assert_eq!(report.mismatch_count, 1);
        assert_eq!(report.components[0].result, ModelVerifyResult::Mismatch);
        assert!(report.has_failures());
    }

    #[test]
    fn reports_missing_when_no_file_found() {
        let model_dir = tempfile::tempdir().unwrap();
        let digest = sha256_hex(b"weights that are not on disk");

        let sbom = sbom_with(model_component("bert", &digest));
        let report = verify_model_dir(&sbom, model_dir.path());

        assert_eq!(report.missing_count, 1);
        assert_eq!(report.components[0].result, ModelVerifyResult::Missing);
    }

    #[test]
    fn reports_no_hash_when_only_weak_hash_present() {
        let model_dir = tempfile::tempdir().unwrap();
        let mut weak = Component::new("bert".to_string(), "bert-ref".to_string());
        weak.component_type = ComponentType::MachineLearningModel;
        weak.hashes
            .push(Hash::new(HashAlgorithm::Md5, "deadbeef".to_string()));

        let sbom = sbom_with(weak);
        let report = verify_model_dir(&sbom, model_dir.path());

        assert_eq!(report.no_hash_count, 1);
        assert_eq!(report.components[0].result, ModelVerifyResult::NoHash);
    }

    #[cfg(unix)]
    #[test]
    fn does_not_follow_symlink_escaping_model_dir() {
        use std::os::unix::fs::symlink;

        // The real file lives in a separate temp dir; the model dir only links to it.
        let outside_dir = tempfile::tempdir().unwrap();
        let payload = b"weights that live outside the model dir";
        let digest = sha256_hex(payload);
        let outside_file = outside_dir.path().join("model.safetensors");
        fs::write(&outside_file, payload).unwrap();

        let model_dir = tempfile::tempdir().unwrap();
        symlink(&outside_file, model_dir.path().join("model.safetensors")).unwrap();

        let sbom = sbom_with(model_component("escape", &digest));
        let report = verify_model_dir(&sbom, model_dir.path());

        assert_eq!(report.total_models, 1);
        assert_eq!(
            report.verified_count, 0,
            "a symlink escaping the model dir must not be followed/verified"
        );
        assert_eq!(
            report.components[0].result,
            ModelVerifyResult::Missing,
            "out-of-tree symlink target is treated as no in-tree file found"
        );
    }

    #[cfg(unix)]
    #[test]
    fn follows_intra_tree_symlink_like_hf_cache() {
        use std::os::unix::fs::symlink;

        let model_dir = tempfile::tempdir().unwrap();
        let payload = b"in-tree hf blob bytes";
        let digest = sha256_hex(payload);
        let blob_dir = model_dir.path().join("blobs");
        let snapshot_dir = model_dir.path().join("snapshots").join("main");
        fs::create_dir_all(&blob_dir).unwrap();
        fs::create_dir_all(&snapshot_dir).unwrap();
        let blob_path = blob_dir.join(&digest);
        fs::write(&blob_path, payload).unwrap();
        symlink(&blob_path, snapshot_dir.join("model.safetensors")).unwrap();

        let sbom = sbom_with(model_component("bert", &digest));
        let report = verify_model_dir(&sbom, model_dir.path());

        assert_eq!(
            report.verified_count, 1,
            "intra-tree HF snapshot→blob symlink must still verify"
        );
    }

    #[test]
    fn ignores_non_model_components() {
        let model_dir = tempfile::tempdir().unwrap();
        let mut library = Component::new("lib".to_string(), "lib-ref".to_string());
        library.component_type = ComponentType::Library;
        library
            .hashes
            .push(Hash::new(HashAlgorithm::Sha256, "a".repeat(64)));

        let sbom = sbom_with(library);
        let report = verify_model_dir(&sbom, model_dir.path());

        assert_eq!(report.total_models, 0, "library components are not models");
    }
}