use std::collections::HashMap;
use std::fs;
use std::path::Path;
use memmap2::Mmap;
use metal::MTLResourceOptions;
use safetensors::SafeTensors;
use serde::Deserialize;
use crate::buffer::MlxBuffer;
use crate::device::MlxDevice;
use crate::dtypes::DType;
use crate::error::{MlxError, Result};
/// Quantization parameters for a model: global defaults plus optional
/// per-tensor overrides.
#[derive(Debug, Clone, Deserialize)]
pub struct QuantizationConfig {
    /// Default bit width applied when a tensor has no override (serde default: 4).
    #[serde(default = "default_bits")]
    pub bits: u8,
    /// Default quantization group size (serde default: 64).
    #[serde(default = "default_group_size")]
    pub group_size: usize,
    /// Per-tensor overrides keyed by tensor name; see `config_for_tensor`
    /// for the name-matching rules applied at lookup time.
    #[serde(default)]
    pub per_tensor: HashMap<String, TensorQuantConfig>,
}
/// Quantization override for a single tensor.
#[derive(Debug, Clone, Deserialize)]
pub struct TensorQuantConfig {
    /// Bit width for this tensor.
    pub bits: u8,
    /// Quantization group size for this tensor.
    pub group_size: usize,
}
/// Serde default for `QuantizationConfig::bits`: 4-bit quantization.
fn default_bits() -> u8 {
    const DEFAULT_BITS: u8 = 4;
    DEFAULT_BITS
}
/// Serde default for `QuantizationConfig::group_size`: 64 elements per group.
fn default_group_size() -> usize {
    const DEFAULT_GROUP_SIZE: usize = 64;
    DEFAULT_GROUP_SIZE
}
/// Remove a trailing `.weight`, `.scales`, or `.biases` component from a
/// tensor name, returning the name unchanged when no such suffix is present.
fn strip_tensor_suffix(name: &str) -> &str {
    [".weight", ".scales", ".biases"]
        .iter()
        .find_map(|suffix| name.strip_suffix(suffix))
        .unwrap_or(name)
}
impl QuantizationConfig {
    /// Load a standalone quantization config from a JSON file on disk.
    ///
    /// # Errors
    /// `MlxError::IoError` if the file cannot be read;
    /// `MlxError::QuantConfigError` if the JSON is malformed.
    pub fn from_file(path: &Path) -> Result<Self> {
        let contents = fs::read_to_string(path).map_err(|e| {
            MlxError::IoError(format!("Failed to read quantization config at {}: {}", path.display(), e))
        })?;
        Self::from_json(&contents)
    }

    /// Parse a standalone quantization config from a JSON string.
    ///
    /// # Errors
    /// `MlxError::QuantConfigError` on malformed JSON.
    pub fn from_json(json: &str) -> Result<Self> {
        serde_json::from_str(json).map_err(|e| {
            MlxError::QuantConfigError(format!("Failed to parse quantization config JSON: {e}"))
        })
    }

    /// Build a config from a full model `config.json` that contains a
    /// `"quantization"` object. Besides the global `bits`/`group_size`
    /// scalars, every other object-valued entry carrying a `bits` field is
    /// treated as a per-tensor override keyed by that entry's name.
    ///
    /// # Errors
    /// `MlxError::QuantConfigError` if the JSON is malformed, the
    /// `"quantization"` key is missing, or its value is not an object.
    pub fn from_model_config_json(json: &str) -> Result<Self> {
        let root: serde_json::Value = serde_json::from_str(json).map_err(|e| {
            MlxError::QuantConfigError(format!("Failed to parse config.json: {e}"))
        })?;
        let quant_section = root.get("quantization").ok_or_else(|| {
            MlxError::QuantConfigError("No \"quantization\" key in config.json".into())
        })?;
        let quant_obj = quant_section.as_object().ok_or_else(|| {
            MlxError::QuantConfigError("\"quantization\" is not an object".into())
        })?;
        // Global defaults; fall back to 4-bit / group size 64 when absent.
        let bits = quant_obj
            .get("bits")
            .and_then(|v| v.as_u64())
            .unwrap_or(4) as u8;
        let group_size = quant_obj
            .get("group_size")
            .and_then(|v| v.as_u64())
            .unwrap_or(64) as usize;
        let mut per_tensor = HashMap::new();
        for (key, value) in quant_obj {
            // Skip the global scalar keys; anything else may be a
            // per-tensor override object.
            if key == "bits" || key == "group_size" || key == "quant_method" {
                continue;
            }
            if let Some(obj) = value.as_object() {
                if let Some(tensor_bits) = obj.get("bits").and_then(|v| v.as_u64()) {
                    // A missing per-tensor group_size inherits the global one.
                    let tensor_gs = obj
                        .get("group_size")
                        .and_then(|v| v.as_u64())
                        .unwrap_or(group_size as u64) as usize;
                    per_tensor.insert(
                        key.clone(),
                        TensorQuantConfig {
                            bits: tensor_bits as u8,
                            group_size: tensor_gs,
                        },
                    );
                }
            }
        }
        Ok(Self {
            bits,
            group_size,
            per_tensor,
        })
    }

