use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use once_cell::sync::Lazy;
/// GPU architectures for which FlashAttention kernels can be compiled.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SmArch {
    /// `sm_80`.
    Sm80,
    /// `sm_89`.
    Sm89,
    /// `sm_90a` (architecture-specific profile).
    Sm90a,
    /// Compiled for the architecture-specific `sm_100a` profile.
    Sm100,
}
impl SmArch {
    /// NVRTC `--gpu-architecture` option selecting this architecture.
    ///
    /// `Sm100` deliberately maps to `sm_100a`, not plain `sm_100`.
    pub fn nvrtc_flag(self) -> &'static str {
        match self {
            Self::Sm80 => "--gpu-architecture=sm_80",
            Self::Sm89 => "--gpu-architecture=sm_89",
            Self::Sm90a => "--gpu-architecture=sm_90a",
            Self::Sm100 => "--gpu-architecture=sm_100a",
        }
    }

    /// Whether the FA3 kernel family (rather than FA2) is used on this architecture.
    pub fn supports_fa3(self) -> bool {
        // Exhaustive on purpose: adding a variant must force a decision here.
        match self {
            Self::Sm90a | Self::Sm100 => true,
            Self::Sm80 | Self::Sm89 => false,
        }
    }

    /// Whether FP8 kernels are available on this architecture.
    pub fn supports_fp8(self) -> bool {
        match self {
            Self::Sm90a | Self::Sm100 => true,
            Self::Sm80 | Self::Sm89 => false,
        }
    }
}
/// Runtime element type of attention operands.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DType {
    /// 16-bit IEEE half precision.
    F16,
    /// 16-bit bfloat.
    Bf16,
    /// 8-bit float, E4M3 encoding.
    F8E4m3,
    /// 8-bit float, E5M2 encoding.
    F8E5m2,
}
impl DType {
    /// Storage size of a single element, in bytes.
    pub fn size_in_bytes(self) -> usize {
        match self {
            Self::F8E4m3 | Self::F8E5m2 => 1,
            Self::F16 | Self::Bf16 => 2,
        }
    }

    /// `true` for the two 8-bit float formats.
    pub fn is_fp8(self) -> bool {
        matches!(self, Self::F8E4m3 | Self::F8E5m2)
    }

