1use std::marker::PhantomData;
13
14use tokio::sync::oneshot;
15
16use crate::dispatch::{DType, DispatchKey, FaFwdDispatch, GemmSupported, SmArch};
17use crate::fa2::{MaskKind, PositionBias};
18use crate::FlashAttnError;
19
/// Launch strategy requested for the FA3 forward kernel.
///
/// There are exactly two modes; `Fa3FwdRequest::is_persistent` tells them apart.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PersistentMode {
    /// Non-persistent launch (presumably one ordinary grid over all tiles — confirm with the
    /// launcher).
    Grid,
    /// Persistent launch sized by `num_sms`.
    ///
    /// NOTE(review): the tests pass 132 (the SM count of an H100 SXM). Nothing in this module
    /// rejects `num_sms == 0`; confirm the launcher guards against it.
    Persistent { num_sms: u32 },
}
31
32pub struct Fa3FwdRequest<T: GemmSupported> {
37 pub arch: SmArch,
38 pub head_dim: u32,
39 pub gqa_ratio: u32,
40 pub mask: MaskKind,
41 pub bias: PositionBias,
42 pub sink_tokens: u32,
43 pub softmax_scale: f32,
44 pub persistent: PersistentMode,
45 pub reply: oneshot::Sender<Result<(), FlashAttnError>>,
48 _marker: PhantomData<T>,
49}
50
51impl<T: GemmSupported> Fa3FwdRequest<T> {
52 pub fn new(
53 arch: SmArch,
54 head_dim: u32,
55 gqa_ratio: u32,
56 mask: MaskKind,
57 bias: PositionBias,
58 sink_tokens: u32,
59 softmax_scale: f32,
60 persistent: PersistentMode,
61 ) -> Result<(Self, oneshot::Receiver<Result<(), FlashAttnError>>), FlashAttnError> {
62 if !arch.supports_fa3() {
63 return Err(FlashAttnError::Fa3RequiresHopper(arch));
64 }
65 if T::dtype() == DType::F8E4m3 || T::dtype() == DType::F8E5m2 {
66 return Err(FlashAttnError::Fp8MustUseFp8Request);
67 }
68 let (tx, rx) = oneshot::channel();
69 let req = Self {
70 arch,
71 head_dim,
72 gqa_ratio,
73 mask,
74 bias,
75 sink_tokens,
76 softmax_scale,
77 persistent,
78 reply: tx,
79 _marker: PhantomData,
80 };
81 let key = req.compute_key();
82 key.validate_fwd().map_err(FlashAttnError::Dispatch)?;
83 Ok((req, rx))
84 }
85
86 fn compute_key(&self) -> DispatchKey {
87 DispatchKey {
88 arch: self.arch,
89 dtype: T::dtype(),
90 head_dim: self.head_dim,
91 causal: self.mask.causal(),
92 varlen: false,
93 sliding_window: self.mask.sliding_window(),
94 alibi: self.bias.requires_alibi_flag(),
95 sink: self.sink_tokens,
96 paged: false,
97 gqa_ratio: self.gqa_ratio,
98 }
99 }
100
101 pub fn is_persistent(&self) -> bool {
103 matches!(self.persistent, PersistentMode::Persistent { .. })
104 }
105}
106
107impl<T: GemmSupported> FaFwdDispatch for Fa3FwdRequest<T> {
108 fn dispatch_key(&self) -> DispatchKey {
109 self.compute_key()
110 }
111}
112
/// FA3 forward request for FP8 inputs.
///
/// Q uses element type `TQ`; K and V use `TKV`. Compared with [`Fa3FwdRequest`], it has no
/// positional-bias field and carries three per-tensor scale factors.
#[cfg(feature = "fp8")]
pub struct Fa3FwdFp8Request<TQ: GemmSupported, TKV: GemmSupported> {
    /// Target GPU architecture; `new` rejects architectures for which `supports_fp8()` is false.
    pub arch: SmArch,
    /// Attention head dimension, copied into `DispatchKey::head_dim`.
    pub head_dim: u32,
    /// Grouped-query-attention ratio (presumably query heads per K/V head).
    pub gqa_ratio: u32,
    /// Masking scheme; supplies the `causal` and `sliding_window` dispatch fields.
    pub mask: MaskKind,
    /// Sink-token count, forwarded as `DispatchKey::sink`.
    pub sink_tokens: u32,
    /// Softmax scaling factor (the tests pass `1 / sqrt(head_dim)`).
    pub softmax_scale: f32,
    /// Scale factor for Q; presumably the FP8 dequantization multiplier. Confirm against the
    /// kernel.
    pub q_scale: f32,
    /// Scale factor for K; presumably the FP8 dequantization multiplier. Confirm against the
    /// kernel.
    pub k_scale: f32,
    /// Scale factor for V; presumably the FP8 dequantization multiplier. Confirm against the
    /// kernel.
    pub v_scale: f32,
    /// Requested kernel launch strategy.
    pub persistent: PersistentMode,
    /// Sender half of the channel created by `new`. Presumably the executor reports the
    /// outcome here.
    pub reply: oneshot::Sender<Result<(), FlashAttnError>>,
    /// Ties `TQ` and `TKV` to the request without storing values of them.
    _marker: PhantomData<(TQ, TKV)>,
}
136
#[cfg(feature = "fp8")]
impl<TQ: GemmSupported, TKV: GemmSupported> Fa3FwdFp8Request<TQ, TKV> {
    /// Creates a validated FP8 FA3 forward request together with the receiver for its reply.
    ///
    /// The checks run in this order:
    /// 1. architecture FP8 support;
    /// 2. both element types being FP8;
    /// 3. the derived [`DispatchKey`], via `validate_fwd`.
    ///
    /// # Errors
    ///
    /// * `FlashAttnError::Dispatch(DispatchError::Fp8RequiresHopper(arch))` if
    ///   `arch.supports_fp8()` is false.
    /// * [`FlashAttnError::Fp8MustUseFp8Request`] if either `TQ` or `TKV` is not an FP8
    ///   dtype.
    /// * [`FlashAttnError::Dispatch`] if `validate_fwd` rejects the derived key.
    // The positional constructor mirrors the public fields one-to-one, hence the arity.
    #[allow(clippy::too_many_arguments)]
    pub fn new(
        arch: SmArch,
        head_dim: u32,
        gqa_ratio: u32,
        mask: MaskKind,
        sink_tokens: u32,
        softmax_scale: f32,
        q_scale: f32,
        k_scale: f32,
        v_scale: f32,
        persistent: PersistentMode,
    ) -> Result<(Self, oneshot::Receiver<Result<(), FlashAttnError>>), FlashAttnError> {
        if !arch.supports_fp8() {
            return Err(FlashAttnError::Dispatch(
                crate::dispatch::DispatchError::Fp8RequiresHopper(arch),
            ));
        }
        // Both the Q element type and the K/V element type have to be FP8.
        if !(TQ::dtype().is_fp8() && TKV::dtype().is_fp8()) {
            return Err(FlashAttnError::Fp8MustUseFp8Request);
        }

        let (reply, receiver) = oneshot::channel();
        let request = Self {
            arch,
            head_dim,
            gqa_ratio,
            mask,
            sink_tokens,
            softmax_scale,
            q_scale,
            k_scale,
            v_scale,
            persistent,
            reply,
            _marker: PhantomData,
        };

        // NOTE(review): the key records only `TQ`'s dtype, so `validate_fwd` never sees
        // `TKV`. Confirm that every FP8 Q/KV pairing that passes the check above is supported.
        request
            .compute_key()
            .validate_fwd()
            .map_err(FlashAttnError::Dispatch)?;
        Ok((request, receiver))
    }

