use serde_json::Value;
/// Result of checking that every attention row sums to 1
/// (see [`classify_row_softmax_normalization`]).
#[derive(Clone, Debug, PartialEq)]
pub enum AttnRowsOutcome {
    /// All rows were within tolerance. `shape` holds the dimension sizes reported
    /// by the parser, ordered `(layers, heads, rows, cols)`.
    Ok { shape: (usize, usize, usize, usize) },
    /// The input could not be parsed as a nested 4-level array of numbers;
    /// `message` describes the first problem found.
    NotA4DArray { message: String },
    /// The first row (scanning layer, then head, then row) whose sum differs
    /// from 1 by more than `tolerance`.
    RowOutOfNormalization {
        layer: usize,
        head: usize,
        row: usize,
        sum: f64,
        tolerance: f64,
    },
}
/// Result of checking that no weight is placed on a future position
/// (see [`classify_causal_mask`]).
#[derive(Clone, Debug, PartialEq)]
pub enum AttnCausalMaskOutcome {
    /// No entry with `col > row` exceeded `epsilon` in absolute value.
    Ok,
    /// The first entry (scanning layer, head, row, col) with `col > row`
    /// whose absolute value is greater than `epsilon`.
    NonZeroFuturePosition {
        layer: usize,
        head: usize,
        row: usize,
        col: usize,
        value: f64,
        epsilon: f64,
    },
}
/// Result of counting heatmap elements in rendered HTML
/// (see [`classify_html_heatmap_count`]).
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum AttnHtmlOutcome {
    /// At least the expected number of elements was found; `count` is the total.
    Ok {
        count: usize,
    },
    /// Fewer than `expected` elements were found; `got` is the actual total.
    TooFewHeatmaps {
        got: usize,
        expected: usize,
    },
}
/// Parses `body` as a rectangular 4-D JSON array of numbers.
///
/// Nesting levels are named layer -> head -> row -> col in error messages. On
/// success returns the values together with the dimension sizes
/// `(layers, heads, rows, cols)`.
///
/// # Errors
///
/// Returns a human-readable message when:
/// - the root is not an array, or is empty;
/// - any nested level is not an array;
/// - any leaf is not a JSON number;
/// - the array is ragged, i.e. a layer, head or row has a different length than
///   the first one seen at that depth.
///
/// Inner dimensions of size zero are accepted as long as they are consistent.
fn parse_4d_array(
    body: &Value,
) -> Result<(Vec<Vec<Vec<Vec<f64>>>>, (usize, usize, usize, usize)), String> {
    let layers = body
        .as_array()
        .ok_or_else(|| "root is not an array".to_string())?;
    if layers.is_empty() {
        return Err("layers dimension is empty".to_string());
    }
    // Lengths of the head, row and column dimensions. Each is pinned by the first
    // array seen at that depth. Every later sibling must agree, otherwise the
    // reported shape would only describe part of the data.
    let mut n_heads: Option<usize> = None;
    let mut n_rows: Option<usize> = None;
    let mut n_cols: Option<usize> = None;
    let mut out: Vec<Vec<Vec<Vec<f64>>>> = Vec::with_capacity(layers.len());
    for (li, layer) in layers.iter().enumerate() {
        let heads = layer
            .as_array()
            .ok_or_else(|| format!("layer[{li}] not an array"))?;
        expect_dim_len(&mut n_heads, heads.len(), "heads", || format!("layer[{li}]"))?;
        let mut layer_vec: Vec<Vec<Vec<f64>>> = Vec::with_capacity(heads.len());
        for (hi, head) in heads.iter().enumerate() {
            let rows = head
                .as_array()
                .ok_or_else(|| format!("layer[{li}].head[{hi}] not an array"))?;
            expect_dim_len(&mut n_rows, rows.len(), "rows", || {
                format!("layer[{li}].head[{hi}]")
            })?;
            let mut head_vec: Vec<Vec<f64>> = Vec::with_capacity(rows.len());
            for (ri, row) in rows.iter().enumerate() {
                let cols = row
                    .as_array()
                    .ok_or_else(|| format!("layer[{li}].head[{hi}].row[{ri}] not an array"))?;
                expect_dim_len(&mut n_cols, cols.len(), "cols", || {
                    format!("layer[{li}].head[{hi}].row[{ri}]")
                })?;
                // Short-circuits on the first non-numeric leaf.
                let row_vec = cols
                    .iter()
                    .enumerate()
                    .map(|(ci, c)| {
                        c.as_f64().ok_or_else(|| {
                            format!("layer[{li}].head[{hi}].row[{ri}].col[{ci}] not a number")
                        })
                    })
                    .collect::<Result<Vec<f64>, String>>()?;
                head_vec.push(row_vec);
            }
            layer_vec.push(head_vec);
        }
        out.push(layer_vec);
    }
    // A dimension never reached (e.g. every layer has zero heads) has size 0.
    let shape = (
        layers.len(),
        n_heads.unwrap_or(0),
        n_rows.unwrap_or(0),
        n_cols.unwrap_or(0),
    );
    Ok((out, shape))
}

