ferrum_quantization/gptq.rs
1//! GPTQ linear projection.
2//!
3//! GPTQ packs f16 weights as int4 groups, each group sharing a scale +
4//! zero_point. On-disk layout from AutoGPTQ / gptq-for-llama:
5//!
6//! qweight: `[in_features / 8, out_features]` i32 — 8 int4s per int32
7//! qzeros: `[in_features / group_size, out_features / 8]` i32
8//! scales: `[in_features / group_size, out_features]` f16
9//! g_idx: `[in_features]` i32 — per-row scale-group map (desc_act only)
10//!
11//! `GptqLinear<B>` stores a backend-specific `B::GptqStore` produced by
12//! `Backend::load_gptq`. The store holds whatever format the backend
13//! needs (CUDA: Marlin-repacked tiles; CPU: dequantized f32 weights;
14//! Metal: unsupported).
15
16use ferrum_kernels::backend::Backend;
17use ferrum_kernels::Linear;
18use ferrum_types::Result;
19
20pub struct GptqLinear<B: Backend> {
21 store: B::GptqStore,
22 bias: Option<B::Buffer>,
23 in_features: usize,
24 out_features: usize,
25}
26
27impl<B: Backend> GptqLinear<B> {
28 /// Build from raw host-side GPTQ tensors. The Backend repacks into
29 /// its preferred format once; inference uses the repacked store.
30 ///
31 /// `qweight`: `[k/8, n]` i32 (packed int4)
32 /// `scales`: `[k/group_size, n]` f32 (converted from f16 by caller)
33 /// `qzeros`: `[k/group_size, n/8]` i32
34 /// `g_idx`: `[k]` i32 — optional, only used for desc_act=true
35 #[allow(clippy::too_many_arguments)]
36 pub fn from_raw(
37 qweight: &[i32],
38 scales: &[f32],
39 qzeros: &[i32],
40 g_idx: Option<&[i32]>,
41 bits: u32,
42 group_size: usize,
43 in_features: usize,
44 out_features: usize,
45 ) -> Result<Self> {
46 let store = B::load_gptq(
47 qweight,
48 scales,
49 qzeros,
50 g_idx,
51 bits,
52 group_size,
53 in_features,
54 out_features,
55 )?;
56 Ok(Self {
57 store,
58 bias: None,
59 in_features,
60 out_features,
61 })
62 }
63
64 /// Construct directly from a pre-built backend store (e.g. tests).
65 pub fn from_store(store: B::GptqStore, in_features: usize, out_features: usize) -> Self {
66 Self {
67 store,
68 bias: None,
69 in_features,
70 out_features,
71 }
72 }
73
74 /// Attach a bias vector (`[out_features]` f32 on host, uploaded to backend).
75 /// Qwen2.5 / Llama-with-bias variants require this.
76 pub fn with_bias(mut self, bias: &[f32]) -> Self {
77 debug_assert_eq!(
78 bias.len(),
79 self.out_features,
80 "GptqLinear bias length mismatch"
81 );
82 self.bias = Some(B::from_slice(bias));
83 self
84 }
85
86 pub fn store(&self) -> &B::GptqStore {
87 &self.store
88 }
89}
90
91impl<B: Backend> Linear<B> for GptqLinear<B> {
92 fn in_features(&self) -> usize {
93 self.in_features
94 }
95
96 fn out_features(&self) -> usize {
97 self.out_features
98 }
99
100 fn forward(&self, ctx: &mut B::Context, input: &B::Buffer, out: &mut B::Buffer, m: usize) {
101 B::gemm_gptq(ctx, input, &self.store, out, m)
102 .unwrap_or_else(|e| panic!("GPTQ forward failed: {e}"));
103 if let Some(bias) = &self.bias {
104 B::add_bias(ctx, out, bias, m, self.out_features);
105 }
106 }
107}