use core::ffi::c_void;
use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_types::{
ArchSku, ElementKind, EpilogueKind, Fp8E4M3, Fp8E5M2, FpElement, LayoutSku, MatrixMut,
MatrixRef, PlanPreference, PrecisionGuarantee, VectorRef, Workspace,
};
pub use baracuda_cutlass::GemmSku;
/// Problem description used to select an [`Fp8GemmPlan`].
///
/// Implied operand shapes: A is `m × k`, B is `k × n`, and the output D is
/// `m × n`.
#[derive(Clone, Copy, Debug)]
pub struct Fp8GemmDescriptor {
    /// Rows of A and of the output; must be positive.
    pub m: i32,
    /// Columns of B and of the output; must be positive.
    pub n: i32,
    /// Shared inner dimension (columns of A, rows of B); must be positive.
    pub k: i32,
    /// Operand storage layout; plan selection accepts only `Rcr` and `Rrr`.
    pub layout: LayoutSku,
    /// Epilogue variant; the `Bias*` kinds additionally require a bias vector.
    pub epilogue: EpilogueKind,
}
/// Per-launch operands for an [`Fp8GemmPlan`].
#[derive(Debug)]
pub struct Fp8GemmArgs<'a, T: FpElement> {
    /// Left operand, `m × k`.
    pub a: MatrixRef<'a, T>,
    /// Right operand, `k × n`.
    pub b: MatrixRef<'a, T>,
    /// Optional source matrix; when `None` the kernel is handed a null pointer
    /// with a leading dimension of 0.
    pub c: Option<MatrixRef<'a, T>>,
    /// Output matrix, `m × n`.
    pub d: MatrixMut<'a, T>,
    /// `f32` bias vector of length `n`; must be present for `Bias*` epilogues
    /// and absent otherwise.
    pub bias: Option<VectorRef<'a, f32>>,
    /// Scalar forwarded unchanged to the kernel (conventionally scales A·B).
    pub alpha: f32,
    /// Scalar forwarded unchanged to the kernel (conventionally scales C).
    pub beta: f32,
}
/// A validated FP8 GEMM plan targeting SM89 kernels, parameterised over the
/// FP8 element type `T`.
///
/// Built by [`Fp8GemmPlan::select`] and launched with [`Fp8GemmPlan::run`].
pub struct Fp8GemmPlan<T: FpElement> {
    /// Descriptor captured when the plan was selected.
    desc: Fp8GemmDescriptor,
    /// Kernel SKU derived from the descriptor and `T`.
    sku: GemmSku,
    /// Ties the plan to its element type without storing a `T`.
    _phantom: PhantomData<T>,
}
impl<T: FpElement> Fp8GemmPlan<T> {
pub fn select(
_stream: &Stream,
desc: &Fp8GemmDescriptor,
_pref: PlanPreference,
) -> Result<Self> {
if desc.m <= 0 || desc.n <= 0 || desc.k <= 0 {
return Err(Error::InvalidProblem(
"FP8 GEMM problem must have positive M, N, K",
));
}
if !matches!(desc.layout, LayoutSku::Rcr | LayoutSku::Rrr) {
return Err(Error::Unsupported(
"baracuda-kernels: FP8 GEMM: only RCR / RRR layouts are shipped",
));
}
if !matches!(T::KIND, ElementKind::Fp8E4M3 | ElementKind::Fp8E5M2) {
return Err(Error::Unsupported(
"baracuda-kernels: FP8 GEMM: only Fp8E4M3 / Fp8E5M2 elements are shipped",
));
}
let sku = GemmSku {
arch: ArchSku::Sm89,
layout: desc.layout,
epilogue: desc.epilogue,
element: T::KIND,
bias_element: if desc.epilogue.requires_bias() {
Some(baracuda_kernels_types::BiasElementKind::F32)
} else {
None
},
};
Ok(Self {
desc: *desc,
sku,
_phantom: PhantomData,
})
}
pub fn can_implement(&self, args: &Fp8GemmArgs<'_, T>) -> Result<()> {
if self.desc.m <= 0 || self.desc.n <= 0 || self.desc.k <= 0 {
return Err(Error::InvalidProblem(
"FP8 GEMM problem must have positive M, N, K",
));
}
if args.a.rows != self.desc.m || args.a.cols != self.desc.k {
return Err(Error::InvalidProblem(
"A shape mismatch with descriptor (M, K)",
));
}
if args.b.rows != self.desc.k || args.b.cols != self.desc.n {
return Err(Error::InvalidProblem(
"B shape mismatch with descriptor (K, N)",
));
}
match self.sku.layout {
LayoutSku::Rcr => {
if args.b.ld < self.desc.k as i64 {
return Err(Error::InvalidProblem(
"B leading dimension must be >= K for RCR (col-major K-contig)",
));
}
}
LayoutSku::Rrr => {
if args.b.ld < self.desc.n as i64 {
return Err(Error::InvalidProblem(
"B leading dimension must be >= N for RRR (row-major N-contig)",
));
}
}
}
if args.d.rows != self.desc.m || args.d.cols != self.desc.n {
return Err(Error::InvalidProblem(
"D shape mismatch with descriptor (M, N)",
));
}
let needs_bias = self.sku.epilogue.requires_bias();
match (needs_bias, &args.bias) {
(true, None) => {
return Err(Error::InvalidProblem(
"Bias* epilogue requires a bias vector",
));
}
(false, Some(_)) => {
return Err(Error::InvalidProblem(
"Identity epilogue must not be supplied a bias vector",
));
}
_ => {}
}
if let Some(b) = &args.bias
&& b.len != self.desc.n
{
return Err(Error::InvalidProblem(
"bias length must equal N",
));
}
Ok(())
}
pub fn workspace_size(&self) -> usize {
0
}
#[inline]
pub fn sku(&self) -> GemmSku {
self.sku
}
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
self.sku.precision_guarantee()
}
pub fn run(
&self,
stream: &Stream,
workspace: Workspace<'_>,
args: Fp8GemmArgs<'_, T>,
) -> Result<()> {
let _ = workspace;
let a_ptr = args.a.data.as_raw().0 as *const c_void;
let b_ptr = args.b.data.as_raw().0 as *const c_void;
let d_ptr = args.d.data.as_raw().0 as *mut c_void;
let (c_ptr, ldc) = match &args.c {
Some(c) => (c.data.as_raw().0 as *const c_void, c.ld),
None => (core::ptr::null(), 0i64),
};
let bias_ptr = match &args.bias {
Some(b) => b.data.as_raw().0 as *const c_void,
None => core::ptr::null(),
};
let needs_bias = self.sku.epilogue.requires_bias();
if needs_bias && bias_ptr.is_null() {
return Err(Error::InvalidProblem(
"Bias* epilogue requires a bias vector",
));
}
if !needs_bias && !bias_ptr.is_null() {
return Err(Error::InvalidProblem(
"Identity epilogue must not be supplied a bias vector",
));
}
let stream_ptr = stream.as_raw() as *mut c_void;
let m = self.desc.m;
let n = self.desc.n;
let k = self.desc.k;
let lda = args.a.ld;
let ldb = args.b.ld;
let ldd = args.d.ld;
let alpha = args.alpha;
let beta = args.beta;
#[cfg(not(feature = "sm89"))]
{
let _ = (a_ptr, b_ptr, c_ptr, d_ptr, bias_ptr, ldc, lda, ldb, ldd,
m, n, k, alpha, beta, stream_ptr);
return Err(Error::Unsupported(
"baracuda-kernels: FP8 GEMM requires the `sm89` feature \
to be enabled in baracuda-kernels-sys",
));
}
#[cfg(feature = "sm89")]
let status = match (T::KIND, self.sku.layout, self.sku.epilogue) {
(ElementKind::Fp8E4M3, LayoutSku::Rcr, EpilogueKind::Identity) => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_fp8_e4m3_rcr_sm89_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
alpha, beta,
core::ptr::null_mut(), 0,
stream_ptr,
)
},
(ElementKind::Fp8E4M3, LayoutSku::Rrr, EpilogueKind::Identity) => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_fp8_e4m3_rrr_sm89_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
alpha, beta,
core::ptr::null_mut(), 0,
stream_ptr,
)
},
(ElementKind::Fp8E5M2, LayoutSku::Rcr, EpilogueKind::Identity) => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_fp8_e5m2_rcr_sm89_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
alpha, beta,
core::ptr::null_mut(), 0,
stream_ptr,
)
},
(ElementKind::Fp8E5M2, LayoutSku::Rrr, EpilogueKind::Identity) => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_fp8_e5m2_rrr_sm89_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
alpha, beta,
core::ptr::null_mut(), 0,
stream_ptr,
)
},
(ElementKind::Fp8E4M3, LayoutSku::Rcr, EpilogueKind::Bias) => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_fp8_e4m3_rcr_sm89_bias_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::Fp8E4M3, LayoutSku::Rcr, EpilogueKind::BiasRelu) => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_fp8_e4m3_rcr_sm89_bias_relu_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::Fp8E4M3, LayoutSku::Rcr, EpilogueKind::BiasGelu) => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_fp8_e4m3_rcr_sm89_bias_gelu_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::Fp8E4M3, LayoutSku::Rcr, EpilogueKind::BiasSilu) => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_fp8_e4m3_rcr_sm89_bias_silu_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::Fp8E4M3, LayoutSku::Rrr, EpilogueKind::Bias) => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_fp8_e4m3_rrr_sm89_bias_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::Fp8E4M3, LayoutSku::Rrr, EpilogueKind::BiasRelu) => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_fp8_e4m3_rrr_sm89_bias_relu_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::Fp8E4M3, LayoutSku::Rrr, EpilogueKind::BiasGelu) => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_fp8_e4m3_rrr_sm89_bias_gelu_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::Fp8E4M3, LayoutSku::Rrr, EpilogueKind::BiasSilu) => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_fp8_e4m3_rrr_sm89_bias_silu_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::Fp8E5M2, LayoutSku::Rcr, EpilogueKind::Bias) => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_fp8_e5m2_rcr_sm89_bias_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::Fp8E5M2, LayoutSku::Rcr, EpilogueKind::BiasRelu) => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_fp8_e5m2_rcr_sm89_bias_relu_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::Fp8E5M2, LayoutSku::Rcr, EpilogueKind::BiasGelu) => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_fp8_e5m2_rcr_sm89_bias_gelu_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::Fp8E5M2, LayoutSku::Rcr, EpilogueKind::BiasSilu) => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_fp8_e5m2_rcr_sm89_bias_silu_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::Fp8E5M2, LayoutSku::Rrr, EpilogueKind::Bias) => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_fp8_e5m2_rrr_sm89_bias_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::Fp8E5M2, LayoutSku::Rrr, EpilogueKind::BiasRelu) => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_fp8_e5m2_rrr_sm89_bias_relu_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::Fp8E5M2, LayoutSku::Rrr, EpilogueKind::BiasGelu) => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_fp8_e5m2_rrr_sm89_bias_gelu_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
(ElementKind::Fp8E5M2, LayoutSku::Rrr, EpilogueKind::BiasSilu) => unsafe {
baracuda_kernels_sys::baracuda_kernels_gemm_fp8_e5m2_rrr_sm89_bias_silu_run(
m, n, k,
a_ptr, lda, b_ptr, ldb, c_ptr, ldc, d_ptr, ldd,
bias_ptr, alpha, beta,
core::ptr::null_mut(), 0, stream_ptr,
)
},
_ => {
return Err(Error::Unsupported(
"baracuda-kernels: FP8 GEMM dispatcher reached an \
unimplemented (element, layout, epilogue) triple",
));
}
};
#[cfg(feature = "sm89")]
{ map_status(status) }
#[cfg(not(feature = "sm89"))]
#[allow(unreachable_code)]
{ unreachable!("returned earlier under #[cfg(not(feature = \"sm89\"))]") }
}
}
/// Never called. It references the `Fp8E4M3` / `Fp8E5M2` imports so they stay
/// in use; judging by its name, that is this function's only purpose.
#[allow(dead_code)] // intentionally uncalled; exists only to reference the imports
fn _hold_fp8_elements_in_scope() {
    let _ = (Fp8E4M3(0), Fp8E5M2(0));
}
/// Translates a status code returned by a `baracuda_kernels_sys` launch into a
/// crate [`Result`].
///
/// Code 0 is success and codes 1–4 map to dedicated error variants. Any other
/// code is surfaced verbatim through [`Error::CutlassInternal`].
#[cfg(feature = "sm89")]
fn map_status(code: i32) -> Result<()> {
    let err = match code {
        0 => return Ok(()),
        1 => Error::MisalignedOperand,
        2 => Error::InvalidProblem("baracuda-kernels-sys reported invalid problem"),
        3 => Error::Unsupported("baracuda-kernels-sys reported unsupported configuration"),
        // Only the code crosses the FFI boundary, so the sizes are unknown and
        // reported as placeholder zeros.
        4 => Error::WorkspaceTooSmall { needed: 0, got: 0 },
        other => Error::CutlassInternal(other),
    };
    Err(err)
}