use crate::error::{AttentionError, AttentionResult};
use crate::traits::Attention;
/// Single-query attention evaluated block-by-block with an online softmax.
///
/// Keys are consumed in tiles of `block_size`. A running maximum and a running
/// softmax denominator are kept, so the full score vector is never
/// materialised at once.
#[derive(Debug, Clone)]
pub struct FlashAttention {
    /// Required length of the query vector (checked by `compute`).
    dim: usize,
    /// Number of keys scored and folded into the accumulator per step.
    block_size: usize,
    /// Multiplier applied to every query·key dot product; the constructors set
    /// it to `1 / sqrt(dim)`.
    scale: f32,
    /// When `true`, `compute_block_scores` masks keys with `-inf` scores.
    causal: bool,
}
impl FlashAttention {
    /// Builds a non-causal instance.
    ///
    /// `dim` is the query length the attention expects, and `block_size` is how
    /// many keys are processed per tile (a value of 0 is not meaningful).
    /// Dot products are scaled by `1 / sqrt(dim)`.
    pub fn new(dim: usize, block_size: usize) -> Self {
        let scale = 1.0 / (dim as f32).sqrt();
        Self {
            dim,
            block_size,
            scale,
            causal: false,
        }
    }

    /// Builds an instance identical to [`FlashAttention::new`] but with causal
    /// masking enabled (see `compute_block_scores` for what that masks).
    pub fn causal(dim: usize, block_size: usize) -> Self {
        Self {
            causal: true,
            ..Self::new(dim, block_size)
        }
    }

    /// Returns one scaled score per key in `keys`.
    ///
    /// `start_idx` is the absolute index of `keys[0]` in the full key sequence.
    /// A non-masked key scores `scale * (query · key)`; the dot product runs over
    /// the shorter of the two slices.
    ///
    /// With causal masking on, every key whose absolute index is greater than 0
    /// gets `f32::NEG_INFINITY`, so only the very first key remains visible.
    /// NOTE(review): this effectively places the single query at position 0.
    /// Confirm that this is the intended causal semantics, since no query
    /// position is passed in.
    fn compute_block_scores(&self, query: &[f32], keys: &[&[f32]], start_idx: usize) -> Vec<f32> {
        let mut scores = Vec::with_capacity(keys.len());
        for (offset, key) in keys.iter().enumerate() {
            let hidden = self.causal && start_idx + offset > 0;
            let score = if hidden {
                f32::NEG_INFINITY
            } else {
                let dot: f32 = query.iter().zip(key.iter()).map(|(q, k)| q * k).sum();
                dot * self.scale
            };
            scores.push(score);
        }
        scores
    }
}
impl Attention for FlashAttention {
    /// Attends a single `query` over `keys`/`values` using a tiled online
    /// softmax.
    ///
    /// The output is `sum_i softmax_i(scale * q·k_i) * v_i` over every key whose
    /// score is finite. Keys are processed `block_size` at a time; a block size
    /// of 0 is treated as 1. If every key is masked, the returned vector is all
    /// zeros.
    ///
    /// # Errors
    /// - `InvalidConfig` if `keys` is empty.
    /// - `DimensionMismatch` if `keys` and `values` differ in count, if `query`
    ///   or any key does not have length `dim`, or if the value rows do not all
    ///   share the same length.
    fn compute(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
    ) -> AttentionResult<Vec<f32>> {
        if keys.is_empty() {
            return Err(AttentionError::InvalidConfig("Empty keys".to_string()));
        }
        if keys.len() != values.len() {
            return Err(AttentionError::DimensionMismatch {
                expected: keys.len(),
                actual: values.len(),
            });
        }
        if query.len() != self.dim {
            return Err(AttentionError::DimensionMismatch {
                expected: self.dim,
                actual: query.len(),
            });
        }
        // Without this check, a key of the wrong length would be silently
        // truncated by the `zip` inside the dot product.
        if let Some(bad) = keys.iter().find(|k| k.len() != self.dim) {
            return Err(AttentionError::DimensionMismatch {
                expected: self.dim,
                actual: bad.len(),
            });
        }
        let value_dim = values[0].len();
        // A longer row would otherwise index past `output` and panic; a shorter
        // one would silently contribute nothing to the trailing components.
        if let Some(bad) = values.iter().find(|v| v.len() != value_dim) {
            return Err(AttentionError::DimensionMismatch {
                expected: value_dim,
                actual: bad.len(),
            });
        }

        let n = keys.len();
        // `step_by(0)` panics, so a zero block size degrades to one key per block.
        let block_size = self.block_size.max(1);
        let mut output = vec![0.0f32; value_dim];
        // Running maximum of all finite scores seen so far, and the running sum
        // of exp(score - max_so_far). Together they let each block be folded in
        // without revisiting earlier blocks.
        let mut max_so_far = f32::NEG_INFINITY;
        let mut sum_exp = 0.0f32;

        for block_start in (0..n).step_by(block_size) {
            let block_end = block_start.saturating_add(block_size).min(n);
            let block_scores =
                self.compute_block_scores(query, &keys[block_start..block_end], block_start);
            let block_max = block_scores
                .iter()
                .copied()
                .filter(|x| x.is_finite())
                .fold(f32::NEG_INFINITY, f32::max);
            // Every score in this block is masked or non-finite: nothing to add.
            if !block_max.is_finite() {
                continue;
            }
            let new_max = max_so_far.max(block_max);
            // Re-express the accumulators relative to the new maximum so every
            // exponential stays <= 1 and cannot overflow.
            if max_so_far.is_finite() {
                let rescale = (max_so_far - new_max).exp();
                sum_exp *= rescale;
                output.iter_mut().for_each(|o| *o *= rescale);
            }
            for (&score, value) in block_scores.iter().zip(&values[block_start..block_end]) {
                if !score.is_finite() {
                    continue;
                }
                let weight = (score - new_max).exp();
                sum_exp += weight;
                for (o, &v) in output.iter_mut().zip(value.iter()) {
                    *o += weight * v;
                }
            }
            max_so_far = new_max;
        }

        // When any key was unmasked, the maximum contributes exp(0) = 1, so
        // sum_exp >= 1. A fully masked input leaves sum_exp == 0 and the output
        // stays zero.
        if sum_exp > 1e-8 {
            output.iter_mut().for_each(|o| *o /= sum_exp);
        }
        Ok(output)
    }

