use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, Read};
use std::path::Path;
/// Report produced by [`check_safetensors_integrity`] and
/// [`check_safetensors_file_integrity`].
///
/// Those checks start with `passed` and every `*_match` flag set to `true`.
/// They clear `passed` whenever they add an entry to `errors`, and clear a
/// `*_match` flag only when both the config and the tensors supply that value
/// and the two differ.
#[derive(Debug, Clone, Serialize, Deserialize)]
// One boolean per gate keeps the serialized report flat; the lint is silenced on purpose.
#[allow(clippy::struct_excessive_bools)]
pub struct IntegrityResult {
    /// `false` if any check failed.
    pub passed: bool,
    /// Whether a config file was found and parsed successfully.
    pub config_found: bool,
    /// `false` when the config's `num_hidden_layers` differs from the layer
    /// count derived from tensor names.
    pub layer_count_match: bool,
    /// `false` when the config's `hidden_size` differs from the embedding's
    /// second dimension.
    pub hidden_size_match: bool,
    /// `false` when the config's `vocab_size` differs from the embedding's
    /// first dimension.
    pub vocab_size_match: bool,
    /// Human-readable failures, each prefixed with a gate id from [`gate_ids`].
    pub errors: Vec<String>,
    /// Values read from the config, present once the config was parsed.
    pub config_values: Option<ConfigValues>,
    /// Values derived from tensor metadata, present once the safetensors
    /// header(s) were read.
    pub tensor_values: Option<TensorDerivedValues>,
}
/// Values read from the model's config file and echoed in
/// [`IntegrityResult::config_values`].
///
/// Each field is `None` when the config does not provide that value.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ConfigValues {
    /// Number of transformer layers declared by the config.
    pub num_hidden_layers: Option<usize>,
    /// Hidden (model) width declared by the config.
    pub hidden_size: Option<usize>,
    /// Vocabulary size declared by the config.
    pub vocab_size: Option<usize>,
    /// Attention head count; reported only and not compared against tensors.
    pub num_attention_heads: Option<usize>,
}
/// Values inferred from safetensors tensor names and shapes.
///
/// Each field is `None` when it could not be inferred from the tensors.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorDerivedValues {
    /// One more than the highest layer index found in tensor names.
    pub layer_count: Option<usize>,
    /// Second dimension of the first recognised embedding / LM-head tensor.
    pub hidden_size: Option<usize>,
    /// First dimension of that same tensor.
    pub vocab_size: Option<usize>,
}
/// The subset of `config.json` fields this module reads.
///
/// Every field is an `Option`, so a config that lacks one of these keys still
/// deserializes. Keys not listed here are not read.
#[derive(Debug, Deserialize)]
struct HfConfig {
    num_hidden_layers: Option<usize>,
    hidden_size: Option<usize>,
    vocab_size: Option<usize>,
    num_attention_heads: Option<usize>,
}
#[must_use]
pub fn check_safetensors_integrity(model_dir: &Path) -> IntegrityResult {
let mut result = IntegrityResult {
passed: true,
config_found: false,
layer_count_match: true,
hidden_size_match: true,
vocab_size_match: true,
errors: Vec::new(),
config_values: None,
tensor_values: None,
};
let config_path = model_dir.join("config.json");
let config_path = if config_path.exists() {
config_path
} else {
find_config_json(model_dir).unwrap_or(config_path)
};
let config = match read_config(&config_path) {
Ok(cfg) => {
result.config_found = true;
result.config_values = Some(ConfigValues {
num_hidden_layers: cfg.num_hidden_layers,
hidden_size: cfg.hidden_size,
vocab_size: cfg.vocab_size,
num_attention_heads: cfg.num_attention_heads,
});
cfg
}
Err(e) => {
result.config_found = false;
result.passed = false;
result.errors.push(format!("G0-INTEGRITY-CONFIG: {e}"));
return result;
}
};
let safetensors_files = find_safetensors_files(model_dir);
if safetensors_files.is_empty() {
result.passed = false;
result
.errors
.push("G0-INTEGRITY-CONFIG: No .safetensors files found".to_string());
return result;
}
let mut all_tensors: HashMap<String, Vec<usize>> = HashMap::new();
for st_path in &safetensors_files {
match read_safetensors_metadata(st_path) {
Ok(tensors) => {
all_tensors.extend(tensors);
}
Err(e) => {
result.passed = false;
result.errors.push(format!(
"G0-INTEGRITY-CONFIG: Failed to read {}: {e}",
st_path.display()
));
return result;
}
}
}
let tensor_values = derive_values_from_tensors(&all_tensors);
result.tensor_values = Some(tensor_values.clone());
if let (Some(config_layers), Some(tensor_layers)) =
(config.num_hidden_layers, tensor_values.layer_count)
{
if config_layers != tensor_layers {
result.layer_count_match = false;
result.passed = false;
result.errors.push(format!(
"G0-INTEGRITY-LAYERS: config says {config_layers} layers but tensors have {tensor_layers}"
));
}
}
if let (Some(config_hidden), Some(tensor_hidden)) =
(config.hidden_size, tensor_values.hidden_size)
{
if config_hidden != tensor_hidden {
result.hidden_size_match = false;
result.passed = false;
result.errors.push(format!(
"G0-INTEGRITY-HIDDEN: config says hidden_size={config_hidden} but embedding has {tensor_hidden}"
));
}
}
if let (Some(config_vocab), Some(tensor_vocab)) = (config.vocab_size, tensor_values.vocab_size)
{
if config_vocab != tensor_vocab {
result.vocab_size_match = false;
result.passed = false;
result.errors.push(format!(
"G0-INTEGRITY-VOCAB: config says vocab_size={config_vocab} but embedding has {tensor_vocab}"
));
}
}
result
}
#[must_use]
pub fn check_safetensors_file_integrity(model_file: &Path) -> IntegrityResult {
let mut result = IntegrityResult {
passed: true,
config_found: false,
layer_count_match: true,
hidden_size_match: true,
vocab_size_match: true,
errors: Vec::new(),
config_values: None,
tensor_values: None,
};
let config_path = find_config_for_model_file(model_file);
let Some(config_path) = config_path else {
result.config_found = false;
result.passed = false;
result.errors.push(format!(
"G0-INTEGRITY-CONFIG: No config.json found for {}",
model_file.display()
));
return result;
};
let config = match read_config(&config_path) {
Ok(cfg) => {
result.config_found = true;
result.config_values = Some(ConfigValues {
num_hidden_layers: cfg.num_hidden_layers,
hidden_size: cfg.hidden_size,
vocab_size: cfg.vocab_size,
num_attention_heads: cfg.num_attention_heads,
});
cfg
}
Err(e) => {
result.config_found = false;
result.passed = false;
result.errors.push(format!("G0-INTEGRITY-CONFIG: {e}"));
return result;
}
};
let all_tensors = match read_safetensors_metadata(model_file) {
Ok(tensors) => tensors,
Err(e) => {
result.passed = false;
result.errors.push(format!(
"G0-INTEGRITY-CONFIG: Failed to read {}: {e}",
model_file.display()
));
return result;
}
};
let tensor_values = derive_values_from_tensors(&all_tensors);
result.tensor_values = Some(tensor_values.clone());
if let (Some(config_layers), Some(tensor_layers)) =
(config.num_hidden_layers, tensor_values.layer_count)
{
if config_layers != tensor_layers {
result.layer_count_match = false;
result.passed = false;
result.errors.push(format!(
"G0-INTEGRITY-LAYERS: config says {config_layers} layers but tensors have {tensor_layers}"
));
}
}
if let (Some(config_hidden), Some(tensor_hidden)) =
(config.hidden_size, tensor_values.hidden_size)
{
if config_hidden != tensor_hidden {
result.hidden_size_match = false;
result.passed = false;
result.errors.push(format!(
"G0-INTEGRITY-HIDDEN: config says hidden_size={config_hidden} but embedding has {tensor_hidden}"
));
}
}
if let (Some(config_vocab), Some(tensor_vocab)) = (config.vocab_size, tensor_values.vocab_size)
{
if config_vocab != tensor_vocab {
result.vocab_size_match = false;
result.passed = false;
result.errors.push(format!(
"G0-INTEGRITY-VOCAB: config says vocab_size={config_vocab} but embedding has {tensor_vocab}"
));
}
}
result
}
/// Finds the config file that belongs to a single safetensors `model_file`.
///
/// For a file named `<prefix>.safetensors`, this first looks for a sibling
/// `<prefix>.config.json`. It then falls back to `config.json` in the same
/// directory. Only the config candidates must exist; `model_file` itself is
/// not checked.
///
/// Returns `None` when `model_file` has no parent or file name, or when
/// neither candidate exists.
fn find_config_for_model_file(model_file: &Path) -> Option<std::path::PathBuf> {
    let parent = model_file.parent()?;
    let file_name = model_file.file_name()?;
    // A non-UTF-8 file name only rules out the hash-prefixed lookup; the plain
    // `config.json` fallback must still be tried.
    let prefixed = file_name
        .to_str()
        .and_then(|name| name.strip_suffix(".safetensors"))
        .map(|prefix| parent.join(format!("{prefix}.config.json")));
    if let Some(config_path) = prefixed {
        if config_path.exists() {
            return Some(config_path);
        }
    }
    let config_path = parent.join("config.json");
    if config_path.exists() {
        return Some(config_path);
    }
    None
}
/// Opens the file at `path` and deserializes it as an [`HfConfig`].
///
/// # Errors
/// Returns a message prefixed with `config.json not found or unreadable:` when
/// the file cannot be opened. Returns one prefixed with
/// `config.json parse error:` when its contents do not deserialize.
fn read_config(path: &Path) -> Result<HfConfig, String> {
    let file = match File::open(path) {
        Ok(handle) => handle,
        Err(err) => return Err(format!("config.json not found or unreadable: {err}")),
    };
    match serde_json::from_reader::<_, HfConfig>(BufReader::new(file)) {
        Ok(config) => Ok(config),
        Err(err) => Err(format!("config.json parse error: {err}")),
    }
}
/// Locates a hash-prefixed config (`<prefix>.config.json`) directly inside `dir`.
///
/// This is the fallback used when `dir/config.json` does not exist. A plain
/// `config.json` does not match, because the name must end in `.config.json`.
/// When several candidates exist, the lexicographically smallest path is
/// returned. That keeps the choice independent of `read_dir` iteration order,
/// which is platform- and filesystem-dependent.
///
/// Returns `None` if `dir` cannot be read or contains no candidate.
fn find_config_json(dir: &Path) -> Option<std::path::PathBuf> {
    std::fs::read_dir(dir)
        .ok()?
        .flatten()
        .map(|entry| entry.path())
        .filter(|path| {
            path.file_name()
                .and_then(|n| n.to_str())
                .is_some_and(|name| name.ends_with(".config.json"))
        })
        .min()
}
/// Lists the entries directly inside `dir` whose extension is `safetensors`,
/// sorted by path.
///
/// Subdirectories are not descended into. If `dir` cannot be read, the result
/// is empty. Entries that fail to read are skipped.
fn find_safetensors_files(dir: &Path) -> Vec<std::path::PathBuf> {
    let Ok(entries) = std::fs::read_dir(dir) else {
        return Vec::new();
    };
    let mut paths: Vec<std::path::PathBuf> = entries
        .filter_map(Result::ok)
        .map(|entry| entry.path())
        .filter(|path| path.extension().is_some_and(|ext| ext == "safetensors"))
        .collect();
    paths.sort();
    paths
}
/// Upper bound on the JSON header length accepted by [`read_safetensors_metadata`].
///
/// The length prefix comes straight from the file, so it is capped before a
/// buffer of that size is allocated.
const MAX_HEADER_SIZE: usize = 100 * 1024 * 1024;

