use super::{softmax_inplace, Attention};
/// Tiled ("flash"-style) softmax attention for a single query vector.
///
/// Keys and values are consumed in tiles of `block_size_kv` entries, and the
/// per-tile softmax statistics are merged so the result matches ordinary
/// softmax attention up to floating-point rounding.
#[derive(Debug, Clone)]
pub struct FlashAttention {
    // Query tile size. It is stored but never read in this file (each call
    // handles a single query vector), hence the dead_code allowance.
    #[allow(dead_code)]
    block_size_q: usize,
    /// Number of key/value pairs processed per tile by `forward_tiled`.
    block_size_kv: usize,
    /// Factor applied to every query·key dot product; `new` sets it to
    /// `1 / sqrt(head_dim)`.
    scale: f32,
}
impl FlashAttention {
    /// Creates an operator for heads of dimension `head_dim` that processes
    /// keys and values in tiles of `block_size` entries.
    ///
    /// Scores are scaled by `1 / sqrt(head_dim)`, so a `head_dim` of 0 yields
    /// an infinite scale. `block_size` is clamped to at least 1: a zero-sized
    /// tile can never advance through the key sequence and used to make
    /// [`Self::forward_tiled`] panic.
    pub fn new(head_dim: usize, block_size: usize) -> Self {
        let block_size = block_size.max(1);
        Self {
            block_size_q: block_size,
            block_size_kv: block_size,
            scale: 1.0 / (head_dim as f32).sqrt(),
        }
    }

    /// Creates an operator for `head_dim` with the default tile size of 64.
    pub fn with_head_dim(head_dim: usize) -> Self {
        Self::new(head_dim, 64)
    }

    /// Scaled dot product `scale * (query · key)`.
    ///
    /// `zip` stops at the shorter slice, so mismatched lengths are silently
    /// truncated rather than reported.
    #[inline]
    fn compute_score(&self, query: &[f32], key: &[f32]) -> f32 {
        let dot: f32 = query.iter().zip(key.iter()).map(|(q, k)| q * k).sum();
        dot * self.scale
    }

    /// Plain softmax attention over one tile: score every key, normalise the
    /// scores with `softmax_inplace`, and return the weighted sum of value rows.
    ///
    /// Returns a zero vector sized like the first value row (or an empty one)
    /// when `key_block` is empty.
    ///
    /// # Panics
    /// Panics if `key_block` is non-empty but `value_block` is empty.
    fn process_block(
        &self,
        query_block: &[f32],
        key_block: &[&[f32]],
        value_block: &[&[f32]],
    ) -> Vec<f32> {
        if key_block.is_empty() {
            return vec![0.0; value_block.first().map_or(0, |v| v.len())];
        }
        let mut scores: Vec<f32> = key_block
            .iter()
            .map(|key| self.compute_score(query_block, key))
            .collect();
        softmax_inplace(&mut scores);
        let value_dim = value_block[0].len();
        let mut output = vec![0.0; value_dim];
        for (score, value) in scores.iter().zip(value_block.iter()) {
            for (out, val) in output.iter_mut().zip(value.iter()) {
                *out += score * val;
            }
        }
        output
    }

