use core::ffi::c_void;
use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_types::{
ArchSku, BiasElement, BiasElementKind, ElementKind, EpilogueKind, IntElement, LayoutSku,
PlanPreference, PrecisionGuarantee, S8, Workspace,
};
pub use baracuda_cutlass::{GemmSku, IntGemmArgs, IntGemmDescriptor};
/// Integer GEMM plan that routes each problem to one of two backends.
///
/// Built by [`IntGemmPlan::select`]: `LayoutSku::Rcr` problems go to the
/// CUTLASS backend, `LayoutSku::Rrr` problems go to the bespoke SM80 kernels.
/// `T` is the A/B element type; `BT` is the bias element type (defaults to `f32`).
pub struct IntGemmPlan<T: IntElement, BT: BiasElement = f32> {
    /// Problem descriptor this plan was selected for (a copy of the caller's).
    desc: IntGemmDescriptor,
    /// SKU describing the selected kernel; returned by `sku()`.
    sku: GemmSku,
    /// Backend that executes the plan.
    backend: Backend<T, BT>,
}
/// Execution backend chosen by `IntGemmPlan::select`.
enum Backend<T: IntElement, BT: BiasElement> {
    /// Delegates to the CUTLASS-based plan (selected for `LayoutSku::Rcr`).
    Cutlass(baracuda_cutlass::IntGemmPlan<T, BT>),
    /// Bespoke RRR-layout kernels (selected for `LayoutSku::Rrr`); stateless.
    Bespoke(BespokeRrr<T, BT>),
}
/// Marker for the bespoke RRR backend.
///
/// Holds no runtime data; the `PhantomData` fields only tie the element type
/// `T` and bias element type `BT` to the variant.
struct BespokeRrr<T: IntElement, BT: BiasElement> {
    _element: PhantomData<T>,
    _bias_element: PhantomData<BT>,
}
impl<T: IntElement, BT: BiasElement> IntGemmPlan<T, BT> {
    /// Selects a backend for `desc` and builds a plan for it.
    ///
    /// * `LayoutSku::Rcr` problems are delegated to
    ///   `baracuda_cutlass::IntGemmPlan::select`, which also chooses the SKU.
    /// * `LayoutSku::Rrr` problems are served by the bespoke SM80 kernels in
    ///   `baracuda_kernels_sys`. Kernels exist for `S8` / `U8` elements with an
    ///   `Identity` epilogue, or with a `Bias` / `BiasRelu` / `BiasGelu` /
    ///   `BiasSilu` epilogue and an `F32` / `I32` bias element.
    ///
    /// `stream` and `pref` are only consulted on the CUTLASS path.
    ///
    /// # Errors
    ///
    /// * [`Error::InvalidProblem`] if an RRR problem has a non-positive M, N or K.
    /// * [`Error::Unsupported`] if no bespoke RRR kernel exists for the element
    ///   type, or for the (element, epilogue, bias element) combination.
    /// * Any error returned by the CUTLASS backend for RCR problems.
    pub fn select(
        stream: &Stream,
        desc: &IntGemmDescriptor,
        pref: PlanPreference,
    ) -> Result<Self> {
        match desc.layout {
            LayoutSku::Rcr => {
                let inner = baracuda_cutlass::IntGemmPlan::<T, BT>::select(stream, desc, pref)?;
                let sku = inner.sku();
                Ok(Self {
                    desc: *desc,
                    sku,
                    backend: Backend::Cutlass(inner),
                })
            }
            LayoutSku::Rrr => {
                if desc.m <= 0 || desc.n <= 0 || desc.k <= 0 {
                    return Err(Error::InvalidProblem(
                        "int GEMM problem must have positive M, N, K",
                    ));
                }
                if !matches!(T::KIND, ElementKind::S8 | ElementKind::U8) {
                    return Err(Error::Unsupported(
                        "baracuda-kernels: int RRR bespoke kernels: \
                         only S8 / U8 are implemented today \
                         (int4 / bin land in Phase 2)",
                    ));
                }
                // Reject combinations that the dispatcher in `run` has no kernel
                // for now, rather than handing out a plan that can only fail at
                // launch time.
                if !Self::bespoke_rrr_supported(T::KIND, desc.epilogue, BT::KIND) {
                    return Err(Error::Unsupported(
                        "baracuda-kernels: int RRR bespoke kernels: \
                         no kernel for this (element, epilogue, bias) triple",
                    ));
                }
                // NOTE(review): the SKU is pinned to SM80, but the device behind
                // `stream` is never queried on this path — confirm callers only
                // reach it on SM80+ hardware.
                let sku = GemmSku {
                    arch: ArchSku::Sm80,
                    layout: desc.layout,
                    epilogue: desc.epilogue,
                    element: T::KIND,
                    bias_element: if desc.epilogue.requires_bias() {
                        Some(BT::KIND)
                    } else {
                        None
                    },
                };
                Ok(Self {
                    desc: *desc,
                    sku,
                    backend: Backend::Bespoke(BespokeRrr {
                        _element: PhantomData,
                        _bias_element: PhantomData,
                    }),
                })
            }
        }
    }

