use core::ffi::c_void;
use core::marker::PhantomData;
use baracuda_driver::{Context, PinnedBuffer, Stream};
use baracuda_kernels_types::BackendKind;
use crate::error::{status_to_result, Error, Result};
use crate::types::{
ArchSku, BatchedGemmArgs, BatchedGemmDescriptor, BiasElement, CutlassElement, ElementKind,
EpilogueKind, GemmArgs, GemmDescriptor, GemmSku, GroupedPlanPreference, GroupedProblem,
GroupedScheduleMode, IntElement, IntGemmArgs, IntGemmDescriptor, LayoutSku, PlanPreference,
PrecisionGuarantee, ScalarType, Workspace,
};
mod dispatch {
use super::{ElementKind, LayoutSku};
use core::ffi::c_void;
use super::EpilogueKind;
/// Launches the SM80 GEMM-with-bias kernel matching `(layout, kind, epilogue)`.
///
/// The triple selects one `baracuda_cutlass_gemm_bias_*_sm80_run` entry point
/// of `baracuda_cutlass_kernels_sys`. All remaining arguments are forwarded to
/// it unchanged and its `i32` result is returned as-is.
///
/// Kernel naming visible in the table: `F16` -> `f16`, `Bf16` -> `bf16`,
/// `F32` -> `tf32`, `F32Strict` -> `f32_simt`.
///
/// A triple with no table row (for example any `EpilogueKind::Identity`
/// request) returns `3` without calling into the kernel library.
/// NOTE(review): `3` is presumably the "unsupported" status code — confirm
/// against `status_to_result`.
///
/// # Safety
///
/// Every pointer and size is handed to foreign code untouched, so the caller
/// must meet the requirements of the selected entry point.
/// NOTE(review): presumably these are device buffers valid for the given
/// dimensions and strides — confirm against the kernel library.
#[cfg(feature = "sm80")]
#[allow(clippy::too_many_arguments)]
pub(super) unsafe fn gemm_bias_sm80_run(
    layout: LayoutSku,
    kind: ElementKind,
    epilogue: EpilogueKind,
    m: i32,
    n: i32,
    k: i32,
    a: *const c_void,
    lda: i64,
    b: *const c_void,
    ldb: i64,
    c: *const c_void,
    ldc: i64,
    d: *mut c_void,
    ldd: i64,
    bias: *const c_void,
    alpha: f32,
    beta: f32,
    workspace: *mut c_void,
    workspace_bytes: usize,
    stream: *mut c_void,
) -> i32 {
    use baracuda_cutlass_kernels_sys as k_sys;

    // Expands a `layout, kind, epilogue => entry;` table into one `match`.
    // Each row becomes an arm that forwards the full argument list to `entry`.
    macro_rules! dispatch {
        ($($lay:ident, $elem:ident, $epi:ident => $entry:ident;)+) => {
            match (layout, kind, epilogue) {
                $(
                    (LayoutSku::$lay, ElementKind::$elem, EpilogueKind::$epi) => {
                        // SAFETY: the arguments are forwarded unchanged; upholding
                        // the entry point's contract is the caller's obligation
                        // under this function's `# Safety` section.
                        unsafe {
                            k_sys::$entry(
                                m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias, alpha, beta,
                                workspace, workspace_bytes, stream,
                            )
                        }
                    }
                )+
                // No kernel exists for this combination.
                _ => 3,
            }
        };
    }

    dispatch! {
        Rcr, F16, Bias => baracuda_cutlass_gemm_bias_f16_rcr_sm80_run;
        Rcr, Bf16, Bias => baracuda_cutlass_gemm_bias_bf16_rcr_sm80_run;
        Rcr, F16, BiasRelu => baracuda_cutlass_gemm_bias_relu_f16_rcr_sm80_run;
        Rcr, Bf16, BiasRelu => baracuda_cutlass_gemm_bias_relu_bf16_rcr_sm80_run;
        Rcr, F16, BiasGelu => baracuda_cutlass_gemm_bias_gelu_f16_rcr_sm80_run;
        Rcr, Bf16, BiasGelu => baracuda_cutlass_gemm_bias_gelu_bf16_rcr_sm80_run;
        Rcr, F16, BiasSilu => baracuda_cutlass_gemm_bias_silu_f16_rcr_sm80_run;
        Rcr, Bf16, BiasSilu => baracuda_cutlass_gemm_bias_silu_bf16_rcr_sm80_run;
        Rrr, F16, Bias => baracuda_cutlass_gemm_bias_f16_rrr_sm80_run;
        Rrr, Bf16, Bias => baracuda_cutlass_gemm_bias_bf16_rrr_sm80_run;
        Rrr, F16, BiasRelu => baracuda_cutlass_gemm_bias_relu_f16_rrr_sm80_run;
        Rrr, Bf16, BiasRelu => baracuda_cutlass_gemm_bias_relu_bf16_rrr_sm80_run;
        Rrr, F16, BiasGelu => baracuda_cutlass_gemm_bias_gelu_f16_rrr_sm80_run;
        Rrr, Bf16, BiasGelu => baracuda_cutlass_gemm_bias_gelu_bf16_rrr_sm80_run;
        Rrr, F16, BiasSilu => baracuda_cutlass_gemm_bias_silu_f16_rrr_sm80_run;
        Rrr, Bf16, BiasSilu => baracuda_cutlass_gemm_bias_silu_bf16_rrr_sm80_run;
        Rcr, F32, Bias => baracuda_cutlass_gemm_bias_tf32_rcr_sm80_run;
        Rcr, F32, BiasRelu => baracuda_cutlass_gemm_bias_relu_tf32_rcr_sm80_run;
        Rcr, F32, BiasGelu => baracuda_cutlass_gemm_bias_gelu_tf32_rcr_sm80_run;
        Rcr, F32, BiasSilu => baracuda_cutlass_gemm_bias_silu_tf32_rcr_sm80_run;
        Rrr, F32, Bias => baracuda_cutlass_gemm_bias_tf32_rrr_sm80_run;
        Rrr, F32, BiasRelu => baracuda_cutlass_gemm_bias_relu_tf32_rrr_sm80_run;
        Rrr, F32, BiasGelu => baracuda_cutlass_gemm_bias_gelu_tf32_rrr_sm80_run;
        Rrr, F32, BiasSilu => baracuda_cutlass_gemm_bias_silu_tf32_rrr_sm80_run;
        Rcr, F32Strict, Bias => baracuda_cutlass_gemm_bias_f32_simt_rcr_sm80_run;
        Rcr, F32Strict, BiasRelu => baracuda_cutlass_gemm_bias_relu_f32_simt_rcr_sm80_run;
        Rcr, F32Strict, BiasGelu => baracuda_cutlass_gemm_bias_gelu_f32_simt_rcr_sm80_run;
        Rcr, F32Strict, BiasSilu => baracuda_cutlass_gemm_bias_silu_f32_simt_rcr_sm80_run;
        Rrr, F32Strict, Bias => baracuda_cutlass_gemm_bias_f32_simt_rrr_sm80_run;
        Rrr, F32Strict, BiasRelu => baracuda_cutlass_gemm_bias_relu_f32_simt_rrr_sm80_run;
        Rrr, F32Strict, BiasGelu => baracuda_cutlass_gemm_bias_gelu_f32_simt_rrr_sm80_run;
        Rrr, F32Strict, BiasSilu => baracuda_cutlass_gemm_bias_silu_f32_simt_rrr_sm80_run;
    }
}
/// Queries the workspace size for the SM80 GEMM-with-bias kernel matching
/// `(layout, kind, epilogue)`.
///
/// The triple selects one `baracuda_cutlass_gemm_bias_*_sm80_workspace_size`
/// entry point of `baracuda_cutlass_kernels_sys`, which is called with
/// `(m, n, k)`; its `usize` result is returned as-is.
/// NOTE(review): the result is presumably a byte count, matching the
/// `workspace_bytes` parameter of the `run` functions — confirm.
///
/// A triple with no table row (for example any `EpilogueKind::Identity`
/// request) returns `0` without calling into the kernel library.
#[cfg(feature = "sm80")]
pub(super) fn gemm_bias_sm80_workspace_size(
    layout: LayoutSku,
    kind: ElementKind,
    epilogue: EpilogueKind,
    m: i32,
    n: i32,
    k: i32,
) -> usize {
    use baracuda_cutlass_kernels_sys as k_sys;

    // Expands a `layout, kind, epilogue => entry;` table into one `match`.
    // Each row becomes an arm that calls `entry(m, n, k)`.
    macro_rules! dispatch {
        ($($lay:ident, $elem:ident, $epi:ident => $entry:ident;)+) => {
            match (layout, kind, epilogue) {
                $(
                    (LayoutSku::$lay, ElementKind::$elem, EpilogueKind::$epi) => {
                        // SAFETY: only by-value integers cross the FFI boundary.
                        // NOTE(review): assumes the entry point has no further
                        // preconditions — confirm against the kernel library.
                        unsafe { k_sys::$entry(m, n, k) }
                    }
                )+
                // No kernel exists for this combination.
                _ => 0,
            }
        };
    }

    dispatch! {
        Rcr, F16, Bias => baracuda_cutlass_gemm_bias_f16_rcr_sm80_workspace_size;
        Rcr, Bf16, Bias => baracuda_cutlass_gemm_bias_bf16_rcr_sm80_workspace_size;
        Rcr, F16, BiasRelu => baracuda_cutlass_gemm_bias_relu_f16_rcr_sm80_workspace_size;
        Rcr, Bf16, BiasRelu => baracuda_cutlass_gemm_bias_relu_bf16_rcr_sm80_workspace_size;
        Rcr, F16, BiasGelu => baracuda_cutlass_gemm_bias_gelu_f16_rcr_sm80_workspace_size;
        Rcr, Bf16, BiasGelu => baracuda_cutlass_gemm_bias_gelu_bf16_rcr_sm80_workspace_size;
        Rcr, F16, BiasSilu => baracuda_cutlass_gemm_bias_silu_f16_rcr_sm80_workspace_size;
        Rcr, Bf16, BiasSilu => baracuda_cutlass_gemm_bias_silu_bf16_rcr_sm80_workspace_size;
        Rrr, F16, Bias => baracuda_cutlass_gemm_bias_f16_rrr_sm80_workspace_size;
        Rrr, Bf16, Bias => baracuda_cutlass_gemm_bias_bf16_rrr_sm80_workspace_size;
        Rrr, F16, BiasRelu => baracuda_cutlass_gemm_bias_relu_f16_rrr_sm80_workspace_size;
        Rrr, Bf16, BiasRelu => baracuda_cutlass_gemm_bias_relu_bf16_rrr_sm80_workspace_size;
        Rrr, F16, BiasGelu => baracuda_cutlass_gemm_bias_gelu_f16_rrr_sm80_workspace_size;
        Rrr, Bf16, BiasGelu => baracuda_cutlass_gemm_bias_gelu_bf16_rrr_sm80_workspace_size;
        Rrr, F16, BiasSilu => baracuda_cutlass_gemm_bias_silu_f16_rrr_sm80_workspace_size;
        Rrr, Bf16, BiasSilu => baracuda_cutlass_gemm_bias_silu_bf16_rrr_sm80_workspace_size;
        Rcr, F32, Bias => baracuda_cutlass_gemm_bias_tf32_rcr_sm80_workspace_size;
        Rcr, F32, BiasRelu => baracuda_cutlass_gemm_bias_relu_tf32_rcr_sm80_workspace_size;
        Rcr, F32, BiasGelu => baracuda_cutlass_gemm_bias_gelu_tf32_rcr_sm80_workspace_size;
        Rcr, F32, BiasSilu => baracuda_cutlass_gemm_bias_silu_tf32_rcr_sm80_workspace_size;
        Rrr, F32, Bias => baracuda_cutlass_gemm_bias_tf32_rrr_sm80_workspace_size;
        Rrr, F32, BiasRelu => baracuda_cutlass_gemm_bias_relu_tf32_rrr_sm80_workspace_size;
        Rrr, F32, BiasGelu => baracuda_cutlass_gemm_bias_gelu_tf32_rrr_sm80_workspace_size;
        Rrr, F32, BiasSilu => baracuda_cutlass_gemm_bias_silu_tf32_rrr_sm80_workspace_size;
        Rcr, F32Strict, Bias => baracuda_cutlass_gemm_bias_f32_simt_rcr_sm80_workspace_size;
        Rcr, F32Strict, BiasRelu => baracuda_cutlass_gemm_bias_relu_f32_simt_rcr_sm80_workspace_size;
        Rcr, F32Strict, BiasGelu => baracuda_cutlass_gemm_bias_gelu_f32_simt_rcr_sm80_workspace_size;
        Rcr, F32Strict, BiasSilu => baracuda_cutlass_gemm_bias_silu_f32_simt_rcr_sm80_workspace_size;
        Rrr, F32Strict, Bias => baracuda_cutlass_gemm_bias_f32_simt_rrr_sm80_workspace_size;
        Rrr, F32Strict, BiasRelu => baracuda_cutlass_gemm_bias_relu_f32_simt_rrr_sm80_workspace_size;
        Rrr, F32Strict, BiasGelu => baracuda_cutlass_gemm_bias_gelu_f32_simt_rrr_sm80_workspace_size;
        Rrr, F32Strict, BiasSilu => baracuda_cutlass_gemm_bias_silu_f32_simt_rrr_sm80_workspace_size;
    }
}
/// Asks the SM80 GEMM-with-bias kernel matching `(layout, kind, epilogue)`
/// whether it can handle the described problem.
///
/// The triple selects one `baracuda_cutlass_gemm_bias_*_sm80_can_implement`
/// entry point of `baracuda_cutlass_kernels_sys`. The dimensions, pointers and
/// leading dimensions are forwarded unchanged and the entry point's `i32`
/// result is returned as-is.
///
/// A triple with no table row (for example any `EpilogueKind::Identity`
/// request) returns `3` without calling into the kernel library.
/// NOTE(review): `3` is presumably the "unsupported" status code — confirm
/// against `status_to_result`.
///
/// # Safety
///
/// The pointers are handed to foreign code untouched, so the caller must meet
/// the requirements of the selected entry point.
/// NOTE(review): whether the entry point dereferences the pointers or only
/// inspects their values is not visible here — confirm.
#[cfg(feature = "sm80")]
#[allow(clippy::too_many_arguments)]
pub(super) unsafe fn gemm_bias_sm80_can_implement(
    layout: LayoutSku,
    kind: ElementKind,
    epilogue: EpilogueKind,
    m: i32,
    n: i32,
    k: i32,
    a: *const c_void,
    lda: i64,
    b: *const c_void,
    ldb: i64,
    c: *const c_void,
    ldc: i64,
    d: *mut c_void,
    ldd: i64,
    bias: *const c_void,
) -> i32 {
    use baracuda_cutlass_kernels_sys as k_sys;

    // Expands a `layout, kind, epilogue => entry;` table into one `match`.
    // Each row becomes an arm that forwards the problem description to `entry`.
    macro_rules! dispatch {
        ($($lay:ident, $elem:ident, $epi:ident => $entry:ident;)+) => {
            match (layout, kind, epilogue) {
                $(
                    (LayoutSku::$lay, ElementKind::$elem, EpilogueKind::$epi) => {
                        // SAFETY: the arguments are forwarded unchanged; upholding
                        // the entry point's contract is the caller's obligation
                        // under this function's `# Safety` section.
                        unsafe { k_sys::$entry(m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias) }
                    }
                )+
                // No kernel exists for this combination.
                _ => 3,
            }
        };
    }

    dispatch! {
        Rcr, F16, Bias => baracuda_cutlass_gemm_bias_f16_rcr_sm80_can_implement;
        Rcr, Bf16, Bias => baracuda_cutlass_gemm_bias_bf16_rcr_sm80_can_implement;
        Rcr, F16, BiasRelu => baracuda_cutlass_gemm_bias_relu_f16_rcr_sm80_can_implement;
        Rcr, Bf16, BiasRelu => baracuda_cutlass_gemm_bias_relu_bf16_rcr_sm80_can_implement;
        Rcr, F16, BiasGelu => baracuda_cutlass_gemm_bias_gelu_f16_rcr_sm80_can_implement;
        Rcr, Bf16, BiasGelu => baracuda_cutlass_gemm_bias_gelu_bf16_rcr_sm80_can_implement;
        Rcr, F16, BiasSilu => baracuda_cutlass_gemm_bias_silu_f16_rcr_sm80_can_implement;
        Rcr, Bf16, BiasSilu => baracuda_cutlass_gemm_bias_silu_bf16_rcr_sm80_can_implement;
        Rrr, F16, Bias => baracuda_cutlass_gemm_bias_f16_rrr_sm80_can_implement;
        Rrr, Bf16, Bias => baracuda_cutlass_gemm_bias_bf16_rrr_sm80_can_implement;
        Rrr, F16, BiasRelu => baracuda_cutlass_gemm_bias_relu_f16_rrr_sm80_can_implement;
        Rrr, Bf16, BiasRelu => baracuda_cutlass_gemm_bias_relu_bf16_rrr_sm80_can_implement;
        Rrr, F16, BiasGelu => baracuda_cutlass_gemm_bias_gelu_f16_rrr_sm80_can_implement;
        Rrr, Bf16, BiasGelu => baracuda_cutlass_gemm_bias_gelu_bf16_rrr_sm80_can_implement;
        Rrr, F16, BiasSilu => baracuda_cutlass_gemm_bias_silu_f16_rrr_sm80_can_implement;
        Rrr, Bf16, BiasSilu => baracuda_cutlass_gemm_bias_silu_bf16_rrr_sm80_can_implement;
        Rcr, F32, Bias => baracuda_cutlass_gemm_bias_tf32_rcr_sm80_can_implement;
        Rcr, F32, BiasRelu => baracuda_cutlass_gemm_bias_relu_tf32_rcr_sm80_can_implement;
        Rcr, F32, BiasGelu => baracuda_cutlass_gemm_bias_gelu_tf32_rcr_sm80_can_implement;
        Rcr, F32, BiasSilu => baracuda_cutlass_gemm_bias_silu_tf32_rcr_sm80_can_implement;
        Rrr, F32, Bias => baracuda_cutlass_gemm_bias_tf32_rrr_sm80_can_implement;
        Rrr, F32, BiasRelu => baracuda_cutlass_gemm_bias_relu_tf32_rrr_sm80_can_implement;
        Rrr, F32, BiasGelu => baracuda_cutlass_gemm_bias_gelu_tf32_rrr_sm80_can_implement;
        Rrr, F32, BiasSilu => baracuda_cutlass_gemm_bias_silu_tf32_rrr_sm80_can_implement;
        Rcr, F32Strict, Bias => baracuda_cutlass_gemm_bias_f32_simt_rcr_sm80_can_implement;
        Rcr, F32Strict, BiasRelu => baracuda_cutlass_gemm_bias_relu_f32_simt_rcr_sm80_can_implement;
        Rcr, F32Strict, BiasGelu => baracuda_cutlass_gemm_bias_gelu_f32_simt_rcr_sm80_can_implement;
        Rcr, F32Strict, BiasSilu => baracuda_cutlass_gemm_bias_silu_f32_simt_rcr_sm80_can_implement;
        Rrr, F32Strict, Bias => baracuda_cutlass_gemm_bias_f32_simt_rrr_sm80_can_implement;
        Rrr, F32Strict, BiasRelu => baracuda_cutlass_gemm_bias_relu_f32_simt_rrr_sm80_can_implement;
        Rrr, F32Strict, BiasGelu => baracuda_cutlass_gemm_bias_gelu_f32_simt_rrr_sm80_can_implement;
        Rrr, F32Strict, BiasSilu => baracuda_cutlass_gemm_bias_silu_f32_simt_rrr_sm80_can_implement;
    }
}
/// Launches the plain SM80 GEMM kernel matching `(layout, kind)`.
///
/// `F16`, `Bf16`, `F32` (the `tf32` kernels) and `F32Strict` (the `f32_simt`
/// kernels) each have an `Rcr` and an `Rrr` entry point in
/// `baracuda_cutlass_kernels_sys`. The remaining arguments are forwarded
/// unchanged and the entry point's `i32` result is returned as-is.
///
/// Every other element kind returns `3` without calling into the kernel
/// library.
/// NOTE(review): `3` is presumably the "unsupported" status code — confirm
/// against `status_to_result`.
///
/// # Safety
///
/// Every pointer and size is handed to foreign code untouched, so the caller
/// must meet the requirements of the selected entry point.
/// NOTE(review): presumably these are device buffers valid for the given
/// dimensions and strides — confirm against the kernel library.
#[cfg(feature = "sm80")]
#[allow(clippy::too_many_arguments)]
pub(super) unsafe fn gemm_sm80_run(
    layout: LayoutSku,
    kind: ElementKind,
    m: i32,
    n: i32,
    k: i32,
    a: *const c_void,
    lda: i64,
    b: *const c_void,
    ldb: i64,
    c: *const c_void,
    ldc: i64,
    d: *mut c_void,
    ldd: i64,
    alpha: f32,
    beta: f32,
    workspace: *mut c_void,
    workspace_bytes: usize,
    stream: *mut c_void,
) -> i32 {
    use baracuda_cutlass_kernels_sys as k_sys;

    // Forwards the full argument list to the named `k_sys` entry point.
    macro_rules! call {
        ($entry:ident) => {
            // SAFETY: the arguments are forwarded unchanged; upholding the entry
            // point's contract is the caller's obligation under `# Safety`.
            unsafe {
                k_sys::$entry(
                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, alpha, beta, workspace,
                    workspace_bytes, stream,
                )
            }
        };
    }

    // Both matches are deliberately wildcard-free so that a new enum variant
    // fails to compile here instead of silently falling through.
    match kind {
        ElementKind::F16 => match layout {
            LayoutSku::Rcr => call!(baracuda_cutlass_gemm_f16_rcr_sm80_run),
            LayoutSku::Rrr => call!(baracuda_cutlass_gemm_f16_rrr_sm80_run),
        },
        ElementKind::Bf16 => match layout {
            LayoutSku::Rcr => call!(baracuda_cutlass_gemm_bf16_rcr_sm80_run),
            LayoutSku::Rrr => call!(baracuda_cutlass_gemm_bf16_rrr_sm80_run),
        },
        ElementKind::F32 => match layout {
            LayoutSku::Rcr => call!(baracuda_cutlass_gemm_tf32_rcr_sm80_run),
            LayoutSku::Rrr => call!(baracuda_cutlass_gemm_tf32_rrr_sm80_run),
        },
        ElementKind::F32Strict => match layout {
            LayoutSku::Rcr => call!(baracuda_cutlass_gemm_f32_simt_rcr_sm80_run),
            LayoutSku::Rrr => call!(baracuda_cutlass_gemm_f32_simt_rrr_sm80_run),
        },
        // Rejected here: this signature carries `f32` scalars, while the
        // sibling `gemm_sm80_run_f64` takes `f64` alpha/beta.
        ElementKind::F64 => 3,
        // No entry point is dispatched from here for these kinds.
        ElementKind::S8
        | ElementKind::U8
        | ElementKind::I32
        | ElementKind::I64
        | ElementKind::Bool
        | ElementKind::Fp8E4M3
        | ElementKind::Fp8E5M2
        | ElementKind::S4
        | ElementKind::U4
        | ElementKind::Bin
        | ElementKind::Complex32
        | ElementKind::Complex64 => 3,
    }
}
/// Queries the workspace size for the plain SM80 GEMM kernel matching
/// `(layout, kind)`.
///
/// `F16`, `Bf16`, `F32` (the `tf32` kernels) and `F32Strict` (the `f32_simt`
/// kernels) each call their `*_workspace_size` entry point with `(m, n, k)`
/// and return its `usize` result as-is. Every other element kind returns `0`
/// without calling into the kernel library.
/// NOTE(review): the result is presumably a byte count, matching the
/// `workspace_bytes` parameter of the `run` functions — confirm.
#[cfg(feature = "sm80")]
pub(super) fn gemm_sm80_workspace_size(
    layout: LayoutSku,
    kind: ElementKind,
    m: i32,
    n: i32,
    k: i32,
) -> usize {
    use baracuda_cutlass_kernels_sys as k_sys;

    // Calls the named `k_sys` entry point with the problem dimensions.
    macro_rules! call {
        ($entry:ident) => {
            // SAFETY: only by-value integers cross the FFI boundary.
            // NOTE(review): assumes the entry point has no further
            // preconditions — confirm against the kernel library.
            unsafe { k_sys::$entry(m, n, k) }
        };
    }

    // Both matches are deliberately wildcard-free so that a new enum variant
    // fails to compile here instead of silently falling through.
    match kind {
        ElementKind::F16 => match layout {
            LayoutSku::Rcr => call!(baracuda_cutlass_gemm_f16_rcr_sm80_workspace_size),
            LayoutSku::Rrr => call!(baracuda_cutlass_gemm_f16_rrr_sm80_workspace_size),
        },
        ElementKind::Bf16 => match layout {
            LayoutSku::Rcr => call!(baracuda_cutlass_gemm_bf16_rcr_sm80_workspace_size),
            LayoutSku::Rrr => call!(baracuda_cutlass_gemm_bf16_rrr_sm80_workspace_size),
        },
        ElementKind::F32 => match layout {
            LayoutSku::Rcr => call!(baracuda_cutlass_gemm_tf32_rcr_sm80_workspace_size),
            LayoutSku::Rrr => call!(baracuda_cutlass_gemm_tf32_rrr_sm80_workspace_size),
        },
        ElementKind::F32Strict => match layout {
            LayoutSku::Rcr => call!(baracuda_cutlass_gemm_f32_simt_rcr_sm80_workspace_size),
            LayoutSku::Rrr => call!(baracuda_cutlass_gemm_f32_simt_rrr_sm80_workspace_size),
        },
        // The sibling `gemm_sm80_workspace_size_f64` queries the f64 kernels.
        ElementKind::F64 => 0,
        // No entry point is dispatched from here for these kinds.
        ElementKind::S8
        | ElementKind::U8
        | ElementKind::I32
        | ElementKind::I64
        | ElementKind::Bool
        | ElementKind::Fp8E4M3
        | ElementKind::Fp8E5M2
        | ElementKind::S4
        | ElementKind::U4
        | ElementKind::Bin
        | ElementKind::Complex32
        | ElementKind::Complex64 => 0,
    }
}
/// Asks the plain SM80 GEMM kernel matching `(layout, kind)` whether it can
/// handle the described problem.
///
/// `F16`, `Bf16`, `F32` (the `tf32` kernels) and `F32Strict` (the `f32_simt`
/// kernels) each forward the dimensions, pointers and leading dimensions to
/// their `*_can_implement` entry point and return its `i32` result as-is.
///
/// Every other element kind returns `3` without calling into the kernel
/// library.
/// NOTE(review): `3` is presumably the "unsupported" status code — confirm
/// against `status_to_result`.
///
/// # Safety
///
/// The pointers are handed to foreign code untouched, so the caller must meet
/// the requirements of the selected entry point.
/// NOTE(review): whether the entry point dereferences the pointers or only
/// inspects their values is not visible here — confirm.
#[cfg(feature = "sm80")]
#[allow(clippy::too_many_arguments)]
pub(super) unsafe fn gemm_sm80_can_implement(
    layout: LayoutSku,
    kind: ElementKind,
    m: i32,
    n: i32,
    k: i32,
    a: *const c_void,
    lda: i64,
    b: *const c_void,
    ldb: i64,
    c: *const c_void,
    ldc: i64,
    d: *mut c_void,
    ldd: i64,
) -> i32 {
    use baracuda_cutlass_kernels_sys as k_sys;

    // Forwards the problem description to the named `k_sys` entry point.
    macro_rules! call {
        ($entry:ident) => {
            // SAFETY: the arguments are forwarded unchanged; upholding the entry
            // point's contract is the caller's obligation under `# Safety`.
            unsafe { k_sys::$entry(m, n, k, a, lda, b, ldb, c, ldc, d, ldd) }
        };
    }

    // Both matches are deliberately wildcard-free so that a new enum variant
    // fails to compile here instead of silently falling through.
    match kind {
        ElementKind::F16 => match layout {
            LayoutSku::Rcr => call!(baracuda_cutlass_gemm_f16_rcr_sm80_can_implement),
            LayoutSku::Rrr => call!(baracuda_cutlass_gemm_f16_rrr_sm80_can_implement),
        },
        ElementKind::Bf16 => match layout {
            LayoutSku::Rcr => call!(baracuda_cutlass_gemm_bf16_rcr_sm80_can_implement),
            LayoutSku::Rrr => call!(baracuda_cutlass_gemm_bf16_rrr_sm80_can_implement),
        },
        ElementKind::F32 => match layout {
            LayoutSku::Rcr => call!(baracuda_cutlass_gemm_tf32_rcr_sm80_can_implement),
            LayoutSku::Rrr => call!(baracuda_cutlass_gemm_tf32_rrr_sm80_can_implement),
        },
        ElementKind::F32Strict => match layout {
            LayoutSku::Rcr => call!(baracuda_cutlass_gemm_f32_simt_rcr_sm80_can_implement),
            LayoutSku::Rrr => call!(baracuda_cutlass_gemm_f32_simt_rrr_sm80_can_implement),
        },
        // The sibling `gemm_sm80_can_implement_f64` checks the f64 kernels.
        ElementKind::F64 => 3,
        // No entry point is dispatched from here for these kinds.
        ElementKind::S8
        | ElementKind::U8
        | ElementKind::I32
        | ElementKind::I64
        | ElementKind::Bool
        | ElementKind::Fp8E4M3
        | ElementKind::Fp8E5M2
        | ElementKind::S4
        | ElementKind::U4
        | ElementKind::Bin
        | ElementKind::Complex32
        | ElementKind::Complex64 => 3,
    }
}
/// Launches the SM80 `f64` GEMM kernel for `layout`.
///
/// Selects `baracuda_cutlass_gemm_f64_rcr_sm80_run` or
/// `baracuda_cutlass_gemm_f64_rrr_sm80_run`, forwards every remaining argument
/// unchanged (with `alpha`/`beta` as `f64`) and returns the entry point's
/// `i32` result as-is.
///
/// # Safety
///
/// Every pointer and size is handed to foreign code untouched, so the caller
/// must meet the requirements of the selected entry point.
/// NOTE(review): presumably these are device buffers valid for the given
/// dimensions and strides — confirm against the kernel library.
#[cfg(feature = "sm80")]
#[allow(clippy::too_many_arguments)]
pub(super) unsafe fn gemm_sm80_run_f64(
    layout: LayoutSku,
    m: i32,
    n: i32,
    k: i32,
    a: *const c_void,
    lda: i64,
    b: *const c_void,
    ldb: i64,
    c: *const c_void,
    ldc: i64,
    d: *mut c_void,
    ldd: i64,
    alpha: f64,
    beta: f64,
    workspace: *mut c_void,
    workspace_bytes: usize,
    stream: *mut c_void,
) -> i32 {
    use baracuda_cutlass_kernels_sys as k_sys;

    // Forwards the full argument list to the named `k_sys` entry point.
    macro_rules! call {
        ($entry:ident) => {
            // SAFETY: the arguments are forwarded unchanged; upholding the entry
            // point's contract is the caller's obligation under `# Safety`.
            unsafe {
                k_sys::$entry(
                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, alpha, beta, workspace,
                    workspace_bytes, stream,
                )
            }
        };
    }

    match layout {
        LayoutSku::Rcr => call!(baracuda_cutlass_gemm_f64_rcr_sm80_run),
        LayoutSku::Rrr => call!(baracuda_cutlass_gemm_f64_rrr_sm80_run),
    }
}
/// Queries the workspace size for the SM80 `f64` GEMM kernel for `layout`.
///
/// Calls the `rcr` or `rrr` variant of
/// `baracuda_cutlass_gemm_f64_*_sm80_workspace_size` with `(m, n, k)` and
/// returns its `usize` result as-is.
/// NOTE(review): the result is presumably a byte count, matching the
/// `workspace_bytes` parameter of the `run` functions — confirm.
#[cfg(feature = "sm80")]
pub(super) fn gemm_sm80_workspace_size_f64(layout: LayoutSku, m: i32, n: i32, k: i32) -> usize {
    use baracuda_cutlass_kernels_sys as k_sys;

    // Calls the named `k_sys` entry point with the problem dimensions.
    macro_rules! call {
        ($entry:ident) => {
            // SAFETY: only by-value integers cross the FFI boundary.
            // NOTE(review): assumes the entry point has no further
            // preconditions — confirm against the kernel library.
            unsafe { k_sys::$entry(m, n, k) }
        };
    }

    match layout {
        LayoutSku::Rcr => call!(baracuda_cutlass_gemm_f64_rcr_sm80_workspace_size),
        LayoutSku::Rrr => call!(baracuda_cutlass_gemm_f64_rrr_sm80_workspace_size),
    }
}
/// Asks the SM80 `f64` GEMM kernel for `layout` whether it can handle the
/// described problem.
///
/// Forwards the dimensions, pointers and leading dimensions to the `rcr` or
/// `rrr` variant of `baracuda_cutlass_gemm_f64_*_sm80_can_implement` and
/// returns its `i32` result as-is.
///
/// # Safety
///
/// The pointers are handed to foreign code untouched, so the caller must meet
/// the requirements of the selected entry point.
/// NOTE(review): whether the entry point dereferences the pointers or only
/// inspects their values is not visible here — confirm.
#[cfg(feature = "sm80")]
#[allow(clippy::too_many_arguments)]
pub(super) unsafe fn gemm_sm80_can_implement_f64(
    layout: LayoutSku,
    m: i32,
    n: i32,
    k: i32,
    a: *const c_void,
    lda: i64,
    b: *const c_void,
    ldb: i64,
    c: *const c_void,
    ldc: i64,
    d: *mut c_void,
    ldd: i64,
) -> i32 {
    use baracuda_cutlass_kernels_sys as k_sys;

    // Forwards the problem description to the named `k_sys` entry point.
    macro_rules! call {
        ($entry:ident) => {
            // SAFETY: the arguments are forwarded unchanged; upholding the entry
            // point's contract is the caller's obligation under `# Safety`.
            unsafe { k_sys::$entry(m, n, k, a, lda, b, ldb, c, ldc, d, ldd) }
        };
    }

    match layout {
        LayoutSku::Rcr => call!(baracuda_cutlass_gemm_f64_rcr_sm80_can_implement),
        LayoutSku::Rrr => call!(baracuda_cutlass_gemm_f64_rrr_sm80_can_implement),
    }
}
/// Launches the SM80 `f64` GEMM-with-bias kernel matching `(layout, epilogue)`.
///
/// Each of `Bias`, `BiasRelu`, `BiasGelu` and `BiasSilu` has an `Rcr` and an
/// `Rrr` entry point in `baracuda_cutlass_kernels_sys`. The remaining
/// arguments are forwarded unchanged (with `alpha`/`beta` as `f64`) and the
/// entry point's `i32` result is returned as-is.
///
/// `EpilogueKind::Identity` returns `3` without calling into the kernel
/// library.
/// NOTE(review): `3` is presumably the "unsupported" status code — confirm
/// against `status_to_result`.
///
/// # Safety
///
/// Every pointer and size is handed to foreign code untouched, so the caller
/// must meet the requirements of the selected entry point.
/// NOTE(review): presumably these are device buffers valid for the given
/// dimensions and strides — confirm against the kernel library.
#[cfg(feature = "sm80")]
#[allow(clippy::too_many_arguments)]
pub(super) unsafe fn gemm_bias_sm80_run_f64(
    layout: LayoutSku,
    epilogue: EpilogueKind,
    m: i32,
    n: i32,
    k: i32,
    a: *const c_void,
    lda: i64,
    b: *const c_void,
    ldb: i64,
    c: *const c_void,
    ldc: i64,
    d: *mut c_void,
    ldd: i64,
    bias: *const c_void,
    alpha: f64,
    beta: f64,
    workspace: *mut c_void,
    workspace_bytes: usize,
    stream: *mut c_void,
) -> i32 {
    use baracuda_cutlass_kernels_sys as k_sys;

    // Forwards the full argument list to the named `k_sys` entry point.
    macro_rules! call {
        ($entry:ident) => {
            // SAFETY: the arguments are forwarded unchanged; upholding the entry
            // point's contract is the caller's obligation under `# Safety`.
            unsafe {
                k_sys::$entry(
                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias, alpha, beta, workspace,
                    workspace_bytes, stream,
                )
            }
        };
    }

    match epilogue {
        EpilogueKind::Bias => match layout {
            LayoutSku::Rcr => call!(baracuda_cutlass_gemm_bias_f64_rcr_sm80_run),
            LayoutSku::Rrr => call!(baracuda_cutlass_gemm_bias_f64_rrr_sm80_run),
        },
        EpilogueKind::BiasRelu => match layout {
            LayoutSku::Rcr => call!(baracuda_cutlass_gemm_bias_relu_f64_rcr_sm80_run),
            LayoutSku::Rrr => call!(baracuda_cutlass_gemm_bias_relu_f64_rrr_sm80_run),
        },
        EpilogueKind::BiasGelu => match layout {
            LayoutSku::Rcr => call!(baracuda_cutlass_gemm_bias_gelu_f64_rcr_sm80_run),
            LayoutSku::Rrr => call!(baracuda_cutlass_gemm_bias_gelu_f64_rrr_sm80_run),
        },
        EpilogueKind::BiasSilu => match layout {
            LayoutSku::Rcr => call!(baracuda_cutlass_gemm_bias_silu_f64_rcr_sm80_run),
            LayoutSku::Rrr => call!(baracuda_cutlass_gemm_bias_silu_f64_rrr_sm80_run),
        },
        // The bias dispatcher has no kernel for a bias-free epilogue.
        EpilogueKind::Identity => 3,
    }
}
/// Queries the workspace size for the SM80 `f64` GEMM-with-bias kernel
/// matching `(layout, epilogue)`.
///
/// Each of `Bias`, `BiasRelu`, `BiasGelu` and `BiasSilu` calls its `Rcr` or
/// `Rrr` `*_workspace_size` entry point with `(m, n, k)` and returns the
/// `usize` result as-is. `EpilogueKind::Identity` returns `0` without calling
/// into the kernel library.
/// NOTE(review): the result is presumably a byte count, matching the
/// `workspace_bytes` parameter of the `run` functions — confirm.
#[cfg(feature = "sm80")]
pub(super) fn gemm_bias_sm80_workspace_size_f64(
    layout: LayoutSku,
    epilogue: EpilogueKind,
    m: i32,
    n: i32,
    k: i32,
) -> usize {
    use baracuda_cutlass_kernels_sys as k_sys;

    // Calls the named `k_sys` entry point with the problem dimensions.
    macro_rules! call {
        ($entry:ident) => {
            // SAFETY: only by-value integers cross the FFI boundary.
            // NOTE(review): assumes the entry point has no further
            // preconditions — confirm against the kernel library.
            unsafe { k_sys::$entry(m, n, k) }
        };
    }

    match epilogue {
        EpilogueKind::Bias => match layout {
            LayoutSku::Rcr => call!(baracuda_cutlass_gemm_bias_f64_rcr_sm80_workspace_size),
            LayoutSku::Rrr => call!(baracuda_cutlass_gemm_bias_f64_rrr_sm80_workspace_size),
        },
        EpilogueKind::BiasRelu => match layout {
            LayoutSku::Rcr => call!(baracuda_cutlass_gemm_bias_relu_f64_rcr_sm80_workspace_size),
            LayoutSku::Rrr => call!(baracuda_cutlass_gemm_bias_relu_f64_rrr_sm80_workspace_size),
        },
        EpilogueKind::BiasGelu => match layout {
            LayoutSku::Rcr => call!(baracuda_cutlass_gemm_bias_gelu_f64_rcr_sm80_workspace_size),
            LayoutSku::Rrr => call!(baracuda_cutlass_gemm_bias_gelu_f64_rrr_sm80_workspace_size),
        },
        EpilogueKind::BiasSilu => match layout {
            LayoutSku::Rcr => call!(baracuda_cutlass_gemm_bias_silu_f64_rcr_sm80_workspace_size),
            LayoutSku::Rrr => call!(baracuda_cutlass_gemm_bias_silu_f64_rrr_sm80_workspace_size),
        },
        // The bias dispatcher has no kernel for a bias-free epilogue.
        EpilogueKind::Identity => 0,
    }
}
/// Asks the SM80 `f64` GEMM-with-bias kernel matching `(layout, epilogue)`
/// whether it can handle the described problem.
///
/// Each of `Bias`, `BiasRelu`, `BiasGelu` and `BiasSilu` forwards the
/// dimensions, pointers and leading dimensions to its `Rcr` or `Rrr`
/// `*_can_implement` entry point and returns the `i32` result as-is.
///
/// `EpilogueKind::Identity` returns `3` without calling into the kernel
/// library.
/// NOTE(review): `3` is presumably the "unsupported" status code — confirm
/// against `status_to_result`.
///
/// # Safety
///
/// The pointers are handed to foreign code untouched, so the caller must meet
/// the requirements of the selected entry point.
/// NOTE(review): whether the entry point dereferences the pointers or only
/// inspects their values is not visible here — confirm.
#[cfg(feature = "sm80")]
#[allow(clippy::too_many_arguments)]
pub(super) unsafe fn gemm_bias_sm80_can_implement_f64(
    layout: LayoutSku,
    epilogue: EpilogueKind,
    m: i32,
    n: i32,
    k: i32,
    a: *const c_void,
    lda: i64,
    b: *const c_void,
    ldb: i64,
    c: *const c_void,
    ldc: i64,
    d: *mut c_void,
    ldd: i64,
    bias: *const c_void,
) -> i32 {
    use baracuda_cutlass_kernels_sys as k_sys;

    // Forwards the problem description to the named `k_sys` entry point.
    macro_rules! call {
        ($entry:ident) => {
            // SAFETY: the arguments are forwarded unchanged; upholding the entry
            // point's contract is the caller's obligation under `# Safety`.
            unsafe { k_sys::$entry(m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias) }
        };
    }

    match epilogue {
        EpilogueKind::Bias => match layout {
            LayoutSku::Rcr => call!(baracuda_cutlass_gemm_bias_f64_rcr_sm80_can_implement),
            LayoutSku::Rrr => call!(baracuda_cutlass_gemm_bias_f64_rrr_sm80_can_implement),
        },
        EpilogueKind::BiasRelu => match layout {
            LayoutSku::Rcr => call!(baracuda_cutlass_gemm_bias_relu_f64_rcr_sm80_can_implement),
            LayoutSku::Rrr => call!(baracuda_cutlass_gemm_bias_relu_f64_rrr_sm80_can_implement),
        },
        EpilogueKind::BiasGelu => match layout {
            LayoutSku::Rcr => call!(baracuda_cutlass_gemm_bias_gelu_f64_rcr_sm80_can_implement),
            LayoutSku::Rrr => call!(baracuda_cutlass_gemm_bias_gelu_f64_rrr_sm80_can_implement),
        },
        EpilogueKind::BiasSilu => match layout {
            LayoutSku::Rcr => call!(baracuda_cutlass_gemm_bias_silu_f64_rcr_sm80_can_implement),
            LayoutSku::Rrr => call!(baracuda_cutlass_gemm_bias_silu_f64_rrr_sm80_can_implement),
        },
        // The bias dispatcher has no kernel for a bias-free epilogue.
        EpilogueKind::Identity => 3,
    }
}
/// Launches the strided-batched GEMM kernel for `(layout, kind)`.
///
/// Only `LayoutSku::Rcr` with `ElementKind::F16` or `ElementKind::Bf16`
/// has a kernel in this table; every other combination returns the
/// literal status `3` without touching the FFI layer. Otherwise the raw
/// status of the `*_run` entry point is returned.
///
/// # Safety
///
/// All pointers (`a`, `b`, `c`, `d`, `workspace`, `stream`) are passed to
/// the FFI kernel as-is. The caller must guarantee they satisfy that
/// kernel's requirements (not visible from this module).
#[cfg(feature = "sm80")]
#[allow(clippy::too_many_arguments)]
pub(super) unsafe fn batched_gemm_sm80_run(
    layout: LayoutSku,
    kind: ElementKind,
    m: i32,
    n: i32,
    k: i32,
    a: *const c_void,
    lda: i64,
    stride_a: i64,
    b: *const c_void,
    ldb: i64,
    stride_b: i64,
    c: *const c_void,
    ldc: i64,
    stride_c: i64,
    d: *mut c_void,
    ldd: i64,
    stride_d: i64,
    alpha: f32,
    beta: f32,
    batch_count: i32,
    workspace: *mut c_void,
    workspace_bytes: usize,
    stream: *mut c_void,
) -> i32 {
    use baracuda_cutlass_kernels_sys as k_sys;
    // No batched kernel exists for any layout other than Rcr.
    if !matches!(layout, LayoutSku::Rcr) {
        return 3;
    }
    match kind {
        ElementKind::F16 => unsafe {
            k_sys::baracuda_cutlass_gemm_batched_f16_rcr_sm80_run(
                m, n, k,
                a, lda, stride_a,
                b, ldb, stride_b,
                c, ldc, stride_c,
                d, ldd, stride_d,
                alpha, beta,
                batch_count,
                workspace, workspace_bytes,
                stream,
            )
        },
        ElementKind::Bf16 => unsafe {
            k_sys::baracuda_cutlass_gemm_batched_bf16_rcr_sm80_run(
                m, n, k,
                a, lda, stride_a,
                b, ldb, stride_b,
                c, ldc, stride_c,
                d, ldd, stride_d,
                alpha, beta,
                batch_count,
                workspace, workspace_bytes,
                stream,
            )
        },
        _ => 3,
    }
}
/// Workspace size reported by the strided-batched GEMM kernel for
/// `(layout, kind)`.
///
/// Combinations without a kernel (anything other than Rcr with F16/Bf16)
/// report `0`.
#[cfg(feature = "sm80")]
pub(super) fn batched_gemm_sm80_workspace_size(
    layout: LayoutSku,
    kind: ElementKind,
    m: i32,
    n: i32,
    k: i32,
    batch_count: i32,
) -> usize {
    use baracuda_cutlass_kernels_sys as k_sys;
    if !matches!(layout, LayoutSku::Rcr) {
        return 0;
    }
    // Only by-value integers cross the FFI boundary below.
    match kind {
        ElementKind::F16 => unsafe {
            k_sys::baracuda_cutlass_gemm_batched_f16_rcr_sm80_workspace_size(
                m, n, k, batch_count,
            )
        },
        ElementKind::Bf16 => unsafe {
            k_sys::baracuda_cutlass_gemm_batched_bf16_rcr_sm80_workspace_size(
                m, n, k, batch_count,
            )
        },
        _ => 0,
    }
}
/// Asks the strided-batched GEMM kernel for `(layout, kind)` whether it
/// can run the described problem.
///
/// Returns the raw status of the `*_can_implement` entry point, or the
/// literal `3` when no kernel exists for the combination (anything other
/// than Rcr with F16/Bf16).
///
/// # Safety
///
/// The operand pointers are forwarded to the FFI layer unchanged. The
/// caller must meet the underlying kernel's pointer requirements (not
/// visible from this module).
#[cfg(feature = "sm80")]
#[allow(clippy::too_many_arguments)]
pub(super) unsafe fn batched_gemm_sm80_can_implement(
    layout: LayoutSku,
    kind: ElementKind,
    m: i32,
    n: i32,
    k: i32,
    a: *const c_void,
    lda: i64,
    stride_a: i64,
    b: *const c_void,
    ldb: i64,
    stride_b: i64,
    c: *const c_void,
    ldc: i64,
    stride_c: i64,
    d: *mut c_void,
    ldd: i64,
    stride_d: i64,
    batch_count: i32,
) -> i32 {
    use baracuda_cutlass_kernels_sys as k_sys;
    if !matches!(layout, LayoutSku::Rcr) {
        return 3;
    }
    match kind {
        ElementKind::F16 => unsafe {
            k_sys::baracuda_cutlass_gemm_batched_f16_rcr_sm80_can_implement(
                m, n, k,
                a, lda, stride_a,
                b, ldb, stride_b,
                c, ldc, stride_c,
                d, ldd, stride_d,
                batch_count,
            )
        },
        ElementKind::Bf16 => unsafe {
            k_sys::baracuda_cutlass_gemm_batched_bf16_rcr_sm80_can_implement(
                m, n, k,
                a, lda, stride_a,
                b, ldb, stride_b,
                c, ldc, stride_c,
                d, ldd, stride_d,
                batch_count,
            )
        },
        _ => 3,
    }
}
/// Forwards to the grouped-GEMM `*_sufficient` FFI query for `kind`.
///
/// Only `F16` and `Bf16` have a grouped Rcr kernel; every other element
/// kind returns `0`. The unsupported kinds are listed explicitly (no `_`
/// arm) so that adding an `ElementKind` variant forces a decision here.
///
/// # Safety
///
/// `h_m`, `h_n` and `h_k` are passed to the FFI layer unchanged.
/// NOTE(review): presumably each must point to `group_count` readable
/// `i32` values on the host — confirm against the sys crate.
#[cfg(feature = "sm80")]
pub(super) unsafe fn grouped_gemm_rcr_sm80_sufficient(
    kind: ElementKind,
    h_m: *const i32,
    h_n: *const i32,
    h_k: *const i32,
    group_count: i32,
) -> i32 {
    use baracuda_cutlass_kernels_sys as k_sys;
    match kind {
        ElementKind::F16 => unsafe {
            k_sys::baracuda_cutlass_grouped_gemm_f16_rcr_sm80_sufficient(h_m, h_n, h_k, group_count)
        },
        ElementKind::Bf16 => unsafe {
            k_sys::baracuda_cutlass_grouped_gemm_bf16_rcr_sm80_sufficient(h_m, h_n, h_k, group_count)
        },
        // Wider floating-point kinds.
        ElementKind::F32 | ElementKind::F32Strict | ElementKind::F64 => 0,
        // Integer, boolean and sub-byte kinds.
        ElementKind::S8
        | ElementKind::U8
        | ElementKind::I32
        | ElementKind::I64
        | ElementKind::Bool
        | ElementKind::S4
        | ElementKind::U4
        | ElementKind::Bin => 0,
        // FP8 kinds.
        ElementKind::Fp8E4M3 | ElementKind::Fp8E5M2 => 0,
        // Complex kinds.
        ElementKind::Complex32 | ElementKind::Complex64 => 0,
    }
}
/// Forwards to the grouped-GEMM `*_scratch_bytes` FFI query for `kind`.
///
/// Only `F16` and `Bf16` have a grouped Rcr kernel; every other element
/// kind reports `0`. The unsupported kinds are listed explicitly (no `_`
/// arm) so that adding an `ElementKind` variant forces a decision here.
///
/// # Safety
///
/// `h_m`, `h_n` and `h_k` are passed to the FFI layer unchanged.
/// NOTE(review): presumably each must point to `group_count` readable
/// `i32` values on the host — confirm against the sys crate.
#[cfg(feature = "sm80")]
pub(super) unsafe fn grouped_gemm_rcr_sm80_scratch_bytes(
    kind: ElementKind,
    h_m: *const i32,
    h_n: *const i32,
    h_k: *const i32,
    group_count: i32,
    threadblock_count: i32,
) -> usize {
    use baracuda_cutlass_kernels_sys as k_sys;
    match kind {
        ElementKind::F16 => unsafe {
            k_sys::baracuda_cutlass_grouped_gemm_f16_rcr_sm80_scratch_bytes(
                h_m, h_n, h_k, group_count, threadblock_count,
            )
        },
        ElementKind::Bf16 => unsafe {
            k_sys::baracuda_cutlass_grouped_gemm_bf16_rcr_sm80_scratch_bytes(
                h_m, h_n, h_k, group_count, threadblock_count,
            )
        },
        // Wider floating-point kinds.
        ElementKind::F32 | ElementKind::F32Strict | ElementKind::F64 => 0,
        // Integer, boolean and sub-byte kinds.
        ElementKind::S8
        | ElementKind::U8
        | ElementKind::I32
        | ElementKind::I64
        | ElementKind::Bool
        | ElementKind::S4
        | ElementKind::U4
        | ElementKind::Bin => 0,
        // FP8 kinds.
        ElementKind::Fp8E4M3 | ElementKind::Fp8E5M2 => 0,
        // Complex kinds.
        ElementKind::Complex32 | ElementKind::Complex64 => 0,
    }
}
/// Forwards to the grouped-GEMM `*_can_implement` FFI query for `kind`.
///
/// Only `F16` and `Bf16` have a grouped Rcr kernel; every other element
/// kind returns the literal status `3`. The unsupported kinds are listed
/// explicitly (no `_` arm) so that adding an `ElementKind` variant forces
/// a decision here.
///
/// # Safety
///
/// `h_m`, `h_n` and `h_k` are passed to the FFI layer unchanged.
/// NOTE(review): presumably each must point to `group_count` readable
/// `i32` values on the host — confirm against the sys crate.
#[cfg(feature = "sm80")]
pub(super) unsafe fn grouped_gemm_rcr_sm80_can_implement(
    kind: ElementKind,
    h_m: *const i32,
    h_n: *const i32,
    h_k: *const i32,
    group_count: i32,
) -> i32 {
    use baracuda_cutlass_kernels_sys as k_sys;
    match kind {
        ElementKind::F16 => unsafe {
            k_sys::baracuda_cutlass_grouped_gemm_f16_rcr_sm80_can_implement(h_m, h_n, h_k, group_count)
        },
        ElementKind::Bf16 => unsafe {
            k_sys::baracuda_cutlass_grouped_gemm_bf16_rcr_sm80_can_implement(h_m, h_n, h_k, group_count)
        },
        // Wider floating-point kinds.
        ElementKind::F32 | ElementKind::F32Strict | ElementKind::F64 => 3,
        // Integer, boolean and sub-byte kinds.
        ElementKind::S8
        | ElementKind::U8
        | ElementKind::I32
        | ElementKind::I64
        | ElementKind::Bool
        | ElementKind::S4
        | ElementKind::U4
        | ElementKind::Bin => 3,
        // FP8 kinds.
        ElementKind::Fp8E4M3 | ElementKind::Fp8E5M2 => 3,
        // Complex kinds.
        ElementKind::Complex32 | ElementKind::Complex64 => 3,
    }
}
/// Launches the grouped Rcr GEMM kernel for `kind`.
///
/// Only `F16` and `Bf16` have a grouped kernel; every other element kind
/// returns the literal status `3` without calling into the FFI layer.
/// The unsupported kinds are listed explicitly (no `_` arm) so that
/// adding an `ElementKind` variant forces a decision here.
///
/// # Safety
///
/// Every pointer argument is forwarded to the FFI kernel unchanged.
/// NOTE(review): the `d_`/`h_` prefixes suggest device- and host-resident
/// arrays respectively — confirm the exact requirements against the sys
/// crate; they are the caller's responsibility.
#[cfg(feature = "sm80")]
#[allow(clippy::too_many_arguments)]
pub(super) unsafe fn grouped_gemm_rcr_sm80_run(
    kind: ElementKind,
    group_count: i32,
    threadblock_count: i32,
    d_problem_sizes: *const c_void,
    d_ptr_a: *const c_void,
    d_ptr_b: *const c_void,
    d_ptr_c: *const c_void,
    d_ptr_d: *mut c_void,
    d_lda: *const c_void,
    d_ldb: *const c_void,
    d_ldc: *const c_void,
    d_ldd: *const c_void,
    h_problem_sizes: *const c_void,
    alpha: f32,
    beta: f32,
    scratch: *mut c_void,
    scratch_bytes: usize,
    stream: *mut c_void,
) -> i32 {
    use baracuda_cutlass_kernels_sys as k_sys;
    match kind {
        ElementKind::F16 => unsafe {
            k_sys::baracuda_cutlass_grouped_gemm_f16_rcr_sm80_run(
                group_count, threadblock_count,
                d_problem_sizes,
                d_ptr_a, d_ptr_b, d_ptr_c, d_ptr_d,
                d_lda, d_ldb, d_ldc, d_ldd,
                h_problem_sizes,
                alpha, beta,
                scratch, scratch_bytes,
                stream,
            )
        },
        ElementKind::Bf16 => unsafe {
            k_sys::baracuda_cutlass_grouped_gemm_bf16_rcr_sm80_run(
                group_count, threadblock_count,
                d_problem_sizes,
                d_ptr_a, d_ptr_b, d_ptr_c, d_ptr_d,
                d_lda, d_ldb, d_ldc, d_ldd,
                h_problem_sizes,
                alpha, beta,
                scratch, scratch_bytes,
                stream,
            )
        },
        // Wider floating-point kinds.
        ElementKind::F32 | ElementKind::F32Strict | ElementKind::F64 => 3,
        // Integer, boolean and sub-byte kinds.
        ElementKind::S8
        | ElementKind::U8
        | ElementKind::I32
        | ElementKind::I64
        | ElementKind::Bool
        | ElementKind::S4
        | ElementKind::U4
        | ElementKind::Bin => 3,
        // FP8 kinds.
        ElementKind::Fp8E4M3 | ElementKind::Fp8E5M2 => 3,
        // Complex kinds.
        ElementKind::Complex32 | ElementKind::Complex64 => 3,
    }
}
/// Launches the 8-bit integer GEMM kernel for `(layout, kind)`.
///
/// Kernels exist only for `LayoutSku::Rcr` with `ElementKind::S8` or
/// `ElementKind::U8`. Every other combination — including Rrr with
/// S8/U8 — returns the literal status `3` without an FFI call.
///
/// # Safety
///
/// The operand, workspace and stream pointers are forwarded to the FFI
/// kernel unchanged. The caller must satisfy that kernel's pointer
/// requirements (not visible from this module).
#[cfg(feature = "sm80")]
#[allow(clippy::too_many_arguments)]
pub(super) unsafe fn int_gemm_rcr_sm80_run(
    layout: LayoutSku,
    kind: ElementKind,
    m: i32,
    n: i32,
    k: i32,
    a: *const c_void,
    lda: i64,
    b: *const c_void,
    ldb: i64,
    c: *const c_void,
    ldc: i64,
    d: *mut c_void,
    ldd: i64,
    alpha: f32,
    beta: f32,
    workspace: *mut c_void,
    workspace_bytes: usize,
    stream: *mut c_void,
) -> i32 {
    use baracuda_cutlass_kernels_sys as k_sys;
    // Rrr (and anything else that is not Rcr) has no integer kernel.
    if !matches!(layout, LayoutSku::Rcr) {
        return 3;
    }
    match kind {
        ElementKind::S8 => unsafe {
            k_sys::baracuda_cutlass_gemm_s8_rcr_sm80_run(
                m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
                alpha, beta, workspace, workspace_bytes, stream,
            )
        },
        ElementKind::U8 => unsafe {
            k_sys::baracuda_cutlass_gemm_u8_rcr_sm80_run(
                m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
                alpha, beta, workspace, workspace_bytes, stream,
            )
        },
        _ => 3,
    }
}
/// Workspace size reported by the 8-bit integer GEMM kernel for
/// `(layout, kind)`.
///
/// Combinations without a kernel (anything other than Rcr with S8/U8)
/// report `0`.
#[cfg(feature = "sm80")]
pub(super) fn int_gemm_rcr_sm80_workspace_size(
    layout: LayoutSku,
    kind: ElementKind,
    m: i32,
    n: i32,
    k: i32,
) -> usize {
    use baracuda_cutlass_kernels_sys as k_sys;
    if !matches!(layout, LayoutSku::Rcr) {
        return 0;
    }
    // Only by-value integers cross the FFI boundary below.
    match kind {
        ElementKind::S8 => unsafe {
            k_sys::baracuda_cutlass_gemm_s8_rcr_sm80_workspace_size(m, n, k)
        },
        ElementKind::U8 => unsafe {
            k_sys::baracuda_cutlass_gemm_u8_rcr_sm80_workspace_size(m, n, k)
        },
        _ => 0,
    }
}
/// Asks the 8-bit integer GEMM kernel for `(layout, kind)` whether it can
/// run the described problem.
///
/// Returns the raw status of the `*_can_implement` entry point, or the
/// literal `3` for every combination without a kernel — including Rrr
/// with S8/U8.
///
/// # Safety
///
/// The operand pointers are forwarded to the FFI layer unchanged. The
/// caller must meet the underlying kernel's pointer requirements (not
/// visible from this module).
#[cfg(feature = "sm80")]
#[allow(clippy::too_many_arguments)]
pub(super) unsafe fn int_gemm_rcr_sm80_can_implement(
    layout: LayoutSku,
    kind: ElementKind,
    m: i32,
    n: i32,
    k: i32,
    a: *const c_void,
    lda: i64,
    b: *const c_void,
    ldb: i64,
    c: *const c_void,
    ldc: i64,
    d: *mut c_void,
    ldd: i64,
) -> i32 {
    use baracuda_cutlass_kernels_sys as k_sys;
    // Rrr (and anything else that is not Rcr) has no integer kernel.
    if !matches!(layout, LayoutSku::Rcr) {
        return 3;
    }
    match kind {
        ElementKind::S8 => unsafe {
            k_sys::baracuda_cutlass_gemm_s8_rcr_sm80_can_implement(
                m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
            )
        },
        ElementKind::U8 => unsafe {
            k_sys::baracuda_cutlass_gemm_u8_rcr_sm80_can_implement(
                m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
            )
        },
        _ => 3,
    }
}
use crate::types::BiasElementKind;
/// Launches the fused-bias 8-bit integer GEMM kernel selected by
/// `(kind, bias_kind, epilogue)`.
///
/// Kernels exist only for `LayoutSku::Rcr`, element kinds `S8`/`U8`,
/// bias kinds `F32`/`I32`, and the four bias-family epilogues. Every
/// other combination — including `EpilogueKind::Identity` — returns the
/// literal status `3` without an FFI call.
///
/// # Safety
///
/// The operand, bias, workspace and stream pointers are forwarded to the
/// FFI kernel unchanged. The caller must satisfy that kernel's pointer
/// requirements (not visible from this module).
#[cfg(feature = "sm80")]
#[allow(clippy::too_many_arguments)]
pub(super) unsafe fn int_gemm_bias_rcr_sm80_run(
    layout: LayoutSku,
    kind: ElementKind,
    epilogue: EpilogueKind,
    bias_kind: BiasElementKind,
    m: i32,
    n: i32,
    k: i32,
    a: *const c_void,
    lda: i64,
    b: *const c_void,
    ldb: i64,
    c: *const c_void,
    ldc: i64,
    d: *mut c_void,
    ldd: i64,
    bias: *const c_void,
    alpha: f32,
    beta: f32,
    workspace: *mut c_void,
    workspace_bytes: usize,
    stream: *mut c_void,
) -> i32 {
    use baracuda_cutlass_kernels_sys as k_sys;
    if !matches!(layout, LayoutSku::Rcr) {
        return 3;
    }
    // First pick the (operand, bias) element pairing, then the epilogue.
    match (kind, bias_kind) {
        (ElementKind::S8, BiasElementKind::F32) => match epilogue {
            EpilogueKind::Bias => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_f32bias_s8_rcr_sm80_run(
                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
                    bias, alpha, beta, workspace, workspace_bytes, stream,
                )
            },
            EpilogueKind::BiasRelu => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_relu_f32bias_s8_rcr_sm80_run(
                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
                    bias, alpha, beta, workspace, workspace_bytes, stream,
                )
            },
            EpilogueKind::BiasGelu => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_gelu_f32bias_s8_rcr_sm80_run(
                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
                    bias, alpha, beta, workspace, workspace_bytes, stream,
                )
            },
            EpilogueKind::BiasSilu => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_silu_f32bias_s8_rcr_sm80_run(
                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
                    bias, alpha, beta, workspace, workspace_bytes, stream,
                )
            },
            EpilogueKind::Identity => 3,
        },
        (ElementKind::S8, BiasElementKind::I32) => match epilogue {
            EpilogueKind::Bias => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_i32bias_s8_rcr_sm80_run(
                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
                    bias, alpha, beta, workspace, workspace_bytes, stream,
                )
            },
            EpilogueKind::BiasRelu => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_relu_i32bias_s8_rcr_sm80_run(
                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
                    bias, alpha, beta, workspace, workspace_bytes, stream,
                )
            },
            EpilogueKind::BiasGelu => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_gelu_i32bias_s8_rcr_sm80_run(
                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
                    bias, alpha, beta, workspace, workspace_bytes, stream,
                )
            },
            EpilogueKind::BiasSilu => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_silu_i32bias_s8_rcr_sm80_run(
                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
                    bias, alpha, beta, workspace, workspace_bytes, stream,
                )
            },
            EpilogueKind::Identity => 3,
        },
        (ElementKind::U8, BiasElementKind::F32) => match epilogue {
            EpilogueKind::Bias => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_f32bias_u8_rcr_sm80_run(
                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
                    bias, alpha, beta, workspace, workspace_bytes, stream,
                )
            },
            EpilogueKind::BiasRelu => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_relu_f32bias_u8_rcr_sm80_run(
                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
                    bias, alpha, beta, workspace, workspace_bytes, stream,
                )
            },
            EpilogueKind::BiasGelu => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_gelu_f32bias_u8_rcr_sm80_run(
                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
                    bias, alpha, beta, workspace, workspace_bytes, stream,
                )
            },
            EpilogueKind::BiasSilu => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_silu_f32bias_u8_rcr_sm80_run(
                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
                    bias, alpha, beta, workspace, workspace_bytes, stream,
                )
            },
            EpilogueKind::Identity => 3,
        },
        (ElementKind::U8, BiasElementKind::I32) => match epilogue {
            EpilogueKind::Bias => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_i32bias_u8_rcr_sm80_run(
                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
                    bias, alpha, beta, workspace, workspace_bytes, stream,
                )
            },
            EpilogueKind::BiasRelu => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_relu_i32bias_u8_rcr_sm80_run(
                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
                    bias, alpha, beta, workspace, workspace_bytes, stream,
                )
            },
            EpilogueKind::BiasGelu => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_gelu_i32bias_u8_rcr_sm80_run(
                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
                    bias, alpha, beta, workspace, workspace_bytes, stream,
                )
            },
            EpilogueKind::BiasSilu => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_silu_i32bias_u8_rcr_sm80_run(
                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd,
                    bias, alpha, beta, workspace, workspace_bytes, stream,
                )
            },
            EpilogueKind::Identity => 3,
        },
        _ => 3,
    }
}
/// Workspace size reported by the fused-bias 8-bit integer GEMM kernel
/// selected by `(kind, bias_kind, epilogue)`.
///
/// Any combination without a kernel (non-Rcr layout, element kinds other
/// than S8/U8, bias kinds other than F32/I32, or the Identity epilogue)
/// reports `0`.
#[cfg(feature = "sm80")]
pub(super) fn int_gemm_bias_rcr_sm80_workspace_size(
    layout: LayoutSku,
    kind: ElementKind,
    epilogue: EpilogueKind,
    bias_kind: BiasElementKind,
    m: i32,
    n: i32,
    k: i32,
) -> usize {
    use baracuda_cutlass_kernels_sys as k_sys;
    if !matches!(layout, LayoutSku::Rcr) {
        return 0;
    }
    // Only by-value integers cross the FFI boundary in every arm below.
    match (kind, bias_kind) {
        (ElementKind::S8, BiasElementKind::F32) => match epilogue {
            EpilogueKind::Bias => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_f32bias_s8_rcr_sm80_workspace_size(m, n, k)
            },
            EpilogueKind::BiasRelu => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_relu_f32bias_s8_rcr_sm80_workspace_size(m, n, k)
            },
            EpilogueKind::BiasGelu => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_gelu_f32bias_s8_rcr_sm80_workspace_size(m, n, k)
            },
            EpilogueKind::BiasSilu => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_silu_f32bias_s8_rcr_sm80_workspace_size(m, n, k)
            },
            EpilogueKind::Identity => 0,
        },
        (ElementKind::S8, BiasElementKind::I32) => match epilogue {
            EpilogueKind::Bias => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_i32bias_s8_rcr_sm80_workspace_size(m, n, k)
            },
            EpilogueKind::BiasRelu => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_relu_i32bias_s8_rcr_sm80_workspace_size(m, n, k)
            },
            EpilogueKind::BiasGelu => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_gelu_i32bias_s8_rcr_sm80_workspace_size(m, n, k)
            },
            EpilogueKind::BiasSilu => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_silu_i32bias_s8_rcr_sm80_workspace_size(m, n, k)
            },
            EpilogueKind::Identity => 0,
        },
        (ElementKind::U8, BiasElementKind::F32) => match epilogue {
            EpilogueKind::Bias => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_f32bias_u8_rcr_sm80_workspace_size(m, n, k)
            },
            EpilogueKind::BiasRelu => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_relu_f32bias_u8_rcr_sm80_workspace_size(m, n, k)
            },
            EpilogueKind::BiasGelu => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_gelu_f32bias_u8_rcr_sm80_workspace_size(m, n, k)
            },
            EpilogueKind::BiasSilu => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_silu_f32bias_u8_rcr_sm80_workspace_size(m, n, k)
            },
            EpilogueKind::Identity => 0,
        },
        (ElementKind::U8, BiasElementKind::I32) => match epilogue {
            EpilogueKind::Bias => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_i32bias_u8_rcr_sm80_workspace_size(m, n, k)
            },
            EpilogueKind::BiasRelu => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_relu_i32bias_u8_rcr_sm80_workspace_size(m, n, k)
            },
            EpilogueKind::BiasGelu => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_gelu_i32bias_u8_rcr_sm80_workspace_size(m, n, k)
            },
            EpilogueKind::BiasSilu => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_silu_i32bias_u8_rcr_sm80_workspace_size(m, n, k)
            },
            EpilogueKind::Identity => 0,
        },
        _ => 0,
    }
}
/// Asks the fused-bias 8-bit integer GEMM kernel selected by
/// `(kind, bias_kind, epilogue)` whether it can run the described problem.
///
/// Returns the raw status of the `*_can_implement` entry point, or the
/// literal `3` for any combination without a kernel (non-Rcr layout,
/// element kinds other than S8/U8, bias kinds other than F32/I32, or the
/// Identity epilogue).
///
/// # Safety
///
/// The operand and bias pointers are forwarded to the FFI layer
/// unchanged. The caller must meet the underlying kernel's pointer
/// requirements (not visible from this module).
#[cfg(feature = "sm80")]
#[allow(clippy::too_many_arguments)]
pub(super) unsafe fn int_gemm_bias_rcr_sm80_can_implement(
    layout: LayoutSku,
    kind: ElementKind,
    epilogue: EpilogueKind,
    bias_kind: BiasElementKind,
    m: i32,
    n: i32,
    k: i32,
    a: *const c_void,
    lda: i64,
    b: *const c_void,
    ldb: i64,
    c: *const c_void,
    ldc: i64,
    d: *mut c_void,
    ldd: i64,
    bias: *const c_void,
) -> i32 {
    use baracuda_cutlass_kernels_sys as k_sys;
    if !matches!(layout, LayoutSku::Rcr) {
        return 3;
    }
    // First pick the (operand, bias) element pairing, then the epilogue.
    match (kind, bias_kind) {
        (ElementKind::S8, BiasElementKind::F32) => match epilogue {
            EpilogueKind::Bias => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_f32bias_s8_rcr_sm80_can_implement(
                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
                )
            },
            EpilogueKind::BiasRelu => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_relu_f32bias_s8_rcr_sm80_can_implement(
                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
                )
            },
            EpilogueKind::BiasGelu => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_gelu_f32bias_s8_rcr_sm80_can_implement(
                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
                )
            },
            EpilogueKind::BiasSilu => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_silu_f32bias_s8_rcr_sm80_can_implement(
                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
                )
            },
            EpilogueKind::Identity => 3,
        },
        (ElementKind::S8, BiasElementKind::I32) => match epilogue {
            EpilogueKind::Bias => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_i32bias_s8_rcr_sm80_can_implement(
                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
                )
            },
            EpilogueKind::BiasRelu => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_relu_i32bias_s8_rcr_sm80_can_implement(
                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
                )
            },
            EpilogueKind::BiasGelu => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_gelu_i32bias_s8_rcr_sm80_can_implement(
                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
                )
            },
            EpilogueKind::BiasSilu => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_silu_i32bias_s8_rcr_sm80_can_implement(
                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
                )
            },
            EpilogueKind::Identity => 3,
        },
        (ElementKind::U8, BiasElementKind::F32) => match epilogue {
            EpilogueKind::Bias => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_f32bias_u8_rcr_sm80_can_implement(
                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
                )
            },
            EpilogueKind::BiasRelu => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_relu_f32bias_u8_rcr_sm80_can_implement(
                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
                )
            },
            EpilogueKind::BiasGelu => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_gelu_f32bias_u8_rcr_sm80_can_implement(
                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
                )
            },
            EpilogueKind::BiasSilu => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_silu_f32bias_u8_rcr_sm80_can_implement(
                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
                )
            },
            EpilogueKind::Identity => 3,
        },
        (ElementKind::U8, BiasElementKind::I32) => match epilogue {
            EpilogueKind::Bias => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_i32bias_u8_rcr_sm80_can_implement(
                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
                )
            },
            EpilogueKind::BiasRelu => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_relu_i32bias_u8_rcr_sm80_can_implement(
                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
                )
            },
            EpilogueKind::BiasGelu => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_gelu_i32bias_u8_rcr_sm80_can_implement(
                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
                )
            },
            EpilogueKind::BiasSilu => unsafe {
                k_sys::baracuda_cutlass_gemm_bias_silu_i32bias_u8_rcr_sm80_can_implement(
                    m, n, k, a, lda, b, ldb, c, ldc, d, ldd, bias,
                )
            },
            EpilogueKind::Identity => 3,
        },
        _ => 3,
    }
}
}
/// Minimum number of elements a row-major `rows x cols` matrix with leading
/// dimension `ld` must span: `(rows - 1) * ld + cols`.
///
/// Returns `None` when the arithmetic overflows `i64`, or when the result
/// is negative or does not fit in `usize`.
fn min_elements_row_major(rows: i32, cols: i32, ld: i64) -> Option<usize> {
    // Widen before subtracting: `rows - 1` in `i32` would overflow for
    // `i32::MIN` (panic in debug builds, wrap to a bogus size in release).
    let full_rows = i64::from(rows) - 1;
    let needed = full_rows.checked_mul(ld)?.checked_add(i64::from(cols))?;
    usize::try_from(needed).ok()
}
/// Minimum number of elements a column-major `rows x cols` matrix with
/// leading dimension `ld` must span: `(cols - 1) * ld + rows`.
///
/// Returns `None` when the arithmetic overflows `i64`, or when the result
/// is negative or does not fit in `usize`.
fn min_elements_col_major(rows: i32, cols: i32, ld: i64) -> Option<usize> {
    // Widen before subtracting: `cols - 1` in `i32` would overflow for
    // `i32::MIN` (panic in debug builds, wrap to a bogus size in release).
    let full_cols = i64::from(cols) - 1;
    let needed = full_cols.checked_mul(ld)?.checked_add(i64::from(rows))?;
    usize::try_from(needed).ok()
}
/// Test-only alias: storage requirement of the A operand, computed with
/// the row-major formula.
#[cfg(test)]
fn min_elements_rcr_a(rows: i32, cols: i32, ld: i64) -> Option<usize> {
min_elements_row_major(rows, cols, ld)
}
/// Test-only alias: storage requirement of the B operand, computed with
/// the column-major formula.
#[cfg(test)]
fn min_elements_rcr_b(rows: i32, cols: i32, ld: i64) -> Option<usize> {
min_elements_col_major(rows, cols, ld)
}
/// Test-only alias: storage requirement of the C / D operands, computed
/// with the row-major formula.
#[cfg(test)]
fn min_elements_rcr_cd(rows: i32, cols: i32, ld: i64) -> Option<usize> {
min_elements_row_major(rows, cols, ld)
}
fn check_descriptor(desc: &GemmDescriptor) -> Result<()> {
if desc.m <= 0 || desc.n <= 0 || desc.k <= 0 {
return Err(Error::InvalidProblem("M, N, K must all be positive"));
}
Ok(())
}
fn check_args<T: CutlassElement>(desc: &GemmDescriptor, args: &GemmArgs<'_, T>) -> Result<()> {
match (desc.epilogue.requires_bias(), &args.bias) {
(false, Some(_)) => {
return Err(Error::InvalidProblem(
"args.bias must be None when descriptor.epilogue is Identity",
));
}
(true, None) => {
return Err(Error::InvalidProblem(
"args.bias is required when descriptor.epilogue is in the Bias family \
(Bias / BiasRelu / BiasGelu / BiasSilu)",
));
}
(false, None) | (true, Some(_)) => {}
}
if let Some(bias) = &args.bias {
if bias.len != desc.n {
return Err(Error::InvalidProblem(
"bias vector length must equal N",
));
}
if bias.stride != 1 {
return Err(Error::Unsupported(
"bias vector must be contiguous (stride 1) — strided bias not supported",
));
}
if bias.data.len() < desc.n as usize {
return Err(Error::BufferTooSmall {
needed: desc.n as usize,
got: bias.data.len(),
});
}
}
if args.a.rows != desc.m || args.a.cols != desc.k {
return Err(Error::InvalidProblem("A shape doesn't match descriptor (M, K)"));
}
if args.b.rows != desc.k || args.b.cols != desc.n {
return Err(Error::InvalidProblem("B shape doesn't match descriptor (K, N)"));
}
if args.d.rows != desc.m || args.d.cols != desc.n {
return Err(Error::InvalidProblem("D shape doesn't match descriptor (M, N)"));
}
if let Some(c) = &args.c {
if c.rows != desc.m || c.cols != desc.n {
return Err(Error::InvalidProblem("C shape doesn't match descriptor (M, N)"));
}
}
if args.a.ld < desc.k as i64 {
return Err(Error::InvalidProblem("A leading dimension must be >= K"));
}
let b_min_ld = match desc.layout {
LayoutSku::Rcr => desc.k as i64,
LayoutSku::Rrr => desc.n as i64,
};
if args.b.ld < b_min_ld {
return Err(Error::InvalidProblem(match desc.layout {
LayoutSku::Rcr => "B leading dimension must be >= K (column-major Rcr layout)",
LayoutSku::Rrr => "B leading dimension must be >= N (row-major Rrr layout)",
}));
}
if args.d.ld < desc.n as i64 {
return Err(Error::InvalidProblem("D leading dimension must be >= N"));
}
if let Some(c) = &args.c {
if c.ld < desc.n as i64 {
return Err(Error::InvalidProblem("C leading dimension must be >= N"));
}
}
let need_a = min_elements_row_major(args.a.rows, args.a.cols, args.a.ld)
.ok_or(Error::InvalidProblem("A storage size overflow"))?;
if args.a.data.len() < need_a {
return Err(Error::BufferTooSmall {
needed: need_a,
got: args.a.data.len(),
});
}
let need_b = match desc.layout {
LayoutSku::Rcr => min_elements_col_major(args.b.rows, args.b.cols, args.b.ld),
LayoutSku::Rrr => min_elements_row_major(args.b.rows, args.b.cols, args.b.ld),
}
.ok_or(Error::InvalidProblem("B storage size overflow"))?;
if args.b.data.len() < need_b {
return Err(Error::BufferTooSmall {
needed: need_b,
got: args.b.data.len(),
});
}
let need_d = min_elements_row_major(args.d.rows, args.d.cols, args.d.ld)
.ok_or(Error::InvalidProblem("D storage size overflow"))?;
if args.d.data.len() < need_d {
return Err(Error::BufferTooSmall {
needed: need_d,
got: args.d.data.len(),
});
}
if let Some(c) = &args.c {
let need_c = min_elements_row_major(c.rows, c.cols, c.ld)
.ok_or(Error::InvalidProblem("C storage size overflow"))?;
if c.data.len() < need_c {
return Err(Error::BufferTooSmall {
needed: need_c,
got: c.data.len(),
});
}
}
Ok(())
}
mod cublas_backend {
use core::cell::RefCell;
use baracuda_cublas::Handle as CublasHandle;
use baracuda_driver::Stream;
thread_local! {
static HANDLE_CACHE: RefCell<Vec<(usize, CublasHandle)>> =
const { RefCell::new(Vec::new()) };
}
pub(super) fn handle_for(stream: &Stream) -> crate::Result<CublasHandle> {
let ctx_key = stream.context().as_raw() as usize;
let handle = HANDLE_CACHE.with(|cache| -> crate::Result<CublasHandle> {
let mut cache = cache.borrow_mut();
if let Some((_, h)) = cache.iter().find(|(k, _)| *k == ctx_key) {
return Ok(h.clone());
}
stream
.context()
.set_current()
.map_err(crate::Error::Driver)?;
let h = {
let mut last_err = None;
let mut handle: Option<CublasHandle> = None;
for attempt in 0..5 {
match CublasHandle::new() {
Ok(h) => { handle = Some(h); break }
Err(e) => {
last_err = Some(e);
std::thread::sleep(std::time::Duration::from_millis(
50 * (attempt as u64 + 1),
));
}
}
}
match handle {
Some(h) => h,
None => {
let _ = last_err; return Err(crate::Error::Unsupported(
"cuBLAS handle creation failed after 5 retries \
(library missing, device unavailable, or \
persistent driver-init contention)",
));
}
}
};
cache.push((ctx_key, h.clone()));
Ok(h)
})?;
handle
.set_stream(stream)
.map_err(|_| crate::Error::Unsupported(
"cuBLAS set_stream failed",
))?;
Ok(handle)
}
}
// Algorithm selector for cuBLAS GEMM calls.
// NOTE(review): -1 matches `CUBLAS_GEMM_DEFAULT` in the cuBLAS headers —
// confirm at the use site, which is outside this part of the file.
const CUBLAS_GEMM_ALGO: i32 = -1;
/// Internal record of which backend a plan was routed to.
///
/// Mirrors the public `BackendKind` (see `as_public`) but additionally
/// remembers the architecture SKU for the CUTLASS route.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
enum BackendChoice {
/// CUTLASS kernels for the given architecture SKU.
Cutlass { arch: ArchSku },
/// cuBLAS.
Cublas,
/// Ozaki path; `slices` is carried through unchanged from the request.
Ozaki { slices: u8 },
}
impl BackendChoice {
    /// Converts the internal choice into the public `BackendKind`,
    /// dropping the CUTLASS architecture and keeping the Ozaki `slices`.
    fn as_public(self) -> BackendKind {
        match self {
            Self::Ozaki { slices } => BackendKind::Ozaki { slices },
            Self::Cublas => BackendKind::Cublas,
            Self::Cutlass { arch: _ } => BackendKind::Cutlass,
        }
    }
}
/// Heuristic used when the caller expresses no backend preference:
/// should this floating-point GEMM be routed to cuBLAS?
///
/// True only when the epilogue needs no bias, the element kind is `F16`
/// or `Bf16`, and `2 <= M < 128`. Every other element kind stays off
/// cuBLAS.
fn should_use_cublas_for_fp(
    desc: &GemmDescriptor,
    element: ElementKind,
) -> bool {
    let half_precision = matches!(element, ElementKind::F16 | ElementKind::Bf16);
    !desc.epilogue.requires_bias() && half_precision && (2..128).contains(&desc.m)
}
/// Validates an explicit request for the Ozaki backend.
///
/// Without the `ozimmu` cargo feature the request is always rejected.
/// With it, the request must be FP64 with the Identity epilogue, and
/// `slices` must decode to a valid configuration: the low 5 bits are the
/// slice count (`0` or `3..=18`), the remaining high bits are the variant
/// (`0..=3`).
///
/// # Errors
///
/// `Error::Unsupported` describing the first violated requirement.
#[cfg_attr(not(feature = "ozimmu"), allow(unused_variables))]
fn validate_ozaki_request(
    desc: &GemmDescriptor,
    element: ElementKind,
    slices: u8,
) -> Result<()> {
    #[cfg(not(feature = "ozimmu"))]
    {
        return Err(Error::Unsupported(
            "PlanPreference::prefer_backend = Some(Ozaki {..}) requires the \
            `ozimmu` cargo feature on baracuda-cutlass (off by default — \
            enable on baracuda-kernels too if going through the kernels facade)",
        ));
    }
    #[cfg(feature = "ozimmu")]
    {
        if !matches!(element, ElementKind::F64) {
            return Err(Error::Unsupported(
                "BackendKind::Ozaki is FP64-only (Ozaki-scheme synthesizes \
                DGEMM from int8; f16/bf16/f32/F32Strict have no Ozaki path)",
            ));
        }
        if !matches!(desc.epilogue, EpilogueKind::Identity) {
            return Err(Error::Unsupported(
                "BackendKind::Ozaki only supports the Identity epilogue \
                (no fused bias / activation chain on the Ozaki path)",
            ));
        }
        // Packed byte: low 5 bits = slice count, high bits = variant.
        let slice_count = slices & 0x1F;
        let variant = slices >> 5;
        let slice_count_ok = slice_count == 0 || (3..=18).contains(&slice_count);
        if !slice_count_ok {
            return Err(Error::Unsupported(
                "BackendKind::Ozaki slice count (low 5 bits) must be 0 \
                (auto) or 3..=18",
            ));
        }
        if variant > 3 {
            return Err(Error::Unsupported(
                "BackendKind::Ozaki variant (high 3 bits) must be 0 (Base), \
                1 (EF), 2 (RN), or 3 (H)",
            ));
        }
        Ok(())
    }
}
/// Maps an element kind to the cuBLAS data-type tag used for it, or
/// `None` when the kind has no mapping here (only `F16`, `Bf16`, `F32`
/// and `F64` do).
fn cublas_dtype_for(kind: ElementKind) -> Option<baracuda_cublas_sys::functions::cudaDataType_t> {
    use baracuda_cublas_sys::functions::cudaDataType_t as Dtype;
    let dtype = match kind {
        ElementKind::F64 => Dtype::R_64F,
        ElementKind::F32 => Dtype::R_32F,
        ElementKind::Bf16 => Dtype::R_16BF,
        ElementKind::F16 => Dtype::R_16F,
        _ => return None,
    };
    Some(dtype)
}
/// A GEMM plan for element type `T`: the validated problem descriptor
/// together with the kernel SKU and backend chosen for it by `select`.
#[derive(Debug)]
pub struct GemmPlan<T: CutlassElement> {
// Problem descriptor the plan was built from (copied at selection time).
desc: GemmDescriptor,
// Kernel SKU: architecture, layout, epilogue and element kind.
sku: GemmSku,
// Backend the plan was routed to.
backend: BackendChoice,
// Ties the plan to its element type without storing a value of it.
_element: PhantomData<T>,
}
impl<T: CutlassElement> GemmPlan<T> {
/// Builds a plan for `desc`, choosing the backend from `pref`.
///
/// Routing rules, evaluated before the architecture is picked:
/// * `Ozaki { slices }` — validated by `validate_ozaki_request`, then used.
/// * `Cublas` — rejected for bias-family epilogues and for element kinds
///   without a cuBLAS dtype; otherwise used.
/// * `Cutlass` — used as requested.
/// * anything else, or no preference — cuBLAS when the
///   `should_use_cublas_for_fp` heuristic fires and a cuBLAS dtype exists,
///   CUTLASS otherwise.
///
/// # Errors
///
/// Propagates failures from `check_descriptor`, `validate_ozaki_request`
/// and `pick_arch`, and returns `Error::Unsupported` for an unsatisfiable
/// explicit cuBLAS request.
pub fn select(stream: &Stream, desc: &GemmDescriptor, pref: PlanPreference) -> Result<Self> {
    // Backend decision made before the architecture is known.
    enum Route {
        Ozaki(u8),
        Cublas,
        Cutlass,
    }

    check_descriptor(desc)?;
    let element = T::KIND;

    let route = match pref.prefer_backend {
        Some(BackendKind::Ozaki { slices }) => {
            validate_ozaki_request(desc, element, slices)?;
            Route::Ozaki(slices)
        }
        Some(BackendKind::Cublas) => {
            if desc.epilogue.requires_bias() {
                return Err(Error::Unsupported(
                    "cuBLAS backend doesn't fuse bias activations \
                    (use Cutlass backend for Bias* epilogues)",
                ));
            }
            if cublas_dtype_for(element).is_none() {
                return Err(Error::Unsupported(
                    "cuBLAS backend has no GemmEx dtype for this element \
                    (F32Strict / integer / FP8 stay on Cutlass)",
                ));
            }
            Route::Cublas
        }
        Some(BackendKind::Cutlass) => Route::Cutlass,
        Some(_) | None => {
            let heuristic_hit = should_use_cublas_for_fp(desc, element)
                && cublas_dtype_for(element).is_some();
            if heuristic_hit {
                Route::Cublas
            } else {
                Route::Cutlass
            }
        }
    };

    // The SKU always records an architecture, whichever backend runs.
    let arch = pick_arch(stream, desc, pref)?;
    let backend = match route {
        Route::Ozaki(slices) => BackendChoice::Ozaki { slices },
        Route::Cublas => BackendChoice::Cublas,
        Route::Cutlass => BackendChoice::Cutlass { arch },
    };
    Ok(Self {
        desc: *desc,
        sku: GemmSku {
            arch,
            layout: desc.layout,
            epilogue: desc.epilogue,
            element,
            bias_element: None,
        },
        backend,
        _element: PhantomData,
    })
}
/// Returns the backend this plan was routed to, as the public
/// `BackendKind`.
pub fn backend(&self) -> BackendKind {
self.backend.as_public()
}
/// Checks whether this plan can run with `args`.
///
/// First validates `args` against the descriptor with `check_args`, then
/// asks the kernel matching the plan's SKU via its `*_can_implement`
/// dispatch function and converts the returned status with
/// `status_to_result`.
///
/// # Errors
///
/// Any `check_args` error; `Error::Unsupported` for `Sm90a` / `Sm89`
/// SKUs or for `Sm80` when the `sm80` feature is disabled; otherwise
/// whatever `status_to_result` produces for the kernel status.
///
/// NOTE(review): the dispatch is keyed on `self.sku.arch` only, so the
/// CUTLASS kernel is consulted even when `self.backend` is cuBLAS or
/// Ozaki — confirm this is intended.
pub fn can_implement(&self, args: &GemmArgs<'_, T>) -> Result<()> {
check_args(&self.desc, args)?;
// Raw operand pointers handed to the FFI dispatch layer.
let a_ptr = args.a.data.as_raw().0 as *const c_void;
let b_ptr = args.b.data.as_raw().0 as *const c_void;
let d_ptr = args.d.data.as_raw().0 as *mut c_void;
// An absent C operand is passed as a null pointer with ldc = 0.
let (c_ptr, ldc) = match &args.c {
Some(c) => (c.data.as_raw().0 as *const c_void, c.ld),
None => (core::ptr::null(), 0i64),
};
// An absent bias is passed as a null pointer.
let bias_ptr = args
.bias
.as_ref()
.map(|b| b.data.as_raw().0 as *const c_void)
.unwrap_or(core::ptr::null());
let bias_family = self.sku.epilogue.requires_bias();
// Arm order matters: the guarded FP64 arms must precede the generic
// arms for the same `(arch, bias_family)` pattern.
let status = match (self.sku.arch, bias_family) {
#[cfg(feature = "sm80")]
(ArchSku::Sm80, false) if <T::Scalar as ScalarType>::IS_F64 => unsafe {
dispatch::gemm_sm80_can_implement_f64(
self.sku.layout,
self.desc.m, self.desc.n, self.desc.k,
a_ptr, args.a.ld,
b_ptr, args.b.ld,
c_ptr, ldc,
d_ptr, args.d.ld,
)
},
#[cfg(feature = "sm80")]
(ArchSku::Sm80, false) => unsafe {
dispatch::gemm_sm80_can_implement(
self.sku.layout,
T::KIND,
self.desc.m, self.desc.n, self.desc.k,
a_ptr, args.a.ld,
b_ptr, args.b.ld,
c_ptr, ldc,
d_ptr, args.d.ld,
)
},
#[cfg(feature = "sm80")]
(ArchSku::Sm80, true) if <T::Scalar as ScalarType>::IS_F64 => unsafe {
dispatch::gemm_bias_sm80_can_implement_f64(
self.sku.layout,
self.sku.epilogue,
self.desc.m, self.desc.n, self.desc.k,
a_ptr, args.a.ld,
b_ptr, args.b.ld,
c_ptr, ldc,
d_ptr, args.d.ld,
bias_ptr,
)
},
#[cfg(feature = "sm80")]
(ArchSku::Sm80, true) => unsafe {
dispatch::gemm_bias_sm80_can_implement(
self.sku.layout,
T::KIND,
self.sku.epilogue,
self.desc.m, self.desc.n, self.desc.k,
a_ptr, args.a.ld,
b_ptr, args.b.ld,
c_ptr, ldc,
d_ptr, args.d.ld,
bias_ptr,
)
},
// Without the `sm80` feature no Sm80 kernel is compiled in.
#[cfg(not(feature = "sm80"))]
(ArchSku::Sm80, _) => {
return Err(Error::Unsupported(
"sm80 selected but the `sm80` feature isn't enabled",
));
}
(ArchSku::Sm90a, _) => {
return Err(Error::Unsupported(
"sm90a kernels not yet shipped (deferred until Hopper hardware available for validation)",
));
}
(ArchSku::Sm89, _) => {
return Err(Error::Unsupported(
"Ada-specialized FP8 / sm_89 SKUs live in baracuda-kernels-sys, not baracuda-cutlass",
));
}
};
status_to_result(status)
}
/// Returns the workspace size, in bytes, that the selected kernel reports
/// for this plan's problem shape.
///
/// Architectures without kernels in this build report `0`.
pub fn workspace_size(&self) -> usize {
    match (self.sku.arch, self.sku.epilogue.requires_bias()) {
        // The f64 arms must come first: they share the (arch, bias) pair
        // with the generic arms and are told apart only by the guard.
        #[cfg(feature = "sm80")]
        (ArchSku::Sm80, false) if <T::Scalar as ScalarType>::IS_F64 => {
            dispatch::gemm_sm80_workspace_size_f64(
                self.sku.layout,
                self.desc.m,
                self.desc.n,
                self.desc.k,
            )
        }
        #[cfg(feature = "sm80")]
        (ArchSku::Sm80, false) => {
            dispatch::gemm_sm80_workspace_size(
                self.sku.layout,
                T::KIND,
                self.desc.m,
                self.desc.n,
                self.desc.k,
            )
        }
        #[cfg(feature = "sm80")]
        (ArchSku::Sm80, true) if <T::Scalar as ScalarType>::IS_F64 => {
            dispatch::gemm_bias_sm80_workspace_size_f64(
                self.sku.layout,
                self.sku.epilogue,
                self.desc.m,
                self.desc.n,
                self.desc.k,
            )
        }
        #[cfg(feature = "sm80")]
        (ArchSku::Sm80, true) => {
            dispatch::gemm_bias_sm80_workspace_size(
                self.sku.layout,
                T::KIND,
                self.sku.epilogue,
                self.desc.m,
                self.desc.n,
                self.desc.k,
            )
        }
        #[cfg(not(feature = "sm80"))]
        (ArchSku::Sm80, _) => 0,
        (ArchSku::Sm90a | ArchSku::Sm89, _) => 0,
    }
}
/// Returns a copy of the SKU (arch, layout, epilogue, element, bias element)
/// recorded for this plan.
pub fn sku(&self) -> GemmSku {
self.sku
}
/// Returns the precision guarantee reported by this plan's SKU
/// (delegates to `GemmSku::precision_guarantee`).
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
self.sku.precision_guarantee()
}
pub fn run(
&self,
stream: &Stream,
workspace: Workspace<'_>,
args: GemmArgs<'_, T>,
) -> Result<()> {
self.can_implement(&args)?;
let needed = self.workspace_size();
let (ws_ptr, ws_bytes): (*mut c_void, usize) = match workspace {
Workspace::None => {
if needed != 0 {
return Err(Error::WorkspaceTooSmall {
needed,
got: 0,
});
}
(core::ptr::null_mut(), 0)
}
Workspace::Borrowed(slice) => {
if slice.len() < needed {
return Err(Error::WorkspaceTooSmall {
needed,
got: slice.len(),
});
}
(slice.as_raw().0 as *mut c_void, slice.len())
}
};
let a_ptr = args.a.data.as_raw().0 as *const c_void;
let b_ptr = args.b.data.as_raw().0 as *const c_void;
let d_ptr = args.d.data.as_raw().0 as *mut c_void;
let (c_ptr, ldc) = match &args.c {
Some(c) => (c.data.as_raw().0 as *const c_void, c.ld),
None => (core::ptr::null(), 0i64),
};
let bias_ptr = args
.bias
.as_ref()
.map(|b| b.data.as_raw().0 as *const c_void)
.unwrap_or(core::ptr::null());
let beta_eff = if args.c.is_some() { args.beta } else { <T::Scalar as Default>::default() };
let stream_raw = stream.as_raw();
if matches!(self.backend, BackendChoice::Cublas) {
let capturing = stream.is_capturing().unwrap_or(false);
if !capturing {
return self.run_cublas(stream, args, beta_eff);
}
}
#[cfg(feature = "ozimmu")]
if let BackendChoice::Ozaki { slices } = self.backend {
let capturing = stream.is_capturing().unwrap_or(false);
if !capturing {
return self.run_ozaki(stream, args, beta_eff, slices);
}
}
#[cfg(not(feature = "ozimmu"))]
if matches!(self.backend, BackendChoice::Ozaki { .. }) {
return Err(Error::Unsupported(
"BackendChoice::Ozaki selected without `ozimmu` cargo feature",
));
}
let bias_family = self.sku.epilogue.requires_bias();
let status = match (self.sku.arch, bias_family) {
#[cfg(feature = "sm80")]
(ArchSku::Sm80, false) if <T::Scalar as ScalarType>::IS_F64 => unsafe {
dispatch::gemm_sm80_run_f64(
self.sku.layout,
self.desc.m, self.desc.n, self.desc.k,
a_ptr, args.a.ld,
b_ptr, args.b.ld,
c_ptr, ldc,
d_ptr, args.d.ld,
args.alpha.to_f64(),
beta_eff.to_f64(),
ws_ptr, ws_bytes, stream_raw,
)
},
#[cfg(feature = "sm80")]
(ArchSku::Sm80, false) => unsafe {
dispatch::gemm_sm80_run(
self.sku.layout,
T::KIND,
self.desc.m, self.desc.n, self.desc.k,
a_ptr, args.a.ld,
b_ptr, args.b.ld,
c_ptr, ldc,
d_ptr, args.d.ld,
args.alpha.to_f32(),
beta_eff.to_f32(),
ws_ptr, ws_bytes, stream_raw,
)
},
#[cfg(feature = "sm80")]
(ArchSku::Sm80, true) if <T::Scalar as ScalarType>::IS_F64 => unsafe {
dispatch::gemm_bias_sm80_run_f64(
self.sku.layout,
self.sku.epilogue,
self.desc.m, self.desc.n, self.desc.k,
a_ptr, args.a.ld,
b_ptr, args.b.ld,
c_ptr, ldc,
d_ptr, args.d.ld,
bias_ptr,
args.alpha.to_f64(),
beta_eff.to_f64(),
ws_ptr, ws_bytes, stream_raw,
)
},
#[cfg(feature = "sm80")]
(ArchSku::Sm80, true) => unsafe {
dispatch::gemm_bias_sm80_run(
self.sku.layout,
T::KIND,
self.sku.epilogue,
self.desc.m, self.desc.n, self.desc.k,
a_ptr, args.a.ld,
b_ptr, args.b.ld,
c_ptr, ldc,
d_ptr, args.d.ld,
bias_ptr,
args.alpha.to_f32(),
beta_eff.to_f32(),
ws_ptr, ws_bytes, stream_raw,
)
},
#[cfg(not(feature = "sm80"))]
(ArchSku::Sm80, _) => {
return Err(Error::Unsupported(
"sm80 selected but the `sm80` feature isn't enabled",
));
}
(ArchSku::Sm90a, _) => {
return Err(Error::Unsupported(
"sm90a kernels not yet implemented (Phase 4c)",
));
}
(ArchSku::Sm89, _) => {
return Err(Error::Unsupported(
"Ada-specialized FP8 / sm_89 SKUs live in baracuda-kernels-sys, not baracuda-cutlass",
));
}
};
status_to_result(status)
}
fn run_cublas(
&self,
stream: &Stream,
args: GemmArgs<'_, T>,
beta_eff: T::Scalar,
) -> Result<()> {
use baracuda_cublas::Op as CublasOp;
use baracuda_cublas_sys::functions::cublasComputeType_t;
if self.sku.epilogue.requires_bias() {
return Err(Error::Unsupported(
"cuBLAS backend doesn't fuse bias activations \
(caller forced a Bias* epilogue onto the cuBLAS path)",
));
}
let handle = cublas_backend::handle_for(stream)?;
let m = self.desc.m;
let n = self.desc.n;
let k = self.desc.k;
let a_ptr = args.a.data.as_raw().0 as *const c_void;
let b_ptr = args.b.data.as_raw().0 as *const c_void;
let d_ptr = args.d.data.as_raw().0 as *mut c_void;
let (c_ptr, ldc_arg) = match &args.c {
Some(c) => (c.data.as_raw().0 as *mut c_void, c.ld as i32),
None => (d_ptr, args.d.ld as i32),
};
let (transa, transb) = match self.desc.layout {
LayoutSku::Rcr => (CublasOp::T, CublasOp::N),
LayoutSku::Rrr => (CublasOp::N, CublasOp::N),
};
let cublas_lda = args.b.ld as i32; let cublas_ldb = args.a.ld as i32; let ldd_arg = args.d.ld as i32;
if <T::Scalar as ScalarType>::IS_F64 {
use baracuda_cublas_sys::cublasOperation_t;
let to_raw = |op: CublasOp| match op {
CublasOp::N => cublasOperation_t::N,
CublasOp::T => cublasOperation_t::T,
CublasOp::C => cublasOperation_t::C,
};
let alpha_f64 = args.alpha.to_f64();
let beta_f64 = beta_eff.to_f64();
let c_api = baracuda_cublas_sys::cublas()
.map_err(|_| Error::Unsupported("cuBLAS library unavailable"))?;
let dgemm = c_api
.cublas_dgemm()
.map_err(|_| Error::Unsupported("cublasDgemm symbol unavailable"))?;
let status = unsafe {
dgemm(
handle.as_raw(),
to_raw(transa),
to_raw(transb),
n,
m,
k,
&alpha_f64,
b_ptr as *const f64,
cublas_lda,
a_ptr as *const f64,
cublas_ldb,
&beta_f64,
if args.c.is_some() {
return Err(Error::Unsupported(
"cuBLAS f64 GEMM with explicit C operand is not yet wired \
(D and C alias differently than cuBLAS expects); \
use Cutlass backend or set c = None",
));
} else {
d_ptr as *mut f64
},
ldd_arg,
)
};
return match status {
baracuda_cublas_sys::cublasStatus_t::SUCCESS => Ok(()),
_ => Err(Error::CutlassInternal(status.0)),
};
}
let dtype = cublas_dtype_for(self.sku.element).ok_or(Error::Unsupported(
"cuBLAS backend selected for element kind without a cuBLAS dtype mapping",
))?;
let a_type = dtype;
let b_type = dtype;
let c_type = dtype;
let alpha_f32 = args.alpha.to_f32();
let beta_f32 = beta_eff.to_f32();
if args.c.is_some() {
return Err(Error::Unsupported(
"cuBLAS GemmPlan path requires c = None \
(cublasGemmEx writes the output in-place into the C operand; \
explicit-C with D ≠ C requires an extra copy step — \
force Cutlass backend if you need it)",
));
}
let _ = (c_ptr, ldc_arg);
unsafe {
baracuda_cublas::gemm_ex(
&handle,
transa,
transb,
n,
m,
k,
&alpha_f32 as *const f32 as *const c_void,
b_ptr,
b_type,
cublas_lda,
a_ptr,
a_type,
cublas_ldb,
&beta_f32 as *const f32 as *const c_void,
d_ptr,
c_type,
ldd_arg,
cublasComputeType_t::Compute32F,
CUBLAS_GEMM_ALGO,
)
.map_err(|_| Error::CutlassInternal(-1))
}
}
/// Executes the GEMM through the ozIMMU backend (`dgemm_with_variant`).
///
/// `slices` packs two fields: the low five bits select the slice count
/// (`0` = automatic, otherwise 3 through 18) and the upper three bits select
/// the variant (0 through 3). As with the cuBLAS path, operands are passed
/// swapped (B first, A second) with M and N exchanged.
///
/// # Errors
/// `Error::Unsupported` for non-f64 elements, an explicit C operand, an
/// out-of-range slice/variant encoding, or a rejected request;
/// `Error::CutlassInternal` when ozIMMU reports `DgemmFailed`.
#[cfg(feature = "ozimmu")]
fn run_ozaki(
    &self,
    stream: &Stream,
    args: GemmArgs<'_, T>,
    beta_eff: T::Scalar,
    slices: u8,
) -> Result<()> {
    use baracuda_ozimmu::{Op as OzakiOp, OzakiSlices, OzakiVariant};

    let is_f64 = <T::Scalar as ScalarType>::IS_F64;
    if !is_f64 {
        return Err(Error::Unsupported(
            "BackendChoice::Ozaki reached on non-f64 element \
             (select() guard should have rejected this)",
        ));
    }
    if args.c.is_some() {
        return Err(Error::Unsupported(
            "ozIMMU GemmPlan path requires c = None \
             (the Ozaki path writes its output in-place into the C \
             operand of the underlying cuBLAS GEMM — explicit-C with \
             D ≠ C requires an extra copy step that the Phase 44 \
             alpha does not yet wire; force Cutlass backend if needed)",
        ));
    }

    // Unpack the two bit fields carried in `slices`.
    let (slice_bits, variant_bits) = (slices & 0x1F, slices >> 5);
    let slice_choice = match slice_bits {
        0 => OzakiSlices::Auto,
        3 => OzakiSlices::S3,
        4 => OzakiSlices::S4,
        5 => OzakiSlices::S5,
        6 => OzakiSlices::S6,
        7 => OzakiSlices::S7,
        8 => OzakiSlices::S8,
        9 => OzakiSlices::S9,
        10 => OzakiSlices::S10,
        11 => OzakiSlices::S11,
        12 => OzakiSlices::S12,
        13 => OzakiSlices::S13,
        14 => OzakiSlices::S14,
        15 => OzakiSlices::S15,
        16 => OzakiSlices::S16,
        17 => OzakiSlices::S17,
        18 => OzakiSlices::S18,
        _ => {
            return Err(Error::Unsupported(
                "ozIMMU slice count out of range (validated at select; \
                 this is unreachable)",
            ));
        }
    };
    let variant_choice = match variant_bits {
        0 => OzakiVariant::Base,
        1 => OzakiVariant::EF,
        2 => OzakiVariant::RN,
        3 => OzakiVariant::H,
        _ => {
            return Err(Error::Unsupported(
                "ozIMMU variant out of range (validated at select; \
                 this is unreachable)",
            ));
        }
    };

    let handle = ozimmu_backend::handle_for(stream)?;
    let (transa, transb) = match self.desc.layout {
        LayoutSku::Rcr => (OzakiOp::T, OzakiOp::N),
        LayoutSku::Rrr => (OzakiOp::N, OzakiOp::N),
    };

    let rows = self.desc.m as usize;
    let cols = self.desc.n as usize;
    let depth = self.desc.k as usize;
    // Operands are swapped, so the first leading dimension is B's.
    let first_ld = args.b.ld as usize;
    let second_ld = args.a.ld as usize;
    let out_ld = args.d.ld as usize;
    let a_raw = args.a.data.as_raw().0 as *const f64;
    let b_raw = args.b.data.as_raw().0 as *const f64;
    let d_raw = args.d.data.as_raw().0 as *mut f64;
    let alpha = args.alpha.to_f64();
    let beta = beta_eff.to_f64();

    let outcome = unsafe {
        handle.dgemm_with_variant(
            transa,
            transb,
            cols,
            rows,
            depth,
            alpha,
            b_raw,
            first_ld,
            a_raw,
            second_ld,
            beta,
            d_raw,
            out_ld,
            slice_choice,
            variant_choice,
        )
    };
    outcome.map_err(|err| {
        use baracuda_ozimmu::Error as OzErr;
        match err {
            OzErr::DgemmFailed(code) => Error::CutlassInternal(code),
            _ => Error::Unsupported("ozIMMU dgemm rejected the request (see logs)"),
        }
    })
}
}
/// Per-thread cache of ozIMMU handles, keyed by the raw CUDA context of the
/// stream they are requested for.
#[cfg(feature = "ozimmu")]
mod ozimmu_backend {
    use core::cell::RefCell;
    use std::rc::Rc;
    use baracuda_driver::Stream;
    use baracuda_ozimmu::Handle as OzimmuHandle;

