/// Tiling and launch parameters for a flash-attention kernel.
///
/// This is a small `Copy` value type: build one from a preset or
/// `Default::default()`, then adjust it with the `with_*` builder methods.
/// Equality and hashing are derived so configs can be compared, deduplicated
/// or used as cache keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct FlashAttentionConfig {
    /// Edge length of a tile, in sequence positions. It sizes the Q/K/V and
    /// score tiles in `shared_memory_bytes` and is the divisor when counting
    /// tiles over a sequence, so it must be non-zero.
    pub tile_size: u32,
    /// Per-head feature dimension; multiplies `tile_size` when sizing the
    /// Q/K/V tiles.
    pub head_dim: u32,
    /// NOTE(review): not read by the sizing/tiling helpers; presumably the
    /// kernel's threads-per-block — confirm against the launch code.
    pub block_size: u32,
    /// NOTE(review): not read by the sizing/tiling helpers; presumably
    /// selects wide (vectorized) global-memory loads in the kernel — confirm.
    pub use_vectorized_loads: bool,
    /// Set by `with_causal_mask`; presumably enables causal (lower-triangular)
    /// masking in the kernel — confirm against the kernel code.
    pub causal_mask: bool,
}
impl Default for FlashAttentionConfig {
    /// Baseline configuration: 128-position tiles, a 64-wide head, a block
    /// size of 256, vectorized loads on, and causal masking off.
    fn default() -> Self {
        let tile_size = 128;
        let head_dim = 64;
        let block_size = 256;
        let use_vectorized_loads = true;
        // Masking is opt-in via the builder.
        let causal_mask = false;

        Self {
            tile_size,
            head_dim,
            block_size,
            use_vectorized_loads,
            causal_mask,
        }
    }
}
impl FlashAttentionConfig {
    /// Preset for the RTX 5080: `tile_size` 256, `head_dim` 64,
    /// `block_size` 256, vectorized loads on, no causal mask.
    ///
    /// With these values `shared_memory_bytes(2)` is 230_400 bytes and
    /// `shared_memory_bytes(4)` is 460_800 bytes.
    /// NOTE(review): confirm this fits the device's per-block shared-memory limit.
    #[must_use]
    pub const fn for_rtx_5080() -> Self {
        Self {
            tile_size: 256,
            head_dim: 64,
            block_size: 256,
            use_vectorized_loads: true,
            causal_mask: false,
        }
    }

    /// Preset for the RTX 3090 Ti: `tile_size` 128, `head_dim` 64,
    /// `block_size` 256, vectorized loads on, no causal mask.
    ///
    /// These are the same values as `Default::default()`, but usable in
    /// `const` contexts. `shared_memory_bytes(2)` is 82_432 bytes.
    #[must_use]
    pub const fn for_rtx_3090_ti() -> Self {
        Self {
            tile_size: 128,
            head_dim: 64,
            block_size: 256,
            use_vectorized_loads: true,
            causal_mask: false,
        }
    }

    /// Preset for datacenter GPUs: `tile_size` 256, `head_dim` 128,
    /// `block_size` 256, vectorized loads on, no causal mask.
    ///
    /// With these values `shared_memory_bytes(2)` is 328_704 bytes.
    /// NOTE(review): confirm this fits the target's per-block shared-memory limit.
    #[must_use]
    pub const fn for_datacenter() -> Self {
        Self {
            tile_size: 256,
            head_dim: 128,
            block_size: 256,
            use_vectorized_loads: true,
            causal_mask: false,
        }
    }

    /// Returns this configuration with `causal_mask` switched on.
    #[must_use]
    pub const fn with_causal_mask(mut self) -> Self {
        self.causal_mask = true;
        self
    }

    /// Returns this configuration with `tile_size` replaced.
    ///
    /// # Panics
    ///
    /// Panics if `tile_size` is not a power of two. This includes `0`, which
    /// would otherwise cause a division by zero in the tile-count helpers.
    /// In a `const` context the failure surfaces as a compile-time error.
    #[must_use]
    pub const fn with_tile_size(mut self, tile_size: u32) -> Self {
        assert!(tile_size.is_power_of_two(), "tile_size must be power of 2");
        self.tile_size = tile_size;
        self
    }

    /// Returns this configuration with `head_dim` replaced (no validation).
    #[must_use]
    pub const fn with_head_dim(mut self, head_dim: u32) -> Self {
        self.head_dim = head_dim;
        self
    }

