Skip to main content

ferrum_quantization/
native_safetensors.rs

//! Native safetensors `WeightLoader<B>` — mmap + `safetensors` crate, no
//! candle dependency on the LLM hot path.
//!
//! What this owns:
//!   - Discovering `model.safetensors` vs sharded `model.safetensors.index.json`.
//!   - Mmapping each shard file.
//!   - Per-tensor lookup: returns shape + dtype + a byte slice into the mmap.
//!   - f32 materialisation for Dense weights (bf16 / f16 / f32 accepted).
//!   - GPTQ-packed linears (`qweight` / `scales` / `qzeros`, plus optional
//!     `g_idx` and `bias`), configured from `quantize_config.json` or the
//!     `quantization_config` block embedded in `config.json`.
//!   - The Qwen3 / Llama fusion trick: `qkv_proj` / `gate_up_proj` synthesised
//!     on the fly from split `q_proj`+`k_proj`+`v_proj` etc., for both Dense
//!     and GPTQ weights.
//!
//! What it deliberately doesn't do:
//!   - AWQ / GGUF packed weights. Those need `B::from_slice_i32` /
//!     `B::from_slice_f16` which aren't on the Backend trait yet. A dedicated
//!     loader per quant format lands in Phase E.
16
17use std::collections::HashMap;
18use std::fs::File;
19use std::path::Path;
20
21use ferrum_kernels::backend::Backend;
22use ferrum_types::{FerrumError, Result};
23use half::{bf16, f16};
24use memmap2::Mmap;
25use safetensors::{Dtype, SafeTensors};
26
27use crate::config::{QuantConfig, QuantMethod};
28use crate::dense::DenseLinear;
29use crate::gptq::GptqLinear;
30use crate::loader::WeightLoader;
31use crate::traits::Linear;
32
33/// A single shard file: mmap + name→(shape, dtype, byte-offset-in-mmap).
34struct Shard {
35    mmap: Mmap,
36    /// Parsed entries. Safetensors' `SafeTensors` type borrows from the mmap,
37    /// so we can't store it directly — instead we pre-extract name → metadata
38    /// and rebuild a `SafeTensors` view on demand via `SafeTensors::deserialize`.
39    names: Vec<String>,
40}
41
42impl Shard {
43    fn open(path: &Path) -> Result<Self> {
44        let file = File::open(path).map_err(|e| FerrumError::io(format!("open {path:?}: {e}")))?;
45        let mmap = unsafe {
46            Mmap::map(&file).map_err(|e| FerrumError::io(format!("mmap {path:?}: {e}")))?
47        };
48        // Parse just to validate and extract names; the SafeTensors view is
49        // rebuilt on each read (cheap — it's a header reparse).
50        let st = SafeTensors::deserialize(&mmap)
51            .map_err(|e| FerrumError::model(format!("parse {path:?}: {e}")))?;
52        let names = st.names().iter().map(|s| s.to_string()).collect();
53        Ok(Self { mmap, names })
54    }
55
56    fn get<'a>(&'a self, name: &str) -> Result<safetensors::tensor::TensorView<'a>> {
57        let st = SafeTensors::deserialize(&self.mmap)
58            .map_err(|e| FerrumError::model(format!("reparse: {e}")))?;
59        st.tensor(name)
60            .map_err(|e| FerrumError::model(format!("tensor '{name}': {e}")))
61    }
62}
63
64/// Native safetensors loader. Generic over `Backend` so every tensor is
65/// materialised directly into backend-native buffers.
66pub struct NativeSafetensorsLoader<B: Backend> {
67    /// All shards keyed by file; each tensor's name maps to its shard here.
68    shards: Vec<Shard>,
69    /// Name → shard index. Populated once at construction.
70    index: HashMap<String, usize>,
71    /// Optional `quantize_config.json` contents.
72    quant_config: Option<QuantConfig>,
73    _m: std::marker::PhantomData<B>,
74}
75
76impl<B: Backend> NativeSafetensorsLoader<B> {
77    /// Discover shards under `model_dir` and build the name → shard index.
78    pub fn open(model_dir: impl AsRef<Path>) -> Result<Self> {
79        let dir = model_dir.as_ref();
80
81        let shard_paths = if dir.join("model.safetensors").exists() {
82            vec![dir.join("model.safetensors")]
83        } else if dir.join("model.safetensors.index.json").exists() {
84            Self::parse_sharded_index(&dir.join("model.safetensors.index.json"))?
85                .into_iter()
86                .map(|name| dir.join(name))
87                .collect()
88        } else {
89            return Err(FerrumError::model(format!(
90                "no safetensors files in {dir:?}"
91            )));
92        };
93
94        let mut shards = Vec::with_capacity(shard_paths.len());
95        let mut index: HashMap<String, usize> = HashMap::new();
96        for (i, p) in shard_paths.iter().enumerate() {
97            let shard = Shard::open(p)?;
98            for name in &shard.names {
99                index.insert(name.clone(), i);
100            }
101            shards.push(shard);
102        }
103
104        let quant_config = load_quantize_config(dir)?;
105
106        Ok(Self {
107            shards,
108            index,
109            quant_config,
110            _m: std::marker::PhantomData,
111        })
112    }
113
114    fn parse_sharded_index(index_path: &Path) -> Result<Vec<String>> {
115        let data = std::fs::read_to_string(index_path)
116            .map_err(|e| FerrumError::io(format!("read {index_path:?}: {e}")))?;
117        let json: serde_json::Value = serde_json::from_str(&data)
118            .map_err(|e| FerrumError::serialization(format!("index json: {e}")))?;
119        let weight_map = json
120            .get("weight_map")
121            .and_then(|v| v.as_object())
122            .ok_or_else(|| FerrumError::model("index missing weight_map"))?;
123        let mut files: Vec<String> = weight_map
124            .values()
125            .filter_map(|v| v.as_str().map(|s| s.to_string()))
126            .collect();
127        files.sort();
128        files.dedup();
129        Ok(files)
130    }
131
132    /// Read a tensor as f32 (converting from bf16 / f16 / f32) + its shape.
133    fn read_f32(&self, name: &str) -> Result<(Vec<f32>, Vec<usize>)> {
134        let shard_idx = *self
135            .index
136            .get(name)
137            .ok_or_else(|| FerrumError::model(format!("tensor '{name}' not in index")))?;
138        let view = self.shards[shard_idx].get(name)?;
139        let shape = view.shape().to_vec();
140        let data = dtype_to_f32(view.dtype(), view.data())?;
141        Ok((data, shape))
142    }
143
144    /// Read a tensor as i32 (for GPTQ qweight / qzeros / g_idx).
145    fn read_i32(&self, name: &str) -> Result<(Vec<i32>, Vec<usize>)> {
146        let shard_idx = *self
147            .index
148            .get(name)
149            .ok_or_else(|| FerrumError::model(format!("tensor '{name}' not in index")))?;
150        let view = self.shards[shard_idx].get(name)?;
151        let shape = view.shape().to_vec();
152        if view.dtype() != Dtype::I32 {
153            return Err(FerrumError::model(format!(
154                "'{name}': expected I32, got {:?}",
155                view.dtype()
156            )));
157        }
158        let bytes = view.data();
159        debug_assert_eq!(bytes.len() % 4, 0);
160        let mut out = vec![0i32; bytes.len() / 4];
161        out.as_mut_slice()
162            .iter_mut()
163            .zip(bytes.chunks_exact(4))
164            .for_each(|(d, chunk)| {
165                *d = i32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])
166            });
167        Ok((out, shape))
168    }
169
170    fn has(&self, name: &str) -> bool {
171        self.index.contains_key(name)
172    }
173}
174
175impl<B: Backend> WeightLoader<B> for NativeSafetensorsLoader<B> {
176    fn load_tensor(&self, name: &str) -> Result<B::Buffer> {
177        let (data, _) = self.read_f32(name)?;
178        Ok(B::from_slice(&data))
179    }
180
181    fn load_linear(&self, name: &str) -> Result<Box<dyn Linear<B>>> {
182        // GPTQ first: `<name>.qweight` + `<name>.scales` + `<name>.qzeros`.
183        let qw_key = format!("{name}.qweight");
184        if self.has(&qw_key) {
185            return self.load_gptq_linear(name);
186        }
187        // GPTQ fusion shims: synthesise qkv_proj / gate_up_proj from split
188        // components — same pattern as Dense but concatenating the GPTQ
189        // tensors (qweight/scales/qzeros) along the N dim.
190        if name.ends_with("qkv_proj") {
191            let prefix = &name[..name.len() - "qkv_proj".len()];
192            let parts = [
193                format!("{prefix}q_proj"),
194                format!("{prefix}k_proj"),
195                format!("{prefix}v_proj"),
196            ];
197            if parts.iter().all(|p| self.has(&format!("{p}.qweight"))) {
198                return self.load_gptq_linear_fused(&parts);
199            }
200        }
201        if name.ends_with("gate_up_proj") {
202            let prefix = &name[..name.len() - "gate_up_proj".len()];
203            let parts = [format!("{prefix}gate_proj"), format!("{prefix}up_proj")];
204            if parts.iter().all(|p| self.has(&format!("{p}.qweight"))) {
205                return self.load_gptq_linear_fused(&parts);
206            }
207        }
208
209        // Direct fused `<name>.weight` next.
210        let direct = format!("{name}.weight");
211        if self.has(&direct) {
212            let (data, shape) = self.read_f32(&direct)?;
213            if shape.len() != 2 {
214                return Err(FerrumError::model(format!(
215                    "linear '{name}': expected 2D weight, got {shape:?}"
216                )));
217            }
218            return Ok(Box::new(DenseLinear::<B>::from_rows(
219                &data, shape[0], shape[1],
220            )));
221        }
222
223        // Llama-family fusion shims: synthesise qkv_proj / gate_up_proj from
224        // split q_proj+k_proj+v_proj / gate_proj+up_proj if present.
225        if name.ends_with("qkv_proj") {
226            let prefix = &name[..name.len() - "qkv_proj".len()];
227            let parts = [
228                format!("{prefix}q_proj.weight"),
229                format!("{prefix}k_proj.weight"),
230                format!("{prefix}v_proj.weight"),
231            ];
232            if parts.iter().all(|p| self.has(p)) {
233                let (rows, cols, data) = self.cat_rows(&parts)?;
234                return Ok(Box::new(DenseLinear::<B>::from_rows(&data, rows, cols)));
235            }
236        }
237        if name.ends_with("gate_up_proj") {
238            let prefix = &name[..name.len() - "gate_up_proj".len()];
239            let parts = [
240                format!("{prefix}gate_proj.weight"),
241                format!("{prefix}up_proj.weight"),
242            ];
243            if parts.iter().all(|p| self.has(p)) {
244                let (rows, cols, data) = self.cat_rows(&parts)?;
245                return Ok(Box::new(DenseLinear::<B>::from_rows(&data, rows, cols)));
246            }
247        }
248
249        Err(FerrumError::model(format!(
250            "could not load linear '{name}' — no direct `.weight`, no split components"
251        )))
252    }
253
254    fn has_tensor(&self, name: &str) -> bool {
255        self.has(name)
256    }
257
258    fn quant_config(&self) -> Option<&QuantConfig> {
259        self.quant_config.as_ref()
260    }
261}
262
263impl<B: Backend> NativeSafetensorsLoader<B> {
264    /// Load a GPTQ-packed linear projection: reads `<name>.qweight`,
265    /// `<name>.scales`, `<name>.qzeros`, optionally `<name>.g_idx`, and
266    /// hands the raw host-side tensors to `Backend::load_gptq` which
267    /// repacks + uploads per its own strategy.
268    fn load_gptq_linear(&self, name: &str) -> Result<Box<dyn Linear<B>>> {
269        let qcfg = self.quant_config.as_ref().ok_or_else(|| {
270            FerrumError::model(format!(
271                "'{name}.qweight' present but no quantize_config.json — \
272                 can't determine bits/group_size"
273            ))
274        })?;
275        if qcfg.method != QuantMethod::Gptq {
276            return Err(FerrumError::model(format!(
277                "'{name}.qweight' present but quant_method={:?} (expected GPTQ)",
278                qcfg.method
279            )));
280        }
281
282        let (qweight, qw_shape) = self.read_i32(&format!("{name}.qweight"))?;
283        let (scales_f32, sc_shape) = self.read_f32(&format!("{name}.scales"))?;
284        let (qzeros, _qz_shape) = self.read_i32(&format!("{name}.qzeros"))?;
285        let g_idx = if self.has(&format!("{name}.g_idx")) {
286            Some(self.read_i32(&format!("{name}.g_idx"))?.0)
287        } else {
288            None
289        };
290
291        // Shape inference: qweight is [K/8, N]; scales is [K/group, N].
292        // → K = qw_shape[0] * 8, N = qw_shape[1].
293        if qw_shape.len() != 2 {
294            return Err(FerrumError::model(format!(
295                "'{name}.qweight' expected 2D, got {qw_shape:?}"
296            )));
297        }
298        let in_features = qw_shape[0] * 8;
299        let out_features = qw_shape[1];
300        if sc_shape.len() != 2 || sc_shape[1] != out_features {
301            return Err(FerrumError::model(format!(
302                "'{name}.scales' {sc_shape:?} incompatible with qweight {qw_shape:?}"
303            )));
304        }
305
306        let mut linear = GptqLinear::<B>::from_raw(
307            &qweight,
308            &scales_f32,
309            &qzeros,
310            g_idx.as_deref(),
311            qcfg.bits,
312            qcfg.group_size,
313            in_features,
314            out_features,
315        )?;
316
317        // Bias (Qwen2.5 attention projections, some Llama variants).
318        let bias_key = format!("{name}.bias");
319        if self.has(&bias_key) {
320            let (bias, bias_shape) = self.read_f32(&bias_key)?;
321            if bias_shape != [out_features] {
322                return Err(FerrumError::model(format!(
323                    "'{bias_key}' {bias_shape:?} != [{out_features}]"
324                )));
325            }
326            linear = linear.with_bias(&bias);
327        }
328        Ok(Box::new(linear))
329    }
330
331    /// Fuse multiple GPTQ projections by concatenating qweight/scales/qzeros
332    /// along the output (N) dim. Matches the Dense fusion shim used for
333    /// non-quantized models: q_proj + k_proj + v_proj → qkv_proj.
334    ///
335    /// All parts must share:
336    /// - in_features (K)
337    /// - bits, group_size
338    /// - qzeros N-packing (which the GPTQ format always honours: qzeros[-1]
339    ///   = N/8, concat along that axis works)
340    ///
341    /// g_idx: only present when desc_act=true. When present, all parts
342    /// share it (same K rows, same activation permutation).
343    fn load_gptq_linear_fused(&self, parts: &[String]) -> Result<Box<dyn Linear<B>>> {
344        let qcfg = self.quant_config.as_ref().ok_or_else(|| {
345            FerrumError::model("GPTQ fusion requires quantize_config.json".to_string())
346        })?;
347        if qcfg.method != QuantMethod::Gptq {
348            return Err(FerrumError::model(format!(
349                "GPTQ fusion but quant_method={:?}",
350                qcfg.method
351            )));
352        }
353
354        let mut qw_acc: Vec<i32> = Vec::new();
355        let mut sc_acc: Vec<f32> = Vec::new();
356        let mut qz_acc: Vec<i32> = Vec::new();
357        let mut qw_rows = 0usize;
358        let mut sc_rows = 0usize;
359        let mut qz_rows = 0usize;
360        let mut total_n = 0usize;
361        let mut total_n_scales = 0usize;
362        let mut total_n_zeros = 0usize;
363        let mut g_idx: Option<Vec<i32>> = None;
364        // Segments: (qw_slice, sc_slice, qz_slice) per part, needed for N-major layout concat
365        let mut qw_parts: Vec<(Vec<i32>, usize, usize)> = Vec::new(); // (data, rows, cols)
366        let mut sc_parts: Vec<(Vec<f32>, usize, usize)> = Vec::new();
367        let mut qz_parts: Vec<(Vec<i32>, usize, usize)> = Vec::new();
368
369        for p in parts {
370            let (qw, qw_sh) = self.read_i32(&format!("{p}.qweight"))?;
371            let (sc, sc_sh) = self.read_f32(&format!("{p}.scales"))?;
372            let (qz, qz_sh) = self.read_i32(&format!("{p}.qzeros"))?;
373            if qw_sh.len() != 2 || sc_sh.len() != 2 || qz_sh.len() != 2 {
374                return Err(FerrumError::model(format!(
375                    "GPTQ fusion '{p}': expected 2D tensors, got qw {qw_sh:?} sc {sc_sh:?} qz {qz_sh:?}"
376                )));
377            }
378            if qw_rows == 0 {
379                qw_rows = qw_sh[0];
380                sc_rows = sc_sh[0];
381                qz_rows = qz_sh[0];
382            } else if qw_sh[0] != qw_rows || sc_sh[0] != sc_rows || qz_sh[0] != qz_rows {
383                return Err(FerrumError::model(format!(
384                    "GPTQ fusion row mismatch on '{p}'"
385                )));
386            }
387            total_n += qw_sh[1];
388            total_n_scales += sc_sh[1];
389            total_n_zeros += qz_sh[1];
390            qw_parts.push((qw, qw_sh[0], qw_sh[1]));
391            sc_parts.push((sc, sc_sh[0], sc_sh[1]));
392            qz_parts.push((qz, qz_sh[0], qz_sh[1]));
393
394            // g_idx optional; if first part has it, use that
395            if g_idx.is_none() {
396                if self.has(&format!("{p}.g_idx")) {
397                    g_idx = Some(self.read_i32(&format!("{p}.g_idx"))?.0);
398                }
399            }
400        }
401
402        // Interleave row-major concatenation: for each row, write all parts' cols.
403        qw_acc.reserve(qw_rows * total_n);
404        for r in 0..qw_rows {
405            for (part, _rows, cols) in &qw_parts {
406                qw_acc.extend_from_slice(&part[r * cols..r * cols + cols]);
407            }
408        }
409        sc_acc.reserve(sc_rows * total_n_scales);
410        for r in 0..sc_rows {
411            for (part, _rows, cols) in &sc_parts {
412                sc_acc.extend_from_slice(&part[r * cols..r * cols + cols]);
413            }
414        }
415        qz_acc.reserve(qz_rows * total_n_zeros);
416        for r in 0..qz_rows {
417            for (part, _rows, cols) in &qz_parts {
418                qz_acc.extend_from_slice(&part[r * cols..r * cols + cols]);
419            }
420        }
421
422        let in_features = qw_rows * 8;
423        let out_features = total_n;
424        let mut linear = GptqLinear::<B>::from_raw(
425            &qw_acc,
426            &sc_acc,
427            &qz_acc,
428            g_idx.as_deref(),
429            qcfg.bits,
430            qcfg.group_size,
431            in_features,
432            out_features,
433        )?;
434
435        // Biases: concatenate `<part>.bias` across parts in the same order as
436        // qweights. All-or-none; if any part has a bias, all must.
437        let bias_keys: Vec<String> = parts.iter().map(|p| format!("{p}.bias")).collect();
438        let any = bias_keys.iter().any(|k| self.has(k));
439        let all = bias_keys.iter().all(|k| self.has(k));
440        if any && !all {
441            return Err(FerrumError::model(
442                "GPTQ fusion: inconsistent bias presence across parts".to_string(),
443            ));
444        }
445        if all {
446            let mut fused: Vec<f32> = Vec::with_capacity(out_features);
447            for k in &bias_keys {
448                let (b, _) = self.read_f32(k)?;
449                fused.extend_from_slice(&b);
450            }
451            if fused.len() != out_features {
452                return Err(FerrumError::model(format!(
453                    "GPTQ fusion bias length {} != out_features {out_features}",
454                    fused.len()
455                )));
456            }
457            linear = linear.with_bias(&fused);
458        }
459        Ok(Box::new(linear))
460    }
461
462    /// Read each name, assert shape width matches, concatenate along dim 0.
463    fn cat_rows(&self, names: &[String]) -> Result<(usize, usize, Vec<f32>)> {
464        let mut total_rows = 0usize;
465        let mut cols = 0usize;
466        let mut out: Vec<f32> = Vec::new();
467        for n in names {
468            let (data, shape) = self.read_f32(n)?;
469            if shape.len() != 2 {
470                return Err(FerrumError::model(format!(
471                    "cat_rows: '{n}' is {shape:?}, need 2D"
472                )));
473            }
474            if cols == 0 {
475                cols = shape[1];
476            } else if cols != shape[1] {
477                return Err(FerrumError::model(format!(
478                    "cat_rows: col mismatch {cols} vs {}",
479                    shape[1]
480                )));
481            }
482            total_rows += shape[0];
483            out.extend_from_slice(&data);
484        }
485        Ok((total_rows, cols, out))
486    }
487}
488
489fn dtype_to_f32(dtype: Dtype, raw: &[u8]) -> Result<Vec<f32>> {
490    match dtype {
491        Dtype::F32 => {
492            debug_assert_eq!(raw.len() % 4, 0);
493            let n = raw.len() / 4;
494            let mut out = vec![0.0f32; n];
495            for i in 0..n {
496                let bytes = [raw[i * 4], raw[i * 4 + 1], raw[i * 4 + 2], raw[i * 4 + 3]];
497                out[i] = f32::from_le_bytes(bytes);
498            }
499            Ok(out)
500        }
501        Dtype::F16 => {
502            debug_assert_eq!(raw.len() % 2, 0);
503            let n = raw.len() / 2;
504            let mut out = vec![0.0f32; n];
505            for i in 0..n {
506                let bytes = [raw[i * 2], raw[i * 2 + 1]];
507                out[i] = f16::from_le_bytes(bytes).to_f32();
508            }
509            Ok(out)
510        }
511        Dtype::BF16 => {
512            debug_assert_eq!(raw.len() % 2, 0);
513            let n = raw.len() / 2;
514            let mut out = vec![0.0f32; n];
515            for i in 0..n {
516                let bytes = [raw[i * 2], raw[i * 2 + 1]];
517                out[i] = bf16::from_le_bytes(bytes).to_f32();
518            }
519            Ok(out)
520        }
521        other => Err(FerrumError::model(format!(
522            "dtype {other:?} not supported by NativeSafetensorsLoader's f32 path; \
523             use a format-specific loader (GPTQ / AWQ / GGUF)",
524        ))),
525    }
526}
527
528fn load_quantize_config(dir: &Path) -> Result<Option<QuantConfig>> {
529    // AutoGPTQ / gptq-for-llama format: separate quantize_config.json.
530    let p = dir.join("quantize_config.json");
531    if p.exists() {
532        let data =
533            std::fs::read_to_string(&p).map_err(|e| FerrumError::io(format!("read {p:?}: {e}")))?;
534        let qc: QuantConfig = serde_json::from_str(&data)
535            .map_err(|e| FerrumError::serialization(format!("parse quantize_config.json: {e}")))?;
536        return Ok(Some(qc));
537    }
538    // Qwen GPTQ / transformers-style: embedded in config.json under
539    // "quantization_config": { "quant_method": "gptq", "bits": 4, ... }.
540    let cfg = dir.join("config.json");
541    if cfg.exists() {
542        let data = std::fs::read_to_string(&cfg)
543            .map_err(|e| FerrumError::io(format!("read {cfg:?}: {e}")))?;
544        let root: serde_json::Value = serde_json::from_str(&data)
545            .map_err(|e| FerrumError::serialization(format!("parse config.json: {e}")))?;
546        if let Some(qc_val) = root.get("quantization_config") {
547            // The embedded block has "quant_method" (not "method"); remap.
548            let method = qc_val
549                .get("quant_method")
550                .and_then(|v| v.as_str())
551                .unwrap_or("none");
552            let method = match method.to_lowercase().as_str() {
553                "gptq" => QuantMethod::Gptq,
554                "awq" => QuantMethod::Awq,
555                "gguf" => QuantMethod::Gguf,
556                _ => QuantMethod::None,
557            };
558            let bits = qc_val.get("bits").and_then(|v| v.as_u64()).unwrap_or(0) as u32;
559            let group_size = qc_val
560                .get("group_size")
561                .and_then(|v| v.as_i64())
562                .unwrap_or(128)
563                .max(0) as usize;
564            let desc_act = qc_val
565                .get("desc_act")
566                .and_then(|v| v.as_bool())
567                .unwrap_or(false);
568            let sym = qc_val.get("sym").and_then(|v| v.as_bool()).unwrap_or(false);
569            if method != QuantMethod::None {
570                return Ok(Some(QuantConfig {
571                    method,
572                    bits,
573                    group_size,
574                    desc_act,
575                    sym,
576                }));
577            }
578        }
579    }
580    Ok(None)
581}