use crate::backend::{BackendError, BackendResult};
use crate::tensor::{DType, Tensor};
use rayon::prelude::*;
/// Default block (tile) size, in sequence positions; also the default
/// `block_q`/`block_kv` in `FlashAttentionConfig`.
const BLOCK_SIZE: usize = 64;
/// Tuning parameters for the blocked flash-attention entry point.
#[derive(Debug, Clone)]
pub struct FlashAttentionConfig {
    /// Tile size along the query sequence dimension.
    /// NOTE(review): currently not consumed by `flash_attention_blocked` — confirm intent.
    pub block_q: usize,
    /// Tile size along the key/value sequence dimension.
    /// NOTE(review): currently not consumed by `flash_attention_blocked` — confirm intent.
    pub block_kv: usize,
    /// When true, query position `i` attends only to KV positions `<= i`.
    pub causal: bool,
}
impl Default for FlashAttentionConfig {
fn default() -> Self {
Self {
block_q: BLOCK_SIZE,
block_kv: BLOCK_SIZE,
causal: true,
}
}
}
/// Scaled dot-product attention with a numerically stable streaming
/// (online) softmax, parallelised with rayon over (batch, head) pairs.
///
/// Accepted layouts (row-major, contiguous):
/// * 3D: `[num_heads, seq_len, head_dim]`
/// * 4D: `[batch, num_heads, seq_len, head_dim]`
///
/// Grouped-query attention is supported: `k`/`v` may carry fewer heads
/// than `q`, in which case `num_heads / num_kv_heads` consecutive query
/// heads share each KV head.
///
/// `scale` multiplies every raw Q·K score (typically `1/sqrt(head_dim)`).
/// With `causal` set, query position `i` attends only to KV positions `<= i`.
///
/// # Errors
/// Returns `BackendError::InvalidArgument` when any tensor is not F32 or
/// has fewer than 3 dimensions, when batch sizes, head dims, or KV lengths
/// disagree, when the query head count is not a non-zero multiple of the
/// KV head count, or when a tensor's buffer is smaller than its shape
/// requires (previously these cases panicked or silently produced wrong
/// results — 4D inputs in particular ignored the batch dimension).
pub fn flash_attention(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    out: &mut Tensor,
    scale: f32,
    causal: bool,
) -> BackendResult<()> {
    let q_shape = q.shape();
    let k_shape = k.shape();
    let v_shape = v.shape();
    if q_shape.len() < 3 || k_shape.len() < 3 || v_shape.len() < 3 {
        return Err(BackendError::InvalidArgument(
            "Flash attention requires at least 3D tensors".into(),
        ));
    }
    if q.dtype() != DType::F32 || k.dtype() != DType::F32 || v.dtype() != DType::F32 {
        return Err(BackendError::InvalidArgument(
            "Flash attention requires F32 tensors".into(),
        ));
    }
    // Normalise both layouts to (batch, heads, seq, dim).
    let (batch, num_heads, seq_len_q, head_dim) = if q_shape.len() == 3 {
        (1, q_shape[0], q_shape[1], q_shape[2])
    } else {
        (q_shape[0], q_shape[1], q_shape[2], q_shape[3])
    };
    let (kv_batch, num_kv_heads, seq_len_kv, kv_head_dim) = if k_shape.len() == 3 {
        (1, k_shape[0], k_shape[1], k_shape[2])
    } else {
        (k_shape[0], k_shape[1], k_shape[2], k_shape[3])
    };
    let v_seq_len = if v_shape.len() == 3 { v_shape[1] } else { v_shape[2] };
    if kv_batch != batch {
        return Err(BackendError::InvalidArgument(
            "Flash attention: Q and K batch sizes differ".into(),
        ));
    }
    // `last()` cannot fail: v_shape.len() >= 3 was checked above.
    if kv_head_dim != head_dim || *v_shape.last().unwrap() != head_dim {
        return Err(BackendError::InvalidArgument(
            "Flash attention: head dims of Q, K and V must match".into(),
        ));
    }
    if v_seq_len != seq_len_kv {
        return Err(BackendError::InvalidArgument(
            "Flash attention: K and V sequence lengths differ".into(),
        ));
    }
    // Guards the `head / heads_per_group` mapping below against divide-by-zero
    // and out-of-range KV head indices.
    if num_kv_heads == 0 || num_heads % num_kv_heads != 0 {
        return Err(BackendError::InvalidArgument(
            "Flash attention: query head count must be a non-zero multiple of KV head count"
                .into(),
        ));
    }
    let q_data = q.as_f32()?;
    let k_data = k.as_f32()?;
    let v_data = v.as_f32()?;
    let out_data = out.as_f32_mut()?;
    // Validate buffer sizes up front so the hot loops cannot panic on
    // tensors whose buffers are smaller than their shapes advertise.
    if q_data.len() < batch * num_heads * seq_len_q * head_dim
        || k_data.len() < batch * num_kv_heads * seq_len_kv * head_dim
        || v_data.len() < batch * num_kv_heads * seq_len_kv * head_dim
        || out_data.len() < batch * num_heads * seq_len_q * head_dim
    {
        return Err(BackendError::InvalidArgument(
            "Flash attention: tensor buffer smaller than its shape requires".into(),
        ));
    }
    let heads_per_group = num_heads / num_kv_heads;
    // One independent task per (batch, head) pair; each produces its own
    // output slab so no synchronisation is needed inside the map.
    let results: Vec<(usize, Vec<f32>)> = (0..batch * num_heads)
        .into_par_iter()
        .map(|task| {
            let b = task / num_heads;
            let head = task % num_heads;
            let kv_head = head / heads_per_group;
            // Batch-aware base offsets (contiguous row-major layout).
            let q_head_offset = (b * num_heads + head) * seq_len_q * head_dim;
            let kv_head_offset = (b * num_kv_heads + kv_head) * seq_len_kv * head_dim;
            let mut head_output = vec![0.0f32; seq_len_q * head_dim];
            for q_pos in 0..seq_len_q {
                let q_offset = q_head_offset + q_pos * head_dim;
                // Online-softmax state: running max, running normaliser, and
                // the un-normalised weighted sum of V rows seen so far.
                let mut max_score = f32::NEG_INFINITY;
                let mut sum_exp = 0.0f32;
                let mut acc = vec![0.0f32; head_dim];
                let kv_end = if causal {
                    (q_pos + 1).min(seq_len_kv)
                } else {
                    seq_len_kv
                };
                for kv_pos in 0..kv_end {
                    let k_offset = kv_head_offset + kv_pos * head_dim;
                    let mut score = 0.0f32;
                    for d in 0..head_dim {
                        score += q_data[q_offset + d] * k_data[k_offset + d];
                    }
                    score *= scale;
                    // New running max: rescale the accumulator so every term
                    // stays expressed relative to the current max (stable exp).
                    if score > max_score {
                        let rescale = (max_score - score).exp();
                        for a in acc.iter_mut() {
                            *a *= rescale;
                        }
                        sum_exp *= rescale;
                        max_score = score;
                    }
                    let exp_score = (score - max_score).exp();
                    sum_exp += exp_score;
                    let v_offset = kv_head_offset + kv_pos * head_dim;
                    for d in 0..head_dim {
                        acc[d] += exp_score * v_data[v_offset + d];
                    }
                }
                // sum_exp == 0 only when no KV position was visited
                // (e.g. seq_len_kv == 0); emit zeros rather than NaN.
                let inv_sum = if sum_exp > 0.0 { 1.0 / sum_exp } else { 0.0 };
                let out_pos_offset = q_pos * head_dim;
                for d in 0..head_dim {
                    head_output[out_pos_offset + d] = acc[d] * inv_sum;
                }
            }
            (task, head_output)
        })
        .collect();
    // Serial write-back; `task` encodes (batch, head) in output order.
    for (task, head_output) in results {
        let out_offset = task * seq_len_q * head_dim;
        out_data[out_offset..out_offset + head_output.len()].copy_from_slice(&head_output);
    }
    Ok(())
}
/// Config-driven wrapper around [`flash_attention`].
///
/// NOTE(review): only `config.causal` is forwarded; `config.block_q` and
/// `config.block_kv` are currently ignored by the underlying kernel —
/// confirm whether block-size tiling is still planned here.
pub fn flash_attention_blocked(
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    out: &mut Tensor,
    scale: f32,
    config: &FlashAttentionConfig,
) -> BackendResult<()> {
    flash_attention(q, k, v, out, scale, config.causal)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_flash_attention_basic() {
        // Identical ramp data (0.0, 0.1, 0.2, ...) for Q, K and V over
        // 2 heads x 4 positions x 8 dims; just checks a non-zero result.
        let ramp: Vec<f32> = (0..64).map(|i| i as f32 * 0.1).collect();
        let q = Tensor::from_f32(&ramp, vec![2, 4, 8]).unwrap();
        let k = Tensor::from_f32(&ramp, vec![2, 4, 8]).unwrap();
        let v = Tensor::from_f32(&ramp, vec![2, 4, 8]).unwrap();
        let mut out = Tensor::zeros(vec![2, 4, 8], DType::F32);

        let scale = 1.0 / (8.0f32).sqrt();
        assert!(flash_attention(&q, &k, &v, &mut out, scale, true).is_ok());

        let total: f32 = out.as_f32().unwrap().iter().sum();
        assert!(total.abs() > 0.0);
    }

    #[test]
    fn test_flash_attention_causal() {
        // One query attending causally over three keys: only the first
        // value row (one-hot at dim 0) may contribute to the output.
        let q = Tensor::from_f32(&[1.0f32; 8], vec![1, 1, 8]).unwrap();
        let k = Tensor::from_f32(&[1.0f32; 24], vec![1, 3, 8]).unwrap();
        let mut v_rows = vec![0.0f32; 24];
        v_rows[0] = 1.0; // row 0: one-hot at dim 0
        v_rows[9] = 1.0; // row 1: one-hot at dim 1
        v_rows[18] = 1.0; // row 2: one-hot at dim 2
        let v = Tensor::from_f32(&v_rows, vec![1, 3, 8]).unwrap();
        let mut out = Tensor::zeros(vec![1, 1, 8], DType::F32);

        let scale = 1.0 / (8.0f32).sqrt();
        flash_attention(&q, &k, &v, &mut out, scale, true).unwrap();

        let result = out.as_f32().unwrap();
        assert!((result[0] - 1.0).abs() < 0.01);
        assert!(result[1].abs() < 0.01);
    }

    #[test]
    fn test_flash_attention_non_causal() {
        // Uniform scores over two keys: each one-hot value row gets
        // softmax weight 0.5 in the output.
        let q = Tensor::from_f32(&[1.0f32; 8], vec![1, 1, 8]).unwrap();
        let k = Tensor::from_f32(&[1.0f32; 16], vec![1, 2, 8]).unwrap();
        let mut v_rows = vec![0.0f32; 16];
        v_rows[0] = 1.0; // row 0: one-hot at dim 0
        v_rows[9] = 1.0; // row 1: one-hot at dim 1
        let v = Tensor::from_f32(&v_rows, vec![1, 2, 8]).unwrap();
        let mut out = Tensor::zeros(vec![1, 1, 8], DType::F32);

        let scale = 1.0 / (8.0f32).sqrt();
        flash_attention(&q, &k, &v, &mut out, scale, false).unwrap();

        let result = out.as_f32().unwrap();
        assert!((result[0] - 0.5).abs() < 0.01);
        assert!((result[1] - 0.5).abs() < 0.01);
    }
}