use crate::error::{RealizarError, Result};
use std::fmt;
/// Summary statistics for a flat `f32` tensor, computed in a single pass by
/// [`TensorStats::compute`].
///
/// Non-finite values (NaN, ±Inf) are only counted; every other statistic is
/// computed over the finite elements alone.
#[derive(Debug, Clone)]
pub struct TensorStats {
    /// Total number of elements, including non-finite ones.
    pub len: usize,
    /// Number of finite elements whose magnitude is below `1e-10`.
    pub zero_count: usize,
    /// Number of NaN elements.
    pub nan_count: usize,
    /// Number of +Inf / -Inf elements.
    pub inf_count: usize,
    /// Smallest finite element, or 0.0 when there are no finite elements.
    pub min: f32,
    /// Largest finite element, or 0.0 when there are no finite elements.
    pub max: f32,
    /// Arithmetic mean of the finite elements, or 0.0 when there are none.
    pub mean: f32,
    /// Euclidean (L2) norm of the finite elements.
    pub l2_norm: f32,
}

impl TensorStats {
    /// Computes statistics over `data`.
    ///
    /// NaN and ±Inf values only increment `nan_count` / `inf_count`. `zero_count`,
    /// `min`, `max`, `mean` and `l2_norm` are computed over the finite values only.
    /// `len` is always `data.len()`. If there are no finite values (this includes
    /// empty input), `min`, `max`, `mean` and `l2_norm` are all 0.0.
    ///
    /// Sums are accumulated in `f64` to limit rounding error on large tensors.
    pub fn compute(data: &[f32]) -> Self {
        let mut zero_count = 0;
        let mut nan_count = 0;
        let mut inf_count = 0;
        let mut finite_count = 0usize;
        let mut min = f32::INFINITY;
        let mut max = f32::NEG_INFINITY;
        let mut sum = 0.0f64;
        let mut sum_sq = 0.0f64;

        for &v in data {
            if v.is_nan() {
                nan_count += 1;
            } else if v.is_infinite() {
                inf_count += 1;
            } else {
                finite_count += 1;
                // Values with magnitude below 1e-10 are treated as zero.
                if v.abs() < 1e-10 {
                    zero_count += 1;
                }
                if v < min {
                    min = v;
                }
                if v > max {
                    max = v;
                }
                let x = f64::from(v);
                sum += x;
                sum_sq += x * x;
            }
        }

        // No finite values (includes empty input): report neutral zeros instead of
        // the ±INFINITY sentinels or a 0/0 division.
        if finite_count == 0 {
            return Self {
                len: data.len(),
                zero_count,
                nan_count,
                inf_count,
                min: 0.0,
                max: 0.0,
                mean: 0.0,
                l2_norm: 0.0,
            };
        }

        Self {
            len: data.len(),
            zero_count,
            nan_count,
            inf_count,
            min,
            max,
            // Divide by the number of finite values, not `len`: NaN/Inf were excluded
            // from `sum`, so dividing by `len` would bias the mean toward zero.
            mean: (sum / finite_count as f64) as f32,
            l2_norm: sum_sq.sqrt() as f32,
        }
    }