    /// Load a full model `config.json` from disk and extract its
    /// quantization section (see [`Self::from_model_config_json`]).
    ///
    /// # Errors
    /// `MlxError::IoError` if the file cannot be read, plus any error from
    /// [`Self::from_model_config_json`].
    pub fn from_model_config_file(path: &Path) -> Result<Self> {
        let contents = fs::read_to_string(path).map_err(|e| {
            MlxError::IoError(format!(
                "Failed to read config.json at {}: {}",
                path.display(),
                e
            ))
        })?;
        Self::from_model_config_json(&contents)
    }

    /// Single per-tensor override lookup, mapped to `(bits, group_size)`.
    fn lookup(&self, key: &str) -> Option<(u8, usize)> {
        self.per_tensor.get(key).map(|tc| (tc.bits, tc.group_size))
    }

    /// Resolve `(bits, group_size)` for a tensor name.
    ///
    /// Override keys may be stored with or without a trailing
    /// `.weight`/`.scales`/`.biases` suffix and with or without a leading
    /// `language_model.` prefix, so the matching spellings are tried in
    /// order before falling back to the global defaults.
    pub fn config_for_tensor(&self, tensor_name: &str) -> (u8, usize) {
        let lm_prefix = "language_model.";
        let base = strip_tensor_suffix(tensor_name);
        // 1. Exact name, then the name with the tensor-kind suffix stripped.
        if let Some(found) = self.lookup(tensor_name) {
            return found;
        }
        if base != tensor_name {
            if let Some(found) = self.lookup(base) {
                return found;
            }
        }
        if let Some(stripped) = tensor_name.strip_prefix(lm_prefix) {
            // 2. Name with the "language_model." prefix removed.
            if let Some(found) = self.lookup(stripped) {
                return found;
            }
            let stripped_base = strip_tensor_suffix(stripped);
            if stripped_base != stripped {
                if let Some(found) = self.lookup(stripped_base) {
                    return found;
                }
            }
        } else {
            // 3. Name with the "language_model." prefix added. The prefixed
            //    keys are allocated only when this branch is reached, and the
            //    base-name variant only when it actually differs.
            if let Some(found) = self.lookup(&format!("{lm_prefix}{tensor_name}")) {
                return found;
            }
            if base != tensor_name {
                if let Some(found) = self.lookup(&format!("{lm_prefix}{base}")) {
                    return found;
                }
            }
        }
        (self.bits, self.group_size)
    }
}
/// One quantized tensor resident on the GPU: packed weight data plus its
/// scales and optional biases.
pub struct QuantizedWeight {
    /// Tensor name as keyed in the safetensors file.
    tensor_name: String,
    /// Shape recorded for this weight; `load_quantized_weights` passes the
    /// packed tensor's shape here.
    shape: Vec<usize>,
    /// Dtype recorded for this weight; `load_quantized_weights` passes the
    /// packed tensor's dtype here.
    dtype: DType,
    /// Quantization bit width.
    bits: u8,
    /// Number of elements per quantization group.
    group_size: usize,
    /// Per-group scale factors.
    scales: MlxBuffer,
    /// Optional per-group biases (zero points); absent for some schemes.
    biases: Option<MlxBuffer>,
    /// The packed (quantized) weight payload.
    packed_data: MlxBuffer,
}
impl QuantizedWeight {
    /// Assemble a quantized weight from its already-uploaded GPU buffers.
    pub fn new(
        tensor_name: String,
        shape: Vec<usize>,
        dtype: DType,
        bits: u8,
        group_size: usize,
        scales: MlxBuffer,
        biases: Option<MlxBuffer>,
        packed_data: MlxBuffer,
    ) -> Self {
        Self {
            tensor_name,
            shape,
            dtype,
            bits,
            group_size,
            scales,
            biases,
            packed_data,
        }
    }

    /// Tensor name as keyed in the safetensors file.
    #[inline]
    pub fn tensor_name(&self) -> &str {
        self.tensor_name.as_str()
    }

    /// Recorded tensor shape.
    #[inline]
    pub fn shape(&self) -> &[usize] {
        self.shape.as_slice()
    }

    /// Recorded dtype.
    #[inline]
    pub fn dtype(&self) -> DType {
        self.dtype
    }

    /// Quantization bit width.
    #[inline]
    pub fn bits(&self) -> u8 {
        self.bits
    }

    /// Elements per quantization group.
    #[inline]
    pub fn group_size(&self) -> usize {
        self.group_size
    }

    /// Per-group scale buffer.
    #[inline]
    pub fn scales(&self) -> &MlxBuffer {
        &self.scales
    }

    /// Optional per-group bias buffer.
    #[inline]
    pub fn biases(&self) -> Option<&MlxBuffer> {
        self.biases.as_ref()
    }

    /// Packed weight payload buffer.
    #[inline]
    pub fn packed_data(&self) -> &MlxBuffer {
        &self.packed_data
    }

    /// Total number of elements implied by the recorded shape.
    pub fn element_count(&self) -> usize {
        self.shape.iter().product()
    }

