1use std::collections::HashMap;
18use std::fs::File;
19use std::path::Path;
20
21use ferrum_kernels::backend::{Backend, BackendQuantMarlin, SrcDtype};
22use ferrum_types::{FerrumError, Result};
23use half::{bf16, f16};
24use memmap2::Mmap;
25use safetensors::{Dtype, SafeTensors};
26
27fn map_src_dtype(dtype: Dtype) -> Result<SrcDtype> {
29 match dtype {
30 Dtype::F32 => Ok(SrcDtype::F32),
31 Dtype::F16 => Ok(SrcDtype::F16),
32 Dtype::BF16 => Ok(SrcDtype::BF16),
33 other => Err(FerrumError::model(format!(
34 "dtype {other:?} not supported; Dense path expects F32/F16/BF16"
35 ))),
36 }
37}
38
39use crate::config::{QuantConfig, QuantMethod};
40use crate::dense::DenseLinear;
41use crate::gptq::GptqLinear;
42use crate::loader::WeightLoader;
43use crate::traits::Linear;
44
/// Pre-computed type and location information for one tensor inside a shard.
struct TensorMeta {
    /// Element type as reported by the safetensors view.
    dtype: Dtype,
    /// Tensor dimensions as reported by the safetensors view.
    shape: Vec<usize>,
    /// Offset of the tensor's first byte, measured from the start of the shard's mmap.
    data_start: usize,
    /// Offset one past the tensor's last byte, measured from the start of the shard's mmap.
    data_end: usize,
}
57
/// One memory-mapped safetensors file together with an index of its tensors.
struct Shard {
    /// Read-only mapping of the whole file; tensor bytes are sliced out of it.
    mmap: Mmap,
    /// Names of all tensors found in the file when it was opened.
    names: Vec<String>,
    /// Per-tensor dtype, shape and byte range inside `mmap`, keyed by tensor name.
    meta: HashMap<String, TensorMeta>,
}
66
67impl Shard {
68 fn open(path: &Path) -> Result<Self> {
69 let file = File::open(path).map_err(|e| FerrumError::io(format!("open {path:?}: {e}")))?;
70 let mmap = unsafe {
71 Mmap::map(&file).map_err(|e| FerrumError::io(format!("mmap {path:?}: {e}")))?
72 };
73 let st = SafeTensors::deserialize(&mmap)
77 .map_err(|e| FerrumError::model(format!("parse {path:?}: {e}")))?;
78 debug_assert!(mmap.len() >= 8, "safetensors smaller than 8 bytes");
83 let header_len = u64::from_le_bytes(
84 mmap[0..8]
85 .try_into()
86 .expect("8-byte header len read failed"),
87 ) as usize;
88 let data_base = 8 + header_len;
89 let names: Vec<String> = st.names().iter().map(|s| s.to_string()).collect();
90 let mut meta = HashMap::with_capacity(names.len());
91 for name in &names {
92 let view = st.tensor(name).map_err(|e| {
93 FerrumError::model(format!("tensor '{name}' missing during preindex: {e}"))
94 })?;
95 let view_data = view.data();
99 let start = view_data.as_ptr() as usize - mmap.as_ptr() as usize;
100 let end = start + view_data.len();
101 debug_assert!(start >= data_base);
102 meta.insert(
103 name.clone(),
104 TensorMeta {
105 dtype: view.dtype(),
106 shape: view.shape().to_vec(),
107 data_start: start,
108 data_end: end,
109 },
110 );
111 }
112 let _ = data_base;
113 Ok(Self { mmap, names, meta })
114 }
115
116 fn get_cached(&self, name: &str) -> Result<(&[u8], Dtype, &[usize])> {
119 let m = self
120 .meta
121 .get(name)
122 .ok_or_else(|| FerrumError::model(format!("tensor '{name}' not in shard")))?;
123 Ok((&self.mmap[m.data_start..m.data_end], m.dtype, &m.shape))
124 }
125}
126
/// Weight loader backed by one or more memory-mapped safetensors shards.
pub struct NativeSafetensorsLoader<B: Backend + BackendQuantMarlin> {
    /// Opened shards, in the order their paths were resolved when the loader was opened.
    shards: Vec<Shard>,
    /// Maps a tensor name to the position in `shards` of the shard that holds it.
    index: HashMap<String, usize>,
    /// Quantization settings found alongside the weights, if any.
    quant_config: Option<QuantConfig>,
    /// Ties the loader to backend `B` without storing a value of that type.
    _m: std::marker::PhantomData<B>,
}
138
139impl<B: Backend + BackendQuantMarlin> NativeSafetensorsLoader<B> {
140 pub fn open(model_dir: impl AsRef<Path>) -> Result<Self> {
142 let dir = model_dir.as_ref();
143
144 let shard_paths = if dir.join("model.safetensors").exists() {
145 vec![dir.join("model.safetensors")]
146 } else if dir.join("model.safetensors.index.json").exists() {
147 Self::parse_sharded_index(&dir.join("model.safetensors.index.json"))?
148 .into_iter()
149 .map(|name| dir.join(name))
150 .collect()
151 } else {
152 return Err(FerrumError::model(format!(
153 "no safetensors files in {dir:?}"
154 )));
155 };
156
157 let mut shards = Vec::with_capacity(shard_paths.len());
158 let mut index: HashMap<String, usize> = HashMap::new();
159 for (i, p) in shard_paths.iter().enumerate() {
160 let shard = Shard::open(p)?;
161 for name in &shard.names {
162 index.insert(name.clone(), i);
163 }
164 shards.push(shard);
165 }
166
167 let quant_config = load_quantize_config(dir)?;
168
169 Ok(Self {
170 shards,
171 index,
172 quant_config,
173 _m: std::marker::PhantomData,
174 })
175 }
176
177 fn parse_sharded_index(index_path: &Path) -> Result<Vec<String>> {
178 let data = std::fs::read_to_string(index_path)
179 .map_err(|e| FerrumError::io(format!("read {index_path:?}: {e}")))?;
180 let json: serde_json::Value = serde_json::from_str(&data)
181 .map_err(|e| FerrumError::serialization(format!("index json: {e}")))?;
182 let weight_map = json
183 .get("weight_map")
184 .and_then(|v| v.as_object())
185 .ok_or_else(|| FerrumError::model("index missing weight_map"))?;
186 let mut files: Vec<String> = weight_map
187 .values()
188 .filter_map(|v| v.as_str().map(|s| s.to_string()))
189 .collect();
190 files.sort();
191 files.dedup();
192 Ok(files)
193 }
194
195 fn read_f32(&self, name: &str) -> Result<(Vec<f32>, Vec<usize>)> {
197 let shard_idx = *self
198 .index
199 .get(name)
200 .ok_or_else(|| FerrumError::model(format!("tensor '{name}' not in index")))?;
201 let (data_bytes, dtype, shape) = self.shards[shard_idx].get_cached(name)?;
202 let data = dtype_to_f32(dtype, data_bytes)?;
203 Ok((data, shape.to_vec()))
204 }
205
206 fn read_bytes_typed(&self, name: &str) -> Result<(&[u8], SrcDtype, Vec<usize>)> {
210 let shard_idx = *self
211 .index
212 .get(name)
213 .ok_or_else(|| FerrumError::model(format!("tensor '{name}' not in index")))?;
214 let (data_bytes, st_dtype, shape) = self.shards[shard_idx].get_cached(name)?;
215 let dtype = map_src_dtype(st_dtype)?;
216 Ok((data_bytes, dtype, shape.to_vec()))
217 }
218
219 fn cat_rows_bytes(&self, names: &[String]) -> Result<(Vec<u8>, SrcDtype, (usize, usize))> {
223 let mut total_rows = 0usize;
224 let mut cols = 0usize;
225 let mut dtype: Option<SrcDtype> = None;
226 let mut bytes: Vec<u8> = Vec::new();
227 for n in names {
228 let (raw, d, shape) = self.read_bytes_typed(n)?;
229 if shape.len() != 2 {
230 return Err(FerrumError::model(format!(
231 "cat_rows_bytes: '{n}' is {shape:?}, need 2D"
232 )));
233 }
234 match dtype {
235 Some(prev) if prev != d => {
236 return Err(FerrumError::model(format!(
237 "cat_rows_bytes: dtype mismatch on '{n}'"
238 )))
239 }
240 _ => dtype = Some(d),
241 }
242 if cols == 0 {
243 cols = shape[1];
244 } else if cols != shape[1] {
245 return Err(FerrumError::model(format!(
246 "cat_rows_bytes: col mismatch {cols} vs {}",
247 shape[1]
248 )));
249 }
250 total_rows += shape[0];
251 bytes.extend_from_slice(raw);
252 }
253 Ok((bytes, dtype.expect("at least one part"), (total_rows, cols)))
254 }
255
256 fn read_i32(&self, name: &str) -> Result<(Vec<i32>, Vec<usize>)> {
261 let shard_idx = *self
262 .index
263 .get(name)
264 .ok_or_else(|| FerrumError::model(format!("tensor '{name}' not in index")))?;
265 let (bytes, dtype, shape) = self.shards[shard_idx].get_cached(name)?;
266 if dtype != Dtype::I32 {
267 return Err(FerrumError::model(format!(
268 "'{name}': expected I32, got {:?}",
269 dtype
270 )));
271 }
272 debug_assert_eq!(bytes.len() % 4, 0);
273 let count = bytes.len() / 4;
274 let mut out = Vec::<i32>::with_capacity(count);
275 unsafe {
280 std::ptr::copy_nonoverlapping(bytes.as_ptr(), out.as_mut_ptr() as *mut u8, bytes.len());
281 out.set_len(count);
282 }
283 Ok((out, shape.to_vec()))
284 }
285
    /// Report whether a tensor called `name` is present in the loader's name index.
    fn has(&self, name: &str) -> bool {
        self.index.contains_key(name)
    }
289
290 pub fn read_gptq_raw(
298 &self,
299 name: &str,
300 ) -> Result<(Vec<i32>, Vec<f32>, Vec<i32>, Option<Vec<i32>>, usize, usize)> {
301 let (qweight, qw_shape) = self.read_i32(&format!("{name}.qweight"))?;
302 let (scales, _) = self.read_f32(&format!("{name}.scales"))?;
303 let (qzeros, _) = self.read_i32(&format!("{name}.qzeros"))?;
304 let g_idx = if self.has(&format!("{name}.g_idx")) {
305 Some(self.read_i32(&format!("{name}.g_idx"))?.0)
306 } else {
307 None
308 };
309 if qw_shape.len() != 2 {
310 return Err(FerrumError::model(format!(
311 "'{name}.qweight' expected 2D, got {qw_shape:?}"
312 )));
313 }
314 let k = qw_shape[0] * 8;
315 let n = qw_shape[1];
316 Ok((qweight, scales, qzeros, g_idx, k, n))
317 }
318
    /// Borrow the quantization config stored in the loader, or `None` when the
    /// loader holds none.
    pub fn quant_config_ref(&self) -> Option<&crate::config::QuantConfig> {
        self.quant_config.as_ref()
    }
322
323 pub fn load_stacked_gptq_experts(
341 &self,
342 expert_prefix_fmt: &str,
343 num_experts: usize,
344 proj_names: &[&str],
345 ) -> Result<(
346 std::sync::Arc<dyn ferrum_kernels::MarlinExpertStack<B>>,
347 usize,
348 usize,
349 )> {
350 let qcfg = self.quant_config.as_ref().ok_or_else(|| {
351 FerrumError::model(
352 "load_stacked_gptq_experts requires quantize_config.json".to_string(),
353 )
354 })?;
355 if qcfg.method != QuantMethod::Gptq {
356 return Err(FerrumError::model(format!(
357 "stacked GPTQ load but quant_method={:?}",
358 qcfg.method
359 )));
360 }
361
362 let mut qw_rows = 0usize;
363 let mut sc_rows = 0usize;
364 let mut qz_rows = 0usize;
365 let mut n_per_expert = 0usize;
366 let mut n_per_expert_scales = 0usize;
367 let mut n_per_expert_zeros = 0usize;
368 let mut k_shared = 0usize;
369 let mut g_idx_first: Option<Vec<i32>> = None;
370
371 let total_pairs = num_experts * proj_names.len();
373 let mut qw_parts: Vec<(Vec<i32>, usize)> = Vec::with_capacity(total_pairs); let mut sc_parts: Vec<(Vec<f32>, usize)> = Vec::with_capacity(total_pairs);
375 let mut qz_parts: Vec<(Vec<i32>, usize)> = Vec::with_capacity(total_pairs);
376
377 for e in 0..num_experts {
378 let prefix = expert_prefix_fmt.replace("{e}", &e.to_string());
379 let mut e_n = 0usize;
380 let mut e_n_scales = 0usize;
381 let mut e_n_zeros = 0usize;
382 for proj in proj_names {
383 let name = format!("{prefix}{proj}");
384 let (qw, qw_sh) = self.read_i32(&format!("{name}.qweight"))?;
385 let (sc, sc_sh) = self.read_f32(&format!("{name}.scales"))?;
386 let (qz, qz_sh) = self.read_i32(&format!("{name}.qzeros"))?;
387 if qw_sh.len() != 2 || sc_sh.len() != 2 || qz_sh.len() != 2 {
388 return Err(FerrumError::model(format!(
389 "stacked GPTQ '{name}': expected 2D, got qw {qw_sh:?} sc {sc_sh:?} qz {qz_sh:?}"
390 )));
391 }
392 if qw_rows == 0 {
393 qw_rows = qw_sh[0];
394 sc_rows = sc_sh[0];
395 qz_rows = qz_sh[0];
396 k_shared = qw_sh[0] * 8;
397 } else if qw_sh[0] != qw_rows || sc_sh[0] != sc_rows || qz_sh[0] != qz_rows {
398 return Err(FerrumError::model(format!(
399 "stacked GPTQ '{name}': row mismatch qw {} sc {} qz {} vs ref {qw_rows}/{sc_rows}/{qz_rows}",
400 qw_sh[0], sc_sh[0], qz_sh[0]
401 )));
402 }
403 e_n += qw_sh[1];
404 e_n_scales += sc_sh[1];
405 e_n_zeros += qz_sh[1];
406 qw_parts.push((qw, qw_sh[1]));
407 sc_parts.push((sc, sc_sh[1]));
408 qz_parts.push((qz, qz_sh[1]));
409
410 let g_key = format!("{name}.g_idx");
416 if self.has(&g_key) {
417 let (gx, _) = self.read_i32(&g_key)?;
418 match &g_idx_first {
419 None => g_idx_first = Some(gx),
420 Some(prev) => {
421 if prev.len() != gx.len() || prev.iter().zip(&gx).any(|(a, b)| a != b) {
422 return Err(FerrumError::model(format!(
423 "stacked GPTQ '{name}': g_idx mismatch with first \
424 expert — Marlin requires identical act-order across \
425 experts in the same stacked tile"
426 )));
427 }
428 }
429 }
430 }
431 }
432 if e == 0 {
433 n_per_expert = e_n;
434 n_per_expert_scales = e_n_scales;
435 n_per_expert_zeros = e_n_zeros;
436 } else if e_n != n_per_expert
437 || e_n_scales != n_per_expert_scales
438 || e_n_zeros != n_per_expert_zeros
439 {
440 return Err(FerrumError::model(format!(
441 "stacked GPTQ expert {e} N mismatch: qw {e_n} sc {e_n_scales} qz {e_n_zeros} vs expert 0 {n_per_expert}/{n_per_expert_scales}/{n_per_expert_zeros}"
442 )));
443 }
444 }
445
446 let proj_count = proj_names.len();
447 let pairs_per_expert = proj_count;
448 debug_assert_eq!(total_pairs, num_experts * pairs_per_expert);
449
450 let mut per_expert_qw: Vec<Vec<i32>> = Vec::with_capacity(num_experts);
463 let mut per_expert_sc: Vec<Vec<f32>> = Vec::with_capacity(num_experts);
464 let mut per_expert_qz: Vec<Vec<i32>> = Vec::with_capacity(num_experts);
465 for e in 0..num_experts {
466 let mut qw: Vec<i32> = Vec::with_capacity(qw_rows * n_per_expert);
467 let mut sc: Vec<f32> = Vec::with_capacity(sc_rows * n_per_expert_scales);
468 let mut qz: Vec<i32> = Vec::with_capacity(qz_rows * n_per_expert_zeros);
469 for r in 0..qw_rows {
470 for j in 0..pairs_per_expert {
471 let pair_idx = e * pairs_per_expert + j;
472 let (data, cols) = &qw_parts[pair_idx];
473 qw.extend_from_slice(&data[r * cols..(r + 1) * cols]);
474 }
475 }
476 for r in 0..sc_rows {
477 for j in 0..pairs_per_expert {
478 let pair_idx = e * pairs_per_expert + j;
479 let (data, cols) = &sc_parts[pair_idx];
480 sc.extend_from_slice(&data[r * cols..(r + 1) * cols]);
481 }
482 }
483 for r in 0..qz_rows {
484 for j in 0..pairs_per_expert {
485 let pair_idx = e * pairs_per_expert + j;
486 let (data, cols) = &qz_parts[pair_idx];
487 qz.extend_from_slice(&data[r * cols..(r + 1) * cols]);
488 }
489 }
490 per_expert_qw.push(qw);
491 per_expert_sc.push(sc);
492 per_expert_qz.push(qz);
493 }
494
495 drop(qw_parts);
497 drop(sc_parts);
498 drop(qz_parts);
499
500 let qw_refs: Vec<&[i32]> = per_expert_qw.iter().map(|v| v.as_slice()).collect();
501 let sc_refs: Vec<&[f32]> = per_expert_sc.iter().map(|v| v.as_slice()).collect();
502 let qz_refs: Vec<&[i32]> = per_expert_qz.iter().map(|v| v.as_slice()).collect();
503
504 let store = B::load_gptq_stacked(
505 &qw_refs,
506 &sc_refs,
507 &qz_refs,
508 g_idx_first.as_deref(),
509 qcfg.bits,
510 qcfg.group_size,
511 k_shared,
512 n_per_expert,
513 )?;
514 Ok((store, n_per_expert, k_shared))
515 }
516}
517
518impl<B: Backend + BackendQuantMarlin> WeightLoader<B> for NativeSafetensorsLoader<B> {
519 fn load_tensor(&self, name: &str) -> Result<B::Buffer> {
520 let (raw, src_dtype, _) = self.read_bytes_typed(name)?;
525 Ok(B::from_weight_bytes(raw, src_dtype))
526 }
527
528 fn load_linear(&self, name: &str) -> Result<Box<dyn Linear<B>>> {
529 let qw_key = format!("{name}.qweight");
531 if self.has(&qw_key) {
532 return self.load_gptq_linear(name);
533 }
534 if let Some(prefix) = name.strip_suffix("qkv_proj") {
538 let parts = [
539 format!("{prefix}q_proj"),
540 format!("{prefix}k_proj"),
541 format!("{prefix}v_proj"),
542 ];
543 if parts.iter().all(|p| self.has(&format!("{p}.qweight"))) {
544 return self.load_gptq_linear_fused(&parts);
545 }
546 }
547 if let Some(prefix) = name.strip_suffix("gate_up_proj") {
548 let parts = [format!("{prefix}gate_proj"), format!("{prefix}up_proj")];
549 if parts.iter().all(|p| self.has(&format!("{p}.qweight"))) {
550 return self.load_gptq_linear_fused(&parts);
551 }
552 }
553
554 let direct = format!("{name}.weight");
557 if self.has(&direct) {
558 let (raw, src_dtype, shape) = self.read_bytes_typed(&direct)?;
559 if shape.len() != 2 {
560 return Err(FerrumError::model(format!(
561 "linear '{name}': expected 2D weight, got {shape:?}"
562 )));
563 }
564 let weight = B::from_weight_bytes(raw, src_dtype);
565 return Ok(Box::new(DenseLinear::<B>::from_buffer(
566 weight, shape[0], shape[1],
567 )));
568 }
569
570 if let Some(prefix) = name.strip_suffix("qkv_proj") {
575 let parts = [
576 format!("{prefix}q_proj.weight"),
577 format!("{prefix}k_proj.weight"),
578 format!("{prefix}v_proj.weight"),
579 ];
580 if parts.iter().all(|p| self.has(p)) {
581 let (bytes, dtype, (rows, cols)) = self.cat_rows_bytes(&parts)?;
582 let weight = B::from_weight_bytes(&bytes, dtype);
583 return Ok(Box::new(DenseLinear::<B>::from_buffer(weight, rows, cols)));
584 }
585 }
586 if let Some(prefix) = name.strip_suffix("gate_up_proj") {
587 let parts = [
588 format!("{prefix}gate_proj.weight"),
589 format!("{prefix}up_proj.weight"),
590 ];
591 if parts.iter().all(|p| self.has(p)) {
592 let (bytes, dtype, (rows, cols)) = self.cat_rows_bytes(&parts)?;
593 let weight = B::from_weight_bytes(&bytes, dtype);
594 return Ok(Box::new(DenseLinear::<B>::from_buffer(weight, rows, cols)));
595 }
596 }
597
598 Err(FerrumError::model(format!(
599 "could not load linear '{name}' — no direct `.weight`, no split components"
600 )))
601 }
602
    /// Report whether a tensor called `name` exists; forwards to the loader's `has`.
    fn has_tensor(&self, name: &str) -> bool {
        self.has(name)
    }
606
    /// Borrow the quantization config stored in the loader, or `None` when the
    /// loader holds none.
    fn quant_config(&self) -> Option<&QuantConfig> {
        self.quant_config.as_ref()
    }
610}
611
612impl<B: Backend + BackendQuantMarlin> NativeSafetensorsLoader<B> {
613 fn load_gptq_linear(&self, name: &str) -> Result<Box<dyn Linear<B>>> {
618 let qcfg = self.quant_config.as_ref().ok_or_else(|| {
619 FerrumError::model(format!(
620 "'{name}.qweight' present but no quantize_config.json — \
621 can't determine bits/group_size"
622 ))
623 })?;
624 if qcfg.method != QuantMethod::Gptq {
625 return Err(FerrumError::model(format!(
626 "'{name}.qweight' present but quant_method={:?} (expected GPTQ)",
627 qcfg.method
628 )));
629 }
630
631 let (qweight, qw_shape) = self.read_i32(&format!("{name}.qweight"))?;
632 let (scales_f32, sc_shape) = self.read_f32(&format!("{name}.scales"))?;
633 let (qzeros, _qz_shape) = self.read_i32(&format!("{name}.qzeros"))?;
634 let g_idx = if self.has(&format!("{name}.g_idx")) {
635 Some(self.read_i32(&format!("{name}.g_idx"))?.0)
636 } else {
637 None
638 };
639
640 if qw_shape.len() != 2 {
643 return Err(FerrumError::model(format!(
644 "'{name}.qweight' expected 2D, got {qw_shape:?}"
645 )));
646 }
647 let in_features = qw_shape[0] * 8;
648 let out_features = qw_shape[1];
649
650 let is_desc_act = g_idx.as_ref().map_or(false, |gx| {
653 !gx.iter()
654 .enumerate()
655 .all(|(i, &g)| g == (i as i32) / qcfg.group_size as i32)
656 });
657
658 #[cfg(not(feature = "cuda"))]
663 if is_desc_act {
664 let dequant_f32 = dequantize_gptq_with_g_idx(
665 &qweight,
666 &scales_f32,
667 &qzeros,
668 g_idx.as_ref().expect("desc_act=true requires g_idx"),
669 qcfg.group_size,
670 in_features,
671 out_features,
672 );
673 let mut linear =
674 crate::dense::DenseLinear::<B>::from_rows(&dequant_f32, out_features, in_features);
675 let bias_key = format!("{name}.bias");
676 if self.has(&bias_key) {
677 let (bias, _) = self.read_f32(&bias_key)?;
678 linear = linear.with_bias(B::from_slice(&bias));
679 }
680 tracing::info!(
681 "GPTQ load (desc_act dequant→DenseLinear, non-cuda): name={name} K={in_features} N={out_features}"
682 );
683 return Ok(Box::new(linear));
684 }
685 #[cfg(feature = "cuda")]
686 let _ = is_desc_act; if sc_shape.len() != 2 || sc_shape[1] != out_features {
688 return Err(FerrumError::model(format!(
689 "'{name}.scales' {sc_shape:?} incompatible with qweight {qw_shape:?}"
690 )));
691 }
692
693 let bias_key = format!("{name}.bias");
697 let bias_vec = if self.has(&bias_key) {
698 let (bias, bias_shape) = self.read_f32(&bias_key)?;
699 if bias_shape != [out_features] {
700 return Err(FerrumError::model(format!(
701 "'{bias_key}' {bias_shape:?} != [{out_features}]"
702 )));
703 }
704 Some(bias)
705 } else {
706 None
707 };
708
709 let linear = GptqLinear::<B>::from_raw(
710 &qweight,
711 &scales_f32,
712 &qzeros,
713 g_idx.as_deref(),
714 bias_vec.as_deref(),
715 qcfg.bits,
716 qcfg.group_size,
717 in_features,
718 out_features,
719 )?;
720 Ok(Box::new(linear))
721 }
722
723 fn load_gptq_linear_fused(&self, parts: &[String]) -> Result<Box<dyn Linear<B>>> {
736 let qcfg = self.quant_config.as_ref().ok_or_else(|| {
737 FerrumError::model("GPTQ fusion requires quantize_config.json".to_string())
738 })?;
739 if qcfg.method != QuantMethod::Gptq {
740 return Err(FerrumError::model(format!(
741 "GPTQ fusion but quant_method={:?}",
742 qcfg.method
743 )));
744 }
745
746 let mut qw_acc: Vec<i32> = Vec::new();
747 let mut sc_acc: Vec<f32> = Vec::new();
748 let mut qz_acc: Vec<i32> = Vec::new();
749 let mut qw_rows = 0usize;
750 let mut sc_rows = 0usize;
751 let mut qz_rows = 0usize;
752 let mut total_n = 0usize;
753 let mut total_n_scales = 0usize;
754 let mut total_n_zeros = 0usize;
755 let mut g_idx: Option<Vec<i32>> = None;
756 let mut qw_parts: Vec<(Vec<i32>, usize, usize)> = Vec::new(); let mut sc_parts: Vec<(Vec<f32>, usize, usize)> = Vec::new();
759 let mut qz_parts: Vec<(Vec<i32>, usize, usize)> = Vec::new();
760
761 for p in parts {
762 let (qw, qw_sh) = self.read_i32(&format!("{p}.qweight"))?;
763 let (sc, sc_sh) = self.read_f32(&format!("{p}.scales"))?;
764 let (qz, qz_sh) = self.read_i32(&format!("{p}.qzeros"))?;
765 if qw_sh.len() != 2 || sc_sh.len() != 2 || qz_sh.len() != 2 {
766 return Err(FerrumError::model(format!(
767 "GPTQ fusion '{p}': expected 2D tensors, got qw {qw_sh:?} sc {sc_sh:?} qz {qz_sh:?}"
768 )));
769 }
770 if qw_rows == 0 {
771 qw_rows = qw_sh[0];
772 sc_rows = sc_sh[0];
773 qz_rows = qz_sh[0];
774 } else if qw_sh[0] != qw_rows || sc_sh[0] != sc_rows || qz_sh[0] != qz_rows {
775 return Err(FerrumError::model(format!(
776 "GPTQ fusion row mismatch on '{p}'"
777 )));
778 }
779 total_n += qw_sh[1];
780 total_n_scales += sc_sh[1];
781 total_n_zeros += qz_sh[1];
782 qw_parts.push((qw, qw_sh[0], qw_sh[1]));
783 sc_parts.push((sc, sc_sh[0], sc_sh[1]));
784 qz_parts.push((qz, qz_sh[0], qz_sh[1]));
785
786 if g_idx.is_none() && self.has(&format!("{p}.g_idx")) {
788 g_idx = Some(self.read_i32(&format!("{p}.g_idx"))?.0);
789 }
790 }
791
792 qw_acc.reserve(qw_rows * total_n);
794 for r in 0..qw_rows {
795 for (part, _rows, cols) in &qw_parts {
796 qw_acc.extend_from_slice(&part[r * cols..r * cols + cols]);
797 }
798 }
799 sc_acc.reserve(sc_rows * total_n_scales);
800 for r in 0..sc_rows {
801 for (part, _rows, cols) in &sc_parts {
802 sc_acc.extend_from_slice(&part[r * cols..r * cols + cols]);
803 }
804 }
805 qz_acc.reserve(qz_rows * total_n_zeros);
806 for r in 0..qz_rows {
807 for (part, _rows, cols) in &qz_parts {
808 qz_acc.extend_from_slice(&part[r * cols..r * cols + cols]);
809 }
810 }
811
812 let in_features = qw_rows * 8;
813 let out_features = total_n;
814
815 let is_desc_act = g_idx.as_ref().map_or(false, |gx| {
818 !gx.iter()
819 .enumerate()
820 .all(|(i, &g)| g == (i as i32) / qcfg.group_size as i32)
821 });
822 #[cfg(not(feature = "cuda"))]
824 if is_desc_act {
825 let dequant_f32 = dequantize_gptq_with_g_idx(
826 &qw_acc,
827 &sc_acc,
828 &qz_acc,
829 g_idx.as_ref().expect("desc_act=true requires g_idx"),
830 qcfg.group_size,
831 in_features,
832 out_features,
833 );
834 let mut linear =
835 crate::dense::DenseLinear::<B>::from_rows(&dequant_f32, out_features, in_features);
836 let mut bias_acc: Vec<f32> = Vec::new();
837 let mut any_bias = false;
838 for p in parts {
839 let bk = format!("{p}.bias");
840 if self.has(&bk) {
841 any_bias = true;
842 bias_acc.extend_from_slice(&self.read_f32(&bk)?.0);
843 } else if any_bias {
844 return Err(FerrumError::model(format!(
845 "GPTQ fusion bias mix: '{p}' has no bias but earlier part did"
846 )));
847 }
848 }
849 if any_bias {
850 linear = linear.with_bias(B::from_slice(&bias_acc));
851 }
852 tracing::info!(
853 "GPTQ fused load (desc_act dequant→DenseLinear, non-cuda): K={in_features} N={out_features} parts={}",
854 parts.len()
855 );
856 return Ok(Box::new(linear));
857 }
858 #[cfg(feature = "cuda")]
859 let _ = is_desc_act;
860
861 let bias_keys: Vec<String> = parts.iter().map(|p| format!("{p}.bias")).collect();
865 let any = bias_keys.iter().any(|k| self.has(k));
866 let all = bias_keys.iter().all(|k| self.has(k));
867 if any && !all {
868 return Err(FerrumError::model(
869 "GPTQ fusion: inconsistent bias presence across parts".to_string(),
870 ));
871 }
872 let fused_bias = if all {
873 let mut fused: Vec<f32> = Vec::with_capacity(out_features);
874 for k in &bias_keys {
875 let (b, _) = self.read_f32(k)?;
876 fused.extend_from_slice(&b);
877 }
878 if fused.len() != out_features {
879 return Err(FerrumError::model(format!(
880 "GPTQ fusion bias length {} != out_features {out_features}",
881 fused.len()
882 )));
883 }
884 Some(fused)
885 } else {
886 None
887 };
888
889 let linear = GptqLinear::<B>::from_raw(
890 &qw_acc,
891 &sc_acc,
892 &qz_acc,
893 g_idx.as_deref(),
894 fused_bias.as_deref(),
895 qcfg.bits,
896 qcfg.group_size,
897 in_features,
898 out_features,
899 )?;
900
901 Ok(Box::new(linear))
902 }
903
904 #[allow(dead_code)]
908 fn cat_rows(&self, names: &[String]) -> Result<(usize, usize, Vec<f32>)> {
909 let mut total_rows = 0usize;
910 let mut cols = 0usize;
911 let mut out: Vec<f32> = Vec::new();
912 for n in names {
913 let (data, shape) = self.read_f32(n)?;
914 if shape.len() != 2 {
915 return Err(FerrumError::model(format!(
916 "cat_rows: '{n}' is {shape:?}, need 2D"
917 )));
918 }
919 if cols == 0 {
920 cols = shape[1];
921 } else if cols != shape[1] {
922 return Err(FerrumError::model(format!(
923 "cat_rows: col mismatch {cols} vs {}",
924 shape[1]
925 )));
926 }
927 total_rows += shape[0];
928 out.extend_from_slice(&data);
929 }
930 Ok((total_rows, cols, out))
931 }
932}
933
/// Expand packed 4-bit GPTQ weights into a dense `f32` matrix using an explicit
/// per-input-row group index.
///
/// Inputs, as indexed by this function:
/// - `qweight`: `k / 8` rows × `n` columns, each word holding eight 4-bit values
///   for eight consecutive input rows (lowest nibble first).
/// - `scales`: one row of `n` values per group.
/// - `qzeros`: one row of `n / 8` words per group, eight 4-bit zero points per
///   word (lowest nibble first); 1 is added to every stored zero point.
/// - `g_idx`: group of each of the `k` input rows.
///
/// Returns a vector of `n * k` values where element `col * k + row` is
/// `(q - zero) * scale`. Input rows beyond `(k / 8) * 8` stay zero. `_group_size`
/// is not used.
///
/// # Panics
/// Panics on out-of-range indexing when the slices are shorter than described.
#[cfg(not(feature = "cuda"))]
fn dequantize_gptq_with_g_idx(
    qweight: &[i32],
    scales: &[f32],
    qzeros: &[i32],
    g_idx: &[i32],
    _group_size: usize,
    k: usize,
    n: usize,
) -> Vec<f32> {
    const VALUES_PER_WORD: usize = 8;
    const NIBBLE_BITS: usize = 4;
    const NIBBLE_MASK: u32 = 0xF;

    debug_assert_eq!(g_idx.len(), k);

    let zero_words_per_group = n / VALUES_PER_WORD;
    let mut dense = vec![0.0f32; n * k];
    for word_row in 0..k / VALUES_PER_WORD {
        let first_in_row = word_row * VALUES_PER_WORD;
        for out_col in 0..n {
            let word = qweight[word_row * n + out_col] as u32;
            let zero_word_col = out_col / VALUES_PER_WORD;
            let zero_shift = (out_col % VALUES_PER_WORD) * NIBBLE_BITS;
            for nibble in 0..VALUES_PER_WORD {
                let in_row = first_in_row + nibble;
                let quant = ((word >> (nibble * NIBBLE_BITS)) & NIBBLE_MASK) as i32;
                let group = g_idx[in_row] as usize;
                let scale = scales[group * n + out_col];
                let zero_word = qzeros[group * zero_words_per_group + zero_word_col] as u32;
                let zero_point = ((zero_word >> zero_shift) & NIBBLE_MASK) as i32 + 1;
                dense[out_col * k + in_row] = (quant - zero_point) as f32 * scale;
            }
        }
    }
    dense
}
981
982fn dtype_to_f32(dtype: Dtype, raw: &[u8]) -> Result<Vec<f32>> {
983 match dtype {
984 Dtype::F32 => {
985 debug_assert_eq!(raw.len() % 4, 0);
990 let n = raw.len() / 4;
991 let mut out = Vec::<f32>::with_capacity(n);
992 unsafe {
993 std::ptr::copy_nonoverlapping(raw.as_ptr(), out.as_mut_ptr() as *mut u8, raw.len());
994 out.set_len(n);
995 }
996 Ok(out)
997 }
998 Dtype::F16 => {
999 debug_assert_eq!(raw.len() % 2, 0);
1000 let n = raw.len() / 2;
1001 let mut tmp = Vec::<f16>::with_capacity(n);
1004 unsafe {
1005 std::ptr::copy_nonoverlapping(raw.as_ptr(), tmp.as_mut_ptr() as *mut u8, raw.len());
1006 tmp.set_len(n);
1007 }
1008 let mut out = Vec::with_capacity(n);
1009 for h in &tmp {
1010 out.push(h.to_f32());
1011 }
1012 Ok(out)
1013 }
1014 Dtype::BF16 => {
1015 debug_assert_eq!(raw.len() % 2, 0);
1016 let n = raw.len() / 2;
1017 let mut tmp = Vec::<bf16>::with_capacity(n);
1018 unsafe {
1019 std::ptr::copy_nonoverlapping(raw.as_ptr(), tmp.as_mut_ptr() as *mut u8, raw.len());
1020 tmp.set_len(n);
1021 }
1022 let mut out = Vec::with_capacity(n);
1023 for h in &tmp {
1024 out.push(h.to_f32());
1025 }
1026 Ok(out)
1027 }
1028 other => Err(FerrumError::model(format!(
1029 "dtype {other:?} not supported by NativeSafetensorsLoader's f32 path; \
1030 use a format-specific loader (GPTQ / AWQ / GGUF)",
1031 ))),
1032 }
1033}
1034
1035fn load_quantize_config(dir: &Path) -> Result<Option<QuantConfig>> {
1036 let p = dir.join("quantize_config.json");
1038 if p.exists() {
1039 let data =
1040 std::fs::read_to_string(&p).map_err(|e| FerrumError::io(format!("read {p:?}: {e}")))?;
1041 let qc: QuantConfig = serde_json::from_str(&data)
1042 .map_err(|e| FerrumError::serialization(format!("parse quantize_config.json: {e}")))?;
1043 return Ok(Some(qc));
1044 }
1045 let cfg = dir.join("config.json");
1048 if cfg.exists() {
1049 let data = std::fs::read_to_string(&cfg)
1050 .map_err(|e| FerrumError::io(format!("read {cfg:?}: {e}")))?;
1051 let root: serde_json::Value = serde_json::from_str(&data)
1052 .map_err(|e| FerrumError::serialization(format!("parse config.json: {e}")))?;
1053 if let Some(qc_val) = root.get("quantization_config") {
1054 let method = qc_val
1056 .get("quant_method")
1057 .and_then(|v| v.as_str())
1058 .unwrap_or("none");
1059 let method = match method.to_lowercase().as_str() {
1060 "gptq" => QuantMethod::Gptq,
1061 "awq" => QuantMethod::Awq,
1062 "gguf" => QuantMethod::Gguf,
1063 _ => QuantMethod::None,
1064 };
1065 let bits = qc_val.get("bits").and_then(|v| v.as_u64()).unwrap_or(0) as u32;
1066 let group_size = qc_val
1067 .get("group_size")
1068 .and_then(|v| v.as_i64())
1069 .unwrap_or(128)
1070 .max(0) as usize;
1071 let desc_act = qc_val
1072 .get("desc_act")
1073 .and_then(|v| v.as_bool())
1074 .unwrap_or(false);
1075 let sym = qc_val.get("sym").and_then(|v| v.as_bool()).unwrap_or(false);
1076 if method != QuantMethod::None {
1077 return Ok(Some(QuantConfig {
1078 method,
1079 bits,
1080 group_size,
1081 desc_act,
1082 sym,
1083 }));
1084 }
1085 }
1086 }
1087 Ok(None)
1088}