    /// Percentage (0–100) of *all* elements, non-finite ones included, that count as
    /// zero. Returns 0.0 for an empty tensor.
    pub fn zero_pct(&self) -> f32 {
        if self.len == 0 {
            return 0.0;
        }
        100.0 * self.zero_count as f32 / self.len as f32
    }
}
/// Outcome of running one of the `validate_*` checks in this module on a tensor.
#[derive(Debug)]
pub struct ValidationResult {
    /// Overall verdict; the validators in this module set it to `failures.is_empty()`.
    pub passed: bool,
    /// Statistics computed over the validated data.
    pub stats: TensorStats,
    /// One human-readable message per failed check, in the order the checks ran.
    pub failures: Vec<String>,
}
/// Validates an embedding table of `vocab_size` rows by `hidden_dim` columns.
///
/// Every failing check appends a message to `failures`:
/// - the element count must equal `vocab_size * hidden_dim` (a product that overflows
///   `usize` is reported as a shape mismatch rather than panicking or wrapping);
/// - at most 50% of elements may be (near-)zero;
/// - no NaN or Inf values;
/// - the overall L2 norm must not be ~0;
/// - the tensor must not be constant (`max - min` below 1e-10);
/// - the rows 10%, 50% and 90% of the way through the vocabulary must each have a
///   non-zero L2 norm (rows that do not fit inside `data` are skipped).
///
/// When any check fails, the messages are also printed to stderr, prefixed with `name`.
pub fn validate_embedding(
    name: &str,
    data: &[f32],
    vocab_size: usize,
    hidden_dim: usize,
) -> ValidationResult {
    let stats = TensorStats::compute(data);
    let mut failures = Vec::new();

    // `checked_mul` so absurd dimensions (e.g. from corrupt metadata) are reported as a
    // mismatch instead of panicking (debug) or wrapping to a bogus length (release).
    match vocab_size.checked_mul(hidden_dim) {
        Some(expected_len) if expected_len == data.len() => {}
        Some(expected_len) => failures.push(format!(
            "Shape mismatch: got {} elements, expected {} ({}x{})",
            data.len(),
            expected_len,
            vocab_size,
            hidden_dim
        )),
        None => failures.push(format!(
            "Shape mismatch: got {} elements, expected {}x{} (element count overflows usize)",
            data.len(),
            vocab_size,
            hidden_dim
        )),
    }

    let zero_pct = stats.zero_pct();
    if zero_pct > 50.0 {
        failures.push(format!(
            "DENSITY FAILURE: {:.1}% zeros (max 50%). Data likely loaded from wrong offset!",
            zero_pct
        ));
    }
    if stats.nan_count > 0 {
        failures.push(format!("Contains {} NaN values", stats.nan_count));
    }
    if stats.inf_count > 0 {
        failures.push(format!("Contains {} Inf values", stats.inf_count));
    }
    if stats.l2_norm < 1e-6 {
        failures.push("L2 norm ~0: tensor is effectively empty".to_string());
    }
    if (stats.max - stats.min).abs() < 1e-10 {
        failures.push("All values identical: tensor is constant".to_string());
    }

    // Spot-check individual rows so a zeroed-out row is caught even when the global
    // statistics above look healthy. All index math is checked; rows that cannot be
    // addressed or do not fit in `data` (e.g. after a shape mismatch) are skipped.
    for pct in [10usize, 50, 90] {
        let row = vocab_size
            .checked_mul(pct)
            .map(|scaled| scaled / 100)
            .and_then(|token_id| {
                let start = token_id.checked_mul(hidden_dim)?;
                let end = start.checked_add(hidden_dim)?;
                data.get(start..end).map(|row| (token_id, row))
            });
        if let Some((token_id, row)) = row {
            let token_l2: f32 = row.iter().map(|x| x * x).sum::<f32>().sqrt();
            if token_l2 < 1e-6 {
                failures.push(format!(
                    "Token {} ({}% of vocab) has L2=0: embedding data likely corrupted",
                    token_id, pct
                ));
            }
        }
    }

    let passed = failures.is_empty();
    if !passed {
        eprintln!("[VALIDATION FAILED] {}: {:?}", name, failures);
    }
    ValidationResult {
        passed,
        stats,
        failures,
    }
}
/// Validates a dense weight matrix with `out_dim` rows by `in_dim` columns.
///
/// Every failing check appends a message to `failures`:
/// - the element count must equal `out_dim * in_dim` (a product that overflows `usize`
///   is reported as a shape mismatch rather than panicking or wrapping);
/// - at most 80% of elements may be (near-)zero;
/// - no NaN or Inf values;
/// - the overall L2 norm must not be ~0.
///
/// When any check fails, the messages are also printed to stderr, prefixed with `name`.
pub fn validate_weight(
    name: &str,
    data: &[f32],
    out_dim: usize,
    in_dim: usize,
) -> ValidationResult {
    let stats = TensorStats::compute(data);
    let mut failures = Vec::new();

    // `checked_mul` so absurd dimensions (e.g. from corrupt metadata) are reported as a
    // mismatch instead of panicking (debug) or wrapping to a bogus length (release).
    match out_dim.checked_mul(in_dim) {
        Some(expected_len) if expected_len == data.len() => {}
        Some(expected_len) => failures.push(format!(
            "Shape mismatch: got {} elements, expected {} ({}x{})",
            data.len(),
            expected_len,
            out_dim,
            in_dim
        )),
        None => failures.push(format!(
            "Shape mismatch: got {} elements, expected {}x{} (element count overflows usize)",
            data.len(),
            out_dim,
            in_dim
        )),
    }

    let zero_pct = stats.zero_pct();
    if zero_pct > 80.0 {
        failures.push(format!("DENSITY FAILURE: {:.1}% zeros (max 80%)", zero_pct));
    }
    if stats.nan_count > 0 {
        failures.push(format!("Contains {} NaN values", stats.nan_count));
    }
    if stats.inf_count > 0 {
        failures.push(format!("Contains {} Inf values", stats.inf_count));
    }
    if stats.l2_norm < 1e-6 {
        failures.push("L2 norm ~0".to_string());
    }

    let passed = failures.is_empty();
    if !passed {
        eprintln!("[VALIDATION FAILED] {}: {:?}", name, failures);
    }
    ValidationResult {
        passed,
        stats,
        failures,
    }
}
/// Validates a 1-D vector (e.g. a bias or norm weight) of `expected_len` elements.
///
/// Checks, in order: length equals `expected_len`, no NaN values, no Inf values.
/// Unlike the matrix validators, this one performs no density/norm checks and does
/// not print anything; `_name` is accepted only for signature parity and is unused.
pub fn validate_vector(_name: &str, data: &[f32], expected_len: usize) -> ValidationResult {
    let stats = TensorStats::compute(data);

    let length_issue = (data.len() != expected_len).then(|| {
        format!(
            "Length mismatch: got {}, expected {}",
            data.len(),
            expected_len
        )
    });
    let nan_issue =
        (stats.nan_count > 0).then(|| format!("Contains {} NaN values", stats.nan_count));
    let inf_issue =
        (stats.inf_count > 0).then(|| format!("Contains {} Inf values", stats.inf_count));

    // Keep only the checks that produced a message, preserving check order.
    let failures: Vec<String> = vec![length_issue, nan_issue, inf_issue]
        .into_iter()
        .flatten()
        .collect();

    ValidationResult {
        passed: failures.is_empty(),
        stats,
        failures,
    }
}
/// Runs [`validate_embedding`] and turns a failed validation into an error.
///
/// # Errors
/// Returns `RealizarError::FormatError` naming the tensor and listing every failed
/// check, joined with `"; "`, when the embedding does not pass.
pub fn enforce_embedding_validation(
    name: &str,
    data: &[f32],
    vocab_size: usize,
    hidden_dim: usize,
) -> Result<()> {
    let report = validate_embedding(name, data, vocab_size, hidden_dim);
    if report.passed {
        return Ok(());
    }
    let details = report.failures.join("; ");
    Err(RealizarError::FormatError {
        reason: format!("Tensor '{}' failed validation: {}", name, details),
    })
}
pub fn enforce_weight_validation(
name: &str,
data: &[f32],
out_dim: usize,
in_dim: usize,
) -> Result<()> {
let result = validate_weight(name, data, out_dim, in_dim);
if !result.passed {
return Err(RealizarError::FormatError {
reason: format!(
"Tensor '{}' failed validation: {}",
name,
result.failures.join("; ")
),
});
}
Ok(())
}
/// A violated tensor contract rule: which tensor, which rule, and why.
#[derive(Debug, Clone)]
pub struct ContractValidationError {
    /// Name of the offending tensor.
    pub tensor_name: String,
    /// Identifier of the violated rule.
    pub rule_id: String,
    /// Human-readable description of the violation.
    pub message: String,
}

