use core::ffi::c_void;
use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_types::{
ArchSku, BackendKind, Element, ElementKind, KernelSku, MathPrecision, OpCategory,
PlanPreference, PrecisionGuarantee, QuantizeKind, TensorMut, TensorRef, Workspace,
};
use half::f16;
use crate::quantize::map_status;
/// Marker trait for element types accepted as the activation dtype `T` of
/// [`Int4AwqGemmPlan`] and [`Int4AwqGemmArgs`].
///
/// The `sealed::Sealed` supertrait lives in a private module, so downstream
/// crates cannot add their own implementations.
pub trait AwqActivation: Element + sealed::Sealed {}
// Private module that hosts the sealing supertrait for `AwqActivation`.
mod sealed {
/// Supertrait required by `AwqActivation`; only types listed in this module
/// implement it.
pub trait Sealed {}
// `half::f16` is the only type sealed in here.
impl Sealed for half::f16 {}
}
// `f16` is the only activation dtype enabled in this file.
impl AwqActivation for f16 {}
/// Problem description for the INT4 AWQ GEMM.
///
/// The values are only stored here; they are validated by
/// `Int4AwqGemmPlan::select`.
#[derive(Copy, Clone, Debug)]
#[non_exhaustive]
pub struct Int4AwqGemmDescriptor {
    /// Number of activation / output rows (`M`).
    pub m: i32,
    /// Number of input channels (`IC`), i.e. activation columns.
    pub ic: i32,
    /// Number of output channels (`OC`), i.e. output columns.
    pub oc: i32,
    /// Quantization group size along `IC`.
    pub group_size: i32,
    /// Split-K iteration count forwarded to the kernel.
    pub split_k_iters: i32,
}

impl Int4AwqGemmDescriptor {
    /// Group size used when the caller does not override it.
    const DEFAULT_GROUP_SIZE: i32 = 128;
    /// Split-K iteration count used when the caller does not override it.
    const DEFAULT_SPLIT_K_ITERS: i32 = 8;

    /// Creates a descriptor for an `m x ic` by `ic x oc` problem with the
    /// default group size (128) and split-K iteration count (8).
    pub fn new(m: i32, ic: i32, oc: i32) -> Self {
        Self {
            group_size: Self::DEFAULT_GROUP_SIZE,
            split_k_iters: Self::DEFAULT_SPLIT_K_ITERS,
            m,
            ic,
            oc,
        }
    }

    /// Returns a copy of the descriptor with `group_size` replaced by `g`.
    #[must_use]
    pub fn with_group_size(self, g: i32) -> Self {
        Self {
            group_size: g,
            ..self
        }
    }

