use core::ffi::c_void;
use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_types::{
ArchSku, AttentionKind, BackendKind, Element, ElementKind, KernelSku, MathPrecision,
OpCategory, PlanPreference, PrecisionGuarantee, TensorMut, TensorRef, Workspace,
};
use super::map_status;
/// Problem description consumed by [`SelectiveScanPlan::select`].
///
/// The four extents are the `B`, `L`, `D` and `N` that appear in the plan's
/// shape error messages. `select` rejects negative extents and `dstate > 256`.
#[derive(Copy, Clone, Debug)]
pub struct SelectiveScanDescriptor {
/// `B`: first extent of `u`, `delta`, `b`, `c`, `z`, `y` and `last_state`.
pub batch_size: i32,
/// `L`: second extent of the `[B, L, D]` and `[B, L, N]` tensors.
pub seq_len: i32,
/// `D`: last extent of `u`/`delta`/`y`, first extent of `a`, and the length
/// of the optional `d_skip` / `delta_bias` vectors.
pub dim: i32,
/// `N`: last extent of `a`, `b`, `c` and `last_state`.
pub dstate: i32,
/// Forwarded to the kernel as a 0/1 integer flag.
/// NOTE(review): presumably enables a softplus on `delta` inside the kernel —
/// confirm against the kernel source.
pub delta_softplus: bool,
/// Element kind of the tensors; must equal `T::KIND` of the plan being selected.
pub element: ElementKind,
}
/// Tensor arguments for one [`SelectiveScanPlan::run`] call.
///
/// Shapes are written in terms of the descriptor extents `B` (batch_size),
/// `L` (seq_len), `D` (dim) and `N` (dstate). `SelectiveScanPlan::can_implement`
/// requires every tensor that is present to have exactly the listed shape and
/// to be contiguous. An optional tensor left as `None` reaches the kernel as a
/// null pointer.
pub struct SelectiveScanArgs<'a, T: Element> {
/// Read-only input, shape `[B, L, D]`.
pub u: TensorRef<'a, T, 3>,
/// Read-only input, shape `[B, L, D]`.
pub delta: TensorRef<'a, T, 3>,
/// Read-only input, shape `[D, N]`.
pub a: TensorRef<'a, T, 2>,
/// Read-only input, shape `[B, L, N]`.
pub b: TensorRef<'a, T, 3>,
/// Read-only input, shape `[B, L, N]`.
pub c: TensorRef<'a, T, 3>,
/// Optional read-only input, shape `[D]` (called "D (skip)" in error messages).
pub d_skip: Option<TensorRef<'a, T, 1>>,
/// Optional read-only input, shape `[B, L, D]`.
pub z: Option<TensorRef<'a, T, 3>>,
/// Optional read-only input, shape `[D]`.
pub delta_bias: Option<TensorRef<'a, T, 1>>,
/// Output, shape `[B, L, D]`.
pub y: TensorMut<'a, T, 3>,
/// Optional output, shape `[B, D, N]`.
/// NOTE(review): presumably receives the recurrent state after the final
/// sequence step — confirm against the kernel.
pub last_state: Option<TensorMut<'a, T, 3>>,
}
/// Forward selective-scan plan for element type `T`.
///
/// Created by `SelectiveScanPlan::select`; stores the descriptor it was
/// selected for together with the kernel SKU built at selection time.
pub struct SelectiveScanPlan<T: Element> {
// Copy of the descriptor passed to `select`.
desc: SelectiveScanDescriptor,
// Kernel identity and precision guarantee reported by `sku()`.
sku: KernelSku,
// Ties the plan to `T` without storing a value of that type.
_marker: PhantomData<T>,
}
impl<T: Element> SelectiveScanPlan<T> {
pub fn select(
_stream: &Stream,
desc: &SelectiveScanDescriptor,
_pref: PlanPreference,
) -> Result<Self> {
if desc.element != T::KIND {
return Err(Error::Unsupported(
"baracuda-kernels::SelectiveScanPlan: descriptor element != T",
));
}
if desc.batch_size < 0 || desc.seq_len < 0 || desc.dim < 0 || desc.dstate < 0 {
return Err(Error::InvalidProblem(
"baracuda-kernels::SelectiveScanPlan: extents must be non-negative",
));
}
if desc.dstate > 256 {
return Err(Error::Unsupported(
"baracuda-kernels::SelectiveScanPlan: dstate must be <= 256 in the trailblazer",
));
}
let dtype_in_scope = matches!(
T::KIND,
ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16
);
if !dtype_in_scope {
return Err(Error::Unsupported(
"baracuda-kernels::SelectiveScanPlan: wired today: `{f32, f16, bf16}`",
));
}
let precision_guarantee = PrecisionGuarantee {
math_precision: MathPrecision::F32,
accumulator: ElementKind::F32,
bit_stable_on_same_hardware: true,
deterministic: true,
};
let sku = KernelSku {
category: OpCategory::Attention,
op: AttentionKind::SelectiveScan as u16,
element: T::KIND,
aux_element: None,
layout: None,
epilogue: None,
arch: ArchSku::Sm80,
backend: BackendKind::Bespoke,
precision_guarantee,
};
Ok(Self {
desc: *desc,
sku,
_marker: PhantomData,
})
}
pub fn can_implement(&self, args: &SelectiveScanArgs<'_, T>) -> Result<()> {
let shape_udy = [self.desc.batch_size, self.desc.seq_len, self.desc.dim];
let shape_a = [self.desc.dim, self.desc.dstate];
let shape_bc = [self.desc.batch_size, self.desc.seq_len, self.desc.dstate];
if args.u.shape != shape_udy
|| args.delta.shape != shape_udy
|| args.y.shape != shape_udy
{
return Err(Error::InvalidProblem(
"baracuda-kernels::SelectiveScanPlan: u/delta/y shape must be [B, L, D]",
));
}
if args.a.shape != shape_a {
return Err(Error::InvalidProblem(
"baracuda-kernels::SelectiveScanPlan: A shape must be [D, N]",
));
}
if args.b.shape != shape_bc || args.c.shape != shape_bc {
return Err(Error::InvalidProblem(
"baracuda-kernels::SelectiveScanPlan: B / C shape must be [B, L, N]",
));
}
if let Some(ds) = &args.d_skip {
if ds.shape != [self.desc.dim] {
return Err(Error::InvalidProblem(
"baracuda-kernels::SelectiveScanPlan: D (skip) shape must be [D]",
));
}
if !ds.is_contiguous() {
return Err(Error::Unsupported(
"baracuda-kernels::SelectiveScanPlan: D (skip) must be contiguous",
));
}
}
if let Some(z) = &args.z {
if z.shape != shape_udy {
return Err(Error::InvalidProblem(
"baracuda-kernels::SelectiveScanPlan: z shape must be [B, L, D]",
));
}
if !z.is_contiguous() {
return Err(Error::Unsupported(
"baracuda-kernels::SelectiveScanPlan: z must be contiguous",
));
}
}
if let Some(db) = &args.delta_bias {
if db.shape != [self.desc.dim] {
return Err(Error::InvalidProblem(
"baracuda-kernels::SelectiveScanPlan: delta_bias shape must be [D]",
));
}
if !db.is_contiguous() {
return Err(Error::Unsupported(
"baracuda-kernels::SelectiveScanPlan: delta_bias must be contiguous",
));
}
}
if let Some(ls) = &args.last_state {
if ls.shape != [self.desc.batch_size, self.desc.dim, self.desc.dstate] {
return Err(Error::InvalidProblem(
"baracuda-kernels::SelectiveScanPlan: last_state shape must be [B, D, N]",
));
}
if !ls.is_contiguous() {
return Err(Error::Unsupported(
"baracuda-kernels::SelectiveScanPlan: last_state must be contiguous",
));
}
}
if !args.u.is_contiguous()
|| !args.delta.is_contiguous()
|| !args.a.is_contiguous()
|| !args.b.is_contiguous()
|| !args.c.is_contiguous()
|| !args.y.is_contiguous()
{
return Err(Error::Unsupported(
"baracuda-kernels::SelectiveScanPlan: trailblazer requires contiguous tensors",
));
}
Ok(())
}
#[inline]
pub fn workspace_size(&self) -> usize {
0
}
#[inline]
pub fn sku(&self) -> KernelSku {
self.sku
}
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
self.sku.precision_guarantee
}
pub fn run(
&self,
stream: &Stream,
_workspace: Workspace<'_>,
mut args: SelectiveScanArgs<'_, T>,
) -> Result<()> {
self.can_implement(&args)?;
if self.desc.batch_size == 0 || self.desc.seq_len == 0 || self.desc.dim == 0 {
return Ok(());
}
let stream_ptr = stream.as_raw() as *mut c_void;
let u_ptr = args.u.data.as_raw().0 as *const c_void;
let delta_ptr = args.delta.data.as_raw().0 as *const c_void;
let a_ptr = args.a.data.as_raw().0 as *const c_void;
let b_ptr = args.b.data.as_raw().0 as *const c_void;
let c_ptr = args.c.data.as_raw().0 as *const c_void;
let d_ptr = args.d_skip.as_ref()
.map(|d| d.data.as_raw().0 as *const c_void)
.unwrap_or(core::ptr::null());
let z_ptr = args.z.as_ref()
.map(|z| z.data.as_raw().0 as *const c_void)
.unwrap_or(core::ptr::null());
let db_ptr = args.delta_bias.as_ref()
.map(|db| db.data.as_raw().0 as *const c_void)
.unwrap_or(core::ptr::null());
let y_ptr = args.y.data.as_raw().0 as *mut c_void;
let ls_ptr = args.last_state.as_mut()
.map(|ls| ls.data.as_raw().0 as *mut c_void)
.unwrap_or(core::ptr::null_mut());
let dsp = if self.desc.delta_softplus { 1 } else { 0 };
let status = match T::KIND {
ElementKind::F32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_selective_scan_f32_run(
self.desc.batch_size, self.desc.seq_len,
self.desc.dim, self.desc.dstate, dsp,
u_ptr, delta_ptr, a_ptr, b_ptr, c_ptr,
d_ptr, z_ptr, db_ptr,
y_ptr, ls_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
},
ElementKind::F16 => unsafe {
baracuda_kernels_sys::baracuda_kernels_selective_scan_f16_run(
self.desc.batch_size, self.desc.seq_len,
self.desc.dim, self.desc.dstate, dsp,
u_ptr, delta_ptr, a_ptr, b_ptr, c_ptr,
d_ptr, z_ptr, db_ptr,
y_ptr, ls_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
},
ElementKind::Bf16 => unsafe {
baracuda_kernels_sys::baracuda_kernels_selective_scan_bf16_run(
self.desc.batch_size, self.desc.seq_len,
self.desc.dim, self.desc.dstate, dsp,
u_ptr, delta_ptr, a_ptr, b_ptr, c_ptr,
d_ptr, z_ptr, db_ptr,
y_ptr, ls_ptr,
core::ptr::null_mut(), 0, stream_ptr,
)
},
_ => return Err(Error::Unsupported(
"baracuda-kernels::SelectiveScanPlan: dtype not wired",
)),
};
map_status(status)
}
}
/// Problem description consumed by [`SelectiveScanBackwardPlan::select`].
///
/// Field for field identical to `SelectiveScanDescriptor`. `select` rejects
/// negative extents and `dstate > 256`.
#[derive(Copy, Clone, Debug)]
pub struct SelectiveScanBackwardDescriptor {
/// Batch extent, passed to the backward kernel and to the workspace-size query.
pub batch_size: i32,
/// Sequence-length extent, passed to the backward kernel and to the workspace-size query.
pub seq_len: i32,
/// Channel extent, passed to the backward kernel and to the workspace-size query.
pub dim: i32,
/// State extent, passed to the backward kernel and to the workspace-size query.
pub dstate: i32,
/// Forwarded to the kernel as a 0/1 integer flag.
/// NOTE(review): presumably must match the flag used in the forward pass — confirm.
pub delta_softplus: bool,
/// Element kind of the tensors; must equal `T::KIND` of the plan being selected.
pub element: ElementKind,
}
/// Tensor arguments for one [`SelectiveScanBackwardPlan::run`] call.
///
/// The first eight fields have the same names, types and ranks as the fields of
/// `SelectiveScanArgs`. The remaining fields are the incoming gradient `dy`
/// and the gradient outputs.
///
/// `run` requires each optional input and its gradient output to be supplied
/// together or omitted together: `d_skip` with `d_d`, `z` with `dz`, and
/// `delta_bias` with `d_delta_bias`. An omitted optional reaches the kernel as
/// a null pointer.
///
/// NOTE(review): each gradient output is presumably shaped like the tensor it
/// is the gradient of, and `dy` like the forward output `y` — confirm against
/// the kernel.
pub struct SelectiveScanBackwardArgs<'a, T: Element> {
/// Forward input `u` (read-only).
pub u: TensorRef<'a, T, 3>,
/// Forward input `delta` (read-only).
pub delta: TensorRef<'a, T, 3>,
/// Forward input `a` (read-only).
pub a: TensorRef<'a, T, 2>,
/// Forward input `b` (read-only).
pub b: TensorRef<'a, T, 3>,
/// Forward input `c` (read-only).
pub c: TensorRef<'a, T, 3>,
/// Optional forward input `d_skip` (read-only); paired with `d_d`.
pub d_skip: Option<TensorRef<'a, T, 1>>,
/// Optional forward input `z` (read-only); paired with `dz`.
pub z: Option<TensorRef<'a, T, 3>>,
/// Optional forward input `delta_bias` (read-only); paired with `d_delta_bias`.
pub delta_bias: Option<TensorRef<'a, T, 1>>,
/// Incoming gradient (read-only); presumably with respect to the forward output `y`.
pub dy: TensorRef<'a, T, 3>,
/// Gradient output; presumably with respect to `u`.
pub du: TensorMut<'a, T, 3>,
/// Gradient output; presumably with respect to `b`.
pub d_b: TensorMut<'a, T, 3>,
/// Gradient output; presumably with respect to `c`.
pub d_c: TensorMut<'a, T, 3>,
/// Gradient output; presumably with respect to `delta`.
pub d_delta: TensorMut<'a, T, 3>,
/// Gradient output; presumably with respect to `a`.
pub d_a: TensorMut<'a, T, 2>,
/// Optional gradient output; must be present exactly when `d_skip` is.
pub d_d: Option<TensorMut<'a, T, 1>>,
/// Optional gradient output; must be present exactly when `z` is.
pub dz: Option<TensorMut<'a, T, 3>>,
/// Optional gradient output; must be present exactly when `delta_bias` is.
pub d_delta_bias: Option<TensorMut<'a, T, 1>>,
}
/// Backward selective-scan plan for element type `T`.
///
/// Created by `SelectiveScanBackwardPlan::select`; stores the descriptor it was
/// selected for together with the kernel SKU built at selection time.
pub struct SelectiveScanBackwardPlan<T: Element> {
// Copy of the descriptor passed to `select`.
desc: SelectiveScanBackwardDescriptor,
// Kernel identity and precision guarantee reported by `sku()`.
sku: KernelSku,
// Ties the plan to `T` without storing a value of that type.
_marker: PhantomData<T>,
}
impl<T: Element> SelectiveScanBackwardPlan<T> {
pub fn select(
_stream: &Stream,
desc: &SelectiveScanBackwardDescriptor,
_pref: PlanPreference,
) -> Result<Self> {
if desc.element != T::KIND {
return Err(Error::Unsupported(
"baracuda-kernels::SelectiveScanBackwardPlan: descriptor element != T",
));
}
if desc.batch_size < 0 || desc.seq_len < 0 || desc.dim < 0 || desc.dstate < 0 {
return Err(Error::InvalidProblem(
"baracuda-kernels::SelectiveScanBackwardPlan: extents must be non-negative",
));
}
if desc.dstate > 256 {
return Err(Error::Unsupported(
"baracuda-kernels::SelectiveScanBackwardPlan: dstate must be <= 256 in the trailblazer",
));
}
let dtype_in_scope = matches!(
T::KIND,
ElementKind::F32 | ElementKind::F16 | ElementKind::Bf16
);
if !dtype_in_scope {
return Err(Error::Unsupported(
"baracuda-kernels::SelectiveScanBackwardPlan: wired today: `{f32, f16, bf16}`",
));
}
let precision_guarantee = PrecisionGuarantee {
math_precision: MathPrecision::F32,
accumulator: ElementKind::F32,
bit_stable_on_same_hardware: false,
deterministic: false,
};
let sku = KernelSku {
category: OpCategory::Attention,
op: AttentionKind::SelectiveScan as u16,
element: T::KIND,
aux_element: None,
layout: None,
epilogue: None,
arch: ArchSku::Sm80,
backend: BackendKind::Bespoke,
precision_guarantee,
};
Ok(Self {
desc: *desc,
sku,
_marker: PhantomData,
})
}
#[inline]
pub fn workspace_size(&self) -> usize {
let dtype_id: i32 = match T::KIND {
ElementKind::F32 => 0,
ElementKind::F16 => 1,
ElementKind::Bf16 => 2,
_ => return 0,
};
unsafe {
baracuda_kernels_sys::baracuda_kernels_selective_scan_workspace_bytes(
self.desc.batch_size, self.desc.seq_len,
self.desc.dim, self.desc.dstate, dtype_id,
)
}
}
#[inline]
pub fn sku(&self) -> KernelSku {
self.sku
}
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
self.sku.precision_guarantee
}
pub fn run(
&self,
stream: &Stream,
workspace: Workspace<'_>,
mut args: SelectiveScanBackwardArgs<'_, T>,
) -> Result<()> {
if self.desc.batch_size == 0 || self.desc.seq_len == 0 || self.desc.dim == 0 {
return Ok(());
}
let stream_ptr = stream.as_raw() as *mut c_void;
let (ws_ptr, ws_bytes) = match workspace {
Workspace::Borrowed(buf) => (buf.as_raw().0 as *mut c_void, buf.len()),
Workspace::None => (core::ptr::null_mut(), 0usize),
};
let need = self.workspace_size();
if ws_bytes < need {
return Err(Error::WorkspaceTooSmall { needed: need, got: ws_bytes });
}
if args.d_skip.is_some() != args.d_d.is_some() {
return Err(Error::InvalidProblem(
"baracuda-kernels::SelectiveScanBackwardPlan: d_skip and d_d must be both given or both omitted",
));
}
if args.z.is_some() != args.dz.is_some() {
return Err(Error::InvalidProblem(
"baracuda-kernels::SelectiveScanBackwardPlan: z and dz must be both given or both omitted",
));
}
if args.delta_bias.is_some() != args.d_delta_bias.is_some() {
return Err(Error::InvalidProblem(
"baracuda-kernels::SelectiveScanBackwardPlan: delta_bias and d_delta_bias must be both given or both omitted",
));
}
let u_ptr = args.u.data.as_raw().0 as *const c_void;
let delta_ptr = args.delta.data.as_raw().0 as *const c_void;
let a_ptr = args.a.data.as_raw().0 as *const c_void;
let b_ptr = args.b.data.as_raw().0 as *const c_void;
let c_ptr = args.c.data.as_raw().0 as *const c_void;
let d_in_ptr = args.d_skip.as_ref()
.map(|d| d.data.as_raw().0 as *const c_void)
.unwrap_or(core::ptr::null());
let z_ptr = args.z.as_ref()
.map(|z| z.data.as_raw().0 as *const c_void)
.unwrap_or(core::ptr::null());
let db_ptr = args.delta_bias.as_ref()
.map(|db| db.data.as_raw().0 as *const c_void)
.unwrap_or(core::ptr::null());
let dy_ptr = args.dy.data.as_raw().0 as *const c_void;
let du_ptr = args.du.data.as_raw().0 as *mut c_void;
let dB_ptr = args.d_b.data.as_raw().0 as *mut c_void;
let dC_ptr = args.d_c.data.as_raw().0 as *mut c_void;
let ddelta_ptr = args.d_delta.data.as_raw().0 as *mut c_void;
let dA_ptr = args.d_a.data.as_raw().0 as *mut c_void;
let dD_ptr = args.d_d.as_mut()
.map(|d| d.data.as_raw().0 as *mut c_void)
.unwrap_or(core::ptr::null_mut());
let dz_ptr = args.dz.as_mut()
.map(|z| z.data.as_raw().0 as *mut c_void)
.unwrap_or(core::ptr::null_mut());
let ddb_ptr = args.d_delta_bias.as_mut()
.map(|db| db.data.as_raw().0 as *mut c_void)
.unwrap_or(core::ptr::null_mut());
let dsp = if self.desc.delta_softplus { 1 } else { 0 };
let status = match T::KIND {
ElementKind::F32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_selective_scan_f32_backward_run(
self.desc.batch_size, self.desc.seq_len,
self.desc.dim, self.desc.dstate, dsp,
u_ptr, delta_ptr, a_ptr, b_ptr, c_ptr,
d_in_ptr, z_ptr, db_ptr,
dy_ptr,
du_ptr, dB_ptr, dC_ptr, ddelta_ptr,
dA_ptr, dD_ptr, dz_ptr, ddb_ptr,
ws_ptr, ws_bytes, stream_ptr,
)
},
ElementKind::F16 => unsafe {
baracuda_kernels_sys::baracuda_kernels_selective_scan_f16_backward_run(
self.desc.batch_size, self.desc.seq_len,
self.desc.dim, self.desc.dstate, dsp,
u_ptr, delta_ptr, a_ptr, b_ptr, c_ptr,
d_in_ptr, z_ptr, db_ptr,
dy_ptr,
du_ptr, dB_ptr, dC_ptr, ddelta_ptr,
dA_ptr, dD_ptr, dz_ptr, ddb_ptr,
ws_ptr, ws_bytes, stream_ptr,
)
},
ElementKind::Bf16 => unsafe {
baracuda_kernels_sys::baracuda_kernels_selective_scan_bf16_backward_run(
self.desc.batch_size, self.desc.seq_len,
self.desc.dim, self.desc.dstate, dsp,
u_ptr, delta_ptr, a_ptr, b_ptr, c_ptr,
d_in_ptr, z_ptr, db_ptr,
dy_ptr,
du_ptr, dB_ptr, dC_ptr, ddelta_ptr,
dA_ptr, dD_ptr, dz_ptr, ddb_ptr,
ws_ptr, ws_bytes, stream_ptr,
)
},
_ => return Err(Error::Unsupported(
"baracuda-kernels::SelectiveScanBackwardPlan: dtype not wired",
)),
};
map_status(status)
}
}