use core::ffi::c_void;
use core::marker::PhantomData;
use baracuda_cutlass::{Error, Result};
use baracuda_driver::Stream;
use baracuda_kernels_types::{
ArchSku, BackendKind, Element, ElementKind, IntElement, KernelSku, MathPrecision, OpCategory,
PlanPreference, PrecisionGuarantee, QuantizeKind, S8, TensorMut, TensorRef, Workspace,
};
use super::map_status;
/// Problem descriptor consumed by `SmoothQuantLinearPlan::select`.
///
/// All fields are public, so `select` re-validates every one of them rather
/// than trusting that the value came from [`SmoothQuantLinearDescriptor::new`].
#[derive(Copy, Clone, Debug)]
#[non_exhaustive]
pub struct SmoothQuantLinearDescriptor {
/// Row count M. `can_implement` requires `act_q` to be `[M, K]`, `output`
/// to be `[M, N]` and `act_scale_scratch` to be `[M]`. Must be >= 0.
pub m: i32,
/// Column count N. `can_implement` requires `weight_q` to be `[N, K]`,
/// `weight_scale` to be `[N]` and `output` to be `[M, N]`. Must be >= 0.
pub n: i32,
/// Shared inner dimension K of `act_q` (`[M, K]`) and `weight_q` (`[N, K]`).
/// Must be >= 0.
pub k: i32,
/// Single activation scale. `run` writes this value into every element of
/// the M-long `act_scale_scratch` buffer before launching the linear kernel.
/// `select` rejects non-finite values.
pub act_scale: f32,
/// Element kind of the quantized activations; `select` only accepts `S8`.
pub activation_element: ElementKind,
/// Element kind of the quantized weights; `select` requires it to equal
/// the plan's `TWQ::KIND`, and currently only accepts `S8`.
pub weight_element: ElementKind,
/// Element kind of the output; `select` requires it to equal the plan's
/// `TIn::KIND`, and currently only accepts `F32` or `F64`.
pub output_element: ElementKind,
}
impl SmoothQuantLinearDescriptor {
    /// Builds a descriptor for an `M x N x K` problem whose output element is `TIn`.
    ///
    /// Both the activation and the weight element kinds are fixed to `S8`;
    /// only the output kind follows the type parameter. No validation happens
    /// here: negative dimensions or a non-finite `act_scale` are accepted and
    /// left for `SmoothQuantLinearPlan::select` to reject.
    pub fn new<TIn: Element>(m: i32, n: i32, k: i32, act_scale: f32) -> Self {
        // The two quantized operands always share the same element kind.
        let quantized_kind = ElementKind::S8;
        let output_kind = TIn::KIND;

        Self {
            activation_element: quantized_kind,
            weight_element: quantized_kind,
            output_element: output_kind,
            act_scale,
            k,
            n,
            m,
        }
    }
}
/// Tensor arguments for one `SmoothQuantLinearPlan::run` call.
///
/// The shapes listed below are the ones enforced by
/// `SmoothQuantLinearPlan::can_implement`; M, N and K come from the plan's
/// descriptor.
pub struct SmoothQuantLinearArgs<'a, TIn: Element, TWQ: IntElement> {
/// Quantized (S8) activations, shape `[M, K]`.
pub act_q: TensorRef<'a, S8, 2>,
/// Quantized weights, shape `[N, K]`.
pub weight_q: TensorRef<'a, TWQ, 2>,
/// Weight scales, shape `[N]`.
/// NOTE(review): presumably one scale per output channel — confirm against
/// the `quantized_linear_w8a8` kernel.
pub weight_scale: TensorRef<'a, TIn, 1>,
/// Destination tensor, shape `[M, N]`.
pub output: TensorMut<'a, TIn, 2>,
/// Caller-provided scratch buffer, shape `[M]`. `run` overwrites it with the
/// descriptor's `act_scale` and then hands it to the linear kernel as the
/// activation-scale input.
pub act_scale_scratch: TensorMut<'a, TIn, 1>,
}
/// A validated execution plan for the SmoothQuant quantized linear op.
///
/// Instances are created by `select`, which checks the descriptor against the
/// `TIn` (output element) and `TWQ` (quantized weight element) type parameters.
pub struct SmoothQuantLinearPlan<TIn: Element, TWQ: IntElement> {
// Copy of the descriptor taken in `select` after it passed validation.
desc: SmoothQuantLinearDescriptor,
// Kernel identity (category, op, element kinds, arch, backend, precision
// guarantee) assembled in `select`.
sku: KernelSku,
// Ties the plan to its element type parameters without storing any values.
_marker: PhantomData<(TIn, TWQ)>,
}
impl<TIn: Element, TWQ: IntElement> SmoothQuantLinearPlan<TIn, TWQ> {
    /// Validates `desc` against the plan's type parameters and builds a plan.
    ///
    /// `_stream` and `_pref` are accepted for signature parity but are not
    /// consulted: there is exactly one kernel choice per supported `TIn`.
    ///
    /// # Errors
    ///
    /// Checks run in this order and the first failure is returned:
    ///
    /// * [`Error::Unsupported`] if `desc.output_element != TIn::KIND`,
    ///   `desc.weight_element != TWQ::KIND`, the activation element is not
    ///   `S8`, `TIn` is not `F32` / `F64`, or `TWQ` is not `S8`.
    /// * [`Error::InvalidProblem`] if any of `m`, `n`, `k` is negative or
    ///   `act_scale` is not finite.
    pub fn select(
        _stream: &Stream,
        desc: &SmoothQuantLinearDescriptor,
        _pref: PlanPreference,
    ) -> Result<Self> {
        // The descriptor's fields are public, so it may disagree with the
        // type parameters the caller instantiated the plan with.
        if desc.output_element != TIn::KIND {
            return Err(Error::Unsupported(
                "SmoothQuantLinearPlan: descriptor output_element != TIn",
            ));
        }
        if desc.weight_element != TWQ::KIND {
            return Err(Error::Unsupported(
                "SmoothQuantLinearPlan: descriptor weight_element != TWQ",
            ));
        }

        // Element kinds that `run` actually has kernels for.
        if desc.activation_element != ElementKind::S8 {
            return Err(Error::Unsupported(
                "SmoothQuantLinearPlan: trailblazer only wires S8 activation \
                 (matches underlying quantized_linear_w8a8 kernel)",
            ));
        }
        if !matches!(TIn::KIND, ElementKind::F32 | ElementKind::F64) {
            return Err(Error::Unsupported(
                "SmoothQuantLinearPlan: trailblazer only wires f32 / f64 \
                 output (f16 / bf16 follow when quantized_linear_w8a8 grows them)",
            ));
        }
        if TWQ::KIND != ElementKind::S8 {
            return Err(Error::Unsupported(
                "SmoothQuantLinearPlan: trailblazer only wires S8 weight (U8 deferred)",
            ));
        }

        // Problem-size and scale sanity. Zero-sized dimensions are allowed
        // and handled in `run`.
        if desc.m < 0 || desc.n < 0 || desc.k < 0 {
            return Err(Error::InvalidProblem(
                "SmoothQuantLinearPlan: m, n, k must be non-negative",
            ));
        }
        if !desc.act_scale.is_finite() {
            return Err(Error::InvalidProblem(
                "SmoothQuantLinearPlan: act_scale must be finite",
            ));
        }

        // NOTE(review): `accumulator` is reported as F32 even when
        // `math_precision` is F64 — confirm this matches the f64 kernel.
        let precision_guarantee = PrecisionGuarantee {
            math_precision: if TIn::KIND == ElementKind::F64 {
                MathPrecision::F64
            } else {
                MathPrecision::F32
            },
            accumulator: ElementKind::F32,
            bit_stable_on_same_hardware: true,
            deterministic: true,
        };
        let sku = KernelSku {
            category: OpCategory::Quantization,
            op: QuantizeKind::QuantizedLinear as u16,
            element: TIn::KIND,
            aux_element: Some(TWQ::KIND),
            layout: None,
            epilogue: None,
            arch: ArchSku::Sm80,
            backend: BackendKind::Bespoke,
            precision_guarantee,
        };

        Ok(Self {
            desc: *desc,
            sku,
            _marker: PhantomData,
        })
    }