    /// Returns a copy of the descriptor with `split_k_iters` replaced by `s`.
    #[must_use]
    pub fn with_split_k_iters(self, s: i32) -> Self {
        Self {
            split_k_iters: s,
            ..self
        }
    }
}
/// Tensor arguments for one `Int4AwqGemmPlan::run` call.
///
/// The shapes listed below are the ones `Int4AwqGemmPlan::can_implement`
/// checks against the plan's descriptor (`M`, `IC`, `OC`, `group_size`).
pub struct Int4AwqGemmArgs<'a, T: AwqActivation> {
/// Activation matrix, shape `[M, IC]`, unit stride along `IC`.
pub activation: TensorRef<'a, T, 2>,
/// Packed weights stored as `i32`, shape `[OC, IC/8]`.
pub weight_packed: TensorRef<'a, i32, 2>,
/// Per-group scales, shape `[IC/group_size, OC]`.
pub scales: TensorRef<'a, T, 2>,
/// Packed zero points stored as `i32`, shape `[IC/group_size, OC/8]`.
pub zeros: TensorRef<'a, i32, 2>,
/// Output matrix, shape `[M, OC]`, unit stride along `OC`.
pub output: TensorMut<'a, T, 2>,
}
/// A validated INT4 AWQ GEMM plan for activation dtype `T`.
///
/// Instances are created by `Int4AwqGemmPlan::select`, which validates the
/// descriptor before storing it.
pub struct Int4AwqGemmPlan<T: AwqActivation> {
// Copy of the descriptor that passed validation in `select`.
desc: Int4AwqGemmDescriptor,
// Kernel identity / precision metadata produced by `build_sku`.
sku: KernelSku,
// Ties the plan to `T` without storing a value of that type.
_phantom: PhantomData<T>,
}
impl<T: AwqActivation> Int4AwqGemmPlan<T> {
    /// Validates `desc` and builds a plan for it.
    ///
    /// `_stream` and `_pref` are accepted for API symmetry but are not
    /// consulted by this implementation.
    ///
    /// # Errors
    ///
    /// * [`Error::InvalidProblem`] if `M < 0`, `IC <= 0`, `OC <= 0`,
    ///   `split_k_iters <= 0`, `OC` is not a multiple of 64, `IC` is not a
    ///   multiple of `group_size`, or `IC` is not a multiple of
    ///   `32 * split_k_iters`.
    /// * [`Error::Unsupported`] if `group_size` is neither 64 nor 128, or if
    ///   `T` is not `f16`.
    pub fn select(
        _stream: &Stream,
        desc: &Int4AwqGemmDescriptor,
        _pref: PlanPreference,
    ) -> Result<Self> {
        if desc.m < 0 || desc.ic <= 0 || desc.oc <= 0 {
            return Err(Error::InvalidProblem(
                "Int4AwqGemmPlan: M must be non-negative; IC and OC must be positive",
            ));
        }
        if desc.group_size != 64 && desc.group_size != 128 {
            return Err(Error::Unsupported(
                "Int4AwqGemmPlan: group_size must be 64 or 128",
            ));
        }
        if desc.split_k_iters <= 0 {
            return Err(Error::InvalidProblem(
                "Int4AwqGemmPlan: split_k_iters must be positive",
            ));
        }
        if desc.oc % 64 != 0 {
            return Err(Error::InvalidProblem(
                "Int4AwqGemmPlan: OC must be divisible by 64 (kernel cta_N tile)",
            ));
        }
        // group_size is 64 or 128 here, so the divisor is never zero.
        if desc.ic % desc.group_size != 0 {
            return Err(Error::InvalidProblem(
                "Int4AwqGemmPlan: IC must be divisible by group_size",
            ));
        }
        // Widen to i64: `32 * split_k_iters` does not fit in i32 once
        // split_k_iters >= 2^26. In i32 that either panics on overflow or
        // wraps (possibly to 0, making `%` panic with a zero divisor).
        // split_k_iters > 0 was checked above, so the i64 divisor is positive.
        if i64::from(desc.ic) % (32 * i64::from(desc.split_k_iters)) != 0 {
            return Err(Error::InvalidProblem(
                "Int4AwqGemmPlan: IC must be divisible by 32 * split_k_iters",
            ));
        }
        if !matches!(T::KIND, ElementKind::F16) {
            return Err(Error::Unsupported(
                "Int4AwqGemmPlan: activation dtype must be f16 (upstream is fp16-only)",
            ));
        }
        Ok(Self {
            desc: *desc,
            sku: build_sku(T::KIND),
            _phantom: PhantomData,
        })
    }

    /// Checks that `args` matches the shapes and layout this plan was built
    /// for.
    ///
    /// The kernel entry point receives only base pointers and the problem
    /// dimensions (no strides), so every tensor must be dense row-major:
    /// unit stride along the last axis and a row stride equal to the column
    /// count. The row stride is not checked when a tensor has at most one row.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidProblem`] naming the first tensor whose shape
    /// or layout does not match.
    pub fn can_implement(&self, args: &Int4AwqGemmArgs<'_, T>) -> Result<()> {
        let m = self.desc.m;
        let ic = self.desc.ic;
        let oc = self.desc.oc;
        let g = self.desc.group_size;

        // A row stride only matters once a second row exists.
        let rows_packed =
            |rows: i32, cols: i32, row_stride: i64| rows <= 1 || row_stride == i64::from(cols);

        if args.activation.shape != [m, ic] {
            return Err(Error::InvalidProblem(
                "Int4AwqGemmPlan: activation shape != [M, IC]",
            ));
        }
        if args.activation.stride[1] != 1 {
            return Err(Error::InvalidProblem(
                "Int4AwqGemmPlan: activation must be contig along IC",
            ));
        }
        if !rows_packed(m, ic, args.activation.stride[0] as i64) {
            return Err(Error::InvalidProblem(
                "Int4AwqGemmPlan: activation row stride must equal IC (dense rows)",
            ));
        }

        if args.weight_packed.shape != [oc, ic / 8] {
            return Err(Error::InvalidProblem(
                "Int4AwqGemmPlan: weight_packed shape != [OC, IC/8] (i32 storage)",
            ));
        }
        if args.weight_packed.stride[1] != 1
            || !rows_packed(oc, ic / 8, args.weight_packed.stride[0] as i64)
        {
            return Err(Error::InvalidProblem(
                "Int4AwqGemmPlan: weight_packed must be dense row-major",
            ));
        }

        if args.scales.shape != [ic / g, oc] {
            return Err(Error::InvalidProblem(
                "Int4AwqGemmPlan: scales shape != [IC/group_size, OC]",
            ));
        }
        if args.scales.stride[1] != 1 || !rows_packed(ic / g, oc, args.scales.stride[0] as i64) {
            return Err(Error::InvalidProblem(
                "Int4AwqGemmPlan: scales must be dense row-major",
            ));
        }

        if args.zeros.shape != [ic / g, oc / 8] {
            return Err(Error::InvalidProblem(
                "Int4AwqGemmPlan: zeros shape != [IC/group_size, OC/8] (i32 storage)",
            ));
        }
        if args.zeros.stride[1] != 1 || !rows_packed(ic / g, oc / 8, args.zeros.stride[0] as i64) {
            return Err(Error::InvalidProblem(
                "Int4AwqGemmPlan: zeros must be dense row-major",
            ));
        }

        if args.output.shape != [m, oc] {
            return Err(Error::InvalidProblem(
                "Int4AwqGemmPlan: output shape != [M, OC]",
            ));
        }
        if args.output.stride[1] != 1 {
            return Err(Error::InvalidProblem(
                "Int4AwqGemmPlan: output must be contig along OC",
            ));
        }
        if !rows_packed(m, oc, args.output.stride[0] as i64) {
            return Err(Error::InvalidProblem(
                "Int4AwqGemmPlan: output row stride must equal OC (dense rows)",
            ));
        }
        Ok(())
    }

