#[cfg(target_arch = "aarch64")]
use std::arch::aarch64::*;
use smallvec::SmallVec;
use super::AttentionConfig;
#[cfg(target_arch = "aarch64")]
use super::{NEON_LANE_WIDTH, UNROLL_FACTOR};
#[cfg(all(feature = "parallel", not(target_arch = "wasm32")))]
use rayon::prelude::*;
pub const BLOCK_SIZE_SMALL: usize = 32;
pub const BLOCK_SIZE_MEDIUM: usize = 64;
pub const BLOCK_SIZE_LARGE: usize = 128;
const ATTENTION_BLOCK_SIZE: usize = BLOCK_SIZE_MEDIUM;
#[cfg(all(feature = "parallel", not(target_arch = "wasm32")))]
const PARALLEL_THRESHOLD: usize = 256;
#[cfg(target_arch = "aarch64")]
const SMALLVEC_BLOCK_SIZE: usize = 128;
/// Reusable scratch buffers for the attention kernels, sized once up front
/// so repeated decode steps perform no per-call allocation.
#[derive(Debug)]
pub struct AttentionScratch {
output: Vec<f32>,
block_scores: Vec<f32>,
kv_buffer: Vec<f32>,
head_outputs: Vec<f32>,
head_dim: usize,
max_block_size: usize,
max_num_heads: usize,
max_kv_len: usize,
}
impl AttentionScratch {
pub fn new(head_dim: usize, max_block_size: usize, max_num_heads: usize) -> Self {
Self::with_kv_capacity(head_dim, max_block_size, max_num_heads, 4096)
}
pub fn with_kv_capacity(
head_dim: usize,
max_block_size: usize,
max_num_heads: usize,
max_kv_len: usize,
) -> Self {
Self {
output: vec![0.0; head_dim],
block_scores: vec![0.0; max_block_size],
kv_buffer: vec![0.0; max_kv_len * head_dim * 2],
head_outputs: vec![0.0; max_num_heads * head_dim],
head_dim,
max_block_size,
max_num_heads,
max_kv_len,
}
}
#[inline]
pub fn reset(&mut self) {
self.output.fill(0.0);
self.block_scores.fill(0.0);
}
#[inline]
pub fn output_buffer(&mut self) -> &mut [f32] {
&mut self.output
}
#[inline]
pub fn block_scores_buffer(&mut self, block_size: usize) -> &mut [f32] {
debug_assert!(
block_size <= self.max_block_size,
"block_size {} exceeds max_block_size {}",
block_size,
self.max_block_size
);
&mut self.block_scores[..block_size]
}
#[inline]
pub fn kv_buffer(&mut self, kv_len: usize) -> (&mut [f32], &mut [f32]) {
debug_assert!(
kv_len <= self.max_kv_len,
"kv_len {} exceeds max_kv_len {}",
kv_len,
self.max_kv_len
);
let size = kv_len * self.head_dim;
// Keys occupy the first `size` elements, values the next `size`.
let (keys, rest) = self.kv_buffer.split_at_mut(size);
(keys, &mut rest[..size])
}
#[inline]
pub fn head_outputs_buffer(&mut self, num_heads: usize) -> &mut [f32] {
let size = num_heads * self.head_dim;
debug_assert!(
num_heads <= self.max_num_heads,
"num_heads {} exceeds max_num_heads {}",
num_heads,
self.max_num_heads
);
&mut self.head_outputs[..size]
}
#[inline]
pub fn head_dim(&self) -> usize {
self.head_dim
}
#[inline]
pub fn max_block_size(&self) -> usize {
self.max_block_size
}
#[inline]
pub fn is_compatible(&self, head_dim: usize, block_size: usize, num_heads: usize) -> bool {
self.head_dim >= head_dim
&& self.max_block_size >= block_size
&& self.max_num_heads >= num_heads
}
}
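// Hedged usage sketch (illustrative only; assumes a `PagedKvCache` named
// `cache` and a `kv_len` in scope): stage gathered KV data through the
// scratch so its buffers are reused across decode steps.
//
//   let mut scratch = AttentionScratch::with_kv_capacity(64, 64, 8, 1024);
//   let (k_buf, v_buf) = scratch.kv_buffer(kv_len);
//   cache.copy_keys_into(k_buf);
//   cache.copy_values_into(v_buf);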
impl Clone for AttentionScratch {
fn clone(&self) -> Self {
Self {
output: vec![0.0; self.head_dim],
block_scores: vec![0.0; self.max_block_size],
kv_buffer: vec![0.0; self.max_kv_len * self.head_dim * 2],
head_outputs: vec![0.0; self.max_num_heads * self.head_dim],
head_dim: self.head_dim,
max_block_size: self.max_block_size,
max_num_heads: self.max_num_heads,
max_kv_len: self.max_kv_len,
}
}
}
#[cfg(not(target_arch = "wasm32"))]
thread_local! {
pub static THREAD_LOCAL_SCRATCH: std::cell::RefCell<AttentionScratch> =
std::cell::RefCell::new(AttentionScratch::with_kv_capacity(128, 128, 32, 4096));
}
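// Hedged usage sketch: borrow the thread-local scratch for one decode step
// (assumes `q`, `k`, `v`, `scale`, and `out` are in scope):
//
//   THREAD_LOCAL_SCRATCH.with(|cell| {
//       let mut scratch = cell.borrow_mut();
//       flash_attention_with_scratch(&q, &k, &v, scale, &mut scratch, &mut out);
//   });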
/// Paged KV cache: keys and values are stored in fixed-size blocks of
/// `block_size` tokens, each block holding `block_size * num_kv_heads *
/// head_dim` f32 values per tensor.
#[derive(Debug, Clone)]
pub struct PagedKvCache {
pub key_blocks: Vec<Vec<f32>>,
pub value_blocks: Vec<Vec<f32>>,
pub block_size: usize,
pub num_kv_heads: usize,
pub head_dim: usize,
pub num_tokens: usize,
preallocated_blocks: usize,
}
impl PagedKvCache {
pub fn new(block_size: usize, num_kv_heads: usize, head_dim: usize) -> Self {
Self {
key_blocks: Vec::new(),
value_blocks: Vec::new(),
block_size,
num_kv_heads,
head_dim,
num_tokens: 0,
preallocated_blocks: 0,
}
}
pub fn with_capacity(
block_size: usize,
num_kv_heads: usize,
head_dim: usize,
num_blocks: usize,
) -> Self {
let block_capacity = block_size * num_kv_heads * head_dim;
let mut key_blocks = Vec::with_capacity(num_blocks);
let mut value_blocks = Vec::with_capacity(num_blocks);
for _ in 0..num_blocks {
key_blocks.push(vec![0.0; block_capacity]);
value_blocks.push(vec![0.0; block_capacity]);
}
Self {
key_blocks,
value_blocks,
block_size,
num_kv_heads,
head_dim,
num_tokens: 0,
preallocated_blocks: num_blocks,
}
}
pub fn with_max_tokens(
block_size: usize,
num_kv_heads: usize,
head_dim: usize,
max_tokens: usize,
) -> Self {
let num_blocks = (max_tokens + block_size - 1) / block_size;
Self::with_capacity(block_size, num_kv_heads, head_dim, num_blocks)
}
#[inline]
pub fn reset(&mut self) {
self.num_tokens = 0;
}
#[inline]
pub fn capacity(&self) -> usize {
self.key_blocks.len() * self.block_size
}
#[inline]
pub fn has_capacity(&self, additional_tokens: usize) -> bool {
self.num_tokens + additional_tokens <= self.capacity()
}
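// Token `t` lives in block `t / block_size` at element offset
// `(t % block_size) * stride`, where `stride = num_kv_heads * head_dim`.
// For example, with block_size = 4 and stride = 16, token 6 is stored in
// block 1 starting at offset 32.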
pub fn append(&mut self, keys: &[f32], values: &[f32]) {
let stride = self.num_kv_heads * self.head_dim;
debug_assert_eq!(keys.len(), values.len(), "keys/values length mismatch");
debug_assert_eq!(keys.len() % stride, 0, "input is not a whole number of tokens");
let num_tokens = keys.len() / stride;
for i in 0..num_tokens {
let offset = i * stride;
if self.num_tokens % self.block_size == 0 {
let block_idx = self.num_tokens / self.block_size;
if block_idx >= self.key_blocks.len() {
let block_capacity = self.block_size * stride;
self.key_blocks.push(vec![0.0; block_capacity]);
self.value_blocks.push(vec![0.0; block_capacity]);
}
}
let block_idx = self.num_tokens / self.block_size;
let pos_in_block = (self.num_tokens % self.block_size) * stride;
unsafe {
let key_block = self.key_blocks.get_unchecked_mut(block_idx);
let value_block = self.value_blocks.get_unchecked_mut(block_idx);
key_block[pos_in_block..pos_in_block + stride]
.copy_from_slice(&keys[offset..offset + stride]);
value_block[pos_in_block..pos_in_block + stride]
.copy_from_slice(&values[offset..offset + stride]);
}
self.num_tokens += 1;
}
}
/// # Safety
///
/// The caller must guarantee capacity for every incoming token (see
/// [`Self::has_capacity`]); block bounds are not checked here.
#[inline]
pub unsafe fn append_unchecked(&mut self, keys: &[f32], values: &[f32]) {
let stride = self.num_kv_heads * self.head_dim;
let num_tokens = keys.len() / stride;
for i in 0..num_tokens {
let offset = i * stride;
let block_idx = self.num_tokens / self.block_size;
let pos_in_block = (self.num_tokens % self.block_size) * stride;
let key_block = self.key_blocks.get_unchecked_mut(block_idx);
let value_block = self.value_blocks.get_unchecked_mut(block_idx);
std::ptr::copy_nonoverlapping(
keys.as_ptr().add(offset),
key_block.as_mut_ptr().add(pos_in_block),
stride,
);
std::ptr::copy_nonoverlapping(
values.as_ptr().add(offset),
value_block.as_mut_ptr().add(pos_in_block),
stride,
);
self.num_tokens += 1;
}
}
pub fn get_keys(&self) -> Vec<f32> {
let stride = self.num_kv_heads * self.head_dim;
let mut result = Vec::with_capacity(self.num_tokens * stride);
for (block_idx, block) in self.key_blocks.iter().enumerate() {
let tokens_in_block = self.tokens_in_block(block_idx);
if tokens_in_block > 0 {
result.extend_from_slice(&block[..tokens_in_block * stride]);
}
}
result
}
#[inline]
pub fn copy_keys_into(&self, output: &mut [f32]) -> usize {
let stride = self.num_kv_heads * self.head_dim;
let total_size = self.num_tokens * stride;
debug_assert!(output.len() >= total_size, "Output buffer too small");
let mut write_pos = 0;
for (block_idx, block) in self.key_blocks.iter().enumerate() {
let tokens_in_block = self.tokens_in_block(block_idx);
if tokens_in_block > 0 {
let slice_len = tokens_in_block * stride;
output[write_pos..write_pos + slice_len].copy_from_slice(&block[..slice_len]);
write_pos += slice_len;
}
}
write_pos
}
pub fn get_values(&self) -> Vec<f32> {
let stride = self.num_kv_heads * self.head_dim;
let mut result = Vec::with_capacity(self.num_tokens * stride);
for (block_idx, block) in self.value_blocks.iter().enumerate() {
let tokens_in_block = self.tokens_in_block(block_idx);
if tokens_in_block > 0 {
result.extend_from_slice(&block[..tokens_in_block * stride]);
}
}
result
}
#[inline]
pub fn copy_values_into(&self, output: &mut [f32]) -> usize {
let stride = self.num_kv_heads * self.head_dim;
let total_size = self.num_tokens * stride;
debug_assert!(output.len() >= total_size, "Output buffer too small");
let mut write_pos = 0;
for (block_idx, block) in self.value_blocks.iter().enumerate() {
let tokens_in_block = self.tokens_in_block(block_idx);
if tokens_in_block > 0 {
let slice_len = tokens_in_block * stride;
output[write_pos..write_pos + slice_len].copy_from_slice(&block[..slice_len]);
write_pos += slice_len;
}
}
write_pos
}
#[inline]
fn tokens_in_block(&self, block_idx: usize) -> usize {
let tokens_before_this_block = block_idx * self.block_size;
if tokens_before_this_block >= self.num_tokens {
return 0;
}
let remaining_tokens = self.num_tokens - tokens_before_this_block;
remaining_tokens.min(self.block_size)
}
}
#[inline(always)]
pub fn select_block_size(kv_len: usize, head_dim: usize) -> usize {
// Assume a ~128 KiB L1 data budget (typical of recent Apple Silicon
// performance cores); each cached token costs head_dim f32 keys plus
// head_dim f32 values.
let l1_budget = 128 * 1024;
let bytes_per_token = head_dim * 4 * 2;
if kv_len <= 64 {
return BLOCK_SIZE_SMALL;
}
if kv_len <= 512 {
return BLOCK_SIZE_MEDIUM;
}
if bytes_per_token * BLOCK_SIZE_LARGE > l1_budget {
return BLOCK_SIZE_MEDIUM;
}
BLOCK_SIZE_LARGE
}
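// Worked example: head_dim = 128 gives bytes_per_token = 128 * 4 * 2 = 1024,
// so a 128-token block streams 128 KiB of K+V, exactly the assumed budget;
// the strict `>` comparison therefore still selects BLOCK_SIZE_LARGE once
// kv_len > 512. At head_dim = 256 the same block would need 256 KiB, so the
// kernel falls back to BLOCK_SIZE_MEDIUM.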
/// Single-query flash attention over contiguous key/value buffers.
///
/// `causal` is accepted for API symmetry, but the single-query decode
/// kernels treat every cached position as visible, so it is currently a
/// no-op.
#[inline(always)]
pub fn flash_attention_neon(
query: &[f32],
key: &[f32],
value: &[f32],
scale: f32,
causal: bool,
) -> Vec<f32> {
flash_attention_v2(query, key, value, scale, causal, ATTENTION_BLOCK_SIZE)
}
#[inline(always)]
pub fn flash_attention_v2(
query: &[f32],
key: &[f32],
value: &[f32],
scale: f32,
causal: bool,
block_size: usize,
) -> Vec<f32> {
if query.is_empty() || key.is_empty() {
return vec![];
}
let head_dim = query.len();
let kv_len = key.len() / head_dim;
if kv_len == 0 {
return vec![0.0; head_dim];
}
#[cfg(target_arch = "aarch64")]
unsafe {
flash_attention_v2_neon_impl(
query, key, value, head_dim, kv_len, scale, causal, block_size,
)
}
#[cfg(not(target_arch = "aarch64"))]
{
flash_attention_scalar(query, key, value, head_dim, kv_len, scale, causal)
}
}
#[inline(always)]
pub fn flash_attention_auto(
query: &[f32],
key: &[f32],
value: &[f32],
scale: f32,
causal: bool,
) -> Vec<f32> {
let head_dim = if !query.is_empty() {
query.len()
} else {
return vec![];
};
let kv_len = key.len() / head_dim;
let block_size = select_block_size(kv_len, head_dim);
flash_attention_v2(query, key, value, scale, causal, block_size)
}
#[inline(always)]
pub fn flash_attention_into(
query: &[f32],
key: &[f32],
value: &[f32],
scale: f32,
causal: bool,
output: &mut [f32],
) {
let head_dim = query.len();
if head_dim == 0 || key.is_empty() {
return;
}
let kv_len = key.len() / head_dim;
if kv_len == 0 {
output[..head_dim].fill(0.0);
return;
}
let block_size = select_block_size(kv_len, head_dim);
#[cfg(target_arch = "aarch64")]
{
unsafe {
flash_attention_v2_neon_into(
query, key, value, head_dim, kv_len, scale, causal, block_size, output,
);
}
return;
}
#[cfg(not(target_arch = "aarch64"))]
{
flash_attention_scalar_into(query, key, value, head_dim, kv_len, scale, causal, output);
}
}
#[inline(always)]
pub fn flash_attention_with_scratch(
query: &[f32],
key: &[f32],
value: &[f32],
scale: f32,
scratch: &mut AttentionScratch,
output: &mut [f32],
) {
let head_dim = query.len();
if head_dim == 0 || key.is_empty() {
return;
}
let kv_len = key.len() / head_dim;
if kv_len == 0 {
output[..head_dim].fill(0.0);
return;
}
let block_size = select_block_size(kv_len, head_dim).min(scratch.max_block_size());
#[cfg(target_arch = "aarch64")]
{
unsafe {
flash_attention_v2_neon_with_scratch(
query, key, value, head_dim, kv_len, scale, block_size, scratch, output,
);
}
return;
}
#[cfg(not(target_arch = "aarch64"))]
{
let _ = scratch;
flash_attention_scalar_into(query, key, value, head_dim, kv_len, scale, false, output);
}
}
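// The NEON kernels below use the FlashAttention-style online softmax: with
// running max `m` and running denominator `l`, each block of scores s_j
// applies the single-pass update
//   m_new = max(m, max_j s_j)
//   alpha = exp(m - m_new)                       // rescales prior partials
//   l_new = l * alpha + sum_j exp(s_j - m_new)
//   o     = o * alpha + sum_j exp(s_j - m_new) * v_j
// and the final output is o / l. This streams over the KV sequence once and
// never materializes more than one block of scores.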
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn flash_attention_v2_neon_into(
query: &[f32],
key: &[f32],
value: &[f32],
head_dim: usize,
kv_len: usize,
scale: f32,
_causal: bool,
block_size: usize,
output: &mut [f32],
) {
debug_assert_eq!(query.len(), head_dim);
debug_assert_eq!(key.len(), kv_len * head_dim);
debug_assert_eq!(value.len(), kv_len * head_dim);
debug_assert!(output.len() >= head_dim);
let q_ptr = query.as_ptr();
let k_ptr = key.as_ptr();
let v_ptr = value.as_ptr();
let out_ptr = output.as_mut_ptr();
output[..head_dim].fill(0.0);
let mut m = f32::NEG_INFINITY;
let mut l = 0.0f32;
let num_blocks = (kv_len + block_size - 1) / block_size;
let mut block_scores: SmallVec<[f32; SMALLVEC_BLOCK_SIZE]> = SmallVec::new();
block_scores.resize(block_size, 0.0);
for block_idx in 0..num_blocks {
let block_start = block_idx * block_size;
let block_end = (block_start + block_size).min(kv_len);
let block_len = block_end - block_start;
let mut block_max = f32::NEG_INFINITY;
for t in 0..block_len {
let k_offset = (block_start + t) * head_dim;
let score = compute_dot_product_8x(q_ptr, k_ptr.add(k_offset), head_dim) * scale;
*block_scores.get_unchecked_mut(t) = score;
block_max = block_max.max(score);
}
let m_new = m.max(block_max);
let alpha = (m - m_new).exp();
if l > 0.0 {
rescale_output_8x(out_ptr, head_dim, alpha);
}
let mut l_new = l * alpha;
for t in 0..block_len {
let v_offset = (block_start + t) * head_dim;
let p = (*block_scores.get_unchecked(t) - m_new).exp();
l_new += p;
accumulate_weighted_value_8x(out_ptr, v_ptr.add(v_offset), head_dim, p);
}
m = m_new;
l = l_new;
}
if l > 0.0 {
let inv_l = 1.0 / l;
normalize_output_8x(out_ptr, head_dim, inv_l);
}
}
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn flash_attention_v2_neon_with_scratch(
query: &[f32],
key: &[f32],
value: &[f32],
head_dim: usize,
kv_len: usize,
scale: f32,
block_size: usize,
scratch: &mut AttentionScratch,
output: &mut [f32],
) {
debug_assert_eq!(query.len(), head_dim);
debug_assert_eq!(key.len(), kv_len * head_dim);
debug_assert_eq!(value.len(), kv_len * head_dim);
debug_assert!(output.len() >= head_dim);
let q_ptr = query.as_ptr();
let k_ptr = key.as_ptr();
let v_ptr = value.as_ptr();
let out_ptr = output.as_mut_ptr();
output[..head_dim].fill(0.0);
let mut m = f32::NEG_INFINITY;
let mut l = 0.0f32;
let num_blocks = (kv_len + block_size - 1) / block_size;
let block_scores = scratch.block_scores_buffer(block_size);
for block_idx in 0..num_blocks {
let block_start = block_idx * block_size;
let block_end = (block_start + block_size).min(kv_len);
let block_len = block_end - block_start;
let mut block_max = f32::NEG_INFINITY;
for t in 0..block_len {
let k_offset = (block_start + t) * head_dim;
let score = compute_dot_product_8x(q_ptr, k_ptr.add(k_offset), head_dim) * scale;
*block_scores.get_unchecked_mut(t) = score;
block_max = block_max.max(score);
}
let m_new = m.max(block_max);
let alpha = (m - m_new).exp();
if l > 0.0 {
rescale_output_8x(out_ptr, head_dim, alpha);
}
let mut l_new = l * alpha;
for t in 0..block_len {
let v_offset = (block_start + t) * head_dim;
let p = (*block_scores.get_unchecked(t) - m_new).exp();
l_new += p;
accumulate_weighted_value_8x(out_ptr, v_ptr.add(v_offset), head_dim, p);
}
m = m_new;
l = l_new;
}
if l > 0.0 {
let inv_l = 1.0 / l;
normalize_output_8x(out_ptr, head_dim, inv_l);
}
}
#[allow(dead_code)]
fn flash_attention_scalar_into(
query: &[f32],
key: &[f32],
value: &[f32],
head_dim: usize,
kv_len: usize,
scale: f32,
_causal: bool,
output: &mut [f32],
) {
let mut scores: SmallVec<[f32; 512]> = SmallVec::with_capacity(kv_len);
for t in 0..kv_len {
let k_offset = t * head_dim;
let score: f32 = query
.iter()
.zip(&key[k_offset..k_offset + head_dim])
.map(|(q, k)| q * k * scale)
.sum();
scores.push(score);
}
let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let mut sum_exp = 0.0f32;
for score in scores.iter_mut() {
*score = (*score - max_score).exp();
sum_exp += *score;
}
let inv_sum = 1.0 / sum_exp;
for score in scores.iter_mut() {
*score *= inv_sum;
}
output[..head_dim].fill(0.0);
for (t, &weight) in scores.iter().enumerate() {
let v_offset = t * head_dim;
for (i, v) in value[v_offset..v_offset + head_dim].iter().enumerate() {
unsafe {
*output.get_unchecked_mut(i) += weight * v;
}
}
}
}
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn flash_attention_v2_neon_impl(
query: &[f32],
key: &[f32],
value: &[f32],
head_dim: usize,
kv_len: usize,
scale: f32,
_causal: bool,
block_size: usize,
) -> Vec<f32> {
debug_assert_eq!(query.len(), head_dim);
debug_assert_eq!(key.len(), kv_len * head_dim);
debug_assert_eq!(value.len(), kv_len * head_dim);
let q_ptr = query.as_ptr();
let k_ptr = key.as_ptr();
let v_ptr = value.as_ptr();
let mut m = f32::NEG_INFINITY;
let mut l = 0.0f32;
let mut output = vec![0.0f32; head_dim];
let out_ptr = output.as_mut_ptr();
let num_blocks = (kv_len + block_size - 1) / block_size;
let mut block_scores = vec![0.0f32; block_size];
for block_idx in 0..num_blocks {
let block_start = block_idx * block_size;
let block_end = (block_start + block_size).min(kv_len);
let block_len = block_end - block_start;
let mut block_max = f32::NEG_INFINITY;
for t in 0..block_len {
let k_offset = (block_start + t) * head_dim;
let score = compute_dot_product_8x(q_ptr, k_ptr.add(k_offset), head_dim) * scale;
block_scores[t] = score;
block_max = block_max.max(score);
}
let m_new = m.max(block_max);
let alpha = (m - m_new).exp();
if l > 0.0 {
rescale_output_8x(out_ptr, head_dim, alpha);
}
let mut l_new = l * alpha;
for t in 0..block_len {
let v_offset = (block_start + t) * head_dim;
let p = (block_scores[t] - m_new).exp();
l_new += p;
accumulate_weighted_value_8x(out_ptr, v_ptr.add(v_offset), head_dim, p);
}
m = m_new;
l = l_new;
}
if l > 0.0 {
let inv_l = 1.0 / l;
normalize_output_8x(out_ptr, head_dim, inv_l);
}
output
}
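// The three helpers below are the kernel's inner-loop primitives: dot(q, k),
// out *= factor, and out += weight * v. Each unrolls eight 4-lane NEON
// registers per main-loop iteration (32 floats), then drains progressively
// narrower vector and scalar tails.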
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn compute_dot_product_8x(a_ptr: *const f32, b_ptr: *const f32, len: usize) -> f32 {
let mut acc0 = vdupq_n_f32(0.0);
let mut acc1 = vdupq_n_f32(0.0);
let chunks_32 = len / 32;
let mut idx = 0usize;
for _ in 0..chunks_32 {
let a0 = vld1q_f32(a_ptr.add(idx));
let a1 = vld1q_f32(a_ptr.add(idx + 4));
let a2 = vld1q_f32(a_ptr.add(idx + 8));
let a3 = vld1q_f32(a_ptr.add(idx + 12));
let a4 = vld1q_f32(a_ptr.add(idx + 16));
let a5 = vld1q_f32(a_ptr.add(idx + 20));
let a6 = vld1q_f32(a_ptr.add(idx + 24));
let a7 = vld1q_f32(a_ptr.add(idx + 28));
let b0 = vld1q_f32(b_ptr.add(idx));
let b1 = vld1q_f32(b_ptr.add(idx + 4));
let b2 = vld1q_f32(b_ptr.add(idx + 8));
let b3 = vld1q_f32(b_ptr.add(idx + 12));
let b4 = vld1q_f32(b_ptr.add(idx + 16));
let b5 = vld1q_f32(b_ptr.add(idx + 20));
let b6 = vld1q_f32(b_ptr.add(idx + 24));
let b7 = vld1q_f32(b_ptr.add(idx + 28));
acc0 = vfmaq_f32(acc0, a0, b0);
acc1 = vfmaq_f32(acc1, a1, b1);
acc0 = vfmaq_f32(acc0, a2, b2);
acc1 = vfmaq_f32(acc1, a3, b3);
acc0 = vfmaq_f32(acc0, a4, b4);
acc1 = vfmaq_f32(acc1, a5, b5);
acc0 = vfmaq_f32(acc0, a6, b6);
acc1 = vfmaq_f32(acc1, a7, b7);
idx += 32;
}
let mut acc = vaddq_f32(acc0, acc1);
let remaining_16 = (len - idx) / 16;
for _ in 0..remaining_16 {
let a0 = vld1q_f32(a_ptr.add(idx));
let a1 = vld1q_f32(a_ptr.add(idx + 4));
let a2 = vld1q_f32(a_ptr.add(idx + 8));
let a3 = vld1q_f32(a_ptr.add(idx + 12));
let b0 = vld1q_f32(b_ptr.add(idx));
let b1 = vld1q_f32(b_ptr.add(idx + 4));
let b2 = vld1q_f32(b_ptr.add(idx + 8));
let b3 = vld1q_f32(b_ptr.add(idx + 12));
acc = vfmaq_f32(acc, a0, b0);
acc = vfmaq_f32(acc, a1, b1);
acc = vfmaq_f32(acc, a2, b2);
acc = vfmaq_f32(acc, a3, b3);
idx += 16;
}
let remaining_4 = (len - idx) / 4;
for _ in 0..remaining_4 {
let a_v = vld1q_f32(a_ptr.add(idx));
let b_v = vld1q_f32(b_ptr.add(idx));
acc = vfmaq_f32(acc, a_v, b_v);
idx += 4;
}
let mut result = vaddvq_f32(acc);
for i in idx..len {
result += *a_ptr.add(i) * *b_ptr.add(i);
}
result
}
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn rescale_output_8x(out_ptr: *mut f32, len: usize, factor: f32) {
let factor_vec = vdupq_n_f32(factor);
let chunks_32 = len / 32;
let mut idx = 0usize;
for _ in 0..chunks_32 {
let o0 = vmulq_f32(vld1q_f32(out_ptr.add(idx)), factor_vec);
let o1 = vmulq_f32(vld1q_f32(out_ptr.add(idx + 4)), factor_vec);
let o2 = vmulq_f32(vld1q_f32(out_ptr.add(idx + 8)), factor_vec);
let o3 = vmulq_f32(vld1q_f32(out_ptr.add(idx + 12)), factor_vec);
let o4 = vmulq_f32(vld1q_f32(out_ptr.add(idx + 16)), factor_vec);
let o5 = vmulq_f32(vld1q_f32(out_ptr.add(idx + 20)), factor_vec);
let o6 = vmulq_f32(vld1q_f32(out_ptr.add(idx + 24)), factor_vec);
let o7 = vmulq_f32(vld1q_f32(out_ptr.add(idx + 28)), factor_vec);
vst1q_f32(out_ptr.add(idx), o0);
vst1q_f32(out_ptr.add(idx + 4), o1);
vst1q_f32(out_ptr.add(idx + 8), o2);
vst1q_f32(out_ptr.add(idx + 12), o3);
vst1q_f32(out_ptr.add(idx + 16), o4);
vst1q_f32(out_ptr.add(idx + 20), o5);
vst1q_f32(out_ptr.add(idx + 24), o6);
vst1q_f32(out_ptr.add(idx + 28), o7);
idx += 32;
}
let remaining_4 = (len - idx) / 4;
for _ in 0..remaining_4 {
let o = vmulq_f32(vld1q_f32(out_ptr.add(idx)), factor_vec);
vst1q_f32(out_ptr.add(idx), o);
idx += 4;
}
for i in idx..len {
*out_ptr.add(i) *= factor;
}
}
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn accumulate_weighted_value_8x(
out_ptr: *mut f32,
v_ptr: *const f32,
len: usize,
weight: f32,
) {
let weight_vec = vdupq_n_f32(weight);
let chunks_32 = len / 32;
let mut idx = 0usize;
for _ in 0..chunks_32 {
let v0 = vld1q_f32(v_ptr.add(idx));
let v1 = vld1q_f32(v_ptr.add(idx + 4));
let v2 = vld1q_f32(v_ptr.add(idx + 8));
let v3 = vld1q_f32(v_ptr.add(idx + 12));
let v4 = vld1q_f32(v_ptr.add(idx + 16));
let v5 = vld1q_f32(v_ptr.add(idx + 20));
let v6 = vld1q_f32(v_ptr.add(idx + 24));
let v7 = vld1q_f32(v_ptr.add(idx + 28));
let o0 = vfmaq_f32(vld1q_f32(out_ptr.add(idx)), v0, weight_vec);
let o1 = vfmaq_f32(vld1q_f32(out_ptr.add(idx + 4)), v1, weight_vec);
let o2 = vfmaq_f32(vld1q_f32(out_ptr.add(idx + 8)), v2, weight_vec);
let o3 = vfmaq_f32(vld1q_f32(out_ptr.add(idx + 12)), v3, weight_vec);
let o4 = vfmaq_f32(vld1q_f32(out_ptr.add(idx + 16)), v4, weight_vec);
let o5 = vfmaq_f32(vld1q_f32(out_ptr.add(idx + 20)), v5, weight_vec);
let o6 = vfmaq_f32(vld1q_f32(out_ptr.add(idx + 24)), v6, weight_vec);
let o7 = vfmaq_f32(vld1q_f32(out_ptr.add(idx + 28)), v7, weight_vec);
vst1q_f32(out_ptr.add(idx), o0);
vst1q_f32(out_ptr.add(idx + 4), o1);
vst1q_f32(out_ptr.add(idx + 8), o2);
vst1q_f32(out_ptr.add(idx + 12), o3);
vst1q_f32(out_ptr.add(idx + 16), o4);
vst1q_f32(out_ptr.add(idx + 20), o5);
vst1q_f32(out_ptr.add(idx + 24), o6);
vst1q_f32(out_ptr.add(idx + 28), o7);
idx += 32;
}
let remaining_4 = (len - idx) / 4;
for _ in 0..remaining_4 {
let v = vld1q_f32(v_ptr.add(idx));
let o = vfmaq_f32(vld1q_f32(out_ptr.add(idx)), v, weight_vec);
vst1q_f32(out_ptr.add(idx), o);
idx += 4;
}
for i in idx..len {
*out_ptr.add(i) += weight * *v_ptr.add(i);
}
}
#[cfg(target_arch = "aarch64")]
#[inline(always)]
unsafe fn normalize_output_8x(out_ptr: *mut f32, len: usize, factor: f32) {
rescale_output_8x(out_ptr, len, factor);
}
#[allow(dead_code)]
fn flash_attention_scalar(
query: &[f32],
key: &[f32],
value: &[f32],
head_dim: usize,
kv_len: usize,
scale: f32,
_causal: bool,
) -> Vec<f32> {
let mut scores = Vec::with_capacity(kv_len);
for t in 0..kv_len {
let k_offset = t * head_dim;
let score: f32 = query
.iter()
.zip(&key[k_offset..k_offset + head_dim])
.map(|(q, k)| q * k * scale)
.sum();
scores.push(score);
}
let max_score = scores.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let exp_scores: Vec<f32> = scores.iter().map(|s| (s - max_score).exp()).collect();
let sum_exp: f32 = exp_scores.iter().sum();
let attn_weights: Vec<f32> = exp_scores.iter().map(|e| e / sum_exp).collect();
let mut output = vec![0.0; head_dim];
for (t, weight) in attn_weights.iter().enumerate() {
let v_offset = t * head_dim;
for (i, v) in value[v_offset..v_offset + head_dim].iter().enumerate() {
output[i] += weight * v;
}
}
output
}
/// Attention over a paged KV cache. `_block_tables` is accepted for API
/// compatibility but currently unused: blocks are gathered into contiguous
/// buffers before the flash kernel runs.
pub fn paged_attention_neon(
query: &[f32],
kv_cache: &PagedKvCache,
_block_tables: &[usize],
scale: f32,
) -> Vec<f32> {
if kv_cache.num_tokens == 0 {
return vec![0.0; query.len()];
}
let keys = kv_cache.get_keys();
let values = kv_cache.get_values();
flash_attention_neon(query, &keys, &values, scale, false)
}
pub fn multi_query_attention_neon(
queries: &[f32],
key: &[f32],
value: &[f32],
config: &AttentionConfig,
) -> Vec<f32> {
let head_dim = config.head_dim;
let num_heads = config.num_heads;
let scale = config.effective_scale();
#[cfg(all(feature = "parallel", not(target_arch = "wasm32")))]
{
let kv_len = key.len() / head_dim;
if num_heads >= 4 && kv_len >= PARALLEL_THRESHOLD {
return multi_query_attention_parallel(queries, key, value, config);
}
}
let mut output = vec![0.0; num_heads * head_dim];
for h in 0..num_heads {
let q_offset = h * head_dim;
let q_slice = &queries[q_offset..q_offset + head_dim];
let head_output = flash_attention_neon(q_slice, key, value, scale, config.causal);
output[q_offset..q_offset + head_dim].copy_from_slice(&head_output);
}
output
}
#[cfg(all(feature = "parallel", not(target_arch = "wasm32")))]
pub fn multi_query_attention_parallel(
queries: &[f32],
key: &[f32],
value: &[f32],
config: &AttentionConfig,
) -> Vec<f32> {
let head_dim = config.head_dim;
let num_heads = config.num_heads;
let scale = config.effective_scale();
let causal = config.causal;
let results: Vec<Vec<f32>> = (0..num_heads)
.into_par_iter()
.map(|h| {
let q_offset = h * head_dim;
let q_slice = &queries[q_offset..q_offset + head_dim];
flash_attention_neon(q_slice, key, value, scale, causal)
})
.collect();
let mut output = Vec::with_capacity(num_heads * head_dim);
for head_output in results {
output.extend(head_output);
}
output
}
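// GQA maps query head `h` to KV head `h / gqa_ratio`, where
// gqa_ratio = num_heads / num_kv_heads. For example, with 8 query heads and
// 2 KV heads, query heads 0..=3 read KV head 0 and heads 4..=7 read KV
// head 1. MQA is the special case num_kv_heads = 1.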
pub fn grouped_query_attention_neon(
queries: &[f32],
keys: &[f32],
values: &[f32],
config: &AttentionConfig,
) -> Vec<f32> {
let head_dim = config.head_dim;
let num_heads = config.num_heads;
let num_kv_heads = config.num_kv_heads;
let gqa_ratio = config.gqa_ratio();
let scale = config.effective_scale();
let kv_len = keys.len() / (num_kv_heads * head_dim);
#[cfg(all(feature = "parallel", not(target_arch = "wasm32")))]
if num_heads >= 4 && kv_len >= PARALLEL_THRESHOLD {
return grouped_query_attention_parallel(queries, keys, values, config);
}
let mut output = vec![0.0; num_heads * head_dim];
for h in 0..num_heads {
let kv_head = h / gqa_ratio;
let q_offset = h * head_dim;
let q_slice = &queries[q_offset..q_offset + head_dim];
let mut kv_keys = Vec::with_capacity(kv_len * head_dim);
let mut kv_values = Vec::with_capacity(kv_len * head_dim);
for t in 0..kv_len {
let kv_offset = (t * num_kv_heads + kv_head) * head_dim;
kv_keys.extend_from_slice(&keys[kv_offset..kv_offset + head_dim]);
kv_values.extend_from_slice(&values[kv_offset..kv_offset + head_dim]);
}
let head_output = flash_attention_neon(q_slice, &kv_keys, &kv_values, scale, config.causal);
output[q_offset..q_offset + head_dim].copy_from_slice(&head_output);
}
output
}
#[cfg(all(feature = "parallel", not(target_arch = "wasm32")))]
pub fn grouped_query_attention_parallel(
queries: &[f32],
keys: &[f32],
values: &[f32],
config: &AttentionConfig,
) -> Vec<f32> {
let head_dim = config.head_dim;
let num_heads = config.num_heads;
let num_kv_heads = config.num_kv_heads;
let gqa_ratio = config.gqa_ratio();
let scale = config.effective_scale();
let causal = config.causal;
let kv_len = keys.len() / (num_kv_heads * head_dim);
let kv_slices: Vec<(Vec<f32>, Vec<f32>)> = (0..num_kv_heads)
.map(|kv_head| {
let mut kv_keys = Vec::with_capacity(kv_len * head_dim);
let mut kv_values = Vec::with_capacity(kv_len * head_dim);
for t in 0..kv_len {
let kv_offset = (t * num_kv_heads + kv_head) * head_dim;
kv_keys.extend_from_slice(&keys[kv_offset..kv_offset + head_dim]);
kv_values.extend_from_slice(&values[kv_offset..kv_offset + head_dim]);
}
(kv_keys, kv_values)
})
.collect();
let results: Vec<(usize, Vec<f32>)> = (0..num_heads)
.into_par_iter()
.map(|h| {
let kv_head = h / gqa_ratio;
let q_offset = h * head_dim;
let q_slice = &queries[q_offset..q_offset + head_dim];
let (ref kv_keys, ref kv_values) = kv_slices[kv_head];
let head_output = flash_attention_neon(q_slice, kv_keys, kv_values, scale, causal);
(h, head_output)
})
.collect();
let mut output = vec![0.0; num_heads * head_dim];
for (h, head_output) in results {
let q_offset = h * head_dim;
output[q_offset..q_offset + head_dim].copy_from_slice(&head_output);
}
output
}
#[cfg(all(feature = "parallel", not(target_arch = "wasm32")))]
pub fn multi_head_attention_parallel(
queries: &[f32],
keys: &[f32],
values: &[f32],
config: &AttentionConfig,
) -> Vec<f32> {
let head_dim = config.head_dim;
let num_heads = config.num_heads;
let scale = config.effective_scale();
let causal = config.causal;
let kv_len = keys.len() / (num_heads * head_dim);
let results: Vec<(usize, Vec<f32>)> = (0..num_heads)
.into_par_iter()
.map(|h| {
let q_offset = h * head_dim;
let kv_offset = h * kv_len * head_dim;
let q_slice = &queries[q_offset..q_offset + head_dim];
let k_slice = &keys[kv_offset..kv_offset + kv_len * head_dim];
let v_slice = &values[kv_offset..kv_offset + kv_len * head_dim];
let head_output = flash_attention_neon(q_slice, k_slice, v_slice, scale, causal);
(h, head_output)
})
.collect();
let mut output = vec![0.0; num_heads * head_dim];
for (h, head_output) in results {
let q_offset = h * head_dim;
output[q_offset..q_offset + head_dim].copy_from_slice(&head_output);
}
output
}
/// # Safety
///
/// `queries`, `keys`, and `scores` must have exactly the shapes asserted
/// below; the raw-pointer loads perform no bounds checks in release builds.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
pub unsafe fn batched_attention_scores_neon(
queries: &[f32],
keys: &[f32],
scores: &mut [f32],
batch_size: usize,
seq_len: usize,
kv_len: usize,
head_dim: usize,
scale: f32,
) {
debug_assert_eq!(queries.len(), batch_size * seq_len * head_dim);
debug_assert_eq!(keys.len(), batch_size * kv_len * head_dim);
debug_assert_eq!(scores.len(), batch_size * seq_len * kv_len);
let scale_vec = vdupq_n_f32(scale);
for b in 0..batch_size {
for q_idx in 0..seq_len {
for k_idx in 0..kv_len {
let q_offset = (b * seq_len + q_idx) * head_dim;
let k_offset = (b * kv_len + k_idx) * head_dim;
let s_offset = (b * seq_len + q_idx) * kv_len + k_idx;
let q_ptr = queries.as_ptr().add(q_offset);
let k_ptr = keys.as_ptr().add(k_offset);
let mut dot = vdupq_n_f32(0.0);
// The loop body consumes 16 floats per iteration, so this assumes
// NEON_LANE_WIDTH * UNROLL_FACTOR == 16 (4 lanes x 4-way unroll).
let chunks = head_dim / (NEON_LANE_WIDTH * UNROLL_FACTOR);
let mut idx = 0usize;
for _ in 0..chunks {
let q0 = vld1q_f32(q_ptr.add(idx));
let k0 = vld1q_f32(k_ptr.add(idx));
dot = vfmaq_f32(dot, q0, k0);
let q1 = vld1q_f32(q_ptr.add(idx + 4));
let k1 = vld1q_f32(k_ptr.add(idx + 4));
dot = vfmaq_f32(dot, q1, k1);
let q2 = vld1q_f32(q_ptr.add(idx + 8));
let k2 = vld1q_f32(k_ptr.add(idx + 8));
dot = vfmaq_f32(dot, q2, k2);
let q3 = vld1q_f32(q_ptr.add(idx + 12));
let k3 = vld1q_f32(k_ptr.add(idx + 12));
dot = vfmaq_f32(dot, q3, k3);
idx += 16;
}
let remaining = (head_dim - idx) / NEON_LANE_WIDTH;
for _ in 0..remaining {
let q_v = vld1q_f32(q_ptr.add(idx));
let k_v = vld1q_f32(k_ptr.add(idx));
dot = vfmaq_f32(dot, q_v, k_v);
idx += 4;
}
let mut score = vaddvq_f32(vmulq_f32(dot, scale_vec));
for i in idx..head_dim {
score += *q_ptr.add(i) * *k_ptr.add(i) * scale;
}
scores[s_offset] = score;
}
}
}
}
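// Hedged usage sketch: score one batch, then normalize each query row in
// place (assumes `q` and `k` are packed exactly as asserted above):
//
//   let mut scores = vec![0.0f32; batch * seq * kv];
//   unsafe {
//       batched_attention_scores_neon(&q, &k, &mut scores, batch, seq, kv, head_dim, scale);
//       for row in scores.chunks_mut(kv) {
//           let len = row.len();
//           softmax_neon(row, len);
//       }
//   }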
/// # Safety
///
/// `x` must hold at least `len` elements; the first `len` entries are
/// normalized in place via unchecked pointer arithmetic.
#[cfg(target_arch = "aarch64")]
#[inline(always)]
pub unsafe fn softmax_neon(x: &mut [f32], len: usize) {
debug_assert!(x.len() >= len);
let x_ptr = x.as_mut_ptr();
let mut max_vec = vdupq_n_f32(f32::NEG_INFINITY);
let chunks = len / NEON_LANE_WIDTH;
let mut idx = 0usize;
for _ in 0..chunks {
let v = vld1q_f32(x_ptr.add(idx));
max_vec = vmaxq_f32(max_vec, v);
idx += 4;
}
let mut max_val = vmaxvq_f32(max_vec);
for i in idx..len {
max_val = max_val.max(*x_ptr.add(i));
}
let max_vec = vdupq_n_f32(max_val);
let mut sum_vec = vdupq_n_f32(0.0);
idx = 0;
for _ in 0..chunks {
let v = vld1q_f32(x_ptr.add(idx));
let shifted = vsubq_f32(v, max_vec);
// NEON has no exp intrinsic, so exponentiate each lane in scalar code
// and repack the results into a vector.
let exp_vec = vsetq_lane_f32(
(vgetq_lane_f32(shifted, 3)).exp(),
vsetq_lane_f32(
(vgetq_lane_f32(shifted, 2)).exp(),
vsetq_lane_f32(
(vgetq_lane_f32(shifted, 1)).exp(),
vsetq_lane_f32((vgetq_lane_f32(shifted, 0)).exp(), vdupq_n_f32(0.0), 0),
1,
),
2,
),
3,
);
vst1q_f32(x_ptr.add(idx), exp_vec);
sum_vec = vaddq_f32(sum_vec, exp_vec);
idx += 4;
}
let mut sum_val = vaddvq_f32(sum_vec);
for i in idx..len {
let exp_val = (*x_ptr.add(i) - max_val).exp();
*x_ptr.add(i) = exp_val;
sum_val += exp_val;
}
let inv_sum = 1.0 / sum_val;
let inv_sum_vec = vdupq_n_f32(inv_sum);
idx = 0;
for _ in 0..chunks {
let v = vld1q_f32(x_ptr.add(idx));
vst1q_f32(x_ptr.add(idx), vmulq_f32(v, inv_sum_vec));
idx += 4;
}
for i in idx..len {
*x_ptr.add(i) *= inv_sum;
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_flash_attention_basic() {
let head_dim = 16;
let kv_len = 4;
let query: Vec<f32> = (0..head_dim).map(|i| (i as f32) * 0.1).collect();
let key: Vec<f32> = (0..kv_len * head_dim).map(|i| (i as f32) * 0.01).collect();
let value: Vec<f32> = (0..kv_len * head_dim).map(|i| (i as f32) * 0.02).collect();
let scale = 1.0 / (head_dim as f32).sqrt();
let output = flash_attention_neon(&query, &key, &value, scale, false);
assert_eq!(output.len(), head_dim);
assert!(output.iter().all(|&x| x.is_finite()));
}
#[test]
fn test_paged_kv_cache() {
let mut cache = PagedKvCache::new(4, 2, 8);
let keys = vec![1.0; 2 * 8];
let values = vec![2.0; 2 * 8];
cache.append(&keys, &values);
assert_eq!(cache.num_tokens, 1);
cache.append(&keys, &values);
assert_eq!(cache.num_tokens, 2);
let retrieved_keys = cache.get_keys();
assert_eq!(retrieved_keys.len(), 2 * 2 * 8);
}
#[test]
fn test_gqa() {
let config = AttentionConfig {
num_heads: 8,
num_kv_heads: 2,
head_dim: 16,
causal: false,
..Default::default()
};
let queries: Vec<f32> = (0..config.num_heads * config.head_dim)
.map(|i| (i as f32) * 0.01)
.collect();
let kv_len = 4;
let keys: Vec<f32> = (0..kv_len * config.num_kv_heads * config.head_dim)
.map(|i| (i as f32) * 0.01)
.collect();
let values: Vec<f32> = (0..kv_len * config.num_kv_heads * config.head_dim)
.map(|i| (i as f32) * 0.01)
.collect();
let output = grouped_query_attention_neon(&queries, &keys, &values, &config);
assert_eq!(output.len(), config.num_heads * config.head_dim);
assert!(output.iter().all(|&x| x.is_finite()));
}
#[test]
fn test_mqa() {
let config = AttentionConfig {
num_heads: 8,
num_kv_heads: 1,
head_dim: 16,
causal: false,
..Default::default()
};
let queries: Vec<f32> = (0..config.num_heads * config.head_dim)
.map(|i| (i as f32) * 0.01)
.collect();
let kv_len = 4;
let keys: Vec<f32> = (0..kv_len * config.head_dim)
.map(|i| (i as f32) * 0.01)
.collect();
let values: Vec<f32> = (0..kv_len * config.head_dim)
.map(|i| (i as f32) * 0.02)
.collect();
let output = multi_query_attention_neon(&queries, &keys, &values, &config);
assert_eq!(output.len(), config.num_heads * config.head_dim);
assert!(output.iter().all(|&x| x.is_finite()));
}
#[test]
fn test_paged_attention() {
let mut cache = PagedKvCache::new(16, 1, 16);
for _ in 0..8 {
let keys: Vec<f32> = (0..16).map(|i| (i as f32) * 0.1).collect();
let values: Vec<f32> = (0..16).map(|i| (i as f32) * 0.2).collect();
cache.append(&keys, &values);
}
let query: Vec<f32> = (0..16).map(|i| (i as f32) * 0.05).collect();
let scale = 1.0 / (16.0f32).sqrt();
let output = paged_attention_neon(&query, &cache, &[], scale);
assert_eq!(output.len(), 16);
assert!(output.iter().all(|&x| x.is_finite()));
}
#[test]
fn test_attention_scratch_buffer() {
let scratch = AttentionScratch::new(128, 64, 32);
assert_eq!(scratch.head_dim(), 128);
assert_eq!(scratch.max_block_size(), 64);
assert!(scratch.is_compatible(128, 64, 32));
assert!(scratch.is_compatible(64, 32, 16));
assert!(!scratch.is_compatible(256, 64, 32));
}
#[test]
fn test_attention_scratch_buffers() {
let mut scratch = AttentionScratch::new(128, 64, 32);
let output = scratch.output_buffer();
assert_eq!(output.len(), 128);
let block_scores = scratch.block_scores_buffer(32);
assert_eq!(block_scores.len(), 32);
let head_outputs = scratch.head_outputs_buffer(16);
assert_eq!(head_outputs.len(), 16 * 128);
}
#[test]
fn test_flash_attention_into_basic() {
let head_dim = 16;
let kv_len = 4;
let query: Vec<f32> = (0..head_dim).map(|i| (i as f32) * 0.1).collect();
let key: Vec<f32> = (0..kv_len * head_dim).map(|i| (i as f32) * 0.01).collect();
let value: Vec<f32> = (0..kv_len * head_dim).map(|i| (i as f32) * 0.02).collect();
let scale = 1.0 / (head_dim as f32).sqrt();
let mut output = vec![0.0; head_dim];
flash_attention_into(&query, &key, &value, scale, false, &mut output);
assert_eq!(output.len(), head_dim);
assert!(output.iter().all(|&x| x.is_finite()));
let expected = flash_attention_neon(&query, &key, &value, scale, false);
for (a, b) in output.iter().zip(expected.iter()) {
assert!((a - b).abs() < 1e-5, "Output mismatch: {} vs {}", a, b);
}
}
#[test]
fn test_flash_attention_with_scratch() {
let head_dim = 16;
let kv_len = 8;
let query: Vec<f32> = (0..head_dim).map(|i| (i as f32) * 0.1).collect();
let key: Vec<f32> = (0..kv_len * head_dim).map(|i| (i as f32) * 0.01).collect();
let value: Vec<f32> = (0..kv_len * head_dim).map(|i| (i as f32) * 0.02).collect();
let scale = 1.0 / (head_dim as f32).sqrt();
let mut scratch = AttentionScratch::new(head_dim, 64, 1);
let mut output = vec![0.0; head_dim];
flash_attention_with_scratch(&query, &key, &value, scale, &mut scratch, &mut output);
assert!(output.iter().all(|&x| x.is_finite()));
let expected = flash_attention_neon(&query, &key, &value, scale, false);
for (a, b) in output.iter().zip(expected.iter()) {
assert!((a - b).abs() < 1e-5, "Output mismatch: {} vs {}", a, b);
}
}
#[test]
fn test_paged_kv_cache_with_capacity() {
let cache = PagedKvCache::with_capacity(16, 2, 8, 4);
assert_eq!(cache.capacity(), 64);
assert!(cache.has_capacity(64));
assert!(!cache.has_capacity(65));
assert_eq!(cache.num_tokens, 0);
}
#[test]
fn test_paged_kv_cache_with_max_tokens() {
let cache = PagedKvCache::with_max_tokens(16, 2, 8, 100);
assert!(cache.capacity() >= 100);
assert!(cache.has_capacity(100));
}
#[test]
fn test_paged_kv_cache_reset() {
let mut cache = PagedKvCache::with_capacity(16, 2, 8, 4);
let keys = vec![1.0; 2 * 8];
let values = vec![2.0; 2 * 8];
cache.append(&keys, &values);
cache.append(&keys, &values);
assert_eq!(cache.num_tokens, 2);
cache.reset();
assert_eq!(cache.num_tokens, 0);
assert_eq!(cache.capacity(), 64);
}
#[test]
fn test_paged_kv_cache_copy_into() {
let mut cache = PagedKvCache::new(4, 2, 8);
let keys = vec![1.0; 2 * 8];
let values = vec![2.0; 2 * 8];
cache.append(&keys, &values);
cache.append(&keys, &values);
let mut key_buffer = vec![0.0; cache.num_tokens * 2 * 8];
let written = cache.copy_keys_into(&mut key_buffer);
assert_eq!(written, cache.num_tokens * 2 * 8);
assert!(key_buffer.iter().all(|&x| (x - 1.0).abs() < 1e-6));
let mut value_buffer = vec![0.0; cache.num_tokens * 2 * 8];
let written = cache.copy_values_into(&mut value_buffer);
assert_eq!(written, cache.num_tokens * 2 * 8);
assert!(value_buffer.iter().all(|&x| (x - 2.0).abs() < 1e-6));
}
#[test]
fn test_paged_kv_cache_append_unchecked() {
let mut cache = PagedKvCache::with_capacity(16, 2, 8, 4);
let keys = vec![1.0; 2 * 8];
let values = vec![2.0; 2 * 8];
unsafe {
cache.append_unchecked(&keys, &values);
cache.append_unchecked(&keys, &values);
}
assert_eq!(cache.num_tokens, 2);
let retrieved_keys = cache.get_keys();
assert_eq!(retrieved_keys.len(), 2 * 2 * 8);
assert!(retrieved_keys.iter().all(|&x| (x - 1.0).abs() < 1e-6));
}
#[test]
fn test_zero_allocation_repeated_calls() {
let head_dim = 32;
let kv_len = 16;
let query: Vec<f32> = (0..head_dim).map(|i| (i as f32) * 0.1).collect();
let key: Vec<f32> = (0..kv_len * head_dim).map(|i| (i as f32) * 0.01).collect();
let value: Vec<f32> = (0..kv_len * head_dim).map(|i| (i as f32) * 0.02).collect();
let scale = 1.0 / (head_dim as f32).sqrt();
let mut scratch = AttentionScratch::new(head_dim, 64, 1);
let mut output = vec![0.0; head_dim];
for _ in 0..100 {
flash_attention_with_scratch(&query, &key, &value, scale, &mut scratch, &mut output);
}
assert!(output.iter().all(|&x| x.is_finite()));
}
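// The two tests below are cross-checks added as sketches: they validate the
// kernels against plain two-pass references computed in scalar Rust.
#[test]
fn test_flash_attention_matches_naive_reference() {
let head_dim = 16;
let kv_len = 7; // deliberately not a multiple of any block size
let query: Vec<f32> = (0..head_dim).map(|i| (i as f32) * 0.1).collect();
let key: Vec<f32> = (0..kv_len * head_dim).map(|i| ((i % 13) as f32) * 0.03).collect();
let value: Vec<f32> = (0..kv_len * head_dim).map(|i| ((i % 7) as f32) * 0.05).collect();
let scale = 1.0 / (head_dim as f32).sqrt();
let output = flash_attention_neon(&query, &key, &value, scale, false);
// Direct softmax(q . K^T * scale) @ V reference.
let mut weights: Vec<f32> = (0..kv_len)
.map(|t| {
query
.iter()
.zip(&key[t * head_dim..(t + 1) * head_dim])
.map(|(q, k)| q * k)
.sum::<f32>()
* scale
})
.collect();
let max = weights.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let mut sum = 0.0f32;
for w in weights.iter_mut() {
*w = (*w - max).exp();
sum += *w;
}
let mut expected = vec![0.0f32; head_dim];
for (t, w) in weights.iter().enumerate() {
for i in 0..head_dim {
expected[i] += (w / sum) * value[t * head_dim + i];
}
}
for (a, b) in output.iter().zip(expected.iter()) {
assert!((a - b).abs() < 1e-4, "mismatch: {} vs {}", a, b);
}
}
#[cfg(target_arch = "aarch64")]
#[test]
fn test_softmax_neon_matches_scalar() {
let mut x: Vec<f32> = (0..37).map(|i| (i as f32) * 0.13 - 2.0).collect();
let reference: Vec<f32> = {
let max = x.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
let exps: Vec<f32> = x.iter().map(|v| (v - max).exp()).collect();
let sum: f32 = exps.iter().sum();
exps.iter().map(|e| e / sum).collect()
};
let len = x.len();
unsafe { softmax_neon(&mut x, len) };
for (a, b) in x.iter().zip(reference.iter()) {
assert!((a - b).abs() < 1e-5, "softmax mismatch: {} vs {}", a, b);
}
}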
}