use std::collections::VecDeque;
use thiserror::Error;
/// Errors reported by the attention-sink KV cache types in this file.
#[derive(Debug, Error)]
pub enum SinkError {
/// A head index was not less than the layer's `num_heads`.
#[error("head {head} out of range (num_heads = {num_heads})")]
HeadOutOfRange { head: usize, num_heads: usize },
/// A layer index did not refer to an existing layer of the cache.
#[error("layer {layer} out of range (num_layers = {num_layers})")]
LayerOutOfRange { layer: usize, num_layers: usize },
/// A buffer had the wrong length. `AttentionSinkLayer::push` reports flat
/// element counts (`num_heads * head_dim`); `AttentionSinkCache::push_step`
/// also uses this variant when the number of per-layer buffers is wrong, in
/// which case `expected`/`actual` are buffer counts rather than element counts.
#[error("shape mismatch: expected {expected} elements, got {actual}")]
ShapeMismatch { expected: usize, actual: usize },
/// Per its message: fewer sink tokens have been pushed than are required.
/// NOTE(review): nothing in the visible part of this file constructs this
/// variant — confirm whether a caller elsewhere still needs it.
#[error("sink slots not yet filled (only {filled}/{total} sink tokens pushed)")]
SinkNotFilled { filled: usize, total: usize },
}
/// Sizing parameters for an attention-sink KV cache.
///
/// The cache keeps the first `num_sink_tokens` tokens permanently and, after
/// those, only the `window_size` most recently pushed tokens.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AttentionSinkConfig {
    /// Number of leading tokens that are stored as sinks and never evicted.
    pub num_sink_tokens: usize,
    /// Maximum number of non-sink tokens retained in the sliding window.
    pub window_size: usize,
}

impl AttentionSinkConfig {
    /// Builds a configuration from a sink count and a sliding-window size.
    pub fn new(num_sink_tokens: usize, window_size: usize) -> Self {
        Self {
            num_sink_tokens,
            window_size,
        }
    }

    /// Largest number of tokens the cache can hold per head: sinks plus window.
    ///
    /// The sum saturates at `usize::MAX` instead of overflowing, so extreme
    /// configurations neither panic (debug builds) nor wrap around to a tiny
    /// capacity (release builds).
    #[inline]
    pub fn capacity(&self) -> usize {
        self.num_sink_tokens.saturating_add(self.window_size)
    }

    /// Alias of [`capacity`](Self::capacity).
    #[inline]
    pub fn max_seq_len(&self) -> usize {
        self.capacity()
    }
}

impl Default for AttentionSinkConfig {
    /// Four sink tokens followed by a 512-token sliding window.
    fn default() -> Self {
        Self::new(4, 512)
    }
}
/// One cached token for a single attention head.
#[derive(Debug, Clone)]
pub struct SinkSlot {
    /// Value of the owning layer's token counter when this token was pushed.
    pub original_position: usize,
    /// Key vector of this token for one head.
    pub key: Vec<f32>,
    /// Value vector of this token for one head.
    pub value: Vec<f32>,
}

