// ferrum_quantization/native_safetensors.rs

//! Native safetensors `WeightLoader<B>` — mmap + `safetensors` crate, no
//! candle dependency on the LLM hot path.
//!
//! What this owns:
//!   - Discovering `model.safetensors` vs sharded `model.safetensors.index.json`.
//!   - Mmapping each shard file.
//!   - Per-tensor lookup: returns shape + dtype + a byte slice into the mmap.
//!   - Handing Dense weights (bf16 / f16 / f32 accepted) to the backend as raw
//!     bytes, plus f32 materialisation where host-side f32 is needed (GPTQ
//!     scales and biases).
//!   - The Qwen3 / Llama fusion trick: `qkv_proj` / `gate_up_proj` synthesised
//!     on the fly from split `q_proj`+`k_proj`+`v_proj` etc.
//!   - GPTQ-packed linears (`.qweight` / `.scales` / `.qzeros`, optional
//!     `.g_idx` / `.bias`), including the same qkv / gate_up fusion, handed to
//!     `GptqLinear::from_raw`.
//!
//! What it deliberately doesn't do:
//!   - AWQ / GGUF packed weights. A `.qweight` tensor in a checkpoint whose
//!     quant method is not GPTQ is rejected. A dedicated loader per quant
//!     format lands in Phase E.

