use scirs2_core::ndarray::{s, Array2, Array3, ArrayView2};
use scirs2_core::numeric::Float;

use super::{
    kv_page::{KvPagePool, PageId},
    InferenceError, InferenceResult,
};
/// Configuration for a paged-attention forward pass.
#[derive(Debug, Clone)]
pub struct PagedAttentionConfig {
    /// Number of attention heads.
    pub num_heads: usize,
    /// Dimensionality of each head.
    pub head_dim: usize,
    /// Optional score scaling factor; `None` means `1 / sqrt(head_dim)`.
    pub scale: Option<f64>,
}
impl Default for PagedAttentionConfig {
fn default() -> Self {
Self {
num_heads: 8,
head_dim: 64,
scale: None,
}
}
}
/// Computes single-token attention over key/value entries stored in a paged cache.
pub struct PagedAttentionForward {
    config: PagedAttentionConfig,
}
impl PagedAttentionForward {
    /// Creates a forward pass for the given configuration.
    pub fn new(config: PagedAttentionConfig) -> Self {
        Self { config }
    }
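    /// Runs scaled dot-product attention for a single query token against every
    /// key/value entry reachable through `page_chain` in `pool`.
    ///
    /// Returns an all-zero `[num_heads, head_dim]` array when the chain is empty or
    /// holds no live slots, and `InferenceError::KvShapeMismatch` when the query
    /// shape does not match the configured `[num_heads, head_dim]`.
    ///
    /// A minimal usage sketch (assumes a `KvPagePool<f32>` whose layout matches this
    /// config, and a `page_chain` of pages already written for the sequence):
    ///
    /// ```ignore
    /// let cfg = PagedAttentionConfig { num_heads: 8, head_dim: 64, scale: None };
    /// let attn = PagedAttentionForward::new(cfg);
    /// let query = Array2::<f32>::zeros((8, 64));
    /// let context = attn.forward(&query, &page_chain, &pool)?;
    /// assert_eq!(context.shape(), &[8, 64]);
    /// ```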
pub fn forward<F: Float + Default + Clone>(
&self,
query: &Array2<F>,
page_chain: &[PageId],
pool: &KvPagePool<F>,
) -> InferenceResult<Array2<F>> {
let num_heads = self.config.num_heads;
let head_dim = self.config.head_dim;
if query.shape() != [num_heads, head_dim] {
return Err(InferenceError::KvShapeMismatch {
expected_heads: num_heads,
expected_dim: head_dim,
got_heads: query.shape()[0],
got_dim: query.shape()[1],
});
}
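        // No pages yet (e.g. the first token of a sequence): the context is all zeros.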
if page_chain.is_empty() {
return Ok(Array2::default((num_heads, head_dim)));
}
let block_size = pool.config().block_size;
let total_slots = self.count_live_slots(page_chain, pool, block_size)?;
if total_slots == 0 {
return Ok(Array2::default((num_heads, head_dim)));
}
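        // Gather the paged keys/values into contiguous [seq, heads, dim] buffers
        // for the dense attention computation below.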
let mut keys_buf: Array3<F> = Array3::default((total_slots, num_heads, head_dim));
let mut vals_buf: Array3<F> = Array3::default((total_slots, num_heads, head_dim));
self.gather_kv(page_chain, pool, block_size, &mut keys_buf, &mut vals_buf)?;
let scale = self.effective_scale::<F>();
let output = self.sdp_attention(query, &keys_buf, &vals_buf, scale)?;
Ok(output)
}
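    /// Counts the populated key/value slots across `page_chain`. Interior pages are
    /// clamped to `block_size`; only the final page may be partially filled.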
fn count_live_slots<F: Float + Default + Clone>(
&self,
page_chain: &[PageId],
pool: &KvPagePool<F>,
block_size: usize,
) -> InferenceResult<usize> {
let mut total = 0usize;
let chain_len = page_chain.len();
for (idx, &pid) in page_chain.iter().enumerate() {
let page = pool.get_page(pid)?;
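            // Interior pages contribute at most `block_size` slots; the final page
            // in the chain contributes however many slots it reports as written.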
let live = if idx < chain_len.saturating_sub(1) {
block_size.min(page.len())
} else {
page.len()
};
total += live;
}
Ok(total)
}
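    /// Copies the live key/value entries from every page in `page_chain` into the
    /// contiguous `keys_buf`/`vals_buf` buffers, preserving token order.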
fn gather_kv<F: Float + Default + Clone>(
&self,
page_chain: &[PageId],
pool: &KvPagePool<F>,
block_size: usize,
keys_buf: &mut Array3<F>,
vals_buf: &mut Array3<F>,
) -> InferenceResult<()> {
let chain_len = page_chain.len();
let mut dst_slot = 0usize;
for (idx, &pid) in page_chain.iter().enumerate() {
let page = pool.get_page(pid)?;
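            // Mirror the live-slot rule used in `count_live_slots`.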
let live = if idx < chain_len.saturating_sub(1) {
block_size.min(page.len())
} else {
page.len()
};
for slot in 0..live {
let (k_view, v_view) = page.read_kv(slot)?;
keys_buf.slice_mut(s![dst_slot, .., ..]).assign(&k_view);
vals_buf.slice_mut(s![dst_slot, .., ..]).assign(&v_view);
dst_slot += 1;
}
}
Ok(())
}
    /// Scaled dot-product attention for a single query token per head.
    ///
    /// `keys` and `values` are the gathered caches of shape
    /// `[seq_len, num_heads, head_dim]`.
    fn sdp_attention<F: Float + Default + Clone>(
        &self,
        query: &Array2<F>,
        keys: &Array3<F>,
        values: &Array3<F>,
        scale: F,
    ) -> InferenceResult<Array2<F>> {
        let num_heads = self.config.num_heads;
        let head_dim = self.config.head_dim;
        let seq_len = keys.shape()[0];
        let mut output = Array2::<F>::default((num_heads, head_dim));
        for h in 0..num_heads {
            let q_head = query.slice(s![h, ..]);
            let k_heads = keys.slice(s![.., h, ..]);
            let v_heads = values.slice(s![.., h, ..]);
            // Raw scores: scaled dot product between the query and every cached key.
            let mut scores = Vec::with_capacity(seq_len);
            for n in 0..seq_len {
                let k_tok = k_heads.slice(s![n, ..]);
                let dot: F = q_head
                    .iter()
                    .zip(k_tok.iter())
                    .map(|(&qi, &ki)| qi * ki)
                    .fold(F::zero(), |acc, x| acc + x);
                scores.push(dot * scale);
            }
            // Numerically stable softmax: subtract the max score before exponentiating.
            let max_score = scores.iter().copied().fold(F::neg_infinity(), F::max);
            let exp_scores: Vec<F> = scores
                .iter()
                .map(|&score| (score - max_score).exp())
                .collect();
            let sum_exp: F = exp_scores.iter().copied().fold(F::zero(), |a, b| a + b);
            // Defensive guard against a zero denominator.
            let safe_sum = if sum_exp == F::zero() { F::one() } else { sum_exp };
            let weights: Vec<F> = exp_scores.iter().map(|&e| e / safe_sum).collect();
            // Accumulate the weighted value vectors into this head's output row.
            let mut out_head = output.slice_mut(s![h, ..]);
            for (n, &w) in weights.iter().enumerate() {
                let v_tok = v_heads.slice(s![n, ..]);
                for (out_el, &v_el) in out_head.iter_mut().zip(v_tok.iter()) {
                    *out_el = *out_el + w * v_el;
                }
            }
        }
        Ok(output)
    }
    /// Resolves the score scaling factor, defaulting to `1 / sqrt(head_dim)`.
    fn effective_scale<F: Float>(&self) -> F {
        let s = self
            .config
            .scale
            .unwrap_or_else(|| 1.0 / (self.config.head_dim as f64).sqrt());
        F::from(s).unwrap_or_else(F::one)
    }
    /// Returns the attention configuration.
    pub fn config(&self) -> &PagedAttentionConfig {
        &self.config
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::inference::kv_page::KvPageConfig;
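    // Builds a small pool whose page layout matches the attention config under test.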
fn make_pool(
num_pages: usize,
block_size: usize,
num_heads: usize,
head_dim: usize,
) -> KvPagePool<f32> {
let cfg = KvPageConfig {
block_size,
num_heads,
head_dim,
dtype_bytes: 4,
};
KvPagePool::<f32>::new(num_pages, cfg)
}
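    // Writes constant key/value tensors into the first `num_slots` slots of page `pid`.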
fn write_constant_kv(
pool: &mut KvPagePool<f32>,
pid: PageId,
num_slots: usize,
k_val: f32,
v_val: f32,
) {
let (num_heads, head_dim) = {
let cfg = pool.config();
(cfg.num_heads, cfg.head_dim)
};
let k = Array2::<f32>::from_elem((num_heads, head_dim), k_val);
let v = Array2::<f32>::from_elem((num_heads, head_dim), v_val);
for slot in 0..num_slots {
pool.get_page_mut(pid)
.expect("page")
.write_kv(slot, k.view(), v.view())
.expect("write");
}
}
#[test]
fn test_forward_output_shape_correct() {
let num_heads = 4;
let head_dim = 8;
let mut pool = make_pool(8, 4, num_heads, head_dim);
let pid = pool.alloc_page().expect("alloc");
write_constant_kv(&mut pool, pid, 2, 1.0, 1.0);
let cfg = PagedAttentionConfig {
num_heads,
head_dim,
scale: None,
};
let attn = PagedAttentionForward::new(cfg);
let query = Array2::<f32>::from_elem((num_heads, head_dim), 1.0);
let output = attn.forward(&query, &[pid], &pool).expect("forward");
assert_eq!(output.shape(), &[num_heads, head_dim]);
}
#[test]
fn test_forward_over_two_pages() {
let num_heads = 2;
let head_dim = 4;
let block_size = 3;
let mut pool = make_pool(8, block_size, num_heads, head_dim);
let pid0 = pool.alloc_page().expect("alloc p0");
let pid1 = pool.alloc_page().expect("alloc p1");
write_constant_kv(&mut pool, pid0, block_size, 0.1, 0.2);
write_constant_kv(&mut pool, pid1, 2, 0.5, 0.6);
let cfg = PagedAttentionConfig {
num_heads,
head_dim,
scale: Some(1.0),
};
let attn = PagedAttentionForward::new(cfg);
let query = Array2::<f32>::from_elem((num_heads, head_dim), 1.0);
let output = attn.forward(&query, &[pid0, pid1], &pool).expect("forward");
assert_eq!(output.shape(), &[num_heads, head_dim]);
assert!(output.iter().all(|x| x.is_finite()));
}
    // With constant K/V every score is identical, so the softmax weights are uniform
    // regardless of scale; this only checks that extreme scales keep the output finite.
    #[test]
    fn test_extreme_scale_factors_keep_output_finite() {
let num_heads = 1;
let head_dim = 4;
let mut pool1 = make_pool(4, 4, num_heads, head_dim);
let mut pool2 = make_pool(4, 4, num_heads, head_dim);
let pid1 = pool1.alloc_page().expect("alloc");
write_constant_kv(&mut pool1, pid1, 2, 1.0, 1.0);
let pid2 = pool2.alloc_page().expect("alloc");
write_constant_kv(&mut pool2, pid2, 2, 1.0, 1.0);
let query = Array2::<f32>::from_elem((num_heads, head_dim), 1.0);
let attn_small = PagedAttentionForward::new(PagedAttentionConfig {
num_heads,
head_dim,
scale: Some(0.01),
});
let attn_large = PagedAttentionForward::new(PagedAttentionConfig {
num_heads,
head_dim,
scale: Some(100.0),
});
let out_small = attn_small.forward(&query, &[pid1], &pool1).expect("fwd");
let out_large = attn_large.forward(&query, &[pid2], &pool2).expect("fwd");
assert!(out_small.iter().all(|x| x.is_finite()));
assert!(out_large.iter().all(|x| x.is_finite()));
}
#[test]
fn test_single_token_query_and_single_kv() {
let num_heads = 2;
let head_dim = 3;
let mut pool = make_pool(4, 4, num_heads, head_dim);
let pid = pool.alloc_page().expect("alloc");
{
let k = Array2::<f32>::from_elem((num_heads, head_dim), 0.5);
let v = Array2::<f32>::from_elem((num_heads, head_dim), 0.5);
pool.get_page_mut(pid)
.expect("page")
.write_kv(0, k.view(), v.view())
.expect("write");
}
let cfg = PagedAttentionConfig {
num_heads,
head_dim,
scale: Some(1.0),
};
let attn = PagedAttentionForward::new(cfg);
let query = Array2::<f32>::from_elem((num_heads, head_dim), 1.0);
let output = attn.forward(&query, &[pid], &pool).expect("forward");
for &x in output.iter() {
assert!((x - 0.5_f32).abs() < 1e-5, "expected 0.5, got {x}");
}
}
#[test]
fn test_empty_page_chain_returns_zeros() {
let pool = make_pool(4, 4, 2, 4);
let cfg = PagedAttentionConfig {
num_heads: 2,
head_dim: 4,
scale: None,
};
let attn = PagedAttentionForward::new(cfg);
let query = Array2::<f32>::zeros((2, 4));
let output = attn
.forward(&query, &[], &pool)
.expect("forward empty chain");
assert_eq!(output.shape(), &[2, 4]);
assert!(output.iter().all(|&x| x == 0.0));
}
#[test]
fn test_query_shape_mismatch_errors() {
let pool = make_pool(4, 4, 2, 4);
let cfg = PagedAttentionConfig {
num_heads: 2,
head_dim: 4,
scale: None,
};
let attn = PagedAttentionForward::new(cfg);
let query = Array2::<f32>::zeros((3, 4));
let err = attn.forward(&query, &[], &pool).expect_err("should error");
assert!(matches!(err, InferenceError::KvShapeMismatch { .. }));
}
#[test]
fn test_uniform_kv_produces_uniform_output() {
let num_heads = 3;
let head_dim = 5;
let block_size = 4;
let kv_val = 0.7_f32;
let mut pool = make_pool(4, block_size, num_heads, head_dim);
let pid = pool.alloc_page().expect("alloc");
write_constant_kv(&mut pool, pid, block_size, kv_val, kv_val);
let cfg = PagedAttentionConfig {
num_heads,
head_dim,
scale: Some(1.0),
};
let attn = PagedAttentionForward::new(cfg);
let query = Array2::<f32>::from_elem((num_heads, head_dim), 1.0);
let output = attn.forward(&query, &[pid], &pool).expect("forward");
for &x in output.iter() {
assert!((x - kv_val).abs() < 1e-5, "expected {kv_val}, got {x}");
}
}
}