use core::ffi::c_void;
use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_types::{
ArchSku, AttentionKind, BackendKind, Element, ElementKind, KernelSku, MathPrecision,
OpCategory, PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
};
use super::map_status;
/// Largest `head_dim` accepted by [`FlashDecodingPlan::select`]; larger values are
/// rejected with `Error::Unsupported`.
pub const FLASH_DECODING_MAX_D: i32 = 128;
/// Number of keys per split; `num_splits` divides `k_len` into `ceil(k_len / CHUNK_K)`
/// splits, which in turn sizes the workspace in `FlashDecodingPlan::workspace_size`.
/// NOTE(review): presumably must match the chunk size used by the native kernel — confirm.
const CHUNK_K: i32 = 256;
/// Problem description for flash-decoding attention.
///
/// One query vector per `(batch, head)` (the `q` tensor has no sequence axis) attends
/// over `k_len` key/value rows. Validated by [`FlashDecodingPlan::select`].
#[derive(Copy, Clone, Debug)]
pub struct FlashDecodingDescriptor {
    /// Number of batch entries (`B`).
    pub batch_size: i32,
    /// Number of query heads (`H_q`).
    pub num_heads: i32,
    /// Number of key/value heads (`H_kv`); `select` requires it to divide `num_heads`.
    pub num_kv_heads: i32,
    /// Number of key/value positions; `select` allows 0.
    pub k_len: i32,
    /// Per-head feature width (`D`); `select` caps it at `FLASH_DECODING_MAX_D`.
    pub head_dim: i32,
    /// Scale factor forwarded to the kernel (presumably applied to `q·k` before the
    /// softmax — confirm). The constructors default it to `1 / sqrt(head_dim)`.
    pub scale: f32,
    /// Element type of `q`/`k`/`v`/`y`; `select` requires it to equal the plan's `T`.
    pub element: ElementKind,
}

impl FlashDecodingDescriptor {
    /// Builds a descriptor with one K/V head per query head
    /// (`num_kv_heads == num_heads`) and `scale = 1 / sqrt(head_dim)`.
    #[inline]
    pub fn new(
        batch_size: i32,
        num_heads: i32,
        k_len: i32,
        head_dim: i32,
        element: ElementKind,
    ) -> Self {
        Self::new_gqa(batch_size, num_heads, num_heads, k_len, head_dim, element)
    }

    /// Builds a descriptor with an explicit K/V head count (grouped-query layout)
    /// and `scale = 1 / sqrt(head_dim)`.
    #[inline]
    pub fn new_gqa(
        batch_size: i32,
        num_heads: i32,
        num_kv_heads: i32,
        k_len: i32,
        head_dim: i32,
        element: ElementKind,
    ) -> Self {
        Self {
            batch_size,
            num_heads,
            num_kv_heads,
            k_len,
            head_dim,
            scale: Self::default_scale(head_dim),
            element,
        }
    }

    /// Returns a copy of `self` with `scale` overridden.
    #[inline]
    pub fn with_scale(self, scale: f32) -> Self {
        Self { scale, ..self }
    }

    /// Query heads per K/V head (`num_heads / num_kv_heads`), or 0 when
    /// `num_kv_heads` is 0.
    #[inline]
    pub fn group_size(&self) -> i32 {
        match self.num_kv_heads {
            0 => 0,
            kv_heads => self.num_heads / kv_heads,
        }
    }

