use core::ffi::c_void;
use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_types::{
ArchSku, BackendKind, Element, ElementKind, IntElement, KernelSku, MathPrecision, OpCategory,
PlanPreference, PrecisionGuarantee, QuantizeKind, S8, TensorMut, TensorRef, Workspace,
};
use super::map_status;
/// Problem description for a quantized linear layer.
///
/// The tensor shapes this descriptor implies are `[m, k]` for the activation,
/// `[c_out, k]` for the quantized weight and `[m, c_out]` for the output.
#[derive(Clone, Copy, Debug)]
pub struct QuantizedLinearDescriptor {
    /// Number of activation rows (`M`). Must be non-negative.
    pub m: i32,
    /// Number of output features (`C_out`). Must be non-negative.
    pub c_out: i32,
    /// Reduction length (`K`) shared by activation and weight. Must be non-negative.
    pub k: i32,
    /// Lower bound forwarded to the activation-quantize kernel.
    pub q_min: i32,
    /// Upper bound forwarded to the activation-quantize kernel.
    /// Plan selection requires `q_max > 0` and `q_max >= q_min`.
    pub q_max: i32,
    /// Element kind of the activation; must equal the plan's `TIn::KIND`.
    pub activation_element: ElementKind,
    /// Element kind of the quantized weight; must equal the plan's `TWQ::KIND`.
    pub weight_element: ElementKind,
}
/// Tensor arguments for one execution of a [`QuantizedLinearPlan`].
///
/// Expected shapes (checked by `QuantizedLinearPlan::can_implement`) are given
/// per field in terms of the plan descriptor's `m`, `k` and `c_out`.
pub struct QuantizedLinearArgs<'a, TIn: Element, TWQ: IntElement> {
    /// Input activation, shape `[M, K]`.
    pub activation: TensorRef<'a, TIn, 2>,
    /// Quantized weight, shape `[C_out, K]`.
    pub weight_q: TensorRef<'a, TWQ, 2>,
    /// Weight scales, shape `[C_out]`.
    pub weight_scale: TensorRef<'a, TIn, 1>,
    /// Result tensor, shape `[M, C_out]`.
    pub output: TensorMut<'a, TIn, 2>,
    /// Scratch for the quantized activation, shape `[M, K]`. Written by the
    /// activation-quantize pass and then read by the linear pass.
    pub act_q_scratch: TensorMut<'a, S8, 2>,
    /// Scratch for the activation scales, shape `[M]`. Written by the
    /// activation-quantize pass and then read by the linear pass.
    pub act_scale_scratch: TensorMut<'a, TIn, 1>,
}
/// A validated plan for running a quantized linear layer with activation
/// type `TIn` and quantized weight type `TWQ`.
///
/// Instances are created through `QuantizedLinearPlan::select`, which checks
/// the descriptor before storing it.
pub struct QuantizedLinearPlan<TIn: Element, TWQ: IntElement> {
    /// Descriptor captured (by copy) at selection time.
    desc: QuantizedLinearDescriptor,
    /// Kernel identity built at selection time.
    sku: KernelSku,
    /// Ties the plan to its element types without storing any values of them.
    _marker: PhantomData<(TIn, TWQ)>,
}
impl<TIn: Element, TWQ: IntElement> QuantizedLinearPlan<TIn, TWQ> {
    /// Validates `desc` against the generic parameters and builds a plan.
    ///
    /// The stream and preference arguments are accepted but currently unused.
    ///
    /// # Errors
    ///
    /// * `Error::Unsupported` if the descriptor's element kinds do not match
    ///   `TIn` / `TWQ`, if `TIn` is not `F32` / `F64`, if `TWQ` is not `S8`,
    ///   or if `m > 65535`.
    /// * `Error::InvalidProblem` if any of `m`, `c_out`, `k` is negative, if
    ///   `q_max <= 0`, if `q_max < q_min`, or if `q_min` / `q_max` fall
    ///   outside the `S8` range `[-128, 127]`.
    pub fn select(
        _stream: &Stream,
        desc: &QuantizedLinearDescriptor,
        _pref: PlanPreference,
    ) -> Result<Self> {
        if desc.activation_element != TIn::KIND {
            return Err(Error::Unsupported(
                "QuantizedLinearPlan: descriptor activation_element != TIn",
            ));
        }
        if desc.weight_element != TWQ::KIND {
            return Err(Error::Unsupported(
                "QuantizedLinearPlan: descriptor weight_element != TWQ",
            ));
        }
        if !matches!(TIn::KIND, ElementKind::F32 | ElementKind::F64) {
            return Err(Error::Unsupported(
                "QuantizedLinearPlan: 8.3 trailblazer only wires f32 / f64 \
                 activation (f16 / bf16 deferred)",
            ));
        }
        if TWQ::KIND != ElementKind::S8 {
            return Err(Error::Unsupported(
                "QuantizedLinearPlan: 8.3 trailblazer only wires S8 weight \
                 (U8 deferred)",
            ));
        }
        if desc.m < 0 || desc.c_out < 0 || desc.k < 0 {
            return Err(Error::InvalidProblem(
                "QuantizedLinearPlan: m, c_out, k must be non-negative",
            ));
        }
        if desc.q_max <= 0 {
            return Err(Error::InvalidProblem(
                "QuantizedLinearPlan: q_max must be > 0",
            ));
        }
        if desc.q_max < desc.q_min {
            return Err(Error::InvalidProblem(
                "QuantizedLinearPlan: q_max < q_min",
            ));
        }
        // The quantized activation is always stored as S8 (see
        // `QuantizedLinearArgs::act_q_scratch`), so clamp bounds outside the
        // i8 range cannot be represented and are rejected up front.
        if desc.q_min < i32::from(i8::MIN) || desc.q_max > i32::from(i8::MAX) {
            return Err(Error::InvalidProblem(
                "QuantizedLinearPlan: q_min / q_max must lie within the S8 range [-128, 127]",
            ));
        }
        if desc.m > 65535 {
            return Err(Error::Unsupported(
                "QuantizedLinearPlan: M > 65535 — the internal dynamic-range pass \
                 uses one block per row and would exceed the legacy grid limit \
                 (lift when row tiling lands)",
            ));
        }
        let sku = build_sku::<TIn, TWQ>(QuantizeKind::QuantizedLinear);
        Ok(Self {
            desc: *desc,
            sku,
            _marker: PhantomData,
        })
    }