    /// Returns the scratch size in bytes that [`Self::run`] requires.
    ///
    /// Computed as `split_k_iters * round_up(M, 128) * OC * 2`; the trailing
    /// factor of 2 is presumably the byte width of an f16 partial result
    /// (TODO confirm against the kernel). Returns 0 when `M`, `OC` or
    /// `split_k_iters` is not positive.
    ///
    /// If the product does not fit in `usize` the result saturates to
    /// `usize::MAX`, so `run` reports `WorkspaceTooSmall` instead of
    /// accepting a buffer sized from a wrapped value.
    pub fn workspace_size(&self) -> usize {
        let d = &self.desc;
        if d.m <= 0 || d.oc <= 0 || d.split_k_iters <= 0 {
            return 0;
        }
        // All three factors are known to be positive, so the `as usize`
        // conversions below cannot change their value.
        let padded_m = (i64::from(d.m) + 127) / 128 * 128;
        (d.split_k_iters as usize)
            .checked_mul(padded_m as usize)
            .and_then(|n| n.checked_mul(d.oc as usize))
            .and_then(|n| n.checked_mul(2))
            .unwrap_or(usize::MAX)
    }

    /// Returns the kernel SKU recorded when the plan was selected.
    #[inline]
    pub fn sku(&self) -> KernelSku {
        self.sku
    }

