use core::ffi::c_void;
use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_types::{
ArchSku, AttentionKind, BackendKind, Element, ElementKind, KernelSku, MathPrecision,
OpCategory, PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
};
use super::flash_sdpa::FLASH_SDPA_MAX_D;
use super::map_status;
/// Crate-internal record of which kernel family a plan dispatches to.
///
/// Converted to the public [`BackendKind`] through `as_public`.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum BackendChoice {
/// Dispatches to the `baracuda_kernels_flash_sdpa_backward_*_run` entry points.
Bespoke,
/// Dispatches to the `baracuda_kernels_fa2_sdpa_backward_*_run` entry points.
/// Only exists when the `fa2` feature is enabled.
#[cfg(feature = "fa2")]
FlashAttentionV2,
}
impl BackendChoice {
    /// Maps the internal selector onto the matching public [`BackendKind`] variant.
    fn as_public(self) -> BackendKind {
        let public = match self {
            Self::Bespoke => BackendKind::Bespoke,
            #[cfg(feature = "fa2")]
            Self::FlashAttentionV2 => BackendKind::FlashAttentionV2,
        };
        public
    }
}
/// Head dimensions for which the FA2 backward backend is considered eligible.
/// Consulted through `fa2_bw_supports_head_dim`.
#[cfg(feature = "fa2")]
const FA2_BW_SUPPORTED_HEAD_DIMS: &[i32] = &[32, 64, 96, 128, 192, 256];
/// Returns `true` when `d` is one of the entries of `FA2_BW_SUPPORTED_HEAD_DIMS`.
#[cfg(feature = "fa2")]
#[inline]
fn fa2_bw_supports_head_dim(d: i32) -> bool {
    FA2_BW_SUPPORTED_HEAD_DIMS.contains(&d)
}
/// Decides whether the FA2 backward backend can serve `desc` with `num_heads_k`
/// key/value heads.
///
/// All of the following must hold:
/// * `d_k` is a supported FA2 head dimension and equals `d_v`;
/// * the descriptor element is `F16` or `Bf16`;
/// * `num_heads_k` is positive, not larger than `num_heads`, and divides it.
#[cfg(feature = "fa2")]
fn should_use_fa2_bw(desc: &FlashSdpaBackwardDescriptor, num_heads_k: i32) -> bool {
    let head_dim_ok = fa2_bw_supports_head_dim(desc.d_k) && desc.d_k == desc.d_v;
    let dtype_ok = matches!(desc.element, ElementKind::F16 | ElementKind::Bf16);
    // `num_heads_k > 0` is tested first so the remainder never divides by zero.
    let heads_ok = num_heads_k > 0
        && num_heads_k <= desc.num_heads
        && desc.num_heads % num_heads_k == 0;
    head_dim_ok && dtype_ok && heads_ok
}
/// Problem description for the flash SDPA backward pass.
///
/// Extents are validated as non-negative by `FlashSdpaBackwardPlan::select`,
/// and tensor shapes are later checked against them by `can_implement`.
#[derive(Copy, Clone, Debug)]
#[non_exhaustive]
pub struct FlashSdpaBackwardDescriptor {
/// Leading extent of every tensor argument.
pub batch_size: i32,
/// Head count of Q / y / dy / dQ (second extent of those tensors).
pub num_heads: i32,
/// Third extent of Q / y / dy / dQ and last extent of the LSE tensors.
pub query_len: i32,
/// Third extent of K / V / dK / dV.
pub key_len: i32,
/// Last extent of Q / K. Plan selection requires `d_k == d_v`.
pub d_k: i32,
/// Last extent of V / y / dy.
pub d_v: i32,
/// Scale forwarded to the kernel; must be finite.
pub scale: f32,
/// Forwarded to the kernel as a 0/1 causal flag.
pub is_causal: bool,
/// Operand element kind; must equal the plan's `T::KIND`.
pub element: ElementKind,
/// Left window size; `None` is passed to FA2 as `-1`. `Some` requires the FA2 backend.
pub window_size_left: Option<i32>,
/// Right window size; `None` is passed to FA2 as `-1`. `Some` requires the FA2 backend.
pub window_size_right: Option<i32>,
/// Softcap value; must be finite and non-negative. Non-zero requires the FA2 backend.
pub softcap: f32,
}
impl FlashSdpaBackwardDescriptor {
#[allow(clippy::too_many_arguments)]
#[inline]
pub fn new(
batch_size: i32,
num_heads: i32,
query_len: i32,
key_len: i32,
d_k: i32,
d_v: i32,
scale: f32,
is_causal: bool,
element: ElementKind,
) -> Self {
Self {
batch_size,
num_heads,
query_len,
key_len,
d_k,
d_v,
scale,
is_causal,
element,
window_size_left: None,
window_size_right: None,
softcap: 0.0,
}
}
#[inline]
pub fn with_window_size_left(mut self, n: Option<i32>) -> Self {
self.window_size_left = n;
self
}
#[inline]
pub fn with_window_size_right(mut self, n: Option<i32>) -> Self {
self.window_size_right = n;
self
}
#[inline]
pub fn with_softcap(mut self, cap: f32) -> Self {
self.softcap = cap;
self
}
}
/// Tensor arguments for one backward launch.
///
/// Shapes below are the ones enforced by `FlashSdpaBackwardPlan::can_implement`,
/// where `B = batch_size`, `H = num_heads`, `Hk = k.shape[1]`, `Q = query_len`,
/// `K = key_len`. All non-optional tensors must be contiguous.
pub struct FlashSdpaBackwardArgs<'a, T: Element> {
/// Query, `[B, H, Q, d_k]`.
pub q: TensorRef<'a, T, 4>,
/// Key, `[B, Hk, K, d_k]`; `Hk` must divide `H`.
pub k: TensorRef<'a, T, 4>,
/// Value, `[B, Hk, K, d_v]`.
pub v: TensorRef<'a, T, 4>,
/// Forward output, `[B, H, Q, d_v]`.
pub y: TensorRef<'a, T, 4>,
/// `[B, H, Q]` tensor in the operand dtype; passed to the bespoke kernels only.
/// Presumably the forward log-sum-exp — confirm against the forward plan.
pub lse: TensorRef<'a, T, 3>,
/// Upstream gradient, same shape as `y`.
pub dy: TensorRef<'a, T, 4>,
/// Mutable `[B, H, Q]` buffer passed to the bespoke kernels only.
/// NOTE(review): looks like kernel scratch space — confirm.
pub d_ws: TensorMut<'a, T, 3>,
/// Output gradient for `q`, same shape as `q`.
pub dq: TensorMut<'a, T, 4>,
/// Output gradient for `k`, same shape as `k`.
pub dk: TensorMut<'a, T, 4>,
/// Output gradient for `v`, same shape as `v`.
pub dv: TensorMut<'a, T, 4>,
/// `[B, H, Q]` f32 LSE; mandatory (and contiguous) when the plan uses the FA2 backend.
pub lse_f32: Option<TensorRef<'a, f32, 3>>,
/// Optional ALiBi slopes, `[1 or B, H]`; only accepted on the FA2 backend.
pub alibi_slopes: Option<TensorRef<'a, f32, 2>>,
}
/// A validated backward-pass plan for element type `T`, produced by `select`.
pub struct FlashSdpaBackwardPlan<T: Element> {
// Copy of the descriptor the plan was selected for.
desc: FlashSdpaBackwardDescriptor,
// SKU record assembled during selection (includes the precision guarantee).
sku: KernelSku,
// Kernel family chosen by `pick_backend`.
backend: BackendChoice,
// Ties the plan to `T` without storing a value of that type.
_marker: PhantomData<T>,
}
impl<T: Element> FlashSdpaBackwardPlan<T> {
/// Validates `desc`, picks a backend according to `pref`, and builds the plan.
///
/// # Errors
///
/// * `Error::Unsupported` — descriptor element differs from `T`, `d_k != d_v`,
///   `T` is not one of f32 / f16 / bf16 / f64, the bespoke backend was chosen
///   with `d_k` above `FLASH_SDPA_MAX_D`, or sliding window / softcap was
///   requested without the FA2 backend.
/// * `Error::InvalidProblem` — a negative extent, a non-finite scale, or a
///   softcap that is negative or non-finite.
pub fn select(
    _stream: &Stream,
    desc: &FlashSdpaBackwardDescriptor,
    pref: PlanPreference,
) -> Result<Self> {
    if desc.element != T::KIND {
        return Err(Error::Unsupported(
            "baracuda-kernels::FlashSdpaBackwardPlan: descriptor element != T",
        ));
    }

    let extents = [
        desc.batch_size,
        desc.num_heads,
        desc.query_len,
        desc.key_len,
        desc.d_k,
        desc.d_v,
    ];
    if extents.iter().any(|&extent| extent < 0) {
        return Err(Error::InvalidProblem(
            "baracuda-kernels::FlashSdpaBackwardPlan: extents must be non-negative",
        ));
    }

    if !desc.scale.is_finite() {
        return Err(Error::InvalidProblem(
            "baracuda-kernels::FlashSdpaBackwardPlan: scale must be finite",
        ));
    }

    if desc.d_k != desc.d_v {
        return Err(Error::Unsupported(
            "baracuda-kernels::FlashSdpaBackwardPlan: requires d_k == d_v",
        ));
    }

    match T::KIND {
        ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16 | ElementKind::F64 => {}
        _ => {
            return Err(Error::Unsupported(
                "baracuda-kernels::FlashSdpaBackwardPlan: wired today: `{f32, f16, bf16, f64}`",
            ));
        }
    }

    // Written as a negated conjunction so NaN (which fails `>=`) is rejected too.
    let softcap_ok = desc.softcap >= 0.0 && desc.softcap.is_finite();
    if !softcap_ok {
        return Err(Error::InvalidProblem(
            "baracuda-kernels::FlashSdpaBackwardPlan: softcap must be finite and non-negative",
        ));
    }

    let backend = pick_backend::<T>(desc, pref);
    // `BackendChoice` has only the bespoke variant plus (with `fa2`) the FA2
    // variant, so "not bespoke" is exactly "FA2".
    let on_bespoke = matches!(backend, BackendChoice::Bespoke);

    if on_bespoke {
        if desc.d_k > FLASH_SDPA_MAX_D {
            return Err(Error::Unsupported(
                "baracuda-kernels::FlashSdpaBackwardPlan: bespoke kernel requires d_k ≤ 128 \
                 (enable `fa2` feature for d_k > 128)",
            ));
        }
        let wants_window = desc.window_size_left.is_some() || desc.window_size_right.is_some();
        if wants_window {
            return Err(Error::Unsupported(
                "baracuda-kernels::FlashSdpaBackwardPlan: sliding window requires the FA2 backend",
            ));
        }
        if desc.softcap != 0.0 {
            return Err(Error::Unsupported(
                "baracuda-kernels::FlashSdpaBackwardPlan: softcap requires the FA2 backend",
            ));
        }
    }

    // Only the bespoke backend is advertised as bit-stable / deterministic.
    let sku = KernelSku {
        category: OpCategory::Attention,
        op: AttentionKind::FlashAttention as u16,
        element: T::KIND,
        aux_element: None,
        layout: None,
        epilogue: None,
        arch: ArchSku::Sm80,
        backend: backend.as_public(),
        precision_guarantee: PrecisionGuarantee {
            math_precision: MathPrecision::F32,
            accumulator: ElementKind::F32,
            bit_stable_on_same_hardware: on_bespoke,
            deterministic: on_bespoke,
        },
    };

    Ok(Self {
        desc: *desc,
        sku,
        backend,
        _marker: PhantomData,
    })
}
/// Returns the public [`BackendKind`] corresponding to the backend this plan selected.
#[inline]
pub fn backend(&self) -> BackendKind {
self.backend.as_public()
}
pub fn can_implement(&self, args: &FlashSdpaBackwardArgs<'_, T>) -> Result<()> {
let num_heads_k = args.k.shape[1];
if num_heads_k <= 0
|| num_heads_k > self.desc.num_heads
|| self.desc.num_heads % num_heads_k != 0
{
return Err(Error::InvalidProblem(
"baracuda-kernels::FlashSdpaBackwardPlan: K shape[1] (num_heads_k) must divide num_heads",
));
}
let is_gqa = num_heads_k != self.desc.num_heads;
#[cfg(feature = "fa2")]
let backend_is_fa2 = matches!(self.backend, BackendChoice::FlashAttentionV2);
#[cfg(not(feature = "fa2"))]
let backend_is_fa2 = false;
if is_gqa && !backend_is_fa2 {
return Err(Error::Unsupported(
"baracuda-kernels::FlashSdpaBackwardPlan: GQA on the bespoke backend is unsupported",
));
}
let shape_q = [
self.desc.batch_size,
self.desc.num_heads,
self.desc.query_len,
self.desc.d_k,
];
let shape_k = [
self.desc.batch_size,
num_heads_k,
self.desc.key_len,
self.desc.d_k,
];
let shape_v = [
self.desc.batch_size,
num_heads_k,
self.desc.key_len,
self.desc.d_v,
];
let shape_y = [
self.desc.batch_size,
self.desc.num_heads,
self.desc.query_len,
self.desc.d_v,
];
let shape_lse = [
self.desc.batch_size,
self.desc.num_heads,
self.desc.query_len,
];
if args.q.shape != shape_q {
return Err(Error::InvalidProblem(
"baracuda-kernels::FlashSdpaBackwardPlan: Q shape mismatch",
));
}
if args.k.shape != shape_k {
return Err(Error::InvalidProblem(
"baracuda-kernels::FlashSdpaBackwardPlan: K shape mismatch",
));
}
if args.v.shape != shape_v {
return Err(Error::InvalidProblem(
"baracuda-kernels::FlashSdpaBackwardPlan: V shape mismatch",
));
}
if args.y.shape != shape_y {
return Err(Error::InvalidProblem(
"baracuda-kernels::FlashSdpaBackwardPlan: y shape mismatch",
));
}
if args.lse.shape != shape_lse {
return Err(Error::InvalidProblem(
"baracuda-kernels::FlashSdpaBackwardPlan: lse shape must be [B, H, Q]",
));
}
if args.dy.shape != shape_y {
return Err(Error::InvalidProblem(
"baracuda-kernels::FlashSdpaBackwardPlan: dy shape mismatch",
));
}
if args.d_ws.shape != shape_lse {
return Err(Error::InvalidProblem(
"baracuda-kernels::FlashSdpaBackwardPlan: d_ws shape must be [B, H, Q]",
));
}
if args.dq.shape != shape_q {
return Err(Error::InvalidProblem(
"baracuda-kernels::FlashSdpaBackwardPlan: dQ shape mismatch with Q",
));
}
if args.dk.shape != shape_k {
return Err(Error::InvalidProblem(
"baracuda-kernels::FlashSdpaBackwardPlan: dK shape mismatch with K",
));
}
if args.dv.shape != shape_v {
return Err(Error::InvalidProblem(
"baracuda-kernels::FlashSdpaBackwardPlan: dV shape mismatch with V",
));
}
if !args.q.is_contiguous()
|| !args.k.is_contiguous()
|| !args.v.is_contiguous()
|| !args.y.is_contiguous()
|| !args.lse.is_contiguous()
|| !args.dy.is_contiguous()
|| !args.d_ws.is_contiguous()
|| !args.dq.is_contiguous()
|| !args.dk.is_contiguous()
|| !args.dv.is_contiguous()
{
return Err(Error::Unsupported(
"baracuda-kernels::FlashSdpaBackwardPlan: requires contiguous tensors",
));
}
if backend_is_fa2 {
if args.lse_f32.is_none() {
return Err(Error::InvalidProblem(
"baracuda-kernels::FlashSdpaBackwardPlan: FA2 backend requires lse_f32 \
(FA2 stores LSE in f32 regardless of operand dtype)",
));
}
let lse_f32 = args.lse_f32.as_ref().unwrap();
if lse_f32.shape != shape_lse {
return Err(Error::InvalidProblem(
"baracuda-kernels::FlashSdpaBackwardPlan: lse_f32 shape must be [B, H, Q]",
));
}
if !lse_f32.is_contiguous() {
return Err(Error::Unsupported(
"baracuda-kernels::FlashSdpaBackwardPlan: lse_f32 must be contiguous",
));
}
if let Some(slopes) = args.alibi_slopes.as_ref() {
if slopes.shape[1] != self.desc.num_heads {
return Err(Error::InvalidProblem(
"baracuda-kernels::FlashSdpaBackwardPlan: alibi_slopes shape[1] must equal num_heads",
));
}
if slopes.shape[0] != 1 && slopes.shape[0] != self.desc.batch_size {
return Err(Error::InvalidProblem(
"baracuda-kernels::FlashSdpaBackwardPlan: alibi_slopes shape[0] must be 1 or batch_size",
));
}
}
} else if args.alibi_slopes.is_some() {
return Err(Error::Unsupported(
"baracuda-kernels::FlashSdpaBackwardPlan: ALiBi requires the FA2 backend",
));
}
Ok(())
}
/// Returns the workspace size this plan needs.
///
/// The bespoke backend needs none (`0`). For FA2 the size is obtained from the
/// sys crate using `batch_size`, `num_heads`, `query_len` and `d_k`.
/// NOTE(review): the unit is presumably bytes (the launcher compares it with
/// the workspace slice length) — confirm against the sys crate.
#[inline]
pub fn workspace_size(&self) -> usize {
match self.backend {
BackendChoice::Bespoke => 0,
#[cfg(feature = "fa2")]
// SAFETY: only plain integer extents from the descriptor are passed.
// Assumes the FFI query has no further preconditions — TODO confirm.
BackendChoice::FlashAttentionV2 => unsafe {
baracuda_kernels_sys::baracuda_kernels_fa2_sdpa_backward_workspace_size(
self.desc.batch_size,
self.desc.num_heads,
self.desc.query_len,
self.desc.d_k,
)
},
}
}
/// Returns the [`KernelSku`] assembled when the plan was selected.
#[inline]
pub fn sku(&self) -> KernelSku {
self.sku
}
/// Returns the [`PrecisionGuarantee`] stored in this plan's SKU.
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
self.sku.precision_guarantee
}
/// Validates `args` and launches the backward kernel on `stream`.
///
/// * `args` are first checked with `can_implement`.
/// * If `q` or `k` has zero elements the call returns `Ok(())` without
///   launching anything (output tensors are left untouched).
/// * An FA2 plan dispatches to the FA2 launcher unless the stream reports that
///   it is capturing; a failed capture query is treated as "not capturing".
/// * While capturing, an FA2 plan falls back to the bespoke kernel. That is
///   only done when the bespoke kernel can execute the same problem; otherwise
///   `Error::Unsupported` is returned instead of launching a kernel that would
///   ignore GQA / sliding window / softcap / ALiBi or exceed its head-dim limit.
/// * The bespoke path ignores `workspace`.
///
/// # Errors
///
/// Any error from `can_implement`, `Error::Unsupported` for an incompatible
/// capture fallback or an unwired dtype, or the mapped kernel status.
pub fn run(
    &self,
    stream: &Stream,
    workspace: Workspace<'_>,
    args: FlashSdpaBackwardArgs<'_, T>,
) -> Result<()> {
    self.can_implement(&args)?;
    if args.q.numel() == 0 || args.k.numel() == 0 {
        return Ok(());
    }

    #[cfg(feature = "fa2")]
    if matches!(self.backend, BackendChoice::FlashAttentionV2) {
        let capturing = stream.is_capturing().unwrap_or(false);
        if !capturing {
            return self.run_fa2_bw(stream, workspace, &args);
        }
        // Capture fallback: the bespoke launch below takes no head-count,
        // window, softcap or ALiBi arguments, so it can stand in for FA2 only
        // when none of those features are in play and `d_k` is within its limit.
        let bespoke_compatible = self.desc.d_k <= FLASH_SDPA_MAX_D
            && args.k.shape[1] == self.desc.num_heads
            && self.desc.window_size_left.is_none()
            && self.desc.window_size_right.is_none()
            && self.desc.softcap == 0.0
            && args.alibi_slopes.is_none();
        if !bespoke_compatible {
            return Err(Error::Unsupported(
                "baracuda-kernels::FlashSdpaBackwardPlan: FA2 plan cannot fall back to the bespoke \
                 kernel during stream capture (GQA, sliding window, softcap, ALiBi, or d_k above \
                 the bespoke limit)",
            ));
        }
    }

    // The bespoke kernels take no workspace.
    let _ = workspace;

    let desc = &self.desc;
    let stream_ptr = stream.as_raw() as *mut c_void;
    let q_ptr = args.q.data.as_raw().0 as *const c_void;
    let k_ptr = args.k.data.as_raw().0 as *const c_void;
    let v_ptr = args.v.data.as_raw().0 as *const c_void;
    let y_ptr = args.y.data.as_raw().0 as *const c_void;
    let lse_ptr = args.lse.data.as_raw().0 as *const c_void;
    let dy_ptr = args.dy.data.as_raw().0 as *const c_void;
    let d_ws_ptr = args.d_ws.data.as_raw().0 as *mut c_void;
    let dq_ptr = args.dq.data.as_raw().0 as *mut c_void;
    let dk_ptr = args.dk.data.as_raw().0 as *mut c_void;
    let dv_ptr = args.dv.data.as_raw().0 as *mut c_void;
    let is_causal_flag = if desc.is_causal { 1 } else { 0 };

    // SAFETY (all arms): every pointer comes from a tensor whose shape and
    // contiguity were validated by `can_implement` above, and the extents
    // passed are the descriptor values those shapes were checked against.
    // Assumes the sys entry points accept a null workspace of size 0 — this
    // is what the original code passed.
    let status = match T::KIND {
        ElementKind::F32 => unsafe {
            baracuda_kernels_sys::baracuda_kernels_flash_sdpa_backward_f32_run(
                desc.batch_size, desc.num_heads,
                desc.query_len, desc.key_len,
                desc.d_k, desc.d_v, desc.scale, is_causal_flag,
                q_ptr, k_ptr, v_ptr, y_ptr, lse_ptr, dy_ptr, d_ws_ptr,
                dq_ptr, dk_ptr, dv_ptr,
                core::ptr::null_mut(), 0, stream_ptr,
            )
        },
        ElementKind::F16 => unsafe {
            baracuda_kernels_sys::baracuda_kernels_flash_sdpa_backward_f16_run(
                desc.batch_size, desc.num_heads,
                desc.query_len, desc.key_len,
                desc.d_k, desc.d_v, desc.scale, is_causal_flag,
                q_ptr, k_ptr, v_ptr, y_ptr, lse_ptr, dy_ptr, d_ws_ptr,
                dq_ptr, dk_ptr, dv_ptr,
                core::ptr::null_mut(), 0, stream_ptr,
            )
        },
        ElementKind::Bf16 => unsafe {
            baracuda_kernels_sys::baracuda_kernels_flash_sdpa_backward_bf16_run(
                desc.batch_size, desc.num_heads,
                desc.query_len, desc.key_len,
                desc.d_k, desc.d_v, desc.scale, is_causal_flag,
                q_ptr, k_ptr, v_ptr, y_ptr, lse_ptr, dy_ptr, d_ws_ptr,
                dq_ptr, dk_ptr, dv_ptr,
                core::ptr::null_mut(), 0, stream_ptr,
            )
        },
        ElementKind::F64 => unsafe {
            baracuda_kernels_sys::baracuda_kernels_flash_sdpa_backward_f64_run(
                desc.batch_size, desc.num_heads,
                desc.query_len, desc.key_len,
                desc.d_k, desc.d_v, desc.scale, is_causal_flag,
                q_ptr, k_ptr, v_ptr, y_ptr, lse_ptr, dy_ptr, d_ws_ptr,
                dq_ptr, dk_ptr, dv_ptr,
                core::ptr::null_mut(), 0, stream_ptr,
            )
        },
        _ => {
            return Err(Error::Unsupported(
                "baracuda-kernels::FlashSdpaBackwardPlan::run reached an unimplemented dtype",
            ));
        }
    };
    map_status(status)
}
/// Launches the FA2 backward kernel for an already-validated argument set.
///
/// Reads the f32 LSE from `args.lse_f32` (panics if it is absent — `run`
/// validates this through `can_implement` before calling). Unset window sizes
/// are passed as `-1`. The ALiBi batch stride is `0` when the slopes have a
/// leading extent of 1 and `num_heads` otherwise.
///
/// # Errors
///
/// `Error::WorkspaceTooSmall` when `workspace` is smaller than
/// `workspace_size()`, `Error::Unsupported` for element kinds other than
/// f16 / bf16, or the mapped kernel status.
#[cfg(feature = "fa2")]
fn run_fa2_bw(
    &self,
    stream: &Stream,
    workspace: Workspace<'_>,
    args: &FlashSdpaBackwardArgs<'_, T>,
) -> Result<()> {
    let desc = &self.desc;

    let stream_ptr = stream.as_raw() as *mut c_void;

    // Read-only operands.
    let q_ptr = args.q.data.as_raw().0 as *const c_void;
    let k_ptr = args.k.data.as_raw().0 as *const c_void;
    let v_ptr = args.v.data.as_raw().0 as *const c_void;
    let y_ptr = args.y.data.as_raw().0 as *const c_void;
    let dy_ptr = args.dy.data.as_raw().0 as *const c_void;

    // Gradient outputs.
    let dq_ptr = args.dq.data.as_raw().0 as *mut c_void;
    let dk_ptr = args.dk.data.as_raw().0 as *mut c_void;
    let dv_ptr = args.dv.data.as_raw().0 as *mut c_void;

    let lse_ptr = args.lse_f32.as_ref().unwrap().data.as_raw().0 as *const c_void;

    let is_causal_flag = if desc.is_causal { 1 } else { 0 };
    let num_heads_k = args.k.shape[1];

    let (alibi_ptr, alibi_batch_stride) = args.alibi_slopes.as_ref().map_or(
        (core::ptr::null::<c_void>(), 0i32),
        |slopes| {
            // A single row of slopes is shared by every batch entry.
            let stride = if slopes.shape[0] == 1 {
                0_i32
            } else {
                desc.num_heads
            };
            (slopes.data.as_raw().0 as *const c_void, stride)
        },
    );

    let window_left = match desc.window_size_left {
        Some(n) => n,
        None => -1,
    };
    let window_right = match desc.window_size_right {
        Some(n) => n,
        None => -1,
    };
    let softcap = desc.softcap;

    let required = self.workspace_size();
    let (ws_ptr, ws_bytes) = match workspace {
        Workspace::None if required == 0 => (core::ptr::null_mut::<c_void>(), 0usize),
        Workspace::None => {
            return Err(Error::WorkspaceTooSmall { needed: required, got: 0 });
        }
        Workspace::Borrowed(buf) if buf.len() >= required => {
            (buf.as_raw().0 as *mut c_void, buf.len())
        }
        Workspace::Borrowed(buf) => {
            return Err(Error::WorkspaceTooSmall {
                needed: required,
                got: buf.len(),
            });
        }
    };

    let status = match T::KIND {
        ElementKind::Bf16 => unsafe {
            baracuda_kernels_sys::baracuda_kernels_fa2_sdpa_backward_bf16_run(
                desc.batch_size, desc.num_heads, num_heads_k,
                desc.query_len, desc.key_len, desc.d_k,
                desc.scale, is_causal_flag,
                alibi_ptr, alibi_batch_stride,
                window_left, window_right, softcap,
                q_ptr, k_ptr, v_ptr, y_ptr, dy_ptr, lse_ptr,
                dq_ptr, dk_ptr, dv_ptr,
                ws_ptr, ws_bytes, stream_ptr,
            )
        },
        ElementKind::F16 => unsafe {
            baracuda_kernels_sys::baracuda_kernels_fa2_sdpa_backward_f16_run(
                desc.batch_size, desc.num_heads, num_heads_k,
                desc.query_len, desc.key_len, desc.d_k,
                desc.scale, is_causal_flag,
                alibi_ptr, alibi_batch_stride,
                window_left, window_right, softcap,
                q_ptr, k_ptr, v_ptr, y_ptr, dy_ptr, lse_ptr,
                dq_ptr, dk_ptr, dv_ptr,
                ws_ptr, ws_bytes, stream_ptr,
            )
        },
        _ => {
            return Err(Error::Unsupported(
                "baracuda-kernels::FlashSdpaBackwardPlan::run_fa2_bw: FA2 BW supports only f16 / bf16",
            ));
        }
    };
    map_status(status)
}
}
/// Chooses the backend for `desc` given the caller's preference.
///
/// * An explicit bespoke preference always yields the bespoke backend.
/// * With the `fa2` feature, an explicit FA2 preference yields FA2 only when
///   `fa2_bw_is_eligible` accepts the problem; otherwise bespoke is used.
/// * Any other preference picks FA2 when `should_use_fa2_bw` accepts the
///   problem (evaluated with `num_heads_k == num_heads`), else bespoke.
/// * Without the `fa2` feature the result is always bespoke.
fn pick_backend<T: Element>(
    #[cfg_attr(not(feature = "fa2"), allow(unused_variables))] desc: &FlashSdpaBackwardDescriptor,
    pref: PlanPreference,
) -> BackendChoice {
    #[cfg(feature = "fa2")]
    {
        let fa2_selected = match pref.prefer_backend {
            Some(BackendKind::Bespoke) => false,
            Some(BackendKind::FlashAttentionV2) => fa2_bw_is_eligible::<T>(desc),
            _ => should_use_fa2_bw(desc, desc.num_heads),
        };
        if fa2_selected {
            return BackendChoice::FlashAttentionV2;
        }
    }
    // Without FA2 compiled in, the preference cannot change the outcome.
    #[cfg(not(feature = "fa2"))]
    let _ = pref;

    BackendChoice::Bespoke
}
/// Returns `true` when FA2 backward can serve `desc` for element type `T`:
/// a supported head dimension, `d_k == d_v`, and `T` being f16 or bf16.
#[cfg(feature = "fa2")]
fn fa2_bw_is_eligible<T: Element>(desc: &FlashSdpaBackwardDescriptor) -> bool {
    if !fa2_bw_supports_head_dim(desc.d_k) {
        return false;
    }
    if desc.d_k != desc.d_v {
        return false;
    }
    matches!(T::KIND, ElementKind::F16 | ElementKind::Bf16)
}