Skip to main content

ferrum_quantization/
native_safetensors.rs

//! Native safetensors `WeightLoader<B>` — mmap + `safetensors` crate, no
//! candle dependency on the LLM hot path.
//!
//! What this owns:
//!   - Discovering `model.safetensors` vs sharded `model.safetensors.index.json`.
//!   - Mmapping each shard file and indexing its tensor headers once at open.
//!   - Per-tensor lookup: returns shape + dtype + a byte slice into the mmap.
//!   - Dense weights handed to the backend as raw on-disk bytes via
//!     `B::from_weight_bytes` (bf16 / f16 / f32 accepted); f32 materialisation
//!     only where a host-side f32 copy is needed (e.g. GPTQ scales, biases).
//!   - The Qwen3 / Llama fusion trick: `qkv_proj` / `gate_up_proj` synthesised
//!     on the fly from split `q_proj`+`k_proj`+`v_proj` etc.
//!   - GPTQ linears (`qweight` / `scales` / `qzeros` / optional `g_idx`),
//!     including fused projections and stacked MoE expert tiles, built through
//!     `GptqLinear` / `BackendQuantMarlin`.
//!
//! What it deliberately doesn't do:
//!   - AWQ / GGUF packed weights. A dedicated loader per quant format lands
//!     in Phase E.
16
17use std::collections::HashMap;
18use std::fs::File;
19use std::path::Path;
20
21use ferrum_kernels::backend::{Backend, BackendQuantMarlin, SrcDtype};
22use ferrum_types::{FerrumError, Result};
23use half::{bf16, f16};
24use memmap2::Mmap;
25use safetensors::{Dtype, SafeTensors};
26
27/// Map a safetensors Dtype to ferrum's SrcDtype.
28fn map_src_dtype(dtype: Dtype) -> Result<SrcDtype> {
29    match dtype {
30        Dtype::F32 => Ok(SrcDtype::F32),
31        Dtype::F16 => Ok(SrcDtype::F16),
32        Dtype::BF16 => Ok(SrcDtype::BF16),
33        other => Err(FerrumError::model(format!(
34            "dtype {other:?} not supported; Dense path expects F32/F16/BF16"
35        ))),
36    }
37}
38
39use crate::config::{QuantConfig, QuantMethod};
40use crate::dense::DenseLinear;
41use crate::gptq::GptqLinear;
42use crate::loader::WeightLoader;
43use crate::traits::Linear;
44
45/// Tensor metadata extracted ONCE from the safetensors header at open time.
46/// Avoids the per-tensor `SafeTensors::deserialize(&mmap)` re-parse that
47/// previously dominated cold-load on stacked-MoE (>= 18 000 calls per
48/// model, ~30 ms each on a 7 000-tensor header → 9+ minutes of header
49/// re-parse alone).
50struct TensorMeta {
51    dtype: Dtype,
52    shape: Vec<usize>,
53    /// Byte range in the shard's mmap that holds the raw tensor data.
54    data_start: usize,
55    data_end: usize,
56}
57
58/// A single shard file: mmap + name→TensorMeta cache.
59struct Shard {
60    mmap: Mmap,
61    names: Vec<String>,
62    /// Pre-extracted tensor metadata. Looked up by name; no header
63    /// re-parse on every read.
64    meta: HashMap<String, TensorMeta>,
65}
66
67impl Shard {
68    fn open(path: &Path) -> Result<Self> {
69        let file = File::open(path).map_err(|e| FerrumError::io(format!("open {path:?}: {e}")))?;
70        let mmap = unsafe {
71            Mmap::map(&file).map_err(|e| FerrumError::io(format!("mmap {path:?}: {e}")))?
72        };
73        // Parse the header ONCE and cache (offset, len, dtype, shape) for
74        // every tensor — re-deriving the data slice from the cache is a
75        // simple `&mmap[start..end]` rather than a full header re-parse.
76        let st = SafeTensors::deserialize(&mmap)
77            .map_err(|e| FerrumError::model(format!("parse {path:?}: {e}")))?;
78        // SafeTensors stores: 8-byte little-endian header_len + header_json
79        // + data_blob. The TensorView's `data()` returns a slice into the
80        // data_blob region. We compute the data_blob base by reading the
81        // 8-byte header_len.
82        debug_assert!(mmap.len() >= 8, "safetensors smaller than 8 bytes");
83        let header_len = u64::from_le_bytes(
84            mmap[0..8]
85                .try_into()
86                .expect("8-byte header len read failed"),
87        ) as usize;
88        let data_base = 8 + header_len;
89        let names: Vec<String> = st.names().iter().map(|s| s.to_string()).collect();
90        let mut meta = HashMap::with_capacity(names.len());
91        for name in &names {
92            let view = st.tensor(name).map_err(|e| {
93                FerrumError::model(format!("tensor '{name}' missing during preindex: {e}"))
94            })?;
95            // TensorView::data() is &[u8] into the mmap; we recompute its
96            // [start, end) byte range relative to the mmap base via
97            // pointer arithmetic.
98            let view_data = view.data();
99            let start = view_data.as_ptr() as usize - mmap.as_ptr() as usize;
100            let end = start + view_data.len();
101            debug_assert!(start >= data_base);
102            meta.insert(
103                name.clone(),
104                TensorMeta {
105                    dtype: view.dtype(),
106                    shape: view.shape().to_vec(),
107                    data_start: start,
108                    data_end: end,
109                },
110            );
111        }
112        let _ = data_base;
113        Ok(Self { mmap, names, meta })
114    }
115
116    /// Returns (data_bytes, dtype, shape) for the named tensor without
117    /// re-parsing the safetensors header.
118    fn get_cached(&self, name: &str) -> Result<(&[u8], Dtype, &[usize])> {
119        let m = self
120            .meta
121            .get(name)
122            .ok_or_else(|| FerrumError::model(format!("tensor '{name}' not in shard")))?;
123        Ok((&self.mmap[m.data_start..m.data_end], m.dtype, &m.shape))
124    }
125}
126
127/// Native safetensors loader. Generic over `Backend` so every tensor is
128/// materialised directly into backend-native buffers.
129pub struct NativeSafetensorsLoader<B: Backend + BackendQuantMarlin> {
130    /// All shards keyed by file; each tensor's name maps to its shard here.
131    shards: Vec<Shard>,
132    /// Name → shard index. Populated once at construction.
133    index: HashMap<String, usize>,
134    /// Optional `quantize_config.json` contents.
135    quant_config: Option<QuantConfig>,
136    _m: std::marker::PhantomData<B>,
137}
138
139impl<B: Backend + BackendQuantMarlin> NativeSafetensorsLoader<B> {
140    /// Discover shards under `model_dir` and build the name → shard index.
141    pub fn open(model_dir: impl AsRef<Path>) -> Result<Self> {
142        let dir = model_dir.as_ref();
143
144        let shard_paths = if dir.join("model.safetensors").exists() {
145            vec![dir.join("model.safetensors")]
146        } else if dir.join("model.safetensors.index.json").exists() {
147            Self::parse_sharded_index(&dir.join("model.safetensors.index.json"))?
148                .into_iter()
149                .map(|name| dir.join(name))
150                .collect()
151        } else {
152            return Err(FerrumError::model(format!(
153                "no safetensors files in {dir:?}"
154            )));
155        };
156
157        let mut shards = Vec::with_capacity(shard_paths.len());
158        let mut index: HashMap<String, usize> = HashMap::new();
159        for (i, p) in shard_paths.iter().enumerate() {
160            let shard = Shard::open(p)?;
161            for name in &shard.names {
162                index.insert(name.clone(), i);
163            }
164            shards.push(shard);
165        }
166
167        let quant_config = load_quantize_config(dir)?;
168
169        Ok(Self {
170            shards,
171            index,
172            quant_config,
173            _m: std::marker::PhantomData,
174        })
175    }
176
177    fn parse_sharded_index(index_path: &Path) -> Result<Vec<String>> {
178        let data = std::fs::read_to_string(index_path)
179            .map_err(|e| FerrumError::io(format!("read {index_path:?}: {e}")))?;
180        let json: serde_json::Value = serde_json::from_str(&data)
181            .map_err(|e| FerrumError::serialization(format!("index json: {e}")))?;
182        let weight_map = json
183            .get("weight_map")
184            .and_then(|v| v.as_object())
185            .ok_or_else(|| FerrumError::model("index missing weight_map"))?;
186        let mut files: Vec<String> = weight_map
187            .values()
188            .filter_map(|v| v.as_str().map(|s| s.to_string()))
189            .collect();
190        files.sort();
191        files.dedup();
192        Ok(files)
193    }
194
195    /// Read a tensor as f32 (converting from bf16 / f16 / f32) + its shape.
196    fn read_f32(&self, name: &str) -> Result<(Vec<f32>, Vec<usize>)> {
197        let shard_idx = *self
198            .index
199            .get(name)
200            .ok_or_else(|| FerrumError::model(format!("tensor '{name}' not in index")))?;
201        let (data_bytes, dtype, shape) = self.shards[shard_idx].get_cached(name)?;
202        let data = dtype_to_f32(dtype, data_bytes)?;
203        Ok((data, shape.to_vec()))
204    }
205
206    /// Read the raw on-disk byte slice plus dtype and shape. Zero-copy into
207    /// the mmap — used to hand weights straight to `B::from_weight_bytes` so
208    /// a fp16-preferring backend can skip the transient f32 Vec.
209    fn read_bytes_typed(&self, name: &str) -> Result<(&[u8], SrcDtype, Vec<usize>)> {
210        let shard_idx = *self
211            .index
212            .get(name)
213            .ok_or_else(|| FerrumError::model(format!("tensor '{name}' not in index")))?;
214        let (data_bytes, st_dtype, shape) = self.shards[shard_idx].get_cached(name)?;
215        let dtype = map_src_dtype(st_dtype)?;
216        Ok((data_bytes, dtype, shape.to_vec()))
217    }
218
219    /// Concatenate several tensors along dim 0 at the byte level. All parts
220    /// must share the same dtype and trailing-dim width. Returns the fused
221    /// raw bytes + common dtype + `(total_rows, cols)` shape.
222    fn cat_rows_bytes(&self, names: &[String]) -> Result<(Vec<u8>, SrcDtype, (usize, usize))> {
223        let mut total_rows = 0usize;
224        let mut cols = 0usize;
225        let mut dtype: Option<SrcDtype> = None;
226        let mut bytes: Vec<u8> = Vec::new();
227        for n in names {
228            let (raw, d, shape) = self.read_bytes_typed(n)?;
229            if shape.len() != 2 {
230                return Err(FerrumError::model(format!(
231                    "cat_rows_bytes: '{n}' is {shape:?}, need 2D"
232                )));
233            }
234            match dtype {
235                Some(prev) if prev != d => {
236                    return Err(FerrumError::model(format!(
237                        "cat_rows_bytes: dtype mismatch on '{n}'"
238                    )))
239                }
240                _ => dtype = Some(d),
241            }
242            if cols == 0 {
243                cols = shape[1];
244            } else if cols != shape[1] {
245                return Err(FerrumError::model(format!(
246                    "cat_rows_bytes: col mismatch {cols} vs {}",
247                    shape[1]
248                )));
249            }
250            total_rows += shape[0];
251            bytes.extend_from_slice(raw);
252        }
253        Ok((bytes, dtype.expect("at least one part"), (total_rows, cols)))
254    }
255
256    /// Read a tensor as i32 (for GPTQ qweight / qzeros / g_idx).
257    /// Bulk memcpy from the LE-stored bytes (safetensors guarantees LE)
258    /// — the previous per-element `from_le_bytes` was 4 ms for a single
259    /// 768 KB tensor and dominated stacked-MoE load.
260    fn read_i32(&self, name: &str) -> Result<(Vec<i32>, Vec<usize>)> {
261        let shard_idx = *self
262            .index
263            .get(name)
264            .ok_or_else(|| FerrumError::model(format!("tensor '{name}' not in index")))?;
265        let (bytes, dtype, shape) = self.shards[shard_idx].get_cached(name)?;
266        if dtype != Dtype::I32 {
267            return Err(FerrumError::model(format!(
268                "'{name}': expected I32, got {:?}",
269                dtype
270            )));
271        }
272        debug_assert_eq!(bytes.len() % 4, 0);
273        let count = bytes.len() / 4;
274        let mut out = Vec::<i32>::with_capacity(count);
275        // SAFETY: Vec<i32>'s buffer is 4-byte aligned by allocator
276        // contract. `bytes` is a raw u8 slice; copy_nonoverlapping
277        // doesn't require src alignment. We're on x86_64 LE, and
278        // safetensors stores LE i32 — bit pattern is identical.
279        unsafe {
280            std::ptr::copy_nonoverlapping(bytes.as_ptr(), out.as_mut_ptr() as *mut u8, bytes.len());
281            out.set_len(count);
282        }
283        Ok((out, shape.to_vec()))
284    }
285
286    fn has(&self, name: &str) -> bool {
287        self.index.contains_key(name)
288    }
289
290    /// Read the four raw GPTQ tensors for a named projection without
291    /// triggering a Backend repack. Used by MoE batch loading: callers
292    /// stack many experts host-side then issue a single `B::load_gptq`,
293    /// avoiding the 12 288× per-expert Marlin repack overhead.
294    ///
295    /// Returns `(qweight, scales, qzeros, g_idx, k, n)`.
296    /// `g_idx` is `None` when desc_act=false (no act-order perm needed).
297    pub fn read_gptq_raw(
298        &self,
299        name: &str,
300    ) -> Result<(Vec<i32>, Vec<f32>, Vec<i32>, Option<Vec<i32>>, usize, usize)> {
301        let (qweight, qw_shape) = self.read_i32(&format!("{name}.qweight"))?;
302        let (scales, _) = self.read_f32(&format!("{name}.scales"))?;
303        let (qzeros, _) = self.read_i32(&format!("{name}.qzeros"))?;
304        let g_idx = if self.has(&format!("{name}.g_idx")) {
305            Some(self.read_i32(&format!("{name}.g_idx"))?.0)
306        } else {
307            None
308        };
309        if qw_shape.len() != 2 {
310            return Err(FerrumError::model(format!(
311                "'{name}.qweight' expected 2D, got {qw_shape:?}"
312            )));
313        }
314        let k = qw_shape[0] * 8;
315        let n = qw_shape[1];
316        Ok((qweight, scales, qzeros, g_idx, k, n))
317    }
318
319    pub fn quant_config_ref(&self) -> Option<&crate::config::QuantConfig> {
320        self.quant_config.as_ref()
321    }
322
323    /// Load a STACKED GPTQ tile that concatenates `num_experts` experts'
324    /// raw GPTQ tensors along the N (column) axis and runs ONE backend
325    /// repack — instead of `num_experts × proj_names.len()` repacks.
326    ///
327    /// Layout: per row `r`, the cols are emitted in expert-major order:
328    /// `expert_0[proj_0|proj_1|...] | expert_1[...] | ... | expert_{N-1}[...]`.
329    /// Caller can therefore index expert `e` at column offset
330    /// `e * n_per_expert`, where `n_per_expert = Σ n(proj)` across the
331    /// `proj_names` for one expert.
332    ///
333    /// `expert_prefix_fmt` should be a closure-style `&str` that contains
334    /// `"{e}"` placeholder (replaced by the expert index) and ends *just
335    /// before* the proj name — e.g. `"model.layers.5.mlp.experts.{e}."`.
336    /// The full tensor name probed is `{expert_prefix}{proj}`.
337    ///
338    /// Returns `(store, n_per_expert, k)` where `n_per_expert` is the
339    /// per-expert column width and `k = in_features` (shared by all).
340    pub fn load_stacked_gptq_experts(
341        &self,
342        expert_prefix_fmt: &str,
343        num_experts: usize,
344        proj_names: &[&str],
345    ) -> Result<(
346        std::sync::Arc<dyn ferrum_kernels::MarlinExpertStack<B>>,
347        usize,
348        usize,
349    )> {
350        let qcfg = self.quant_config.as_ref().ok_or_else(|| {
351            FerrumError::model(
352                "load_stacked_gptq_experts requires quantize_config.json".to_string(),
353            )
354        })?;
355        if qcfg.method != QuantMethod::Gptq {
356            return Err(FerrumError::model(format!(
357                "stacked GPTQ load but quant_method={:?}",
358                qcfg.method
359            )));
360        }
361
362        let mut qw_rows = 0usize;
363        let mut sc_rows = 0usize;
364        let mut qz_rows = 0usize;
365        let mut n_per_expert = 0usize;
366        let mut n_per_expert_scales = 0usize;
367        let mut n_per_expert_zeros = 0usize;
368        let mut k_shared = 0usize;
369        let mut g_idx_first: Option<Vec<i32>> = None;
370
371        // Per (expert, proj) raw slices — row-major (rows × cols).
372        let total_pairs = num_experts * proj_names.len();
373        let mut qw_parts: Vec<(Vec<i32>, usize)> = Vec::with_capacity(total_pairs); // (data, cols)
374        let mut sc_parts: Vec<(Vec<f32>, usize)> = Vec::with_capacity(total_pairs);
375        let mut qz_parts: Vec<(Vec<i32>, usize)> = Vec::with_capacity(total_pairs);
376
377        for e in 0..num_experts {
378            let prefix = expert_prefix_fmt.replace("{e}", &e.to_string());
379            let mut e_n = 0usize;
380            let mut e_n_scales = 0usize;
381            let mut e_n_zeros = 0usize;
382            for proj in proj_names {
383                let name = format!("{prefix}{proj}");
384                let (qw, qw_sh) = self.read_i32(&format!("{name}.qweight"))?;
385                let (sc, sc_sh) = self.read_f32(&format!("{name}.scales"))?;
386                let (qz, qz_sh) = self.read_i32(&format!("{name}.qzeros"))?;
387                if qw_sh.len() != 2 || sc_sh.len() != 2 || qz_sh.len() != 2 {
388                    return Err(FerrumError::model(format!(
389                        "stacked GPTQ '{name}': expected 2D, got qw {qw_sh:?} sc {sc_sh:?} qz {qz_sh:?}"
390                    )));
391                }
392                if qw_rows == 0 {
393                    qw_rows = qw_sh[0];
394                    sc_rows = sc_sh[0];
395                    qz_rows = qz_sh[0];
396                    k_shared = qw_sh[0] * 8;
397                } else if qw_sh[0] != qw_rows || sc_sh[0] != sc_rows || qz_sh[0] != qz_rows {
398                    return Err(FerrumError::model(format!(
399                        "stacked GPTQ '{name}': row mismatch qw {} sc {} qz {} vs ref {qw_rows}/{sc_rows}/{qz_rows}",
400                        qw_sh[0], sc_sh[0], qz_sh[0]
401                    )));
402                }
403                e_n += qw_sh[1];
404                e_n_scales += sc_sh[1];
405                e_n_zeros += qz_sh[1];
406                qw_parts.push((qw, qw_sh[1]));
407                sc_parts.push((sc, sc_sh[1]));
408                qz_parts.push((qz, qz_sh[1]));
409
410                // g_idx is a permutation over K — Marlin assumes ONE g_idx
411                // for the whole stacked tile. Validate all experts share
412                // identical g_idx if any has it (which they should, since
413                // K = hidden_size is the same across experts and GPTQ's
414                // act-order is computed on the input distribution).
415                let g_key = format!("{name}.g_idx");
416                if self.has(&g_key) {
417                    let (gx, _) = self.read_i32(&g_key)?;
418                    match &g_idx_first {
419                        None => g_idx_first = Some(gx),
420                        Some(prev) => {
421                            if prev.len() != gx.len() || prev.iter().zip(&gx).any(|(a, b)| a != b) {
422                                return Err(FerrumError::model(format!(
423                                    "stacked GPTQ '{name}': g_idx mismatch with first \
424                                     expert — Marlin requires identical act-order across \
425                                     experts in the same stacked tile"
426                                )));
427                            }
428                        }
429                    }
430                }
431            }
432            if e == 0 {
433                n_per_expert = e_n;
434                n_per_expert_scales = e_n_scales;
435                n_per_expert_zeros = e_n_zeros;
436            } else if e_n != n_per_expert
437                || e_n_scales != n_per_expert_scales
438                || e_n_zeros != n_per_expert_zeros
439            {
440                return Err(FerrumError::model(format!(
441                    "stacked GPTQ expert {e} N mismatch: qw {e_n} sc {e_n_scales} qz {e_n_zeros} vs expert 0 {n_per_expert}/{n_per_expert_scales}/{n_per_expert_zeros}"
442                )));
443            }
444        }
445
446        let proj_count = proj_names.len();
447        let pairs_per_expert = proj_count;
448        debug_assert_eq!(total_pairs, num_experts * pairs_per_expert);
449
450        // PER-EXPERT layout: build num_experts independent
451        // `[K/8, n_per_expert]` qweight tiles + scales + qzeros, each
452        // a row-major concat of the proj_names within that expert.
453        // Hand them to `B::load_gptq_stacked` which repacks PER-EXPERT
454        // and concats the resulting Marlin-format tiles into one
455        // contiguous buffer. Each expert's packed bytes are then
456        // contiguous, so the offset GEMM dispatches correctly via
457        // pointer arithmetic alone.
458        //
459        // Without per-expert repack, a single concat-then-repack of
460        // the stacked tile mangles per-expert tile boundaries (Marlin
461        // permutes in K-tile-major order across the whole tile).
462        let mut per_expert_qw: Vec<Vec<i32>> = Vec::with_capacity(num_experts);
463        let mut per_expert_sc: Vec<Vec<f32>> = Vec::with_capacity(num_experts);
464        let mut per_expert_qz: Vec<Vec<i32>> = Vec::with_capacity(num_experts);
465        for e in 0..num_experts {
466            let mut qw: Vec<i32> = Vec::with_capacity(qw_rows * n_per_expert);
467            let mut sc: Vec<f32> = Vec::with_capacity(sc_rows * n_per_expert_scales);
468            let mut qz: Vec<i32> = Vec::with_capacity(qz_rows * n_per_expert_zeros);
469            for r in 0..qw_rows {
470                for j in 0..pairs_per_expert {
471                    let pair_idx = e * pairs_per_expert + j;
472                    let (data, cols) = &qw_parts[pair_idx];
473                    qw.extend_from_slice(&data[r * cols..(r + 1) * cols]);
474                }
475            }
476            for r in 0..sc_rows {
477                for j in 0..pairs_per_expert {
478                    let pair_idx = e * pairs_per_expert + j;
479                    let (data, cols) = &sc_parts[pair_idx];
480                    sc.extend_from_slice(&data[r * cols..(r + 1) * cols]);
481                }
482            }
483            for r in 0..qz_rows {
484                for j in 0..pairs_per_expert {
485                    let pair_idx = e * pairs_per_expert + j;
486                    let (data, cols) = &qz_parts[pair_idx];
487                    qz.extend_from_slice(&data[r * cols..(r + 1) * cols]);
488                }
489            }
490            per_expert_qw.push(qw);
491            per_expert_sc.push(sc);
492            per_expert_qz.push(qz);
493        }
494
495        // Drop the original part buffers — we own copies in per_expert_*.
496        drop(qw_parts);
497        drop(sc_parts);
498        drop(qz_parts);
499
500        let qw_refs: Vec<&[i32]> = per_expert_qw.iter().map(|v| v.as_slice()).collect();
501        let sc_refs: Vec<&[f32]> = per_expert_sc.iter().map(|v| v.as_slice()).collect();
502        let qz_refs: Vec<&[i32]> = per_expert_qz.iter().map(|v| v.as_slice()).collect();
503
504        let store = B::load_gptq_stacked(
505            &qw_refs,
506            &sc_refs,
507            &qz_refs,
508            g_idx_first.as_deref(),
509            qcfg.bits,
510            qcfg.group_size,
511            k_shared,
512            n_per_expert,
513        )?;
514        Ok((store, n_per_expert, k_shared))
515    }
516}
517
518impl<B: Backend + BackendQuantMarlin> WeightLoader<B> for NativeSafetensorsLoader<B> {
519    fn load_tensor(&self, name: &str) -> Result<B::Buffer> {
520        // Route through `from_weight_bytes` so fp16-preferring backends can
521        // materialise big tensors (embed table) directly as half-precision
522        // without the transient f32 Vec. Tiny tensors (norm weights) still
523        // end up as f32 because backends size-threshold inside the override.
524        let (raw, src_dtype, _) = self.read_bytes_typed(name)?;
525        Ok(B::from_weight_bytes(raw, src_dtype))
526    }
527
528    fn load_linear(&self, name: &str) -> Result<Box<dyn Linear<B>>> {
529        // GPTQ first: `<name>.qweight` + `<name>.scales` + `<name>.qzeros`.
530        let qw_key = format!("{name}.qweight");
531        if self.has(&qw_key) {
532            return self.load_gptq_linear(name);
533        }
534        // GPTQ fusion shims: synthesise qkv_proj / gate_up_proj from split
535        // components — same pattern as Dense but concatenating the GPTQ
536        // tensors (qweight/scales/qzeros) along the N dim.
537        if let Some(prefix) = name.strip_suffix("qkv_proj") {
538            let parts = [
539                format!("{prefix}q_proj"),
540                format!("{prefix}k_proj"),
541                format!("{prefix}v_proj"),
542            ];
543            if parts.iter().all(|p| self.has(&format!("{p}.qweight"))) {
544                return self.load_gptq_linear_fused(&parts);
545            }
546        }
547        if let Some(prefix) = name.strip_suffix("gate_up_proj") {
548            let parts = [format!("{prefix}gate_proj"), format!("{prefix}up_proj")];
549            if parts.iter().all(|p| self.has(&format!("{p}.qweight"))) {
550                return self.load_gptq_linear_fused(&parts);
551            }
552        }
553
554        // Direct fused `<name>.weight` next. Load straight from raw bytes
555        // so fp16-preferring backends can skip the f32 Vec intermediate.
556        let direct = format!("{name}.weight");
557        if self.has(&direct) {
558            let (raw, src_dtype, shape) = self.read_bytes_typed(&direct)?;
559            if shape.len() != 2 {
560                return Err(FerrumError::model(format!(
561                    "linear '{name}': expected 2D weight, got {shape:?}"
562                )));
563            }
564            let weight = B::from_weight_bytes(raw, src_dtype);
565            return Ok(Box::new(DenseLinear::<B>::from_buffer(
566                weight, shape[0], shape[1],
567            )));
568        }
569
570        // Llama-family fusion shims: synthesise qkv_proj / gate_up_proj from
571        // split q_proj+k_proj+v_proj / gate_proj+up_proj if present. The cat
572        // happens at the byte level so fused-weight memory is the same size
573        // as the per-part weights — no expansion to f32.
574        if let Some(prefix) = name.strip_suffix("qkv_proj") {
575            let parts = [
576                format!("{prefix}q_proj.weight"),
577                format!("{prefix}k_proj.weight"),
578                format!("{prefix}v_proj.weight"),
579            ];
580            if parts.iter().all(|p| self.has(p)) {
581                let (bytes, dtype, (rows, cols)) = self.cat_rows_bytes(&parts)?;
582                let weight = B::from_weight_bytes(&bytes, dtype);
583                return Ok(Box::new(DenseLinear::<B>::from_buffer(weight, rows, cols)));
584            }
585        }
586        if let Some(prefix) = name.strip_suffix("gate_up_proj") {
587            let parts = [
588                format!("{prefix}gate_proj.weight"),
589                format!("{prefix}up_proj.weight"),
590            ];
591            if parts.iter().all(|p| self.has(p)) {
592                let (bytes, dtype, (rows, cols)) = self.cat_rows_bytes(&parts)?;
593                let weight = B::from_weight_bytes(&bytes, dtype);
594                return Ok(Box::new(DenseLinear::<B>::from_buffer(weight, rows, cols)));
595            }
596        }
597
598        Err(FerrumError::model(format!(
599            "could not load linear '{name}' — no direct `.weight`, no split components"
600        )))
601    }
602
603    fn has_tensor(&self, name: &str) -> bool {
604        self.has(name)
605    }
606
607    fn quant_config(&self) -> Option<&QuantConfig> {
608        self.quant_config.as_ref()
609    }
610}
611
612impl<B: Backend + BackendQuantMarlin> NativeSafetensorsLoader<B> {
613    /// Load a GPTQ-packed linear projection: reads `<name>.qweight`,
614    /// `<name>.scales`, `<name>.qzeros`, optionally `<name>.g_idx`, and
615    /// hands the raw host-side tensors to `Backend::load_gptq` which
616    /// repacks + uploads per its own strategy.
617    fn load_gptq_linear(&self, name: &str) -> Result<Box<dyn Linear<B>>> {
618        let qcfg = self.quant_config.as_ref().ok_or_else(|| {
619            FerrumError::model(format!(
620                "'{name}.qweight' present but no quantize_config.json — \
621                 can't determine bits/group_size"
622            ))
623        })?;
624        if qcfg.method != QuantMethod::Gptq {
625            return Err(FerrumError::model(format!(
626                "'{name}.qweight' present but quant_method={:?} (expected GPTQ)",
627                qcfg.method
628            )));
629        }
630
631        let (qweight, qw_shape) = self.read_i32(&format!("{name}.qweight"))?;
632        let (scales_f32, sc_shape) = self.read_f32(&format!("{name}.scales"))?;
633        let (qzeros, _qz_shape) = self.read_i32(&format!("{name}.qzeros"))?;
634        let g_idx = if self.has(&format!("{name}.g_idx")) {
635            Some(self.read_i32(&format!("{name}.g_idx"))?.0)
636        } else {
637            None
638        };
639
640        // Shape inference: qweight is [K/8, N]; scales is [K/group, N].
641        // → K = qw_shape[0] * 8, N = qw_shape[1].
642        if qw_shape.len() != 2 {
643            return Err(FerrumError::model(format!(
644                "'{name}.qweight' expected 2D, got {qw_shape:?}"
645            )));
646        }
647        let in_features = qw_shape[0] * 8;
648        let out_features = qw_shape[1];
649
650        // desc_act=true detection. AutoGPTQ writes g_idx[k] = k/group_size
651        // for desc_act=false (trivial). Non-monotonic values → act-order.
652        let is_desc_act = g_idx.as_ref().map_or(false, |gx| {
653            !gx.iter()
654                .enumerate()
655                .all(|(i, &g)| g == (i as i32) / qcfg.group_size as i32)
656        });
657
658        // Act-order GPTQ. CUDA backend has perm-aware Marlin (load_gptq
659        // builds perm = argsort(g_idx) + permutes qweight rows at load;
660        // gemm_gptq gathers input columns before the standard Marlin call).
661        // CPU/Metal still need the dequant→DenseLinear fallback.
662        #[cfg(not(feature = "cuda"))]
663        if is_desc_act {
664            let dequant_f32 = dequantize_gptq_with_g_idx(
665                &qweight,
666                &scales_f32,
667                &qzeros,
668                g_idx.as_ref().expect("desc_act=true requires g_idx"),
669                qcfg.group_size,
670                in_features,
671                out_features,
672            );
673            let mut linear =
674                crate::dense::DenseLinear::<B>::from_rows(&dequant_f32, out_features, in_features);
675            let bias_key = format!("{name}.bias");
676            if self.has(&bias_key) {
677                let (bias, _) = self.read_f32(&bias_key)?;
678                linear = linear.with_bias(B::from_slice(&bias));
679            }
680            tracing::info!(
681                "GPTQ load (desc_act dequant→DenseLinear, non-cuda): name={name} K={in_features} N={out_features}"
682            );
683            return Ok(Box::new(linear));
684        }
685        #[cfg(feature = "cuda")]
686        let _ = is_desc_act; // CUDA: g_idx threaded through to GptqLinear below
687        if sc_shape.len() != 2 || sc_shape[1] != out_features {
688            return Err(FerrumError::model(format!(
689                "'{name}.scales' {sc_shape:?} incompatible with qweight {qw_shape:?}"
690            )));
691        }
692
693        // Read optional bias FIRST (Qwen2.5 attention projections, some
694        // Llama variants). Phase 3e/2: load_gptq takes the bias eagerly
695        // because the boxed Linear bakes it in.
696        let bias_key = format!("{name}.bias");
697        let bias_vec = if self.has(&bias_key) {
698            let (bias, bias_shape) = self.read_f32(&bias_key)?;
699            if bias_shape != [out_features] {
700                return Err(FerrumError::model(format!(
701                    "'{bias_key}' {bias_shape:?} != [{out_features}]"
702                )));
703            }
704            Some(bias)
705        } else {
706            None
707        };
708
709        let linear = GptqLinear::<B>::from_raw(
710            &qweight,
711            &scales_f32,
712            &qzeros,
713            g_idx.as_deref(),
714            bias_vec.as_deref(),
715            qcfg.bits,
716            qcfg.group_size,
717            in_features,
718            out_features,
719        )?;
720        Ok(Box::new(linear))
721    }
722
723    /// Fuse multiple GPTQ projections by concatenating qweight/scales/qzeros
724    /// along the output (N) dim. Matches the Dense fusion shim used for
725    /// non-quantized models: q_proj + k_proj + v_proj → qkv_proj.
726    ///
727    /// All parts must share:
728    /// - in_features (K)
729    /// - bits, group_size
730    /// - qzeros N-packing (which the GPTQ format always honours: qzeros[-1]
731    ///   = N/8, concat along that axis works)
732    ///
733    /// g_idx: only present when desc_act=true. When present, all parts
734    /// share it (same K rows, same activation permutation).
735    fn load_gptq_linear_fused(&self, parts: &[String]) -> Result<Box<dyn Linear<B>>> {
736        let qcfg = self.quant_config.as_ref().ok_or_else(|| {
737            FerrumError::model("GPTQ fusion requires quantize_config.json".to_string())
738        })?;
739        if qcfg.method != QuantMethod::Gptq {
740            return Err(FerrumError::model(format!(
741                "GPTQ fusion but quant_method={:?}",
742                qcfg.method
743            )));
744        }
745
746        let mut qw_acc: Vec<i32> = Vec::new();
747        let mut sc_acc: Vec<f32> = Vec::new();
748        let mut qz_acc: Vec<i32> = Vec::new();
749        let mut qw_rows = 0usize;
750        let mut sc_rows = 0usize;
751        let mut qz_rows = 0usize;
752        let mut total_n = 0usize;
753        let mut total_n_scales = 0usize;
754        let mut total_n_zeros = 0usize;
755        let mut g_idx: Option<Vec<i32>> = None;
756        // Segments: (qw_slice, sc_slice, qz_slice) per part, needed for N-major layout concat
757        let mut qw_parts: Vec<(Vec<i32>, usize, usize)> = Vec::new(); // (data, rows, cols)
758        let mut sc_parts: Vec<(Vec<f32>, usize, usize)> = Vec::new();
759        let mut qz_parts: Vec<(Vec<i32>, usize, usize)> = Vec::new();
760
761        for p in parts {
762            let (qw, qw_sh) = self.read_i32(&format!("{p}.qweight"))?;
763            let (sc, sc_sh) = self.read_f32(&format!("{p}.scales"))?;
764            let (qz, qz_sh) = self.read_i32(&format!("{p}.qzeros"))?;
765            if qw_sh.len() != 2 || sc_sh.len() != 2 || qz_sh.len() != 2 {
766                return Err(FerrumError::model(format!(
767                    "GPTQ fusion '{p}': expected 2D tensors, got qw {qw_sh:?} sc {sc_sh:?} qz {qz_sh:?}"
768                )));
769            }
770            if qw_rows == 0 {
771                qw_rows = qw_sh[0];
772                sc_rows = sc_sh[0];
773                qz_rows = qz_sh[0];
774            } else if qw_sh[0] != qw_rows || sc_sh[0] != sc_rows || qz_sh[0] != qz_rows {
775                return Err(FerrumError::model(format!(
776                    "GPTQ fusion row mismatch on '{p}'"
777                )));
778            }
779            total_n += qw_sh[1];
780            total_n_scales += sc_sh[1];
781            total_n_zeros += qz_sh[1];
782            qw_parts.push((qw, qw_sh[0], qw_sh[1]));
783            sc_parts.push((sc, sc_sh[0], sc_sh[1]));
784            qz_parts.push((qz, qz_sh[0], qz_sh[1]));
785
786            // g_idx optional; if first part has it, use that
787            if g_idx.is_none() && self.has(&format!("{p}.g_idx")) {
788                g_idx = Some(self.read_i32(&format!("{p}.g_idx"))?.0);
789            }
790        }
791
792        // Interleave row-major concatenation: for each row, write all parts' cols.
793        qw_acc.reserve(qw_rows * total_n);
794        for r in 0..qw_rows {
795            for (part, _rows, cols) in &qw_parts {
796                qw_acc.extend_from_slice(&part[r * cols..r * cols + cols]);
797            }
798        }
799        sc_acc.reserve(sc_rows * total_n_scales);
800        for r in 0..sc_rows {
801            for (part, _rows, cols) in &sc_parts {
802                sc_acc.extend_from_slice(&part[r * cols..r * cols + cols]);
803            }
804        }
805        qz_acc.reserve(qz_rows * total_n_zeros);
806        for r in 0..qz_rows {
807            for (part, _rows, cols) in &qz_parts {
808                qz_acc.extend_from_slice(&part[r * cols..r * cols + cols]);
809            }
810        }
811
812        let in_features = qw_rows * 8;
813        let out_features = total_n;
814
815        // desc_act detection — same as load_gptq_linear. Q/K/V (or
816        // gate/up) share K, so g_idx from first part covers all.
817        let is_desc_act = g_idx.as_ref().map_or(false, |gx| {
818            !gx.iter()
819                .enumerate()
820                .all(|(i, &g)| g == (i as i32) / qcfg.group_size as i32)
821        });
822        // CUDA: perm-aware Marlin via load_gptq. CPU/Metal: dequant→Dense.
823        #[cfg(not(feature = "cuda"))]
824        if is_desc_act {
825            let dequant_f32 = dequantize_gptq_with_g_idx(
826                &qw_acc,
827                &sc_acc,
828                &qz_acc,
829                g_idx.as_ref().expect("desc_act=true requires g_idx"),
830                qcfg.group_size,
831                in_features,
832                out_features,
833            );
834            let mut linear =
835                crate::dense::DenseLinear::<B>::from_rows(&dequant_f32, out_features, in_features);
836            let mut bias_acc: Vec<f32> = Vec::new();
837            let mut any_bias = false;
838            for p in parts {
839                let bk = format!("{p}.bias");
840                if self.has(&bk) {
841                    any_bias = true;
842                    bias_acc.extend_from_slice(&self.read_f32(&bk)?.0);
843                } else if any_bias {
844                    return Err(FerrumError::model(format!(
845                        "GPTQ fusion bias mix: '{p}' has no bias but earlier part did"
846                    )));
847                }
848            }
849            if any_bias {
850                linear = linear.with_bias(B::from_slice(&bias_acc));
851            }
852            tracing::info!(
853                "GPTQ fused load (desc_act dequant→DenseLinear, non-cuda): K={in_features} N={out_features} parts={}",
854                parts.len()
855            );
856            return Ok(Box::new(linear));
857        }
858        #[cfg(feature = "cuda")]
859        let _ = is_desc_act;
860
861        // Biases: concatenate `<part>.bias` across parts in the same
862        // order as qweights. All-or-none; if any part has a bias, all
863        // must. Phase 3e/2: read first, pass into load_gptq.
864        let bias_keys: Vec<String> = parts.iter().map(|p| format!("{p}.bias")).collect();
865        let any = bias_keys.iter().any(|k| self.has(k));
866        let all = bias_keys.iter().all(|k| self.has(k));
867        if any && !all {
868            return Err(FerrumError::model(
869                "GPTQ fusion: inconsistent bias presence across parts".to_string(),
870            ));
871        }
872        let fused_bias = if all {
873            let mut fused: Vec<f32> = Vec::with_capacity(out_features);
874            for k in &bias_keys {
875                let (b, _) = self.read_f32(k)?;
876                fused.extend_from_slice(&b);
877            }
878            if fused.len() != out_features {
879                return Err(FerrumError::model(format!(
880                    "GPTQ fusion bias length {} != out_features {out_features}",
881                    fused.len()
882                )));
883            }
884            Some(fused)
885        } else {
886            None
887        };
888
889        let linear = GptqLinear::<B>::from_raw(
890            &qw_acc,
891            &sc_acc,
892            &qz_acc,
893            g_idx.as_deref(),
894            fused_bias.as_deref(),
895            qcfg.bits,
896            qcfg.group_size,
897            in_features,
898            out_features,
899        )?;
900
901        Ok(Box::new(linear))
902    }
903
904    /// Read each name, assert shape width matches, concatenate along dim 0.
905    /// Kept for diagnostic / fallback paths; DenseLinear fusion prefers the
906    /// byte-level `cat_rows_bytes` above.
907    #[allow(dead_code)]
908    fn cat_rows(&self, names: &[String]) -> Result<(usize, usize, Vec<f32>)> {
909        let mut total_rows = 0usize;
910        let mut cols = 0usize;
911        let mut out: Vec<f32> = Vec::new();
912        for n in names {
913            let (data, shape) = self.read_f32(n)?;
914            if shape.len() != 2 {
915                return Err(FerrumError::model(format!(
916                    "cat_rows: '{n}' is {shape:?}, need 2D"
917                )));
918            }
919            if cols == 0 {
920                cols = shape[1];
921            } else if cols != shape[1] {
922                return Err(FerrumError::model(format!(
923                    "cat_rows: col mismatch {cols} vs {}",
924                    shape[1]
925                )));
926            }
927            total_rows += shape[0];
928            out.extend_from_slice(&data);
929        }
930        Ok((total_rows, cols, out))
931    }
932}
933
/// Dequantise GPTQ INT4 weights with desc_act=true (act-order) g_idx and
/// return original-order f32 weights laid out `[N, K]` row-major (matches
/// `DenseLinear::from_rows`).
///
/// Key insight: in AutoGPTQ desc_act format, qweight rows are NOT
/// permuted from original-K order. The act-order trick is encoded purely
/// in `g_idx[k]` — which records the QUANTISATION GROUP (not column
/// position) chosen for disk row k. Different rows that originally
/// belonged to far-apart positions can share a group via g_idx.
///
/// Verified against vLLM's exllama path (gptq.py:368): for desc_act it
/// runs `g_idx ← argsort(g_idx)` then `gptq_shuffle(qweight, g_idx)`,
/// which physically reorders qweight by argsort and gathers x by argsort
/// at GEMM. Net effect: y[n] = Σⱼ x[j] · dequant(qweight[j, n],
/// scales[g_idx_orig[j], n], qzeros[g_idx_orig[j], n]).
/// → disk_k IS original_k; only the (scale, zero) LOOKUP differs.
///
/// Packing as read below:
/// - `qweight[pr * n + col]` holds K rows `pr*8 .. pr*8+8` of column `col`;
///   nibble `bi` is row `pr*8 + bi`.
/// - each qzeros row is `ceil(n / 8)` words wide; nibble `col % 8` of word
///   `col / 8` stores `zero - 1`, so the effective zero is that value + 1.
///
/// `_group_size` is unused: each K row's group comes from `g_idx`.
#[cfg(not(feature = "cuda"))]
fn dequantize_gptq_with_g_idx(
    qweight: &[i32], // [K/8, N] packed int4
    scales: &[f32],  // [num_groups, N]
    qzeros: &[i32],  // [num_groups, ceil(N/8)] packed int4
    g_idx: &[i32],   // [K]
    _group_size: usize,
    k: usize,
    n: usize,
) -> Vec<f32> {
    debug_assert_eq!(g_idx.len(), k);
    debug_assert_eq!(k % 8, 0, "GPTQ packs 8 int4 K rows per i32");

    // qzeros rows are padded to whole i32 words, so the row stride is
    // ceil(n / 8). Plain `n / 8` under-strides when n % 8 != 0 and makes
    // every group after the first read another group's zero points.
    let zeros_stride = n.div_ceil(8);

    // Output: [N, K] row-major → out[col * k + k_idx] = value.
    let mut w = vec![0.0f32; n * k];
    for pr in 0..k / 8 {
        for col in 0..n {
            let packed = qweight[pr * n + col] as u32;
            let zero_word = col / 8;
            let zero_shift = (col % 8) * 4;
            for bi in 0..8 {
                let ki = pr * 8 + bi;
                let q = ((packed >> (bi * 4)) & 0xF) as i32;
                let g = g_idx[ki] as usize;
                let scale = scales[g * n + col];
                let z_packed = qzeros[g * zeros_stride + zero_word] as u32;
                let zero = ((z_packed >> zero_shift) & 0xF) as i32 + 1;
                w[col * k + ki] = (q - zero) as f32 * scale;
            }
        }
    }
    w
}
981
982fn dtype_to_f32(dtype: Dtype, raw: &[u8]) -> Result<Vec<f32>> {
983    match dtype {
984        Dtype::F32 => {
985            // Bulk memcpy from LE-stored bytes (safetensors is LE; we're
986            // on x86_64 LE). Per-element from_le_bytes was the bottleneck
987            // for stacked-MoE load (4-5 ms per call * 384 calls/layer *
988            // 48 layers = ~80 sec just for f32 reads).
989            debug_assert_eq!(raw.len() % 4, 0);
990            let n = raw.len() / 4;
991            let mut out = Vec::<f32>::with_capacity(n);
992            unsafe {
993                std::ptr::copy_nonoverlapping(raw.as_ptr(), out.as_mut_ptr() as *mut u8, raw.len());
994                out.set_len(n);
995            }
996            Ok(out)
997        }
998        Dtype::F16 => {
999            debug_assert_eq!(raw.len() % 2, 0);
1000            let n = raw.len() / 2;
1001            // Reinterpret raw bytes as f16, then convert. This avoids
1002            // the per-element from_le_bytes byte-array construction.
1003            let mut tmp = Vec::<f16>::with_capacity(n);
1004            unsafe {
1005                std::ptr::copy_nonoverlapping(raw.as_ptr(), tmp.as_mut_ptr() as *mut u8, raw.len());
1006                tmp.set_len(n);
1007            }
1008            let mut out = Vec::with_capacity(n);
1009            for h in &tmp {
1010                out.push(h.to_f32());
1011            }
1012            Ok(out)
1013        }
1014        Dtype::BF16 => {
1015            debug_assert_eq!(raw.len() % 2, 0);
1016            let n = raw.len() / 2;
1017            let mut tmp = Vec::<bf16>::with_capacity(n);
1018            unsafe {
1019                std::ptr::copy_nonoverlapping(raw.as_ptr(), tmp.as_mut_ptr() as *mut u8, raw.len());
1020                tmp.set_len(n);
1021            }
1022            let mut out = Vec::with_capacity(n);
1023            for h in &tmp {
1024                out.push(h.to_f32());
1025            }
1026            Ok(out)
1027        }
1028        other => Err(FerrumError::model(format!(
1029            "dtype {other:?} not supported by NativeSafetensorsLoader's f32 path; \
1030             use a format-specific loader (GPTQ / AWQ / GGUF)",
1031        ))),
1032    }
1033}
1034
1035fn load_quantize_config(dir: &Path) -> Result<Option<QuantConfig>> {
1036    // AutoGPTQ / gptq-for-llama format: separate quantize_config.json.
1037    let p = dir.join("quantize_config.json");
1038    if p.exists() {
1039        let data =
1040            std::fs::read_to_string(&p).map_err(|e| FerrumError::io(format!("read {p:?}: {e}")))?;
1041        let qc: QuantConfig = serde_json::from_str(&data)
1042            .map_err(|e| FerrumError::serialization(format!("parse quantize_config.json: {e}")))?;
1043        return Ok(Some(qc));
1044    }
1045    // Qwen GPTQ / transformers-style: embedded in config.json under
1046    // "quantization_config": { "quant_method": "gptq", "bits": 4, ... }.
1047    let cfg = dir.join("config.json");
1048    if cfg.exists() {
1049        let data = std::fs::read_to_string(&cfg)
1050            .map_err(|e| FerrumError::io(format!("read {cfg:?}: {e}")))?;
1051        let root: serde_json::Value = serde_json::from_str(&data)
1052            .map_err(|e| FerrumError::serialization(format!("parse config.json: {e}")))?;
1053        if let Some(qc_val) = root.get("quantization_config") {
1054            // The embedded block has "quant_method" (not "method"); remap.
1055            let method = qc_val
1056                .get("quant_method")
1057                .and_then(|v| v.as_str())
1058                .unwrap_or("none");
1059            let method = match method.to_lowercase().as_str() {
1060                "gptq" => QuantMethod::Gptq,
1061                "awq" => QuantMethod::Awq,
1062                "gguf" => QuantMethod::Gguf,
1063                _ => QuantMethod::None,
1064            };
1065            let bits = qc_val.get("bits").and_then(|v| v.as_u64()).unwrap_or(0) as u32;
1066            let group_size = qc_val
1067                .get("group_size")
1068                .and_then(|v| v.as_i64())
1069                .unwrap_or(128)
1070                .max(0) as usize;
1071            let desc_act = qc_val
1072                .get("desc_act")
1073                .and_then(|v| v.as_bool())
1074                .unwrap_or(false);
1075            let sym = qc_val.get("sym").and_then(|v| v.as_bool()).unwrap_or(false);
1076            if method != QuantMethod::None {
1077                return Ok(Some(QuantConfig {
1078                    method,
1079                    bits,
1080                    group_size,
1081                    desc_act,
1082                    sym,
1083                }));
1084            }
1085        }
1086    }
1087    Ok(None)
1088}