use core::ffi::c_void;
use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_types::{
ArchSku, BiasElement, BiasElementKind, ElementKind, EpilogueKind, IntElement, LayoutSku,
MatrixMut, MatrixRef, PlanPreference, PrecisionGuarantee, S4, U4, VectorRef, Workspace,
};
pub use baracuda_cutlass::GemmSku;
/// Shape, layout and epilogue of an int4 GEMM problem.
///
/// `Int4GemmPlan::select` rejects descriptors where any of `m`, `n`, `k` is
/// non-positive, or where `k` or `n` is odd.
#[derive(Copy, Clone, Debug)]
pub struct Int4GemmDescriptor {
/// Row count of A, D (and C when supplied), in elements.
pub m: i32,
/// Column count of B, D (and C when supplied), in elements; must be even.
pub n: i32,
/// Column count of A / row count of B, in elements; must be even.
pub k: i32,
/// Operand layout SKU (`Rcr` or `Rrr` are the variants dispatched here).
pub layout: LayoutSku,
/// Epilogue applied by the kernel; decides whether a bias vector is required.
pub epilogue: EpilogueKind,
}
/// Per-launch operands for an int4 GEMM.
///
/// Shapes are expressed in elements and are checked against the plan's
/// descriptor by `Int4GemmPlan::can_implement`. Leading dimensions are compared
/// against `K/2` or `N/2` there, and the validation messages describe them as
/// byte counts.
#[derive(Debug)]
pub struct Int4GemmArgs<'a, T: IntElement, BT: BiasElement = f32> {
/// Left operand; expected shape is (M, K).
pub a: MatrixRef<'a, T>,
/// Right operand; expected shape is (K, N).
pub b: MatrixRef<'a, T>,
/// Optional source matrix of shape (M, N). When absent, `run` passes a null
/// pointer and forces `beta` to `0.0`.
pub c: Option<MatrixRef<'a, T>>,
/// Output matrix; expected shape is (M, N).
pub d: MatrixMut<'a, T>,
/// Bias vector of length N. Must be present exactly when the epilogue
/// reports `requires_bias()`.
pub bias: Option<VectorRef<'a, BT>>,
/// Scalar forwarded to the kernel as `alpha`.
pub alpha: f32,
/// Scalar forwarded to the kernel as `beta` (only when `c` is supplied).
pub beta: f32,
}
/// A validated int4 GEMM plan: the problem descriptor plus the kernel SKU
/// chosen for it by `select`.
///
/// `T` is the int4 element type and `BT` the bias element type; both are
/// carried only at the type level.
pub struct Int4GemmPlan<T: IntElement, BT: BiasElement = f32> {
// Copy of the descriptor the plan was built from.
desc: Int4GemmDescriptor,
// Kernel SKU (arch, layout, epilogue, element kinds) used for dispatch.
sku: GemmSku,
// Zero-sized markers tying the plan to its element and bias types.
_element: PhantomData<T>,
_bias_element: PhantomData<BT>,
}
impl<T: IntElement, BT: BiasElement> Int4GemmPlan<T, BT> {
pub fn select(
_stream: &Stream,
desc: &Int4GemmDescriptor,
_pref: PlanPreference,
) -> Result<Self> {
if desc.m <= 0 || desc.n <= 0 || desc.k <= 0 {
return Err(Error::InvalidProblem(
"int4 GEMM problem must have positive M, N, K",
));
}
if (desc.k & 1) != 0 {
return Err(Error::InvalidProblem(
"int4 GEMM requires K to be even (packed-pair storage along K)",
));
}
if (desc.n & 1) != 0 {
return Err(Error::InvalidProblem(
"int4 GEMM requires N to be even (packed-pair storage along N for D output)",
));
}
if !matches!(T::KIND, ElementKind::S4 | ElementKind::U4) {
return Err(Error::Unsupported(
"baracuda-kernels: int4 GEMM: only S4 / U4 elements are accepted",
));
}
let _ = desc.layout;
let _ = desc.epilogue;
let sku = GemmSku {
arch: ArchSku::Sm89,
layout: desc.layout,
epilogue: desc.epilogue,
element: T::KIND,
bias_element: if desc.epilogue.requires_bias() {
Some(BT::KIND)
} else {
None
},
};
Ok(Self {
desc: *desc,
sku,
_element: PhantomData,
_bias_element: PhantomData,
})
}
pub fn can_implement(&self, args: &Int4GemmArgs<'_, T, BT>) -> Result<()> {
if self.desc.m <= 0 || self.desc.n <= 0 || self.desc.k <= 0 {
return Err(Error::InvalidProblem(
"int4 GEMM problem must have positive M, N, K",
));
}
if args.a.rows != self.desc.m || args.a.cols != self.desc.k {
return Err(Error::InvalidProblem(
"A shape mismatch with descriptor (M, K in elements)",
));
}
if args.b.rows != self.desc.k || args.b.cols != self.desc.n {
return Err(Error::InvalidProblem(
"B shape mismatch with descriptor (K, N in elements)",
));
}
if args.d.rows != self.desc.m || args.d.cols != self.desc.n {
return Err(Error::InvalidProblem(
"D shape mismatch with descriptor (M, N in elements)",
));
}
let k_bytes_min = (self.desc.k / 2) as i64;
let n_bytes_min = (self.desc.n / 2) as i64;
match self.sku.layout {
LayoutSku::Rcr => {
if args.a.ld < k_bytes_min {
return Err(Error::InvalidProblem(
"A leading dimension (bytes) must be >= K/2 for row-major int4 A",
));
}
if args.b.ld < k_bytes_min {
return Err(Error::InvalidProblem(
"B leading dimension (bytes) must be >= K/2 for col-major int4 B (RCR)",
));
}
}
LayoutSku::Rrr => {
if args.a.ld < k_bytes_min {
return Err(Error::InvalidProblem(
"A leading dimension (bytes) must be >= K/2 for row-major int4 A",
));
}
if args.b.ld < n_bytes_min {
return Err(Error::InvalidProblem(
"B leading dimension (bytes) must be >= N/2 for row-major int4 B (RRR)",
));
}
}
}
if args.d.ld < n_bytes_min {
return Err(Error::InvalidProblem(
"D leading dimension (bytes) must be >= N/2 for row-major int4 D",
));
}
if let Some(c) = &args.c {
if c.rows != self.desc.m || c.cols != self.desc.n {
return Err(Error::InvalidProblem(
"C shape mismatch with descriptor (M, N in elements)",
));
}
if c.ld < n_bytes_min {
return Err(Error::InvalidProblem(
"C leading dimension (bytes) must be >= N/2 for row-major int4 C",
));
}
}
let needs_bias = self.sku.epilogue.requires_bias();
match (needs_bias, &args.bias) {
(true, None) => {
return Err(Error::InvalidProblem(
"Bias* epilogue requires a bias vector",
));
}
(false, Some(_)) => {
return Err(Error::InvalidProblem(
"Identity epilogue must not be supplied a bias vector",
));
}
_ => {}
}
if let Some(b) = &args.bias
&& b.len != self.desc.n
{
return Err(Error::InvalidProblem(
"bias length must equal N",
));
}
Ok(())
}
/// Scratch bytes required by this plan; always zero (no workspace is
/// requested by any kernel dispatched from `run`).
pub fn workspace_size(&self) -> usize {
0
}
/// Returns a copy of the kernel SKU selected for this plan.
#[inline]
pub fn sku(&self) -> GemmSku {
self.sku
}
/// Returns the precision guarantee reported by the selected SKU
/// (delegates to `GemmSku::precision_guarantee`).
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
self.sku.precision_guarantee()
}
pub fn run(
&self,
stream: &Stream,
workspace: Workspace<'_>,
args: Int4GemmArgs<'_, T, BT>,
) -> Result<()> {
let _ = workspace;
let a_ptr = args.a.data.as_raw().0 as *const c_void;
let b_ptr = args.b.data.as_raw().0 as *const c_void;
let d_ptr = args.d.data.as_raw().0 as *mut c_void;
let (c_ptr, ldc) = match &args.c {
Some(c) => (c.data.as_raw().0 as *const c_void, c.ld),
None => (core::ptr::null(), 0i64),
};
let bias_ptr: *const c_void = match &args.bias {
Some(b) => b.data.as_raw().0 as *const c_void,
None => core::ptr::null(),
};
let needs_bias = self.sku.epilogue.requires_bias();
if needs_bias && bias_ptr.is_null() {
return Err(Error::InvalidProblem(
"Bias* epilogue requires a bias vector",
));
}
if !needs_bias && !bias_ptr.is_null() {
return Err(Error::InvalidProblem(
"Identity epilogue must not be supplied a bias vector",
));
}
let stream_ptr = stream.as_raw() as *mut c_void;
let m = self.desc.m;
let n = self.desc.n;
let k = self.desc.k;
let lda = args.a.ld;
let ldb = args.b.ld;
let ldd = args.d.ld;
let alpha = args.alpha;
let beta = if args.c.is_some() { args.beta } else { 0.0 };
#[cfg(not(feature = "sm89"))]
{
let _ = (a_ptr, b_ptr, c_ptr, d_ptr, ldc, lda, ldb, ldd,
m, n, k, alpha, beta, stream_ptr);
return Err(Error::Unsupported(
"baracuda-kernels: int4 GEMM requires the `sm89` feature \
to be enabled in baracuda-kernels-sys",
));
}
#[cfg(feature = "sm89")]
let status = match (T::KIND, self.sku.layout, self.sku.epilogue) {
(ElementKind::S4, LayoutSku::Rcr, EpilogueKind::Identity) => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_s4_rcr_sm89_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
alpha, beta,
core::ptr::null_mut(), 0,
stream_ptr,
)
},
(ElementKind::U4, LayoutSku::Rcr, EpilogueKind::Identity) => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_u4_rcr_sm89_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
alpha, beta,
core::ptr::null_mut(), 0,
stream_ptr,
)
},
(ElementKind::S4, LayoutSku::Rrr, EpilogueKind::Identity) => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_s4_rrr_sm89_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
alpha, beta,
core::ptr::null_mut(), 0,
stream_ptr,
)
},
(ElementKind::U4, LayoutSku::Rrr, EpilogueKind::Identity) => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_u4_rrr_sm89_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
alpha, beta,
core::ptr::null_mut(), 0,
stream_ptr,
)
},
(ElementKind::S4, LayoutSku::Rcr, EpilogueKind::Bias) => match BT::KIND {
BiasElementKind::F32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_s4_rcr_sm89_bias_f32_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
BiasElementKind::I32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_s4_rcr_sm89_bias_i32_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
},
(ElementKind::S4, LayoutSku::Rcr, EpilogueKind::BiasRelu) => match BT::KIND {
BiasElementKind::F32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_s4_rcr_sm89_bias_relu_f32_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
BiasElementKind::I32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_s4_rcr_sm89_bias_relu_i32_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
},
(ElementKind::S4, LayoutSku::Rcr, EpilogueKind::BiasGelu) => match BT::KIND {
BiasElementKind::F32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_s4_rcr_sm89_bias_gelu_f32_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
BiasElementKind::I32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_s4_rcr_sm89_bias_gelu_i32_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
},
(ElementKind::S4, LayoutSku::Rcr, EpilogueKind::BiasSilu) => match BT::KIND {
BiasElementKind::F32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_s4_rcr_sm89_bias_silu_f32_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
BiasElementKind::I32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_s4_rcr_sm89_bias_silu_i32_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
},
(ElementKind::U4, LayoutSku::Rcr, EpilogueKind::Bias) => match BT::KIND {
BiasElementKind::F32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_u4_rcr_sm89_bias_f32_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
BiasElementKind::I32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_u4_rcr_sm89_bias_i32_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
},
(ElementKind::U4, LayoutSku::Rcr, EpilogueKind::BiasRelu) => match BT::KIND {
BiasElementKind::F32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_u4_rcr_sm89_bias_relu_f32_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
BiasElementKind::I32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_u4_rcr_sm89_bias_relu_i32_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
},
(ElementKind::U4, LayoutSku::Rcr, EpilogueKind::BiasGelu) => match BT::KIND {
BiasElementKind::F32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_u4_rcr_sm89_bias_gelu_f32_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
BiasElementKind::I32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_u4_rcr_sm89_bias_gelu_i32_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
},
(ElementKind::U4, LayoutSku::Rcr, EpilogueKind::BiasSilu) => match BT::KIND {
BiasElementKind::F32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_u4_rcr_sm89_bias_silu_f32_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
BiasElementKind::I32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_u4_rcr_sm89_bias_silu_i32_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
},
(ElementKind::S4, LayoutSku::Rrr, EpilogueKind::Bias) => match BT::KIND {
BiasElementKind::F32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_s4_rrr_sm89_bias_f32_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
BiasElementKind::I32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_s4_rrr_sm89_bias_i32_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
},
(ElementKind::S4, LayoutSku::Rrr, EpilogueKind::BiasRelu) => match BT::KIND {
BiasElementKind::F32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_s4_rrr_sm89_bias_relu_f32_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
BiasElementKind::I32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_s4_rrr_sm89_bias_relu_i32_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
},
(ElementKind::S4, LayoutSku::Rrr, EpilogueKind::BiasGelu) => match BT::KIND {
BiasElementKind::F32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_s4_rrr_sm89_bias_gelu_f32_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
BiasElementKind::I32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_s4_rrr_sm89_bias_gelu_i32_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
},
(ElementKind::S4, LayoutSku::Rrr, EpilogueKind::BiasSilu) => match BT::KIND {
BiasElementKind::F32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_s4_rrr_sm89_bias_silu_f32_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
BiasElementKind::I32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_s4_rrr_sm89_bias_silu_i32_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
},
(ElementKind::U4, LayoutSku::Rrr, EpilogueKind::Bias) => match BT::KIND {
BiasElementKind::F32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_u4_rrr_sm89_bias_f32_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
BiasElementKind::I32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_u4_rrr_sm89_bias_i32_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
},
(ElementKind::U4, LayoutSku::Rrr, EpilogueKind::BiasRelu) => match BT::KIND {
BiasElementKind::F32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_u4_rrr_sm89_bias_relu_f32_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
BiasElementKind::I32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_u4_rrr_sm89_bias_relu_i32_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
},
(ElementKind::U4, LayoutSku::Rrr, EpilogueKind::BiasGelu) => match BT::KIND {
BiasElementKind::F32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_u4_rrr_sm89_bias_gelu_f32_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
BiasElementKind::I32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_u4_rrr_sm89_bias_gelu_i32_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
},
(ElementKind::U4, LayoutSku::Rrr, EpilogueKind::BiasSilu) => match BT::KIND {
BiasElementKind::F32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_u4_rrr_sm89_bias_silu_f32_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
BiasElementKind::I32 => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_u4_rrr_sm89_bias_silu_i32_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
},
_ => {
return Err(Error::Unsupported(
"baracuda-kernels: int4 GEMM dispatcher reached an \
unimplemented (element, layout, epilogue) triple \
(T must be S4 / U4)",
));
}
};
#[cfg(feature = "sm89")]
{ map_status(status) }
#[cfg(not(feature = "sm89"))]
#[allow(unreachable_code)]
{ unreachable!("returned earlier under #[cfg(not(feature = \"sm89\"))]") }
}
}
/// Constructs and immediately discards one `S4` and one `U4` value.
///
/// Never called (hence `dead_code` is allowed); presumably it exists so the
/// `S4` / `U4` imports stay referenced in this module — confirm before removing.
#[allow(dead_code)]
fn _hold_int4_elements_in_scope() {
    let _ = (S4(0), U4(0));
}
/// Translates a status code returned by a `baracuda_kernels_sys` entry point
/// into this crate's `Result`.
///
/// `0` is success; `1`..=`4` map to dedicated error variants; any other value
/// is passed through as `Error::CutlassInternal`.
#[cfg(feature = "sm89")]
fn map_status(code: i32) -> Result<()> {
    if code == 0 {
        return Ok(());
    }
    let failure = match code {
        1 => Error::MisalignedOperand,
        2 => Error::InvalidProblem("baracuda-kernels-sys reported invalid problem"),
        3 => Error::Unsupported(
            "baracuda-kernels-sys reported unsupported configuration \
             (likely K or N odd — int4 packing requires both even)",
        ),
        // The status code carries no sizes, so both fields are reported as 0.
        4 => Error::WorkspaceTooSmall { needed: 0, got: 0 },
        other => Error::CutlassInternal(other),
    };
    Err(failure)
}