/// Configuration for [`SlidingWindowAttention`].
///
/// Because of `#[non_exhaustive]`, code outside this crate cannot build this
/// struct with a literal; start from [`SlidingWindowConfig::default`] and
/// assign the public fields instead.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct SlidingWindowConfig {
    /// How many positions on each side of a (non-global) query it may attend to.
    pub window_size: usize,
    /// How many leading positions are global: they attend to every position
    /// and every query attends to them.
    pub global_tokens: usize,
}

impl Default for SlidingWindowConfig {
    /// A 512-position window on each side and no global tokens.
    fn default() -> Self {
        let window_size = 512;
        let global_tokens = 0;
        Self {
            window_size,
            global_tokens,
        }
    }
}
/// Single-head sliding-window (local) attention with optional global tokens.
///
/// All buffers are flat and row-major: row `i` of `q`, `k`, `v`, and of the
/// output occupies elements `[i * head_dim, (i + 1) * head_dim)`.
#[derive(Debug, Clone)]
pub struct SlidingWindowAttention {
    config: SlidingWindowConfig,
}

impl SlidingWindowAttention {
    /// Creates an attention operator that uses `config` for every call to
    /// [`forward`](Self::forward).
    pub fn new(config: SlidingWindowConfig) -> Self {
        Self { config }
    }

    /// Computes attention for all `seq_len` query rows.
    ///
    /// A non-global query `qi` attends to keys `qi - window_size ..= qi + window_size`
    /// (clamped to `0..seq_len`) plus every global key (`index < global_tokens`).
    /// A global query (`qi < global_tokens`) attends to every key. Scores are
    /// `q_row · k_row / sqrt(head_dim)`, softmax-normalized over the attended keys,
    /// and the output row is the weighted sum of the matching `v` rows.
    ///
    /// Returns a row-major buffer of length `seq_len * head_dim`; input elements
    /// beyond `seq_len * head_dim` are ignored. If `seq_len` or `head_dim` is
    /// zero, or any of `q`, `k`, `v` holds fewer than `seq_len * head_dim`
    /// elements, an all-zero buffer of that length is returned instead of
    /// panicking.
    pub fn forward(
        &self,
        q: &[f64],
        k: &[f64],
        v: &[f64],
        seq_len: usize,
        head_dim: usize,
    ) -> Vec<f64> {
        let expected = seq_len * head_dim;
        // All three buffers are sliced row-by-row below, so all three must be
        // long enough; checking only `q` let a short `k`/`v` panic.
        let inputs_too_short = [q, k, v].iter().any(|buf| buf.len() < expected);
        if seq_len == 0 || head_dim == 0 || inputs_too_short {
            return vec![0.0; expected];
        }
        let scale = 1.0 / (head_dim as f64).sqrt();
        let window = self.config.window_size;
        let global = self.config.global_tokens;
        let mut out = vec![0.0f64; expected];
        // `head_dim > 0` is guaranteed by the guard above, as chunks_exact_mut requires.
        for (qi, out_row) in out.chunks_exact_mut(head_dim).enumerate() {
            let (k_start, k_end) = if qi < global {
                (0, seq_len)
            } else {
                // Saturating arithmetic: a very large window (e.g. `usize::MAX`
                // to mean "attend everywhere") must clamp, not overflow.
                let end = qi.saturating_add(window).saturating_add(1).min(seq_len);
                (qi.saturating_sub(window), end)
            };
            let row = self.compute_local_attention(
                qi, k_start, k_end, q, k, v, head_dim, scale, global, seq_len,
            );
            out_row.copy_from_slice(&row);
        }
        out
    }

    /// Computes the attention output row (`head_dim` values) for the single
    /// query at `q_pos`.
    ///
    /// The attended keys are the union of `k_start..k_end` and the global
    /// positions `0..min(global, seq_len)`, visited in ascending order without
    /// duplicates. Scores are `q_row · k_row * scale`, softmax-normalized; the
    /// result is the weighted sum of the corresponding `v` rows. Returns zeros
    /// when no key is attended.
    ///
    /// # Panics
    ///
    /// Panics if row `q_pos` of `q`, or any attended row of `k`/`v`, lies
    /// outside its buffer. [`forward`](Self::forward) validates buffer lengths
    /// before calling this.
    // Public signature kept as-is for backward compatibility.
    #[allow(clippy::too_many_arguments)]
    pub fn compute_local_attention(
        &self,
        q_pos: usize,
        k_start: usize,
        k_end: usize,
        q: &[f64],
        k: &[f64],
        v: &[f64],
        head_dim: usize,
        scale: f64,
        global: usize,
        seq_len: usize,
    ) -> Vec<f64> {
        let q_row = &q[q_pos * head_dim..(q_pos + 1) * head_dim];
        let key_positions = Self::attended_keys(k_start, k_end, global.min(seq_len));
        if key_positions.is_empty() {
            return vec![0.0f64; head_dim];
        }
        let mut weights: Vec<f64> = key_positions
            .iter()
            .map(|&kj| {
                let k_row = &k[kj * head_dim..(kj + 1) * head_dim];
                let dot: f64 = q_row.iter().zip(k_row).map(|(&a, &b)| a * b).sum();
                dot * scale
            })
            .collect();
        Self::softmax_in_place(&mut weights);
        let mut out = vec![0.0f64; head_dim];
        for (&weight, &kj) in weights.iter().zip(&key_positions) {
            let v_row = &v[kj * head_dim..(kj + 1) * head_dim];
            for (o, &vv) in out.iter_mut().zip(v_row) {
                *o += weight * vv;
            }
        }
        out
    }

