Skip to main content

atomr_accel_flashattn/
dispatch.rs

1//! Dispatch table — maps a `(arch, dtype, head_dim, …)` cell onto a
2//! mangled kernel name expression.
3//!
4//! The Phase 7 FlashAttention crate ships forward + backward paths for
5//! v2 (sm_80 / sm_89) and v3 (sm_90a, including the fp8 e4m3 / e5m2
6//! variants). Every kernel is NVRTC-compiled lazily through the Phase
7//! 0.6 disk cache; the dispatch table is the *only* place that knows
//! the canonical mangled symbol. Every request type ([`crate::fa2`],
//! [`crate::fa3`], [`crate::paged`], [`crate::prefill`], [`crate::varlen`])
//! produces a [`DispatchKey`], and equal keys always resolve to the same
//! kernel-name string.
11//!
12//! Hot path:
13//!
14//! 1. Caller constructs a request (e.g. [`crate::fa2::Fa2FwdRequest`]).
15//! 2. [`FaFwdDispatch::dispatch_key`] yields a [`DispatchKey`].
16//! 3. [`DispatchTable::lookup`] resolves the key to a kernel name.
17//! 4. The actor asks `NvrtcActor` to compile-or-fetch by name.
18//! 5. The cubin is launched on the actor's stream.
19//!
//! Steps 4–5 are GPU-only and gated behind `cuda-runtime-tests`. Steps 1–3
//! (request construction and the table lookup) run on the CPU and are
//! exercised by the unit tests below. Steps 1–2 are also covered by each
//! request-type module's `tests` block.
23
24use std::collections::HashMap;
25use std::hash::{Hash, Hasher};
26
27use once_cell::sync::Lazy;
28
/// Target CUDA streaming-multiprocessor architecture.
///
/// Every [`DispatchKey`] names exactly one of these variants, so the dispatch
/// table can only ever resolve kernels for the targets listed here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SmArch {
    /// Ampere (A100, A30) — fa2 only.
    Sm80,
    /// Ada (RTX 40xx, L4) — fa2 only; has fp8 through cuBLASLt, but no fa3.
    Sm89,
    /// Hopper (H100, H200) — fa3, fp8, TMA, WGMMA, persistent kernels.
    Sm90a,
    /// Blackwell (B100, B200) — forward-compat target; fa3 with fifth-gen
    /// tensor cores. Currently falls back to the Hopper kernels.
    Sm100,
}

impl SmArch {
    /// The NVRTC `--gpu-architecture=…` option that selects this target.
    pub fn nvrtc_flag(self) -> &'static str {
        match self {
            Self::Sm80 => "--gpu-architecture=sm_80",
            Self::Sm89 => "--gpu-architecture=sm_89",
            Self::Sm90a => "--gpu-architecture=sm_90a",
            Self::Sm100 => "--gpu-architecture=sm_100a",
        }
    }

    /// Whether FlashAttention v3 kernels can target this architecture
    /// (Hopper and newer).
    pub fn supports_fa3(self) -> bool {
        match self {
            Self::Sm90a | Self::Sm100 => true,
            Self::Sm80 | Self::Sm89 => false,
        }
    }

    /// Whether the fp8 (e4m3 / e5m2) FA3 kernels can target this
    /// architecture. fp8 attention only exists in FA3, so this is exactly
    /// the set of FA3-capable architectures.
    pub fn supports_fp8(self) -> bool {
        self.supports_fa3()
    }
}
65
/// Runtime element type of the Q / K / V tiles.
///
/// Deliberately independent of the (planned) `CudaDtype` in
/// `atomr-accel-cuda` so this crate stays self-contained.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DType {
    /// IEEE 754 binary16 — usable by fa2 and fa3.
    F16,
    /// bfloat16 — usable by fa2 and fa3.
    Bf16,
    /// 8-bit float, e4m3 — fa3 only, sm_90a and newer.
    F8E4m3,
    /// 8-bit float, e5m2 — fa3 only, sm_90a and newer (used for V in
    /// DPA-mixed-precision).
    F8E5m2,
}

