// ferrum_models/loader/gptq_loader.rs
use ferrum_types::{FerrumError, Result};
use std::collections::HashMap;
use std::path::{Path, PathBuf};

/// GPTQ quantization parameters, deserialized from `quantize_config.json`
/// or from the `quantization_config` object embedded in `config.json`.
#[derive(Debug, Clone, serde::Deserialize)]
pub struct QuantizeConfig {
    // Quantization bit-width. The dequantization code in this file unpacks
    // 8 nibbles per i32, i.e. it assumes 4-bit weights.
    pub bits: usize,
    // Number of K rows sharing one scale/zero-point group. A value <= 0
    // means a single group spanning all of K (see `effective_group_size`).
    pub group_size: i64,
    // Symmetric quantization: when true, a fixed zero point of 8 is used
    // during dequantization and no `qzeros` tensors are loaded.
    #[serde(default)]
    pub sym: bool,
    // Activation-reordering ("desc_act") flag; parsed but not otherwise
    // read anywhere in this file.
    #[serde(default)]
    pub desc_act: bool,
    // Quantization method name from the config (e.g. "gptq"); parsed but
    // not otherwise read anywhere in this file.
    #[serde(default)]
    pub quant_method: String,
}
23
24impl QuantizeConfig {
25 pub fn from_model_dir(model_dir: &Path) -> Result<Option<Self>> {
28 let path = model_dir.join("quantize_config.json");
29 if !path.exists() {
30 let config_path = model_dir.join("config.json");
32 if config_path.exists() {
33 if let Ok(content) = std::fs::read_to_string(&config_path) {
34 if let Ok(config) = serde_json::from_str::<serde_json::Value>(&content) {
35 if let Some(qc) = config.get("quantization_config") {
36 if let Ok(qconfig) =
37 serde_json::from_value::<QuantizeConfig>(qc.clone())
38 {
39 tracing::info!("GPTQ config found in config.json: {:?}", qconfig);
40 return Ok(Some(qconfig));
41 }
42 }
43 }
44 }
45 }
46 return Ok(None);
47 }
48 let content = std::fs::read_to_string(&path)
49 .map_err(|e| FerrumError::model(format!("read quantize_config.json: {e}")))?;
50 let config: QuantizeConfig = serde_json::from_str(&content)
51 .map_err(|e| FerrumError::model(format!("parse quantize_config.json: {e}")))?;
52 tracing::info!("GPTQ config: {:?}", config);
53 Ok(Some(config))
54 }
55
56 pub fn effective_group_size(&self, k: usize) -> usize {
57 if self.group_size <= 0 {
58 k } else {
60 self.group_size as usize
61 }
62 }
63}
64
/// One GPTQ-quantized linear layer in its packed on-disk representation.
#[derive(Debug)]
pub struct GptqLayerWeights {
    // 4-bit weights, 8 consecutive K rows packed low-nibble-first into each
    // i32; logical layout is [K/8, N] row-major.
    pub qweight: Vec<i32>,
    // Per-group, per-column scales; layout [K/group_size, N] row-major.
    pub scales: Vec<half::f16>,
    // Packed 4-bit zero points, layout [K/group_size, N/8]; `None` for
    // symmetric checkpoints, where a fixed zero point of 8 is assumed.
    pub qzeros: Option<Vec<i32>>,
    // Input (reduction) dimension of the layer.
    pub k: usize,
    // Output dimension of the layer.
    pub n: usize,
    // Number of K rows sharing one scale/zero-point group.
    pub group_size: usize,
    // True when the checkpoint was symmetrically quantized (no qzeros).
    pub symmetric: bool,
}
79
impl GptqLayerWeights {
    /// Dequantize the packed 4-bit weights into a dense [K, N] row-major
    /// f16 matrix on the CPU.
    ///
    /// Each `qweight` i32 holds 8 consecutive K rows for one column in its
    /// low-to-high nibbles; the dequantized value is `(q - zero) * scale`,
    /// where `zero` is the fixed value 8 for symmetric checkpoints and is
    /// otherwise unpacked from `qzeros` (also nibble-packed, 8 columns per
    /// i32). The missing-`qzeros` asymmetric case falls back to zero = 8.
    ///
    /// NOTE(review): classic AutoGPTQ checkpoints store zero points offset
    /// by -1 (dequant would use `zp + 1`); this code uses the stored value
    /// directly — confirm against the checkpoint format version targeted.
    pub fn dequantize_cpu(&self) -> Vec<half::f16> {
        let mut output = vec![half::f16::ZERO; self.k * self.n];
        // Each packed row of qweight covers 8 logical K rows.
        let packed_rows = self.k / 8;

        for packed_row in 0..packed_rows {
            for col in 0..self.n {
                let packed = self.qweight[packed_row * self.n + col];
                let base_k = packed_row * 8;
                // Group index for this run of 8 rows. Assumes group_size is a
                // multiple of 8 so a packed word never straddles two groups.
                let group = base_k / self.group_size;
                let scale = self.scales[group * self.n + col].to_f32();

                let zero = if self.symmetric {
                    8
                } else if let Some(ref qz) = self.qzeros {
                    // qzeros packs 8 columns per i32, 4 bits each.
                    let zp_packed = qz[group * (self.n / 8) + col / 8];
                    let zp_shift = (col % 8) * 4;
                    (zp_packed >> zp_shift) & 0xF
                } else {
                    8
                };

                // Unpack the 8 nibbles, low nibble = lowest K row.
                for i in 0..8 {
                    let val = (packed >> (i * 4)) & 0xF;
                    let dequantized = (val - zero) as f32 * scale;
                    output[(base_k + i as usize) * self.n + col] = half::f16::from_f32(dequantized);
                }
            }
        }
        output
    }
}
113
114pub fn load_gptq_weights(
119 model_dir: &Path,
120 qconfig: &QuantizeConfig,
121) -> Result<HashMap<String, GptqLayerWeights>> {
122 use safetensors::SafeTensors;
123
124 let safetensor_files = find_safetensor_files(model_dir)?;
125 if safetensor_files.is_empty() {
126 return Err(FerrumError::model("No safetensor files found"));
127 }
128
129 let mut result = HashMap::new();
130
131 for path in &safetensor_files {
133 let data = std::fs::read(path)
134 .map_err(|e| FerrumError::model(format!("read {}: {e}", path.display())))?;
135 let st = SafeTensors::deserialize(&data)
136 .map_err(|e| FerrumError::model(format!("parse {}: {e}", path.display())))?;
137
138 for (name, _) in st.tensors() {
139 if !name.ends_with(".qweight") {
140 continue;
141 }
142 let prefix = name.strip_suffix(".qweight").unwrap().to_string();
143
144 let qw_tensor = st
146 .tensor(&format!("{prefix}.qweight"))
147 .map_err(|e| FerrumError::model(format!("{prefix}.qweight: {e}")))?;
148 let qweight: Vec<i32> = bytemuck::cast_slice(qw_tensor.data()).to_vec();
149 let qw_shape = qw_tensor.shape();
150 let packed_k = qw_shape[0]; let n = qw_shape[1];
152 let k = packed_k * 8;
153
154 let sc_tensor = st
156 .tensor(&format!("{prefix}.scales"))
157 .map_err(|e| FerrumError::model(format!("{prefix}.scales: {e}")))?;
158 let scales: Vec<half::f16> = bytemuck::cast_slice(sc_tensor.data()).to_vec();
159
160 let qzeros = if !qconfig.sym {
162 let qz_tensor = st
163 .tensor(&format!("{prefix}.qzeros"))
164 .map_err(|e| FerrumError::model(format!("{prefix}.qzeros: {e}")))?;
165 Some(bytemuck::cast_slice(qz_tensor.data()).to_vec())
166 } else {
167 None
168 };
169
170 let gs = qconfig.effective_group_size(k);
171
172 tracing::debug!(
173 "GPTQ layer: {prefix} K={k} N={n} group_size={gs} sym={}",
174 qconfig.sym
175 );
176
177 result.insert(
178 prefix,
179 GptqLayerWeights {
180 qweight,
181 scales,
182 qzeros,
183 k,
184 n,
185 group_size: gs,
186 symmetric: qconfig.sym,
187 },
188 );
189 }
190 }
191
192 tracing::info!("Loaded {} GPTQ quantized layers (raw)", result.len());
193
194 fuse_qkv_and_gate_up(&mut result);
197
198 tracing::info!(
199 "After fusion: {} GPTQ layers (includes fused qkv_proj, gate_up_proj)",
200 result.len()
201 );
202 Ok(result)
203}
204
205fn fuse_qkv_and_gate_up(weights: &mut HashMap<String, GptqLayerWeights>) {
208 let prefixes: Vec<String> = weights
209 .keys()
210 .filter(|k| k.ends_with(".self_attn.q_proj"))
211 .map(|k| k.strip_suffix(".self_attn.q_proj").unwrap().to_string())
212 .collect();
213
214 for layer_prefix in &prefixes {
215 let q_key = format!("{layer_prefix}.self_attn.q_proj");
217 let k_key = format!("{layer_prefix}.self_attn.k_proj");
218 let v_key = format!("{layer_prefix}.self_attn.v_proj");
219 if let (Some(q), Some(k), Some(v)) = (
220 weights.get(&q_key),
221 weights.get(&k_key),
222 weights.get(&v_key),
223 ) {
224 if q.k == k.k && q.k == v.k {
225 let fused = fuse_columns(&[q, k, v]);
226 let fused_key = format!("{layer_prefix}.self_attn.qkv_proj");
227 tracing::info!(
228 "Fused {q_key}+{k_key}+{v_key} → {fused_key} K={} N={}",
229 fused.k,
230 fused.n
231 );
232 weights.insert(fused_key, fused);
233 }
234 }
235
236 let gate_key = format!("{layer_prefix}.mlp.gate_proj");
238 let up_key = format!("{layer_prefix}.mlp.up_proj");
239 if let (Some(gate), Some(up)) = (weights.get(&gate_key), weights.get(&up_key)) {
240 if gate.k == up.k {
241 let fused = fuse_columns(&[gate, up]);
242 let fused_key = format!("{layer_prefix}.mlp.gate_up_proj");
243 tracing::info!(
244 "Fused {gate_key}+{up_key} → {fused_key} K={} N={}",
245 fused.k,
246 fused.n
247 );
248 weights.insert(fused_key, fused);
249 }
250 }
251 }
252}
253
254fn fuse_columns(parts: &[&GptqLayerWeights]) -> GptqLayerWeights {
261 let k = parts[0].k;
262 let gs = parts[0].group_size;
263 let sym = parts[0].symmetric;
264 let total_n: usize = parts.iter().map(|p| p.n).sum();
265 let packed_k = k / 8;
266 let num_groups = k / gs;
267
268 let mut qweight = vec![0i32; packed_k * total_n];
270 let mut col_offset = 0;
271 for part in parts {
272 for row in 0..packed_k {
273 for col in 0..part.n {
274 qweight[row * total_n + col_offset + col] = part.qweight[row * part.n + col];
275 }
276 }
277 col_offset += part.n;
278 }
279
280 let mut scales = vec![half::f16::ZERO; num_groups * total_n];
282 col_offset = 0;
283 for part in parts {
284 for row in 0..num_groups {
285 for col in 0..part.n {
286 scales[row * total_n + col_offset + col] = part.scales[row * part.n + col];
287 }
288 }
289 col_offset += part.n;
290 }
291
292 let qzeros = if !sym {
294 let mut all_zeros = vec![0u8; num_groups * total_n];
295 let mut col_off = 0usize;
296 for part in parts {
297 if let Some(ref qz) = part.qzeros {
298 let part_n8 = part.n / 8;
299 for row in 0..num_groups {
300 for col in 0..part.n {
301 let packed = qz[row * part_n8 + col / 8];
302 let val = ((packed >> ((col % 8) * 4)) & 0xF) as u8;
303 all_zeros[row * total_n + col_off + col] = val;
304 }
305 }
306 }
307 col_off += part.n;
308 }
309 let total_n8 = total_n / 8;
311 let mut packed_zeros = vec![0i32; num_groups * total_n8];
312 for row in 0..num_groups {
313 for col in 0..total_n {
314 let val = all_zeros[row * total_n + col] as i32;
315 packed_zeros[row * total_n8 + col / 8] |= val << ((col % 8) * 4);
316 }
317 }
318 Some(packed_zeros)
319 } else {
320 None
321 };
322
323 GptqLayerWeights {
324 qweight,
325 scales,
326 qzeros,
327 k,
328 n: total_n,
329 group_size: gs,
330 symmetric: sym,
331 }
332}
333
334fn find_safetensor_files(model_dir: &Path) -> Result<Vec<PathBuf>> {
335 let mut files = Vec::new();
336
337 let single = model_dir.join("model.safetensors");
339 if single.exists() {
340 files.push(single);
341 return Ok(files);
342 }
343
344 let index_path = model_dir.join("model.safetensors.index.json");
346 if index_path.exists() {
347 let content = std::fs::read_to_string(&index_path)
348 .map_err(|e| FerrumError::model(format!("read index: {e}")))?;
349 let index: serde_json::Value = serde_json::from_str(&content)
350 .map_err(|e| FerrumError::model(format!("parse index: {e}")))?;
351 if let Some(weight_map) = index.get("weight_map").and_then(|v| v.as_object()) {
352 let mut seen = std::collections::HashSet::new();
353 for filename in weight_map.values().filter_map(|v| v.as_str()) {
354 if seen.insert(filename.to_string()) {
355 let path = model_dir.join(filename);
356 if path.exists() {
357 files.push(path);
358 }
359 }
360 }
361 }
362 }
363
364 Ok(files)
365}