/// Statistical fingerprint of a single tensor, used as a baseline for regression checks.
///
/// `TensorCanary::from_data` fills the statistics from `TensorFeatures::from_data`
/// and the checksum from `crc32_simple`. All fields are public, so a canary can
/// also be assembled by hand.
#[derive(Debug, Clone)]
pub struct TensorCanary {
/// Tensor name; `CanaryFile::verify` pairs baseline and current tensors on this field.
pub name: String,
/// Tensor dimensions; compared for exact equality by `detect_regression`.
pub shape: Vec<usize>,
/// Element type label, stored as a free-form string. `detect_regression` does not compare it.
pub dtype: String,
/// Mean of the tensor values, as reported by `TensorFeatures`.
pub mean: f32,
/// Standard deviation of the tensor values, as reported by `TensorFeatures`.
pub std: f32,
/// Smallest tensor value, as reported by `TensorFeatures`.
pub min: f32,
/// Largest tensor value, as reported by `TensorFeatures`.
pub max: f32,
/// CRC-32 fingerprint. `from_data` computes it over the little-endian bytes of
/// at most the first 256 elements, not over the whole tensor.
pub checksum: u32,
}
/// A difference between a baseline canary and the current state of the same tensor.
///
/// `TensorCanary::detect_regression` returns at most one of these per tensor:
/// the first check that fails, in the order the variants are declared here.
#[derive(Debug, Clone)]
pub enum Regression {
/// The tensor shapes differ.
///
/// `CanaryFile::verify` also emits this variant, with an empty `actual`, when
/// a baseline tensor has no counterpart with the same name in the current set.
ShapeMismatch {
/// Shape recorded in the baseline.
expected: Vec<usize>,
/// Shape of the current tensor (empty when the tensor is missing).
actual: Vec<usize>,
},
/// The mean moved away from the baseline beyond the tolerance applied by
/// `TensorCanary::detect_regression`.
MeanDrift {
/// Baseline mean.
expected: f32,
/// Current mean.
actual: f32,
/// Relative error: `|actual - expected| / max(|expected|, 1e-7)`.
error: f32,
},
/// The standard deviation moved away from the baseline beyond the tolerance
/// applied by `TensorCanary::detect_regression`.
StdDrift {
/// Baseline standard deviation.
expected: f32,
/// Current standard deviation.
actual: f32,
/// Relative error: `|actual - expected| / max(|expected|, 1e-7)`.
error: f32,
},
/// The current minimum or maximum fell outside the baseline range after that
/// range was widened by the margin used in `TensorCanary::detect_regression`.
RangeDrift {
/// Baseline minimum.
expected_min: f32,
/// Baseline maximum.
expected_max: f32,
/// Current minimum.
actual_min: f32,
/// Current maximum.
actual_max: f32,
},
/// The checksums differ even though shape and statistics passed their checks.
ChecksumMismatch { expected: u32, actual: u32 },
}
impl TensorCanary {
    /// Builds a canary for one tensor from its raw `f32` values.
    ///
    /// The summary statistics (`mean`, `std`, `min`, `max`) are taken from
    /// `TensorFeatures::from_data(data)`. The checksum is a CRC-32 over the
    /// little-endian bytes of at most the first 256 elements, so it
    /// fingerprints the head of the tensor rather than its full contents.
    ///
    /// `shape` and `dtype` are stored as given; they are not validated
    /// against `data`.
    #[must_use]
    pub fn from_data(
        name: impl Into<String>,
        shape: Vec<usize>,
        dtype: impl Into<String>,
        data: &[f32],
    ) -> Self {
        // Number of leading elements that feed the checksum.
        const CHECKSUM_SAMPLE_LEN: usize = 256;

        let features = TensorFeatures::from_data(data);
        let bytes: Vec<u8> = data
            .iter()
            .take(CHECKSUM_SAMPLE_LEN)
            .flat_map(|value| value.to_le_bytes())
            .collect();
        let checksum = crc32_simple(&bytes);
        Self {
            name: name.into(),
            shape,
            dtype: dtype.into(),
            mean: features.mean,
            std: features.std,
            min: features.min,
            max: features.max,
            checksum,
        }
    }

