use crate::command::CommandRunner;
use crate::evidence::Evidence;
use apr_qa_gen::{Backend, Format, Modality, ModelId, QaScenario};
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
use std::path::{Path, PathBuf};
use std::sync::Arc;
/// Raw YAML text of the format contract, embedded into the binary at compile
/// time from `apr_format_contract.yaml` (path resolved relative to this file).
/// Parsed by `load_format_contract`.
const CONTRACT_YAML: &str = include_str!("apr_format_contract.yaml");
/// Top-level shape of the format contract, as deserialized from the embedded
/// YAML by `load_format_contract`.
#[derive(Debug, Clone, Deserialize)]
pub struct FormatContract {
/// Version string of the contract document.
pub version: String,
/// Tensor naming section (convention name, examples, pattern).
pub tensor_naming: TensorNamingContract,
/// Dtype-to-byte mapping section; checked for duplicates by
/// `validate_dtype_bytes`.
pub dtype_bytes: DtypeByteSection,
/// Per-dtype tolerance entries; queried through `lookup_tolerance`.
pub tolerances: Vec<ToleranceEntry>,
/// Invariant definitions; `run_contract_tests` looks entries up by `id` to
/// obtain the gate identifier.
pub invariants: Vec<InvariantDef>,
}
/// Tensor naming section of the format contract.
#[derive(Debug, Clone, Deserialize)]
pub struct TensorNamingContract {
/// Name of the naming convention, as written in the YAML.
pub convention: String,
/// Free-form description taken from the YAML.
pub description: String,
/// Pairs of canonical / forbidden example names.
pub examples: Vec<NamingExample>,
/// Pattern string from the YAML. It is forwarded by `validate_tensor_name`,
/// but NOTE(review): `is_valid_tensor_name` currently ignores it and applies
/// hard-coded structural rules instead.
pub pattern: String,
}
/// One naming example from the contract: a name that should be accepted and a
/// counterpart that should be rejected.
#[derive(Debug, Clone, Deserialize)]
pub struct NamingExample {
/// Example name the unit tests expect `validate_tensor_name` to accept.
pub canonical: String,
/// Example name the unit tests expect `validate_tensor_name` to reject.
pub forbidden: String,
}
/// Dtype byte-mapping section of the format contract.
#[derive(Debug, Clone, Deserialize)]
pub struct DtypeByteSection {
/// Free-form description taken from the YAML.
pub description: String,
/// The individual dtype-to-byte entries.
pub mappings: Vec<DtypeByteEntry>,
}
/// A single dtype-to-byte mapping.
#[derive(Debug, Clone, Deserialize)]
pub struct DtypeByteEntry {
/// Dtype name (the unit tests look up names such as `F32` and `Q4_K`).
pub dtype: String,
/// Byte value assigned to the dtype; `validate_dtype_bytes` rejects a
/// contract in which two entries share the same value.
pub byte: u8,
}
/// Tolerance pair associated with one dtype; returned by `lookup_tolerance`
/// as `(atol, rtol)`.
#[derive(Debug, Clone, Deserialize)]
pub struct ToleranceEntry {
/// Dtype name this entry applies to (matched by exact string equality).
pub dtype: String,
/// First element of the tolerance pair — presumably the absolute tolerance;
/// confirm against the contract YAML.
pub atol: f64,
/// Second element of the tolerance pair — presumably the relative tolerance;
/// confirm against the contract YAML.
pub rtol: f64,
}
/// Definition of one invariant as listed in the contract YAML.
#[derive(Debug, Clone, Deserialize)]
pub struct InvariantDef {
/// Invariant label (e.g. `I-2`); compared against the configured labels in
/// `run_contract_tests`.
pub id: String,
/// Name of the invariant as written in the YAML.
pub name: String,
/// Free-form description taken from the YAML.
pub description: String,
/// List of strings from the YAML; the unit tests check for entries such as
/// `GH-190`, so these look like issue references — confirm with the YAML.
pub catches: Vec<String>,
/// Gate identifier reported on the evidence for this invariant; takes
/// precedence over `InvariantId::gate_id` in `run_contract_tests`.
pub gate_id: String,
/// Optional string from the YAML; `None` when the key is absent.
#[serde(default)]
pub test: Option<String>,
/// Flag from the YAML; `false` when the key is absent.
#[serde(default)]
pub implemented: bool,
}
/// Selects which invariants `run_contract_tests` executes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContractTestConfig {
/// Invariant labels such as `"I-2"`. When the field is missing during
/// deserialization it is filled in by `default_invariants`.
#[serde(default = "default_invariants")]
pub invariants: Vec<String>,
}
/// Builds the default invariant label list: `I-2`, `I-3`, `I-4`, `I-5`, in
/// that order.
fn default_invariants() -> Vec<String> {
    (2..=5).map(|n| format!("I-{n}")).collect()
}
/// The default configuration enables exactly the labels returned by
/// `default_invariants`.
impl Default for ContractTestConfig {
    fn default() -> Self {
        let invariants = default_invariants();
        Self { invariants }
    }
}
/// Identifier for one of the five format-contract invariants, `I-1` to `I-5`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InvariantId {
    /// Invariant `I-1`.
    I1,
    /// Invariant `I-2`.
    I2,
    /// Invariant `I-3`.
    I3,
    /// Invariant `I-4`.
    I4,
    /// Invariant `I-5`.
    I5,
}

