mod build_ptx;
/// Parameter set describing a multi-warp incremental attention kernel.
///
/// Every field is public and is stored verbatim from the constructor
/// arguments, except `scale` and `indirect_seq_len`, which
/// [`MultiWarpIncrementalAttentionKernel::new`] derives or defaults.
///
/// NOTE(review): the per-field descriptions below that go beyond "stored as
/// given" are inferred from the field names; confirm against the code that
/// consumes this struct.
#[derive(Debug, Clone)]
pub struct MultiWarpIncrementalAttentionKernel {
    /// Stored as given to `new`; presumably the longest supported sequence.
    pub max_seq_len: u32,
    /// Stored as given to `new`; also the input from which `scale` is derived.
    pub head_dim: u32,
    /// Stored as given to `new`; presumably the number of query heads.
    pub num_heads: u32,
    /// Stored as given to `new`; presumably the number of key/value heads.
    pub num_kv_heads: u32,
    /// Receives the `num_warps` argument of `new`.
    pub num_warps_per_head: u32,
    /// Scaling factor; `new` initialises it to `1 / sqrt(head_dim)`.
    pub scale: f32,
    /// Flag that is `false` after `new` and set through
    /// [`MultiWarpIncrementalAttentionKernel::with_indirect_seq_len`].
    pub indirect_seq_len: bool,
}

impl MultiWarpIncrementalAttentionKernel {
    /// Creates a kernel description from the supplied dimensions.
    ///
    /// * `max_seq_len`, `head_dim`, `num_heads`, `num_kv_heads` are copied
    ///   into the fields of the same name.
    /// * `num_warps` is stored as `num_warps_per_head`.
    ///
    /// `scale` is set to the reciprocal of the square root of `head_dim`, and
    /// `indirect_seq_len` starts out `false`.
    ///
    /// No argument is validated here; in particular a `head_dim` of `0`
    /// produces an infinite `scale`.
    #[must_use]
    pub fn new(
        max_seq_len: u32,
        head_dim: u32,
        num_heads: u32,
        num_kv_heads: u32,
        num_warps: u32,
    ) -> Self {
        // `recip()` is exactly `1.0 / x`, so this is the same value as the
        // spelled-out division.
        let inv_sqrt_head_dim = (head_dim as f32).sqrt().recip();

        Self {
            indirect_seq_len: false,
            scale: inv_sqrt_head_dim,
            num_warps_per_head: num_warps,
            num_kv_heads,
            num_heads,
            head_dim,
            max_seq_len,
        }
    }

    /// Returns the same configuration with `indirect_seq_len` replaced by
    /// `indirect`; every other field is carried over unchanged.
    #[must_use]
    pub fn with_indirect_seq_len(self, indirect: bool) -> Self {
        Self {
            indirect_seq_len: indirect,
            ..self
        }
    }
}