use std::fs;
use std::path::Path;
use std::time::{SystemTime, UNIX_EPOCH};
/// A single golden reference case: an input sequence and the logits expected for it.
#[derive(Debug, Clone)]
pub struct GoldenTrace {
    /// Label identifying this case.
    pub name: String,
    /// Input id sequence for this case.
    pub input_ids: Vec<u32>,
    /// Expected logit values for this input.
    pub expected_logits: Vec<f32>,
    /// Allowed deviation; [`GoldenTrace::new`] sets it to `1e-4`.
    pub tolerance: f32,
}

impl GoldenTrace {
    /// Builds a trace from its parts, using the default tolerance of `1e-4`.
    pub fn new(name: impl Into<String>, input_ids: Vec<u32>, expected_logits: Vec<f32>) -> Self {
        let name: String = name.into();
        Self {
            tolerance: 1e-4,
            name,
            input_ids,
            expected_logits,
        }
    }

    /// Returns the same trace with `tolerance` replacing the current value.
    #[must_use]
    pub fn with_tolerance(self, tolerance: f32) -> Self {
        Self { tolerance, ..self }
    }
}
/// A collection of golden traces for one model, together with provenance metadata.
#[derive(Debug, Clone, Default)]
pub struct GoldenTraceSet {
    /// Architecture label supplied by the caller.
    pub architecture: String,
    /// Model name supplied by the caller.
    pub model_name: String,
    /// Traces in insertion order.
    pub traces: Vec<GoldenTrace>,
    /// Creation stamp; [`GoldenTraceSet::new`] stores the result of `timestamp_now()`.
    pub created_at: String,
    /// Origin of the expected values; [`GoldenTraceSet::new`] stores `"PyTorch/HuggingFace"`.
    pub reference: String,
}

impl GoldenTraceSet {
    /// Creates an empty set for the given architecture and model, stamped with the current time.
    pub fn new(architecture: impl Into<String>, model_name: impl Into<String>) -> Self {
        Self {
            architecture: architecture.into(),
            model_name: model_name.into(),
            traces: Vec::new(),
            created_at: timestamp_now(),
            reference: "PyTorch/HuggingFace".to_string(),
        }
    }

    /// Appends `trace` to the set.
    pub fn add_trace(&mut self, trace: GoldenTrace) {
        self.traces.push(trace);
    }

    /// Reads the file at `path` and parses it with [`GoldenTraceSet::from_json`].
    ///
    /// # Errors
    ///
    /// Returns a message if the file cannot be read.
    pub fn load(path: &Path) -> Result<Self, String> {
        let json = fs::read_to_string(path)
            .map_err(|e| format!("Failed to read golden trace file: {e}"))?;
        Self::from_json(&json)
    }

    /// Serializes the set with [`GoldenTraceSet::to_json`] and writes it to `path`.
    ///
    /// # Errors
    ///
    /// Returns a message if the file cannot be written.
    pub fn save(&self, path: &Path) -> Result<(), String> {
        let json = self.to_json()?;
        fs::write(path, json).map_err(|e| format!("Failed to write golden trace file: {e}"))
    }

    /// Renders the set as JSON text.
    ///
    /// For each trace only the name, the input ids, the tolerance and the *number* of expected
    /// logits are written; the logit values themselves are not part of the output.
    ///
    /// String values are escaped so that quotes, backslashes and control characters cannot
    /// break the document.
    ///
    /// # Errors
    ///
    /// This implementation has no failing path; the `Result` is part of the existing interface.
    pub fn to_json(&self) -> Result<String, String> {
        use std::fmt::Write;

        let mut json = String::new();
        json.push_str("{\n");
        // Formatting into a `String` cannot fail, so the `fmt::Result`s are discarded.
        let _ = writeln!(
            json,
            " \"architecture\": \"{}\",",
            Self::escape_json(&self.architecture)
        );
        let _ = writeln!(
            json,
            " \"model_name\": \"{}\",",
            Self::escape_json(&self.model_name)
        );
        let _ = writeln!(
            json,
            " \"created_at\": \"{}\",",
            Self::escape_json(&self.created_at)
        );
        let _ = writeln!(
            json,
            " \"reference\": \"{}\",",
            Self::escape_json(&self.reference)
        );
        json.push_str(" \"traces\": [\n");
        let trace_count = self.traces.len();
        for (i, trace) in self.traces.iter().enumerate() {
            json.push_str(" {\n");
            let _ = writeln!(json, " \"name\": \"{}\",", Self::escape_json(&trace.name));
            // `{:?}` on a `Vec<u32>` prints `[a, b, c]`, which is also a JSON array.
            let _ = writeln!(json, " \"input_ids\": {:?},", trace.input_ids);
            let _ = writeln!(json, " \"tolerance\": {},", trace.tolerance);
            let _ = writeln!(
                json,
                " \"expected_logits_len\": {}",
                trace.expected_logits.len()
            );
            // Every element except the last is followed by a comma.
            json.push_str(if i + 1 < trace_count { " },\n" } else { " }\n" });
        }
        json.push_str(" ]\n");
        json.push_str("}\n");
        Ok(json)
    }

