use std::sync::Arc;
use oxicuda_blas::GpuFloat;
use oxicuda_driver::Module;
use oxicuda_launch::{Dim3, Kernel, LaunchParams};
use oxicuda_memory::DeviceBuffer;
use oxicuda_ptx::prelude::*;
use crate::error::{DnnError, DnnResult};
use crate::handle::DnnHandle;
use crate::tensor_util::{attn_dims, attn_dims_mut};
use crate::types::{TensorDesc, TensorDescMut};
#[derive(Debug, Clone)]
pub struct FlashAttentionConfig {
pub head_dim: u32,
pub num_heads: u32,
pub seq_len_q: u32,
pub seq_len_kv: u32,
pub causal: bool,
pub sm_scale: f32,
pub block_m: u32,
pub block_n: u32,
pub num_warps: u32,
pub num_stages: u32,
pub precision: PtxType,
pub sm_version: SmVersion,
}
impl FlashAttentionConfig {
#[must_use]
pub fn auto(
head_dim: u32,
seq_len_q: u32,
seq_len_kv: u32,
causal: bool,
sm: SmVersion,
) -> Self {
let (block_m, block_n) = match head_dim {
d if d <= 64 => (128, 128),
d if d <= 128 => (128, 64),
_ => (64, 64),
};
let num_warps = if sm >= SmVersion::Sm90 && block_m >= 128 {
8
} else {
4
};
let num_stages = if sm >= SmVersion::Sm90 { 3 } else { 2 };
Self {
head_dim,
num_heads: 0, seq_len_q,
seq_len_kv,
causal,
sm_scale: 1.0 / (head_dim as f32).sqrt(),
block_m,
block_n,
num_warps,
num_stages,
precision: PtxType::F32,
sm_version: sm,
}
}
#[must_use]
pub fn shared_mem_bytes(&self) -> u32 {
let elem_size = self.precision.size_bytes() as u32;
let q_tile = self.block_m * self.head_dim * elem_size;
let k_tile = self.block_n * self.head_dim * elem_size * self.num_stages;
let v_tile = self.block_n * self.head_dim * elem_size * self.num_stages;
let softmax_scratch = self.block_m * 4;
q_tile + k_tile + v_tile + softmax_scratch
}
#[must_use]
pub fn num_q_tiles(&self) -> u32 {
self.seq_len_q.div_ceil(self.block_m)
}
#[must_use]
pub fn num_kv_tiles(&self) -> u32 {
self.seq_len_kv.div_ceil(self.block_n)
}
pub fn generate_ptx(&self) -> DnnResult<String> {
let kernel_name = format!(
"flash_attn_fwd_d{}_bm{}_bn{}_{}",
self.head_dim,
self.block_m,
self.block_n,
ptx_type_suffix(self.precision)
);
let block_m = self.block_m;
let block_n = self.block_n;
let head_dim = self.head_dim;
let causal = self.causal;
let num_warps = self.num_warps;
let sm = self.sm_version;
let threads_per_block = num_warps * 32;
let q_smem_elems = (block_m * head_dim) as usize;
let k_smem_elems = (block_n * head_dim) as usize;
let v_smem_elems = (block_n * head_dim) as usize;
let ptx = KernelBuilder::new(&kernel_name)
.target(sm)
.param("q_ptr", PtxType::U64)
.param("k_ptr", PtxType::U64)
.param("v_ptr", PtxType::U64)
.param("o_ptr", PtxType::U64)
.param("lse_ptr", PtxType::U64)
.param("seq_len_q", PtxType::U32)
.param("seq_len_kv", PtxType::U32)
.param("head_dim", PtxType::U32)
.param("num_heads", PtxType::U32)
.param("sm_scale", PtxType::F32)
.param("num_kv_tiles", PtxType::U32)
.shared_mem("q_smem", PtxType::F32, q_smem_elems)
.shared_mem("k_smem", PtxType::F32, k_smem_elems)
.shared_mem("v_smem", PtxType::F32, v_smem_elems)
.max_threads_per_block(threads_per_block)
.body(move |b| {
let tid = b.thread_id_x();
let _bid_x = b.block_id_x();
let _seq_q = b.load_param_u32("seq_len_q");
let _seq_kv = b.load_param_u32("seq_len_kv");
let _hdim = b.load_param_u32("head_dim");
let _nheads = b.load_param_u32("num_heads");
let _scale = b.load_param_f32("sm_scale");
let _nkv_tiles = b.load_param_u32("num_kv_tiles");
b.comment("=== FlashAttention-2 Forward Pass ===");
b.comment("");
b.comment("Step 1: Load Q block from global to shared memory");
b.comment(" Each thread loads head_dim / threads_per_block elements");
b.comment(" for block_m rows of Q");
let q_base = b.load_param_u64("q_ptr");
let _ = q_base;
b.comment("");
b.comment("Step 2: Initialise accumulators");
b.comment(" O_acc[block_m][head_dim] = 0.0");
b.comment(" m_i[block_m] = -INFINITY (running row max)");
b.comment(" l_i[block_m] = 0.0 (running row sum)");
b.comment("");
b.comment("Step 3: Loop over KV tiles");
b.comment(" for j in 0..num_kv_tiles:");
b.comment(" 3a. Load K_j block to shared memory");
b.comment(" 3b. Compute S = Q_smem @ K_smem^T (block_m x block_n)");
b.comment(" 3c. Apply causal mask if enabled");
if causal {
b.comment(" [CAUSAL] Set S[i,j] = -inf where j > i + offset");
}
b.comment(" 3d. Online softmax update:");
b.comment(" m_new = max(m_old, row_max(S))");
b.comment(" correction = exp(m_old - m_new)");
b.comment(" l_new = correction * l_old + row_sum(exp(S - m_new))");
b.comment(" O_acc = correction * O_acc");
b.comment(" 3e. Load V_j block to shared memory");
b.comment(" 3f. Accumulate O_acc += P_block @ V_smem");
b.comment(" 3g. Update m_i, l_i");
b.bar_sync(0);
b.comment("");
b.comment("Step 4: Final rescale and store");
b.comment(" O_out = O_acc / l_i (normalise by softmax denominator)");
b.comment(" logsumexp = m_i + log(l_i) (for backward pass)");
let o_base = b.load_param_u64("o_ptr");
let lse_base = b.load_param_u64("lse_ptr");
let _ = o_base;
let _ = lse_base;
let _ = tid;
b.ret();
})
.build()
.map_err(|e| DnnError::PtxGeneration(e.to_string()))?;
Ok(ptx)
}
}
pub fn flash_attention_forward<T: GpuFloat>(
handle: &DnnHandle,
q: &TensorDesc<T>,
k: &TensorDesc<T>,
v: &TensorDesc<T>,
output: &mut TensorDescMut<T>,
logsumexp: &mut DeviceBuffer<f32>,
config: &FlashAttentionConfig,
) -> DnnResult<()> {
validate_flash_shapes(q, k, v, output, logsumexp, config)?;
let (batch, num_heads, _seq_q, _head_dim) = attn_dims(q)?;
let ptx = config.generate_ptx()?;
let kernel_name = format!(
"flash_attn_fwd_d{}_bm{}_bn{}_{}",
config.head_dim,
config.block_m,
config.block_n,
ptx_type_suffix(config.precision)
);
let module = Arc::new(Module::from_ptx(&ptx)?);
let kernel = Kernel::from_module(module, &kernel_name)?;
let num_q_tiles = config.num_q_tiles();
let num_kv_tiles = config.num_kv_tiles();
let threads_per_block = config.num_warps * 32;
let grid = Dim3::new(num_q_tiles, batch * num_heads, 1);
let block = Dim3::new(threads_per_block, 1, 1);
let params = LaunchParams::builder()
.grid(grid)
.block(block)
.shared_mem(config.shared_mem_bytes())
.build();
kernel.launch(
¶ms,
handle.stream(),
&(
q.ptr,
k.ptr,
v.ptr,
output.ptr,
logsumexp.as_device_ptr(),
config.seq_len_q,
config.seq_len_kv,
config.head_dim,
config.num_heads,
config.sm_scale,
num_kv_tiles,
),
)?;
Ok(())
}
fn validate_flash_shapes<T: GpuFloat>(
q: &TensorDesc<T>,
k: &TensorDesc<T>,
v: &TensorDesc<T>,
output: &TensorDescMut<T>,
logsumexp: &DeviceBuffer<f32>,
config: &FlashAttentionConfig,
) -> DnnResult<()> {
let (batch, heads, seq_q, head_dim) = attn_dims(q)?;
if head_dim != config.head_dim {
return Err(DnnError::InvalidDimension(format!(
"Q head_dim {} != config head_dim {}",
head_dim, config.head_dim
)));
}
let (kb, kh, _ksn, kd) = attn_dims(k)?;
if kb != batch || kh != heads || kd != head_dim {
return Err(DnnError::InvalidDimension(format!(
"K dims {:?} incompatible with Q dims {:?}",
k.dims, q.dims
)));
}
let (vb, vh, vsn, _vd) = attn_dims(v)?;
if vb != batch || vh != heads || vsn != k.dims[2] {
return Err(DnnError::InvalidDimension(format!(
"V dims {:?} incompatible with K dims {:?}",
v.dims, k.dims
)));
}
let (ob, oh, osn, _od) = attn_dims_mut(output)?;
if ob != batch || oh != heads || osn != seq_q {
return Err(DnnError::InvalidDimension(format!(
"output dims {:?} incompatible with Q dims {:?}",
output.dims, q.dims
)));
}
let lse_required = batch as usize * heads as usize * seq_q as usize;
if logsumexp.len() < lse_required {
return Err(DnnError::BufferTooSmall {
expected: lse_required * 4,
actual: logsumexp.len() * 4,
});
}
Ok(())
}
fn ptx_type_suffix(ty: PtxType) -> &'static str {
match ty {
PtxType::F32 => "f32",
PtxType::F64 => "f64",
PtxType::F16 => "f16",
PtxType::BF16 => "bf16",
_ => "unknown",
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn auto_config_defaults() {
let cfg = FlashAttentionConfig::auto(64, 512, 512, false, SmVersion::Sm80);
assert_eq!(cfg.block_m, 128);
assert_eq!(cfg.block_n, 128);
assert_eq!(cfg.num_warps, 4);
assert_eq!(cfg.num_stages, 2);
assert_eq!(cfg.head_dim, 64);
}
#[test]
fn auto_config_large_head_dim() {
let cfg = FlashAttentionConfig::auto(256, 1024, 1024, true, SmVersion::Sm90);
assert_eq!(cfg.block_m, 64);
assert_eq!(cfg.block_n, 64);
assert!(cfg.causal);
}
#[test]
fn auto_config_hopper() {
let cfg = FlashAttentionConfig::auto(128, 2048, 2048, false, SmVersion::Sm90);
assert_eq!(cfg.num_warps, 8);
assert_eq!(cfg.num_stages, 3);
}
#[test]
fn shared_mem_calculation() {
let cfg = FlashAttentionConfig::auto(64, 512, 512, false, SmVersion::Sm80);
let smem = cfg.shared_mem_bytes();
assert!(smem > 0);
}
#[test]
fn num_tiles() {
let cfg = FlashAttentionConfig::auto(64, 512, 1024, false, SmVersion::Sm80);
assert_eq!(cfg.num_q_tiles(), 4); assert_eq!(cfg.num_kv_tiles(), 8); }
#[test]
fn generate_ptx_succeeds() {
let mut cfg = FlashAttentionConfig::auto(64, 128, 128, false, SmVersion::Sm80);
cfg.num_heads = 8;
let ptx = cfg.generate_ptx();
assert!(ptx.is_ok());
let text = ptx.ok().unwrap_or_default();
assert!(text.contains("flash_attn_fwd"));
assert!(text.contains(".shared"));
}
#[test]
fn generate_causal_ptx_succeeds() {
let mut cfg = FlashAttentionConfig::auto(128, 256, 256, true, SmVersion::Sm80);
cfg.num_heads = 4;
let ptx = cfg.generate_ptx();
assert!(ptx.is_ok());
let text = ptx.ok().unwrap_or_default();
assert!(text.contains("CAUSAL"));
}
#[test]
fn test_flash_attn_tile_selection_128x64_for_head_dim_128() {
let cfg = FlashAttentionConfig::auto(128, 2048, 2048, false, SmVersion::Sm80);
assert_eq!(
cfg.block_m, 128,
"Br (block_m) should be 128 for head_dim=128 on SM80"
);
assert_eq!(
cfg.block_n, 64,
"Bc (block_n) should be 64 for head_dim=128 on SM80"
);
}
#[test]
fn test_flash_attn_tile_128x128_for_head_dim_64() {
let cfg = FlashAttentionConfig::auto(64, 2048, 2048, false, SmVersion::Sm80);
assert_eq!(
cfg.block_m, 128,
"Br (block_m) should be 128 for head_dim=64"
);
assert_eq!(
cfg.block_n, 128,
"Bc (block_n) should be 128 for head_dim=64"
);
}
#[test]
fn test_flash_attn_4_warps_sm80() {
let cfg = FlashAttentionConfig::auto(128, 2048, 2048, false, SmVersion::Sm80);
assert_eq!(
cfg.num_warps, 4,
"FlashAttention-2 on SM80 should use 4 warps"
);
assert_eq!(
cfg.num_warps * 32,
128,
"block_size should be 128 (4 warps × 32 threads)"
);
}
#[test]
fn test_flash_attn_causal_mask_math() {
for q in 0u32..8 {
for k in 0u32..8 {
let mask: f32 = if k > q { f32::NEG_INFINITY } else { 0.0 };
if k > q {
assert!(
mask.is_infinite() && mask < 0.0,
"causal mask at (q={q}, k={k}) should be -inf, got {mask}"
);
} else {
assert_eq!(mask, 0.0, "causal mask at (q={q}, k={k}) should be 0.0");
}
}
}
}
#[test]
fn test_flash_attn_online_softmax_update() {
let scores = [0.0f32, 1.0, 2.0, 3.0];
let m_new = scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
assert_eq!(m_new, 3.0, "row max should be 3.0");
let l_new: f32 = scores.iter().map(|&s| (s - m_new).exp()).sum();
assert!(
(l_new - 1.553).abs() < 1e-3,
"softmax denominator should be ≈1.553, got {l_new}"
);
let probs: Vec<f32> = scores.iter().map(|&s| (s - m_new).exp() / l_new).collect();
let prob_sum: f32 = probs.iter().sum();
assert!(
(prob_sum - 1.0).abs() < 1e-5,
"softmax probabilities must sum to 1.0, got {prob_sum}"
);
}
#[test]
fn test_flash_attn_block_rescale_correctness() {
let tile1 = [0.0f32, 1.0];
let m1 = tile1.iter().copied().fold(f32::NEG_INFINITY, f32::max);
let l1: f32 = tile1.iter().map(|&s| (s - m1).exp()).sum();
let tile2 = [2.0f32, 3.0];
let m2_new = tile2.iter().copied().fold(f32::NEG_INFINITY, f32::max);
let m_new = f32::max(m1, m2_new);
let correction = (m1 - m_new).exp();
let l2: f32 = tile2.iter().map(|&s| (s - m_new).exp()).sum();
let l_new = correction * l1 + l2;
let all_scores = [0.0f32, 1.0, 2.0, 3.0];
let m_all = all_scores.iter().copied().fold(f32::NEG_INFINITY, f32::max);
let l_all: f32 = all_scores.iter().map(|&s| (s - m_all).exp()).sum();
assert_eq!(m_new, m_all, "max should be the same");
assert!(
(l_new - l_all).abs() < 1e-5,
"online accumulation: l_new={l_new} should equal l_all={l_all}"
);
}
}