    /// Number of quantization groups along the last dimension, rounded up.
    /// Yields 0 when `group_size` is 0 or the shape is empty.
    pub fn num_groups(&self) -> usize {
        let gs = self.group_size;
        if gs == 0 {
            return 0;
        }
        let last_dim = self.shape.last().copied().unwrap_or(0);
        (last_dim + gs - 1) / gs
    }
}
impl std::fmt::Debug for QuantizedWeight {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Buffers are summarized by byte length so Debug output stays small.
        let mut ds = f.debug_struct("QuantizedWeight");
        ds.field("tensor_name", &self.tensor_name);
        ds.field("shape", &self.shape);
        ds.field("dtype", &self.dtype);
        ds.field("bits", &self.bits);
        ds.field("group_size", &self.group_size);
        ds.field("packed_data_bytes", &self.packed_data.byte_len());
        ds.field("scales_bytes", &self.scales.byte_len());
        ds.field("has_biases", &self.biases.is_some());
        ds.finish()
    }
}
fn safetensors_dtype_to_dtype(st_dtype: safetensors::Dtype) -> Result<DType> {
match st_dtype {
safetensors::Dtype::F32 => Ok(DType::F32),
safetensors::Dtype::F16 => Ok(DType::F16),
safetensors::Dtype::BF16 => Ok(DType::BF16),
safetensors::Dtype::U8 => Ok(DType::U8),
safetensors::Dtype::U16 => Ok(DType::U16),
safetensors::Dtype::U32 => Ok(DType::U32),
safetensors::Dtype::I32 => Ok(DType::I32),
other => Err(MlxError::UnsupportedDtype(format!("{other:?}"))),
}
}
/// Copy raw tensor bytes into a newly allocated shared-mode Metal buffer.
///
/// `shape` is recorded on the returned buffer as-is; no check is made here
/// that `data.len()` is consistent with `shape`/`dtype` — callers are
/// expected to pass the exact byte payload of a safetensors view.
///
/// # Errors
/// `MlxError::InvalidArgument` for empty `data`;
/// `MlxError::BufferAllocationError` when Metal returns a null allocation.
pub fn safetensors_to_metal_buffer(
    device: &MlxDevice,
    data: &[u8],
    dtype: DType,
    shape: Vec<usize>,
) -> Result<MlxBuffer> {
    if data.is_empty() {
        return Err(MlxError::InvalidArgument(
            "Cannot create Metal buffer from empty data".into(),
        ));
    }
    let byte_len = data.len();
    // StorageModeShared keeps the buffer CPU-visible so it can be filled
    // with a plain memcpy below.
    let metal_buf = device
        .metal_device()
        .new_buffer(byte_len as u64, MTLResourceOptions::StorageModeShared);
    if metal_buf.contents().is_null() {
        return Err(MlxError::BufferAllocationError { bytes: byte_len });
    }
    // SAFETY: `contents()` was checked non-null above, the buffer was
    // allocated with exactly `byte_len` bytes, and `data` is a valid
    // `byte_len`-byte source that cannot overlap the fresh allocation.
    unsafe {
        std::ptr::copy_nonoverlapping(data.as_ptr(), metal_buf.contents() as *mut u8, byte_len);
    }
    Ok(MlxBuffer::from_raw(metal_buf, dtype, shape))
}
/// A memory-mapped `.safetensors` file. Holds only the mapping; the header
/// is re-parsed on each access via `parse`.
pub struct SafetensorsFile {
    // NOTE(review): `mmap` is read by `parse` and the `Debug` impl, so the
    // `dead_code` allow looks unnecessary — confirm before removing.
    #[allow(dead_code)]
    mmap: Mmap,
}
impl std::fmt::Debug for SafetensorsFile {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Only the mapping length is useful to show; the contents are bulk data.
        let mut ds = f.debug_struct("SafetensorsFile");
        ds.field("mmap_len", &self.mmap.len());
        ds.finish()
    }
}
impl SafetensorsFile {
    /// Open and memory-map a `.safetensors` file.
    ///
    /// # Errors
    /// `MlxError::IoError` if the file cannot be opened or mapped.
    pub fn open(path: &Path) -> Result<Self> {
        let file = fs::File::open(path).map_err(|e| {
            MlxError::IoError(format!("Failed to open safetensors file {}: {}", path.display(), e))
        })?;
        // SAFETY assumption: the mapped file must not be truncated or
        // modified by another process while the mapping is alive (standard
        // mmap caveat). Nothing here enforces that — TODO confirm callers.
        let mmap = unsafe {
            Mmap::map(&file).map_err(|e| {
                MlxError::IoError(format!("Failed to mmap safetensors file {}: {}", path.display(), e))
            })?
        };
        Ok(Self { mmap })
    }