    /// Attention output of `query` over `keys`/`values`, computed tile by tile.
    ///
    /// An input of at most one tile goes straight through
    /// [`Self::process_block`]. Longer inputs use the online-softmax
    /// recurrence. The function keeps three pieces of running state:
    /// - the largest score seen so far,
    /// - the sum of `exp(score - max)`,
    /// - the matching weighted sum of value rows.
    ///
    /// When a later tile raises the maximum, the state accumulated so far is
    /// rescaled. Only `O(value_dim)` state is kept regardless of sequence
    /// length, and the result equals ordinary softmax attention up to
    /// floating-point rounding.
    ///
    /// Returns an empty vector when `keys` is empty.
    ///
    /// # Panics
    /// Panics if `keys.len() != values.len()`.
    pub fn forward_tiled(&self, query: &[f32], keys: &[&[f32]], values: &[&[f32]]) -> Vec<f32> {
        assert_eq!(keys.len(), values.len(), "Keys and values length mismatch");
        if keys.is_empty() {
            return Vec::new();
        }
        // `new` guarantees this is at least 1, so `chunks` below cannot panic.
        let block_size = self.block_size_kv;
        if keys.len() <= block_size {
            return self.process_block(query, keys, values);
        }

        let value_dim = values[0].len();
        let mut running_max = f32::NEG_INFINITY;
        let mut running_sum = 0.0_f32;
        let mut output = vec![0.0_f32; value_dim];
        // Score buffer reused across tiles to avoid one allocation per tile.
        let mut weights: Vec<f32> = Vec::with_capacity(block_size);

        for (key_block, value_block) in keys.chunks(block_size).zip(values.chunks(block_size)) {
            weights.clear();
            weights.extend(key_block.iter().map(|key| self.compute_score(query, key)));
            let block_max = weights.iter().copied().fold(f32::NEG_INFINITY, f32::max);
            let new_max = running_max.max(block_max);

            // Re-express everything accumulated so far relative to `new_max`.
            // On the first tile `running_max` is -inf, so the factor is 0 and
            // it multiplies an all-zero state.
            let correction = (running_max - new_max).exp();
            running_sum *= correction;
            for out in &mut output {
                *out *= correction;
            }

            for (weight, value) in weights.iter_mut().zip(value_block) {
                *weight = (*weight - new_max).exp();
                running_sum += *weight;
                for (out, val) in output.iter_mut().zip(value.iter()) {
                    *out += *weight * val;
                }
            }
            running_max = new_max;
        }

        if running_sum > 0.0 {
            for out in &mut output {
                *out /= running_sum;
            }
        }
        output
    }
}
impl Default for FlashAttention {
    /// An operator for 64-dimensional heads using the standard 64-entry tile
    /// (scale `1 / sqrt(64) = 0.125`).
    fn default() -> Self {
        Self::with_head_dim(64)
    }
}
impl Attention for FlashAttention {
    /// Softmax-normalised, scaled dot-product weights of `query` against each
    /// key, in key order. An empty `keys` slice yields an empty vector.
    fn attention_scores(&self, query: &[f32], keys: &[&[f32]]) -> Vec<f32> {
        let mut weights: Vec<f32> = Vec::with_capacity(keys.len());
        weights.extend(keys.iter().map(|key| self.compute_score(query, key)));
        // Normalise only when there is something to normalise.
        if !weights.is_empty() {
            softmax_inplace(&mut weights);
        }
        weights
    }

