use crate::commands::nf4_classifier::{
expected_nf4_storage_bytes, nearest_codebook_index, nf4_dequantize_block, nf4_quantize_block,
rel_l2_error, NF4_CODEBOOK, NF4_DEFAULT_BLOCK_SIZE, NF4_MAX_REL_L2_ERROR_SYNTHETIC,
};
use serde_json::Value;
use std::fs;
use std::path::Path;
/// Arguments accepted by [`run`].
#[derive(Debug, Clone)]
pub struct Nf4LintArgs {
/// Path of the observation file that `run` reads and parses as JSON.
pub observation_file: String,
/// When `true`, `run` prints the gate reports as one pretty-printed JSON
/// document; otherwise it prints one `[PASS]`/`[FAIL]` line per gate.
pub json: bool,
}
/// Result of evaluating a single gate, as printed or serialized by `run`.
#[derive(Debug, Clone, serde::Serialize)]
struct GateReport {
/// Short gate name (`"codebook"`, `"roundtrip"`, `"storage"` or `"parity"`).
gate: &'static str,
/// Falsification identifier attached to the gate, e.g. `"FALSIFY-CRUX-B-10-001"`.
falsify_id: &'static str,
/// Human-readable description of what the gate observed.
outcome: String,
/// Whether the gate passed.
passed: bool,
}
pub fn run(args: Nf4LintArgs) -> Result<(), String> {
let path = Path::new(&args.observation_file);
if !path.exists() {
return Err(format!(
"FALSIFY-CRUX-B-10: observation file not found: {}",
args.observation_file
));
}
let raw = fs::read_to_string(path)
.map_err(|e| format!("FALSIFY-CRUX-B-10: failed to read observation: {e}"))?;
if raw.trim().is_empty() {
return Err("FALSIFY-CRUX-B-10: observation file is empty".to_string());
}
let obs: Value = serde_json::from_str(&raw)
.map_err(|e| format!("FALSIFY-CRUX-B-10: observation is not valid JSON: {e}"))?;
let mut reports: Vec<GateReport> = Vec::new();
let mut failures: Vec<String> = Vec::new();
if let Some(cb) = obs.get("codebook") {
let (report, err) = run_codebook_gate(cb);
reports.push(report);
if let Some(e) = err {
failures.push(e);
}
}
if let Some(rt) = obs.get("roundtrip") {
let (report, err) = run_roundtrip_gate(rt);
reports.push(report);
if let Some(e) = err {
failures.push(e);
}
}
if let Some(st) = obs.get("storage") {
let (report, err) = run_storage_gate(st);
reports.push(report);
if let Some(e) = err {
failures.push(e);
}
}
if let Some(p) = obs.get("parity") {
let (report, err) = run_parity_gate(p);
reports.push(report);
if let Some(e) = err {
failures.push(e);
}
}
if reports.is_empty() {
return Err(
"FALSIFY-CRUX-B-10: observation has none of codebook/roundtrip/storage/parity".into(),
);
}
if args.json {
let payload = serde_json::json!({
"contract": "CRUX-B-10",
"gates": reports,
});
println!("{}", serde_json::to_string_pretty(&payload).unwrap());
} else {
for r in &reports {
let tag = if r.passed { "PASS" } else { "FAIL" };
println!("[{tag}] {} ({}): {}", r.gate, r.falsify_id, r.outcome);
}
}
if !failures.is_empty() {
return Err(failures.join("\n"));
}
Ok(())
}
fn read_f32_array(v: &Value) -> Vec<f32> {
v.as_array()
.map(|a| {
a.iter()
.filter_map(|n| n.as_f64().map(|f| f as f32))
.collect()
})
.unwrap_or_default()
}
fn run_codebook_gate(v: &Value) -> (GateReport, Option<String>) {
let expected = v.get("expected").map(read_f32_array).unwrap_or_default();
let passed = if expected.is_empty() {
NF4_CODEBOOK.len() == 16
} else if expected.len() != NF4_CODEBOOK.len() {
false
} else {
expected
.iter()
.zip(NF4_CODEBOOK.iter())
.all(|(e, c)| (*e - *c).abs() < 1e-6)
};
let desc = if passed {
format!("codebook matches ({} entries)", NF4_CODEBOOK.len())
} else {
format!(
"codebook divergence (expected_len={}, got_len={})",
expected.len(),
NF4_CODEBOOK.len()
)
};
let err = if passed {
None
} else {
Some(format!(
"FALSIFY-CRUX-B-10-001 codebook gate failed: {desc}"
))
};
(
GateReport {
gate: "codebook",
falsify_id: "FALSIFY-CRUX-B-10-001",
outcome: desc,
passed,
},
err,
)
}
fn run_roundtrip_gate(v: &Value) -> (GateReport, Option<String>) {
let weights = v.get("weights").map(read_f32_array).unwrap_or_default();
let max_rel_l2 = v
.get("max_rel_l2")
.and_then(|x| x.as_f64())
.unwrap_or(NF4_MAX_REL_L2_ERROR_SYNTHETIC);
if weights.is_empty() {
let desc = "weights array is empty".to_string();
return (
GateReport {
gate: "roundtrip",
falsify_id: "FALSIFY-CRUX-B-10-003",
outcome: desc.clone(),
passed: false,
},
Some(format!(
"FALSIFY-CRUX-B-10-003 roundtrip gate failed: {desc}"
)),
);
}
let (idx, scale) = nf4_quantize_block(&weights);
let recon = nf4_dequantize_block(&idx, scale);
let err = rel_l2_error(&weights, &recon);
let passed = err.is_finite() && err <= max_rel_l2;
let desc = format!("rel_l2={err:.6} (max={max_rel_l2})");
let fail = if passed {
None
} else {
Some(format!(
"FALSIFY-CRUX-B-10-003 roundtrip gate failed: {desc}"
))
};
(
GateReport {
gate: "roundtrip",
falsify_id: "FALSIFY-CRUX-B-10-003",
outcome: desc,
passed,
},
fail,
)
}
fn run_storage_gate(v: &Value) -> (GateReport, Option<String>) {
let n_weights = v.get("n_weights").and_then(|x| x.as_u64()).unwrap_or(0);
let block_size = v
.get("block_size")
.and_then(|x| x.as_u64())
.unwrap_or(NF4_DEFAULT_BLOCK_SIZE as u64);
let dq = v
.get("double_quant")
.and_then(|x| x.as_bool())
.unwrap_or(false);
let min_bpw = v
.get("expected_min_bytes_per_weight")
.and_then(|x| x.as_f64())
.unwrap_or(0.50);
let max_bpw = v
.get("expected_max_bytes_per_weight")
.and_then(|x| x.as_f64())
.unwrap_or(0.65);
if n_weights == 0 || block_size == 0 {
let desc = format!("invalid n_weights={n_weights} block_size={block_size}");
return (
GateReport {
gate: "storage",
falsify_id: "FALSIFY-CRUX-B-10-002",
outcome: desc.clone(),
passed: false,
},
Some(format!("FALSIFY-CRUX-B-10-002 storage gate failed: {desc}")),
);
}
let bytes = expected_nf4_storage_bytes(n_weights, block_size, dq);
let bpw = (bytes as f64) / (n_weights as f64);
let passed = bpw >= min_bpw && bpw <= max_bpw;
let desc = format!("bytes={bytes} bpw={bpw:.4} (envelope [{min_bpw},{max_bpw}], dq={dq})");
let err = if passed {
None
} else {
Some(format!("FALSIFY-CRUX-B-10-002 storage gate failed: {desc}"))
};
(
GateReport {
gate: "storage",
falsify_id: "FALSIFY-CRUX-B-10-002",
outcome: desc,
passed,
},
err,
)
}
/// Parity gate (`FALSIFY-CRUX-B-10-004`).
///
/// Reads `target` (default `0.0`) and `expected_index` (default `0`) from `v`
/// and passes when the index returned by `nearest_codebook_index(target)`
/// equals `expected_index`.
///
/// Returns the gate report together with a failure message when the gate
/// did not pass.
fn run_parity_gate(v: &Value) -> (GateReport, Option<String>) {
    let target = v
        .get("target")
        .and_then(|x| x.as_f64())
        .map(|f| f as f32)
        .unwrap_or(0.0);
    // Keep the observation's value at full width. Narrowing it with `as u8`
    // would wrap out-of-range values (e.g. 263 -> 7) and let a bogus
    // expectation pass the gate.
    let expected = v
        .get("expected_index")
        .and_then(|x| x.as_u64())
        .unwrap_or(0);
    let got = nearest_codebook_index(target);
    // Widening `got` is lossless, so the comparison is exact.
    let passed = u64::from(got) == expected;
    let desc = format!("target={target} expected_index={expected} got={got}");
    let err = if passed {
        None
    } else {
        Some(format!("FALSIFY-CRUX-B-10-004 parity gate failed: {desc}"))
    };
    (
        GateReport {
            gate: "parity",
            falsify_id: "FALSIFY-CRUX-B-10-004",
            outcome: desc,
            passed,
        },
        err,
    )
}