atomr_accel_flashattn/dispatch.rs
1//! Dispatch table — maps a `(arch, dtype, head_dim, …)` cell onto a
2//! mangled kernel name expression.
3//!
4//! The Phase 7 FlashAttention crate ships forward + backward paths for
5//! v2 (sm_80 / sm_89) and v3 (sm_90a, including the fp8 e4m3 / e5m2
6//! variants). Every kernel is NVRTC-compiled lazily through the Phase
7//! 0.6 disk cache; the dispatch table is the *only* place that knows
8//! the canonical mangled symbol — every request type ([`crate::fa2`],
9//! [`crate::fa3`], [`crate::paged`], [`crate::prefill`], [`crate::varlen`])
10//! produces a [`DispatchKey`] that hashes to the same string.
11//!
12//! Hot path:
13//!
14//! 1. Caller constructs a request (e.g. [`crate::fa2::Fa2FwdRequest`]).
15//! 2. [`FaFwdDispatch::dispatch_key`] yields a [`DispatchKey`].
16//! 3. [`DispatchTable::lookup`] resolves the key to a kernel name.
17//! 4. The actor asks `NvrtcActor` to compile-or-fetch by name.
18//! 5. The cubin is launched on the actor's stream.
19//!
//! Steps 4–5 are GPU-only and gated behind `cuda-runtime-tests`; key
//! construction and name resolution (1–3) are pure host code and are
//! exercised by the unit tests below and from each request-type
//! module's `tests` block.
23
24use std::collections::HashMap;
25use std::hash::{Hash, Hasher};
26
27use once_cell::sync::Lazy;
28
/// GPU streaming-multiprocessor generation a kernel is compiled for.
///
/// A [`DispatchKey`] can only name one of the architectures listed here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SmArch {
    /// Ampere (A100, A30): FlashAttention v2 only.
    Sm80,
    /// Ada (RTX 40xx, L4): FlashAttention v2 only. fp8 is available
    /// through cuBLASLt on this part, but FlashAttention v3 is not.
    Sm89,
    /// Hopper (H100, H200): FlashAttention v3, fp8, TMA, WGMMA and
    /// persistent kernels.
    Sm90a,
    /// Blackwell (B100, B200): forward-compatibility target with
    /// fifth-generation tensor cores. Uses the Hopper FA3 kernels for now.
    Sm100,
}

impl SmArch {
    /// Returns the `--gpu-architecture=…` option handed to NVRTC for
    /// this target.
    pub fn nvrtc_flag(self) -> &'static str {
        match self {
            Self::Sm80 => "--gpu-architecture=sm_80",
            Self::Sm89 => "--gpu-architecture=sm_89",
            Self::Sm90a => "--gpu-architecture=sm_90a",
            Self::Sm100 => "--gpu-architecture=sm_100a",
        }
    }

    /// Whether FlashAttention v3 kernels are available (Hopper and newer).
    pub fn supports_fa3(self) -> bool {
        // Spelled out per variant so that adding an architecture forces a
        // decision here instead of silently defaulting to `false`.
        match self {
            Self::Sm80 | Self::Sm89 => false,
            Self::Sm90a | Self::Sm100 => true,
        }
    }

    /// Whether the fp8 (e4m3 / e5m2) FA3 variants are available.
    pub fn supports_fp8(self) -> bool {
        match self {
            Self::Sm80 | Self::Sm89 => false,
            Self::Sm90a | Self::Sm100 => true,
        }
    }
}
65
/// Element type of the Q / K / V tiles.
///
/// Kept separate from `atomr-accel-cuda`'s future `CudaDtype` so that the
/// FlashAttention crate stays self-contained.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DType {
    /// IEEE 754 binary16; available in fa2 and fa3.
    F16,
    /// bfloat16; available in fa2 and fa3.
    Bf16,
    /// 8-bit float, e4m3 layout; fa3 only, sm_90a and newer.
    F8E4m3,
    /// 8-bit float, e5m2 layout; fa3 only, sm_90a and newer (used for V
    /// in DPA-mixed-precision).
    F8E5m2,
}

impl DType {
    /// Width of a single element, in bytes.
    pub fn size_in_bytes(self) -> usize {
        match self {
            Self::F8E4m3 | Self::F8E5m2 => 1,
            Self::F16 | Self::Bf16 => 2,
        }
    }

    /// Returns `true` exactly for the two fp8 variants.
    pub fn is_fp8(self) -> bool {
        match self {
            Self::F8E4m3 | Self::F8E5m2 => true,
            Self::F16 | Self::Bf16 => false,
        }
    }

