// mlx_native/weight.rs
1//! Weight loading from safetensors files into Metal GPU buffers.
2//!
3//! This module provides utilities for loading quantized model weights from
4//! the [safetensors](https://huggingface.co/docs/safetensors) file format
5//! into Metal `StorageModeShared` buffers for GPU inference.
6//!
7//! # Architecture
8//!
9//! The loading pipeline is:
10//!
11//! 1. Memory-map the safetensors file(s) via `memmap2` (no full read into RAM).
12//! 2. Parse the header to discover tensor names, shapes, dtypes, and byte offsets.
13//! 3. For each tensor, create a Metal `StorageModeShared` buffer and copy the
14//!    raw bytes from the mmap region into it.
15//! 4. Attach quantization metadata (bits, group_size) from the
16//!    `quantization_config.json` file.
17//!
18//! # Zero-Copy Consideration
19//!
20//! On Apple Silicon, Metal shared-mode buffers reside in unified memory.  We
21//! *could* create a Metal buffer that wraps the mmap pointer directly, but this
22//! is unsafe because the mmap lifetime is tied to the file mapping.  Instead we
23//! copy the tensor bytes into a fresh Metal buffer, which is a single memcpy on
24//! unified memory and guarantees the buffer outlives the file mapping.
25
26use std::collections::HashMap;
27use std::fs;
28use std::path::Path;
29
30use memmap2::Mmap;
31use metal::MTLResourceOptions;
32use safetensors::SafeTensors;
33use serde::Deserialize;
34
35use crate::buffer::MlxBuffer;
36use crate::device::MlxDevice;
37use crate::dtypes::DType;
38use crate::error::{MlxError, Result};
39
40// ---------------------------------------------------------------------------
41// Quantization config parsing
42// ---------------------------------------------------------------------------
43
/// Per-tensor quantization configuration from `quantization_config.json`.
///
/// This mirrors the JSON structure produced by hf2q's `--quant auto` mode,
/// where each tensor may have a different bit-width and group size.
///
/// Lookups are performed via [`QuantizationConfig::config_for_tensor`], which
/// falls back to the top-level `bits` / `group_size` defaults when no
/// per-tensor entry matches.
#[derive(Debug, Clone, Deserialize)]
pub struct QuantizationConfig {
    /// Default bit-width applied when a tensor has no per-tensor override.
    /// Defaults to 4 when absent from the JSON (see `default_bits`).
    #[serde(default = "default_bits")]
    pub bits: u8,

    /// Default group size applied when a tensor has no per-tensor override.
    /// Defaults to 64 when absent from the JSON (see `default_group_size`).
    #[serde(default = "default_group_size")]
    pub group_size: usize,

