atomr_accel_flashattn/lib.rs
1//! # atomr-accel-flashattn
2//!
3//! FlashAttention v2 + v3 kernel templates for atomr-accel.
4//!
5//! Provides forward + backward attention via NVRTC-compiled CUDA C++
6//! kernels with the full feature matrix:
7//!
8//! | Feature | FA2 (sm_80 / sm_89) | FA3 (sm_90a / sm_100) |
9//! |-------------------------|---------------------|------------------------|
10//! | f16, bf16 | ✔ | ✔ |
11//! | fp8 e4m3 / e5m2 | | ✔ (`fp8`) |
12//! | causal | ✔ | ✔ |
13//! | varlen (cu_seqlens) | ✔ | ✔ |
14//! | sliding window | ✔ | ✔ |
15//! | sink tokens | ✔ | ✔ |
16//! | ALiBi | ✔ | ✔ |
17//! | MQA / GQA | ✔ | ✔ |
18//! | paged KV-cache | ✔ (`paged`) | ✔ (`paged`) |
19//! | chunked prefill | ✔ | ✔ |
20//! | persistent kernels | | ✔ |
//! | backward | ✔ (f16/bf16 only) | falls back to FA2 |
22//!
23//! ## Architecture
24//!
//! Each request type ([`fa2::Fa2FwdRequest`], [`fa3::Fa3FwdRequest`],
//! [`varlen::VarlenFwdRequest`], [`prefill::ChunkedPrefillRequest`] and,
//! with the `paged` feature, `paged::PagedAttentionRequest`) is generic
//! over a [`dispatch::GemmSupported`] dtype marker and produces a
//! [`dispatch::DispatchKey`]: the canonical (arch, dtype, head_dim,
//! causal, varlen, sliding_window, alibi, sink, paged, gqa_ratio)
//! tuple that selects either an FA2 or an FA3 cubin.
32//!
33//! At runtime, the Phase 0.6 NVRTC disk cache compiles the matching
34//! cubin lazily; the dispatch table maps a [`dispatch::DispatchKey`]
35//! to the canonical mangled kernel-name expression. The hot path for
36//! steady-state inference is: `dispatch_key()` → cache hit → cubin
37//! launch.
38//!
39//! ## Cargo features
40//!
//! - `fp8` — enables the fp8 dtype markers (`F8E4m3` / `F8E5m2`) and the
//!   FA3 fp8 request type (`Fa3FwdFp8Request`).
//! - `paged` — enables the `paged` module (`PagedAttentionRequest`,
//!   `PagedKvCache`) and the paged dispatch-key cells.
44//! - `cuda-runtime-tests` — gates the real-GPU example + the
45//! `Real` actor variant. Off by default so the crate builds and
46//! unit-tests on hosts without CUDA.
47
48#![allow(clippy::module_name_repetitions, clippy::too_many_arguments)]
49
50pub mod actor;
51pub mod dispatch;
52pub mod fa2;
53pub mod fa3;
54#[cfg(feature = "paged")]
55pub mod paged;
56pub mod prefill;
57pub mod varlen;
58
#[cfg(feature = "cuda-runtime-tests")]
mod cuda_real {
    //! Real-GPU types referenced from [`crate::actor::FlashAttnInner::Real`].
    //!
    //! This module is compiled only with the `cuda-runtime-tests` feature,
    //! so the crate still builds on hosts without a working CUDA driver.

