use std::fmt;
use std::marker::PhantomData;
/// Zero-sized marker type denoting row-major memory layout; used as the
/// default layout parameter of `ValidatedWeight`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct RowMajor;
/// Error produced when a tensor fails one of the validation contract rules.
#[derive(Debug, Clone)]
pub struct ContractValidationError {
/// Name of the offending tensor (e.g. "embedding").
pub tensor_name: String,
/// Identifier of the violated rule (e.g. "F-DATA-QUALITY-001").
pub rule_id: String,
/// Human-readable description of the failure.
pub message: String,
}
impl fmt::Display for ContractValidationError {
    /// Renders the error as `[RULE] Tensor 'name': message`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            tensor_name,
            rule_id,
            message,
        } = self;
        write!(f, "[{rule_id}] Tensor '{tensor_name}': {message}")
    }
}
// Marker impl so the error can be boxed/propagated as `dyn std::error::Error`;
// the existing `Display` and `Debug` impls satisfy the trait's requirements.
impl std::error::Error for ContractValidationError {}
/// Single-pass summary statistics over a slice of `f32` values.
#[derive(Debug, Clone)]
pub struct TensorStats {
/// Total element count, including NaN/Inf entries.
pub len: usize,
/// Count of finite values with |x| < 1e-10.
pub zero_count: usize,
/// Count of NaN values.
pub nan_count: usize,
/// Count of infinite (non-NaN) values.
pub inf_count: usize,
/// Minimum finite value; 0.0 when no finite value exists.
pub min: f32,
/// Maximum finite value; 0.0 when no finite value exists.
pub max: f32,
/// Sum of the finite values divided by the TOTAL length
/// (non-finite entries count toward the denominator, not the numerator).
pub mean: f32,
/// Square root of the sum of squares of the finite values.
pub l2_norm: f32,
}
impl TensorStats {
    /// Computes summary statistics over `data` in a single pass.
    ///
    /// NaN and Inf values are tallied separately and excluded from the
    /// min/max/sum accumulation. The mean divides the finite-value sum by
    /// the TOTAL element count (non-finite entries included), and min/max
    /// fall back to 0.0 when no finite value was seen.
    #[must_use]
    pub fn compute(data: &[f32]) -> Self {
        let len = data.len();
        let mut stats = Self {
            len,
            zero_count: 0,
            nan_count: 0,
            inf_count: 0,
            min: f32::INFINITY,
            max: f32::NEG_INFINITY,
            mean: 0.0,
            l2_norm: 0.0,
        };
        if len == 0 {
            stats.min = 0.0;
            stats.max = 0.0;
            return stats;
        }
        // Accumulate in f64 to limit rounding error over long tensors.
        let mut sum = 0.0f64;
        let mut sum_sq = 0.0f64;
        for &value in data {
            if value.is_nan() {
                stats.nan_count += 1;
            } else if value.is_infinite() {
                stats.inf_count += 1;
            } else {
                if value.abs() < 1e-10 {
                    stats.zero_count += 1;
                }
                // NaN was filtered above, so f32::min/max behave like the
                // plain comparisons they replace.
                stats.min = stats.min.min(value);
                stats.max = stats.max.max(value);
                let v = f64::from(value);
                sum += v;
                sum_sq += v * v;
            }
        }
        if stats.min == f32::INFINITY {
            stats.min = 0.0;
        }
        if stats.max == f32::NEG_INFINITY {
            stats.max = 0.0;
        }
        stats.mean = (sum / len as f64) as f32;
        stats.l2_norm = sum_sq.sqrt() as f32;
        stats
    }