    /// Default scale used by both constructors: `1 / sqrt(head_dim)`.
    #[inline]
    fn default_scale(head_dim: i32) -> f32 {
        (head_dim as f32).sqrt().recip()
    }
}
/// Tensors consumed by one [`FlashDecodingPlan::run`] call.
///
/// Shapes are validated against the plan's descriptor by
/// [`FlashDecodingPlan::can_implement`]. `run` forwards every stride except the
/// innermost (`head_dim`) one to the kernel, so that axis is presumably assumed to be
/// contiguous — TODO confirm.
pub struct FlashDecodingArgs<'a, T: Element> {
    /// Queries, shape `[batch_size, num_heads, head_dim]`.
    pub q: TensorRef<'a, T, 3>,
    /// Keys, shape `[batch_size, num_kv_heads, k_len, head_dim]`.
    pub k: TensorRef<'a, T, 4>,
    /// Values, shape `[batch_size, num_kv_heads, k_len, head_dim]`.
    pub v: TensorRef<'a, T, 4>,
    /// Output, shape `[batch_size, num_heads, head_dim]`; passed to the kernel as the
    /// mutable output pointer.
    pub y: TensorMut<'a, T, 3>,
}
/// A validated flash-decoding launch configuration for element type `T`.
///
/// The fields are private, so code outside this module can only obtain a plan through
/// [`FlashDecodingPlan::select`], which performs the descriptor checks.
pub struct FlashDecodingPlan<T: Element> {
    /// Descriptor copied in at `select` time.
    desc: FlashDecodingDescriptor,
    /// Kernel identity / precision metadata returned by `sku()`.
    sku: KernelSku,
    /// Ties the plan to `T` without storing a value of it.
    _marker: PhantomData<T>,
}
impl<T: Element> FlashDecodingPlan<T> {
    /// Validates `desc` and builds a plan for element type `T`.
    ///
    /// `_stream` and `_pref` are accepted for API uniformity but are not consulted.
    ///
    /// # Errors
    ///
    /// - [`Error::Unsupported`] if `desc.element != T::KIND`, if `desc.head_dim` exceeds
    ///   [`FLASH_DECODING_MAX_D`], or if `T` is neither `f16` nor `bf16`.
    /// - [`Error::InvalidProblem`] if `batch_size`, `num_heads`, `num_kv_heads` or
    ///   `head_dim` is not positive, if `k_len` is negative, if `num_heads` is not a
    ///   multiple of `num_kv_heads`, or if `scale` is NaN or infinite.
    pub fn select(
        _stream: &Stream,
        desc: &FlashDecodingDescriptor,
        _pref: PlanPreference,
    ) -> Result<Self> {
        if desc.element != T::KIND {
            return Err(Error::Unsupported(
                "baracuda-kernels::FlashDecodingPlan: descriptor element != T",
            ));
        }
        if desc.batch_size <= 0
            || desc.num_heads <= 0
            || desc.num_kv_heads <= 0
            || desc.k_len < 0
            || desc.head_dim <= 0
        {
            return Err(Error::InvalidProblem(
                "baracuda-kernels::FlashDecodingPlan: extents must be positive (k_len may be 0)",
            ));
        }
        // `group_size()` is an integer division, so the head ratio must be exact.
        if desc.num_heads % desc.num_kv_heads != 0 {
            return Err(Error::InvalidProblem(
                "baracuda-kernels::FlashDecodingPlan: num_heads must be a multiple of num_kv_heads",
            ));
        }
        // Reject NaN / +-inf (e.g. injected via `with_scale`) instead of launching with it.
        if !desc.scale.is_finite() {
            return Err(Error::InvalidProblem(
                "baracuda-kernels::FlashDecodingPlan: scale must be finite",
            ));
        }
        if desc.head_dim > FLASH_DECODING_MAX_D {
            return Err(Error::Unsupported(
                "baracuda-kernels::FlashDecodingPlan: head_dim > 128 not supported",
            ));
        }
        if !matches!(T::KIND, ElementKind::F16 | ElementKind::Bf16) {
            return Err(Error::Unsupported(
                "baracuda-kernels::FlashDecodingPlan: wired today: {f16, bf16}",
            ));
        }
        let precision_guarantee = PrecisionGuarantee {
            math_precision: MathPrecision::F32,
            accumulator: ElementKind::F32,
            bit_stable_on_same_hardware: true,
            deterministic: true,
        };
        let sku = KernelSku {
            category: OpCategory::Attention,
            // NOTE(review): reported as FlashAttention; confirm whether a dedicated
            // decoding kind should be used instead.
            op: AttentionKind::FlashAttention as u16,
            element: T::KIND,
            aux_element: None,
            layout: None,
            epilogue: None,
            arch: ArchSku::Sm80,
            backend: BackendKind::Bespoke,
            precision_guarantee,
        };
        Ok(Self {
            desc: *desc,
            sku,
            _marker: PhantomData,
        })
    }