    /// Short identifier embedded in the mangled kernel name.
    pub fn tag(self) -> &'static str {
        match self {
            Self::F16 => "f16",
            Self::Bf16 => "bf16",
            Self::F8E4m3 => "e4m3",
            Self::F8E5m2 => "e5m2",
        }
    }
}
104
105/// Marker trait for dtypes that can drive a FlashAttention GEMM. Implemented
106/// by the same set of zero-sized types that the rest of `atomr-accel`
107/// uses to phantom-tag GEMM-supported dtypes. The trait itself carries
108/// no methods so it can be referenced from [`crate::fa2`] / [`crate::fa3`]
109/// without requiring callers to depend on `atomr-accel-cuda` directly.
110pub trait GemmSupported: Send + Sync + 'static {
111 /// The runtime dtype tag this marker maps onto.
112 fn dtype() -> DType;
113}
114
115/// Zero-sized marker for `f16` (IEEE binary16).
116#[derive(Debug, Clone, Copy)]
117pub struct F16;
118impl GemmSupported for F16 {
119 fn dtype() -> DType {
120 DType::F16
121 }
122}
123
124/// Zero-sized marker for `bf16` (bfloat16).
125#[derive(Debug, Clone, Copy)]
126pub struct Bf16;
127impl GemmSupported for Bf16 {
128 fn dtype() -> DType {
129 DType::Bf16
130 }
131}
132
/// Type-level tag for fp8 e4m3 elements; carries no data.
///
/// Only compiled when the `fp8` feature is enabled.
#[cfg(feature = "fp8")]
#[derive(Debug, Clone, Copy)]
pub struct F8E4m3;

#[cfg(feature = "fp8")]
impl GemmSupported for F8E4m3 {
    /// Maps the tag onto [`DType::F8E4m3`].
    fn dtype() -> DType {
        DType::F8E4m3
    }
}
143
/// Type-level tag for fp8 e5m2 elements; carries no data.
///
/// Only compiled when the `fp8` feature is enabled.
#[cfg(feature = "fp8")]
#[derive(Debug, Clone, Copy)]
pub struct F8E5m2;