    /// Restores the metadata fields from JSON text.
    ///
    /// Only `architecture`, `model_name`, `created_at` and `reference` are read; a field that
    /// cannot be found keeps its `Default` value and `traces` is left empty. Escape sequences in
    /// the extracted values are decoded.
    ///
    /// # Errors
    ///
    /// This implementation has no failing path; the `Result` is part of the existing interface.
    pub fn from_json(json: &str) -> Result<Self, String> {
        let field = |key: &str| -> Option<String> {
            extract_json_string(json, key).map(|raw: String| Self::unescape_json(&raw))
        };

        let mut set = Self::default();
        if let Some(arch) = field("architecture") {
            set.architecture = arch;
        }
        if let Some(name) = field("model_name") {
            set.model_name = name;
        }
        if let Some(created) = field("created_at") {
            set.created_at = created;
        }
        if let Some(reference) = field("reference") {
            set.reference = reference;
        }
        Ok(set)
    }

    /// Escapes `s` for embedding between double quotes in a JSON document.
    fn escape_json(s: &str) -> String {
        use std::fmt::Write;

        let mut out = String::with_capacity(s.len());
        for c in s.chars() {
            match c {
                // Written as `\u0022` rather than `\"` so the encoded value never contains a raw
                // quote: even a scanner that stops at the first `"` sees the whole value.
                '"' => out.push_str("\\u0022"),
                '\\' => out.push_str("\\\\"),
                '\n' => out.push_str("\\n"),
                '\r' => out.push_str("\\r"),
                '\t' => out.push_str("\\t"),
                c if u32::from(c) < 0x20 => {
                    let _ = write!(out, "\\u{:04x}", u32::from(c));
                }
                c => out.push(c),
            }
        }
        out
    }

    /// Decodes JSON escape sequences in `raw`.
    ///
    /// Decoding is lenient: an unknown or malformed escape is copied through unchanged instead
    /// of being rejected.
    fn unescape_json(raw: &str) -> String {
        let mut out = String::with_capacity(raw.len());
        let mut chars = raw.chars();
        while let Some(c) = chars.next() {
            if c != '\\' {
                out.push(c);
                continue;
            }
            match chars.next() {
                Some('n') => out.push('\n'),
                Some('r') => out.push('\r'),
                Some('t') => out.push('\t'),
                Some('b') => out.push('\u{0008}'),
                Some('f') => out.push('\u{000C}'),
                Some('"') => out.push('"'),
                Some('\\') => out.push('\\'),
                Some('/') => out.push('/'),
                Some('u') => {
                    // Peek at the next four characters without consuming them yet.
                    let hex: String = chars.clone().take(4).collect();
                    let decoded = if hex.len() == 4 && hex.bytes().all(|b| b.is_ascii_hexdigit()) {
                        u32::from_str_radix(&hex, 16).ok().and_then(char::from_u32)
                    } else {
                        None
                    };
                    match decoded {
                        Some(ch) => {
                            out.push(ch);
                            // Consume the four hex digits that were just decoded.
                            let _ = chars.nth(3);
                        }
                        None => out.push_str("\\u"),
                    }
                }
                Some(other) => {
                    out.push('\\');
                    out.push(other);
                }
                None => out.push('\\'),
            }
        }
        out
    }
}
/// Outcome of checking one trace.
#[derive(Debug, Clone)]
pub struct TraceVerifyResult {
    /// Name of the trace this result belongs to.
    pub name: String,
    /// Whether the check succeeded.
    pub passed: bool,
    /// Largest deviation observed; `f32::MAX` in results built by [`TraceVerifyResult::fail`].
    pub max_deviation: f32,
    /// Mean deviation observed; `f32::MAX` in results built by [`TraceVerifyResult::fail`].
    pub mean_deviation: f32,
    /// Number of logits compared; `0` in results built by [`TraceVerifyResult::fail`].
    pub logits_compared: usize,
    /// Tolerance the check used; `0.0` in results built by [`TraceVerifyResult::fail`].
    pub tolerance: f32,
    /// Failure description, or `None` in results built by [`TraceVerifyResult::pass`].
    pub error: Option<String>,
}

