# atomr-accel-flashattn
FlashAttention v2 + v3 kernel templates for atomr-accel.
Provides forward and backward attention via NVRTC-compiled CUDA C++ kernels, with the following feature matrix:
| Feature | FA2 (sm_80 / sm_89) | FA3 (sm_90a / sm_100) |
|---|---|---|
| f16, bf16 | ✔ | ✔ |
| fp8 e4m3 / e5m2 | | ✔ (fp8) |
| causal | ✔ | ✔ |
| varlen (cu_seqlens) | ✔ | ✔ |
| sliding window | ✔ | ✔ |
| sink tokens | ✔ | ✔ |
| ALiBi | ✔ | ✔ |
| MQA / GQA | ✔ | ✔ |
| paged KV-cache | ✔ (paged) | ✔ (paged) |
| chunked prefill | ✔ | ✔ |
| persistent kernels | | ✔ |
| backward | ✔ (f16/bf16 only) | falls back to FA2 |

Cells marked `(fp8)` or `(paged)` require the Cargo feature of the same name.
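The backward row is the one place where the matrix crosses columns: FA3 architectures reuse the FA2 backward kernels. A minimal sketch of that selection rule, assuming plain SM numbers (`90` standing in for `sm_90a`) and hypothetical `Family`/`Pass`/`Elem` enums that are not part of this crate's API; forward dtype gating (e.g. fp8) is deliberately not modelled:

```rust
/// Hypothetical labels for the two kernel families in the feature matrix.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Family {
    Fa2,
    Fa3,
}

/// Hypothetical pass selector.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Pass {
    Forward,
    Backward,
}

/// Hypothetical element types; only used to gate the backward pass here.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Elem {
    F16,
    Bf16,
    Fp8,
}

/// Picks a kernel family for an SM version (80, 89, 90, 100).
/// Returns `None` for combinations the matrix marks as unsupported.
pub fn pick_family(sm: u32, elem: Elem, pass: Pass) -> Option<Family> {
    let native = match sm {
        80 | 89 => Family::Fa2,
        90 | 100 => Family::Fa3,
        _ => return None, // architecture outside the matrix
    };
    match pass {
        Pass::Forward => Some(native),
        // Backward exists only in FA2 and only for f16/bf16;
        // FA3 architectures fall back to the FA2 kernels.
        Pass::Backward => match elem {
            Elem::F16 | Elem::Bf16 => Some(Family::Fa2),
            Elem::Fp8 => None,
        },
    }
}
```

Returning `Option` rather than panicking keeps "unsupported cell" a value the caller can surface as a validation error.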
## Architecture
Each request type (`fa2::Fa2FwdRequest`, `fa3::Fa3FwdRequest`,
`varlen::VarlenFwdRequest`, `paged::PagedAttentionRequest` with the
`paged` feature, and `prefill::ChunkedPrefillRequest`) is generic over a
`dispatch::GemmSupported` dtype marker and produces a
`dispatch::DispatchKey`: the canonical `(arch, dtype, head_dim,
causal, varlen, sliding_window, alibi, sink, paged, gqa_ratio)`
tuple that selects an FA2 or FA3 cubin.
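A sketch of how a dtype-generic request can derive a hashable key with those ten cells. `ElemMarker`, `HalfMarker`, `BfloatMarker`, `SketchFwdRequest`, and `SketchKey` are hypothetical stand-ins for `GemmSupported`, the dtype markers, the request types, and `DispatchKey`; the real types' fields and constructors may differ:

```rust
use std::marker::PhantomData;

/// Hypothetical runtime dtype tag (not the crate's `DType`).
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ElemTy {
    F16,
    Bf16,
}

/// Hypothetical marker trait standing in for `dispatch::GemmSupported`.
pub trait ElemMarker {
    const TY: ElemTy;
}
pub struct HalfMarker;
pub struct BfloatMarker;
impl ElemMarker for HalfMarker {
    const TY: ElemTy = ElemTy::F16;
}
impl ElemMarker for BfloatMarker {
    const TY: ElemTy = ElemTy::Bf16;
}

/// Hypothetical key with the ten cells listed in the crate docs.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub struct SketchKey {
    pub arch: u32, // SM version, e.g. 80, 89, 90, 100
    pub dtype: ElemTy,
    pub head_dim: u32,
    pub causal: bool,
    pub varlen: bool,
    pub sliding_window: bool,
    pub alibi: bool,
    pub sink: bool,
    pub paged: bool,
    pub gqa_ratio: u32, // query heads per KV head
}

/// Hypothetical forward request, generic over a dtype marker.
pub struct SketchFwdRequest<T: ElemMarker> {
    pub head_dim: u32,
    pub num_q_heads: u32,
    pub num_kv_heads: u32, // assumed non-zero
    pub causal: bool,
    pub window: Option<(i32, i32)>,
    _dtype: PhantomData<T>,
}

impl<T: ElemMarker> SketchFwdRequest<T> {
    pub fn new(head_dim: u32, num_q_heads: u32, num_kv_heads: u32, causal: bool) -> Self {
        Self { head_dim, num_q_heads, num_kv_heads, causal, window: None, _dtype: PhantomData }
    }

    /// The dtype cell comes from the type parameter; the rest from request fields.
    pub fn dispatch_key(&self, arch: u32) -> SketchKey {
        SketchKey {
            arch,
            dtype: T::TY,
            head_dim: self.head_dim,
            causal: self.causal,
            varlen: false,
            sliding_window: self.window.is_some(),
            alibi: false,
            sink: false,
            paged: false,
            gqa_ratio: self.num_q_heads / self.num_kv_heads,
        }
    }
}
```

Because the dtype is a type parameter, a request for an unsupported element type fails to compile rather than failing at dispatch time.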
At runtime, the dispatch table maps a `dispatch::DispatchKey`
to the canonical mangled kernel-name expression, and the Phase 0.6
NVRTC disk cache compiles the matching cubin lazily on first use.
The steady-state inference hot path is therefore:
`dispatch_key()` → cache hit → cubin launch.
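That hot path amounts to a memoized lookup. A minimal in-memory sketch, assuming a generic key type and caller-supplied closures in place of the dispatch table and NVRTC; `LazyKernelCache`, `Compiled`, and `get_or_compile` are hypothetical names, and a real cache would also persist cubins to disk:

```rust
use std::collections::HashMap;
use std::hash::Hash;

/// Hypothetical stand-in for a compiled module; a real one would hold cubin bytes.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct Compiled {
    pub kernel_name: String,
}

/// Hypothetical lazy cache: compile on the first miss, reuse on every later hit.
pub struct LazyKernelCache<K> {
    entries: HashMap<K, Compiled>,
    pub compiles: usize, // number of cache misses that triggered a compile
}

impl<K: Hash + Eq + Clone> LazyKernelCache<K> {
    pub fn new() -> Self {
        Self { entries: HashMap::new(), compiles: 0 }
    }

    /// `name_for` plays the dispatch table (key -> kernel-name expression);
    /// `compile` plays NVRTC. Both are placeholders supplied by the caller.
    pub fn get_or_compile<N, C>(&mut self, key: &K, name_for: N, compile: C) -> &Compiled
    where
        N: FnOnce(&K) -> String,
        C: FnOnce(&str) -> Compiled,
    {
        if !self.entries.contains_key(key) {
            let name = name_for(key);
            let compiled = compile(&name);
            self.compiles += 1;
            self.entries.insert(key.clone(), compiled);
        }
        &self.entries[key]
    }
}
```

On a hit neither closure runs, so steady-state cost is one hash lookup before the launch.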
## Cargo features
- `fp8`: enables the FA3 fp8 (`F8E4m3`/`F8E5m2`) request types.
- `paged`: enables the `paged` module and the paged dispatch-key cells.
- `cuda-runtime-tests`: gates the real-GPU example and the `Real` actor variant. Off by default so the crate builds and unit-tests on hosts without CUDA.
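Feature gating of this kind is normally expressed with `#[cfg(feature = "...")]` attributes and the `cfg!` macro. A hypothetical sketch (`paged_sketch`, `paged_available`, and `enabled_features` are illustrative names, not items of this crate):

```rust
// Hypothetical: a module that only exists when the `paged` feature is on.
#[cfg(feature = "paged")]
pub mod paged_sketch {
    /// Placeholder item; compiled only with `--features paged`.
    pub fn paged_available() -> bool {
        true
    }
}

/// Lists which of the three documented features this build was compiled with.
pub fn enabled_features() -> Vec<&'static str> {
    let mut on = Vec::new();
    if cfg!(feature = "fp8") {
        on.push("fp8");
    }
    if cfg!(feature = "paged") {
        on.push("paged");
    }
    if cfg!(feature = "cuda-runtime-tests") {
        on.push("cuda-runtime-tests");
    }
    on
}
```

`#[cfg]` removes the item from compilation entirely, while `cfg!` evaluates to a constant `bool` and leaves both branches type-checked.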
## Re-exports
- `pub use actor::FlashAttnActor;`
- `pub use actor::FlashAttnInner;`
- `pub use actor::FlashAttnMsg;`
- `pub use actor::FlashAttnProps;`
- `pub use dispatch::lookup;`
- `pub use dispatch::Bf16;`
- `pub use dispatch::DType;`
- `pub use dispatch::DispatchError;`
- `pub use dispatch::DispatchKey;`
- `pub use dispatch::DispatchTable;`
- `pub use dispatch::FaBwdDispatch;`
- `pub use dispatch::FaFwdDispatch;`
- `pub use dispatch::FaPagedFwdDispatch;`
- `pub use dispatch::GemmSupported;`
- `pub use dispatch::SmArch;`
- `pub use dispatch::DISPATCH_TABLE;`
- `pub use dispatch::F16;`
- `pub use fa2::Fa2BwdRequest;`
- `pub use fa2::Fa2FwdRequest;`
- `pub use fa2::MaskKind;`
- `pub use fa2::PositionBias;`
- `pub use fa3::Fa3FwdRequest;`
- `pub use fa3::PersistentMode;`
- `pub use prefill::ChunkLayout;`
- `pub use prefill::ChunkedPrefillRequest;`
- `pub use varlen::CumulativeSeqlens;`
- `pub use varlen::VarlenFwdRequest;`
## Modules
- `actor`: `FlashAttnActor` receives `FlashAttnMsg`s and dispatches to the NVRTC-compiled FA2/FA3 cubin selected by the request's `crate::dispatch::DispatchKey`.
- `dispatch`: Dispatch table; maps an `(arch, dtype, head_dim, …)` cell onto a mangled kernel-name expression.
- `fa2`: FlashAttention v2 forward and backward request types.
- `fa3`: FlashAttention v3 forward request types for Hopper/Blackwell.
- `prefill`: Chunked-prefill helpers.
- `varlen`: Variable-length attention; packs sequences of different lengths into a single batch tensor with a `cu_seqlens` cumulative offset array.
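The `cu_seqlens` array is an exclusive prefix sum over per-sequence lengths, so sequence `i` occupies rows `cu_seqlens[i]..cu_seqlens[i + 1]` of the packed tensor. A sketch assuming `i32` offsets (the convention of upstream FlashAttention's varlen API); the function names are hypothetical, and the crate's `CumulativeSeqlens` type may wrap this differently:

```rust
use std::ops::Range;

/// Builds a cumulative-offset array of length `lens.len() + 1`, starting at 0.
pub fn cu_seqlens(lens: &[i32]) -> Vec<i32> {
    let mut out = Vec::with_capacity(lens.len() + 1);
    let mut acc = 0i32;
    out.push(acc);
    for &len in lens {
        assert!(len >= 0, "sequence lengths must be non-negative");
        acc = acc.checked_add(len).expect("total token count overflows i32");
        out.push(acc);
    }
    out
}

/// Row range of sequence `i` inside the packed tensor.
pub fn seq_range(cu: &[i32], i: usize) -> Range<usize> {
    cu[i] as usize..cu[i + 1] as usize
}

/// Longest sequence; kernels typically size their tiling grid from this.
pub fn max_seqlen(lens: &[i32]) -> i32 {
    lens.iter().copied().max().unwrap_or(0)
}
```

For lengths `[3, 5, 2]` the offsets are `[0, 3, 8, 10]`: the last entry is the total token count of the packed batch.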
## Enums
- `FlashAttnError`: Errors surfaced by the FlashAttention crate. Most are construction-time validation failures; a small set are runtime launch errors produced by the actor (and kept here so callers can pattern-match without depending on the rest of `atomr-accel-cuda`).
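A hypothetical error shape illustrating that split: construction-time validation variants plus one runtime launch variant, so callers can branch on the category. The variant names and the GQA divisibility check (query heads must be a multiple of KV heads) are illustrative assumptions, not the crate's actual `FlashAttnError` definition:

```rust
use std::fmt;

/// Hypothetical error enum; not the crate's real variants.
#[derive(Debug, PartialEq, Eq)]
pub enum SketchAttnError {
    /// Construction-time validation failure.
    HeadDimUnsupported(u32),
    /// Construction-time validation failure.
    GqaRatioInvalid { q_heads: u32, kv_heads: u32 },
    /// Runtime launch failure reported by the actor.
    Launch(String),
}

impl SketchAttnError {
    /// True for errors raised before any kernel is launched.
    pub fn is_validation(&self) -> bool {
        !matches!(self, SketchAttnError::Launch(_))
    }
}

impl fmt::Display for SketchAttnError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            SketchAttnError::HeadDimUnsupported(d) => write!(f, "unsupported head_dim {d}"),
            SketchAttnError::GqaRatioInvalid { q_heads, kv_heads } => {
                write!(f, "{q_heads} query heads not divisible by {kv_heads} KV heads")
            }
            SketchAttnError::Launch(msg) => write!(f, "kernel launch failed: {msg}"),
        }
    }
}

impl std::error::Error for SketchAttnError {}

/// Returns the GQA ratio, or a validation error if the head counts don't divide.
pub fn validate_heads(q_heads: u32, kv_heads: u32) -> Result<u32, SketchAttnError> {
    if kv_heads == 0 || q_heads % kv_heads != 0 {
        return Err(SketchAttnError::GqaRatioInvalid { q_heads, kv_heads });
    }
    Ok(q_heads / kv_heads)
}
```

Keeping launch failures in the same enum means one `match` covers both categories without pulling in CUDA-specific error types.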