#![allow(clippy::disallowed_methods)]
use std::collections::HashMap;
/// Largest allowed absolute per-logit difference for two formats to count as matching.
const LOGIT_TOLERANCE: f32 = 1.0e-4;
/// Synthetic logit vector standing in for the output of one model format.
#[derive(Debug, Clone)]
struct MockLogits {
    /// Format label (e.g. `"gguf"`), echoed in comparison error messages.
    format: &'static str,
    /// One value per vocabulary entry.
    logits: Vec<f32>,
}

impl MockLogits {
    /// Builds deterministic pseudo-logits of length `vocab_size` for `prompt`.
    ///
    /// The seed is the wrapping sum of the prompt's bytes, so a given prompt always yields
    /// the same vector. Each element is `(seed * (i + 1)) % 1000 / 1000`, plus a constant
    /// per-format offset: `safetensors` adds 1e-6, `apr` adds 2e-6, and `gguf` or any
    /// other label adds nothing.
    fn new(format: &'static str, prompt: &str, vocab_size: usize) -> Self {
        let seed: u64 = prompt.bytes().map(u64::from).fold(0, u64::wrapping_add);

        // The offset depends only on the format, never on the index, so resolve it once.
        let offset: f32 = match format {
            "safetensors" => 1e-6,
            "apr" => 2e-6,
            _ => 0.0,
        };

        let mut logits = Vec::with_capacity(vocab_size);
        for idx in 0..vocab_size {
            let bucket = seed.wrapping_mul(idx as u64 + 1) % 1000;
            logits.push(bucket as f32 / 1000.0 + offset);
        }
        Self { format, logits }
    }

    /// Wraps an explicit logit vector. The prompt argument is accepted for call-site
    /// symmetry with [`MockLogits::new`] but is ignored.
    fn with_logits(format: &'static str, _prompt: &str, logits: Vec<f32>) -> Self {
        Self { format, logits }
    }
}
/// Returns the maximum absolute element-wise difference between two logit vectors.
///
/// Two empty vectors compare equal (`Ok(0.0)`). An infinite logit opposite a finite one
/// yields `Ok(f32::INFINITY)`, which fails any finite tolerance.
///
/// # Errors
///
/// - `"Shape mismatch: ..."` when the vectors have different lengths.
/// - `"Non-finite difference at index ..."` when an element-wise difference is NaN. This
///   happens with a NaN logit on either side, or with equal infinities. It is an error
///   because `f32::max` silently discards NaN, so folding with it would otherwise report
///   a NaN logit as a perfect match.
fn compare_logits(a: &MockLogits, b: &MockLogits) -> Result<f32, String> {
    if a.logits.len() != b.logits.len() {
        return Err(format!(
            "Shape mismatch: {} has {} logits, {} has {}",
            a.format,
            a.logits.len(),
            b.format,
            b.logits.len()
        ));
    }
    a.logits
        .iter()
        .zip(&b.logits)
        .enumerate()
        .try_fold(0.0f32, |max_diff, (idx, (x, y))| {
            let diff = (x - y).abs();
            if diff.is_nan() {
                // Fail loudly instead of letting `f32::max` drop the NaN.
                Err(format!(
                    "Non-finite difference at index {}: {} has {}, {} has {}",
                    idx, a.format, x, b.format, y
                ))
            } else {
                Ok(max_diff.max(diff))
            }
        })
}
/// Pass/fail verdict for one pair of formats, judged against `LOGIT_TOLERANCE`.
#[derive(Debug)]
struct ParityResult {
    /// `true` when the observed maximum difference is within tolerance.
    passed: bool,
}