    /// Number of times handle creation is attempted before giving up.
    const MAX_CREATE_ATTEMPTS: u64 = 5;
    /// Base back-off between attempts; attempt `i` (1-based) waits `i` times this.
    const BACKOFF_STEP_MS: u64 = 50;

    thread_local! {
        static HANDLE_CACHE: RefCell<Vec<(usize, Rc<OzimmuHandle>)>> =
            const { RefCell::new(Vec::new()) };
    }

    /// Returns the cached handle for `stream`'s context, creating it on first
    /// use, and binds the handle to `stream` before returning it.
    ///
    /// Creation makes the stream's context current, then retries
    /// `OzimmuHandle::new()` up to `MAX_CREATE_ATTEMPTS` times with a linearly
    /// growing sleep between attempts. No sleep follows the final failed
    /// attempt, since nothing would be retried afterwards.
    ///
    /// # Errors
    /// `Error::Driver` if the context cannot be made current, or
    /// `Error::Unsupported` if every creation attempt fails.
    pub(super) fn handle_for(stream: &Stream) -> crate::Result<Rc<OzimmuHandle>> {
        let ctx_key = stream.context().as_raw() as usize;
        let handle = HANDLE_CACHE.with(|cache| -> crate::Result<Rc<OzimmuHandle>> {
            let mut cache = cache.borrow_mut();
            if let Some((_, cached)) = cache.iter().find(|(key, _)| *key == ctx_key) {
                return Ok(Rc::clone(cached));
            }
            stream
                .context()
                .set_current()
                .map_err(crate::Error::Driver)?;

            let mut created: Option<OzimmuHandle> = None;
            for attempt in 1..=MAX_CREATE_ATTEMPTS {
                match OzimmuHandle::new() {
                    Ok(h) => {
                        created = Some(h);
                        break;
                    }
                    // Back off only when another attempt will follow.
                    Err(_) if attempt < MAX_CREATE_ATTEMPTS => {
                        std::thread::sleep(std::time::Duration::from_millis(
                            BACKOFF_STEP_MS * attempt,
                        ));
                    }
                    Err(_) => {}
                }
            }
            let created = match created {
                Some(h) => h,
                None => {
                    return Err(crate::Error::Unsupported(
                        "ozIMMU handle creation failed after 5 retries \
                         (library missing, device unavailable, or persistent \
                         init contention)",
                    ));
                }
            };
            let shared = Rc::new(created);
            cache.push((ctx_key, Rc::clone(&shared)));
            Ok(shared)
        })?;
        handle.set_stream(stream);
        Ok(handle)
    }
}
fn check_batched_descriptor(desc: &BatchedGemmDescriptor) -> Result<()> {
if desc.m <= 0 || desc.n <= 0 || desc.k <= 0 {
return Err(Error::InvalidProblem("M, N, K must all be positive"));
}
if desc.batch_count <= 0 {
return Err(Error::InvalidProblem("batch_count must be positive"));
}
if desc.epilogue != EpilogueKind::Identity {
return Err(Error::Unsupported(
"BatchedGemmPlan v1 supports only EpilogueKind::Identity",
));
}
Ok(())
}
fn check_batched_args<T: CutlassElement>(
desc: &BatchedGemmDescriptor,
args: &BatchedGemmArgs<'_, T>,
) -> Result<()> {
if args.a.rows != desc.m || args.a.cols != desc.k {
return Err(Error::InvalidProblem("A shape doesn't match descriptor (M, K)"));
}
if args.b.rows != desc.k || args.b.cols != desc.n {
return Err(Error::InvalidProblem("B shape doesn't match descriptor (K, N)"));
}
if args.d.rows != desc.m || args.d.cols != desc.n {
return Err(Error::InvalidProblem("D shape doesn't match descriptor (M, N)"));
}
if let Some(c) = &args.c {
if c.rows != desc.m || c.cols != desc.n {
return Err(Error::InvalidProblem("C shape doesn't match descriptor (M, N)"));
}
}
if args.a.ld < desc.k as i64 {
return Err(Error::InvalidProblem("A leading dimension must be >= K"));
}
let b_min_ld = match desc.layout {
LayoutSku::Rcr => desc.k as i64,
LayoutSku::Rrr => desc.n as i64,
};
if args.b.ld < b_min_ld {
return Err(Error::InvalidProblem("B leading dimension too small for layout"));
}
if args.d.ld < desc.n as i64 {
return Err(Error::InvalidProblem("D leading dimension must be >= N"));
}
if let Some(c) = &args.c {
if c.ld < desc.n as i64 {
return Err(Error::InvalidProblem("C leading dimension must be >= N"));
}
}
fn need_for_batches(
per_batch_min: usize,
stride: i64,
batch_count: i32,
) -> Option<usize> {
if batch_count <= 1 || stride == 0 {
return Some(per_batch_min);
}
let extra = stride.checked_mul((batch_count - 1) as i64)?;
let extra = usize::try_from(extra).ok()?;
per_batch_min.checked_add(extra)
}
let a_per = min_elements_row_major(args.a.rows, args.a.cols, args.a.ld)
.ok_or(Error::InvalidProblem("A storage size overflow"))?;
let need_a = need_for_batches(a_per, args.stride_a, desc.batch_count)
.ok_or(Error::InvalidProblem("A batched storage size overflow"))?;
if args.a.data.len() < need_a {
return Err(Error::BufferTooSmall {
needed: need_a,
got: args.a.data.len(),
});
}
let b_per = match desc.layout {
LayoutSku::Rcr => min_elements_col_major(args.b.rows, args.b.cols, args.b.ld),
LayoutSku::Rrr => min_elements_row_major(args.b.rows, args.b.cols, args.b.ld),
}
.ok_or(Error::InvalidProblem("B storage size overflow"))?;
let need_b = need_for_batches(b_per, args.stride_b, desc.batch_count)
.ok_or(Error::InvalidProblem("B batched storage size overflow"))?;
if args.b.data.len() < need_b {
return Err(Error::BufferTooSmall {
needed: need_b,
got: args.b.data.len(),
});
}
let d_per = min_elements_row_major(args.d.rows, args.d.cols, args.d.ld)
.ok_or(Error::InvalidProblem("D storage size overflow"))?;
let need_d = need_for_batches(d_per, args.stride_d, desc.batch_count)
.ok_or(Error::InvalidProblem("D batched storage size overflow"))?;
if args.d.data.len() < need_d {
return Err(Error::BufferTooSmall {
needed: need_d,
got: args.d.data.len(),
});
}
if let Some(c) = &args.c {
let c_per = min_elements_row_major(c.rows, c.cols, c.ld)
.ok_or(Error::InvalidProblem("C storage size overflow"))?;
let need_c = need_for_batches(c_per, args.stride_c, desc.batch_count)
.ok_or(Error::InvalidProblem("C batched storage size overflow"))?;
if c.data.len() < need_c {
return Err(Error::BufferTooSmall {
needed: need_c,
got: c.data.len(),
});
}
}
Ok(())
}
/// A plan for a batched GEMM over element type `T`.
///
/// Holds the validated problem descriptor and the kernel SKU chosen for it;
/// `T` is tracked only at the type level.
#[derive(Debug)]
pub struct BatchedGemmPlan<T: CutlassElement> {
/// Problem description (M, N, K, batch count, layout, epilogue).
desc: BatchedGemmDescriptor,
/// Kernel SKU (arch, layout, epilogue, element) this plan dispatches to.
sku: GemmSku,
/// Marker tying the plan to its element type without storing a `T`.
_element: PhantomData<T>,
}
impl<T: CutlassElement> BatchedGemmPlan<T> {
pub fn select(
stream: &Stream,
desc: &BatchedGemmDescriptor,
pref: PlanPreference,
) -> Result<Self> {
check_batched_descriptor(desc)?;
let one_off_desc = GemmDescriptor {
m: desc.m,
n: desc.n,
k: desc.k,
layout: desc.layout,
epilogue: desc.epilogue,
};
let arch = pick_arch(stream, &one_off_desc, pref)?;
match (desc.layout, T::KIND) {
(LayoutSku::Rcr, ElementKind::F16) | (LayoutSku::Rcr, ElementKind::Bf16) => {}
_ => {
return Err(Error::Unsupported(
"BatchedGemmPlan v1 only ships Rcr × {F16, Bf16} on sm_80",
));
}
}
let sku = GemmSku {
arch,
layout: desc.layout,
epilogue: desc.epilogue,
element: T::KIND,
bias_element: None,
};
Ok(Self {
desc: *desc,
sku,
_element: PhantomData,
})
}
pub fn can_implement(&self, args: &BatchedGemmArgs<'_, T>) -> Result<()> {
check_batched_args(&self.desc, args)?;
let a_ptr = args.a.data.as_raw().0 as *const c_void;
let b_ptr = args.b.data.as_raw().0 as *const c_void;
let d_ptr = args.d.data.as_raw().0 as *mut c_void;
let (c_ptr, ldc, stride_c) = match &args.c {
Some(c) => (c.data.as_raw().0 as *const c_void, c.ld, args.stride_c),
None => (core::ptr::null(), 0i64, 0i64),
};
let status = match self.sku.arch {
#[cfg(feature = "sm80")]
ArchSku::Sm80 => unsafe {
dispatch::batched_gemm_sm80_can_implement(
self.sku.layout,
T::KIND,
self.desc.m,
self.desc.n,
self.desc.k,
a_ptr,
args.a.ld,
args.stride_a,
b_ptr,
args.b.ld,
args.stride_b,
c_ptr,
ldc,
stride_c,
d_ptr,
args.d.ld,
args.stride_d,
self.desc.batch_count,
)
},
#[cfg(not(feature = "sm80"))]
ArchSku::Sm80 => {
return Err(Error::Unsupported(
"sm80 selected but the `sm80` feature isn't enabled",
));
}
ArchSku::Sm90a => {
return Err(Error::Unsupported(
"sm90a batched kernels not yet shipped",
));
}
ArchSku::Sm89 => {
return Err(Error::Unsupported(
"Ada-specialized FP8 / sm_89 SKUs live in baracuda-kernels-sys, not baracuda-cutlass",
));
}
};
status_to_result(status)
}
pub fn workspace_size(&self) -> usize {
match self.sku.arch {
#[cfg(feature = "sm80")]
ArchSku::Sm80 => dispatch::batched_gemm_sm80_workspace_size(
self.sku.layout,
T::KIND,
self.desc.m,
self.desc.n,
self.desc.k,
self.desc.batch_count,
),
#[cfg(not(feature = "sm80"))]
ArchSku::Sm80 => 0,
ArchSku::Sm90a => 0,
ArchSku::Sm89 => 0,
}
}
pub fn sku(&self) -> GemmSku {
self.sku
}
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
self.sku.precision_guarantee()
}
pub fn run(
&self,
stream: &Stream,
workspace: Workspace<'_>,
args: BatchedGemmArgs<'_, T>,
) -> Result<()> {
self.can_implement(&args)?;
let needed = self.workspace_size();
let (ws_ptr, ws_bytes): (*mut c_void, usize) = match workspace {
Workspace::None => {
if needed != 0 {
return Err(Error::WorkspaceTooSmall { needed, got: 0 });
}
(core::ptr::null_mut(), 0)
}
Workspace::Borrowed(slice) => {
if slice.len() < needed {
return Err(Error::WorkspaceTooSmall {
needed,
got: slice.len(),
});
}
(slice.as_raw().0 as *mut c_void, slice.len())
}
};
let a_ptr = args.a.data.as_raw().0 as *const c_void;
let b_ptr = args.b.data.as_raw().0 as *const c_void;
let d_ptr = args.d.data.as_raw().0 as *mut c_void;
let (c_ptr, ldc, stride_c) = match &args.c {
Some(c) => (c.data.as_raw().0 as *const c_void, c.ld, args.stride_c),
None => (core::ptr::null(), 0i64, 0i64),
};
let beta_eff = if args.c.is_some() { args.beta } else { <T::Scalar as Default>::default() };
let stream_raw = stream.as_raw();
let status = match self.sku.arch {
#[cfg(feature = "sm80")]
ArchSku::Sm80 => unsafe {
dispatch::batched_gemm_sm80_run(
self.sku.layout,
T::KIND,
self.desc.m,
self.desc.n,
self.desc.k,
a_ptr,
args.a.ld,
args.stride_a,
b_ptr,
args.b.ld,
args.stride_b,
c_ptr,
ldc,
stride_c,
d_ptr,
args.d.ld,
args.stride_d,
args.alpha.to_f32(),
beta_eff.to_f32(),
self.desc.batch_count,
ws_ptr,
ws_bytes,
stream_raw,
)
},
#[cfg(not(feature = "sm80"))]
ArchSku::Sm80 => {
return Err(Error::Unsupported(
"sm80 selected but the `sm80` feature isn't enabled",
));
}
ArchSku::Sm90a => {
return Err(Error::Unsupported(
"sm90a batched kernels not yet shipped",
));
}
ArchSku::Sm89 => {
return Err(Error::Unsupported(
"Ada-specialized FP8 / sm_89 SKUs live in baracuda-kernels-sys, not baracuda-cutlass",
));
}
};
status_to_result(status)
}
}
/// Chooses the kernel architecture for the device behind `stream`.
///
/// `Sm90a` is chosen only when the caller allows it, the `sm90a` feature is
/// compiled in, and the device's major compute capability is at least 9.
/// Otherwise `Sm80` is chosen when the `sm80` feature is compiled in and the
/// major compute capability is at least 8. The descriptor is currently not
/// consulted.
///
/// # Errors
/// Propagates the compute-capability query error; returns
/// `Error::Unsupported` when no compiled-in architecture fits the device.
fn pick_arch(
    stream: &Stream,
    _desc: &GemmDescriptor,
    pref: PlanPreference,
) -> Result<ArchSku> {
    let (major, _) = stream.context().device().compute_capability()?;

    let hopper_eligible = pref.allow_sm90a && cfg!(feature = "sm90a") && major >= 9;
    if hopper_eligible {
        return Ok(ArchSku::Sm90a);
    }

    if !cfg!(feature = "sm80") {
        return Err(Error::Unsupported(
            "no arch features enabled — build with --features sm80",
        ));
    }
    if major >= 8 {
        Ok(ArchSku::Sm80)
    } else {
        Err(Error::Unsupported(
            "device compute capability < 8.0; sm_80 kernels won't run here",
        ))
    }
}
/// Bytes per packed problem size: three 32-bit integers (M, N, K).
const COORD_BYTES: usize = 12;
/// Bytes per packed operand pointer.
const PTR_BYTES: usize = 8;
/// Bytes per packed leading dimension.
const LD_BYTES: usize = 8;
/// Alignment applied to the start of the scratch region.
const SCRATCH_ALIGN: usize = 256;

