/// Sliding-window attention: each query position attends to at most
/// `window_size` key positions instead of the full sequence, bounding
/// per-query work and memory.
#[derive(Debug, Clone)]
pub struct SlidingWindowAttention {
    /// Dimension of each attention head.
    head_dim: usize,
    /// Softmax scaling factor, `1 / sqrt(head_dim)`.
    scale: f32,
    /// Maximum number of key positions each query attends to.
    window_size: usize,
}
impl SlidingWindowAttention {
    /// Creates a new sliding-window attention layer.
    ///
    /// Returns an error if `head_dim` or `window_size` is zero.
    pub fn new(head_dim: usize, window_size: usize) -> Result<Self> {
if head_dim == 0 {
return Err(RealizarError::InvalidShape {
reason: "head_dim must be > 0".to_string(),
});
}
if window_size == 0 {
return Err(RealizarError::InvalidShape {
reason: "window_size must be > 0".to_string(),
});
}
#[allow(clippy::cast_precision_loss)]
let scale = 1.0 / (head_dim as f32).sqrt();
Ok(Self {
head_dim,
scale,
window_size,
})
}
    /// Computes causal sliding-window attention.
    ///
    /// Query position `i` attends to key positions `[i + 1 - window_size, i]`
    /// (clamped to the sequence). Inputs are `[seq_len, head_dim]` (or
    /// `[head_dim]` for a single position); the output is
    /// `[q_seq_len, head_dim]`.
    pub fn forward(
        &self,
        query: &Tensor<f32>,
        key: &Tensor<f32>,
        value: &Tensor<f32>,
    ) -> Result<Tensor<f32>> {
let q_shape = query.shape();
let k_shape = key.shape();
let v_shape = value.shape();
if q_shape.is_empty() || k_shape.is_empty() || v_shape.is_empty() {
return Err(RealizarError::InvalidShape {
reason: "Query, key, value tensors must have at least 1 dimension".to_string(),
});
}
let q_last = q_shape[q_shape.len() - 1];
let k_last = k_shape[k_shape.len() - 1];
let v_last = v_shape[v_shape.len() - 1];
if q_last != self.head_dim || k_last != self.head_dim || v_last != self.head_dim {
return Err(RealizarError::InvalidShape {
reason: format!(
"Expected head_dim={}, got Q={}, K={}, V={}",
self.head_dim, q_last, k_last, v_last
),
});
}
let q_seq_len = if q_shape.len() > 1 { q_shape[0] } else { 1 };
let k_seq_len = if k_shape.len() > 1 { k_shape[0] } else { 1 };
let v_seq_len = if v_shape.len() > 1 { v_shape[0] } else { 1 };
if k_seq_len != v_seq_len {
return Err(RealizarError::InvalidShape {
reason: format!("Key seq_len {k_seq_len} != Value seq_len {v_seq_len}"),
});
}
let q_data = query.data();
let k_data = key.data();
let v_data = value.data();
let mut output = Vec::with_capacity(q_seq_len * self.head_dim);
for i in 0..q_seq_len {
            // Causal window: the last `window_size` key positions up to and
            // including position `i`.
            let window_end = (i + 1).min(k_seq_len);
            let window_start = window_end.saturating_sub(self.window_size);
            let window_len = window_end - window_start;
if window_len == 0 {
output.extend(std::iter::repeat_n(0.0, self.head_dim));
continue;
}
            // Scaled dot-product scores against the keys in the window.
            let mut scores = Vec::with_capacity(window_len);
for j in window_start..window_end {
let mut dot = 0.0;
for k in 0..self.head_dim {
dot += q_data[i * self.head_dim + k] * k_data[j * self.head_dim + k];
}
scores.push(dot * self.scale);
}
            // Numerically stable softmax over the windowed scores.
            let max_score = scores.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
let mut exp_sum = 0.0;
for score in &mut scores {
let exp_val = (*score - max_score).exp();
*score = exp_val;
exp_sum += exp_val;
}
let inv_sum = 1.0 / exp_sum;
for score in &mut scores {
*score *= inv_sum;
}
            // Weighted sum of the windowed values.
            for k in 0..self.head_dim {
let mut sum = 0.0;
for (idx, j) in (window_start..window_end).enumerate() {
sum += scores[idx] * v_data[j * self.head_dim + k];
}
output.push(sum);
}
}
Tensor::from_vec(vec![q_seq_len, self.head_dim], output)
}
    /// Computes sliding-window attention with a selectable mask.
    ///
    /// With `causal == true` this is identical to [`Self::forward`].
    /// Otherwise each query position attends to a symmetric window of up to
    /// `window_size / 2` positions on either side of itself.
    pub fn forward_with_mask(
        &self,
        query: &Tensor<f32>,
        key: &Tensor<f32>,
        value: &Tensor<f32>,
        causal: bool,
    ) -> Result<Tensor<f32>> {
if causal {
return self.forward(query, key, value);
}
let q_shape = query.shape();
let k_shape = key.shape();
let v_shape = value.shape();
if q_shape.is_empty() || k_shape.is_empty() || v_shape.is_empty() {
return Err(RealizarError::InvalidShape {
reason: "Query, key, value tensors must have at least 1 dimension".to_string(),
});
}
let q_last = q_shape[q_shape.len() - 1];
let k_last = k_shape[k_shape.len() - 1];
let v_last = v_shape[v_shape.len() - 1];
if q_last != self.head_dim || k_last != self.head_dim || v_last != self.head_dim {
return Err(RealizarError::InvalidShape {
reason: format!(
"Expected head_dim={}, got Q={}, K={}, V={}",
self.head_dim, q_last, k_last, v_last
),
});
}
let q_seq_len = if q_shape.len() > 1 { q_shape[0] } else { 1 };
let k_seq_len = if k_shape.len() > 1 { k_shape[0] } else { 1 };
let v_seq_len = if v_shape.len() > 1 { v_shape[0] } else { 1 };
if k_seq_len != v_seq_len {
return Err(RealizarError::InvalidShape {
reason: format!("Key seq_len {k_seq_len} != Value seq_len {v_seq_len}"),
});
}
let q_data = query.data();
let k_data = key.data();
let v_data = value.data();
let mut output = Vec::with_capacity(q_seq_len * self.head_dim);
        let half_window = self.window_size / 2;
        for i in 0..q_seq_len {
            // Symmetric window centred on `i`; clamp the start so the length
            // cannot underflow when the query is longer than the key sequence.
            let window_end = (i + half_window + 1).min(k_seq_len);
            let window_start = i.saturating_sub(half_window).min(window_end);
            let window_len = window_end - window_start;
if window_len == 0 {
output.extend(std::iter::repeat_n(0.0, self.head_dim));
continue;
}
let mut scores = Vec::with_capacity(window_len);
for j in window_start..window_end {
let mut dot = 0.0;
for k in 0..self.head_dim {
dot += q_data[i * self.head_dim + k] * k_data[j * self.head_dim + k];
}
scores.push(dot * self.scale);
}
let max_score = scores.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
let mut exp_sum = 0.0;
for score in &mut scores {
let exp_val = (*score - max_score).exp();
*score = exp_val;
exp_sum += exp_val;
}
let inv_sum = 1.0 / exp_sum;
for score in &mut scores {
*score *= inv_sum;
}
for k in 0..self.head_dim {
let mut sum = 0.0;
for (idx, j) in (window_start..window_end).enumerate() {
sum += scores[idx] * v_data[j * self.head_dim + k];
}
output.push(sum);
}
}
Tensor::from_vec(vec![q_seq_len, self.head_dim], output)
}
    /// Returns the per-head dimension.
    #[must_use]
    pub fn head_dim(&self) -> usize {
        self.head_dim
    }

    /// Returns the softmax scaling factor, `1 / sqrt(head_dim)`.
    #[must_use]
    pub fn scale(&self) -> f32 {
        self.scale
    }

    /// Returns the attention window size.
    #[must_use]
    pub fn window_size(&self) -> usize {
        self.window_size
    }
    /// Returns how many key positions a query at `position` actually attends
    /// to under the causal window, given a sequence of `seq_len` tokens.
    #[must_use]
    pub fn effective_context(&self, position: usize, seq_len: usize) -> usize {
        let window_end = (position + 1).min(seq_len);
        let window_start = window_end.saturating_sub(self.window_size);
        window_end - window_start
    }
    /// Returns the fraction of key positions attended per query relative to
    /// full attention over `seq_len`; values below `1.0` indicate savings.
    #[must_use]
    pub fn memory_ratio(&self, seq_len: usize) -> f32 {
        if seq_len == 0 {
            return 1.0;
        }
#[allow(clippy::cast_precision_loss)]
{
(self.window_size.min(seq_len) as f32) / (seq_len as f32)
}
}
}
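
// A minimal usage sketch of the causal sliding window, assuming the
// `Tensor::from_vec(shape, data)` constructor and `shape()` accessor used
// above; the concrete values and test name are illustrative only.
#[cfg(test)]
mod sliding_window_attention_sketch {
    use super::*;

    #[test]
    fn causal_window_bounds_the_context() {
        // head_dim = 2, window_size = 2: each position sees at most itself
        // and the previous token.
        let attn = SlidingWindowAttention::new(2, 2).expect("valid dims");
        let q = Tensor::from_vec(vec![4, 2], vec![0.1; 8]).expect("query");
        let k = Tensor::from_vec(vec![4, 2], vec![0.2; 8]).expect("key");
        let v = Tensor::from_vec(vec![4, 2], vec![0.3; 8]).expect("value");

        let out = attn.forward(&q, &k, &v).expect("forward");
        assert_eq!(out.shape()[0], 4);
        assert_eq!(out.shape()[1], 2);

        // Position 0 can only see itself; later positions see two tokens.
        assert_eq!(attn.effective_context(0, 4), 1);
        assert_eq!(attn.effective_context(3, 4), 2);
        // With a window of 2 over 4 tokens, at most half the keys are attended.
        assert!((attn.memory_ratio(4) - 0.5).abs() < 1e-6);
    }
}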
/// Multi-head attention whose query, key, value, and output projection
/// weights are stored together for a fused forward pass.
#[derive(Debug, Clone)]
pub struct FusedQKVAttention {
    /// Dimension of each attention head.
    head_dim: usize,
    /// Model hidden dimension.
    hidden_dim: usize,
    /// Number of attention heads.
    num_heads: usize,
    /// Softmax scaling factor, `1 / sqrt(head_dim)`.
    scale: f32,
    /// Query projection weights (flattened).
    w_q: Vec<f32>,
    /// Key projection weights (flattened).
    w_k: Vec<f32>,
    /// Value projection weights (flattened).
    w_v: Vec<f32>,
    /// Output projection weights (flattened).
    w_o: Vec<f32>,
}