    /// Derives the dispatch key from the current field values.
    ///
    /// * `dtype` is taken from `TQ`.
    /// * `alibi` is always `false`, because this request has no bias field.
    /// * `varlen` and `paged` are always `false`.
    fn compute_key(&self) -> DispatchKey {
        DispatchKey {
            arch: self.arch,
            dtype: TQ::dtype(),
            head_dim: self.head_dim,
            causal: self.mask.causal(),
            varlen: false,
            sliding_window: self.mask.sliding_window(),
            alibi: false,
            sink: self.sink_tokens,
            paged: false,
            gqa_ratio: self.gqa_ratio,
        }
    }

    /// Returns the `(Q, K/V)` element dtypes of this request.
    pub fn fp8_dtypes(&self) -> (DType, DType) {
        let q_dtype = TQ::dtype();
        let kv_dtype = TKV::dtype();
        (q_dtype, kv_dtype)
    }
}
203
#[cfg(feature = "fp8")]
impl<TQ: GemmSupported, TKV: GemmSupported> FaFwdDispatch for Fa3FwdFp8Request<TQ, TKV> {
    /// Recomputes the dispatch key from the request's current field values.
    ///
    /// NOTE(review): the fields are `pub`, so a caller that mutates them after `new` can make
    /// this differ from the key `new` validated.
    fn dispatch_key(&self) -> DispatchKey {
        Self::compute_key(self)
    }
}
210
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dispatch::{Bf16, F16};

    /// FA3 needs an FA3-capable arch.
    ///
    /// On such an arch, the dispatch key must mirror the request fields, and `varlen`/`paged`
    /// must be forced off.
    #[test]
    fn fa3_fwd_request_requires_hopper() {
        let err = Fa3FwdRequest::<F16>::new(
            SmArch::Sm80,
            128,
            1,
            MaskKind::Causal,
            PositionBias::None,
            0,
            1.0 / (128f32).sqrt(),
            PersistentMode::Grid,
        )
        .err()
        .expect("expected an error");
        // The error must carry the rejected architecture back to the caller.
        assert!(matches!(
            err,
            FlashAttnError::Fa3RequiresHopper(arch) if arch == SmArch::Sm80
        ));

        let (req, _rx) = Fa3FwdRequest::<Bf16>::new(
            SmArch::Sm90a,
            128,
            8,
            MaskKind::Causal,
            PositionBias::None,
            0,
            1.0 / (128f32).sqrt(),
            PersistentMode::Persistent { num_sms: 132 },
        )
        .expect("fa3 fwd on hopper");
        assert!(req.is_persistent());
        let key = req.dispatch_key();
        assert_eq!(key.arch, SmArch::Sm90a);
        assert_eq!(key.dtype, DType::Bf16);
        assert!(key.causal);
        assert_eq!(key.head_dim, 128);
        assert_eq!(key.gqa_ratio, 8);
        assert_eq!(key.sink, 0);
        assert!(!key.varlen);
        assert!(!key.paged);
    }

    /// The plain FA3 request must refuse FP8 element types, even on an FA3-capable arch.
    #[cfg(feature = "fp8")]
    #[test]
    fn fa3_fwd_request_rejects_fp8_dtypes() {
        use crate::dispatch::{F8E4m3, F8E5m2};

        let err = Fa3FwdRequest::<F8E4m3>::new(
            SmArch::Sm90a,
            128,
            1,
            MaskKind::Full,
            PositionBias::None,
            0,
            1.0 / (128f32).sqrt(),
            PersistentMode::Grid,
        )
        .err()
        .expect("expected an error");
        assert!(matches!(err, FlashAttnError::Fp8MustUseFp8Request));

        let err = Fa3FwdRequest::<F8E5m2>::new(
            SmArch::Sm90a,
            128,
            1,
            MaskKind::Full,
            PositionBias::None,
            0,
            1.0 / (128f32).sqrt(),
            PersistentMode::Grid,
        )
        .err()
        .expect("expected an error");
        assert!(matches!(err, FlashAttnError::Fp8MustUseFp8Request));
    }

    /// FP8 request round trip, plus its two rejection paths.
    ///
    /// The rejections are: a non-FP8 element type, and a non-FP8-capable arch.
    #[cfg(feature = "fp8")]
    #[test]
    fn fa3_fwd_fp8_request_round_trip() {
        use crate::dispatch::{F8E4m3, F8E5m2};

        let (req, _rx) = Fa3FwdFp8Request::<F8E4m3, F8E5m2>::new(
            SmArch::Sm90a,
            128,
            8,
            MaskKind::Causal,
            0,
            1.0 / (128f32).sqrt(),
            1.0,
            1.0,
            1.0,
            PersistentMode::Persistent { num_sms: 132 },
        )
        .expect("fp8 fwd on hopper");
        let key = req.dispatch_key();
        assert_eq!(key.arch, SmArch::Sm90a);
        assert_eq!(key.dtype, DType::F8E4m3);
        assert!(key.causal);
        assert_eq!(key.head_dim, 128);
        assert_eq!(key.gqa_ratio, 8);
        assert!(!key.alibi);
        assert!(!key.varlen);
        assert!(!key.paged);

        let (q_t, kv_t) = req.fp8_dtypes();
        assert_eq!(q_t, DType::F8E4m3);
        assert_eq!(kv_t, DType::F8E5m2);

        let err = Fa3FwdFp8Request::<F16, F8E5m2>::new(
            SmArch::Sm90a,
            128,
            1,
            MaskKind::Full,
            0,
            1.0 / (128f32).sqrt(),
            1.0,
            1.0,
            1.0,
            PersistentMode::Grid,
        )
        .err()
        .expect("expected an error");
        assert!(matches!(err, FlashAttnError::Fp8MustUseFp8Request));

        let err = Fa3FwdFp8Request::<F8E4m3, F8E4m3>::new(
            SmArch::Sm80,
            128,
            1,
            MaskKind::Full,
            0,
            1.0 / (128f32).sqrt(),
            1.0,
            1.0,
            1.0,
            PersistentMode::Grid,
        )
        .err()
        .expect("expected an error");
        assert!(matches!(
            err,
            FlashAttnError::Dispatch(crate::dispatch::DispatchError::Fp8RequiresHopper(_))
        ));
    }
}