use crate::error::{Result, TrustformerError};
use tensorlogic_ir::{EinsumGraph, EinsumNode};
/// Tunable parameters for a flash (block-tiled) attention layer.
///
/// [`FlashAttentionConfig::new`] fills in defaults, the `with_*` builders
/// override individual fields, and [`FlashAttentionConfig::validate`] checks
/// the result.
#[derive(Debug, Clone)]
pub struct FlashAttentionConfig {
/// Model dimension; `new` requires it to be a multiple of `n_heads`.
pub d_model: usize,
/// Number of attention heads.
pub n_heads: usize,
/// Per-head dimension; `new` sets it to `d_model / n_heads`.
pub d_k: usize,
/// Query block size used by the tiling estimates (`new` default: 128).
pub block_size_q: usize,
/// Key/value block size used by the tiling estimates (`new` default: 128).
pub block_size_kv: usize,
/// Causal flag (`new` default: `false`).
pub causal: bool,
/// Dropout value (`new` default: 0.0); `validate` rejects values below 0.0
/// or above 1.0.
pub dropout: f64,
/// Maximum sequence length (`new` default: 4096).
pub max_seq_len: usize,
}
impl FlashAttentionConfig {
    /// Creates a configuration with default block sizes (128 x 128), no
    /// causal flag, zero dropout and a maximum sequence length of 4096.
    ///
    /// The per-head dimension `d_k` is derived as `d_model / n_heads`.
    ///
    /// # Errors
    ///
    /// Returns [`TrustformerError::InvalidHeadCount`] when `n_heads` is zero
    /// or does not divide `d_model` evenly.
    pub fn new(d_model: usize, n_heads: usize) -> Result<Self> {
        // `0.is_multiple_of(0)` is `true`, so the zero check has to come first;
        // otherwise `new(0, 0)` would reach the division below and panic.
        if n_heads == 0 || !d_model.is_multiple_of(n_heads) {
            return Err(TrustformerError::InvalidHeadCount { d_model, n_heads });
        }
        Ok(Self {
            d_model,
            n_heads,
            d_k: d_model / n_heads,
            block_size_q: 128,
            block_size_kv: 128,
            causal: false,
            dropout: 0.0,
            max_seq_len: 4096,
        })
    }

    /// Overrides the query and key/value block sizes.
    ///
    /// The values are not checked here; [`Self::validate`] rejects zero.
    pub fn with_block_sizes(mut self, block_q: usize, block_kv: usize) -> Self {
        self.block_size_q = block_q;
        self.block_size_kv = block_kv;
        self
    }

    /// Sets the causal flag.
    pub fn with_causal(mut self, causal: bool) -> Self {
        self.causal = causal;
        self
    }

    /// Sets the dropout value.
    ///
    /// The value is not checked here; [`Self::validate`] enforces `0.0..=1.0`.
    pub fn with_dropout(mut self, dropout: f64) -> Self {
        self.dropout = dropout;
        self
    }

    /// Sets the maximum sequence length.
    pub fn with_max_seq_len(mut self, max_seq_len: usize) -> Self {
        self.max_seq_len = max_seq_len;
        self
    }

    /// Checks that the configuration is usable.
    ///
    /// # Errors
    ///
    /// * [`TrustformerError::MissingParameter`] if `d_model`, `n_heads` or
    ///   either block size is zero.
    /// * [`TrustformerError::CompilationError`] if `dropout` is not a number
    ///   in `0.0..=1.0` (NaN is rejected as well).
    pub fn validate(&self) -> Result<()> {
        if self.d_model == 0 {
            return Err(TrustformerError::MissingParameter(
                "d_model must be positive".to_string(),
            ));
        }
        if self.n_heads == 0 {
            return Err(TrustformerError::MissingParameter(
                "n_heads must be positive".to_string(),
            ));
        }
        if self.block_size_q == 0 || self.block_size_kv == 0 {
            return Err(TrustformerError::MissingParameter(
                "block sizes must be positive".to_string(),
            ));
        }
        // A range membership test (unlike `< 0.0 || > 1.0`) is false for NaN,
        // so a NaN dropout is reported instead of slipping through.
        if !(0.0..=1.0).contains(&self.dropout) {
            return Err(TrustformerError::CompilationError(
                "dropout must be between 0 and 1".to_string(),
            ));
        }
        Ok(())
    }