    /// Returns the precision guarantee carried by the plan's SKU.
    #[inline]
    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
        self.sku.precision_guarantee
    }

    /// Validates `args` and dispatches the GEMM on `stream`.
    ///
    /// Returns `Ok(())` without calling the kernel when `M == 0`.
    ///
    /// # Errors
    ///
    /// * Any error from [`Self::can_implement`].
    /// * [`Error::WorkspaceTooSmall`] if `workspace` provides fewer than
    ///   [`Self::workspace_size`] bytes (`Workspace::None` counts as 0).
    /// * Whatever `map_status` produces for a non-success kernel status.
    pub fn run(
        &self,
        stream: &Stream,
        workspace: Workspace<'_>,
        args: Int4AwqGemmArgs<'_, T>,
    ) -> Result<()> {
        self.can_implement(&args)?;
        if self.desc.m == 0 {
            return Ok(());
        }
        let need = self.workspace_size();
        let (ws_ptr, ws_bytes) = match workspace {
            Workspace::Borrowed(buf) => (buf.as_raw().0 as *mut c_void, buf.len()),
            Workspace::None => (core::ptr::null_mut(), 0usize),
        };
        if ws_bytes < need {
            return Err(Error::WorkspaceTooSmall {
                needed: need,
                got: ws_bytes,
            });
        }
        let a_ptr = args.activation.data.as_raw().0 as *const c_void;
        let w_ptr = args.weight_packed.data.as_raw().0 as *const c_void;
        let s_ptr = args.scales.data.as_raw().0 as *const c_void;
        let z_ptr = args.zeros.data.as_raw().0 as *const c_void;
        let o_ptr = args.output.data.as_raw().0 as *mut c_void;
        let stream_ptr = stream.as_raw() as *mut c_void;
        // SAFETY: `can_implement` verified that every tensor has the shape the
        // descriptor implies and is dense row-major, and the workspace holds
        // at least `workspace_size()` bytes. All pointers come from buffers
        // borrowed by `args` / `workspace`.
        // NOTE(review): if the launch is asynchronous, the caller presumably
        // must keep those buffers alive until the stream has finished.
        let status = unsafe {
            dispatch_awq::<T>(
                self.desc.m,
                self.desc.ic,
                self.desc.oc,
                self.desc.group_size,
                self.desc.split_k_iters,
                a_ptr,
                w_ptr,
                s_ptr,
                z_ptr,
                o_ptr,
                ws_ptr,
                ws_bytes,
                stream_ptr,
            )
        };
        map_status(status)
    }
}
/// Assembles the [`KernelSku`] recorded on every plan built in this file.
///
/// `act_kind` becomes the SKU's primary element kind; all other fields are
/// fixed constants.
fn build_sku(act_kind: ElementKind) -> KernelSku {
    // F32 math and accumulation, advertised as deterministic and bit-stable
    // on the same hardware.
    let precision_guarantee = PrecisionGuarantee {
        deterministic: true,
        bit_stable_on_same_hardware: true,
        accumulator: ElementKind::F32,
        math_precision: MathPrecision::F32,
    };

    // NOTE(review): `GgufMmvq` is an odd op id for an AWQ INT4 GEMM, and the
    // packed tensors in this file are `i32`, not `U8`. Both look copied from
    // another kernel's SKU; confirm the intended `QuantizeKind` variant and
    // aux element kind.
    let op = QuantizeKind::GgufMmvq as u16;
    let aux_element = Some(ElementKind::U8);

    KernelSku {
        category: OpCategory::Quantization,
        op,
        element: act_kind,
        aux_element,
        layout: None,
        epilogue: None,
        arch: ArchSku::Sm80,
        backend: BackendKind::Bespoke,
        precision_guarantee,
    }
}
/// Forwards one AWQ GEMM launch to the native f16 kernel entry point.
///
/// Returns the raw status code from the kernel, or `3` when `T` is not `f16`.
///
/// # Safety
///
/// The pointers are handed to the native kernel unchanged. The caller must
/// ensure they are valid for the given problem dimensions and that
/// `workspace` covers `workspace_bytes` bytes.
#[cfg(feature = "awq")]
#[inline]
unsafe fn dispatch_awq<T: AwqActivation>(
    m: i32,
    ic: i32,
    oc: i32,
    group_size: i32,
    split_k_iters: i32,
    a: *const c_void,
    w: *const c_void,
    s: *const c_void,
    z: *const c_void,
    out: *mut c_void,
    workspace: *mut c_void,
    workspace_bytes: usize,
    stream: *mut c_void,
) -> i32 {
    // Only the f16 entry point exists; every other dtype yields status 3.
    if !matches!(T::KIND, ElementKind::F16) {
        return 3;
    }
    // SAFETY: the caller upholds this function's pointer contract, which is
    // passed through to the kernel unchanged.
    unsafe {
        baracuda_kernels_sys::baracuda_kernels_int4_awq_gemm_f16_run(
            m,
            ic,
            oc,
            group_size,
            split_k_iters,
            a,
            w,
            s,
            z,
            out,
            workspace,
            workspace_bytes,
            stream,
        )
    }
}
/// Fallback used when the `awq` feature is disabled: ignores every argument
/// and returns status `3` without touching any pointer.
///
/// # Safety
///
/// Declared `unsafe` only to mirror the signature of the feature-enabled
/// variant; this stub dereferences nothing.
#[cfg(not(feature = "awq"))]
#[inline]
unsafe fn dispatch_awq<T: AwqActivation>(
    _m: i32,
    _ic: i32,
    _oc: i32,
    _group_size: i32,
    _split_k_iters: i32,
    _a: *const c_void,
    _w: *const c_void,
    _s: *const c_void,
    _z: *const c_void,
    _out: *mut c_void,
    _workspace: *mut c_void,
    _workspace_bytes: usize,
    _stream: *mut c_void,
) -> i32 {
    3
}