impl TraceVerifyResult {
    /// Builds a passing result carrying the measured deviations.
    #[must_use]
    pub fn pass(name: &str, max_dev: f32, mean_dev: f32, count: usize, tol: f32) -> Self {
        let name = String::from(name);
        Self {
            error: None,
            passed: true,
            name,
            max_deviation: max_dev,
            mean_deviation: mean_dev,
            logits_compared: count,
            tolerance: tol,
        }
    }

    /// Builds a failing result that carries only an error message; the numeric fields are set
    /// to fixed placeholder values.
    pub fn fail(name: &str, error: impl Into<String>) -> Self {
        let message: String = error.into();
        let name = String::from(name);
        Self {
            error: Some(message),
            passed: false,
            name,
            max_deviation: f32::MAX,
            mean_deviation: f32::MAX,
            logits_compared: 0,
            tolerance: 0.0,
        }
    }
}
/// Aggregate outcome over a list of per-trace results.
#[derive(Debug, Clone)]
pub struct GoldenVerifyReport {
    /// The individual results, in the order they were supplied.
    pub results: Vec<TraceVerifyResult>,
    /// `true` only when there is at least one result and every result passed.
    pub passed: bool,
    /// Number of results whose `passed` flag is set.
    pub passed_count: usize,
    /// Total number of results.
    pub total_count: usize,
}

impl GoldenVerifyReport {
    /// Summarizes `results` into a report.
    ///
    /// An empty list yields `passed == false`.
    #[must_use]
    pub fn from_results(results: Vec<TraceVerifyResult>) -> Self {
        let total_count = results.len();
        let mut passed_count = 0usize;
        for result in &results {
            if result.passed {
                passed_count += 1;
            }
        }
        let passed = total_count != 0 && passed_count == total_count;
        Self {
            results,
            passed,
            passed_count,
            total_count,
        }
    }
}
/// Compares `actual` logits against `expected` element by element.
///
/// The result passes when the largest absolute difference does not exceed `tolerance`.
/// It fails when:
///
/// * the two slices differ in length,
/// * any compared pair involves a NaN (a NaN can never be shown to be within tolerance),
/// * the largest absolute difference exceeds `tolerance`, or `tolerance` itself is NaN.
///
/// Two equal values, including two identical infinities, count as a deviation of zero.
/// For empty slices the result passes with both deviations reported as `0.0`.
#[must_use]
pub fn verify_logits(
    name: &str,
    actual: &[f32],
    expected: &[f32],
    tolerance: f32,
) -> TraceVerifyResult {
    if actual.len() != expected.len() {
        return TraceVerifyResult::fail(
            name,
            format!(
                "Logit count mismatch: expected {}, got {}",
                expected.len(),
                actual.len()
            ),
        );
    }

    let mut max_dev = 0.0f32;
    let mut sum_dev = 0.0f32;
    for (index, (&a, &e)) in actual.iter().zip(expected.iter()).enumerate() {
        // `inf - inf` is NaN, so equal values are handled before subtracting.
        let dev = if a == e { 0.0 } else { (a - e).abs() };
        // `f32::max` returns the non-NaN operand, so a NaN deviation would otherwise be
        // dropped silently and the trace reported as passing.
        if dev.is_nan() {
            return TraceVerifyResult::fail(
                name,
                format!("NaN logit at index {index}: expected {e}, got {a}"),
            );
        }
        max_dev = max_dev.max(dev);
        sum_dev += dev;
    }

    // Avoid 0/0 (NaN) when there is nothing to compare.
    let mean_dev = if actual.is_empty() {
        0.0
    } else {
        sum_dev / actual.len() as f32
    };

    // Written as `<=` in the passing branch so that a NaN tolerance fails instead of passing.
    if max_dev <= tolerance {
        TraceVerifyResult::pass(name, max_dev, mean_dev, actual.len(), tolerance)
    } else {
        TraceVerifyResult {
            name: name.to_string(),
            passed: false,
            max_deviation: max_dev,
            mean_deviation: mean_dev,
            logits_compared: actual.len(),
            tolerance,
            error: Some(format!(
                "Max deviation {max_dev:.6} exceeds tolerance {tolerance:.6}"
            )),
        }
    }
}
/// Summary statistics over a slice of logits.
#[derive(Debug, Clone)]
pub struct LogitStats {
    /// Arithmetic mean.
    pub mean: f32,
    /// Population standard deviation (the squared deviations are divided by `n`).
    pub std: f32,
    /// Smallest value; NaN entries are ignored by the underlying `f32::min`.
    pub min: f32,
    /// Largest value; NaN entries are ignored by the underlying `f32::max`.
    pub max: f32,
    /// Index of the largest value; always equal to `top5[0]` for non-empty input.
    pub argmax: usize,
    /// Indices of up to five largest values, largest first.
    pub top5: Vec<usize>,
}

