use core::ffi::c_void;
use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_types::{
BackendKind, Element, ElementKind, MathPrecision, MatrixMut, MatrixRef, PlanPreference,
PrecisionGuarantee, ScalarType, Workspace,
};
/// Storage order of the three GEMM operands, spelled A-B-D with `r` for
/// row-major and `c` for column-major.
///
/// D is row-major in every variant; only A and B vary.
#[derive(Copy, Clone, Debug)]
#[derive(Eq, PartialEq, Hash)]
pub enum DenseGemmLayout {
    /// A row-major (M x K), B row-major (K x N), D row-major.
    Rrr,
    /// A row-major (M x K), B column-major (K x N), D row-major.
    Rcr,
    /// A column-major (M x K), B row-major (K x N), D row-major.
    Crr,
}

impl DenseGemmLayout {
    /// Integer tag identifying this layout to the C facade (0, 1, 2).
    #[inline]
    fn ffi_tag(self) -> i32 {
        match self {
            Self::Rrr => 0,
            Self::Rcr => 1,
            Self::Crr => 2,
        }
    }

    /// Smallest legal `(lda, ldb)` for an M x K by K x N product.
    ///
    /// A row-major operand needs `ld >= cols`; a column-major one needs
    /// `ld >= rows`. No clamping to 1 happens here; the caller in this file
    /// applies `.max(1)` itself.
    #[inline]
    fn min_lds(self, m: i32, n: i32, k: i32) -> (i64, i64) {
        let (lda, ldb) = match self {
            Self::Rrr => (k, n),
            Self::Rcr => (k, k),
            Self::Crr => (m, n),
        };
        (i64::from(lda), i64::from(ldb))
    }
}
/// Problem description for a (possibly batched) dense GEMM in which D is
/// M x N, A is M x K and B is K x N.
///
/// The fields are plain `i32`s. `DenseGemmPlan::select` rejects negative values.
#[derive(Clone, Copy)]
#[derive(Debug)]
pub struct DenseGemmDescriptor {
    /// Rows of A and of D.
    pub m: i32,
    /// Columns of B and of D.
    pub n: i32,
    /// Shared inner dimension: columns of A, rows of B.
    pub k: i32,
    /// Number of problems. Each batch is addressed through the per-operand
    /// strides carried in `DenseGemmArgs`.
    pub batch: i32,
    /// Storage order of A, B and D.
    pub layout: DenseGemmLayout,
}
/// Operands and scalars for one `DenseGemmPlan::run` call.
///
/// Strides are offsets between consecutive batches, measured in the same
/// unit as `data.len()` (presumably elements). `can_implement` only
/// validates them when the descriptor's `batch` is greater than 1.
#[derive(Debug)]
pub struct DenseGemmArgs<'a, T: Element> {
    /// Left operand, logical shape M x K (storage order set by the layout).
    pub a: MatrixRef<'a, T>,
    /// Batch stride of A; must be non-negative when batch > 1.
    pub stride_a: i64,
    /// Right operand, logical shape K x N (storage order set by the layout).
    pub b: MatrixRef<'a, T>,
    /// Batch stride of B; must be non-negative when batch > 1.
    pub stride_b: i64,
    /// Output, logical shape M x N, row-major in every layout.
    pub d: MatrixMut<'a, T>,
    /// Batch stride of D; must be positive when batch > 1.
    pub stride_d: i64,
    /// Scalar applied to the A·B product. Presumably D = alpha·A·B + beta·D;
    /// confirm against the sys facade.
    pub alpha: T::Scalar,
    /// Scalar applied to the existing contents of D (see `alpha`).
    pub beta: T::Scalar,
}
/// A validated dense-GEMM plan for element type `T`.
///
/// The fields are private, so a plan can only be obtained from
/// `DenseGemmPlan::select`. It stores nothing but the descriptor it was
/// built from.
pub struct DenseGemmPlan<T: Element> {
    /// Problem description captured (by copy) at `select` time.
    desc: DenseGemmDescriptor,
    /// Binds the plan to `T` without holding a value of it.
    _marker: PhantomData<T>,
}
impl<T: Element> DenseGemmPlan<T> {
    /// Builds a plan for `desc` after validating the element type, the
    /// problem extents and the backend preference.
    ///
    /// `_stream` is currently unused. Only `pref.prefer_backend` is read
    /// from `pref`.
    ///
    /// # Errors
    ///
    /// * [`Error::Unsupported`] when `T` is not f32 / f64 / f16 / bf16, or
    ///   when `pref.prefer_backend` names a backend other than cuBLAS.
    /// * [`Error::InvalidProblem`] when M, N, K or batch is negative.
    pub fn select(
        _stream: &Stream,
        desc: &DenseGemmDescriptor,
        pref: PlanPreference,
    ) -> Result<Self> {
        if !matches!(
            T::KIND,
            ElementKind::F32 | ElementKind::F64 | ElementKind::F16 | ElementKind::Bf16
        ) {
            return Err(Error::Unsupported(
                "baracuda-kernels::DenseGemmPlan: dense GEMM covers f32 / f64 / f16 / \
                 bf16 only (for bit-stable strict-f32 SIMT math use \
                 baracuda_cutlass::GemmPlan<F32Strict>)",
            ));
        }
        if desc.m < 0 || desc.n < 0 || desc.k < 0 || desc.batch < 0 {
            return Err(Error::InvalidProblem(
                "baracuda-kernels::DenseGemmPlan: M, N, K, batch must be non-negative",
            ));
        }
        match pref.prefer_backend {
            None | Some(BackendKind::Cublas) => {}
            Some(_) => {
                return Err(Error::Unsupported(
                    "baracuda-kernels::DenseGemmPlan: v1 is cuBLAS-backed only; \
                     leave PlanPreference::prefer_backend unset or set it to \
                     BackendKind::Cublas",
                ));
            }
        }
        Ok(Self {
            desc: *desc,
            _marker: PhantomData,
        })
    }

