Skip to main content

ferrum_kernels/quant_linear/
cpu_dequant.rs

//! `Linear<CpuBackend>` impl for GPTQ weights, dequantized at load time.
//!
//! Phase 3e/2: replaces the old `BackendQuantMarlin::gemm_gptq` impl on
//! `CpuBackend`. The kernel call (`CpuBackend::gemm` on the dequantized
//! weights) now lives inside `CpuGptqLinear::forward` instead of the trait
//! method body.
7
8use crate::backend::cpu::CpuBackend;
9use crate::Linear;
10
/// GPTQ linear layer for the CPU backend, stored as plain fp32.
///
/// The quantized GPTQ weights are expanded to `f32` a single time, in
/// `BackendQuantMarlin::load_gptq`. Running the layer is therefore an ordinary
/// f32 GEMM dispatched through `CpuBackend::gemm`, optionally followed by a
/// bias add.
pub struct CpuGptqLinear {
    /// Input dimension (`k`): the number of columns in `weight_f32`.
    pub in_features: usize,
    /// Output dimension (`n`): the number of rows in `weight_f32`.
    pub out_features: usize,
    /// Dequantized weights, row-major, shape `[out_features, in_features]`.
    pub weight_f32: Vec<f32>,
    /// Optional bias with shape `[out_features]`.
    pub bias: Option<Vec<f32>>,
}
22
23impl Linear<CpuBackend> for CpuGptqLinear {
24    fn in_features(&self) -> usize {
25        self.in_features
26    }
27
28    fn out_features(&self) -> usize {
29        self.out_features
30    }
31
32    fn forward(
33        &self,
34        ctx: &mut <CpuBackend as crate::backend::Backend>::Context,
35        input: &<CpuBackend as crate::backend::Backend>::Buffer,
36        out: &mut <CpuBackend as crate::backend::Backend>::Buffer,
37        m: usize,
38    ) {
39        // out[m, n] = a[m, k] @ w[n, k]^T — same contract as `B::gemm`.
40        <CpuBackend as crate::backend::Backend>::gemm(
41            ctx,
42            input,
43            &self.weight_f32,
44            out,
45            m,
46            self.out_features,
47            self.in_features,
48        );
49        if let Some(bias) = &self.bias {
50            <CpuBackend as crate::backend::Backend>::add_bias(ctx, out, bias, m, self.out_features);
51        }
52    }
53}