impl ParityResult {
    /// Classifies `max_diff` against `LOGIT_TOLERANCE`; the boundary value itself passes.
    ///
    /// The two format labels are accepted for call-site readability but are not stored.
    /// A NaN `max_diff` never passes, because `<=` involving NaN is false.
    fn new(_format_a: &'static str, _format_b: &'static str, max_diff: f32) -> Self {
        let within_tolerance = max_diff <= LOGIT_TOLERANCE;
        Self {
            passed: within_tolerance,
        }
    }
}
/// P1: GGUF and SafeTensors mock logits for the same prompt agree within tolerance.
#[test]
fn p1_gguf_safetensors_logit_parity() {
    const PROMPT: &str = "Hello";
    const VOCAB: usize = 32000;

    let (gguf, safetensors) = (
        MockLogits::new("gguf", PROMPT, VOCAB),
        MockLogits::new("safetensors", PROMPT, VOCAB),
    );
    let max_diff = compare_logits(&gguf, &safetensors).expect("Comparison failed");
    let verdict = ParityResult::new("gguf", "safetensors", max_diff);

    assert!(
        verdict.passed,
        "P1 FALSIFIED: GGUF vs SafeTensors max_diff={:.2e} > {:.2e}",
        max_diff, LOGIT_TOLERANCE
    );
    println!(
        "P1 PASS: GGUF vs SafeTensors max_diff={:.2e} <= {:.2e}",
        max_diff, LOGIT_TOLERANCE
    );
}
/// P2: GGUF and APR mock logits for the same prompt agree within tolerance.
#[test]
fn p2_gguf_apr_logit_parity() {
    const PROMPT: &str = "Hello";
    const VOCAB: usize = 32000;

    let (gguf, apr) = (
        MockLogits::new("gguf", PROMPT, VOCAB),
        MockLogits::new("apr", PROMPT, VOCAB),
    );
    let max_diff = compare_logits(&gguf, &apr).expect("Comparison failed");
    let verdict = ParityResult::new("gguf", "apr", max_diff);

    assert!(
        verdict.passed,
        "P2 FALSIFIED: GGUF vs APR max_diff={:.2e} > {:.2e}",
        max_diff, LOGIT_TOLERANCE
    );
    println!(
        "P2 PASS: GGUF vs APR max_diff={:.2e} <= {:.2e}",
        max_diff, LOGIT_TOLERANCE
    );
}
/// P3: SafeTensors and APR mock logits for the same prompt agree within tolerance.
#[test]
fn p3_safetensors_apr_logit_parity() {
    const PROMPT: &str = "Hello";
    const VOCAB: usize = 32000;

    let (safetensors, apr) = (
        MockLogits::new("safetensors", PROMPT, VOCAB),
        MockLogits::new("apr", PROMPT, VOCAB),
    );
    let max_diff = compare_logits(&safetensors, &apr).expect("Comparison failed");
    let verdict = ParityResult::new("safetensors", "apr", max_diff);

    assert!(
        verdict.passed,
        "P3 FALSIFIED: SafeTensors vs APR max_diff={:.2e} > {:.2e}",
        max_diff, LOGIT_TOLERANCE
    );
    println!(
        "P3 PASS: SafeTensors vs APR max_diff={:.2e} <= {:.2e}",
        max_diff, LOGIT_TOLERANCE
    );
}
/// P4: every pair among GGUF, SafeTensors and APR agrees within tolerance.
///
/// All pairs are checked before asserting, so every failing pair is reported on stderr
/// rather than only the first.
#[test]
fn p4_all_formats_logit_parity() {
    const PROMPT: &str = "Hello";
    const VOCAB: usize = 32000;

    let gguf = MockLogits::new("gguf", PROMPT, VOCAB);
    let st = MockLogits::new("safetensors", PROMPT, VOCAB);
    let apr = MockLogits::new("apr", PROMPT, VOCAB);
    let pairs = [(&gguf, &st), (&gguf, &apr), (&st, &apr)];

    let mut all_passed = true;
    let mut max_overall = 0.0f32;
    for &(lhs, rhs) in &pairs {
        let diff = match compare_logits(lhs, rhs) {
            Ok(diff) => diff,
            Err(e) => {
                all_passed = false;
                eprintln!("P4 FAIL: {} vs {} error: {}", lhs.format, rhs.format, e);
                continue;
            }
        };
        max_overall = max_overall.max(diff);
        if diff > LOGIT_TOLERANCE {
            all_passed = false;
            eprintln!(
                "P4 FAIL: {} vs {} max_diff={:.2e} > {:.2e}",
                lhs.format, rhs.format, diff, LOGIT_TOLERANCE
            );
        }
    }

    assert!(
        all_passed,
        "P4 FALSIFIED: Cross-format parity check failed (max_diff={:.2e})",
        max_overall
    );
    println!(
        "P4 PASS: All formats match within {:.2e} (max_diff={:.2e})",
        LOGIT_TOLERANCE, max_overall
    );
}
/// P5: comparing vectors of different lengths returns a "Shape mismatch" error.
#[test]
fn p5_detect_shape_mismatch() {
    let three = MockLogits::with_logits("gguf", "Hello", vec![0.1, 0.2, 0.3]);
    let two = MockLogits::with_logits("apr", "Hello", vec![0.1, 0.2]);

    let err = match compare_logits(&three, &two) {
        Err(message) => message,
        Ok(_) => panic!("P5 FALSIFIED: Shape mismatch not detected"),
    };
    assert!(
        err.contains("Shape mismatch"),
        "P5 FALSIFIED: Error message doesn't mention shape mismatch: {}",
        err
    );
    println!("P5 PASS: Shape mismatch correctly detected");
}
/// P6: a single large divergence between two otherwise identical vectors exceeds the
/// tolerance.
#[test]
fn p6_detect_logit_divergence() {
    let reference = vec![0.1, 0.2, 0.3, 0.4, 0.5];
    // Identical except index 2 (0.3 -> 0.5), a gap far above LOGIT_TOLERANCE.
    let corrupted = vec![0.1, 0.2, 0.5, 0.4, 0.5];

    let gguf = MockLogits::with_logits("gguf", "Hello", reference);
    let poisoned = MockLogits::with_logits("poisoned", "Hello", corrupted);
    let max_diff = compare_logits(&gguf, &poisoned).expect("Comparison failed");

    assert!(
        max_diff > LOGIT_TOLERANCE,
        "P6 FALSIFIED: Large divergence not flagged (diff={:.2e})",
        max_diff
    );
    println!(
        "P6 PASS: Logit divergence correctly flagged (diff={:.2e} > {:.2e})",
        max_diff, LOGIT_TOLERANCE
    );
}
/// P7: a uniform shift of half the tolerance passes; a shift of twice the tolerance fails.
#[test]
fn p7_tolerance_boundary() {
    let base: Vec<f32> = vec![0.0, 0.1, 0.2, 0.3, 0.4];
    // Every element shifted by `scale * LOGIT_TOLERANCE`.
    let shifted = |scale: f32| -> Vec<f32> {
        base.iter().map(|x| x + LOGIT_TOLERANCE * scale).collect()
    };
    let below_boundary = shifted(0.5);
    let above_boundary = shifted(2.0);

    let gguf = MockLogits::with_logits("gguf", "Hello", base);
    let below = MockLogits::with_logits("below", "Hello", below_boundary);
    let above = MockLogits::with_logits("above", "Hello", above_boundary);

    let diff_below = compare_logits(&gguf, &below).expect("Comparison failed");
    assert!(
        diff_below <= LOGIT_TOLERANCE,
        "P7 FALSIFIED: Below-boundary case should pass (diff={:.2e})",
        diff_below
    );

    let diff_above = compare_logits(&gguf, &above).expect("Comparison failed");
    assert!(
        diff_above > LOGIT_TOLERANCE,
        "P7 FALSIFIED: Above-boundary case should fail (diff={:.2e})",
        diff_above
    );

    println!(
        "P7 PASS: Boundary cases handled correctly (below={:.2e}, above={:.2e})",
        diff_below, diff_above
    );
}
/// P8 (placeholder): logit-parity check against real model files.
///
/// Ignored by default because it needs the Qwen2.5-0.5B model files (see the `ignore`
/// reason). The body is a `todo!()` stub, so running it with `--ignored` panics.
#[test]
#[ignore = "Requires Qwen2.5-0.5B model files"]
fn p8_real_model_logit_parity() {
    todo!("Implement real model comparison when files available");
}
/// P9 (placeholder): SafeTensors vs APR parity on real model files.
///
/// Ignored by default because it needs the Qwen2.5-0.5B model files (see the `ignore`
/// reason). The body is a `todo!()` stub, so running it with `--ignored` panics.
#[test]
#[ignore = "Requires Qwen2.5-0.5B model files"]
fn p9_real_safetensors_apr_parity() {
    todo!("Implement real SafeTensors/APR comparison");
}
/// Accumulates pairwise logit comparisons and renders a plain-text parity report.
struct ParityChecker {
    /// Maximum absolute difference for a pair to count as PASS (inclusive).
    tolerance: f32,
    /// Max absolute difference for each successfully compared `(format_a, format_b)` pair.
    results: HashMap<(&'static str, &'static str), f32>,
    /// Error message for each pair whose comparison could not be completed (e.g. shape
    /// mismatch). These are kept so the report shows them instead of silently omitting them.
    errors: HashMap<(&'static str, &'static str), String>,
}