    /// Checks that `args` match this plan's descriptor and that every buffer
    /// is large enough for every batch.
    ///
    /// Checks performed, in order:
    /// * logical shapes: A is M x K, B is K x N, D is M x N;
    /// * leading dimensions: at least the layout's minimum (and at least 1),
    ///   with D's `ld >= N`, and no larger than `i32::MAX` (cuBLAS limit);
    /// * batch strides (only when batch > 1): `stride_d != 0`, and no stride
    ///   negative;
    /// * buffer lengths (only when M, N and batch are all positive): each
    ///   operand's `data.len()` covers its last batch's footprint.
    ///
    /// # Errors
    ///
    /// [`Error::InvalidProblem`] for shape, leading-dimension, stride or
    /// overflow violations; [`Error::BufferTooSmall`] when a buffer is short.
    pub fn can_implement(&self, args: &DenseGemmArgs<'_, T>) -> Result<()> {
        let d = &self.desc;
        if args.a.rows != d.m || args.a.cols != d.k {
            return Err(Error::InvalidProblem(
                "baracuda-kernels::DenseGemmPlan: A logical shape mismatch with \
                 descriptor (M, K)",
            ));
        }
        if args.b.rows != d.k || args.b.cols != d.n {
            return Err(Error::InvalidProblem(
                "baracuda-kernels::DenseGemmPlan: B logical shape mismatch with \
                 descriptor (K, N)",
            ));
        }
        if args.d.rows != d.m || args.d.cols != d.n {
            return Err(Error::InvalidProblem(
                "baracuda-kernels::DenseGemmPlan: D shape mismatch with descriptor \
                 (M, N)",
            ));
        }
        let (min_lda, min_ldb) = d.layout.min_lds(d.m, d.n, d.k);
        if args.a.ld < min_lda.max(1) {
            return Err(Error::InvalidProblem(
                "baracuda-kernels::DenseGemmPlan: lda below the layout's minimum",
            ));
        }
        if args.b.ld < min_ldb.max(1) {
            return Err(Error::InvalidProblem(
                "baracuda-kernels::DenseGemmPlan: ldb below the layout's minimum",
            ));
        }
        // D is row-major in every layout, so its minimum ld is N.
        if args.d.ld < (d.n as i64).max(1) {
            return Err(Error::InvalidProblem(
                "baracuda-kernels::DenseGemmPlan: ldd below N",
            ));
        }
        let i32_max = i32::MAX as i64;
        if args.a.ld > i32_max || args.b.ld > i32_max || args.d.ld > i32_max {
            return Err(Error::InvalidProblem(
                "baracuda-kernels::DenseGemmPlan: leading dimensions must fit in i32 \
                 (cuBLAS limit)",
            ));
        }
        if d.batch > 1 && args.stride_d == 0 {
            return Err(Error::InvalidProblem(
                "baracuda-kernels::DenseGemmPlan: stride_d must be non-zero when \
                 batch > 1 (overlapping outputs race)",
            ));
        }
        if d.batch > 1 && (args.stride_a < 0 || args.stride_b < 0 || args.stride_d < 0) {
            return Err(Error::InvalidProblem(
                "baracuda-kernels::DenseGemmPlan: negative batch strides walk before \
                 the buffer base (MatrixRef has no base offset)",
            ));
        }
        // Buffer-length checks are skipped entirely when M, N or batch is 0.
        if d.m > 0 && d.n > 0 && d.batch > 0 {
            // Elements spanned by one `rows` x `cols` matrix whose rows are
            // `ld` apart: every row but the last is a full `ld`, the last is
            // just `cols`. Empty matrices span nothing.
            let footprint = |rows: i64, cols: i64, ld: i64| -> i64 {
                if rows == 0 || cols == 0 {
                    0
                } else {
                    (rows - 1) * ld + cols
                }
            };
            let (m, n, k) = (d.m as i64, d.n as i64, d.k as i64);
            // A column-major M x K matrix is laid out like a row-major K x M one.
            let a_slot = match d.layout {
                DenseGemmLayout::Rrr | DenseGemmLayout::Rcr => footprint(m, k, args.a.ld),
                DenseGemmLayout::Crr => footprint(k, m, args.a.ld),
            };
            let b_slot = match d.layout {
                DenseGemmLayout::Rrr | DenseGemmLayout::Crr => footprint(k, n, args.b.ld),
                DenseGemmLayout::Rcr => footprint(n, k, args.b.ld),
            };
            let d_slot = footprint(m, n, args.d.ld);
            // Elements needed to reach the end of the last batch's slot. An
            // empty slot (e.g. A and B when K == 0) addresses nothing in any
            // batch, so its stride must not inflate the requirement; this
            // keeps batched K == 0 problems consistent with batch == 1.
            let reach = |slot: i64, stride: i64| -> Result<i64> {
                if slot == 0 || d.batch == 1 || stride == 0 {
                    return Ok(slot);
                }
                (d.batch as i64 - 1)
                    .checked_mul(stride)
                    .and_then(|extra| slot.checked_add(extra))
                    .ok_or(Error::InvalidProblem(
                        "baracuda-kernels::DenseGemmPlan: batch-stride reach overflows i64",
                    ))
            };
            // `needed` is never negative here: slots are >= 0, and when
            // batch > 1 negative strides were rejected above. Comparing as
            // u64 avoids wrapping `got` when it exceeds i64::MAX.
            let check = |needed: i64, got: usize| -> Result<()> {
                if (got as u64) < (needed as u64) {
                    return Err(Error::BufferTooSmall {
                        needed: needed as usize,
                        got,
                    });
                }
                Ok(())
            };
            check(reach(a_slot, args.stride_a)?, args.a.data.len())?;
            check(reach(b_slot, args.stride_b)?, args.b.data.len())?;
            check(reach(d_slot, args.stride_d)?, args.d.data.len())?;
        }
        Ok(())
    }