    /// Checks that every tensor in `args` has the shape the plan was built for.
    ///
    /// Only shapes are compared here; strides and pointer validity are not
    /// inspected.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidProblem`] naming the first tensor whose shape
    /// does not match, checked in the order `act_q`, `weight_q`,
    /// `weight_scale`, `output`, `act_scale_scratch`.
    pub fn can_implement(&self, args: &SmoothQuantLinearArgs<'_, TIn, TWQ>) -> Result<()> {
        if args.act_q.shape != [self.desc.m, self.desc.k] {
            return Err(Error::InvalidProblem(
                "SmoothQuantLinearPlan: act_q shape != [M, K]",
            ));
        }
        if args.weight_q.shape != [self.desc.n, self.desc.k] {
            return Err(Error::InvalidProblem(
                "SmoothQuantLinearPlan: weight_q shape != [N, K]",
            ));
        }
        if args.weight_scale.shape != [self.desc.n] {
            return Err(Error::InvalidProblem(
                "SmoothQuantLinearPlan: weight_scale shape != [N]",
            ));
        }
        if args.output.shape != [self.desc.m, self.desc.n] {
            return Err(Error::InvalidProblem(
                "SmoothQuantLinearPlan: output shape != [M, N]",
            ));
        }
        if args.act_scale_scratch.shape != [self.desc.m] {
            return Err(Error::InvalidProblem(
                "SmoothQuantLinearPlan: act_scale_scratch shape != [M]",
            ));
        }
        Ok(())
    }