impl LogitStats {
    /// Computes the statistics for `logits`.
    ///
    /// Empty input yields all-zero statistics and an empty `top5`.
    ///
    /// Ranking rules for `argmax` and `top5`: larger values come first, equal values keep
    /// their original index order (so the lowest index wins a tie), and NaN entries rank after
    /// every number. Both fields come from the same ranking, so they never disagree.
    #[must_use]
    pub fn compute(logits: &[f32]) -> Self {
        if logits.is_empty() {
            return Self {
                mean: 0.0,
                std: 0.0,
                min: 0.0,
                max: 0.0,
                argmax: 0,
                top5: vec![],
            };
        }

        let n = logits.len() as f32;
        let sum: f32 = logits.iter().sum();
        let mean = sum / n;
        let var_sum: f32 = logits.iter().map(|x| (x - mean).powi(2)).sum();
        let std = (var_sum / n).sqrt();
        let min = logits.iter().fold(f32::INFINITY, |a, &b| a.min(b));
        let max = logits.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));

        let mut indexed: Vec<(usize, f32)> = logits.iter().copied().enumerate().collect();
        // `sort_by` is stable, so ties stay in ascending index order.
        indexed.sort_by(|a, b| Self::descending_nan_last(a.1, b.1));
        let top5: Vec<usize> = indexed.iter().take(5).map(|&(i, _)| i).collect();
        let argmax = top5.first().copied().unwrap_or(0);

        Self {
            mean,
            std,
            min,
            max,
            argmax,
            top5,
        }
    }

    /// Orders larger values first and places NaN after every number.
    ///
    /// Unlike `partial_cmp(..).unwrap_or(Equal)`, this is a consistent total preorder, which
    /// the standard sort routines require.
    fn descending_nan_last(a: f32, b: f32) -> std::cmp::Ordering {
        use std::cmp::Ordering;

        match (a.is_nan(), b.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            // Neither side is NaN, so `partial_cmp` always yields `Some`.
            (false, false) => b.partial_cmp(&a).unwrap_or(Ordering::Equal),
        }
    }
}
/// Returns the current time as whole seconds since the Unix epoch, rendered in decimal.
///
/// If the system clock reports a time before the epoch, the result is `"0"`.
fn timestamp_now() -> String {
    let seconds = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_secs());
    seconds.to_string()
}
/// Finds the first `"key":` in `json` and returns the string value that follows it.
///
/// The value is returned exactly as it appears between its quotes; escape sequences are
/// **not** decoded. A backslash-escaped character (for example `\"`) does not end the value.
///
/// Returns `None` when the key is absent, when the value does not start with a double quote,
/// or when the string is never closed.
fn extract_json_string(json: &str, key: &str) -> Option<String> {
    let pattern = format!("\"{key}\":");
    let start = json.find(&pattern)?;
    let rest = json.get(start + pattern.len()..)?.trim_start();
    let body = rest.strip_prefix('"')?;

    // Scan for the closing quote, skipping whichever character follows a backslash.
    let mut escaped = false;
    for (idx, ch) in body.char_indices() {
        if escaped {
            escaped = false;
        } else if ch == '\\' {
            escaped = true;
        } else if ch == '"' {
            return body.get(..idx).map(str::to_string);
        }
    }
    None
}
// Unit tests are kept in the separate file `golden_tests.rs` and compiled only for test builds.
#[cfg(test)]
#[path = "golden_tests.rs"]
mod tests;