/// Reads the header of the safetensors file at `path` and maps each tensor
/// name to its shape.
///
/// Only the header is read: an 8-byte little-endian length prefix followed by
/// that many bytes of UTF-8 JSON. Some entries are skipped:
///
/// - the `__metadata__` entry;
/// - any entry that is not an object with a `shape` array;
/// - any entry whose shape contains a value that is not a non-negative integer
///   fitting in `usize`. Such an entry is dropped entirely rather than returned
///   with fewer dimensions.
///
/// # Errors
/// Returns a message if:
///
/// - the file cannot be opened;
/// - the length prefix or header cannot be read;
/// - the declared header length exceeds [`MAX_HEADER_SIZE`];
/// - the header is not UTF-8, not valid JSON, or not a JSON object.
fn read_safetensors_metadata(path: &Path) -> Result<HashMap<String, Vec<usize>>, String> {
    let mut file = File::open(path).map_err(|e| format!("Failed to open safetensors file: {e}"))?;
    let mut header_len_bytes = [0u8; 8];
    file.read_exact(&mut header_len_bytes)
        .map_err(|e| format!("Failed to read header length: {e}"))?;
    let declared_len = u64::from_le_bytes(header_len_bytes);
    // Bound-check in u64 before narrowing: on 32-bit targets `as usize` would wrap a
    // huge prefix into a small value that slips past the limit.
    if declared_len > MAX_HEADER_SIZE as u64 {
        return Err(format!(
            "Header size {declared_len} exceeds maximum {MAX_HEADER_SIZE}"
        ));
    }
    // Lossless: the value is bounded by MAX_HEADER_SIZE above.
    let header_len = declared_len as usize;
    let mut header_bytes = vec![0u8; header_len];
    file.read_exact(&mut header_bytes)
        .map_err(|e| format!("Failed to read header: {e}"))?;
    let header_str = std::str::from_utf8(&header_bytes)
        .map_err(|e| format!("Header is not valid UTF-8: {e}"))?;
    let header: serde_json::Value =
        serde_json::from_str(header_str).map_err(|e| format!("Header JSON parse error: {e}"))?;
    let obj = header.as_object().ok_or("Header is not a JSON object")?;
    let tensors = obj
        .iter()
        .filter(|(name, _)| *name != "__metadata__")
        .filter_map(|(name, value)| {
            let shape = value.as_object()?.get("shape")?.as_array()?;
            // All-or-nothing: one malformed dimension invalidates the whole shape, so a
            // tensor's rank is never silently reduced.
            let dims = shape
                .iter()
                .map(|v| v.as_u64().and_then(|n| usize::try_from(n).ok()))
                .collect::<Option<Vec<usize>>>()?;
            Some((name.clone(), dims))
        })
        .collect();
    Ok(tensors)
}
/// Infers the layer count, hidden size and vocabulary size from tensor names
/// and shapes.
///
/// Any value that cannot be inferred is left as `None`.
fn derive_values_from_tensors(tensors: &HashMap<String, Vec<usize>>) -> TensorDerivedValues {
    let (vocab_size, hidden_size) = find_embedding_shape(tensors);
    TensorDerivedValues {
        layer_count: derive_layer_count(tensors),
        hidden_size,
        vocab_size,
    }
}
/// Estimates the number of layers as one more than the highest layer index
/// found in any tensor name.
///
/// Returns `None` when no tensor name carries a layer index. Gaps are not
/// detected: indices `{0, 5}` yield `Some(6)`.
fn derive_layer_count(tensors: &HashMap<String, Vec<usize>>) -> Option<usize> {
    let mut highest: Option<usize> = None;
    for name in tensors.keys() {
        if let Some(index) = extract_layer_number(name) {
            highest = Some(highest.map_or(index, |best| best.max(index)));
        }
    }
    highest.map(|best| best + 1)
}
/// Returns `(vocab_size, hidden_size)` taken from the first recognised
/// embedding or LM-head tensor in `tensors`.
///
/// Candidate names are tried in a fixed priority order. The first one present
/// with at least two dimensions supplies `shape[0]` as the vocabulary size and
/// `shape[1]` as the hidden size. Returns `(None, None)` when no candidate
/// qualifies.
fn find_embedding_shape(tensors: &HashMap<String, Vec<usize>>) -> (Option<usize>, Option<usize>) {
    const CANDIDATES: [&str; 6] = [
        "model.embed_tokens.weight",
        "embed_tokens.weight",
        "transformer.wte.weight",
        "wte.weight",
        "lm_head.weight",
        "model.lm_head.weight",
    ];
    CANDIDATES
        .iter()
        .filter_map(|name| tensors.get(*name))
        .find_map(|shape| match shape.as_slice() {
            [rows, cols, ..] => Some((Some(*rows), Some(*cols))),
            _ => None,
        })
        .unwrap_or((None, None))
}
/// Extracts the layer index from a tensor name.
///
/// Examples: `model.layers.23.self_attn.q_proj.weight` gives 23, and
/// `transformer.h.7.mlp.weight` gives 7.
///
/// The name is split on `.`. The index is the first all-digit segment that
/// directly follows a segment named exactly `layers`; failing that, one named
/// exactly `h`. Matching whole segments rather than substrings has two effects:
///
/// - it avoids false hits where `h.` ends a longer word (`graph.0`, `patch.3`);
/// - it keeps searching past a container segment that is not followed by an
///   index (`layers.norm.layers.4` gives 4).
///
/// Returns `None` when no such pair exists.
fn extract_layer_number(name: &str) -> Option<usize> {
    // Container segment names, in priority order.
    const CONTAINERS: [&str; 2] = ["layers", "h"];
    CONTAINERS.iter().find_map(|&container| {
        let segments = name.split('.');
        segments
            .clone()
            .zip(segments.skip(1))
            .filter(|&(segment, _)| segment == container)
            .find_map(|(_, index)| {
                let all_digits = !index.is_empty() && index.bytes().all(|b| b.is_ascii_digit());
                if all_digits {
                    index.parse().ok()
                } else {
                    None
                }
            })
    })
}
/// Identifiers of the integrity gates.
///
/// Every error message produced by the integrity checks in this module starts
/// with one of these ids followed by `": "`.
pub mod gate_ids {
    /// Config lookup/parse failures and safetensors discovery/read failures.
    pub const CONFIG: &str = "G0-INTEGRITY-CONFIG";
    /// Config layer count differs from the count derived from tensor names.
    pub const LAYERS: &str = "G0-INTEGRITY-LAYERS";
    /// Config `hidden_size` differs from the embedding tensor's second dimension.
    pub const HIDDEN: &str = "G0-INTEGRITY-HIDDEN";
    /// Config `vocab_size` differs from the embedding tensor's first dimension.
    pub const VOCAB: &str = "G0-INTEGRITY-VOCAB";
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::TempDir;