impl InvariantId {
    /// Parses a textual label such as `"I-3"`.
    ///
    /// Only the exact, case-sensitive labels `I-1` through `I-5` are accepted;
    /// everything else yields `None`.
    #[must_use]
    pub fn from_label(label: &str) -> Option<Self> {
        let id = match label.strip_prefix("I-")? {
            "1" => Self::I1,
            "2" => Self::I2,
            "3" => Self::I3,
            "4" => Self::I4,
            "5" => Self::I5,
            _ => return None,
        };
        Some(id)
    }

    /// Returns the built-in gate identifier for this invariant.
    #[must_use]
    pub fn gate_id(self) -> &'static str {
        match self {
            Self::I5 => "F-CONTRACT-I5-001",
            Self::I4 => "F-CONTRACT-I4-001",
            Self::I3 => "F-CONTRACT-I3-001",
            Self::I2 => "F-CONTRACT-I2-001",
            Self::I1 => "F-CONTRACT-I1-001",
        }
    }
}
pub fn load_format_contract() -> crate::error::Result<FormatContract> {
serde_yaml::from_str(CONTRACT_YAML).map_err(crate::error::Error::from)
}
pub fn validate_dtype_bytes(contract: &FormatContract) -> crate::error::Result<()> {
let mut seen = HashSet::new();
for entry in &contract.dtype_bytes.mappings {
if !seen.insert(entry.byte) {
return Err(crate::error::Error::Execution(format!(
"Duplicate GGML byte value {} for dtype {}",
entry.byte, entry.dtype
)));
}
}
Ok(())
}
/// Reports whether `name` is an acceptable tensor name under `contract`.
///
/// The contract's naming pattern is handed to `is_valid_tensor_name`, which
/// performs the actual check.
#[must_use]
pub fn validate_tensor_name(name: &str, contract: &FormatContract) -> bool {
    let pattern = contract.tensor_naming.pattern.as_str();
    is_valid_tensor_name(name, pattern)
}
/// Structural check for tensor names; `_pattern` is accepted but not used.
///
/// A name is valid when, split on `.`, it has either
/// * two segments: one of `token_embd`, `output_norm`, `output`, followed by a
///   word; or
/// * three segments: a non-empty run of ASCII digits followed by two words.
///
/// "Word" is whatever `is_word` accepts.
fn is_valid_tensor_name(name: &str, _pattern: &str) -> bool {
    let segments: Vec<&str> = name.split('.').collect();
    match segments.as_slice() {
        [global, param] => {
            matches!(*global, "token_embd" | "output_norm" | "output") && is_word(param)
        }
        [layer, module, param] => {
            let numeric_layer = !layer.is_empty() && layer.bytes().all(|b| b.is_ascii_digit());
            numeric_layer && is_word(module) && is_word(param)
        }
        _ => false,
    }
}
/// Returns `true` when `s` is non-empty and every character is either `_` or
/// alphanumeric according to `char::is_alphanumeric` (which is Unicode-aware).
fn is_word(s: &str) -> bool {
    if s.is_empty() {
        return false;
    }
    !s.chars().any(|ch| ch != '_' && !ch.is_alphanumeric())
}
/// Finds the tolerance pair for `dtype` in the contract.
///
/// Returns `(atol, rtol)` from the first entry whose `dtype` equals the
/// argument exactly, or `None` when no entry matches.
#[must_use]
pub fn lookup_tolerance(dtype: &str, contract: &FormatContract) -> Option<(f64, f64)> {
    for entry in &contract.tolerances {
        if entry.dtype == dtype {
            return Some((entry.atol, entry.rtol));
        }
    }
    None
}
/// Runs the configured format-contract invariants against a model workspace
/// and returns one [`Evidence`] per executed invariant, in configuration order.
///
/// * If the embedded contract cannot be loaded, a single falsified evidence
///   with gate `F-CONTRACT-LOAD-001` is returned and nothing else runs.
/// * Labels that `InvariantId::from_label` does not recognise are skipped
///   silently, and so is `I-1`, which has no runner here.
/// * The gate identifier comes from the contract's invariant definition when
///   one with a matching `id` exists, otherwise from `InvariantId::gate_id`.
pub fn run_contract_tests(
    runner: &Arc<dyn CommandRunner>,
    model_path: &Path,
    model_id: &ModelId,
    config: &ContractTestConfig,
) -> Vec<Evidence> {
    let contract = match load_format_contract() {
        Ok(loaded) => loaded,
        Err(e) => {
            return vec![Evidence::falsified(
                "F-CONTRACT-LOAD-001",
                contract_scenario(model_id),
                format!("Failed to load format contract: {e}"),
                "N/A",
                0,
            )];
        }
    };

    let mut evidence = Vec::new();
    for label in &config.invariants {
        let Some(inv_id) = InvariantId::from_label(label) else {
            continue;
        };
        let gate_id = contract
            .invariants
            .iter()
            .find(|def| def.id == *label)
            .map_or_else(|| inv_id.gate_id(), |def| def.gate_id.as_str());
        let outcome = match inv_id {
            // I-1 is not executed by this runner.
            InvariantId::I1 => continue,
            InvariantId::I2 => run_i2_tensor_bijection(runner, model_path, model_id, gate_id),
            InvariantId::I3 => run_i3_no_silent_fallbacks(runner, model_path, model_id, gate_id),
            InvariantId::I4 => {
                run_i4_statistical_preservation(runner, model_path, model_id, gate_id)
            }
            InvariantId::I5 => run_i5_tokenizer_roundtrip(runner, model_path, model_id, gate_id),
        };
        evidence.push(outcome);
    }
    evidence
}
/// Returns `<model_path>/apr/model.apr`.
///
/// The components are appended rather than swapped in as an extension, so
/// dots in the workspace directory name are left untouched.
fn resolve_apr_path(model_path: &Path) -> PathBuf {
    let mut apr_file = model_path.to_path_buf();
    apr_file.push("apr");
    apr_file.push("model.apr");
    apr_file
}
/// Returns `<model_path>/safetensors/model.safetensors`.
///
/// The components are appended rather than swapped in as an extension, so
/// dots in the workspace directory name are left untouched.
fn resolve_safetensors_path(model_path: &Path) -> PathBuf {
    let mut st_file = model_path.to_path_buf();
    st_file.push("safetensors");
    st_file.push("model.safetensors");
    st_file
}
fn run_i2_tensor_bijection(
runner: &Arc<dyn CommandRunner>,
model_path: &Path,
model_id: &ModelId,
gate_id: &str,
) -> Evidence {
let st_path = resolve_safetensors_path(model_path);
let apr_path = resolve_apr_path(model_path);
let st_inspect = runner.inspect_model_json(&st_path);
let apr_inspect = runner.inspect_model_json(&apr_path);
if !st_inspect.success || !apr_inspect.success {
let err = if st_inspect.success {
&apr_inspect.stderr
} else {
&st_inspect.stderr
};
return Evidence::falsified(
gate_id,
contract_scenario(model_id),
format!("I-2 Tensor Name Bijection: inspect failed: {err}"),
&format!("st: {}, apr: {}", st_inspect.stdout, apr_inspect.stdout),
0,
);
}
let st_names = parse_tensor_names(&st_inspect.stdout);
let apr_names = parse_tensor_names(&apr_inspect.stdout);
let missing: Vec<&str> = st_names
.iter()
.filter(|n| !apr_names.contains(n.as_str()))
.map(String::as_str)
.collect();
if !missing.is_empty() {
return Evidence::falsified(
gate_id,
contract_scenario(model_id),
format!(
"I-2 Tensor Name Bijection: {} source tensors missing in APR: {}",
missing.len(),
missing.join(", ")
),
&format!("source={}, apr={}", st_names.len(), apr_names.len()),
0,
);
}
let extra: Vec<&str> = apr_names
.iter()
.filter(|n| !st_names.contains(n.as_str()))
.map(String::as_str)
.collect();
let allowed_extras: HashSet<&str> = HashSet::from(["lm_head.weight", "lm_head.bias"]);
let unexpected: Vec<&str> = extra
.iter()
.filter(|n| !allowed_extras.contains(*n))
.copied()
.collect();
if !unexpected.is_empty() {
return Evidence::falsified(
gate_id,
contract_scenario(model_id),
format!(
"I-2 Tensor Name Bijection: {} unexpected extra tensors in APR: {}",
unexpected.len(),
unexpected.join(", ")
),
&format!(
"source={}, apr={}, extra={:?}",
st_names.len(),
apr_names.len(),
extra
),
0,
);
}
let tied = if extra.is_empty() {
""
} else {
" (tied embedding materialized)"
};
let mut ev = Evidence::corroborated(
gate_id,
contract_scenario(model_id),
&format!("source={}, apr={}", st_names.len(), apr_names.len()),
0,
);
ev.reason = format!(
"I-2 Tensor Name Bijection: all {} source tensors present in APR ({} total){}",
st_names.len(),
apr_names.len(),
tied,
);
ev
}
/// Extracts the entries of the `"tensor_names"` array from inspect output.
///
/// This is a lightweight textual scan, not a JSON parser: it locates the
/// `"tensor_names"` key, expects `:` and `[` after it (with optional
/// whitespace around either, so both compact and pretty-printed JSON work),
/// and splits everything up to the next `]` on commas. Each item is trimmed of
/// surrounding whitespace and double quotes; empty items are dropped.
///
/// Returns an empty set when the key is absent, is not followed by an array,
/// or the array is not terminated. Names containing `,` or `]` are not
/// supported by this scan.
fn parse_tensor_names(json_output: &str) -> HashSet<String> {
    const KEY: &str = "\"tensor_names\"";

    // First occurrence of the key that is actually followed by `: [`.
    let array_body = json_output.match_indices(KEY).find_map(|(at, _)| {
        let after_key = json_output[at + KEY.len()..].trim_start();
        let after_colon = after_key.strip_prefix(':')?.trim_start();
        after_colon.strip_prefix('[')
    });
    let Some(array_body) = array_body else {
        return HashSet::new();
    };
    let Some(end) = array_body.find(']') else {
        return HashSet::new();
    };

    array_body[..end]
        .split(',')
        .map(|item| item.trim().trim_matches('"'))
        .filter(|name| !name.is_empty())
        .map(str::to_owned)
        .collect()
}
/// I-3: runs the model check on the APR file and looks for signs of a silent
/// F32 fallback in its output.
///
/// Falsified when the check itself reports failure, or when
/// `contains_f32_fallback` matches either stdout or stderr; corroborated
/// otherwise.
fn run_i3_no_silent_fallbacks(
    runner: &Arc<dyn CommandRunner>,
    model_path: &Path,
    model_id: &ModelId,
    gate_id: &str,
) -> Evidence {
    let result = runner.check_model(&resolve_apr_path(model_path));

    if !result.success {
        return Evidence::falsified(
            gate_id,
            contract_scenario(model_id),
            format!("I-3 No Silent Fallbacks: check failed: {}", result.stderr),
            &result.stdout,
            0,
        );
    }

    let fallback_seen =
        contains_f32_fallback(&result.stdout) || contains_f32_fallback(&result.stderr);
    if fallback_seen {
        return Evidence::falsified(
            gate_id,
            contract_scenario(model_id),
            "I-3 No Silent Fallbacks: detected F32 fallback in check output",
            &result.stdout,
            0,
        );
    }

    let mut ev = Evidence::corroborated(gate_id, contract_scenario(model_id), &result.stdout, 0);
    ev.reason = String::from("I-3 No Silent Fallbacks: no F32 fallbacks detected");
    ev
}
/// Case-insensitive heuristic for fallback messages in tool output.
///
/// Matches when the text contains both `fallback` and `f32` (anywhere, not
/// necessarily adjacent), or the phrase `defaulting to f32`, or the phrase
/// `unknown dtype`.
fn contains_f32_fallback(output: &str) -> bool {
    let lowered = output.to_lowercase();
    let mentions = |needle: &str| lowered.contains(needle);

    if mentions("fallback") && mentions("f32") {
        return true;
    }
    mentions("defaulting to f32") || mentions("unknown dtype")
}
/// I-4: compares tensor statistics between the SafeTensors source and the APR
/// file through the runner's `validate_stats` command.
///
/// Falsified when the command reports failure, when its stdout contains an
/// explicit `"passed":false` (whitespace-insensitive), or when stdout never
/// mentions `passed` at all; corroborated otherwise.
///
/// The explicit-failure check matters because the bare word `passed` also
/// occurs inside `"passed":false`, which would otherwise be mistaken for
/// success.
fn run_i4_statistical_preservation(
    runner: &Arc<dyn CommandRunner>,
    model_path: &Path,
    model_id: &ModelId,
    gate_id: &str,
) -> Evidence {
    let st_path = resolve_safetensors_path(model_path);
    let apr_path = resolve_apr_path(model_path);
    let result = runner.validate_stats(&st_path, &apr_path);
    if !result.success {
        return Evidence::falsified(
            gate_id,
            contract_scenario(model_id),
            format!(
                "I-4 Statistical Preservation: validate-stats failed: {}",
                result.stderr
            ),
            &result.stdout,
            0,
        );
    }

    // Drop whitespace so `"passed": false` and `"passed":false` look the same.
    let compact: String = result
        .stdout
        .chars()
        .filter(|c| !c.is_whitespace())
        .collect();
    let explicit_failure = compact.contains("\"passed\":false");
    // Any mention of `passed` (JSON `"passed":true` or plain text) still counts
    // as success, as before, unless an explicit failure flag is present.
    let passed = !explicit_failure && result.stdout.contains("passed");

    if passed {
        let mut ev =
            Evidence::corroborated(gate_id, contract_scenario(model_id), &result.stdout, 0);
        ev.reason = "I-4 Statistical Preservation: tensor statistics preserved within tolerance"
            .to_string();
        ev
    } else {
        Evidence::falsified(
            gate_id,
            contract_scenario(model_id),
            format!(
                "I-4 Statistical Preservation: statistics diverged: {}",
                result.stdout
            ),
            &result.stdout,
            0,
        )
    }
}
/// I-5: asks the runner to compare inference between the SafeTensors source
/// and the APR file (called with `"Hello"`, `1` and `0.0`).
///
/// Falsified when the command reports failure or when stdout lacks the exact
/// text `"passed":true`; corroborated otherwise.
fn run_i5_tokenizer_roundtrip(
    runner: &Arc<dyn CommandRunner>,
    model_path: &Path,
    model_id: &ModelId,
    gate_id: &str,
) -> Evidence {
    let result = runner.compare_inference(
        &resolve_safetensors_path(model_path),
        &resolve_apr_path(model_path),
        "Hello",
        1,
        0.0,
    );

    if !result.success {
        return Evidence::falsified(
            gate_id,
            contract_scenario(model_id),
            format!(
                "I-5 Tokenizer Roundtrip: compare-inference failed: {}",
                result.stderr
            ),
            &result.stdout,
            0,
        );
    }

    if !result.stdout.contains("\"passed\":true") {
        return Evidence::falsified(
            gate_id,
            contract_scenario(model_id),
            format!(
                "I-5 Tokenizer Roundtrip: inference output mismatch: {}",
                result.stdout
            ),
            &result.stdout,
            0,
        );
    }

    let mut ev = Evidence::corroborated(gate_id, contract_scenario(model_id), &result.stdout, 0);
    ev.reason = String::from("I-5 Tokenizer Roundtrip: tokenizer roundtrip verified");
    ev
}
/// Builds the [`QaScenario`] attached to every piece of contract evidence:
/// a clone of `model_id` with `Modality::Run`, `Backend::Cpu`, `Format::Apr`,
/// the fixed text `Format contract invariant`, and a trailing `0`.
fn contract_scenario(model_id: &ModelId) -> QaScenario {
    let model = model_id.clone();
    let text = String::from("Format contract invariant");
    QaScenario::new(model, Modality::Run, Backend::Cpu, Format::Apr, text, 0)
}
// Unit tests for contract loading, name validation, tolerance lookup, and the
// I-2..I-5 invariant runners (the latter driven through `MockCommandRunner`).
#[cfg(test)]
mod tests {
use super::*;
use crate::evidence::Outcome;
// The embedded contract parses and has non-empty invariants, dtype mappings
// and tolerances.
#[test]
fn test_load_format_contract() {
let contract = load_format_contract().expect("Failed to load contract");
assert!(!contract.invariants.is_empty());
assert!(!contract.dtype_bytes.mappings.is_empty());
assert!(!contract.tolerances.is_empty());
}
// The embedded contract declares version "1.0".
#[test]
fn test_contract_version() {
let contract = load_format_contract().expect("Failed to load contract");
assert_eq!(contract.version, "1.0");
}
// Every dtype the suite relies on has a byte mapping in the contract.
#[test]
fn test_dtype_byte_mappings_complete() {
let contract = load_format_contract().expect("Failed to load contract");
let dtypes: Vec<&str> = contract
.dtype_bytes
.mappings
.iter()
.map(|m| m.dtype.as_str())
.collect();
assert!(dtypes.contains(&"F32"));
assert!(dtypes.contains(&"F16"));
assert!(dtypes.contains(&"Q4_K"));
assert!(dtypes.contains(&"Q6_K"));
assert!(dtypes.contains(&"BF16"));
assert!(dtypes.contains(&"Q8_0"));
assert!(dtypes.contains(&"Q2_K"));
assert!(dtypes.contains(&"Q3_K"));
assert!(dtypes.contains(&"Q5_K"));
assert!(dtypes.contains(&"Q4_0"));
assert!(dtypes.contains(&"Q5_0"));
}
// The embedded contract passes the duplicate-byte validation.
#[test]
fn test_dtype_byte_no_duplicates() {
let contract = load_format_contract().expect("Failed to load contract");
validate_dtype_bytes(&contract).expect("No duplicates expected");
}
// Pins the byte values of a few dtypes in the embedded contract.
#[test]
fn test_dtype_byte_ggml_values() {
let contract = load_format_contract().expect("Failed to load contract");
let find_byte = |dtype: &str| -> u8 {
contract
.dtype_bytes
.mappings
.iter()
.find(|m| m.dtype == dtype)
.expect("dtype not found")
.byte
};
assert_eq!(find_byte("F32"), 0);
assert_eq!(find_byte("F16"), 1);
assert_eq!(find_byte("Q4_K"), 12);
assert_eq!(find_byte("Q6_K"), 14);
assert_eq!(find_byte("BF16"), 30);
}
// Short layer-indexed and global names are accepted; long dotted names and
// the empty string are rejected.
#[test]
fn test_tensor_naming_pattern() {
let contract = load_format_contract().expect("Failed to load contract");
assert!(validate_tensor_name("0.q_proj.weight", &contract));
assert!(validate_tensor_name("31.down_proj.weight", &contract));
assert!(validate_tensor_name("token_embd.weight", &contract));
assert!(validate_tensor_name("output_norm.weight", &contract));
assert!(validate_tensor_name("output.weight", &contract));
assert!(!validate_tensor_name(
"model.layers.0.self_attn.q_proj.weight",
&contract
));
assert!(!validate_tensor_name(
"model.embed_tokens.weight",
&contract
));
assert!(!validate_tensor_name("", &contract));
}
// The contract defines exactly the five invariants I-1 through I-5.
#[test]
fn test_invariant_definitions_complete() {
let contract = load_format_contract().expect("Failed to load contract");
assert_eq!(contract.invariants.len(), 5);
let ids: Vec<&str> = contract.invariants.iter().map(|i| i.id.as_str()).collect();
assert!(ids.contains(&"I-1"));
assert!(ids.contains(&"I-2"));
assert!(ids.contains(&"I-3"));
assert!(ids.contains(&"I-4"));
assert!(ids.contains(&"I-5"));
}
// Pins the tolerance pairs for F32, Q4_K and Q6_K, and checks that an unknown
// dtype yields `None`.
#[test]
fn test_tolerance_lookup() {
let contract = load_format_contract().expect("Failed to load contract");
let (atol, rtol) = lookup_tolerance("F32", &contract).expect("F32 tolerance");
assert!((atol - 0.0).abs() < f64::EPSILON);
assert!((rtol - 0.0).abs() < f64::EPSILON);
let (atol, rtol) = lookup_tolerance("Q4_K", &contract).expect("Q4_K tolerance");
assert!((atol - 0.05).abs() < f64::EPSILON);
assert!((rtol - 0.05).abs() < f64::EPSILON);
let (atol, rtol) = lookup_tolerance("Q6_K", &contract).expect("Q6_K tolerance");
assert!((atol - 0.02).abs() < f64::EPSILON);
assert!((rtol - 0.02).abs() < f64::EPSILON);
assert!(lookup_tolerance("UNKNOWN", &contract).is_none());
}
// Every `canonical` example in the contract validates.
#[test]
fn test_validate_tensor_name_valid() {
let contract = load_format_contract().expect("Failed to load contract");
for example in &contract.tensor_naming.examples {
assert!(
validate_tensor_name(&example.canonical, &contract),
"Expected '{}' to be valid",
example.canonical
);
}
}
// Every `forbidden` example in the contract is rejected.
#[test]
fn test_validate_tensor_name_invalid() {
let contract = load_format_contract().expect("Failed to load contract");
for example in &contract.tensor_naming.examples {
assert!(
!validate_tensor_name(&example.forbidden, &contract),
"Expected '{}' to be invalid",
example.forbidden
);
}
}
// The default config lists I-2..I-5 (four labels, no I-1).
#[test]
fn test_contract_test_config_default() {
let config = ContractTestConfig::default();
assert_eq!(config.invariants.len(), 4);
assert!(config.invariants.contains(&"I-2".to_string()));
assert!(config.invariants.contains(&"I-3".to_string()));
assert!(config.invariants.contains(&"I-4".to_string()));
assert!(config.invariants.contains(&"I-5".to_string()));
}
// Label parsing covers all five invariants and rejects an unknown label.
#[test]
fn test_invariant_id_from_label() {
assert_eq!(InvariantId::from_label("I-1"), Some(InvariantId::I1));
assert_eq!(InvariantId::from_label("I-2"), Some(InvariantId::I2));
assert_eq!(InvariantId::from_label("I-3"), Some(InvariantId::I3));
assert_eq!(InvariantId::from_label("I-4"), Some(InvariantId::I4));
assert_eq!(InvariantId::from_label("I-5"), Some(InvariantId::I5));
assert_eq!(InvariantId::from_label("I-99"), None);
}
// Pins the built-in gate identifier of each invariant.
#[test]
fn test_invariant_id_gate_id() {
assert_eq!(InvariantId::I1.gate_id(), "F-CONTRACT-I1-001");
assert_eq!(InvariantId::I2.gate_id(), "F-CONTRACT-I2-001");
assert_eq!(InvariantId::I3.gate_id(), "F-CONTRACT-I3-001");
assert_eq!(InvariantId::I4.gate_id(), "F-CONTRACT-I4-001");
assert_eq!(InvariantId::I5.gate_id(), "F-CONTRACT-I5-001");
}
// Messages that should be recognised as fallback indications.
#[test]
fn test_contains_f32_fallback_positive() {
assert!(contains_f32_fallback(
"Warning: fallback to F32 for unknown type"
));
assert!(contains_f32_fallback("defaulting to f32"));
assert!(contains_f32_fallback("unknown dtype detected"));
}
// Messages that must not be mistaken for fallback indications.
#[test]
fn test_contains_f32_fallback_negative() {
assert!(!contains_f32_fallback("All checks passed"));
assert!(!contains_f32_fallback("Using Q4_K quantization"));
assert!(!contains_f32_fallback("F32 tensors loaded normally"));
}
// With the default mock runner, I-2 evidence exists and is corroborated.
#[test]
fn test_contract_i2_tensor_name_bijection_pass() {
use crate::command::MockCommandRunner;
let runner: Arc<dyn CommandRunner> = Arc::new(MockCommandRunner::new());
let model_id = ModelId::new("test", "model");
let config = ContractTestConfig::default();
let evidence = run_contract_tests(
&runner,
Path::new("/test/workspace/org/model"),
&model_id,
&config,
);
let i2 = evidence.iter().find(|e| e.gate_id == "F-CONTRACT-I2-001");
assert!(i2.is_some(), "I-2 evidence should exist");
assert_eq!(i2.unwrap().outcome, Outcome::Corroborated);
}
// When the mock's inspect call fails, I-2 is falsified.
#[test]
fn test_contract_i2_tensor_name_bijection_fail() {
use crate::command::MockCommandRunner;
let runner: Arc<dyn CommandRunner> =
Arc::new(MockCommandRunner::new().with_inspect_json_failure());
let model_id = ModelId::new("test", "model");
let config = ContractTestConfig {
invariants: vec!["I-2".to_string()],
};
let evidence = run_contract_tests(
&runner,
Path::new("/test/workspace/org/model"),
&model_id,
&config,
);
let i2 = evidence.iter().find(|e| e.gate_id == "F-CONTRACT-I2-001");
assert!(i2.is_some());
assert_eq!(i2.unwrap().outcome, Outcome::Falsified);
}
// Compact JSON with three names parses to a three-element set.
#[test]
fn test_parse_tensor_names_valid() {
let json = r#"{"format":"SafeTensors","tensor_count":3,"tensor_names":["embed.weight","lm_head.weight","0.q_proj.weight"],"parameters":"1.5B"}"#;
let names = parse_tensor_names(json);
assert_eq!(names.len(), 3);
assert!(names.contains("embed.weight"));
assert!(names.contains("lm_head.weight"));
assert!(names.contains("0.q_proj.weight"));
}
// An empty array parses to an empty set.
#[test]
fn test_parse_tensor_names_empty() {
let json = r#"{"tensor_names":[]}"#;
let names = parse_tensor_names(json);
assert!(names.is_empty());
}
// A document without the `tensor_names` key parses to an empty set.
#[test]
fn test_parse_tensor_names_missing_field() {
let json = r#"{"format":"SafeTensors","tensor_count":3}"#;
let names = parse_tensor_names(json);
assert!(names.is_empty());
}
// Non-JSON input parses to an empty set rather than panicking.
#[test]
fn test_parse_tensor_names_malformed() {
let names = parse_tensor_names("not json at all");
assert!(names.is_empty());
}
// NOTE(review): this builds its own local set instead of exercising
// `run_i2_tensor_bijection`, so it only documents the intended allow-list.
#[test]
fn test_i2_tied_embedding_allowed_extras() {
let allowed: HashSet<&str> = HashSet::from(["lm_head.weight", "lm_head.bias"]);
assert!(allowed.contains("lm_head.weight"));
assert!(allowed.contains("lm_head.bias"));
assert!(!allowed.contains("unexpected_tensor.weight"));
}
// With the default mock runner, I-3 is corroborated.
#[test]
fn test_contract_i3_no_silent_fallbacks_pass() {
use crate::command::MockCommandRunner;
let runner: Arc<dyn CommandRunner> = Arc::new(MockCommandRunner::new());
let model_id = ModelId::new("test", "model");
let config = ContractTestConfig {
invariants: vec!["I-3".to_string()],
};
let evidence = run_contract_tests(
&runner,
Path::new("/test/workspace/org/model"),
&model_id,
&config,
);
let i3 = evidence.iter().find(|e| e.gate_id == "F-CONTRACT-I3-001");
assert!(i3.is_some());
assert_eq!(i3.unwrap().outcome, Outcome::Corroborated);
}
// When the mock's check call fails, I-3 is falsified.
#[test]
fn test_contract_i3_no_silent_fallbacks_fail() {
use crate::command::MockCommandRunner;
let runner: Arc<dyn CommandRunner> =
Arc::new(MockCommandRunner::new().with_check_failure());
let model_id = ModelId::new("test", "model");
let config = ContractTestConfig {
invariants: vec!["I-3".to_string()],
};
let evidence = run_contract_tests(
&runner,
Path::new("/test/workspace/org/model"),
&model_id,
&config,
);
let i3 = evidence.iter().find(|e| e.gate_id == "F-CONTRACT-I3-001");
assert!(i3.is_some());
assert_eq!(i3.unwrap().outcome, Outcome::Falsified);
}
// With the default mock runner, I-4 is corroborated.
#[test]
fn test_contract_i4_statistical_preservation_pass() {
use crate::command::MockCommandRunner;
let runner: Arc<dyn CommandRunner> = Arc::new(MockCommandRunner::new());
let model_id = ModelId::new("test", "model");
let config = ContractTestConfig {
invariants: vec!["I-4".to_string()],
};
let evidence = run_contract_tests(
&runner,
Path::new("/test/workspace/org/model"),
&model_id,
&config,
);
let i4 = evidence.iter().find(|e| e.gate_id == "F-CONTRACT-I4-001");
assert!(i4.is_some());
assert_eq!(i4.unwrap().outcome, Outcome::Corroborated);
}
// When the mock's validate-stats call fails, I-4 is falsified.
#[test]
fn test_contract_i4_statistical_preservation_fail() {
use crate::command::MockCommandRunner;
let runner: Arc<dyn CommandRunner> =
Arc::new(MockCommandRunner::new().with_validate_stats_failure());
let model_id = ModelId::new("test", "model");
let config = ContractTestConfig {
invariants: vec!["I-4".to_string()],
};
let evidence = run_contract_tests(
&runner,
Path::new("/test/workspace/org/model"),
&model_id,
&config,
);
let i4 = evidence.iter().find(|e| e.gate_id == "F-CONTRACT-I4-001");
assert!(i4.is_some());
assert_eq!(i4.unwrap().outcome, Outcome::Falsified);
}
// With the default mock runner, I-5 is corroborated.
#[test]
fn test_contract_i5_tokenizer_roundtrip_pass() {
use crate::command::MockCommandRunner;
let runner: Arc<dyn CommandRunner> = Arc::new(MockCommandRunner::new());
let model_id = ModelId::new("test", "model");
let config = ContractTestConfig {
invariants: vec!["I-5".to_string()],
};
let evidence = run_contract_tests(
&runner,
Path::new("/test/workspace/org/model"),
&model_id,
&config,
);
let i5 = evidence.iter().find(|e| e.gate_id == "F-CONTRACT-I5-001");
assert!(i5.is_some());
assert_eq!(i5.unwrap().outcome, Outcome::Corroborated);
}
// When the mock's compare-inference call fails, I-5 is falsified.
#[test]
fn test_contract_i5_tokenizer_roundtrip_fail() {
use crate::command::MockCommandRunner;
let runner: Arc<dyn CommandRunner> =
Arc::new(MockCommandRunner::new().with_compare_inference_failure());
let model_id = ModelId::new("test", "model");
let config = ContractTestConfig {
invariants: vec!["I-5".to_string()],
};
let evidence = run_contract_tests(
&runner,
Path::new("/test/workspace/org/model"),
&model_id,
&config,
);
let i5 = evidence.iter().find(|e| e.gate_id == "F-CONTRACT-I5-001");
assert!(i5.is_some());
assert_eq!(i5.unwrap().outcome, Outcome::Falsified);
}
// The default config produces four pieces of evidence, all corroborated.
#[test]
fn test_contract_all_invariants_pass() {
use crate::command::MockCommandRunner;
let runner: Arc<dyn CommandRunner> = Arc::new(MockCommandRunner::new());
let model_id = ModelId::new("test", "model");
let config = ContractTestConfig::default();
let evidence = run_contract_tests(
&runner,
Path::new("/test/workspace/org/model"),
&model_id,
&config,
);
assert_eq!(evidence.len(), 4);
for ev in &evidence {
assert_eq!(
ev.outcome,
Outcome::Corroborated,
"Gate {} should pass",
ev.gate_id
);
}
}
// Listing I-1 produces no evidence for it; only I-2 is reported.
#[test]
fn test_contract_skips_i1() {
use crate::command::MockCommandRunner;
let runner: Arc<dyn CommandRunner> = Arc::new(MockCommandRunner::new());
let model_id = ModelId::new("test", "model");
let config = ContractTestConfig {
invariants: vec!["I-1".to_string(), "I-2".to_string()],
};
let evidence = run_contract_tests(
&runner,
Path::new("/test/workspace/org/model"),
&model_id,
&config,
);
assert_eq!(evidence.len(), 1);
assert_eq!(evidence[0].gate_id, "F-CONTRACT-I2-001");
}
// An unrecognised label is skipped without producing evidence.
#[test]
fn test_contract_unknown_invariant_skipped() {
use crate::command::MockCommandRunner;
let runner: Arc<dyn CommandRunner> = Arc::new(MockCommandRunner::new());
let model_id = ModelId::new("test", "model");
let config = ContractTestConfig {
invariants: vec!["I-99".to_string()],
};
let evidence = run_contract_tests(
&runner,
Path::new("/test/workspace/org/model"),
&model_id,
&config,
);
assert!(evidence.is_empty());
}
// Path resolution appends components, so dots in the model directory name
// survive; the final assertion shows how `with_extension` would truncate the
// same path.
#[test]
fn test_resolve_paths_with_dots_in_name() {
let workspace = Path::new("/output/workspace/Qwen/Qwen2.5-Coder-0.5B-Instruct");
let apr = resolve_apr_path(workspace);
let st = resolve_safetensors_path(workspace);
assert_eq!(
apr,
PathBuf::from("/output/workspace/Qwen/Qwen2.5-Coder-0.5B-Instruct/apr/model.apr")
);
assert_eq!(
st,
PathBuf::from(
"/output/workspace/Qwen/Qwen2.5-Coder-0.5B-Instruct/safetensors/model.safetensors"
)
);
let broken = workspace.with_extension("apr");
assert_eq!(
broken,
PathBuf::from("/output/workspace/Qwen/Qwen2.5-Coder-0.apr")
);
}
// End-to-end run with a dotted workspace path: four results, none of whose
// reasons contain the truncated `Coder-0.apr` path.
#[test]
fn test_contract_tests_with_dotted_workspace_path() {
use crate::command::MockCommandRunner;
let runner: Arc<dyn CommandRunner> = Arc::new(MockCommandRunner::new());
let model_id = ModelId::new("Qwen", "Qwen2.5-Coder-0.5B-Instruct");
let config = ContractTestConfig::default();
let evidence = run_contract_tests(
&runner,
Path::new("/workspace/Qwen/Qwen2.5-Coder-0.5B-Instruct"),
&model_id,
&config,
);
assert_eq!(evidence.len(), 4, "Expected 4 invariant results");
for ev in &evidence {
assert!(
!ev.reason.contains("Coder-0.apr"),
"Path was truncated by with_extension: {}",
ev.reason
);
}
}
// Edge cases of the structural name check: wrong segment counts, an empty or
// non-numeric layer segment.
#[test]
fn test_is_valid_tensor_name_edge_cases() {
let contract = load_format_contract().expect("Failed to load contract");
assert!(validate_tensor_name("0.attn.weight", &contract));
assert!(validate_tensor_name("99.mlp.bias", &contract));
assert!(!validate_tensor_name("weight", &contract));
assert!(!validate_tensor_name(".q_proj.weight", &contract));
assert!(!validate_tensor_name("a.q_proj.weight", &contract));
assert!(!validate_tensor_name("0.q_proj.weight.extra", &contract));
}
// The contract's naming convention is "gguf-short".
#[test]
fn test_naming_convention() {
let contract = load_format_contract().expect("Failed to load contract");
assert_eq!(contract.tensor_naming.convention, "gguf-short");
}
// Pins the `catches` and `implemented` values the YAML declares for I-1 and
// I-2 (I-1 implemented, I-2 not).
#[test]
fn test_invariant_catches_fields() {
let contract = load_format_contract().expect("Failed to load contract");
let i1 = contract.invariants.iter().find(|i| i.id == "I-1").unwrap();
assert!(i1.catches.contains(&"GH-190".to_string()));
assert!(i1.implemented);
let i2 = contract.invariants.iter().find(|i| i.id == "I-2").unwrap();
assert!(i2.catches.contains(&"GH-190".to_string()));
assert!(!i2.implemented);
}
// F32 has a zero first tolerance value while Q2_K's exceeds 0.1.
#[test]
fn test_tolerance_entries_ordered_by_precision() {
let contract = load_format_contract().expect("Failed to load contract");
let f32_tol = lookup_tolerance("F32", &contract).unwrap();
assert!(f32_tol.0.abs() < f64::EPSILON);
let q2k_tol = lookup_tolerance("Q2_K", &contract).unwrap();
assert!(q2k_tol.0 > 0.1);
}
// `is_word` accepts alphanumerics and underscores, and rejects empty strings,
// dots and spaces.
#[test]
fn test_is_word() {
assert!(is_word("weight"));
assert!(is_word("q_proj"));
assert!(is_word("down_proj"));
assert!(is_word("a"));
assert!(!is_word(""));
assert!(!is_word("has.dot"));
assert!(!is_word("has space"));
}
}