use crate::managers::paged::{PagedKvCacheHandle, PagedKvCacheManager};
use ferrum_types::{FerrumError, Result};
/// Reference (CPU, `f32`) causal scaled dot-product attention over a paged KV cache.
///
/// Reads KV positions `0..kv_len` of `layer` for `handle` via `manager.read_kv` and
/// attends `q_tokens` query tokens against them. The query tokens are treated as the
/// *last* `q_tokens` positions of the sequence, so query token `qt` may attend to KV
/// positions `0..=kv_len - q_tokens + qt`. With `q_tokens == 1` this is a decode step;
/// with `q_tokens == kv_len` it is a full prefill.
///
/// Grouped-query attention is supported: query head `h` reads KV head
/// `h / (num_heads / num_kv_heads)`.
///
/// # Layout
/// * `query` is indexed as `[q_tokens][num_heads][head_dim]`, row-major.
/// * The key/value buffers returned by `read_kv` are addressed as
///   `[kv_pos][num_kv_heads][head_dim]`. This assumes the manager stores them that
///   way; confirm against `PagedKvCacheManager::read_kv`.
/// * The returned vector has the same layout and length as `query`.
///
/// # Errors
/// Returns an invalid-parameter error when:
/// * `query.len() != q_tokens * num_heads * head_dim`;
/// * `kv_len == 0`, or `q_tokens > kv_len`;
/// * `num_kv_heads` is zero or does not evenly divide `num_heads`;
/// * the buffers returned by `read_kv` are shorter than
///   `kv_len * num_kv_heads * head_dim` elements.
///
/// Any error from `read_kv` is propagated unchanged.
pub fn paged_attention(
    query: &[f32],
    q_tokens: usize,
    num_heads: usize,
    num_kv_heads: usize,
    head_dim: usize,
    manager: &PagedKvCacheManager,
    handle: &PagedKvCacheHandle,
    layer: usize,
    kv_len: usize,
) -> Result<Vec<f32>> {
    let q_len = q_tokens * num_heads * head_dim;
    if query.len() != q_len {
        return Err(FerrumError::invalid_parameter(format!(
            "Query length mismatch: expected {}, got {}",
            q_len,
            query.len()
        )));
    }
    if kv_len == 0 {
        return Err(FerrumError::invalid_parameter("kv_len must be positive"));
    }
    // The causal offset `kv_len - q_tokens` below would underflow otherwise.
    if q_tokens > kv_len {
        return Err(FerrumError::invalid_parameter(format!(
            "q_tokens ({}) must not exceed kv_len ({})",
            q_tokens, kv_len
        )));
    }
    // Every query head must map onto exactly one KV head. This also rules out
    // the division by zero that `num_kv_heads == 0` would otherwise cause.
    if num_kv_heads == 0 || num_heads % num_kv_heads != 0 {
        return Err(FerrumError::invalid_parameter(format!(
            "num_kv_heads ({}) must be non-zero and evenly divide num_heads ({})",
            num_kv_heads, num_heads
        )));
    }
    let heads_per_kv = num_heads / num_kv_heads;

    let (all_keys, all_values) = manager.read_kv(handle, layer, 0, kv_len)?;

    // Guard the slicing below: a cache configured with a different head count or
    // head size would otherwise cause an index-out-of-bounds panic.
    let kv_needed = kv_len * num_kv_heads * head_dim;
    if all_keys.len() < kv_needed || all_values.len() < kv_needed {
        return Err(FerrumError::invalid_parameter(format!(
            "KV cache read too short: need {} elements, got {} keys and {} values",
            kv_needed,
            all_keys.len(),
            all_values.len()
        )));
    }

    let scale = 1.0 / (head_dim as f32).sqrt();
    let mut output = vec![0.0f32; q_len];
    // Scratch buffer for attention weights, reused for every (token, head) pair.
    let mut weights: Vec<f32> = Vec::with_capacity(kv_len);

    for qt in 0..q_tokens {
        // Causal mask: query token `qt` sits at absolute position
        // `kv_len - q_tokens + qt` and sees every KV position up to and including it.
        let visible = kv_len - q_tokens + qt + 1;
        for h in 0..num_heads {
            let kv_h = h / heads_per_kv;
            let row = (qt * num_heads + h) * head_dim;
            let q = &query[row..row + head_dim];

            // Scaled dot-product scores for the visible prefix only. Masked positions
            // would get zero weight after softmax, so they are skipped entirely.
            weights.clear();
            weights.extend((0..visible).map(|kv_pos| {
                let k_off = (kv_pos * num_kv_heads + kv_h) * head_dim;
                let k = &all_keys[k_off..k_off + head_dim];
                q.iter().zip(k).map(|(a, b)| a * b).sum::<f32>() * scale
            }));

            // Numerically stable softmax: subtract the max before exponentiating.
            let max_score = weights.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let mut sum = 0.0f32;
            for w in weights.iter_mut() {
                *w = (*w - max_score).exp();
                sum += *w;
            }
            if sum > 0.0 {
                for w in weights.iter_mut() {
                    *w /= sum;
                }
            }

            // Output row = attention-weighted sum of the visible value vectors.
            let out = &mut output[row..row + head_dim];
            for (kv_pos, &w) in weights.iter().enumerate() {
                let v_off = (kv_pos * num_kv_heads + kv_h) * head_dim;
                let v = &all_values[v_off..v_off + head_dim];
                for (o, &x) in out.iter_mut().zip(v) {
                    *o += w * x;
                }
            }
        }
    }
    Ok(output)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::managers::paged::PagedKvCacheConfig;
    use ferrum_interfaces::kv_cache::AllocationRequest;
    use ferrum_interfaces::KvCacheManager;
    use ferrum_types::{DataType, Device, RequestId};

    /// Builds a CPU-only paged KV cache manager and allocates a single request in it.
    ///
    /// Returns the manager together with the id of the allocated request.
    async fn setup(
        num_layers: usize,
        num_heads: usize,
        head_dim: usize,
        block_size: usize,
        initial_tokens: usize,
    ) -> (PagedKvCacheManager, RequestId) {
        let config = PagedKvCacheConfig {
            block_size,
            max_gpu_blocks: 64,
            max_cpu_blocks: 0,
            enable_cow: false,
            enable_swapping: false,
            num_layers,
            num_heads,
            head_dim,
            enable_prefix_cache: false,
            ..Default::default()
        };
        let manager = PagedKvCacheManager::new(Device::CPU, config).unwrap();
        let request = AllocationRequest {
            request_id: RequestId::new(),
            initial_tokens,
            max_sequence_length: 1024,
            num_layers,
            num_heads,
            head_dim,
            device: Device::CPU,
            dtype: DataType::FP16,
            priority: ferrum_types::Priority::Normal,
        };
        let rid = request.request_id.clone();
        let _ = manager.allocate(&request).await.unwrap();
        (manager, rid)
    }

    /// Decode step: one query token attending to three cached positions.
    ///
    /// All keys are identical, so attention is uniform and every output element is
    /// the mean of the values 1, 2 and 3, i.e. 2.0.
    #[tokio::test]
    async fn single_token_decode_attention() {
        let num_heads = 2;
        let head_dim = 4;
        let (manager, rid) = setup(1, num_heads, head_dim, 16, 3).await;
        let handle_dyn = manager.get_handle(rid.clone()).unwrap();
        let handle = handle_dyn
            .as_any()
            .downcast_ref::<PagedKvCacheHandle>()
            .unwrap();
        let kv_size = num_heads * head_dim;
        for pos in 0..3 {
            let key = vec![1.0f32; kv_size];
            let val = vec![(pos + 1) as f32; kv_size];
            manager.write_kv(handle, 0, pos, &key, &val).unwrap();
        }
        let query = vec![1.0f32; num_heads * head_dim];
        let output = paged_attention(
            &query, 1, num_heads, num_heads, head_dim, &manager, handle, 0, 3,
        )
        .unwrap();
        assert_eq!(output.len(), num_heads * head_dim);
        for &v in &output {
            assert!((v - 2.0).abs() < 1e-5, "Expected ~2.0, got {}", v);
        }
    }

    /// Prefill of three tokens, where each query equals the key at its own position.
    ///
    /// * Token 0 sees only position 0, so its output is exactly value 0 = `[1, 0]`.
    /// * Token 1 sees positions 0 and 1. Its query `[0, 1]` matches key 1 best, so
    ///   the second component dominates.
    /// * Token 2 sees all three positions. Keys 0 and 1 score equally and key 2 scores
    ///   highest, so both components equal `w0 + w2 = w1 + w2 > 0.5`.
    #[tokio::test]
    async fn prefill_causal_masking() {
        let num_heads = 1;
        let head_dim = 2;
        let (manager, rid) = setup(1, num_heads, head_dim, 16, 3).await;
        let handle_dyn = manager.get_handle(rid.clone()).unwrap();
        let handle = handle_dyn
            .as_any()
            .downcast_ref::<PagedKvCacheHandle>()
            .unwrap();
        let kv_size = num_heads * head_dim;
        let keys_data = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
        let vals_data = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]];
        for pos in 0..3 {
            manager
                .write_kv(handle, 0, pos, &keys_data[pos], &vals_data[pos])
                .unwrap();
        }
        let mut query = Vec::with_capacity(3 * kv_size);
        for pos in 0..3 {
            query.extend_from_slice(&keys_data[pos]);
        }
        let output = paged_attention(
            &query, 3, num_heads, num_heads, head_dim, &manager, handle, 0, 3,
        )
        .unwrap();
        assert_eq!(output.len(), 3 * kv_size);
        assert!((output[0] - 1.0).abs() < 1e-5);
        assert!((output[1] - 0.0).abs() < 1e-5);
        assert!(
            output[2] < 0.5,
            "Expected output[2] < 0.5, got {}",
            output[2]
        );
        assert!(
            output[3] > 0.5,
            "Expected output[3] > 0.5, got {}",
            output[3]
        );
        assert!(
            (output[4] - output[5]).abs() < 1e-5,
            "Expected output[4] == output[5], got {} and {}",
            output[4],
            output[5]
        );
        assert!(
            output[4] > 0.5,
            "Expected output[4] > 0.5, got {}",
            output[4]
        );
    }

    /// Decode over four positions stored in two blocks (`block_size == 2`).
    ///
    /// Later positions have larger keys and values, so the output lies strictly
    /// between the smallest (10) and largest (40) value. Each value vector has
    /// identical components, so both output components must agree.
    #[tokio::test]
    async fn attention_across_blocks() {
        let num_heads = 1;
        let head_dim = 2;
        let (manager, rid) = setup(1, num_heads, head_dim, 2, 4).await;
        let handle_dyn = manager.get_handle(rid.clone()).unwrap();
        let handle = handle_dyn
            .as_any()
            .downcast_ref::<PagedKvCacheHandle>()
            .unwrap();
        for pos in 0..4 {
            let key = vec![(pos + 1) as f32; head_dim];
            let val = vec![(pos + 1) as f32 * 10.0; head_dim];
            manager.write_kv(handle, 0, pos, &key, &val).unwrap();
        }
        let query = vec![1.0f32; head_dim];
        let output = paged_attention(
            &query, 1, num_heads, num_heads, head_dim, &manager, handle, 0, 4,
        )
        .unwrap();
        assert_eq!(output.len(), head_dim);
        assert!(
            output[0] > 10.0 && output[0] < 40.0,
            "Expected 10 < output[0] < 40, got {}",
            output[0]
        );
        assert!(
            (output[0] - output[1]).abs() < 1e-4,
            "Expected output[0] == output[1], got {} and {}",
            output[0],
            output[1]
        );
    }
}