use core::ffi::c_void;
use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_types::{
ArchSku, AttentionKind, BackendKind, Element, ElementKind, KernelSku, MathPrecision,
OpCategory, PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
};
use super::map_status;
/// Problem description consumed by [`SsdChunkScanPlan::select`] for the forward
/// SSD chunk-scan op.
///
/// `select` rejects a descriptor when any of `batch_size`, `seq_len`,
/// `num_heads`, `head_dim` or `state_dim` is negative, when `chunk_size` is not
/// positive, when `head_dim` or `state_dim` exceeds 256, or when `element`
/// differs from the plan's element type.
#[derive(Copy, Clone, Debug)]
pub struct SsdChunkScanDescriptor {
/// Leading extent `B` of `x`, `dt`, `b`, `c` and `y`. Zero makes `run` a no-op.
pub batch_size: i32,
/// Second extent `L` of `x`, `dt`, `b`, `c` and `y`. Zero makes `run` a no-op.
pub seq_len: i32,
/// Head extent `H`; also the length of the rank-1 `a` tensor. Zero makes `run` a no-op.
pub num_heads: i32,
/// Trailing extent `D` of `x` and `y`; must be `<= 256`.
pub head_dim: i32,
/// Trailing extent `N` of `b` and `c`; must be `<= 256`.
pub state_dim: i32,
/// Chunk length forwarded verbatim to the kernel; must be strictly positive.
pub chunk_size: i32,
/// Element kind of every tensor; must equal `T::KIND` of the plan being selected.
pub element: ElementKind,
}
/// Tensor arguments for one forward launch of [`SsdChunkScanPlan`].
///
/// `SsdChunkScanPlan::can_implement` requires every tensor to be contiguous and
/// to have exactly the shape listed below, where `B`, `L`, `H`, `D`, `N` are the
/// descriptor's `batch_size`, `seq_len`, `num_heads`, `head_dim`, `state_dim`.
pub struct SsdChunkScanArgs<'a, T: Element> {
/// Read-only input, shape `[B, L, H, D]`.
pub x: TensorRef<'a, T, 4>,
/// Read-only input, shape `[B, L, H]`.
pub dt: TensorRef<'a, T, 3>,
/// Read-only input, shape `[H]`.
pub a: TensorRef<'a, T, 1>,
/// Read-only input, shape `[B, L, H, N]`.
pub b: TensorRef<'a, T, 4>,
/// Read-only input, shape `[B, L, H, N]`.
pub c: TensorRef<'a, T, 4>,
/// Output written by the kernel, shape `[B, L, H, D]`.
pub y: TensorMut<'a, T, 4>,
}
/// Validated forward SSD chunk-scan plan for element type `T`.
///
/// Built by `select`, which checks the descriptor once; `run` then launches the
/// kernel matching `T::KIND`.
pub struct SsdChunkScanPlan<T: Element> {
// Copy of the descriptor that passed validation in `select`.
desc: SsdChunkScanDescriptor,
// Kernel identity and precision guarantee reported by `sku()`.
sku: KernelSku,
// Ties the plan to `T` without storing a value of that type.
_marker: PhantomData<T>,
}
impl<T: Element> SsdChunkScanPlan<T> {
pub fn select(
_stream: &Stream,
desc: &SsdChunkScanDescriptor,
_pref: PlanPreference,
) -> Result<Self> {
if desc.element != T::KIND {
return Err(Error::Unsupported(
"baracuda-kernels::SsdChunkScanPlan: descriptor element != T",
));
}
if desc.batch_size < 0
|| desc.seq_len < 0
|| desc.num_heads < 0
|| desc.head_dim < 0
|| desc.state_dim < 0
{
return Err(Error::InvalidProblem(
"baracuda-kernels::SsdChunkScanPlan: extents must be non-negative",
));
}
if desc.chunk_size <= 0 {
return Err(Error::InvalidProblem(
"baracuda-kernels::SsdChunkScanPlan: chunk_size must be positive",
));
}
if desc.head_dim > 256 || desc.state_dim > 256 {
return Err(Error::Unsupported(
"baracuda-kernels::SsdChunkScanPlan: head_dim and state_dim must be <= 256 in the trailblazer",
));
}
let dtype_in_scope = matches!(
T::KIND,
ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16
);
if !dtype_in_scope {
return Err(Error::Unsupported(
"baracuda-kernels::SsdChunkScanPlan: wired today: `{f32, f16, bf16}`",
));
}
let precision_guarantee = PrecisionGuarantee {
math_precision: MathPrecision::F32,
accumulator: ElementKind::F32,
bit_stable_on_same_hardware: true,
deterministic: true,
};
let sku = KernelSku {
category: OpCategory::Attention,
op: AttentionKind::SsdChunkScan as u16,
element: T::KIND,
aux_element: None,
layout: None,
epilogue: None,
arch: ArchSku::Sm80,
backend: BackendKind::Bespoke,
precision_guarantee,
};
Ok(Self {
desc: *desc,
sku,
_marker: PhantomData,
})
}
pub fn can_implement(&self, args: &SsdChunkScanArgs<'_, T>) -> Result<()> {
let shape_x = [
self.desc.batch_size,
self.desc.seq_len,
self.desc.num_heads,
self.desc.head_dim,
];
let shape_dt = [self.desc.batch_size, self.desc.seq_len, self.desc.num_heads];
let shape_a = [self.desc.num_heads];
let shape_bn = [
self.desc.batch_size,
self.desc.seq_len,
self.desc.num_heads,
self.desc.state_dim,
];
if args.x.shape != shape_x || args.y.shape != shape_x {
return Err(Error::InvalidProblem(
"baracuda-kernels::SsdChunkScanPlan: x / y shape must be [B, L, H, D]",
));
}
if args.dt.shape != shape_dt {
return Err(Error::InvalidProblem(
"baracuda-kernels::SsdChunkScanPlan: dt shape must be [B, L, H]",
));
}
if args.a.shape != shape_a {
return Err(Error::InvalidProblem(
"baracuda-kernels::SsdChunkScanPlan: A shape must be [H]",
));
}
if args.b.shape != shape_bn || args.c.shape != shape_bn {
return Err(Error::InvalidProblem(
"baracuda-kernels::SsdChunkScanPlan: B / C shape must be [B, L, H, N]",
));
}
if !args.x.is_contiguous()
|| !args.dt.is_contiguous()
|| !args.a.is_contiguous()
|| !args.b.is_contiguous()
|| !args.c.is_contiguous()
|| !args.y.is_contiguous()
{
return Err(Error::Unsupported(
"baracuda-kernels::SsdChunkScanPlan: trailblazer requires contiguous tensors",
));
}
Ok(())
}
#[inline]
pub fn workspace_size(&self) -> usize {
0
}
#[inline]
pub fn sku(&self) -> KernelSku {
self.sku
}
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
self.sku.precision_guarantee
}
pub fn run(
&self,
stream: &Stream,
_workspace: Workspace<'_>,
args: SsdChunkScanArgs<'_, T>,
) -> Result<()> {
self.can_implement(&args)?;
if self.desc.batch_size == 0
|| self.desc.seq_len == 0
|| self.desc.num_heads == 0
{
return Ok(());
}
let stream_ptr = stream.as_raw() as *mut c_void;
let x_ptr = args.x.data.as_raw().0 as *const c_void;
let dt_ptr = args.dt.data.as_raw().0 as *const c_void;
let a_ptr = args.a.data.as_raw().0 as *const c_void;
let b_ptr = args.b.data.as_raw().0 as *const c_void;
let c_ptr = args.c.data.as_raw().0 as *const c_void;
let y_ptr = args.y.data.as_raw().0 as *mut c_void;
let status = match T::KIND {
ElementKind::F32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_ssd_chunk_scan_f32_run(
self.desc.batch_size, self.desc.seq_len, self.desc.num_heads,
self.desc.head_dim, self.desc.state_dim, self.desc.chunk_size,
x_ptr, dt_ptr, a_ptr, b_ptr, c_ptr, y_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
},
ElementKind::F16 => unsafe {
baracuda_kernels_sys::baracuda_kernels_ssd_chunk_scan_f16_run(
self.desc.batch_size, self.desc.seq_len, self.desc.num_heads,
self.desc.head_dim, self.desc.state_dim, self.desc.chunk_size,
x_ptr, dt_ptr, a_ptr, b_ptr, c_ptr, y_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
},
ElementKind::Bf16 => unsafe {
baracuda_kernels_sys::baracuda_kernels_ssd_chunk_scan_bf16_run(
self.desc.batch_size, self.desc.seq_len, self.desc.num_heads,
self.desc.head_dim, self.desc.state_dim, self.desc.chunk_size,
x_ptr, dt_ptr, a_ptr, b_ptr, c_ptr, y_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
},
_ => return Err(Error::Unsupported(
"baracuda-kernels::SsdChunkScanPlan: dtype not wired",
)),
};
map_status(status)
}
}
/// Problem description consumed by [`SsdChunkScanBackwardPlan::select`] for the
/// backward SSD chunk-scan op.
///
/// `select` rejects a descriptor whose `element` differs from the plan's
/// element type or whose `head_dim` / `state_dim` exceeds 64.
#[derive(Copy, Clone, Debug)]
pub struct SsdChunkScanBackwardDescriptor {
/// Batch extent; zero makes `run` a no-op.
pub batch_size: i32,
/// Sequence extent; zero makes `run` a no-op.
pub seq_len: i32,
/// Head extent; zero makes `run` a no-op.
pub num_heads: i32,
/// Per-head feature extent; must be `<= 64` for the backward kernel.
pub head_dim: i32,
/// State extent; must be `<= 64` for the backward kernel.
pub state_dim: i32,
/// Chunk length forwarded verbatim to the kernel and to the workspace-size query.
pub chunk_size: i32,
/// Element kind of every tensor; must equal `T::KIND` of the plan being selected.
pub element: ElementKind,
}
/// Tensor arguments for one backward launch of [`SsdChunkScanBackwardPlan`].
///
/// The first six tensors are read-only inputs; the remaining five are handed to
/// the kernel as mutable output pointers.
///
/// NOTE(review): shapes are presumably those of the forward pass (`x`:
/// `[B, L, H, D]`, `dt`: `[B, L, H]`, `a`: `[H]`, `b` / `c`: `[B, L, H, N]`),
/// with each gradient matching the tensor it differentiates — confirm against
/// the kernel.
pub struct SsdChunkScanBackwardArgs<'a, T: Element> {
/// Forward input `x` (rank 4).
pub x: TensorRef<'a, T, 4>,
/// Forward input `dt` (rank 3).
pub dt: TensorRef<'a, T, 3>,
/// Forward input `a` (rank 1).
pub a: TensorRef<'a, T, 1>,
/// Forward input `b` (rank 4).
pub b: TensorRef<'a, T, 4>,
/// Forward input `c` (rank 4).
pub c: TensorRef<'a, T, 4>,
/// Incoming gradient, presumably with respect to the forward output `y` (rank 4).
pub dy: TensorRef<'a, T, 4>,
/// Output: gradient for `x` (rank 4).
pub dx: TensorMut<'a, T, 4>,
/// Output: gradient for `b` (rank 4).
pub d_b: TensorMut<'a, T, 4>,
/// Output: gradient for `c` (rank 4).
pub d_c: TensorMut<'a, T, 4>,
/// Output: gradient for `dt` (rank 3).
pub d_dt: TensorMut<'a, T, 3>,
/// Output: gradient for `a` (rank 1).
pub d_a: TensorMut<'a, T, 1>,
}
/// Validated backward SSD chunk-scan plan for element type `T`.
///
/// Built by `select`; `run` launches the backward kernel matching `T::KIND`
/// using a caller-provided workspace of at least `workspace_size()` bytes.
pub struct SsdChunkScanBackwardPlan<T: Element> {
// Copy of the descriptor that passed validation in `select`.
desc: SsdChunkScanBackwardDescriptor,
// Kernel identity and precision guarantee reported by `sku()`.
sku: KernelSku,
// Ties the plan to `T` without storing a value of that type.
_marker: PhantomData<T>,
}
impl<T: Element> SsdChunkScanBackwardPlan<T> {
pub fn select(
_stream: &Stream,
desc: &SsdChunkScanBackwardDescriptor,
_pref: PlanPreference,
) -> Result<Self> {
if desc.element != T::KIND {
return Err(Error::Unsupported(
"baracuda-kernels::SsdChunkScanBackwardPlan: descriptor element != T",
));
}
if desc.head_dim > 64 || desc.state_dim > 64 {
return Err(Error::Unsupported(
"baracuda-kernels::SsdChunkScanBackwardPlan: BW caps head_dim/state_dim at 64 (SMEM budget)",
));
}
let dtype_in_scope = matches!(
T::KIND,
ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16
);
if !dtype_in_scope {
return Err(Error::Unsupported(
"baracuda-kernels::SsdChunkScanBackwardPlan: wired today: `{f32, f16, bf16}`",
));
}
let precision_guarantee = PrecisionGuarantee {
math_precision: MathPrecision::F32,
accumulator: ElementKind::F32,
bit_stable_on_same_hardware: false,
deterministic: false,
};
let sku = KernelSku {
category: OpCategory::Attention,
op: AttentionKind::SsdChunkScan as u16,
element: T::KIND,
aux_element: None,
layout: None,
epilogue: None,
arch: ArchSku::Sm80,
backend: BackendKind::Bespoke,
precision_guarantee,
};
Ok(Self {
desc: *desc,
sku,
_marker: PhantomData,
})
}
#[inline]
pub fn workspace_size(&self) -> usize {
unsafe {
baracuda_kernels_sys::baracuda_kernels_ssd_chunk_scan_workspace_bytes(
self.desc.batch_size, self.desc.seq_len, self.desc.num_heads,
self.desc.head_dim, self.desc.state_dim,
self.desc.chunk_size, 0,
)
}
}
#[inline]
pub fn sku(&self) -> KernelSku {
self.sku
}
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
self.sku.precision_guarantee
}
pub fn run(
&self,
stream: &Stream,
workspace: Workspace<'_>,
args: SsdChunkScanBackwardArgs<'_, T>,
) -> Result<()> {
if self.desc.batch_size == 0
|| self.desc.seq_len == 0
|| self.desc.num_heads == 0
{
return Ok(());
}
let stream_ptr = stream.as_raw() as *mut c_void;
let (ws_ptr, ws_bytes) = match workspace {
Workspace::Borrowed(buf) => (
buf.as_raw().0 as *mut c_void,
buf.len(),
),
Workspace::None => (core::ptr::null_mut(), 0usize),
};
if ws_bytes < self.workspace_size() {
return Err(Error::WorkspaceTooSmall {
needed: self.workspace_size(),
got: ws_bytes,
});
}
let x_ptr = args.x.data.as_raw().0 as *const c_void;
let dt_ptr = args.dt.data.as_raw().0 as *const c_void;
let a_ptr = args.a.data.as_raw().0 as *const c_void;
let b_ptr = args.b.data.as_raw().0 as *const c_void;
let c_ptr = args.c.data.as_raw().0 as *const c_void;
let dy_ptr = args.dy.data.as_raw().0 as *const c_void;
let dx_ptr = args.dx.data.as_raw().0 as *mut c_void;
let db_ptr = args.d_b.data.as_raw().0 as *mut c_void;
let dc_ptr = args.d_c.data.as_raw().0 as *mut c_void;
let ddt_ptr = args.d_dt.data.as_raw().0 as *mut c_void;
let da_ptr = args.d_a.data.as_raw().0 as *mut c_void;
let status = match T::KIND {
ElementKind::F32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_ssd_chunk_scan_f32_backward_run(
self.desc.batch_size, self.desc.seq_len, self.desc.num_heads,
self.desc.head_dim, self.desc.state_dim, self.desc.chunk_size,
x_ptr, dt_ptr, a_ptr, b_ptr, c_ptr, dy_ptr,
dx_ptr, db_ptr, dc_ptr, ddt_ptr, da_ptr,
ws_ptr, ws_bytes, stream_ptr,
)
},
ElementKind::F16 => unsafe {
baracuda_kernels_sys::baracuda_kernels_ssd_chunk_scan_f16_backward_run(
self.desc.batch_size, self.desc.seq_len, self.desc.num_heads,
self.desc.head_dim, self.desc.state_dim, self.desc.chunk_size,
x_ptr, dt_ptr, a_ptr, b_ptr, c_ptr, dy_ptr,
dx_ptr, db_ptr, dc_ptr, ddt_ptr, da_ptr,
ws_ptr, ws_bytes, stream_ptr,
)
},
ElementKind::Bf16 => unsafe {
baracuda_kernels_sys::baracuda_kernels_ssd_chunk_scan_bf16_backward_run(
self.desc.batch_size, self.desc.seq_len, self.desc.num_heads,
self.desc.head_dim, self.desc.state_dim, self.desc.chunk_size,
x_ptr, dt_ptr, a_ptr, b_ptr, c_ptr, dy_ptr,
dx_ptr, db_ptr, dc_ptr, ddt_ptr, da_ptr,
ws_ptr, ws_bytes, stream_ptr,
)
},
_ => return Err(Error::Unsupported(
"baracuda-kernels::SsdChunkScanBackwardPlan: dtype not wired",
)),
};
map_status(status)
}
}