    /// Checks whether `args` can be run by this plan.
    ///
    /// The CUTLASS backend performs its own checks. For the bespoke RRR backend
    /// this verifies that M, N, K are positive, that A is M×K, B is K×N and D
    /// is M×N, and that a bias vector is supplied exactly when the epilogue
    /// requires one.
    ///
    /// # Errors
    ///
    /// For the bespoke backend, [`Error::InvalidProblem`] describing the first
    /// failed check; otherwise whatever the CUTLASS backend reports.
    pub fn can_implement(&self, args: &IntGemmArgs<'_, T, BT>) -> Result<()> {
        match &self.backend {
            Backend::Cutlass(inner) => inner.can_implement(args),
            Backend::Bespoke(_) => self.check_bespoke_args(args),
        }
    }

    /// Validates `args` against the descriptor for the bespoke RRR backend.
    ///
    /// `run` calls this before launching, because the bespoke kernels take
    /// M, N, K from the descriptor but the operand pointers from `args`.
    fn check_bespoke_args(&self, args: &IntGemmArgs<'_, T, BT>) -> Result<()> {
        if self.desc.m <= 0 || self.desc.n <= 0 || self.desc.k <= 0 {
            return Err(Error::InvalidProblem(
                "int GEMM problem must have positive M, N, K",
            ));
        }
        if args.a.rows != self.desc.m || args.a.cols != self.desc.k {
            return Err(Error::InvalidProblem(
                "A shape mismatch with descriptor (M, K)",
            ));
        }
        if args.b.rows != self.desc.k || args.b.cols != self.desc.n {
            return Err(Error::InvalidProblem(
                "B shape mismatch with descriptor (K, N) (row-major)",
            ));
        }
        if args.d.rows != self.desc.m || args.d.cols != self.desc.n {
            return Err(Error::InvalidProblem(
                "D shape mismatch with descriptor (M, N)",
            ));
        }
        let needs_bias = self.sku.epilogue.requires_bias();
        if needs_bias && args.bias.is_none() {
            return Err(Error::InvalidProblem(
                "Bias* epilogue requires a bias vector",
            ));
        }
        if !needs_bias && args.bias.is_some() {
            return Err(Error::InvalidProblem(
                "Identity epilogue must not be supplied a bias vector",
            ));
        }
        // NOTE(review): leading dimensions (`ld`), the bias length, and the
        // shape of the optional C operand are not validated here — confirm the
        // kernels' expectations and add checks if they do not guard these.
        Ok(())
    }

    /// Returns the workspace size, in whatever unit the CUTLASS backend uses,
    /// that `run` needs. The bespoke RRR backend needs none and reports 0.
    pub fn workspace_size(&self) -> usize {
        match &self.backend {
            Backend::Cutlass(inner) => inner.workspace_size(),
            Backend::Bespoke(_) => 0,
        }
    }

    /// Returns the SKU of the selected kernel.
    #[inline]
    pub fn sku(&self) -> GemmSku {
        self.sku
    }