impl fmt::Display for ContractValidationError {
    /// Renders as `[<rule_id>] Tensor '<tensor_name>': <message>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            tensor_name,
            rule_id,
            message,
        } = self;
        write!(
            f,
            "[{rule}] Tensor '{tensor}': {msg}",
            rule = rule_id,
            tensor = tensor_name,
            msg = message
        )
    }
}

// Leaf error: it wraps no underlying cause, so the default `source()` (None) is used.
impl std::error::Error for ContractValidationError {}
impl From<ContractValidationError> for RealizarError {
fn from(e: ContractValidationError) -> Self {
RealizarError::FormatError {
reason: e.to_string(),
}
}
}
/// An owned embedding buffer together with its dimensions and a `TensorStats` value.
///
/// All fields are private, so values can only be constructed inside this module
/// (which also covers code pulled in via `include!`).
/// NOTE(review): the constructor is not visible here; presumably it validates `data`
/// and fills `stats` from it — confirm against the implementation.
#[derive(Debug, Clone)]
pub struct ValidatedEmbedding {
    // presumably `vocab_size * hidden_dim` elements, one row per token — TODO confirm
    data: Vec<f32>,
    vocab_size: usize,
    hidden_dim: usize,
    // presumably `TensorStats::compute(&data)` — TODO confirm
    stats: TensorStats,
}
// The rest of this module is spliced in textually from sibling files. `include!`d
// items share this module's scope and `use` imports.
// NOTE(review): going by the file names, validation_embedding.rs presumably holds the
// `ValidatedEmbedding` impl and falsification_tests.rs the tests — confirm.
include!("validation_embedding.rs");
include!("inner.rs");
include!("falsification_tests.rs");