1use std::collections::HashMap;
20use std::io::{BufReader, Read, Seek, SeekFrom};
21use std::path::Path;
22use std::sync::Mutex;
23
24use half::f16;
25
26use crate::ops::quantized_matmul_ggml::GgmlType;
27use crate::{DType, MlxBuffer, MlxBufferPool, MlxDevice, MlxError, Result};
28
/// GGUF file magic: the ASCII bytes `GGUF` interpreted as a little-endian `u32`.
const GGUF_MAGIC: u32 = 0x4655_4747;

/// The only GGUF container version accepted by `GgufFile::open`.
const GGUF_VERSION: u32 = 3;

/// Tensor-data alignment used when the file carries no usable `general.alignment` entry.
const GGUF_DEFAULT_ALIGNMENT: u64 = 32;

/// Metadata key whose value, when present and readable as `u32`, overrides the default alignment.
const GGUF_ALIGNMENT_KEY: &str = "general.alignment";
44
// Metadata value-type tags. `read_metadata_value` matches the `u32` tag read
// from the file against these constants to decide how to decode the payload.
const GGUF_TYPE_UINT8: u32 = 0;
const GGUF_TYPE_INT8: u32 = 1;
const GGUF_TYPE_UINT16: u32 = 2;
const GGUF_TYPE_INT16: u32 = 3;
const GGUF_TYPE_UINT32: u32 = 4;
const GGUF_TYPE_INT32: u32 = 5;
const GGUF_TYPE_FLOAT32: u32 = 6;
const GGUF_TYPE_BOOL: u32 = 7;
const GGUF_TYPE_STRING: u32 = 8;
const GGUF_TYPE_ARRAY: u32 = 9;
const GGUF_TYPE_UINT64: u32 = 10;
const GGUF_TYPE_INT64: u32 = 11;
const GGUF_TYPE_FLOAT64: u32 = 12;
62
// Tensor element-type ids as stored in the tensor directory. Only the ids
// listed here are mapped to a `GgmlType` by `ggml_type_from_u32`; any other
// id is rejected as unsupported.
const GGML_TYPE_F32: u32 = 0;
const GGML_TYPE_F16: u32 = 1;
const GGML_TYPE_Q4_0: u32 = 2;
const GGML_TYPE_Q5_1: u32 = 7;
const GGML_TYPE_Q8_0: u32 = 8;
const GGML_TYPE_Q4_K: u32 = 12;
const GGML_TYPE_Q5_K: u32 = 13;
const GGML_TYPE_Q6_K: u32 = 14;
const GGML_TYPE_I16: u32 = 17;
const GGML_TYPE_IQ4_NL: u32 = 20;
77
/// Lookup table for IQ4_NL: `dequantize_iq4_nl` uses each 4-bit quant as an
/// index into this table and multiplies the selected entry by the block scale.
const KVALUES_IQ4_NL: [i8; 16] = [
    -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
];
84
/// A single decoded GGUF metadata value.
///
/// There is one variant per `GGUF_TYPE_*` tag understood by the parser;
/// `Array` holds its elements as nested `MetadataValue`s.
#[derive(Debug, Clone)]
pub enum MetadataValue {
    Uint8(u8),
    Int8(i8),
    Uint16(u16),
    Int16(i16),
    Uint32(u32),
    Int32(i32),
    Float32(f32),
    Bool(bool),
    String(String),
    Array(Vec<MetadataValue>),
    Uint64(u64),
    Int64(i64),
    Float64(f64),
}

impl MetadataValue {
    /// Borrows the text of a `String` value; every other variant yields `None`.
    pub fn as_str(&self) -> Option<&str> {
        if let MetadataValue::String(text) = self {
            Some(text.as_str())
        } else {
            None
        }
    }

    /// Returns the value as `u32` when it is a `Uint8`, `Uint16`, `Uint32`,
    /// or a non-negative `Int32`. All other variants (including the 64-bit
    /// integers) yield `None`.
    pub fn as_u32(&self) -> Option<u32> {
        match *self {
            MetadataValue::Uint8(small) => Some(u32::from(small)),
            MetadataValue::Uint16(medium) => Some(u32::from(medium)),
            MetadataValue::Uint32(exact) => Some(exact),
            MetadataValue::Int32(signed) => {
                if signed < 0 {
                    None
                } else {
                    Some(signed as u32)
                }
            }
            _ => None,
        }
    }

