use core::ffi::c_void;
use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_types::{
ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
PlanPreference, PrecisionGuarantee, QuantizeKind, TensorMut, TensorRef, Workspace,
};
use half::f16;
use crate::quantize::map_status;
/// Private module holding the sealing trait, so that crates outside this one
/// cannot add new [`MarlinActivation`] implementations.
mod sealed {
    /// Marker trait; only types listed here may implement `MarlinActivation`.
    pub trait Sealed {}

    impl Sealed for half::f16 {}
}

/// Activation element types accepted by the INT4 Marlin GEMM plan.
///
/// The trait is sealed: it requires `sealed::Sealed`, which is only
/// implemented for `half::f16` in this file.
pub trait MarlinActivation: Element + sealed::Sealed {}

impl MarlinActivation for f16 {}
/// Problem description for the INT4 Marlin GEMM plan.
///
/// Built with [`Int4MarlinGemmDescriptor::new`] and refined with the
/// `with_*` builder methods. Validation of the values happens when a plan is
/// selected, not here.
#[derive(Copy, Clone, Debug)]
#[non_exhaustive]
pub struct Int4MarlinGemmDescriptor {
    /// `M` extent: rows of the activation and of the output.
    pub m: i32,
    /// `N` extent: columns of the output.
    pub n: i32,
    /// `K` extent: columns of the activation (the reduction dimension).
    pub k: i32,
    /// Quantization group size along `K`; `-1` selects per-channel scales.
    pub group_size: i32,
    /// Parallelism bound forwarded to the kernel; it also scales the
    /// required workspace length.
    pub max_par: i32,
}

impl Int4MarlinGemmDescriptor {
    /// Creates a descriptor for an `m x n x k` problem with the default
    /// `group_size` (128) and `max_par` (16).
    pub fn new(m: i32, n: i32, k: i32) -> Self {
        const DEFAULT_GROUP_SIZE: i32 = 128;
        const DEFAULT_MAX_PAR: i32 = 16;

        Self {
            m,
            n,
            k,
            group_size: DEFAULT_GROUP_SIZE,
            max_par: DEFAULT_MAX_PAR,
        }
    }

    /// Returns a copy of the descriptor with `group_size` replaced by `g`.
    #[must_use]
    pub fn with_group_size(self, g: i32) -> Self {
        Self {
            group_size: g,
            ..self
        }
    }

