use crate::error::{RealizarError, Result};
use std::fmt;
/// Largest absolute weight value `validate_weight` accepts; a tensor whose
/// largest finite magnitude exceeds it is reported as an "Extreme magnitude"
/// failure.
const MAX_REASONABLE_WEIGHT: f32 = 1e6;
/// Summary statistics over a flat slice of `f32` tensor data.
///
/// NaN and infinite elements are only counted; every other statistic is
/// computed over the finite elements alone.
#[derive(Debug, Clone)]
pub struct TensorStats {
    /// Total number of elements, including NaN and infinite ones.
    pub len: usize,
    /// Number of finite elements whose absolute value is below `1e-10`.
    pub zero_count: usize,
    /// Number of NaN elements.
    pub nan_count: usize,
    /// Number of `+Inf` / `-Inf` elements.
    pub inf_count: usize,
    /// Smallest finite element, or `0.0` when there is no finite element.
    pub min: f32,
    /// Largest finite element, or `0.0` when there is no finite element.
    pub max: f32,
    /// Arithmetic mean of the finite elements, or `0.0` when there is none.
    pub mean: f32,
    /// Euclidean (L2) norm of the finite elements.
    pub l2_norm: f32,
}

impl TensorStats {
    /// Computes the statistics of `data` in a single pass.
    ///
    /// Sums are accumulated in `f64` and narrowed to `f32` at the end. An
    /// empty slice, or one holding only NaN/Inf values, yields `0.0` for
    /// `min`, `max`, `mean` and `l2_norm`.
    pub fn compute(data: &[f32]) -> Self {
        let mut zero_count = 0usize;
        let mut nan_count = 0usize;
        let mut inf_count = 0usize;
        let mut finite_count = 0usize;
        let mut min = f32::INFINITY;
        let mut max = f32::NEG_INFINITY;
        let mut sum = 0.0f64;
        let mut sum_sq = 0.0f64;

        for &v in data {
            if v.is_nan() {
                nan_count += 1;
            } else if v.is_infinite() {
                inf_count += 1;
            } else {
                finite_count += 1;
                if v.abs() < 1e-10 {
                    zero_count += 1;
                }
                if v < min {
                    min = v;
                }
                if v > max {
                    max = v;
                }
                let wide = f64::from(v);
                sum += wide;
                sum_sq += wide * wide;
            }
        }

        // The mean is taken over the finite elements only: the sum never
        // includes NaN/Inf, so dividing by the total length would bias the
        // result toward zero whenever non-finite values are present.
        let (min, max, mean) = if finite_count == 0 {
            (0.0, 0.0, 0.0)
        } else {
            (min, max, (sum / finite_count as f64) as f32)
        };

        Self {
            len: data.len(),
            zero_count,
            nan_count,
            inf_count,
            min,
            max,
            mean,
            l2_norm: sum_sq.sqrt() as f32,
        }
    }