    /// Number of elements needed for one tile: the Q, K, V and output blocks
    /// (`block * d_k` each) plus the `block_size_q * block_size_kv` score block.
    pub fn sram_usage_per_block(&self) -> usize {
        let q_block = self.block_size_q * self.d_k;
        let k_block = self.block_size_kv * self.d_k;
        let v_block = self.block_size_kv * self.d_k;
        let s_block = self.block_size_q * self.block_size_kv;
        let o_block = self.block_size_q * self.d_k;
        q_block + k_block + v_block + s_block + o_block
    }

    /// Fraction of memory saved relative to the standard-attention estimate
    /// `seq_len^2 + seq_len * d_k`, where the flash estimate is
    /// `seq_len * d_k + sram_usage_per_block()`.
    ///
    /// The result is negative when the per-block usage outweighs the
    /// `seq_len^2` term, and `0.0` when the standard estimate is zero
    /// (for example `seq_len == 0`), where no ratio is defined.
    pub fn memory_savings(&self, seq_len: usize) -> f64 {
        // Work in f64 from the start so `seq_len * seq_len` cannot overflow.
        let n = seq_len as f64;
        let d_k = self.d_k as f64;
        let standard_memory = n * n + n * d_k;
        if standard_memory == 0.0 {
            return 0.0;
        }
        let flash_memory = n * d_k + self.sram_usage_per_block() as f64;
        1.0 - flash_memory / standard_memory
    }

    /// Number of key/value blocks covering `seq_len`, rounded up.
    ///
    /// # Panics
    ///
    /// Panics if `block_size_kv` is zero (see [`Self::validate`]).
    pub fn num_kv_passes(&self, seq_len: usize) -> usize {
        seq_len.div_ceil(self.block_size_kv)
    }

    /// Number of query blocks covering `seq_len`, rounded up.
    ///
    /// # Panics
    ///
    /// Panics if `block_size_q` is zero (see [`Self::validate`]).
    pub fn num_q_blocks(&self, seq_len: usize) -> usize {
        seq_len.div_ceil(self.block_size_q)
    }
}
/// Graph builder for flash attention, wrapping a [`FlashAttentionConfig`].
///
/// Instances are created through `FlashAttention::new`, which validates the
/// configuration before storing it.
#[derive(Debug, Clone)]
pub struct FlashAttention {
/// Configuration supplied to `FlashAttention::new`.
config: FlashAttentionConfig,
}
impl FlashAttention {
    /// Wraps `config` after checking it with [`FlashAttentionConfig::validate`].
    ///
    /// # Errors
    ///
    /// Propagates whatever error `validate` reports.
    pub fn new(config: FlashAttentionConfig) -> Result<Self> {
        config.validate()?;
        Ok(Self { config })
    }

    /// Returns the configuration this builder was created with.
    pub fn config(&self) -> &FlashAttentionConfig {
        &self.config
    }

