/// Returns `true` when query position `q` may attend to key position `k`.
///
/// Positions below `prompt_len` are "prompt" positions; positions at or above
/// it are "canvas" positions. The rules are:
///
/// * **Canvas query** – always sees every canvas key. On a non-windowed layer
///   it also sees the whole prompt; on a windowed layer (`swa_layer`) it only
///   sees prompt keys whose distance to `prompt_len` is strictly less than
///   `n_swa`.
/// * **Prompt query** – never sees canvas keys and never sees keys after
///   itself (`k > q`). On a windowed layer the key must additionally satisfy
///   `q - k < n_swa`.
///
/// `n_swa` is ignored when `swa_layer` is `false`. Any `n_swa` value is
/// accepted, including `usize::MAX`: the window checks are written as
/// subtractions so they cannot overflow.
#[inline]
pub fn allowed(q: usize, k: usize, prompt_len: usize, n_swa: usize, swa_layer: bool) -> bool {
    let q_is_canvas = q >= prompt_len;
    let k_is_canvas = k >= prompt_len;

    if q_is_canvas {
        if k_is_canvas || !swa_layer {
            return true;
        }
        // Here `k < prompt_len`, so the subtraction cannot underflow. This is
        // the overflow-free form of `k + n_swa > prompt_len`.
        return prompt_len - k < n_swa;
    }

    // Prompt rows: no canvas keys, no future keys.
    if k_is_canvas || k > q {
        return false;
    }

    // `k <= q` is guaranteed by the guard above, so `q - k` is safe.
    !swa_layer || q - k < n_swa
}
/// Builds a dense, row-major `n_tokens × n_tokens` additive attention mask.
///
/// The entry at index `q * n_tokens + k` is `0.0` when
/// `allowed(q, k, prompt_len, n_swa, swa_layer)` holds and
/// `f32::NEG_INFINITY` otherwise. An `n_tokens` of zero yields an empty
/// vector.
pub fn build_unified_mask(
    n_tokens: usize,
    prompt_len: usize,
    n_swa: usize,
    swa_layer: bool,
) -> Vec<f32> {
    let mut mask = Vec::with_capacity(n_tokens * n_tokens);
    for q in 0..n_tokens {
        // Append one full query row at a time, keeping row-major order.
        mask.extend((0..n_tokens).map(|k| {
            if allowed(q, k, prompt_len, n_swa, swa_layer) {
                0.0
            } else {
                f32::NEG_INFINITY
            }
        }));
    }
    mask
}
#[cfg(test)]
mod tests {
    use super::*;

    /// On a non-windowed layer, prompt rows are causal within the prompt and
    /// never attend to any canvas key.
    #[test]
    fn prompt_rows_are_causal_and_never_see_canvas() {
        let prompt_len = 4usize;
        let canvas_len = 3usize;
        let n_swa = 1024usize;
        let total = prompt_len + canvas_len;
        for q in 0..prompt_len {
            for k in 0..total {
                let visible = allowed(q, k, prompt_len, n_swa, false);
                if k < prompt_len {
                    assert_eq!(visible, k <= q, "prompt q{q}→k{k} must be causal");
                } else {
                    assert!(!visible, "prompt q{q} must not see canvas k{k}");
                }
            }
        }
    }

    /// On a non-windowed layer, canvas rows attend to every position.
    #[test]
    fn canvas_rows_are_bidirectional_global() {
        let prompt_len = 4usize;
        let canvas_len = 3usize;
        let total = prompt_len + canvas_len;
        for q in prompt_len..total {
            for k in 0..total {
                let visible = allowed(q, k, prompt_len, 1024, false);
                assert!(
                    visible,
                    "global canvas q{q} must attend everywhere (k{k})"
                );
            }
        }
    }

    /// On a windowed layer, canvas rows see only the tail of the prompt but
    /// still see the entire canvas.
    #[test]
    fn canvas_rows_swa_window_the_prompt_but_keep_full_canvas() {
        let prompt_len = 5usize;
        let canvas_len = 4usize;
        let n_swa = 3usize;
        let total = prompt_len + canvas_len;
        // First prompt key that is still expected to be visible from the canvas.
        let first_visible = prompt_len - (n_swa - 1);
        for q in prompt_len..total {
            for k in 0..total {
                let visible = allowed(q, k, prompt_len, n_swa, true);
                if k < prompt_len {
                    assert_eq!(visible, k >= first_visible, "swa canvas q{q}→prompt k{k}");
                } else {
                    assert!(visible, "swa canvas keeps all canvas");
                }
            }
        }
    }

    /// On a windowed layer, prompt rows are causal and limited to the last
    /// `n_swa` keys (self included).
    #[test]
    fn prompt_rows_swa_are_causal_windowed() {
        let prompt_len = 6usize;
        let n_swa = 2usize;
        for q in 0..prompt_len {
            for k in 0..prompt_len {
                let expected = k <= q && q - k < n_swa;
                let visible = allowed(q, k, prompt_len, n_swa, true);
                assert_eq!(visible, expected, "swa prompt q{q}→k{k}");
            }
        }
    }

    /// Every cell of the dense mask agrees with the per-pair predicate.
    #[test]
    fn full_mask_matches_allowed() {
        const N: usize = 7;
        let mask = build_unified_mask(N, 4, 3, true);
        for q in 0..N {
            for k in 0..N {
                let expected = if allowed(q, k, 4, 3, true) {
                    0.0
                } else {
                    f32::NEG_INFINITY
                };
                assert_eq!(mask[q * N + k], expected);
            }
        }
    }
}