    /// Writes a config named `name` into `dir` with the given layer count,
    /// hidden size and vocab size (and 12 attention heads).
    fn create_named_config(dir: &Path, name: &str, layers: usize, hidden: usize, vocab: usize) {
        let config = serde_json::json!({
            "num_hidden_layers": layers,
            "hidden_size": hidden,
            "vocab_size": vocab,
            "num_attention_heads": 12
        });
        std::fs::write(dir.join(name), config.to_string()).expect("write config");
    }

    /// Writes `dir/config.json` via `create_named_config`.
    fn create_test_config(dir: &Path, layers: usize, hidden: usize, vocab: usize) {
        create_named_config(dir, "config.json", layers, hidden, vocab);
    }

    /// Builds one safetensors header entry with an `F32` dtype.
    fn tensor_entry(shape: serde_json::Value, data_offsets: serde_json::Value) -> serde_json::Value {
        serde_json::json!({ "shape": shape, "dtype": "F32", "data_offsets": data_offsets })
    }

    /// Writes a safetensors file at `path`.
    ///
    /// The file holds the 8-byte little-endian header length, then the
    /// serialized `header`, then `padding` zero bytes standing in for tensor data.
    fn write_safetensors_file(
        path: &Path,
        header: &serde_json::Map<String, serde_json::Value>,
        padding: usize,
    ) {
        let header_json = serde_json::to_string(header).expect("serialize header");
        let header_bytes = header_json.as_bytes();
        let header_len = header_bytes.len() as u64;
        let mut file = File::create(path).expect("create safetensors");
        file.write_all(&header_len.to_le_bytes()).expect("write len");
        file.write_all(header_bytes).expect("write header");
        file.write_all(&vec![0u8; padding]).expect("write data");
    }

