/// Settings for sliding-window attention with leading sink tokens.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SlidingWindowConfig {
    /// Total number of positions a query may see once the sequence outgrows
    /// the window. A value of `0` disables the window entirely.
    pub window_size: usize,
    /// Number of leading positions (sinks) that stay visible regardless of
    /// how far the window has moved.
    pub sink_tokens: usize,
}
impl Default for SlidingWindowConfig {
fn default() -> Self {
Self {
window_size: 4096,
sink_tokens: 4,
}
}
}
impl SlidingWindowConfig {
pub fn new(window_size: usize, sink_tokens: usize) -> Self {
Self {
window_size,
sink_tokens,
}
}
pub fn is_disabled(&self) -> bool {
self.window_size == 0
}
}
/// Computes the key positions a query at `pos` may attend to.
///
/// Only the *visible prefix* is considered: the first `min(seq_len, pos + 1)`
/// positions, so nothing after the query and nothing past the end of the
/// sequence is ever returned.
///
/// * If `window_size` covers the whole visible prefix, every position in it
///   is returned.
/// * Otherwise the result is the leading sink positions followed by the most
///   recent positions ending at `pos`. The recent part is given
///   `window_size - sink_tokens` slots (saturating at zero) and never overlaps
///   the sinks.
///
/// Positions are returned in strictly ascending order.
///
/// # Returns
///
/// `(positions, positions.len())`. Both are empty / zero when the window is
/// disabled (`window_size == 0`) or `seq_len == 0`.
///
/// # Notes
///
/// Sinks are never trimmed: if `sink_tokens` exceeds `window_size`, the
/// returned count can exceed `window_size`.
pub fn attention_range(
    pos: usize,
    seq_len: usize,
    config: &SlidingWindowConfig,
) -> (Vec<usize>, usize) {
    if config.window_size == 0 || seq_len == 0 {
        return (Vec::new(), 0);
    }

    // One past the query position; saturating so `usize::MAX` cannot overflow.
    let query_end = pos.saturating_add(1);
    let visible_len = seq_len.min(query_end);

    // The window is at least as large as the visible prefix: keep all of it.
    if config.window_size >= visible_len {
        let positions: Vec<usize> = (0..visible_len).collect();
        return (positions, visible_len);
    }

    let sink_count = config.sink_tokens.min(visible_len);
    let recent_budget = config.window_size.saturating_sub(sink_count);

    // The recent window is anchored at the query, must not reach back into
    // the sinks, and must not run past the end of the sequence. When
    // `seq_len <= pos` this used to emit positions `>= seq_len`.
    let recent_start = query_end.saturating_sub(recent_budget).max(sink_count);
    let recent_end = visible_len;
    let recent_len = recent_end.saturating_sub(recent_start);

    let mut positions: Vec<usize> = Vec::with_capacity(sink_count + recent_len);
    positions.extend(0..sink_count);
    // An empty range (start >= end) simply contributes nothing.
    positions.extend(recent_start..recent_end);

    let count = positions.len();
    (positions, count)
}
/// Masks attention scores whose keys fall outside the sliding window.
///
/// `scores[i]` is the score for the key at position `key_positions[i]`. Every
/// score whose key position is not in
/// `attention_range(query_pos, query_pos + 1, config)` is overwritten with
/// `f32::NEG_INFINITY`; all other scores are left untouched. When the window
/// is disabled (`window_size == 0`) every score is overwritten.
///
/// `key_positions` may be in any order and may contain duplicates.
///
/// # Panics
///
/// In debug builds, panics if `scores` and `key_positions` differ in length.
/// In release builds only the common prefix of the two slices is processed.
pub fn apply_sliding_window_mask(
    scores: &mut [f32],
    query_pos: usize,
    key_positions: &[usize],
    config: &SlidingWindowConfig,
) {
    debug_assert_eq!(scores.len(), key_positions.len());

    if config.window_size == 0 {
        scores.fill(f32::NEG_INFINITY);
        return;
    }

    let (visible, _) = attention_range(query_pos, query_pos + 1, config);
    // `attention_range` lists sinks first and then the recent positions, both
    // ascending and non-overlapping, so the list is sorted. That allows an
    // O(log n) lookup per key instead of a linear `contains` scan.
    debug_assert!(visible.windows(2).all(|pair| pair[0] < pair[1]));

    for (score, key_pos) in scores.iter_mut().zip(key_positions) {
        if visible.binary_search(key_pos).is_err() {
            *score = f32::NEG_INFINITY;
        }
    }
}
/// Resets cache entries that are no longer visible from `current_pos`.
///
/// Only the first `min(cache.len(), current_pos + 1)` entries are examined;
/// entries beyond that are left untouched. Each examined entry whose index is
/// not in `attention_range(current_pos, seq_len, config)` is replaced with
/// `T::default()`.
///
/// # Returns
///
/// The number of entries that were reset. Entries that already held the
/// default value are reset (and counted) again. Returns `0` without touching
/// the cache when the window is disabled (`window_size == 0`) or the cache is
/// empty.
pub fn evict_outside_window<T: Default + Clone>(
    cache: &mut [T],
    current_pos: usize,
    config: &SlidingWindowConfig,
) -> usize {
    if config.window_size == 0 || cache.is_empty() {
        return 0;
    }

    let seq_len = cache.len().min(current_pos + 1);
    let (retained, _) = attention_range(current_pos, seq_len, config);
    // `attention_range` lists sinks first and then the recent positions, both
    // ascending and non-overlapping, so the list is sorted. That allows an
    // O(log n) lookup per entry instead of a linear `contains` scan.
    debug_assert!(retained.windows(2).all(|pair| pair[0] < pair[1]));

    let mut evicted = 0;
    for (pos, entry) in cache.iter_mut().enumerate().take(seq_len) {
        if retained.binary_search(&pos).is_err() {
            *entry = T::default();
            evicted += 1;
        }
    }
    evicted
}
/// Reports whether the key at `key_pos` is visible to the query at `query_pos`.
///
/// A key is visible when it is not after the query and is either one of the
/// first `sink_tokens` positions or fewer than `window_size - sink_tokens`
/// (saturating at zero) positions behind the query.
///
/// Always returns `false` when the window is disabled (`window_size == 0`) or
/// when `key_pos > query_pos`. The latter holds for sink positions too, so a
/// query can never see a key that comes after it.
#[inline]
pub fn is_in_window(key_pos: usize, query_pos: usize, config: &SlidingWindowConfig) -> bool {
    // The causal check must precede the sink check; otherwise a sink that
    // lies after the query would be reported as visible.
    if config.window_size == 0 || key_pos > query_pos {
        return false;
    }
    if key_pos < config.sink_tokens {
        return true;
    }
    let recent_budget = config.window_size.saturating_sub(config.sink_tokens);
    // `key_pos <= query_pos` is guaranteed above, so this cannot underflow.
    query_pos - key_pos < recent_budget
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Defaults are a 4096-position window with 4 sinks, which is enabled.
    #[test]
    fn default_config() {
        let cfg = SlidingWindowConfig::default();
        assert_eq!((cfg.window_size, cfg.sink_tokens), (4096, 4));
        assert!(!cfg.is_disabled());
    }

    /// A zero-sized window is reported as disabled.
    #[test]
    fn disabled_config() {
        assert!(SlidingWindowConfig::new(0, 0).is_disabled());
    }

    /// A sequence shorter than the window is returned in full.
    #[test]
    fn small_sequence_within_window() {
        let cfg = SlidingWindowConfig::new(8, 2);
        let (visible, total) = attention_range(4, 5, &cfg);
        assert_eq!(total, 5);
        assert_eq!(visible, vec![0, 1, 2, 3, 4]);
    }

    /// Once the sequence outgrows the window, sinks plus recent positions remain.
    #[test]
    fn window_slides_past_beginning() {
        let cfg = SlidingWindowConfig::new(4, 2);
        let (visible, total) = attention_range(10, 11, &cfg);
        assert_eq!(total, 4);
        assert_eq!(visible, vec![0, 1, 9, 10]);
    }

    /// Sink positions stay visible however far the query has advanced.
    #[test]
    fn sink_tokens_always_included() {
        let cfg = SlidingWindowConfig::new(4, 2);
        let (visible, _) = attention_range(100, 101, &cfg);
        assert!(visible.contains(&0));
        assert!(visible.contains(&1));
    }

    /// Without sinks the whole budget goes to the most recent positions.
    #[test]
    fn window_with_no_sinks() {
        let cfg = SlidingWindowConfig::new(3, 0);
        let (visible, total) = attention_range(10, 11, &cfg);
        assert_eq!(total, 3);
        assert_eq!(visible, vec![8, 9, 10]);
    }

    /// An empty sequence yields no positions.
    #[test]
    fn empty_sequence() {
        let cfg = SlidingWindowConfig::default();
        let (visible, total) = attention_range(0, 0, &cfg);
        assert_eq!(total, 0);
        assert!(visible.is_empty());
    }

    /// Scores outside the window become negative infinity; the rest are kept.
    #[test]
    fn mask_application() {
        let cfg = SlidingWindowConfig::new(3, 1);
        let keys: Vec<usize> = (0..6).collect();
        let mut scores = vec![1.0; 6];

        apply_sliding_window_mask(&mut scores, 5, &keys, &cfg);

        assert!(scores[0].is_finite(), "sink token should be kept");
        assert!(scores[1] == f32::NEG_INFINITY, "pos 1 should be masked");
        assert!(scores[2] == f32::NEG_INFINITY, "pos 2 should be masked");
        assert!(scores[3] == f32::NEG_INFINITY, "pos 3 should be masked");
        assert!(scores[4].is_finite(), "recent pos 4 should be kept");
        assert!(scores[5].is_finite(), "recent pos 5 should be kept");
    }

    /// Entries outside the window are reset to their default and counted.
    #[test]
    fn eviction() {
        let cfg = SlidingWindowConfig::new(3, 1);
        let mut cache: Vec<i32> = vec![10, 20, 30, 40, 50, 60];

        let removed = evict_outside_window(&mut cache, 5, &cfg);

        assert_eq!(removed, 3);
        assert_eq!(cache[0], 10, "sink should be preserved");
        assert_eq!(cache[1], 0, "pos 1 should be evicted");
        assert_eq!(cache[2], 0, "pos 2 should be evicted");
        assert_eq!(cache[3], 0, "pos 3 should be evicted");
        assert_eq!(cache[4], 50, "pos 4 should be preserved");
        assert_eq!(cache[5], 60, "pos 5 should be preserved");
    }

    /// Sinks and the two most recent positions are visible; position 98 is not.
    #[test]
    fn is_in_window_basic() {
        let cfg = SlidingWindowConfig::new(4, 2);
        for &key in &[0usize, 1, 100, 99] {
            assert!(is_in_window(key, 100, &cfg));
        }
        assert!(!is_in_window(98, 100, &cfg));
    }

    /// A disabled window masks every score.
    #[test]
    fn zero_window_masks_everything() {
        let cfg = SlidingWindowConfig::new(0, 0);
        let mut scores = vec![1.0, 2.0, 3.0];
        let keys = vec![0, 1, 2];

        apply_sliding_window_mask(&mut scores, 2, &keys, &cfg);

        for masked in scores.iter() {
            assert_eq!(*masked, f32::NEG_INFINITY);
        }
    }

    /// The number of visible positions never exceeds the configured window.
    #[test]
    fn window_total_count_bounded() {
        let cfg = SlidingWindowConfig::new(4, 2);
        let (visible, total) = attention_range(1000, 1001, &cfg);
        assert!(total <= cfg.window_size);
        assert_eq!(visible.len(), total);
    }
}