    /// Attention output for `query`; delegates to the inherent tiled kernel.
    fn forward(&self, query: &[f32], keys: &[&[f32]], values: &[&[f32]]) -> Vec<f32> {
        Self::forward_tiled(self, query, keys, values)
    }
}
#[cfg(feature = "pg_test")]
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Borrows each owned row as a slice, the shape the attention API takes.
    fn as_refs(rows: &[Vec<f32>]) -> Vec<&[f32]> {
        rows.iter().map(Vec::as_slice).collect()
    }

    /// Untiled softmax attention evaluated in f64, used as ground truth for
    /// the tiled kernel. Uses the same `1 / sqrt(head_dim)` scale as
    /// `FlashAttention::new`.
    fn reference_forward(
        head_dim: usize,
        query: &[f32],
        keys: &[&[f32]],
        values: &[&[f32]],
    ) -> Vec<f32> {
        let scale = 1.0 / (head_dim as f64).sqrt();
        let scores: Vec<f64> = keys
            .iter()
            .map(|key| {
                query
                    .iter()
                    .zip(key.iter())
                    .map(|(&a, &b)| f64::from(a) * f64::from(b))
                    .sum::<f64>()
                    * scale
            })
            .collect();
        let max = scores.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let weights: Vec<f64> = scores.iter().map(|s| (s - max).exp()).collect();
        let total: f64 = weights.iter().sum();
        let mut out = vec![0.0_f64; values[0].len()];
        for (w, value) in weights.iter().zip(values.iter()) {
            for (o, &x) in out.iter_mut().zip(value.iter()) {
                *o += w * f64::from(x);
            }
        }
        out.into_iter().map(|o| (o / total) as f32).collect()
    }

    /// Asserts two outputs have equal length and agree element-wise.
    fn assert_close(actual: &[f32], expected: &[f32], epsilon: f32) {
        assert_eq!(actual.len(), expected.len());
        for (a, e) in actual.iter().zip(expected) {
            assert_relative_eq!(*a, *e, epsilon = epsilon);
        }
    }

    #[test]
    fn test_flash_attention_basic() {
        let flash = FlashAttention::new(4, 64);
        let query = vec![1.0, 0.0, 0.0, 0.0];
        let key1 = vec![1.0, 0.0, 0.0, 0.0];
        let key2 = vec![0.0, 1.0, 0.0, 0.0];
        let keys = vec![&key1[..], &key2[..]];
        let scores = flash.attention_scores(&query, &keys);
        assert_eq!(scores.len(), 2);
        let sum: f32 = scores.iter().sum();
        assert_relative_eq!(sum, 1.0, epsilon = 1e-6);
        assert!(scores[0] > scores[1]);
        // Raw scores are 0.5 and 0.0 (scale 1/sqrt(4)), so the weight of the
        // matching key is e^0.5 / (e^0.5 + 1).
        let e = 0.5_f32.exp();
        assert_relative_eq!(scores[0], e / (e + 1.0), epsilon = 1e-6);
    }

    #[test]
    fn test_flash_forward_small() {
        let flash = FlashAttention::new(2, 64);
        let query = vec![1.0, 0.0];
        let key1 = vec![1.0, 0.0];
        let key2 = vec![0.0, 1.0];
        let value1 = vec![1.0, 2.0, 3.0];
        let value2 = vec![4.0, 5.0, 6.0];
        let keys = vec![&key1[..], &key2[..]];
        let values = vec![&value1[..], &value2[..]];
        let result = flash.forward(&query, &keys, &values);
        assert_eq!(result.len(), 3);
        assert!(result[0] < 2.5);
        assert_close(&result, &reference_forward(2, &query, &keys, &values), 1e-5);
    }

    #[test]
    fn test_flash_tiled_processing() {
        let flash = FlashAttention::new(4, 2);
        let query = vec![1.0, 0.0, 0.0, 0.0];
        let keys: Vec<Vec<f32>> = vec![
            vec![1.0, 0.0, 0.0, 0.0],
            vec![0.9, 0.1, 0.0, 0.0],
            vec![0.8, 0.2, 0.0, 0.0],
            vec![0.0, 1.0, 0.0, 0.0],
        ];
        let values: Vec<Vec<f32>> = vec![vec![1.0], vec![2.0], vec![3.0], vec![4.0]];
        let key_refs = as_refs(&keys);
        let value_refs = as_refs(&values);
        let result = flash.forward(&query, &key_refs, &value_refs);
        assert_eq!(result.len(), 1);
        assert!(result[0] < 2.5);
        // Tiling must not change the answer: compare against a single-tile
        // run and against the untiled f64 reference.
        let single_tile = FlashAttention::new(4, 64).forward(&query, &key_refs, &value_refs);
        assert_close(&result, &single_tile, 1e-5);
        assert_close(
            &result,
            &reference_forward(4, &query, &key_refs, &value_refs),
            1e-5,
        );
    }

    #[test]
    fn test_flash_tiled_matches_reference_across_block_sizes() {
        let head_dim = 3;
        let query: Vec<f32> = vec![0.3, -1.2, 0.8];
        let keys: Vec<Vec<f32>> = vec![
            vec![1.0, 0.0, 0.0],
            vec![0.0, 1.0, 0.0],
            vec![0.0, 0.0, 1.0],
            vec![1.0, 1.0, 0.0],
            vec![0.5, -0.5, 2.0],
        ];
        let values: Vec<Vec<f32>> = vec![
            vec![1.0, 2.0],
            vec![3.0, 4.0],
            vec![5.0, 6.0],
            vec![7.0, 8.0],
            vec![9.0, 10.0],
        ];
        let key_refs = as_refs(&keys);
        let value_refs = as_refs(&values);
        let expected = reference_forward(head_dim, &query, &key_refs, &value_refs);
        // Covers one key per tile, uneven last tiles, and the single-tile path.
        for block_size in 1..=keys.len() + 1 {
            let flash = FlashAttention::new(head_dim, block_size);
            let result = flash.forward(&query, &key_refs, &value_refs);
            assert_close(&result, &expected, 1e-4);
        }
    }

    #[test]
    fn test_flash_vs_standard_attention() {
        use super::super::ScaledDotAttention;
        let head_dim = 4;
        let flash = FlashAttention::new(head_dim, 2);
        let standard = ScaledDotAttention::new(head_dim);
        let query = vec![1.0, 0.5, 0.25, 0.0];
        let keys: Vec<Vec<f32>> = vec![
            vec![1.0, 0.5, 0.25, 0.0],
            vec![0.0, 0.25, 0.5, 1.0],
            vec![0.5, 0.5, 0.5, 0.5],
        ];
        let values: Vec<Vec<f32>> = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.5, 0.5]];
        let key_refs = as_refs(&keys);
        let value_refs = as_refs(&values);
        let flash_result = flash.forward(&query, &key_refs, &value_refs);
        let standard_result = standard.forward(&query, &key_refs, &value_refs);
        assert_close(&flash_result, &standard_result, 1e-4);
    }

    #[test]
    fn test_flash_empty_sequence() {
        let flash = FlashAttention::new(4, 64);
        let query = vec![1.0, 0.0, 0.0, 0.0];
        let keys: Vec<&[f32]> = vec![];
        let values: Vec<&[f32]> = vec![];
        let result = flash.forward(&query, &keys, &values);
        assert!(result.is_empty());
    }

    #[test]
    fn test_flash_numerical_stability() {
        let flash = FlashAttention::new(4, 2);
        let query = vec![100.0, 100.0, 100.0, 100.0];
        let keys: Vec<Vec<f32>> = vec![
            vec![100.0, 100.0, 100.0, 100.0],
            vec![99.0, 99.0, 99.0, 99.0],
            vec![98.0, 98.0, 98.0, 98.0],
        ];
        let values: Vec<Vec<f32>> = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.5, 0.5]];
        let key_refs = as_refs(&keys);
        let value_refs = as_refs(&values);
        let result = flash.forward(&query, &key_refs, &value_refs);
        assert!(result.iter().all(|x| x.is_finite()));
        // Scores are 20000, 19800 and 19600; the first key outweighs the
        // others by at least e^200, so the output is its value row.
        assert_close(&result, &[1.0, 0.0], 1e-6);
    }
}
#[cfg(feature = "pg_test")]
#[pgrx::pg_schema]
mod pg_tests {
    use super::*;
    use pgrx::prelude::*;