#[cfg(feature = "fp8")]
impl GemmSupported for F8E5m2 {
    /// Maps the tag onto [`DType::F8E5m2`].
    fn dtype() -> DType {
        DType::F8E5m2
    }
}
154
155/// Cell key for the FlashAttention dispatch table.
156///
157/// Every field directly affects the generated CUDA C++ template
158/// instantiation — flipping any one of them changes the resulting
159/// cubin. The table refuses to resolve unsupported combinations
160/// (e.g. `fp8` on `Sm80`, head_dim > 256).
161#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
162pub struct DispatchKey {
163 /// Target SM architecture.
164 pub arch: SmArch,
165 /// Element type for Q/K/V.
166 pub dtype: DType,
167 /// Per-head dimension (D). Supported: 64, 80, 96, 128, 192, 256.
168 pub head_dim: u32,
169 /// Causal masking — autoregressive attention.
170 pub causal: bool,
171 /// Variable-length (cu_seqlens). When false, batched attention with
172 /// uniform seqlen.
173 pub varlen: bool,
174 /// Sliding-window size; `None` means full attention. Window size
175 /// is the number of past tokens each query attends to.
176 pub sliding_window: Option<u32>,
177 /// ALiBi linear-position biases.
178 pub alibi: bool,
179 /// Number of "sink" tokens (StreamingLLM); each query unconditionally
180 /// attends to the first `sink` keys regardless of `sliding_window`.
181 pub sink: u32,
182 /// vLLM-style paged KV-cache.
183 pub paged: bool,
184 /// Q heads per KV head. 1 = MHA, >1 = GQA, equal to num_heads = MQA.
185 pub gqa_ratio: u32,
186}
187
188impl DispatchKey {
189 /// Validate the cell for a *forward* path. Returns `Err` for
190 /// unreachable combinations.
191 pub fn validate_fwd(&self) -> Result<(), DispatchError> {
192 // Head-dim whitelist
193 const ALLOWED: &[u32] = &[64, 80, 96, 128, 192, 256];
194 if !ALLOWED.contains(&self.head_dim) {
195 return Err(DispatchError::UnsupportedHeadDim(self.head_dim));
196 }
197
198 // fp8 only on FA3-capable architectures
199 if self.dtype.is_fp8() && !self.arch.supports_fp8() {
200 return Err(DispatchError::Fp8RequiresHopper(self.arch));
201 }
202
203 // Sink tokens require sliding_window or causal — otherwise the
204 // mask is just full attention.
205 if self.sink > 0 && self.sliding_window.is_none() && !self.causal {
206 return Err(DispatchError::SinkWithoutMask);
207 }
208
209 // GQA ratio must be a power of two and at least 1.
210 if self.gqa_ratio == 0 {
211 return Err(DispatchError::InvalidGqaRatio(self.gqa_ratio));
212 }
213
214 // Sliding-window size must be > 0 when present.
215 if let Some(w) = self.sliding_window {
216 if w == 0 {
217 return Err(DispatchError::ZeroWindow);
218 }
219 }
220
221 Ok(())
222 }
223
224 /// Validate the cell for a *backward* path. Currently the same as
225 /// forward, but kept distinct so we can refuse e.g. fp8 backward
226 /// (numerically too lossy in the stock FA3) without affecting the
227 /// forward whitelist.
228 pub fn validate_bwd(&self) -> Result<(), DispatchError> {
229 self.validate_fwd()?;
230 if self.dtype.is_fp8() {
231 return Err(DispatchError::Fp8BackwardUnsupported);
232 }
233 Ok(())
234 }
235
236 /// Validate the cell for a *paged* forward path.
237 pub fn validate_paged(&self) -> Result<(), DispatchError> {
238 self.validate_fwd()?;
239 if !self.paged {
240 return Err(DispatchError::PagedFlagNotSet);
241 }
242 Ok(())
243 }
244
245 /// Stable 64-bit hash of the key. Useful as a cubin-cache index
246 /// alongside the kernel-name string.
247 pub fn stable_hash(&self) -> u64 {
248 let mut h = std::collections::hash_map::DefaultHasher::new();
249 self.hash(&mut h);
250 h.finish()
251 }
252
253 /// Build the canonical mangled kernel-name expression. Mirrors the
254 /// FA2/FA3 csrc naming convention so we can resolve it via NVRTC's
255 /// `nvrtcGetLoweredName`.
256 pub fn kernel_name(&self) -> String {
257 let kind = if self.arch.supports_fa3() {
258 "fa3"
259 } else {
260 "fa2"
261 };
262 let mut s = format!(
263 "atomr_flashattn::{}::fwd<{}, {}, {}>",
264 kind,
265 self.dtype.tag(),
266 self.head_dim,
267 self.causal_tag(),
268 );
269 if self.varlen {
270 s.push_str("_varlen");
271 }
272 if let Some(w) = self.sliding_window {
273 s.push_str(&format!("_sw{w}"));
274 }
275 if self.alibi {
276 s.push_str("_alibi");
277 }
278 if self.sink > 0 {
279 s.push_str(&format!("_sink{}", self.sink));
280 }
281 if self.paged {
282 s.push_str("_paged");
283 }
284 if self.gqa_ratio > 1 {
285 s.push_str(&format!("_gqa{}", self.gqa_ratio));
286 }
287 s
288 }
289
290 fn causal_tag(&self) -> &'static str {
291 if self.causal {
292 "causal"
293 } else {
294 "full"
295 }
296 }
297}
298
299/// Errors returned from [`DispatchKey::validate_fwd`] /
300/// [`DispatchTable::lookup`].
301#[derive(Debug, Clone, thiserror::Error)]
302pub enum DispatchError {
303 #[error("head_dim {0} is not in the FA whitelist (64, 80, 96, 128, 192, 256)")]
304 UnsupportedHeadDim(u32),
305 #[error("fp8 requires sm_90a or newer, got {0:?}")]
306 Fp8RequiresHopper(SmArch),
307 #[error("fp8 backward is not supported in FA3")]
308 Fp8BackwardUnsupported,
309 #[error("sink tokens require either sliding_window or causal")]
310 SinkWithoutMask,
311 #[error("invalid GQA ratio {0} (must be >= 1)")]
312 InvalidGqaRatio(u32),
313 #[error("sliding window must be > 0")]
314 ZeroWindow,
315 #[error("paged path requires DispatchKey::paged = true")]
316 PagedFlagNotSet,
317 #[error("no kernel registered for key {0:?}")]
318 UnknownKey(Box<DispatchKey>),
319}
320
321/// Forward-pass dispatch trait. Every forward-attention request type
322/// (FA2, FA3, varlen, paged, prefill) implements this and produces a
323/// `DispatchKey`.
324pub trait FaFwdDispatch: Send + 'static {
325 fn dispatch_key(&self) -> DispatchKey;
326}
327
328/// Backward-pass dispatch trait.
329pub trait FaBwdDispatch: Send + 'static {
330 fn dispatch_key(&self) -> DispatchKey;
331}
332
333/// Paged-forward dispatch trait. Distinct from `FaFwdDispatch` so the
334/// `FlashAttnMsg::PagedForward` variant can specialise on the paged
335/// API surface (block table, slot mapping).
336pub trait FaPagedFwdDispatch: Send + 'static {
337 fn dispatch_key(&self) -> DispatchKey;
338}
339
340/// In-process registry of known kernel names. Populated lazily on first
341/// access and shared across all `FlashAttnActor`s.
342///
343/// The "table" is really a `HashMap<DispatchKey, &'static str>`; the
344/// values are static name expressions, never owned. Real cubin
345/// compilation is delegated to `NvrtcActor` via the Phase 0.6 disk
346/// cache.
347pub struct DispatchTable {
348 entries: HashMap<DispatchKey, String>,
349}
350
351impl DispatchTable {
352 fn build() -> Self {
353 let mut entries: HashMap<DispatchKey, String> = HashMap::new();
354
355 // Pre-populate a cross-product of common cells. The dispatch
356 // table also resolves keys absent from this map by falling back
357 // to `key.kernel_name()` — so callers don't need every cell
358 // pre-registered. Pre-registration is just a self-test that
359 // every "common" combination produces a unique mangled name.
360 for &arch in &[SmArch::Sm80, SmArch::Sm89, SmArch::Sm90a, SmArch::Sm100] {
361 for &dtype in &[DType::F16, DType::Bf16] {
362 for &head_dim in &[64u32, 80, 96, 128, 192, 256] {
363 for &causal in &[false, true] {
364 let key = DispatchKey {
365 arch,
366 dtype,
367 head_dim,
368 causal,
369 varlen: false,
370 sliding_window: None,
371 alibi: false,
372 sink: 0,
373 paged: false,
374 gqa_ratio: 1,
375 };
376 if key.validate_fwd().is_ok() {
377 entries.insert(key, key.kernel_name());
378 }
379 }
380 }
381 }
382 }
383
384 // FA3 fp8 cells (sm_90a / sm_100 only)
385 #[cfg(feature = "fp8")]
386 for &dtype in &[DType::F8E4m3, DType::F8E5m2] {
387 for &head_dim in &[64u32, 128, 256] {
388 for &arch in &[SmArch::Sm90a, SmArch::Sm100] {
389 for &causal in &[false, true] {
390 let key = DispatchKey {
391 arch,
392 dtype,
393 head_dim,
394 causal,
395 varlen: false,
396 sliding_window: None,
397 alibi: false,
398 sink: 0,
399 paged: false,
400 gqa_ratio: 1,
401 };
402 if key.validate_fwd().is_ok() {
403 entries.insert(key, key.kernel_name());
404 }
405 }
406 }
407 }
408 }
409
410 Self { entries }
411 }
412
413 /// Resolve a key to a kernel-name expression.
414 ///
415 /// Lookup order:
416 ///
417 /// 1. Pre-registered entry (fast path — no allocation).
418 /// 2. Computed [`DispatchKey::kernel_name`] for cells outside the
419 /// pre-registration cross-product.
420 /// 3. `Err(DispatchError::UnknownKey(_))` if the key is invalid.
421 pub fn lookup(&self, key: &DispatchKey) -> Result<String, DispatchError> {
422 key.validate_fwd()?;
423 if let Some(name) = self.entries.get(key) {
424 return Ok(name.clone());
425 }
426 Ok(key.kernel_name())
427 }
428
429 /// Resolve a key, and additionally fail with `UnknownKey` if it is
430 /// not in the pre-registered set. Used by tests.
431 pub fn strict_lookup(&self, key: &DispatchKey) -> Result<&str, DispatchError> {
432 self.entries
433 .get(key)
434 .map(String::as_str)
435 .ok_or_else(|| DispatchError::UnknownKey(Box::new(*key)))
436 }
437
438 /// Number of pre-registered entries.
439 pub fn len(&self) -> usize {
440 self.entries.len()
441 }
442
443 /// True iff the table is empty.
444 pub fn is_empty(&self) -> bool {
445 self.entries.is_empty()
446 }
447}
448
449/// Process-wide dispatch table singleton.
450pub static DISPATCH_TABLE: Lazy<DispatchTable> = Lazy::new(DispatchTable::build);
451
452/// Convenience accessor — `DISPATCH_TABLE.lookup(key)`.
453pub fn lookup(key: &DispatchKey) -> Result<String, DispatchError> {
454 DISPATCH_TABLE.lookup(key)
455}
456
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain forward cell: only the four leading fields vary; every
    /// optional feature is off and `gqa_ratio` is 1.
    fn fwd_key(arch: SmArch, dtype: DType, head_dim: u32, causal: bool) -> DispatchKey {
        DispatchKey {
            arch,
            dtype,
            head_dim,
            causal,
            varlen: false,
            sliding_window: None,
            alibi: false,
            sink: 0,
            paged: false,
            gqa_ratio: 1,
        }
    }

    /// Every 16-bit `(arch, dtype, head_dim, causal)` cell validates,
    /// and its `kernel_name` / `stable_hash` are deterministic.
    #[test]
    fn dispatch_key_round_trip() {
        for arch in [SmArch::Sm80, SmArch::Sm89, SmArch::Sm90a, SmArch::Sm100] {
            for dtype in [DType::F16, DType::Bf16] {
                for head_dim in [64u32, 80, 96, 128, 192, 256] {
                    for causal in [false, true] {
                        let first = fwd_key(arch, dtype, head_dim, causal);
                        assert!(first.validate_fwd().is_ok());

                        // An identically built key must hash and name the same.
                        let second = fwd_key(arch, dtype, head_dim, causal);
                        assert_eq!(first.stable_hash(), second.stable_hash());
                        assert_eq!(first.kernel_name(), second.kernel_name());

                        // Resolution through the shared table.
                        let resolved = lookup(&first).expect("lookup");
                        assert!(resolved.contains(dtype.tag()));
                        assert!(resolved.contains(&head_dim.to_string()));
                    }
                }
            }
        }

        // Flipping `causal` changes both the hash and the name.
        let masked = fwd_key(SmArch::Sm90a, DType::F16, 128, true);
        let unmasked = fwd_key(SmArch::Sm90a, DType::F16, 128, false);
        assert_ne!(masked.stable_hash(), unmasked.stable_hash());
        assert_ne!(masked.kernel_name(), unmasked.kernel_name());
    }

    /// A valid key outside the pre-registered set fails `strict_lookup`
    /// with `UnknownKey`, while the soft `lookup` still resolves it via
    /// `kernel_name`.
    #[test]
    fn lookup_misses_unknown_key() {
        // Every optional feature except paging is enabled, so this cell
        // is not part of the pre-registration cross-product.
        let key = DispatchKey {
            varlen: true,
            sliding_window: Some(4096),
            alibi: true,
            sink: 4,
            gqa_ratio: 8,
            ..fwd_key(SmArch::Sm90a, DType::Bf16, 128, true)
        };
        assert!(key.validate_fwd().is_ok());

        // Not pre-registered, so the strict path reports a miss.
        let strict = DISPATCH_TABLE.strict_lookup(&key);
        assert!(matches!(strict, Err(DispatchError::UnknownKey(_))));

        // The soft path synthesises the name on the fly.
        let name = lookup(&key).expect("soft lookup synthesises a name");
        for fragment in ["varlen", "alibi", "sink4", "sw4096", "gqa8"] {
            assert!(name.contains(fragment));
        }
    }

    #[test]
    fn fp8_requires_hopper() {
        let on_ampere = fwd_key(SmArch::Sm80, DType::F8E4m3, 128, true);
        assert!(matches!(
            on_ampere.validate_fwd(),
            Err(DispatchError::Fp8RequiresHopper(_))
        ));

        // The same cell on Hopper is accepted.
        let on_hopper = DispatchKey {
            arch: SmArch::Sm90a,
            ..on_ampere
        };
        assert!(on_hopper.validate_fwd().is_ok());
    }

    #[test]
    fn unsupported_head_dim_rejected() {
        let key = fwd_key(SmArch::Sm90a, DType::F16, 100, false);
        assert!(matches!(
            key.validate_fwd(),
            Err(DispatchError::UnsupportedHeadDim(100))
        ));
    }

    #[test]
    fn sink_without_mask_rejected() {
        // Sink tokens with neither a causal mask nor a sliding window.
        let key = DispatchKey {
            sink: 4,
            ..fwd_key(SmArch::Sm90a, DType::Bf16, 128, false)
        };
        assert!(matches!(
            key.validate_fwd(),
            Err(DispatchError::SinkWithoutMask)
        ));
    }
}
605}