/// Pins `expected` to `actual` the first time a dimension is seen. Afterwards,
/// returns a ragged-array error naming `path()` whenever a sibling's length differs.
fn expect_dim_len(
    expected: &mut Option<usize>,
    actual: usize,
    dim: &str,
    path: impl FnOnce() -> String,
) -> Result<(), String> {
    let want = *expected.get_or_insert(actual);
    if want == actual {
        Ok(())
    } else {
        Err(format!(
            "{} has {actual} {dim}, expected {want} (ragged array)",
            path()
        ))
    }
}
/// Checks that every row of the 4-D array in `body` sums to 1 within `tolerance`.
///
/// - Returns [`AttnRowsOutcome::NotA4DArray`] carrying the parser's message if
///   `body` does not parse.
/// - Otherwise returns [`AttnRowsOutcome::RowOutOfNormalization`] for the first row,
///   scanning layer, then head, then row, with `|sum - 1| > tolerance`.
/// - If no row fails, returns [`AttnRowsOutcome::Ok`] with the parsed shape.
///
/// A row with no columns sums to 0, so it is reported unless `tolerance >= 1`.
/// Presumably the rows are softmax outputs (attention weights); the check itself
/// only looks at the sums.
pub fn classify_row_softmax_normalization(body: &Value, tolerance: f64) -> AttnRowsOutcome {
    let (tensor, shape) = match parse_4d_array(body) {
        Ok(parsed) => parsed,
        Err(message) => return AttnRowsOutcome::NotA4DArray { message },
    };

    let offender = tensor.iter().enumerate().find_map(|(layer, heads)| {
        heads.iter().enumerate().find_map(|(head, rows)| {
            rows.iter().enumerate().find_map(|(row, values)| {
                let sum: f64 = values.iter().sum();
                if (sum - 1.0).abs() > tolerance {
                    Some((layer, head, row, sum))
                } else {
                    None
                }
            })
        })
    });

    match offender {
        Some((layer, head, row, sum)) => AttnRowsOutcome::RowOutOfNormalization {
            layer,
            head,
            row,
            sum,
            tolerance,
        },
        None => AttnRowsOutcome::Ok { shape },
    }
}
/// Checks that no entry above the diagonal (`col > row`) of any layer/head matrix
/// in `body` has absolute value greater than `epsilon`.
///
/// Returns the first violation found, scanning layer, head, row and col in order,
/// as [`AttnCausalMaskOutcome::NonZeroFuturePosition`]. If there is no violation,
/// returns [`AttnCausalMaskOutcome::Ok`].
///
/// NOTE(review): if `body` fails to parse as a 4-D numeric array, this returns
/// `Ok`. Malformed input therefore passes vacuously. Pair it with
/// `classify_row_softmax_normalization`, which does report parse failures.
pub fn classify_causal_mask(body: &Value, epsilon: f64) -> AttnCausalMaskOutcome {
    let tensor = match parse_4d_array(body) {
        Ok((tensor, _shape)) => tensor,
        Err(_) => return AttnCausalMaskOutcome::Ok,
    };

    for (layer, heads) in tensor.iter().enumerate() {
        for (head, rows) in heads.iter().enumerate() {
            for (row, values) in rows.iter().enumerate() {
                // Columns strictly after `row` are the "future" positions.
                let leak = values
                    .iter()
                    .enumerate()
                    .skip(row + 1)
                    .find(|&(_, &v)| v.abs() > epsilon);
                if let Some((col, &value)) = leak {
                    return AttnCausalMaskOutcome::NonZeroFuturePosition {
                        layer,
                        head,
                        row,
                        col,
                        value,
                        epsilon,
                    };
                }
            }
        }
    }
    AttnCausalMaskOutcome::Ok
}
/// Counts `<svg` and `<canvas` opening tags in `html` using `count_tag_opens`.
///
/// Returns [`AttnHtmlOutcome::Ok`] when the total is at least `expected`, and
/// [`AttnHtmlOutcome::TooFewHeatmaps`] otherwise.
pub fn classify_html_heatmap_count(html: &str, expected: usize) -> AttnHtmlOutcome {
    let count: usize = ["<svg", "<canvas"]
        .iter()
        .map(|tag| count_tag_opens(html, tag))
        .sum();
    if count < expected {
        return AttnHtmlOutcome::TooFewHeatmaps {
            got: count,
            expected,
        };
    }
    AttnHtmlOutcome::Ok { count }
}
/// Counts non-overlapping opening tags in `haystack` that start with `needle`
/// (e.g. `"<svg"`).
///
/// - Matching is ASCII case-insensitive, because HTML tag names are.
/// - A match only counts when the needle is followed by a tag-name terminator:
///   ASCII whitespace, `>`, `/`, or end of input. So `<svgfoo>` or `<canvas-x>`
///   are not counted as `<svg>` / `<canvas>`.
/// - Closing tags such as `</svg>` never match, because the `<` is followed by `/`.
/// - An empty `needle` counts as 0 rather than matching everywhere.
///
/// This is a textual scan, not an HTML parser: occurrences inside comments,
/// scripts or attribute values are still counted.
fn count_tag_opens(haystack: &str, needle: &str) -> usize {
    if needle.is_empty() {
        return 0;
    }
    // ASCII lowercasing leaves byte offsets unchanged, so indices into `hay`
    // line up with the original text.
    let hay = haystack.to_ascii_lowercase();
    let pat = needle.to_ascii_lowercase();
    let bytes = hay.as_bytes();
    hay.match_indices(pat.as_str())
        .filter(|&(idx, _)| match bytes.get(idx + pat.len()) {
            None => true,
            Some(&b) => b == b'>' || b == b'/' || b.is_ascii_whitespace(),
        })
        .count()
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// A 2 layer x 2 head array whose every head is the same causal, row-normalised
    /// 3x3 matrix.
    fn good_array() -> Value {
        let head = json!([[1.0, 0.0, 0.0], [0.4, 0.6, 0.0], [0.2, 0.3, 0.5]]);
        let layer = json!([head.clone(), head]);
        json!([layer.clone(), layer])
    }

    #[test]
    fn rows_softmax_ok_on_good_array() {
        let out = classify_row_softmax_normalization(&good_array(), 1e-5);
        assert!(matches!(out, AttnRowsOutcome::Ok { .. }), "{out:?}");
    }

    #[test]
    fn rows_softmax_reports_shape() {
        assert_eq!(
            classify_row_softmax_normalization(&good_array(), 1e-5),
            AttnRowsOutcome::Ok {
                shape: (2, 2, 3, 3)
            }
        );
    }

    #[test]
    fn rows_softmax_rejects_non_4d_input() {
        let out = classify_row_softmax_normalization(&json!([1, 2, 3]), 1e-5);
        assert!(
            matches!(out, AttnRowsOutcome::NotA4DArray { .. }),
            "{out:?}"
        );
    }

    #[test]
    fn rows_softmax_rejects_empty_root() {
        assert_eq!(
            classify_row_softmax_normalization(&json!([]), 1e-5),
            AttnRowsOutcome::NotA4DArray {
                message: "layers dimension is empty".to_string()
            }
        );
    }

    #[test]
    fn rows_softmax_rejects_non_numeric_leaf() {
        let out = classify_row_softmax_normalization(&json!([[[["x"]]]]), 1e-5);
        assert!(
            matches!(out, AttnRowsOutcome::NotA4DArray { .. }),
            "{out:?}"
        );
    }

    #[test]
    fn rows_softmax_reports_unnormalized_row() {
        let bad = json!([[[[0.6, 0.6, 0.0], [0.4, 0.6, 0.0], [0.2, 0.3, 0.5]]]]);
        match classify_row_softmax_normalization(&bad, 1e-5) {
            AttnRowsOutcome::RowOutOfNormalization {
                layer: 0,
                head: 0,
                row: 0,
                sum,
                ..
            } => {
                assert!((sum - 1.2).abs() < 1e-9, "got sum {sum}");
            }
            other => panic!("expected RowOutOfNormalization, got {other:?}"),
        }
    }

    #[test]
    fn rows_softmax_tolerance_can_be_relaxed() {
        let body = json!([[[[0.51, 0.50, 0.0], [0.4, 0.6, 0.0], [0.2, 0.3, 0.5]]]]);
        assert!(matches!(
            classify_row_softmax_normalization(&body, 1e-5),
            AttnRowsOutcome::RowOutOfNormalization { .. }
        ));
        assert!(matches!(
            classify_row_softmax_normalization(&body, 0.05),
            AttnRowsOutcome::Ok { .. }
        ));
    }

    #[test]
    fn causal_mask_ok_on_good_array() {
        assert_eq!(
            classify_causal_mask(&good_array(), 1e-9),
            AttnCausalMaskOutcome::Ok
        );
    }

    #[test]
    fn causal_mask_reports_nonzero_future() {
        let body = json!([[[[0.5, 0.5, 0.0], [0.4, 0.6, 0.0], [0.2, 0.3, 0.5]]]]);
        match classify_causal_mask(&body, 1e-9) {
            AttnCausalMaskOutcome::NonZeroFuturePosition {
                layer: 0,
                head: 0,
                row: 0,
                col: 1,
                value,
                ..
            } => {
                assert!((value - 0.5).abs() < 1e-9, "got value {value}");
            }
            other => panic!("expected NonZeroFuturePosition, got {other:?}"),
        }
    }

    #[test]
    fn html_heatmap_count_ok_when_threshold_met() {
        let html = "<html><svg></svg><svg></svg><canvas></canvas><svg></svg></html>";
        assert_eq!(
            classify_html_heatmap_count(html, 4),
            AttnHtmlOutcome::Ok { count: 4 }
        );
    }

    #[test]
    fn html_heatmap_count_reports_too_few() {
        let html = "<html><svg></svg></html>";
        assert_eq!(
            classify_html_heatmap_count(html, 4),
            AttnHtmlOutcome::TooFewHeatmaps {
                got: 1,
                expected: 4
            }
        );
    }

    #[test]
    fn html_heatmap_count_handles_attributes_in_open_tag() {
        let html = r#"<svg class="heatmap" width="200"></svg><svg id="b"></svg>"#;
        assert_eq!(
            classify_html_heatmap_count(html, 2),
            AttnHtmlOutcome::Ok { count: 2 }
        );
    }

    #[test]
    fn html_heatmap_count_ignores_closing_tags() {
        assert_eq!(
            classify_html_heatmap_count("</svg></canvas>", 1),
            AttnHtmlOutcome::TooFewHeatmaps {
                got: 0,
                expected: 1
            }
        );
    }
}