use ferrum_kernels::backend::Backend;
use ferrum_kernels::Linear;
use ferrum_types::Result;
/// Linear layer whose weights live in a backend-specific GPTQ store (as produced by
/// `B::load_gptq`), optionally followed by a bias addition.
pub struct GptqLinear<B: Backend> {
    /// Input feature count this layer was constructed with.
    in_features: usize,
    /// Output feature count; also the length a bias passed to `with_bias` must have.
    out_features: usize,
    /// Backend-owned weight storage consumed by `B::gemm_gptq` in `forward`.
    store: B::GptqStore,
    /// Optional bias buffer applied by `B::add_bias` after the GPTQ matmul.
    bias: Option<B::Buffer>,
}
impl<B: Backend> GptqLinear<B> {
    /// Builds a layer by handing raw GPTQ tensors to [`Backend::load_gptq`].
    ///
    /// Every argument is forwarded unchanged to `B::load_gptq`; any layout or
    /// consistency checks are whatever that backend function enforces. The
    /// returned layer has no bias; attach one with [`GptqLinear::with_bias`].
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `B::load_gptq`.
    // The parameter list mirrors `Backend::load_gptq` one-to-one; bundling it into
    // a struct here would only add an adapter layer, so the lint is silenced.
    #[allow(clippy::too_many_arguments)]
    pub fn from_raw(
        qweight: &[i32],
        scales: &[f32],
        qzeros: &[i32],
        g_idx: Option<&[i32]>,
        bits: u32,
        group_size: usize,
        in_features: usize,
        out_features: usize,
    ) -> Result<Self> {
        let store = B::load_gptq(
            qweight,
            scales,
            qzeros,
            g_idx,
            bits,
            group_size,
            in_features,
            out_features,
        )?;
        // Single construction path so both constructors stay in sync.
        Ok(Self::from_store(store, in_features, out_features))
    }

    /// Wraps an already-loaded GPTQ store.
    ///
    /// The caller is responsible for `in_features`/`out_features` matching the
    /// dimensions the store was loaded with; nothing here checks that.
    /// The returned layer has no bias.
    pub fn from_store(store: B::GptqStore, in_features: usize, out_features: usize) -> Self {
        Self {
            store,
            bias: None,
            in_features,
            out_features,
        }
    }

    /// Attaches a bias, uploaded via `B::from_slice`, that `forward` adds after the matmul.
    ///
    /// Replaces any previously attached bias.
    ///
    /// # Panics
    ///
    /// Panics if `bias.len() != self.out_features`. The check is unconditional
    /// (not debug-only) because `forward` passes `out_features` to `B::add_bias`
    /// as the bias width. In a release build, a short bias could otherwise be
    /// indexed past its end by the backend.
    pub fn with_bias(mut self, bias: &[f32]) -> Self {
        assert_eq!(
            bias.len(),
            self.out_features,
            "GptqLinear bias length mismatch"
        );
        self.bias = Some(B::from_slice(bias));
        self
    }

    /// Borrows the underlying backend GPTQ store.
    pub fn store(&self) -> &B::GptqStore {
        &self.store
    }
}
impl<B: Backend> Linear<B> for GptqLinear<B> {
    /// Input feature count this layer was constructed with.
    fn in_features(&self) -> usize {
        self.in_features
    }

    /// Output feature count this layer was constructed with.
    fn out_features(&self) -> usize {
        self.out_features
    }

    /// Runs `B::gemm_gptq` over `m` rows of `input` into `out`, then adds the bias if one is attached.
    ///
    /// # Panics
    ///
    /// Panics if `B::gemm_gptq` returns an error. `Linear::forward` has no
    /// error channel, so a backend failure is treated as fatal.
    fn forward(&self, ctx: &mut B::Context, input: &B::Buffer, out: &mut B::Buffer, m: usize) {
        if let Err(err) = B::gemm_gptq(ctx, input, &self.store, out, m) {
            panic!("GPTQ forward failed: {err}");
        }

        if let Some(bias_buf) = self.bias.as_ref() {
            B::add_bias(ctx, out, bias_buf, m, self.out_features);
        }
    }
}