    /// Compares `current` against this baseline and returns the first
    /// regression found, or `None` when every check passes.
    ///
    /// Checks run in this order and stop at the first failure:
    ///
    /// 1. shape equality;
    /// 2. mean, relative error above `0.01`;
    /// 3. standard deviation, relative error above `0.05`;
    /// 4. range, `current.min` / `current.max` outside the baseline range
    ///    widened on each side by 10% of its span;
    /// 5. checksum equality.
    ///
    /// Relative error is `|actual - expected| / max(|expected|, 1e-7)`.
    ///
    /// Non-finite statistics are handled explicitly: a value that differs
    /// from the baseline but yields an incomparable (NaN) error, for example
    /// a mean that turned NaN, is reported as drift instead of slipping
    /// through the comparison. In that case the reported `error` may itself
    /// be NaN. A statistic that is NaN in both canaries counts as unchanged.
    ///
    /// `name` and `dtype` are not compared.
    #[must_use]
    pub fn detect_regression(&self, current: &TensorCanary) -> Option<Regression> {
        const MEAN_TOLERANCE: f32 = 0.01;
        const STD_TOLERANCE: f32 = 0.05;
        // Fraction of the baseline span added on each side of the range.
        const RANGE_MARGIN: f32 = 0.1;

        if self.shape != current.shape {
            return Some(Regression::ShapeMismatch {
                expected: self.shape.clone(),
                actual: current.shape.clone(),
            });
        }

        if let Some(error) = Self::relative_drift(self.mean, current.mean, MEAN_TOLERANCE) {
            return Some(Regression::MeanDrift {
                expected: self.mean,
                actual: current.mean,
                error,
            });
        }

        if let Some(error) = Self::relative_drift(self.std, current.std, STD_TOLERANCE) {
            return Some(Regression::StdDrift {
                expected: self.std,
                actual: current.std,
                error,
            });
        }

        let range_tolerance = (self.max - self.min).abs() * RANGE_MARGIN;
        // `<` and `>` are always false for NaN, so a bound that turned NaN
        // has to be detected separately.
        let turned_nan = (current.min.is_nan() && !self.min.is_nan())
            || (current.max.is_nan() && !self.max.is_nan());
        let below = current.min < self.min - range_tolerance;
        let above = current.max > self.max + range_tolerance;
        if turned_nan || below || above {
            return Some(Regression::RangeDrift {
                expected_min: self.min,
                expected_max: self.max,
                actual_min: current.min,
                actual_max: current.max,
            });
        }

        if self.checksum != current.checksum {
            return Some(Regression::ChecksumMismatch {
                expected: self.checksum,
                actual: current.checksum,
            });
        }

        None
    }

    /// Returns `Some(error)` when `actual` deviates from `expected` by more
    /// than `tolerance` (relative), and `None` otherwise.
    fn relative_drift(expected: f32, actual: f32, tolerance: f32) -> Option<f32> {
        // Equal values (including a pair of equal infinities, whose difference
        // would be NaN) and a pair of NaNs both mean "nothing changed".
        let unchanged = expected == actual || (expected.is_nan() && actual.is_nan());
        if unchanged {
            return None;
        }

        // The floor on the denominator keeps a zero baseline from dividing by zero.
        let base = expected.abs().max(1e-7);
        let error = (actual - expected).abs() / base;

        // Phrased as "within tolerance, else drift" so that a NaN error falls
        // into the drift branch; `error > tolerance` would be false for NaN.
        if error <= tolerance {
            None
        } else {
            Some(error)
        }
    }
}
/// A named collection of baseline tensor canaries for one model.
#[derive(Debug, Clone)]
pub struct CanaryFile {
/// Name of the model the canaries belong to.
pub model_name: String,
/// Creation timestamp. `CanaryFile::new` sets it to the output of `chrono_now`,
/// i.e. whole seconds since the Unix epoch as a decimal string.
pub created_at: String,
/// Baseline canaries, in the order they were added.
pub tensors: Vec<TensorCanary>,
}
impl CanaryFile {
    /// Creates an empty canary file for `model_name`, stamped with the
    /// current time as returned by `chrono_now`.
    #[must_use]
    pub fn new(model_name: impl Into<String>) -> Self {
        let model_name = model_name.into();
        let created_at = chrono_now();
        Self {
            model_name,
            created_at,
            tensors: vec![],
        }
    }

    /// Appends `canary` to the baseline set. Names are not checked for duplicates.
    pub fn add_tensor(&mut self, canary: TensorCanary) {
        self.tensors.push(canary);
    }

    /// Checks every baseline tensor against `current_tensors`.
    ///
    /// Each baseline canary is paired with the first current tensor that has
    /// the same name. The result lists `(tensor name, regression)` pairs in
    /// baseline order:
    ///
    /// * a paired tensor contributes whatever `detect_regression` reports, if anything;
    /// * a baseline tensor with no counterpart contributes a
    ///   `Regression::ShapeMismatch` whose `actual` shape is empty.
    ///
    /// Tensors that appear only in `current_tensors` are ignored.
    #[must_use]
    pub fn verify(&self, current_tensors: &[TensorCanary]) -> Vec<(String, Regression)> {
        self.tensors
            .iter()
            .filter_map(|baseline| {
                let counterpart = current_tensors
                    .iter()
                    .find(|candidate| candidate.name == baseline.name);
                let regression = match counterpart {
                    // `?` drops tensors that passed every check.
                    Some(current) => baseline.detect_regression(current)?,
                    None => Regression::ShapeMismatch {
                        expected: baseline.shape.clone(),
                        actual: Vec::new(),
                    },
                };
                Some((baseline.name.clone(), regression))
            })
            .collect()
    }
}
/// Computes the CRC-32 of `data` bit by bit, without a lookup table.
///
/// Uses the reflected polynomial `0xEDB8_8320`, an all-ones initial register
/// and a final bitwise inversion. Empty input yields `0`.
fn crc32_simple(data: &[u8]) -> u32 {
    const REFLECTED_POLY: u32 = 0xEDB8_8320;

    let register = data.iter().fold(u32::MAX, |state, &byte| {
        // Mix the byte into the low bits, then shift it out one bit at a time.
        (0..8).fold(state ^ u32::from(byte), |bits, _| {
            let shifted = bits >> 1;
            if bits & 1 == 1 {
                shifted ^ REFLECTED_POLY
            } else {
                shifted
            }
        })
    });

    !register
}
/// Returns the current time as whole seconds since the Unix epoch, rendered
/// as a decimal string.
///
/// If the system clock reports a time before the epoch, the result is `"0"`.
fn chrono_now() -> String {
    use std::time::{SystemTime, UNIX_EPOCH};

    let seconds = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_secs());
    seconds.to_string()
}
#[cfg(test)]
mod tests;