/// Rounds `x` up to the next multiple of `align`.
///
/// The bit-mask form is only meaningful when `align` is a power of two.
#[inline]
fn align_up(x: usize, align: usize) -> usize {
    let mask = align - 1;
    (x + align - 1) & !mask
}

/// Byte offsets of every table packed into the grouped-GEMM workspace.
///
/// Order of regions: problem sizes, the A/B/C/D pointer tables, the
/// lda/ldb/ldc/ldd tables, then (after alignment) the scratch area.
#[derive(Copy, Clone, Debug)]
struct MetadataLayout {
    problem_sizes_offset: usize,
    ptr_a_offset: usize,
    ptr_b_offset: usize,
    ptr_c_offset: usize,
    ptr_d_offset: usize,
    lda_offset: usize,
    ldb_offset: usize,
    ldc_offset: usize,
    ldd_offset: usize,
    /// One past the last metadata byte (before scratch alignment padding).
    metadata_end: usize,
    /// Start of the scratch region, aligned to `SCRATCH_ALIGN`.
    scratch_offset: usize,
    /// `scratch_offset` plus the requested scratch size.
    total_workspace_bytes: usize,
}

impl MetadataLayout {
    /// Lays out the metadata tables for `group_count` groups followed by
    /// `scratch_bytes` of scratch space.
    fn compute(group_count: usize, scratch_bytes: usize) -> Self {
        let ptr_table_bytes = PTR_BYTES * group_count;
        let ld_table_bytes = LD_BYTES * group_count;

        // Problem sizes come first; the pointer tables that follow hold
        // 8-byte entries, so their start is rounded up to 8.
        let problem_sizes_offset = 0usize;
        let ptr_a_offset = align_up(problem_sizes_offset + COORD_BYTES * group_count, 8);
        let ptr_b_offset = ptr_a_offset + ptr_table_bytes;
        let ptr_c_offset = ptr_b_offset + ptr_table_bytes;
        let ptr_d_offset = ptr_c_offset + ptr_table_bytes;

        let lda_offset = ptr_d_offset + ptr_table_bytes;
        let ldb_offset = lda_offset + ld_table_bytes;
        let ldc_offset = ldb_offset + ld_table_bytes;
        let ldd_offset = ldc_offset + ld_table_bytes;

        let metadata_end = ldd_offset + ld_table_bytes;
        let scratch_offset = align_up(metadata_end, SCRATCH_ALIGN);

        Self {
            problem_sizes_offset,
            ptr_a_offset,
            ptr_b_offset,
            ptr_c_offset,
            ptr_d_offset,
            lda_offset,
            ldb_offset,
            ldc_offset,
            ldd_offset,
            metadata_end,
            scratch_offset,
            total_workspace_bytes: scratch_offset + scratch_bytes,
        }
    }
}
/// A plan for grouped GEMM (several independent problems in one launch)
/// over element type `T`.
///
/// Unlike the single and batched plans it carries no problem shape; shapes
/// are supplied per group when `prepare` is called.
#[derive(Debug)]
pub struct GroupedGemmPlan<T: CutlassElement> {
/// Kernel SKU this plan dispatches to.
sku: GemmSku,
/// Scheduling mode taken from the caller's preference at selection time.
schedule: GroupedScheduleMode,
/// Context cloned from the selecting stream; used to allocate the pinned
/// metadata buffer in `prepare`.
context: Context,
/// Marker tying the plan to its element type without storing a `T`.
_element: PhantomData<T>,
}
impl<T: CutlassElement> GroupedGemmPlan<T> {
    /// Builds a grouped plan for the device behind `stream`.
    ///
    /// Only `EpilogueKind::Identity` is accepted. The architecture is picked
    /// with `pick_arch` using a placeholder 1×1×1 `Rcr` descriptor, and the
    /// resulting SKU is always `Rcr`.
    ///
    /// # Errors
    /// `Error::Unsupported` for a non-identity epilogue, plus any error from
    /// `pick_arch`.
    pub fn select(
        stream: &Stream,
        epilogue: EpilogueKind,
        pref: GroupedPlanPreference,
    ) -> Result<Self> {
        if epilogue != EpilogueKind::Identity {
            return Err(Error::Unsupported(
                "v0 grouped GEMM supports only EpilogueKind::Identity",
            ));
        }
        // `pick_arch` requires a descriptor but does not depend on the shape.
        let dummy_desc = GemmDescriptor {
            m: 1,
            n: 1,
            k: 1,
            layout: LayoutSku::Rcr,
            epilogue,
        };
        let arch = pick_arch(stream, &dummy_desc, pref.base)?;
        let sku = GemmSku {
            arch,
            layout: LayoutSku::Rcr,
            epilogue,
            element: T::KIND,
            bias_element: None,
        };
        Ok(Self {
            sku,
            schedule: pref.schedule,
            context: stream.context().clone(),
            _element: PhantomData,
        })
    }

