Skip to main content

ferrum_models/loader/
gptq_loader.rs

1//! GPTQ quantized model loader.
2//!
3//! Loads GPTQ INT4 weights from HuggingFace safetensors and provides:
4//! - Dequantized FP16 tensors for candle prefill (via VarBuilder)
5//! - Packed INT4 weights for CUDA runner decode (via GpuQuantWeight)
6
7use ferrum_types::{FerrumError, Result};
8use std::collections::HashMap;
9use std::path::{Path, PathBuf};
10
/// GPTQ quantization config (from quantize_config.json).
#[derive(Debug, Clone, serde::Deserialize)]
pub struct QuantizeConfig {
    /// Bit width of the quantized weights (4 for GPTQ INT4).
    pub bits: usize,
    /// Quantization group size along the reduction dimension K.
    /// Non-positive means per-channel (see `effective_group_size`).
    pub group_size: i64,
    /// Symmetric quantization: when true, no zero-points are stored.
    #[serde(default)]
    pub sym: bool,
    /// Activation-order ("desc_act") flag from the config.
    /// NOTE(review): parsed but not referenced anywhere in this file — confirm
    /// whether act-order models are supported by the dequant path.
    #[serde(default)]
    pub desc_act: bool,
    /// Quantization method name as declared in the config (empty if absent).
    #[serde(default)]
    pub quant_method: String,
}
23
24impl QuantizeConfig {
25    /// Try to load quantize_config.json from model directory.
26    /// Returns None if file doesn't exist (non-quantized model).
27    pub fn from_model_dir(model_dir: &Path) -> Result<Option<Self>> {
28        let path = model_dir.join("quantize_config.json");
29        if !path.exists() {
30            // Also check config.json for embedded quantization_config
31            let config_path = model_dir.join("config.json");
32            if config_path.exists() {
33                if let Ok(content) = std::fs::read_to_string(&config_path) {
34                    if let Ok(config) = serde_json::from_str::<serde_json::Value>(&content) {
35                        if let Some(qc) = config.get("quantization_config") {
36                            if let Ok(qconfig) =
37                                serde_json::from_value::<QuantizeConfig>(qc.clone())
38                            {
39                                tracing::info!("GPTQ config found in config.json: {:?}", qconfig);
40                                return Ok(Some(qconfig));
41                            }
42                        }
43                    }
44                }
45            }
46            return Ok(None);
47        }
48        let content = std::fs::read_to_string(&path)
49            .map_err(|e| FerrumError::model(format!("read quantize_config.json: {e}")))?;
50        let config: QuantizeConfig = serde_json::from_str(&content)
51            .map_err(|e| FerrumError::model(format!("parse quantize_config.json: {e}")))?;
52        tracing::info!("GPTQ config: {:?}", config);
53        Ok(Some(config))
54    }
55
56    pub fn effective_group_size(&self, k: usize) -> usize {
57        if self.group_size <= 0 {
58            k // per-channel
59        } else {
60            self.group_size as usize
61        }
62    }
63}
64
/// Packed GPTQ weights for one linear layer (CPU-side, before GPU upload).
#[derive(Debug)]
pub struct GptqLayerWeights {
    /// Packed INT4 weights: [K/8, N] as int32 — each i32 holds 8 consecutive
    /// K-values for one output column, lowest nibble first (see
    /// `dequantize_cpu`'s `packed >> (i * 4)` unpack order).
    pub qweight: Vec<i32>,
    /// Per-group scales: [K/group_size, N] as f16 bytes
    pub scales: Vec<half::f16>,
    /// Per-group zero-points: [K/group_size, N/8] as int32 (None for symmetric)
    pub qzeros: Option<Vec<i32>>,
    /// Input (reduction) dimension K.
    pub k: usize,
    /// Output dimension N.
    pub n: usize,
    /// Resolved group size along K (always > 0 here; per-channel is
    /// represented by group_size == k).
    pub group_size: usize,
    /// True for symmetric quantization — zero-point is implicitly 8 and
    /// `qzeros` is None.
    pub symmetric: bool,
}
79
80impl GptqLayerWeights {
81    /// Dequantize to FP16 on CPU. Returns [K, N] in row-major order.
82    pub fn dequantize_cpu(&self) -> Vec<half::f16> {
83        let mut output = vec![half::f16::ZERO; self.k * self.n];
84        let packed_rows = self.k / 8;
85
86        for packed_row in 0..packed_rows {
87            for col in 0..self.n {
88                let packed = self.qweight[packed_row * self.n + col];
89                let base_k = packed_row * 8;
90                let group = base_k / self.group_size;
91                let scale = self.scales[group * self.n + col].to_f32();
92
93                let zero = if self.symmetric {
94                    8
95                } else if let Some(ref qz) = self.qzeros {
96                    let zp_packed = qz[group * (self.n / 8) + col / 8];
97                    let zp_shift = (col % 8) * 4;
98                    (zp_packed >> zp_shift) & 0xF
99                } else {
100                    8
101                };
102
103                for i in 0..8 {
104                    let val = (packed >> (i * 4)) & 0xF;
105                    let dequantized = (val - zero) as f32 * scale;
106                    output[(base_k + i as usize) * self.n + col] = half::f16::from_f32(dequantized);
107                }
108            }
109        }
110        output
111    }
112}
113
114/// Load GPTQ packed weights from safetensors files.
115///
116/// Returns a map of layer_prefix → GptqLayerWeights.
117/// Layer prefixes are like "model.layers.0.self_attn.q_proj".
118pub fn load_gptq_weights(
119    model_dir: &Path,
120    qconfig: &QuantizeConfig,
121) -> Result<HashMap<String, GptqLayerWeights>> {
122    use safetensors::SafeTensors;
123
124    let safetensor_files = find_safetensor_files(model_dir)?;
125    if safetensor_files.is_empty() {
126        return Err(FerrumError::model("No safetensor files found"));
127    }
128
129    let mut result = HashMap::new();
130
131    // Collect all qweight tensor names to find layer prefixes
132    for path in &safetensor_files {
133        let data = std::fs::read(path)
134            .map_err(|e| FerrumError::model(format!("read {}: {e}", path.display())))?;
135        let st = SafeTensors::deserialize(&data)
136            .map_err(|e| FerrumError::model(format!("parse {}: {e}", path.display())))?;
137
138        for (name, _) in st.tensors() {
139            if !name.ends_with(".qweight") {
140                continue;
141            }
142            let prefix = name.strip_suffix(".qweight").unwrap().to_string();
143
144            // Load qweight
145            let qw_tensor = st
146                .tensor(&format!("{prefix}.qweight"))
147                .map_err(|e| FerrumError::model(format!("{prefix}.qweight: {e}")))?;
148            let qweight: Vec<i32> = bytemuck::cast_slice(qw_tensor.data()).to_vec();
149            let qw_shape = qw_tensor.shape();
150            let packed_k = qw_shape[0]; // K/8
151            let n = qw_shape[1];
152            let k = packed_k * 8;
153
154            // Load scales
155            let sc_tensor = st
156                .tensor(&format!("{prefix}.scales"))
157                .map_err(|e| FerrumError::model(format!("{prefix}.scales: {e}")))?;
158            let scales: Vec<half::f16> = bytemuck::cast_slice(sc_tensor.data()).to_vec();
159
160            // Load qzeros (optional for symmetric)
161            let qzeros = if !qconfig.sym {
162                let qz_tensor = st
163                    .tensor(&format!("{prefix}.qzeros"))
164                    .map_err(|e| FerrumError::model(format!("{prefix}.qzeros: {e}")))?;
165                Some(bytemuck::cast_slice(qz_tensor.data()).to_vec())
166            } else {
167                None
168            };
169
170            let gs = qconfig.effective_group_size(k);
171
172            tracing::debug!(
173                "GPTQ layer: {prefix} K={k} N={n} group_size={gs} sym={}",
174                qconfig.sym
175            );
176
177            result.insert(
178                prefix,
179                GptqLayerWeights {
180                    qweight,
181                    scales,
182                    qzeros,
183                    k,
184                    n,
185                    group_size: gs,
186                    symmetric: qconfig.sym,
187                },
188            );
189        }
190    }
191
192    tracing::info!("Loaded {} GPTQ quantized layers (raw)", result.len());
193
194    // Fuse separate q/k/v → qkv_proj and gate/up → gate_up_proj.
195    // GPTQ stores separate projections, but the CUDA runner expects fused weights.
196    fuse_qkv_and_gate_up(&mut result);
197
198    tracing::info!(
199        "After fusion: {} GPTQ layers (includes fused qkv_proj, gate_up_proj)",
200        result.len()
201    );
202    Ok(result)
203}
204
205/// Fuse separate q/k/v projections into qkv_proj, and gate/up into gate_up_proj.
206/// GPTQ packing is [K/8, N] — fusing along N dimension means concatenating columns.
207fn fuse_qkv_and_gate_up(weights: &mut HashMap<String, GptqLayerWeights>) {
208    let prefixes: Vec<String> = weights
209        .keys()
210        .filter(|k| k.ends_with(".self_attn.q_proj"))
211        .map(|k| k.strip_suffix(".self_attn.q_proj").unwrap().to_string())
212        .collect();
213
214    for layer_prefix in &prefixes {
215        // Fuse q + k + v → qkv_proj
216        let q_key = format!("{layer_prefix}.self_attn.q_proj");
217        let k_key = format!("{layer_prefix}.self_attn.k_proj");
218        let v_key = format!("{layer_prefix}.self_attn.v_proj");
219        if let (Some(q), Some(k), Some(v)) = (
220            weights.get(&q_key),
221            weights.get(&k_key),
222            weights.get(&v_key),
223        ) {
224            if q.k == k.k && q.k == v.k {
225                let fused = fuse_columns(&[q, k, v]);
226                let fused_key = format!("{layer_prefix}.self_attn.qkv_proj");
227                tracing::info!(
228                    "Fused {q_key}+{k_key}+{v_key} → {fused_key} K={} N={}",
229                    fused.k,
230                    fused.n
231                );
232                weights.insert(fused_key, fused);
233            }
234        }
235
236        // Fuse gate + up → gate_up_proj
237        let gate_key = format!("{layer_prefix}.mlp.gate_proj");
238        let up_key = format!("{layer_prefix}.mlp.up_proj");
239        if let (Some(gate), Some(up)) = (weights.get(&gate_key), weights.get(&up_key)) {
240            if gate.k == up.k {
241                let fused = fuse_columns(&[gate, up]);
242                let fused_key = format!("{layer_prefix}.mlp.gate_up_proj");
243                tracing::info!(
244                    "Fused {gate_key}+{up_key} → {fused_key} K={} N={}",
245                    fused.k,
246                    fused.n
247                );
248                weights.insert(fused_key, fused);
249            }
250        }
251    }
252}
253
254/// Fuse multiple GPTQ weights along the N (output) dimension.
255/// All weights must have the same K. Result has N = sum(w.n for w in weights).
256///
257/// qweight layout: [K/8, N] — concatenate columns.
258/// scales layout: [K/gs, N] — concatenate columns.
259/// qzeros layout: [K/gs, N/8] — concatenate columns (trickier due to packing).
260fn fuse_columns(parts: &[&GptqLayerWeights]) -> GptqLayerWeights {
261    let k = parts[0].k;
262    let gs = parts[0].group_size;
263    let sym = parts[0].symmetric;
264    let total_n: usize = parts.iter().map(|p| p.n).sum();
265    let packed_k = k / 8;
266    let num_groups = k / gs;
267
268    // Fuse qweight [K/8, N] — row by row, concatenate columns
269    let mut qweight = vec![0i32; packed_k * total_n];
270    let mut col_offset = 0;
271    for part in parts {
272        for row in 0..packed_k {
273            for col in 0..part.n {
274                qweight[row * total_n + col_offset + col] = part.qweight[row * part.n + col];
275            }
276        }
277        col_offset += part.n;
278    }
279
280    // Fuse scales [K/gs, N]
281    let mut scales = vec![half::f16::ZERO; num_groups * total_n];
282    col_offset = 0;
283    for part in parts {
284        for row in 0..num_groups {
285            for col in 0..part.n {
286                scales[row * total_n + col_offset + col] = part.scales[row * part.n + col];
287            }
288        }
289        col_offset += part.n;
290    }
291
292    // Fuse qzeros [K/gs, N/8] — need to unpack, concatenate, repack
293    let qzeros = if !sym {
294        let mut all_zeros = vec![0u8; num_groups * total_n];
295        let mut col_off = 0usize;
296        for part in parts {
297            if let Some(ref qz) = part.qzeros {
298                let part_n8 = part.n / 8;
299                for row in 0..num_groups {
300                    for col in 0..part.n {
301                        let packed = qz[row * part_n8 + col / 8];
302                        let val = ((packed >> ((col % 8) * 4)) & 0xF) as u8;
303                        all_zeros[row * total_n + col_off + col] = val;
304                    }
305                }
306            }
307            col_off += part.n;
308        }
309        // Repack
310        let total_n8 = total_n / 8;
311        let mut packed_zeros = vec![0i32; num_groups * total_n8];
312        for row in 0..num_groups {
313            for col in 0..total_n {
314                let val = all_zeros[row * total_n + col] as i32;
315                packed_zeros[row * total_n8 + col / 8] |= val << ((col % 8) * 4);
316            }
317        }
318        Some(packed_zeros)
319    } else {
320        None
321    };
322
323    GptqLayerWeights {
324        qweight,
325        scales,
326        qzeros,
327        k,
328        n: total_n,
329        group_size: gs,
330        symmetric: sym,
331    }
332}
333
334fn find_safetensor_files(model_dir: &Path) -> Result<Vec<PathBuf>> {
335    let mut files = Vec::new();
336
337    // Check single file
338    let single = model_dir.join("model.safetensors");
339    if single.exists() {
340        files.push(single);
341        return Ok(files);
342    }
343
344    // Check sharded index
345    let index_path = model_dir.join("model.safetensors.index.json");
346    if index_path.exists() {
347        let content = std::fs::read_to_string(&index_path)
348            .map_err(|e| FerrumError::model(format!("read index: {e}")))?;
349        let index: serde_json::Value = serde_json::from_str(&content)
350            .map_err(|e| FerrumError::model(format!("parse index: {e}")))?;
351        if let Some(weight_map) = index.get("weight_map").and_then(|v| v.as_object()) {
352            let mut seen = std::collections::HashSet::new();
353            for filename in weight_map.values().filter_map(|v| v.as_str()) {
354                if seen.insert(filename.to_string()) {
355                    let path = model_dir.join(filename);
356                    if path.exists() {
357                        files.push(path);
358                    }
359                }
360            }
361        }
362    }
363
364    Ok(files)
365}