Skip to main content

ferrum_quantization/
gptq.rs

//! GPTQ linear projection — thin factory wrapper.
//!
//! Phase 3e/2: the actual kernel dispatch lives inside the boxed
//! `Linear<B>` returned by `B::load_gptq` (`CudaMarlinLinear` on
//! CUDA, `CpuGptqLinear` on CPU). This module only re-exposes the
//! historical constructor names so existing callers do not need to change.
//! It also provides `StackedExpertLinear`, a per-expert view built from a
//! shared `MarlinExpertStack`.
7
8use ferrum_kernels::backend::{Backend, BackendQuantMarlin};
9use ferrum_kernels::Linear;
10use ferrum_types::Result;
11use std::sync::Arc;
12
13/// GPTQ-format Linear projection, polymorphic over backend.
14///
15/// Holds a boxed backend-specific `Linear<B>` produced by `B::load_gptq`.
16/// `forward()` delegates straight through.
17pub struct GptqLinear<B: Backend + BackendQuantMarlin> {
18    inner: Box<dyn Linear<B> + Send + Sync>,
19}
20
21impl<B: Backend + BackendQuantMarlin> GptqLinear<B> {
22    /// Build from raw host-side GPTQ tensors. The Backend repacks into
23    /// its preferred format once (Marlin tiles on CUDA, dequant on CPU)
24    /// and returns a boxed Linear; inference uses the boxed forward.
25    ///
26    /// `qweight`: `[k/8, n]` i32 (packed int4)
27    /// `scales`:  `[k/group_size, n]` f32 (converted from f16 by caller)
28    /// `qzeros`:  `[k/group_size, n/8]` i32
29    /// `g_idx`:   `[k]` i32 — optional, only used for desc_act=true
30    /// `bias`:    `[n]` f32 — optional fused bias (Qwen2.5 attention)
31    #[allow(clippy::too_many_arguments)]
32    pub fn from_raw(
33        qweight: &[i32],
34        scales: &[f32],
35        qzeros: &[i32],
36        g_idx: Option<&[i32]>,
37        bias: Option<&[f32]>,
38        bits: u32,
39        group_size: usize,
40        in_features: usize,
41        out_features: usize,
42    ) -> Result<Self> {
43        let inner = B::load_gptq(
44            qweight,
45            scales,
46            qzeros,
47            g_idx,
48            bias,
49            bits,
50            group_size,
51            in_features,
52            out_features,
53        )?;
54        Ok(Self { inner })
55    }
56}
57
58impl<B: Backend + BackendQuantMarlin> Linear<B> for GptqLinear<B> {
59    fn in_features(&self) -> usize {
60        self.inner.in_features()
61    }
62
63    fn out_features(&self) -> usize {
64        self.inner.out_features()
65    }
66
67    fn forward(&self, ctx: &mut B::Context, input: &B::Buffer, out: &mut B::Buffer, m: usize) {
68        self.inner.forward(ctx, input, out, m);
69    }
70}
71
72/// View into a single column-slice of a shared stacked GPTQ store.
73///
74/// Phase 3e/2: backed by a `Box<dyn Linear<B>>` produced by
75/// `B::make_stacked_expert_linear` (CUDA: `CudaMarlinStackedExpertLinear`;
76/// CPU: `CpuGptqLinear` over a sliced row range). The store itself is
77/// `Arc<B::GptqStore>` so cloning a view is cheap; dropping all views
78/// drops the underlying store.
79pub struct StackedExpertLinear<B: Backend + BackendQuantMarlin> {
80    inner: Box<dyn Linear<B> + Send + Sync>,
81    /// Kept for in_features() reporting.
82    k: usize,
83    /// Kept for out_features() reporting.
84    expert_n: usize,
85}
86
87impl<B: Backend + BackendQuantMarlin> StackedExpertLinear<B> {
88    /// Phase C step 4b: takes the trait-object MarlinExpertStack
89    /// directly (was `Arc<B::GptqStore>` + `B::make_stacked_expert_linear`).
90    pub fn new(
91        stack: Arc<dyn ferrum_kernels::MarlinExpertStack<B>>,
92        expert_offset: usize,
93        expert_n: usize,
94    ) -> Result<Self> {
95        let k = stack.k();
96        let inner = stack.make_expert_linear(expert_offset, expert_n, None)?;
97        Ok(Self { inner, k, expert_n })
98    }
99
100    pub fn new_with_bias(
101        stack: Arc<dyn ferrum_kernels::MarlinExpertStack<B>>,
102        expert_offset: usize,
103        expert_n: usize,
104        bias: &[f32],
105    ) -> Result<Self> {
106        let k = stack.k();
107        let inner = stack.make_expert_linear(expert_offset, expert_n, Some(bias))?;
108        Ok(Self { inner, k, expert_n })
109    }
110}
111
112impl<B: Backend + BackendQuantMarlin> Linear<B> for StackedExpertLinear<B> {
113    fn in_features(&self) -> usize {
114        self.k
115    }
116
117    fn out_features(&self) -> usize {
118        self.expert_n
119    }
120
121    fn forward(&self, ctx: &mut B::Context, input: &B::Buffer, out: &mut B::Buffer, m: usize) {
122        self.inner.forward(ctx, input, out, m);
123    }
124}