pub mod cpu_flash;
pub mod varlen;
use hanzo_ml::Tensor;
pub use cpu_flash::flash_attn;
pub use cpu_flash::varlen::flash_attn_varlen_cpu;
pub use varlen::flash_attn_varlen_unfused;
#[derive(Debug, Clone, Default)]
pub enum AttnMask {
#[default]
None,
Causal {
kv_offset: usize,
},
Mask(Tensor),
}
impl AttnMask {
#[inline]
pub fn causal() -> Self {
AttnMask::Causal { kv_offset: 0 }
}
#[inline]
pub fn causal_with_offset(kv_offset: usize) -> Self {
AttnMask::Causal { kv_offset }
}
#[inline]
pub fn is_causal(&self) -> bool {
matches!(self, AttnMask::Causal { .. })
}
#[inline]
pub fn kv_offset(&self) -> usize {
match self {
AttnMask::Causal { kv_offset } => *kv_offset,
_ => 0,
}
}
}