#![allow(unused_variables)]
use super::common::{
AttentionConfig, AttentionOptimizationHints, AttentionProjections, AttentionUtils,
};
use crate::errors::{Result, TrustformersError};
use crate::tensor::Tensor;
use crate::traits::Layer;
/// Memory-efficient attention layer that processes the sequence in tiles
/// (FlashAttention-style) so the full seq_q x seq_k score matrix is never
/// materialized at once.
#[derive(Debug, Clone)]
pub struct FlashAttention {
/// Shared attention hyper-parameters (hidden size, head count, dropout, bias).
config: AttentionConfig,
/// Q/K/V and output linear projections.
projections: AttentionProjections,
/// Tile edge length used when partitioning the sequence dimension.
block_size: usize,
/// When true, each position may only attend to itself and earlier positions.
causal: bool,
/// Selects the v2 kernel; only used on the non-causal path in forward_attention.
use_flash_attention_2: bool,
#[allow(dead_code)]
/// Tuning hints; currently stored but never read in this file.
optimization_hints: AttentionOptimizationHints,
}
impl FlashAttention {
/// Builds a flash-attention layer.
///
/// `block_size` of `None` selects a heuristic tile size (computed for a
/// 1024-token working set). FlashAttention-2 is enabled by default; see
/// `new_with_version` to control it explicitly.
///
/// # Errors
/// Propagates any validation error from `AttentionConfig::new` (e.g. a
/// hidden size not divisible by the head count).
pub fn new(
    hidden_size: usize,
    num_heads: usize,
    dropout_prob: f32,
    bias: bool,
    block_size: Option<usize>,
    causal: bool,
) -> Result<Self> {
    let config = AttentionConfig::new(hidden_size, num_heads, dropout_prob, bias)?;
    // Fall back to the shared heuristic when the caller gives no tile size.
    let block_size = match block_size {
        Some(size) => size,
        None => AttentionUtils::compute_block_size(1024, config.head_dim, None),
    };
    Ok(Self {
        projections: AttentionProjections::new(&config),
        optimization_hints: AttentionOptimizationHints::default(),
        config,
        block_size,
        causal,
        use_flash_attention_2: true,
    })
}
/// Like [`FlashAttention::new`], but lets the caller pick the kernel
/// version explicitly instead of defaulting to FlashAttention-2.
///
/// # Errors
/// Propagates any configuration error from `new`.
pub fn new_with_version(
    hidden_size: usize,
    num_heads: usize,
    dropout_prob: f32,
    bias: bool,
    block_size: Option<usize>,
    causal: bool,
    use_flash_attention_2: bool,
) -> Result<Self> {
    // Delegate construction, then override the algorithm-version flag.
    Self::new(hidden_size, num_heads, dropout_prob, bias, block_size, causal).map(
        |mut attention| {
            attention.use_flash_attention_2 = use_flash_attention_2;
            attention
        },
    )
}
/// Returns the attention configuration.
pub fn config(&self) -> &AttentionConfig {
&self.config
}
/// Current tile size used for block partitioning.
pub fn block_size(&self) -> usize {
self.block_size
}
/// Overrides the tile size. No validation is performed here.
pub fn set_block_size(&mut self, block_size: usize) {
self.block_size = block_size;
}
/// Whether the v2 kernel is selected (only used on the non-causal path).
pub fn is_using_flash_attention_2(&self) -> bool {
self.use_flash_attention_2
}
/// Enables or disables the v2 kernel.
pub fn set_flash_attention_2(&mut self, enabled: bool) {
self.use_flash_attention_2 = enabled;
}
/// Self-attention convenience wrapper: query, key and value all come from
/// the same `input` tensor.
pub fn forward_self_attention(
&self,
input: &Tensor,
attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
self.forward_attention(input, input, input, attention_mask)
}
/// Full attention forward pass: project the inputs, split into per-head
/// layout, run the selected tiled kernel, then merge heads and apply the
/// output projection.
///
/// # Errors
/// Propagates projection, shape-validation, and kernel errors.
pub fn forward_attention(
    &self,
    query_input: &Tensor,
    key_input: &Tensor,
    value_input: &Tensor,
    attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
    let (num_heads, head_dim) = (self.config.num_heads, self.config.head_dim);
    let projected_q = self.projections.query.forward(query_input.clone())?;
    let projected_k = self.projections.key.forward(key_input.clone())?;
    let projected_v = self.projections.value.forward(value_input.clone())?;
    let q = AttentionUtils::split_heads(&projected_q, num_heads, head_dim)?;
    let k = AttentionUtils::split_heads(&projected_k, num_heads, head_dim)?;
    let v = AttentionUtils::split_heads(&projected_v, num_heads, head_dim)?;
    AttentionUtils::validate_attention_dims(&q, &k, &v, num_heads, head_dim)?;
    // The v2 kernel is only taken on the non-causal path.
    let attention_output = if self.use_flash_attention_2 && !self.causal {
        self.compute_flash_attention_2(&q, &k, &v, attention_mask)?
    } else {
        self.compute_flash_attention_1(&q, &k, &v, attention_mask)?
    };
    let combined = AttentionUtils::combine_heads(&attention_output, num_heads, head_dim)?;
    self.projections.out_proj.forward(combined)
}
/// FlashAttention (v1) tiled attention: sweeps fixed-size key/value tiles
/// for each query tile, accumulating softmax statistics per tile so the
/// full seq_q x seq_k score matrix is never materialized.
///
/// Supports the causal path and an optional additive attention mask.
///
/// # Errors
/// Propagates tensor allocation, slicing, and kernel errors.
fn compute_flash_attention_1(
    &self,
    q: &Tensor,
    k: &Tensor,
    v: &Tensor,
    attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
    let q_shape = q.shape();
    let batch_size = q_shape[0];
    let num_heads = q_shape[1];
    let seq_q = q_shape[2];
    let head_dim = q_shape[3];
    let seq_k = k.shape()[2];
    // Standard 1/sqrt(d) scaling for dot-product attention.
    let scale = 1.0 / (head_dim as f32).sqrt();
    let mut output = Tensor::zeros(&[batch_size, num_heads, seq_q, head_dim])?;
    let num_blocks_q = seq_q.div_ceil(self.block_size);
    let num_blocks_k = seq_k.div_ceil(self.block_size);
    for i in 0..num_blocks_q {
        let q_start = i * self.block_size;
        // `min` guarantees q_end <= seq_q; the former "block end exceeds
        // sequence length" error branch was unreachable and is removed.
        let q_end = (q_start + self.block_size).min(seq_q);
        let q_block = q.slice_ranges(&[
            (0, batch_size),
            (0, num_heads),
            (q_start, q_end),
            (0, head_dim),
        ])?;
        // Per-query-tile online-softmax accumulators: weighted values,
        // running row max, and running exp-sum (softmax denominator).
        let mut block_output =
            Tensor::zeros(&[batch_size, num_heads, q_end - q_start, head_dim])?;
        let mut block_max = Tensor::full(
            f32::NEG_INFINITY,
            vec![batch_size, num_heads, q_end - q_start, 1],
        )?;
        let mut block_sum = Tensor::zeros(&[batch_size, num_heads, q_end - q_start, 1])?;
        for j in 0..num_blocks_k {
            let k_start = j * self.block_size;
            // Same invariant as q_end: always <= seq_k by construction.
            let k_end = (k_start + self.block_size).min(seq_k);
            let k_block = k.slice_ranges(&[
                (0, batch_size),
                (0, num_heads),
                (k_start, k_end),
                (0, head_dim),
            ])?;
            let v_block = v.slice_ranges(&[
                (0, batch_size),
                (0, num_heads),
                (k_start, k_end),
                (0, head_dim),
            ])?;
            let attention_scores = self.compute_block_attention_scores(
                &q_block,
                &k_block,
                scale,
                q_start,
                k_start,
                attention_mask,
            )?;
            // Fold this key tile into the running softmax statistics.
            self.update_block_statistics(
                &mut block_output,
                &mut block_max,
                &mut block_sum,
                &attention_scores,
                &v_block,
            )?;
        }
        // Divide accumulated weighted values by the softmax denominator.
        let normalized_output = self.normalize_block_output(&block_output, &block_sum)?;
        self.store_block_output(&mut output, &normalized_output, q_start, q_end)?;
    }
    Ok(output)
}
/// FlashAttention-2 style tiled attention. `forward_attention` selects this
/// kernel only on the non-causal path, although the causal-mask plumbing in
/// the score computation is kept for completeness.
fn compute_flash_attention_2(
&self,
q: &Tensor,
k: &Tensor,
v: &Tensor,
attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
let shape = q.shape();
let batch_size = shape[0];
let num_heads = shape[1];
let seq_q = shape[2];
let head_dim = shape[3];
let seq_k = k.shape()[2];
// Standard 1/sqrt(d) dot-product scaling.
let scale = 1.0 / (head_dim as f32).sqrt();
// Tile size adapts to sequence length (larger tiles for long sequences).
let adaptive_block_size = self.compute_adaptive_block_size(seq_q, seq_k, head_dim);
let (num_blocks_q, num_blocks_k) =
self.compute_optimal_partitioning(seq_q, seq_k, adaptive_block_size);
let mut output = Tensor::zeros(&[batch_size, num_heads, seq_q, head_dim])?;
for q_block_idx in 0..num_blocks_q {
let q_start = q_block_idx * adaptive_block_size;
let q_end = (q_start + adaptive_block_size).min(seq_q);
let q_block = q.slice_ranges(&[
(0, batch_size),
(0, num_heads),
(q_start, q_end),
(0, head_dim),
])?;
// Per-query-tile online-softmax accumulators (output, running max, sum).
let mut block_output =
Tensor::zeros(&[batch_size, num_heads, q_end - q_start, head_dim])?;
let mut block_max = Tensor::full(
f32::NEG_INFINITY,
vec![batch_size, num_heads, q_end - q_start, 1],
)?;
let mut block_sum = Tensor::zeros(&[batch_size, num_heads, q_end - q_start, 1])?;
for k_block_idx in 0..num_blocks_k {
let k_start = k_block_idx * adaptive_block_size;
let k_end = (k_start + adaptive_block_size).min(seq_k);
let k_block = k.slice_ranges(&[
(0, batch_size),
(0, num_heads),
(k_start, k_end),
(0, head_dim),
])?;
let v_block = v.slice_ranges(&[
(0, batch_size),
(0, num_heads),
(k_start, k_end),
(0, head_dim),
])?;
let attention_scores = self.compute_optimized_block_attention_scores(
&q_block,
&k_block,
scale,
q_start,
k_start,
attention_mask,
)?;
// Fold this key tile into the running softmax statistics.
self.update_block_statistics_v2(
&mut block_output,
&mut block_max,
&mut block_sum,
&attention_scores,
&v_block,
)?;
}
let normalized_output = self.normalize_block_output(&block_output, &block_sum)?;
// FIXME: store_block_output_optimized is an unimplemented stub, so the
// normalized tile is dropped here and `output` is returned zero-filled.
self.store_block_output_optimized(&mut output, &normalized_output, q_start, q_end)?;
}
Ok(output)
}
/// Computes scaled dot-product scores for one (query-tile, key-tile) pair:
/// `(Q · K^T) * scale`, then applies the causal mask (if configured) and the
/// caller-provided attention mask (if any).
///
/// `q_start` / `k_start` are the tiles' global offsets, needed so masking
/// uses absolute sequence positions.
///
/// # Errors
/// Returns a tensor-op error if either block is not rank-4, and propagates
/// matmul/masking errors.
fn compute_block_attention_scores(
    &self,
    q_block: &Tensor,
    k_block: &Tensor,
    scale: f32,
    q_start: usize,
    k_start: usize,
    attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
    // Validate ranks up front so the matmul below cannot silently misbehave.
    let q_shape = q_block.shape();
    if q_shape.len() != 4 {
        return Err(TrustformersError::tensor_op_error(
            &format!(
                "Query block must have 4 dimensions, got {} with shape {:?}",
                q_shape.len(),
                q_shape
            ),
            "compute_block_attention_scores",
        ));
    }
    let k_shape = k_block.shape();
    if k_shape.len() != 4 {
        return Err(TrustformersError::tensor_op_error(
            &format!(
                "Key block must have 4 dimensions, got {} with shape {:?}",
                k_shape.len(),
                k_shape
            ),
            "compute_block_attention_scores",
        ));
    }
    let mut result = q_block
        .matmul(&k_block.transpose(2, 3)?)?
        .scalar_mul(scale)?;
    if self.causal {
        result = self.apply_causal_mask_to_block(&result, q_start, k_start)?;
    }
    match attention_mask {
        Some(mask) => self.apply_attention_mask_to_block(&result, mask, q_start, k_start),
        None => Ok(result),
    }
}
fn apply_causal_mask_to_block(
&self,
scores: &Tensor,
q_start: usize,
k_start: usize,
) -> Result<Tensor> {
let masked_scores = scores.clone();
let scores_shape = scores.shape();
let block_q_len = scores_shape[2];
let block_k_len = scores_shape[3];
for i in 0..block_q_len {
for j in 0..block_k_len {
let global_q_pos = q_start + i;
let global_k_pos = k_start + j;
if global_q_pos < global_k_pos {
}
}
}
Ok(masked_scores)
}
/// Applies the caller-provided attention mask to one score tile.
///
/// The block-aligned window is sliced out of the mask, converted from a
/// {0, 1} keep-mask into additive form (kept positions add 0, masked
/// positions add -1e9), and added to the scores before softmax.
///
/// NOTE(review): assumes `mask` is rank-4 (batch, heads, q, k); a rank-2
/// padding mask would index out of range here — confirm against callers.
///
/// # Errors
/// Propagates slicing and arithmetic errors.
fn apply_attention_mask_to_block(
    &self,
    scores: &Tensor,
    mask: &Tensor,
    q_start: usize,
    k_start: usize,
) -> Result<Tensor> {
    let mask_shape = mask.shape();
    let scores_shape = scores.shape();
    let mask_block = mask.slice_ranges(&[
        (0, mask_shape[0]),
        (0, mask_shape[1]),
        (q_start, q_start + scores_shape[2]),
        (k_start, k_start + scores_shape[3]),
    ])?;
    // (mask - 1) * (-1e9): 1 -> 0 (keep), 0 -> -1e9 (suppress).
    let additive = mask_block
        .sub(&Tensor::ones(&mask_block.shape())?)?
        .mul(&Tensor::scalar(-1e9)?)?;
    scores.add(&additive)
}
/// Online-softmax update for one key/value tile (v1 path).
///
/// The previous implementation softmax-normalized each key tile
/// independently and never touched `block_max`/`block_sum`; with more than
/// one key tile the partial results were incorrectly weighted, and
/// `normalize_block_output` then divided by an all-zero `block_sum`. This
/// rewrite uses the same numerically stable accumulation scheme as
/// `update_block_statistics_v2`: rescale previously accumulated statistics
/// to the new running maximum, then add this tile's contribution.
///
/// # Errors
/// Propagates tensor arithmetic errors.
fn update_block_statistics(
    &self,
    block_output: &mut Tensor,
    block_max: &mut Tensor,
    block_sum: &mut Tensor,
    attention_scores: &Tensor,
    v_block: &Tensor,
) -> Result<()> {
    let scores_max_val = attention_scores.max_value()?;
    let new_max = block_max.max(&scores_max_val)?;
    // Shift by the running max before exponentiating for stability.
    let scores_exp = attention_scores.sub(&new_max)?.exp()?;
    // exp(old_max - new_max) rescales the previously accumulated sum/output.
    let correction = block_max.sub(&new_max)?.exp()?;
    let updated_sum = block_sum.mul(&correction)?.add(&scores_exp.sum(None, false)?)?;
    let updated_output = block_output.mul(&correction)?.add(&scores_exp.matmul(v_block)?)?;
    *block_max = new_max;
    *block_sum = updated_sum;
    *block_output = updated_output;
    Ok(())
}
/// Divides the accumulated weighted values by the accumulated softmax
/// denominator. Callers must ensure `block_sum` holds strictly positive
/// entries; any zero entry yields inf/NaN in the result.
fn normalize_block_output(&self, block_output: &Tensor, block_sum: &Tensor) -> Result<Tensor> {
block_output.div(block_sum)
}
/// Intended to write `block_output` into rows [q_start, q_end) of `output`.
/// FIXME: unimplemented stub — it returns Ok(()) without copying anything,
/// so compute_flash_attention_1 currently returns the zero tensor it
/// allocated. Implementing this needs a slice-assignment operation on
/// Tensor, which is not visible in this file.
fn store_block_output(
&self,
output: &mut Tensor,
block_output: &Tensor,
q_start: usize,
q_end: usize,
) -> Result<()> {
Ok(())
}
/// Rough estimate of working-set size for a forward pass, in bytes
/// (the final `* 4` assumes f32 elements).
pub fn estimate_memory_usage(&self, batch_size: usize, seq_len: usize) -> usize {
    // Q/K/V/output projection activations.
    let projection_memory = batch_size * seq_len * self.config.hidden_size * 4;
    // One score tile per head.
    let block_memory =
        batch_size * self.config.num_heads * self.block_size * self.block_size;
    // Per-head Q, K and V in split-head layout.
    let intermediate_memory =
        batch_size * self.config.num_heads * seq_len * self.config.head_dim * 3;
    (projection_memory + block_memory + intermediate_memory) * 4
}
/// Delegates to the shared heuristic to pick a tile size for `seq_len`,
/// optionally constrained by an available-memory budget (in MB).
pub fn compute_optimal_block_size(
&self,
seq_len: usize,
available_memory_mb: Option<usize>,
) -> usize {
AttentionUtils::compute_block_size(seq_len, self.config.head_dim, available_memory_mb)
}
/// Recomputes and stores the tile size for the given sequence length.
pub fn update_block_size(&mut self, seq_len: usize, available_memory_mb: Option<usize>) {
self.block_size = self.compute_optimal_block_size(seq_len, available_memory_mb);
}
/// Chooses a tile size for the v2 kernel based on sequence lengths: grow
/// (up to 1024) for long sequences, shrink (down to 32) for short ones,
/// otherwise keep the configured size. `head_dim` is accepted but not used
/// by the current heuristic.
fn compute_adaptive_block_size(&self, seq_q: usize, seq_k: usize, head_dim: usize) -> usize {
    let base = self.block_size;
    let candidate = if seq_q > 4096 || seq_k > 4096 {
        (base * 2).min(1024)
    } else if seq_q < 128 && seq_k < 128 {
        (base / 2).max(32)
    } else {
        base
    };
    // Never exceed either sequence length; clamp also enforces a floor of 1.
    candidate.min(seq_q).clamp(1, seq_k)
}
/// Returns the number of tiles needed to cover the query and key sequence
/// lengths at the given tile size (ceiling division).
fn compute_optimal_partitioning(
    &self,
    seq_q: usize,
    seq_k: usize,
    block_size: usize,
) -> (usize, usize) {
    (seq_q.div_ceil(block_size), seq_k.div_ceil(block_size))
}
/// Computes scaled dot-product scores for one tile on the v2 path, applying
/// the causal mask (if configured) and the caller-provided attention mask
/// (if any). `q_offset` / `k_offset` are global positions of the tile.
///
/// # Errors
/// Propagates matmul, masking, and slicing errors.
fn compute_optimized_block_attention_scores(
    &self,
    q_block: &Tensor,
    k_block: &Tensor,
    scale: f32,
    q_offset: usize,
    k_offset: usize,
    attention_mask: Option<&Tensor>,
) -> Result<Tensor> {
    // (Q · K^T) * scale for this tile.
    let mut scores = q_block
        .matmul(&k_block.transpose(2, 3)?)?
        .mul(&Tensor::scalar(scale)?)?;
    if self.causal {
        let block_q = q_block.shape()[2];
        let block_k = k_block.shape()[2];
        scores = self.apply_optimized_causal_mask(&scores, block_q, block_k, q_offset, k_offset)?;
    }
    match attention_mask {
        None => Ok(scores),
        Some(mask) => {
            // Additive mask window aligned with this tile.
            let mask_block = self.extract_mask_block(
                mask,
                q_offset,
                k_offset,
                q_block.shape()[2],
                k_block.shape()[2],
            )?;
            scores.add(&mask_block)
        },
    }
}
/// Builds and adds an additive causal mask for one tile: entries whose
/// global key position exceeds the global query position become -inf so
/// softmax zeroes them. The [seq_q, seq_k] mask broadcasts over the
/// batch/head dimensions of `scores`.
///
/// Cleanup: the previous version bound an unused `shape` local and cloned
/// `scores` into `result` before calling `add`, which already returns a new
/// tensor — both were dead work and are removed.
///
/// # Errors
/// Propagates tensor construction/addition errors.
fn apply_optimized_causal_mask(
    &self,
    scores: &Tensor,
    seq_q: usize,
    seq_k: usize,
    q_offset: usize,
    k_offset: usize,
) -> Result<Tensor> {
    let mut mask_data = vec![0.0f32; seq_q * seq_k];
    for i in 0..seq_q {
        for j in 0..seq_k {
            // Compare absolute positions, not tile-local indices.
            if k_offset + j > q_offset + i {
                mask_data[i * seq_k + j] = f32::NEG_INFINITY;
            }
        }
    }
    let causal_mask = Tensor::from_vec(mask_data, &[seq_q, seq_k])?;
    scores.add(&causal_mask)
}
/// Slices the tile-aligned window out of an attention mask. Supports rank-2
/// `(q, k)` masks and rank-4 `(batch, heads, q, k)` masks; the window is
/// taken over the last two axes in both cases.
///
/// # Errors
/// Returns a tensor-op error for any other mask rank, and propagates
/// slicing errors.
fn extract_mask_block(
    &self,
    mask: &Tensor,
    q_offset: usize,
    k_offset: usize,
    q_block_size: usize,
    k_block_size: usize,
) -> Result<Tensor> {
    let mask_shape = mask.shape();
    let q_range = (q_offset, q_offset + q_block_size);
    let k_range = (k_offset, k_offset + k_block_size);
    match mask_shape.len() {
        2 => mask.slice_ranges(&[q_range, k_range]),
        4 => mask.slice_ranges(&[
            (0, mask_shape[0]),
            (0, mask_shape[1]),
            q_range,
            k_range,
        ]),
        _ => Err(TrustformersError::tensor_op_error(
            &format!("Unsupported mask shape: {:?}", mask_shape),
            "extract_mask_block",
        )),
    }
}
/// Online-softmax update for one key/value tile (v2 path): rescales the
/// previously accumulated sum and output to a new running maximum, then adds
/// this tile's contribution. Keeps exp() numerically stable across tiles.
fn update_block_statistics_v2(
&self,
block_output: &mut Tensor,
block_max: &mut Tensor,
block_sum: &mut Tensor,
attention_scores: &Tensor,
v_block: &Tensor,
) -> Result<()> {
// NOTE(review): `max_value()` appears to reduce over the whole tensor
// rather than per query row; exact flash attention uses a row-wise max.
// This still stabilizes exp() numerically, but confirm the intent.
let scores_max_val = attention_scores.max_value()?;
let new_max = block_max.max(&scores_max_val)?;
let scores_shifted = attention_scores.sub(&new_max)?;
let scores_exp = scores_shifted.exp()?;
// exp(old_max - new_max) rescales the previously accumulated statistics.
let old_sum_correction = block_max.sub(&new_max)?.exp()?;
let corrected_old_sum = block_sum.mul(&old_sum_correction)?;
// NOTE(review): confirm `sum(None, false)` reduces over the key axis with
// a per-row shape; a full reduction would break per-row denominators.
let new_contribution = scores_exp.sum(None, false)?;
let updated_sum = corrected_old_sum.add(&new_contribution)?;
let old_output_correction = block_output.mul(&old_sum_correction)?;
let new_output_contribution = scores_exp.matmul(v_block)?;
let updated_output = old_output_correction.add(&new_output_contribution)?;
*block_max = new_max;
*block_sum = updated_sum;
*block_output = updated_output;
Ok(())
}
/// Intended to write `block_output` into rows [q_start, q_end) of `output`
/// on the v2 path. FIXME: unimplemented stub — it returns Ok(()) without
/// copying anything, so compute_flash_attention_2 currently returns the
/// zero tensor it allocated. Needs a slice-assignment op on Tensor.
fn store_block_output_optimized(
&self,
output: &mut Tensor,
block_output: &Tensor,
q_start: usize,
q_end: usize,
) -> Result<()> {
Ok(())
}
}
/// Standard layer interface: treats `input` as query, key and value
/// (self-attention) with no attention mask.
impl Layer for FlashAttention {
type Input = Tensor;
type Output = Tensor;
fn forward(&self, input: Self::Input) -> Result<Self::Output> {
self.forward_self_attention(&input, None)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tensor::Tensor;
// 512 hidden / 8 heads => head_dim 64; default block size and v2 enabled.
#[test]
fn test_flash_attention_creation() {
let attention =
FlashAttention::new(512, 8, 0.1, true, None, false).expect("operation failed in test");
assert_eq!(attention.config.hidden_size, 512);
assert_eq!(attention.config.num_heads, 8);
assert_eq!(attention.config.head_dim, 64);
}
// Explicit block size is stored verbatim.
#[test]
fn test_flash_attention_with_custom_block_size() {
let attention = FlashAttention::new(512, 8, 0.1, true, Some(128), false)
.expect("operation failed in test");
assert_eq!(attention.block_size(), 128);
}
// new_with_version must honor the explicit v2 flag.
#[test]
fn test_flash_attention_2_version() {
let attention = FlashAttention::new_with_version(512, 8, 0.1, true, None, false, true)
.expect("operation failed in test");
assert!(attention.is_using_flash_attention_2());
}
// Shape-only check: output matches input (batch, seq, hidden).
// NOTE(review): values are not asserted, so this passes even while the
// store_block_output* stubs leave the attention output zero-filled.
#[test]
fn test_flash_attention_forward() {
let attention =
FlashAttention::new(512, 8, 0.1, true, None, false).expect("operation failed in test");
let input = Tensor::randn(&[2, 10, 512]).expect("Failed to create random tensor");
let output = attention.forward(input).expect("Forward pass failed");
assert_eq!(output.shape(), vec![2, 10, 512]);
}
// Estimate is a positive byte count for any non-trivial configuration.
#[test]
fn test_memory_estimation() {
let attention =
FlashAttention::new(512, 8, 0.1, true, None, false).expect("operation failed in test");
let memory_usage = attention.estimate_memory_usage(2, 1000);
assert!(memory_usage > 0);
}
// Heuristic block size stays within a sane range for seq_len 2048.
#[test]
fn test_optimal_block_size_computation() {
let attention =
FlashAttention::new(512, 8, 0.1, true, None, false).expect("operation failed in test");
let block_size = attention.compute_optimal_block_size(2048, Some(1024));
assert!(block_size > 0);
assert!(block_size <= 512);
}
// Causal path takes the v1 kernel; again shape-only.
#[test]
fn test_causal_attention() {
let attention =
FlashAttention::new(512, 8, 0.1, true, None, true).expect("operation failed in test");
let input = Tensor::randn(&[2, 10, 512]).expect("Failed to create random tensor");
let output = attention.forward(input).expect("Forward pass failed");
assert_eq!(output.shape(), vec![2, 10, 512]);
}
}