1use std::collections::HashMap;
27use std::fs;
28use std::path::Path;
29
30use memmap2::Mmap;
31use metal::MTLResourceOptions;
32use safetensors::SafeTensors;
33use serde::Deserialize;
34
35use crate::buffer::MlxBuffer;
36use crate::device::MlxDevice;
37use crate::dtypes::DType;
38use crate::error::{MlxError, Result};
39
/// Quantization parameters for a model's weights.
///
/// `bits` and `group_size` are the model-wide defaults; `per_tensor`
/// holds named overrides consulted by `config_for_tensor`, which also
/// tries suffix-stripped and `language_model.`-prefixed name variants.
#[derive(Debug, Clone, Deserialize)]
pub struct QuantizationConfig {
    /// Default bit width per quantized element; 4 when absent from the JSON.
    #[serde(default = "default_bits")]
    pub bits: u8,

    /// Default number of elements sharing one scale/bias group; 64 when
    /// absent from the JSON.
    #[serde(default = "default_group_size")]
    pub group_size: usize,

    /// Per-tensor overrides keyed by tensor name; empty when absent.
    #[serde(default)]
    pub per_tensor: HashMap<String, TensorQuantConfig>,
}
63
/// Quantization override for a single tensor (see
/// `QuantizationConfig::per_tensor`). Both fields are required when an
/// override is given in standalone JSON.
#[derive(Debug, Clone, Deserialize)]
pub struct TensorQuantConfig {
    /// Bit width per quantized element for this tensor.
    pub bits: u8,
    /// Elements per scale/bias group for this tensor.
    pub group_size: usize,
}
72
/// Serde default for `QuantizationConfig::bits`: 4-bit quantization.
fn default_bits() -> u8 {
    4
}
76
/// Serde default for `QuantizationConfig::group_size`: 64 elements per group.
fn default_group_size() -> usize {
    64
}
80
/// Returns `name` with a trailing `.weight`, `.scales`, or `.biases`
/// component removed, or `name` unchanged when no such suffix is present.
/// Only the first matching suffix (in the order listed) is stripped.
fn strip_tensor_suffix(name: &str) -> &str {
    [".weight", ".scales", ".biases"]
        .into_iter()
        .find_map(|suffix| name.strip_suffix(suffix))
        .unwrap_or(name)
}
91
92impl QuantizationConfig {
93 pub fn from_file(path: &Path) -> Result<Self> {
100 let contents = fs::read_to_string(path).map_err(|e| {
101 MlxError::IoError(format!("Failed to read quantization config at {}: {}", path.display(), e))
102 })?;
103 Self::from_json(&contents)
104 }
105
106 pub fn from_json(json: &str) -> Result<Self> {
112 serde_json::from_str(json).map_err(|e| {
113 MlxError::QuantConfigError(format!("Failed to parse quantization config JSON: {e}"))
114 })
115 }
116
117 pub fn from_model_config_json(json: &str) -> Result<Self> {
136 let root: serde_json::Value = serde_json::from_str(json).map_err(|e| {
138 MlxError::QuantConfigError(format!("Failed to parse config.json: {e}"))
139 })?;
140
141 let quant_section = root.get("quantization").ok_or_else(|| {
142 MlxError::QuantConfigError("No \"quantization\" key in config.json".into())
143 })?;
144
145 let quant_obj = quant_section.as_object().ok_or_else(|| {
146 MlxError::QuantConfigError("\"quantization\" is not an object".into())
147 })?;
148
149 let bits = quant_obj
150 .get("bits")
151 .and_then(|v| v.as_u64())
152 .unwrap_or(4) as u8;
153
154 let group_size = quant_obj
155 .get("group_size")
156 .and_then(|v| v.as_u64())
157 .unwrap_or(64) as usize;
158
159 let mut per_tensor = HashMap::new();
161 for (key, value) in quant_obj {
162 if key == "bits" || key == "group_size" || key == "quant_method" {
163 continue;
164 }
165 if let Some(obj) = value.as_object() {
166 if let Some(tensor_bits) = obj.get("bits").and_then(|v| v.as_u64()) {
167 let tensor_gs = obj
168 .get("group_size")
169 .and_then(|v| v.as_u64())
170 .unwrap_or(group_size as u64) as usize;
171 per_tensor.insert(
172 key.clone(),
173 TensorQuantConfig {
174 bits: tensor_bits as u8,
175 group_size: tensor_gs,
176 },
177 );
178 }
179 }
180 }
181
182 Ok(Self {
183 bits,
184 group_size,
185 per_tensor,
186 })
187 }
188
189 pub fn from_model_config_file(path: &Path) -> Result<Self> {
191 let contents = fs::read_to_string(path).map_err(|e| {
192 MlxError::IoError(format!(
193 "Failed to read config.json at {}: {}",
194 path.display(),
195 e
196 ))
197 })?;
198 Self::from_model_config_json(&contents)
199 }
200
201 pub fn config_for_tensor(&self, tensor_name: &str) -> (u8, usize) {
211 if let Some(tc) = self.per_tensor.get(tensor_name) {
213 return (tc.bits, tc.group_size);
214 }
215
216 let base = strip_tensor_suffix(tensor_name);
218 if base != tensor_name {
219 if let Some(tc) = self.per_tensor.get(base) {
220 return (tc.bits, tc.group_size);
221 }
222 }
223
224 let lm_prefix = "language_model.";
226 if let Some(stripped) = tensor_name.strip_prefix(lm_prefix) {
227 if let Some(tc) = self.per_tensor.get(stripped) {
228 return (tc.bits, tc.group_size);
229 }
230 let stripped_base = strip_tensor_suffix(stripped);
231 if stripped_base != stripped {
232 if let Some(tc) = self.per_tensor.get(stripped_base) {
233 return (tc.bits, tc.group_size);
234 }
235 }
236 }
237
238 if !tensor_name.starts_with(lm_prefix) {
240 let with_prefix = format!("{lm_prefix}{tensor_name}");
241 if let Some(tc) = self.per_tensor.get(&with_prefix) {
242 return (tc.bits, tc.group_size);
243 }
244 let with_prefix_base = format!("{lm_prefix}{base}");
245 if base != tensor_name {
246 if let Some(tc) = self.per_tensor.get(&with_prefix_base) {
247 return (tc.bits, tc.group_size);
248 }
249 }
250 }
251
252 (self.bits, self.group_size)
253 }
254}
255
/// A single quantized tensor resident in a Metal buffer: the packed
/// integer payload plus the per-group scales (and optional biases)
/// needed to dequantize it.
pub struct QuantizedWeight {
    // Fully-qualified tensor name, e.g. "model.layers.0.self_attn.q_proj.weight".
    tensor_name: String,
    // Shape as passed by the creator. NOTE(review): `load_quantized_weights`
    // stores the packed tensor's on-disk shape here — confirm whether callers
    // expect packed or logical (dequantized) shape.
    shape: Vec<usize>,
    // Dtype as passed by the creator (the loader passes the packed tensor's
    // dtype, e.g. U32 for 4-bit packed data).
    dtype: DType,
    // Bit width of each quantized element.
    bits: u8,
    // Number of elements sharing one scale/bias entry.
    group_size: usize,
    // Per-group scale factors.
    scales: MlxBuffer,
    // Optional per-group biases; None when the model ships no `.biases` tensor.
    biases: Option<MlxBuffer>,
    // The packed quantized payload.
    packed_data: MlxBuffer,
}
290
291impl QuantizedWeight {
292 pub fn new(
298 tensor_name: String,
299 shape: Vec<usize>,
300 dtype: DType,
301 bits: u8,
302 group_size: usize,
303 scales: MlxBuffer,
304 biases: Option<MlxBuffer>,
305 packed_data: MlxBuffer,
306 ) -> Self {
307 Self {
308 tensor_name,
309 shape,
310 dtype,
311 bits,
312 group_size,
313 scales,
314 biases,
315 packed_data,
316 }
317 }
318
319 #[inline]
321 pub fn tensor_name(&self) -> &str {
322 &self.tensor_name
323 }
324
325 #[inline]
327 pub fn shape(&self) -> &[usize] {
328 &self.shape
329 }
330
331 #[inline]
333 pub fn dtype(&self) -> DType {
334 self.dtype
335 }
336
337 #[inline]
339 pub fn bits(&self) -> u8 {
340 self.bits
341 }
342
343 #[inline]
345 pub fn group_size(&self) -> usize {
346 self.group_size
347 }
348
349 #[inline]
351 pub fn scales(&self) -> &MlxBuffer {
352 &self.scales
353 }
354
355 #[inline]
357 pub fn biases(&self) -> Option<&MlxBuffer> {
358 self.biases.as_ref()
359 }
360
361 #[inline]
363 pub fn packed_data(&self) -> &MlxBuffer {
364 &self.packed_data
365 }
366
367 pub fn element_count(&self) -> usize {
369 self.shape.iter().copied().product()
370 }
371
372 pub fn num_groups(&self) -> usize {
376 let last_dim = self.shape.last().copied().unwrap_or(0);
377 if self.group_size == 0 {
378 return 0;
379 }
380 (last_dim + self.group_size - 1) / self.group_size
381 }
382}
383
/// Manual `Debug` impl: GPU buffers are summarized by their byte
/// lengths (and `biases` by mere presence) instead of formatting
/// buffer contents.
impl std::fmt::Debug for QuantizedWeight {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("QuantizedWeight")
            .field("tensor_name", &self.tensor_name)
            .field("shape", &self.shape)
            .field("dtype", &self.dtype)
            .field("bits", &self.bits)
            .field("group_size", &self.group_size)
            .field("packed_data_bytes", &self.packed_data.byte_len())
            .field("scales_bytes", &self.scales.byte_len())
            .field("has_biases", &self.biases.is_some())
            .finish()
    }
}
398
399fn safetensors_dtype_to_dtype(st_dtype: safetensors::Dtype) -> Result<DType> {
407 match st_dtype {
408 safetensors::Dtype::F32 => Ok(DType::F32),
409 safetensors::Dtype::F16 => Ok(DType::F16),
410 safetensors::Dtype::BF16 => Ok(DType::BF16),
411 safetensors::Dtype::U8 => Ok(DType::U8),
412 safetensors::Dtype::U16 => Ok(DType::U16),
413 safetensors::Dtype::U32 => Ok(DType::U32),
414 safetensors::Dtype::I32 => Ok(DType::I32),
415 other => Err(MlxError::UnsupportedDtype(format!("{other:?}"))),
416 }
417}
418
/// Copies raw tensor bytes into a newly allocated shared-storage Metal
/// buffer and wraps it as an `MlxBuffer` with the given dtype and shape.
///
/// NOTE(review): `data.len()` is not validated against `dtype`/`shape`
/// — presumably the caller (safetensors loading) guarantees they agree;
/// confirm before reusing this elsewhere.
///
/// # Errors
/// - `MlxError::InvalidArgument` if `data` is empty.
/// - `MlxError::BufferAllocationError` if the Metal allocation yields a
///   null contents pointer.
pub fn safetensors_to_metal_buffer(
    device: &MlxDevice,
    data: &[u8],
    dtype: DType,
    shape: Vec<usize>,
) -> Result<MlxBuffer> {
    if data.is_empty() {
        return Err(MlxError::InvalidArgument(
            "Cannot create Metal buffer from empty data".into(),
        ));
    }

    let byte_len = data.len();
    let metal_buf = device
        .metal_device()
        .new_buffer(byte_len as u64, MTLResourceOptions::StorageModeShared);

    // A null contents pointer signals a failed allocation.
    if metal_buf.contents().is_null() {
        return Err(MlxError::BufferAllocationError { bytes: byte_len });
    }

    // SAFETY: the buffer was allocated with exactly `byte_len` bytes of
    // CPU-visible (StorageModeShared) memory and its contents pointer
    // was checked non-null above; `data` is a valid slice of `byte_len`
    // bytes, and the freshly allocated destination cannot overlap it.
    unsafe {
        std::ptr::copy_nonoverlapping(data.as_ptr(), metal_buf.contents() as *mut u8, byte_len);
    }

    Ok(MlxBuffer::from_raw(metal_buf, dtype, shape))
}
473
/// A safetensors file held open via a memory mapping.
///
/// The header is parsed on demand (see `parse`); tensor bytes are
/// copied out of the mapping only when a tensor is loaded.
pub struct SafetensorsFile {
    // Keeps the mapped region alive; all reads go through this mapping.
    #[allow(dead_code)]
    mmap: Mmap,
}
488
/// Manual `Debug` impl: reports only the length of the mapped region.
impl std::fmt::Debug for SafetensorsFile {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("SafetensorsFile")
            .field("mmap_len", &self.mmap.len())
            .finish()
    }
}
496
497impl SafetensorsFile {
498 pub fn open(path: &Path) -> Result<Self> {
507 let file = fs::File::open(path).map_err(|e| {
508 MlxError::IoError(format!("Failed to open safetensors file {}: {}", path.display(), e))
509 })?;
510
511 let mmap = unsafe {
515 Mmap::map(&file).map_err(|e| {
516 MlxError::IoError(format!("Failed to mmap safetensors file {}: {}", path.display(), e))
517 })?
518 };
519
520 Ok(Self { mmap })
521 }
522
523 fn parse(&self) -> Result<SafeTensors<'_>> {
527 SafeTensors::deserialize(&self.mmap).map_err(|e| {
528 MlxError::SafetensorsError(format!("Failed to parse safetensors header: {e}"))
529 })
530 }
531
532 pub fn tensor_names(&self) -> Result<Vec<String>> {
534 let st = self.parse()?;
535 Ok(st.names().into_iter().map(|s| s.to_string()).collect())
536 }
537
538 pub fn load_tensor(
548 &self,
549 name: &str,
550 device: &MlxDevice,
551 ) -> Result<(DType, Vec<usize>, MlxBuffer)> {
552 let st = self.parse()?;
553 let view = st.tensor(name).map_err(|e| {
554 MlxError::SafetensorsError(format!("Tensor '{}' not found: {}", name, e))
555 })?;
556
557 let dtype = safetensors_dtype_to_dtype(view.dtype())?;
558 let shape: Vec<usize> = view.shape().to_vec();
559 let data = view.data();
560
561 let buffer = safetensors_to_metal_buffer(device, data, dtype, shape.clone())?;
562 Ok((dtype, shape, buffer))
563 }
564
565 pub fn load_all_tensors(
573 &self,
574 device: &MlxDevice,
575 ) -> Result<HashMap<String, (DType, Vec<usize>, MlxBuffer)>> {
576 let st = self.parse()?;
577 let mut result = HashMap::new();
578
579 for (name, view) in st.tensors() {
580 let dtype = safetensors_dtype_to_dtype(view.dtype())?;
581 let shape: Vec<usize> = view.shape().to_vec();
582 let data = view.data();
583
584 let buffer = safetensors_to_metal_buffer(device, data, dtype, shape.clone())?;
585 result.insert(name, (dtype, shape, buffer));
586 }
587
588 Ok(result)
589 }
590}
591
592pub fn load_quantized_weights(
627 model_dir: &Path,
628 device: &MlxDevice,
629) -> Result<Vec<QuantizedWeight>> {
630 let config_path = model_dir.join("quantization_config.json");
632 let quant_config = QuantizationConfig::from_file(&config_path)?;
633
634 let safetensors_files = discover_safetensors_files(model_dir)?;
636 if safetensors_files.is_empty() {
637 return Err(MlxError::IoError(format!(
638 "No .safetensors files found in {}",
639 model_dir.display()
640 )));
641 }
642
643 let mut all_tensors: HashMap<String, (DType, Vec<usize>, MlxBuffer)> = HashMap::new();
645 for sf_path in &safetensors_files {
646 let sf = SafetensorsFile::open(sf_path)?;
647 let tensors = sf.load_all_tensors(device)?;
648 all_tensors.extend(tensors);
649 }
650
651 let mut weights = Vec::new();
661 let mut processed: std::collections::HashSet<String> = std::collections::HashSet::new();
662
663 let scale_suffix = ".scales";
665 let scale_bases: Vec<String> = all_tensors
666 .keys()
667 .filter(|k| k.ends_with(scale_suffix))
668 .map(|k| k[..k.len() - scale_suffix.len()].to_string())
669 .collect();
670
671 for base_name in &scale_bases {
672 let scales_key = format!("{base_name}.scales");
673 let biases_key = format!("{base_name}.biases");
674
675 let weight_key = if all_tensors.contains_key(&format!("{base_name}.weight")) {
677 format!("{base_name}.weight")
678 } else if all_tensors.contains_key(base_name) {
679 base_name.clone()
680 } else {
681 continue;
683 };
684
685 let (packed_dtype, packed_shape, packed_data) = match all_tensors.remove(&weight_key) {
687 Some(t) => t,
688 None => continue,
689 };
690
691 let (_scales_dtype, _scales_shape, scales_buf) = match all_tensors.remove(&scales_key) {
693 Some(t) => t,
694 None => continue,
695 };
696
697 let biases_buf = all_tensors.remove(&biases_key).map(|(_, _, buf)| buf);
699
700 let (bits, group_size) = quant_config.config_for_tensor(&weight_key);
702
703 weights.push(QuantizedWeight::new(
704 weight_key.clone(),
705 packed_shape,
706 packed_dtype,
707 bits,
708 group_size,
709 scales_buf,
710 biases_buf,
711 packed_data,
712 ));
713
714 processed.insert(weight_key);
715 processed.insert(scales_key);
716 processed.insert(biases_key);
717 }
718
719 Ok(weights)
720}
721
722fn discover_safetensors_files(dir: &Path) -> Result<Vec<std::path::PathBuf>> {
724 let entries = fs::read_dir(dir).map_err(|e| {
725 MlxError::IoError(format!("Failed to read directory {}: {}", dir.display(), e))
726 })?;
727
728 let mut files: Vec<std::path::PathBuf> = Vec::new();
729 for entry in entries {
730 let entry = entry.map_err(|e| {
731 MlxError::IoError(format!("Failed to read directory entry: {e}"))
732 })?;
733 let path = entry.path();
734 if path.extension().and_then(|e| e.to_str()) == Some("safetensors") {
735 files.push(path);
736 }
737 }
738
739 files.sort();
740 Ok(files)
741}
742
#[cfg(test)]
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;
    use safetensors::tensor::{Dtype as StDtype, TensorView};

    // --- QuantizedWeight ------------------------------------------------

    #[test]
    fn test_quantized_weight_construction() {
        let device = MlxDevice::new().expect("device");

        let packed = device.alloc_buffer(64, DType::U32, vec![4, 4]).expect("packed");
        let scales = device.alloc_buffer(16, DType::F16, vec![4, 2]).expect("scales");
        let biases = device.alloc_buffer(16, DType::F16, vec![4, 2]).expect("biases");

        let qw = QuantizedWeight::new(
            "model.layers.0.self_attn.q_proj.weight".to_string(),
            vec![2816, 2816],
            DType::F16,
            4,
            64,
            scales,
            Some(biases),
            packed,
        );

        assert_eq!(qw.tensor_name(), "model.layers.0.self_attn.q_proj.weight");
        assert_eq!(qw.shape(), &[2816, 2816]);
        assert_eq!(qw.dtype(), DType::F16);
        assert_eq!(qw.bits(), 4);
        assert_eq!(qw.group_size(), 64);
        assert!(qw.biases().is_some());
        assert_eq!(qw.element_count(), 2816 * 2816);
        // Ceiling division along the last dimension.
        assert_eq!(qw.num_groups(), (2816 + 64 - 1) / 64);
    }

    #[test]
    fn test_quantized_weight_no_biases() {
        let device = MlxDevice::new().expect("device");

        let packed = device.alloc_buffer(32, DType::U32, vec![4, 2]).expect("packed");
        let scales = device.alloc_buffer(8, DType::F16, vec![4, 1]).expect("scales");

        let qw = QuantizedWeight::new(
            "test.weight".to_string(),
            vec![128, 128],
            DType::BF16,
            6,
            32,
            scales,
            None,
            packed,
        );

        assert!(qw.biases().is_none());
        assert_eq!(qw.bits(), 6);
        assert_eq!(qw.group_size(), 32);
        assert_eq!(qw.num_groups(), (128 + 32 - 1) / 32);
    }

    #[test]
    fn test_quantized_weight_debug() {
        let device = MlxDevice::new().expect("device");
        let packed = device.alloc_buffer(16, DType::U32, vec![4]).expect("packed");
        let scales = device.alloc_buffer(4, DType::F16, vec![2]).expect("scales");

        let qw = QuantizedWeight::new(
            "test.w".to_string(),
            vec![64],
            DType::F32,
            4,
            64,
            scales,
            None,
            packed,
        );

        let debug_str = format!("{:?}", qw);
        assert!(debug_str.contains("QuantizedWeight"));
        assert!(debug_str.contains("test.w"));
        assert!(debug_str.contains("bits: 4"));
    }

    // --- QuantizationConfig ---------------------------------------------

    #[test]
    fn test_quant_config_defaults() {
        let json = r#"{}"#;
        let config = QuantizationConfig::from_json(json).expect("parse");
        assert_eq!(config.bits, 4);
        assert_eq!(config.group_size, 64);
        assert!(config.per_tensor.is_empty());
    }

    #[test]
    fn test_quant_config_with_per_tensor() {
        let json = r#"{
            "bits": 4,
            "group_size": 64,
            "per_tensor": {
                "model.layers.0.self_attn.v_proj.weight": {"bits": 6, "group_size": 128},
                "model.embed_tokens.weight": {"bits": 8, "group_size": 32}
            }
        }"#;

        let config = QuantizationConfig::from_json(json).expect("parse");
        assert_eq!(config.bits, 4);
        assert_eq!(config.group_size, 64);

        // Exact-name override.
        let (bits, gs) = config.config_for_tensor("model.layers.0.self_attn.v_proj.weight");
        assert_eq!(bits, 6);
        assert_eq!(gs, 128);

        // Unlisted tensor falls back to the defaults.
        let (bits, gs) = config.config_for_tensor("model.layers.5.mlp.gate_proj.weight");
        assert_eq!(bits, 4);
        assert_eq!(gs, 64);
    }

    #[test]
    fn test_quant_config_invalid_json() {
        let result = QuantizationConfig::from_json("not json at all {{{");
        assert!(result.is_err());
        match result {
            Err(MlxError::QuantConfigError(msg)) => {
                assert!(msg.contains("parse"), "msg: {msg}");
            }
            other => panic!("Expected QuantConfigError, got {:?}", other),
        }
    }

    #[test]
    fn test_config_for_tensor_strips_weight_suffix() {
        let json = r#"{
            "bits": 4,
            "group_size": 64,
            "per_tensor": {
                "model.layers.0.mlp.down_proj": {"bits": 8, "group_size": 64}
            }
        }"#;
        let config = QuantizationConfig::from_json(json).expect("parse");

        // All three tensor-component suffixes resolve to the same base key.
        let (bits, gs) = config.config_for_tensor("model.layers.0.mlp.down_proj.weight");
        assert_eq!(bits, 8);
        assert_eq!(gs, 64);

        let (bits, _) = config.config_for_tensor("model.layers.0.mlp.down_proj.scales");
        assert_eq!(bits, 8);

        let (bits, _) = config.config_for_tensor("model.layers.0.mlp.down_proj.biases");
        assert_eq!(bits, 8);
    }

    #[test]
    fn test_config_for_tensor_adds_language_model_prefix() {
        let json = r#"{
            "bits": 4,
            "group_size": 64,
            "per_tensor": {
                "language_model.model.layers.0.self_attn.v_proj": {"bits": 6, "group_size": 64}
            }
        }"#;
        let config = QuantizationConfig::from_json(json).expect("parse");

        let (bits, _) = config.config_for_tensor("model.layers.0.self_attn.v_proj.weight");
        assert_eq!(bits, 6);
    }

    #[test]
    fn test_config_for_tensor_strips_language_model_prefix() {
        let json = r#"{
            "bits": 4,
            "group_size": 64,
            "per_tensor": {
                "model.layers.0.self_attn.v_proj": {"bits": 6, "group_size": 64}
            }
        }"#;
        let config = QuantizationConfig::from_json(json).expect("parse");

        let (bits, _) = config.config_for_tensor("language_model.model.layers.0.self_attn.v_proj.weight");
        assert_eq!(bits, 6);
    }

    #[test]
    fn test_from_model_config_json_basic() {
        let json = r#"{
            "model_type": "gemma4",
            "quantization": {
                "bits": 4,
                "group_size": 64,
                "language_model.model.layers.0.mlp.down_proj": {"bits": 8, "group_size": 64},
                "language_model.model.layers.0.self_attn.v_proj": {"bits": 6, "group_size": 64}
            }
        }"#;

        let config = QuantizationConfig::from_model_config_json(json).expect("parse");
        assert_eq!(config.bits, 4);
        assert_eq!(config.group_size, 64);
        assert_eq!(config.per_tensor.len(), 2);

        let (bits, _) = config.config_for_tensor("language_model.model.layers.0.mlp.down_proj.weight");
        assert_eq!(bits, 8);

        let (bits, _) = config.config_for_tensor("language_model.model.layers.0.self_attn.v_proj.weight");
        assert_eq!(bits, 6);

        // Unlisted tensor falls back to the defaults.
        let (bits, _) = config.config_for_tensor("language_model.model.layers.5.mlp.gate_proj.weight");
        assert_eq!(bits, 4);
    }

    #[test]
    fn test_from_model_config_json_no_quantization_key() {
        let json = r#"{"model_type": "gemma4"}"#;
        let result = QuantizationConfig::from_model_config_json(json);
        assert!(result.is_err());
    }

    // --- dtype conversion & buffer upload -------------------------------

    #[test]
    fn test_safetensors_dtype_conversion() {
        assert_eq!(safetensors_dtype_to_dtype(StDtype::F32).unwrap(), DType::F32);
        assert_eq!(safetensors_dtype_to_dtype(StDtype::F16).unwrap(), DType::F16);
        assert_eq!(safetensors_dtype_to_dtype(StDtype::BF16).unwrap(), DType::BF16);
        assert_eq!(safetensors_dtype_to_dtype(StDtype::U8).unwrap(), DType::U8);
        assert_eq!(safetensors_dtype_to_dtype(StDtype::U16).unwrap(), DType::U16);
        assert_eq!(safetensors_dtype_to_dtype(StDtype::U32).unwrap(), DType::U32);
        assert_eq!(safetensors_dtype_to_dtype(StDtype::I32).unwrap(), DType::I32);
    }

    #[test]
    fn test_safetensors_dtype_unsupported() {
        let result = safetensors_dtype_to_dtype(StDtype::BOOL);
        assert!(result.is_err());
        match result {
            Err(MlxError::UnsupportedDtype(_)) => {}
            other => panic!("Expected UnsupportedDtype, got {:?}", other),
        }
    }

    #[test]
    fn test_safetensors_to_metal_buffer_roundtrip() {
        let device = MlxDevice::new().expect("device");

        let values: [f32; 4] = [1.0, 2.5, -3.0, 4.125];
        let bytes: &[u8] = bytemuck::cast_slice(&values);

        let buf = safetensors_to_metal_buffer(&device, bytes, DType::F32, vec![4])
            .expect("to_metal_buffer");

        assert_eq!(buf.byte_len(), 16);
        assert_eq!(buf.dtype(), DType::F32);
        assert_eq!(buf.shape(), &[4]);

        // Shared storage: values written through the CPU pointer are
        // readable back verbatim.
        let read_back: &[f32] = buf.as_slice().expect("as_slice");
        assert_eq!(read_back.len(), 4);
        assert_eq!(read_back[0], 1.0);
        assert_eq!(read_back[1], 2.5);
        assert_eq!(read_back[2], -3.0);
        assert_eq!(read_back[3], 4.125);
    }

    #[test]
    fn test_safetensors_to_metal_buffer_empty_error() {
        let device = MlxDevice::new().expect("device");
        let result = safetensors_to_metal_buffer(&device, &[], DType::F32, vec![0]);
        assert!(result.is_err());
        match result {
            Err(MlxError::InvalidArgument(msg)) => {
                assert!(msg.contains("empty"), "msg: {msg}");
            }
            other => panic!("Expected InvalidArgument, got {:?}", other),
        }
    }

    #[test]
    fn test_safetensors_to_metal_buffer_u8_data() {
        let device = MlxDevice::new().expect("device");
        let data: Vec<u8> = (0..128).collect();

        let buf = safetensors_to_metal_buffer(&device, &data, DType::U8, vec![128])
            .expect("to_metal_buffer");

        assert_eq!(buf.byte_len(), 128);
        let read_back: &[u8] = buf.as_slice().expect("as_slice");
        for (i, &val) in read_back.iter().enumerate() {
            assert_eq!(val, i as u8, "mismatch at index {i}");
        }
    }

    // --- SafetensorsFile -------------------------------------------------

    /// Writes a two-tensor safetensors file into `dir` and returns its path.
    fn create_test_safetensors(dir: &Path) -> std::path::PathBuf {
        let path = dir.join("test_model.safetensors");

        let tensor_a_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let tensor_a_bytes: &[u8] = bytemuck::cast_slice(&tensor_a_data);
        let tensor_b_data: Vec<f32> = vec![10.0, 20.0, 30.0];
        let tensor_b_bytes: &[u8] = bytemuck::cast_slice(&tensor_b_data);

        let tensors = vec![
            (
                "layer.weight",
                TensorView::new(StDtype::F32, vec![2, 3], tensor_a_bytes).unwrap(),
            ),
            (
                "layer.bias",
                TensorView::new(StDtype::F32, vec![3], tensor_b_bytes).unwrap(),
            ),
        ];

        let serialized = safetensors::tensor::serialize(tensors, None).unwrap();
        fs::write(&path, &serialized).unwrap();

        path
    }

    #[test]
    fn test_safetensors_file_open_and_list() {
        let tmp = tempdir();
        let st_path = create_test_safetensors(&tmp);

        let sf = SafetensorsFile::open(&st_path).expect("open");
        let names = sf.tensor_names().expect("names");

        assert_eq!(names.len(), 2);
        assert!(names.contains(&"layer.weight".to_string()));
        assert!(names.contains(&"layer.bias".to_string()));
    }

    #[test]
    fn test_safetensors_file_load_tensor() {
        let device = MlxDevice::new().expect("device");
        let tmp = tempdir();
        let st_path = create_test_safetensors(&tmp);

        let sf = SafetensorsFile::open(&st_path).expect("open");
        let (dtype, shape, buf) = sf.load_tensor("layer.weight", &device).expect("load");

        assert_eq!(dtype, DType::F32);
        assert_eq!(shape, vec![2, 3]);
        // 2 * 3 elements * 4 bytes per f32.
        assert_eq!(buf.byte_len(), 24);
        let data: &[f32] = buf.as_slice().expect("as_slice");
        assert_eq!(data, &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn test_safetensors_file_load_all() {
        let device = MlxDevice::new().expect("device");
        let tmp = tempdir();
        let st_path = create_test_safetensors(&tmp);

        let sf = SafetensorsFile::open(&st_path).expect("open");
        let all = sf.load_all_tensors(&device).expect("load_all");

        assert_eq!(all.len(), 2);

        let (dtype, shape, buf) = all.get("layer.bias").expect("bias");
        assert_eq!(*dtype, DType::F32);
        assert_eq!(*shape, vec![3]);
        let data: &[f32] = buf.as_slice().expect("as_slice");
        assert_eq!(data, &[10.0, 20.0, 30.0]);
    }

    #[test]
    fn test_safetensors_file_tensor_not_found() {
        let tmp = tempdir();
        let st_path = create_test_safetensors(&tmp);
        let device = MlxDevice::new().expect("device");

        let sf = SafetensorsFile::open(&st_path).expect("open");
        let result = sf.load_tensor("nonexistent", &device);
        assert!(result.is_err());
        match result {
            Err(MlxError::SafetensorsError(msg)) => {
                assert!(msg.contains("nonexistent"), "msg: {msg}");
            }
            other => panic!("Expected SafetensorsError, got {:?}", other),
        }
    }

    #[test]
    fn test_safetensors_file_open_missing() {
        let result = SafetensorsFile::open(Path::new("/tmp/does_not_exist_8f3a2b1c.safetensors"));
        assert!(result.is_err());
        match result {
            Err(MlxError::IoError(_)) => {}
            other => panic!("Expected IoError, got {:?}", other),
        }
    }

    // --- load_quantized_weights ------------------------------------------

    /// Populates `dir` with a quantization config and one safetensors
    /// shard containing a single quantized weight triple.
    fn create_test_quant_dir(dir: &Path) {
        let config_json = r#"{
            "bits": 4,
            "group_size": 64,
            "per_tensor": {
                "proj.weight": {"bits": 4, "group_size": 64}
            }
        }"#;
        fs::write(dir.join("quantization_config.json"), config_json).unwrap();

        // 8 packed u32 words of dummy data (32 bytes total).
        let weight_data: Vec<u32> = vec![0xAAAA_BBBB; 8];
        let weight_bytes: &[u8] = bytemuck::cast_slice(&weight_data);

        // Two f16 scales (0x3C00 is 1.0 in IEEE half precision).
        let scales_data: Vec<u16> = vec![0x3C00, 0x3C00];
        let scales_bytes: &[u8] = bytemuck::cast_slice(&scales_data);

        // Two f16 biases (0x0000 is 0.0).
        let biases_data: Vec<u16> = vec![0x0000, 0x0000];
        let biases_bytes: &[u8] = bytemuck::cast_slice(&biases_data);

        let tensors = vec![
            (
                "proj.weight",
                TensorView::new(StDtype::U32, vec![2, 4], weight_bytes).unwrap(),
            ),
            (
                "proj.scales",
                TensorView::new(StDtype::F16, vec![2, 1], scales_bytes).unwrap(),
            ),
            (
                "proj.biases",
                TensorView::new(StDtype::F16, vec![2, 1], biases_bytes).unwrap(),
            ),
        ];

        let serialized = safetensors::tensor::serialize(tensors, None).unwrap();
        fs::write(dir.join("model.safetensors"), &serialized).unwrap();
    }

    #[test]
    fn test_load_quantized_weights_integration() {
        let device = MlxDevice::new().expect("device");
        let tmp = tempdir();
        create_test_quant_dir(&tmp);

        let weights = load_quantized_weights(&tmp, &device).expect("load");

        assert_eq!(weights.len(), 1);
        let qw = &weights[0];
        assert_eq!(qw.tensor_name(), "proj.weight");
        assert_eq!(qw.bits(), 4);
        assert_eq!(qw.group_size(), 64);
        // 8 u32 words -> 32 bytes of packed data.
        assert_eq!(qw.packed_data().byte_len(), 32);
        // 2 f16 scales -> 4 bytes.
        assert_eq!(qw.scales().byte_len(), 4);
        assert!(qw.biases().is_some());
    }

    #[test]
    fn test_load_quantized_weights_no_safetensors() {
        let tmp = tempdir();

        // Config present but no shards: must fail with a clear IoError.
        fs::write(tmp.join("quantization_config.json"), "{}").unwrap();

        let device = MlxDevice::new().expect("device");
        let result = load_quantized_weights(&tmp, &device);
        assert!(result.is_err());
        match result {
            Err(MlxError::IoError(msg)) => {
                assert!(msg.contains("No .safetensors files"), "msg: {msg}");
            }
            other => panic!("Expected IoError, got {:?}", other),
        }
    }

    #[test]
    fn test_load_quantized_weights_missing_config() {
        let tmp = tempdir();
        // A shard exists, but quantization_config.json is absent.
        let data: Vec<u8> = vec![0; 16];
        let tensors = vec![(
            "dummy",
            TensorView::new(StDtype::U8, vec![16], &data).unwrap(),
        )];
        let serialized = safetensors::tensor::serialize(tensors, None).unwrap();
        fs::write(tmp.join("model.safetensors"), &serialized).unwrap();

        let device = MlxDevice::new().expect("device");
        let result = load_quantized_weights(&tmp, &device);
        assert!(result.is_err());
        match result {
            Err(MlxError::IoError(msg)) => {
                assert!(msg.contains("quantization_config"), "msg: {msg}");
            }
            other => panic!("Expected IoError for missing config, got {:?}", other),
        }
    }

    /// Creates a unique scratch directory under the system temp dir.
    /// Directories are intentionally not cleaned up (disposable CI scratch).
    fn tempdir() -> std::path::PathBuf {
        let mut path = std::env::temp_dir();
        path.push(format!("mlx_native_test_{}", std::process::id()));
        path.push(format!("{}", std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_nanos()));
        fs::create_dir_all(&path).expect("create temp dir");
        path
    }
}