Crate atomr_accel_flashattn

atomr-accel-flashattn

FlashAttention v2 + v3 kernel templates for atomr-accel.

Provides forward and backward attention via NVRTC-compiled CUDA C++ kernels, with the full feature matrix:

Feature              | FA2 (sm_80 / sm_89) | FA3 (sm_90a / sm_100)
---------------------+---------------------+----------------------
f16, bf16            |                     |
fp8 e4m3 / e5m2      |                     | ✔ (fp8)
causal               |                     |
varlen (cu_seqlens)  |                     |
sliding window       |                     |
sink tokens          |                     |
ALiBi                |                     |
MQA / GQA            |                     |
paged KV-cache       | ✔ (paged)           | ✔ (paged)
chunked prefill      |                     |
persistent kernels   |                     |
backward             | ✔ (f16/bf16 only)   | falls through to FA2
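The masking rows of the matrix (causal, sliding window, sink tokens, ALiBi) all reduce to a rule evaluated per (query, key) pair. The sketch below is illustrative only: MaskSketch and alibi_bias are hypothetical names, not the crate's MaskKind or PositionBias; it assumes seqlen_q == seqlen_k (no bottom-right alignment offset) and assumes sink tokens stay visible even outside the sliding window, as in StreamingLLM-style attention sinks.

```rust
/// Hypothetical mask description; not the crate's MaskKind.
#[derive(Clone, Copy, Debug)]
struct MaskSketch {
    /// Forbid attending to keys after the query position.
    causal: bool,
    /// How many positions to the left of the query stay visible (None = unlimited).
    window_left: Option<usize>,
    /// Leading "sink" keys that remain visible regardless of the window.
    sink_tokens: usize,
}

impl MaskSketch {
    /// True if query `q` may attend to key `k` (assumes seqlen_q == seqlen_k).
    fn allows(&self, q: usize, k: usize) -> bool {
        if self.causal && k > q {
            return false; // future key: masked by causality
        }
        if k < self.sink_tokens {
            return true; // sink tokens bypass the sliding window
        }
        match self.window_left {
            // Keys more than `w` positions to the left fall out of the window.
            Some(w) => q.saturating_sub(k) <= w,
            None => true,
        }
    }
}

/// ALiBi: a distance-proportional penalty added to the pre-softmax score.
fn alibi_bias(slope: f32, q: usize, k: usize) -> f32 {
    -slope * (q as f32 - k as f32).abs()
}
```

With causal = true, window_left = Some(2) and one sink token, query 5 sees key 0 (the sink) and keys 3 to 5, but not keys 1 or 2.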

Architecture

Each request type (fa2::Fa2FwdRequest, fa3::Fa3FwdRequest, varlen::VarlenFwdRequest, paged::PagedAttentionRequest (behind the paged feature), prefill::ChunkedPrefillRequest) is generic over a dispatch::GemmSupported dtype marker and produces a dispatch::DispatchKey: the canonical (arch, dtype, head_dim, causal, varlen, sliding_window, alibi, sink, paged, gqa_ratio) tuple that selects either an FA2 or an FA3 cubin.
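The arch and dtype cells of the key are what split work between the two kernel families. A minimal sketch of that decision, transcribed from the feature matrix above; Arch, Dt, Family, Pass and select_family are hypothetical stand-ins for SmArch, the dtype markers and the real dispatch table, not the crate's API.

```rust
/// Hypothetical stand-ins; the crate's SmArch and dtype markers may differ.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Arch { Sm80, Sm89, Sm90a, Sm100 }

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Dt { F16, Bf16, F8E4m3, F8E5m2 }

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Family { Fa2, Fa3 }

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Pass { Forward, Backward }

/// Returns the kernel family for a request, or None if the matrix has no cell for it.
fn select_family(arch: Arch, dt: Dt, pass: Pass) -> Option<Family> {
    let is_fp8 = matches!(dt, Dt::F8E4m3 | Dt::F8E5m2);
    let fa3_arch = matches!(arch, Arch::Sm90a | Arch::Sm100);
    match pass {
        Pass::Backward if is_fp8 => None,    // backward is f16/bf16 only
        Pass::Backward => Some(Family::Fa2), // FA3 backward falls through to FA2
        Pass::Forward if fa3_arch => Some(Family::Fa3),
        Pass::Forward if is_fp8 => None,     // fp8 appears only in the FA3 column
        Pass::Forward => Some(Family::Fa2),
    }
}
```

Returning Option rather than panicking mirrors the idea that an unsupported combination is a dispatch failure the caller can handle.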

At runtime, the dispatch table maps a dispatch::DispatchKey to its canonical mangled kernel-name expression, and the Phase 0.6 NVRTC disk cache compiles the matching cubin lazily. For steady-state inference the hot path is therefore: dispatch_key() → cache hit → cubin launch.
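The steady-state path is essentially a memoized map from key to compiled kernel. The sketch below models it in memory with a HashMap; KeySketch, kernel_name, CubinStub and LazyCache are hypothetical, the kernel-name format is invented, the key is trimmed to four cells, and the NVRTC compile step is replaced by a counter so the example runs without CUDA (the crate's cache also persists to disk, which this does not).

```rust
use std::collections::HashMap;

/// Hypothetical, trimmed-down key; the real DispatchKey has more cells.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
struct KeySketch {
    sm: u32,        // compute capability, e.g. 80 or 90
    head_dim: u32,  // e.g. 64 or 128
    causal: bool,
    gqa_ratio: u32, // num_q_heads / num_kv_heads
}

/// Invented kernel-name expression; not the crate's real mangled name.
fn kernel_name(k: &KeySketch) -> String {
    let family = if k.sm >= 90 { "fa3" } else { "fa2" };
    format!("{family}_fwd<{}, {}, {}>", k.head_dim, k.causal, k.gqa_ratio)
}

/// Stand-in for a compiled cubin handle.
struct CubinStub {
    name: String,
}

/// Compiles each key at most once; later lookups are cache hits.
struct LazyCache {
    entries: HashMap<KeySketch, CubinStub>,
    compiles: usize, // how many times the (stubbed) compiler ran
}

impl LazyCache {
    fn new() -> Self {
        LazyCache { entries: HashMap::new(), compiles: 0 }
    }

    fn get(&mut self, key: KeySketch) -> &CubinStub {
        let compiles = &mut self.compiles;
        self.entries.entry(key).or_insert_with(|| {
            *compiles += 1; // cold path: a real cache would invoke NVRTC here
            CubinStub { name: kernel_name(&key) }
        })
    }
}
```

Looking up the same key twice compiles once; a key with a different arch or head_dim triggers one more compile.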

Cargo features

  • fp8: enables the FA3 fp8 (F8E4m3 / F8E5m2) request types.
  • paged: enables the paged module and the paged dispatch-key cells.
  • cuda-runtime-tests: gates the real-GPU example and the Real actor variant. It is off by default so the crate builds and runs its unit tests on hosts without CUDA.
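Gating of this kind is ordinary conditional compilation: items behind #[cfg(feature = "...")] are simply absent when the feature is off. A minimal sketch, assuming only the feature names listed above; paged_sketch, enabled_features and gpu_tests are hypothetical and not part of the crate.

```rust
// Compiled only when the `paged` feature is enabled; otherwise absent entirely.
#[cfg(feature = "paged")]
pub mod paged_sketch {
    // Placeholder for paged KV-cache request types.
    pub const ENABLED: bool = true;
}

/// Reports which of the documented Cargo features this build enabled.
pub fn enabled_features() -> Vec<&'static str> {
    let mut on = Vec::new();
    if cfg!(feature = "fp8") {
        on.push("fp8");
    }
    if cfg!(feature = "paged") {
        on.push("paged");
    }
    if cfg!(feature = "cuda-runtime-tests") {
        on.push("cuda-runtime-tests");
    }
    on
}

// Only built for `cargo test --features cuda-runtime-tests`, so a default
// test run on a CUDA-less host never references GPU code.
#[cfg(all(test, feature = "cuda-runtime-tests"))]
mod gpu_tests {
    #[test]
    fn launches_on_real_gpu() {
        // A real test would create a CUDA context and launch a cubin here.
    }
}
```

Built with no features, enabled_features() returns an empty list; building with --features paged,fp8 (with those features declared in Cargo.toml) would add both entries.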

Re-exports

pub use actor::FlashAttnActor;
pub use actor::FlashAttnInner;
pub use actor::FlashAttnMsg;
pub use actor::FlashAttnProps;
pub use dispatch::lookup;
pub use dispatch::Bf16;
pub use dispatch::DType;
pub use dispatch::DispatchError;
pub use dispatch::DispatchKey;
pub use dispatch::DispatchTable;
pub use dispatch::FaBwdDispatch;
pub use dispatch::FaFwdDispatch;
pub use dispatch::FaPagedFwdDispatch;
pub use dispatch::GemmSupported;
pub use dispatch::SmArch;
pub use dispatch::DISPATCH_TABLE;
pub use dispatch::F16;
pub use fa2::Fa2BwdRequest;
pub use fa2::Fa2FwdRequest;
pub use fa2::MaskKind;
pub use fa2::PositionBias;
pub use fa3::Fa3FwdRequest;
pub use fa3::PersistentMode;
pub use prefill::ChunkLayout;
pub use prefill::ChunkedPrefillRequest;
pub use varlen::CumulativeSeqlens;
pub use varlen::VarlenFwdRequest;

Modules

actor
FlashAttnActor — receives FlashAttnMsgs and dispatches to the NVRTC-compiled FA2/FA3 cubin selected by the request’s crate::dispatch::DispatchKey.
dispatch
Dispatch table — maps a (arch, dtype, head_dim, …) cell onto a mangled kernel name expression.
fa2
FlashAttention v2 — forward + backward request types.
fa3
FlashAttention v3 — forward request types for Hopper/Blackwell.
prefill
Chunked-prefill helpers.
varlen
Variable-length attention — packs sequences of different lengths into a single batch tensor with a cu_seqlens cumulative offset array.
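A minimal sketch of building the cu_seqlens offsets used by the varlen path. cu_seqlens and seq_rows are hypothetical helpers; the i32 element type and the leading zero follow upstream FlashAttention's convention, and the crate's CumulativeSeqlens may store or validate offsets differently.

```rust
use std::ops::Range;

/// Prefix sum of per-sequence lengths with a leading 0: sequence i occupies
/// rows cu[i]..cu[i + 1] of the packed (total_tokens, heads, head_dim) tensor.
fn cu_seqlens(lens: &[usize]) -> Result<Vec<i32>, String> {
    let mut cu = Vec::with_capacity(lens.len() + 1);
    let mut total: i32 = 0;
    cu.push(total); // sequence 0 starts at row 0
    for &n in lens {
        if n > i32::MAX as usize {
            return Err(format!("length {n} overflows i32"));
        }
        total = total
            .checked_add(n as i32)
            .ok_or_else(|| "total token count overflows i32".to_string())?;
        cu.push(total);
    }
    Ok(cu)
}

/// Row range of sequence `i` inside the packed batch.
fn seq_rows(cu: &[i32], i: usize) -> Range<usize> {
    cu[i] as usize..cu[i + 1] as usize
}
```

For lengths [3, 5, 2] the offsets are [0, 3, 8, 10]: sequence 1 occupies packed rows 3..8, and the last entry is the total token count.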

Enums

FlashAttnError
Errors surfaced by the FlashAttention crate. Most are construction-time validation failures; a small set are runtime launch errors produced by the actor (kept here so callers can pattern-match on them without depending on the rest of atomr-accel-cuda).
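Construction-time validation means shape errors surface before any kernel is compiled or launched. A minimal sketch of two plausible checks, head_dim support and the MQA/GQA head ratio that feeds the gqa_ratio dispatch cell; ShapeError, ShapeSketch and the HEAD_DIMS list are hypothetical and are not FlashAttnError's actual variants.

```rust
/// Hypothetical validation errors; not the crate's FlashAttnError variants.
#[derive(Debug, PartialEq, Eq)]
enum ShapeError {
    UnsupportedHeadDim(usize),
    HeadsNotDivisible { q_heads: usize, kv_heads: usize },
}

/// A validated attention shape.
#[derive(Debug)]
struct ShapeSketch {
    head_dim: usize,
    gqa_ratio: usize, // query heads per KV head: 1 = MHA, q_heads = MQA
}

impl ShapeSketch {
    fn new(q_heads: usize, kv_heads: usize, head_dim: usize) -> Result<Self, ShapeError> {
        // Assumed supported head dims; the crate's actual list may differ.
        const HEAD_DIMS: [usize; 5] = [64, 96, 128, 192, 256];
        if !HEAD_DIMS.contains(&head_dim) {
            return Err(ShapeError::UnsupportedHeadDim(head_dim));
        }
        // Every KV head must serve the same number of query heads.
        if q_heads == 0 || kv_heads == 0 || q_heads % kv_heads != 0 {
            return Err(ShapeError::HeadsNotDivisible { q_heads, kv_heads });
        }
        Ok(ShapeSketch { head_dim, gqa_ratio: q_heads / kv_heads })
    }
}
```

ShapeSketch::new(32, 8, 128) yields a GQA ratio of 4, kv_heads = 1 gives MQA, and a caller can match on the returned variant to report which dimension was rejected.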