    /// Same as `compute`, but when `mask` is `Some`, only positions whose mask
    /// entry is `true` take part. `None` attends to every key.
    ///
    /// # Errors
    /// Everything `compute` reports. In addition, `DimensionMismatch` is
    /// returned when `mask` or `values` does not have one entry per key. A mask
    /// that removes every key yields `compute`'s `InvalidConfig("Empty keys")`.
    fn compute_with_mask(
        &self,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
        mask: Option<&[bool]>,
    ) -> AttentionResult<Vec<f32>> {
        let mask = match mask {
            Some(m) => m,
            None => return self.compute(query, keys, values),
        };
        if mask.len() != keys.len() {
            return Err(AttentionError::DimensionMismatch {
                expected: keys.len(),
                actual: mask.len(),
            });
        }
        if values.len() != keys.len() {
            return Err(AttentionError::DimensionMismatch {
                expected: keys.len(),
                actual: values.len(),
            });
        }
        let mut kept_keys: Vec<&[f32]> = Vec::new();
        let mut kept_values: Vec<&[f32]> = Vec::new();
        for ((&key, &value), &keep) in keys.iter().zip(values).zip(mask) {
            if keep {
                kept_keys.push(key);
                kept_values.push(value);
            }
        }
        self.compute(query, &kept_keys, &kept_values)
    }