    // A single key attends with weight 1, so the output is its value row.
    #[pg_test]
    fn test_pg_flash_attention() {
        let flash = FlashAttention::new(4, 64);
        let query = vec![1.0, 0.0, 0.0, 0.0];
        let key = vec![1.0, 0.0, 0.0, 0.0];
        let value = vec![5.0, 10.0];
        let keys = vec![&key[..]];
        let values = vec![&value[..]];
        let result = flash.forward(&query, &keys, &values);
        assert_eq!(result.len(), 2);
        assert!((result[0] - 5.0).abs() < 0.01);
        assert!((result[1] - 10.0).abs() < 0.01);
    }

    // Four keys with a tile size of 2 exercise the multi-tile path; the result
    // must agree with a single-tile run over the same inputs.
    #[pg_test]
    fn test_pg_flash_tiled() {
        let flash = FlashAttention::new(2, 2);
        let query = vec![1.0, 0.0];
        let keys: Vec<Vec<f32>> = vec![
            vec![1.0, 0.0],
            vec![0.9, 0.1],
            vec![0.0, 1.0],
            vec![0.1, 0.9],
        ];
        let values: Vec<Vec<f32>> = vec![vec![10.0], vec![20.0], vec![30.0], vec![40.0]];
        let key_refs: Vec<&[f32]> = keys.iter().map(|k| &k[..]).collect();
        let value_refs: Vec<&[f32]> = values.iter().map(|v| &v[..]).collect();
        let result = flash.forward(&query, &key_refs, &value_refs);
        assert_eq!(result.len(), 1);
        assert!(result[0] < 25.0);
        let single_tile = FlashAttention::new(2, 64).forward(&query, &key_refs, &value_refs);
        assert_eq!(single_tile.len(), 1);
        assert!((result[0] - single_tile[0]).abs() < 1e-3);
    }
}