    /// Appends the flash-attention node sequence to `graph` and returns a
    /// one-element vector holding the id of the final output tensor.
    ///
    /// Nodes are added in this order: three head-split reshapes, three
    /// `bshd->bhsd` transposes, the fused flash node (its spec carries both
    /// block sizes), a `bhsd->bshd` transpose, and the reshape back to `bsd`.
    ///
    /// The split nodes read tensor ids `0`, `1` and `2`, so the caller is
    /// expected to have registered Q, K and V as the first three tensors of
    /// `graph` (the unit tests do exactly that).
    ///
    /// NOTE(review): only `d_k` and the two block sizes are read from the
    /// configuration here; `causal` and `dropout` are not encoded in any spec.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `EinsumGraph::add_node`.
    pub fn build_flash_graph(&self, graph: &mut EinsumGraph) -> Result<Vec<usize>> {
        let d_k = self.config.d_k;

        // Split each of Q, K, V (tensor ids 0, 1, 2) into heads.
        let split =
            ["flash_q_split", "flash_k_split", "flash_v_split"].map(|name| graph.add_tensor(name));
        let split_spec = format!("bsd->bsh{}", d_k);
        for (input, &output) in split.iter().enumerate() {
            graph.add_node(EinsumNode::new(&split_spec, vec![input], vec![output]))?;
        }

        // Move the head axis in front of the sequence axis.
        let transposed = [
            "flash_q_transposed",
            "flash_k_transposed",
            "flash_v_transposed",
        ]
        .map(|name| graph.add_tensor(name));
        for (&input, &output) in split.iter().zip(&transposed) {
            graph.add_node(EinsumNode::new("bshd->bhsd", vec![input], vec![output]))?;
        }

        // Fused attention node; the block sizes are appended to the spec.
        let flash_output = graph.add_tensor("flash_attn_output");
        let flash_spec = format!(
            "flash_bhqd,bhkd,bhkv->bhqv_{}_{}",
            self.config.block_size_q, self.config.block_size_kv
        );
        graph.add_node(EinsumNode::new(
            &flash_spec,
            transposed.to_vec(),
            vec![flash_output],
        ))?;

        // Undo the transpose.
        let transposed_back = graph.add_tensor("flash_transposed_back");
        graph.add_node(EinsumNode::new(
            "bhsd->bshd",
            vec![flash_output],
            vec![transposed_back],
        ))?;

        // Merge the heads again. The separator must be `->`, matching every
        // other spec emitted above (it was previously written as `-:`).
        let output = graph.add_tensor("flash_output");
        let merge_spec = format!("bsh{}->bsd", d_k);
        graph.add_node(EinsumNode::new(
            &merge_spec,
            vec![transposed_back],
            vec![output],
        ))?;

        Ok(vec![output])
    }
}
/// Configuration that extends a [`FlashAttentionConfig`] with two extra options.
#[derive(Debug, Clone)]
pub struct FlashAttentionV2Config {
/// Underlying base configuration.
pub base: FlashAttentionConfig,
/// Sequence-parallel flag (`new` default: `false`).
pub sequence_parallel: bool,
/// Optional window size; `new` leaves it as `None`.
pub window_size: Option<usize>,
}
impl FlashAttentionV2Config {
    /// Builds a V2 configuration on top of `FlashAttentionConfig::new`, with
    /// sequence parallelism disabled and no window.
    ///
    /// # Errors
    ///
    /// Propagates the error from `FlashAttentionConfig::new`.
    pub fn new(d_model: usize, n_heads: usize) -> Result<Self> {
        let base = FlashAttentionConfig::new(d_model, n_heads)?;
        Ok(Self {
            base,
            sequence_parallel: false,
            window_size: None,
        })
    }

    /// Returns the configuration with the sequence-parallel flag replaced.
    pub fn with_sequence_parallel(self, enabled: bool) -> Self {
        Self {
            sequence_parallel: enabled,
            ..self
        }
    }

    /// Returns the configuration with `window_size` set to `Some(window_size)`.
    pub fn with_window(self, window_size: usize) -> Self {
        Self {
            window_size: Some(window_size),
            ..self
        }
    }

    /// Returns the configuration with the causal flag of `base` replaced.
    pub fn with_causal(self, causal: bool) -> Self {
        let Self {
            base,
            sequence_parallel,
            window_size,
        } = self;
        Self {
            base: base.with_causal(causal),
            sequence_parallel,
            window_size,
        }
    }
}
/// Estimates derived from a [`FlashAttentionConfig`] for one sequence length.
///
/// The field descriptions below state what `FlashAttentionStats::from_config`
/// stores in each field.
#[derive(Debug, Clone)]
pub struct FlashAttentionStats {
/// Clone of the configuration the figures were computed from.
pub config: FlashAttentionConfig,
/// Value of `FlashAttentionConfig::memory_savings`; a fraction that
/// `summary` prints multiplied by 100.
pub memory_savings: f64,
/// Value of `FlashAttentionConfig::sram_usage_per_block`.
pub sram_usage: usize,
/// Value of `FlashAttentionConfig::num_kv_passes`.
pub num_kv_passes: usize,
/// Value of `FlashAttentionConfig::num_q_blocks`.
pub num_q_blocks: usize,
}
impl FlashAttentionStats {
    /// Evaluates every estimate exposed by `config` for `seq_len` and stores
    /// the results next to a clone of the configuration.
    pub fn from_config(config: &FlashAttentionConfig, seq_len: usize) -> Self {
        let snapshot = config.clone();
        let memory_savings = config.memory_savings(seq_len);
        let sram_usage = config.sram_usage_per_block();
        let num_kv_passes = config.num_kv_passes(seq_len);
        let num_q_blocks = config.num_q_blocks(seq_len);
        Self {
            config: snapshot,
            memory_savings,
            sram_usage,
            num_kv_passes,
            num_q_blocks,
        }
    }