impl DType {
    /// Bytes per element: 1 for the fp8 formats, 2 for f16 / bf16.
    pub fn size_in_bytes(self) -> usize {
        if self.is_fp8() {
            1
        } else {
            2
        }
    }

    /// Whether this is one of the two 8-bit float formats.
    pub fn is_fp8(self) -> bool {
        match self {
            Self::F8E4m3 | Self::F8E5m2 => true,
            Self::F16 | Self::Bf16 => false,
        }
    }

    /// Short identifier spliced into mangled kernel names.
    pub fn tag(self) -> &'static str {
        match self {
            Self::F16 => "f16",
            Self::Bf16 => "bf16",
            Self::F8E4m3 => "e4m3",
            Self::F8E5m2 => "e5m2",
        }
    }
}
104
105/// Marker trait for dtypes that can drive a FlashAttention GEMM. Implemented
106/// by the same set of zero-sized types that the rest of `atomr-accel`
107/// uses to phantom-tag GEMM-supported dtypes. The trait itself carries
108/// no methods so it can be referenced from [`crate::fa2`] / [`crate::fa3`]
109/// without requiring callers to depend on `atomr-accel-cuda` directly.
110pub trait GemmSupported: Send + Sync + 'static {
111    /// The runtime dtype tag this marker maps onto.
112    fn dtype() -> DType;
113}
114
115/// Zero-sized marker for `f16` (IEEE binary16).
116#[derive(Debug, Clone, Copy)]
117pub struct F16;
118impl GemmSupported for F16 {
119    fn dtype() -> DType {
120        DType::F16
121    }
122}
123
124/// Zero-sized marker for `bf16` (bfloat16).
125#[derive(Debug, Clone, Copy)]
126pub struct Bf16;
127impl GemmSupported for Bf16 {
128    fn dtype() -> DType {
129        DType::Bf16
130    }
131}
132
133/// Zero-sized marker for fp8 e4m3 (gated `fp8`).
134#[cfg(feature = "fp8")]
135#[derive(Debug, Clone, Copy)]
136pub struct F8E4m3;
137#[cfg(feature = "fp8")]
138impl GemmSupported for F8E4m3 {
139    fn dtype() -> DType {
140        DType::F8E4m3
141    }
142}
143
144/// Zero-sized marker for fp8 e5m2 (gated `fp8`).
145#[cfg(feature = "fp8")]
146#[derive(Debug, Clone, Copy)]
147pub struct F8E5m2;
148#[cfg(feature = "fp8")]
149impl GemmSupported for F8E5m2 {
150    fn dtype() -> DType {
151        DType::F8E5m2
152    }
153}
154
155/// Cell key for the FlashAttention dispatch table.
156///
157/// Every field directly affects the generated CUDA C++ template
158/// instantiation — flipping any one of them changes the resulting
159/// cubin. The table refuses to resolve unsupported combinations
160/// (e.g. `fp8` on `Sm80`, head_dim > 256).
161#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
162pub struct DispatchKey {
163    /// Target SM architecture.
164    pub arch: SmArch,
165    /// Element type for Q/K/V.
166    pub dtype: DType,
167    /// Per-head dimension (D). Supported: 64, 80, 96, 128, 192, 256.
168    pub head_dim: u32,
169    /// Causal masking — autoregressive attention.
170    pub causal: bool,
171    /// Variable-length (cu_seqlens). When false, batched attention with
172    /// uniform seqlen.
173    pub varlen: bool,
174    /// Sliding-window size; `None` means full attention. Window size
175    /// is the number of past tokens each query attends to.
176    pub sliding_window: Option<u32>,
177    /// ALiBi linear-position biases.
178    pub alibi: bool,
179    /// Number of "sink" tokens (StreamingLLM); each query unconditionally
180    /// attends to the first `sink` keys regardless of `sliding_window`.
181    pub sink: u32,
182    /// vLLM-style paged KV-cache.
183    pub paged: bool,
184    /// Q heads per KV head. 1 = MHA, >1 = GQA, equal to num_heads = MQA.
185    pub gqa_ratio: u32,
186}
187
188impl DispatchKey {
189    /// Validate the cell for a *forward* path. Returns `Err` for
190    /// unreachable combinations.
191    pub fn validate_fwd(&self) -> Result<(), DispatchError> {
192        // Head-dim whitelist
193        const ALLOWED: &[u32] = &[64, 80, 96, 128, 192, 256];
194        if !ALLOWED.contains(&self.head_dim) {
195            return Err(DispatchError::UnsupportedHeadDim(self.head_dim));
196        }
197
198        // fp8 only on FA3-capable architectures
199        if self.dtype.is_fp8() && !self.arch.supports_fp8() {
200            return Err(DispatchError::Fp8RequiresHopper(self.arch));
201        }
202
203        // Sink tokens require sliding_window or causal — otherwise the
204        // mask is just full attention.
205        if self.sink > 0 && self.sliding_window.is_none() && !self.causal {
206            return Err(DispatchError::SinkWithoutMask);
207        }
208
209        // GQA ratio must be a power of two and at least 1.
210        if self.gqa_ratio == 0 {
211            return Err(DispatchError::InvalidGqaRatio(self.gqa_ratio));
212        }
213
214        // Sliding-window size must be > 0 when present.
215        if let Some(w) = self.sliding_window {
216            if w == 0 {
217                return Err(DispatchError::ZeroWindow);
218            }
219        }
220
221        Ok(())
222    }
223
224    /// Validate the cell for a *backward* path. Currently the same as
225    /// forward, but kept distinct so we can refuse e.g. fp8 backward
226    /// (numerically too lossy in the stock FA3) without affecting the
227    /// forward whitelist.
228    pub fn validate_bwd(&self) -> Result<(), DispatchError> {
229        self.validate_fwd()?;
230        if self.dtype.is_fp8() {
231            return Err(DispatchError::Fp8BackwardUnsupported);
232        }
233        Ok(())
234    }
235
236    /// Validate the cell for a *paged* forward path.
237    pub fn validate_paged(&self) -> Result<(), DispatchError> {
238        self.validate_fwd()?;
239        if !self.paged {
240            return Err(DispatchError::PagedFlagNotSet);
241        }
242        Ok(())
243    }
244
245    /// Stable 64-bit hash of the key. Useful as a cubin-cache index
246    /// alongside the kernel-name string.
247    pub fn stable_hash(&self) -> u64 {
248        let mut h = std::collections::hash_map::DefaultHasher::new();
249        self.hash(&mut h);
250        h.finish()
251    }
252
253    /// Build the canonical mangled kernel-name expression. Mirrors the
254    /// FA2/FA3 csrc naming convention so we can resolve it via NVRTC's
255    /// `nvrtcGetLoweredName`.
256    pub fn kernel_name(&self) -> String {
257        let kind = if self.arch.supports_fa3() {
258            "fa3"
259        } else {
260            "fa2"
261        };
262        let mut s = format!(
263            "atomr_flashattn::{}::fwd<{}, {}, {}>",
264            kind,
265            self.dtype.tag(),
266            self.head_dim,
267            self.causal_tag(),
268        );
269        if self.varlen {
270            s.push_str("_varlen");
271        }
272        if let Some(w) = self.sliding_window {
273            s.push_str(&format!("_sw{w}"));
274        }
275        if self.alibi {
276            s.push_str("_alibi");
277        }
278        if self.sink > 0 {
279            s.push_str(&format!("_sink{}", self.sink));
280        }
281        if self.paged {
282            s.push_str("_paged");
283        }
284        if self.gqa_ratio > 1 {
285            s.push_str(&format!("_gqa{}", self.gqa_ratio));
286        }
287        s
288    }
289
290    fn causal_tag(&self) -> &'static str {
291        if self.causal {
292            "causal"
293        } else {
294            "full"
295        }
296    }
297}
298
299/// Errors returned from [`DispatchKey::validate_fwd`] /
300/// [`DispatchTable::lookup`].
301#[derive(Debug, Clone, thiserror::Error)]
302pub enum DispatchError {
303    #[error("head_dim {0} is not in the FA whitelist (64, 80, 96, 128, 192, 256)")]
304    UnsupportedHeadDim(u32),
305    #[error("fp8 requires sm_90a or newer, got {0:?}")]
306    Fp8RequiresHopper(SmArch),
307    #[error("fp8 backward is not supported in FA3")]
308    Fp8BackwardUnsupported,
309    #[error("sink tokens require either sliding_window or causal")]
310    SinkWithoutMask,
311    #[error("invalid GQA ratio {0} (must be >= 1)")]
312    InvalidGqaRatio(u32),
313    #[error("sliding window must be > 0")]
314    ZeroWindow,
315    #[error("paged path requires DispatchKey::paged = true")]
316    PagedFlagNotSet,
317    #[error("no kernel registered for key {0:?}")]
318    UnknownKey(Box<DispatchKey>),
319}
320
321/// Forward-pass dispatch trait. Every forward-attention request type
322/// (FA2, FA3, varlen, paged, prefill) implements this and produces a
323/// `DispatchKey`.
324pub trait FaFwdDispatch: Send + 'static {
325    fn dispatch_key(&self) -> DispatchKey;
326}
327
328/// Backward-pass dispatch trait.
329pub trait FaBwdDispatch: Send + 'static {
330    fn dispatch_key(&self) -> DispatchKey;
331}
332
333/// Paged-forward dispatch trait. Distinct from `FaFwdDispatch` so the
334/// `FlashAttnMsg::PagedForward` variant can specialise on the paged
335/// API surface (block table, slot mapping).
336pub trait FaPagedFwdDispatch: Send + 'static {
337    fn dispatch_key(&self) -> DispatchKey;
338}
339
340/// In-process registry of known kernel names. Populated lazily on first
341/// access and shared across all `FlashAttnActor`s.
342///
343/// The "table" is really a `HashMap<DispatchKey, &'static str>`; the
344/// values are static name expressions, never owned. Real cubin
345/// compilation is delegated to `NvrtcActor` via the Phase 0.6 disk
346/// cache.
347pub struct DispatchTable {
348    entries: HashMap<DispatchKey, String>,
349}
350
351impl DispatchTable {
352    fn build() -> Self {
353        let mut entries: HashMap<DispatchKey, String> = HashMap::new();
354
355        // Pre-populate a cross-product of common cells. The dispatch
356        // table also resolves keys absent from this map by falling back
357        // to `key.kernel_name()` — so callers don't need every cell
358        // pre-registered. Pre-registration is just a self-test that
359        // every "common" combination produces a unique mangled name.
360        for &arch in &[SmArch::Sm80, SmArch::Sm89, SmArch::Sm90a, SmArch::Sm100] {
361            for &dtype in &[DType::F16, DType::Bf16] {
362                for &head_dim in &[64u32, 80, 96, 128, 192, 256] {
363                    for &causal in &[false, true] {
364                        let key = DispatchKey {
365                            arch,
366                            dtype,
367                            head_dim,
368                            causal,
369                            varlen: false,
370                            sliding_window: None,
371                            alibi: false,
372                            sink: 0,
373                            paged: false,
374                            gqa_ratio: 1,
375                        };
376                        if key.validate_fwd().is_ok() {
377                            entries.insert(key, key.kernel_name());
378                        }
379                    }
380                }
381            }
382        }
383
384        // FA3 fp8 cells (sm_90a / sm_100 only)
385        #[cfg(feature = "fp8")]
386        for &dtype in &[DType::F8E4m3, DType::F8E5m2] {
387            for &head_dim in &[64u32, 128, 256] {
388                for &arch in &[SmArch::Sm90a, SmArch::Sm100] {
389                    for &causal in &[false, true] {
390                        let key = DispatchKey {
391                            arch,
392                            dtype,
393                            head_dim,
394                            causal,
395                            varlen: false,
396                            sliding_window: None,
397                            alibi: false,
398                            sink: 0,
399                            paged: false,
400                            gqa_ratio: 1,
401                        };
402                        if key.validate_fwd().is_ok() {
403                            entries.insert(key, key.kernel_name());
404                        }
405                    }
406                }
407            }
408        }
409
410        Self { entries }
411    }
412
413    /// Resolve a key to a kernel-name expression.
414    ///
415    /// Lookup order:
416    ///
417    /// 1. Pre-registered entry (fast path — no allocation).
418    /// 2. Computed [`DispatchKey::kernel_name`] for cells outside the
419    ///    pre-registration cross-product.
420    /// 3. `Err(DispatchError::UnknownKey(_))` if the key is invalid.
421    pub fn lookup(&self, key: &DispatchKey) -> Result<String, DispatchError> {
422        key.validate_fwd()?;
423        if let Some(name) = self.entries.get(key) {
424            return Ok(name.clone());
425        }
426        Ok(key.kernel_name())
427    }
428
429    /// Resolve a key, and additionally fail with `UnknownKey` if it is
430    /// not in the pre-registered set. Used by tests.
431    pub fn strict_lookup(&self, key: &DispatchKey) -> Result<&str, DispatchError> {
432        self.entries
433            .get(key)
434            .map(String::as_str)
435            .ok_or_else(|| DispatchError::UnknownKey(Box::new(*key)))
436    }
437
438    /// Number of pre-registered entries.
439    pub fn len(&self) -> usize {
440        self.entries.len()
441    }
442
443    /// True iff the table is empty.
444    pub fn is_empty(&self) -> bool {
445        self.entries.is_empty()
446    }
447}
448
449/// Process-wide dispatch table singleton.
450pub static DISPATCH_TABLE: Lazy<DispatchTable> = Lazy::new(DispatchTable::build);
451
452/// Convenience accessor — `DISPATCH_TABLE.lookup(key)`.
453pub fn lookup(key: &DispatchKey) -> Result<String, DispatchError> {
454    DISPATCH_TABLE.lookup(key)
455}
456
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain forward key: every optional feature off, MHA (`gqa_ratio = 1`).
    fn fwd_key(arch: SmArch, dtype: DType, head_dim: u32, causal: bool) -> DispatchKey {
        DispatchKey {
            arch,
            dtype,
            head_dim,
            causal,
            varlen: false,
            sliding_window: None,
            alibi: false,
            sink: 0,
            paged: false,
            gqa_ratio: 1,
        }
    }

    /// Every `(arch, dtype, head_dim, causal, …)` cell builds, validates,
    /// and round-trips through `kernel_name + stable_hash` deterministically.
    #[test]
    fn dispatch_key_round_trip() {
        let arches = [SmArch::Sm80, SmArch::Sm89, SmArch::Sm90a, SmArch::Sm100];
        let dtypes = [DType::F16, DType::Bf16];
        let head_dims = [64u32, 80, 96, 128, 192, 256];

        for &arch in &arches {
            for &dtype in &dtypes {
                for &head_dim in &head_dims {
                    for &causal in &[false, true] {
                        let key = fwd_key(arch, dtype, head_dim, causal);
                        assert!(key.validate_fwd().is_ok());

                        // Re-construct identically and re-hash; must match.
                        let key2 = fwd_key(arch, dtype, head_dim, causal);
                        assert_eq!(key.stable_hash(), key2.stable_hash());
                        assert_eq!(key.kernel_name(), key2.kernel_name());

                        // Lookup goes through the table.
                        let name = lookup(&key).expect("lookup");
                        assert!(name.contains(dtype.tag()));
                        assert!(name.contains(&head_dim.to_string()));
                    }
                }
            }
        }

        // Flipping `causal` changes both the hash and the name.
        let a = fwd_key(SmArch::Sm90a, DType::F16, 128, true);
        let b = fwd_key(SmArch::Sm90a, DType::F16, 128, false);
        assert_ne!(a.stable_hash(), b.stable_hash());
        assert_ne!(a.kernel_name(), b.kernel_name());
    }

    /// Changing any single field changes `stable_hash`. Every field except
    /// `arch` also changes `kernel_name`; `arch` only picks the fa2/fa3
    /// family there.
    #[test]
    fn every_field_affects_hash() {
        let base = fwd_key(SmArch::Sm90a, DType::F16, 128, true);
        let variants = [
            DispatchKey { arch: SmArch::Sm100, ..base },
            DispatchKey { dtype: DType::Bf16, ..base },
            DispatchKey { head_dim: 64, ..base },
            DispatchKey { causal: false, ..base },
            DispatchKey { varlen: true, ..base },
            DispatchKey { sliding_window: Some(128), ..base },
            DispatchKey { alibi: true, ..base },
            DispatchKey { sink: 4, ..base },
            DispatchKey { paged: true, ..base },
            DispatchKey { gqa_ratio: 2, ..base },
        ];
        for v in &variants {
            assert_ne!(base.stable_hash(), v.stable_hash(), "{v:?}");
        }
        for v in &variants[1..] {
            assert_ne!(base.kernel_name(), v.kernel_name(), "{v:?}");
        }
        // Sm90a and Sm100 are both fa3: same name, distinct compile target.
        assert_eq!(base.kernel_name(), variants[0].kernel_name());
    }

    /// Strict lookup of a key that wasn't pre-registered yields
    /// `UnknownKey`; soft `lookup` succeeds via `kernel_name`.
    #[test]
    fn lookup_misses_unknown_key() {
        // varlen + alibi cell — not in the pre-reg cross-product.
        let key = DispatchKey {
            arch: SmArch::Sm90a,
            dtype: DType::Bf16,
            head_dim: 128,
            causal: true,
            varlen: true,
            sliding_window: Some(4096),
            alibi: true,
            sink: 4,
            paged: false,
            gqa_ratio: 8,
        };
        assert!(key.validate_fwd().is_ok());

        // Strict lookup misses (not pre-registered).
        let strict = DISPATCH_TABLE.strict_lookup(&key);
        assert!(matches!(strict, Err(DispatchError::UnknownKey(_))));

        // Soft lookup synthesises the kernel name on the fly.
        let name = lookup(&key).expect("soft lookup synthesises a name");
        assert!(name.contains("varlen"));
        assert!(name.contains("alibi"));
        assert!(name.contains("sink4"));
        assert!(name.contains("sw4096"));
        assert!(name.contains("gqa8"));
    }

    /// A pre-registered cell is served by both strict and soft lookup with
    /// the same name.
    #[test]
    fn strict_lookup_hits_preregistered_cell() {
        let key = fwd_key(SmArch::Sm80, DType::Bf16, 96, false);
        let name = DISPATCH_TABLE
            .strict_lookup(&key)
            .expect("pre-registered cell");
        assert_eq!(name, key.kernel_name());
        assert_eq!(lookup(&key).expect("soft lookup"), name);
    }

    /// The table pre-registers exactly the documented cross-product.
    #[test]
    fn table_preregisters_common_cells() {
        let fp8_cells = if cfg!(feature = "fp8") { 2 * 3 * 2 * 2 } else { 0 };
        assert_eq!(DISPATCH_TABLE.len(), 4 * 2 * 6 * 2 + fp8_cells);
        assert!(!DISPATCH_TABLE.is_empty());
    }

    /// Soft lookup of an invalid key surfaces the validation error, not
    /// `UnknownKey`.
    #[test]
    fn lookup_propagates_validation_error() {
        let key = fwd_key(SmArch::Sm90a, DType::F16, 100, false);
        assert!(matches!(
            lookup(&key),
            Err(DispatchError::UnsupportedHeadDim(100))
        ));
    }

    #[test]
    fn fp8_requires_hopper() {
        let mut key = DispatchKey {
            arch: SmArch::Sm80,
            dtype: DType::F8E4m3,
            head_dim: 128,
            causal: true,
            varlen: false,
            sliding_window: None,
            alibi: false,
            sink: 0,
            paged: false,
            gqa_ratio: 1,
        };
        assert!(matches!(
            key.validate_fwd(),
            Err(DispatchError::Fp8RequiresHopper(_))
        ));
        key.arch = SmArch::Sm90a;
        assert!(key.validate_fwd().is_ok());
    }

    /// Backward refuses fp8 even where forward accepts it; forward errors
    /// take precedence.
    #[test]
    fn bwd_rejects_fp8() {
        let fp8 = fwd_key(SmArch::Sm90a, DType::F8E4m3, 128, true);
        assert!(fp8.validate_fwd().is_ok());
        assert!(matches!(
            fp8.validate_bwd(),
            Err(DispatchError::Fp8BackwardUnsupported)
        ));

        let bf16 = fwd_key(SmArch::Sm90a, DType::Bf16, 128, true);
        assert!(bf16.validate_bwd().is_ok());

        let ampere_fp8 = fwd_key(SmArch::Sm80, DType::F8E5m2, 128, true);
        assert!(matches!(
            ampere_fp8.validate_bwd(),
            Err(DispatchError::Fp8RequiresHopper(SmArch::Sm80))
        ));
    }

    #[test]
    fn paged_requires_flag() {
        let mut key = fwd_key(SmArch::Sm90a, DType::Bf16, 128, true);
        assert!(matches!(
            key.validate_paged(),
            Err(DispatchError::PagedFlagNotSet)
        ));
        key.paged = true;
        assert!(key.validate_paged().is_ok());
    }

    #[test]
    fn unsupported_head_dim_rejected() {
        let key = DispatchKey {
            arch: SmArch::Sm90a,
            dtype: DType::F16,
            head_dim: 100,
            causal: false,
            varlen: false,
            sliding_window: None,
            alibi: false,
            sink: 0,
            paged: false,
            gqa_ratio: 1,
        };
        assert!(matches!(
            key.validate_fwd(),
            Err(DispatchError::UnsupportedHeadDim(100))
        ));
    }

    #[test]
    fn gqa_ratio_and_window_bounds() {
        let base = fwd_key(SmArch::Sm80, DType::F16, 64, false);

        let zero_gqa = DispatchKey { gqa_ratio: 0, ..base };
        assert!(matches!(
            zero_gqa.validate_fwd(),
            Err(DispatchError::InvalidGqaRatio(0))
        ));

        // No power-of-two requirement is enforced.
        let odd_gqa = DispatchKey { gqa_ratio: 3, ..base };
        assert!(odd_gqa.validate_fwd().is_ok());

        let zero_window = DispatchKey { sliding_window: Some(0), ..base };
        assert!(matches!(
            zero_window.validate_fwd(),
            Err(DispatchError::ZeroWindow)
        ));
    }

    #[test]
    fn sink_without_mask_rejected() {
        let key = DispatchKey {
            arch: SmArch::Sm90a,
            dtype: DType::Bf16,
            head_dim: 128,
            causal: false,
            varlen: false,
            sliding_window: None,
            alibi: false,
            sink: 4,
            paged: false,
            gqa_ratio: 1,
        };
        assert!(matches!(
            key.validate_fwd(),
            Err(DispatchError::SinkWithoutMask)
        ));
    }
}