    /// Writes a mock model named `name` into `dir`.
    ///
    /// The model has a `[vocab, hidden]` embedding plus one `[hidden, hidden]`
    /// q_proj tensor per layer.
    fn create_named_safetensors(
        dir: &Path,
        name: &str,
        layers: usize,
        hidden: usize,
        vocab: usize,
    ) {
        let mut header = serde_json::Map::new();
        header.insert(
            "model.embed_tokens.weight".to_string(),
            tensor_entry(
                serde_json::json!([vocab, hidden]),
                serde_json::json!([0, vocab * hidden * 4]),
            ),
        );
        for i in 0..layers {
            header.insert(
                format!("model.layers.{i}.self_attn.q_proj.weight"),
                tensor_entry(serde_json::json!([hidden, hidden]), serde_json::json!([0, 0])),
            );
        }
        write_safetensors_file(&dir.join(name), &header, 1024);
    }

    /// Writes `dir/model.safetensors` via `create_named_safetensors`.
    fn create_mock_safetensors(dir: &Path, layers: usize, hidden: usize, vocab: usize) {
        create_named_safetensors(dir, "model.safetensors", layers, hidden, vocab);
    }

    #[test]
    fn test_integrity_check_all_match() {
        let dir = TempDir::new().expect("create temp dir");
        create_test_config(dir.path(), 24, 896, 151_936);
        create_mock_safetensors(dir.path(), 24, 896, 151_936);
        let result = check_safetensors_integrity(dir.path());
        assert!(
            result.passed,
            "Should pass when all values match: {:?}",
            result.errors
        );
        assert!(result.config_found);
        assert!(result.layer_count_match);
        assert!(result.hidden_size_match);
        assert!(result.vocab_size_match);
        assert!(result.errors.is_empty());
    }

