#![allow(clippy::similar_names)]
mod standard;
mod tensor_core;
use crate::kernels::Kernel;
use crate::ptx::PtxKernel;
#[derive(Debug, Clone)]
pub struct AttentionKernel {
pub seq_len: u32,
pub head_dim: u32,
pub tile_q: u32,
pub tile_kv: u32,
pub scale: f32,
pub causal: bool,
pub use_tensor_cores: bool,
}
impl AttentionKernel {
    /// Shared constructor body for [`Self::new`] and [`Self::tensor_core`].
    ///
    /// Both public constructors only differ in how `tile_q` is derived and in
    /// the `use_tensor_cores` flag; everything else (tile_kv invariant,
    /// default scale, non-causal default) is identical, so it lives here.
    fn from_tile_q(seq_len: u32, head_dim: u32, tile_q: u32, use_tensor_cores: bool) -> Self {
        Self {
            seq_len,
            head_dim,
            tile_q,
            // Invariant shared with `with_tiles`: the KV tile must cover at
            // least one full head dimension.
            tile_kv: tile_q.max(head_dim),
            // Standard attention scaling: 1 / sqrt(d_k).
            scale: 1.0 / (head_dim as f32).sqrt(),
            causal: false,
            use_tensor_cores,
        }
    }

    /// Creates a standard (non-tensor-core) attention kernel configuration.
    ///
    /// `tile_q` is `seq_len` capped at 64; `tile_kv` is the same value raised
    /// to at least `head_dim`. The softmax scale defaults to
    /// `1 / sqrt(head_dim)` and the causal mask is off.
    #[must_use]
    pub fn new(seq_len: u32, head_dim: u32) -> Self {
        Self::from_tile_q(seq_len, head_dim, seq_len.min(64), false)
    }

    /// Creates a tensor-core attention kernel configuration.
    ///
    /// Like [`Self::new`], but `tile_q` is additionally clamped to a minimum
    /// of 16 — presumably the tensor-core fragment size (TODO confirm against
    /// the tensor_core module) — and `use_tensor_cores` is set.
    #[must_use]
    pub fn tensor_core(seq_len: u32, head_dim: u32) -> Self {
        Self::from_tile_q(seq_len, head_dim, seq_len.clamp(16, 64), true)
    }

    /// Overrides the tile sizes; `tile_kv` is raised to at least `head_dim`
    /// to preserve the constructor invariant.
    #[must_use]
    pub const fn with_tiles(mut self, tile_q: u32, tile_kv: u32) -> Self {
        self.tile_q = tile_q;
        // `u32::max` is not usable in `const fn` context here, hence the
        // explicit comparison.
        self.tile_kv = if tile_kv >= self.head_dim { tile_kv } else { self.head_dim };
        self
    }

    /// Enables the causal (lower-triangular) attention mask.
    #[must_use]
    pub const fn with_causal(mut self) -> Self {
        self.causal = true;
        self
    }

    /// Overrides the softmax scaling factor.
    #[must_use]
    pub const fn with_scale(mut self, scale: f32) -> Self {
        self.scale = scale;
        self
    }

    /// Switches code generation to the tensor-core kernel variant.
    #[must_use]
    pub const fn with_tensor_cores(mut self) -> Self {
        self.use_tensor_cores = true;
        self
    }
}
impl Kernel for AttentionKernel {
    /// Returns the kernel's name, encoding the tensor-core and causal-mask
    /// variants selected on this configuration.
    fn name(&self) -> &str {
        if self.use_tensor_cores {
            if self.causal {
                "flash_attention_tensor_core_causal"
            } else {
                "flash_attention_tensor_core"
            }
        } else if self.causal {
            "flash_attention_causal"
        } else {
            "flash_attention"
        }
    }

    /// Emits PTX for this configuration, dispatching to the tensor-core or
    /// standard flash-attention code generator.
    fn build_ptx(&self) -> PtxKernel {
        if self.use_tensor_cores {
            return self.build_tensor_core_attention();
        }
        self.build_flash_attention()
    }
}