    /// Checks that every tensor in `args` has the shape the plan's descriptor
    /// implies.
    ///
    /// Only shapes are compared; nothing else about the tensors is inspected.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidProblem` naming the first mismatching tensor.
    pub fn can_implement(&self, args: &QuantizedLinearArgs<'_, TIn, TWQ>) -> Result<()> {
        if args.activation.shape != [self.desc.m, self.desc.k] {
            return Err(Error::InvalidProblem(
                "QuantizedLinearPlan: activation shape != [M, K]",
            ));
        }
        if args.weight_q.shape != [self.desc.c_out, self.desc.k] {
            return Err(Error::InvalidProblem(
                "QuantizedLinearPlan: weight_q shape != [C_out, K]",
            ));
        }
        if args.weight_scale.shape != [self.desc.c_out] {
            return Err(Error::InvalidProblem(
                "QuantizedLinearPlan: weight_scale shape != [C_out]",
            ));
        }
        if args.output.shape != [self.desc.m, self.desc.c_out] {
            return Err(Error::InvalidProblem(
                "QuantizedLinearPlan: output shape != [M, C_out]",
            ));
        }
        if args.act_q_scratch.shape != [self.desc.m, self.desc.k] {
            return Err(Error::InvalidProblem(
                "QuantizedLinearPlan: act_q_scratch shape != [M, K]",
            ));
        }
        if args.act_scale_scratch.shape != [self.desc.m] {
            return Err(Error::InvalidProblem(
                "QuantizedLinearPlan: act_scale_scratch shape != [M]",
            ));
        }
        Ok(())
    }

    /// Workspace bytes required by [`run`](Self::run); always zero, because
    /// the scratch buffers are supplied through `QuantizedLinearArgs`.
    #[inline]
    pub fn workspace_size(&self) -> usize {
        0
    }

    /// Returns the kernel SKU built at selection time.
    #[inline]
    pub fn sku(&self) -> KernelSku {
        self.sku
    }