    /// Returns a copy of the descriptor with `max_par` replaced by `mp`.
    #[must_use]
    pub fn with_max_par(self, mp: i32) -> Self {
        Self {
            max_par: mp,
            ..self
        }
    }
}
/// Tensor arguments for one launch of an `Int4MarlinGemmPlan`.
///
/// The shape requirements listed on the fields are the ones enforced by
/// `Int4MarlinGemmPlan::can_implement`.
pub struct Int4MarlinGemmArgs<'a, T: MarlinActivation> {
/// Activation matrix; must have shape `[M, K]` with unit stride along `K`.
pub activation: TensorRef<'a, T, 2>,
/// Packed weights as `i32` words; length must equal `K * N / 8`.
pub weight_packed: TensorRef<'a, i32, 1>,
/// Scales; shape `[K / group_size, N]`, or `[1, N]` when `group_size == -1`.
pub scales: TensorRef<'a, T, 2>,
/// Kernel workspace of `i32`; needs at least `(N / 128) * max_par` elements.
pub workspace: TensorMut<'a, i32, 1>,
/// Output matrix; must have shape `[M, N]` with unit stride along `N`.
pub output: TensorMut<'a, T, 2>,
}
/// A validated INT4 Marlin GEMM plan for activation element type `T`.
///
/// Instances are created by `select`, which validates the descriptor before
/// storing it.
pub struct Int4MarlinGemmPlan<T: MarlinActivation> {
// Copy of the descriptor that passed validation in `select`.
desc: Int4MarlinGemmDescriptor,
// Kernel SKU produced by `build_sku` for `T::KIND`.
sku: KernelSku,
// Ties the plan to the activation element type without storing a value.
_phantom: PhantomData<T>,
}
impl<T: MarlinActivation> Int4MarlinGemmPlan<T> {
pub fn select(
_stream: &Stream,
desc: &Int4MarlinGemmDescriptor,
_pref: PlanPreference,
) -> Result<Self> {
if desc.m < 0 || desc.n <= 0 || desc.k <= 0 {
return Err(Error::InvalidProblem(
"Int4MarlinGemmPlan: M, N, K must be non-negative (M) / positive (N, K)",
));
}
if desc.group_size != -1 && desc.group_size != 128 {
return Err(Error::Unsupported(
"Int4MarlinGemmPlan: group_size must be 128 or -1 (per-channel)",
));
}
if desc.k % 128 != 0 {
return Err(Error::InvalidProblem(
"Int4MarlinGemmPlan: K must be divisible by 128 (kernel K-block bound)",
));
}
if desc.n % 256 != 0 {
return Err(Error::InvalidProblem(
"Int4MarlinGemmPlan: N must be divisible by 256 (kernel N-tile bound)",
));
}
if desc.max_par <= 0 {
return Err(Error::InvalidProblem(
"Int4MarlinGemmPlan: max_par must be positive",
));
}
if !matches!(T::KIND, ElementKind::F16) {
return Err(Error::Unsupported(
"Int4MarlinGemmPlan: activation dtype must be f16 (upstream is fp16-only)",
));
}
Ok(Self {
desc: *desc,
sku: build_sku(T::KIND),
_phantom: PhantomData,
})
}
/// Checks that `args` is compatible with the descriptor this plan was
/// selected for.
///
/// Besides the shape checks, every operand must be densely packed: the
/// kernel entry point receives bare pointers with no stride or
/// leading-dimension arguments, so a padded or strided view would be
/// read or written as if it were dense. Row pitch is only checked when
/// there is more than one row, because it is never used otherwise.
///
/// # Errors
///
/// * `Error::InvalidProblem` when a shape, packed length, or layout does
///   not match the descriptor.
/// * `Error::WorkspaceTooSmall` (sizes in bytes) when `args.workspace`
///   has fewer than `(N / 128) * max_par` elements.
pub fn can_implement(&self, args: &Int4MarlinGemmArgs<'_, T>) -> Result<()> {
    /// True when consecutive rows sit exactly `cols` elements apart, or
    /// when there is at most one row so the pitch is irrelevant.
    fn rows_are_dense(rows: i64, row_stride: i64, cols: i64) -> bool {
        rows <= 1 || row_stride == cols
    }

    let m = self.desc.m;
    let n = self.desc.n;
    let k = self.desc.k;
    let g = self.desc.group_size;

    if args.activation.shape != [m, k] {
        return Err(Error::InvalidProblem(
            "Int4MarlinGemmPlan: activation shape != [M, K]",
        ));
    }
    if args.activation.stride[1] != 1 {
        return Err(Error::InvalidProblem(
            "Int4MarlinGemmPlan: activation must be contig along K",
        ));
    }
    if !rows_are_dense(i64::from(m), args.activation.stride[0] as i64, i64::from(k)) {
        return Err(Error::InvalidProblem(
            "Int4MarlinGemmPlan: activation rows must be densely packed (row stride == K)",
        ));
    }

    let expected_packed_i32 = (k as i64) * (n as i64) / 8;
    if (args.weight_packed.shape[0] as i64) != expected_packed_i32 {
        return Err(Error::InvalidProblem(
            "Int4MarlinGemmPlan: weight_packed length != K * N / 8 (i32 count)",
        ));
    }
    if args.weight_packed.stride[0] != 1 {
        return Err(Error::InvalidProblem(
            "Int4MarlinGemmPlan: weight_packed must be contiguous",
        ));
    }

    // Per-channel quantization (-1) uses a single row of scales.
    let scale_rows = if g == -1 { 1 } else { k / g };
    if args.scales.shape != [scale_rows, n] {
        return Err(Error::InvalidProblem(
            "Int4MarlinGemmPlan: scales shape != [K/group_size, N] (or [1, N] for per-channel)",
        ));
    }
    if args.scales.stride[1] != 1 {
        return Err(Error::InvalidProblem(
            "Int4MarlinGemmPlan: scales must be contig along N",
        ));
    }
    if !rows_are_dense(i64::from(scale_rows), args.scales.stride[0] as i64, i64::from(n)) {
        return Err(Error::InvalidProblem(
            "Int4MarlinGemmPlan: scales rows must be densely packed (row stride == N)",
        ));
    }

    let need = (n / 128) as i64 * self.desc.max_par as i64;
    if (args.workspace.shape[0] as i64) < need {
        return Err(Error::WorkspaceTooSmall {
            needed: (need as usize) * core::mem::size_of::<i32>(),
            got: (args.workspace.shape[0] as usize) * core::mem::size_of::<i32>(),
        });
    }
    if args.workspace.stride[0] != 1 {
        return Err(Error::InvalidProblem(
            "Int4MarlinGemmPlan: workspace must be contiguous",
        ));
    }

    if args.output.shape != [m, n] {
        return Err(Error::InvalidProblem(
            "Int4MarlinGemmPlan: output shape != [M, N]",
        ));
    }
    if args.output.stride[1] != 1 {
        return Err(Error::InvalidProblem(
            "Int4MarlinGemmPlan: output must be contig along N",
        ));
    }
    if !rows_are_dense(i64::from(m), args.output.stride[0] as i64, i64::from(n)) {
        return Err(Error::InvalidProblem(
            "Int4MarlinGemmPlan: output rows must be densely packed (row stride == N)",
        ));
    }

    Ok(())
}
/// Returns the workspace size in bytes: `(N / 128) * max_par` elements of
/// `i32`, with negative factors treated as zero.
#[inline]
pub fn workspace_size(&self) -> usize {
    // Negative values contribute nothing instead of wrapping on the cast.
    let non_negative = |v: i32| -> usize {
        if v > 0 {
            v as usize
        } else {
            0
        }
    };

    let tile_count = non_negative(self.desc.n / 128);
    let par = non_negative(self.desc.max_par);
    tile_count * par * core::mem::size_of::<i32>()
}
#[inline]
pub fn sku(&self) -> KernelSku {
self.sku
}
#[inline]
pub fn precision_guarantee(&self) -> PrecisionGuarantee {
self.sku.precision_guarantee
}
pub fn run(
&self,
stream: &Stream,
_workspace: Workspace<'_>,
args: Int4MarlinGemmArgs<'_, T>,
) -> Result<()> {
self.can_implement(&args)?;
if self.desc.m == 0 {
return Ok(());
}
let a_ptr = args.activation.data.as_raw().0 as *const c_void;
let b_ptr = args.weight_packed.data.as_raw().0 as *const c_void;
let s_ptr = args.scales.data.as_raw().0 as *const c_void;
let ws_ptr = args.workspace.data.as_raw().0 as *mut c_void;
let c_ptr = args.output.data.as_raw().0 as *mut c_void;
let stream_ptr = stream.as_raw() as *mut c_void;
let status = unsafe {
dispatch_marlin::<T>(
self.desc.m,
self.desc.n,
self.desc.k,
a_ptr,
b_ptr,
c_ptr,
s_ptr,
ws_ptr,
self.desc.group_size,
self.desc.max_par,
stream_ptr,
)
};
map_status(status)
}
}
/// Builds the `KernelSku` recorded on a plan whose activation element kind
/// is `act_kind`.
fn build_sku(act_kind: ElementKind) -> KernelSku {
    let precision_guarantee = PrecisionGuarantee {
        math_precision: MathPrecision::F32,
        accumulator: ElementKind::F32,
        bit_stable_on_same_hardware: true,
        deterministic: true,
    };

    KernelSku {
        category: OpCategory::Quantization,
        // NOTE(review): the op tag is `GgufMmvq`, yet this SKU is built for the
        // Marlin INT4 GEMM plan. Confirm this is the intended `QuantizeKind`.
        op: QuantizeKind::GgufMmvq as u16,
        element: act_kind,
        aux_element: Some(ElementKind::U8),
        layout: None,
        epilogue: None,
        arch: ArchSku::Sm89,
        backend: BackendKind::Bespoke,
        precision_guarantee,
    }
}
/// Forwards the launch to the f16 Marlin kernel entry point and returns its
/// status code. Any other element kind yields status `3` without a launch.
///
/// # Safety
///
/// The call hands the raw pointers straight to the foreign kernel entry
/// point; the caller must ensure they are valid for the given problem size.
#[cfg(feature = "marlin")]
#[inline]
unsafe fn dispatch_marlin<T: MarlinActivation>(
    m: i32,
    n: i32,
    k: i32,
    a: *const c_void,
    b: *const c_void,
    c: *mut c_void,
    scales: *const c_void,
    workspace: *mut c_void,
    group_size: i32,
    max_par: i32,
    stream: *mut c_void,
) -> i32 {
    // Only the f16 entry point exists; everything else reports status 3.
    if !matches!(T::KIND, ElementKind::F16) {
        return 3;
    }

    // SAFETY: the caller upholds this function's contract for every pointer
    // forwarded here.
    unsafe {
        baracuda_kernels_sys::baracuda_kernels_int4_marlin_gemm_f16_run(
            m, n, k, a, b, c, scales, workspace, group_size, max_par, stream,
        )
    }
}
/// Stand-in used when the `marlin` feature is disabled: ignores every
/// argument and always returns status `3`.
///
/// # Safety
///
/// This variant touches none of its arguments; it is `unsafe` only so that
/// its signature matches the feature-enabled variant.
#[cfg(not(feature = "marlin"))]
#[inline]
unsafe fn dispatch_marlin<T: MarlinActivation>(
    _m: i32,
    _n: i32,
    _k: i32,
    _a: *const c_void,
    _b: *const c_void,
    _c: *mut c_void,
    _scales: *const c_void,
    _workspace: *mut c_void,
    _group_size: i32,
    _max_par: i32,
    _stream: *mut c_void,
) -> i32 {
    // Same status code the feature-enabled variant returns for element kinds
    // it does not handle.
    const STATUS_WITHOUT_MARLIN: i32 = 3;
    STATUS_WITHOUT_MARLIN
}