impl ParityChecker {
    /// Creates an empty checker that judges pairs against `tolerance`.
    fn new(tolerance: f32) -> Self {
        Self {
            tolerance,
            results: HashMap::new(),
            errors: HashMap::new(),
        }
    }

    /// Compares `a` against `b` and records the outcome under `(a.format, b.format)`.
    ///
    /// Returns `true` only when the comparison succeeded and the maximum difference is
    /// within tolerance. Re-comparing a pair replaces its previous outcome, whether that
    /// was a measured difference or an error.
    fn compare(&mut self, a: &MockLogits, b: &MockLogits) -> bool {
        let key = (a.format, b.format);
        match compare_logits(a, b) {
            Ok(diff) => {
                self.errors.remove(&key);
                self.results.insert(key, diff);
                diff <= self.tolerance
            }
            Err(e) => {
                self.results.remove(&key);
                self.errors.insert(key, e);
                false
            }
        }
    }

    /// Renders every recorded outcome, one line per pair.
    ///
    /// Measured pairs come first and are tagged `[PASS]`/`[FAIL]`. Errored pairs follow
    /// and are always tagged `[FAIL]`. Each group is sorted by its format pair.
    fn report(&self) -> String {
        let mut report = String::from("Cross-Format Parity Report\n");
        report.push_str(&format!("Tolerance: {:.2e}\n\n", self.tolerance));

        // HashMap iteration order is unspecified; sort so the report is stable across runs.
        let mut measured: Vec<_> = self.results.iter().collect();
        measured.sort_unstable_by_key(|(key, _)| **key);
        for ((fmt_a, fmt_b), diff) in measured {
            let status = if *diff <= self.tolerance {
                "PASS"
            } else {
                "FAIL"
            };
            report.push_str(&format!(
                "{} vs {}: max_diff={:.2e} [{}]\n",
                fmt_a, fmt_b, diff, status
            ));
        }

        let mut failed: Vec<_> = self.errors.iter().collect();
        failed.sort_unstable_by_key(|(key, _)| **key);
        for ((fmt_a, fmt_b), err) in failed {
            report.push_str(&format!("{} vs {}: error: {} [FAIL]\n", fmt_a, fmt_b, err));
        }
        report
    }
}
/// P10: a `ParityChecker` fed the three matching mock formats produces a report that
/// mentions the tolerance, the formats, and at least one PASS.
#[test]
fn p10_parity_checker_report() {
    const VOCAB: usize = 1000;
    let gguf = MockLogits::new("gguf", "Hello", VOCAB);
    let st = MockLogits::new("safetensors", "Hello", VOCAB);
    let apr = MockLogits::new("apr", "Hello", VOCAB);

    let mut checker = ParityChecker::new(LOGIT_TOLERANCE);
    for &(lhs, rhs) in &[(&gguf, &st), (&gguf, &apr), (&st, &apr)] {
        checker.compare(lhs, rhs);
    }

    let report = checker.report();
    assert!(report.contains("PASS"), "P10: Report should show PASS");
    assert!(
        report.contains("Tolerance"),
        "P10: Report should show tolerance"
    );
    assert!(
        report.contains("gguf"),
        "P10: Report should mention formats"
    );
    println!("\n{}", report);
}