    #[test]
    fn test_integrity_check_layer_mismatch() {
        let dir = TempDir::new().expect("create temp dir");
        create_test_config(dir.path(), 14, 896, 151_936);
        create_mock_safetensors(dir.path(), 24, 896, 151_936);
        let result = check_safetensors_integrity(dir.path());
        assert!(!result.passed, "Should fail on layer mismatch");
        assert!(!result.layer_count_match);
        assert!(result
            .errors
            .iter()
            .any(|e| e.contains("G0-INTEGRITY-LAYERS")));
    }

    #[test]
    fn test_integrity_check_hidden_size_mismatch() {
        let dir = TempDir::new().expect("create temp dir");
        create_test_config(dir.path(), 24, 4096, 151_936);
        create_mock_safetensors(dir.path(), 24, 896, 151_936);
        let result = check_safetensors_integrity(dir.path());
        assert!(!result.passed, "Should fail on hidden_size mismatch");
        assert!(!result.hidden_size_match);
        assert!(result
            .errors
            .iter()
            .any(|e| e.contains("G0-INTEGRITY-HIDDEN")));
    }

    #[test]
    fn test_integrity_check_vocab_size_mismatch() {
        let dir = TempDir::new().expect("create temp dir");
        create_test_config(dir.path(), 24, 896, 896);
        create_mock_safetensors(dir.path(), 24, 896, 151_936);
        let result = check_safetensors_integrity(dir.path());
        assert!(!result.passed, "Should fail on vocab_size mismatch");
        assert!(!result.vocab_size_match);
        assert!(result
            .errors
            .iter()
            .any(|e| e.contains("G0-INTEGRITY-VOCAB")));
    }