    /// Bytes of workspace `run` needs; always zero because the per-row scale
    /// buffer is supplied by the caller through `act_scale_scratch`.
    #[inline]
    pub fn workspace_size(&self) -> usize {
        0
    }

    /// Returns the kernel SKU assembled by `select`.
    #[inline]
    pub fn sku(&self) -> KernelSku {
        self.sku
    }

    /// Returns the precision guarantee stored in the plan's SKU.
    #[inline]
    pub fn precision_guarantee(&self) -> PrecisionGuarantee {
        self.sku.precision_guarantee
    }

    /// Enqueues the SmoothQuant linear op on `stream`.
    ///
    /// Steps:
    ///
    /// 1. `args` is validated with [`Self::can_implement`].
    /// 2. If the output is empty (`M == 0` or `N == 0`) nothing is launched.
    /// 3. If `K == 0` the reduction is empty, so `output` is filled with
    ///    zeros and the linear kernel is skipped.
    /// 4. Otherwise `act_scale_scratch` is filled with the descriptor's
    ///    `act_scale` (M elements) and the `quantized_linear_w8a8` kernel for
    ///    `TIn` is launched on the same stream.
    ///
    /// `_workspace` is unused; see [`Self::workspace_size`].
    ///
    /// # Errors
    ///
    /// Propagates the error from `can_implement`, any non-success status
    /// reported by the kernels (via `map_status`), or [`Error::Unsupported`]
    /// if `TIn` is not f32 / f64 (which `select` is expected to have rejected).
    pub fn run(
        &self,
        stream: &Stream,
        _workspace: Workspace<'_>,
        args: SmoothQuantLinearArgs<'_, TIn, TWQ>,
    ) -> Result<()> {
        self.can_implement(&args)?;

        let SmoothQuantLinearDescriptor {
            m, n, k, act_scale, ..
        } = self.desc;

        // Widen before multiplying so M * N cannot overflow i32.
        let out_elems = i64::from(m) * i64::from(n);
        if out_elems == 0 {
            // No output element exists, so there is nothing to compute or write.
            return Ok(());
        }

        let stream_ptr = stream.as_raw() as *mut c_void;
        let out_ptr = args.output.data.as_raw().0 as *mut c_void;

        // Dispatches to the fill kernel that matches `TIn`, writing `value`
        // into `count` consecutive elements starting at `dst`.
        let launch_fill = |count: i64,
                           dst: *mut c_void,
                           value: f32,
                           unsupported: &'static str|
         -> Result<()> {
            let status = match TIn::KIND {
                // SAFETY: `dst` is the raw pointer of a caller-supplied tensor
                // whose shape was validated by `can_implement`, and `count`
                // equals that tensor's element count. Assumes the tensor is
                // densely packed and valid on `stream`'s device — TODO confirm
                // that `TensorMut` guarantees this.
                ElementKind::F32 => unsafe {
                    baracuda_kernels_sys::baracuda_kernels_fill_f32_run(
                        count,
                        dst,
                        value,
                        core::ptr::null_mut(),
                        0,
                        stream_ptr,
                    )
                },
                // SAFETY: as for the f32 arm above.
                ElementKind::F64 => unsafe {
                    baracuda_kernels_sys::baracuda_kernels_fill_f64_run(
                        count,
                        dst,
                        f64::from(value),
                        core::ptr::null_mut(),
                        0,
                        stream_ptr,
                    )
                },
                _ => return Err(Error::Unsupported(unsupported)),
            };
            map_status(status)
        };

        if k == 0 {
            // Empty reduction with a non-empty output: every dot product is
            // an empty sum, so the result is all zeros. Write it explicitly
            // rather than returning success with `output` left untouched.
            return launch_fill(
                out_elems,
                out_ptr,
                0.0,
                "SmoothQuantLinearPlan::run reached unsupported TIn at \
                 K == 0 output zero-fill (select should have caught)",
            );
        }

        // Broadcast the scalar activation scale into the M-long scratch
        // buffer that the linear kernel reads as its activation-scale input.
        let scratch_ptr = args.act_scale_scratch.data.as_raw().0 as *mut c_void;
        launch_fill(
            i64::from(m),
            scratch_ptr,
            act_scale,
            "SmoothQuantLinearPlan::run reached unsupported TIn at \
             act-scale broadcast (select should have caught)",
        )?;

        let weight_ptr = args.weight_q.data.as_raw().0 as *const c_void;
        let act_q_ptr = args.act_q.data.as_raw().0 as *const c_void;
        let act_scale_const = scratch_ptr as *const c_void;
        let w_scale_ptr = args.weight_scale.data.as_raw().0 as *const c_void;

        // Issued on the same stream as the fill above, so the kernel is
        // expected to observe the freshly written scratch buffer.
        let ql_status = match TIn::KIND {
            // SAFETY: all pointers come from tensors whose shapes were checked
            // against (M, N, K) by `can_implement`. Assumes densely packed
            // storage valid on `stream`'s device — TODO confirm; strides are
            // not validated in this file.
            ElementKind::F32 => unsafe {
                baracuda_kernels_sys::baracuda_kernels_quantized_linear_w8a8_f32_run(
                    m,
                    n,
                    k,
                    weight_ptr,
                    act_q_ptr,
                    act_scale_const,
                    w_scale_ptr,
                    out_ptr,
                    core::ptr::null_mut(),
                    0,
                    stream_ptr,
                )
            },
            // SAFETY: as for the f32 arm above.
            ElementKind::F64 => unsafe {
                baracuda_kernels_sys::baracuda_kernels_quantized_linear_w8a8_f64_run(
                    m,
                    n,
                    k,
                    weight_ptr,
                    act_q_ptr,
                    act_scale_const,
                    w_scale_ptr,
                    out_ptr,
                    core::ptr::null_mut(),
                    0,
                    stream_ptr,
                )
            },
            _ => {
                return Err(Error::Unsupported(
                    "SmoothQuantLinearPlan::run reached unsupported TIn at \
                     quantized-linear pass (select should have caught)",
                ))
            }
        };
        map_status(ql_status)
    }
}