    /// Checks that `args` matches the plan's descriptor.
    ///
    /// Shapes must be `q`/`y`: `[B, H_q, D]` and `k`/`v`: `[B, H_kv, K_len, D]`.
    /// Because [`Self::run`] forwards every stride except the innermost one, the
    /// head-dim axis of each tensor must also be contiguous (stride 1); otherwise the
    /// kernel would silently read/write the wrong elements.
    ///
    /// # Errors
    ///
    /// [`Error::InvalidProblem`] on any shape mismatch or non-contiguous head-dim axis.
    pub fn can_implement(&self, args: &FlashDecodingArgs<'_, T>) -> Result<()> {
        let d = self.desc.head_dim;
        let b = self.desc.batch_size;
        let h_q = self.desc.num_heads;
        let h_kv = self.desc.num_kv_heads;
        let k = self.desc.k_len;
        if args.q.shape != [b, h_q, d] {
            return Err(Error::InvalidProblem(
                "FlashDecodingPlan: q.shape mismatch (expected [B, H_q, D])",
            ));
        }
        if args.y.shape != [b, h_q, d] {
            return Err(Error::InvalidProblem(
                "FlashDecodingPlan: y.shape mismatch (expected [B, H_q, D])",
            ));
        }
        if args.k.shape != [b, h_kv, k, d] {
            return Err(Error::InvalidProblem(
                "FlashDecodingPlan: k.shape mismatch (expected [B, H_kv, K_len, D])",
            ));
        }
        if args.v.shape != [b, h_kv, k, d] {
            return Err(Error::InvalidProblem(
                "FlashDecodingPlan: v.shape mismatch (expected [B, H_kv, K_len, D])",
            ));
        }
        // The innermost stride is never passed to the kernel. A size-1 axis has no
        // meaningful stride, and empty K/V (k_len == 0) are never addressed.
        if d > 1 {
            if args.q.stride[2] != 1 {
                return Err(Error::InvalidProblem(
                    "FlashDecodingPlan: q head-dim axis must be contiguous (stride 1)",
                ));
            }
            if args.y.stride[2] != 1 {
                return Err(Error::InvalidProblem(
                    "FlashDecodingPlan: y head-dim axis must be contiguous (stride 1)",
                ));
            }
            if k > 0 && args.k.stride[3] != 1 {
                return Err(Error::InvalidProblem(
                    "FlashDecodingPlan: k head-dim axis must be contiguous (stride 1)",
                ));
            }
            if k > 0 && args.v.stride[3] != 1 {
                return Err(Error::InvalidProblem(
                    "FlashDecodingPlan: v head-dim axis must be contiguous (stride 1)",
                ));
            }
        }
        Ok(())
    }

    /// Backend family implementing this plan (always [`BackendKind::Bespoke`]).
    #[inline]
    pub fn backend(&self) -> BackendKind {
        BackendKind::Bespoke
    }

    /// Kernel identity and precision metadata recorded at `select` time.
    #[inline]
    pub fn sku(&self) -> &KernelSku {
        &self.sku
    }

    /// Bytes of scratch memory [`Self::run`] requires; 0 when `k_len == 0`.
    ///
    /// Reserves `2 + head_dim` 4-byte slots per `(batch, query head, split)`, where the
    /// split count is `ceil(k_len / CHUNK_K)`. The arithmetic is checked: if the size
    /// cannot be represented it saturates to `usize::MAX`, so `run` reports
    /// `WorkspaceTooSmall` instead of launching with an under-sized buffer.
    pub fn workspace_size(&self) -> usize {
        let splits = num_splits(self.desc.k_len);
        if splits == 0 || self.desc.batch_size == 0 || self.desc.num_heads == 0 {
            return 0;
        }
        // NOTE(review): the `2 + head_dim` f32 slots are presumably a max/sum pair for
        // the split merge plus a `head_dim`-wide partial output — confirm vs. kernel.
        const F32_BYTES: i64 = 4;
        let factors = [
            i64::from(self.desc.batch_size),
            i64::from(self.desc.num_heads),
            i64::from(splits),
            i64::from(self.desc.head_dim) + 2,
            F32_BYTES,
        ];
        factors
            .iter()
            .try_fold(1_i64, |acc, &factor| acc.checked_mul(factor))
            .and_then(|bytes| usize::try_from(bytes).ok())
            .unwrap_or(usize::MAX)
    }

