use serde_json::Value;
/// Op-name fragments a layer listing is expected to cover.
///
/// Suitable as the `required_op_prefixes` argument of
/// [`classify_layer_coverage`], which counts a fragment as covered when it
/// occurs anywhere inside at least one layer name.
pub const F11_REQUIRED_OP_PREFIXES: &[&str] = &[
    "attention_q",
    "attention_k",
    "attention_v",
    "attention_out",
    "ffn_gate",
    "ffn_up",
    "ffn_down",
    "layernorm",
];

/// Top-level keys [`classify_error_json`] requires, checked in this order
/// (the first absent one is the one reported).
pub const F11_ERROR_JSON_KEYS: &[&str] = &[
    "error",
    "layer",
    "shape",
    "first_bad_index",
    "value",
    "op",
];

/// The only strings accepted in the `"value"` field of an error body.
pub const F11_ALLOWED_VALUES: &[&str] = &["nan", "+inf", "-inf"];
/// Verdict of [`classify_error_json`]: `Ok`, or the first validation failure
/// encountered.
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
pub enum CheckFiniteErrorOutcome {
    /// The body passed every check.
    Ok,
    /// The body is not a JSON object.
    NotAnObject,
    /// A key listed in [`F11_ERROR_JSON_KEYS`] is absent.
    MissingKey { key: &'static str },
    /// `"error"` is not the string `"non_finite"`; `got` holds what was found
    /// (empty when the field is not a string).
    ErrorTagWrong { got: String },
    /// `"value"` is not one of [`F11_ALLOWED_VALUES`]; `got` holds what was
    /// found (empty when the field is not a string).
    ValueOutOfSet { got: String },
    /// `"shape"` is not an array of integers.
    ShapeNotIntArray,
    /// `"first_bad_index"` was rejected; `got` is its value, or `-1` when the
    /// field is not representable as an `i64`.
    FirstBadIndexNotNonNegative { got: i64 },
}
/// Verdict of [`classify_layer_coverage`].
#[derive(Debug, Clone)]
#[derive(PartialEq, Eq)]
pub enum CheckFiniteCoverageOutcome {
    /// Every check passed; `count` is the number of entries in `"layers"`.
    Ok { count: usize },
    /// The body is not a JSON object.
    NotAnObject,
    /// `"layers"` is absent or is not an array.
    LayersNotArray,
    /// `"layers"` has `got` entries, fewer than the required `min`.
    TooFewLayers { got: usize, min: usize },
    /// Required op prefixes that no layer name contains, sorted ascending.
    MissingOpPrefixes { missing: Vec<String> },
}
/// Validate the JSON error body reported for a non-finite activation.
///
/// Checks run in this order, and the first failure is returned:
///
/// 1. `body` must be a JSON object.
/// 2. Every key in [`F11_ERROR_JSON_KEYS`] must be present; the first absent
///    key (in that constant's order) is reported.
/// 3. `"error"` must be the string `"non_finite"`.
/// 4. `"value"` must be one of [`F11_ALLOWED_VALUES`].
/// 5. `"shape"` must be an array whose elements are all integers, signed or
///    unsigned. Negative dimensions are not rejected here.
/// 6. `"first_bad_index"` must be a non-negative integer.
///
/// A non-string `"error"` or `"value"` is reported with `got: ""`. A
/// `"first_bad_index"` that is not an integer at all (string, float, null, …)
/// is reported as `got: -1`, because the variant can only carry an `i64`.
pub fn classify_error_json(body: &Value) -> CheckFiniteErrorOutcome {
    let Some(obj) = body.as_object() else {
        return CheckFiniteErrorOutcome::NotAnObject;
    };
    for &key in F11_ERROR_JSON_KEYS {
        if !obj.contains_key(key) {
            return CheckFiniteErrorOutcome::MissingKey { key };
        }
    }

    let err = obj.get("error").and_then(Value::as_str).unwrap_or("");
    if err != "non_finite" {
        return CheckFiniteErrorOutcome::ErrorTagWrong {
            got: err.to_string(),
        };
    }

    let value = obj.get("value").and_then(Value::as_str).unwrap_or("");
    if !F11_ALLOWED_VALUES.contains(&value) {
        return CheckFiniteErrorOutcome::ValueOutOfSet {
            got: value.to_string(),
        };
    }

    let Some(shape) = obj.get("shape").and_then(Value::as_array) else {
        return CheckFiniteErrorOutcome::ShapeNotIntArray;
    };
    // serde_json stores integers above i64::MAX as u64, so `as_i64` alone
    // would misreport them as non-integers; accept either representation.
    if !shape.iter().all(|dim| dim.is_i64() || dim.is_u64()) {
        return CheckFiniteErrorOutcome::ShapeNotIntArray;
    }

    let index = obj.get("first_bad_index");
    // Same reasoning as for `shape`: a u64 index above i64::MAX is a valid
    // non-negative integer but has no i64 form, so check the unsigned form first.
    if index.and_then(Value::as_u64).is_some() {
        return CheckFiniteErrorOutcome::Ok;
    }
    CheckFiniteErrorOutcome::FirstBadIndexNotNonNegative {
        got: index.and_then(Value::as_i64).unwrap_or(-1),
    }
}
/// Check that a layer-listing JSON body has enough layers and covers every
/// required op.
///
/// `body` must be an object whose `"layers"` field is an array. Each entry
/// contributes a name when it is an object with a string `"name"` field or a
/// bare string. Other entries still count towards `min_layers` but contribute
/// no name.
///
/// A required prefix counts as covered when it occurs *anywhere* inside at
/// least one collected name. This is a substring match: `"layernorm"` is
/// covered by `"blk.0.layernorm_in"`.
///
/// # Returns
/// - [`CheckFiniteCoverageOutcome::NotAnObject`] if `body` is not an object.
/// - [`CheckFiniteCoverageOutcome::LayersNotArray`] if `"layers"` is absent or
///   not an array.
/// - [`CheckFiniteCoverageOutcome::TooFewLayers`] if the array has fewer than
///   `min_layers` entries. This is checked before any name is inspected.
/// - [`CheckFiniteCoverageOutcome::MissingOpPrefixes`] with the uncovered
///   prefixes sorted ascending.
/// - [`CheckFiniteCoverageOutcome::Ok`] with the total entry count otherwise.
pub fn classify_layer_coverage(
    body: &Value,
    min_layers: usize,
    required_op_prefixes: &[&str],
) -> CheckFiniteCoverageOutcome {
    let layers = match body
        .as_object()
        .map(|obj| obj.get("layers").and_then(Value::as_array))
    {
        None => return CheckFiniteCoverageOutcome::NotAnObject,
        Some(None) => return CheckFiniteCoverageOutcome::LayersNotArray,
        Some(Some(entries)) => entries,
    };

    let total = layers.len();
    if total < min_layers {
        return CheckFiniteCoverageOutcome::TooFewLayers {
            got: total,
            min: min_layers,
        };
    }

    // Borrow names straight out of the JSON instead of copying them.
    let layer_names: Vec<&str> = layers
        .iter()
        .filter_map(|entry| {
            entry
                .get("name")
                .and_then(Value::as_str)
                .or_else(|| entry.as_str())
        })
        .collect();

    let mut uncovered: Vec<String> = Vec::new();
    for prefix in required_op_prefixes {
        let covered = layer_names.iter().any(|name| name.contains(*prefix));
        if !covered {
            uncovered.push((*prefix).to_string());
        }
    }

    if uncovered.is_empty() {
        CheckFiniteCoverageOutcome::Ok { count: total }
    } else {
        uncovered.sort_unstable();
        CheckFiniteCoverageOutcome::MissingOpPrefixes { missing: uncovered }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Op suffixes emitted for every block by `good_layer_list`. Together they
    /// cover every entry of `F11_REQUIRED_OP_PREFIXES` (via substring match).
    const BLOCK_OPS: &[&str] = &[
        "attention_q",
        "attention_k",
        "attention_v",
        "attention_out",
        "ffn_gate",
        "ffn_up",
        "ffn_down",
        "layernorm_in",
        "layernorm_post",
    ];

    /// An error body that passes every check in `classify_error_json`.
    fn good_error_json() -> Value {
        json!({
            "error": "non_finite",
            "layer": "blk.0.ffn_up",
            "shape": [1, 4096],
            "first_bad_index": 0,
            "value": "nan",
            "op": "ffn_up"
        })
    }

    /// A layer listing with `BLOCK_OPS.len()` (= 9) entries per block plus
    /// three non-block layers, i.e. `n_blocks * 9 + 3` entries in total.
    fn good_layer_list(n_blocks: usize) -> Value {
        let mut layers: Vec<Value> = Vec::new();
        for b in 0..n_blocks {
            for op in BLOCK_OPS {
                layers.push(json!({"name": format!("blk.{b}.{op}"), "shape": [1, 4096]}));
            }
        }
        layers.push(json!({"name": "embed_tokens", "shape": [1, 4096]}));
        layers.push(json!({"name": "output_norm", "shape": [1, 4096]}));
        layers.push(json!({"name": "lm_head", "shape": [1, 32000]}));
        json!({ "layers": layers })
    }

    #[test]
    fn error_json_ok_on_well_formed() {
        assert_eq!(
            classify_error_json(&good_error_json()),
            CheckFiniteErrorOutcome::Ok
        );
    }

    #[test]
    fn error_json_accepts_every_allowed_value() {
        for v in F11_ALLOWED_VALUES {
            let mut body = good_error_json();
            body["value"] = json!(v);
            assert_eq!(
                classify_error_json(&body),
                CheckFiniteErrorOutcome::Ok,
                "value {v}"
            );
        }
    }

    #[test]
    fn error_json_rejects_not_an_object() {
        assert_eq!(
            classify_error_json(&json!([1, 2])),
            CheckFiniteErrorOutcome::NotAnObject
        );
    }

    #[test]
    fn error_json_reports_missing_key() {
        let body = json!({"error": "non_finite", "shape": [1], "first_bad_index": 0, "value": "nan", "op": "x"});
        assert_eq!(
            classify_error_json(&body),
            CheckFiniteErrorOutcome::MissingKey { key: "layer" }
        );
    }

    #[test]
    fn error_json_reports_first_missing_key_in_declared_order() {
        assert_eq!(
            classify_error_json(&json!({})),
            CheckFiniteErrorOutcome::MissingKey { key: "error" }
        );
    }

    #[test]
    fn error_json_rejects_wrong_error_tag() {
        let body = json!({"error": "weird_other", "layer": "L", "shape": [1], "first_bad_index": 0, "value": "nan", "op": "x"});
        assert_eq!(
            classify_error_json(&body),
            CheckFiniteErrorOutcome::ErrorTagWrong {
                got: "weird_other".to_string()
            }
        );
    }

    #[test]
    fn error_json_rejects_value_outside_allowed_set() {
        let body = json!({"error": "non_finite", "layer": "L", "shape": [1], "first_bad_index": 0, "value": "banana", "op": "x"});
        assert!(matches!(
            classify_error_json(&body),
            CheckFiniteErrorOutcome::ValueOutOfSet { .. }
        ));
    }

    #[test]
    fn error_json_rejects_non_int_shape() {
        let body = json!({"error": "non_finite", "layer": "L", "shape": ["wide"], "first_bad_index": 0, "value": "nan", "op": "x"});
        assert_eq!(
            classify_error_json(&body),
            CheckFiniteErrorOutcome::ShapeNotIntArray
        );
    }

    #[test]
    fn error_json_rejects_negative_first_bad_index() {
        let body = json!({"error": "non_finite", "layer": "L", "shape": [1], "first_bad_index": -1, "value": "nan", "op": "x"});
        assert!(matches!(
            classify_error_json(&body),
            CheckFiniteErrorOutcome::FirstBadIndexNotNonNegative { got: -1 }
        ));
    }

    #[test]
    fn error_json_reports_non_integer_first_bad_index_as_minus_one() {
        let mut body = good_error_json();
        body["first_bad_index"] = json!("3");
        assert_eq!(
            classify_error_json(&body),
            CheckFiniteErrorOutcome::FirstBadIndexNotNonNegative { got: -1 }
        );
    }

    #[test]
    fn layer_coverage_ok_on_well_formed_list() {
        let body = good_layer_list(28);
        match classify_layer_coverage(&body, 100, F11_REQUIRED_OP_PREFIXES) {
            CheckFiniteCoverageOutcome::Ok { count } => assert_eq!(count, 28 * 9 + 3),
            other => panic!("expected Ok, got {other:?}"),
        }
    }

    #[test]
    fn layer_coverage_min_layers_is_inclusive() {
        let body = good_layer_list(1); // 1 * 9 + 3 = 12 entries
        assert_eq!(
            classify_layer_coverage(&body, 12, F11_REQUIRED_OP_PREFIXES),
            CheckFiniteCoverageOutcome::Ok { count: 12 }
        );
    }

    #[test]
    fn layer_coverage_accepts_bare_string_names() {
        let layers: Vec<Value> = BLOCK_OPS
            .iter()
            .map(|op| json!(format!("blk.0.{op}")))
            .collect();
        let body = json!({ "layers": layers });
        assert_eq!(
            classify_layer_coverage(&body, BLOCK_OPS.len(), F11_REQUIRED_OP_PREFIXES),
            CheckFiniteCoverageOutcome::Ok {
                count: BLOCK_OPS.len()
            }
        );
    }

    #[test]
    fn layer_coverage_rejects_not_an_object() {
        assert_eq!(
            classify_layer_coverage(&json!([1, 2]), 100, F11_REQUIRED_OP_PREFIXES),
            CheckFiniteCoverageOutcome::NotAnObject
        );
    }

    #[test]
    fn layer_coverage_rejects_missing_layers_array() {
        assert_eq!(
            classify_layer_coverage(&json!({"otherKey": []}), 100, F11_REQUIRED_OP_PREFIXES),
            CheckFiniteCoverageOutcome::LayersNotArray
        );
    }

    #[test]
    fn layer_coverage_reports_too_few_layers() {
        let body = good_layer_list(2);
        assert!(matches!(
            classify_layer_coverage(&body, 100, F11_REQUIRED_OP_PREFIXES),
            CheckFiniteCoverageOutcome::TooFewLayers { got: 21, min: 100 }
        ));
    }

    #[test]
    fn layer_coverage_reports_missing_op_prefixes() {
        let mut layers: Vec<Value> = Vec::new();
        for b in 0..20 {
            for op in &["attention_q", "attention_k", "attention_v", "attention_out"] {
                layers.push(json!({"name": format!("blk.{b}.{op}")}));
            }
            for k in 0..5 {
                layers.push(json!({"name": format!("blk.{b}.extra_{k}")}));
            }
        }
        let body = json!({ "layers": layers });
        match classify_layer_coverage(&body, 100, F11_REQUIRED_OP_PREFIXES) {
            CheckFiniteCoverageOutcome::MissingOpPrefixes { missing } => {
                assert!(missing.contains(&"ffn_gate".to_string()));
                assert!(missing.contains(&"layernorm".to_string()));
                assert!(!missing.contains(&"attention_q".to_string()));
                // The report is sorted ascending, not in declaration order.
                assert_eq!(missing, ["ffn_down", "ffn_gate", "ffn_up", "layernorm"]);
            }
            other => panic!("expected MissingOpPrefixes, got {other:?}"),
        }
    }
}