    /// Returns a copy of the SKU recorded for this plan.
    pub fn sku(&self) -> GemmSku {
        self.sku
    }

    /// Returns the precision guarantee reported by this plan's SKU.
    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
        self.sku.precision_guarantee()
    }

    /// Returns the scheduling mode chosen at selection time.
    pub fn schedule(&self) -> GroupedScheduleMode {
        self.schedule
    }

    /// Validates `groups` and packs their metadata for a later launch.
    ///
    /// Every group must have positive M/N/K, operand shapes matching those
    /// dimensions, and leading dimensions large enough for the `Rcr` layout
    /// (A ≥ K, B ≥ K, D ≥ N, and C ≥ N when present). All groups must share
    /// `alpha`, `beta`, and the presence/absence of C. Groups without C reuse
    /// D's pointer and leading dimension in the C slots, and the effective
    /// beta becomes the scalar type's default value.
    ///
    /// The packed problem sizes, operand pointers and leading dimensions are
    /// written into a pinned host buffer laid out per `MetadataLayout`.
    ///
    /// NOTE(review): the returned value stores raw operand addresses but is
    /// not tied to the `'g` lifetime; callers must keep the group buffers
    /// alive until the prepared GEMM has run.
    ///
    /// # Errors
    /// `Error::InvalidProblem` for an empty group list, bad shapes/leading
    /// dimensions, or a group count that does not fit in `i32`;
    /// `Error::Unsupported` for mixed alpha/beta/C or an unavailable
    /// architecture; `Error::CutlassInternal` when the kernel reports a
    /// non-positive threadblock count; plus kernel-status and pinned
    /// allocation errors.
    pub fn prepare<'a, 'g>(
        &'a self,
        groups: &'g [GroupedProblem<'g, T>],
    ) -> Result<PreparedGroupedGemm<'a, T>> {
        if groups.is_empty() {
            return Err(Error::InvalidProblem("grouped GEMM requires at least one group"));
        }
        let first_alpha = groups[0].alpha;
        let first_beta = groups[0].beta;
        let first_has_c = groups[0].c.is_some();

        for g in groups {
            if g.m <= 0 || g.n <= 0 || g.k <= 0 {
                return Err(Error::InvalidProblem("group M, N, K must all be positive"));
            }
            if g.a.rows != g.m || g.a.cols != g.k {
                return Err(Error::InvalidProblem("group A shape doesn't match (M, K)"));
            }
            if g.b.rows != g.k || g.b.cols != g.n {
                return Err(Error::InvalidProblem("group B shape doesn't match (K, N)"));
            }
            if g.d.rows != g.m || g.d.cols != g.n {
                return Err(Error::InvalidProblem("group D shape doesn't match (M, N)"));
            }
            if let Some(c) = &g.c {
                if c.rows != g.m || c.cols != g.n {
                    return Err(Error::InvalidProblem("group C shape doesn't match (M, N)"));
                }
            }
            if g.a.ld < g.k as i64 || g.b.ld < g.k as i64 || g.d.ld < g.n as i64 {
                return Err(Error::InvalidProblem("group leading dimension too small"));
            }
            // C is row-major M×N like D, so it needs the same minimum
            // leading dimension; this was previously left unchecked.
            if let Some(c) = &g.c {
                if c.ld < g.n as i64 {
                    return Err(Error::InvalidProblem("group leading dimension too small"));
                }
            }
            if g.alpha != first_alpha {
                return Err(Error::Unsupported(
                    "v0 grouped GEMM requires all groups to share alpha",
                ));
            }
            if g.beta != first_beta {
                return Err(Error::Unsupported(
                    "v0 grouped GEMM requires all groups to share beta",
                ));
            }
            if g.c.is_some() != first_has_c {
                return Err(Error::Unsupported(
                    "v0 grouped GEMM requires all groups to consistently have c=None or c=Some",
                ));
            }
        }

        let group_count = groups.len();
        // The kernel entry points take the group count as `i32`; reject
        // counts that would be truncated by the conversion.
        let group_count_i32 = i32::try_from(group_count)
            .map_err(|_| Error::InvalidProblem("grouped GEMM group count exceeds i32::MAX"))?;

        // Separate M / N / K arrays for the host-side kernel queries.
        let h_m: Vec<i32> = groups.iter().map(|g| g.m).collect();
        let h_n: Vec<i32> = groups.iter().map(|g| g.n).collect();
        let h_k: Vec<i32> = groups.iter().map(|g| g.k).collect();
        let kind = T::KIND;

        let ci_status = match self.sku.arch {
            #[cfg(feature = "sm80")]
            ArchSku::Sm80 => unsafe {
                dispatch::grouped_gemm_rcr_sm80_can_implement(
                    kind,
                    h_m.as_ptr(),
                    h_n.as_ptr(),
                    h_k.as_ptr(),
                    group_count_i32,
                )
            },
            #[cfg(not(feature = "sm80"))]
            ArchSku::Sm80 => {
                return Err(Error::Unsupported(
                    "sm80 selected but the `sm80` feature isn't enabled",
                ));
            }
            ArchSku::Sm90a => {
                return Err(Error::Unsupported(
                    "sm90a grouped kernels not yet shipped (deferred until Hopper hardware available)",
                ));
            }
            ArchSku::Sm89 => {
                return Err(Error::Unsupported(
                    "Ada-specialized FP8 / sm_89 SKUs live in baracuda-kernels-sys, not baracuda-cutlass",
                ));
            }
        };
        status_to_result(ci_status)?;

        let threadblock_count = match self.sku.arch {
            #[cfg(feature = "sm80")]
            ArchSku::Sm80 => unsafe {
                dispatch::grouped_gemm_rcr_sm80_sufficient(
                    kind,
                    h_m.as_ptr(),
                    h_n.as_ptr(),
                    h_k.as_ptr(),
                    group_count_i32,
                )
            },
            #[cfg(not(feature = "sm80"))]
            ArchSku::Sm80 => 0,
            ArchSku::Sm90a => 0,
            ArchSku::Sm89 => 0,
        };
        if threadblock_count <= 0 {
            return Err(Error::CutlassInternal(threadblock_count));
        }

        let scratch_bytes = match self.sku.arch {
            #[cfg(feature = "sm80")]
            ArchSku::Sm80 => unsafe {
                dispatch::grouped_gemm_rcr_sm80_scratch_bytes(
                    kind,
                    h_m.as_ptr(),
                    h_n.as_ptr(),
                    h_k.as_ptr(),
                    group_count_i32,
                    threadblock_count,
                )
            },
            #[cfg(not(feature = "sm80"))]
            ArchSku::Sm80 => 0,
            ArchSku::Sm90a => 0,
            ArchSku::Sm89 => 0,
        };

        let layout = MetadataLayout::compute(group_count, scratch_bytes);
        let mut pinned: PinnedBuffer<u8> = PinnedBuffer::new(&self.context, layout.metadata_end)?;

        let ptr_a: Vec<u64> = groups.iter().map(|g| g.a.data.as_raw().0).collect();
        let ptr_b: Vec<u64> = groups.iter().map(|g| g.b.data.as_raw().0).collect();
        let ptr_d: Vec<u64> = groups.iter().map(|g| g.d.data.as_raw().0).collect();
        // Without C the slot aliases D; the effective beta is zeroed below.
        let ptr_c: Vec<u64> = groups
            .iter()
            .map(|g| {
                g.c.as_ref()
                    .map(|c| c.data.as_raw().0)
                    .unwrap_or_else(|| g.d.data.as_raw().0)
            })
            .collect();
        let lda: Vec<i64> = groups.iter().map(|g| g.a.ld).collect();
        let ldb: Vec<i64> = groups.iter().map(|g| g.b.ld).collect();
        let ldd: Vec<i64> = groups.iter().map(|g| g.d.ld).collect();
        let ldc: Vec<i64> = groups
            .iter()
            .map(|g| g.c.as_ref().map(|c| c.ld).unwrap_or(g.d.ld))
            .collect();

        {
            let host_packed: &mut [u8] = &mut pinned;

            // Problem sizes: (M, N, K) as three native-endian i32 per group.
            let mut p = layout.problem_sizes_offset;
            for g in groups {
                host_packed[p..p + 4].copy_from_slice(&g.m.to_ne_bytes());
                host_packed[p + 4..p + 8].copy_from_slice(&g.n.to_ne_bytes());
                host_packed[p + 8..p + 12].copy_from_slice(&g.k.to_ne_bytes());
                p += COORD_BYTES;
            }

            let pack_ptrs = |dst: &mut [u8], offset: usize, ptrs: &[u64]| {
                let mut p = offset;
                for &val in ptrs {
                    dst[p..p + 8].copy_from_slice(&val.to_ne_bytes());
                    p += PTR_BYTES;
                }
            };
            pack_ptrs(host_packed, layout.ptr_a_offset, &ptr_a);
            pack_ptrs(host_packed, layout.ptr_b_offset, &ptr_b);
            pack_ptrs(host_packed, layout.ptr_c_offset, &ptr_c);
            pack_ptrs(host_packed, layout.ptr_d_offset, &ptr_d);

            let pack_lds = |dst: &mut [u8], offset: usize, lds: &[i64]| {
                let mut p = offset;
                for &val in lds {
                    dst[p..p + 8].copy_from_slice(&val.to_ne_bytes());
                    p += LD_BYTES;
                }
            };
            pack_lds(host_packed, layout.lda_offset, &lda);
            pack_lds(host_packed, layout.ldb_offset, &ldb);
            pack_lds(host_packed, layout.ldc_offset, &ldc);
            pack_lds(host_packed, layout.ldd_offset, &ldd);
        }

        // Interleaved (M, N, K) triples kept on the host for the launch call.
        let host_problem_sizes: Vec<i32> = groups.iter().flat_map(|g| [g.m, g.n, g.k]).collect();

        let beta_eff = if first_has_c { first_beta } else { <T::Scalar as Default>::default() };
        Ok(PreparedGroupedGemm {
            plan: self,
            pinned,
            host_problem_sizes,
            layout,
            threadblock_count,
            alpha: first_alpha.to_f32(),
            beta: beta_eff.to_f32(),
            _element: PhantomData,
        })
    }
}
/// A grouped GEMM whose per-group metadata has been validated and packed,
/// ready to be launched with `run`. Borrows the plan it was prepared from.
#[derive(Debug)]
pub struct PreparedGroupedGemm<'a, T: CutlassElement> {
/// The plan that produced this prepared launch.
plan: &'a GroupedGemmPlan<T>,
/// Pinned host copy of the packed metadata (problem sizes, operand
/// pointers, leading dimensions); copied into the workspace by `run`.
pinned: PinnedBuffer<u8>,
/// Interleaved (M, N, K) triples, one per group, kept on the host.
host_problem_sizes: Vec<i32>,
/// Byte offsets of each metadata table and of the scratch region.
layout: MetadataLayout,
/// Threadblock count reported by the kernel during `prepare`.
threadblock_count: i32,
/// Alpha shared by all groups, converted to `f32`.
alpha: f32,
/// Effective beta shared by all groups, converted to `f32`.
beta: f32,
/// Marker tying the value to its element type without storing a `T`.
_element: PhantomData<T>,
}
impl<'a, T: CutlassElement> PreparedGroupedGemm<'a, T> {
pub fn workspace_size(&self) -> usize {
self.layout.total_workspace_bytes
}
pub fn sku(&self) -> GemmSku {
self.plan.sku
}
pub fn group_count(&self) -> usize {
self.host_problem_sizes.len() / 3
}
pub fn run(&self, stream: &Stream, workspace: Workspace<'_>) -> Result<()> {
let needed = self.workspace_size();
let workspace_slice = match workspace {
Workspace::None => {
return Err(Error::WorkspaceTooSmall { needed, got: 0 });
}
Workspace::Borrowed(slice) => {
if slice.len() < needed {
return Err(Error::WorkspaceTooSmall {
needed,
got: slice.len(),
});
}
slice
}
};
let workspace_base = workspace_slice.as_raw().0;
{
let mut workspace_for_meta = workspace_slice;
let metadata_dst = workspace_for_meta.slice_mut(0..self.layout.metadata_end);
metadata_dst.copy_from_host_async(&self.pinned, stream)?;
}
let off = |o: usize| (workspace_base + o as u64) as *const c_void;
let off_mut = |o: usize| (workspace_base + o as u64) as *mut c_void;
let d_problem_sizes = off(self.layout.problem_sizes_offset);
let d_ptr_a = off(self.layout.ptr_a_offset);
let d_ptr_b = off(self.layout.ptr_b_offset);
let d_ptr_c = off(self.layout.ptr_c_offset);
let d_ptr_d = off_mut(self.layout.ptr_d_offset);
let d_lda = off(self.layout.lda_offset);
let d_ldb = off(self.layout.ldb_offset);
let d_ldc = off(self.layout.ldc_offset);
let d_ldd = off(self.layout.ldd_offset);
let scratch_ptr = off_mut(self.layout.scratch_offset);
let scratch_bytes = self.layout.total_workspace_bytes - self.layout.scratch_offset;
let h_problem_sizes = self.host_problem_sizes.as_ptr() as *const c_void;
let stream_raw = stream.as_raw();
let group_count = self.group_count() as i32;
let status = match self.plan.sku.arch {
#[cfg(feature = "sm80")]
ArchSku::Sm80 => unsafe {
dispatch::grouped_gemm_rcr_sm80_run(
T::KIND,
group_count,
self.threadblock_count,
d_problem_sizes,
d_ptr_a,
d_ptr_b,
d_ptr_c,
d_ptr_d,
d_lda,
d_ldb,
d_ldc,
d_ldd,
h_problem_sizes,
self.alpha,
self.beta,
scratch_ptr,
scratch_bytes,
stream_raw,
)
},
#[cfg(not(feature = "sm80"))]
ArchSku::Sm80 => {
return Err(Error::Unsupported(
"sm80 selected but the `sm80` feature isn't enabled",
));
}
ArchSku::Sm90a => {
return Err(Error::Unsupported(
"sm90a grouped kernels not yet shipped",
));
}
ArchSku::Sm89 => {
return Err(Error::Unsupported(
"Ada-specialized FP8 / sm_89 SKUs live in baracuda-kernels-sys, not baracuda-cutlass",
));
}
};
status_to_result(status)
}
}
/// A plan for an integer GEMM with input element type `T` and bias element
/// type `BT` (defaulting to `f32`).
#[derive(Debug)]
pub struct IntGemmPlan<T: IntElement, BT: BiasElement = f32> {
/// Problem description the plan was selected for.
desc: IntGemmDescriptor,
/// Kernel SKU; its `bias_element` is set only when the epilogue needs a bias.
sku: GemmSku,
/// Marker tying the plan to its element type without storing a `T`.
_element: PhantomData<T>,
/// Marker tying the plan to its bias element type without storing a `BT`.
_bias_element: PhantomData<BT>,
}
impl<T: IntElement, BT: BiasElement> IntGemmPlan<T, BT> {
/// Builds an integer GEMM plan for `desc` on the device behind `stream`.
///
/// The descriptor is validated, an architecture is picked with
/// `pick_int_arch`, and the layout is then required to be `Rcr`. The SKU's
/// bias element is recorded as `BT::KIND` only when the epilogue requires a
/// bias.
///
/// # Errors
/// Propagates descriptor and architecture errors; returns
/// `Error::Unsupported` for any layout other than `Rcr`.
pub fn select(
    stream: &Stream,
    desc: &IntGemmDescriptor,
    pref: PlanPreference,
) -> Result<Self> {
    check_int_descriptor(desc)?;
    let arch = pick_int_arch(stream, pref)?;

    match desc.layout {
        LayoutSku::Rcr => {}
        _ => {
            return Err(Error::Unsupported(
                "int8 GEMM kernels are RCR-only in this release \
                 (CUTLASS 4.2.0 lacks 8-bit `TensorOpMultiplicandCongruous` \
                 warp iterators for RRR / row-major-B layout)",
            ));
        }
    }

    let bias_element = desc.epilogue.requires_bias().then(|| BT::KIND);
    Ok(Self {
        desc: *desc,
        sku: GemmSku {
            arch,
            layout: desc.layout,
            epilogue: desc.epilogue,
            element: T::KIND,
            bias_element,
        },
        _element: PhantomData,
        _bias_element: PhantomData,
    })
}
pub fn can_implement(&self, args: &IntGemmArgs<'_, T, BT>) -> Result<()> {
check_int_args(&self.desc, args)?;
let a_ptr = args.a.data.as_raw().0 as *const c_void;
let b_ptr = args.b.data.as_raw().0 as *const c_void;
let d_ptr = args.d.data.as_raw().0 as *mut c_void;
let (c_ptr, ldc) = match &args.c {
Some(c) => (c.data.as_raw().0 as *const c_void, c.ld),
None => (core::ptr::null(), 0i64),
};
let bias_ptr = args
.bias
.as_ref()
.map(|b| b.data.as_raw().0 as *const c_void)
.unwrap_or(core::ptr::null());
let bias_family = self.sku.epilogue.requires_bias();
let status = match (self.sku.arch, bias_family) {
#[cfg(feature = "sm80")]
(ArchSku::Sm80, false) => unsafe {
dispatch::int_gemm_rcr_sm80_can_implement(
self.sku.layout,
T::KIND,
self.desc.m, self.desc.n, self.desc.k,
a_ptr, args.a.ld,
b_ptr, args.b.ld,
c_ptr, ldc,
d_ptr, args.d.ld,
)
},
#[cfg(feature = "sm80")]
(ArchSku::Sm80, true) => unsafe {
dispatch::int_gemm_bias_rcr_sm80_can_implement(
self.sku.layout,
T::KIND,
self.sku.epilogue,
BT::KIND,
self.desc.m, self.desc.n, self.desc.k,
a_ptr, args.a.ld,
b_ptr, args.b.ld,
c_ptr, ldc,
d_ptr, args.d.ld,
bias_ptr,
)
},
#[cfg(not(feature = "sm80"))]
(ArchSku::Sm80, _) => {
return Err(Error::Unsupported(
"sm80 selected but the `sm80` feature isn't enabled",
));
}
(ArchSku::Sm90a, _) => {
return Err(Error::Unsupported(
"sm90a int8 kernels not yet shipped",
));
}
(ArchSku::Sm89, _) => {
return Err(Error::Unsupported(
"Ada-specialized FP8 / sm_89 SKUs live in baracuda-kernels-sys, not baracuda-cutlass",
));
}
};
status_to_result(status)
}
/// Returns the workspace size the selected kernel reports for this plan's problem.
///
/// The value is what `run` compares against the caller's workspace length and is
/// presumably a byte count (TODO confirm against the kernel API). Architectures
/// without an int8 kernel in this build report 0.
pub fn workspace_size(&self) -> usize {
    match (self.sku.arch, self.sku.epilogue.requires_bias()) {
        // No kernel to ask: these architectures never need a workspace here.
        (ArchSku::Sm90a, _) | (ArchSku::Sm89, _) => 0,
        #[cfg(not(feature = "sm80"))]
        (ArchSku::Sm80, _) => 0,
        #[cfg(feature = "sm80")]
        (ArchSku::Sm80, true) => dispatch::int_gemm_bias_rcr_sm80_workspace_size(
            self.sku.layout,
            T::KIND,
            self.sku.epilogue,
            BT::KIND,
            self.desc.m,
            self.desc.n,
            self.desc.k,
        ),
        #[cfg(feature = "sm80")]
        (ArchSku::Sm80, false) => dispatch::int_gemm_rcr_sm80_workspace_size(
            self.sku.layout,
            T::KIND,
            self.desc.m,
            self.desc.n,
            self.desc.k,
        ),
    }
}
pub fn sku(&self) -> GemmSku {
self.sku
}
/// Reports the precision guarantee of the plan's SKU, as computed by
/// `GemmSku::precision_guarantee`.
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
    let selected = self.sku;
    selected.precision_guarantee()
}
pub fn run(
&self,
stream: &Stream,
workspace: Workspace<'_>,
args: IntGemmArgs<'_, T, BT>,
) -> Result<()> {
self.can_implement(&args)?;
let needed = self.workspace_size();
let (ws_ptr, ws_bytes): (*mut c_void, usize) = match workspace {
Workspace::None => {
if needed != 0 {
return Err(Error::WorkspaceTooSmall { needed, got: 0 });
}
(core::ptr::null_mut(), 0)
}
Workspace::Borrowed(slice) => {
if slice.len() < needed {
return Err(Error::WorkspaceTooSmall {
needed,
got: slice.len(),
});
}
(slice.as_raw().0 as *mut c_void, slice.len())
}
};
let a_ptr = args.a.data.as_raw().0 as *const c_void;
let b_ptr = args.b.data.as_raw().0 as *const c_void;
let d_ptr = args.d.data.as_raw().0 as *mut c_void;
let (c_ptr, ldc) = match &args.c {
Some(c) => (c.data.as_raw().0 as *const c_void, c.ld),
None => (core::ptr::null(), 0i64),
};
let bias_ptr = args
.bias
.as_ref()
.map(|b| b.data.as_raw().0 as *const c_void)
.unwrap_or(core::ptr::null());
let beta_eff: f32 = if args.c.is_some() { args.beta } else { 0.0 };
let stream_raw = stream.as_raw();
let bias_family = self.sku.epilogue.requires_bias();
let status = match (self.sku.arch, bias_family) {
#[cfg(feature = "sm80")]
(ArchSku::Sm80, false) => unsafe {
dispatch::int_gemm_rcr_sm80_run(
self.sku.layout,
T::KIND,
self.desc.m, self.desc.n, self.desc.k,
a_ptr, args.a.ld,
b_ptr, args.b.ld,
c_ptr, ldc,
d_ptr, args.d.ld,
args.alpha,
beta_eff,
ws_ptr, ws_bytes, stream_raw,
)
},
#[cfg(feature = "sm80")]
(ArchSku::Sm80, true) => unsafe {
dispatch::int_gemm_bias_rcr_sm80_run(
self.sku.layout,
T::KIND,
self.sku.epilogue,
BT::KIND,
self.desc.m, self.desc.n, self.desc.k,
a_ptr, args.a.ld,
b_ptr, args.b.ld,
c_ptr, ldc,
d_ptr, args.d.ld,
bias_ptr,
args.alpha,
beta_eff,
ws_ptr, ws_bytes, stream_raw,
)
},
#[cfg(not(feature = "sm80"))]
(ArchSku::Sm80, _) => {
return Err(Error::Unsupported(
"sm80 selected but the `sm80` feature isn't enabled",
));
}
(ArchSku::Sm90a, _) => {
return Err(Error::Unsupported("sm90a int8 kernels not yet shipped"));
}
(ArchSku::Sm89, _) => {
return Err(Error::Unsupported(
"Ada-specialized FP8 / sm_89 SKUs live in baracuda-kernels-sys, not baracuda-cutlass",
));
}
};
status_to_result(status)
}
}
fn check_int_descriptor(desc: &IntGemmDescriptor) -> Result<()> {
if desc.m <= 0 || desc.n <= 0 || desc.k <= 0 {
return Err(Error::InvalidProblem("M, N, K must all be positive"));
}
Ok(())
}
fn check_int_args<T: IntElement, BT: BiasElement>(
desc: &IntGemmDescriptor,
args: &IntGemmArgs<'_, T, BT>,
) -> Result<()> {
match (desc.epilogue.requires_bias(), &args.bias) {
(false, Some(_)) => {
return Err(Error::InvalidProblem(
"args.bias must be None when descriptor.epilogue is Identity",
));
}
(true, None) => {
return Err(Error::InvalidProblem(
"args.bias is required when descriptor.epilogue is in the Bias family \
(Bias / BiasRelu / BiasGelu / BiasSilu)",
));
}
(false, None) | (true, Some(_)) => {}
}
if let Some(bias) = &args.bias {
if bias.len != desc.n {
return Err(Error::InvalidProblem("bias vector length must equal N"));
}
if bias.stride != 1 {
return Err(Error::Unsupported(
"bias vector must be contiguous (stride 1) — strided bias not supported",
));
}
if bias.data.len() < desc.n as usize {
return Err(Error::BufferTooSmall {
needed: desc.n as usize,
got: bias.data.len(),
});
}
}
if args.a.rows != desc.m || args.a.cols != desc.k {
return Err(Error::InvalidProblem("A shape doesn't match descriptor (M, K)"));
}
if args.b.rows != desc.k || args.b.cols != desc.n {
return Err(Error::InvalidProblem("B shape doesn't match descriptor (K, N)"));
}
if args.d.rows != desc.m || args.d.cols != desc.n {
return Err(Error::InvalidProblem("D shape doesn't match descriptor (M, N)"));
}
if let Some(c) = &args.c {
if c.rows != desc.m || c.cols != desc.n {
return Err(Error::InvalidProblem("C shape doesn't match descriptor (M, N)"));
}
}
if args.a.ld < desc.k as i64 {
return Err(Error::InvalidProblem("A leading dimension must be >= K"));
}
let b_min_ld = match desc.layout {
LayoutSku::Rcr => desc.k as i64,
LayoutSku::Rrr => desc.n as i64,
};
if args.b.ld < b_min_ld {
return Err(Error::InvalidProblem(match desc.layout {
LayoutSku::Rcr => "B leading dimension must be >= K (column-major Rcr layout)",
LayoutSku::Rrr => "B leading dimension must be >= N (row-major Rrr layout)",
}));
}
if args.d.ld < desc.n as i64 {
return Err(Error::InvalidProblem("D leading dimension must be >= N"));
}
if let Some(c) = &args.c {
if c.ld < desc.n as i64 {
return Err(Error::InvalidProblem("C leading dimension must be >= N"));
}
}
let need_a = min_elements_row_major(args.a.rows, args.a.cols, args.a.ld)
.ok_or(Error::InvalidProblem("A storage size overflow"))?;
if args.a.data.len() < need_a {
return Err(Error::BufferTooSmall {
needed: need_a,
got: args.a.data.len(),
});
}
let need_b = match desc.layout {
LayoutSku::Rcr => min_elements_col_major(args.b.rows, args.b.cols, args.b.ld),
LayoutSku::Rrr => min_elements_row_major(args.b.rows, args.b.cols, args.b.ld),
}
.ok_or(Error::InvalidProblem("B storage size overflow"))?;
if args.b.data.len() < need_b {
return Err(Error::BufferTooSmall {
needed: need_b,
got: args.b.data.len(),
});
}
let need_d = min_elements_row_major(args.d.rows, args.d.cols, args.d.ld)
.ok_or(Error::InvalidProblem("D storage size overflow"))?;
if args.d.data.len() < need_d {
return Err(Error::BufferTooSmall {
needed: need_d,
got: args.d.data.len(),
});
}
if let Some(c) = &args.c {
let need_c = min_elements_row_major(c.rows, c.cols, c.ld)
.ok_or(Error::InvalidProblem("C storage size overflow"))?;
if c.data.len() < need_c {
return Err(Error::BufferTooSmall {
needed: need_c,
got: c.data.len(),
});
}
}
Ok(())
}
/// Picks the architecture SKU used by int8 GEMM plans on the device that owns `stream`.
///
/// `ArchSku::Sm80` is the only SKU this function can return. `pref` currently has no
/// influence on the result: there is no int8 path that yields `ArchSku::Sm90a`, and
/// `IntGemmPlan::can_implement` / `IntGemmPlan::run` reject that architecture with
/// "sm90a int8 kernels not yet shipped".
///
/// # Errors
///
/// * Propagates any error from querying the device's compute capability.
/// * `Error::Unsupported` when the crate was built without the `sm80` feature.
/// * `Error::Unsupported` when the device's major compute capability is below 8.
fn pick_int_arch(stream: &Stream, pref: PlanPreference) -> Result<ArchSku> {
    let (major, _minor) = stream.context().device().compute_capability()?;

    // This used to be an `if pref.allow_sm90a && .. { }` with an empty body, i.e. a no-op
    // that read as if the sm90a preference were honoured. Until an sm90a int8 kernel
    // exists the preference is deliberately ignored; the binding keeps the parameter used.
    let _ = pref;

    if !cfg!(feature = "sm80") {
        return Err(Error::Unsupported(
            "no arch features enabled — build with --features sm80",
        ));
    }
    if major < 8 {
        return Err(Error::Unsupported(
            "device compute capability < 8.0; sm_80 int8 kernels won't run here",
        ));
    }
    Ok(ArchSku::Sm80)
}
/// Unit tests for the minimum-element-count helpers used to validate RCR operand buffers.
///
/// NOTE(review): the helper arguments are presumably `(rows, cols, ld)` — confirm against
/// the helper definitions, which live elsewhere in this file.
#[cfg(test)]
mod buffer_size_tests {
    use super::{min_elements_rcr_a, min_elements_rcr_b, min_elements_rcr_cd};

