/// Builds a symmetric (bidirectional) sliding-window attention mask.
///
/// The result is a `seq_len x seq_len` matrix whose entry `[i][j]` is `1` when
/// positions `i` and `j` are strictly fewer than `w` positions apart
/// (`|i - j| < w`) and `0` otherwise. `w == 0` yields an all-zero mask and
/// `seq_len == 0` yields an empty matrix.
#[must_use]
pub fn swa_mask_symmetric(seq_len: usize, w: usize) -> Vec<Vec<u8>> {
    (0..seq_len)
        .map(|query| {
            (0..seq_len)
                .map(|key| u8::from(query.abs_diff(key) < w))
                .collect()
        })
        .collect()
}
/// Builds a causal sliding-window attention mask.
///
/// Entry `[i][j]` is `1` when `j <= i` (no look-ahead) and `i - j < w`
/// (the key is among the last `w` positions ending at the query), else `0`.
/// Each row `i` therefore contains `min(i + 1, w)` ones; `w == 0` gives an
/// all-zero mask and `seq_len == 0` an empty matrix.
#[must_use]
pub fn swa_mask_causal(seq_len: usize, w: usize) -> Vec<Vec<u8>> {
    (0..seq_len)
        .map(|query| {
            (0..seq_len)
                // `key <= query` is evaluated first so `query - key` cannot underflow.
                .map(|key| u8::from(key <= query && query - key < w))
                .collect()
        })
        .collect()
}
/// Row-wise softmax of an `n x n` logit matrix, restricted to a window mask.
///
/// For each row `i`, probabilities are normalized over the columns `j` with
/// `mask[i][j] == 1`; every other column (any mask value other than `1`)
/// receives exactly `0.0`. The window maximum is subtracted before
/// exponentiating for numerical stability, and all arithmetic after reading
/// the `f32` logits is carried out in `f64`.
///
/// Returns `None` when:
/// - `logits` is empty, or `logits` and `mask` are not both `n x n`;
/// - a row has no position with `mask == 1`, or every such logit is `-inf`;
/// - a logit inside the window is NaN or `+inf`.
///
/// Logits outside the window are never used numerically, so NaN/inf values
/// there are ignored. A `-inf` logit inside the window simply gets
/// probability `0.0` as long as some other windowed logit is finite.
#[must_use]
pub fn windowed_softmax(logits: &[Vec<f32>], mask: &[Vec<u8>]) -> Option<Vec<Vec<f64>>> {
    let n = logits.len();
    if n == 0 || mask.len() != n {
        return None;
    }
    let mut out = Vec::with_capacity(n);
    for (logit_row, mask_row) in logits.iter().zip(mask) {
        if logit_row.len() != n || mask_row.len() != n {
            return None;
        }
        let mut max = f64::NEG_INFINITY;
        for (&x, &m) in logit_row.iter().zip(mask_row) {
            if m != 1 {
                continue;
            }
            // NaN compares false with everything, so it would slip past a
            // max search and then poison every probability in the row.
            if x.is_nan() {
                return None;
            }
            max = max.max(f64::from(x));
        }
        // -inf: empty window (or all -inf); +inf: the row cannot be normalized.
        if !max.is_finite() {
            return None;
        }
        let mut row: Vec<f64> = logit_row
            .iter()
            .zip(mask_row)
            .map(|(&x, &m)| if m == 1 { (f64::from(x) - max).exp() } else { 0.0 })
            .collect();
        // The arg-max entry contributes exp(0) = 1 and every term lies in
        // [0, 1], so `sum` is in [1, n]; this guard is purely defensive.
        let sum: f64 = row.iter().sum();
        if sum <= 0.0 || !sum.is_finite() {
            return None;
        }
        for p in &mut row {
            *p /= sum;
        }
        out.push(row);
    }
    Some(out)
}
/// Outcome of the SWA-001 check: the symmetric window mask equals its transpose.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Swa001Verdict {
    /// The mask was symmetric.
    Pass,
    /// The mask was asymmetric, or the inputs were degenerate.
    Fail,
}
/// SWA-001: checks that `swa_mask_symmetric(seq_len, w)` is a symmetric matrix.
///
/// Returns `Fail` for the degenerate inputs `seq_len == 0` or `w == 0`;
/// otherwise `Pass` iff `mask[i][j] == mask[j][i]` for every pair `(i, j)`.
#[must_use]
pub fn verdict_from_window_symmetry(seq_len: usize, w: usize) -> Swa001Verdict {
    if seq_len == 0 || w == 0 {
        return Swa001Verdict::Fail;
    }
    let mask = swa_mask_symmetric(seq_len, w);
    let is_symmetric =
        (0..seq_len).all(|row| (0..seq_len).all(|col| mask[row][col] == mask[col][row]));
    if is_symmetric {
        Swa001Verdict::Pass
    } else {
        Swa001Verdict::Fail
    }
}
/// Outcome of the SWA-002 check: the causal mask never exposes future keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Swa002Verdict {
    /// No entry above the diagonal was set.
    Pass,
    /// A future key was visible, or the inputs were degenerate.
    Fail,
}
/// SWA-002: checks that `swa_mask_causal(seq_len, w)` is zero strictly above
/// the diagonal, i.e. no query `i` attends to a key `j > i`.
///
/// Returns `Fail` for the degenerate inputs `seq_len == 0` or `w == 0`.
#[must_use]
pub fn verdict_from_causal_constraint(seq_len: usize, w: usize) -> Swa002Verdict {
    if seq_len == 0 || w == 0 {
        return Swa002Verdict::Fail;
    }
    let mask = swa_mask_causal(seq_len, w);
    let sees_future = mask
        .iter()
        .enumerate()
        .any(|(query, row)| row.iter().skip(query + 1).any(|&bit| bit != 0));
    if sees_future {
        Swa002Verdict::Fail
    } else {
        Swa002Verdict::Pass
    }
}
/// Outcome of the SWA-003 check: each causal row sees exactly `min(i + 1, w)` keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Swa003Verdict {
    /// Every row had the expected number of visible keys.
    Pass,
    /// Some row had the wrong count, or the inputs were degenerate.
    Fail,
}
/// SWA-003: checks that row `i` of `swa_mask_causal(seq_len, w)` contains
/// exactly `min(i + 1, w)` ones.
///
/// Returns `Fail` for the degenerate inputs `seq_len == 0` or `w == 0`.
#[must_use]
pub fn verdict_from_effective_context(seq_len: usize, w: usize) -> Swa003Verdict {
    if seq_len == 0 || w == 0 {
        return Swa003Verdict::Fail;
    }
    let mask = swa_mask_causal(seq_len, w);
    let every_row_matches = mask.iter().enumerate().all(|(query, row)| {
        let visible: usize = row.iter().map(|&bit| usize::from(bit)).sum();
        visible == w.min(query + 1)
    });
    if every_row_matches {
        Swa003Verdict::Pass
    } else {
        Swa003Verdict::Fail
    }
}
/// Outcome of the SWA-004 check: a window as wide as the sequence is plain causal attention.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Swa004Verdict {
    /// The full-width mask was exactly lower-triangular.
    Pass,
    /// The mask differed from lower-triangular, or `seq_len` was zero.
    Fail,
}
/// SWA-004: checks that `swa_mask_causal(seq_len, seq_len)` equals the dense
/// lower-triangular causal mask (`1` iff `j <= i`).
///
/// Returns `Fail` when `seq_len == 0`.
#[must_use]
pub fn verdict_from_dense_degeneration(seq_len: usize) -> Swa004Verdict {
    if seq_len == 0 {
        return Swa004Verdict::Fail;
    }
    let mask = swa_mask_causal(seq_len, seq_len);
    let is_lower_triangular = mask.iter().enumerate().all(|(query, row)| {
        row.iter()
            .enumerate()
            .all(|(key, &bit)| bit == u8::from(key <= query))
    });
    if is_lower_triangular {
        Swa004Verdict::Pass
    } else {
        Swa004Verdict::Fail
    }
}
/// Outcome of the SWA-005 check: the observed receptive field matches the formula.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Swa005Verdict {
    /// `observed` equalled the computed receptive field.
    Pass,
    /// Mismatch, degenerate inputs, or a receptive field that overflows `u32`.
    Fail,
}
/// Overflow-checked form of `1 + num_layers * (w - 1)`.
///
/// Returns `Some(0)` when either argument is `0`, and `None` when the exact
/// value does not fit in `u32`.
const fn checked_receptive_field(num_layers: u32, w: u32) -> Option<u32> {
    if w == 0 || num_layers == 0 {
        return Some(0);
    }
    // `?` and `Option::and_then` are not usable in a `const fn`, hence the match.
    match num_layers.checked_mul(w - 1) {
        Some(span) => span.checked_add(1),
        None => None,
    }
}
/// Receptive field, in positions, of `num_layers` stacked sliding-window layers of width `w`.
///
/// Computed as `1 + num_layers * (w - 1)`: the query position itself plus
/// `w - 1` additional positions per layer. Returns `0` when either argument
/// is `0`. If the exact value does not fit in `u32`, the result saturates at
/// `u32::MAX` instead of overflowing.
#[must_use]
pub const fn receptive_field(num_layers: u32, w: u32) -> u32 {
    match checked_receptive_field(num_layers, w) {
        Some(rf) => rf,
        None => u32::MAX,
    }
}
/// SWA-005: checks that `observed` equals the receptive field for `num_layers` and `w`.
///
/// Returns `Fail` for the degenerate inputs `num_layers == 0` or `w == 0`, for a
/// mismatch, and when the exact receptive field overflows `u32`. Failing on
/// overflow means a saturated value can never be mistaken for a genuine match.
#[must_use]
pub const fn verdict_from_receptive_field(
    num_layers: u32,
    w: u32,
    observed: u32,
) -> Swa005Verdict {
    if num_layers == 0 || w == 0 {
        return Swa005Verdict::Fail;
    }
    match checked_receptive_field(num_layers, w) {
        Some(rf) if rf == observed => Swa005Verdict::Pass,
        _ => Swa005Verdict::Fail,
    }
}
/// Absolute tolerance allowed on `|row_sum - 1|` by the SWA-006 normalization check.
pub const AC_SWA_006_TOLERANCE: f64 = 1e-9;
/// Outcome of the SWA-006 check: every windowed-softmax row sums to one.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Swa006Verdict {
    /// Every row sum was finite and within tolerance of `1`.
    Pass,
    /// The softmax rejected its inputs, or some row was not normalized.
    Fail,
}
/// SWA-006: checks that `windowed_softmax(logits, mask)` yields rows that sum
/// to `1` within [`AC_SWA_006_TOLERANCE`].
///
/// Returns `Fail` when `windowed_softmax` returns `None`, or when any row sum
/// is non-finite (NaN or infinite) or deviates from `1` by more than the tolerance.
#[must_use]
pub fn verdict_from_windowed_softmax_normalized(
    logits: &[Vec<f32>],
    mask: &[Vec<u8>],
) -> Swa006Verdict {
    let probs = match windowed_softmax(logits, mask) {
        Some(v) => v,
        None => return Swa006Verdict::Fail,
    };
    let all_normalized = probs.iter().all(|row| {
        let s: f64 = row.iter().sum();
        // Every comparison involving NaN is false, so testing only
        // `(s - 1.0).abs() > tol` would let a NaN sum pass; require finiteness.
        s.is_finite() && (s - 1.0).abs() <= AC_SWA_006_TOLERANCE
    });
    if all_normalized {
        Swa006Verdict::Pass
    } else {
        Swa006Verdict::Fail
    }
}
/// Outcome of the SWA-007 check: no causal row exposes more than `w` keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Swa007Verdict {
    /// Every row had at most `w` visible keys.
    Pass,
    /// Some row exceeded the bound, or the inputs were degenerate.
    Fail,
}
/// SWA-007: checks that every row of `swa_mask_causal(seq_len, w)` has at most `w` ones.
///
/// Returns `Fail` for the degenerate inputs `seq_len == 0` or `w == 0`.
#[must_use]
pub fn verdict_from_count_bound(seq_len: usize, w: usize) -> Swa007Verdict {
    if seq_len == 0 || w == 0 {
        return Swa007Verdict::Fail;
    }
    let within_bound = swa_mask_causal(seq_len, w)
        .iter()
        .all(|row| row.iter().map(|&bit| usize::from(bit)).sum::<usize>() <= w);
    if within_bound {
        Swa007Verdict::Pass
    } else {
        Swa007Verdict::Fail
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Reference values for the two mask builders.

    #[test]
    fn ref_symmetric_w3() {
        let mask = swa_mask_symmetric(5, 3);
        for i in 0..5_usize {
            for j in 0..5_usize {
                assert_eq!(mask[i][j], u8::from(i.abs_diff(j) < 3));
            }
        }
    }

    #[test]
    fn ref_causal_w3_seq4() {
        let mask = swa_mask_causal(4, 3);
        let expected: [[u8; 4]; 4] = [
            [1, 0, 0, 0],
            [1, 1, 0, 0],
            [1, 1, 1, 0],
            [0, 1, 1, 1],
        ];
        for (row, want) in expected.iter().enumerate() {
            assert_eq!(mask[row], want.to_vec());
        }
    }

    // SWA-001: symmetry.

    #[test]
    fn swa001_pass_w3() {
        assert_eq!(verdict_from_window_symmetry(8, 3), Swa001Verdict::Pass);
    }

    #[test]
    fn swa001_pass_w_eq_seq() {
        assert_eq!(verdict_from_window_symmetry(8, 8), Swa001Verdict::Pass);
    }

    #[test]
    fn swa001_fail_zero_w() {
        assert_eq!(verdict_from_window_symmetry(8, 0), Swa001Verdict::Fail);
    }

    #[test]
    fn swa001_fail_zero_seq() {
        assert_eq!(verdict_from_window_symmetry(0, 3), Swa001Verdict::Fail);
    }

    // SWA-002: causality.

    #[test]
    fn swa002_pass_w3() {
        assert_eq!(verdict_from_causal_constraint(8, 3), Swa002Verdict::Pass);
    }

    #[test]
    fn swa002_pass_dense() {
        assert_eq!(verdict_from_causal_constraint(8, 8), Swa002Verdict::Pass);
    }

    // SWA-003: effective context per row.

    #[test]
    fn swa003_pass_canonical() {
        assert_eq!(verdict_from_effective_context(10, 4), Swa003Verdict::Pass);
    }

    #[test]
    fn swa003_pass_w_gt_seq() {
        assert_eq!(verdict_from_effective_context(5, 100), Swa003Verdict::Pass);
    }

    // SWA-004: full-width window degenerates to dense causal attention.

    #[test]
    fn swa004_pass_seq8() {
        assert_eq!(verdict_from_dense_degeneration(8), Swa004Verdict::Pass);
    }

    #[test]
    fn swa004_pass_seq1() {
        assert_eq!(verdict_from_dense_degeneration(1), Swa004Verdict::Pass);
    }

    #[test]
    fn swa004_fail_zero() {
        assert_eq!(verdict_from_dense_degeneration(0), Swa004Verdict::Fail);
    }

    // SWA-005: stacked receptive field.

    #[test]
    fn swa005_pass_canonical() {
        let verdict = verdict_from_receptive_field(4, 8, 29);
        assert_eq!(verdict, Swa005Verdict::Pass);
    }

    #[test]
    fn swa005_pass_l1() {
        let verdict = verdict_from_receptive_field(1, 8, 8);
        assert_eq!(verdict, Swa005Verdict::Pass);
    }

    #[test]
    fn swa005_fail_off_by_one() {
        let verdict = verdict_from_receptive_field(4, 8, 30);
        assert_eq!(verdict, Swa005Verdict::Fail);
    }

    #[test]
    fn swa005_fail_zero_layers() {
        let verdict = verdict_from_receptive_field(0, 8, 1);
        assert_eq!(verdict, Swa005Verdict::Fail);
    }

    // SWA-006: windowed softmax normalization.

    #[test]
    fn swa006_pass_uniform_logits() {
        const N: usize = 4;
        let logits = vec![vec![1.0_f32; N]; N];
        let mask = swa_mask_causal(N, 3);
        let verdict = verdict_from_windowed_softmax_normalized(&logits, &mask);
        assert_eq!(verdict, Swa006Verdict::Pass);
    }

    #[test]
    fn swa006_pass_random_like() {
        const N: usize = 6;
        let logits: Vec<Vec<f32>> = (0..N)
            .map(|i| (0..N).map(|j| (i + j) as f32 * 0.3).collect())
            .collect();
        let mask = swa_mask_causal(N, 3);
        let verdict = verdict_from_windowed_softmax_normalized(&logits, &mask);
        assert_eq!(verdict, Swa006Verdict::Pass);
    }

    #[test]
    fn swa006_fail_dim_mismatch() {
        let logits = [vec![1.0_f32, 2.0]];
        let mask = [vec![1_u8, 1, 1]];
        let verdict = verdict_from_windowed_softmax_normalized(&logits, &mask);
        assert_eq!(verdict, Swa006Verdict::Fail);
    }

    // SWA-007: per-row visible-key bound.

    #[test]
    fn swa007_pass_w3_seq8() {
        assert_eq!(verdict_from_count_bound(8, 3), Swa007Verdict::Pass);
    }

    #[test]
    fn swa007_pass_w_gt_seq() {
        assert_eq!(verdict_from_count_bound(8, 32), Swa007Verdict::Pass);
    }

    #[test]
    fn swa007_fail_zero_w() {
        assert_eq!(verdict_from_count_bound(8, 0), Swa007Verdict::Fail);
    }

    // Constants and helpers.

    #[test]
    fn provenance_tolerance() {
        assert!((AC_SWA_006_TOLERANCE - 1e-9).abs() < 1e-15);
    }

    #[test]
    fn receptive_field_canonical() {
        let cases: [((u32, u32), u32); 5] = [
            ((1, 8), 8),
            ((2, 8), 15),
            ((4, 8), 29),
            ((0, 8), 0),
            ((4, 0), 0),
        ];
        for &((layers, w), expected) in cases.iter() {
            assert_eq!(receptive_field(layers, w), expected);
        }
    }
}