    /// Returns the configuration this operator was built with.
    pub fn config(&self) -> &SlidingWindowConfig {
        &self.config
    }

    /// Sorted, deduplicated union of the window `k_start..k_end` and the global
    /// prefix `0..global_end`.
    fn attended_keys(k_start: usize, k_end: usize, global_end: usize) -> Vec<usize> {
        let window = k_start..k_end;
        let mut keys: Vec<usize> = window.clone().collect();
        keys.extend((0..global_end).filter(|g| !window.contains(g)));
        keys.sort_unstable();
        keys.dedup();
        keys
    }

    /// Max-shifted softmax over `scores`, in place. Falls back to a uniform
    /// distribution when the exponentials do not sum to a positive value
    /// (e.g. NaN or infinite scores).
    fn softmax_in_place(scores: &mut [f64]) {
        let max_s = scores.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        for s in scores.iter_mut() {
            *s = (*s - max_s).exp();
        }
        let sum_s: f64 = scores.iter().sum();
        if sum_s > 0.0 {
            for s in scores.iter_mut() {
                *s /= sum_s;
            }
        } else {
            let uniform = 1.0 / scores.len() as f64;
            scores.fill(uniform);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic, strictly increasing test data `0.05, 0.10, 0.15, ...`
    /// with `seq_len * head_dim` elements.
    fn arithmetic_qkv(seq_len: usize, head_dim: usize) -> Vec<f64> {
        let n = seq_len * head_dim;
        (0..n).map(|i| (i as f64 + 1.0) * 0.05).collect()
    }

    /// Value rows where every element of row `i` equals `i + 1`. With uniform
    /// attention weights, an output element is the mean of the attended labels.
    fn labelled_values(seq_len: usize, head_dim: usize) -> Vec<f64> {
        (0..seq_len)
            .flat_map(|i| vec![(i + 1) as f64; head_dim])
            .collect()
    }

    /// Reference dense softmax attention over every key position; the oracle
    /// that sliding-window attention must match when the window covers everything.
    fn dense_attention(
        q: &[f64],
        k: &[f64],
        v: &[f64],
        seq_len: usize,
        head_dim: usize,
    ) -> Vec<f64> {
        let scale = 1.0 / (head_dim as f64).sqrt();
        let mut out = vec![0.0f64; seq_len * head_dim];
        for qi in 0..seq_len {
            let q_row = &q[qi * head_dim..(qi + 1) * head_dim];
            let mut scores: Vec<f64> = (0..seq_len)
                .map(|kj| {
                    let k_row = &k[kj * head_dim..(kj + 1) * head_dim];
                    let dot: f64 = q_row.iter().zip(k_row.iter()).map(|(a, b)| a * b).sum();
                    dot * scale
                })
                .collect();
            let max_s = scores.iter().copied().fold(f64::NEG_INFINITY, f64::max);
            for s in &mut scores {
                *s = (*s - max_s).exp();
            }
            let sum_s: f64 = scores.iter().sum();
            for s in &mut scores {
                *s /= sum_s;
            }
            let out_row = &mut out[qi * head_dim..(qi + 1) * head_dim];
            for kj in 0..seq_len {
                let v_row = &v[kj * head_dim..(kj + 1) * head_dim];
                for (o, &vv) in out_row.iter_mut().zip(v_row.iter()) {
                    *o += scores[kj] * vv;
                }
            }
        }
        out
    }

    #[test]
    fn test_sliding_window_vs_full_for_short_seq() {
        let seq_len = 6;
        let head_dim = 4;
        let window_size = 8;
        let q = arithmetic_qkv(seq_len, head_dim);
        let k = arithmetic_qkv(seq_len, head_dim);
        let v = arithmetic_qkv(seq_len, head_dim);
        let cfg = SlidingWindowConfig {
            window_size,
            global_tokens: 0,
        };
        let swa = SlidingWindowAttention::new(cfg);
        let swa_out = swa.forward(&q, &k, &v, seq_len, head_dim);
        let dense_out = dense_attention(&q, &k, &v, seq_len, head_dim);
        assert_eq!(swa_out.len(), dense_out.len());
        for (a, b) in swa_out.iter().zip(dense_out.iter()) {
            assert!(
                (a - b).abs() < 1e-9,
                "sliding-window with large window should match dense: {a:.8} vs {b:.8}"
            );
        }
    }

    #[test]
    fn test_sliding_window_attends_only_to_neighbors() {
        let seq_len = 8;
        let head_dim = 2;
        let window_size = 1;
        let q = vec![1.0f64; seq_len * head_dim];
        let k = q.clone();
        let v = labelled_values(seq_len, head_dim);
        let cfg = SlidingWindowConfig {
            window_size,
            global_tokens: 0,
        };
        let swa = SlidingWindowAttention::new(cfg);
        let out = swa.forward(&q, &k, &v, seq_len, head_dim);
        let pos4_val = out[4 * head_dim];
        let expected = (4.0 + 5.0 + 6.0) / 3.0;
        assert!(
            (pos4_val - expected).abs() < 1e-9,
            "pos4 should attend to neighbors: got {pos4_val:.6}, expected {expected:.6}"
        );
    }

    #[test]
    fn test_sliding_window_boundary_positions() {
        let seq_len = 4;
        let head_dim = 2;
        let window_size = 1;
        let q = vec![1.0f64; seq_len * head_dim];
        let k = q.clone();
        let v = labelled_values(seq_len, head_dim);
        let cfg = SlidingWindowConfig {
            window_size,
            global_tokens: 0,
        };
        let swa = SlidingWindowAttention::new(cfg);
        let out = swa.forward(&q, &k, &v, seq_len, head_dim);
        let pos0_val = out[0];
        assert!(
            (pos0_val - 1.5).abs() < 1e-9,
            "boundary pos0: got {pos0_val:.6}"
        );
        let pos3_val = out[3 * head_dim];
        assert!(
            (pos3_val - 3.5).abs() < 1e-9,
            "boundary pos3: got {pos3_val:.6}"
        );
    }

    #[test]
    fn test_sliding_window_global_tokens() {
        let seq_len = 6;
        let head_dim = 2;
        let window_size = 1;
        let q = vec![1.0f64; seq_len * head_dim];
        let k = q.clone();
        let v = labelled_values(seq_len, head_dim);
        let cfg = SlidingWindowConfig {
            window_size,
            global_tokens: 1,
        };
        let swa = SlidingWindowAttention::new(cfg);
        let out = swa.forward(&q, &k, &v, seq_len, head_dim);
        let pos0_val = out[0];
        let expected = (1.0 + 2.0 + 3.0 + 4.0 + 5.0 + 6.0) / 6.0;
        assert!(
            (pos0_val - expected).abs() < 1e-9,
            "global pos0 should attend all: got {pos0_val:.6}, expected {expected:.6}"
        );
        // A non-global query sees its window {2, 3, 4} plus global key 0.
        let pos3_val = out[3 * head_dim];
        let expected3 = (1.0 + 3.0 + 4.0 + 5.0) / 4.0;
        assert!(
            (pos3_val - expected3).abs() < 1e-9,
            "pos3 should attend window plus global: got {pos3_val:.6}, expected {expected3:.6}"
        );
    }

    #[test]
    fn test_sliding_window_global_token_included_by_all() {
        let seq_len = 8;
        let head_dim = 2;
        let window_size = 1;
        let global_tokens = 1;
        let q = vec![1.0f64; seq_len * head_dim];
        let k = q.clone();
        let mut v = vec![1.0f64; seq_len * head_dim];
        v[0] = 100.0;
        v[1] = 100.0;
        let cfg = SlidingWindowConfig {
            window_size,
            global_tokens,
        };
        let swa = SlidingWindowAttention::new(cfg);
        let out = swa.forward(&q, &k, &v, seq_len, head_dim);
        // pos5 attends uniformly to keys {0, 4, 5, 6}: (100 + 1 + 1 + 1) / 4.
        let pos5 = out[5 * head_dim];
        let expected = (100.0 + 1.0 + 1.0 + 1.0) / 4.0;
        assert!(
            (pos5 - expected).abs() < 1e-9,
            "pos5 should include the global token: got {pos5:.4}, expected {expected:.4}"
        );
    }

    #[test]
    fn test_sliding_window_output_shape() {
        let seq_len = 10;
        let head_dim = 8;
        let q = arithmetic_qkv(seq_len, head_dim);
        let cfg = SlidingWindowConfig::default();
        let swa = SlidingWindowAttention::new(cfg);
        let out = swa.forward(&q, &q, &q, seq_len, head_dim);
        assert_eq!(out.len(), seq_len * head_dim);
        assert!(out.iter().all(|x| x.is_finite()));
    }

    #[test]
    fn test_sliding_window_empty_sequence() {
        let cfg = SlidingWindowConfig::default();
        let swa = SlidingWindowAttention::new(cfg);
        // seq_len * head_dim == 0 in both cases, so the output must be empty.
        let out = swa.forward(&[], &[], &[], 0, 4);
        assert!(out.is_empty(), "expected empty output, got {out:?}");
        let out = swa.forward(&[], &[], &[], 3, 0);
        assert!(out.is_empty(), "expected empty output, got {out:?}");
    }

    #[test]
    fn test_sliding_window_config_default() {
        let cfg = SlidingWindowConfig::default();
        assert_eq!(cfg.window_size, 512);
        assert_eq!(cfg.global_tokens, 0);
    }
}