use crate::commands::embeddings_classifier::{
classify_determinism, classify_embeddings_response_shape, classify_usage_tokens,
parse_embeddings_flag, DeterminismOutcome, EmbeddingRow, EmbeddingsFlagOutcome,
EmbeddingsShapeOutcome, UsageOutcome, EMBEDDINGS_COSINE_TOLERANCE,
};
use serde_json::Value;
use std::fs;
use std::path::Path;
/// Arguments accepted by `run`.
#[derive(Debug, Clone)]
pub struct EmbeddingsLintArgs {
/// Path of the observation file; `run` reads it as text and parses it as JSON.
pub observation_file: String,
/// When `true`, `run` prints the gate reports as one pretty-printed JSON
/// payload; otherwise it prints one `[PASS]`/`[FAIL]` line per gate.
pub json: bool,
}
/// Result of evaluating one gate; serialized into the `gates` array of the
/// JSON payload and also used for the plain-text report lines.
#[derive(Debug, Clone, serde::Serialize)]
struct GateReport {
/// Short gate name (`"shape"`, `"determinism"`, `"usage"` or `"flag"`).
gate: &'static str,
/// Falsification identifier of the gate, e.g. `"FALSIFY-CRUX-C-13-001"`.
falsify_id: &'static str,
/// Human-readable description, built from the `Debug` rendering of the
/// classifier outcome.
outcome: String,
/// Whether the gate accepted the observation.
passed: bool,
}
pub fn run(args: EmbeddingsLintArgs) -> Result<(), String> {
let path = Path::new(&args.observation_file);
if !path.exists() {
return Err(format!(
"FALSIFY-CRUX-C-13: observation file not found: {}",
args.observation_file
));
}
let raw = fs::read_to_string(path)
.map_err(|e| format!("FALSIFY-CRUX-C-13: failed to read observation: {e}"))?;
if raw.trim().is_empty() {
return Err("FALSIFY-CRUX-C-13: observation file is empty".to_string());
}
let obs: Value = serde_json::from_str(&raw)
.map_err(|e| format!("FALSIFY-CRUX-C-13: observation is not valid JSON: {e}"))?;
let mut reports: Vec<GateReport> = Vec::new();
let mut failures: Vec<String> = Vec::new();
if let Some(shape) = obs.get("shape") {
let (report, err) = run_shape_gate(shape);
reports.push(report);
if let Some(e) = err {
failures.push(e);
}
}
if let Some(det) = obs.get("determinism") {
let (report, err) = run_determinism_gate(det);
reports.push(report);
if let Some(e) = err {
failures.push(e);
}
}
if let Some(usage) = obs.get("usage") {
let (report, err) = run_usage_gate(usage);
reports.push(report);
if let Some(e) = err {
failures.push(e);
}
}
if let Some(flag) = obs.get("flag") {
let (report, err) = run_flag_gate(flag);
reports.push(report);
if let Some(e) = err {
failures.push(e);
}
}
if reports.is_empty() {
return Err(
"FALSIFY-CRUX-C-13: observation has none of shape/determinism/usage/flag".to_string(),
);
}
if args.json {
let payload = serde_json::json!({
"contract": "CRUX-C-13",
"gates": reports,
});
println!("{}", serde_json::to_string_pretty(&payload).unwrap());
} else {
for r in &reports {
let tag = if r.passed { "PASS" } else { "FAIL" };
println!("[{tag}] {} ({}): {}", r.gate, r.falsify_id, r.outcome);
}
}
if !failures.is_empty() {
return Err(failures.join("\n"));
}
Ok(())
}
fn run_shape_gate(v: &Value) -> (GateReport, Option<String>) {
let input_len = v.get("input_len").and_then(|x| x.as_u64()).unwrap_or(0) as usize;
let hidden_size = v.get("hidden_size").and_then(|x| x.as_u64()).unwrap_or(0) as usize;
let rows_raw: Vec<(u64, Vec<f32>)> = v
.get("data")
.and_then(|x| x.as_array())
.map(|arr| {
arr.iter()
.map(|row| {
let index = row.get("index").and_then(|x| x.as_u64()).unwrap_or(0);
let embedding: Vec<f32> = row
.get("embedding")
.and_then(|x| x.as_array())
.map(|a| {
a.iter()
.filter_map(|n| n.as_f64().map(|f| f as f32))
.collect()
})
.unwrap_or_default();
(index, embedding)
})
.collect()
})
.unwrap_or_default();
let rows: Vec<EmbeddingRow<'_>> = rows_raw
.iter()
.map(|(i, e)| EmbeddingRow {
index: *i,
embedding: e.as_slice(),
})
.collect();
let outcome = classify_embeddings_response_shape(input_len, &rows, hidden_size);
let passed = matches!(outcome, EmbeddingsShapeOutcome::Ok { .. });
let desc = format!("{outcome:?}");
let err = if passed {
None
} else {
Some(format!("FALSIFY-CRUX-C-13-001 shape gate failed: {desc}"))
};
(
GateReport {
gate: "shape",
falsify_id: "FALSIFY-CRUX-C-13-001",
outcome: desc,
passed,
},
err,
)
}
fn run_determinism_gate(v: &Value) -> (GateReport, Option<String>) {
let v1: Vec<f32> = v
.get("v1")
.and_then(|x| x.as_array())
.map(|a| {
a.iter()
.filter_map(|n| n.as_f64().map(|f| f as f32))
.collect()
})
.unwrap_or_default();
let v2: Vec<f32> = v
.get("v2")
.and_then(|x| x.as_array())
.map(|a| {
a.iter()
.filter_map(|n| n.as_f64().map(|f| f as f32))
.collect()
})
.unwrap_or_default();
let outcome = classify_determinism(&v1, &v2, EMBEDDINGS_COSINE_TOLERANCE);
let passed = matches!(outcome, DeterminismOutcome::Deterministic { .. });
let desc = format!("{outcome:?}");
let err = if passed {
None
} else {
Some(format!(
"FALSIFY-CRUX-C-13-002 determinism gate failed: {desc}"
))
};
(
GateReport {
gate: "determinism",
falsify_id: "FALSIFY-CRUX-C-13-002",
outcome: desc,
passed,
},
err,
)
}
fn run_usage_gate(v: &Value) -> (GateReport, Option<String>) {
let prompt = v.get("prompt").and_then(|x| x.as_u64()).unwrap_or(0);
let total = v.get("total").and_then(|x| x.as_u64()).unwrap_or(0);
let outcome = classify_usage_tokens(prompt, total);
let passed = matches!(outcome, UsageOutcome::Ok { .. });
let desc = format!("{outcome:?}");
let err = if passed {
None
} else {
Some(format!("FALSIFY-CRUX-C-13-003 usage gate failed: {desc}"))
};
(
GateReport {
gate: "usage",
falsify_id: "FALSIFY-CRUX-C-13-003",
outcome: desc,
passed,
},
err,
)
}
fn run_flag_gate(v: &Value) -> (GateReport, Option<String>) {
let argv: Vec<String> = v
.get("argv")
.and_then(|x| x.as_array())
.map(|a| {
a.iter()
.filter_map(|n| n.as_str().map(|s| s.to_string()))
.collect()
})
.unwrap_or_default();
let argv_refs: Vec<&str> = argv.iter().map(String::as_str).collect();
let expected = v
.get("expected")
.and_then(|x| x.as_str())
.unwrap_or("enabled");
let outcome = parse_embeddings_flag(&argv_refs);
let observed = match &outcome {
EmbeddingsFlagOutcome::Enabled => "enabled",
EmbeddingsFlagOutcome::Disabled => "disabled",
EmbeddingsFlagOutcome::MalformedFlag { .. } => "malformed",
};
let passed = observed == expected;
let desc = format!("{outcome:?} (expected={expected}, observed={observed})");
let err = if passed {
None
} else {
Some(format!("FALSIFY-CRUX-C-13-004 flag gate failed: {desc}"))
};
(
GateReport {
gate: "flag",
falsify_id: "FALSIFY-CRUX-C-13-004",
outcome: desc,
passed,
},
err,
)
}