1use std::collections::HashMap;
18use std::fs::File;
19use std::path::Path;
20
21use ferrum_kernels::backend::{Backend, BackendQuantMarlin, SrcDtype};
22use ferrum_types::{FerrumError, Result};
23use half::{bf16, f16};
24use memmap2::Mmap;
25use safetensors::{Dtype, SafeTensors};
26
27fn map_src_dtype(dtype: Dtype) -> Result<SrcDtype> {
29 match dtype {
30 Dtype::F32 => Ok(SrcDtype::F32),
31 Dtype::F16 => Ok(SrcDtype::F16),
32 Dtype::BF16 => Ok(SrcDtype::BF16),
33 other => Err(FerrumError::model(format!(
34 "dtype {other:?} not supported; Dense path expects F32/F16/BF16"
35 ))),
36 }
37}
38
39use crate::config::{QuantConfig, QuantMethod};
40use crate::dense::DenseLinear;
41use crate::gptq::GptqLinear;
42use crate::loader::WeightLoader;
43use crate::traits::Linear;
44
/// Type information and byte range for one tensor inside a shard's mmap.
struct TensorMeta {
    /// Element type reported by the safetensors view.
    dtype: Dtype,
    /// Dimensions reported by the safetensors view.
    shape: Vec<usize>,
    /// Offset of the tensor's first byte, measured from the start of the mmap.
    data_start: usize,
    /// Offset one past the tensor's last byte, measured from the start of the mmap.
    data_end: usize,
}
57
/// One memory-mapped safetensors file together with a pre-built tensor index.
struct Shard {
    /// Memory mapping of the whole file.
    mmap: Mmap,
    /// Tensor names as returned by `SafeTensors::names` when the shard was opened.
    names: Vec<String>,
    /// Per-tensor dtype, shape and byte range inside `mmap`, keyed by tensor name.
    meta: HashMap<String, TensorMeta>,
}
66
67impl Shard {
68 fn open(path: &Path) -> Result<Self> {
69 let file = File::open(path).map_err(|e| FerrumError::io(format!("open {path:?}: {e}")))?;
70 let mmap = unsafe {
71 Mmap::map(&file).map_err(|e| FerrumError::io(format!("mmap {path:?}: {e}")))?
72 };
73 let st = SafeTensors::deserialize(&mmap)
77 .map_err(|e| FerrumError::model(format!("parse {path:?}: {e}")))?;
78 debug_assert!(mmap.len() >= 8, "safetensors smaller than 8 bytes");
83 let header_len = u64::from_le_bytes(
84 mmap[0..8]
85 .try_into()
86 .expect("8-byte header len read failed"),
87 ) as usize;
88 let data_base = 8 + header_len;
89 let names: Vec<String> = st.names().iter().map(|s| s.to_string()).collect();
90 let mut meta = HashMap::with_capacity(names.len());
91 for name in &names {
92 let view = st.tensor(name).map_err(|e| {
93 FerrumError::model(format!("tensor '{name}' missing during preindex: {e}"))
94 })?;
95 let view_data = view.data();
99 let start = view_data.as_ptr() as usize - mmap.as_ptr() as usize;
100 let end = start + view_data.len();
101 debug_assert!(start >= data_base);
102 meta.insert(
103 name.clone(),
104 TensorMeta {
105 dtype: view.dtype(),
106 shape: view.shape().to_vec(),
107 data_start: start,
108 data_end: end,
109 },
110 );
111 }
112 let _ = data_base;
113 Ok(Self { mmap, names, meta })
114 }
115
116 fn get_cached(&self, name: &str) -> Result<(&[u8], Dtype, &[usize])> {
119 let m = self
120 .meta
121 .get(name)
122 .ok_or_else(|| FerrumError::model(format!("tensor '{name}' not in shard")))?;
123 Ok((&self.mmap[m.data_start..m.data_end], m.dtype, &m.shape))
124 }
125}
126
/// Weight loader backed by one or more memory-mapped safetensors shards.
pub struct NativeSafetensorsLoader<B: Backend + BackendQuantMarlin> {
    /// Opened shards, in the order their paths were resolved by `open`.
    shards: Vec<Shard>,
    /// Tensor name -> position in `shards` of the shard that holds it.
    index: HashMap<String, usize>,
    /// Quantization settings found in the model directory, if any.
    quant_config: Option<QuantConfig>,
    /// Ties the loader to backend `B` without storing a value of that type.
    _m: std::marker::PhantomData<B>,
}
138
139impl<B: Backend + BackendQuantMarlin> NativeSafetensorsLoader<B> {
140 pub fn open(model_dir: impl AsRef<Path>) -> Result<Self> {
142 let dir = model_dir.as_ref();
143
144 let shard_paths = if dir.join("model.safetensors").exists() {
145 vec![dir.join("model.safetensors")]
146 } else if dir.join("model.safetensors.index.json").exists() {
147 Self::parse_sharded_index(&dir.join("model.safetensors.index.json"))?
148 .into_iter()
149 .map(|name| dir.join(name))
150 .collect()
151 } else {
152 return Err(FerrumError::model(format!(
153 "no safetensors files in {dir:?}"
154 )));
155 };
156
157 let mut shards = Vec::with_capacity(shard_paths.len());
158 let mut index: HashMap<String, usize> = HashMap::new();
159 for (i, p) in shard_paths.iter().enumerate() {
160 let shard = Shard::open(p)?;
161 for name in &shard.names {
162 index.insert(name.clone(), i);
163 }
164 shards.push(shard);
165 }
166
167 let quant_config = load_quantize_config(dir)?;
168
169 Ok(Self {
170 shards,
171 index,
172 quant_config,
173 _m: std::marker::PhantomData,
174 })
175 }
176
177 fn parse_sharded_index(index_path: &Path) -> Result<Vec<String>> {
178 let data = std::fs::read_to_string(index_path)
179 .map_err(|e| FerrumError::io(format!("read {index_path:?}: {e}")))?;
180 let json: serde_json::Value = serde_json::from_str(&data)
181 .map_err(|e| FerrumError::serialization(format!("index json: {e}")))?;
182 let weight_map = json
183 .get("weight_map")
184 .and_then(|v| v.as_object())
185 .ok_or_else(|| FerrumError::model("index missing weight_map"))?;
186 let mut files: Vec<String> = weight_map
187 .values()
188 .filter_map(|v| v.as_str().map(|s| s.to_string()))
189 .collect();
190 files.sort();
191 files.dedup();
192 Ok(files)
193 }
194
195 fn read_f32(&self, name: &str) -> Result<(Vec<f32>, Vec<usize>)> {
197 let shard_idx = *self
198 .index
199 .get(name)
200 .ok_or_else(|| FerrumError::model(format!("tensor '{name}' not in index")))?;
201 let (data_bytes, dtype, shape) = self.shards[shard_idx].get_cached(name)?;
202 let data = dtype_to_f32(dtype, data_bytes)?;
203 Ok((data, shape.to_vec()))
204 }
205
206 fn read_bytes_typed(&self, name: &str) -> Result<(&[u8], SrcDtype, Vec<usize>)> {
210 let shard_idx = *self
211 .index
212 .get(name)
213 .ok_or_else(|| FerrumError::model(format!("tensor '{name}' not in index")))?;
214 let (data_bytes, st_dtype, shape) = self.shards[shard_idx].get_cached(name)?;
215 let dtype = map_src_dtype(st_dtype)?;
216 Ok((data_bytes, dtype, shape.to_vec()))
217 }
218
219 fn cat_rows_bytes(&self, names: &[String]) -> Result<(Vec<u8>, SrcDtype, (usize, usize))> {
223 let mut total_rows = 0usize;
224 let mut cols = 0usize;
225 let mut dtype: Option<SrcDtype> = None;
226 let mut bytes: Vec<u8> = Vec::new();
227 for n in names {
228 let (raw, d, shape) = self.read_bytes_typed(n)?;
229 if shape.len() != 2 {
230 return Err(FerrumError::model(format!(
231 "cat_rows_bytes: '{n}' is {shape:?}, need 2D"
232 )));
233 }
234 match dtype {
235 Some(prev) if prev != d => {
236 return Err(FerrumError::model(format!(
237 "cat_rows_bytes: dtype mismatch on '{n}'"
238 )))
239 }
240 _ => dtype = Some(d),
241 }
242 if cols == 0 {
243 cols = shape[1];
244 } else if cols != shape[1] {
245 return Err(FerrumError::model(format!(
246 "cat_rows_bytes: col mismatch {cols} vs {}",
247 shape[1]
248 )));
249 }
250 total_rows += shape[0];
251 bytes.extend_from_slice(raw);
252 }
253 Ok((bytes, dtype.expect("at least one part"), (total_rows, cols)))
254 }
255
256 fn cat_optional_biases(
260 &self,
261 weight_names: &[String],
262 out_features: usize,
263 ) -> Result<Option<Vec<f32>>> {
264 let bias_names: Vec<String> = weight_names
265 .iter()
266 .map(|name| {
267 name.strip_suffix(".weight")
268 .map(|stem| format!("{stem}.bias"))
269 .unwrap_or_else(|| format!("{name}.bias"))
270 })
271 .collect();
272 let any_bias = bias_names.iter().any(|name| self.has(name));
273 if !any_bias {
274 return Ok(None);
275 }
276 if let Some(missing) = bias_names.iter().find(|name| !self.has(name)) {
277 return Err(FerrumError::model(format!(
278 "dense fusion bias mix: '{missing}' missing while another fused part has bias"
279 )));
280 }
281 let mut fused = Vec::new();
282 for name in &bias_names {
283 let (bias, shape) = self.read_f32(name)?;
284 if shape.len() != 1 {
285 return Err(FerrumError::model(format!(
286 "dense fusion bias '{name}': expected 1D, got {shape:?}"
287 )));
288 }
289 fused.extend_from_slice(&bias);
290 }
291 if fused.len() != out_features {
292 return Err(FerrumError::model(format!(
293 "dense fusion bias length {} != out_features {out_features}",
294 fused.len()
295 )));
296 }
297 Ok(Some(fused))
298 }
299
300 fn read_i32(&self, name: &str) -> Result<(Vec<i32>, Vec<usize>)> {
305 let shard_idx = *self
306 .index
307 .get(name)
308 .ok_or_else(|| FerrumError::model(format!("tensor '{name}' not in index")))?;
309 let (bytes, dtype, shape) = self.shards[shard_idx].get_cached(name)?;
310 if dtype != Dtype::I32 {
311 return Err(FerrumError::model(format!(
312 "'{name}': expected I32, got {:?}",
313 dtype
314 )));
315 }
316 debug_assert_eq!(bytes.len() % 4, 0);
317 let count = bytes.len() / 4;
318 let mut out = Vec::<i32>::with_capacity(count);
319 unsafe {
324 std::ptr::copy_nonoverlapping(bytes.as_ptr(), out.as_mut_ptr() as *mut u8, bytes.len());
325 out.set_len(count);
326 }
327 Ok((out, shape.to_vec()))
328 }
329
    /// Returns `true` when a tensor called `name` is present in the
    /// name -> shard index.
    fn has(&self, name: &str) -> bool {
        self.index.contains_key(name)
    }
333
334 pub fn read_gptq_raw(
342 &self,
343 name: &str,
344 ) -> Result<(Vec<i32>, Vec<f32>, Vec<i32>, Option<Vec<i32>>, usize, usize)> {
345 let (qweight, qw_shape) = self.read_i32(&format!("{name}.qweight"))?;
346 let (scales, _) = self.read_f32(&format!("{name}.scales"))?;
347 let (qzeros, _) = self.read_i32(&format!("{name}.qzeros"))?;
348 let g_idx = if self.has(&format!("{name}.g_idx")) {
349 Some(self.read_i32(&format!("{name}.g_idx"))?.0)
350 } else {
351 None
352 };
353 if qw_shape.len() != 2 {
354 return Err(FerrumError::model(format!(
355 "'{name}.qweight' expected 2D, got {qw_shape:?}"
356 )));
357 }
358 let k = qw_shape[0] * 8;
359 let n = qw_shape[1];
360 Ok((qweight, scales, qzeros, g_idx, k, n))
361 }
362
    /// Borrows the quantization config loaded in `open`, or `None` when the
    /// model directory did not provide one.
    pub fn quant_config_ref(&self) -> Option<&crate::config::QuantConfig> {
        self.quant_config.as_ref()
    }
366
    /// Loads the GPTQ tensors of `num_experts` experts and hands them to the
    /// backend as one stacked store.
    ///
    /// * `expert_prefix_fmt` — template in which every `{e}` is replaced by
    ///   the expert index to form the per-expert tensor prefix.
    /// * `proj_names` — suffixes appended directly to that prefix; within one
    ///   expert the projections are concatenated column-wise in this order.
    ///
    /// Returns `(store, n_per_expert, k_shared)`: the backend store, the
    /// summed `qweight` column count of one expert, and eight times the
    /// shared `qweight` row count.
    ///
    /// # Errors
    /// Model error when no quantization config is loaded, the method is not
    /// GPTQ, a tensor is not 2D, row counts differ between parts, per-expert
    /// column totals differ from expert 0, or two present `g_idx` tensors
    /// differ. Tensor lookup and backend errors are propagated.
    pub fn load_stacked_gptq_experts(
        &self,
        expert_prefix_fmt: &str,
        num_experts: usize,
        proj_names: &[&str],
    ) -> Result<(
        std::sync::Arc<dyn ferrum_kernels::MarlinExpertStack<B>>,
        usize,
        usize,
    )> {
        let qcfg = self.quant_config.as_ref().ok_or_else(|| {
            FerrumError::model(
                "load_stacked_gptq_experts requires quantize_config.json".to_string(),
            )
        })?;
        if qcfg.method != QuantMethod::Gptq {
            return Err(FerrumError::model(format!(
                "stacked GPTQ load but quant_method={:?}",
                qcfg.method
            )));
        }

        // Reference row counts, taken from the first tensor triple read.
        // `qw_rows == 0` doubles as the "not yet initialised" marker.
        let mut qw_rows = 0usize;
        let mut sc_rows = 0usize;
        let mut qz_rows = 0usize;
        // Per-expert column totals, taken from expert 0.
        let mut n_per_expert = 0usize;
        let mut n_per_expert_scales = 0usize;
        let mut n_per_expert_zeros = 0usize;
        let mut k_shared = 0usize;
        // First `g_idx` encountered; later ones must be identical to it.
        let mut g_idx_first: Option<Vec<i32>> = None;

        // One `(data, cols)` entry per (expert, projection) pair, stored in
        // expert-major order.
        let total_pairs = num_experts * proj_names.len();
        let mut qw_parts: Vec<(Vec<i32>, usize)> = Vec::with_capacity(total_pairs);
        let mut sc_parts: Vec<(Vec<f32>, usize)> = Vec::with_capacity(total_pairs);
        let mut qz_parts: Vec<(Vec<i32>, usize)> = Vec::with_capacity(total_pairs);

        for e in 0..num_experts {
            let prefix = expert_prefix_fmt.replace("{e}", &e.to_string());
            let mut e_n = 0usize;
            let mut e_n_scales = 0usize;
            let mut e_n_zeros = 0usize;
            for proj in proj_names {
                let name = format!("{prefix}{proj}");
                let (qw, qw_sh) = self.read_i32(&format!("{name}.qweight"))?;
                let (sc, sc_sh) = self.read_f32(&format!("{name}.scales"))?;
                let (qz, qz_sh) = self.read_i32(&format!("{name}.qzeros"))?;
                if qw_sh.len() != 2 || sc_sh.len() != 2 || qz_sh.len() != 2 {
                    return Err(FerrumError::model(format!(
                        "stacked GPTQ '{name}': expected 2D, got qw {qw_sh:?} sc {sc_sh:?} qz {qz_sh:?}"
                    )));
                }
                if qw_rows == 0 {
                    qw_rows = qw_sh[0];
                    sc_rows = sc_sh[0];
                    qz_rows = qz_sh[0];
                    // NOTE(review): the factor 8 presumably reflects eight
                    // 4-bit values per packed i32 — confirm against `qcfg.bits`.
                    k_shared = qw_sh[0] * 8;
                } else if qw_sh[0] != qw_rows || sc_sh[0] != sc_rows || qz_sh[0] != qz_rows {
                    return Err(FerrumError::model(format!(
                        "stacked GPTQ '{name}': row mismatch qw {} sc {} qz {} vs ref {qw_rows}/{sc_rows}/{qz_rows}",
                        qw_sh[0], sc_sh[0], qz_sh[0]
                    )));
                }
                e_n += qw_sh[1];
                e_n_scales += sc_sh[1];
                e_n_zeros += qz_sh[1];
                qw_parts.push((qw, qw_sh[1]));
                sc_parts.push((sc, sc_sh[1]));
                qz_parts.push((qz, qz_sh[1]));

                // `g_idx` is optional per part: a part without it is skipped,
                // a part with it must match the first one seen.
                let g_key = format!("{name}.g_idx");
                if self.has(&g_key) {
                    let (gx, _) = self.read_i32(&g_key)?;
                    match &g_idx_first {
                        None => g_idx_first = Some(gx),
                        Some(prev) => {
                            if prev.len() != gx.len() || prev.iter().zip(&gx).any(|(a, b)| a != b) {
                                return Err(FerrumError::model(format!(
                                    "stacked GPTQ '{name}': g_idx mismatch with first \
                                     expert — Marlin requires identical act-order across \
                                     experts in the same stacked tile"
                                )));
                            }
                        }
                    }
                }
            }
            // Every expert must have the same column totals as expert 0.
            if e == 0 {
                n_per_expert = e_n;
                n_per_expert_scales = e_n_scales;
                n_per_expert_zeros = e_n_zeros;
            } else if e_n != n_per_expert
                || e_n_scales != n_per_expert_scales
                || e_n_zeros != n_per_expert_zeros
            {
                return Err(FerrumError::model(format!(
                    "stacked GPTQ expert {e} N mismatch: qw {e_n} sc {e_n_scales} qz {e_n_zeros} vs expert 0 {n_per_expert}/{n_per_expert_scales}/{n_per_expert_zeros}"
                )));
            }
        }

        let proj_count = proj_names.len();
        let pairs_per_expert = proj_count;
        debug_assert_eq!(total_pairs, num_experts * pairs_per_expert);

        // For each expert, build row-major buffers in which output row `r` is
        // row `r` of every projection laid side by side.
        let mut per_expert_qw: Vec<Vec<i32>> = Vec::with_capacity(num_experts);
        let mut per_expert_sc: Vec<Vec<f32>> = Vec::with_capacity(num_experts);
        let mut per_expert_qz: Vec<Vec<i32>> = Vec::with_capacity(num_experts);
        for e in 0..num_experts {
            let mut qw: Vec<i32> = Vec::with_capacity(qw_rows * n_per_expert);
            let mut sc: Vec<f32> = Vec::with_capacity(sc_rows * n_per_expert_scales);
            let mut qz: Vec<i32> = Vec::with_capacity(qz_rows * n_per_expert_zeros);
            for r in 0..qw_rows {
                for j in 0..pairs_per_expert {
                    let pair_idx = e * pairs_per_expert + j;
                    let (data, cols) = &qw_parts[pair_idx];
                    qw.extend_from_slice(&data[r * cols..(r + 1) * cols]);
                }
            }
            for r in 0..sc_rows {
                for j in 0..pairs_per_expert {
                    let pair_idx = e * pairs_per_expert + j;
                    let (data, cols) = &sc_parts[pair_idx];
                    sc.extend_from_slice(&data[r * cols..(r + 1) * cols]);
                }
            }
            for r in 0..qz_rows {
                for j in 0..pairs_per_expert {
                    let pair_idx = e * pairs_per_expert + j;
                    let (data, cols) = &qz_parts[pair_idx];
                    qz.extend_from_slice(&data[r * cols..(r + 1) * cols]);
                }
            }
            per_expert_qw.push(qw);
            per_expert_sc.push(sc);
            per_expert_qz.push(qz);
        }

        // The per-part buffers are no longer needed once the per-expert
        // buffers exist; release them before calling into the backend.
        drop(qw_parts);
        drop(sc_parts);
        drop(qz_parts);

        let qw_refs: Vec<&[i32]> = per_expert_qw.iter().map(|v| v.as_slice()).collect();
        let sc_refs: Vec<&[f32]> = per_expert_sc.iter().map(|v| v.as_slice()).collect();
        let qz_refs: Vec<&[i32]> = per_expert_qz.iter().map(|v| v.as_slice()).collect();

        let store = B::load_gptq_stacked(
            &qw_refs,
            &sc_refs,
            &qz_refs,
            g_idx_first.as_deref(),
            qcfg.bits,
            qcfg.group_size,
            k_shared,
            n_per_expert,
        )?;
        Ok((store, n_per_expert, k_shared))
    }
560}
561
562impl<B: Backend + BackendQuantMarlin> WeightLoader<B> for NativeSafetensorsLoader<B> {
563 fn load_tensor(&self, name: &str) -> Result<B::Buffer> {
564 let (raw, src_dtype, _) = self.read_bytes_typed(name)?;
569 Ok(B::from_weight_bytes(raw, src_dtype))
570 }
571
572 fn load_linear(&self, name: &str) -> Result<Box<dyn Linear<B>>> {
573 let qw_key = format!("{name}.qweight");
575 if self.has(&qw_key) {
576 return self.load_gptq_linear(name);
577 }
578 if let Some(prefix) = name.strip_suffix("qkv_proj") {
582 let parts = [
583 format!("{prefix}q_proj"),
584 format!("{prefix}k_proj"),
585 format!("{prefix}v_proj"),
586 ];
587 if parts.iter().all(|p| self.has(&format!("{p}.qweight"))) {
588 return self.load_gptq_linear_fused(&parts);
589 }
590 }
591 if let Some(prefix) = name.strip_suffix("gate_up_proj") {
592 let parts = [format!("{prefix}gate_proj"), format!("{prefix}up_proj")];
593 if parts.iter().all(|p| self.has(&format!("{p}.qweight"))) {
594 return self.load_gptq_linear_fused(&parts);
595 }
596 }
597
598 let direct = format!("{name}.weight");
601 if self.has(&direct) {
602 let (raw, src_dtype, shape) = self.read_bytes_typed(&direct)?;
603 if shape.len() != 2 {
604 return Err(FerrumError::model(format!(
605 "linear '{name}': expected 2D weight, got {shape:?}"
606 )));
607 }
608 let weight = B::from_weight_bytes(raw, src_dtype);
609 return Ok(Box::new(DenseLinear::<B>::from_buffer(
610 weight, shape[0], shape[1],
611 )));
612 }
613
614 if let Some(prefix) = name.strip_suffix("qkv_proj") {
619 let parts = [
620 format!("{prefix}q_proj.weight"),
621 format!("{prefix}k_proj.weight"),
622 format!("{prefix}v_proj.weight"),
623 ];
624 if parts.iter().all(|p| self.has(p)) {
625 let (bytes, dtype, (rows, cols)) = self.cat_rows_bytes(&parts)?;
626 let weight = B::from_weight_bytes(&bytes, dtype);
627 let mut linear = DenseLinear::<B>::from_buffer(weight, rows, cols);
628 if let Some(bias) = self.cat_optional_biases(&parts, rows)? {
629 linear = linear.with_bias(B::from_slice(&bias));
630 }
631 return Ok(Box::new(linear));
632 }
633 }
634 if let Some(prefix) = name.strip_suffix("gate_up_proj") {
635 let parts = [
636 format!("{prefix}gate_proj.weight"),
637 format!("{prefix}up_proj.weight"),
638 ];
639 if parts.iter().all(|p| self.has(p)) {
640 let (bytes, dtype, (rows, cols)) = self.cat_rows_bytes(&parts)?;
641 let weight = B::from_weight_bytes(&bytes, dtype);
642 let mut linear = DenseLinear::<B>::from_buffer(weight, rows, cols);
643 if let Some(bias) = self.cat_optional_biases(&parts, rows)? {
644 linear = linear.with_bias(B::from_slice(&bias));
645 }
646 return Ok(Box::new(linear));
647 }
648 }
649
650 Err(FerrumError::model(format!(
651 "could not load linear '{name}' — no direct `.weight`, no split components"
652 )))
653 }
654
655 fn has_tensor(&self, name: &str) -> bool {
656 self.has(name)
657 }
658
659 fn quant_config(&self) -> Option<&QuantConfig> {
660 self.quant_config.as_ref()
661 }
662}
663
664impl<B: Backend + BackendQuantMarlin> NativeSafetensorsLoader<B> {
665 fn load_gptq_linear(&self, name: &str) -> Result<Box<dyn Linear<B>>> {
670 let qcfg = self.quant_config.as_ref().ok_or_else(|| {
671 FerrumError::model(format!(
672 "'{name}.qweight' present but no quantize_config.json — \
673 can't determine bits/group_size"
674 ))
675 })?;
676 if qcfg.method != QuantMethod::Gptq {
677 return Err(FerrumError::model(format!(
678 "'{name}.qweight' present but quant_method={:?} (expected GPTQ)",
679 qcfg.method
680 )));
681 }
682
683 let (qweight, qw_shape) = self.read_i32(&format!("{name}.qweight"))?;
684 let (scales_f32, sc_shape) = self.read_f32(&format!("{name}.scales"))?;
685 let (qzeros, _qz_shape) = self.read_i32(&format!("{name}.qzeros"))?;
686 let g_idx = if self.has(&format!("{name}.g_idx")) {
687 Some(self.read_i32(&format!("{name}.g_idx"))?.0)
688 } else {
689 None
690 };
691
692 if qw_shape.len() != 2 {
695 return Err(FerrumError::model(format!(
696 "'{name}.qweight' expected 2D, got {qw_shape:?}"
697 )));
698 }
699 let in_features = qw_shape[0] * 8;
700 let out_features = qw_shape[1];
701
702 let is_desc_act = validate_gptq_g_idx(name, qcfg, g_idx.as_deref(), in_features)?;
703
704 #[cfg(not(feature = "cuda"))]
709 if is_desc_act {
710 let dequant_f32 = dequantize_gptq_with_g_idx(
711 &qweight,
712 &scales_f32,
713 &qzeros,
714 g_idx.as_ref().expect("desc_act=true requires g_idx"),
715 qcfg.group_size,
716 in_features,
717 out_features,
718 );
719 let mut linear =
720 crate::dense::DenseLinear::<B>::from_rows(&dequant_f32, out_features, in_features);
721 let bias_key = format!("{name}.bias");
722 if self.has(&bias_key) {
723 let (bias, _) = self.read_f32(&bias_key)?;
724 linear = linear.with_bias(B::from_slice(&bias));
725 }
726 tracing::info!(
727 "GPTQ load (desc_act dequant→DenseLinear, non-cuda): name={name} K={in_features} N={out_features}"
728 );
729 return Ok(Box::new(linear));
730 }
731 #[cfg(feature = "cuda")]
732 let _ = is_desc_act; if sc_shape.len() != 2 || sc_shape[1] != out_features {
734 return Err(FerrumError::model(format!(
735 "'{name}.scales' {sc_shape:?} incompatible with qweight {qw_shape:?}"
736 )));
737 }
738
739 let bias_key = format!("{name}.bias");
743 let bias_vec = if self.has(&bias_key) {
744 let (bias, bias_shape) = self.read_f32(&bias_key)?;
745 if bias_shape != [out_features] {
746 return Err(FerrumError::model(format!(
747 "'{bias_key}' {bias_shape:?} != [{out_features}]"
748 )));
749 }
750 Some(bias)
751 } else {
752 None
753 };
754
755 let linear = GptqLinear::<B>::from_raw(
756 &qweight,
757 &scales_f32,
758 &qzeros,
759 g_idx.as_deref(),
760 bias_vec.as_deref(),
761 qcfg.bits,
762 qcfg.group_size,
763 in_features,
764 out_features,
765 )?;
766 Ok(Box::new(linear))
767 }
768
769 fn load_gptq_linear_fused(&self, parts: &[String]) -> Result<Box<dyn Linear<B>>> {
782 let qcfg = self.quant_config.as_ref().ok_or_else(|| {
783 FerrumError::model("GPTQ fusion requires quantize_config.json".to_string())
784 })?;
785 if qcfg.method != QuantMethod::Gptq {
786 return Err(FerrumError::model(format!(
787 "GPTQ fusion but quant_method={:?}",
788 qcfg.method
789 )));
790 }
791
792 let mut qw_acc: Vec<i32> = Vec::new();
793 let mut sc_acc: Vec<f32> = Vec::new();
794 let mut qz_acc: Vec<i32> = Vec::new();
795 let mut qw_rows = 0usize;
796 let mut sc_rows = 0usize;
797 let mut qz_rows = 0usize;
798 let mut total_n = 0usize;
799 let mut total_n_scales = 0usize;
800 let mut total_n_zeros = 0usize;
801 let mut g_idx: Option<Vec<i32>> = None;
802 let mut g_idx_presence: Vec<(String, bool)> = Vec::with_capacity(parts.len());
803 let mut qw_parts: Vec<(Vec<i32>, usize, usize)> = Vec::new(); let mut sc_parts: Vec<(Vec<f32>, usize, usize)> = Vec::new();
806 let mut qz_parts: Vec<(Vec<i32>, usize, usize)> = Vec::new();
807
808 for p in parts {
809 let (qw, qw_sh) = self.read_i32(&format!("{p}.qweight"))?;
810 let (sc, sc_sh) = self.read_f32(&format!("{p}.scales"))?;
811 let (qz, qz_sh) = self.read_i32(&format!("{p}.qzeros"))?;
812 if qw_sh.len() != 2 || sc_sh.len() != 2 || qz_sh.len() != 2 {
813 return Err(FerrumError::model(format!(
814 "GPTQ fusion '{p}': expected 2D tensors, got qw {qw_sh:?} sc {sc_sh:?} qz {qz_sh:?}"
815 )));
816 }
817 if qw_rows == 0 {
818 qw_rows = qw_sh[0];
819 sc_rows = sc_sh[0];
820 qz_rows = qz_sh[0];
821 } else if qw_sh[0] != qw_rows || sc_sh[0] != sc_rows || qz_sh[0] != qz_rows {
822 return Err(FerrumError::model(format!(
823 "GPTQ fusion row mismatch on '{p}'"
824 )));
825 }
826 total_n += qw_sh[1];
827 total_n_scales += sc_sh[1];
828 total_n_zeros += qz_sh[1];
829 qw_parts.push((qw, qw_sh[0], qw_sh[1]));
830 sc_parts.push((sc, sc_sh[0], sc_sh[1]));
831 qz_parts.push((qz, qz_sh[0], qz_sh[1]));
832
833 let g_key = format!("{p}.g_idx");
834 if self.has(&g_key) {
835 let (gx, gx_shape) = self.read_i32(&g_key)?;
836 if gx_shape != [qw_rows * 8] {
837 return Err(FerrumError::model(format!(
838 "GPTQ fusion '{p}': g_idx shape {gx_shape:?} incompatible with K={}",
839 qw_rows * 8
840 )));
841 }
842 match &g_idx {
843 None => g_idx = Some(gx),
844 Some(prev) => {
845 if prev.len() != gx.len() || prev.iter().zip(&gx).any(|(a, b)| a != b) {
846 return Err(FerrumError::model(format!(
847 "GPTQ fusion '{p}': g_idx mismatch with first part; \
848 fused qkv/gate_up requires identical act-order across parts"
849 )));
850 }
851 }
852 }
853 g_idx_presence.push((p.clone(), true));
854 } else {
855 g_idx_presence.push((p.clone(), false));
856 }
857 }
858
859 qw_acc.reserve(qw_rows * total_n);
861 for r in 0..qw_rows {
862 for (part, _rows, cols) in &qw_parts {
863 qw_acc.extend_from_slice(&part[r * cols..r * cols + cols]);
864 }
865 }
866 sc_acc.reserve(sc_rows * total_n_scales);
867 for r in 0..sc_rows {
868 for (part, _rows, cols) in &sc_parts {
869 sc_acc.extend_from_slice(&part[r * cols..r * cols + cols]);
870 }
871 }
872 qz_acc.reserve(qz_rows * total_n_zeros);
873 for r in 0..qz_rows {
874 for (part, _rows, cols) in &qz_parts {
875 qz_acc.extend_from_slice(&part[r * cols..r * cols + cols]);
876 }
877 }
878
879 let in_features = qw_rows * 8;
880 let out_features = total_n;
881
882 if g_idx.is_some() {
883 let missing = g_idx_presence
884 .iter()
885 .filter_map(|(part, present)| (!present).then_some(part.as_str()))
886 .collect::<Vec<_>>();
887 if !missing.is_empty() {
888 return Err(FerrumError::model(format!(
889 "GPTQ fusion requires all parts to carry g_idx when any part does; \
890 missing g_idx for {missing:?}"
891 )));
892 }
893 }
894 let fused_name = format!("GPTQ fusion {}", parts.join("+"));
895 let is_desc_act = validate_gptq_g_idx(&fused_name, qcfg, g_idx.as_deref(), in_features)?;
896 #[cfg(not(feature = "cuda"))]
898 if is_desc_act {
899 let dequant_f32 = dequantize_gptq_with_g_idx(
900 &qw_acc,
901 &sc_acc,
902 &qz_acc,
903 g_idx.as_ref().expect("desc_act=true requires g_idx"),
904 qcfg.group_size,
905 in_features,
906 out_features,
907 );
908 let mut linear =
909 crate::dense::DenseLinear::<B>::from_rows(&dequant_f32, out_features, in_features);
910 let mut bias_acc: Vec<f32> = Vec::new();
911 let mut any_bias = false;
912 for p in parts {
913 let bk = format!("{p}.bias");
914 if self.has(&bk) {
915 any_bias = true;
916 bias_acc.extend_from_slice(&self.read_f32(&bk)?.0);
917 } else if any_bias {
918 return Err(FerrumError::model(format!(
919 "GPTQ fusion bias mix: '{p}' has no bias but earlier part did"
920 )));
921 }
922 }
923 if any_bias {
924 linear = linear.with_bias(B::from_slice(&bias_acc));
925 }
926 tracing::info!(
927 "GPTQ fused load (desc_act dequant→DenseLinear, non-cuda): K={in_features} N={out_features} parts={}",
928 parts.len()
929 );
930 return Ok(Box::new(linear));
931 }
932 #[cfg(feature = "cuda")]
933 let _ = is_desc_act;
934
935 let bias_keys: Vec<String> = parts.iter().map(|p| format!("{p}.bias")).collect();
939 let any = bias_keys.iter().any(|k| self.has(k));
940 let all = bias_keys.iter().all(|k| self.has(k));
941 if any && !all {
942 return Err(FerrumError::model(
943 "GPTQ fusion: inconsistent bias presence across parts".to_string(),
944 ));
945 }
946 let fused_bias = if all {
947 let mut fused: Vec<f32> = Vec::with_capacity(out_features);
948 for k in &bias_keys {
949 let (b, _) = self.read_f32(k)?;
950 fused.extend_from_slice(&b);
951 }
952 if fused.len() != out_features {
953 return Err(FerrumError::model(format!(
954 "GPTQ fusion bias length {} != out_features {out_features}",
955 fused.len()
956 )));
957 }
958 Some(fused)
959 } else {
960 None
961 };
962
963 let linear = GptqLinear::<B>::from_raw(
964 &qw_acc,
965 &sc_acc,
966 &qz_acc,
967 g_idx.as_deref(),
968 fused_bias.as_deref(),
969 qcfg.bits,
970 qcfg.group_size,
971 in_features,
972 out_features,
973 )?;
974
975 Ok(Box::new(linear))
976 }
977
978 #[allow(dead_code)]
982 fn cat_rows(&self, names: &[String]) -> Result<(usize, usize, Vec<f32>)> {
983 let mut total_rows = 0usize;
984 let mut cols = 0usize;
985 let mut out: Vec<f32> = Vec::new();
986 for n in names {
987 let (data, shape) = self.read_f32(n)?;
988 if shape.len() != 2 {
989 return Err(FerrumError::model(format!(
990 "cat_rows: '{n}' is {shape:?}, need 2D"
991 )));
992 }
993 if cols == 0 {
994 cols = shape[1];
995 } else if cols != shape[1] {
996 return Err(FerrumError::model(format!(
997 "cat_rows: col mismatch {cols} vs {}",
998 shape[1]
999 )));
1000 }
1001 total_rows += shape[0];
1002 out.extend_from_slice(&data);
1003 }
1004 Ok((total_rows, cols, out))
1005 }
1006}
1007
/// Returns `true` when `g_idx` differs anywhere from the trivial layout in
/// which position `i` belongs to group `i / group_size`.
///
/// The comparison is done in `i32`, and `group_size` must be non-zero
/// whenever `g_idx` is non-empty (the division would otherwise panic).
fn gptq_g_idx_is_desc_act(g_idx: &[i32], group_size: usize) -> bool {
    let step = group_size as i32;
    for (position, &group) in g_idx.iter().enumerate() {
        if group != (position as i32) / step {
            return true;
        }
    }
    false
}
1014
1015fn validate_gptq_g_idx(
1016 name: &str,
1017 qcfg: &QuantConfig,
1018 g_idx: Option<&[i32]>,
1019 in_features: usize,
1020) -> Result<bool> {
1021 if qcfg.desc_act && g_idx.is_none() {
1022 return Err(FerrumError::model(format!(
1023 "{name}: quantize_config desc_act=true but no g_idx tensor was found"
1024 )));
1025 }
1026
1027 let Some(g_idx) = g_idx else {
1028 return Ok(false);
1029 };
1030 if qcfg.group_size == 0 {
1031 return Err(FerrumError::model(format!(
1032 "{name}: GPTQ g_idx present but group_size is 0"
1033 )));
1034 }
1035 if g_idx.len() != in_features {
1036 return Err(FerrumError::model(format!(
1037 "{name}: g_idx length {} must match K={in_features}",
1038 g_idx.len()
1039 )));
1040 }
1041 let expected_groups = in_features.div_ceil(qcfg.group_size);
1042 for (idx, &group) in g_idx.iter().enumerate() {
1043 if group < 0 || group as usize >= expected_groups {
1044 return Err(FerrumError::model(format!(
1045 "{name}: g_idx[{idx}]={group} outside expected group range 0..{}",
1046 expected_groups.saturating_sub(1)
1047 )));
1048 }
1049 }
1050 Ok(gptq_g_idx_is_desc_act(g_idx, qcfg.group_size))
1051}
1052
/// Dequantizes packed 4-bit GPTQ weights to `f32`, taking each input
/// channel's group from `g_idx`.
///
/// * `qweight` is read as `[k / 8, n]`, eight 4-bit values per word, lowest
///   nibble first.
/// * `scales` is read as `[groups, n]`.
/// * `qzeros` is read as `[groups, n / 8]`, packed the same way; the stored
///   zero point is incremented by one before use.
///
/// The result has `n * k` elements with element `(col, ki)` stored at
/// `col * k + ki`. `_group_size` is unused. Out-of-range indices panic.
#[cfg(not(feature = "cuda"))]
fn dequantize_gptq_with_g_idx(
    qweight: &[i32],
    scales: &[f32],
    qzeros: &[i32],
    g_idx: &[i32],
    _group_size: usize,
    k: usize,
    n: usize,
) -> Vec<f32> {
    debug_assert_eq!(g_idx.len(), k);

    // Extracts the 4-bit value in position `slot` (0 = lowest bits) of `word`.
    let nibble = |word: i32, slot: usize| (((word as u32) >> (slot * 4)) & 0xF) as i32;

    let zero_words_per_group = n / 8;
    let mut w = vec![0.0f32; n * k];
    for pr in 0..k / 8 {
        for col in 0..n {
            let packed = qweight[pr * n + col];
            for bi in 0..8 {
                let ki = pr * 8 + bi;
                let g = g_idx[ki] as usize;
                let scale = scales[g * n + col];
                let zero = nibble(qzeros[g * zero_words_per_group + col / 8], col % 8) + 1;
                w[col * k + ki] = (nibble(packed, bi) - zero) as f32 * scale;
            }
        }
    }
    w
}
1100
1101fn dtype_to_f32(dtype: Dtype, raw: &[u8]) -> Result<Vec<f32>> {
1102 match dtype {
1103 Dtype::F32 => {
1104 debug_assert_eq!(raw.len() % 4, 0);
1109 let n = raw.len() / 4;
1110 let mut out = Vec::<f32>::with_capacity(n);
1111 unsafe {
1112 std::ptr::copy_nonoverlapping(raw.as_ptr(), out.as_mut_ptr() as *mut u8, raw.len());
1113 out.set_len(n);
1114 }
1115 Ok(out)
1116 }
1117 Dtype::F16 => {
1118 debug_assert_eq!(raw.len() % 2, 0);
1119 let n = raw.len() / 2;
1120 let mut tmp = Vec::<f16>::with_capacity(n);
1123 unsafe {
1124 std::ptr::copy_nonoverlapping(raw.as_ptr(), tmp.as_mut_ptr() as *mut u8, raw.len());
1125 tmp.set_len(n);
1126 }
1127 let mut out = Vec::with_capacity(n);
1128 for h in &tmp {
1129 out.push(h.to_f32());
1130 }
1131 Ok(out)
1132 }
1133 Dtype::BF16 => {
1134 debug_assert_eq!(raw.len() % 2, 0);
1135 let n = raw.len() / 2;
1136 let mut tmp = Vec::<bf16>::with_capacity(n);
1137 unsafe {
1138 std::ptr::copy_nonoverlapping(raw.as_ptr(), tmp.as_mut_ptr() as *mut u8, raw.len());
1139 tmp.set_len(n);
1140 }
1141 let mut out = Vec::with_capacity(n);
1142 for h in &tmp {
1143 out.push(h.to_f32());
1144 }
1145 Ok(out)
1146 }
1147 other => Err(FerrumError::model(format!(
1148 "dtype {other:?} not supported by NativeSafetensorsLoader's f32 path; \
1149 use a format-specific loader (GPTQ / AWQ / GGUF)",
1150 ))),
1151 }
1152}
1153
1154fn load_quantize_config(dir: &Path) -> Result<Option<QuantConfig>> {
1155 let p = dir.join("quantize_config.json");
1157 if p.exists() {
1158 let data =
1159 std::fs::read_to_string(&p).map_err(|e| FerrumError::io(format!("read {p:?}: {e}")))?;
1160 let qc: QuantConfig = serde_json::from_str(&data)
1161 .map_err(|e| FerrumError::serialization(format!("parse quantize_config.json: {e}")))?;
1162 return Ok(Some(qc));
1163 }
1164 let cfg = dir.join("config.json");
1167 if cfg.exists() {
1168 let data = std::fs::read_to_string(&cfg)
1169 .map_err(|e| FerrumError::io(format!("read {cfg:?}: {e}")))?;
1170 let root: serde_json::Value = serde_json::from_str(&data)
1171 .map_err(|e| FerrumError::serialization(format!("parse config.json: {e}")))?;
1172 if let Some(qc_val) = root.get("quantization_config") {
1173 let method = qc_val
1175 .get("quant_method")
1176 .and_then(|v| v.as_str())
1177 .unwrap_or("none");
1178 let method = match method.to_lowercase().as_str() {
1179 "gptq" => QuantMethod::Gptq,
1180 "awq" => QuantMethod::Awq,
1181 "gguf" => QuantMethod::Gguf,
1182 _ => QuantMethod::None,
1183 };
1184 let bits = qc_val.get("bits").and_then(|v| v.as_u64()).unwrap_or(0) as u32;
1185 let group_size = qc_val
1186 .get("group_size")
1187 .and_then(|v| v.as_i64())
1188 .unwrap_or(128)
1189 .max(0) as usize;
1190 let desc_act = qc_val
1191 .get("desc_act")
1192 .and_then(|v| v.as_bool())
1193 .unwrap_or(false);
1194 let sym = qc_val.get("sym").and_then(|v| v.as_bool()).unwrap_or(false);
1195 if method != QuantMethod::None {
1196 return Ok(Some(QuantConfig {
1197 method,
1198 bits,
1199 group_size,
1200 desc_act,
1201 sym,
1202 }));
1203 }
1204 }
1205 }
1206 Ok(None)
1207}
1208
/// Unit tests for `validate_gptq_g_idx`.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a 4-bit GPTQ config with `group_size = 2` and the given
    /// `desc_act` flag.
    fn gptq_config(desc_act: bool) -> QuantConfig {
        QuantConfig {
            method: QuantMethod::Gptq,
            bits: 4,
            group_size: 2,
            desc_act,
            sym: true,
        }
    }

    /// `desc_act = true` without a `g_idx` tensor must be rejected.
    #[test]
    fn validate_gptq_g_idx_requires_tensor_when_desc_act_configured() {
        let err = validate_gptq_g_idx("proj", &gptq_config(true), None, 4)
            .unwrap_err()
            .to_string();

        assert!(err.contains("desc_act=true"));
        assert!(err.contains("no g_idx"));
    }

    /// A `g_idx` equal to `i / group_size` is not reported as act-order.
    #[test]
    fn validate_gptq_g_idx_accepts_trivial_non_desc_act_order() {
        let is_desc_act =
            validate_gptq_g_idx("proj", &gptq_config(false), Some(&[0, 0, 1, 1]), 4).unwrap();

        assert!(!is_desc_act);
    }

    /// A permuted `g_idx` is reported as act-order even when the config's
    /// `desc_act` flag is false.
    #[test]
    fn validate_gptq_g_idx_detects_nontrivial_act_order() {
        let is_desc_act =
            validate_gptq_g_idx("proj", &gptq_config(false), Some(&[1, 1, 0, 0]), 4).unwrap();

        assert!(is_desc_act);
    }

    /// Wrong length and out-of-range group ids each produce their own error.
    #[test]
    fn validate_gptq_g_idx_rejects_invalid_shape_and_group() {
        let short = validate_gptq_g_idx("proj", &gptq_config(false), Some(&[0, 0, 1]), 4)
            .unwrap_err()
            .to_string();
        assert!(short.contains("must match K=4"));

        // With K=4 and group_size=2 only groups 0 and 1 are valid.
        let out_of_range = validate_gptq_g_idx("proj", &gptq_config(false), Some(&[0, 0, 2, 1]), 4)
            .unwrap_err()
            .to_string();
        assert!(out_of_range.contains("outside expected group range"));
    }
}