    /// Launches the flash-decoding kernel for `args` on `stream`.
    ///
    /// Validates `args` with [`Self::can_implement`], binds `workspace` (which must hold
    /// at least [`Self::workspace_size`] bytes), then calls the `f16` or `bf16` entry
    /// point in `baracuda_kernels_sys` and converts its status via `map_status`.
    ///
    /// # Errors
    ///
    /// - Any error from [`Self::can_implement`].
    /// - [`Error::WorkspaceTooSmall`] if the workspace is missing or too small.
    /// - [`Error::Unsupported`] if `T` is not `f16`/`bf16` (not reachable for plans
    ///   built by `select`).
    /// - Whatever `map_status` returns for a failing native status.
    pub fn run(
        &self,
        stream: &Stream,
        workspace: Workspace<'_>,
        args: FlashDecodingArgs<'_, T>,
    ) -> Result<()> {
        self.can_implement(&args)?;
        let (ws_ptr, ws_bytes) = Self::bind_workspace(workspace, self.workspace_size())?;
        let stream_ptr = stream.as_raw() as *mut c_void;
        let q_ptr = args.q.data.as_raw().0 as *const c_void;
        let k_ptr = args.k.data.as_raw().0 as *const c_void;
        let v_ptr = args.v.data.as_raw().0 as *const c_void;
        let y_ptr = args.y.data.as_raw().0 as *mut c_void;
        // SAFETY: tensor shapes and head-dim contiguity were validated against the
        // descriptor by `can_implement`, and the workspace holds at least
        // `workspace_size()` bytes (checked in `bind_workspace`). The pointers come from
        // buffers borrowed by `args` / `workspace`. NOTE(review): if the launch is
        // asynchronous on `stream`, callers must keep those buffers alive until the
        // stream is synchronized — confirm the lifetime contract.
        let status = unsafe {
            match T::KIND {
                ElementKind::F16 => baracuda_kernels_sys::baracuda_kernels_flash_decoding_f16_run(
                    q_ptr,
                    k_ptr,
                    v_ptr,
                    y_ptr,
                    ws_ptr,
                    ws_bytes,
                    self.desc.batch_size,
                    self.desc.num_heads,
                    self.desc.num_kv_heads,
                    self.desc.k_len,
                    self.desc.head_dim,
                    args.q.stride[0],
                    args.q.stride[1],
                    args.k.stride[0],
                    args.k.stride[1],
                    args.k.stride[2],
                    args.v.stride[0],
                    args.v.stride[1],
                    args.v.stride[2],
                    args.y.stride[0],
                    args.y.stride[1],
                    self.desc.scale,
                    stream_ptr,
                ),
                ElementKind::Bf16 => baracuda_kernels_sys::baracuda_kernels_flash_decoding_bf16_run(
                    q_ptr,
                    k_ptr,
                    v_ptr,
                    y_ptr,
                    ws_ptr,
                    ws_bytes,
                    self.desc.batch_size,
                    self.desc.num_heads,
                    self.desc.num_kv_heads,
                    self.desc.k_len,
                    self.desc.head_dim,
                    args.q.stride[0],
                    args.q.stride[1],
                    args.k.stride[0],
                    args.k.stride[1],
                    args.k.stride[2],
                    args.v.stride[0],
                    args.v.stride[1],
                    args.v.stride[2],
                    args.y.stride[0],
                    args.y.stride[1],
                    self.desc.scale,
                    stream_ptr,
                ),
                _ => {
                    return Err(Error::Unsupported(
                        "baracuda-kernels::FlashDecodingPlan: only f16 / bf16 wired",
                    ));
                }
            }
        };
        map_status(status)
    }

    /// Resolves `workspace` to a raw pointer and byte length, requiring at least
    /// `needed` bytes (a missing workspace is fine only when `needed == 0`).
    fn bind_workspace(workspace: Workspace<'_>, needed: usize) -> Result<(*mut c_void, usize)> {
        match workspace {
            Workspace::None if needed > 0 => Err(Error::WorkspaceTooSmall { needed, got: 0 }),
            Workspace::None => Ok((core::ptr::null_mut::<c_void>(), 0_usize)),
            Workspace::Borrowed(buf) => {
                let got = buf.len();
                if got < needed {
                    return Err(Error::WorkspaceTooSmall { needed, got });
                }
                Ok((buf.as_raw().0 as *mut c_void, got))
            }
        }
    }
}
#[inline]
fn num_splits(k_len: i32) -> i32 {
if k_len <= 0 {
return 0;
}
(k_len + CHUNK_K - 1) / CHUNK_K
}