use std::collections::BTreeMap;
use std::path::Path;
use tribev2::model::tribe::TribeV2;
use tribev2::tensor::Tensor;
const DATA_DIR: &str = "/Users/Shared/tribev2-rs/data";
const REFS_DIR: &str = "/Users/Shared/tribev2-rs/data/parity_refs";
fn refs_exist() -> bool {
Path::new(&format!("{}/final_output.bin", REFS_DIR)).exists()
&& Path::new(&format!("{}/model.safetensors", DATA_DIR)).exists()
}
fn load_ref(name: &str) -> Tensor {
let path = format!("{}/{}", REFS_DIR, name);
let bytes = std::fs::read(&path).unwrap_or_else(|e| panic!("failed to read {}: {}", path, e));
let ndims = u32::from_le_bytes(bytes[0..4].try_into().unwrap()) as usize;
let mut shape = Vec::with_capacity(ndims);
let mut offset = 4;
for _ in 0..ndims {
let d = u32::from_le_bytes(bytes[offset..offset + 4].try_into().unwrap()) as usize;
shape.push(d);
offset += 4;
}
let n_floats: usize = shape.iter().product();
let mut data = Vec::with_capacity(n_floats);
for i in 0..n_floats {
let start = offset + i * 4;
let v = f32::from_le_bytes(bytes[start..start + 4].try_into().unwrap());
data.push(v);
}
Tensor::from_vec(data, shape)
}
fn compare_tensors(name: &str, rust: &Tensor, python: &Tensor, atol: f32, rtol: f32) -> bool {
assert_eq!(
rust.shape, python.shape,
"{}: shape mismatch: rust={:?} vs python={:?}",
name, rust.shape, python.shape
);
let n = rust.data.len();
let mut max_abs_err: f32 = 0.0;
let mut max_rel_err: f32 = 0.0;
let mut sum_sq_err: f64 = 0.0;
let mut n_mismatched = 0usize;
for i in 0..n {
let r = rust.data[i];
let p = python.data[i];
let abs_err = (r - p).abs();
let rel_err = if p.abs() > 1e-8 {
abs_err / p.abs()
} else {
abs_err
};
max_abs_err = max_abs_err.max(abs_err);
max_rel_err = max_rel_err.max(rel_err);
sum_sq_err += (abs_err as f64) * (abs_err as f64);
let tol = atol + rtol * p.abs();
if abs_err > tol {
n_mismatched += 1;
}
}
let rmse = (sum_sq_err / n as f64).sqrt();
let mean_r: f64 = rust.data.iter().map(|&v| v as f64).sum::<f64>() / n as f64;
let mean_p: f64 = python.data.iter().map(|&v| v as f64).sum::<f64>() / n as f64;
let mut cov = 0.0f64;
let mut var_r = 0.0f64;
let mut var_p = 0.0f64;
for i in 0..n {
let dr = rust.data[i] as f64 - mean_r;
let dp = python.data[i] as f64 - mean_p;
cov += dr * dp;
var_r += dr * dr;
var_p += dp * dp;
}
let pearson = if var_r > 0.0 && var_p > 0.0 {
cov / (var_r.sqrt() * var_p.sqrt())
} else {
0.0
};
let pass = n_mismatched == 0;
let status = if pass { "PASS" } else { "FAIL" };
eprintln!(
" [{status}] {name}: shape={:?}, max_abs={:.2e}, max_rel={:.2e}, rmse={:.2e}, pearson={:.8}, mismatched={}/{}",
rust.shape, max_abs_err, max_rel_err, rmse, pearson, n_mismatched, n
);
if !pass {
let mut shown = 0;
for i in 0..n {
let r = rust.data[i];
let p = python.data[i];
let abs_err = (r - p).abs();
let tol = atol + rtol * p.abs();
if abs_err > tol && shown < 5 {
eprintln!(
" idx={}: rust={:.8}, python={:.8}, diff={:.2e}",
i, r, p, abs_err
);
shown += 1;
}
}
}
pass
}
#[test]
fn test_full_numeric_parity() {
if !refs_exist() {
eprintln!("Skipping numeric parity test: reference files not found.");
eprintln!(" Run: python3 scripts/generate_parity_refs.py");
return;
}
eprintln!("\n=== TRIBE v2 Numeric Parity Test ===\n");
eprintln!("Loading model...");
let t0 = std::time::Instant::now();
let model = TribeV2::from_pretrained(
&format!("{}/config.yaml", DATA_DIR),
&format!("{}/model.safetensors", DATA_DIR),
Some(&format!("{}/build_args.json", DATA_DIR)),
)
.expect("failed to load model");
let load_ms = t0.elapsed().as_millis();
eprintln!(" Model loaded in {}ms\n", load_ms);
eprintln!("Loading reference inputs...");
let input_text = load_ref("input_text.bin");
let input_audio = load_ref("input_audio.bin");
let input_video = load_ref("input_video.bin");
eprintln!(" text: {:?}", input_text.shape);
eprintln!(" audio: {:?}", input_audio.shape);
eprintln!(" video: {:?}\n", input_video.shape);
let mut features = BTreeMap::new();
features.insert("text".to_string(), input_text);
features.insert("audio".to_string(), input_audio);
features.insert("video".to_string(), input_video);
eprintln!("Running Rust forward pass...");
let t1 = std::time::Instant::now();
let rust_output = model.forward(&features, None, true);
let fwd_ms = t1.elapsed().as_millis();
eprintln!(" Forward pass: {}ms", fwd_ms);
eprintln!(" Output: {:?}\n", rust_output.shape);
eprintln!("Comparing intermediate stages:");
let agg = model.aggregate_features(&features);
let ref_cat = load_ref("after_cat.bin");
compare_tensors("after_cat", &agg, &ref_cat, 1e-4, 1e-3);
let ref_encoder = load_ref("after_encoder.bin");
let ref_final = load_ref("final_output.bin");
let final_pass = compare_tensors("final_output", &rust_output, &ref_final, 5e-3, 5e-2);
eprintln!("\nSpeed benchmark (5 iterations, T=20):");
let mut times = Vec::new();
for i in 0..5 {
let t = std::time::Instant::now();
let _ = model.forward(&features, None, true);
let ms = t.elapsed().as_millis();
times.push(ms);
eprintln!(" Run {}: {}ms", i + 1, ms);
}
let avg_ms = times.iter().sum::<u128>() as f64 / times.len() as f64;
let min_ms = *times.iter().min().unwrap();
let max_ms = *times.iter().max().unwrap();
eprintln!(" Avg: {:.0}ms, Min: {}ms, Max: {}ms", avg_ms, min_ms, max_ms);
eprintln!("\nSpeed benchmark (T=100, 3 iterations):");
let t100 = 100;
let mut features_100 = BTreeMap::new();
for (name, tensor) in &features {
let dim = tensor.shape[1];
let t_in = tensor.shape[2];
let mut data = vec![0.0f32; dim * t100];
for d in 0..dim {
for t in 0..t100 {
data[d * t100 + t] = tensor.data[d * t_in + (t % t_in)];
}
}
features_100.insert(name.clone(), Tensor::from_vec(data, vec![1, dim, t100]));
}
let mut times_100 = Vec::new();
for i in 0..3 {
let t = std::time::Instant::now();
let out = model.forward(&features_100, None, true);
let ms = t.elapsed().as_millis();
times_100.push(ms);
eprintln!(" Run {} (T=100): {}ms, output={:?}", i + 1, ms, out.shape);
}
let avg_100 = times_100.iter().sum::<u128>() as f64 / times_100.len() as f64;
eprintln!(" Avg (T=100): {:.0}ms", avg_100);
eprintln!("\n=== SUMMARY ===");
eprintln!(" Model load: {}ms", load_ms);
eprintln!(" Forward (T=20): {:.0}ms avg", avg_ms);
eprintln!(" Forward (T=100): {:.0}ms avg", avg_100);
eprintln!(" Final output parity: {}", if final_pass { "PASS ✓" } else { "FAIL ✗" });
if !final_pass {
let ref_data = &ref_final.data;
let rust_data = &rust_output.data;
let n = ref_data.len();
let mean_r: f64 = rust_data.iter().map(|&v| v as f64).sum::<f64>() / n as f64;
let mean_p: f64 = ref_data.iter().map(|&v| v as f64).sum::<f64>() / n as f64;
let mut cov = 0.0f64;
let mut var_r = 0.0f64;
let mut var_p = 0.0f64;
for i in 0..n {
let dr = rust_data[i] as f64 - mean_r;
let dp = ref_data[i] as f64 - mean_p;
cov += dr * dp;
var_r += dr * dr;
var_p += dp * dp;
}
let pearson = cov / (var_r.sqrt() * var_p.sqrt());
eprintln!(" Pearson correlation: {:.10}", pearson);
assert!(
pearson > 0.99,
"Pearson correlation {:.6} is below 0.99 — significant numerical divergence",
pearson
);
eprintln!(" Pearson > 0.99 — acceptable f32 divergence ✓");
}
}