hanzo_quant/gptq/
gptq_cpu.rs1use crate::{
2 has_missing_required_tensors, make_dummy_or_error, IsqType, QuantMethod, QuantMethodConfig,
3 QuantizeOntoGuard, QuantizedConfig, QuantizedSerde, ShardedVarBuilder,
4};
5use hanzo_ml::{DType, Device, Result, Tensor};
6use std::sync::{atomic::AtomicUsize, Arc};
7
/// Placeholder GPTQ/AWQ layer used when the crate is built without CUDA support.
///
/// It carries no data: its `QuantMethod::new` implementation rejects GPTQ/AWQ
/// configurations with "GPTQ is only supported on CUDA.".
#[derive(Debug)]
pub struct GptqLayer;

11impl QuantMethod for GptqLayer {
12 fn new(method: QuantMethodConfig) -> Result<Self>
13 where
14 Self: Sized,
15 {
16 match method {
17 QuantMethodConfig::GptqAwq { .. } => {
18 hanzo_ml::bail!("GPTQ is only supported on CUDA.")
19 }
20 QuantMethodConfig::Gguf { .. }
21 | QuantMethodConfig::Unquantized(_)
22 | QuantMethodConfig::Hqq { .. }
23 | QuantMethodConfig::Dummy
24 | QuantMethodConfig::FP8 { .. }
25 | QuantMethodConfig::Bnb { .. }
26 | QuantMethodConfig::BlockwiseFP8 { .. }
27 | QuantMethodConfig::PerTensorFP8 { .. }
28 | QuantMethodConfig::Afq { .. }
29 | QuantMethodConfig::MXFP4 { .. } => {
30 unreachable!()
31 }
32 }
33 }
34
35 fn dequantize_w(&self) -> Result<Tensor> {
36 todo!()
37 }
38
39 fn forward_raw(&self, _a: &Tensor) -> Result<Tensor> {
40 todo!()
41 }
42
43 fn quantized_act_type(&self) -> Option<DType> {
44 todo!()
45 }
46
47 fn add_delta_w(&self, _delta: &Tensor) -> Result<Arc<dyn QuantMethod>> {
48 todo!()
49 }
50
51 fn dtype_and_device(&self) -> (DType, hanzo_ml::Device) {
52 todo!()
53 }
54
55 fn apply_isq(
56 self: Arc<Self>,
57 _dtype: Option<IsqType>,
58 _device: Device,
59 _n_quantized: &AtomicUsize,
60 _imatrix_weight: Option<Vec<f32>>,
61 _guard: QuantizeOntoGuard,
62 ) -> Result<Arc<dyn QuantMethod>> {
63 todo!()
64 }
65}
66
67impl QuantizedSerde for GptqLayer {
68 fn name(&self) -> &'static str {
69 "gptq"
70 }
71}
72
/// Number of `bits`-wide quantized values that fit in one 32-bit packed word.
///
/// Uses integer division, so any remainder is discarded. `$bits` is an `expr`
/// fragment, so compound arguments such as `a + b` stay grouped as `32 / (a + b)`.
macro_rules! pack_factor {
    ($bits:expr) => (
        32 / $bits
    );
}

79pub fn gptq_linear(
80 in_dim: usize,
81 out_dim: usize,
82 config: &QuantizedConfig,
83 vb: ShardedVarBuilder,
84) -> Result<Arc<dyn QuantMethod>> {
85 let QuantizedConfig::GptqAwq {
86 bits,
87 group_size,
88 checkpoint_format: _,
89 is_awq,
90 } = config
91 else {
92 hanzo_ml::bail!("Unexpected quantization config.")
93 };
94
95 let is_awq = *is_awq;
96 if vb.contains_tensor("weight") {
98 return crate::linear_b(in_dim, out_dim, false, &None, vb);
99 }
100
101 let mut required = vec!["qweight", "qzeros", "scales"];
102 if !is_awq {
103 required.push("g_idx");
104 }
105 if has_missing_required_tensors(&vb, &required) {
106 return make_dummy_or_error("gptq_awq_linear", &vb, &required);
107 }
108
109 let qw_shape = if !is_awq {
110 (in_dim / pack_factor!(bits), out_dim)
112 } else {
113 (in_dim, out_dim / pack_factor!(bits))
115 };
116
117 let qweight = vb.get_with_hints_dtype(qw_shape, "qweight", Default::default(), DType::I32)?;
118 let scale_and_zero_size = in_dim / group_size;
119 let qzeros = vb.get_with_hints_dtype(
120 (scale_and_zero_size, out_dim / pack_factor!(bits)),
121 "qzeros",
122 Default::default(),
123 DType::I32,
124 )?;
125 let g_idx = if is_awq {
126 None
127 } else {
128 Some(vb.get_with_hints_dtype((in_dim,), "g_idx", Default::default(), DType::I32)?)
129 };
130 let scales = vb.get_with_hints_dtype(
131 (scale_and_zero_size, out_dim),
132 "scales",
133 Default::default(),
134 DType::F16,
135 )?;
136 let bias = if vb.contains_tensor("bias") {
137 Some(vb.get_with_hints_dtype((out_dim,), "bias", Default::default(), DType::F16)?)
138 } else {
139 None
140 };
141
142 let config = QuantMethodConfig::GptqAwq {
143 bits: *bits as i32,
144 use_exllama: false,
145 q_weight: qweight,
146 qzeros: Some(qzeros),
147 scales,
148 g_idx,
149 bias,
150 workspace: None,
151 is_marlin: false,
152 is_awq,
153 };
154 Ok(Arc::new(GptqLayer::new(config)?))
155}