//! Memory-efficient exact attention using the tiled FlashAttention algorithm
//! with an online (streaming) softmax, implemented on CPU `ndarray` tensors.

use crate::{CoreError, CoreResult};
use scirs2_core::ndarray::{s, Array2, Array3};
/// Configuration for the tiled FlashAttention forward pass.
#[derive(Debug, Clone)]
pub struct FlashAttentionConfig {
    /// Number of attention heads.
    pub num_heads: usize,
    /// Dimension of each head; `d_model = num_heads * head_dim`.
    pub head_dim: usize,
    /// Tile size along the query sequence dimension.
    pub tile_q: usize,
    /// Tile size along the key/value sequence dimension.
    pub tile_kv: usize,
    /// Dropout probability in `[0, 1)` (stored for validation; not currently
    /// applied in the forward pass).
    pub dropout: f32,
    /// Softmax scaling factor, `1 / sqrt(head_dim)` by default.
    pub scale: f32,
    /// Whether to apply a causal (lower-triangular) attention mask.
    pub causal: bool,
}
impl FlashAttentionConfig {
    /// Creates a configuration with 64x64 tiles, no dropout, no causal
    /// masking, and the standard `1 / sqrt(head_dim)` scaling.
    pub fn new(num_heads: usize, head_dim: usize) -> Self {
        let scale = 1.0 / (head_dim as f32).sqrt();
        Self {
            num_heads,
            head_dim,
            tile_q: 64,
            tile_kv: 64,
            dropout: 0.0,
            scale,
            causal: false,
        }
    }
    /// Overrides the query and key/value tile sizes used by the tiled kernel.
    pub fn with_tile_sizes(mut self, tile_q: usize, tile_kv: usize) -> Self {
        self.tile_q = tile_q;
        self.tile_kv = tile_kv;
        self
    }

    /// Enables or disables causal (autoregressive) masking.
    pub fn with_causal(mut self, causal: bool) -> Self {
        self.causal = causal;
        self
    }

    /// Sets the dropout probability; must be in `[0, 1)` to pass validation.
    pub fn with_dropout(mut self, dropout: f32) -> Self {
        self.dropout = dropout;
        self
    }

    /// Checks that head counts, dimensions, tile sizes, and dropout are valid.
    pub fn validate(&self) -> CoreResult<()> {
if self.num_heads == 0 || self.head_dim == 0 {
return Err(CoreError::InvalidConfig(
"num_heads and head_dim must be positive".to_string(),
));
}
if self.tile_q == 0 || self.tile_kv == 0 {
return Err(CoreError::InvalidConfig(
"tile sizes must be positive".to_string(),
));
}
if self.dropout < 0.0 || self.dropout >= 1.0 {
return Err(CoreError::InvalidConfig(
"dropout must be in [0, 1)".to_string(),
));
}
Ok(())
}
}
impl Default for FlashAttentionConfig {
fn default() -> Self {
Self::new(8, 64)
}
}
/// Multi-head attention operator using the tiled, online-softmax
/// FlashAttention algorithm.
pub struct FlashAttention {
    config: FlashAttentionConfig,
}
impl FlashAttention {
    /// Validates the configuration and constructs the attention operator.
    pub fn new(config: FlashAttentionConfig) -> CoreResult<Self> {
        config.validate()?;
        Ok(Self { config })
    }

    /// Computes multi-head attention for inputs of shape
    /// `(batch, seq_len, d_model)`, where `d_model = num_heads * head_dim`.
    ///
    /// Each batch element is split into heads, run through the tiled kernel,
    /// and merged back into `(batch, seq_len, d_model)`.
    pub fn forward(
        &self,
        q: &Array3<f32>,
        k: &Array3<f32>,
        v: &Array3<f32>,
    ) -> CoreResult<Array3<f32>> {
        let (batch_size, seq_len_q, d_model) = q.dim();
        if d_model != self.config.num_heads * self.config.head_dim {
            return Err(CoreError::DimensionMismatch {
                expected: self.config.num_heads * self.config.head_dim,
                got: d_model,
            });
        }
        let mut output = Array3::zeros((batch_size, seq_len_q, d_model));
        // Process each batch element independently: split d_model into heads,
        // run the tiled kernel, then merge the heads back into d_model.
        for b in 0..batch_size {
            let q_b = self.reshape_qkv(q.slice(s![b, .., ..]).to_owned())?;
            let k_b = self.reshape_qkv(k.slice(s![b, .., ..]).to_owned())?;
            let v_b = self.reshape_qkv(v.slice(s![b, .., ..]).to_owned())?;
            let out_b = self.flash_attention_forward(&q_b, &k_b, &v_b)?;
            let out_b = out_b.into_shape_with_order((seq_len_q, d_model))?;
            output.slice_mut(s![b, .., ..]).assign(&out_b);
        }
        Ok(output)
    }
    /// Tiled attention kernel for a single batch element with inputs of shape
    /// `(seq_len, num_heads, head_dim)`. Uses the FlashAttention
    /// online-softmax recurrence so only one `(tile_q, tile_kv)` score block
    /// per head is materialized at a time.
    fn flash_attention_forward(
        &self,
        q: &Array3<f32>,
        k: &Array3<f32>,
        v: &Array3<f32>,
    ) -> CoreResult<Array3<f32>> {
let (seq_len_q, num_heads, head_dim) = q.dim();
let (seq_len_kv, _, _) = k.dim();
let tile_q = self.config.tile_q.min(seq_len_q);
let tile_kv = self.config.tile_kv.min(seq_len_kv);
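        // Running statistics for the online softmax: for every (query
        // position, head) pair we track the maximum score seen so far and the
        // sum of exponentials relative to that maximum, so attention can be
        // accumulated tile by tile without materializing the full
        // seq_q x seq_kv score matrix.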
let mut output = Array3::zeros((seq_len_q, num_heads, head_dim));
let mut row_max = Array2::<f32>::from_elem((seq_len_q, num_heads), f32::NEG_INFINITY);
let mut row_sum = Array2::<f32>::zeros((seq_len_q, num_heads));
let num_tiles_kv = seq_len_kv.div_ceil(tile_kv);
for kv_tile_idx in 0..num_tiles_kv {
let kv_start = kv_tile_idx * tile_kv;
let kv_end = (kv_start + tile_kv).min(seq_len_kv);
let kv_tile_size = kv_end - kv_start;
let k_tile = k.slice(s![kv_start..kv_end, .., ..]);
let v_tile = v.slice(s![kv_start..kv_end, .., ..]);
let num_tiles_q = seq_len_q.div_ceil(tile_q);
for q_tile_idx in 0..num_tiles_q {
let q_start = q_tile_idx * tile_q;
let q_end = (q_start + tile_q).min(seq_len_q);
let q_tile_size = q_end - q_start;
let q_tile = q.slice(s![q_start..q_end, .., ..]);
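                // Scaled dot-product scores between this query tile and the
                // current key tile, with future positions masked to -inf when
                // causal attention is enabled.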
let mut scores = Array3::zeros((q_tile_size, num_heads, kv_tile_size));
for h in 0..num_heads {
for i in 0..q_tile_size {
for j in 0..kv_tile_size {
let mut score = 0.0f32;
for d in 0..head_dim {
score += q_tile[[i, h, d]] * k_tile[[j, h, d]];
}
score *= self.config.scale;
if self.config.causal {
let q_pos = q_start + i;
let kv_pos = kv_start + j;
if kv_pos > q_pos {
score = f32::NEG_INFINITY;
}
}
scores[[i, h, j]] = score;
}
}
}
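                // Online softmax update: merge this tile's scores into the
                // running max/sum of each (row, head) and rescale the output
                // accumulated from earlier tiles accordingly.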
for i in 0..q_tile_size {
let global_i = q_start + i;
for h in 0..num_heads {
let mut tile_max = f32::NEG_INFINITY;
for j in 0..kv_tile_size {
tile_max = tile_max.max(scores[[i, h, j]]);
}
let old_max = row_max[[global_i, h]];
let new_max = old_max.max(tile_max);
let mut tile_sum = 0.0f32;
for j in 0..kv_tile_size {
let exp_val = (scores[[i, h, j]] - new_max).exp();
scores[[i, h, j]] = exp_val;
tile_sum += exp_val;
}
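                        // Output accumulated so far was computed against the
                        // old running max; rescale it so it is consistent with
                        // the new max (the factor is 1 when the max is
                        // unchanged).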
let scale_factor = (old_max - new_max).exp();
for d in 0..head_dim {
output[[global_i, h, d]] *= scale_factor;
}
let old_sum = row_sum[[global_i, h]] * scale_factor;
for j in 0..kv_tile_size {
let attn_weight = scores[[i, h, j]];
for d in 0..head_dim {
output[[global_i, h, d]] += attn_weight * v_tile[[j, h, d]];
}
}
row_max[[global_i, h]] = new_max;
row_sum[[global_i, h]] = old_sum + tile_sum;
}
}
}
}
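        // Final normalization: divide by the accumulated softmax denominator.
        // Rows with an effectively zero sum are left as zeros.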
for i in 0..seq_len_q {
for h in 0..num_heads {
let sum = row_sum[[i, h]];
if sum > 1e-8 {
for d in 0..head_dim {
output[[i, h, d]] /= sum;
}
}
}
}
Ok(output)
}
    /// Splits the model dimension of a `(seq_len, d_model)` matrix into heads,
    /// returning an array of shape `(seq_len, num_heads, head_dim)`.
    fn reshape_qkv(&self, x: Array2<f32>) -> CoreResult<Array3<f32>> {
        let (seq_len, d_model) = x.dim();
        let num_heads = self.config.num_heads;
        let head_dim = self.config.head_dim;
        if d_model != num_heads * head_dim {
            return Err(CoreError::DimensionMismatch {
                expected: num_heads * head_dim,
                got: d_model,
            });
        }
        Ok(x.into_shape_with_order((seq_len, num_heads, head_dim))?)
    }
    /// Returns the configuration used by this attention operator.
    pub fn config(&self) -> &FlashAttentionConfig {
        &self.config
    }
}
/// Convenience wrapper that runs single-batch FlashAttention on 2D inputs of
/// shape `(seq_len, d_model)` and returns the output in the same shape.
pub fn flash_attention_fused(
    q: &Array2<f32>,
    k: &Array2<f32>,
    v: &Array2<f32>,
    num_heads: usize,
    head_dim: usize,
    causal: bool,
) -> CoreResult<Array2<f32>> {
let config = FlashAttentionConfig::new(num_heads, head_dim).with_causal(causal);
let flash_attn = FlashAttention::new(config)?;
let q_3d = q.clone().into_shape_with_order((1, q.nrows(), q.ncols()))?;
let k_3d = k.clone().into_shape_with_order((1, k.nrows(), k.ncols()))?;
let v_3d = v.clone().into_shape_with_order((1, v.nrows(), v.ncols()))?;
let output_3d = flash_attn.forward(&q_3d, &k_3d, &v_3d)?;
let output = output_3d.into_shape_with_order((q.nrows(), q.ncols()))?;
Ok(output)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_flash_attention_config() {
let config = FlashAttentionConfig::new(8, 64);
assert_eq!(config.num_heads, 8);
assert_eq!(config.head_dim, 64);
assert!(config.validate().is_ok());
}
#[test]
fn test_flash_attention_config_validation() {
let mut config = FlashAttentionConfig::new(8, 64);
config.num_heads = 0;
assert!(config.validate().is_err());
let mut config = FlashAttentionConfig::new(8, 64);
config.dropout = 1.5;
assert!(config.validate().is_err());
}
#[test]
fn test_flash_attention_creation() {
let config = FlashAttentionConfig::new(4, 32);
let flash_attn = FlashAttention::new(config);
assert!(flash_attn.is_ok());
}
#[test]
fn test_flash_attention_forward_small() {
let config = FlashAttentionConfig::new(2, 4);
let flash_attn = FlashAttention::new(config).unwrap();
let batch_size = 1;
let seq_len = 4;
let d_model = 8;
let q = Array3::from_elem((batch_size, seq_len, d_model), 0.1);
let k = Array3::from_elem((batch_size, seq_len, d_model), 0.1);
let v = Array3::from_elem((batch_size, seq_len, d_model), 1.0);
let result = flash_attn.forward(&q, &k, &v);
assert!(result.is_ok());
let output = result.unwrap();
assert_eq!(output.dim(), (batch_size, seq_len, d_model));
assert!(output.iter().all(|&x| x.is_finite()));
}
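
    // A sanity-check sketch of the softmax-averaging property: with a constant
    // value tensor, attention output is a convex combination of identical
    // rows, so every output element should be (approximately) that constant.
    // The test name and tolerance are illustrative choices.
    #[test]
    fn test_flash_attention_constant_values() {
        let config = FlashAttentionConfig::new(2, 4);
        let flash_attn = FlashAttention::new(config).unwrap();
        let q = Array3::from_elem((1, 4, 8), 0.1);
        let k = Array3::from_elem((1, 4, 8), 0.2);
        let v = Array3::from_elem((1, 4, 8), 3.0);
        let output = flash_attn.forward(&q, &k, &v).unwrap();
        // Softmax weights sum to 1, so constant values pass through unchanged.
        assert!(output.iter().all(|&x| (x - 3.0).abs() < 1e-4));
    }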
#[test]
fn test_flash_attention_causal_mask() {
let config = FlashAttentionConfig::new(1, 4).with_causal(true);
let flash_attn = FlashAttention::new(config).unwrap();
let batch_size = 1;
let seq_len = 4;
let d_model = 4;
let q = Array3::from_elem((batch_size, seq_len, d_model), 1.0);
let k = Array3::from_elem((batch_size, seq_len, d_model), 1.0);
let v = Array3::from_elem((batch_size, seq_len, d_model), 1.0);
let result = flash_attn.forward(&q, &k, &v);
assert!(result.is_ok());
let output = result.unwrap();
assert!(output.iter().all(|&x| x.is_finite()));
}
#[test]
fn test_flash_attention_tiling() {
        let config = FlashAttentionConfig::new(2, 8).with_tile_sizes(2, 2);
        let flash_attn = FlashAttention::new(config).unwrap();
let batch_size = 1;
let seq_len = 8;
let d_model = 16;
let q = Array3::from_elem((batch_size, seq_len, d_model), 0.5);
let k = Array3::from_elem((batch_size, seq_len, d_model), 0.5);
let v = Array3::from_elem((batch_size, seq_len, d_model), 1.0);
let result = flash_attn.forward(&q, &k, &v);
assert!(result.is_ok());
let output = result.unwrap();
assert_eq!(output.dim(), (batch_size, seq_len, d_model));
assert!(output.iter().all(|&x| x.is_finite()));
}
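
    // A tiling-invariance sketch: because the online softmax only changes how
    // the sum is accumulated, the tiled result should match the single-tile
    // result up to floating-point round-off. Test name, input values, and
    // tolerance are illustrative choices.
    #[test]
    fn test_flash_attention_tile_size_invariance() {
        let q = Array3::from_shape_fn((1, 6, 8), |(_, i, j)| (i * 8 + j) as f32 * 0.01);
        let k = Array3::from_shape_fn((1, 6, 8), |(_, i, j)| (i + j) as f32 * 0.02);
        let v = Array3::from_shape_fn((1, 6, 8), |(_, i, j)| (i * j) as f32 * 0.05);
        let big = FlashAttention::new(FlashAttentionConfig::new(2, 4)).unwrap();
        let small =
            FlashAttention::new(FlashAttentionConfig::new(2, 4).with_tile_sizes(2, 3)).unwrap();
        let out_big = big.forward(&q, &k, &v).unwrap();
        let out_small = small.forward(&q, &k, &v).unwrap();
        for (a, b) in out_big.iter().zip(out_small.iter()) {
            assert!((a - b).abs() < 1e-4);
        }
    }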
#[test]
fn test_flash_attention_fused() {
let seq_len = 8;
let d_model = 16;
let num_heads = 4;
let head_dim = 4;
let q = Array2::from_elem((seq_len, d_model), 0.5);
let k = Array2::from_elem((seq_len, d_model), 0.5);
let v = Array2::from_elem((seq_len, d_model), 1.0);
let result = flash_attention_fused(&q, &k, &v, num_heads, head_dim, false);
assert!(result.is_ok());
let output = result.unwrap();
assert_eq!(output.dim(), (seq_len, d_model));
assert!(output.iter().all(|&x| x.is_finite()));
}
#[test]
fn test_flash_attention_batch() {
let config = FlashAttentionConfig::new(4, 8);
let flash_attn = FlashAttention::new(config).unwrap();
let batch_size = 3;
let seq_len = 16;
let d_model = 32;
let q = Array3::from_elem((batch_size, seq_len, d_model), 0.3);
let k = Array3::from_elem((batch_size, seq_len, d_model), 0.3);
let v = Array3::from_elem((batch_size, seq_len, d_model), 0.7);
let result = flash_attn.forward(&q, &k, &v);
assert!(result.is_ok());
let output = result.unwrap();
assert_eq!(output.dim(), (batch_size, seq_len, d_model));
assert!(output.iter().all(|&x| x.is_finite()));
}
#[test]
fn test_flash_attention_numerical_stability() {
let config = FlashAttentionConfig::new(2, 4);
let flash_attn = FlashAttention::new(config).unwrap();
let batch_size = 1;
let seq_len = 4;
let d_model = 8;
let q = Array3::from_elem((batch_size, seq_len, d_model), 10.0);
let k = Array3::from_elem((batch_size, seq_len, d_model), 10.0);
let v = Array3::from_elem((batch_size, seq_len, d_model), 1.0);
let result = flash_attn.forward(&q, &k, &v);
assert!(result.is_ok());
let output = result.unwrap();
assert!(output.iter().all(|&x| x.is_finite()));
assert!(output.iter().any(|&x| x.abs() > 1e-6));
}
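
    // A sketch combining causal masking with tiles smaller than the sequence
    // length, a case the other tests do not cover: tiles that fall entirely in
    // a row's masked future must not introduce NaNs or infinities. The test
    // name is an illustrative choice.
    #[test]
    fn test_flash_attention_causal_with_small_tiles() {
        let config = FlashAttentionConfig::new(2, 4)
            .with_tile_sizes(2, 2)
            .with_causal(true);
        let flash_attn = FlashAttention::new(config).unwrap();
        let q = Array3::from_elem((1, 8, 8), 0.5);
        let k = Array3::from_elem((1, 8, 8), 0.5);
        let v = Array3::from_elem((1, 8, 8), 1.0);
        let output = flash_attn.forward(&q, &k, &v).unwrap();
        assert_eq!(output.dim(), (1, 8, 8));
        assert!(output.iter().all(|&x| x.is_finite()));
    }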
}