Skip to main content

hanzo_quant/gptq/
gptq_cpu.rs

1use crate::{
2    has_missing_required_tensors, make_dummy_or_error, IsqType, QuantMethod, QuantMethodConfig,
3    QuantizeOntoGuard, QuantizedConfig, QuantizedSerde, ShardedVarBuilder,
4};
5use hanzo_ml::{DType, Device, Result, Tensor};
6use std::sync::{atomic::AtomicUsize, Arc};
7
/// GPTQ/AWQ quantized layer type for this (`gptq_cpu.rs`) build.
///
/// The type carries no data: in this file its constructor never returns a
/// usable instance and the remaining trait methods are unimplemented stubs.
#[derive(Debug)]
pub struct GptqLayer;
10
11impl QuantMethod for GptqLayer {
12    fn new(method: QuantMethodConfig) -> Result<Self>
13    where
14        Self: Sized,
15    {
16        match method {
17            QuantMethodConfig::GptqAwq { .. } => {
18                hanzo_ml::bail!("GPTQ is only supported on CUDA.")
19            }
20            QuantMethodConfig::Gguf { .. }
21            | QuantMethodConfig::Unquantized(_)
22            | QuantMethodConfig::Hqq { .. }
23            | QuantMethodConfig::Dummy
24            | QuantMethodConfig::FP8 { .. }
25            | QuantMethodConfig::Bnb { .. }
26            | QuantMethodConfig::BlockwiseFP8 { .. }
27            | QuantMethodConfig::PerTensorFP8 { .. }
28            | QuantMethodConfig::Afq { .. }
29            | QuantMethodConfig::MXFP4 { .. } => {
30                unreachable!()
31            }
32        }
33    }
34
35    fn dequantize_w(&self) -> Result<Tensor> {
36        todo!()
37    }
38
39    fn forward_raw(&self, _a: &Tensor) -> Result<Tensor> {
40        todo!()
41    }
42
43    fn quantized_act_type(&self) -> Option<DType> {
44        todo!()
45    }
46
47    fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
48        todo!()
49    }
50
51    fn dtype_and_device(&self) -> (DType, hanzo_ml::Device) {
52        todo!()
53    }
54
55    fn apply_isq(
56        self: Arc<Self>,
57        _dtype: Option<IsqType>,
58        _device: Device,
59        _n_quantized: &AtomicUsize,
60        _imatrix_weight: Option<Vec<f32>>,
61        _guard: QuantizeOntoGuard,
62    ) -> Result<Arc<dyn QuantMethod>> {
63        todo!()
64    }
65}
66
67impl QuantizedSerde for GptqLayer {
68    fn name(&self) -> &'static str {
69        "gptq"
70    }
71}
72
/// Expands to `32 / <width>`: the number of `<width>`-bit values that fit in
/// one 32-bit word (integer division).
///
/// The argument may be any expression that `32` can be divided by; a zero
/// width makes the expansion panic with a division by zero.
macro_rules! pack_factor {
    ($width:expr) => {
        32 / $width
    };
}
78
79pub fn gptq_linear(
80    in_dim: usize,
81    out_dim: usize,
82    config: &QuantizedConfig,
83    vb: ShardedVarBuilder,
84) -> Result<Arc<dyn QuantMethod>> {
85    let QuantizedConfig::GptqAwq {
86        bits,
87        group_size,
88        checkpoint_format: _,
89        is_awq,
90    } = config
91    else {
92        hanzo_ml::bail!("Unexpected quantization config.")
93    };
94
95    let is_awq = *is_awq;
96    // Handle the case where we actually have an unquantized
97    if vb.contains_tensor("weight") {
98        return crate::linear_b(in_dim, out_dim, false, &None, vb);
99    }
100
101    let mut required = vec!["qweight", "qzeros", "scales"];
102    if !is_awq {
103        required.push("g_idx");
104    }
105    if has_missing_required_tensors(&vb, &required) {
106        return make_dummy_or_error("gptq_awq_linear", &vb, &required);
107    }
108
109    let qw_shape = if !is_awq {
110        //quantized gptq (k/pack_factor, n) format
111        (in_dim / pack_factor!(bits), out_dim)
112    } else {
113        //quantized awq (k, n/pack_factor) format
114        (in_dim, out_dim / pack_factor!(bits))
115    };
116
117    let qweight = vb.get_with_hints_dtype(qw_shape, "qweight", Default::default(), DType::I32)?;
118    let scale_and_zero_size = in_dim / group_size;
119    let qzeros = vb.get_with_hints_dtype(
120        (scale_and_zero_size, out_dim / pack_factor!(bits)),
121        "qzeros",
122        Default::default(),
123        DType::I32,
124    )?;
125    let g_idx = if is_awq {
126        None
127    } else {
128        Some(vb.get_with_hints_dtype((in_dim,), "g_idx", Default::default(), DType::I32)?)
129    };
130    let scales = vb.get_with_hints_dtype(
131        (scale_and_zero_size, out_dim),
132        "scales",
133        Default::default(),
134        DType::F16,
135    )?;
136    let bias = if vb.contains_tensor("bias") {
137        Some(vb.get_with_hints_dtype((out_dim,), "bias", Default::default(), DType::F16)?)
138    } else {
139        None
140    };
141
142    let config = QuantMethodConfig::GptqAwq {
143        bits: *bits as i32,
144        use_exllama: false,
145        q_weight: qweight,
146        qzeros: Some(qzeros),
147        scales,
148        g_idx,
149        bias,
150        workspace: None,
151        is_marlin: false,
152        is_awq,
153    };
154    Ok(Arc::new(GptqLayer::new(config)?))
155}