Skip to main content

ferrum_quantization/
native_safetensors.rs

//! Native safetensors `WeightLoader<B>` — mmap + `safetensors` crate, no
//! candle dependency on the LLM hot path.
//!
//! What this owns:
//!   - Discovering `model.safetensors` vs sharded `model.safetensors.index.json`.
//!   - Mmapping each shard file.
//!   - Per-tensor lookup: returns shape + dtype + a byte slice into the mmap.
//!   - f32 materialisation for Dense weights (bf16 / f16 / f32 accepted).
//!   - The Qwen3 / Llama fusion trick: `qkv_proj` / `gate_up_proj` synthesised
//!     on the fly from split `q_proj`+`k_proj`+`v_proj` etc.
//!   - GPTQ tensors (`qweight` / `scales` / `qzeros` / optional `g_idx`): read
//!     host-side and handed to the backend (`load_gptq_linear`,
//!     `load_stacked_gptq_experts`).
//!
//! What it deliberately doesn't do:
//!   - AWQ / GGUF packed weights. Those need `B::from_slice_i32` /
//!     `B::from_slice_f16` which aren't on the Backend trait yet. A dedicated
//!     loader per quant format lands in Phase E.
16
17use std::collections::HashMap;
18use std::fs::File;
19use std::path::Path;
20
21use ferrum_kernels::backend::{Backend, BackendQuantMarlin, SrcDtype};
22use ferrum_types::{FerrumError, Result};
23use half::{bf16, f16};
24use memmap2::Mmap;
25use safetensors::{Dtype, SafeTensors};
26
27/// Map a safetensors Dtype to ferrum's SrcDtype.
28fn map_src_dtype(dtype: Dtype) -> Result<SrcDtype> {
29    match dtype {
30        Dtype::F32 => Ok(SrcDtype::F32),
31        Dtype::F16 => Ok(SrcDtype::F16),
32        Dtype::BF16 => Ok(SrcDtype::BF16),
33        other => Err(FerrumError::model(format!(
34            "dtype {other:?} not supported; Dense path expects F32/F16/BF16"
35        ))),
36    }
37}
38
39use crate::config::{QuantConfig, QuantMethod};
40use crate::dense::DenseLinear;
41use crate::gptq::GptqLinear;
42use crate::loader::WeightLoader;
43use crate::traits::Linear;
44
45/// Tensor metadata extracted ONCE from the safetensors header at open time.
46/// Avoids the per-tensor `SafeTensors::deserialize(&mmap)` re-parse that
47/// previously dominated cold-load on stacked-MoE (>= 18 000 calls per
48/// model, ~30 ms each on a 7 000-tensor header → 9+ minutes of header
49/// re-parse alone).
50struct TensorMeta {
51    dtype: Dtype,
52    shape: Vec<usize>,
53    /// Byte range in the shard's mmap that holds the raw tensor data.
54    data_start: usize,
55    data_end: usize,
56}
57
58/// A single shard file: mmap + name→TensorMeta cache.
59struct Shard {
60    mmap: Mmap,
61    names: Vec<String>,
62    /// Pre-extracted tensor metadata. Looked up by name; no header
63    /// re-parse on every read.
64    meta: HashMap<String, TensorMeta>,
65}
66
67impl Shard {
68    fn open(path: &Path) -> Result<Self> {
69        let file = File::open(path).map_err(|e| FerrumError::io(format!("open {path:?}: {e}")))?;
70        let mmap = unsafe {
71            Mmap::map(&file).map_err(|e| FerrumError::io(format!("mmap {path:?}: {e}")))?
72        };
73        // Parse the header ONCE and cache (offset, len, dtype, shape) for
74        // every tensor — re-deriving the data slice from the cache is a
75        // simple `&mmap[start..end]` rather than a full header re-parse.
76        let st = SafeTensors::deserialize(&mmap)
77            .map_err(|e| FerrumError::model(format!("parse {path:?}: {e}")))?;
78        // SafeTensors stores: 8-byte little-endian header_len + header_json
79        // + data_blob. The TensorView's `data()` returns a slice into the
80        // data_blob region. We compute the data_blob base by reading the
81        // 8-byte header_len.
82        debug_assert!(mmap.len() >= 8, "safetensors smaller than 8 bytes");
83        let header_len = u64::from_le_bytes(
84            mmap[0..8]
85                .try_into()
86                .expect("8-byte header len read failed"),
87        ) as usize;
88        let data_base = 8 + header_len;
89        let names: Vec<String> = st.names().iter().map(|s| s.to_string()).collect();
90        let mut meta = HashMap::with_capacity(names.len());
91        for name in &names {
92            let view = st.tensor(name).map_err(|e| {
93                FerrumError::model(format!("tensor '{name}' missing during preindex: {e}"))
94            })?;
95            // TensorView::data() is &[u8] into the mmap; we recompute its
96            // [start, end) byte range relative to the mmap base via
97            // pointer arithmetic.
98            let view_data = view.data();
99            let start = view_data.as_ptr() as usize - mmap.as_ptr() as usize;
100            let end = start + view_data.len();
101            debug_assert!(start >= data_base);
102            meta.insert(
103                name.clone(),
104                TensorMeta {
105                    dtype: view.dtype(),
106                    shape: view.shape().to_vec(),
107                    data_start: start,
108                    data_end: end,
109                },
110            );
111        }
112        let _ = data_base;
113        Ok(Self { mmap, names, meta })
114    }
115
116    /// Returns (data_bytes, dtype, shape) for the named tensor without
117    /// re-parsing the safetensors header.
118    fn get_cached(&self, name: &str) -> Result<(&[u8], Dtype, &[usize])> {
119        let m = self
120            .meta
121            .get(name)
122            .ok_or_else(|| FerrumError::model(format!("tensor '{name}' not in shard")))?;
123        Ok((&self.mmap[m.data_start..m.data_end], m.dtype, &m.shape))
124    }
125}
126
127/// Native safetensors loader. Generic over `Backend` so every tensor is
128/// materialised directly into backend-native buffers.
129pub struct NativeSafetensorsLoader<B: Backend + BackendQuantMarlin> {
130    /// All shards keyed by file; each tensor's name maps to its shard here.
131    shards: Vec<Shard>,
132    /// Name → shard index. Populated once at construction.
133    index: HashMap<String, usize>,
134    /// Optional `quantize_config.json` contents.
135    quant_config: Option<QuantConfig>,
136    _m: std::marker::PhantomData<B>,
137}
138
139impl<B: Backend + BackendQuantMarlin> NativeSafetensorsLoader<B> {
140    /// Discover shards under `model_dir` and build the name → shard index.
141    pub fn open(model_dir: impl AsRef<Path>) -> Result<Self> {
142        let dir = model_dir.as_ref();
143
144        let shard_paths = if dir.join("model.safetensors").exists() {
145            vec![dir.join("model.safetensors")]
146        } else if dir.join("model.safetensors.index.json").exists() {
147            Self::parse_sharded_index(&dir.join("model.safetensors.index.json"))?
148                .into_iter()
149                .map(|name| dir.join(name))
150                .collect()
151        } else {
152            return Err(FerrumError::model(format!(
153                "no safetensors files in {dir:?}"
154            )));
155        };
156
157        let mut shards = Vec::with_capacity(shard_paths.len());
158        let mut index: HashMap<String, usize> = HashMap::new();
159        for (i, p) in shard_paths.iter().enumerate() {
160            let shard = Shard::open(p)?;
161            for name in &shard.names {
162                index.insert(name.clone(), i);
163            }
164            shards.push(shard);
165        }
166
167        let quant_config = load_quantize_config(dir)?;
168
169        Ok(Self {
170            shards,
171            index,
172            quant_config,
173            _m: std::marker::PhantomData,
174        })
175    }
176
177    fn parse_sharded_index(index_path: &Path) -> Result<Vec<String>> {
178        let data = std::fs::read_to_string(index_path)
179            .map_err(|e| FerrumError::io(format!("read {index_path:?}: {e}")))?;
180        let json: serde_json::Value = serde_json::from_str(&data)
181            .map_err(|e| FerrumError::serialization(format!("index json: {e}")))?;
182        let weight_map = json
183            .get("weight_map")
184            .and_then(|v| v.as_object())
185            .ok_or_else(|| FerrumError::model("index missing weight_map"))?;
186        let mut files: Vec<String> = weight_map
187            .values()
188            .filter_map(|v| v.as_str().map(|s| s.to_string()))
189            .collect();
190        files.sort();
191        files.dedup();
192        Ok(files)
193    }
194
195    /// Read a tensor as f32 (converting from bf16 / f16 / f32) + its shape.
196    fn read_f32(&self, name: &str) -> Result<(Vec<f32>, Vec<usize>)> {
197        let shard_idx = *self
198            .index
199            .get(name)
200            .ok_or_else(|| FerrumError::model(format!("tensor '{name}' not in index")))?;
201        let (data_bytes, dtype, shape) = self.shards[shard_idx].get_cached(name)?;
202        let data = dtype_to_f32(dtype, data_bytes)?;
203        Ok((data, shape.to_vec()))
204    }
205
206    /// Read the raw on-disk byte slice plus dtype and shape. Zero-copy into
207    /// the mmap — used to hand weights straight to `B::from_weight_bytes` so
208    /// a fp16-preferring backend can skip the transient f32 Vec.
209    fn read_bytes_typed(&self, name: &str) -> Result<(&[u8], SrcDtype, Vec<usize>)> {
210        let shard_idx = *self
211            .index
212            .get(name)
213            .ok_or_else(|| FerrumError::model(format!("tensor '{name}' not in index")))?;
214        let (data_bytes, st_dtype, shape) = self.shards[shard_idx].get_cached(name)?;
215        let dtype = map_src_dtype(st_dtype)?;
216        Ok((data_bytes, dtype, shape.to_vec()))
217    }
218
219    /// Concatenate several tensors along dim 0 at the byte level. All parts
220    /// must share the same dtype and trailing-dim width. Returns the fused
221    /// raw bytes + common dtype + `(total_rows, cols)` shape.
222    fn cat_rows_bytes(&self, names: &[String]) -> Result<(Vec<u8>, SrcDtype, (usize, usize))> {
223        let mut total_rows = 0usize;
224        let mut cols = 0usize;
225        let mut dtype: Option<SrcDtype> = None;
226        let mut bytes: Vec<u8> = Vec::new();
227        for n in names {
228            let (raw, d, shape) = self.read_bytes_typed(n)?;
229            if shape.len() != 2 {
230                return Err(FerrumError::model(format!(
231                    "cat_rows_bytes: '{n}' is {shape:?}, need 2D"
232                )));
233            }
234            match dtype {
235                Some(prev) if prev != d => {
236                    return Err(FerrumError::model(format!(
237                        "cat_rows_bytes: dtype mismatch on '{n}'"
238                    )))
239                }
240                _ => dtype = Some(d),
241            }
242            if cols == 0 {
243                cols = shape[1];
244            } else if cols != shape[1] {
245                return Err(FerrumError::model(format!(
246                    "cat_rows_bytes: col mismatch {cols} vs {}",
247                    shape[1]
248                )));
249            }
250            total_rows += shape[0];
251            bytes.extend_from_slice(raw);
252        }
253        Ok((bytes, dtype.expect("at least one part"), (total_rows, cols)))
254    }
255
256    /// Concatenate optional projection biases in the same order as split
257    /// weight fusion. Biases must be all-or-none; silently dropping one part's
258    /// bias corrupts Qwen2.5 attention logits.
259    fn cat_optional_biases(
260        &self,
261        weight_names: &[String],
262        out_features: usize,
263    ) -> Result<Option<Vec<f32>>> {
264        let bias_names: Vec<String> = weight_names
265            .iter()
266            .map(|name| {
267                name.strip_suffix(".weight")
268                    .map(|stem| format!("{stem}.bias"))
269                    .unwrap_or_else(|| format!("{name}.bias"))
270            })
271            .collect();
272        let any_bias = bias_names.iter().any(|name| self.has(name));
273        if !any_bias {
274            return Ok(None);
275        }
276        if let Some(missing) = bias_names.iter().find(|name| !self.has(name)) {
277            return Err(FerrumError::model(format!(
278                "dense fusion bias mix: '{missing}' missing while another fused part has bias"
279            )));
280        }
281        let mut fused = Vec::new();
282        for name in &bias_names {
283            let (bias, shape) = self.read_f32(name)?;
284            if shape.len() != 1 {
285                return Err(FerrumError::model(format!(
286                    "dense fusion bias '{name}': expected 1D, got {shape:?}"
287                )));
288            }
289            fused.extend_from_slice(&bias);
290        }
291        if fused.len() != out_features {
292            return Err(FerrumError::model(format!(
293                "dense fusion bias length {} != out_features {out_features}",
294                fused.len()
295            )));
296        }
297        Ok(Some(fused))
298    }
299
300    /// Read a tensor as i32 (for GPTQ qweight / qzeros / g_idx).
301    /// Bulk memcpy from the LE-stored bytes (safetensors guarantees LE)
302    /// — the previous per-element `from_le_bytes` was 4 ms for a single
303    /// 768 KB tensor and dominated stacked-MoE load.
304    fn read_i32(&self, name: &str) -> Result<(Vec<i32>, Vec<usize>)> {
305        let shard_idx = *self
306            .index
307            .get(name)
308            .ok_or_else(|| FerrumError::model(format!("tensor '{name}' not in index")))?;
309        let (bytes, dtype, shape) = self.shards[shard_idx].get_cached(name)?;
310        if dtype != Dtype::I32 {
311            return Err(FerrumError::model(format!(
312                "'{name}': expected I32, got {:?}",
313                dtype
314            )));
315        }
316        debug_assert_eq!(bytes.len() % 4, 0);
317        let count = bytes.len() / 4;
318        let mut out = Vec::<i32>::with_capacity(count);
319        // SAFETY: Vec<i32>'s buffer is 4-byte aligned by allocator
320        // contract. `bytes` is a raw u8 slice; copy_nonoverlapping
321        // doesn't require src alignment. We're on x86_64 LE, and
322        // safetensors stores LE i32 — bit pattern is identical.
323        unsafe {
324            std::ptr::copy_nonoverlapping(bytes.as_ptr(), out.as_mut_ptr() as *mut u8, bytes.len());
325            out.set_len(count);
326        }
327        Ok((out, shape.to_vec()))
328    }
329
330    fn has(&self, name: &str) -> bool {
331        self.index.contains_key(name)
332    }
333
334    /// Read the four raw GPTQ tensors for a named projection without
335    /// triggering a Backend repack. Used by MoE batch loading: callers
336    /// stack many experts host-side then issue a single `B::load_gptq`,
337    /// avoiding the 12 288× per-expert Marlin repack overhead.
338    ///
339    /// Returns `(qweight, scales, qzeros, g_idx, k, n)`.
340    /// `g_idx` is `None` when desc_act=false (no act-order perm needed).
341    pub fn read_gptq_raw(
342        &self,
343        name: &str,
344    ) -> Result<(Vec<i32>, Vec<f32>, Vec<i32>, Option<Vec<i32>>, usize, usize)> {
345        let (qweight, qw_shape) = self.read_i32(&format!("{name}.qweight"))?;
346        let (scales, _) = self.read_f32(&format!("{name}.scales"))?;
347        let (qzeros, _) = self.read_i32(&format!("{name}.qzeros"))?;
348        let g_idx = if self.has(&format!("{name}.g_idx")) {
349            Some(self.read_i32(&format!("{name}.g_idx"))?.0)
350        } else {
351            None
352        };
353        if qw_shape.len() != 2 {
354            return Err(FerrumError::model(format!(
355                "'{name}.qweight' expected 2D, got {qw_shape:?}"
356            )));
357        }
358        let k = qw_shape[0] * 8;
359        let n = qw_shape[1];
360        Ok((qweight, scales, qzeros, g_idx, k, n))
361    }
362
363    pub fn quant_config_ref(&self) -> Option<&crate::config::QuantConfig> {
364        self.quant_config.as_ref()
365    }
366
367    /// Load a STACKED GPTQ tile that concatenates `num_experts` experts'
368    /// raw GPTQ tensors along the N (column) axis and runs ONE backend
369    /// repack — instead of `num_experts × proj_names.len()` repacks.
370    ///
371    /// Layout: per row `r`, the cols are emitted in expert-major order:
372    /// `expert_0[proj_0|proj_1|...] | expert_1[...] | ... | expert_{N-1}[...]`.
373    /// Caller can therefore index expert `e` at column offset
374    /// `e * n_per_expert`, where `n_per_expert = Σ n(proj)` across the
375    /// `proj_names` for one expert.
376    ///
377    /// `expert_prefix_fmt` should be a closure-style `&str` that contains
378    /// `"{e}"` placeholder (replaced by the expert index) and ends *just
379    /// before* the proj name — e.g. `"model.layers.5.mlp.experts.{e}."`.
380    /// The full tensor name probed is `{expert_prefix}{proj}`.
381    ///
382    /// Returns `(store, n_per_expert, k)` where `n_per_expert` is the
383    /// per-expert column width and `k = in_features` (shared by all).
384    pub fn load_stacked_gptq_experts(
385        &self,
386        expert_prefix_fmt: &str,
387        num_experts: usize,
388        proj_names: &[&str],
389    ) -> Result<(
390        std::sync::Arc<dyn ferrum_kernels::MarlinExpertStack<B>>,
391        usize,
392        usize,
393    )> {
394        let qcfg = self.quant_config.as_ref().ok_or_else(|| {
395            FerrumError::model(
396                "load_stacked_gptq_experts requires quantize_config.json".to_string(),
397            )
398        })?;
399        if qcfg.method != QuantMethod::Gptq {
400            return Err(FerrumError::model(format!(
401                "stacked GPTQ load but quant_method={:?}",
402                qcfg.method
403            )));
404        }
405
406        let mut qw_rows = 0usize;
407        let mut sc_rows = 0usize;
408        let mut qz_rows = 0usize;
409        let mut n_per_expert = 0usize;
410        let mut n_per_expert_scales = 0usize;
411        let mut n_per_expert_zeros = 0usize;
412        let mut k_shared = 0usize;
413        let mut g_idx_first: Option<Vec<i32>> = None;
414
415        // Per (expert, proj) raw slices — row-major (rows × cols).
416        let total_pairs = num_experts * proj_names.len();
417        let mut qw_parts: Vec<(Vec<i32>, usize)> = Vec::with_capacity(total_pairs); // (data, cols)
418        let mut sc_parts: Vec<(Vec<f32>, usize)> = Vec::with_capacity(total_pairs);
419        let mut qz_parts: Vec<(Vec<i32>, usize)> = Vec::with_capacity(total_pairs);
420
421        for e in 0..num_experts {
422            let prefix = expert_prefix_fmt.replace("{e}", &e.to_string());
423            let mut e_n = 0usize;
424            let mut e_n_scales = 0usize;
425            let mut e_n_zeros = 0usize;
426            for proj in proj_names {
427                let name = format!("{prefix}{proj}");
428                let (qw, qw_sh) = self.read_i32(&format!("{name}.qweight"))?;
429                let (sc, sc_sh) = self.read_f32(&format!("{name}.scales"))?;
430                let (qz, qz_sh) = self.read_i32(&format!("{name}.qzeros"))?;
431                if qw_sh.len() != 2 || sc_sh.len() != 2 || qz_sh.len() != 2 {
432                    return Err(FerrumError::model(format!(
433                        "stacked GPTQ '{name}': expected 2D, got qw {qw_sh:?} sc {sc_sh:?} qz {qz_sh:?}"
434                    )));
435                }
436                if qw_rows == 0 {
437                    qw_rows = qw_sh[0];
438                    sc_rows = sc_sh[0];
439                    qz_rows = qz_sh[0];
440                    k_shared = qw_sh[0] * 8;
441                } else if qw_sh[0] != qw_rows || sc_sh[0] != sc_rows || qz_sh[0] != qz_rows {
442                    return Err(FerrumError::model(format!(
443                        "stacked GPTQ '{name}': row mismatch qw {} sc {} qz {} vs ref {qw_rows}/{sc_rows}/{qz_rows}",
444                        qw_sh[0], sc_sh[0], qz_sh[0]
445                    )));
446                }
447                e_n += qw_sh[1];
448                e_n_scales += sc_sh[1];
449                e_n_zeros += qz_sh[1];
450                qw_parts.push((qw, qw_sh[1]));
451                sc_parts.push((sc, sc_sh[1]));
452                qz_parts.push((qz, qz_sh[1]));
453
454                // g_idx is a permutation over K — Marlin assumes ONE g_idx
455                // for the whole stacked tile. Validate all experts share
456                // identical g_idx if any has it (which they should, since
457                // K = hidden_size is the same across experts and GPTQ's
458                // act-order is computed on the input distribution).
459                let g_key = format!("{name}.g_idx");
460                if self.has(&g_key) {
461                    let (gx, _) = self.read_i32(&g_key)?;
462                    match &g_idx_first {
463                        None => g_idx_first = Some(gx),
464                        Some(prev) => {
465                            if prev.len() != gx.len() || prev.iter().zip(&gx).any(|(a, b)| a != b) {
466                                return Err(FerrumError::model(format!(
467                                    "stacked GPTQ '{name}': g_idx mismatch with first \
468                                     expert — Marlin requires identical act-order across \
469                                     experts in the same stacked tile"
470                                )));
471                            }
472                        }
473                    }
474                }
475            }
476            if e == 0 {
477                n_per_expert = e_n;
478                n_per_expert_scales = e_n_scales;
479                n_per_expert_zeros = e_n_zeros;
480            } else if e_n != n_per_expert
481                || e_n_scales != n_per_expert_scales
482                || e_n_zeros != n_per_expert_zeros
483            {
484                return Err(FerrumError::model(format!(
485                    "stacked GPTQ expert {e} N mismatch: qw {e_n} sc {e_n_scales} qz {e_n_zeros} vs expert 0 {n_per_expert}/{n_per_expert_scales}/{n_per_expert_zeros}"
486                )));
487            }
488        }
489
490        let proj_count = proj_names.len();
491        let pairs_per_expert = proj_count;
492        debug_assert_eq!(total_pairs, num_experts * pairs_per_expert);
493
494        // PER-EXPERT layout: build num_experts independent
495        // `[K/8, n_per_expert]` qweight tiles + scales + qzeros, each
496        // a row-major concat of the proj_names within that expert.
497        // Hand them to `B::load_gptq_stacked` which repacks PER-EXPERT
498        // and concats the resulting Marlin-format tiles into one
499        // contiguous buffer. Each expert's packed bytes are then
500        // contiguous, so the offset GEMM dispatches correctly via
501        // pointer arithmetic alone.
502        //
503        // Without per-expert repack, a single concat-then-repack of
504        // the stacked tile mangles per-expert tile boundaries (Marlin
505        // permutes in K-tile-major order across the whole tile).
506        let mut per_expert_qw: Vec<Vec<i32>> = Vec::with_capacity(num_experts);
507        let mut per_expert_sc: Vec<Vec<f32>> = Vec::with_capacity(num_experts);
508        let mut per_expert_qz: Vec<Vec<i32>> = Vec::with_capacity(num_experts);
509        for e in 0..num_experts {
510            let mut qw: Vec<i32> = Vec::with_capacity(qw_rows * n_per_expert);
511            let mut sc: Vec<f32> = Vec::with_capacity(sc_rows * n_per_expert_scales);
512            let mut qz: Vec<i32> = Vec::with_capacity(qz_rows * n_per_expert_zeros);
513            for r in 0..qw_rows {
514                for j in 0..pairs_per_expert {
515                    let pair_idx = e * pairs_per_expert + j;
516                    let (data, cols) = &qw_parts[pair_idx];
517                    qw.extend_from_slice(&data[r * cols..(r + 1) * cols]);
518                }
519            }
520            for r in 0..sc_rows {
521                for j in 0..pairs_per_expert {
522                    let pair_idx = e * pairs_per_expert + j;
523                    let (data, cols) = &sc_parts[pair_idx];
524                    sc.extend_from_slice(&data[r * cols..(r + 1) * cols]);
525                }
526            }
527            for r in 0..qz_rows {
528                for j in 0..pairs_per_expert {
529                    let pair_idx = e * pairs_per_expert + j;
530                    let (data, cols) = &qz_parts[pair_idx];
531                    qz.extend_from_slice(&data[r * cols..(r + 1) * cols]);
532                }
533            }
534            per_expert_qw.push(qw);
535            per_expert_sc.push(sc);
536            per_expert_qz.push(qz);
537        }
538
539        // Drop the original part buffers — we own copies in per_expert_*.
540        drop(qw_parts);
541        drop(sc_parts);
542        drop(qz_parts);
543
544        let qw_refs: Vec<&[i32]> = per_expert_qw.iter().map(|v| v.as_slice()).collect();
545        let sc_refs: Vec<&[f32]> = per_expert_sc.iter().map(|v| v.as_slice()).collect();
546        let qz_refs: Vec<&[i32]> = per_expert_qz.iter().map(|v| v.as_slice()).collect();
547
548        let store = B::load_gptq_stacked(
549            &qw_refs,
550            &sc_refs,
551            &qz_refs,
552            g_idx_first.as_deref(),
553            qcfg.bits,
554            qcfg.group_size,
555            k_shared,
556            n_per_expert,
557        )?;
558        Ok((store, n_per_expert, k_shared))
559    }
560}
561
562impl<B: Backend + BackendQuantMarlin> WeightLoader<B> for NativeSafetensorsLoader<B> {
563    fn load_tensor(&self, name: &str) -> Result<B::Buffer> {
564        // Route through `from_weight_bytes` so fp16-preferring backends can
565        // materialise big tensors (embed table) directly as half-precision
566        // without the transient f32 Vec. Tiny tensors (norm weights) still
567        // end up as f32 because backends size-threshold inside the override.
568        let (raw, src_dtype, _) = self.read_bytes_typed(name)?;
569        Ok(B::from_weight_bytes(raw, src_dtype))
570    }
571
572    fn load_linear(&self, name: &str) -> Result<Box<dyn Linear<B>>> {
573        // GPTQ first: `<name>.qweight` + `<name>.scales` + `<name>.qzeros`.
574        let qw_key = format!("{name}.qweight");
575        if self.has(&qw_key) {
576            return self.load_gptq_linear(name);
577        }
578        // GPTQ fusion shims: synthesise qkv_proj / gate_up_proj from split
579        // components — same pattern as Dense but concatenating the GPTQ
580        // tensors (qweight/scales/qzeros) along the N dim.
581        if let Some(prefix) = name.strip_suffix("qkv_proj") {
582            let parts = [
583                format!("{prefix}q_proj"),
584                format!("{prefix}k_proj"),
585                format!("{prefix}v_proj"),
586            ];
587            if parts.iter().all(|p| self.has(&format!("{p}.qweight"))) {
588                return self.load_gptq_linear_fused(&parts);
589            }
590        }
591        if let Some(prefix) = name.strip_suffix("gate_up_proj") {
592            let parts = [format!("{prefix}gate_proj"), format!("{prefix}up_proj")];
593            if parts.iter().all(|p| self.has(&format!("{p}.qweight"))) {
594                return self.load_gptq_linear_fused(&parts);
595            }
596        }
597
598        // Direct fused `<name>.weight` next. Load straight from raw bytes
599        // so fp16-preferring backends can skip the f32 Vec intermediate.
600        let direct = format!("{name}.weight");
601        if self.has(&direct) {
602            let (raw, src_dtype, shape) = self.read_bytes_typed(&direct)?;
603            if shape.len() != 2 {
604                return Err(FerrumError::model(format!(
605                    "linear '{name}': expected 2D weight, got {shape:?}"
606                )));
607            }
608            let weight = B::from_weight_bytes(raw, src_dtype);
609            return Ok(Box::new(DenseLinear::<B>::from_buffer(
610                weight, shape[0], shape[1],
611            )));
612        }
613
614        // Llama-family fusion shims: synthesise qkv_proj / gate_up_proj from
615        // split q_proj+k_proj+v_proj / gate_proj+up_proj if present. The cat
616        // happens at the byte level so fused-weight memory is the same size
617        // as the per-part weights — no expansion to f32.
618        if let Some(prefix) = name.strip_suffix("qkv_proj") {
619            let parts = [
620                format!("{prefix}q_proj.weight"),
621                format!("{prefix}k_proj.weight"),
622                format!("{prefix}v_proj.weight"),
623            ];
624            if parts.iter().all(|p| self.has(p)) {
625                let (bytes, dtype, (rows, cols)) = self.cat_rows_bytes(&parts)?;
626                let weight = B::from_weight_bytes(&bytes, dtype);
627                let mut linear = DenseLinear::<B>::from_buffer(weight, rows, cols);
628                if let Some(bias) = self.cat_optional_biases(&parts, rows)? {
629                    linear = linear.with_bias(B::from_slice(&bias));
630                }
631                return Ok(Box::new(linear));
632            }
633        }
634        if let Some(prefix) = name.strip_suffix("gate_up_proj") {
635            let parts = [
636                format!("{prefix}gate_proj.weight"),
637                format!("{prefix}up_proj.weight"),
638            ];
639            if parts.iter().all(|p| self.has(p)) {
640                let (bytes, dtype, (rows, cols)) = self.cat_rows_bytes(&parts)?;
641                let weight = B::from_weight_bytes(&bytes, dtype);
642                let mut linear = DenseLinear::<B>::from_buffer(weight, rows, cols);
643                if let Some(bias) = self.cat_optional_biases(&parts, rows)? {
644                    linear = linear.with_bias(B::from_slice(&bias));
645                }
646                return Ok(Box::new(linear));
647            }
648        }
649
650        Err(FerrumError::model(format!(
651            "could not load linear '{name}' — no direct `.weight`, no split components"
652        )))
653    }
654
655    fn has_tensor(&self, name: &str) -> bool {
656        self.has(name)
657    }
658
659    fn quant_config(&self) -> Option<&QuantConfig> {
660        self.quant_config.as_ref()
661    }
662}
663
664impl<B: Backend + BackendQuantMarlin> NativeSafetensorsLoader<B> {
665    /// Load a GPTQ-packed linear projection: reads `<name>.qweight`,
666    /// `<name>.scales`, `<name>.qzeros`, optionally `<name>.g_idx`, and
667    /// hands the raw host-side tensors to `Backend::load_gptq` which
668    /// repacks + uploads per its own strategy.
669    fn load_gptq_linear(&self, name: &str) -> Result<Box<dyn Linear<B>>> {
670        let qcfg = self.quant_config.as_ref().ok_or_else(|| {
671            FerrumError::model(format!(
672                "'{name}.qweight' present but no quantize_config.json — \
673                 can't determine bits/group_size"
674            ))
675        })?;
676        if qcfg.method != QuantMethod::Gptq {
677            return Err(FerrumError::model(format!(
678                "'{name}.qweight' present but quant_method={:?} (expected GPTQ)",
679                qcfg.method
680            )));
681        }
682
683        let (qweight, qw_shape) = self.read_i32(&format!("{name}.qweight"))?;
684        let (scales_f32, sc_shape) = self.read_f32(&format!("{name}.scales"))?;
685        let (qzeros, _qz_shape) = self.read_i32(&format!("{name}.qzeros"))?;
686        let g_idx = if self.has(&format!("{name}.g_idx")) {
687            Some(self.read_i32(&format!("{name}.g_idx"))?.0)
688        } else {
689            None
690        };
691
692        // Shape inference: qweight is [K/8, N]; scales is [K/group, N].
693        // → K = qw_shape[0] * 8, N = qw_shape[1].
694        if qw_shape.len() != 2 {
695            return Err(FerrumError::model(format!(
696                "'{name}.qweight' expected 2D, got {qw_shape:?}"
697            )));
698        }
699        let in_features = qw_shape[0] * 8;
700        let out_features = qw_shape[1];
701
702        let is_desc_act = validate_gptq_g_idx(name, qcfg, g_idx.as_deref(), in_features)?;
703
704        // Act-order GPTQ. CUDA backend has perm-aware Marlin (load_gptq
705        // builds perm = argsort(g_idx) + permutes qweight rows at load;
706        // gemm_gptq gathers input columns before the standard Marlin call).
707        // CPU/Metal still need the dequant→DenseLinear fallback.
708        #[cfg(not(feature = "cuda"))]
709        if is_desc_act {
710            let dequant_f32 = dequantize_gptq_with_g_idx(
711                &qweight,
712                &scales_f32,
713                &qzeros,
714                g_idx.as_ref().expect("desc_act=true requires g_idx"),
715                qcfg.group_size,
716                in_features,
717                out_features,
718            );
719            let mut linear =
720                crate::dense::DenseLinear::<B>::from_rows(&dequant_f32, out_features, in_features);
721            let bias_key = format!("{name}.bias");
722            if self.has(&bias_key) {
723                let (bias, _) = self.read_f32(&bias_key)?;
724                linear = linear.with_bias(B::from_slice(&bias));
725            }
726            tracing::info!(
727                "GPTQ load (desc_act dequant→DenseLinear, non-cuda): name={name} K={in_features} N={out_features}"
728            );
729            return Ok(Box::new(linear));
730        }
731        #[cfg(feature = "cuda")]
732        let _ = is_desc_act; // CUDA: g_idx threaded through to GptqLinear below
733        if sc_shape.len() != 2 || sc_shape[1] != out_features {
734            return Err(FerrumError::model(format!(
735                "'{name}.scales' {sc_shape:?} incompatible with qweight {qw_shape:?}"
736            )));
737        }
738
739        // Read optional bias FIRST (Qwen2.5 attention projections, some
740        // Llama variants). Phase 3e/2: load_gptq takes the bias eagerly
741        // because the boxed Linear bakes it in.
742        let bias_key = format!("{name}.bias");
743        let bias_vec = if self.has(&bias_key) {
744            let (bias, bias_shape) = self.read_f32(&bias_key)?;
745            if bias_shape != [out_features] {
746                return Err(FerrumError::model(format!(
747                    "'{bias_key}' {bias_shape:?} != [{out_features}]"
748                )));
749            }
750            Some(bias)
751        } else {
752            None
753        };
754
755        let linear = GptqLinear::<B>::from_raw(
756            &qweight,
757            &scales_f32,
758            &qzeros,
759            g_idx.as_deref(),
760            bias_vec.as_deref(),
761            qcfg.bits,
762            qcfg.group_size,
763            in_features,
764            out_features,
765        )?;
766        Ok(Box::new(linear))
767    }
768
769    /// Fuse multiple GPTQ projections by concatenating qweight/scales/qzeros
770    /// along the output (N) dim. Matches the Dense fusion shim used for
771    /// non-quantized models: q_proj + k_proj + v_proj → qkv_proj.
772    ///
773    /// All parts must share:
774    /// - in_features (K)
775    /// - bits, group_size
776    /// - qzeros N-packing (which the GPTQ format always honours: qzeros[-1]
777    ///   = N/8, concat along that axis works)
778    ///
779    /// g_idx: only present when desc_act=true. When present, all parts
780    /// share it (same K rows, same activation permutation).
781    fn load_gptq_linear_fused(&self, parts: &[String]) -> Result<Box<dyn Linear<B>>> {
782        let qcfg = self.quant_config.as_ref().ok_or_else(|| {
783            FerrumError::model("GPTQ fusion requires quantize_config.json".to_string())
784        })?;
785        if qcfg.method != QuantMethod::Gptq {
786            return Err(FerrumError::model(format!(
787                "GPTQ fusion but quant_method={:?}",
788                qcfg.method
789            )));
790        }
791
792        let mut qw_acc: Vec<i32> = Vec::new();
793        let mut sc_acc: Vec<f32> = Vec::new();
794        let mut qz_acc: Vec<i32> = Vec::new();
795        let mut qw_rows = 0usize;
796        let mut sc_rows = 0usize;
797        let mut qz_rows = 0usize;
798        let mut total_n = 0usize;
799        let mut total_n_scales = 0usize;
800        let mut total_n_zeros = 0usize;
801        let mut g_idx: Option<Vec<i32>> = None;
802        let mut g_idx_presence: Vec<(String, bool)> = Vec::with_capacity(parts.len());
803        // Segments: (qw_slice, sc_slice, qz_slice) per part, needed for N-major layout concat
804        let mut qw_parts: Vec<(Vec<i32>, usize, usize)> = Vec::new(); // (data, rows, cols)
805        let mut sc_parts: Vec<(Vec<f32>, usize, usize)> = Vec::new();
806        let mut qz_parts: Vec<(Vec<i32>, usize, usize)> = Vec::new();
807
808        for p in parts {
809            let (qw, qw_sh) = self.read_i32(&format!("{p}.qweight"))?;
810            let (sc, sc_sh) = self.read_f32(&format!("{p}.scales"))?;
811            let (qz, qz_sh) = self.read_i32(&format!("{p}.qzeros"))?;
812            if qw_sh.len() != 2 || sc_sh.len() != 2 || qz_sh.len() != 2 {
813                return Err(FerrumError::model(format!(
814                    "GPTQ fusion '{p}': expected 2D tensors, got qw {qw_sh:?} sc {sc_sh:?} qz {qz_sh:?}"
815                )));
816            }
817            if qw_rows == 0 {
818                qw_rows = qw_sh[0];
819                sc_rows = sc_sh[0];
820                qz_rows = qz_sh[0];
821            } else if qw_sh[0] != qw_rows || sc_sh[0] != sc_rows || qz_sh[0] != qz_rows {
822                return Err(FerrumError::model(format!(
823                    "GPTQ fusion row mismatch on '{p}'"
824                )));
825            }
826            total_n += qw_sh[1];
827            total_n_scales += sc_sh[1];
828            total_n_zeros += qz_sh[1];
829            qw_parts.push((qw, qw_sh[0], qw_sh[1]));
830            sc_parts.push((sc, sc_sh[0], sc_sh[1]));
831            qz_parts.push((qz, qz_sh[0], qz_sh[1]));
832
833            let g_key = format!("{p}.g_idx");
834            if self.has(&g_key) {
835                let (gx, gx_shape) = self.read_i32(&g_key)?;
836                if gx_shape != [qw_rows * 8] {
837                    return Err(FerrumError::model(format!(
838                        "GPTQ fusion '{p}': g_idx shape {gx_shape:?} incompatible with K={}",
839                        qw_rows * 8
840                    )));
841                }
842                match &g_idx {
843                    None => g_idx = Some(gx),
844                    Some(prev) => {
845                        if prev.len() != gx.len() || prev.iter().zip(&gx).any(|(a, b)| a != b) {
846                            return Err(FerrumError::model(format!(
847                                "GPTQ fusion '{p}': g_idx mismatch with first part; \
848                                 fused qkv/gate_up requires identical act-order across parts"
849                            )));
850                        }
851                    }
852                }
853                g_idx_presence.push((p.clone(), true));
854            } else {
855                g_idx_presence.push((p.clone(), false));
856            }
857        }
858
859        // Interleave row-major concatenation: for each row, write all parts' cols.
860        qw_acc.reserve(qw_rows * total_n);
861        for r in 0..qw_rows {
862            for (part, _rows, cols) in &qw_parts {
863                qw_acc.extend_from_slice(&part[r * cols..r * cols + cols]);
864            }
865        }
866        sc_acc.reserve(sc_rows * total_n_scales);
867        for r in 0..sc_rows {
868            for (part, _rows, cols) in &sc_parts {
869                sc_acc.extend_from_slice(&part[r * cols..r * cols + cols]);
870            }
871        }
872        qz_acc.reserve(qz_rows * total_n_zeros);
873        for r in 0..qz_rows {
874            for (part, _rows, cols) in &qz_parts {
875                qz_acc.extend_from_slice(&part[r * cols..r * cols + cols]);
876            }
877        }
878
879        let in_features = qw_rows * 8;
880        let out_features = total_n;
881
882        if g_idx.is_some() {
883            let missing = g_idx_presence
884                .iter()
885                .filter_map(|(part, present)| (!present).then_some(part.as_str()))
886                .collect::<Vec<_>>();
887            if !missing.is_empty() {
888                return Err(FerrumError::model(format!(
889                    "GPTQ fusion requires all parts to carry g_idx when any part does; \
890                     missing g_idx for {missing:?}"
891                )));
892            }
893        }
894        let fused_name = format!("GPTQ fusion {}", parts.join("+"));
895        let is_desc_act = validate_gptq_g_idx(&fused_name, qcfg, g_idx.as_deref(), in_features)?;
896        // CUDA: perm-aware Marlin via load_gptq. CPU/Metal: dequant→Dense.
897        #[cfg(not(feature = "cuda"))]
898        if is_desc_act {
899            let dequant_f32 = dequantize_gptq_with_g_idx(
900                &qw_acc,
901                &sc_acc,
902                &qz_acc,
903                g_idx.as_ref().expect("desc_act=true requires g_idx"),
904                qcfg.group_size,
905                in_features,
906                out_features,
907            );
908            let mut linear =
909                crate::dense::DenseLinear::<B>::from_rows(&dequant_f32, out_features, in_features);
910            let mut bias_acc: Vec<f32> = Vec::new();
911            let mut any_bias = false;
912            for p in parts {
913                let bk = format!("{p}.bias");
914                if self.has(&bk) {
915                    any_bias = true;
916                    bias_acc.extend_from_slice(&self.read_f32(&bk)?.0);
917                } else if any_bias {
918                    return Err(FerrumError::model(format!(
919                        "GPTQ fusion bias mix: '{p}' has no bias but earlier part did"
920                    )));
921                }
922            }
923            if any_bias {
924                linear = linear.with_bias(B::from_slice(&bias_acc));
925            }
926            tracing::info!(
927                "GPTQ fused load (desc_act dequant→DenseLinear, non-cuda): K={in_features} N={out_features} parts={}",
928                parts.len()
929            );
930            return Ok(Box::new(linear));
931        }
932        #[cfg(feature = "cuda")]
933        let _ = is_desc_act;
934
935        // Biases: concatenate `<part>.bias` across parts in the same
936        // order as qweights. All-or-none; if any part has a bias, all
937        // must. Phase 3e/2: read first, pass into load_gptq.
938        let bias_keys: Vec<String> = parts.iter().map(|p| format!("{p}.bias")).collect();
939        let any = bias_keys.iter().any(|k| self.has(k));
940        let all = bias_keys.iter().all(|k| self.has(k));
941        if any && !all {
942            return Err(FerrumError::model(
943                "GPTQ fusion: inconsistent bias presence across parts".to_string(),
944            ));
945        }
946        let fused_bias = if all {
947            let mut fused: Vec<f32> = Vec::with_capacity(out_features);
948            for k in &bias_keys {
949                let (b, _) = self.read_f32(k)?;
950                fused.extend_from_slice(&b);
951            }
952            if fused.len() != out_features {
953                return Err(FerrumError::model(format!(
954                    "GPTQ fusion bias length {} != out_features {out_features}",
955                    fused.len()
956                )));
957            }
958            Some(fused)
959        } else {
960            None
961        };
962
963        let linear = GptqLinear::<B>::from_raw(
964            &qw_acc,
965            &sc_acc,
966            &qz_acc,
967            g_idx.as_deref(),
968            fused_bias.as_deref(),
969            qcfg.bits,
970            qcfg.group_size,
971            in_features,
972            out_features,
973        )?;
974
975        Ok(Box::new(linear))
976    }
977
978    /// Read each name, assert shape width matches, concatenate along dim 0.
979    /// Kept for diagnostic / fallback paths; DenseLinear fusion prefers the
980    /// byte-level `cat_rows_bytes` above.
981    #[allow(dead_code)]
982    fn cat_rows(&self, names: &[String]) -> Result<(usize, usize, Vec<f32>)> {
983        let mut total_rows = 0usize;
984        let mut cols = 0usize;
985        let mut out: Vec<f32> = Vec::new();
986        for n in names {
987            let (data, shape) = self.read_f32(n)?;
988            if shape.len() != 2 {
989                return Err(FerrumError::model(format!(
990                    "cat_rows: '{n}' is {shape:?}, need 2D"
991                )));
992            }
993            if cols == 0 {
994                cols = shape[1];
995            } else if cols != shape[1] {
996                return Err(FerrumError::model(format!(
997                    "cat_rows: col mismatch {cols} vs {}",
998                    shape[1]
999                )));
1000            }
1001            total_rows += shape[0];
1002            out.extend_from_slice(&data);
1003        }
1004        Ok((total_rows, cols, out))
1005    }
1006}
1007
/// Returns `true` when `g_idx` deviates from the sequential layout in which
/// position `i` belongs to group `i / group_size`.
///
/// An empty `g_idx` is never act-order. `group_size` must be non-zero
/// whenever `g_idx` is non-empty (the division would otherwise panic).
fn gptq_g_idx_is_desc_act(g_idx: &[i32], group_size: usize) -> bool {
    let stride = group_size as i32;
    for (position, &group) in g_idx.iter().enumerate() {
        let sequential_group = (position as i32) / stride;
        if group != sequential_group {
            return true;
        }
    }
    false
}
1014
1015fn validate_gptq_g_idx(
1016    name: &str,
1017    qcfg: &QuantConfig,
1018    g_idx: Option<&[i32]>,
1019    in_features: usize,
1020) -> Result<bool> {
1021    if qcfg.desc_act && g_idx.is_none() {
1022        return Err(FerrumError::model(format!(
1023            "{name}: quantize_config desc_act=true but no g_idx tensor was found"
1024        )));
1025    }
1026
1027    let Some(g_idx) = g_idx else {
1028        return Ok(false);
1029    };
1030    if qcfg.group_size == 0 {
1031        return Err(FerrumError::model(format!(
1032            "{name}: GPTQ g_idx present but group_size is 0"
1033        )));
1034    }
1035    if g_idx.len() != in_features {
1036        return Err(FerrumError::model(format!(
1037            "{name}: g_idx length {} must match K={in_features}",
1038            g_idx.len()
1039        )));
1040    }
1041    let expected_groups = in_features.div_ceil(qcfg.group_size);
1042    for (idx, &group) in g_idx.iter().enumerate() {
1043        if group < 0 || group as usize >= expected_groups {
1044            return Err(FerrumError::model(format!(
1045                "{name}: g_idx[{idx}]={group} outside expected group range 0..{}",
1046                expected_groups.saturating_sub(1)
1047            )));
1048        }
1049    }
1050    Ok(gptq_g_idx_is_desc_act(g_idx, qcfg.group_size))
1051}
1052
/// Dequantise 4-bit GPTQ weights whose per-row quantisation group is given
/// by `g_idx`, producing f32 weights laid out `[N, K]` row-major (the layout
/// `DenseLinear::from_rows` is fed with).
///
/// Rows of `qweight` are taken in their stored order; `g_idx[row]` only
/// selects WHICH scales/qzeros row is used for that weight row. For every
/// `(row, col)` the result is
/// `(q - (zero + 1)) * scale`, where `q` is the 4-bit value unpacked from
/// `qweight`, and `zero` / `scale` come from group `g_idx[row]`.
///
/// Layout expectations (indexing panics if they do not hold):
/// - `qweight`: `[k / 8, n]`, eight 4-bit values per word, low nibble first
/// - `scales`:  `[groups, n]`
/// - `qzeros`:  `[groups, n / 8]`, eight 4-bit values per word
/// - `g_idx`:   `[k]`
///
/// `_group_size` is accepted for signature symmetry and not read.
#[cfg(not(feature = "cuda"))]
fn dequantize_gptq_with_g_idx(
    qweight: &[i32], // [K/8, N] packed int4
    scales: &[f32],  // [num_groups, N]
    qzeros: &[i32],  // [num_groups, N/8] packed int4
    g_idx: &[i32],   // [K]
    _group_size: usize,
    k: usize,
    n: usize,
) -> Vec<f32> {
    debug_assert_eq!(g_idx.len(), k);

    // Extracts the `slot`-th 4-bit field (slot 0 = least significant).
    let nibble = |word: i32, slot: usize| -> i32 { (((word as u32) >> (slot * 4)) & 0xF) as i32 };

    let zero_words_per_group = n / 8;
    let word_rows = k / 8;

    // dense[col * k + row] holds the weight for output `col`, input `row`.
    let mut dense = vec![0.0f32; n * k];
    for word_row in 0..word_rows {
        for col in 0..n {
            let word = qweight[word_row * n + col];
            for slot in 0..8 {
                let row = word_row * 8 + slot;
                let quantised = nibble(word, slot);
                let group = g_idx[row] as usize;
                let scale = scales[group * n + col];
                let zero_word = qzeros[group * zero_words_per_group + (col / 8)];
                // Stored zero-points are offset by one.
                let zero_point = nibble(zero_word, col % 8) + 1;
                dense[col * k + row] = (quantised - zero_point) as f32 * scale;
            }
        }
    }
    dense
}
1100
1101fn dtype_to_f32(dtype: Dtype, raw: &[u8]) -> Result<Vec<f32>> {
1102    match dtype {
1103        Dtype::F32 => {
1104            // Bulk memcpy from LE-stored bytes (safetensors is LE; we're
1105            // on x86_64 LE). Per-element from_le_bytes was the bottleneck
1106            // for stacked-MoE load (4-5 ms per call * 384 calls/layer *
1107            // 48 layers = ~80 sec just for f32 reads).
1108            debug_assert_eq!(raw.len() % 4, 0);
1109            let n = raw.len() / 4;
1110            let mut out = Vec::<f32>::with_capacity(n);
1111            unsafe {
1112                std::ptr::copy_nonoverlapping(raw.as_ptr(), out.as_mut_ptr() as *mut u8, raw.len());
1113                out.set_len(n);
1114            }
1115            Ok(out)
1116        }
1117        Dtype::F16 => {
1118            debug_assert_eq!(raw.len() % 2, 0);
1119            let n = raw.len() / 2;
1120            // Reinterpret raw bytes as f16, then convert. This avoids
1121            // the per-element from_le_bytes byte-array construction.
1122            let mut tmp = Vec::<f16>::with_capacity(n);
1123            unsafe {
1124                std::ptr::copy_nonoverlapping(raw.as_ptr(), tmp.as_mut_ptr() as *mut u8, raw.len());
1125                tmp.set_len(n);
1126            }
1127            let mut out = Vec::with_capacity(n);
1128            for h in &tmp {
1129                out.push(h.to_f32());
1130            }
1131            Ok(out)
1132        }
1133        Dtype::BF16 => {
1134            debug_assert_eq!(raw.len() % 2, 0);
1135            let n = raw.len() / 2;
1136            let mut tmp = Vec::<bf16>::with_capacity(n);
1137            unsafe {
1138                std::ptr::copy_nonoverlapping(raw.as_ptr(), tmp.as_mut_ptr() as *mut u8, raw.len());
1139                tmp.set_len(n);
1140            }
1141            let mut out = Vec::with_capacity(n);
1142            for h in &tmp {
1143                out.push(h.to_f32());
1144            }
1145            Ok(out)
1146        }
1147        other => Err(FerrumError::model(format!(
1148            "dtype {other:?} not supported by NativeSafetensorsLoader's f32 path; \
1149             use a format-specific loader (GPTQ / AWQ / GGUF)",
1150        ))),
1151    }
1152}
1153
1154fn load_quantize_config(dir: &Path) -> Result<Option<QuantConfig>> {
1155    // AutoGPTQ / gptq-for-llama format: separate quantize_config.json.
1156    let p = dir.join("quantize_config.json");
1157    if p.exists() {
1158        let data =
1159            std::fs::read_to_string(&p).map_err(|e| FerrumError::io(format!("read {p:?}: {e}")))?;
1160        let qc: QuantConfig = serde_json::from_str(&data)
1161            .map_err(|e| FerrumError::serialization(format!("parse quantize_config.json: {e}")))?;
1162        return Ok(Some(qc));
1163    }
1164    // Qwen GPTQ / transformers-style: embedded in config.json under
1165    // "quantization_config": { "quant_method": "gptq", "bits": 4, ... }.
1166    let cfg = dir.join("config.json");
1167    if cfg.exists() {
1168        let data = std::fs::read_to_string(&cfg)
1169            .map_err(|e| FerrumError::io(format!("read {cfg:?}: {e}")))?;
1170        let root: serde_json::Value = serde_json::from_str(&data)
1171            .map_err(|e| FerrumError::serialization(format!("parse config.json: {e}")))?;
1172        if let Some(qc_val) = root.get("quantization_config") {
1173            // The embedded block has "quant_method" (not "method"); remap.
1174            let method = qc_val
1175                .get("quant_method")
1176                .and_then(|v| v.as_str())
1177                .unwrap_or("none");
1178            let method = match method.to_lowercase().as_str() {
1179                "gptq" => QuantMethod::Gptq,
1180                "awq" => QuantMethod::Awq,
1181                "gguf" => QuantMethod::Gguf,
1182                _ => QuantMethod::None,
1183            };
1184            let bits = qc_val.get("bits").and_then(|v| v.as_u64()).unwrap_or(0) as u32;
1185            let group_size = qc_val
1186                .get("group_size")
1187                .and_then(|v| v.as_i64())
1188                .unwrap_or(128)
1189                .max(0) as usize;
1190            let desc_act = qc_val
1191                .get("desc_act")
1192                .and_then(|v| v.as_bool())
1193                .unwrap_or(false);
1194            let sym = qc_val.get("sym").and_then(|v| v.as_bool()).unwrap_or(false);
1195            if method != QuantMethod::None {
1196                return Ok(Some(QuantConfig {
1197                    method,
1198                    bits,
1199                    group_size,
1200                    desc_act,
1201                    sym,
1202                }));
1203            }
1204        }
1205    }
1206    Ok(None)
1207}
1208
#[cfg(test)]
mod tests {
    use super::*;