    /// Short lowercase tag embedded in generated kernel names.
    pub fn tag(self) -> &'static str {
        match self {
            Self::F16 => "f16",
            Self::Bf16 => "bf16",
            Self::F8E4m3 => "e4m3",
            Self::F8E5m2 => "e5m2",
        }
    }
}
pub trait GemmSupported: Send + Sync + 'static {
fn dtype() -> DType;
}
#[derive(Debug, Clone, Copy)]
pub struct F16;
impl GemmSupported for F16 {
fn dtype() -> DType {
DType::F16
}
}
#[derive(Debug, Clone, Copy)]
pub struct Bf16;
impl GemmSupported for Bf16 {
fn dtype() -> DType {
DType::Bf16
}
}
#[cfg(feature = "fp8")]
#[derive(Debug, Clone, Copy)]
pub struct F8E4m3;
#[cfg(feature = "fp8")]
impl GemmSupported for F8E4m3 {
fn dtype() -> DType {
DType::F8E4m3
}
}
#[cfg(feature = "fp8")]
#[derive(Debug, Clone, Copy)]
pub struct F8E5m2;
#[cfg(feature = "fp8")]
impl GemmSupported for F8E5m2 {
fn dtype() -> DType {
DType::F8E5m2
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct DispatchKey {
pub arch: SmArch,
pub dtype: DType,
pub head_dim: u32,
pub causal: bool,
pub varlen: bool,
pub sliding_window: Option<u32>,
pub alibi: bool,
pub sink: u32,
pub paged: bool,
pub gqa_ratio: u32,
}
impl DispatchKey {
pub fn validate_fwd(&self) -> Result<(), DispatchError> {
const ALLOWED: &[u32] = &[64, 80, 96, 128, 192, 256];
if !ALLOWED.contains(&self.head_dim) {
return Err(DispatchError::UnsupportedHeadDim(self.head_dim));
}
if self.dtype.is_fp8() && !self.arch.supports_fp8() {
return Err(DispatchError::Fp8RequiresHopper(self.arch));
}
if self.sink > 0 && self.sliding_window.is_none() && !self.causal {
return Err(DispatchError::SinkWithoutMask);
}
if self.gqa_ratio == 0 {
return Err(DispatchError::InvalidGqaRatio(self.gqa_ratio));
}
if let Some(w) = self.sliding_window {
if w == 0 {
return Err(DispatchError::ZeroWindow);
}
}
Ok(())
}
pub fn validate_bwd(&self) -> Result<(), DispatchError> {
self.validate_fwd()?;
if self.dtype.is_fp8() {
return Err(DispatchError::Fp8BackwardUnsupported);
}
Ok(())
}
pub fn validate_paged(&self) -> Result<(), DispatchError> {
self.validate_fwd()?;
if !self.paged {
return Err(DispatchError::PagedFlagNotSet);
}
Ok(())
}
pub fn stable_hash(&self) -> u64 {
let mut h = std::collections::hash_map::DefaultHasher::new();
self.hash(&mut h);
h.finish()
}
pub fn kernel_name(&self) -> String {
let kind = if self.arch.supports_fa3() {
"fa3"
} else {
"fa2"
};
let mut s = format!(
"atomr_flashattn::{}::fwd<{}, {}, {}>",
kind,
self.dtype.tag(),
self.head_dim,
self.causal_tag(),
);
if self.varlen {
s.push_str("_varlen");
}
if let Some(w) = self.sliding_window {
s.push_str(&format!("_sw{w}"));
}
if self.alibi {
s.push_str("_alibi");
}
if self.sink > 0 {
s.push_str(&format!("_sink{}", self.sink));
}
if self.paged {
s.push_str("_paged");
}
if self.gqa_ratio > 1 {
s.push_str(&format!("_gqa{}", self.gqa_ratio));
}
s
}
fn causal_tag(&self) -> &'static str {
if self.causal {
"causal"
} else {
"full"
}
}
}
/// Reasons a [`DispatchKey`] can be rejected or fail to resolve to a kernel.
#[derive(Debug, Clone, thiserror::Error)]
pub enum DispatchError {
    /// `head_dim` is outside the whitelist checked by `DispatchKey::validate_fwd`.
    /// The list in the message must be kept in sync with that whitelist by hand.
    #[error("head_dim {0} is not in the FA whitelist (64, 80, 96, 128, 192, 256)")]
    UnsupportedHeadDim(u32),
    /// An FP8 dtype was requested on an architecture whose `supports_fp8()` is false.
    #[error("fp8 requires sm_90a or newer, got {0:?}")]
    Fp8RequiresHopper(SmArch),
    /// Returned by `DispatchKey::validate_bwd` for any FP8 dtype.
    #[error("fp8 backward is not supported in FA3")]
    Fp8BackwardUnsupported,
    /// `sink > 0` with neither a sliding window nor causal masking.
    #[error("sink tokens require either sliding_window or causal")]
    SinkWithoutMask,
    /// `gqa_ratio` was 0.
    #[error("invalid GQA ratio {0} (must be >= 1)")]
    InvalidGqaRatio(u32),
    /// `sliding_window` was `Some(0)`.
    #[error("sliding window must be > 0")]
    ZeroWindow,
    /// Returned by `DispatchKey::validate_paged` when `paged` is `false`.
    #[error("paged path requires DispatchKey::paged = true")]
    PagedFlagNotSet,
    /// `DispatchTable::strict_lookup` found no registered kernel for the key.
    /// The key is boxed to keep the error type small.
    #[error("no kernel registered for key {0:?}")]
    UnknownKey(Box<DispatchKey>),
}
/// Implemented by types that can describe the forward kernel they need.
/// NOTE(review): presumably their keys are checked with `DispatchKey::validate_fwd` — confirm.
pub trait FaFwdDispatch: Send + 'static {
    /// Key used to select the forward kernel.
    fn dispatch_key(&self) -> DispatchKey;
}
/// Implemented by types that can describe the backward kernel they need.
/// NOTE(review): presumably their keys are checked with `DispatchKey::validate_bwd` — confirm.
pub trait FaBwdDispatch: Send + 'static {
    /// Key used to select the backward kernel.
    fn dispatch_key(&self) -> DispatchKey;
}
/// Implemented by types that can describe the paged forward kernel they need.
/// NOTE(review): presumably their keys are checked with `DispatchKey::validate_paged` — confirm.
pub trait FaPagedFwdDispatch: Send + 'static {
    /// Key used to select the paged forward kernel.
    fn dispatch_key(&self) -> DispatchKey;
}
pub struct DispatchTable {
entries: HashMap<DispatchKey, String>,
}
impl DispatchTable {
fn build() -> Self {
let mut entries: HashMap<DispatchKey, String> = HashMap::new();
for &arch in &[SmArch::Sm80, SmArch::Sm89, SmArch::Sm90a, SmArch::Sm100] {
for &dtype in &[DType::F16, DType::Bf16] {
for &head_dim in &[64u32, 80, 96, 128, 192, 256] {
for &causal in &[false, true] {
let key = DispatchKey {
arch,
dtype,
head_dim,
causal,
varlen: false,
sliding_window: None,
alibi: false,
sink: 0,
paged: false,
gqa_ratio: 1,
};
if key.validate_fwd().is_ok() {
entries.insert(key, key.kernel_name());
}
}
}
}
}
#[cfg(feature = "fp8")]
for &dtype in &[DType::F8E4m3, DType::F8E5m2] {
for &head_dim in &[64u32, 128, 256] {
for &arch in &[SmArch::Sm90a, SmArch::Sm100] {
for &causal in &[false, true] {
let key = DispatchKey {
arch,
dtype,
head_dim,
causal,
varlen: false,
sliding_window: None,
alibi: false,
sink: 0,
paged: false,
gqa_ratio: 1,
};
if key.validate_fwd().is_ok() {
entries.insert(key, key.kernel_name());
}
}
}
}
}
Self { entries }
}
pub fn lookup(&self, key: &DispatchKey) -> Result<String, DispatchError> {
key.validate_fwd()?;
if let Some(name) = self.entries.get(key) {
return Ok(name.clone());
}
Ok(key.kernel_name())
}
pub fn strict_lookup(&self, key: &DispatchKey) -> Result<&str, DispatchError> {
self.entries
.get(key)
.map(String::as_str)
.ok_or_else(|| DispatchError::UnknownKey(Box::new(*key)))
}
pub fn len(&self) -> usize {
self.entries.len()
}
pub fn is_empty(&self) -> bool {
self.entries.is_empty()
}
}
/// Process-wide dispatch table, built by `DispatchTable::build` on first access.
pub static DISPATCH_TABLE: Lazy<DispatchTable> = Lazy::new(DispatchTable::build);