impl SinkSlot {
    /// Bundles a position with already-owned key and value vectors.
    fn new(original_position: usize, key: Vec<f32>, value: Vec<f32>) -> Self {
        SinkSlot { original_position, key, value }
    }
}
/// Attention-sink KV cache for one layer: for every attention head it keeps a
/// list of sink slots and a queue of recent slots.
pub struct AttentionSinkLayer {
/// Sink count and window size consulted by `push`.
config: AttentionSinkConfig,
/// Number of `f32` elements per head in each pushed key/value buffer.
head_dim: usize,
/// Number of attention heads; `sinks` and `recent` each hold one entry per head.
num_heads: usize,
/// Per-head sink slots; `push` appends here while the token counter is below
/// `config.num_sink_tokens`.
sinks: Vec<Vec<SinkSlot>>,
/// Per-head queue of the most recent non-sink slots, oldest at the front.
recent: Vec<VecDeque<SinkSlot>>,
/// Number of successful `push` calls so far; also used as the position of the
/// next token.
/// NOTE(review): this field is public, so external writes change which tokens
/// are treated as sinks — confirm that is intended.
pub total_tokens: usize,
/// Counter bumped by `push` each time a non-sink token arrives while the
/// recent window is already at `config.window_size`.
evicted: usize,
}
impl AttentionSinkLayer {
pub fn new(config: AttentionSinkConfig, num_heads: usize, head_dim: usize) -> Self {
let sinks = (0..num_heads).map(|_| Vec::new()).collect();
let recent = (0..num_heads)
.map(|_| VecDeque::with_capacity(config.window_size))
.collect();
Self {
config,
head_dim,
num_heads,
sinks,
recent,
total_tokens: 0,
evicted: 0,
}
}
#[inline]
fn head_key_slice(keys: &[f32], h: usize, head_dim: usize) -> &[f32] {
let start = h * head_dim;
&keys[start..start + head_dim]
}
#[inline]
fn head_value_slice(values: &[f32], h: usize, head_dim: usize) -> &[f32] {
let start = h * head_dim;
&values[start..start + head_dim]
}
pub fn push(&mut self, keys: &[f32], values: &[f32]) -> Result<(), SinkError> {
let expected = self.num_heads * self.head_dim;
if keys.len() != expected {
return Err(SinkError::ShapeMismatch {
expected,
actual: keys.len(),
});
}
if values.len() != expected {
return Err(SinkError::ShapeMismatch {
expected,
actual: values.len(),
});
}
let pos = self.total_tokens;
let is_sink = pos < self.config.num_sink_tokens;
for h in 0..self.num_heads {
let k = Self::head_key_slice(keys, h, self.head_dim).to_vec();
let v = Self::head_value_slice(values, h, self.head_dim).to_vec();
let slot = SinkSlot::new(pos, k, v);
if is_sink {
self.sinks[h].push(slot);
} else {
if self.recent[h].len() >= self.config.window_size {
if h == 0 {
self.evicted += 1;
}
self.recent[h].pop_front();
}
self.recent[h].push_back(slot);
}
}
self.total_tokens += 1;
Ok(())
}
pub fn get_remapped_positions(&self) -> Vec<usize> {
let sink_count = self.sinks.first().map(|s| s.len()).unwrap_or(0);
let recent_count = self.recent.first().map(|r| r.len()).unwrap_or(0);
let total = sink_count + recent_count;
let mut positions = Vec::with_capacity(total);
for i in 0..sink_count {
positions.push(i);
}
for j in 0..recent_count {
positions.push(sink_count + j);
}
positions
}
#[inline]
pub fn cache_len(&self) -> usize {
let sink_count = self.sinks.first().map(|s| s.len()).unwrap_or(0);
let recent_count = self.recent.first().map(|r| r.len()).unwrap_or(0);
sink_count + recent_count
}
#[inline]
pub fn recent_len(&self) -> usize {
self.recent.first().map(|r| r.len()).unwrap_or(0)
}
#[inline]
pub fn is_streaming(&self) -> bool {
self.evicted > 0
}
pub fn get_keys_for_head(&self, head: usize) -> Result<Vec<f32>, SinkError> {
if head >= self.num_heads {
return Err(SinkError::HeadOutOfRange {
head,
num_heads: self.num_heads,
});
}
let cap = self.cache_len() * self.head_dim;
let mut out = Vec::with_capacity(cap);
for slot in &self.sinks[head] {
out.extend_from_slice(&slot.key);
}
for slot in &self.recent[head] {
out.extend_from_slice(&slot.key);
}
Ok(out)
}
pub fn get_values_for_head(&self, head: usize) -> Result<Vec<f32>, SinkError> {
if head >= self.num_heads {
return Err(SinkError::HeadOutOfRange {
head,
num_heads: self.num_heads,
});
}
let cap = self.cache_len() * self.head_dim;
let mut out = Vec::with_capacity(cap);
for slot in &self.sinks[head] {
out.extend_from_slice(&slot.value);
}
for slot in &self.recent[head] {
out.extend_from_slice(&slot.value);
}
Ok(out)
}
#[inline]
pub fn evicted_count(&self) -> usize {
self.evicted
}
pub fn memory_bytes(&self) -> usize {
let bytes_per_slot = self.head_dim * std::mem::size_of::<f32>() * 2; let sink_slots: usize = self.sinks.iter().map(|s| s.len()).sum();
let recent_slots: usize = self.recent.iter().map(|r| r.len()).sum();
(sink_slots + recent_slots) * bytes_per_slot
}
}
/// A stack of `AttentionSinkLayer`s that are pushed to together, one per
/// model layer.
pub struct AttentionSinkCache {
/// One cache per layer, in layer order.
layers: Vec<AttentionSinkLayer>,
/// The configuration passed to `new`; each layer is built from a clone of it.
config: AttentionSinkConfig,
/// Layer count passed to `new`.
/// NOTE(review): this field is public, so it can be changed independently of
/// `layers` — confirm callers never write to it.
pub num_layers: usize,
}
impl AttentionSinkCache {
    /// Creates `num_layers` empty layers, each with `num_heads` heads of width
    /// `head_dim` and its own clone of `config`.
    pub fn new(
        num_layers: usize,
        num_heads: usize,
        head_dim: usize,
        config: AttentionSinkConfig,
    ) -> Self {
        let layers = (0..num_layers)
            .map(|_| AttentionSinkLayer::new(config.clone(), num_heads, head_dim))
            .collect();
        Self {
            layers,
            config,
            num_layers,
        }
    }