    // A with no padding: 4 * 8 = 32 elements.
    #[test]
    fn rcr_a_tight_layout() {
        let required = min_elements_rcr_a(4, 8, 8);
        assert_eq!(required, Some(32));
    }

    // A with stride 16: 56 = 3 * 16 + 8, less than the 4 * 16 = 64 of a fully padded buffer.
    #[test]
    fn rcr_a_padded_layout_accepts_smaller_count() {
        let required = min_elements_rcr_a(4, 8, 16);
        assert_eq!(required, Some(56));
    }

    // B with no padding: 8 * 4 = 32 elements.
    #[test]
    fn rcr_b_tight_layout() {
        let required = min_elements_rcr_b(8, 4, 8);
        assert_eq!(required, Some(32));
    }

    // B with stride 16: 56 = 3 * 16 + 8.
    #[test]
    fn rcr_b_padded_layout_accepts_smaller_count() {
        let required = min_elements_rcr_b(8, 4, 16);
        assert_eq!(required, Some(56));
    }

    // C / D with no padding: 4 * 8 = 32 elements.
    #[test]
    fn rcr_cd_tight_layout() {
        let required = min_elements_rcr_cd(4, 8, 8);
        assert_eq!(required, Some(32));
    }

    // C / D with stride 16: 56 = 3 * 16 + 8.
    #[test]
    fn rcr_cd_padded_layout_accepts_smaller_count() {
        let required = min_elements_rcr_cd(4, 8, 16);
        assert_eq!(required, Some(56));
    }

    // With a first argument of 1 the result is 8 for both strides, i.e. the stride
    // contributes nothing and the computation must not underflow.
    #[test]
    fn single_row_matrix_does_not_underflow() {
        let tight = min_elements_rcr_a(1, 8, 8);
        let padded = min_elements_rcr_a(1, 8, 256);
        assert_eq!(tight, Some(8));
        assert_eq!(padded, Some(8));
    }

    // Extreme inputs must surface as `None` rather than wrapping or panicking.
    #[test]
    fn overflow_returns_none() {
        let required = min_elements_rcr_a(i32::MAX, 1, i64::MAX);
        assert_eq!(required, None);
    }
}