/// Resolves `key` against [`DISPATCH_TABLE`] using [`DispatchTable::lookup`].
///
/// Returns the registered kernel name on a hit and a synthesized name on a miss.
///
/// # Errors
///
/// Propagates the validation error from [`DispatchKey::validate_fwd`].
pub fn lookup(key: &DispatchKey) -> Result<String, DispatchError> {
    let table: &DispatchTable = &DISPATCH_TABLE;
    table.lookup(key)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Baseline forward key: no optional features, GQA ratio 1.
    fn fwd_key(arch: SmArch, dtype: DType, head_dim: u32, causal: bool) -> DispatchKey {
        DispatchKey {
            arch,
            dtype,
            head_dim,
            causal,
            varlen: false,
            sliding_window: None,
            alibi: false,
            sink: 0,
            paged: false,
            gqa_ratio: 1,
        }
    }

    #[test]
    fn dispatch_key_round_trip() {
        let arches = [SmArch::Sm80, SmArch::Sm89, SmArch::Sm90a, SmArch::Sm100];
        let dtypes = [DType::F16, DType::Bf16];
        let head_dims = [64u32, 80, 96, 128, 192, 256];
        for &arch in &arches {
            for &dtype in &dtypes {
                for &head_dim in &head_dims {
                    for &causal in &[false, true] {
                        let key = fwd_key(arch, dtype, head_dim, causal);
                        assert!(key.validate_fwd().is_ok());
                        let key2 = fwd_key(arch, dtype, head_dim, causal);
                        assert_eq!(key.stable_hash(), key2.stable_hash());
                        assert_eq!(key.kernel_name(), key2.kernel_name());
                        let name = lookup(&key).expect("lookup");
                        assert!(name.contains(dtype.tag()));
                        assert!(name.contains(&head_dim.to_string()));
                    }
                }
            }
        }
        let a = fwd_key(SmArch::Sm90a, DType::F16, 128, true);
        let b = fwd_key(SmArch::Sm90a, DType::F16, 128, false);
        assert_ne!(a.stable_hash(), b.stable_hash());
        assert_ne!(a.kernel_name(), b.kernel_name());
    }

    #[test]
    fn lookup_misses_unknown_key() {
        let key = DispatchKey {
            varlen: true,
            sliding_window: Some(4096),
            alibi: true,
            sink: 4,
            gqa_ratio: 8,
            ..fwd_key(SmArch::Sm90a, DType::Bf16, 128, true)
        };
        assert!(key.validate_fwd().is_ok());
        let strict = DISPATCH_TABLE.strict_lookup(&key);
        assert!(matches!(strict, Err(DispatchError::UnknownKey(_))));
        let name = lookup(&key).expect("soft lookup synthesises a name");
        assert!(name.contains("varlen"));
        assert!(name.contains("alibi"));
        assert!(name.contains("sink4"));
        assert!(name.contains("sw4096"));
        assert!(name.contains("gqa8"));
    }

    #[test]
    fn strict_lookup_hits_baseline_key() {
        let key = fwd_key(SmArch::Sm80, DType::F16, 128, false);
        assert_eq!(
            DISPATCH_TABLE.strict_lookup(&key).expect("baseline key is registered"),
            "atomr_flashattn::fa2::fwd<f16, 128, full>"
        );
        assert_eq!(lookup(&key).expect("lookup"), key.kernel_name());
    }

    #[test]
    fn table_covers_expected_grid() {
        let fp8_entries = if cfg!(feature = "fp8") { 2 * 3 * 2 * 2 } else { 0 };
        assert_eq!(DISPATCH_TABLE.len(), 4 * 2 * 6 * 2 + fp8_entries);
        assert!(!DISPATCH_TABLE.is_empty());
    }

    #[test]
    fn stable_hashes_distinct_across_table() {
        let hashes: HashSet<u64> = DISPATCH_TABLE
            .entries
            .keys()
            .map(DispatchKey::stable_hash)
            .collect();
        assert_eq!(hashes.len(), DISPATCH_TABLE.len());
    }

    #[test]
    fn fp8_requires_hopper() {
        let mut key = fwd_key(SmArch::Sm80, DType::F8E4m3, 128, true);
        assert!(matches!(
            key.validate_fwd(),
            Err(DispatchError::Fp8RequiresHopper(_))
        ));
        key.arch = SmArch::Sm90a;
        assert!(key.validate_fwd().is_ok());
    }

    #[test]
    fn fp8_backward_rejected() {
        let key = fwd_key(SmArch::Sm90a, DType::F8E5m2, 128, true);
        assert!(key.validate_fwd().is_ok());
        assert!(matches!(
            key.validate_bwd(),
            Err(DispatchError::Fp8BackwardUnsupported)
        ));
        assert!(fwd_key(SmArch::Sm90a, DType::Bf16, 128, true)
            .validate_bwd()
            .is_ok());
    }

    #[test]
    fn paged_requires_flag() {
        let key = fwd_key(SmArch::Sm90a, DType::Bf16, 128, true);
        assert!(matches!(
            key.validate_paged(),
            Err(DispatchError::PagedFlagNotSet)
        ));
        let paged = DispatchKey { paged: true, ..key };
        assert!(paged.validate_paged().is_ok());
        assert!(paged.kernel_name().ends_with("_paged"));
    }

    #[test]
    fn unsupported_head_dim_rejected() {
        let key = fwd_key(SmArch::Sm90a, DType::F16, 100, false);
        assert!(matches!(
            key.validate_fwd(),
            Err(DispatchError::UnsupportedHeadDim(100))
        ));
    }

    #[test]
    fn sink_without_mask_rejected() {
        let key = DispatchKey {
            sink: 4,
            ..fwd_key(SmArch::Sm90a, DType::Bf16, 128, false)
        };
        assert!(matches!(
            key.validate_fwd(),
            Err(DispatchError::SinkWithoutMask)
        ));
    }

    #[test]
    fn zero_gqa_ratio_rejected() {
        let key = DispatchKey {
            gqa_ratio: 0,
            ..fwd_key(SmArch::Sm80, DType::F16, 64, false)
        };
        assert!(matches!(
            key.validate_fwd(),
            Err(DispatchError::InvalidGqaRatio(0))
        ));
    }

    #[test]
    fn zero_window_rejected() {
        let key = DispatchKey {
            sliding_window: Some(0),
            ..fwd_key(SmArch::Sm90a, DType::F16, 128, true)
        };
        assert!(matches!(key.validate_fwd(), Err(DispatchError::ZeroWindow)));
    }
}