    /// Percentage (0–100) of elements counted in `zero_count`, relative to
    /// `len`; `0.0` for an empty tensor.
    pub fn zero_pct(&self) -> f32 {
        if self.len == 0 {
            return 0.0;
        }
        100.0 * self.zero_count as f32 / self.len as f32
    }
}
/// Outcome of one of the `validate_*` functions in this file.
#[derive(Debug)]
pub struct ValidationResult {
/// `true` exactly when `failures` is empty.
pub passed: bool,
/// Statistics computed over the validated data.
pub stats: TensorStats,
/// One human-readable message per failed check.
pub failures: Vec<String>,
}
/// Validates an embedding table expected to hold `vocab_size x hidden_dim`
/// elements.
///
/// The following checks are run and every failing one adds a message to
/// [`ValidationResult::failures`]:
///
/// * element count equals `vocab_size * hidden_dim`;
/// * at most 50% of the elements are (near) zero;
/// * no NaN or infinite values;
/// * the overall L2 norm is at least `1e-6`;
/// * the values are not all identical (`max - min >= 1e-10`);
/// * the rows at 10%, 50% and 90% of the vocabulary each have an L2 norm of
///   at least `1e-6` (a row that does not fit inside `data` is skipped).
///
/// All index arithmetic is overflow-checked, so absurd dimensions are
/// reported as a shape failure instead of panicking or wrapping around.
/// When any check fails, the failures are also printed to stderr, tagged with
/// `name`.
pub fn validate_embedding(
    name: &str,
    data: &[f32],
    vocab_size: usize,
    hidden_dim: usize,
) -> ValidationResult {
    let stats = TensorStats::compute(data);
    let mut failures = Vec::new();

    match vocab_size.checked_mul(hidden_dim) {
        Some(expected_len) if expected_len == data.len() => {}
        Some(expected_len) => failures.push(format!(
            "Shape mismatch: got {} elements, expected {} ({}x{})",
            data.len(),
            expected_len,
            vocab_size,
            hidden_dim
        )),
        // The requested element count does not even fit in `usize`, so it
        // cannot match any real slice.
        None => failures.push(format!(
            "Shape mismatch: got {} elements, expected {}x{} which overflows usize",
            data.len(),
            vocab_size,
            hidden_dim
        )),
    }

    let zero_pct = stats.zero_pct();
    if zero_pct > 50.0 {
        failures.push(format!(
            "DENSITY FAILURE: {:.1}% zeros (max 50%). Data likely loaded from wrong offset!",
            zero_pct
        ));
    }
    if stats.nan_count > 0 {
        failures.push(format!("Contains {} NaN values", stats.nan_count));
    }
    if stats.inf_count > 0 {
        failures.push(format!("Contains {} Inf values", stats.inf_count));
    }
    if stats.l2_norm < 1e-6 {
        failures.push("L2 norm ~0: tensor is effectively empty".to_string());
    }
    if (stats.max - stats.min).abs() < 1e-10 {
        failures.push("All values identical: tensor is constant".to_string());
    }

    // Spot-check individual token rows spread across the vocabulary.
    for pct in [10usize, 50, 90] {
        let sampled = vocab_size
            .checked_mul(pct)
            .map(|scaled| scaled / 100)
            .and_then(|token_id| {
                let start = token_id.checked_mul(hidden_dim)?;
                let end = start.checked_add(hidden_dim)?;
                // `get` yields `None` when the row lies outside `data`.
                Some((token_id, data.get(start..end)?))
            });
        if let Some((token_id, row)) = sampled {
            let token_l2: f32 = row.iter().map(|x| x * x).sum::<f32>().sqrt();
            if token_l2 < 1e-6 {
                failures.push(format!(
                    "Token {} ({}% of vocab) has L2=0: embedding data likely corrupted",
                    token_id, pct
                ));
            }
        }
    }

    let passed = failures.is_empty();
    if !passed {
        eprintln!("[VALIDATION FAILED] {}: {:?}", name, failures);
    }
    ValidationResult {
        passed,
        stats,
        failures,
    }
}
/// Validates a weight matrix expected to hold `out_dim x in_dim` elements.
///
/// The following checks are run and every failing one adds a message to
/// [`ValidationResult::failures`]:
///
/// * element count equals `out_dim * in_dim`;
/// * at most 80% of the elements are (near) zero;
/// * no NaN or infinite values;
/// * the overall L2 norm is at least `1e-6`;
/// * the largest finite magnitude does not exceed `MAX_REASONABLE_WEIGHT`.
///
/// The element-count product is overflow-checked, so absurd dimensions are
/// reported as a shape failure instead of panicking or wrapping around.
/// When any check fails, the failures are also printed to stderr, tagged with
/// `name`.
pub fn validate_weight(
    name: &str,
    data: &[f32],
    out_dim: usize,
    in_dim: usize,
) -> ValidationResult {
    let stats = TensorStats::compute(data);
    let mut failures = Vec::new();

    match out_dim.checked_mul(in_dim) {
        Some(expected_len) if expected_len == data.len() => {}
        Some(expected_len) => failures.push(format!(
            "Shape mismatch: got {} elements, expected {} ({}x{})",
            data.len(),
            expected_len,
            out_dim,
            in_dim
        )),
        // The requested element count does not even fit in `usize`, so it
        // cannot match any real slice.
        None => failures.push(format!(
            "Shape mismatch: got {} elements, expected {}x{} which overflows usize",
            data.len(),
            out_dim,
            in_dim
        )),
    }

    let zero_pct = stats.zero_pct();
    if zero_pct > 80.0 {
        failures.push(format!("DENSITY FAILURE: {:.1}% zeros (max 80%)", zero_pct));
    }
    if stats.nan_count > 0 {
        failures.push(format!("Contains {} NaN values", stats.nan_count));
    }
    if stats.inf_count > 0 {
        failures.push(format!("Contains {} Inf values", stats.inf_count));
    }
    if stats.l2_norm < 1e-6 {
        failures.push("L2 norm ~0".to_string());
    }

    // `min`/`max` cover finite values only, so this bounds the largest
    // finite magnitude; infinities are reported by the check above.
    let max_abs = stats.max.abs().max(stats.min.abs());
    if max_abs > MAX_REASONABLE_WEIGHT {
        failures.push(format!(
            "Extreme magnitude: max|w|={max_abs:.3e} exceeds {MAX_REASONABLE_WEIGHT:.0e} \
             (corruption — overflows inference)"
        ));
    }

    let passed = failures.is_empty();
    if !passed {
        eprintln!("[VALIDATION FAILED] {}: {:?}", name, failures);
    }
    ValidationResult {
        passed,
        stats,
        failures,
    }
}
/// Validates a 1-D tensor: its length must equal `expected_len` and it must
/// contain no NaN or infinite values.
///
/// `_name` is accepted for signature parity with the other validators but is
/// not used; nothing is printed on failure.
pub fn validate_vector(_name: &str, data: &[f32], expected_len: usize) -> ValidationResult {
    let stats = TensorStats::compute(data);
    let actual_len = data.len();

    // Each check contributes zero or one message, in a fixed order.
    let mut failures: Vec<String> = Vec::new();
    failures.extend((actual_len != expected_len).then(|| {
        format!(
            "Length mismatch: got {}, expected {}",
            actual_len, expected_len
        )
    }));
    failures.extend(
        (stats.nan_count > 0).then(|| format!("Contains {} NaN values", stats.nan_count)),
    );
    failures.extend(
        (stats.inf_count > 0).then(|| format!("Contains {} Inf values", stats.inf_count)),
    );

    ValidationResult {
        passed: failures.is_empty(),
        stats,
        failures,
    }
}
/// Runs `validate_embedding` and converts a failed result into an error.
///
/// # Errors
///
/// Returns `RealizarError::FormatError` whose `reason` names the tensor and
/// lists every failure message joined by `"; "`.
pub fn enforce_embedding_validation(
    name: &str,
    data: &[f32],
    vocab_size: usize,
    hidden_dim: usize,
) -> Result<()> {
    let outcome = validate_embedding(name, data, vocab_size, hidden_dim);
    if outcome.passed {
        return Ok(());
    }

    let details = outcome.failures.join("; ");
    Err(RealizarError::FormatError {
        reason: format!("Tensor '{}' failed validation: {}", name, details),
    })
}
/// Runs `validate_weight` and converts a failed result into an error.
///
/// # Errors
///
/// Returns `RealizarError::FormatError` whose `reason` names the tensor and
/// lists every failure message joined by `"; "`.
pub fn enforce_weight_validation(
    name: &str,
    data: &[f32],
    out_dim: usize,
    in_dim: usize,
) -> Result<()> {
    let outcome = validate_weight(name, data, out_dim, in_dim);
    if outcome.passed {
        return Ok(());
    }

    let details = outcome.failures.join("; ");
    Err(RealizarError::FormatError {
        reason: format!("Tensor '{}' failed validation: {}", name, details),
    })
}
/// Extents of a rank-2 tensor shape.
#[derive(Debug, Clone, Copy)]
struct Shape2D {
/// First dimension of the shape.
rows: usize,
/// Second dimension of the shape.
cols: usize,
}
/// Interprets `shape` as `[rows, cols]`.
///
/// Returns `None` unless the shape has exactly two dimensions.
fn as_2d(shape: &[usize]) -> Option<Shape2D> {
    match shape {
        [rows, cols] => Some(Shape2D {
            rows: *rows,
            cols: *cols,
        }),
        _ => None,
    }
}
/// Finds the first rank-2 tensor whose name designates one of `needles`.
///
/// A name designates a needle when it equals the needle, or ends with it and
/// the needle starts on a name-component boundary (the preceding character is
/// neither alphanumeric nor `_`). So `model.lm_head.weight` matches
/// `lm_head.weight`, but `blk.0.attn_output.weight` does not match
/// `output.weight`: a bare suffix test would mistake that per-layer tensor
/// for the output head.
///
/// Tensors are examined in iteration order; a matching name whose shape is
/// not rank 2 is skipped and the search continues.
fn find_shape<'a, I>(tensors: I, needles: &[&str]) -> Option<(&'a str, Shape2D)>
where
    I: IntoIterator<Item = (&'a str, &'a [usize])>,
{
    /// True when `name` ends with `needle` on a name-component boundary.
    fn designates(name: &str, needle: &str) -> bool {
        match name.strip_suffix(needle) {
            // An empty prefix (exact match) has no trailing identifier char.
            Some(prefix) => !prefix.ends_with(|c: char| c.is_alphanumeric() || c == '_'),
            None => false,
        }
    }

    tensors.into_iter().find_map(|(name, shape)| {
        if needles.iter().any(|needle| designates(name, needle)) {
            as_2d(shape).map(|dims| (name, dims))
        } else {
            None
        }
    })
}
/// Cross-checks the shapes of the embedding table, the output head and the
/// first-found attention input projection against each other.
///
/// Tensors are located by name suffix via `find_shape`. Every comparison is
/// relative to the embedding table, so when no embedding tensor is found
/// nothing is checked; a missing output head or attention projection skips
/// only its own comparison.
///
/// # Errors
///
/// Returns `RealizarError::FormatError` when
/// * the embedding and the output head have a different number of rows, or
/// * the embedding and the attention projection have a different number of
///   columns.
///
/// The row comparison is evaluated first.
pub fn validate_cross_tensor_structure<'a, I>(tensors: I) -> Result<()>
where
    I: IntoIterator<Item = (&'a str, &'a [usize])> + Clone,
{
    const EMBEDDING_NAMES: &[&str] = &["embed_tokens.weight", "tok_embeddings.weight"];
    const OUTPUT_HEAD_NAMES: &[&str] = &["lm_head.weight", "output.weight"];
    const ATTENTION_INPUT_NAMES: &[&str] = &[
        "self_attn.q_proj.weight",
        "attention.wq.weight",
        "self_attn.qkv_proj.weight",
        "attn.c_attn.weight",
    ];

    let embed = find_shape(tensors.clone(), EMBEDDING_NAMES);
    let lm_head = find_shape(tensors.clone(), OUTPUT_HEAD_NAMES);
    let q_proj = find_shape(tensors.clone(), ATTENTION_INPUT_NAMES);

    // Both comparisons need the embedding table as their reference.
    let (emb_name, emb) = match embed {
        Some(found) => found,
        None => return Ok(()),
    };

    if let Some((lm_name, lm)) = lm_head {
        if emb.rows != lm.rows {
            return Err(RealizarError::FormatError {
                reason: format!(
                    "[F-STRUCT-001] Vocab-size mismatch: '{emb_name}' has {} rows (vocab) but \
                     '{lm_name}' has {} rows. The embedding table and the output head MUST index \
                     the same vocabulary; this model would emit out-of-range token ids and produce \
                     garbage. (safetensors/Transformers/Ollama load this silently.)",
                    emb.rows, lm.rows
                ),
            });
        }
    }

    if let Some((q_name, q)) = q_proj {
        if emb.cols != q.cols {
            return Err(RealizarError::FormatError {
                reason: format!(
                    "[F-STRUCT-001] Hidden-dim mismatch: '{emb_name}' has hidden_dim {} but \
                     attention input '{q_name}' expects hidden_dim {}. The first attention matmul \
                     would be dimension-mismatched (OOB / garbage). \
                     (safetensors/Transformers/Ollama load this silently.)",
                    emb.cols, q.cols
                ),
            });
        }
    }

    Ok(())
}
/// Error describing a single contract-rule violation on a named tensor.
///
/// Its `Display` output has the form `[rule_id] Tensor 'tensor_name': message`.
#[derive(Debug, Clone)]
pub struct ContractValidationError {
/// Name of the tensor that violated the rule.
pub tensor_name: String,
/// Identifier of the violated rule.
pub rule_id: String,
/// Human-readable description of the violation.
pub message: String,
}
/// Renders the error as `[rule_id] Tensor 'tensor_name': message`.
impl fmt::Display for ContractValidationError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            tensor_name,
            rule_id,
            message,
        } = self;
        write!(f, "[{}] Tensor '{}': {}", rule_id, tensor_name, message)
    }
}
// Marks the type as a standard error; the empty body relies on the trait's
// default methods.
impl std::error::Error for ContractValidationError {}
impl From<ContractValidationError> for RealizarError {
fn from(e: ContractValidationError) -> Self {
RealizarError::FormatError {
reason: e.to_string(),
}
}
}
/// Owned embedding data bundled with its dimensions and statistics.
///
/// All fields are private and nothing in this file constructs the type.
/// NOTE(review): presumably the constructor and accessors live in one of the
/// `include!`d files and only build this after validation succeeds — confirm.
#[derive(Debug, Clone)]
pub struct ValidatedEmbedding {
// Flat embedding values. NOTE(review): presumably `vocab_size * hidden_dim`
// elements — confirm against the constructor.
data: Vec<f32>,
vocab_size: usize,
hidden_dim: usize,
// NOTE(review): presumably the statistics computed over `data` — confirm.
stats: TensorStats,
}
// The rest of this module is defined in the files below; `include!` splices
// their contents in at this point. Their items are not visible in this file.
include!("validation_embedding.rs");
include!("inner.rs");
include!("falsification_tests.rs");