    /// Handle to the host's `NvrtcActor`.
    ///
    /// The real actor lives in `atomr-accel-cuda::kernel::NvrtcActor`.
    /// `FlashAttnActor` only needs to forward `Compile { … }` /
    /// `Launch { … }` messages to it, so the intended representation is a
    /// newtype around `ActorRef<NvrtcMsg>`.
    ///
    /// For now this is a payload-free placeholder: the only field is `()`.
    #[derive(Debug)]
    pub struct NvrtcRef {
        // Placeholder for the future `ActorRef<NvrtcMsg>`.
        //
        // The field is `pub(crate)` so code elsewhere in this crate can
        // build the stub with a struct literal (`NvrtcRef { _opaque: () }`),
        // while external crates cannot. It is meant to be swapped for the
        // real handle once the runtime launch path lands.
        // NOTE(review): the original comment says the value is embedded by
        // `FlashAttnActor::props`; confirm against `actor.rs`.
        pub(crate) _opaque: (),
    }
}
78
79pub use actor::{FlashAttnActor, FlashAttnInner, FlashAttnMsg, FlashAttnProps};
80pub use dispatch::{
81 lookup, Bf16, DType, DispatchError, DispatchKey, DispatchTable, FaBwdDispatch, FaFwdDispatch,
82 FaPagedFwdDispatch, GemmSupported, SmArch, DISPATCH_TABLE, F16,
83};
84
85#[cfg(feature = "fp8")]
86pub use dispatch::{F8E4m3, F8E5m2};
87
88pub use fa2::{Fa2BwdRequest, Fa2FwdRequest, MaskKind, PositionBias};
89#[cfg(feature = "fp8")]
90pub use fa3::Fa3FwdFp8Request;
91pub use fa3::{Fa3FwdRequest, PersistentMode};
92
93#[cfg(feature = "paged")]
94pub use paged::{PagedAttentionRequest, PagedKvCache};
95
96pub use prefill::{ChunkLayout, ChunkedPrefillRequest};
97pub use varlen::{CumulativeSeqlens, VarlenFwdRequest};
98
99/// Errors surfaced by the FlashAttention crate. Most are construction-
100/// time validation failures; a small set are runtime launch errors
101/// produced by the actor (and kept here so callers can pattern-match
102/// without depending on the rest of `atomr-accel-cuda`).
103#[derive(Debug, Clone, thiserror::Error)]
104pub enum FlashAttnError {
105 /// Validation against the dispatch table failed.
106 #[error("dispatch error: {0}")]
107 Dispatch(#[from] DispatchError),
108
109 /// A FlashAttention v3 request targeted a non-Hopper arch.
110 #[error("FA3 requires sm_90a or newer, got {0:?}")]
111 Fa3RequiresHopper(SmArch),
112
113 /// An fp8 dtype was passed to a non-fp8 request type, or vice
114 /// versa.
115 #[error("fp8 dtypes must use Fa3FwdFp8Request and vice versa")]
116 Fp8MustUseFp8Request,
117
118 /// Variable-length / paged batch is empty.
119 #[error("attention batch must contain at least one sequence")]
120 EmptyBatch,
121
122 /// Sequence length is zero.
123 #[error("seqlen must be > 0")]
124 ZeroSeqlen,
125
126 /// Cumulative seqlens overflow `batch_size * max_seqlen`.
127 #[error("cumulative seqlens overflow batch_size * max_seqlen")]
128 SeqlenOverflow,
129
130 /// Paged KV cache is empty / zero-sized.
131 #[error("paged KV cache must be non-empty")]
132 EmptyPagedCache,
133
134 /// Paged KV-cache block size not in the supported set.
135 #[error("paged block_size {0} is not in (8, 16, 32, 64, 128)")]
136 InvalidPagedBlockSize(u32),
137
138 /// Paged cache head_dim doesn't match the request head_dim.
139 #[error("paged cache head_dim {cache} != request head_dim {req}")]
140 PagedHeadDimMismatch { cache: u32, req: u32 },
141
142 /// Chunked-prefill chunk index is out of range.
143 #[error("chunk_index {index} >= total_chunks {total}")]
144 ChunkIndexOutOfRange { index: u32, total: u32 },
145
146 /// Mock-mode actor saw a launch it can't honour.
147 #[error("flashattn actor is in mock mode (no GPU)")]
148 MockMode,
149}