    /// Returns the precision guarantee reported by the selected SKU.
    #[inline]
    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
        self.sku.precision_guarantee()
    }

    /// Runs the GEMM on `stream`.
    ///
    /// The CUTLASS backend is invoked directly. For the bespoke RRR backend,
    /// `args` is first validated as in [`Self::can_implement`]. The kernel
    /// matching (element, epilogue, bias element) is then launched. `workspace`
    /// is ignored on this path because the backend reports a workspace size of 0.
    ///
    /// # Errors
    ///
    /// * Validation errors as described on [`Self::can_implement`].
    /// * [`Error::InvalidProblem`] if a required bias vector has a null pointer.
    /// * [`Error::Unsupported`] if no bespoke kernel exists for the combination.
    /// * The kernel's status code translated by `map_bespoke_status`.
    pub fn run(
        &self,
        stream: &Stream,
        workspace: Workspace<'_>,
        args: IntGemmArgs<'_, T, BT>,
    ) -> Result<()> {
        match &self.backend {
            Backend::Cutlass(inner) => inner.run(stream, workspace, args),
            Backend::Bespoke(_) => {
                // The kernels take M, N, K from the descriptor but the pointers
                // from `args`; reject mismatched operands before launching.
                self.check_bespoke_args(&args)?;
                // The bespoke backend needs no workspace (`workspace_size` is 0).
                let _ = workspace;
                let a_ptr = args.a.data.as_raw().0 as *const c_void;
                let b_ptr = args.b.data.as_raw().0 as *const c_void;
                let d_ptr = args.d.data.as_raw().0 as *mut c_void;
                // NOTE(review): with no C operand a null pointer and ldc = 0 are
                // passed; confirm the kernels ignore C in that case even when
                // `beta` is non-zero.
                let (c_ptr, ldc) = match &args.c {
                    Some(c) => (c.data.as_raw().0 as *const c_void, c.ld),
                    None => (core::ptr::null(), 0i64),
                };
                let bias_ptr = match &args.bias {
                    Some(b) => b.data.as_raw().0 as *const c_void,
                    None => core::ptr::null(),
                };
                // `check_bespoke_args` guarantees a bias is present exactly when
                // the epilogue needs one; this also catches a present bias whose
                // device pointer is null.
                if self.sku.epilogue.requires_bias() && bias_ptr.is_null() {
                    return Err(Error::InvalidProblem(
                        "Bias* epilogue requires a bias vector",
                    ));
                }
                let stream_ptr = stream.as_raw() as *mut c_void;
                let m = self.desc.m;
                let n = self.desc.n;
                let k = self.desc.k;
                let lda = args.a.ld;
                let ldb = args.b.ld;
                let ldd = args.d.ld;
                let alpha = args.alpha;
                let beta = args.beta;
                // SAFETY: each arm calls into `baracuda_kernels_sys` with M/N/K
                // from the descriptor and operand shapes already checked against
                // it by `check_bespoke_args`. `bias_ptr` is non-null whenever the
                // selected kernel takes a bias. NOTE(review): leading dimensions,
                // the optional C operand, and buffer lifetime beyond this call
                // (if the launch is asynchronous) are assumed to be upheld by the
                // caller — they are not checked here.
                let status = match (T::KIND, self.sku.epilogue, BT::KIND) {
                    (ElementKind::S8, EpilogueKind::Identity, _) => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_gemm_s8_rrr_sm80_run(
                            m, n, k,
                            a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
                            alpha, beta,
                            core::ptr::null_mut(), 0,
                            stream_ptr,
                        )
                    },
                    (ElementKind::U8, EpilogueKind::Identity, _) => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_gemm_u8_rrr_sm80_run(
                            m, n, k,
                            a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
                            alpha, beta,
                            core::ptr::null_mut(), 0,
                            stream_ptr,
                        )
                    },
                    (ElementKind::S8, EpilogueKind::Bias, BiasElementKind::F32) => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_gemm_s8_rrr_sm80_bias_f32_run(
                            m, n, k,
                            a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
                            bias_ptr, alpha, beta,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    (ElementKind::S8, EpilogueKind::BiasRelu, BiasElementKind::F32) => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_gemm_s8_rrr_sm80_bias_relu_f32_run(
                            m, n, k,
                            a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
                            bias_ptr, alpha, beta,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    (ElementKind::S8, EpilogueKind::BiasGelu, BiasElementKind::F32) => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_gemm_s8_rrr_sm80_bias_gelu_f32_run(
                            m, n, k,
                            a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
                            bias_ptr, alpha, beta,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    (ElementKind::S8, EpilogueKind::BiasSilu, BiasElementKind::F32) => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_gemm_s8_rrr_sm80_bias_silu_f32_run(
                            m, n, k,
                            a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
                            bias_ptr, alpha, beta,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    (ElementKind::S8, EpilogueKind::Bias, BiasElementKind::I32) => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_gemm_s8_rrr_sm80_bias_i32_run(
                            m, n, k,
                            a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
                            bias_ptr, alpha, beta,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    (ElementKind::S8, EpilogueKind::BiasRelu, BiasElementKind::I32) => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_gemm_s8_rrr_sm80_bias_relu_i32_run(
                            m, n, k,
                            a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
                            bias_ptr, alpha, beta,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    (ElementKind::S8, EpilogueKind::BiasGelu, BiasElementKind::I32) => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_gemm_s8_rrr_sm80_bias_gelu_i32_run(
                            m, n, k,
                            a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
                            bias_ptr, alpha, beta,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    (ElementKind::S8, EpilogueKind::BiasSilu, BiasElementKind::I32) => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_gemm_s8_rrr_sm80_bias_silu_i32_run(
                            m, n, k,
                            a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
                            bias_ptr, alpha, beta,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    (ElementKind::U8, EpilogueKind::Bias, BiasElementKind::F32) => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_gemm_u8_rrr_sm80_bias_f32_run(
                            m, n, k,
                            a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
                            bias_ptr, alpha, beta,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    (ElementKind::U8, EpilogueKind::BiasRelu, BiasElementKind::F32) => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_gemm_u8_rrr_sm80_bias_relu_f32_run(
                            m, n, k,
                            a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
                            bias_ptr, alpha, beta,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    (ElementKind::U8, EpilogueKind::BiasGelu, BiasElementKind::F32) => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_gemm_u8_rrr_sm80_bias_gelu_f32_run(
                            m, n, k,
                            a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
                            bias_ptr, alpha, beta,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    (ElementKind::U8, EpilogueKind::BiasSilu, BiasElementKind::F32) => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_gemm_u8_rrr_sm80_bias_silu_f32_run(
                            m, n, k,
                            a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
                            bias_ptr, alpha, beta,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    (ElementKind::U8, EpilogueKind::Bias, BiasElementKind::I32) => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_gemm_u8_rrr_sm80_bias_i32_run(
                            m, n, k,
                            a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
                            bias_ptr, alpha, beta,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    (ElementKind::U8, EpilogueKind::BiasRelu, BiasElementKind::I32) => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_gemm_u8_rrr_sm80_bias_relu_i32_run(
                            m, n, k,
                            a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
                            bias_ptr, alpha, beta,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    (ElementKind::U8, EpilogueKind::BiasGelu, BiasElementKind::I32) => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_gemm_u8_rrr_sm80_bias_gelu_i32_run(
                            m, n, k,
                            a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
                            bias_ptr, alpha, beta,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    (ElementKind::U8, EpilogueKind::BiasSilu, BiasElementKind::I32) => unsafe {
                        baracuda_kernels_sys::baracuda_kernels_gemm_u8_rrr_sm80_bias_silu_i32_run(
                            m, n, k,
                            a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
                            bias_ptr, alpha, beta,
                            core::ptr::null_mut(), 0, stream_ptr,
                        )
                    },
                    // Defensive: `select` already rejects triples without a kernel.
                    _ => {
                        return Err(Error::Unsupported(
                            "baracuda-kernels: int RRR bespoke kernel dispatcher \
                             reached an unimplemented (element, epilogue, bias) triple",
                        ));
                    }
                };
                map_bespoke_status(status)
            }
        }
    }

    /// Returns `true` exactly when the dispatch table in `run` has a bespoke
    /// RRR kernel for the (element, epilogue, bias element) triple.
    ///
    /// Keep this in sync with the match arms in `run`.
    fn bespoke_rrr_supported(
        element: ElementKind,
        epilogue: EpilogueKind,
        bias: BiasElementKind,
    ) -> bool {
        matches!(
            (element, epilogue, bias),
            (ElementKind::S8 | ElementKind::U8, EpilogueKind::Identity, _)
                | (
                    ElementKind::S8 | ElementKind::U8,
                    EpilogueKind::Bias
                        | EpilogueKind::BiasRelu
                        | EpilogueKind::BiasGelu
                        | EpilogueKind::BiasSilu,
                    BiasElementKind::F32 | BiasElementKind::I32,
                )
        )
    }
}
fn map_bespoke_status(code: i32) -> Result<()> {
match code {
0 => Ok(()),
1 => Err(Error::MisalignedOperand),
2 => Err(Error::InvalidProblem(
"baracuda-kernels-sys reported invalid problem",
)),
3 => Err(Error::Unsupported(
"baracuda-kernels-sys reported unsupported configuration",
)),
4 => Err(Error::WorkspaceTooSmall { needed: 0, got: 0 }),
n => Err(Error::CutlassInternal(n)),
}
}
/// Never called. It appears to exist only to keep the `S8` import referenced,
/// since `S8` has no other use in this part of the file.
#[allow(dead_code)]
fn _hold_s8_in_scope() {
    let _s8 = S8(0);
}