    /// Renders the statistics as a multi-line, human-readable report.
    ///
    /// `seq_len` is only echoed on the last line; the stored figures are not
    /// recomputed from it.
    pub fn summary(&self, seq_len: usize) -> String {
        let cfg = &self.config;
        // Stored as a fraction, shown as a percentage with one decimal place.
        let savings_percent = self.memory_savings * 100.0;
        format!(
            "Flash Attention Statistics\n d_model: {}\n n_heads: {}\n \
             block_size_q: {}\n block_size_kv: {}\n causal: {}\n \
             memory savings: {:.1}%\n SRAM usage: {} elements\n \
             num_kv_passes: {}\n num_q_blocks: {}\n seq_len: {}",
            cfg.d_model,
            cfg.n_heads,
            cfg.block_size_q,
            cfg.block_size_kv,
            cfg.causal,
            savings_percent,
            self.sram_usage,
            self.num_kv_passes,
            self.num_q_blocks,
            seq_len
        )
    }
}
/// Named block-size presets for [`FlashAttentionConfig`].
///
/// The sizes quoted on each variant are the `(query, key/value)` pair returned
/// by `FlashAttentionPreset::block_sizes`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum FlashAttentionPreset {
/// Block sizes `(128, 128)`.
Standard,
/// Block sizes `(256, 256)`.
LargeBlocks,
/// Block sizes `(64, 64)`.
SmallBlocks,
/// Block sizes `(128, 64)`.
A100Optimized,
/// Block sizes `(128, 128)`, the same pair as `Standard`.
H100Optimized,
}
impl FlashAttentionPreset {
    /// Returns the `(query, key/value)` block sizes of this preset.
    pub fn block_sizes(&self) -> (usize, usize) {
        match *self {
            Self::SmallBlocks => (64, 64),
            Self::A100Optimized => (128, 64),
            Self::Standard | Self::H100Optimized => (128, 128),
            Self::LargeBlocks => (256, 256),
        }
    }

