use core::ffi::c_void;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_types::{
ArchSku, Bin, ElementKind, EpilogueKind, LayoutSku, MatrixMut, MatrixRef, PlanPreference,
PrecisionGuarantee, Workspace,
};
pub use baracuda_cutlass::GemmSku;
/// Shape and operand layout of a binary (bit-packed element) GEMM problem.
///
/// `BinGemmPlan::select` rejects a descriptor unless `m`, `n` and `k` are all
/// positive, `k` is a multiple of 8, and — for `LayoutSku::Rrr` only — `n` is a
/// multiple of 8 as well.
#[derive(Copy, Clone, Debug)]
pub struct BinGemmDescriptor {
/// M dimension, in elements: rows of A and of D.
pub m: i32,
/// N dimension, in elements: columns of B and of D.
pub n: i32,
/// K dimension, in elements: columns of A and rows of B. Must be a multiple
/// of 8 because the operands use packed-bit storage.
pub k: i32,
/// Operand layout; only `LayoutSku::Rcr` and `LayoutSku::Rrr` are accepted
/// by `BinGemmPlan::select`.
pub layout: LayoutSku,
}
/// Operand views handed to `BinGemmPlan::can_implement` and `BinGemmPlan::run`.
///
/// Shapes (`rows` / `cols`) are expressed in elements. The leading dimensions
/// of the binary operands are expressed in bytes, the one of `d` in `i32`
/// elements (see the checks in `BinGemmPlan::can_implement`).
#[derive(Debug)]
pub struct BinGemmArgs<'a> {
/// Operand A: expected to be `m` x `k`, with a leading dimension of at
/// least `k / 8` bytes.
pub a: MatrixRef<'a, Bin>,
/// Operand B: expected to be `k` x `n`, with a leading dimension of at
/// least `k / 8` bytes for RCR or `n / 8` bytes for RRR.
pub b: MatrixRef<'a, Bin>,
/// Mutable `i32` matrix D: expected to be `m` x `n`, with a leading
/// dimension of at least `n` elements.
pub d: MatrixMut<'a, i32>,
}
/// A binary GEMM plan: a validated problem descriptor together with the kernel
/// SKU chosen for it. Instances are created through `BinGemmPlan::select`.
pub struct BinGemmPlan {
// Copy of the descriptor that passed validation in `select`.
desc: BinGemmDescriptor,
// Kernel SKU built in `select`; its `layout` mirrors `desc.layout`.
sku: GemmSku,
}
impl BinGemmPlan {
pub fn select(
_stream: &Stream,
desc: &BinGemmDescriptor,
_pref: PlanPreference,
) -> Result<Self> {
if desc.m <= 0 || desc.n <= 0 || desc.k <= 0 {
return Err(Error::InvalidProblem(
"bin GEMM problem must have positive M, N, K",
));
}
if (desc.k & 7) != 0 {
return Err(Error::InvalidProblem(
"bin GEMM requires K to be a multiple of 8 (packed-bit storage)",
));
}
if matches!(desc.layout, LayoutSku::Rrr) && (desc.n & 7) != 0 {
return Err(Error::InvalidProblem(
"bin GEMM RRR requires N to be a multiple of 8 \
(B is bit-packed along N in gmem)",
));
}
if !matches!(desc.layout, LayoutSku::Rcr | LayoutSku::Rrr) {
return Err(Error::Unsupported(
"baracuda-kernels: bin GEMM: only RCR / RRR layouts are shipped",
));
}
let sku = GemmSku {
arch: ArchSku::Sm89,
layout: desc.layout,
epilogue: EpilogueKind::Identity,
element: ElementKind::Bin,
bias_element: None,
};
Ok(Self { desc: *desc, sku })
}
pub fn can_implement(&self, args: &BinGemmArgs<'_>) -> Result<()> {
if self.desc.m <= 0 || self.desc.n <= 0 || self.desc.k <= 0 {
return Err(Error::InvalidProblem(
"bin GEMM problem must have positive M, N, K",
));
}
if args.a.rows != self.desc.m || args.a.cols != self.desc.k {
return Err(Error::InvalidProblem(
"A shape mismatch with descriptor (M, K in elements)",
));
}
if args.b.rows != self.desc.k || args.b.cols != self.desc.n {
return Err(Error::InvalidProblem(
"B shape mismatch with descriptor (K, N in elements)",
));
}
if args.d.rows != self.desc.m || args.d.cols != self.desc.n {
return Err(Error::InvalidProblem(
"D shape mismatch with descriptor (M, N in elements)",
));
}
let k_bytes_min = (self.desc.k / 8) as i64;
let n_bytes_min = (self.desc.n / 8) as i64;
if args.a.ld < k_bytes_min {
return Err(Error::InvalidProblem(
"A leading dimension (bytes) must be >= K/8 for row-major bin A",
));
}
let b_ld_min = match self.sku.layout {
LayoutSku::Rcr => k_bytes_min,
LayoutSku::Rrr => n_bytes_min,
};
if args.b.ld < b_ld_min {
return Err(Error::InvalidProblem(
"B leading dimension (bytes) must be >= K/8 for RCR \
(col-major) or >= N/8 for RRR (row-major)",
));
}
if args.d.ld < self.desc.n as i64 {
return Err(Error::InvalidProblem(
"D leading dimension (i32 elements) must be >= N for row-major i32 D",
));
}
Ok(())
}
pub fn workspace_size(&self) -> usize {
0
}
#[inline]
pub fn sku(&self) -> GemmSku {
self.sku
}
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
self.sku.precision_guarantee()
}
pub fn run(
&self,
stream: &Stream,
workspace: Workspace<'_>,
args: BinGemmArgs<'_>,
) -> Result<()> {
let _ = workspace;
let a_ptr = args.a.data.as_raw().0 as *const c_void;
let b_ptr = args.b.data.as_raw().0 as *const c_void;
let d_ptr = args.d.data.as_raw().0 as *mut c_void;
let stream_ptr = stream.as_raw() as *mut c_void;
let m = self.desc.m;
let n = self.desc.n;
let k = self.desc.k;
let lda = args.a.ld;
let ldb = args.b.ld;
let ldd = args.d.ld;
#[cfg(not(feature = "sm89"))]
{
let _ = (a_ptr, b_ptr, d_ptr, lda, ldb, ldd, m, n, k, stream_ptr);
return Err(Error::Unsupported(
"baracuda-kernels: bin GEMM requires the `sm89` feature \
to be enabled in baracuda-kernels-sys",
));
}
#[cfg(feature = "sm89")]
let status = match self.sku.layout {
LayoutSku::Rcr => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_bin_rcr_sm89_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, d_ptr, ldd,
core::ptr::null_mut(), 0,
stream_ptr,
)
},
LayoutSku::Rrr => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_bin_rrr_sm89_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, d_ptr, ldd,
core::ptr::null_mut(), 0,
stream_ptr,
)
},
};
#[cfg(feature = "sm89")]
{ map_status(status) }
#[cfg(not(feature = "sm89"))]
#[allow(unreachable_code)]
{ unreachable!("returned earlier under #[cfg(not(feature = \"sm89\"))]") }
}
}
/// Translates a status code returned by a `baracuda_kernels_sys` entry point
/// into this crate's `Result`.
///
/// `0` is success; `1` through `4` map to dedicated error variants; any other
/// value is forwarded unchanged inside `Error::CutlassInternal`.
#[cfg(feature = "sm89")]
fn map_status(code: i32) -> Result<()> {
    if code == 0 {
        return Ok(());
    }

    Err(match code {
        1 => Error::MisalignedOperand,
        2 => Error::InvalidProblem("baracuda-kernels-sys reported invalid problem"),
        3 => Error::Unsupported(
            "baracuda-kernels-sys reported unsupported configuration \
             (bin GEMM requires K to be a multiple of 8)",
        ),
        // The status code carries no sizes, so both fields are reported as zero.
        4 => Error::WorkspaceTooSmall { needed: 0, got: 0 },
        other => Error::CutlassInternal(other),
    })
}