    /// Percentage of near-zero elements; 0.0 for an empty tensor.
    #[must_use]
    pub fn zero_pct(&self) -> f32 {
        match self.len {
            0 => 0.0,
            n => 100.0 * self.zero_count as f32 / n as f32,
        }
    }
}
/// An embedding table whose shape and data quality have passed validation.
///
/// Fields are private: a value of this type can only be obtained through
/// `ValidatedEmbedding::new`, which enforces the validation contract.
#[derive(Debug, Clone)]
pub struct ValidatedEmbedding {
// Flat tensor data, vocab_size * hidden_dim elements.
data: Vec<f32>,
// Number of token rows.
vocab_size: usize,
// Width of each token row.
hidden_dim: usize,
// Statistics computed once during validation.
stats: TensorStats,
}
impl ValidatedEmbedding {
    /// Maximum tolerated percentage of near-zero elements (F-DATA-QUALITY-001).
    const MAX_ZERO_PCT: f32 = 50.0;
    /// Minimum whole-tensor L2 norm; below this the tensor is "empty" (F-DATA-QUALITY-003).
    const MIN_L2_NORM: f32 = 1e-6;
    /// Minimum per-row L2 norm; rows below this count as dead.
    const MIN_TOKEN_L2: f32 = 1e-6;
    /// Vocab percentiles whose token rows are individually spot-checked.
    const SPOT_CHECK_PCTS: [usize; 3] = [10, 50, 90];
    /// Maximum tolerated percentage of dead rows (F-DATA-QUALITY-005).
    const MAX_DEAD_ROW_PCT: f32 = 25.0;

    /// Validates `data` as a row-major `vocab_size x hidden_dim` embedding
    /// table and wraps it on success.
    ///
    /// # Errors
    ///
    /// Returns a `ContractValidationError` when:
    /// - the element count does not match `vocab_size * hidden_dim`
    ///   (or that product overflows `usize`),
    /// - whole-tensor statistics look corrupt (too many zeros, NaN/Inf,
    ///   ~zero norm, or all-constant values),
    /// - a spot-checked token row has ~zero norm, or
    /// - too many token rows are dead.
    pub fn new(
        data: Vec<f32>,
        vocab_size: usize,
        hidden_dim: usize,
    ) -> Result<Self, ContractValidationError> {
        // checked_mul: in release builds a wrapped product could make a
        // garbage shape appear to "match" the data length and let corrupt
        // input through, so overflow is a layout-contract violation.
        let expected_len = vocab_size.checked_mul(hidden_dim).ok_or_else(|| {
            Self::err(
                "F-LAYOUT-CONTRACT-001",
                format!("Shape overflow: {vocab_size}x{hidden_dim} exceeds usize range"),
            )
        })?;
        if data.len() != expected_len {
            return Err(Self::err(
                "F-LAYOUT-CONTRACT-001",
                format!(
                    "Shape mismatch: got {} elements, expected {} ({}x{})",
                    data.len(),
                    expected_len,
                    vocab_size,
                    hidden_dim
                ),
            ));
        }
        let stats = TensorStats::compute(&data);
        Self::validate_stats(&stats)?;
        Self::validate_spot_checks(&data, vocab_size, hidden_dim)?;
        Self::validate_dead_rows(&data, vocab_size, hidden_dim)?;
        Ok(Self {
            data,
            vocab_size,
            hidden_dim,
            stats,
        })
    }

    /// Builds a `ContractValidationError` for this tensor with `rule_id`.
    fn err(rule_id: &str, message: String) -> ContractValidationError {
        ContractValidationError {
            tensor_name: "embedding".to_string(),
            rule_id: rule_id.to_string(),
            message,
        }
    }

    /// L2 norm of the row belonging to `token_id`, or `None` when the row
    /// would run past the end of `data` (same gate the per-row checks used).
    fn row_l2(data: &[f32], token_id: usize, hidden_dim: usize) -> Option<f32> {
        let start = token_id * hidden_dim;
        let row = data.get(start..start + hidden_dim)?;
        Some(row.iter().map(|x| x * x).sum::<f32>().sqrt())
    }

    /// Checks whole-tensor statistics against the data-quality contract.
    fn validate_stats(stats: &TensorStats) -> Result<(), ContractValidationError> {
        if stats.zero_pct() > Self::MAX_ZERO_PCT {
            return Err(Self::err(
                "F-DATA-QUALITY-001",
                format!(
                    "DENSITY FAILURE: {:.1}% zeros (max {}%). Data likely loaded from wrong offset!",
                    stats.zero_pct(),
                    Self::MAX_ZERO_PCT
                ),
            ));
        }
        if stats.nan_count > 0 {
            return Err(Self::err(
                "F-DATA-QUALITY-002",
                format!("Contains {} NaN values", stats.nan_count),
            ));
        }
        if stats.inf_count > 0 {
            return Err(Self::err(
                "F-DATA-QUALITY-002",
                format!("Contains {} Inf values", stats.inf_count),
            ));
        }
        if stats.l2_norm < Self::MIN_L2_NORM {
            return Err(Self::err(
                "F-DATA-QUALITY-003",
                "L2 norm ~0: tensor is effectively empty".to_string(),
            ));
        }
        // A constant tensor carries no information; min == max detects it.
        if (stats.max - stats.min).abs() < 1e-10 {
            return Err(Self::err(
                "F-DATA-QUALITY-003",
                "All values identical: tensor is constant".to_string(),
            ));
        }
        Ok(())
    }