    /// Estimated shared-memory footprint, in bytes, for one tile iteration.
    ///
    /// Counts three `tile_size x head_dim` tiles (Q, K, V), one
    /// `tile_size x tile_size` score tile, and two `tile_size`-long
    /// per-row statistic vectors. Every element is assumed to occupy
    /// `bytes_per_elem` bytes.
    ///
    /// The arithmetic uses plain `*`. Extreme field values (settable directly
    /// through the public fields) can therefore overflow `usize`: this panics
    /// in debug builds and wraps in release builds.
    #[must_use]
    pub const fn shared_memory_bytes(&self, bytes_per_elem: usize) -> usize {
        let tile = self.tile_size as usize;
        let dim = self.head_dim as usize;
        let qkv_tiles = 3 * tile * dim;
        let scores = tile * tile;
        // Two values per row, e.g. running max and running sum.
        let stats = 2 * tile;
        (qkv_tiles + scores + stats) * bytes_per_elem
    }

    /// Number of query tiles needed to cover `seq_len` positions
    /// (the last tile may be partial).
    ///
    /// # Panics
    ///
    /// Panics if `tile_size` is `0`. The builder rejects that value, but the
    /// public field can still be set directly.
    #[must_use]
    pub const fn num_q_tiles(&self, seq_len: u32) -> u32 {
        self.tiles_covering(seq_len)
    }

    /// Number of key/value tiles needed to cover `seq_len` positions
    /// (the last tile may be partial).
    ///
    /// # Panics
    ///
    /// Panics if `tile_size` is `0`. The builder rejects that value, but the
    /// public field can still be set directly.
    #[must_use]
    pub const fn num_kv_tiles(&self, seq_len: u32) -> u32 {
        self.tiles_covering(seq_len)
    }

    /// Ceiling division of `seq_len` by `tile_size`. Q and KV tiles share one
    /// tile size, so both counters delegate here.
    const fn tiles_covering(&self, seq_len: u32) -> u32 {
        seq_len.div_ceil(self.tile_size)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_default_config() {
        let config = FlashAttentionConfig::default();
        assert_eq!(config.tile_size, 128);
        assert_eq!(config.head_dim, 64);
        assert_eq!(config.block_size, 256);
        assert!(config.use_vectorized_loads);
        assert!(!config.causal_mask);
    }

    #[test]
    fn test_presets() {
        let rtx_5080 = FlashAttentionConfig::for_rtx_5080();
        assert_eq!((rtx_5080.tile_size, rtx_5080.head_dim), (256, 64));

        let rtx_3090_ti = FlashAttentionConfig::for_rtx_3090_ti();
        assert_eq!((rtx_3090_ti.tile_size, rtx_3090_ti.head_dim), (128, 64));

        let datacenter = FlashAttentionConfig::for_datacenter();
        assert_eq!((datacenter.tile_size, datacenter.head_dim), (256, 128));
    }

    #[test]
    fn test_shared_memory_calculation() {
        let config = FlashAttentionConfig::default();
        // 3*128*64 (Q/K/V) + 128*128 (scores) + 2*128 (stats) = 41_216 elements.
        assert_eq!(config.shared_memory_bytes(4), 164_864);
        assert_eq!(config.shared_memory_bytes(2), 82_432);
    }

    #[test]
    fn test_num_tiles() {
        let config = FlashAttentionConfig::default();
        assert_eq!(config.num_q_tiles(0), 0);
        assert_eq!(config.num_q_tiles(1), 1);
        assert_eq!(config.num_q_tiles(512), 4);
        assert_eq!(config.num_q_tiles(1024), 8);
        assert_eq!(config.num_q_tiles(1025), 9);
        assert_eq!(config.num_kv_tiles(1024), 8);
        assert_eq!(config.num_kv_tiles(1025), 9);
    }

    #[test]
    fn test_builder_pattern() {
        let config = FlashAttentionConfig::default()
            .with_tile_size(256)
            .with_head_dim(128)
            .with_causal_mask();
        assert_eq!(config.tile_size, 256);
        assert_eq!(config.head_dim, 128);
        assert!(config.causal_mask);
    }

    #[test]
    #[should_panic(expected = "tile_size must be power of 2")]
    fn test_invalid_tile_size() {
        let _ = FlashAttentionConfig::default().with_tile_size(100);
    }

    #[test]
    #[should_panic(expected = "tile_size must be power of 2")]
    fn test_zero_tile_size_rejected() {
        // Zero is not a power of two; accepting it would make tile counting divide by zero.
        let _ = FlashAttentionConfig::default().with_tile_size(0);
    }
}