    /// Returns the human-readable name of this preset.
    pub fn name(&self) -> &'static str {
        match *self {
            Self::Standard => "Standard",
            Self::LargeBlocks => "Large Blocks",
            Self::SmallBlocks => "Small Blocks",
            Self::A100Optimized => "A100 Optimized",
            Self::H100Optimized => "H100 Optimized",
        }
    }

    /// Builds a default configuration for `d_model` / `n_heads` and applies
    /// this preset's block sizes to it.
    ///
    /// # Errors
    ///
    /// Propagates the error from `FlashAttentionConfig::new`.
    pub fn config(&self, d_model: usize, n_heads: usize) -> Result<FlashAttentionConfig> {
        let (block_q, block_kv) = self.block_sizes();
        let defaults = FlashAttentionConfig::new(d_model, n_heads)?;
        Ok(defaults.with_block_sizes(block_q, block_kv))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline configuration used by most tests: `d_model = 512`, `n_heads = 8`.
    fn base_config() -> FlashAttentionConfig {
        FlashAttentionConfig::new(512, 8).expect("unwrap")
    }

    /// Fresh graph with the Q, K and V tensors registered first, in that order.
    fn graph_with_qkv() -> EinsumGraph {
        let mut graph = EinsumGraph::new();
        for name in ["Q", "K", "V"] {
            graph.add_tensor(name);
        }
        graph
    }

    #[test]
    fn test_flash_config_creation() {
        let cfg = base_config();
        assert_eq!(cfg.d_model, 512);
        assert_eq!(cfg.n_heads, 8);
        assert_eq!(cfg.d_k, 64);
        assert_eq!(cfg.block_size_q, 128);
        assert_eq!(cfg.block_size_kv, 128);
    }

    #[test]
    fn test_flash_config_builder() {
        let cfg = base_config()
            .with_block_sizes(64, 64)
            .with_causal(true)
            .with_dropout(0.1)
            .with_max_seq_len(8192);
        assert_eq!(cfg.block_size_q, 64);
        assert_eq!(cfg.block_size_kv, 64);
        assert!(cfg.causal);
        assert!((cfg.dropout - 0.1).abs() < 1e-10);
        assert_eq!(cfg.max_seq_len, 8192);
    }

    #[test]
    fn test_flash_invalid_config() {
        // 512 is not divisible by 7.
        let result = FlashAttentionConfig::new(512, 7);
        assert!(result.is_err());
    }

    #[test]
    fn test_flash_sram_usage() {
        assert_eq!(base_config().sram_usage_per_block(), 49152);
    }

    #[test]
    fn test_flash_memory_savings() {
        assert!(base_config().memory_savings(4096) > 0.9);
    }

    #[test]
    fn test_flash_num_passes() {
        let cfg = base_config();
        assert_eq!(cfg.num_kv_passes(4096), 32);
        assert_eq!(cfg.num_q_blocks(4096), 32);
        // 1000 / 128 rounds up to 8.
        assert_eq!(cfg.num_kv_passes(1000), 8);
    }

    #[test]
    fn test_flash_graph_building() {
        let flash = FlashAttention::new(base_config()).expect("unwrap");
        let mut graph = graph_with_qkv();
        let outputs = flash.build_flash_graph(&mut graph).expect("unwrap");
        assert_eq!(outputs.len(), 1);
    }

    #[test]
    fn test_flash_causal_graph() {
        let causal_cfg = base_config().with_causal(true);
        let flash = FlashAttention::new(causal_cfg).expect("unwrap");
        let mut graph = graph_with_qkv();
        let outputs = flash.build_flash_graph(&mut graph).expect("unwrap");
        assert_eq!(outputs.len(), 1);
    }

    #[test]
    fn test_flash_v2_config() {
        let v2 = FlashAttentionV2Config::new(512, 8)
            .expect("unwrap")
            .with_sequence_parallel(true)
            .with_window(4096)
            .with_causal(true);
        assert!(v2.sequence_parallel);
        assert_eq!(v2.window_size, Some(4096));
        assert!(v2.base.causal);
    }

    #[test]
    fn test_flash_presets() {
        let (q, kv) = FlashAttentionPreset::Standard.block_sizes();
        assert_eq!(q, 128);
        assert_eq!(kv, 128);

        let (q, kv) = FlashAttentionPreset::LargeBlocks.block_sizes();
        assert_eq!(q, 256);
        assert_eq!(kv, 256);

        let (q, kv) = FlashAttentionPreset::A100Optimized.block_sizes();
        assert_eq!(q, 128);
        assert_eq!(kv, 64);
    }

    #[test]
    fn test_flash_preset_config() {
        let cfg = FlashAttentionPreset::Standard
            .config(512, 8)
            .expect("unwrap");
        assert_eq!(cfg.block_size_q, 128);
        assert_eq!(cfg.block_size_kv, 128);
    }

    #[test]
    fn test_flash_preset_names() {
        assert_eq!(FlashAttentionPreset::Standard.name(), "Standard");
        assert_eq!(FlashAttentionPreset::A100Optimized.name(), "A100 Optimized");
    }

    #[test]
    fn test_flash_stats() {
        let stats = FlashAttentionStats::from_config(&base_config(), 4096);
        assert!(stats.memory_savings > 0.9);
        assert_eq!(stats.sram_usage, 49152);
        assert_eq!(stats.num_kv_passes, 32);
        assert_eq!(stats.num_q_blocks, 32);
    }

    #[test]
    fn test_flash_validate() {
        let good = base_config();
        assert!(good.validate().is_ok());

        // Dropout above 1.0 must be rejected.
        let bad_dropout = FlashAttentionConfig {
            dropout: 1.5,
            ..good.clone()
        };
        assert!(bad_dropout.validate().is_err());

        // A zero block size must be rejected.
        let bad_block = FlashAttentionConfig {
            block_size_q: 0,
            ..good.clone()
        };
        assert!(bad_block.validate().is_err());
    }
}