    /// Pushes one token into every layer.
    ///
    /// `all_keys[i]` and `all_values[i]` are the flat key/value buffers for
    /// layer `i`. Every buffer is validated before any layer is modified, so a
    /// failed call leaves all layers exactly as they were and the layers can
    /// never end up holding different numbers of tokens.
    ///
    /// # Errors
    /// [`SinkError::ShapeMismatch`] when the number of buffers differs from the
    /// number of layers (counts are reported in `expected`/`actual`), or when
    /// any buffer's length differs from that layer's `num_heads * head_dim`.
    pub fn push_step(
        &mut self,
        all_keys: &[Vec<f32>],
        all_values: &[Vec<f32>],
    ) -> Result<(), SinkError> {
        // Use the real layer count: the public `num_layers` field may have been
        // changed by a caller, and trusting it could index past the inputs.
        let layer_count = self.layers.len();
        for &actual in &[all_keys.len(), all_values.len()] {
            if actual != layer_count {
                return Err(SinkError::ShapeMismatch {
                    expected: layer_count,
                    actual,
                });
            }
        }

        // Validation pass: mirrors the length rule enforced by the layer's
        // `push`, so the mutation pass below cannot stop half-way.
        for ((layer, keys), values) in self.layers.iter().zip(all_keys).zip(all_values) {
            let expected = layer.num_heads * layer.head_dim;
            for &actual in &[keys.len(), values.len()] {
                if actual != expected {
                    return Err(SinkError::ShapeMismatch { expected, actual });
                }
            }
        }

        for ((layer, keys), values) in self.layers.iter_mut().zip(all_keys).zip(all_values) {
            layer.push(keys, values)?;
        }
        Ok(())
    }

    /// Cached keys of `head` in `layer`, sinks first.
    ///
    /// # Errors
    /// [`SinkError::LayerOutOfRange`] or [`SinkError::HeadOutOfRange`].
    pub fn get_keys_for_head(&self, layer: usize, head: usize) -> Result<Vec<f32>, SinkError> {
        self.layer(layer)?.get_keys_for_head(head)
    }

    /// Cached values of `head` in `layer`, sinks first.
    ///
    /// # Errors
    /// [`SinkError::LayerOutOfRange`] or [`SinkError::HeadOutOfRange`].
    pub fn get_values_for_head(&self, layer: usize, head: usize) -> Result<Vec<f32>, SinkError> {
        self.layer(layer)?.get_values_for_head(head)
    }

    /// Remapped cache positions of `layer`.
    ///
    /// # Errors
    /// [`SinkError::LayerOutOfRange`] when `layer` does not exist.
    pub fn get_remapped_positions(&self, layer: usize) -> Result<Vec<usize>, SinkError> {
        self.layer(layer).map(|l| l.get_remapped_positions())
    }

    /// Number of cached tokens, read from the first layer (0 with no layers).
    pub fn cache_len(&self) -> usize {
        self.layers.first().map_or(0, |l| l.cache_len())
    }