17use std::collections::HashMap;
18use std::fs::File;
19use std::path::Path;
20
21use ferrum_kernels::backend::{Backend, SrcDtype};
22use ferrum_types::{FerrumError, Result};
23use half::{bf16, f16};
24use memmap2::Mmap;
25use safetensors::{Dtype, SafeTensors};
26
27/// Map a safetensors Dtype to ferrum's SrcDtype.
28fn map_src_dtype(dtype: Dtype) -> Result<SrcDtype> {
29    match dtype {
30        Dtype::F32 => Ok(SrcDtype::F32),
31        Dtype::F16 => Ok(SrcDtype::F16),
32        Dtype::BF16 => Ok(SrcDtype::BF16),
33        other => Err(FerrumError::model(format!(
34            "dtype {other:?} not supported; Dense path expects F32/F16/BF16"
35        ))),
36    }
37}
38
39use crate::config::{QuantConfig, QuantMethod};
40use crate::dense::DenseLinear;
41use crate::gptq::GptqLinear;
42use crate::loader::WeightLoader;
43use crate::traits::Linear;
44
45/// A single shard file: mmap + name→(shape, dtype, byte-offset-in-mmap).
46struct Shard {
47    mmap: Mmap,
48    /// Parsed entries. Safetensors' `SafeTensors` type borrows from the mmap,
49    /// so we can't store it directly — instead we pre-extract name → metadata
50    /// and rebuild a `SafeTensors` view on demand via `SafeTensors::deserialize`.
51    names: Vec<String>,
52}
53
54impl Shard {
55    fn open(path: &Path) -> Result<Self> {
56        let file = File::open(path).map_err(|e| FerrumError::io(format!("open {path:?}: {e}")))?;
57        let mmap = unsafe {
58            Mmap::map(&file).map_err(|e| FerrumError::io(format!("mmap {path:?}: {e}")))?
59        };
60        // Parse just to validate and extract names; the SafeTensors view is
61        // rebuilt on each read (cheap — it's a header reparse).
62        let st = SafeTensors::deserialize(&mmap)
63            .map_err(|e| FerrumError::model(format!("parse {path:?}: {e}")))?;
64        let names = st.names().iter().map(|s| s.to_string()).collect();
65        Ok(Self { mmap, names })
66    }
67
68    fn get<'a>(&'a self, name: &str) -> Result<safetensors::tensor::TensorView<'a>> {
69        let st = SafeTensors::deserialize(&self.mmap)
70            .map_err(|e| FerrumError::model(format!("reparse: {e}")))?;
71        st.tensor(name)
72            .map_err(|e| FerrumError::model(format!("tensor '{name}': {e}")))
73    }
74}
75
76/// Native safetensors loader. Generic over `Backend` so every tensor is
77/// materialised directly into backend-native buffers.
78pub struct NativeSafetensorsLoader<B: Backend> {
79    /// All shards keyed by file; each tensor's name maps to its shard here.
80    shards: Vec<Shard>,
81    /// Name → shard index. Populated once at construction.
82    index: HashMap<String, usize>,
83    /// Optional `quantize_config.json` contents.
84    quant_config: Option<QuantConfig>,
85    _m: std::marker::PhantomData<B>,
86}
87
88impl<B: Backend> NativeSafetensorsLoader<B> {
89    /// Discover shards under `model_dir` and build the name → shard index.
90    pub fn open(model_dir: impl AsRef<Path>) -> Result<Self> {
91        let dir = model_dir.as_ref();
92
93        let shard_paths = if dir.join("model.safetensors").exists() {
94            vec![dir.join("model.safetensors")]
95        } else if dir.join("model.safetensors.index.json").exists() {
96            Self::parse_sharded_index(&dir.join("model.safetensors.index.json"))?
97                .into_iter()
98                .map(|name| dir.join(name))
99                .collect()
100        } else {
101            return Err(FerrumError::model(format!(
102                "no safetensors files in {dir:?}"
103            )));
104        };
105
106        let mut shards = Vec::with_capacity(shard_paths.len());
107        let mut index: HashMap<String, usize> = HashMap::new();
108        for (i, p) in shard_paths.iter().enumerate() {
109            let shard = Shard::open(p)?;
110            for name in &shard.names {
111                index.insert(name.clone(), i);
112            }
113            shards.push(shard);
114        }
115
116        let quant_config = load_quantize_config(dir)?;
117
118        Ok(Self {
119            shards,
120            index,
121            quant_config,
122            _m: std::marker::PhantomData,
123        })
124    }
125
126    fn parse_sharded_index(index_path: &Path) -> Result<Vec<String>> {
127        let data = std::fs::read_to_string(index_path)
128            .map_err(|e| FerrumError::io(format!("read {index_path:?}: {e}")))?;
129        let json: serde_json::Value = serde_json::from_str(&data)
130            .map_err(|e| FerrumError::serialization(format!("index json: {e}")))?;
131        let weight_map = json
132            .get("weight_map")
133            .and_then(|v| v.as_object())
134            .ok_or_else(|| FerrumError::model("index missing weight_map"))?;
135        let mut files: Vec<String> = weight_map
136            .values()
137            .filter_map(|v| v.as_str().map(|s| s.to_string()))
138            .collect();
139        files.sort();
140        files.dedup();
141        Ok(files)
142    }
143
144    /// Read a tensor as f32 (converting from bf16 / f16 / f32) + its shape.
145    fn read_f32(&self, name: &str) -> Result<(Vec<f32>, Vec<usize>)> {
146        let shard_idx = *self
147            .index
148            .get(name)
149            .ok_or_else(|| FerrumError::model(format!("tensor '{name}' not in index")))?;
150        let view = self.shards[shard_idx].get(name)?;
151        let shape = view.shape().to_vec();
152        let data = dtype_to_f32(view.dtype(), view.data())?;
153        Ok((data, shape))
154    }
155
156    /// Read the raw on-disk byte slice plus dtype and shape. Zero-copy into
157    /// the mmap — used to hand weights straight to `B::from_weight_bytes` so
158    /// a fp16-preferring backend can skip the transient f32 Vec.
159    fn read_bytes_typed(&self, name: &str) -> Result<(&[u8], SrcDtype, Vec<usize>)> {
160        let shard_idx = *self
161            .index
162            .get(name)
163            .ok_or_else(|| FerrumError::model(format!("tensor '{name}' not in index")))?;
164        let view = self.shards[shard_idx].get(name)?;
165        let shape = view.shape().to_vec();
166        let dtype = map_src_dtype(view.dtype())?;
167        Ok((view.data(), dtype, shape))
168    }
169
170    /// Concatenate several tensors along dim 0 at the byte level. All parts
171    /// must share the same dtype and trailing-dim width. Returns the fused
172    /// raw bytes + common dtype + `(total_rows, cols)` shape.
173    fn cat_rows_bytes(&self, names: &[String]) -> Result<(Vec<u8>, SrcDtype, (usize, usize))> {
174        let mut total_rows = 0usize;
175        let mut cols = 0usize;
176        let mut dtype: Option<SrcDtype> = None;
177        let mut bytes: Vec<u8> = Vec::new();
178        for n in names {
179            let (raw, d, shape) = self.read_bytes_typed(n)?;
180            if shape.len() != 2 {
181                return Err(FerrumError::model(format!(
182                    "cat_rows_bytes: '{n}' is {shape:?}, need 2D"
183                )));
184            }
185            match dtype {
186                Some(prev) if prev != d => {
187                    return Err(FerrumError::model(format!(
188                        "cat_rows_bytes: dtype mismatch on '{n}'"
189                    )))
190                }
191                _ => dtype = Some(d),
192            }
193            if cols == 0 {
194                cols = shape[1];
195            } else if cols != shape[1] {
196                return Err(FerrumError::model(format!(
197                    "cat_rows_bytes: col mismatch {cols} vs {}",
198                    shape[1]
199                )));
200            }
201            total_rows += shape[0];
202            bytes.extend_from_slice(raw);
203        }
204        Ok((bytes, dtype.expect("at least one part"), (total_rows, cols)))
205    }
206
207    /// Read a tensor as i32 (for GPTQ qweight / qzeros / g_idx).
208    fn read_i32(&self, name: &str) -> Result<(Vec<i32>, Vec<usize>)> {
209        let shard_idx = *self
210            .index
211            .get(name)
212            .ok_or_else(|| FerrumError::model(format!("tensor '{name}' not in index")))?;
213        let view = self.shards[shard_idx].get(name)?;
214        let shape = view.shape().to_vec();
215        if view.dtype() != Dtype::I32 {
216            return Err(FerrumError::model(format!(
217                "'{name}': expected I32, got {:?}",
218                view.dtype()
219            )));
220        }
221        let bytes = view.data();
222        debug_assert_eq!(bytes.len() % 4, 0);
223        let mut out = vec![0i32; bytes.len() / 4];
224        out.as_mut_slice()
225            .iter_mut()
226            .zip(bytes.chunks_exact(4))
227            .for_each(|(d, chunk)| {
228                *d = i32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])
229            });
230        Ok((out, shape))
231    }
232
233    fn has(&self, name: &str) -> bool {
234        self.index.contains_key(name)
235    }
236}
237
238impl<B: Backend> WeightLoader<B> for NativeSafetensorsLoader<B> {
239    fn load_tensor(&self, name: &str) -> Result<B::Buffer> {
240        // Route through `from_weight_bytes` so fp16-preferring backends can
241        // materialise big tensors (embed table) directly as half-precision
242        // without the transient f32 Vec. Tiny tensors (norm weights) still
243        // end up as f32 because backends size-threshold inside the override.
244        let (raw, src_dtype, _) = self.read_bytes_typed(name)?;
245        Ok(B::from_weight_bytes(raw, src_dtype))
246    }
247
248    fn load_linear(&self, name: &str) -> Result<Box<dyn Linear<B>>> {
249        // GPTQ first: `<name>.qweight` + `<name>.scales` + `<name>.qzeros`.
250        let qw_key = format!("{name}.qweight");
251        if self.has(&qw_key) {
252            return self.load_gptq_linear(name);
253        }
254        // GPTQ fusion shims: synthesise qkv_proj / gate_up_proj from split
255        // components — same pattern as Dense but concatenating the GPTQ
256        // tensors (qweight/scales/qzeros) along the N dim.
257        if let Some(prefix) = name.strip_suffix("qkv_proj") {
258            let parts = [
259                format!("{prefix}q_proj"),
260                format!("{prefix}k_proj"),
261                format!("{prefix}v_proj"),
262            ];
263            if parts.iter().all(|p| self.has(&format!("{p}.qweight"))) {
264                return self.load_gptq_linear_fused(&parts);
265            }
266        }
267        if let Some(prefix) = name.strip_suffix("gate_up_proj") {
268            let parts = [format!("{prefix}gate_proj"), format!("{prefix}up_proj")];
269            if parts.iter().all(|p| self.has(&format!("{p}.qweight"))) {
270                return self.load_gptq_linear_fused(&parts);
271            }
272        }
273
274        // Direct fused `<name>.weight` next. Load straight from raw bytes
275        // so fp16-preferring backends can skip the f32 Vec intermediate.
276        let direct = format!("{name}.weight");
277        if self.has(&direct) {
278            let (raw, src_dtype, shape) = self.read_bytes_typed(&direct)?;
279            if shape.len() != 2 {
280                return Err(FerrumError::model(format!(
281                    "linear '{name}': expected 2D weight, got {shape:?}"
282                )));
283            }
284            let weight = B::from_weight_bytes(raw, src_dtype);
285            return Ok(Box::new(DenseLinear::<B>::from_buffer(
286                weight, shape[0], shape[1],
287            )));
288        }
289
290        // Llama-family fusion shims: synthesise qkv_proj / gate_up_proj from
291        // split q_proj+k_proj+v_proj / gate_proj+up_proj if present. The cat
292        // happens at the byte level so fused-weight memory is the same size
293        // as the per-part weights — no expansion to f32.
294        if let Some(prefix) = name.strip_suffix("qkv_proj") {
295            let parts = [
296                format!("{prefix}q_proj.weight"),
297                format!("{prefix}k_proj.weight"),
298                format!("{prefix}v_proj.weight"),
299            ];
300            if parts.iter().all(|p| self.has(p)) {
301                let (bytes, dtype, (rows, cols)) = self.cat_rows_bytes(&parts)?;
302                let weight = B::from_weight_bytes(&bytes, dtype);
303                return Ok(Box::new(DenseLinear::<B>::from_buffer(weight, rows, cols)));
304            }
305        }
306        if let Some(prefix) = name.strip_suffix("gate_up_proj") {
307            let parts = [
308                format!("{prefix}gate_proj.weight"),
309                format!("{prefix}up_proj.weight"),
310            ];
311            if parts.iter().all(|p| self.has(p)) {
312                let (bytes, dtype, (rows, cols)) = self.cat_rows_bytes(&parts)?;
313                let weight = B::from_weight_bytes(&bytes, dtype);
314                return Ok(Box::new(DenseLinear::<B>::from_buffer(weight, rows, cols)));
315            }
316        }
317
318        Err(FerrumError::model(format!(
319            "could not load linear '{name}' — no direct `.weight`, no split components"
320        )))
321    }
322
323    fn has_tensor(&self, name: &str) -> bool {
324        self.has(name)
325    }
326
327    fn quant_config(&self) -> Option<&QuantConfig> {
328        self.quant_config.as_ref()
329    }
330}
331
332impl<B: Backend> NativeSafetensorsLoader<B> {
333    /// Load a GPTQ-packed linear projection: reads `<name>.qweight`,
334    /// `<name>.scales`, `<name>.qzeros`, optionally `<name>.g_idx`, and
335    /// hands the raw host-side tensors to `Backend::load_gptq` which
336    /// repacks + uploads per its own strategy.
337    fn load_gptq_linear(&self, name: &str) -> Result<Box<dyn Linear<B>>> {
338        let qcfg = self.quant_config.as_ref().ok_or_else(|| {
339            FerrumError::model(format!(
340                "'{name}.qweight' present but no quantize_config.json — \
341                 can't determine bits/group_size"
342            ))
343        })?;
344        if qcfg.method != QuantMethod::Gptq {
345            return Err(FerrumError::model(format!(
346                "'{name}.qweight' present but quant_method={:?} (expected GPTQ)",
347                qcfg.method
348            )));
349        }
350
351        let (qweight, qw_shape) = self.read_i32(&format!("{name}.qweight"))?;
352        let (scales_f32, sc_shape) = self.read_f32(&format!("{name}.scales"))?;
353        let (qzeros, _qz_shape) = self.read_i32(&format!("{name}.qzeros"))?;
354        let g_idx = if self.has(&format!("{name}.g_idx")) {
355            Some(self.read_i32(&format!("{name}.g_idx"))?.0)
356        } else {
357            None
358        };
359
360        // Shape inference: qweight is [K/8, N]; scales is [K/group, N].
361        // → K = qw_shape[0] * 8, N = qw_shape[1].
362        if qw_shape.len() != 2 {
363            return Err(FerrumError::model(format!(
364                "'{name}.qweight' expected 2D, got {qw_shape:?}"
365            )));
366        }
367        let in_features = qw_shape[0] * 8;
368        let out_features = qw_shape[1];
369        if sc_shape.len() != 2 || sc_shape[1] != out_features {
370            return Err(FerrumError::model(format!(
371                "'{name}.scales' {sc_shape:?} incompatible with qweight {qw_shape:?}"
372            )));
373        }
374
375        let mut linear = GptqLinear::<B>::from_raw(
376            &qweight,
377            &scales_f32,
378            &qzeros,
379            g_idx.as_deref(),
380            qcfg.bits,
381            qcfg.group_size,
382            in_features,
383            out_features,
384        )?;
385
386        // Bias (Qwen2.5 attention projections, some Llama variants).
387        let bias_key = format!("{name}.bias");
388        if self.has(&bias_key) {
389            let (bias, bias_shape) = self.read_f32(&bias_key)?;
390            if bias_shape != [out_features] {
391                return Err(FerrumError::model(format!(
392                    "'{bias_key}' {bias_shape:?} != [{out_features}]"
393                )));
394            }
395            linear = linear.with_bias(&bias);
396        }
397        Ok(Box::new(linear))
398    }
399
400    /// Fuse multiple GPTQ projections by concatenating qweight/scales/qzeros
401    /// along the output (N) dim. Matches the Dense fusion shim used for
402    /// non-quantized models: q_proj + k_proj + v_proj → qkv_proj.
403    ///
404    /// All parts must share:
405    /// - in_features (K)
406    /// - bits, group_size
407    /// - qzeros N-packing (which the GPTQ format always honours: qzeros[-1]
408    ///   = N/8, concat along that axis works)
409    ///
410    /// g_idx: only present when desc_act=true. When present, all parts
411    /// share it (same K rows, same activation permutation).
412    fn load_gptq_linear_fused(&self, parts: &[String]) -> Result<Box<dyn Linear<B>>> {
413        let qcfg = self.quant_config.as_ref().ok_or_else(|| {
414            FerrumError::model("GPTQ fusion requires quantize_config.json".to_string())
415        })?;
416        if qcfg.method != QuantMethod::Gptq {
417            return Err(FerrumError::model(format!(
418                "GPTQ fusion but quant_method={:?}",
419                qcfg.method
420            )));
421        }
422
423        let mut qw_acc: Vec<i32> = Vec::new();
424        let mut sc_acc: Vec<f32> = Vec::new();
425        let mut qz_acc: Vec<i32> = Vec::new();
426        let mut qw_rows = 0usize;
427        let mut sc_rows = 0usize;
428        let mut qz_rows = 0usize;
429        let mut total_n = 0usize;
430        let mut total_n_scales = 0usize;
431        let mut total_n_zeros = 0usize;
432        let mut g_idx: Option<Vec<i32>> = None;
433        // Segments: (qw_slice, sc_slice, qz_slice) per part, needed for N-major layout concat
434        let mut qw_parts: Vec<(Vec<i32>, usize, usize)> = Vec::new(); // (data, rows, cols)
435        let mut sc_parts: Vec<(Vec<f32>, usize, usize)> = Vec::new();
436        let mut qz_parts: Vec<(Vec<i32>, usize, usize)> = Vec::new();
437
438        for p in parts {
439            let (qw, qw_sh) = self.read_i32(&format!("{p}.qweight"))?;
440            let (sc, sc_sh) = self.read_f32(&format!("{p}.scales"))?;
441            let (qz, qz_sh) = self.read_i32(&format!("{p}.qzeros"))?;
442            if qw_sh.len() != 2 || sc_sh.len() != 2 || qz_sh.len() != 2 {
443                return Err(FerrumError::model(format!(
444                    "GPTQ fusion '{p}': expected 2D tensors, got qw {qw_sh:?} sc {sc_sh:?} qz {qz_sh:?}"
445                )));
446            }
447            if qw_rows == 0 {
448                qw_rows = qw_sh[0];
449                sc_rows = sc_sh[0];
450                qz_rows = qz_sh[0];
451            } else if qw_sh[0] != qw_rows || sc_sh[0] != sc_rows || qz_sh[0] != qz_rows {
452                return Err(FerrumError::model(format!(
453                    "GPTQ fusion row mismatch on '{p}'"
454                )));
455            }
456            total_n += qw_sh[1];
457            total_n_scales += sc_sh[1];
458            total_n_zeros += qz_sh[1];
459            qw_parts.push((qw, qw_sh[0], qw_sh[1]));
460            sc_parts.push((sc, sc_sh[0], sc_sh[1]));
461            qz_parts.push((qz, qz_sh[0], qz_sh[1]));
462
463            // g_idx optional; if first part has it, use that
464            if g_idx.is_none() && self.has(&format!("{p}.g_idx")) {
465                g_idx = Some(self.read_i32(&format!("{p}.g_idx"))?.0);
466            }
467        }
468
469        // Interleave row-major concatenation: for each row, write all parts' cols.
470        qw_acc.reserve(qw_rows * total_n);
471        for r in 0..qw_rows {
472            for (part, _rows, cols) in &qw_parts {
473                qw_acc.extend_from_slice(&part[r * cols..r * cols + cols]);
474            }
475        }
476        sc_acc.reserve(sc_rows * total_n_scales);
477        for r in 0..sc_rows {
478            for (part, _rows, cols) in &sc_parts {
479                sc_acc.extend_from_slice(&part[r * cols..r * cols + cols]);
480            }
481        }
482        qz_acc.reserve(qz_rows * total_n_zeros);
483        for r in 0..qz_rows {
484            for (part, _rows, cols) in &qz_parts {
485                qz_acc.extend_from_slice(&part[r * cols..r * cols + cols]);
486            }
487        }
488
489        let in_features = qw_rows * 8;
490        let out_features = total_n;
491        let mut linear = GptqLinear::<B>::from_raw(
492            &qw_acc,
493            &sc_acc,
494            &qz_acc,
495            g_idx.as_deref(),
496            qcfg.bits,
497            qcfg.group_size,
498            in_features,
499            out_features,
500        )?;
501
502        // Biases: concatenate `<part>.bias` across parts in the same order as
503        // qweights. All-or-none; if any part has a bias, all must.
504        let bias_keys: Vec<String> = parts.iter().map(|p| format!("{p}.bias")).collect();
505        let any = bias_keys.iter().any(|k| self.has(k));
506        let all = bias_keys.iter().all(|k| self.has(k));
507        if any && !all {
508            return Err(FerrumError::model(
509                "GPTQ fusion: inconsistent bias presence across parts".to_string(),
510            ));
511        }
512        if all {
513            let mut fused: Vec<f32> = Vec::with_capacity(out_features);
514            for k in &bias_keys {
515                let (b, _) = self.read_f32(k)?;
516                fused.extend_from_slice(&b);
517            }
518            if fused.len() != out_features {
519                return Err(FerrumError::model(format!(
520                    "GPTQ fusion bias length {} != out_features {out_features}",
521                    fused.len()
522                )));
523            }
524            linear = linear.with_bias(&fused);
525        }
526        Ok(Box::new(linear))
527    }
528
529    /// Read each name, assert shape width matches, concatenate along dim 0.
530    /// Kept for diagnostic / fallback paths; DenseLinear fusion prefers the
531    /// byte-level `cat_rows_bytes` above.
532    #[allow(dead_code)]
533    fn cat_rows(&self, names: &[String]) -> Result<(usize, usize, Vec<f32>)> {
534        let mut total_rows = 0usize;
535        let mut cols = 0usize;
536        let mut out: Vec<f32> = Vec::new();
537        for n in names {
538            let (data, shape) = self.read_f32(n)?;
539            if shape.len() != 2 {
540                return Err(FerrumError::model(format!(
541                    "cat_rows: '{n}' is {shape:?}, need 2D"
542                )));
543            }
544            if cols == 0 {
545                cols = shape[1];
546            } else if cols != shape[1] {
547                return Err(FerrumError::model(format!(
548                    "cat_rows: col mismatch {cols} vs {}",
549                    shape[1]
550                )));
551            }
552            total_rows += shape[0];
553            out.extend_from_slice(&data);
554        }
555        Ok((total_rows, cols, out))
556    }
557}
558
559fn dtype_to_f32(dtype: Dtype, raw: &[u8]) -> Result<Vec<f32>> {
560    match dtype {
561        Dtype::F32 => {
562            debug_assert_eq!(raw.len() % 4, 0);
563            let n = raw.len() / 4;
564            let mut out = vec![0.0f32; n];
565            for i in 0..n {
566                let bytes = [raw[i * 4], raw[i * 4 + 1], raw[i * 4 + 2], raw[i * 4 + 3]];
567                out[i] = f32::from_le_bytes(bytes);
568            }
569            Ok(out)
570        }
571        Dtype::F16 => {
572            debug_assert_eq!(raw.len() % 2, 0);
573            let n = raw.len() / 2;
574            let mut out = vec![0.0f32; n];
575            for i in 0..n {
576                let bytes = [raw[i * 2], raw[i * 2 + 1]];
577                out[i] = f16::from_le_bytes(bytes).to_f32();
578            }
579            Ok(out)
580        }
581        Dtype::BF16 => {
582            debug_assert_eq!(raw.len() % 2, 0);
583            let n = raw.len() / 2;
584            let mut out = vec![0.0f32; n];
585            for i in 0..n {
586                let bytes = [raw[i * 2], raw[i * 2 + 1]];
587                out[i] = bf16::from_le_bytes(bytes).to_f32();
588            }
589            Ok(out)
590        }
591        other => Err(FerrumError::model(format!(
592            "dtype {other:?} not supported by NativeSafetensorsLoader's f32 path; \
593             use a format-specific loader (GPTQ / AWQ / GGUF)",
594        ))),
595    }
596}
597
598fn load_quantize_config(dir: &Path) -> Result<Option<QuantConfig>> {
599    // AutoGPTQ / gptq-for-llama format: separate quantize_config.json.
600    let p = dir.join("quantize_config.json");
601    if p.exists() {
602        let data =
603            std::fs::read_to_string(&p).map_err(|e| FerrumError::io(format!("read {p:?}: {e}")))?;
604        let qc: QuantConfig = serde_json::from_str(&data)
605            .map_err(|e| FerrumError::serialization(format!("parse quantize_config.json: {e}")))?;
606        return Ok(Some(qc));
607    }
608    // Qwen GPTQ / transformers-style: embedded in config.json under
609    // "quantization_config": { "quant_method": "gptq", "bits": 4, ... }.
610    let cfg = dir.join("config.json");
611    if cfg.exists() {
612        let data = std::fs::read_to_string(&cfg)
613            .map_err(|e| FerrumError::io(format!("read {cfg:?}: {e}")))?;
614        let root: serde_json::Value = serde_json::from_str(&data)
615            .map_err(|e| FerrumError::serialization(format!("parse config.json: {e}")))?;
616        if let Some(qc_val) = root.get("quantization_config") {
617            // The embedded block has "quant_method" (not "method"); remap.
618            let method = qc_val
619                .get("quant_method")
620                .and_then(|v| v.as_str())
621                .unwrap_or("none");
622            let method = match method.to_lowercase().as_str() {
623                "gptq" => QuantMethod::Gptq,
624                "awq" => QuantMethod::Awq,
625                "gguf" => QuantMethod::Gguf,
626                _ => QuantMethod::None,
627            };
628            let bits = qc_val.get("bits").and_then(|v| v.as_u64()).unwrap_or(0) as u32;
629            let group_size = qc_val
630                .get("group_size")
631                .and_then(|v| v.as_i64())
632                .unwrap_or(128)
633                .max(0) as usize;
634            let desc_act = qc_val
635                .get("desc_act")
636                .and_then(|v| v.as_bool())
637                .unwrap_or(false);
638            let sym = qc_val.get("sym").and_then(|v| v.as_bool()).unwrap_or(false);
639            if method != QuantMethod::None {
640                return Ok(Some(QuantConfig {
641                    method,
642                    bits,
643                    group_size,
644                    desc_act,
645                    sym,
646                }));
647            }
648        }
649    }
650    Ok(None)
651}