    /// Per-tensor overrides keyed by tensor name pattern.
    /// Each entry maps a tensor name (or glob pattern) to its quant config.
    /// Defaults to an empty map when absent from the JSON.
    ///
    /// NOTE(review): `config_for_tensor` performs exact `HashMap` lookups
    /// only — glob patterns are never expanded here; confirm whether glob
    /// support is expected elsewhere.
    #[serde(default)]
    pub per_tensor: HashMap<String, TensorQuantConfig>,
}
63
/// Quantization parameters for an individual tensor.
///
/// Used as the value type of `QuantizationConfig::per_tensor`.  Unlike the
/// top-level config, both fields are required in the JSON (no serde defaults).
#[derive(Debug, Clone, Deserialize)]
pub struct TensorQuantConfig {
    /// Bit-width for this tensor (3, 4, 6, or 8).
    pub bits: u8,
    /// Number of consecutive values sharing one scale/bias pair.
    pub group_size: usize,
}
72
/// Serde default for `QuantizationConfig::bits`: 4-bit quantization.
fn default_bits() -> u8 { 4 }
76
/// Serde default for `QuantizationConfig::group_size`: groups of 64 values.
fn default_group_size() -> usize { 64 }
80
/// Remove a trailing `.weight`, `.scales`, or `.biases` component from a
/// tensor name, yielding the base linear-layer name.  Names without one of
/// these suffixes are returned unchanged.
fn strip_tensor_suffix(name: &str) -> &str {
    [".weight", ".scales", ".biases"]
        .into_iter()
        .find_map(|suffix| name.strip_suffix(suffix))
        .unwrap_or(name)
}
91
92impl QuantizationConfig {
93    /// Load and parse a `quantization_config.json` file from disk.
94    ///
95    /// # Errors
96    ///
97    /// Returns `MlxError::IoError` if the file cannot be read, or
98    /// `MlxError::QuantConfigError` if the JSON is malformed.
99    pub fn from_file(path: &Path) -> Result<Self> {
100        let contents = fs::read_to_string(path).map_err(|e| {
101            MlxError::IoError(format!("Failed to read quantization config at {}: {}", path.display(), e))
102        })?;
103        Self::from_json(&contents)
104    }
105
106    /// Parse a `QuantizationConfig` from a JSON string.
107    ///
108    /// # Errors
109    ///
110    /// Returns `MlxError::QuantConfigError` if the JSON is malformed.
111    pub fn from_json(json: &str) -> Result<Self> {
112        serde_json::from_str(json).map_err(|e| {
113            MlxError::QuantConfigError(format!("Failed to parse quantization config JSON: {e}"))
114        })
115    }
116
117    /// Parse per-tensor quantization overrides from the `"quantization"` section
118    /// of an MLX model's `config.json`.
119    ///
120    /// In this format, the quantization section contains flat keys for tensor
121    /// names alongside the default `bits` and `group_size`:
122    ///
123    /// ```json
124    /// {
125    ///   "quantization": {
126    ///     "bits": 4,
127    ///     "group_size": 64,
128    ///     "model.layers.0.mlp.down_proj": {"bits": 8, "group_size": 64}
129    ///   }
130    /// }
131    /// ```
132    ///
133    /// This parses the entire `"quantization"` object, extracting `bits` and
134    /// `group_size` as defaults, and any nested objects as per-tensor overrides.
135    pub fn from_model_config_json(json: &str) -> Result<Self> {
136        // Parse the full config.json, extract the "quantization" section.
137        let root: serde_json::Value = serde_json::from_str(json).map_err(|e| {
138            MlxError::QuantConfigError(format!("Failed to parse config.json: {e}"))
139        })?;
140
141        let quant_section = root.get("quantization").ok_or_else(|| {
142            MlxError::QuantConfigError("No \"quantization\" key in config.json".into())
143        })?;
144
145        let quant_obj = quant_section.as_object().ok_or_else(|| {
146            MlxError::QuantConfigError("\"quantization\" is not an object".into())
147        })?;
148
149        let bits = quant_obj
150            .get("bits")
151            .and_then(|v| v.as_u64())
152            .unwrap_or(4) as u8;
153
154        let group_size = quant_obj
155            .get("group_size")
156            .and_then(|v| v.as_u64())
157            .unwrap_or(64) as usize;
158
159        // Any key whose value is an object with "bits" is a per-tensor override.
160        let mut per_tensor = HashMap::new();
161        for (key, value) in quant_obj {
162            if key == "bits" || key == "group_size" || key == "quant_method" {
163                continue;
164            }
165            if let Some(obj) = value.as_object() {
166                if let Some(tensor_bits) = obj.get("bits").and_then(|v| v.as_u64()) {
167                    let tensor_gs = obj
168                        .get("group_size")
169                        .and_then(|v| v.as_u64())
170                        .unwrap_or(group_size as u64) as usize;
171                    per_tensor.insert(
172                        key.clone(),
173                        TensorQuantConfig {
174                            bits: tensor_bits as u8,
175                            group_size: tensor_gs,
176                        },
177                    );
178                }
179            }
180        }
181
182        Ok(Self {
183            bits,
184            group_size,
185            per_tensor,
186        })
187    }
188
189    /// Parse per-tensor overrides from a `config.json` file on disk.
190    pub fn from_model_config_file(path: &Path) -> Result<Self> {
191        let contents = fs::read_to_string(path).map_err(|e| {
192            MlxError::IoError(format!(
193                "Failed to read config.json at {}: {}",
194                path.display(),
195                e
196            ))
197        })?;
198        Self::from_model_config_json(&contents)
199    }
200
201    /// Look up the quantization parameters for a specific tensor name.
202    ///
203    /// Matching strategy (in order):
204    /// 1. Exact match in `per_tensor`.
205    /// 2. Strip `.weight` / `.scales` / `.biases` suffix, then exact match.
206    /// 3. Strip `language_model.` prefix (with or without suffix), then match.
207    /// 4. Add `language_model.` prefix (with or without suffix), then match.
208    ///
209    /// If no override matches, returns the default bits and group_size.
210    pub fn config_for_tensor(&self, tensor_name: &str) -> (u8, usize) {
211        // 1. Exact match.
212        if let Some(tc) = self.per_tensor.get(tensor_name) {
213            return (tc.bits, tc.group_size);
214        }
215
216        // 2. Strip component suffix (.weight, .scales, .biases).
217        let base = strip_tensor_suffix(tensor_name);
218        if base != tensor_name {
219            if let Some(tc) = self.per_tensor.get(base) {
220                return (tc.bits, tc.group_size);
221            }
222        }
223
224        // 3. Strip `language_model.` prefix.
225        let lm_prefix = "language_model.";
226        if let Some(stripped) = tensor_name.strip_prefix(lm_prefix) {
227            if let Some(tc) = self.per_tensor.get(stripped) {
228                return (tc.bits, tc.group_size);
229            }
230            let stripped_base = strip_tensor_suffix(stripped);
231            if stripped_base != stripped {
232                if let Some(tc) = self.per_tensor.get(stripped_base) {
233                    return (tc.bits, tc.group_size);
234                }
235            }
236        }
237
238        // 4. Add `language_model.` prefix.
239        if !tensor_name.starts_with(lm_prefix) {
240            let with_prefix = format!("{lm_prefix}{tensor_name}");
241            if let Some(tc) = self.per_tensor.get(&with_prefix) {
242                return (tc.bits, tc.group_size);
243            }
244            let with_prefix_base = format!("{lm_prefix}{base}");
245            if base != tensor_name {
246                if let Some(tc) = self.per_tensor.get(&with_prefix_base) {
247                    return (tc.bits, tc.group_size);
248                }
249            }
250        }
251
252        (self.bits, self.group_size)
253    }
254}
255
256// ---------------------------------------------------------------------------
257// QuantizedWeight
258// ---------------------------------------------------------------------------
259
/// A quantized weight tensor loaded into Metal GPU buffers.
///
/// Tracks the tensor name, logical shape, original dtype, quantization
/// parameters, and the Metal buffers holding the packed data, scales, and
/// optional biases.
///
/// All fields are private; read access goes through the accessor methods
/// on the `impl` block below.
///
/// # Layout
///
/// * `packed_data` — Packed quantized integers (e.g. 4-bit values packed
///   8-per-uint32, or 6-bit values packed 4-per-uint32).
/// * `scales` — Per-group scale factors as f16 values.
/// * `biases` — Per-group biases as f16 values (present for affine quant).
pub struct QuantizedWeight {
    /// Full tensor path, e.g. `model.layers.0.self_attn.q_proj.weight`.
    tensor_name: String,
    /// Logical tensor dimensions before quantization.
    shape: Vec<usize>,
    /// Original dtype before quantization (e.g. `F16` or `BF16`).
    dtype: DType,
    /// Quantization bit-width (3, 4, 6, or 8).
    bits: u8,
    /// Number of consecutive values sharing one scale/bias pair.
    group_size: usize,
    /// Per-group scale factors (f16 Metal buffer).
    scales: MlxBuffer,
    /// Per-group biases (f16 Metal buffer), if asymmetric quantization.
    biases: Option<MlxBuffer>,
    /// Packed quantized weight data (Metal buffer).
    packed_data: MlxBuffer,
}
290
291impl QuantizedWeight {
292    /// Construct a new `QuantizedWeight` with all fields specified.
293    ///
294    /// This is the primary constructor used by [`load_quantized_weights`].
295    /// It does not validate buffer sizes — the caller is responsible for
296    /// ensuring the buffers match the declared shape, bits, and group_size.
297    pub fn new(
298        tensor_name: String,
299        shape: Vec<usize>,
300        dtype: DType,
301        bits: u8,
302        group_size: usize,
303        scales: MlxBuffer,
304        biases: Option<MlxBuffer>,
305        packed_data: MlxBuffer,
306    ) -> Self {
307        Self {
308            tensor_name,
309            shape,
310            dtype,
311            bits,
312            group_size,
313            scales,
314            biases,
315            packed_data,
316        }
317    }
318
319    /// Full tensor name path.
320    #[inline]
321    pub fn tensor_name(&self) -> &str {
322        &self.tensor_name
323    }
324
325    /// Logical tensor shape (dimensions before quantization).
326    #[inline]
327    pub fn shape(&self) -> &[usize] {
328        &self.shape
329    }
330
331    /// Original element dtype before quantization.
332    #[inline]
333    pub fn dtype(&self) -> DType {
334        self.dtype
335    }
336
337    /// Quantization bit-width.
338    #[inline]
339    pub fn bits(&self) -> u8 {
340        self.bits
341    }
342
343    /// Quantization group size.
344    #[inline]
345    pub fn group_size(&self) -> usize {
346        self.group_size
347    }
348
349    /// Borrow the per-group scales buffer.
350    #[inline]
351    pub fn scales(&self) -> &MlxBuffer {
352        &self.scales
353    }
354
355    /// Borrow the per-group biases buffer, if present.
356    #[inline]
357    pub fn biases(&self) -> Option<&MlxBuffer> {
358        self.biases.as_ref()
359    }
360
361    /// Borrow the packed quantized data buffer.
362    #[inline]
363    pub fn packed_data(&self) -> &MlxBuffer {
364        &self.packed_data
365    }
366
367    /// Number of logical elements in the weight tensor (product of shape dims).
368    pub fn element_count(&self) -> usize {
369        self.shape.iter().copied().product()
370    }
371
372    /// Number of quantization groups along the last dimension.
373    ///
374    /// This is `ceil(last_dim / group_size)`.
375    pub fn num_groups(&self) -> usize {
376        let last_dim = self.shape.last().copied().unwrap_or(0);
377        if self.group_size == 0 {
378            return 0;
379        }
380        (last_dim + self.group_size - 1) / self.group_size
381    }
382}
383
384impl std::fmt::Debug for QuantizedWeight {
385    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
386        f.debug_struct("QuantizedWeight")
387            .field("tensor_name", &self.tensor_name)
388            .field("shape", &self.shape)
389            .field("dtype", &self.dtype)
390            .field("bits", &self.bits)
391            .field("group_size", &self.group_size)
392            .field("packed_data_bytes", &self.packed_data.byte_len())
393            .field("scales_bytes", &self.scales.byte_len())
394            .field("has_biases", &self.biases.is_some())
395            .finish()
396    }
397}
398
399// ---------------------------------------------------------------------------
400// DType conversion
401// ---------------------------------------------------------------------------
402
403/// Convert a safetensors `Dtype` to our `DType`.
404///
405/// Returns `Err(MlxError::UnsupportedDtype)` for types we don't handle.
406fn safetensors_dtype_to_dtype(st_dtype: safetensors::Dtype) -> Result<DType> {
407    match st_dtype {
408        safetensors::Dtype::F32 => Ok(DType::F32),
409        safetensors::Dtype::F16 => Ok(DType::F16),
410        safetensors::Dtype::BF16 => Ok(DType::BF16),
411        safetensors::Dtype::U8 => Ok(DType::U8),
412        safetensors::Dtype::U16 => Ok(DType::U16),
413        safetensors::Dtype::U32 => Ok(DType::U32),
414        safetensors::Dtype::I32 => Ok(DType::I32),
415        other => Err(MlxError::UnsupportedDtype(format!("{other:?}"))),
416    }
417}
418
419// ---------------------------------------------------------------------------
420// Buffer creation
421// ---------------------------------------------------------------------------
422
423/// Copy raw bytes from a safetensors tensor view into a new Metal
424/// `StorageModeShared` buffer.
425///
426/// This is the core data-transfer function.  It:
427/// 1. Allocates a Metal buffer of the exact byte length.
428/// 2. Copies the tensor data from the (mmap'd) safetensors region into the
429///    Metal buffer via a single `std::ptr::copy_nonoverlapping`.
430///
431/// # Arguments
432///
433/// * `device`   — The Metal device to allocate from.
434/// * `data`     — Raw tensor bytes (borrowed from the safetensors mmap).
435/// * `dtype`    — Element data type for metadata tracking.
436/// * `shape`    — Tensor dimensions for metadata tracking.
437///
438/// # Errors
439///
440/// * `MlxError::InvalidArgument` if `data` is empty.
441/// * `MlxError::BufferAllocationError` if Metal allocation fails.
442pub fn safetensors_to_metal_buffer(
443    device: &MlxDevice,
444    data: &[u8],
445    dtype: DType,
446    shape: Vec<usize>,
447) -> Result<MlxBuffer> {
448    if data.is_empty() {
449        return Err(MlxError::InvalidArgument(
450            "Cannot create Metal buffer from empty data".into(),
451        ));
452    }
453
454    let byte_len = data.len();
455    let metal_buf = device
456        .metal_device()
457        .new_buffer(byte_len as u64, MTLResourceOptions::StorageModeShared);
458
459    if metal_buf.contents().is_null() {
460        return Err(MlxError::BufferAllocationError { bytes: byte_len });
461    }
462
463    // Copy tensor bytes into the Metal buffer.
464    // SAFETY: Metal guarantees the buffer contents pointer is valid for
465    // `byte_len` bytes.  The source slice is valid for `byte_len` bytes.
466    // The regions do not overlap (one is mmap, the other is Metal allocation).
467    unsafe {
468        std::ptr::copy_nonoverlapping(data.as_ptr(), metal_buf.contents() as *mut u8, byte_len);
469    }
470
471    Ok(MlxBuffer::from_raw(metal_buf, dtype, shape))
472}
473
474// ---------------------------------------------------------------------------
475// Memory-mapped safetensors file handle
476// ---------------------------------------------------------------------------
477
/// A memory-mapped safetensors file ready for tensor extraction.
///
/// This struct owns the mmap and parsed header.  Individual tensors can be
/// loaded into Metal buffers on demand via [`load_tensor`](Self::load_tensor)
/// or all at once via [`load_all_tensors`](Self::load_all_tensors).
pub struct SafetensorsFile {
    /// The memory-mapped file data.  Keeping it owned here ensures the
    /// mapping outlives every `SafeTensors` view borrowed from it.
    ///
    /// NOTE(review): the field is read by `parse` and the `Debug` impl, so
    /// the `#[allow(dead_code)]` looks unnecessary — confirm and remove.
    #[allow(dead_code)]
    mmap: Mmap,
}
488
489impl std::fmt::Debug for SafetensorsFile {
490    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
491        f.debug_struct("SafetensorsFile")
492            .field("mmap_len", &self.mmap.len())
493            .finish()
494    }
495}
496
497impl SafetensorsFile {
498    /// Open and memory-map a safetensors file.
499    ///
500    /// The file is mapped read-only.  No tensor data is copied until
501    /// `load_tensor` or `load_all_tensors` is called.
502    ///
503    /// # Errors
504    ///
505    /// * `MlxError::IoError` if the file cannot be opened or mapped.
506    pub fn open(path: &Path) -> Result<Self> {
507        let file = fs::File::open(path).map_err(|e| {
508            MlxError::IoError(format!("Failed to open safetensors file {}: {}", path.display(), e))
509        })?;
510
511        // SAFETY: We are mapping a regular file read-only.  The only hazard is
512        // if another process truncates the file while mapped, which is
513        // undefined behavior.  This is the standard use case for memmap2.
514        let mmap = unsafe {
515            Mmap::map(&file).map_err(|e| {
516                MlxError::IoError(format!("Failed to mmap safetensors file {}: {}", path.display(), e))
517            })?
518        };
519
520        Ok(Self { mmap })
521    }
522
523    /// Parse the safetensors header and return the deserialized view.
524    ///
525    /// This borrows from the mmap and is cheap (no tensor data is copied).
526    fn parse(&self) -> Result<SafeTensors<'_>> {
527        SafeTensors::deserialize(&self.mmap).map_err(|e| {
528            MlxError::SafetensorsError(format!("Failed to parse safetensors header: {e}"))
529        })
530    }
531
532    /// List all tensor names in the file.
533    pub fn tensor_names(&self) -> Result<Vec<String>> {
534        let st = self.parse()?;
535        Ok(st.names().into_iter().map(|s| s.to_string()).collect())
536    }
537
538    /// Load a single named tensor into a Metal buffer.
539    ///
540    /// Returns the dtype, shape, and a Metal buffer containing the raw bytes.
541    ///
542    /// # Errors
543    ///
544    /// * `MlxError::SafetensorsError` if the tensor name is not found.
545    /// * `MlxError::UnsupportedDtype` if the tensor's dtype is not supported.
546    /// * `MlxError::BufferAllocationError` if Metal allocation fails.
547    pub fn load_tensor(
548        &self,
549        name: &str,
550        device: &MlxDevice,
551    ) -> Result<(DType, Vec<usize>, MlxBuffer)> {
552        let st = self.parse()?;
553        let view = st.tensor(name).map_err(|e| {
554            MlxError::SafetensorsError(format!("Tensor '{}' not found: {}", name, e))
555        })?;
556
557        let dtype = safetensors_dtype_to_dtype(view.dtype())?;
558        let shape: Vec<usize> = view.shape().to_vec();
559        let data = view.data();
560
561        let buffer = safetensors_to_metal_buffer(device, data, dtype, shape.clone())?;
562        Ok((dtype, shape, buffer))
563    }
564
565    /// Load all tensors from the file into Metal buffers.
566    ///
567    /// Returns a map from tensor name to `(DType, shape, MlxBuffer)`.
568    ///
569    /// # Errors
570    ///
571    /// Returns the first error encountered during loading.
572    pub fn load_all_tensors(
573        &self,
574        device: &MlxDevice,
575    ) -> Result<HashMap<String, (DType, Vec<usize>, MlxBuffer)>> {
576        let st = self.parse()?;
577        let mut result = HashMap::new();
578
579        for (name, view) in st.tensors() {
580            let dtype = safetensors_dtype_to_dtype(view.dtype())?;
581            let shape: Vec<usize> = view.shape().to_vec();
582            let data = view.data();
583
584            let buffer = safetensors_to_metal_buffer(device, data, dtype, shape.clone())?;
585            result.insert(name, (dtype, shape, buffer));
586        }
587
588        Ok(result)
589    }
590}
591
592// ---------------------------------------------------------------------------
593// High-level quantized weight loading
594// ---------------------------------------------------------------------------
595
596/// Load quantized weights from a directory containing safetensors file(s) and
597/// a `quantization_config.json`.
598///
599/// This is the primary entry point for weight loading.  It:
600///
601/// 1. Reads `quantization_config.json` from the directory to determine
602///    per-tensor bit-widths and group sizes.
603/// 2. Discovers all `*.safetensors` files in the directory.
604/// 3. Memory-maps each file and loads tensors that look like quantized weight
605///    components (packed data, scales, biases) into Metal buffers.
606/// 4. Groups the components by base tensor name and constructs
607///    [`QuantizedWeight`] instances.
608///
609/// # Tensor Naming Convention
610///
611/// Quantized weights in safetensors use a naming convention:
612/// - `<base_name>.weight` — packed quantized data
613/// - `<base_name>.scales` — per-group scale factors
614/// - `<base_name>.biases` — per-group biases (optional, for affine quant)
615///
616/// # Arguments
617///
618/// * `model_dir` — Path to the directory containing safetensors files and config.
619/// * `device`    — The Metal device for buffer allocation.
620///
621/// # Errors
622///
623/// * `MlxError::IoError` if the directory or files cannot be read.
624/// * `MlxError::QuantConfigError` if the quantization config is invalid.
625/// * `MlxError::SafetensorsError` if a safetensors file is malformed.
626pub fn load_quantized_weights(
627    model_dir: &Path,
628    device: &MlxDevice,
629) -> Result<Vec<QuantizedWeight>> {
630    // 1. Load quantization config.
631    let config_path = model_dir.join("quantization_config.json");
632    let quant_config = QuantizationConfig::from_file(&config_path)?;
633
634    // 2. Discover safetensors files.
635    let safetensors_files = discover_safetensors_files(model_dir)?;
636    if safetensors_files.is_empty() {
637        return Err(MlxError::IoError(format!(
638            "No .safetensors files found in {}",
639            model_dir.display()
640        )));
641    }
642
643    // 3. Load all tensors from all files into a flat map.
644    let mut all_tensors: HashMap<String, (DType, Vec<usize>, MlxBuffer)> = HashMap::new();
645    for sf_path in &safetensors_files {
646        let sf = SafetensorsFile::open(sf_path)?;
647        let tensors = sf.load_all_tensors(device)?;
648        all_tensors.extend(tensors);
649    }
650
651    // 4. Group by base tensor name and construct QuantizedWeight instances.
652    //
653    // We look for groups of related tensors.  The convention is:
654    //   - `<base>.weight` or just `<base>` — packed quantized data
655    //   - `<base>.scales` — per-group scales (f16)
656    //   - `<base>.biases` — per-group biases (f16, optional)
657    //
658    // A tensor is considered quantized if it has a corresponding `.scales` entry.
659
660    let mut weights = Vec::new();
661    let mut processed: std::collections::HashSet<String> = std::collections::HashSet::new();
662
663    // Collect all base names that have .scales entries.
664    let scale_suffix = ".scales";
665    let scale_bases: Vec<String> = all_tensors
666        .keys()
667        .filter(|k| k.ends_with(scale_suffix))
668        .map(|k| k[..k.len() - scale_suffix.len()].to_string())
669        .collect();
670
671    for base_name in &scale_bases {
672        let scales_key = format!("{base_name}.scales");
673        let biases_key = format!("{base_name}.biases");
674
675        // The packed data might be at `<base>.weight` or just `<base>`.
676        let weight_key = if all_tensors.contains_key(&format!("{base_name}.weight")) {
677            format!("{base_name}.weight")
678        } else if all_tensors.contains_key(base_name) {
679            base_name.clone()
680        } else {
681            // Scales without a weight tensor — skip.
682            continue;
683        };
684
685        // Extract the packed data buffer.
686        let (packed_dtype, packed_shape, packed_data) = match all_tensors.remove(&weight_key) {
687            Some(t) => t,
688            None => continue,
689        };
690
691        // Extract scales buffer.
692        let (_scales_dtype, _scales_shape, scales_buf) = match all_tensors.remove(&scales_key) {
693            Some(t) => t,
694            None => continue,
695        };
696
697        // Extract biases buffer (optional).
698        let biases_buf = all_tensors.remove(&biases_key).map(|(_, _, buf)| buf);
699
700        // Look up quant config for this tensor.
701        let (bits, group_size) = quant_config.config_for_tensor(&weight_key);
702
703        weights.push(QuantizedWeight::new(
704            weight_key.clone(),
705            packed_shape,
706            packed_dtype,
707            bits,
708            group_size,
709            scales_buf,
710            biases_buf,
711            packed_data,
712        ));
713
714        processed.insert(weight_key);
715        processed.insert(scales_key);
716        processed.insert(biases_key);
717    }
718
719    Ok(weights)
720}
721
722/// Discover all `*.safetensors` files in a directory, sorted by name.
723fn discover_safetensors_files(dir: &Path) -> Result<Vec<std::path::PathBuf>> {
724    let entries = fs::read_dir(dir).map_err(|e| {
725        MlxError::IoError(format!("Failed to read directory {}: {}", dir.display(), e))
726    })?;
727
728    let mut files: Vec<std::path::PathBuf> = Vec::new();
729    for entry in entries {
730        let entry = entry.map_err(|e| {
731            MlxError::IoError(format!("Failed to read directory entry: {e}"))
732        })?;
733        let path = entry.path();
734        if path.extension().and_then(|e| e.to_str()) == Some("safetensors") {
735            files.push(path);
736        }
737    }
738
739    files.sort();
740    Ok(files)
741}
742
743// ---------------------------------------------------------------------------
744// Tests
745// ---------------------------------------------------------------------------
746
747#[cfg(test)]
748#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
749mod tests {
750    use super::*;
751    use safetensors::tensor::{Dtype as StDtype, TensorView};
752
753    // ---- QuantizedWeight construction and accessors ----
754
755    #[test]
756    fn test_quantized_weight_construction() {
757        let device = MlxDevice::new().expect("device");
758
759        // Create minimal buffers for testing.
760        let packed = device.alloc_buffer(64, DType::U32, vec![4, 4]).expect("packed");
761        let scales = device.alloc_buffer(16, DType::F16, vec![4, 2]).expect("scales");
762        let biases = device.alloc_buffer(16, DType::F16, vec![4, 2]).expect("biases");
763
764        let qw = QuantizedWeight::new(
765            "model.layers.0.self_attn.q_proj.weight".to_string(),
766            vec![2816, 2816],
767            DType::F16,
768            4,
769            64,
770            scales,
771            Some(biases),
772            packed,
773        );
774
775        assert_eq!(qw.tensor_name(), "model.layers.0.self_attn.q_proj.weight");
776        assert_eq!(qw.shape(), &[2816, 2816]);
777        assert_eq!(qw.dtype(), DType::F16);
778        assert_eq!(qw.bits(), 4);
779        assert_eq!(qw.group_size(), 64);
780        assert!(qw.biases().is_some());
781        assert_eq!(qw.element_count(), 2816 * 2816);
782        assert_eq!(qw.num_groups(), (2816 + 64 - 1) / 64);
783    }
784
785    #[test]
786    fn test_quantized_weight_no_biases() {
787        let device = MlxDevice::new().expect("device");
788
789        let packed = device.alloc_buffer(32, DType::U32, vec![4, 2]).expect("packed");
790        let scales = device.alloc_buffer(8, DType::F16, vec![4, 1]).expect("scales");
791
792        let qw = QuantizedWeight::new(
793            "test.weight".to_string(),
794            vec![128, 128],
795            DType::BF16,
796            6,
797            32,
798            scales,
799            None,
800            packed,
801        );
802
803        assert!(qw.biases().is_none());
804        assert_eq!(qw.bits(), 6);
805        assert_eq!(qw.group_size(), 32);
806        assert_eq!(qw.num_groups(), (128 + 32 - 1) / 32);
807    }
808
809    #[test]
810    fn test_quantized_weight_debug() {
811        let device = MlxDevice::new().expect("device");
812        let packed = device.alloc_buffer(16, DType::U32, vec![4]).expect("packed");
813        let scales = device.alloc_buffer(4, DType::F16, vec![2]).expect("scales");
814
815        let qw = QuantizedWeight::new(
816            "test.w".to_string(),
817            vec![64],
818            DType::F32,
819            4,
820            64,
821            scales,
822            None,
823            packed,
824        );
825
826        let debug_str = format!("{:?}", qw);
827        assert!(debug_str.contains("QuantizedWeight"));
828        assert!(debug_str.contains("test.w"));
829        assert!(debug_str.contains("bits: 4"));
830    }
831
832    // ---- QuantizationConfig parsing ----
833
834    #[test]
835    fn test_quant_config_defaults() {
836        let json = r#"{}"#;
837        let config = QuantizationConfig::from_json(json).expect("parse");
838        assert_eq!(config.bits, 4);
839        assert_eq!(config.group_size, 64);
840        assert!(config.per_tensor.is_empty());
841    }
842
843    #[test]
844    fn test_quant_config_with_per_tensor() {
845        let json = r#"{
846            "bits": 4,
847            "group_size": 64,
848            "per_tensor": {
849                "model.layers.0.self_attn.v_proj.weight": {"bits": 6, "group_size": 128},
850                "model.embed_tokens.weight": {"bits": 8, "group_size": 32}
851            }
852        }"#;
853
854        let config = QuantizationConfig::from_json(json).expect("parse");
855        assert_eq!(config.bits, 4);
856        assert_eq!(config.group_size, 64);
857
858        // Per-tensor override.
859        let (bits, gs) = config.config_for_tensor("model.layers.0.self_attn.v_proj.weight");
860        assert_eq!(bits, 6);
861        assert_eq!(gs, 128);
862
863        // Default for unknown tensor.
864        let (bits, gs) = config.config_for_tensor("model.layers.5.mlp.gate_proj.weight");
865        assert_eq!(bits, 4);
866        assert_eq!(gs, 64);
867    }
868
869    #[test]
870    fn test_quant_config_invalid_json() {
871        let result = QuantizationConfig::from_json("not json at all {{{");
872        assert!(result.is_err());
873        match result {
874            Err(MlxError::QuantConfigError(msg)) => {
875                assert!(msg.contains("parse"), "msg: {msg}");
876            }
877            other => panic!("Expected QuantConfigError, got {:?}", other),
878        }
879    }
880
881    // ---- config_for_tensor suffix/prefix stripping ----
882
883    #[test]
884    fn test_config_for_tensor_strips_weight_suffix() {
885        let json = r#"{
886            "bits": 4,
887            "group_size": 64,
888            "per_tensor": {
889                "model.layers.0.mlp.down_proj": {"bits": 8, "group_size": 64}
890            }
891        }"#;
892        let config = QuantizationConfig::from_json(json).expect("parse");
893
894        // Querying with .weight suffix should match the override without suffix.
895        let (bits, gs) = config.config_for_tensor("model.layers.0.mlp.down_proj.weight");
896        assert_eq!(bits, 8);
897        assert_eq!(gs, 64);
898
899        // Querying with .scales suffix should also match.
900        let (bits, _) = config.config_for_tensor("model.layers.0.mlp.down_proj.scales");
901        assert_eq!(bits, 8);
902
903        // Querying with .biases suffix should also match.
904        let (bits, _) = config.config_for_tensor("model.layers.0.mlp.down_proj.biases");
905        assert_eq!(bits, 8);
906    }
907
908    #[test]
909    fn test_config_for_tensor_adds_language_model_prefix() {
910        let json = r#"{
911            "bits": 4,
912            "group_size": 64,
913            "per_tensor": {
914                "language_model.model.layers.0.self_attn.v_proj": {"bits": 6, "group_size": 64}
915            }
916        }"#;
917        let config = QuantizationConfig::from_json(json).expect("parse");
918
919        // Query without the language_model. prefix — should still match.
920        let (bits, _) = config.config_for_tensor("model.layers.0.self_attn.v_proj.weight");
921        assert_eq!(bits, 6);
922    }
923
924    #[test]
925    fn test_config_for_tensor_strips_language_model_prefix() {
926        let json = r#"{
927            "bits": 4,
928            "group_size": 64,
929            "per_tensor": {
930                "model.layers.0.self_attn.v_proj": {"bits": 6, "group_size": 64}
931            }
932        }"#;
933        let config = QuantizationConfig::from_json(json).expect("parse");
934
935        // Query with the language_model. prefix — should match by stripping it.
936        let (bits, _) = config.config_for_tensor("language_model.model.layers.0.self_attn.v_proj.weight");
937        assert_eq!(bits, 6);
938    }
939
940    // ---- from_model_config_json ----
941
942    #[test]
943    fn test_from_model_config_json_basic() {
944        let json = r#"{
945            "model_type": "gemma4",
946            "quantization": {
947                "bits": 4,
948                "group_size": 64,
949                "language_model.model.layers.0.mlp.down_proj": {"bits": 8, "group_size": 64},
950                "language_model.model.layers.0.self_attn.v_proj": {"bits": 6, "group_size": 64}
951            }
952        }"#;
953
954        let config = QuantizationConfig::from_model_config_json(json).expect("parse");
955        assert_eq!(config.bits, 4);
956        assert_eq!(config.group_size, 64);
957        assert_eq!(config.per_tensor.len(), 2);
958
959        let (bits, _) = config.config_for_tensor("language_model.model.layers.0.mlp.down_proj.weight");
960        assert_eq!(bits, 8);
961
962        let (bits, _) = config.config_for_tensor("language_model.model.layers.0.self_attn.v_proj.weight");
963        assert_eq!(bits, 6);
964
965        // Unknown tensor gets default.
966        let (bits, _) = config.config_for_tensor("language_model.model.layers.5.mlp.gate_proj.weight");
967        assert_eq!(bits, 4);
968    }
969
970    #[test]
971    fn test_from_model_config_json_no_quantization_key() {
972        let json = r#"{"model_type": "gemma4"}"#;
973        let result = QuantizationConfig::from_model_config_json(json);
974        assert!(result.is_err());
975    }
976
977    // ---- DType conversion ----
978
979    #[test]
980    fn test_safetensors_dtype_conversion() {
981        assert_eq!(safetensors_dtype_to_dtype(StDtype::F32).unwrap(), DType::F32);
982        assert_eq!(safetensors_dtype_to_dtype(StDtype::F16).unwrap(), DType::F16);
983        assert_eq!(safetensors_dtype_to_dtype(StDtype::BF16).unwrap(), DType::BF16);
984        assert_eq!(safetensors_dtype_to_dtype(StDtype::U8).unwrap(), DType::U8);
985        assert_eq!(safetensors_dtype_to_dtype(StDtype::U16).unwrap(), DType::U16);
986        assert_eq!(safetensors_dtype_to_dtype(StDtype::U32).unwrap(), DType::U32);
987        assert_eq!(safetensors_dtype_to_dtype(StDtype::I32).unwrap(), DType::I32);
988    }
989
990    #[test]
991    fn test_safetensors_dtype_unsupported() {
992        let result = safetensors_dtype_to_dtype(StDtype::BOOL);
993        assert!(result.is_err());
994        match result {
995            Err(MlxError::UnsupportedDtype(_)) => {}
996            other => panic!("Expected UnsupportedDtype, got {:?}", other),
997        }
998    }
999
1000    // ---- safetensors_to_metal_buffer ----
1001
1002    #[test]
1003    fn test_safetensors_to_metal_buffer_roundtrip() {
1004        let device = MlxDevice::new().expect("device");
1005
1006        // Create test data: 4 f32 values.
1007        let values: [f32; 4] = [1.0, 2.5, -3.0, 4.125];
1008        let bytes: &[u8] = bytemuck::cast_slice(&values);
1009
1010        let buf = safetensors_to_metal_buffer(&device, bytes, DType::F32, vec![4])
1011            .expect("to_metal_buffer");
1012
1013        assert_eq!(buf.byte_len(), 16);
1014        assert_eq!(buf.dtype(), DType::F32);
1015        assert_eq!(buf.shape(), &[4]);
1016
1017        // Verify data integrity.
1018        let read_back: &[f32] = buf.as_slice().expect("as_slice");
1019        assert_eq!(read_back.len(), 4);
1020        assert_eq!(read_back[0], 1.0);
1021        assert_eq!(read_back[1], 2.5);
1022        assert_eq!(read_back[2], -3.0);
1023        assert_eq!(read_back[3], 4.125);
1024    }
1025
1026    #[test]
1027    fn test_safetensors_to_metal_buffer_empty_error() {
1028        let device = MlxDevice::new().expect("device");
1029        let result = safetensors_to_metal_buffer(&device, &[], DType::F32, vec![0]);
1030        assert!(result.is_err());
1031        match result {
1032            Err(MlxError::InvalidArgument(msg)) => {
1033                assert!(msg.contains("empty"), "msg: {msg}");
1034            }
1035            other => panic!("Expected InvalidArgument, got {:?}", other),
1036        }
1037    }
1038
1039    #[test]
1040    fn test_safetensors_to_metal_buffer_u8_data() {
1041        let device = MlxDevice::new().expect("device");
1042        let data: Vec<u8> = (0..128).collect();
1043
1044        let buf = safetensors_to_metal_buffer(&device, &data, DType::U8, vec![128])
1045            .expect("to_metal_buffer");
1046
1047        assert_eq!(buf.byte_len(), 128);
1048        let read_back: &[u8] = buf.as_slice().expect("as_slice");
1049        for (i, &val) in read_back.iter().enumerate() {
1050            assert_eq!(val, i as u8, "mismatch at index {i}");
1051        }
1052    }
1053
1054    // ---- SafetensorsFile with synthetic test file ----
1055
1056    /// Create a minimal safetensors file in a temp directory for testing.
1057    fn create_test_safetensors(dir: &Path) -> std::path::PathBuf {
1058        let path = dir.join("test_model.safetensors");
1059
1060        // Build tensors: two small f32 tensors.
1061        let tensor_a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
1062        let tensor_a_bytes: &[u8] = bytemuck::cast_slice(&tensor_a_data);
1063        let tensor_b_data: Vec<f32> = vec![10.0, 20.0, 30.0];
1064        let tensor_b_bytes: &[u8] = bytemuck::cast_slice(&tensor_b_data);
1065
1066        let tensors = vec![
1067            (
1068                "layer.weight",
1069                TensorView::new(StDtype::F32, vec![2, 3], tensor_a_bytes).unwrap(),
1070            ),
1071            (
1072                "layer.bias",
1073                TensorView::new(StDtype::F32, vec![3], tensor_b_bytes).unwrap(),
1074            ),
1075        ];
1076
1077        let serialized = safetensors::tensor::serialize(tensors, None).unwrap();
1078        fs::write(&path, &serialized).unwrap();
1079
1080        path
1081    }
1082
1083    #[test]
1084    fn test_safetensors_file_open_and_list() {
1085        let tmp = tempdir();
1086        let st_path = create_test_safetensors(&tmp);
1087
1088        let sf = SafetensorsFile::open(&st_path).expect("open");
1089        let names = sf.tensor_names().expect("names");
1090
1091        assert_eq!(names.len(), 2);
1092        assert!(names.contains(&"layer.weight".to_string()));
1093        assert!(names.contains(&"layer.bias".to_string()));
1094    }
1095
1096    #[test]
1097    fn test_safetensors_file_load_tensor() {
1098        let device = MlxDevice::new().expect("device");
1099        let tmp = tempdir();
1100        let st_path = create_test_safetensors(&tmp);
1101
1102        let sf = SafetensorsFile::open(&st_path).expect("open");
1103        let (dtype, shape, buf) = sf.load_tensor("layer.weight", &device).expect("load");
1104
1105        assert_eq!(dtype, DType::F32);
1106        assert_eq!(shape, vec![2, 3]);
1107        assert_eq!(buf.byte_len(), 24); // 6 * 4 bytes
1108
1109        let data: &[f32] = buf.as_slice().expect("as_slice");
1110        assert_eq!(data, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
1111    }
1112
1113    #[test]
1114    fn test_safetensors_file_load_all() {
1115        let device = MlxDevice::new().expect("device");
1116        let tmp = tempdir();
1117        let st_path = create_test_safetensors(&tmp);
1118
1119        let sf = SafetensorsFile::open(&st_path).expect("open");
1120        let all = sf.load_all_tensors(&device).expect("load_all");
1121
1122        assert_eq!(all.len(), 2);
1123
1124        let (dtype, shape, buf) = all.get("layer.bias").expect("bias");
1125        assert_eq!(*dtype, DType::F32);
1126        assert_eq!(*shape, vec![3]);
1127        let data: &[f32] = buf.as_slice().expect("as_slice");
1128        assert_eq!(data, &[10.0, 20.0, 30.0]);
1129    }
1130
1131    #[test]
1132    fn test_safetensors_file_tensor_not_found() {
1133        let tmp = tempdir();
1134        let st_path = create_test_safetensors(&tmp);
1135        let device = MlxDevice::new().expect("device");
1136
1137        let sf = SafetensorsFile::open(&st_path).expect("open");
1138        let result = sf.load_tensor("nonexistent", &device);
1139        assert!(result.is_err());
1140        match result {
1141            Err(MlxError::SafetensorsError(msg)) => {
1142                assert!(msg.contains("nonexistent"), "msg: {msg}");
1143            }
1144            other => panic!("Expected SafetensorsError, got {:?}", other),
1145        }
1146    }
1147
1148    #[test]
1149    fn test_safetensors_file_open_missing() {
1150        let result = SafetensorsFile::open(Path::new("/tmp/does_not_exist_8f3a2b1c.safetensors"));
1151        assert!(result.is_err());
1152        match result {
1153            Err(MlxError::IoError(_)) => {}
1154            other => panic!("Expected IoError, got {:?}", other),
1155        }
1156    }
1157
1158    // ---- load_quantized_weights with synthetic directory ----
1159
1160    /// Create a synthetic quantized model directory for integration testing.
1161    fn create_test_quant_dir(dir: &Path) {
1162        // Create quantization_config.json.
1163        let config_json = r#"{
1164            "bits": 4,
1165            "group_size": 64,
1166            "per_tensor": {
1167                "proj.weight": {"bits": 4, "group_size": 64}
1168            }
1169        }"#;
1170        fs::write(dir.join("quantization_config.json"), config_json).unwrap();
1171
1172        // Create a safetensors file with weight, scales, and biases tensors.
1173        //
1174        // proj.weight — packed quantized data (stored as U32)
1175        // proj.scales — per-group scale factors (stored as F16)
1176        // proj.biases — per-group biases (stored as F16)
1177        let weight_data: Vec<u32> = vec![0xAAAA_BBBB; 8]; // 8 uint32s = 32 bytes
1178        let weight_bytes: &[u8] = bytemuck::cast_slice(&weight_data);
1179
1180        // Scales: 2 f16 values (4 bytes).
1181        let scales_data: Vec<u16> = vec![0x3C00, 0x3C00]; // f16 = 1.0
1182        let scales_bytes: &[u8] = bytemuck::cast_slice(&scales_data);
1183
1184        // Biases: 2 f16 values (4 bytes).
1185        let biases_data: Vec<u16> = vec![0x0000, 0x0000]; // f16 = 0.0
1186        let biases_bytes: &[u8] = bytemuck::cast_slice(&biases_data);
1187
1188        let tensors = vec![
1189            (
1190                "proj.weight",
1191                TensorView::new(StDtype::U32, vec![2, 4], weight_bytes).unwrap(),
1192            ),
1193            (
1194                "proj.scales",
1195                TensorView::new(StDtype::F16, vec![2, 1], scales_bytes).unwrap(),
1196            ),
1197            (
1198                "proj.biases",
1199                TensorView::new(StDtype::F16, vec![2, 1], biases_bytes).unwrap(),
1200            ),
1201        ];
1202
1203        let serialized = safetensors::tensor::serialize(tensors, None).unwrap();
1204        fs::write(dir.join("model.safetensors"), &serialized).unwrap();
1205    }
1206
1207    #[test]
1208    fn test_load_quantized_weights_integration() {
1209        let device = MlxDevice::new().expect("device");
1210        let tmp = tempdir();
1211        create_test_quant_dir(&tmp);
1212
1213        let weights = load_quantized_weights(&tmp, &device).expect("load");
1214
1215        assert_eq!(weights.len(), 1);
1216        let qw = &weights[0];
1217        assert_eq!(qw.tensor_name(), "proj.weight");
1218        assert_eq!(qw.bits(), 4);
1219        assert_eq!(qw.group_size(), 64);
1220        assert_eq!(qw.packed_data().byte_len(), 32); // 8 * 4 bytes
1221        assert_eq!(qw.scales().byte_len(), 4); // 2 * 2 bytes
1222        assert!(qw.biases().is_some());
1223    }
1224
1225    #[test]
1226    fn test_load_quantized_weights_no_safetensors() {
1227        let tmp = tempdir();
1228
1229        // Create config but no safetensors files.
1230        fs::write(tmp.join("quantization_config.json"), "{}").unwrap();
1231
1232        let device = MlxDevice::new().expect("device");
1233        let result = load_quantized_weights(&tmp, &device);
1234        assert!(result.is_err());
1235        match result {
1236            Err(MlxError::IoError(msg)) => {
1237                assert!(msg.contains("No .safetensors files"), "msg: {msg}");
1238            }
1239            other => panic!("Expected IoError, got {:?}", other),
1240        }
1241    }
1242
1243    #[test]
1244    fn test_load_quantized_weights_missing_config() {
1245        let tmp = tempdir();
1246        // Create a dummy safetensors file but no config.
1247        let data: Vec<u8> = vec![0; 16];
1248        let tensors = vec![(
1249            "dummy",
1250            TensorView::new(StDtype::U8, vec![16], &data).unwrap(),
1251        )];
1252        let serialized = safetensors::tensor::serialize(tensors, None).unwrap();
1253        fs::write(tmp.join("model.safetensors"), &serialized).unwrap();
1254
1255        let device = MlxDevice::new().expect("device");
1256        let result = load_quantized_weights(&tmp, &device);
1257        assert!(result.is_err());
1258        match result {
1259            Err(MlxError::IoError(msg)) => {
1260                assert!(msg.contains("quantization_config"), "msg: {msg}");
1261            }
1262            other => panic!("Expected IoError for missing config, got {:?}", other),
1263        }
1264    }
1265
1266    // ---- Helper: create a temp directory and return its path ----
1267
1268    fn tempdir() -> std::path::PathBuf {
1269        let mut path = std::env::temp_dir();
1270        path.push(format!("mlx_native_test_{}", std::process::id()));
1271        path.push(format!("{}", std::time::SystemTime::now()
1272            .duration_since(std::time::UNIX_EPOCH)
1273            .unwrap_or_default()
1274            .as_nanos()));
1275        fs::create_dir_all(&path).expect("create temp dir");
1276        path
1277    }
1278}