    /// Returns the value as `f32` for `Float32`, or a narrowed `Float64`;
    /// `None` for every non-float variant.
    pub fn as_f32(&self) -> Option<f32> {
        match *self {
            MetadataValue::Float32(single) => Some(single),
            MetadataValue::Float64(double) => Some(double as f32),
            _ => None,
        }
    }
}
136
/// Directory entry describing one tensor stored in a GGUF file.
#[derive(Debug, Clone)]
pub struct TensorInfo {
    /// Tensor name as read from the file.
    pub name: String,
    /// Dimensions, stored in the reverse of the order in which the file lists them.
    pub shape: Vec<usize>,
    /// Element/quantization format of the stored payload.
    pub ggml_type: GgmlType,
    /// Payload offset relative to the start of the tensor-data section
    /// (the loader adds the section's absolute offset before seeking).
    pub offset: u64,
    /// Payload size in bytes, derived from `shape` and `ggml_type`.
    pub byte_len: usize,
}
151
/// An opened GGUF file: parsed metadata and tensor directory plus a reader
/// used to fetch tensor payloads on demand.
pub struct GgufFile {
    /// Metadata key/value pairs parsed from the header.
    metadata: HashMap<String, MetadataValue>,
    /// Tensor directory, keyed by tensor name.
    tensors: HashMap<String, TensorInfo>,
    /// Absolute file offset of the tensor-data section (end of the directory,
    /// rounded up to the file's alignment).
    tensor_data_offset: u64,
    /// Shared reader; the mutex lets `&self` methods seek and read.
    reader: Mutex<BufReader<std::fs::File>>,
}
164
165fn read_u8<R: Read>(r: &mut R) -> Result<u8> {
171 let mut buf = [0u8; 1];
172 r.read_exact(&mut buf)
173 .map_err(|e| MlxError::GgufParseError(format!("read u8: {e}")))?;
174 Ok(buf[0])
175}
176
177fn read_i8<R: Read>(r: &mut R) -> Result<i8> {
179 Ok(read_u8(r)? as i8)
180}
181
182fn read_u16<R: Read>(r: &mut R) -> Result<u16> {
184 let mut buf = [0u8; 2];
185 r.read_exact(&mut buf)
186 .map_err(|e| MlxError::GgufParseError(format!("read u16: {e}")))?;
187 Ok(u16::from_le_bytes(buf))
188}
189
190fn read_i16<R: Read>(r: &mut R) -> Result<i16> {
192 let mut buf = [0u8; 2];
193 r.read_exact(&mut buf)
194 .map_err(|e| MlxError::GgufParseError(format!("read i16: {e}")))?;
195 Ok(i16::from_le_bytes(buf))
196}
197
198fn read_u32<R: Read>(r: &mut R) -> Result<u32> {
200 let mut buf = [0u8; 4];
201 r.read_exact(&mut buf)
202 .map_err(|e| MlxError::GgufParseError(format!("read u32: {e}")))?;
203 Ok(u32::from_le_bytes(buf))
204}
205
206fn read_i32<R: Read>(r: &mut R) -> Result<i32> {
208 let mut buf = [0u8; 4];
209 r.read_exact(&mut buf)
210 .map_err(|e| MlxError::GgufParseError(format!("read i32: {e}")))?;
211 Ok(i32::from_le_bytes(buf))
212}
213
214fn read_u64<R: Read>(r: &mut R) -> Result<u64> {
216 let mut buf = [0u8; 8];
217 r.read_exact(&mut buf)
218 .map_err(|e| MlxError::GgufParseError(format!("read u64: {e}")))?;
219 Ok(u64::from_le_bytes(buf))
220}
221
222fn read_i64<R: Read>(r: &mut R) -> Result<i64> {
224 let mut buf = [0u8; 8];
225 r.read_exact(&mut buf)
226 .map_err(|e| MlxError::GgufParseError(format!("read i64: {e}")))?;
227 Ok(i64::from_le_bytes(buf))
228}
229
230fn read_f32<R: Read>(r: &mut R) -> Result<f32> {
232 let mut buf = [0u8; 4];
233 r.read_exact(&mut buf)
234 .map_err(|e| MlxError::GgufParseError(format!("read f32: {e}")))?;
235 Ok(f32::from_le_bytes(buf))
236}
237
238fn read_f64<R: Read>(r: &mut R) -> Result<f64> {
240 let mut buf = [0u8; 8];
241 r.read_exact(&mut buf)
242 .map_err(|e| MlxError::GgufParseError(format!("read f64: {e}")))?;
243 Ok(f64::from_le_bytes(buf))
244}
245
246fn read_gguf_string<R: Read>(r: &mut R) -> Result<String> {
249 let len = read_u64(r)? as usize;
250 if len > 256 * 1024 * 1024 {
251 return Err(MlxError::GgufParseError(format!(
252 "string length {len} exceeds 256 MiB safety limit"
253 )));
254 }
255 let mut buf = vec![0u8; len];
256 r.read_exact(&mut buf)
257 .map_err(|e| MlxError::GgufParseError(format!("read string bytes: {e}")))?;
258 String::from_utf8(buf)
259 .map_err(|e| MlxError::GgufParseError(format!("invalid UTF-8 in string: {e}")))
260}
261
262fn read_metadata_value<R: Read>(r: &mut R, value_type: u32) -> Result<MetadataValue> {
268 match value_type {
269 GGUF_TYPE_UINT8 => Ok(MetadataValue::Uint8(read_u8(r)?)),
270 GGUF_TYPE_INT8 => Ok(MetadataValue::Int8(read_i8(r)?)),
271 GGUF_TYPE_UINT16 => Ok(MetadataValue::Uint16(read_u16(r)?)),
272 GGUF_TYPE_INT16 => Ok(MetadataValue::Int16(read_i16(r)?)),
273 GGUF_TYPE_UINT32 => Ok(MetadataValue::Uint32(read_u32(r)?)),
274 GGUF_TYPE_INT32 => Ok(MetadataValue::Int32(read_i32(r)?)),
275 GGUF_TYPE_FLOAT32 => Ok(MetadataValue::Float32(read_f32(r)?)),
276 GGUF_TYPE_BOOL => {
277 let byte = read_u8(r)?;
278 Ok(MetadataValue::Bool(byte != 0))
279 }
280 GGUF_TYPE_STRING => Ok(MetadataValue::String(read_gguf_string(r)?)),
281 GGUF_TYPE_ARRAY => {
282 let elem_type = read_u32(r)?;
283 let count = read_u64(r)? as usize;
284 if count > 64 * 1024 * 1024 {
285 return Err(MlxError::GgufParseError(format!(
286 "array count {count} exceeds 64M element safety limit"
287 )));
288 }
289 let mut elems = Vec::with_capacity(count);
290 for _ in 0..count {
291 elems.push(read_metadata_value(r, elem_type)?);
292 }
293 Ok(MetadataValue::Array(elems))
294 }
295 GGUF_TYPE_UINT64 => Ok(MetadataValue::Uint64(read_u64(r)?)),
296 GGUF_TYPE_INT64 => Ok(MetadataValue::Int64(read_i64(r)?)),
297 GGUF_TYPE_FLOAT64 => Ok(MetadataValue::Float64(read_f64(r)?)),
298 other => Err(MlxError::GgufParseError(format!(
299 "unknown metadata value type {other}"
300 ))),
301 }
302}
303
304fn ggml_type_from_u32(id: u32) -> Result<GgmlType> {
310 match id {
311 GGML_TYPE_F32 => Ok(GgmlType::F32),
312 GGML_TYPE_F16 => Ok(GgmlType::F16),
313 GGML_TYPE_Q4_0 => Ok(GgmlType::Q4_0),
314 GGML_TYPE_Q5_1 => Ok(GgmlType::Q5_1),
315 GGML_TYPE_Q8_0 => Ok(GgmlType::Q8_0),
316 GGML_TYPE_Q4_K => Ok(GgmlType::Q4_K),
317 GGML_TYPE_Q5_K => Ok(GgmlType::Q5_K),
318 GGML_TYPE_Q6_K => Ok(GgmlType::Q6_K),
319 GGML_TYPE_I16 => Ok(GgmlType::I16),
320 GGML_TYPE_IQ4_NL => Ok(GgmlType::IQ4_NL),
321 other => Err(MlxError::GgufParseError(format!(
322 "unsupported GGML type ID {other}"
323 ))),
324 }
325}
326
327fn compute_byte_len(shape: &[usize], ggml_type: GgmlType) -> Result<usize> {
332 let total_elements: usize = shape.iter().product();
333 if total_elements == 0 {
334 return Ok(0);
335 }
336
337 let elems_per_block = ggml_type.block_values() as usize;
338 let bytes_per_block = ggml_type.block_bytes() as usize;
339
340 if total_elements % elems_per_block != 0 {
341 return Err(MlxError::GgufParseError(format!(
342 "total elements {total_elements} not divisible by block size {elems_per_block} \
343 for type {:?}",
344 ggml_type
345 )));
346 }
347
348 Ok((total_elements / elems_per_block) * bytes_per_block)
349}
350
351#[inline]
357fn f16_from_le_bytes(bytes: [u8; 2]) -> f32 {
358 f16::from_le_bytes(bytes).to_f32()
359}
360
361fn dequantize_q4_0(data: &[u8], output: &mut [f32]) -> Result<()> {
367 const BLOCK_BYTES: usize = 18;
368 const BLOCK_ELEMS: usize = 32;
369
370 if data.len() % BLOCK_BYTES != 0 {
371 return Err(MlxError::GgufParseError(format!(
372 "Q4_0 data length {} not divisible by block size {BLOCK_BYTES}",
373 data.len()
374 )));
375 }
376
377 let num_blocks = data.len() / BLOCK_BYTES;
378 if output.len() < num_blocks * BLOCK_ELEMS {
379 return Err(MlxError::GgufParseError(
380 "Q4_0 output buffer too small".into(),
381 ));
382 }
383
384 for i in 0..num_blocks {
385 let block = &data[i * BLOCK_BYTES..(i + 1) * BLOCK_BYTES];
386 let d = f16_from_le_bytes([block[0], block[1]]);
387 let qs = &block[2..18]; let out = &mut output[i * BLOCK_ELEMS..(i + 1) * BLOCK_ELEMS];
390
391 for j in 0..16 {
392 let x0 = (qs[j] & 0x0F) as i16 - 8;
393 let x1 = (qs[j] >> 4) as i16 - 8;
394 out[j] = x0 as f32 * d;
395 out[j + 16] = x1 as f32 * d;
396 }
397 }
398 Ok(())
399}
400
401fn dequantize_q8_0(data: &[u8], output: &mut [f32]) -> Result<()> {
407 const BLOCK_BYTES: usize = 34;
408 const BLOCK_ELEMS: usize = 32;
409
410 if data.len() % BLOCK_BYTES != 0 {
411 return Err(MlxError::GgufParseError(format!(
412 "Q8_0 data length {} not divisible by block size {BLOCK_BYTES}",
413 data.len()
414 )));
415 }
416
417 let num_blocks = data.len() / BLOCK_BYTES;
418 if output.len() < num_blocks * BLOCK_ELEMS {
419 return Err(MlxError::GgufParseError(
420 "Q8_0 output buffer too small".into(),
421 ));
422 }
423
424 for i in 0..num_blocks {
425 let block = &data[i * BLOCK_BYTES..(i + 1) * BLOCK_BYTES];
426 let d = f16_from_le_bytes([block[0], block[1]]);
427 let qs = &block[2..34]; let out = &mut output[i * BLOCK_ELEMS..(i + 1) * BLOCK_ELEMS];
430
431 for j in 0..32 {
432 out[j] = (qs[j] as i8) as f32 * d;
433 }
434 }
435 Ok(())
436}
437
/// Unpacks the 6-bit `(scale, min)` pair for sub-block `j` from a packed
/// K-quant scales array.
///
/// For `j < 4` both values are the low 6 bits of `scales[j]` and
/// `scales[j + 4]`. For `j >= 4` the low 4 bits come from the two nibbles of
/// `scales[j + 4]`, and the top 2 bits come from the top 2 bits of
/// `scales[j - 4]` (scale) and `scales[j]` (min).
///
/// Panics if `scales` is too short for the indices used.
#[inline]
fn get_scale_min_k4(j: usize, scales: &[u8]) -> (u8, u8) {
    if j >= 4 {
        let packed = scales[j + 4];
        let scale_top = (scales[j - 4] >> 6) << 4;
        let min_top = (scales[j] >> 6) << 4;
        return ((packed & 0xF) | scale_top, (packed >> 4) | min_top);
    }
    let scale = scales[j] & 63;
    let min = scales[j + 4] & 63;
    (scale, min)
}
462
463fn dequantize_q5_k(data: &[u8], output: &mut [f32]) -> Result<()> {
481 const BLOCK_BYTES: usize = 176;
482 const BLOCK_ELEMS: usize = 256;
483
484 if data.len() % BLOCK_BYTES != 0 {
485 return Err(MlxError::GgufParseError(format!(
486 "Q5_K data length {} not divisible by block size {BLOCK_BYTES}",
487 data.len()
488 )));
489 }
490
491 let num_blocks = data.len() / BLOCK_BYTES;
492 if output.len() < num_blocks * BLOCK_ELEMS {
493 return Err(MlxError::GgufParseError(
494 "Q5_K output buffer too small".into(),
495 ));
496 }
497
498 for i in 0..num_blocks {
499 let block = &data[i * BLOCK_BYTES..(i + 1) * BLOCK_BYTES];
500
501 let d = f16_from_le_bytes([block[0], block[1]]);
502 let dmin = f16_from_le_bytes([block[2], block[3]]);
503 let scales = &block[4..16]; let qh = &block[16..48]; let qs = &block[48..176]; let out = &mut output[i * BLOCK_ELEMS..(i + 1) * BLOCK_ELEMS];
508
509 let mut is = 0usize;
513 let mut u1: u8 = 1;
514 let mut u2: u8 = 2;
515 let mut ys_index = 0usize;
516 let mut ql_off = 0usize;
517
518 while ql_off < 128 {
519 let ql = &qs[ql_off..ql_off + 32];
520
521 let (sc1, m1) = get_scale_min_k4(is, scales);
522 let d1 = d * sc1 as f32;
523 let m1 = dmin * m1 as f32;
524 let (sc2, m2) = get_scale_min_k4(is + 1, scales);
525 let d2 = d * sc2 as f32;
526 let m2 = dmin * m2 as f32;
527
528 for l in 0..32 {
530 let low = (ql[l] & 0x0F) as u32;
531 let high = if (qh[l] & u1) != 0 { 16 } else { 0 };
532 let q = low + high;
533 out[ys_index] = d1 * q as f32 - m1;
534 ys_index += 1;
535 }
536 for l in 0..32 {
538 let low = (ql[l] >> 4) as u32;
539 let high = if (qh[l] & u2) != 0 { 16 } else { 0 };
540 let q = low + high;
541 out[ys_index] = d2 * q as f32 - m2;
542 ys_index += 1;
543 }
544
545 is += 2;
546 ql_off += 32;
547 u1 <<= 2;
548 u2 <<= 2;
549 }
550 }
551 Ok(())
552}
553
554fn dequantize_i16(data: &[u8], output: &mut [f32]) -> Result<()> {
563 if data.len() % 2 != 0 {
564 return Err(MlxError::GgufParseError(format!(
565 "I16 data length {} not even",
566 data.len()
567 )));
568 }
569 let num_elements = data.len() / 2;
570 if output.len() < num_elements {
571 return Err(MlxError::GgufParseError(
572 "I16 output buffer too small".into(),
573 ));
574 }
575 for i in 0..num_elements {
576 let v = i16::from_le_bytes([data[2 * i], data[2 * i + 1]]);
577 output[i] = v as f32;
578 }
579 Ok(())
580}
581
582fn dequantize_q4_k(data: &[u8], output: &mut [f32]) -> Result<()> {
594 const BLOCK_BYTES: usize = 144;
595 const BLOCK_ELEMS: usize = 256;
596
597 if data.len() % BLOCK_BYTES != 0 {
598 return Err(MlxError::GgufParseError(format!(
599 "Q4_K data length {} not divisible by block size {BLOCK_BYTES}",
600 data.len()
601 )));
602 }
603
604 let num_blocks = data.len() / BLOCK_BYTES;
605 if output.len() < num_blocks * BLOCK_ELEMS {
606 return Err(MlxError::GgufParseError(
607 "Q4_K output buffer too small".into(),
608 ));
609 }
610
611 for i in 0..num_blocks {
612 let block = &data[i * BLOCK_BYTES..(i + 1) * BLOCK_BYTES];
613
614 let d = f16_from_le_bytes([block[0], block[1]]);
615 let dmin = f16_from_le_bytes([block[2], block[3]]);
616 let scales = &block[4..16]; let qs = &block[16..144]; let out = &mut output[i * BLOCK_ELEMS..(i + 1) * BLOCK_ELEMS];
620
621 let mut is = 0usize;
625 let mut ys_index = 0usize;
626
627 let mut j = 0usize;
630 while j < 128 {
631 let q = &qs[j..j + 32];
632 let (sc1, m1) = get_scale_min_k4(is, scales);
633 let d1 = d * sc1 as f32;
634 let min1 = dmin * m1 as f32;
635 let (sc2, m2) = get_scale_min_k4(is + 1, scales);
636 let d2 = d * sc2 as f32;
637 let min2 = dmin * m2 as f32;
638
639 for byte in q.iter() {
641 out[ys_index] = d1 * (*byte & 0xF) as f32 - min1;
642 ys_index += 1;
643 }
644 for byte in q.iter() {
646 out[ys_index] = d2 * (*byte >> 4) as f32 - min2;
647 ys_index += 1;
648 }
649
650 is += 2;
651 j += 32;
652 }
653 }
654 Ok(())
655}
656
657fn dequantize_q6_k(data: &[u8], output: &mut [f32]) -> Result<()> {
668 const BLOCK_BYTES: usize = 210;
669 const BLOCK_ELEMS: usize = 256;
670
671 if data.len() % BLOCK_BYTES != 0 {
672 return Err(MlxError::GgufParseError(format!(
673 "Q6_K data length {} not divisible by block size {BLOCK_BYTES}",
674 data.len()
675 )));
676 }
677
678 let num_blocks = data.len() / BLOCK_BYTES;
679 if output.len() < num_blocks * BLOCK_ELEMS {
680 return Err(MlxError::GgufParseError(
681 "Q6_K output buffer too small".into(),
682 ));
683 }
684
685 for i in 0..num_blocks {
686 let block = &data[i * BLOCK_BYTES..(i + 1) * BLOCK_BYTES];
687
688 let ql = &block[0..128];
689 let qh = &block[128..192];
690 let sc = &block[192..208]; let d = f16_from_le_bytes([block[208], block[209]]);
692
693 let out = &mut output[i * BLOCK_ELEMS..(i + 1) * BLOCK_ELEMS];
694
695 for idx in 0..2 {
697 let ql_base = &ql[64 * idx..];
698 let qh_base = &qh[32 * idx..];
699 let sc_base = &sc[8 * idx..];
700 let out_base = &mut out[128 * idx..];
701
702 for l in 0..32 {
703 let is = l / 16; let q1 = ((ql_base[l] & 0xF) | ((qh_base[l] & 3) << 4)) as i8 - 32_i8;
706 let q2 = ((ql_base[l + 32] & 0xF) | (((qh_base[l] >> 2) & 3) << 4)) as i8
707 - 32_i8;
708 let q3 = ((ql_base[l] >> 4) | (((qh_base[l] >> 4) & 3) << 4)) as i8 - 32_i8;
709 let q4 = ((ql_base[l + 32] >> 4) | (((qh_base[l] >> 6) & 3) << 4)) as i8
710 - 32_i8;
711
712 out_base[l] = d * sc_base[is] as i8 as f32 * q1 as f32;
713 out_base[l + 32] = d * sc_base[is + 2] as i8 as f32 * q2 as f32;
714 out_base[l + 64] = d * sc_base[is + 4] as i8 as f32 * q3 as f32;
715 out_base[l + 96] = d * sc_base[is + 6] as i8 as f32 * q4 as f32;
716 }
717 }
718 }
719 Ok(())
720}
721
722fn dequantize_f16(data: &[u8], output: &mut [f32]) -> Result<()> {
724 if data.len() % 2 != 0 {
725 return Err(MlxError::GgufParseError(
726 "F16 data length not even".into(),
727 ));
728 }
729 let count = data.len() / 2;
730 if output.len() < count {
731 return Err(MlxError::GgufParseError(
732 "F16 output buffer too small".into(),
733 ));
734 }
735 for i in 0..count {
736 output[i] = f16_from_le_bytes([data[2 * i], data[2 * i + 1]]);
737 }
738 Ok(())
739}
740
741fn copy_f32(data: &[u8], output: &mut [f32]) -> Result<()> {
743 if data.len() % 4 != 0 {
744 return Err(MlxError::GgufParseError(
745 "F32 data length not multiple of 4".into(),
746 ));
747 }
748 let count = data.len() / 4;
749 if output.len() < count {
750 return Err(MlxError::GgufParseError(
751 "F32 output buffer too small".into(),
752 ));
753 }
754 for i in 0..count {
755 output[i] = f32::from_le_bytes([
756 data[4 * i],
757 data[4 * i + 1],
758 data[4 * i + 2],
759 data[4 * i + 3],
760 ]);
761 }
762 Ok(())
763}
764
765fn dequantize_q5_1(data: &[u8], output: &mut [f32]) -> Result<()> {
780 const BLOCK_BYTES: usize = 24;
781 const BLOCK_ELEMS: usize = 32;
782
783 if data.len() % BLOCK_BYTES != 0 {
784 return Err(MlxError::GgufParseError(format!(
785 "Q5_1 data length {} not divisible by block size {BLOCK_BYTES}",
786 data.len()
787 )));
788 }
789
790 let num_blocks = data.len() / BLOCK_BYTES;
791 if output.len() < num_blocks * BLOCK_ELEMS {
792 return Err(MlxError::GgufParseError(
793 "Q5_1 output buffer too small".into(),
794 ));
795 }
796
797 for i in 0..num_blocks {
798 let block = &data[i * BLOCK_BYTES..(i + 1) * BLOCK_BYTES];
799
800 let d = f16_from_le_bytes([block[0], block[1]]);
801 let m = f16_from_le_bytes([block[2], block[3]]);
802 let qh = u32::from_le_bytes([block[4], block[5], block[6], block[7]]);
803 let qs = &block[8..24]; let out = &mut output[i * BLOCK_ELEMS..(i + 1) * BLOCK_ELEMS];
806
807 for j in 0..(BLOCK_ELEMS / 2) {
808 let xh_0 = (((qh >> j) << 4) & 0x10) as u8;
812 let xh_1 = ((qh >> (j + 12)) & 0x10) as u8;
813 let x0 = ((qs[j] & 0x0F) | xh_0) as i32;
814 let x1 = ((qs[j] >> 4) | xh_1) as i32;
815 out[j] = (x0 as f32) * d + m;
816 out[j + BLOCK_ELEMS / 2] = (x1 as f32) * d + m;
817 }
818 }
819 Ok(())
820}
821
822fn dequantize_iq4_nl(data: &[u8], output: &mut [f32]) -> Result<()> {
834 const BLOCK_BYTES: usize = 18;
835 const BLOCK_ELEMS: usize = 32;
836
837 if data.len() % BLOCK_BYTES != 0 {
838 return Err(MlxError::GgufParseError(format!(
839 "IQ4_NL data length {} not divisible by block size {BLOCK_BYTES}",
840 data.len()
841 )));
842 }
843
844 let num_blocks = data.len() / BLOCK_BYTES;
845 if output.len() < num_blocks * BLOCK_ELEMS {
846 return Err(MlxError::GgufParseError(
847 "IQ4_NL output buffer too small".into(),
848 ));
849 }
850
851 for i in 0..num_blocks {
852 let block = &data[i * BLOCK_BYTES..(i + 1) * BLOCK_BYTES];
853
854 let d = f16_from_le_bytes([block[0], block[1]]);
855 let qs = &block[2..18];
856
857 let out = &mut output[i * BLOCK_ELEMS..(i + 1) * BLOCK_ELEMS];
858
859 for j in 0..(BLOCK_ELEMS / 2) {
860 let lo = (qs[j] & 0x0F) as usize;
861 let hi = (qs[j] >> 4) as usize;
862 out[j] = d * KVALUES_IQ4_NL[lo] as f32;
863 out[j + BLOCK_ELEMS / 2] = d * KVALUES_IQ4_NL[hi] as f32;
864 }
865 }
866 Ok(())
867}
868
/// Public, doc-hidden pass-through to the private `dequantize_q5_1`.
/// The name indicates it is intended for tests only.
#[doc(hidden)]
pub fn test_only_dequantize_q5_1(data: &[u8], output: &mut [f32]) -> Result<()> {
    dequantize_q5_1(data, output)
}
877
/// Public, doc-hidden pass-through to the private `dequantize_iq4_nl`.
/// The name indicates it is intended for tests only.
#[doc(hidden)]
pub fn test_only_dequantize_iq4_nl(data: &[u8], output: &mut [f32]) -> Result<()> {
    dequantize_iq4_nl(data, output)
}
883
/// Returns a copy of the private `KVALUES_IQ4_NL` lookup table.
/// The name indicates it is intended for tests only.
#[doc(hidden)]
pub fn test_only_kvalues_iq4_nl() -> [i8; 16] {
    KVALUES_IQ4_NL
}
890
/// Public, doc-hidden pass-through to the private `dequantize_to_f32`
/// dispatcher. The name indicates it is intended for tests only.
#[doc(hidden)]
pub fn test_only_dequantize(data: &[u8], ggml_type: GgmlType, output: &mut [f32]) -> Result<()> {
    dequantize_to_f32(data, ggml_type, output)
}
898
899fn dequantize_to_f32(data: &[u8], ggml_type: GgmlType, output: &mut [f32]) -> Result<()> {
901 match ggml_type {
902 GgmlType::F32 => copy_f32(data, output),
903 GgmlType::F16 => dequantize_f16(data, output),
904 GgmlType::Q4_0 => dequantize_q4_0(data, output),
905 GgmlType::Q8_0 => dequantize_q8_0(data, output),
906 GgmlType::Q4_K => dequantize_q4_k(data, output),
907 GgmlType::Q6_K => dequantize_q6_k(data, output),
908 GgmlType::Q5_K => dequantize_q5_k(data, output),
909 GgmlType::I16 => dequantize_i16(data, output),
910 GgmlType::Q5_1 => dequantize_q5_1(data, output),
911 GgmlType::IQ4_NL => dequantize_iq4_nl(data, output),
912 }
913}
914
915impl GgufFile {
920 pub fn open(path: &Path) -> Result<Self> {
932 let file = std::fs::File::open(path).map_err(|e| {
933 MlxError::IoError(format!("cannot open GGUF file '{}': {e}", path.display()))
934 })?;
935 let mut reader = BufReader::new(file);
936
937 let magic = read_u32(&mut reader)?;
939 if magic != GGUF_MAGIC {
940 return Err(MlxError::GgufParseError(format!(
941 "bad magic: expected 0x{GGUF_MAGIC:08X}, got 0x{magic:08X}"
942 )));
943 }
944
945 let version = read_u32(&mut reader)?;
946 if version != GGUF_VERSION {
947 return Err(MlxError::GgufParseError(format!(
948 "unsupported GGUF version {version} (only v3 is supported)"
949 )));
950 }
951
952 let tensor_count = read_u64(&mut reader)? as usize;
953 let metadata_kv_count = read_u64(&mut reader)? as usize;
954
955 if tensor_count > 100_000 {
957 return Err(MlxError::GgufParseError(format!(
958 "tensor_count {tensor_count} exceeds 100k safety limit"
959 )));
960 }
961 if metadata_kv_count > 1_000_000 {
962 return Err(MlxError::GgufParseError(format!(
963 "metadata_kv_count {metadata_kv_count} exceeds 1M safety limit"
964 )));
965 }
966
967 let mut metadata = HashMap::with_capacity(metadata_kv_count);
969 for _ in 0..metadata_kv_count {
970 let key = read_gguf_string(&mut reader)?;
971 let value_type = read_u32(&mut reader)?;
972 let value = read_metadata_value(&mut reader, value_type)?;
973 metadata.insert(key, value);
974 }
975
976 let alignment = metadata
978 .get(GGUF_ALIGNMENT_KEY)
979 .and_then(|v| v.as_u32())
980 .map(|v| v as u64)
981 .unwrap_or(GGUF_DEFAULT_ALIGNMENT);
982
983 if alignment == 0 || (alignment & (alignment - 1)) != 0 {
984 return Err(MlxError::GgufParseError(format!(
985 "alignment {alignment} is not a power of two"
986 )));
987 }
988
989 let mut tensors = HashMap::with_capacity(tensor_count);
991 for _ in 0..tensor_count {
992 let name = read_gguf_string(&mut reader)?;
993 let n_dims = read_u32(&mut reader)? as usize;
994
995 if n_dims > 8 {
996 return Err(MlxError::GgufParseError(format!(
997 "tensor '{name}' has {n_dims} dimensions (max 8)"
998 )));
999 }
1000
1001 let mut shape = Vec::with_capacity(n_dims);
1002 for _ in 0..n_dims {
1003 shape.push(read_u64(&mut reader)? as usize);
1004 }
1005 shape.reverse();
1009
1010 let ggml_type_id = read_u32(&mut reader)?;
1011 let ggml_type = ggml_type_from_u32(ggml_type_id).map_err(|e| {
1012 MlxError::GgufParseError(format!("tensor '{name}': {e}"))
1013 })?;
1014
1015 let offset = read_u64(&mut reader)?;
1016 let byte_len = compute_byte_len(&shape, ggml_type).map_err(|e| {
1017 MlxError::GgufParseError(format!("tensor '{name}': {e}"))
1018 })?;
1019
1020 tensors.insert(
1021 name.clone(),
1022 TensorInfo {
1023 name,
1024 shape,
1025 ggml_type,
1026 offset,
1027 byte_len,
1028 },
1029 );
1030 }
1031
1032 let pos = reader
1036 .stream_position()
1037 .map_err(|e| MlxError::GgufParseError(format!("stream_position: {e}")))?;
1038 let tensor_data_offset = align_offset(pos, alignment);
1039
1040 Ok(GgufFile {
1041 metadata,
1042 tensors,
1043 tensor_data_offset,
1044 reader: Mutex::new(reader),
1045 })
1046 }
1047
1048 pub fn metadata(&self, key: &str) -> Option<&MetadataValue> {
1054 self.metadata.get(key)
1055 }
1056
1057 pub fn metadata_string(&self, key: &str) -> Option<&str> {
1059 self.metadata.get(key).and_then(|v| v.as_str())
1060 }
1061
1062 pub fn metadata_u32(&self, key: &str) -> Option<u32> {
1064 self.metadata.get(key).and_then(|v| v.as_u32())
1065 }
1066
1067 pub fn metadata_f32(&self, key: &str) -> Option<f32> {
1069 self.metadata.get(key).and_then(|v| v.as_f32())
1070 }
1071
1072 pub fn tensor_names(&self) -> Vec<&str> {
1078 self.tensors.keys().map(|s| s.as_str()).collect()
1079 }
1080
1081 pub fn tensor_info(&self, name: &str) -> Option<&TensorInfo> {
1083 self.tensors.get(name)
1084 }
1085
1086 pub fn tensor_count(&self) -> usize {
1088 self.tensors.len()
1089 }
1090
1091 pub fn metadata_count(&self) -> usize {
1093 self.metadata.len()
1094 }
1095
1096 fn read_tensor_bytes(&self, info: &TensorInfo) -> Result<Vec<u8>> {
1105 let abs_offset = self.tensor_data_offset + info.offset;
1106 let mut reader = self
1107 .reader
1108 .lock()
1109 .map_err(|_| MlxError::GgufParseError("reader mutex poisoned".into()))?;
1110
1111 reader
1112 .seek(SeekFrom::Start(abs_offset))
1113 .map_err(|e| MlxError::IoError(format!("seek to tensor '{}': {e}", info.name)))?;
1114
1115 let mut buf = vec![0u8; info.byte_len];
1116 reader.read_exact(&mut buf).map_err(|e| {
1117 MlxError::IoError(format!(
1118 "read tensor '{}' ({} bytes at offset {}): {e}",
1119 info.name, info.byte_len, abs_offset
1120 ))
1121 })?;
1122
1123 Ok(buf)
1124 }
1125
1126 pub fn load_tensor(&self, name: &str, device: &MlxDevice) -> Result<MlxBuffer> {
1138 let info = self.tensors.get(name).ok_or_else(|| {
1139 MlxError::GgufParseError(format!("tensor '{name}' not found in GGUF file"))
1140 })?;
1141
1142 let data = self.read_tensor_bytes(info)?;
1143
1144 match info.ggml_type {
1145 GgmlType::F32 => {
1146 let mut buf =
1147 device.alloc_buffer(info.byte_len, DType::F32, info.shape.clone())?;
1148 {
1149 let slice: &mut [u8] = buf.as_mut_slice()?;
1150 slice.copy_from_slice(&data);
1151 }
1152 Ok(buf)
1153 }
1154 GgmlType::F16 => {
1155 let mut buf =
1156 device.alloc_buffer(info.byte_len, DType::F16, info.shape.clone())?;
1157 {
1158 let slice: &mut [u8] = buf.as_mut_slice()?;
1159 slice.copy_from_slice(&data);
1160 }
1161 Ok(buf)
1162 }
1163 GgmlType::Q4_0
1164 | GgmlType::Q8_0
1165 | GgmlType::Q4_K
1166 | GgmlType::Q5_K
1167 | GgmlType::Q6_K
1168 | GgmlType::I16
1169 | GgmlType::Q5_1
1170 | GgmlType::IQ4_NL => {
1171 let mut buf =
1184 device.alloc_buffer(info.byte_len, DType::U8, info.shape.clone())?;
1185 {
1186 let slice: &mut [u8] = buf.as_mut_slice()?;
1187 slice.copy_from_slice(&data);
1188 }
1189 Ok(buf)
1190 }
1191 }
1192 }
1193
1194 pub fn load_tensor_f32(&self, name: &str, device: &MlxDevice) -> Result<MlxBuffer> {
1205 let info = self.tensors.get(name).ok_or_else(|| {
1206 MlxError::GgufParseError(format!("tensor '{name}' not found in GGUF file"))
1207 })?;
1208
1209 let data = self.read_tensor_bytes(info)?;
1210 let total_elements: usize = info.shape.iter().product();
1211
1212 if total_elements == 0 {
1213 return Err(MlxError::GgufParseError(format!(
1214 "tensor '{name}' has zero elements"
1215 )));
1216 }
1217
1218 let f32_byte_len = total_elements * 4;
1219 let mut buf =
1220 device.alloc_buffer(f32_byte_len, DType::F32, info.shape.clone())?;
1221
1222 {
1223 let out_slice: &mut [f32] = buf.as_mut_slice()?;
1224 dequantize_to_f32(&data, info.ggml_type, out_slice)?;
1225 }
1226
1227 Ok(buf)
1228 }
1229
1230 pub fn load_tensor_into_pool(
1266 &self,
1267 name: &str,
1268 device: &MlxDevice,
1269 pool: &mut MlxBufferPool,
1270 ) -> Result<MlxBuffer> {
1271 let buf = self.load_tensor(name, device)?;
1272 pool.register_existing(device, &buf)?;
1273 Ok(buf)
1274 }
1275}
1276
/// Rounds `offset` up to the next multiple of `alignment`.
///
/// `alignment` must be a non-zero power of two; the caller is responsible
/// for validating that.
fn align_offset(offset: u64, alignment: u64) -> u64 {
    let low_bits = alignment - 1;
    let bumped = offset + low_bits;
    // Dropping the low bits of the bumped value lands on the aligned offset.
    bumped - (bumped & low_bits)
}