use crate::error::{RuntimeError, RuntimeResult};
use crate::flash_attention::flash_attention_forward;
use crate::kv_cache::BatchedKvView;
/// Runs [`flash_attention_forward`] once per slot of a batch and concatenates
/// the per-slot results.
///
/// `q_batch` is read as `batch_size` back-to-back chunks of
/// `num_heads * head_dim` values, chunk `i` belonging to slot `i` of
/// `kv_view`. For every slot, the chunk is handed to the kernel together with
/// the first `kv_view.position(slot) * num_heads * head_dim` elements of that
/// slot's key and value buffers (any trailing elements are ignored). The
/// kernel output is written to the same chunk position of the returned vector,
/// which therefore has exactly `q_batch.len()` elements.
///
/// Slots whose `position` is `0` are skipped without calling the kernel; their
/// output chunk stays all zeros.
///
/// # Errors
///
/// Returns [`RuntimeError::AttentionError`] when:
/// - `num_heads` or `head_dim` is zero, or their product overflows `usize`;
/// - `q_batch.len()` is not a multiple of `num_heads * head_dim`;
/// - the number of query chunks differs from `kv_view.slot_count()`;
/// - a slot's required K/V element count overflows `usize`, or its key or
///   value buffer is shorter than that count;
/// - the kernel returns a buffer whose length is not `num_heads * head_dim`.
///
/// Any error produced by [`flash_attention_forward`] itself is propagated
/// unchanged.
pub fn batched_flash_attention<V: BatchedKvView>(
    q_batch: &[f32],
    kv_view: &V,
    num_heads: usize,
    head_dim: usize,
    softmax_scale: f32,
) -> RuntimeResult<Vec<f32>> {
    if num_heads == 0 || head_dim == 0 {
        return Err(RuntimeError::AttentionError {
            message: "num_heads and head_dim must be > 0".to_string(),
        });
    }

    // Both factors are non-zero, so a product that does not overflow is
    // non-zero as well; that makes it a valid chunk size below.
    let head_stride = num_heads
        .checked_mul(head_dim)
        .ok_or_else(|| RuntimeError::AttentionError {
            message: format!("num_heads ({num_heads}) * head_dim ({head_dim}) overflows usize"),
        })?;

    if q_batch.len() % head_stride != 0 {
        return Err(RuntimeError::AttentionError {
            message: format!(
                "q_batch length {} is not a multiple of num_heads * head_dim = {}",
                q_batch.len(),
                head_stride
            ),
        });
    }

    let batch_size = q_batch.len() / head_stride;
    let slot_count = kv_view.slot_count();
    if batch_size != slot_count {
        return Err(RuntimeError::AttentionError {
            message: format!(
                "batch_size ({batch_size}) must equal kv_view.slot_count() ({slot_count})"
            ),
        });
    }

    // `q_batch.len()` is an exact multiple of `head_stride`, so the output has
    // the same length as the input and both split into `batch_size` chunks.
    let mut output = vec![0.0f32; q_batch.len()];
    let query_chunks = q_batch.chunks_exact(head_stride);
    let output_chunks = output.chunks_exact_mut(head_stride);

    for (slot, (q_slot, out_slot)) in query_chunks.zip(output_chunks).enumerate() {
        let seq_len_kv = kv_view.position(slot);
        if seq_len_kv == 0 {
            // Nothing to attend over: leave this slot's output zeroed.
            continue;
        }

        let (k_flat, v_flat) = kv_view.kv_for_slot(slot);

        let kv_expected =
            seq_len_kv
                .checked_mul(head_stride)
                .ok_or_else(|| RuntimeError::AttentionError {
                    message: format!(
                        "slot {slot}: seq_len_kv ({seq_len_kv}) * num_heads * head_dim \
                         ({head_stride}) overflows usize"
                    ),
                })?;

        // Shared constructor for the "buffer too short" error of K and V.
        let too_short = |name: &str, actual: usize| RuntimeError::AttentionError {
            message: format!(
                "slot {slot}: {name} length {} < expected {} (seq_len_kv={seq_len_kv}, \
                 num_heads={num_heads}, head_dim={head_dim})",
                actual, kv_expected,
            ),
        };
        if k_flat.len() < kv_expected {
            return Err(too_short("k_flat", k_flat.len()));
        }
        if v_flat.len() < kv_expected {
            return Err(too_short("v_flat", v_flat.len()));
        }

        let slot_out = flash_attention_forward(
            q_slot,
            &k_flat[..kv_expected],
            &v_flat[..kv_expected],
            num_heads,
            head_dim,
            softmax_scale,
            // NOTE(review): presumably the causal-mask flag; confirm against
            // the signature of `flash_attention_forward`.
            false,
        )?;

        // `copy_from_slice` panics on a length mismatch, so report a
        // misbehaving kernel through the error channel instead.
        if slot_out.len() != head_stride {
            return Err(RuntimeError::AttentionError {
                message: format!(
                    "slot {slot}: attention output length {} != expected {head_stride}",
                    slot_out.len()
                ),
            });
        }
        out_slot.copy_from_slice(&slot_out);
    }

    Ok(output)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::kv_cache::{BatchedKvView, KvSlot, VecBatchedKvView};

    /// Builds a `len`-element vector whose `i`-th entry is `f(i as f32)`.
    fn generate(len: usize, f: impl Fn(f32) -> f32) -> Vec<f32> {
        (0..len).map(|i| f(i as f32)).collect()
    }

    /// Low-amplitude sine ramp used as deterministic synthetic input.
    fn sin_data(len: usize) -> Vec<f32> {
        generate(len, |x| f32::sin(x * 0.1) * 0.1)
    }

    /// Asserts element-wise agreement (absolute tolerance `1e-5`) between a
    /// slot's batched output and its serial reference, over the common prefix
    /// of the two slices.
    fn assert_slot_matches(slot: usize, batched: &[f32], serial: &[f32]) {
        for (idx, (&b, &s)) in batched.iter().zip(serial).enumerate() {
            let diff = (b - s).abs();
            assert!(
                diff < 1e-5,
                "slot {slot}, index {idx}: batched={b} serial={s} diff={diff}"
            );
        }
    }

    /// `VecBatchedKvView` must report the slot count, positions, buffer
    /// lengths and contents it was constructed with.
    #[test]
    fn batched_kv_view_basic() {
        let (num_heads, head_dim) = (2usize, 4usize);
        let kv_dim = num_heads * head_dim;
        let positions = [3usize, 7, 1];

        // Keys for slot i start at i*10, values at i*100; both then count up.
        let slots: Vec<_> = positions
            .iter()
            .enumerate()
            .map(|(i, &pos)| KvSlot::new(i as u64 + 1, i, pos))
            .collect();
        let keys: Vec<Vec<f32>> = positions
            .iter()
            .enumerate()
            .map(|(i, &pos)| generate(pos * kv_dim, |j| (i as f32 * 10.0) + j))
            .collect();
        let values: Vec<Vec<f32>> = positions
            .iter()
            .enumerate()
            .map(|(i, &pos)| generate(pos * kv_dim, |j| (i as f32 * 100.0) + j))
            .collect();

        let view = VecBatchedKvView::new(slots, keys, values);

        assert_eq!(view.slot_count(), 3, "slot_count must be 3");
        for (slot, &expected) in positions.iter().enumerate() {
            assert_eq!(
                view.position(slot),
                expected,
                "slot {slot} position must be {expected}"
            );
        }

        let (k0, v0) = view.kv_for_slot(0);
        let (k1, v1) = view.kv_for_slot(1);
        let (k2, v2) = view.kv_for_slot(2);

        assert_eq!(k0.len(), 3 * kv_dim, "slot 0 key length must be pos*kv_dim");
        assert_eq!(v0.len(), 3 * kv_dim, "slot 0 value length must be pos*kv_dim");
        assert_eq!(k1.len(), 7 * kv_dim, "slot 1 key length");
        assert_eq!(v1.len(), 7 * kv_dim, "slot 1 value length");
        assert_eq!(k2.len(), kv_dim, "slot 2 key length");
        assert_eq!(v2.len(), kv_dim, "slot 2 value length");

        let near = |actual: f32, expected: f32| (actual - expected).abs() < 1e-7;
        assert!(near(k0[0], 0.0), "slot 0 k[0] should be 0.0, got {}", k0[0]);
        assert!(near(k0[1], 1.0), "slot 0 k[1] should be 1.0, got {}", k0[1]);
        assert!(near(v0[0], 0.0), "slot 0 v[0] should be 0.0, got {}", v0[0]);
        assert!(near(k1[0], 10.0), "slot 1 k[0] should be 10.0, got {}", k1[0]);
        assert!(near(v2[0], 200.0), "slot 2 v[0] should be 200.0, got {}", v2[0]);
    }

    /// The batched entry point must reproduce, slot by slot, what calling
    /// `flash_attention_forward` directly on each slot's data yields.
    #[test]
    fn batched_flash_decode_matches_serial() {
        let (num_heads, head_dim) = (2usize, 8usize);
        let kv_dim = num_heads * head_dim;
        let scale = 1.0f32 / (head_dim as f32).sqrt();
        let (seq_kv_0, seq_kv_1) = (16usize, 24usize);

        // Slot 0 inputs.
        let q0 = sin_data(kv_dim);
        let k0 = sin_data(seq_kv_0 * kv_dim);
        let v0 = generate(seq_kv_0 * kv_dim, |x| f32::cos(x * 0.13) * 0.1);

        // Slot 1 inputs, with a longer KV history and different waveforms.
        let q1 = generate(kv_dim, |x| f32::cos(x * 0.07) * 0.12);
        let k1 = generate(seq_kv_1 * kv_dim, |x| f32::sin(x * 0.05 + 1.0) * 0.08);
        let v1 = generate(seq_kv_1 * kv_dim, |x| f32::cos(x * 0.09 + 0.5) * 0.09);

        let view = VecBatchedKvView::new(
            vec![KvSlot::new(1, 0, seq_kv_0), KvSlot::new(2, 1, seq_kv_1)],
            vec![k0.clone(), k1.clone()],
            vec![v0.clone(), v1.clone()],
        );
        let q_batch: Vec<f32> = [q0.as_slice(), q1.as_slice()].concat();

        let batched_out = batched_flash_attention(&q_batch, &view, num_heads, head_dim, scale)
            .expect("batched_flash_attention failed");
        assert_eq!(
            batched_out.len(),
            2 * kv_dim,
            "output must be batch_size * num_heads * head_dim"
        );

        let serial_out_0 =
            flash_attention_forward(&q0, &k0, &v0, num_heads, head_dim, scale, false)
                .expect("serial slot 0 failed");
        let serial_out_1 =
            flash_attention_forward(&q1, &k1, &v1, num_heads, head_dim, scale, false)
                .expect("serial slot 1 failed");

        let (batched_0, batched_1) = batched_out.split_at(kv_dim);
        assert_slot_matches(0, batched_0, &serial_out_0);
        assert_slot_matches(1, batched_1, &serial_out_1);
    }
}