    /// In-features used by every fixture below.
    const K: usize = 4;

    /// 4-bit GPTQ config with a group size of 2, so K = 4 spans two groups.
    fn gptq_config(desc_act: bool) -> QuantConfig {
        QuantConfig {
            method: QuantMethod::Gptq,
            bits: 4,
            group_size: 2,
            desc_act,
            sym: true,
        }
    }

    /// Runs the validator for projection "proj" with `desc_act = false`.
    fn validate_without_desc_act(g_idx: &[i32]) -> Result<bool> {
        validate_gptq_g_idx("proj", &gptq_config(false), Some(g_idx), K)
    }

    #[test]
    fn validate_gptq_g_idx_requires_tensor_when_desc_act_configured() {
        let outcome = validate_gptq_g_idx("proj", &gptq_config(true), None, K);
        let message = outcome.unwrap_err().to_string();

        assert!(message.contains("desc_act=true"));
        assert!(message.contains("no g_idx"));
    }

    #[test]
    fn validate_gptq_g_idx_accepts_trivial_non_desc_act_order() {
        let sequential = [0, 0, 1, 1];

        assert!(!validate_without_desc_act(&sequential).unwrap());
    }

    #[test]
    fn validate_gptq_g_idx_detects_nontrivial_act_order() {
        let permuted = [1, 1, 0, 0];

        assert!(validate_without_desc_act(&permuted).unwrap());
    }

    #[test]
    fn validate_gptq_g_idx_rejects_invalid_shape_and_group() {
        // One entry short of K.
        let too_short = validate_without_desc_act(&[0, 0, 1])
            .unwrap_err()
            .to_string();
        assert!(too_short.contains("must match K=4"));

        // Group 2 does not exist when K = 4 and group_size = 2.
        let bad_group = validate_without_desc_act(&[0, 0, 2, 1])
            .unwrap_err()
            .to_string();
        assert!(bad_group.contains("outside expected group range"));
    }
}