    /// The query length this instance was constructed for.
    fn dim(&self) -> usize {
        self.dim
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::attention::ScaledDotProductAttention;

    /// Borrows each owned row as a slice, which is the shape the `Attention` API takes.
    fn as_refs(rows: &[Vec<f32>]) -> Vec<&[f32]> {
        rows.iter().map(Vec::as_slice).collect()
    }

    /// Element-wise comparison with a small absolute tolerance.
    fn assert_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len());
        for (a, e) in actual.iter().zip(expected) {
            assert!((a - e).abs() < 1e-4, "got {}, expected {}", a, e);
        }
    }

    #[test]
    fn test_flash_attention() {
        let attention = FlashAttention::new(64, 16);
        let query = vec![0.5f32; 64];
        let keys: Vec<Vec<f32>> = (0..256).map(|_| vec![0.3; 64]).collect();
        let values: Vec<Vec<f32>> = (0..256).map(|_| vec![1.0; 64]).collect();
        let result = attention
            .compute(&query, &as_refs(&keys), &as_refs(&values))
            .unwrap();
        // Every value row is all ones, so any convex combination is all ones.
        assert_close(&result, &[1.0f32; 64]);
    }

    #[test]
    fn test_flash_matches_standard() {
        let dim = 32;
        let standard = ScaledDotProductAttention::new(dim);
        let query = vec![0.5f32; dim];
        let keys: Vec<Vec<f32>> = (0..16).map(|i| vec![(i as f32) * 0.1; dim]).collect();
        let values: Vec<Vec<f32>> = (0..16).map(|i| vec![(i as f32) * 0.2; dim]).collect();
        let keys_refs = as_refs(&keys);
        let values_refs = as_refs(&values);
        let standard_result = standard.compute(&query, &keys_refs, &values_refs).unwrap();
        // Cover block sizes that do not divide the sequence length, the
        // degenerate one-key block, and a single block spanning everything.
        for &block_size in &[1usize, 5, 8, 16, 32] {
            let flash = FlashAttention::new(dim, block_size);
            let flash_result = flash.compute(&query, &keys_refs, &values_refs).unwrap();
            assert_close(&flash_result, &standard_result);
        }
    }

    #[test]
    fn test_causal_flash() {
        let attention = FlashAttention::causal(32, 8);
        let query = vec![1.0f32; 32];
        let keys = vec![vec![0.5f32; 32]; 20];
        let values = vec![vec![1.0f32; 32]; 20];
        let result = attention
            .compute(&query, &as_refs(&keys), &as_refs(&values))
            .unwrap();
        assert_close(&result, &[1.0f32; 32]);
    }

    #[test]
    fn test_mask_matches_filtered_inputs() {
        let dim = 8;
        let flash = FlashAttention::new(dim, 3);
        let query = vec![0.25f32; dim];
        let keys: Vec<Vec<f32>> = (0..10).map(|i| vec![(i as f32) * 0.1; dim]).collect();
        let values: Vec<Vec<f32>> = (0..10).map(|i| vec![i as f32; dim]).collect();
        let mask: Vec<bool> = (0..10).map(|i| i % 3 != 0).collect();
        let keys_refs = as_refs(&keys);
        let values_refs = as_refs(&values);

        let masked = flash
            .compute_with_mask(&query, &keys_refs, &values_refs, Some(mask.as_slice()))
            .unwrap();

        let kept: Vec<usize> = (0..mask.len()).filter(|&i| mask[i]).collect();
        let kept_keys: Vec<&[f32]> = kept.iter().map(|&i| keys_refs[i]).collect();
        let kept_values: Vec<&[f32]> = kept.iter().map(|&i| values_refs[i]).collect();
        let expected = flash.compute(&query, &kept_keys, &kept_values).unwrap();
        assert_close(&masked, &expected);

        let unmasked = flash
            .compute_with_mask(&query, &keys_refs, &values_refs, None)
            .unwrap();
        let plain = flash.compute(&query, &keys_refs, &values_refs).unwrap();
        assert_close(&unmasked, &plain);
    }

    #[test]
    fn test_empty_keys_is_error() {
        let flash = FlashAttention::new(4, 2);
        let query = vec![1.0f32; 4];
        let no_rows: Vec<&[f32]> = Vec::new();
        assert!(flash.compute(&query, &no_rows, &no_rows).is_err());
    }
}