    #[test]
    fn test_integrity_check_missing_config() {
        let dir = TempDir::new().expect("create temp dir");
        create_mock_safetensors(dir.path(), 24, 896, 151_936);
        let result = check_safetensors_integrity(dir.path());
        assert!(!result.passed, "Should fail when config.json missing");
        assert!(!result.config_found);
        assert!(result
            .errors
            .iter()
            .any(|e| e.contains("G0-INTEGRITY-CONFIG")));
    }

    #[test]
    fn test_integrity_check_no_safetensors() {
        let dir = TempDir::new().expect("create temp dir");
        create_test_config(dir.path(), 24, 896, 151_936);
        let result = check_safetensors_integrity(dir.path());
        assert!(!result.passed, "Should fail when no .safetensors files");
        assert!(result
            .errors
            .iter()
            .any(|e| e.contains("No .safetensors files")));
    }

    #[test]
    fn test_integrity_check_multiple_mismatches() {
        let dir = TempDir::new().expect("create temp dir");
        create_test_config(dir.path(), 14, 4096, 896);
        create_mock_safetensors(dir.path(), 24, 896, 151_936);
        let result = check_safetensors_integrity(dir.path());
        assert!(!result.passed, "Should fail on multiple mismatches");
        assert!(!result.layer_count_match);
        assert!(!result.hidden_size_match);
        assert!(!result.vocab_size_match);
        assert_eq!(result.errors.len(), 3, "Should have 3 error messages");
    }

    #[test]
    fn test_extract_layer_number() {
        assert_eq!(
            extract_layer_number("model.layers.23.self_attn.q_proj.weight"),
            Some(23)
        );
        assert_eq!(
            extract_layer_number("layers.0.mlp.gate_proj.weight"),
            Some(0)
        );
        assert_eq!(extract_layer_number("h.15.attn.c_attn.weight"), Some(15));
        assert_eq!(extract_layer_number("transformer.h.7.mlp.weight"), Some(7));
        assert_eq!(extract_layer_number("model.embed_tokens.weight"), None);
        assert_eq!(extract_layer_number("lm_head.weight"), None);
    }

    #[test]
    fn test_config_values_serialization() {
        let values = ConfigValues {
            num_hidden_layers: Some(24),
            hidden_size: Some(896),
            vocab_size: Some(151_936),
            num_attention_heads: Some(14),
        };
        let json = serde_json::to_string(&values).expect("serialize");
        assert!(json.contains("24"));
        assert!(json.contains("896"));
    }

    #[test]
    fn test_tensor_derived_values_serialization() {
        let values = TensorDerivedValues {
            layer_count: Some(24),
            hidden_size: Some(896),
            vocab_size: Some(151_936),
        };
        let json = serde_json::to_string(&values).expect("serialize");
        assert!(json.contains("24"));
        assert!(json.contains("151936"));
    }

    #[test]
    fn test_integrity_result_serialization() {
        let result = IntegrityResult {
            passed: false,
            config_found: true,
            layer_count_match: false,
            hidden_size_match: true,
            vocab_size_match: true,
            errors: vec!["G0-INTEGRITY-LAYERS: mismatch".to_string()],
            config_values: None,
            tensor_values: None,
        };
        let json = serde_json::to_string(&result).expect("serialize");
        assert!(json.contains("G0-INTEGRITY-LAYERS"));
    }

    #[test]
    fn test_gate_ids() {
        assert_eq!(gate_ids::CONFIG, "G0-INTEGRITY-CONFIG");
        assert_eq!(gate_ids::LAYERS, "G0-INTEGRITY-LAYERS");
        assert_eq!(gate_ids::HIDDEN, "G0-INTEGRITY-HIDDEN");
        assert_eq!(gate_ids::VOCAB, "G0-INTEGRITY-VOCAB");
    }

    #[test]
    fn test_integrity_result_debug() {
        let result = IntegrityResult {
            passed: true,
            config_found: true,
            layer_count_match: true,
            hidden_size_match: true,
            vocab_size_match: true,
            errors: vec![],
            config_values: None,
            tensor_values: None,
        };
        let debug_str = format!("{result:?}");
        assert!(debug_str.contains("IntegrityResult"));
    }

