use crate::error::Result;
use crate::quant::traits::QuantMatmulOps;
use numr::dtype::DType;
use numr::ops::{BinaryOps, TypeConversionOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
use super::method::DecomposedQuantMethod;
use super::tensor::DecomposedQuantTensor;
/// A linear layer whose weight is stored in decomposed int4-quantized form
/// (packed qweight + scales + zero points), with an optional dense bias.
pub struct DecomposedQuantLinear<R: Runtime> {
    // Quantized weight bundle; its `method` field selects AWQ vs GPTQ dispatch in `forward`.
    weight: DecomposedQuantTensor<R>,
    // Optional bias, added to the GEMM output when present.
    bias: Option<Tensor<R>>,
}
impl<R: Runtime> DecomposedQuantLinear<R> {
pub fn new(weight: DecomposedQuantTensor<R>, bias: Option<Tensor<R>>) -> Self {
Self { weight, bias }
}
pub fn forward<C>(&self, client: &C, input: &Tensor<R>) -> Result<Tensor<R>>
where
C: QuantMatmulOps<R> + BinaryOps<R> + RuntimeClient<R> + TypeConversionOps<R>,
R: Runtime<DType = DType>,
{
let input_dtype = input.dtype();
let input_f32 = if input_dtype != DType::F32 {
client
.cast(input, DType::F32)
.map_err(crate::error::Error::Numr)?
} else {
input.clone()
};
let output =
match self.weight.method {
DecomposedQuantMethod::Awq { group_size } => client.int4_gemm(
&input_f32,
&self.weight.qweight,
&self.weight.scales,
&self.weight.qzeros,
group_size,
)?,
DecomposedQuantMethod::Gptq { .. } => {
let g_idx = self.weight.g_idx.as_ref().ok_or_else(|| {
crate::error::Error::ModelError {
reason: "GPTQ requires g_idx tensor".into(),
}
})?;
client.int4_gemm_gptq(
&input_f32,
&self.weight.qweight,
&self.weight.qzeros,
&self.weight.scales,
g_idx,
)?
}
};
let output = match &self.bias {
Some(bias) => client
.add(&output, bias)
.map_err(crate::error::Error::Numr)?,
None => output,
};
if input_dtype != DType::F32 {
client
.cast(&output, input_dtype)
.map_err(crate::error::Error::Numr)
} else {
Ok(output)
}
}
pub fn weight(&self) -> &DecomposedQuantTensor<R> {
&self.weight
}
pub fn bias(&self) -> Option<&Tensor<R>> {
self.bias.as_ref()
}
}