    /// Returns the precision guarantee recorded in the plan's SKU.
    #[inline]
    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
        self.sku.precision_guarantee
    }

    /// Runs the quantized linear layer on `stream`.
    ///
    /// Two FFI calls are issued in order, both receiving the raw stream
    /// handle:
    ///
    /// 1. a dynamic-range quantize pass that reads `args.activation` and
    ///    writes `args.act_scale_scratch` and `args.act_q_scratch`;
    /// 2. a quantized-linear pass that reads the weight, both scratch buffers
    ///    and the weight scales, and writes `args.output`.
    ///
    /// `_workspace` is unused. If `m`, `c_out` or `k` is zero the function
    /// returns `Ok(())` without calling either kernel.
    ///
    /// # Errors
    ///
    /// Propagates any shape error from [`can_implement`](Self::can_implement)
    /// and any non-success status from the kernels as translated by
    /// `map_status`.
    pub fn run(
        &self,
        stream: &Stream,
        _workspace: Workspace<'_>,
        args: QuantizedLinearArgs<'_, TIn, TWQ>,
    ) -> Result<()> {
        self.can_implement(&args)?;

        // Empty problem: nothing is launched.
        // NOTE(review): when only `k == 0` the output is non-empty but is left
        // untouched rather than zero-filled — confirm callers expect this.
        if self.desc.m == 0 || self.desc.c_out == 0 || self.desc.k == 0 {
            return Ok(());
        }

        let stream_ptr = stream.as_raw() as *mut c_void;

        // Pass 1: quantize the activation into the scratch buffers.
        let act_ptr = args.activation.data.as_raw().0 as *const c_void;
        let act_scale_ptr = args.act_scale_scratch.data.as_raw().0 as *mut c_void;
        let act_q_ptr = args.act_q_scratch.data.as_raw().0 as *mut c_void;
        let drq_status = match TIn::KIND {
            // SAFETY: `can_implement` verified the shapes of all three tensors
            // against the descriptor dimensions passed here. Assumes each
            // tensor's storage is contiguous and at least as large as its
            // shape implies (not checked in this function).
            ElementKind::F32 => unsafe {
                baracuda_kernels_sys::baracuda_kernels_dynamic_range_quantize_per_token_sym_f32_s8_run(
                    self.desc.m,
                    self.desc.k,
                    self.desc.q_min,
                    self.desc.q_max,
                    act_ptr,
                    act_scale_ptr,
                    act_q_ptr,
                    core::ptr::null_mut(),
                    0,
                    stream_ptr,
                )
            },
            // SAFETY: same argument as the F32 arm above.
            ElementKind::F64 => unsafe {
                baracuda_kernels_sys::baracuda_kernels_dynamic_range_quantize_per_token_sym_f64_s8_run(
                    self.desc.m,
                    self.desc.k,
                    self.desc.q_min,
                    self.desc.q_max,
                    act_ptr,
                    act_scale_ptr,
                    act_q_ptr,
                    core::ptr::null_mut(),
                    0,
                    stream_ptr,
                )
            },
            _ => {
                return Err(Error::Unsupported(
                    "QuantizedLinearPlan::run reached unsupported TIn at \
                     activation-quantize pass (select should have caught)",
                ))
            }
        };
        map_status(drq_status)?;

        // Pass 2: consume the scratch buffers (now read-only) together with
        // the quantized weight and its scales.
        let weight_ptr = args.weight_q.data.as_raw().0 as *const c_void;
        let act_q_const = act_q_ptr as *const c_void;
        let act_scale_const = act_scale_ptr as *const c_void;
        let w_scale_ptr = args.weight_scale.data.as_raw().0 as *const c_void;
        let out_ptr = args.output.data.as_raw().0 as *mut c_void;
        let ql_status = match TIn::KIND {
            // SAFETY: `can_implement` verified the shapes of the weight,
            // scratch, scale and output tensors against the descriptor
            // dimensions passed here. Assumes contiguous, sufficiently large
            // storage behind every pointer (not checked in this function).
            ElementKind::F32 => unsafe {
                baracuda_kernels_sys::baracuda_kernels_quantized_linear_w8a8_f32_run(
                    self.desc.m,
                    self.desc.c_out,
                    self.desc.k,
                    weight_ptr,
                    act_q_const,
                    act_scale_const,
                    w_scale_ptr,
                    out_ptr,
                    core::ptr::null_mut(),
                    0,
                    stream_ptr,
                )
            },
            // SAFETY: same argument as the F32 arm above.
            ElementKind::F64 => unsafe {
                baracuda_kernels_sys::baracuda_kernels_quantized_linear_w8a8_f64_run(
                    self.desc.m,
                    self.desc.c_out,
                    self.desc.k,
                    weight_ptr,
                    act_q_const,
                    act_scale_const,
                    w_scale_ptr,
                    out_ptr,
                    core::ptr::null_mut(),
                    0,
                    stream_ptr,
                )
            },
            _ => {
                return Err(Error::Unsupported(
                    "QuantizedLinearPlan::run reached unsupported TIn at \
                     quantized-linear pass (select should have caught)",
                ))
            }
        };
        map_status(ql_status)
    }
}
/// Assembles the [`KernelSku`] describing the quantization kernel `op` for
/// activation type `TIn` and quantized weight type `TWQ`.
///
/// The SKU is always tagged as a bespoke `Sm80` quantization kernel with no
/// layout or epilogue, and is reported as deterministic and bit-stable on the
/// same hardware.
fn build_sku<TIn: Element, TWQ: IntElement>(op: QuantizeKind) -> KernelSku {
    // Math precision follows the activation type: F64 activations report F64,
    // everything else reports F32.
    let math_precision = if TIn::KIND == ElementKind::F64 {
        MathPrecision::F64
    } else {
        MathPrecision::F32
    };

    KernelSku {
        category: OpCategory::Quantization,
        op: op as u16,
        element: TIn::KIND,
        aux_element: Some(TWQ::KIND),
        layout: None,
        epilogue: None,
        arch: ArchSku::Sm80,
        backend: BackendKind::Bespoke,
        precision_guarantee: PrecisionGuarantee {
            math_precision,
            // NOTE(review): the accumulator is reported as F32 for every
            // activation type, including F64 — confirm this matches the kernels.
            accumulator: ElementKind::F32,
            bit_stable_on_same_hardware: true,
            deterministic: true,
        },
    }
}