    #[test]
    fn test_config_values_debug() {
        let values = ConfigValues {
            num_hidden_layers: Some(24),
            hidden_size: Some(896),
            vocab_size: Some(151_936),
            num_attention_heads: Some(14),
        };
        let debug_str = format!("{values:?}");
        assert!(debug_str.contains("ConfigValues"));
    }

    #[test]
    fn test_tensor_derived_values_debug() {
        let values = TensorDerivedValues {
            layer_count: Some(24),
            hidden_size: Some(896),
            vocab_size: Some(151_936),
        };
        let debug_str = format!("{values:?}");
        assert!(debug_str.contains("TensorDerivedValues"));
    }

    #[test]
    fn test_integrity_result_clone() {
        let result = IntegrityResult {
            passed: true,
            config_found: true,
            layer_count_match: true,
            hidden_size_match: true,
            vocab_size_match: true,
            errors: vec!["test".to_string()],
            config_values: Some(ConfigValues {
                num_hidden_layers: Some(24),
                hidden_size: Some(896),
                vocab_size: Some(151_936),
                num_attention_heads: Some(14),
            }),
            tensor_values: Some(TensorDerivedValues {
                layer_count: Some(24),
                hidden_size: Some(896),
                vocab_size: Some(151_936),
            }),
        };
        let cloned = result.clone();
        assert_eq!(cloned.passed, result.passed);
        assert_eq!(cloned.errors.len(), result.errors.len());
    }

    #[test]
    fn test_read_safetensors_corrupted_file() {
        let dir = TempDir::new().expect("create temp dir");
        let path = dir.path().join("corrupt.safetensors");
        std::fs::write(&path, b"short").expect("write corrupt");
        let result = read_safetensors_metadata(&path);
        assert!(result.is_err());
    }