    /// Bytes of scratch space `run` needs. Always 0, because `run` passes
    /// a null, zero-length workspace to the facade.
    #[inline]
    pub fn workspace_size(&self) -> usize {
        0
    }

    /// Backend that executes this plan; always cuBLAS in this version.
    #[inline]
    pub fn backend(&self) -> BackendKind {
        BackendKind::Cublas
    }

    /// Operand storage order this plan was built for.
    #[inline]
    pub fn layout(&self) -> DenseGemmLayout {
        self.desc.layout
    }

    /// Numeric contract advertised for this element type.
    ///
    /// f16 and bf16 report f32 accumulation; f64 reports f64. Every other
    /// kind reports f32 math and f32 accumulation. Results are reported as
    /// neither bit-stable nor deterministic.
    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
        let (math_precision, accumulator) = match T::KIND {
            ElementKind::F16 => (MathPrecision::F16, ElementKind::F32),
            ElementKind::Bf16 => (MathPrecision::Bf16, ElementKind::F32),
            ElementKind::F64 => (MathPrecision::F64, ElementKind::F64),
            _ => (MathPrecision::F32, ElementKind::F32),
        };
        PrecisionGuarantee {
            math_precision,
            accumulator,
            bit_stable_on_same_hardware: false,
            deterministic: false,
        }
    }

    /// Catalogue identity of this plan.
    ///
    /// `Crr` maps to no `LayoutSku` (presumably no matching variant exists;
    /// confirm). NOTE(review): `arch` is hard-coded to `Sm80` regardless of
    /// the device the plan runs on; confirm this is intended.
    pub fn sku(&self) -> baracuda_kernels_types::KernelSku {
        use baracuda_kernels_types::{KernelSku, LayoutSku, OpCategory};
        KernelSku {
            category: OpCategory::Gemm,
            op: 0,
            element: T::KIND,
            aux_element: None,
            layout: match self.desc.layout {
                DenseGemmLayout::Rrr => Some(LayoutSku::Rrr),
                DenseGemmLayout::Rcr => Some(LayoutSku::Rcr),
                DenseGemmLayout::Crr => None,
            },
            epilogue: None,
            arch: baracuda_kernels_types::ArchSku::Sm80,
            backend: BackendKind::Cublas,
            precision_guarantee: self.precision_guarantee(),
        }
    }

    /// Validates `args` with [`Self::can_implement`], then calls the
    /// element-specific dense-GEMM facade on `stream`.
    ///
    /// The caller's `workspace` is ignored. The facade always gets a null,
    /// zero-length workspace (see [`Self::workspace_size`]). f32, f16 and
    /// bf16 pass `alpha`/`beta` as f32; f64 passes them as f64.
    ///
    /// # Errors
    ///
    /// * Any error from `can_implement`.
    /// * [`Error::InvalidProblem`] when the facade returns status 2.
    /// * [`Error::CutlassInternal`] carrying any other non-zero status.
    /// * [`Error::Unsupported`] for an element kind `select` would already
    ///   have rejected.
    pub fn run(
        &self,
        stream: &Stream,
        workspace: Workspace<'_>,
        args: DenseGemmArgs<'_, T>,
    ) -> Result<()> {
        self.can_implement(&args)?;
        // No scratch space is needed (workspace_size() == 0).
        let _ = workspace;
        let d = &self.desc;
        let layout = d.layout.ffi_tag();
        let a_ptr = args.a.data.as_raw().0 as *const c_void;
        let b_ptr = args.b.data.as_raw().0 as *const c_void;
        let d_ptr = args.d.data.as_raw().0 as *mut c_void;
        let stream_ptr = stream.as_raw() as *mut c_void;
        // SAFETY (all arms): `can_implement` has just verified shapes,
        // leading dimensions, batch strides and buffer lengths against the
        // descriptor. This assumes `as_raw()` yields the device address of a
        // live allocation of `data.len()` elements, which the borrows in
        // `args` keep alive for the duration of the call. TODO confirm
        // whether the facade's launch is asynchronous w.r.t. those borrows.
        let status = match T::KIND {
            ElementKind::F32 => unsafe {
                baracuda_kernels_sys::baracuda_kernels_gemm_dense_f32_run(
                    d.m, d.n, d.k, d.batch, layout,
                    args.alpha.to_f32(), args.beta.to_f32(),
                    a_ptr, args.a.ld, args.stride_a,
                    b_ptr, args.b.ld, args.stride_b,
                    d_ptr, args.d.ld, args.stride_d,
                    core::ptr::null_mut(), 0,
                    stream_ptr,
                )
            },
            ElementKind::F64 => unsafe {
                baracuda_kernels_sys::baracuda_kernels_gemm_dense_f64_run(
                    d.m, d.n, d.k, d.batch, layout,
                    args.alpha.to_f64(), args.beta.to_f64(),
                    a_ptr, args.a.ld, args.stride_a,
                    b_ptr, args.b.ld, args.stride_b,
                    d_ptr, args.d.ld, args.stride_d,
                    core::ptr::null_mut(), 0,
                    stream_ptr,
                )
            },
            ElementKind::F16 => unsafe {
                baracuda_kernels_sys::baracuda_kernels_gemm_dense_f16_run(
                    d.m, d.n, d.k, d.batch, layout,
                    args.alpha.to_f32(), args.beta.to_f32(),
                    a_ptr, args.a.ld, args.stride_a,
                    b_ptr, args.b.ld, args.stride_b,
                    d_ptr, args.d.ld, args.stride_d,
                    core::ptr::null_mut(), 0,
                    stream_ptr,
                )
            },
            ElementKind::Bf16 => unsafe {
                baracuda_kernels_sys::baracuda_kernels_gemm_dense_bf16_run(
                    d.m, d.n, d.k, d.batch, layout,
                    args.alpha.to_f32(), args.beta.to_f32(),
                    a_ptr, args.a.ld, args.stride_a,
                    b_ptr, args.b.ld, args.stride_b,
                    d_ptr, args.d.ld, args.stride_d,
                    core::ptr::null_mut(), 0,
                    stream_ptr,
                )
            },
            // `select` rejects every other kind. Return an error rather than
            // panic in case a plan is ever built another way.
            _ => {
                return Err(Error::Unsupported(
                    "baracuda-kernels::DenseGemmPlan: unreachable dtype dispatch arm",
                ));
            }
        };
        match status {
            0 => Ok(()),
            2 => Err(Error::InvalidProblem(
                "baracuda-kernels-sys dense GEMM facade reported an invalid problem",
            )),
            n => Err(Error::CutlassInternal(n)),
        }
    }
}