/// Relative tolerance applied by `verdict_from_grad_correctness`: the check
/// passes only when the Frobenius norm of the element-wise difference is
/// strictly less than this factor times the Frobenius norm of the reference.
pub const AC_GEMM_BW_RELATIVE_TOLERANCE: f32 = 1e-4;
/// Two-state outcome returned by every `verdict_from_*` function in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum GoGembwVerdict {
/// The condition examined by the check held.
Pass,
/// The condition examined by the check did not hold.
Fail,
}
/// Reports whether `text` should be treated as garbage.
///
/// Returns `true` when any of the following holds:
/// - `text` is empty or consists only of whitespace;
/// - `text` contains a control character other than `\n`, `\t` or `\r`;
/// - `text` contains U+FFFD, the Unicode replacement character.
///
/// Returns `false` otherwise.
#[must_use]
pub fn classify_garbage(text: &str) -> bool {
    // A single character is suspicious if it is the replacement character or
    // a control character outside the three tolerated whitespace controls.
    let is_suspicious = |c: char| {
        c == '\u{FFFD}' || (c.is_control() && !matches!(c, '\n' | '\t' | '\r'))
    };

    text.trim().is_empty() || text.chars().any(is_suspicious)
}
/// Checks that `text` is not flagged by [`classify_garbage`].
///
/// Returns [`GoGembwVerdict::Fail`] when the classifier reports `text` as
/// garbage and [`GoGembwVerdict::Pass`] when it does not.
#[must_use]
pub fn verdict_from_no_false_positive(text: &str) -> GoGembwVerdict {
    if classify_garbage(text) {
        return GoGembwVerdict::Fail;
    }
    GoGembwVerdict::Pass
}
/// Checks that a detector's decision agrees with the known nature of its input.
///
/// * `is_layout_garbage_input` - whether the input is known to be layout garbage.
/// * `detector_flagged` - whether the detector flagged that input.
///
/// Returns [`GoGembwVerdict::Pass`] when the two flags agree (both `true` or
/// both `false`) and [`GoGembwVerdict::Fail`] when they disagree.
#[must_use]
pub fn verdict_from_layout002_detection(
    is_layout_garbage_input: bool,
    detector_flagged: bool,
) -> GoGembwVerdict {
    match (is_layout_garbage_input, detector_flagged) {
        // Garbage caught, or clean input left alone.
        (true, true) | (false, false) => GoGembwVerdict::Pass,
        // Missed garbage, or a false positive on clean input.
        (true, false) | (false, true) => GoGembwVerdict::Fail,
    }
}
/// Checks that `text` is flagged by [`classify_garbage`].
///
/// Returns [`GoGembwVerdict::Pass`] when the classifier reports `text` as
/// garbage and [`GoGembwVerdict::Fail`] when it does not.
///
/// Note that the classifier also flags blank text and text containing U+FFFD,
/// so a `Pass` does not by itself show that a control character was present.
#[must_use]
pub fn verdict_from_control_char_detection(text: &str) -> GoGembwVerdict {
    if !classify_garbage(text) {
        return GoGembwVerdict::Fail;
    }
    GoGembwVerdict::Pass
}
/// Checks that empty or whitespace-only `text` is flagged as garbage.
///
/// Returns [`GoGembwVerdict::Pass`] only when `text` trims to the empty string
/// and [`classify_garbage`] flags it. Any other combination, including every
/// non-blank `text`, yields [`GoGembwVerdict::Fail`].
#[must_use]
pub fn verdict_from_empty_is_garbage(text: &str) -> GoGembwVerdict {
    let blank = text.trim().is_empty();
    let flagged = classify_garbage(text);
    match (blank, flagged) {
        (true, true) => GoGembwVerdict::Pass,
        _ => GoGembwVerdict::Fail,
    }
}
/// Computes the Frobenius norm of `matrix`: the square root of the sum of the
/// squares of its entries. The slice is treated as a flat list of entries.
///
/// # Returns
///
/// * `0.0` for an empty slice.
/// * `NaN` if any entry is NaN or infinite.
/// * Otherwise the norm rounded to `f32`. The result is `+inf` only when the
///   norm itself is too large for `f32`.
///
/// Squares are accumulated in `f64`. Squaring in `f32` overflows to infinity
/// for entries above roughly `1.8e19` and underflows to zero for entries below
/// roughly `2.6e-23`, even though the norm of such data fits in `f32`; the
/// square of any finite `f32` is comfortably inside the `f64` range.
#[must_use]
// The final narrowing to `f32` is intentional: callers work in `f32`.
#[allow(clippy::cast_possible_truncation)]
pub fn frobenius_norm(matrix: &[f32]) -> f32 {
    let mut sum_of_squares = 0.0_f64;
    for &entry in matrix {
        // Non-finite data has no meaningful norm; signal it with NaN.
        if !entry.is_finite() {
            return f32::NAN;
        }
        let wide = f64::from(entry);
        sum_of_squares += wide * wide;
    }
    sum_of_squares.sqrt() as f32
}
/// Compares a tiled gradient against a naive reference using a relative
/// Frobenius-norm bound.
///
/// * `dw_tiled` - gradient entries produced by the implementation under test.
/// * `dw_naive` - reference gradient entries; must have the same length.
///
/// The verdict is decided by the first rule that applies:
///
/// 1. `Fail` if `dw_tiled` is empty or the two slices differ in length.
/// 2. `Fail` if the norm of the element-wise difference is not finite. This
///    covers NaN or infinite entries in either slice.
/// 3. `Pass` if that difference norm is exactly zero.
/// 4. `Fail` if the reference norm is zero or not finite.
/// 5. Otherwise `Pass` exactly when the difference norm is strictly less than
///    [`AC_GEMM_BW_RELATIVE_TOLERANCE`] times the reference norm.
#[must_use]
pub fn verdict_from_grad_correctness(
    dw_tiled: &[f32],
    dw_naive: &[f32],
) -> GoGembwVerdict {
    if dw_tiled.is_empty() || dw_tiled.len() != dw_naive.len() {
        return GoGembwVerdict::Fail;
    }

    let diff: Vec<f32> = dw_tiled
        .iter()
        .zip(dw_naive)
        .map(|(tiled, naive)| tiled - naive)
        .collect();
    let diff_norm = frobenius_norm(&diff);
    if !diff_norm.is_finite() {
        return GoGembwVerdict::Fail;
    }

    // A finite, zero difference norm means the gradients agree. Decide this
    // before inspecting the reference norm: that norm can be non-finite for
    // large but finite data, which previously turned an exact match into Fail.
    if diff_norm == 0.0 {
        return GoGembwVerdict::Pass;
    }

    let naive_norm = frobenius_norm(dw_naive);
    // A non-zero difference against a zero (or unusable) reference cannot
    // satisfy any relative bound.
    if !naive_norm.is_finite() || naive_norm == 0.0 {
        return GoGembwVerdict::Fail;
    }

    if diff_norm < AC_GEMM_BW_RELATIVE_TOLERANCE * naive_norm {
        GoGembwVerdict::Pass
    } else {
        GoGembwVerdict::Fail
    }
}
/// Checks that a twice-transposed matrix reproduces the original bit for bit.
///
/// * `original` - entries of the matrix before any transpose.
/// * `double_transposed` - entries after transposing twice.
///
/// Returns [`GoGembwVerdict::Fail`] if `original` is empty, if the slices
/// differ in length, or if any pair of corresponding entries has different
/// bit patterns. Returns [`GoGembwVerdict::Pass`] otherwise.
///
/// Entries are compared through `f32::to_bits`, not `==`, so `0.0` and `-0.0`
/// count as different and two NaNs match only if their bit patterns match.
#[must_use]
pub fn verdict_from_transpose_involution(
    original: &[f32],
    double_transposed: &[f32],
) -> GoGembwVerdict {
    let comparable = !original.is_empty() && original.len() == double_transposed.len();
    let bit_identical = comparable
        && original
            .iter()
            .zip(double_transposed)
            .all(|(lhs, rhs)| lhs.to_bits() == rhs.to_bits());

    if bit_identical {
        GoGembwVerdict::Pass
    } else {
        GoGembwVerdict::Fail
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Short aliases so each assertion reads as "call, expected verdict".
    const PASS: GoGembwVerdict = GoGembwVerdict::Pass;
    const FAIL: GoGembwVerdict = GoGembwVerdict::Fail;

    /// The `f32` whose bit pattern is one greater than that of `1.0`.
    fn one_ulp_above_one() -> f32 {
        f32::from_bits(1.0_f32.to_bits() + 1)
    }

    // ---- constants ----

    #[test]
    fn provenance_constants() {
        assert_eq!(AC_GEMM_BW_RELATIVE_TOLERANCE, 1e-4);
    }

    // ---- classify_garbage ----

    #[test]
    fn classify_empty_is_garbage() {
        let flagged = classify_garbage("");
        assert!(flagged);
    }

    #[test]
    fn classify_whitespace_only_is_garbage() {
        let flagged = classify_garbage(" \t ");
        assert!(flagged);
    }

    #[test]
    fn classify_normal_text_not_garbage() {
        let flagged = classify_garbage("Hello, world!");
        assert!(!flagged);
    }

    #[test]
    fn classify_with_newlines_not_garbage() {
        let flagged = classify_garbage("Line 1\nLine 2\n");
        assert!(!flagged);
    }

    #[test]
    fn classify_with_null_byte_is_garbage() {
        let flagged = classify_garbage("hello\x00world");
        assert!(flagged);
    }

    #[test]
    fn classify_with_replacement_char_is_garbage() {
        let flagged = classify_garbage("garbage\u{FFFD}text");
        assert!(flagged);
    }

    // ---- verdict_from_no_false_positive ----

    #[test]
    fn fgo001_pass_valid_english() {
        assert_eq!(
            verdict_from_no_false_positive("The quick brown fox jumps."),
            PASS
        );
    }

    #[test]
    fn fgo001_pass_valid_code() {
        assert_eq!(
            verdict_from_no_false_positive("fn main() {\n println!(\"hi\");\n}"),
            PASS
        );
    }

    // ---- verdict_from_layout002_detection ----

    #[test]
    fn fgo002_pass_layout_garbage_detected() {
        assert_eq!(verdict_from_layout002_detection(true, true), PASS);
    }

    #[test]
    fn fgo002_pass_clean_input_not_flagged() {
        assert_eq!(verdict_from_layout002_detection(false, false), PASS);
    }

    #[test]
    fn fgo002_fail_garbage_not_detected() {
        assert_eq!(verdict_from_layout002_detection(true, false), FAIL);
    }

    #[test]
    fn fgo002_fail_false_positive() {
        assert_eq!(verdict_from_layout002_detection(false, true), FAIL);
    }

    // ---- verdict_from_control_char_detection ----

    #[test]
    fn fgo003_pass_control_char_detected() {
        assert_eq!(
            verdict_from_control_char_detection("text\x01with\x02ctrl"),
            PASS
        );
    }

    #[test]
    fn fgo003_fail_normal_text_classified_garbage() {
        assert_eq!(verdict_from_control_char_detection("normal text"), FAIL);
    }

    // ---- verdict_from_empty_is_garbage ----

    #[test]
    fn fgo004_pass_empty_string() {
        assert_eq!(verdict_from_empty_is_garbage(""), PASS);
    }

    #[test]
    fn fgo004_pass_whitespace_only() {
        assert_eq!(verdict_from_empty_is_garbage(" "), PASS);
    }

    #[test]
    fn fgo004_fail_non_empty() {
        assert_eq!(verdict_from_empty_is_garbage("hello"), FAIL);
    }

    // ---- verdict_from_grad_correctness ----

    #[test]
    fn fgembw001_pass_within_tolerance() {
        let reference = [1.0_f32, 2.0, 3.0, 4.0];
        let candidate = [1.00001_f32, 2.00001, 3.00001, 4.00001];
        assert_eq!(verdict_from_grad_correctness(&candidate, &reference), PASS);
    }

    #[test]
    fn fgembw001_fail_far_drift() {
        let reference = [1.0_f32];
        let candidate = [2.0_f32];
        assert_eq!(verdict_from_grad_correctness(&candidate, &reference), FAIL);
    }

    #[test]
    fn fgembw001_pass_both_zero() {
        let reference = [0.0_f32, 0.0];
        let candidate = [0.0_f32, 0.0];
        assert_eq!(verdict_from_grad_correctness(&candidate, &reference), PASS);
    }

    #[test]
    fn fgembw001_fail_length_mismatch() {
        assert_eq!(verdict_from_grad_correctness(&[1.0], &[1.0, 2.0]), FAIL);
    }

    // ---- verdict_from_transpose_involution ----

    #[test]
    fn fgembw002_pass_involution() {
        let matrix = [1.0_f32, 2.0, 3.0];
        let round_tripped = matrix;
        assert_eq!(
            verdict_from_transpose_involution(&matrix, &round_tripped),
            PASS
        );
    }

    #[test]
    fn fgembw002_fail_one_ulp_drift() {
        let matrix = [1.0_f32];
        let drifted = [one_ulp_above_one()];
        assert_eq!(verdict_from_transpose_involution(&matrix, &drifted), FAIL);
    }

    // ---- surveys across several inputs / all checks ----

    #[test]
    fn mutation_survey_garbage_strings() {
        let garbage_samples = [
            "",
            " ",
            "\x00\x01\x02",
            "olumbia+lsi nunca/localENTS\u{FFFD}",
            "\u{FFFD}",
        ];
        for bad in garbage_samples {
            assert!(classify_garbage(bad), "should be garbage: {bad:?}");
        }

        let clean_samples = [
            "Hello, world!",
            "x = 42",
            "fn main() { println!(\"hi\"); }",
            "Line 1\nLine 2",
            "tab\there",
        ];
        for good in clean_samples {
            assert!(!classify_garbage(good), "should not be garbage: {good:?}");
        }
    }

    #[test]
    fn realistic_healthy_passes_all_6() {
        let reference = [1.0_f32, 2.0];
        let candidate = [1.0_f32, 2.0];
        let verdicts = [
            verdict_from_no_false_positive("Hello world"),
            verdict_from_layout002_detection(true, true),
            verdict_from_control_char_detection("\x01ctrl"),
            verdict_from_empty_is_garbage(""),
            verdict_from_grad_correctness(&candidate, &reference),
            verdict_from_transpose_involution(&[1.0, 2.0], &[1.0, 2.0]),
        ];
        for verdict in verdicts {
            assert_eq!(verdict, PASS);
        }
    }

    #[test]
    fn realistic_pre_fix_all_6_failures() {
        let drifted = one_ulp_above_one();
        let verdicts = [
            verdict_from_no_false_positive("\x00\x01"),
            verdict_from_layout002_detection(true, false),
            verdict_from_control_char_detection("clean text"),
            verdict_from_empty_is_garbage("non-empty"),
            verdict_from_grad_correctness(&[10.0], &[1.0]),
            verdict_from_transpose_involution(&[1.0], &[drifted]),
        ];
        for verdict in verdicts {
            assert_eq!(verdict, FAIL);
        }
    }
}