    #[test]
    fn test_read_safetensors_oversized_header() {
        let dir = TempDir::new().expect("create temp dir");
        let path = dir.path().join("oversize.safetensors");
        let mut file = std::fs::File::create(&path).expect("create file");
        let huge: u64 = 200_000_000;
        file.write_all(&huge.to_le_bytes()).expect("write len");
        drop(file);
        let result = read_safetensors_metadata(&path);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("exceeds maximum"));
    }

    #[test]
    fn test_read_safetensors_with_metadata_key() {
        let dir = TempDir::new().expect("create temp dir");
        let mut header_obj = serde_json::Map::new();
        header_obj.insert(
            "__metadata__".to_string(),
            serde_json::json!({"format": "pt"}),
        );
        header_obj.insert(
            "model.weight".to_string(),
            tensor_entry(serde_json::json!([100, 50]), serde_json::json!([0, 20000])),
        );
        let path = dir.path().join("model.safetensors");
        write_safetensors_file(&path, &header_obj, 128);
        let tensors = read_safetensors_metadata(&path).expect("should parse");
        assert!(!tensors.contains_key("__metadata__"));
        assert!(tensors.contains_key("model.weight"));
        assert_eq!(tensors["model.weight"], vec![100, 50]);
    }

    #[test]
    fn test_derive_values_from_lm_head_fallback() {
        let mut tensors = HashMap::new();
        tensors.insert("lm_head.weight".to_string(), vec![32000, 4096]);
        tensors.insert(
            "model.layers.0.self_attn.q_proj.weight".to_string(),
            vec![4096, 4096],
        );
        tensors.insert(
            "model.layers.1.self_attn.q_proj.weight".to_string(),
            vec![4096, 4096],
        );
        let values = derive_values_from_tensors(&tensors);
        assert_eq!(values.vocab_size, Some(32000));
        assert_eq!(values.hidden_size, Some(4096));
        assert_eq!(values.layer_count, Some(2));
    }

    #[test]
    fn test_derive_values_model_lm_head_fallback() {
        let mut tensors = HashMap::new();
        tensors.insert("model.lm_head.weight".to_string(), vec![50_000, 768]);
        let values = derive_values_from_tensors(&tensors);
        assert_eq!(values.vocab_size, Some(50_000));
        assert_eq!(values.hidden_size, Some(768));
    }

    #[test]
    fn test_check_safetensors_integrity_read_error() {
        let dir = TempDir::new().expect("create temp dir");
        create_test_config(dir.path(), 12, 768, 30_000);
        let path = dir.path().join("model.safetensors");
        std::fs::write(&path, b"bad").expect("write corrupt");
        let result = check_safetensors_integrity(dir.path());
        assert!(!result.passed);
        assert!(result
            .errors
            .iter()
            .any(|e| e.contains("G0-INTEGRITY-CONFIG")));
    }

    #[test]
    fn test_check_safetensors_integrity_hidden_size_mismatch() {
        let dir = TempDir::new().expect("create temp dir");
        create_test_config(dir.path(), 2, 1024, 30_000);
        create_mock_safetensors(dir.path(), 2, 768, 30_000);
        let result = check_safetensors_integrity(dir.path());
        assert!(!result.passed);
        assert!(result.errors.iter().any(|e| e.contains("HIDDEN")));
    }

    #[test]
    fn test_check_safetensors_integrity_vocab_size_mismatch() {
        let dir = TempDir::new().expect("create temp dir");
        create_test_config(dir.path(), 2, 768, 50_000);
        create_mock_safetensors(dir.path(), 2, 768, 30_000);
        let result = check_safetensors_integrity(dir.path());
        assert!(!result.passed);
        assert!(result.errors.iter().any(|e| e.contains("VOCAB")));
    }

    #[test]
    fn test_file_integrity_with_hash_prefix_config() {
        let dir = TempDir::new().expect("create temp dir");
        create_named_config(dir.path(), "abc123.config.json", 24, 896, 151_936);
        create_named_safetensors(dir.path(), "abc123.safetensors", 24, 896, 151_936);
        let model_file = dir.path().join("abc123.safetensors");
        let result = check_safetensors_file_integrity(&model_file);
        assert!(
            result.passed,
            "Should pass with hash-prefixed config: {:?}",
            result.errors
        );
        assert!(result.config_found);
        assert!(result.layer_count_match);
    }

    #[test]
    fn test_file_integrity_ignores_other_models_in_shared_dir() {
        let dir = TempDir::new().expect("create temp dir");
        create_named_config(dir.path(), "aaa111.config.json", 24, 896, 151_936);
        create_named_safetensors(dir.path(), "aaa111.safetensors", 24, 896, 151_936);
        create_named_config(dir.path(), "bbb222.config.json", 28, 3584, 151_936);
        create_named_safetensors(dir.path(), "bbb222.safetensors", 28, 3584, 151_936);
        let model_file = dir.path().join("aaa111.safetensors");
        let result = check_safetensors_file_integrity(&model_file);
        assert!(
            result.passed,
            "Must use only aaa111's config and tensors, not bbb222's: {:?}",
            result.errors
        );
        assert_eq!(
            result.tensor_values.as_ref().unwrap().layer_count,
            Some(24),
            "Should see 24 layers from aaa111, not 28 from bbb222"
        );
    }

    #[test]
    fn test_file_integrity_no_config_found() {
        let dir = TempDir::new().expect("create temp dir");
        create_named_safetensors(dir.path(), "orphan.safetensors", 12, 768, 30_000);
        let model_file = dir.path().join("orphan.safetensors");
        let result = check_safetensors_file_integrity(&model_file);
        assert!(!result.passed);
        assert!(!result.config_found);
        assert!(result
            .errors
            .iter()
            .any(|e| e.contains("G0-INTEGRITY-CONFIG")));
    }

    #[test]
    fn test_file_integrity_falls_back_to_plain_config() {
        let dir = TempDir::new().expect("create temp dir");
        create_test_config(dir.path(), 24, 896, 151_936);
        create_named_safetensors(dir.path(), "model.safetensors", 24, 896, 151_936);
        let model_file = dir.path().join("model.safetensors");
        let result = check_safetensors_file_integrity(&model_file);
        assert!(
            result.passed,
            "Should fall back to config.json: {:?}",
            result.errors
        );
    }

    #[test]
    fn test_file_integrity_layer_mismatch() {
        let dir = TempDir::new().expect("create temp dir");
        create_named_config(dir.path(), "bad.config.json", 14, 896, 151_936);
        create_named_safetensors(dir.path(), "bad.safetensors", 24, 896, 151_936);
        let model_file = dir.path().join("bad.safetensors");
        let result = check_safetensors_file_integrity(&model_file);
        assert!(!result.passed);
        assert!(!result.layer_count_match);
        assert!(result.errors.iter().any(|e| e.contains("LAYERS")));
    }

    #[test]
    fn test_find_config_for_model_file_hash_prefix() {
        let dir = TempDir::new().expect("create temp dir");
        create_named_config(dir.path(), "d71534cb.config.json", 24, 896, 151_936);
        create_named_safetensors(dir.path(), "d71534cb.safetensors", 24, 896, 151_936);
        let result = find_config_for_model_file(&dir.path().join("d71534cb.safetensors"));
        assert!(result.is_some());
        assert!(result
            .unwrap()
            .file_name()
            .unwrap()
            .to_str()
            .unwrap()
            .contains("d71534cb.config.json"));
    }

    #[test]
    fn test_find_config_for_model_file_no_match() {
        let dir = TempDir::new().expect("create temp dir");
        create_named_safetensors(dir.path(), "noconf.safetensors", 2, 768, 30_000);
        let result = find_config_for_model_file(&dir.path().join("noconf.safetensors"));
        assert!(result.is_none());
    }
}