    /// Spot-checks token rows at fixed vocab percentiles; a ~zero row there
    /// is treated as a sign of corrupted or mis-offset embedding data.
    fn validate_spot_checks(
        data: &[f32],
        vocab_size: usize,
        hidden_dim: usize,
    ) -> Result<(), ContractValidationError> {
        for pct in Self::SPOT_CHECK_PCTS {
            let token_id = vocab_size * pct / 100;
            if let Some(token_l2) = Self::row_l2(data, token_id, hidden_dim) {
                if token_l2 < Self::MIN_TOKEN_L2 {
                    return Err(Self::err(
                        "F-DATA-QUALITY-004",
                        format!(
                            "Token {} ({}% of vocab) has L2={:.2e}: embedding data likely corrupted or offset",
                            token_id, pct, token_l2
                        ),
                    ));
                }
            }
        }
        Ok(())
    }

    /// Counts dead (~zero-norm) rows and fails when their share of the
    /// vocabulary exceeds `MAX_DEAD_ROW_PCT`.
    fn validate_dead_rows(
        data: &[f32],
        vocab_size: usize,
        hidden_dim: usize,
    ) -> Result<(), ContractValidationError> {
        let dead_count = (0..vocab_size)
            .filter(|&token_id| {
                Self::row_l2(data, token_id, hidden_dim)
                    .map_or(false, |l2| l2 < Self::MIN_TOKEN_L2)
            })
            .count();
        // vocab_size == 0 implies empty data, which validate_stats already
        // rejects (zero L2 norm), so this division cannot be 0/0 here.
        let dead_pct = 100.0 * dead_count as f32 / vocab_size as f32;
        if dead_pct > Self::MAX_DEAD_ROW_PCT {
            return Err(Self::err(
                "F-DATA-QUALITY-005",
                format!(
                    "DEAD ROW FAILURE: {dead_count}/{vocab_size} rows ({dead_pct:.1}%) are dead (max {}%)",
                    Self::MAX_DEAD_ROW_PCT
                ),
            ));
        }
        Ok(())
    }

    /// Borrows the validated, flat tensor data.
    #[must_use]
    pub fn data(&self) -> &[f32] {
        &self.data
    }

    /// Consumes the wrapper and returns the raw tensor data.
    #[must_use]
    pub fn into_inner(self) -> Vec<f32> {
        self.data
    }

    /// Number of token rows.
    #[must_use]
    pub fn vocab_size(&self) -> usize {
        self.vocab_size
    }

    /// Width of each token row.
    #[must_use]
    pub fn hidden_dim(&self) -> usize {
        self.hidden_dim
    }

    /// Statistics computed during validation.
    #[must_use]
    pub fn stats(&self) -> &TensorStats {
        &self.stats
    }
}
/// A validated weight matrix of shape `out_dim` x `in_dim`, tagged with a
/// compile-time layout marker `L` (defaults to `RowMajor`).
///
/// NOTE(review): the impl for this type lives in an `include!`d file and is
/// not visible here; field semantics below are inferred from names — confirm
/// against `validated_vector.rs`.
#[derive(Debug, Clone)]
pub struct ValidatedWeight<L = RowMajor> {
// Flat tensor data; presumably out_dim * in_dim elements.
data: Vec<f32>,
// Number of output rows.
out_dim: usize,
// Number of input columns.
in_dim: usize,
// Tensor name, presumably used for error reporting like "embedding" above.
name: String,
// Statistics computed during validation.
stats: TensorStats,
// Zero-sized marker carrying the layout type parameter without storing an L.
_layout: PhantomData<L>,
}
include!("validated_vector.rs");
include!("validated_tensors_tests.rs");