    /// Parse the safetensors header over the mapping. Runs once per call;
    /// the returned view borrows from `self.mmap`.
    fn parse(&self) -> Result<SafeTensors<'_>> {
        SafeTensors::deserialize(&self.mmap).map_err(|e| {
            MlxError::SafetensorsError(format!("Failed to parse safetensors header: {e}"))
        })
    }

    /// Names of all tensors stored in this file.
    pub fn tensor_names(&self) -> Result<Vec<String>> {
        let st = self.parse()?;
        Ok(st.names().into_iter().map(|s| s.to_string()).collect())
    }

    /// Load a single tensor by name, copying its bytes into a Metal buffer.
    ///
    /// # Errors
    /// `MlxError::SafetensorsError` when `name` is absent, plus any dtype or
    /// buffer-allocation error from the conversion helpers.
    pub fn load_tensor(
        &self,
        name: &str,
        device: &MlxDevice,
    ) -> Result<(DType, Vec<usize>, MlxBuffer)> {
        let st = self.parse()?;
        let view = st.tensor(name).map_err(|e| {
            MlxError::SafetensorsError(format!("Tensor '{}' not found: {}", name, e))
        })?;
        let dtype = safetensors_dtype_to_dtype(view.dtype())?;
        let shape: Vec<usize> = view.shape().to_vec();
        let data = view.data();
        let buffer = safetensors_to_metal_buffer(device, data, dtype, shape.clone())?;
        Ok((dtype, shape, buffer))
    }

    /// Load every tensor in the file into Metal buffers, keyed by name.
    ///
    /// # Errors
    /// Fails on the first unsupported dtype or buffer-allocation failure.
    pub fn load_all_tensors(
        &self,
        device: &MlxDevice,
    ) -> Result<HashMap<String, (DType, Vec<usize>, MlxBuffer)>> {
        let st = self.parse()?;
        let mut result = HashMap::new();
        for (name, view) in st.tensors() {
            let dtype = safetensors_dtype_to_dtype(view.dtype())?;
            let shape: Vec<usize> = view.shape().to_vec();
            let data = view.data();
            let buffer = safetensors_to_metal_buffer(device, data, dtype, shape.clone())?;
            result.insert(name, (dtype, shape, buffer));
        }
        Ok(result)
    }
}
/// Load every quantized weight (packed data + scales + optional biases)
/// found in `model_dir` onto `device`.
///
/// Expects a `quantization_config.json` alongside one or more
/// `.safetensors` files. A weight is emitted for each `<base>.scales`
/// tensor whose packed counterpart exists as either `<base>.weight` or
/// bare `<base>`; groups missing their packed data or scales are skipped.
///
/// # Errors
/// Propagates config/IO/parse errors, and returns `MlxError::IoError` when
/// the directory contains no `.safetensors` files.
pub fn load_quantized_weights(
    model_dir: &Path,
    device: &MlxDevice,
) -> Result<Vec<QuantizedWeight>> {
    let config_path = model_dir.join("quantization_config.json");
    let quant_config = QuantizationConfig::from_file(&config_path)?;
    let safetensors_files = discover_safetensors_files(model_dir)?;
    if safetensors_files.is_empty() {
        return Err(MlxError::IoError(format!(
            "No .safetensors files found in {}",
            model_dir.display()
        )));
    }
    // Merge all shards into one name -> tensor map; on duplicate names the
    // later file wins, matching `HashMap::extend` semantics.
    let mut all_tensors: HashMap<String, (DType, Vec<usize>, MlxBuffer)> = HashMap::new();
    for sf_path in &safetensors_files {
        let sf = SafetensorsFile::open(sf_path)?;
        all_tensors.extend(sf.load_all_tensors(device)?);
    }
    // Every ".scales" tensor marks one quantized weight group. Collect the
    // base names up front so the map can be mutated while assembling.
    let scale_bases: Vec<String> = all_tensors
        .keys()
        .filter_map(|k| k.strip_suffix(".scales"))
        .map(str::to_string)
        .collect();
    let mut weights = Vec::with_capacity(scale_bases.len());
    for base_name in &scale_bases {
        let scales_key = format!("{base_name}.scales");
        let biases_key = format!("{base_name}.biases");
        // The packed data may be stored as "<base>.weight" or bare "<base>".
        let weight_key = if all_tensors.contains_key(&format!("{base_name}.weight")) {
            format!("{base_name}.weight")
        } else if all_tensors.contains_key(base_name) {
            base_name.clone()
        } else {
            continue;
        };
        let (packed_dtype, packed_shape, packed_data) = match all_tensors.remove(&weight_key) {
            Some(t) => t,
            None => continue,
        };
        let (_scales_dtype, _scales_shape, scales_buf) = match all_tensors.remove(&scales_key) {
            Some(t) => t,
            None => continue,
        };
        let biases_buf = all_tensors.remove(&biases_key).map(|(_, _, buf)| buf);
        let (bits, group_size) = quant_config.config_for_tensor(&weight_key);
        weights.push(QuantizedWeight::new(
            weight_key,
            packed_shape,
            packed_dtype,
            bits,
            group_size,
            scales_buf,
            biases_buf,
            packed_data,
        ));
    }
    Ok(weights)
}
/// List regular files in `dir` with a `.safetensors` extension, sorted by
/// path for a deterministic load order.
///
/// # Errors
/// `MlxError::IoError` if the directory (or any entry in it) cannot be read.
fn discover_safetensors_files(dir: &Path) -> Result<Vec<std::path::PathBuf>> {
    let entries = fs::read_dir(dir).map_err(|e| {
        MlxError::IoError(format!("Failed to read directory {}: {}", dir.display(), e))
    })?;
    let mut files: Vec<std::path::PathBuf> = Vec::new();
    for entry in entries {
        let entry = entry.map_err(|e| {
            MlxError::IoError(format!("Failed to read directory entry: {e}"))
        })?;
        let path = entry.path();
        // `is_file()` skips directories (or other non-files) whose name
        // merely ends in ".safetensors", which would fail later at mmap time.
        if path.is_file() && path.extension().and_then(|e| e.to_str()) == Some("safetensors") {
            files.push(path);
        }
    }
    files.sort();
    Ok(files)
}
// Unit and integration tests. Tests that allocate Metal buffers require a
// working `MlxDevice`; fixture files are written under a per-run temp dir
// created by `tempdir()` (not cleaned up afterwards).
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use safetensors::tensor::{Dtype as StDtype, TensorView};

    // Constructor + every accessor, with biases present.
    #[test]
    fn test_quantized_weight_construction() {
        let device = MlxDevice::new().expect("device");
        let packed = device.alloc_buffer(64, DType::U32, vec![4, 4]).expect("packed");
        let scales = device.alloc_buffer(16, DType::F16, vec![4, 2]).expect("scales");
        let biases = device.alloc_buffer(16, DType::F16, vec![4, 2]).expect("biases");
        let qw = QuantizedWeight::new(
            "model.layers.0.self_attn.q_proj.weight".to_string(),
            vec![2816, 2816],
            DType::F16,
            4,
            64,
            scales,
            Some(biases),
            packed,
        );
        assert_eq!(qw.tensor_name(), "model.layers.0.self_attn.q_proj.weight");
        assert_eq!(qw.shape(), &[2816, 2816]);
        assert_eq!(qw.dtype(), DType::F16);
        assert_eq!(qw.bits(), 4);
        assert_eq!(qw.group_size(), 64);
        assert!(qw.biases().is_some());
        assert_eq!(qw.element_count(), 2816 * 2816);
        assert_eq!(qw.num_groups(), (2816 + 64 - 1) / 64);
    }

    // Biases are optional; accessors still behave without them.
    #[test]
    fn test_quantized_weight_no_biases() {
        let device = MlxDevice::new().expect("device");
        let packed = device.alloc_buffer(32, DType::U32, vec![4, 2]).expect("packed");
        let scales = device.alloc_buffer(8, DType::F16, vec![4, 1]).expect("scales");
        let qw = QuantizedWeight::new(
            "test.weight".to_string(),
            vec![128, 128],
            DType::BF16,
            6,
            32,
            scales,
            None,
            packed,
        );
        assert!(qw.biases().is_none());
        assert_eq!(qw.bits(), 6);
        assert_eq!(qw.group_size(), 32);
        assert_eq!(qw.num_groups(), (128 + 32 - 1) / 32);
    }

    // Debug output mentions the struct name and key fields.
    #[test]
    fn test_quantized_weight_debug() {
        let device = MlxDevice::new().expect("device");
        let packed = device.alloc_buffer(16, DType::U32, vec![4]).expect("packed");
        let scales = device.alloc_buffer(4, DType::F16, vec![2]).expect("scales");
        let qw = QuantizedWeight::new(
            "test.w".to_string(),
            vec![64],
            DType::F32,
            4,
            64,
            scales,
            None,
            packed,
        );
        let debug_str = format!("{:?}", qw);
        assert!(debug_str.contains("QuantizedWeight"));
        assert!(debug_str.contains("test.w"));
        assert!(debug_str.contains("bits: 4"));
    }

    // Empty JSON falls back to the serde defaults (4 bits, group size 64).
    #[test]
    fn test_quant_config_defaults() {
        let json = r#"{}"#;
        let config = QuantizationConfig::from_json(json).expect("parse");
        assert_eq!(config.bits, 4);
        assert_eq!(config.group_size, 64);
        assert!(config.per_tensor.is_empty());
    }

    // Per-tensor overrides win over globals; unlisted tensors get globals.
    #[test]
    fn test_quant_config_with_per_tensor() {
        let json = r#"{
"bits": 4,
"group_size": 64,
"per_tensor": {
"model.layers.0.self_attn.v_proj.weight": {"bits": 6, "group_size": 128},
"model.embed_tokens.weight": {"bits": 8, "group_size": 32}
}
}"#;
        let config = QuantizationConfig::from_json(json).expect("parse");
        assert_eq!(config.bits, 4);
        assert_eq!(config.group_size, 64);
        let (bits, gs) = config.config_for_tensor("model.layers.0.self_attn.v_proj.weight");
        assert_eq!(bits, 6);
        assert_eq!(gs, 128);
        let (bits, gs) = config.config_for_tensor("model.layers.5.mlp.gate_proj.weight");
        assert_eq!(bits, 4);
        assert_eq!(gs, 64);
    }

    // Malformed JSON surfaces as QuantConfigError.
    #[test]
    fn test_quant_config_invalid_json() {
        let result = QuantizationConfig::from_json("not json at all {{{");
        assert!(result.is_err());
        match result {
            Err(MlxError::QuantConfigError(msg)) => {
                assert!(msg.contains("parse"), "msg: {msg}");
            }
            other => panic!("Expected QuantConfigError, got {:?}", other),
        }
    }

    // A suffix-less override key matches .weight/.scales/.biases lookups.
    #[test]
    fn test_config_for_tensor_strips_weight_suffix() {
        let json = r#"{
"bits": 4,
"group_size": 64,
"per_tensor": {
"model.layers.0.mlp.down_proj": {"bits": 8, "group_size": 64}
}
}"#;
        let config = QuantizationConfig::from_json(json).expect("parse");
        let (bits, gs) = config.config_for_tensor("model.layers.0.mlp.down_proj.weight");
        assert_eq!(bits, 8);
        assert_eq!(gs, 64);
        let (bits, _) = config.config_for_tensor("model.layers.0.mlp.down_proj.scales");
        assert_eq!(bits, 8);
        let (bits, _) = config.config_for_tensor("model.layers.0.mlp.down_proj.biases");
        assert_eq!(bits, 8);
    }

    // An un-prefixed query matches an override stored with "language_model.".
    #[test]
    fn test_config_for_tensor_adds_language_model_prefix() {
        let json = r#"{
"bits": 4,
"group_size": 64,
"per_tensor": {
"language_model.model.layers.0.self_attn.v_proj": {"bits": 6, "group_size": 64}
}
}"#;
        let config = QuantizationConfig::from_json(json).expect("parse");
        let (bits, _) = config.config_for_tensor("model.layers.0.self_attn.v_proj.weight");
        assert_eq!(bits, 6);
    }

    // A "language_model."-prefixed query matches an un-prefixed override.
    #[test]
    fn test_config_for_tensor_strips_language_model_prefix() {
        let json = r#"{
"bits": 4,
"group_size": 64,
"per_tensor": {
"model.layers.0.self_attn.v_proj": {"bits": 6, "group_size": 64}
}
}"#;
        let config = QuantizationConfig::from_json(json).expect("parse");
        let (bits, _) = config.config_for_tensor("language_model.model.layers.0.self_attn.v_proj.weight");
        assert_eq!(bits, 6);
    }

    // Extracting the "quantization" section from a full model config.json.
    #[test]
    fn test_from_model_config_json_basic() {
        let json = r#"{
"model_type": "gemma4",
"quantization": {
"bits": 4,
"group_size": 64,
"language_model.model.layers.0.mlp.down_proj": {"bits": 8, "group_size": 64},
"language_model.model.layers.0.self_attn.v_proj": {"bits": 6, "group_size": 64}
}
}"#;
        let config = QuantizationConfig::from_model_config_json(json).expect("parse");
        assert_eq!(config.bits, 4);
        assert_eq!(config.group_size, 64);
        assert_eq!(config.per_tensor.len(), 2);
        let (bits, _) = config.config_for_tensor("language_model.model.layers.0.mlp.down_proj.weight");
        assert_eq!(bits, 8);
        let (bits, _) = config.config_for_tensor("language_model.model.layers.0.self_attn.v_proj.weight");
        assert_eq!(bits, 6);
        let (bits, _) = config.config_for_tensor("language_model.model.layers.5.mlp.gate_proj.weight");
        assert_eq!(bits, 4);
    }

    // Missing "quantization" key is an error, not a silent default.
    #[test]
    fn test_from_model_config_json_no_quantization_key() {
        let json = r#"{"model_type": "gemma4"}"#;
        let result = QuantizationConfig::from_model_config_json(json);
        assert!(result.is_err());
    }

    // Every supported safetensors dtype maps to the expected DType.
    #[test]
    fn test_safetensors_dtype_conversion() {
        assert_eq!(safetensors_dtype_to_dtype(StDtype::F32).unwrap(), DType::F32);
        assert_eq!(safetensors_dtype_to_dtype(StDtype::F16).unwrap(), DType::F16);
        assert_eq!(safetensors_dtype_to_dtype(StDtype::BF16).unwrap(), DType::BF16);
        assert_eq!(safetensors_dtype_to_dtype(StDtype::U8).unwrap(), DType::U8);
        assert_eq!(safetensors_dtype_to_dtype(StDtype::U16).unwrap(), DType::U16);
        assert_eq!(safetensors_dtype_to_dtype(StDtype::U32).unwrap(), DType::U32);
        assert_eq!(safetensors_dtype_to_dtype(StDtype::I32).unwrap(), DType::I32);
    }

    // An unmapped dtype (BOOL) yields UnsupportedDtype.
    #[test]
    fn test_safetensors_dtype_unsupported() {
        let result = safetensors_dtype_to_dtype(StDtype::BOOL);
        assert!(result.is_err());
        match result {
            Err(MlxError::UnsupportedDtype(_)) => {}
            other => panic!("Expected UnsupportedDtype, got {:?}", other),
        }
    }

    // Bytes copied into a Metal buffer read back exactly (f32 payload).
    #[test]
    fn test_safetensors_to_metal_buffer_roundtrip() {
        let device = MlxDevice::new().expect("device");
        let values: [f32; 4] = [1.0, 2.5, -3.0, 4.125];
        let bytes: &[u8] = bytemuck::cast_slice(&values);
        let buf = safetensors_to_metal_buffer(&device, bytes, DType::F32, vec![4])
            .expect("to_metal_buffer");
        assert_eq!(buf.byte_len(), 16);
        assert_eq!(buf.dtype(), DType::F32);
        assert_eq!(buf.shape(), &[4]);
        let read_back: &[f32] = buf.as_slice().expect("as_slice");
        assert_eq!(read_back.len(), 4);
        assert_eq!(read_back[0], 1.0);
        assert_eq!(read_back[1], 2.5);
        assert_eq!(read_back[2], -3.0);
        assert_eq!(read_back[3], 4.125);
    }

    // Empty input is rejected with InvalidArgument.
    #[test]
    fn test_safetensors_to_metal_buffer_empty_error() {
        let device = MlxDevice::new().expect("device");
        let result = safetensors_to_metal_buffer(&device, &[], DType::F32, vec![0]);
        assert!(result.is_err());
        match result {
            Err(MlxError::InvalidArgument(msg)) => {
                assert!(msg.contains("empty"), "msg: {msg}");
            }
            other => panic!("Expected InvalidArgument, got {:?}", other),
        }
    }

    // Byte-for-byte round trip for a u8 payload.
    #[test]
    fn test_safetensors_to_metal_buffer_u8_data() {
        let device = MlxDevice::new().expect("device");
        let data: Vec<u8> = (0..128).collect();
        let buf = safetensors_to_metal_buffer(&device, &data, DType::U8, vec![128])
            .expect("to_metal_buffer");
        assert_eq!(buf.byte_len(), 128);
        let read_back: &[u8] = buf.as_slice().expect("as_slice");
        for (i, &val) in read_back.iter().enumerate() {
            assert_eq!(val, i as u8, "mismatch at index {i}");
        }
    }

    // Fixture: write a two-tensor safetensors file and return its path.
    fn create_test_safetensors(dir: &Path) -> std::path::PathBuf {
        let path = dir.join("test_model.safetensors");
        let tensor_a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor_a_bytes: &[u8] = bytemuck::cast_slice(&tensor_a_data);
        let tensor_b_data: Vec<f32> = vec![10.0, 20.0, 30.0];
        let tensor_b_bytes: &[u8] = bytemuck::cast_slice(&tensor_b_data);
        let tensors = vec![
            (
                "layer.weight",
                TensorView::new(StDtype::F32, vec![2, 3], tensor_a_bytes).unwrap(),
            ),
            (
                "layer.bias",
                TensorView::new(StDtype::F32, vec![3], tensor_b_bytes).unwrap(),
            ),
        ];
        let serialized = safetensors::tensor::serialize(tensors, None).unwrap();
        fs::write(&path, &serialized).unwrap();
        path
    }

    // Opening a file and listing tensor names.
    #[test]
    fn test_safetensors_file_open_and_list() {
        let tmp = tempdir();
        let st_path = create_test_safetensors(&tmp);
        let sf = SafetensorsFile::open(&st_path).expect("open");
        let names = sf.tensor_names().expect("names");
        assert_eq!(names.len(), 2);
        assert!(names.contains(&"layer.weight".to_string()));
        assert!(names.contains(&"layer.bias".to_string()));
    }

    // Loading a single named tensor yields the exact data/shape/dtype.
    #[test]
    fn test_safetensors_file_load_tensor() {
        let device = MlxDevice::new().expect("device");
        let tmp = tempdir();
        let st_path = create_test_safetensors(&tmp);
        let sf = SafetensorsFile::open(&st_path).expect("open");
        let (dtype, shape, buf) = sf.load_tensor("layer.weight", &device).expect("load");
        assert_eq!(dtype, DType::F32);
        assert_eq!(shape, vec![2, 3]);
        assert_eq!(buf.byte_len(), 24);
        let data: &[f32] = buf.as_slice().expect("as_slice");
        assert_eq!(data, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    // load_all_tensors returns every tensor keyed by name.
    #[test]
    fn test_safetensors_file_load_all() {
        let device = MlxDevice::new().expect("device");
        let tmp = tempdir();
        let st_path = create_test_safetensors(&tmp);
        let sf = SafetensorsFile::open(&st_path).expect("open");
        let all = sf.load_all_tensors(&device).expect("load_all");
        assert_eq!(all.len(), 2);
        let (dtype, shape, buf) = all.get("layer.bias").expect("bias");
        assert_eq!(*dtype, DType::F32);
        assert_eq!(*shape, vec![3]);
        let data: &[f32] = buf.as_slice().expect("as_slice");
        assert_eq!(data, &[10.0, 20.0, 30.0]);
    }

    // Missing tensor name surfaces as SafetensorsError naming the tensor.
    #[test]
    fn test_safetensors_file_tensor_not_found() {
        let tmp = tempdir();
        let st_path = create_test_safetensors(&tmp);
        let device = MlxDevice::new().expect("device");
        let sf = SafetensorsFile::open(&st_path).expect("open");
        let result = sf.load_tensor("nonexistent", &device);
        assert!(result.is_err());
        match result {
            Err(MlxError::SafetensorsError(msg)) => {
                assert!(msg.contains("nonexistent"), "msg: {msg}");
            }
            other => panic!("Expected SafetensorsError, got {:?}", other),
        }
    }

    // Opening a nonexistent path is an IoError.
    #[test]
    fn test_safetensors_file_open_missing() {
        let result = SafetensorsFile::open(Path::new("/tmp/does_not_exist_8f3a2b1c.safetensors"));
        assert!(result.is_err());
        match result {
            Err(MlxError::IoError(_)) => {}
            other => panic!("Expected IoError, got {:?}", other),
        }
    }

    // Fixture: a model dir with quantization_config.json plus one shard
    // containing proj.weight / proj.scales / proj.biases.
    fn create_test_quant_dir(dir: &Path) {
        let config_json = r#"{
"bits": 4,
"group_size": 64,
"per_tensor": {
"proj.weight": {"bits": 4, "group_size": 64}
}
}"#;
        fs::write(dir.join("quantization_config.json"), config_json).unwrap();
        let weight_data: Vec<u32> = vec![0xAAAA_BBBB; 8];
        let weight_bytes: &[u8] = bytemuck::cast_slice(&weight_data);
        let scales_data: Vec<u16> = vec![0x3C00, 0x3C00];
        let scales_bytes: &[u8] = bytemuck::cast_slice(&scales_data);
        let biases_data: Vec<u16> = vec![0x0000, 0x0000];
        let biases_bytes: &[u8] = bytemuck::cast_slice(&biases_data);
        let tensors = vec![
            (
                "proj.weight",
                TensorView::new(StDtype::U32, vec![2, 4], weight_bytes).unwrap(),
            ),
            (
                "proj.scales",
                TensorView::new(StDtype::F16, vec![2, 1], scales_bytes).unwrap(),
            ),
            (
                "proj.biases",
                TensorView::new(StDtype::F16, vec![2, 1], biases_bytes).unwrap(),
            ),
        ];
        let serialized = safetensors::tensor::serialize(tensors, None).unwrap();
        fs::write(dir.join("model.safetensors"), &serialized).unwrap();
    }

    // End-to-end: discover + load + group weight/scales/biases into one entry.
    #[test]
    fn test_load_quantized_weights_integration() {
        let device = MlxDevice::new().expect("device");
        let tmp = tempdir();
        create_test_quant_dir(&tmp);
        let weights = load_quantized_weights(&tmp, &device).expect("load");
        assert_eq!(weights.len(), 1);
        let qw = &weights[0];
        assert_eq!(qw.tensor_name(), "proj.weight");
        assert_eq!(qw.bits(), 4);
        assert_eq!(qw.group_size(), 64);
        assert_eq!(qw.packed_data().byte_len(), 32);
        assert_eq!(qw.scales().byte_len(), 4);
        assert!(qw.biases().is_some());
    }

    // Config present but no shard files => IoError.
    #[test]
    fn test_load_quantized_weights_no_safetensors() {
        let tmp = tempdir();
        fs::write(tmp.join("quantization_config.json"), "{}").unwrap();
        let device = MlxDevice::new().expect("device");
        let result = load_quantized_weights(&tmp, &device);
        assert!(result.is_err());
        match result {
            Err(MlxError::IoError(msg)) => {
                assert!(msg.contains("No .safetensors files"), "msg: {msg}");
            }
            other => panic!("Expected IoError, got {:?}", other),
        }
    }

    // Shard present but no config file => IoError mentioning the config.
    #[test]
    fn test_load_quantized_weights_missing_config() {
        let tmp = tempdir();
        let data: Vec<u8> = vec![0; 16];
        let tensors = vec![(
            "dummy",
            TensorView::new(StDtype::U8, vec![16], &data).unwrap(),
        )];
        let serialized = safetensors::tensor::serialize(tensors, None).unwrap();
        fs::write(tmp.join("model.safetensors"), &serialized).unwrap();
        let device = MlxDevice::new().expect("device");
        let result = load_quantized_weights(&tmp, &device);
        assert!(result.is_err());
        match result {
            Err(MlxError::IoError(msg)) => {
                assert!(msg.contains("quantization_config"), "msg: {msg}");
            }
            other => panic!("Expected IoError for missing config, got {:?}", other),
        }
    }

    // Per-run unique temp dir: <tmp>/mlx_native_test_<pid>/<nanos>.
    // Note: directories are never cleaned up; relies on OS temp cleanup.
    fn tempdir() -> std::path::PathBuf {
        let mut path = std::env::temp_dir();
        path.push(format!("mlx_native_test_{}", std::process::id()));
        path.push(format!("{}", std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_nanos()));
        fs::create_dir_all(&path).expect("create temp dir");
        path
    }
}