Skip to main content

atomr_accel_flashattn/
fa3.rs

1//! FlashAttention v3 — forward request types for Hopper/Blackwell.
2//!
3//! v3 ships the persistent / warp-specialised kernel shapes plus the
4//! fp8 e4m3 / e5m2 paths. Backward is currently shared with FA2 — the
5//! v3 backward kernels in the upstream csrc are not stable enough to
6//! vendor for general use; callers fall through to [`crate::fa2::Fa2BwdRequest`]
7//! against `Sm90a` (which the dispatch layer correctly resolves to the
8//! fa3 cubin via [`crate::dispatch::SmArch::supports_fa3`]).
9//!
//! The fp8 variant (`Fa3FwdFp8Request`) requires feature `fp8`.
11
12use std::marker::PhantomData;
13
14use tokio::sync::oneshot;
15
16use crate::dispatch::{DType, DispatchKey, FaFwdDispatch, GemmSupported, SmArch};
17use crate::fa2::{MaskKind, PositionBias};
18use crate::FlashAttnError;
19
/// How an FA3 forward launch lays out its thread blocks.
///
/// The v3 kernels support two launch shapes: a conventional grid with one
/// block per work tile, or a persistent grid in which a fixed number of
/// blocks each drain tiles from a shared queue. Persistent launches tend to
/// win for short sequence lengths and lose for very long ones, so the choice
/// is left to the caller.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum PersistentMode {
    /// Conventional launch: one block for every (batch, head, q-tile) triple.
    Grid,
    /// Persistent launch: `num_sms` blocks, each pulling work from a tile
    /// queue.
    Persistent {
        /// Number of persistent blocks to launch.
        num_sms: u32,
    },
}
31
32/// Request payload for a FlashAttention v3 forward pass (non-fp8).
33///
34/// Mirrors [`crate::fa2::Fa2FwdRequest`] but with the FA3-only
35/// [`PersistentMode`] knob. Validates that `arch.supports_fa3()`.
36pub struct Fa3FwdRequest<T: GemmSupported> {
37    pub arch: SmArch,
38    pub head_dim: u32,
39    pub gqa_ratio: u32,
40    pub mask: MaskKind,
41    pub bias: PositionBias,
42    pub sink_tokens: u32,
43    pub softmax_scale: f32,
44    pub persistent: PersistentMode,
45    /// FP16 / bfloat16 only at this entry point. Use [`Fa3FwdFp8Request`]
46    /// for the fp8 paths.
47    pub reply: oneshot::Sender<Result<(), FlashAttnError>>,
48    _marker: PhantomData<T>,
49}
50
51impl<T: GemmSupported> Fa3FwdRequest<T> {
52    pub fn new(
53        arch: SmArch,
54        head_dim: u32,
55        gqa_ratio: u32,
56        mask: MaskKind,
57        bias: PositionBias,
58        sink_tokens: u32,
59        softmax_scale: f32,
60        persistent: PersistentMode,
61    ) -> Result<(Self, oneshot::Receiver<Result<(), FlashAttnError>>), FlashAttnError> {
62        if !arch.supports_fa3() {
63            return Err(FlashAttnError::Fa3RequiresHopper(arch));
64        }
65        if T::dtype() == DType::F8E4m3 || T::dtype() == DType::F8E5m2 {
66            return Err(FlashAttnError::Fp8MustUseFp8Request);
67        }
68        let (tx, rx) = oneshot::channel();
69        let req = Self {
70            arch,
71            head_dim,
72            gqa_ratio,
73            mask,
74            bias,
75            sink_tokens,
76            softmax_scale,
77            persistent,
78            reply: tx,
79            _marker: PhantomData,
80        };
81        let key = req.compute_key();
82        key.validate_fwd().map_err(FlashAttnError::Dispatch)?;
83        Ok((req, rx))
84    }
85
86    fn compute_key(&self) -> DispatchKey {
87        DispatchKey {
88            arch: self.arch,
89            dtype: T::dtype(),
90            head_dim: self.head_dim,
91            causal: self.mask.causal(),
92            varlen: false,
93            sliding_window: self.mask.sliding_window(),
94            alibi: self.bias.requires_alibi_flag(),
95            sink: self.sink_tokens,
96            paged: false,
97            gqa_ratio: self.gqa_ratio,
98        }
99    }
100
101    /// True iff this request runs in persistent mode.
102    pub fn is_persistent(&self) -> bool {
103        matches!(self.persistent, PersistentMode::Persistent { .. })
104    }
105}
106
107impl<T: GemmSupported> FaFwdDispatch for Fa3FwdRequest<T> {
108    fn dispatch_key(&self) -> DispatchKey {
109        self.compute_key()
110    }
111}
112
/// Request payload for an FA3 fp8 forward pass.
///
/// Q is stored as the fp8 element type `TQ`. K and V share a second fp8
/// element type `TKV`, which may differ from `TQ` (for example e4m3 Q with
/// e5m2 K/V). Both type parameters must be fp8 types, and
/// [`Fa3FwdFp8Request::new`] rejects anything else.
///
/// Unlike [`Fa3FwdRequest`] there is no positional-bias field, so the
/// dispatch key for this request is always built with `alibi: false`.
#[cfg(feature = "fp8")]
pub struct Fa3FwdFp8Request<TQ: GemmSupported, TKV: GemmSupported> {
    /// Target GPU architecture.
    pub arch: SmArch,
    /// Per-head feature dimension (dispatch key `head_dim`).
    pub head_dim: u32,
    /// Grouped-query-attention ratio (dispatch key `gqa_ratio`).
    pub gqa_ratio: u32,
    /// Attention mask; supplies the `causal` and `sliding_window` dispatch
    /// flags.
    pub mask: MaskKind,
    /// Sink-token count (dispatch key `sink`).
    pub sink_tokens: u32,
    /// Softmax scale factor.
    pub softmax_scale: f32,
    /// Per-tensor descale factor for Q. Required because fp8 storage
    /// precision can't represent the dequantised range without an
    /// out-of-band scale.
    pub q_scale: f32,
    /// Per-tensor descale factor for K.
    pub k_scale: f32,
    /// Per-tensor descale factor for V.
    pub v_scale: f32,
    /// Kernel launch shape.
    pub persistent: PersistentMode,
    /// Completion channel; the matching receiver is returned by
    /// [`Fa3FwdFp8Request::new`].
    pub reply: oneshot::Sender<Result<(), FlashAttnError>>,
    _marker: PhantomData<(TQ, TKV)>,
}
136
#[cfg(feature = "fp8")]
impl<TQ: GemmSupported, TKV: GemmSupported> Fa3FwdFp8Request<TQ, TKV> {
    /// Builds a validated FA3 fp8 forward request plus the receiver for its
    /// reply.
    ///
    /// # Errors
    ///
    /// Checks run in this order; the first failure is returned:
    ///
    /// - [`FlashAttnError::Dispatch`] wrapping
    ///   `DispatchError::Fp8RequiresHopper` when `arch.supports_fp8()` is
    ///   false.
    /// - [`FlashAttnError::Fa3RequiresHopper`] when the architecture supports
    ///   fp8 but cannot run the FA3 kernels (`arch.supports_fa3()` is false).
    /// - [`FlashAttnError::Fp8MustUseFp8Request`] when either `TQ` or `TKV`
    ///   is not an fp8 element type.
    /// - [`FlashAttnError::Dispatch`] when the derived [`DispatchKey`] is
    ///   rejected by `DispatchKey::validate_fwd`.
    // Ten arguments trips `clippy::too_many_arguments`; the list mirrors the
    // public fields one-to-one and is part of the existing constructor API.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        arch: SmArch,
        head_dim: u32,
        gqa_ratio: u32,
        mask: MaskKind,
        sink_tokens: u32,
        softmax_scale: f32,
        q_scale: f32,
        k_scale: f32,
        v_scale: f32,
        persistent: PersistentMode,
    ) -> Result<(Self, oneshot::Receiver<Result<(), FlashAttnError>>), FlashAttnError> {
        if !arch.supports_fp8() {
            return Err(FlashAttnError::Dispatch(
                crate::dispatch::DispatchError::Fp8RequiresHopper(arch),
            ));
        }
        // This is an FA3 entry point: fp8 capability alone is not sufficient,
        // the arch must also be able to run the v3 kernels. This mirrors the
        // guard in `Fa3FwdRequest::new`.
        if !arch.supports_fa3() {
            return Err(FlashAttnError::Fa3RequiresHopper(arch));
        }
        if !TQ::dtype().is_fp8() || !TKV::dtype().is_fp8() {
            return Err(FlashAttnError::Fp8MustUseFp8Request);
        }
        let (tx, rx) = oneshot::channel();
        let req = Self {
            arch,
            head_dim,
            gqa_ratio,
            mask,
            sink_tokens,
            softmax_scale,
            q_scale,
            k_scale,
            v_scale,
            persistent,
            reply: tx,
            _marker: PhantomData,
        };
        // Validate against the Q dtype's cell. KV dtype rides along in
        // the kernel-name expression but doesn't change the dispatch
        // table key shape.
        let key = req.compute_key();
        key.validate_fwd().map_err(FlashAttnError::Dispatch)?;
        Ok((req, rx))
    }

    /// Derives the dispatch key from the current fields, keyed on the Q
    /// element type.
    ///
    /// This request carries no positional bias and never uses varlen or
    /// paged KV, so `alibi`, `varlen` and `paged` are always `false`.
    fn compute_key(&self) -> DispatchKey {
        DispatchKey {
            arch: self.arch,
            dtype: TQ::dtype(),
            head_dim: self.head_dim,
            causal: self.mask.causal(),
            varlen: false,
            sliding_window: self.mask.sliding_window(),
            alibi: false,
            sink: self.sink_tokens,
            paged: false,
            gqa_ratio: self.gqa_ratio,
        }
    }

    /// Convenience: returns `(q_dtype, kv_dtype)` for the kernel-name
    /// suffix the runtime appends.
    #[must_use]
    pub fn fp8_dtypes(&self) -> (DType, DType) {
        (TQ::dtype(), TKV::dtype())
    }
}
203
/// Exposes the fp8 FA3 forward request's dispatch key to the
/// kernel-selection layer.
#[cfg(feature = "fp8")]
impl<TQ, TKV> FaFwdDispatch for Fa3FwdFp8Request<TQ, TKV>
where
    TQ: GemmSupported,
    TKV: GemmSupported,
{
    /// Builds the key from the request's current fields, keyed on the Q
    /// element type only.
    fn dispatch_key(&self) -> DispatchKey {
        Self::compute_key(self)
    }
}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dispatch::{Bf16, F16};

    #[test]
    fn fa3_fwd_request_requires_hopper() {
        // sm_80 must be rejected.
        let err = Fa3FwdRequest::<F16>::new(
            SmArch::Sm80,
            128,
            1,
            MaskKind::Causal,
            PositionBias::None,
            0,
            1.0 / (128f32).sqrt(),
            PersistentMode::Grid,
        )
        .err()
        .expect("expected an error");
        assert!(matches!(err, FlashAttnError::Fa3RequiresHopper(_)));

        // sm_90a is fine.
        let (req, _rx) = Fa3FwdRequest::<Bf16>::new(
            SmArch::Sm90a,
            128,
            8,
            MaskKind::Causal,
            PositionBias::None,
            0,
            1.0 / (128f32).sqrt(),
            PersistentMode::Persistent { num_sms: 132 },
        )
        .expect("fa3 fwd on hopper");
        assert!(req.is_persistent());
        let key = req.dispatch_key();
        assert_eq!(key.arch, SmArch::Sm90a);
        assert_eq!(key.dtype, DType::Bf16);
    }

    #[test]
    fn fa3_fwd_request_grid_mode_is_not_persistent() {
        // Same dispatch cell as the sm_90a case above; the launch shape is
        // not part of the key, so validation must still succeed.
        let (req, _rx) = Fa3FwdRequest::<Bf16>::new(
            SmArch::Sm90a,
            128,
            8,
            MaskKind::Causal,
            PositionBias::None,
            0,
            1.0 / (128f32).sqrt(),
            PersistentMode::Grid,
        )
        .expect("fa3 fwd grid mode on hopper");
        assert!(!req.is_persistent());
        assert_eq!(req.persistent, PersistentMode::Grid);
        let key = req.dispatch_key();
        assert!(key.causal);
        assert_eq!(key.gqa_ratio, 8);
        assert!(!key.varlen);
        assert!(!key.paged);
    }

    #[cfg(feature = "fp8")]
    #[test]
    fn fa3_fwd_request_rejects_fp8_dtypes() {
        use crate::dispatch::{F8E4m3, F8E5m2};

        // fp8 element types must be routed through `Fa3FwdFp8Request`.
        let err = Fa3FwdRequest::<F8E4m3>::new(
            SmArch::Sm90a,
            128,
            1,
            MaskKind::Full,
            PositionBias::None,
            0,
            1.0 / (128f32).sqrt(),
            PersistentMode::Grid,
        )
        .err()
        .expect("expected an error");
        assert!(matches!(err, FlashAttnError::Fp8MustUseFp8Request));

        let err = Fa3FwdRequest::<F8E5m2>::new(
            SmArch::Sm90a,
            128,
            1,
            MaskKind::Full,
            PositionBias::None,
            0,
            1.0 / (128f32).sqrt(),
            PersistentMode::Grid,
        )
        .err()
        .expect("expected an error");
        assert!(matches!(err, FlashAttnError::Fp8MustUseFp8Request));
    }

    #[cfg(feature = "fp8")]
    #[test]
    fn fa3_fwd_fp8_request_round_trip() {
        use crate::dispatch::{F8E4m3, F8E5m2};

        let (req, _rx) = Fa3FwdFp8Request::<F8E4m3, F8E5m2>::new(
            SmArch::Sm90a,
            128,
            8,
            MaskKind::Causal,
            0,
            1.0 / (128f32).sqrt(),
            1.0,
            1.0,
            1.0,
            PersistentMode::Persistent { num_sms: 132 },
        )
        .expect("fp8 fwd on hopper");
        let key = req.dispatch_key();
        assert_eq!(key.arch, SmArch::Sm90a);
        assert_eq!(key.dtype, DType::F8E4m3);
        assert!(key.causal);
        assert_eq!(key.head_dim, 128);
        assert_eq!(key.gqa_ratio, 8);

        let (q_t, kv_t) = req.fp8_dtypes();
        assert_eq!(q_t, DType::F8E4m3);
        assert_eq!(kv_t, DType::F8E5m2);

        // Non-fp8 marker types must be rejected.
        let err = Fa3FwdFp8Request::<F16, F8E5m2>::new(
            SmArch::Sm90a,
            128,
            1,
            MaskKind::Full,
            0,
            1.0 / (128f32).sqrt(),
            1.0,
            1.0,
            1.0,
            PersistentMode::Grid,
        )
        .err()
        .expect("expected an error");
        assert!(matches!(err, FlashAttnError::Fp8MustUseFp8Request));

        // Non-Hopper must be rejected.
        let err = Fa3FwdFp8Request::<F8E4m3, F8E4m3>::new(
            SmArch::Sm80,
            128,
            1,
            MaskKind::Full,
            0,
            1.0 / (128f32).sqrt(),
            1.0,
            1.0,
            1.0,
            PersistentMode::Grid,
        )
        .err()
        .expect("expected an error");
        assert!(matches!(
            err,
            FlashAttnError::Dispatch(crate::dispatch::DispatchError::Fp8RequiresHopper(_))
        ));
    }
}