// ferrum_quantization/gptq.rs

//! GPTQ linear projection.
//!
//! GPTQ stores f16 weights as groups of int4 values; every group shares one
//! scale and one zero point. The on-disk layout used by AutoGPTQ /
//! gptq-for-llama is:
//!
//! - `qweight`: `[in_features / 8, out_features]` i32 — eight int4 values per
//!   i32, packed along the input dimension
//! - `qzeros`: `[in_features / group_size, out_features / 8]` i32 — eight int4
//!   zero points per i32, packed along the output dimension
//! - `scales`: `[in_features / group_size, out_features]` f16 (the caller
//!   converts these to f32 before calling `GptqLinear::from_raw`)
//! - `g_idx`: `[in_features]` i32 — per-row scale-group map (`desc_act` only)
//!
//! `GptqLinear<B>` holds a backend-specific `B::GptqStore` produced by
//! `Backend::load_gptq`. The store uses whatever format the backend needs
//! (CUDA: Marlin-repacked tiles; CPU: dequantized f32 weights; Metal:
//! unsupported).

use ferrum_kernels::backend::Backend;
use ferrum_kernels::Linear;
use ferrum_types::Result;

20pub struct GptqLinear<B: Backend> {
21    store: B::GptqStore,
22    bias: Option<B::Buffer>,
23    in_features: usize,
24    out_features: usize,
25}
26
27impl<B: Backend> GptqLinear<B> {
28    /// Build from raw host-side GPTQ tensors. The Backend repacks into
29    /// its preferred format once; inference uses the repacked store.
30    ///
31    /// `qweight`: `[k/8, n]` i32 (packed int4)
32    /// `scales`:  `[k/group_size, n]` f32 (converted from f16 by caller)
33    /// `qzeros`:  `[k/group_size, n/8]` i32
34    /// `g_idx`:   `[k]` i32 — optional, only used for desc_act=true
35    #[allow(clippy::too_many_arguments)]
36    pub fn from_raw(
37        qweight: &[i32],
38        scales: &[f32],
39        qzeros: &[i32],
40        g_idx: Option<&[i32]>,
41        bits: u32,
42        group_size: usize,
43        in_features: usize,
44        out_features: usize,
45    ) -> Result<Self> {
46        let store = B::load_gptq(
47            qweight,
48            scales,
49            qzeros,
50            g_idx,
51            bits,
52            group_size,
53            in_features,
54            out_features,
55        )?;
56        Ok(Self {
57            store,
58            bias: None,
59            in_features,
60            out_features,
61        })
62    }
63
64    /// Construct directly from a pre-built backend store (e.g. tests).
65    pub fn from_store(store: B::GptqStore, in_features: usize, out_features: usize) -> Self {
66        Self {
67            store,
68            bias: None,
69            in_features,
70            out_features,
71        }
72    }
73
74    /// Attach a bias vector (`[out_features]` f32 on host, uploaded to backend).
75    /// Qwen2.5 / Llama-with-bias variants require this.
76    pub fn with_bias(mut self, bias: &[f32]) -> Self {
77        debug_assert_eq!(
78            bias.len(),
79            self.out_features,
80            "GptqLinear bias length mismatch"
81        );
82        self.bias = Some(B::from_slice(bias));
83        self
84    }
85
86    pub fn store(&self) -> &B::GptqStore {
87        &self.store
88    }
89}
90
91impl<B: Backend> Linear<B> for GptqLinear<B> {
92    fn in_features(&self) -> usize {
93        self.in_features
94    }
95
96    fn out_features(&self) -> usize {
97        self.out_features
98    }
99
100    fn forward(&self, ctx: &mut B::Context, input: &B::Buffer, out: &mut B::Buffer, m: usize) {
101        B::gemm_gptq(ctx, input, &self.store, out, m)
102            .unwrap_or_else(|e| panic!("GPTQ forward failed: {e}"));
103        if let Some(bias) = &self.bias {
104            B::add_bias(ctx, out, bias, m, self.out_features);
105        }
106    }
107}