use std::marker::PhantomData;
use tokio::sync::oneshot;
use crate::dispatch::{DType, DispatchKey, FaFwdDispatch, GemmSupported, SmArch};
use crate::fa2::{MaskKind, PositionBias};
use crate::FlashAttnError;
/// Launch mode selector carried by the FA3 forward request types.
///
/// The variants are only stored and pattern-matched in this file; how each
/// one is turned into an actual kernel launch is decided elsewhere.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum PersistentMode {
    /// Non-persistent mode (presumably a conventional grid launch — confirm
    /// against the launcher).
    Grid,
    /// Persistent mode.
    Persistent {
        /// SM count associated with the persistent launch (presumably the
        /// number of SMs the kernel should occupy — confirm against the
        /// launcher).
        num_sms: u32,
    },
}
/// FA3 forward request for a non-FP8 element type `T`.
///
/// Instances are built through `Fa3FwdRequest::new`, which validates the
/// architecture, the element type and the resulting dispatch key before
/// handing the request back together with the receiving half of `reply`.
pub struct Fa3FwdRequest<T>
where
    T: GemmSupported,
{
    /// Target architecture; copied into the dispatch key.
    pub arch: SmArch,
    /// Head dimension; forwarded to the dispatch key unchanged.
    pub head_dim: u32,
    /// GQA ratio; forwarded to the dispatch key unchanged (presumably query
    /// heads per KV head — confirm).
    pub gqa_ratio: u32,
    /// Mask description; supplies the `causal` and `sliding_window` entries
    /// of the dispatch key.
    pub mask: MaskKind,
    /// Position bias; only its `requires_alibi_flag()` result reaches the
    /// dispatch key.
    pub bias: PositionBias,
    /// Sink token count; forwarded to the dispatch key as `sink`.
    pub sink_tokens: u32,
    /// Softmax scale, stored as given. Not read anywhere in this file
    /// (presumably applied to the QK^T logits by the kernel — confirm).
    pub softmax_scale: f32,
    /// Launch mode; queried by `is_persistent`.
    pub persistent: PersistentMode,
    /// Sending half of the oneshot channel created by the constructor
    /// (presumably used by the executor to report the outcome — confirm).
    pub reply: oneshot::Sender<Result<(), FlashAttnError>>,
    /// Ties the request to its element type without storing a value.
    _marker: PhantomData<T>,
}
impl<T> Fa3FwdRequest<T>
where
    T: GemmSupported,
{
    /// Builds a validated FA3 forward request and the receiver paired with
    /// its `reply` sender.
    ///
    /// # Errors
    ///
    /// * `FlashAttnError::Fa3RequiresHopper` when `arch.supports_fa3()` is
    ///   false.
    /// * `FlashAttnError::Fp8MustUseFp8Request` when `T::dtype()` is
    ///   `DType::F8E4m3` or `DType::F8E5m2`.
    /// * `FlashAttnError::Dispatch` wrapping whatever
    ///   `DispatchKey::validate_fwd` rejects.
    pub fn new(
        arch: SmArch,
        head_dim: u32,
        gqa_ratio: u32,
        mask: MaskKind,
        bias: PositionBias,
        sink_tokens: u32,
        softmax_scale: f32,
        persistent: PersistentMode,
    ) -> Result<(Self, oneshot::Receiver<Result<(), FlashAttnError>>), FlashAttnError> {
        // Architecture gate comes first so unsupported hardware is reported
        // before any dtype complaint.
        if !arch.supports_fa3() {
            return Err(FlashAttnError::Fa3RequiresHopper(arch));
        }

        // FP8 element types are refused here with a dedicated error.
        if [DType::F8E4m3, DType::F8E5m2].contains(&T::dtype()) {
            return Err(FlashAttnError::Fp8MustUseFp8Request);
        }

        let (reply, receiver) = oneshot::channel();
        let request = Self {
            arch,
            head_dim,
            gqa_ratio,
            mask,
            bias,
            sink_tokens,
            softmax_scale,
            persistent,
            reply,
            _marker: PhantomData,
        };

        // The key is derived from the assembled request so validation sees
        // exactly what `dispatch_key` will later report.
        request
            .compute_key()
            .validate_fwd()
            .map_err(FlashAttnError::Dispatch)?;

        Ok((request, receiver))
    }

    /// Derives the dispatch key for this request.
    ///
    /// `varlen` and `paged` are always `false` for this request type.
    fn compute_key(&self) -> DispatchKey {
        let causal = self.mask.causal();
        let sliding_window = self.mask.sliding_window();
        let alibi = self.bias.requires_alibi_flag();

        DispatchKey {
            arch: self.arch,
            dtype: T::dtype(),
            head_dim: self.head_dim,
            causal,
            varlen: false,
            sliding_window,
            alibi,
            sink: self.sink_tokens,
            paged: false,
            gqa_ratio: self.gqa_ratio,
        }
    }

    /// Returns `true` when the request was built with
    /// `PersistentMode::Persistent`.
    pub fn is_persistent(&self) -> bool {
        match self.persistent {
            PersistentMode::Persistent { .. } => true,
            PersistentMode::Grid => false,
        }
    }
}
impl<T> FaFwdDispatch for Fa3FwdRequest<T>
where
    T: GemmSupported,
{
    /// Exposes the key produced by the private `compute_key` helper.
    fn dispatch_key(&self) -> DispatchKey {
        Self::compute_key(self)
    }
}
/// FA3 forward request whose query (`TQ`) and key/value (`TKV`) element
/// types are both FP8. Only compiled with the `fp8` feature.
///
/// Instances are built through `Fa3FwdFp8Request::new`, which validates the
/// architecture, both element types and the resulting dispatch key.
#[cfg(feature = "fp8")]
pub struct Fa3FwdFp8Request<TQ, TKV>
where
    TQ: GemmSupported,
    TKV: GemmSupported,
{
    /// Target architecture; copied into the dispatch key.
    pub arch: SmArch,
    /// Head dimension; forwarded to the dispatch key unchanged.
    pub head_dim: u32,
    /// GQA ratio; forwarded to the dispatch key unchanged (presumably query
    /// heads per KV head — confirm).
    pub gqa_ratio: u32,
    /// Mask description; supplies the `causal` and `sliding_window` entries
    /// of the dispatch key.
    pub mask: MaskKind,
    /// Sink token count; forwarded to the dispatch key as `sink`.
    pub sink_tokens: u32,
    /// Softmax scale, stored as given. Not read anywhere in this file
    /// (presumably applied to the QK^T logits by the kernel — confirm).
    pub softmax_scale: f32,
    /// Scale associated with Q, stored as given and not read in this file
    /// (presumably a dequantization factor — confirm).
    pub q_scale: f32,
    /// Scale associated with K, stored as given and not read in this file
    /// (presumably a dequantization factor — confirm).
    pub k_scale: f32,
    /// Scale associated with V, stored as given and not read in this file
    /// (presumably a dequantization factor — confirm).
    pub v_scale: f32,
    /// Launch mode.
    pub persistent: PersistentMode,
    /// Sending half of the oneshot channel created by the constructor
    /// (presumably used by the executor to report the outcome — confirm).
    pub reply: oneshot::Sender<Result<(), FlashAttnError>>,
    /// Ties the request to its two element types without storing values.
    _marker: PhantomData<(TQ, TKV)>,
}
#[cfg(feature = "fp8")]
impl<TQ, TKV> Fa3FwdFp8Request<TQ, TKV>
where
    TQ: GemmSupported,
    TKV: GemmSupported,
{
    /// Builds a validated FP8 FA3 forward request and the receiver paired
    /// with its `reply` sender.
    ///
    /// # Errors
    ///
    /// * `FlashAttnError::Dispatch(DispatchError::Fp8RequiresHopper)` when
    ///   `arch.supports_fp8()` is false.
    /// * `FlashAttnError::Fp8MustUseFp8Request` when either `TQ::dtype()` or
    ///   `TKV::dtype()` is not FP8 according to `is_fp8()`.
    /// * `FlashAttnError::Dispatch` wrapping whatever
    ///   `DispatchKey::validate_fwd` rejects.
    ///
    /// NOTE(review): unlike `Fa3FwdRequest::new`, this constructor never
    /// calls `supports_fa3`; presumably `supports_fp8` implies it — confirm.
    pub fn new(
        arch: SmArch,
        head_dim: u32,
        gqa_ratio: u32,
        mask: MaskKind,
        sink_tokens: u32,
        softmax_scale: f32,
        q_scale: f32,
        k_scale: f32,
        v_scale: f32,
        persistent: PersistentMode,
    ) -> Result<(Self, oneshot::Receiver<Result<(), FlashAttnError>>), FlashAttnError> {
        // Architecture gate comes first so unsupported hardware is reported
        // before any dtype complaint.
        if !arch.supports_fp8() {
            let cause = crate::dispatch::DispatchError::Fp8RequiresHopper(arch);
            return Err(FlashAttnError::Dispatch(cause));
        }

        // Both operand types have to be FP8; the check on TKV is skipped
        // when TQ already fails.
        let both_fp8 = TQ::dtype().is_fp8() && TKV::dtype().is_fp8();
        if !both_fp8 {
            return Err(FlashAttnError::Fp8MustUseFp8Request);
        }

        let (reply, receiver) = oneshot::channel();
        let request = Self {
            arch,
            head_dim,
            gqa_ratio,
            mask,
            sink_tokens,
            softmax_scale,
            q_scale,
            k_scale,
            v_scale,
            persistent,
            reply,
            _marker: PhantomData,
        };

        // The key is derived from the assembled request so validation sees
        // exactly what `dispatch_key` will later report.
        request
            .compute_key()
            .validate_fwd()
            .map_err(FlashAttnError::Dispatch)?;

        Ok((request, receiver))
    }

    /// Derives the dispatch key for this request.
    ///
    /// Only `TQ`'s dtype is recorded in the key; `alibi`, `varlen` and
    /// `paged` are always `false` for this request type.
    fn compute_key(&self) -> DispatchKey {
        let causal = self.mask.causal();
        let sliding_window = self.mask.sliding_window();

        DispatchKey {
            arch: self.arch,
            dtype: TQ::dtype(),
            head_dim: self.head_dim,
            causal,
            varlen: false,
            sliding_window,
            alibi: false,
            sink: self.sink_tokens,
            paged: false,
            gqa_ratio: self.gqa_ratio,
        }
    }

    /// Returns the `(TQ, TKV)` dtypes, in that order.
    pub fn fp8_dtypes(&self) -> (DType, DType) {
        let query_dtype = TQ::dtype();
        let kv_dtype = TKV::dtype();
        (query_dtype, kv_dtype)
    }
}
#[cfg(feature = "fp8")]
impl<TQ, TKV> FaFwdDispatch for Fa3FwdFp8Request<TQ, TKV>
where
    TQ: GemmSupported,
    TKV: GemmSupported,
{
    /// Exposes the key produced by the private `compute_key` helper.
    fn dispatch_key(&self) -> DispatchKey {
        Self::compute_key(self)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dispatch::{Bf16, F16};

    /// An `Sm80` request is refused with `Fa3RequiresHopper`, while an
    /// `Sm90a` request builds and reports the expected key.
    #[test]
    fn fa3_fwd_request_requires_hopper() {
        let head_dim = 128;
        let softmax_scale = 1.0 / 128f32.sqrt();

        // The request types do not implement `Debug`, so the error is pulled
        // out with an explicit match rather than `expect_err`.
        let rejected = Fa3FwdRequest::<F16>::new(
            SmArch::Sm80,
            head_dim,
            1,
            MaskKind::Causal,
            PositionBias::None,
            0,
            softmax_scale,
            PersistentMode::Grid,
        );
        let err = match rejected {
            Err(e) => e,
            Ok(_) => panic!("expected an error"),
        };
        assert!(matches!(err, FlashAttnError::Fa3RequiresHopper(_)));

        let accepted = Fa3FwdRequest::<Bf16>::new(
            SmArch::Sm90a,
            head_dim,
            8,
            MaskKind::Causal,
            PositionBias::None,
            0,
            softmax_scale,
            PersistentMode::Persistent { num_sms: 132 },
        );
        // Keep the receiver bound so the channel stays open for the test.
        let (req, _receiver) = accepted.expect("fa3 fwd on hopper");
        assert!(req.is_persistent());

        let key = req.dispatch_key();
        assert_eq!(key.arch, SmArch::Sm90a);
        assert_eq!(key.dtype, DType::Bf16);
    }

    /// Covers the FP8 happy path plus the dtype and architecture rejections.
    #[cfg(feature = "fp8")]
    #[test]
    fn fa3_fwd_fp8_request_round_trip() {
        use crate::dispatch::{F8E4m3, F8E5m2};

        let head_dim = 128;
        let softmax_scale = 1.0 / 128f32.sqrt();
        let unit_scale = 1.0;

        // Happy path: mixed FP8 flavours on `Sm90a`.
        let accepted = Fa3FwdFp8Request::<F8E4m3, F8E5m2>::new(
            SmArch::Sm90a,
            head_dim,
            8,
            MaskKind::Causal,
            0,
            softmax_scale,
            unit_scale,
            unit_scale,
            unit_scale,
            PersistentMode::Persistent { num_sms: 132 },
        );
        let (req, _receiver) = accepted.expect("fp8 fwd on hopper");

        let key = req.dispatch_key();
        assert_eq!(key.arch, SmArch::Sm90a);
        assert_eq!(key.dtype, DType::F8E4m3);
        assert!(key.causal);
        assert_eq!(key.head_dim, 128);
        assert_eq!(key.gqa_ratio, 8);

        let (q_t, kv_t) = req.fp8_dtypes();
        assert_eq!(q_t, DType::F8E4m3);
        assert_eq!(kv_t, DType::F8E5m2);

        // A non-FP8 query type is refused.
        let wrong_dtype = Fa3FwdFp8Request::<F16, F8E5m2>::new(
            SmArch::Sm90a,
            head_dim,
            1,
            MaskKind::Full,
            0,
            softmax_scale,
            unit_scale,
            unit_scale,
            unit_scale,
            PersistentMode::Grid,
        );
        let err = match wrong_dtype {
            Err(e) => e,
            Ok(_) => panic!("expected an error"),
        };
        assert!(matches!(err, FlashAttnError::Fp8MustUseFp8Request));

        // `Sm80` is refused even with valid FP8 types.
        let wrong_arch = Fa3FwdFp8Request::<F8E4m3, F8E4m3>::new(
            SmArch::Sm80,
            head_dim,
            1,
            MaskKind::Full,
            0,
            softmax_scale,
            unit_scale,
            unit_scale,
            unit_scale,
            PersistentMode::Grid,
        );
        let err = match wrong_arch {
            Err(e) => e,
            Ok(_) => panic!("expected an error"),
        };
        assert!(matches!(
            err,
            FlashAttnError::Dispatch(crate::dispatch::DispatchError::Fp8RequiresHopper(_))
        ));
    }
}