    /// Whether the first layer has started evicting (`false` with no layers).
    pub fn is_streaming(&self) -> bool {
        self.layers.first().map_or(false, |l| l.is_streaming())
    }

    /// Sum of every layer's eviction counter, so a token evicted from all
    /// layers contributes once per layer.
    pub fn total_evicted(&self) -> usize {
        self.layers.iter().map(|l| l.evicted_count()).sum()
    }

    /// The configuration this cache was created with.
    pub fn config(&self) -> &AttentionSinkConfig {
        &self.config
    }

    /// Bounds-checked layer lookup.
    #[inline]
    fn layer(&self, layer: usize) -> Result<&AttentionSinkLayer, SinkError> {
        self.layers
            .get(layer)
            .ok_or_else(|| SinkError::LayerOutOfRange {
                layer,
                num_layers: self.layers.len(),
            })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a flat `num_heads * head_dim` buffer filled with `val`.
    fn make_kv(num_heads: usize, head_dim: usize, val: f32) -> Vec<f32> {
        std::iter::repeat(val).take(num_heads * head_dim).collect()
    }

    /// The default configuration is 4 sinks plus a 512-token window.
    #[test]
    fn config_default_values() {
        let cfg = AttentionSinkConfig::default();
        assert_eq!(cfg.num_sink_tokens, 4);
        assert_eq!(cfg.window_size, 512);
        assert_eq!(cfg.capacity(), 516);
        assert_eq!(cfg.max_seq_len(), 516);
    }

    /// The first two tokens fill the sinks; the third lands in the window.
    #[test]
    fn push_sink_and_recent() {
        let mut layer = AttentionSinkLayer::new(AttentionSinkConfig::new(2, 3), 1, 4);

        let first = make_kv(1, 4, 1.0);
        layer.push(&first, &first).expect("push sink 0");
        let second = make_kv(1, 4, 2.0);
        layer.push(&second, &second).expect("push sink 1");
        assert_eq!(layer.cache_len(), 2);
        assert_eq!(layer.recent_len(), 0);
        assert!(!layer.is_streaming());

        let third = make_kv(1, 4, 3.0);
        layer.push(&third, &third).expect("push recent 0");
        assert_eq!(layer.cache_len(), 3);
        assert_eq!(layer.recent_len(), 1);
    }

    /// Overflowing the window evicts one token and flips the streaming flag.
    #[test]
    fn eviction_and_streaming_flag() {
        let mut layer = AttentionSinkLayer::new(AttentionSinkConfig::new(1, 2), 1, 2);
        for i in 0..3u32 {
            let kv = [i as f32; 2];
            layer.push(&kv, &kv).expect("push");
        }
        assert!(!layer.is_streaming());
        assert_eq!(layer.cache_len(), 3);

        let overflow = [9.0f32; 2];
        layer.push(&overflow, &overflow).expect("evicting push");
        assert!(layer.is_streaming());
        assert_eq!(layer.evicted_count(), 1);
        assert_eq!(layer.cache_len(), 3);
    }

    /// Remapped positions are the contiguous cache indices.
    #[test]
    fn remapped_positions_contiguous() {
        let mut layer = AttentionSinkLayer::new(AttentionSinkConfig::new(2, 3), 1, 2);
        for i in 0..4u32 {
            let kv = [i as f32; 2];
            layer.push(&kv, &kv).expect("push");
        }
        let expected: Vec<usize> = vec![0, 1, 2, 3];
        assert_eq!(layer.get_remapped_positions(), expected);
    }

    /// One `push_step` adds exactly one token to the multi-layer cache.
    #[test]
    fn multi_layer_cache_push_step() {
        let mut cache = AttentionSinkCache::new(3, 2, 8, AttentionSinkConfig::new(2, 4));
        let keys: Vec<Vec<f32>> = vec![vec![1.0f32; 16]; 3];
        let values: Vec<Vec<f32>> = vec![vec![2.0f32; 16]; 3];
        cache.push_step(&keys, &values).expect("push step");
        assert_eq!(cache.cache_len(), 1);
        assert!(!cache.is_streaming());
    }
}