1use std::collections::HashMap;
20use std::io::{BufReader, Read, Seek, SeekFrom};
21use std::path::Path;
22use std::sync::Mutex;
23
24use half::f16;
25
26use crate::ops::quantized_matmul_ggml::GgmlType;
27use crate::{DType, MlxBuffer, MlxBufferPool, MlxDevice, MlxError, Result};
28
/// File magic: the ASCII bytes `GGUF` interpreted as a little-endian `u32`.
const GGUF_MAGIC: u32 = 0x4655_4747;

/// The only container version `GgufFile::open` accepts.
const GGUF_VERSION: u32 = 3;

/// Alignment used for the tensor data section when the metadata does not provide one.
const GGUF_DEFAULT_ALIGNMENT: u64 = 32;

/// Metadata key whose value, when readable as a `u32`, overrides `GGUF_DEFAULT_ALIGNMENT`.
const GGUF_ALIGNMENT_KEY: &str = "general.alignment";
44
// Value-type tags stored in front of every metadata value. `read_metadata_value`
// dispatches on these; each tag maps to the `MetadataValue` variant of the same name.
const GGUF_TYPE_UINT8: u32 = 0;
const GGUF_TYPE_INT8: u32 = 1;
const GGUF_TYPE_UINT16: u32 = 2;
const GGUF_TYPE_INT16: u32 = 3;
const GGUF_TYPE_UINT32: u32 = 4;
const GGUF_TYPE_INT32: u32 = 5;
const GGUF_TYPE_FLOAT32: u32 = 6;
const GGUF_TYPE_BOOL: u32 = 7;
const GGUF_TYPE_STRING: u32 = 8;
// An array carries its own element type tag and element count before the elements.
const GGUF_TYPE_ARRAY: u32 = 9;
const GGUF_TYPE_UINT64: u32 = 10;
const GGUF_TYPE_INT64: u32 = 11;
const GGUF_TYPE_FLOAT64: u32 = 12;
62
// Tensor element type ids recognised by `ggml_type_from_u32`. Any id not listed
// here is rejected there as unsupported.
const GGML_TYPE_F32: u32 = 0;
const GGML_TYPE_F16: u32 = 1;
const GGML_TYPE_Q4_0: u32 = 2;
const GGML_TYPE_Q8_0: u32 = 8;
const GGML_TYPE_Q4_K: u32 = 12;
const GGML_TYPE_Q5_K: u32 = 13;
const GGML_TYPE_Q6_K: u32 = 14;
const GGML_TYPE_I16: u32 = 17;
75
/// A decoded GGUF metadata value; one variant per on-disk value-type tag.
#[derive(Debug, Clone)]
pub enum MetadataValue {
    /// Unsigned 8-bit integer.
    Uint8(u8),
    /// Signed 8-bit integer.
    Int8(i8),
    /// Unsigned 16-bit integer.
    Uint16(u16),
    /// Signed 16-bit integer.
    Int16(i16),
    /// Unsigned 32-bit integer.
    Uint32(u32),
    /// Signed 32-bit integer.
    Int32(i32),
    /// 32-bit float.
    Float32(f32),
    /// Boolean; the parser treats any non-zero byte as `true`.
    Bool(bool),
    /// UTF-8 string (the parser rejects invalid UTF-8).
    String(String),
    /// Sequence of values that were all decoded with the same element type tag.
    Array(Vec<MetadataValue>),
    /// Unsigned 64-bit integer.
    Uint64(u64),
    /// Signed 64-bit integer.
    Int64(i64),
    /// 64-bit float.
    Float64(f64),
}
97
98impl MetadataValue {
99 pub fn as_str(&self) -> Option<&str> {
101 match self {
102 MetadataValue::String(s) => Some(s.as_str()),
103 _ => None,
104 }
105 }
106
107 pub fn as_u32(&self) -> Option<u32> {
109 match self {
110 MetadataValue::Uint32(v) => Some(*v),
111 MetadataValue::Uint8(v) => Some(*v as u32),
112 MetadataValue::Uint16(v) => Some(*v as u32),
113 MetadataValue::Int32(v) if *v >= 0 => Some(*v as u32),
114 _ => None,
115 }
116 }
117
118 pub fn as_f32(&self) -> Option<f32> {
120 match self {
121 MetadataValue::Float32(v) => Some(*v),
122 MetadataValue::Float64(v) => Some(*v as f32),
123 _ => None,
124 }
125 }
126}
127
/// Descriptor of one tensor as recorded in the GGUF tensor directory.
#[derive(Debug, Clone)]
pub struct TensorInfo {
    /// Tensor name; also the key under which the descriptor is stored in `GgufFile`.
    pub name: String,
    /// Dimensions, stored in the reverse of the order they appear in the file.
    pub shape: Vec<usize>,
    /// Element / quantisation format of the tensor data.
    pub ggml_type: GgmlType,
    /// Byte offset of the tensor data, relative to the start of the tensor data section.
    pub offset: u64,
    /// Size of the tensor data in bytes, derived from `shape` and `ggml_type` at parse time.
    pub byte_len: usize,
}
142
/// An opened GGUF file: parsed metadata and tensor directory plus a reader that
/// is used to fetch tensor data on demand.
pub struct GgufFile {
    /// Metadata key/value pairs parsed from the header.
    metadata: HashMap<String, MetadataValue>,
    /// Tensor descriptors keyed by tensor name.
    tensors: HashMap<String, TensorInfo>,
    /// Absolute file offset of the tensor data section (end of the tensor
    /// directory rounded up to the file's alignment).
    tensor_data_offset: u64,
    /// Reader over the underlying file; behind a `Mutex` because tensor reads
    /// seek it through `&self`.
    reader: Mutex<BufReader<std::fs::File>>,
}
155
156fn read_u8<R: Read>(r: &mut R) -> Result<u8> {
162 let mut buf = [0u8; 1];
163 r.read_exact(&mut buf)
164 .map_err(|e| MlxError::GgufParseError(format!("read u8: {e}")))?;
165 Ok(buf[0])
166}
167
168fn read_i8<R: Read>(r: &mut R) -> Result<i8> {
170 Ok(read_u8(r)? as i8)
171}
172
173fn read_u16<R: Read>(r: &mut R) -> Result<u16> {
175 let mut buf = [0u8; 2];
176 r.read_exact(&mut buf)
177 .map_err(|e| MlxError::GgufParseError(format!("read u16: {e}")))?;
178 Ok(u16::from_le_bytes(buf))
179}
180
181fn read_i16<R: Read>(r: &mut R) -> Result<i16> {
183 let mut buf = [0u8; 2];
184 r.read_exact(&mut buf)
185 .map_err(|e| MlxError::GgufParseError(format!("read i16: {e}")))?;
186 Ok(i16::from_le_bytes(buf))
187}
188
189fn read_u32<R: Read>(r: &mut R) -> Result<u32> {
191 let mut buf = [0u8; 4];
192 r.read_exact(&mut buf)
193 .map_err(|e| MlxError::GgufParseError(format!("read u32: {e}")))?;
194 Ok(u32::from_le_bytes(buf))
195}
196
197fn read_i32<R: Read>(r: &mut R) -> Result<i32> {
199 let mut buf = [0u8; 4];
200 r.read_exact(&mut buf)
201 .map_err(|e| MlxError::GgufParseError(format!("read i32: {e}")))?;
202 Ok(i32::from_le_bytes(buf))
203}
204
205fn read_u64<R: Read>(r: &mut R) -> Result<u64> {
207 let mut buf = [0u8; 8];
208 r.read_exact(&mut buf)
209 .map_err(|e| MlxError::GgufParseError(format!("read u64: {e}")))?;
210 Ok(u64::from_le_bytes(buf))
211}
212
213fn read_i64<R: Read>(r: &mut R) -> Result<i64> {
215 let mut buf = [0u8; 8];
216 r.read_exact(&mut buf)
217 .map_err(|e| MlxError::GgufParseError(format!("read i64: {e}")))?;
218 Ok(i64::from_le_bytes(buf))
219}
220
221fn read_f32<R: Read>(r: &mut R) -> Result<f32> {
223 let mut buf = [0u8; 4];
224 r.read_exact(&mut buf)
225 .map_err(|e| MlxError::GgufParseError(format!("read f32: {e}")))?;
226 Ok(f32::from_le_bytes(buf))
227}
228
229fn read_f64<R: Read>(r: &mut R) -> Result<f64> {
231 let mut buf = [0u8; 8];
232 r.read_exact(&mut buf)
233 .map_err(|e| MlxError::GgufParseError(format!("read f64: {e}")))?;
234 Ok(f64::from_le_bytes(buf))
235}
236
237fn read_gguf_string<R: Read>(r: &mut R) -> Result<String> {
240 let len = read_u64(r)? as usize;
241 if len > 256 * 1024 * 1024 {
242 return Err(MlxError::GgufParseError(format!(
243 "string length {len} exceeds 256 MiB safety limit"
244 )));
245 }
246 let mut buf = vec![0u8; len];
247 r.read_exact(&mut buf)
248 .map_err(|e| MlxError::GgufParseError(format!("read string bytes: {e}")))?;
249 String::from_utf8(buf)
250 .map_err(|e| MlxError::GgufParseError(format!("invalid UTF-8 in string: {e}")))
251}
252
253fn read_metadata_value<R: Read>(r: &mut R, value_type: u32) -> Result<MetadataValue> {
259 match value_type {
260 GGUF_TYPE_UINT8 => Ok(MetadataValue::Uint8(read_u8(r)?)),
261 GGUF_TYPE_INT8 => Ok(MetadataValue::Int8(read_i8(r)?)),
262 GGUF_TYPE_UINT16 => Ok(MetadataValue::Uint16(read_u16(r)?)),
263 GGUF_TYPE_INT16 => Ok(MetadataValue::Int16(read_i16(r)?)),
264 GGUF_TYPE_UINT32 => Ok(MetadataValue::Uint32(read_u32(r)?)),
265 GGUF_TYPE_INT32 => Ok(MetadataValue::Int32(read_i32(r)?)),
266 GGUF_TYPE_FLOAT32 => Ok(MetadataValue::Float32(read_f32(r)?)),
267 GGUF_TYPE_BOOL => {
268 let byte = read_u8(r)?;
269 Ok(MetadataValue::Bool(byte != 0))
270 }
271 GGUF_TYPE_STRING => Ok(MetadataValue::String(read_gguf_string(r)?)),
272 GGUF_TYPE_ARRAY => {
273 let elem_type = read_u32(r)?;
274 let count = read_u64(r)? as usize;
275 if count > 64 * 1024 * 1024 {
276 return Err(MlxError::GgufParseError(format!(
277 "array count {count} exceeds 64M element safety limit"
278 )));
279 }
280 let mut elems = Vec::with_capacity(count);
281 for _ in 0..count {
282 elems.push(read_metadata_value(r, elem_type)?);
283 }
284 Ok(MetadataValue::Array(elems))
285 }
286 GGUF_TYPE_UINT64 => Ok(MetadataValue::Uint64(read_u64(r)?)),
287 GGUF_TYPE_INT64 => Ok(MetadataValue::Int64(read_i64(r)?)),
288 GGUF_TYPE_FLOAT64 => Ok(MetadataValue::Float64(read_f64(r)?)),
289 other => Err(MlxError::GgufParseError(format!(
290 "unknown metadata value type {other}"
291 ))),
292 }
293}
294
295fn ggml_type_from_u32(id: u32) -> Result<GgmlType> {
301 match id {
302 GGML_TYPE_F32 => Ok(GgmlType::F32),
303 GGML_TYPE_F16 => Ok(GgmlType::F16),
304 GGML_TYPE_Q4_0 => Ok(GgmlType::Q4_0),
305 GGML_TYPE_Q8_0 => Ok(GgmlType::Q8_0),
306 GGML_TYPE_Q4_K => Ok(GgmlType::Q4_K),
307 GGML_TYPE_Q5_K => Ok(GgmlType::Q5_K),
308 GGML_TYPE_Q6_K => Ok(GgmlType::Q6_K),
309 GGML_TYPE_I16 => Ok(GgmlType::I16),
310 other => Err(MlxError::GgufParseError(format!(
311 "unsupported GGML type ID {other}"
312 ))),
313 }
314}
315
316fn compute_byte_len(shape: &[usize], ggml_type: GgmlType) -> Result<usize> {
321 let total_elements: usize = shape.iter().product();
322 if total_elements == 0 {
323 return Ok(0);
324 }
325
326 let elems_per_block = ggml_type.block_values() as usize;
327 let bytes_per_block = ggml_type.block_bytes() as usize;
328
329 if total_elements % elems_per_block != 0 {
330 return Err(MlxError::GgufParseError(format!(
331 "total elements {total_elements} not divisible by block size {elems_per_block} \
332 for type {:?}",
333 ggml_type
334 )));
335 }
336
337 Ok((total_elements / elems_per_block) * bytes_per_block)
338}
339
340#[inline]
346fn f16_from_le_bytes(bytes: [u8; 2]) -> f32 {
347 f16::from_le_bytes(bytes).to_f32()
348}
349
350fn dequantize_q4_0(data: &[u8], output: &mut [f32]) -> Result<()> {
356 const BLOCK_BYTES: usize = 18;
357 const BLOCK_ELEMS: usize = 32;
358
359 if data.len() % BLOCK_BYTES != 0 {
360 return Err(MlxError::GgufParseError(format!(
361 "Q4_0 data length {} not divisible by block size {BLOCK_BYTES}",
362 data.len()
363 )));
364 }
365
366 let num_blocks = data.len() / BLOCK_BYTES;
367 if output.len() < num_blocks * BLOCK_ELEMS {
368 return Err(MlxError::GgufParseError(
369 "Q4_0 output buffer too small".into(),
370 ));
371 }
372
373 for i in 0..num_blocks {
374 let block = &data[i * BLOCK_BYTES..(i + 1) * BLOCK_BYTES];
375 let d = f16_from_le_bytes([block[0], block[1]]);
376 let qs = &block[2..18]; let out = &mut output[i * BLOCK_ELEMS..(i + 1) * BLOCK_ELEMS];
379
380 for j in 0..16 {
381 let x0 = (qs[j] & 0x0F) as i16 - 8;
382 let x1 = (qs[j] >> 4) as i16 - 8;
383 out[j] = x0 as f32 * d;
384 out[j + 16] = x1 as f32 * d;
385 }
386 }
387 Ok(())
388}
389
390fn dequantize_q8_0(data: &[u8], output: &mut [f32]) -> Result<()> {
396 const BLOCK_BYTES: usize = 34;
397 const BLOCK_ELEMS: usize = 32;
398
399 if data.len() % BLOCK_BYTES != 0 {
400 return Err(MlxError::GgufParseError(format!(
401 "Q8_0 data length {} not divisible by block size {BLOCK_BYTES}",
402 data.len()
403 )));
404 }
405
406 let num_blocks = data.len() / BLOCK_BYTES;
407 if output.len() < num_blocks * BLOCK_ELEMS {
408 return Err(MlxError::GgufParseError(
409 "Q8_0 output buffer too small".into(),
410 ));
411 }
412
413 for i in 0..num_blocks {
414 let block = &data[i * BLOCK_BYTES..(i + 1) * BLOCK_BYTES];
415 let d = f16_from_le_bytes([block[0], block[1]]);
416 let qs = &block[2..34]; let out = &mut output[i * BLOCK_ELEMS..(i + 1) * BLOCK_ELEMS];
419
420 for j in 0..32 {
421 out[j] = (qs[j] as i8) as f32 * d;
422 }
423 }
424 Ok(())
425}
426
/// Unpacks the 6-bit `(scale, min)` pair for sub-block `j` from the 12 packed
/// scale bytes of a K-quant block.
///
/// For `j < 4` both values are the low six bits of `scales[j]` and
/// `scales[j + 4]`. For larger `j` the low four bits come from the two nibbles
/// of `scales[j + 4]` and the top two bits from bits 6-7 of `scales[j - 4]`
/// (scale) and `scales[j]` (min).
///
/// Panics if `scales` is too short for the indices above.
#[inline]
fn get_scale_min_k4(j: usize, scales: &[u8]) -> (u8, u8) {
    if j >= 4 {
        let nibbles = scales[j + 4];
        let scale_top = scales[j - 4] >> 6;
        let min_top = scales[j] >> 6;
        return (
            (nibbles & 0xF) | (scale_top << 4),
            (nibbles >> 4) | (min_top << 4),
        );
    }
    (scales[j] & 63, scales[j + 4] & 63)
}
451
452fn dequantize_q5_k(data: &[u8], output: &mut [f32]) -> Result<()> {
470 const BLOCK_BYTES: usize = 176;
471 const BLOCK_ELEMS: usize = 256;
472
473 if data.len() % BLOCK_BYTES != 0 {
474 return Err(MlxError::GgufParseError(format!(
475 "Q5_K data length {} not divisible by block size {BLOCK_BYTES}",
476 data.len()
477 )));
478 }
479
480 let num_blocks = data.len() / BLOCK_BYTES;
481 if output.len() < num_blocks * BLOCK_ELEMS {
482 return Err(MlxError::GgufParseError(
483 "Q5_K output buffer too small".into(),
484 ));
485 }
486
487 for i in 0..num_blocks {
488 let block = &data[i * BLOCK_BYTES..(i + 1) * BLOCK_BYTES];
489
490 let d = f16_from_le_bytes([block[0], block[1]]);
491 let dmin = f16_from_le_bytes([block[2], block[3]]);
492 let scales = &block[4..16]; let qh = &block[16..48]; let qs = &block[48..176]; let out = &mut output[i * BLOCK_ELEMS..(i + 1) * BLOCK_ELEMS];
497
498 let mut is = 0usize;
502 let mut u1: u8 = 1;
503 let mut u2: u8 = 2;
504 let mut ys_index = 0usize;
505 let mut ql_off = 0usize;
506
507 while ql_off < 128 {
508 let ql = &qs[ql_off..ql_off + 32];
509
510 let (sc1, m1) = get_scale_min_k4(is, scales);
511 let d1 = d * sc1 as f32;
512 let m1 = dmin * m1 as f32;
513 let (sc2, m2) = get_scale_min_k4(is + 1, scales);
514 let d2 = d * sc2 as f32;
515 let m2 = dmin * m2 as f32;
516
517 for l in 0..32 {
519 let low = (ql[l] & 0x0F) as u32;
520 let high = if (qh[l] & u1) != 0 { 16 } else { 0 };
521 let q = low + high;
522 out[ys_index] = d1 * q as f32 - m1;
523 ys_index += 1;
524 }
525 for l in 0..32 {
527 let low = (ql[l] >> 4) as u32;
528 let high = if (qh[l] & u2) != 0 { 16 } else { 0 };
529 let q = low + high;
530 out[ys_index] = d2 * q as f32 - m2;
531 ys_index += 1;
532 }
533
534 is += 2;
535 ql_off += 32;
536 u1 <<= 2;
537 u2 <<= 2;
538 }
539 }
540 Ok(())
541}
542
543fn dequantize_i16(data: &[u8], output: &mut [f32]) -> Result<()> {
552 if data.len() % 2 != 0 {
553 return Err(MlxError::GgufParseError(format!(
554 "I16 data length {} not even",
555 data.len()
556 )));
557 }
558 let num_elements = data.len() / 2;
559 if output.len() < num_elements {
560 return Err(MlxError::GgufParseError(
561 "I16 output buffer too small".into(),
562 ));
563 }
564 for i in 0..num_elements {
565 let v = i16::from_le_bytes([data[2 * i], data[2 * i + 1]]);
566 output[i] = v as f32;
567 }
568 Ok(())
569}
570
571fn dequantize_q4_k(data: &[u8], output: &mut [f32]) -> Result<()> {
583 const BLOCK_BYTES: usize = 144;
584 const BLOCK_ELEMS: usize = 256;
585
586 if data.len() % BLOCK_BYTES != 0 {
587 return Err(MlxError::GgufParseError(format!(
588 "Q4_K data length {} not divisible by block size {BLOCK_BYTES}",
589 data.len()
590 )));
591 }
592
593 let num_blocks = data.len() / BLOCK_BYTES;
594 if output.len() < num_blocks * BLOCK_ELEMS {
595 return Err(MlxError::GgufParseError(
596 "Q4_K output buffer too small".into(),
597 ));
598 }
599
600 for i in 0..num_blocks {
601 let block = &data[i * BLOCK_BYTES..(i + 1) * BLOCK_BYTES];
602
603 let d = f16_from_le_bytes([block[0], block[1]]);
604 let dmin = f16_from_le_bytes([block[2], block[3]]);
605 let scales = &block[4..16]; let qs = &block[16..144]; let out = &mut output[i * BLOCK_ELEMS..(i + 1) * BLOCK_ELEMS];
609
610 let mut is = 0usize;
614 let mut ys_index = 0usize;
615
616 let mut j = 0usize;
619 while j < 128 {
620 let q = &qs[j..j + 32];
621 let (sc1, m1) = get_scale_min_k4(is, scales);
622 let d1 = d * sc1 as f32;
623 let min1 = dmin * m1 as f32;
624 let (sc2, m2) = get_scale_min_k4(is + 1, scales);
625 let d2 = d * sc2 as f32;
626 let min2 = dmin * m2 as f32;
627
628 for byte in q.iter() {
630 out[ys_index] = d1 * (*byte & 0xF) as f32 - min1;
631 ys_index += 1;
632 }
633 for byte in q.iter() {
635 out[ys_index] = d2 * (*byte >> 4) as f32 - min2;
636 ys_index += 1;
637 }
638
639 is += 2;
640 j += 32;
641 }
642 }
643 Ok(())
644}
645
646fn dequantize_q6_k(data: &[u8], output: &mut [f32]) -> Result<()> {
657 const BLOCK_BYTES: usize = 210;
658 const BLOCK_ELEMS: usize = 256;
659
660 if data.len() % BLOCK_BYTES != 0 {
661 return Err(MlxError::GgufParseError(format!(
662 "Q6_K data length {} not divisible by block size {BLOCK_BYTES}",
663 data.len()
664 )));
665 }
666
667 let num_blocks = data.len() / BLOCK_BYTES;
668 if output.len() < num_blocks * BLOCK_ELEMS {
669 return Err(MlxError::GgufParseError(
670 "Q6_K output buffer too small".into(),
671 ));
672 }
673
674 for i in 0..num_blocks {
675 let block = &data[i * BLOCK_BYTES..(i + 1) * BLOCK_BYTES];
676
677 let ql = &block[0..128];
678 let qh = &block[128..192];
679 let sc = &block[192..208]; let d = f16_from_le_bytes([block[208], block[209]]);
681
682 let out = &mut output[i * BLOCK_ELEMS..(i + 1) * BLOCK_ELEMS];
683
684 for idx in 0..2 {
686 let ql_base = &ql[64 * idx..];
687 let qh_base = &qh[32 * idx..];
688 let sc_base = &sc[8 * idx..];
689 let out_base = &mut out[128 * idx..];
690
691 for l in 0..32 {
692 let is = l / 16; let q1 = ((ql_base[l] & 0xF) | ((qh_base[l] & 3) << 4)) as i8 - 32_i8;
695 let q2 = ((ql_base[l + 32] & 0xF) | (((qh_base[l] >> 2) & 3) << 4)) as i8
696 - 32_i8;
697 let q3 = ((ql_base[l] >> 4) | (((qh_base[l] >> 4) & 3) << 4)) as i8 - 32_i8;
698 let q4 = ((ql_base[l + 32] >> 4) | (((qh_base[l] >> 6) & 3) << 4)) as i8
699 - 32_i8;
700
701 out_base[l] = d * sc_base[is] as i8 as f32 * q1 as f32;
702 out_base[l + 32] = d * sc_base[is + 2] as i8 as f32 * q2 as f32;
703 out_base[l + 64] = d * sc_base[is + 4] as i8 as f32 * q3 as f32;
704 out_base[l + 96] = d * sc_base[is + 6] as i8 as f32 * q4 as f32;
705 }
706 }
707 }
708 Ok(())
709}
710
711fn dequantize_f16(data: &[u8], output: &mut [f32]) -> Result<()> {
713 if data.len() % 2 != 0 {
714 return Err(MlxError::GgufParseError(
715 "F16 data length not even".into(),
716 ));
717 }
718 let count = data.len() / 2;
719 if output.len() < count {
720 return Err(MlxError::GgufParseError(
721 "F16 output buffer too small".into(),
722 ));
723 }
724 for i in 0..count {
725 output[i] = f16_from_le_bytes([data[2 * i], data[2 * i + 1]]);
726 }
727 Ok(())
728}
729
730fn copy_f32(data: &[u8], output: &mut [f32]) -> Result<()> {
732 if data.len() % 4 != 0 {
733 return Err(MlxError::GgufParseError(
734 "F32 data length not multiple of 4".into(),
735 ));
736 }
737 let count = data.len() / 4;
738 if output.len() < count {
739 return Err(MlxError::GgufParseError(
740 "F32 output buffer too small".into(),
741 ));
742 }
743 for i in 0..count {
744 output[i] = f32::from_le_bytes([
745 data[4 * i],
746 data[4 * i + 1],
747 data[4 * i + 2],
748 data[4 * i + 3],
749 ]);
750 }
751 Ok(())
752}
753
754fn dequantize_to_f32(data: &[u8], ggml_type: GgmlType, output: &mut [f32]) -> Result<()> {
756 match ggml_type {
757 GgmlType::F32 => copy_f32(data, output),
758 GgmlType::F16 => dequantize_f16(data, output),
759 GgmlType::Q4_0 => dequantize_q4_0(data, output),
760 GgmlType::Q8_0 => dequantize_q8_0(data, output),
761 GgmlType::Q4_K => dequantize_q4_k(data, output),
762 GgmlType::Q6_K => dequantize_q6_k(data, output),
763 GgmlType::Q5_K => dequantize_q5_k(data, output),
764 GgmlType::I16 => dequantize_i16(data, output),
765 }
766}
767
768impl GgufFile {
773 pub fn open(path: &Path) -> Result<Self> {
785 let file = std::fs::File::open(path).map_err(|e| {
786 MlxError::IoError(format!("cannot open GGUF file '{}': {e}", path.display()))
787 })?;
788 let mut reader = BufReader::new(file);
789
790 let magic = read_u32(&mut reader)?;
792 if magic != GGUF_MAGIC {
793 return Err(MlxError::GgufParseError(format!(
794 "bad magic: expected 0x{GGUF_MAGIC:08X}, got 0x{magic:08X}"
795 )));
796 }
797
798 let version = read_u32(&mut reader)?;
799 if version != GGUF_VERSION {
800 return Err(MlxError::GgufParseError(format!(
801 "unsupported GGUF version {version} (only v3 is supported)"
802 )));
803 }
804
805 let tensor_count = read_u64(&mut reader)? as usize;
806 let metadata_kv_count = read_u64(&mut reader)? as usize;
807
808 if tensor_count > 100_000 {
810 return Err(MlxError::GgufParseError(format!(
811 "tensor_count {tensor_count} exceeds 100k safety limit"
812 )));
813 }
814 if metadata_kv_count > 1_000_000 {
815 return Err(MlxError::GgufParseError(format!(
816 "metadata_kv_count {metadata_kv_count} exceeds 1M safety limit"
817 )));
818 }
819
820 let mut metadata = HashMap::with_capacity(metadata_kv_count);
822 for _ in 0..metadata_kv_count {
823 let key = read_gguf_string(&mut reader)?;
824 let value_type = read_u32(&mut reader)?;
825 let value = read_metadata_value(&mut reader, value_type)?;
826 metadata.insert(key, value);
827 }
828
829 let alignment = metadata
831 .get(GGUF_ALIGNMENT_KEY)
832 .and_then(|v| v.as_u32())
833 .map(|v| v as u64)
834 .unwrap_or(GGUF_DEFAULT_ALIGNMENT);
835
836 if alignment == 0 || (alignment & (alignment - 1)) != 0 {
837 return Err(MlxError::GgufParseError(format!(
838 "alignment {alignment} is not a power of two"
839 )));
840 }
841
842 let mut tensors = HashMap::with_capacity(tensor_count);
844 for _ in 0..tensor_count {
845 let name = read_gguf_string(&mut reader)?;
846 let n_dims = read_u32(&mut reader)? as usize;
847
848 if n_dims > 8 {
849 return Err(MlxError::GgufParseError(format!(
850 "tensor '{name}' has {n_dims} dimensions (max 8)"
851 )));
852 }
853
854 let mut shape = Vec::with_capacity(n_dims);
855 for _ in 0..n_dims {
856 shape.push(read_u64(&mut reader)? as usize);
857 }
858 shape.reverse();
862
863 let ggml_type_id = read_u32(&mut reader)?;
864 let ggml_type = ggml_type_from_u32(ggml_type_id).map_err(|e| {
865 MlxError::GgufParseError(format!("tensor '{name}': {e}"))
866 })?;
867
868 let offset = read_u64(&mut reader)?;
869 let byte_len = compute_byte_len(&shape, ggml_type).map_err(|e| {
870 MlxError::GgufParseError(format!("tensor '{name}': {e}"))
871 })?;
872
873 tensors.insert(
874 name.clone(),
875 TensorInfo {
876 name,
877 shape,
878 ggml_type,
879 offset,
880 byte_len,
881 },
882 );
883 }
884
885 let pos = reader
889 .stream_position()
890 .map_err(|e| MlxError::GgufParseError(format!("stream_position: {e}")))?;
891 let tensor_data_offset = align_offset(pos, alignment);
892
893 Ok(GgufFile {
894 metadata,
895 tensors,
896 tensor_data_offset,
897 reader: Mutex::new(reader),
898 })
899 }
900
901 pub fn metadata(&self, key: &str) -> Option<&MetadataValue> {
907 self.metadata.get(key)
908 }
909
910 pub fn metadata_string(&self, key: &str) -> Option<&str> {
912 self.metadata.get(key).and_then(|v| v.as_str())
913 }
914
915 pub fn metadata_u32(&self, key: &str) -> Option<u32> {
917 self.metadata.get(key).and_then(|v| v.as_u32())
918 }
919
920 pub fn metadata_f32(&self, key: &str) -> Option<f32> {
922 self.metadata.get(key).and_then(|v| v.as_f32())
923 }
924
925 pub fn tensor_names(&self) -> Vec<&str> {
931 self.tensors.keys().map(|s| s.as_str()).collect()
932 }
933
934 pub fn tensor_info(&self, name: &str) -> Option<&TensorInfo> {
936 self.tensors.get(name)
937 }
938
939 pub fn tensor_count(&self) -> usize {
941 self.tensors.len()
942 }
943
944 pub fn metadata_count(&self) -> usize {
946 self.metadata.len()
947 }
948
949 fn read_tensor_bytes(&self, info: &TensorInfo) -> Result<Vec<u8>> {
958 let abs_offset = self.tensor_data_offset + info.offset;
959 let mut reader = self
960 .reader
961 .lock()
962 .map_err(|_| MlxError::GgufParseError("reader mutex poisoned".into()))?;
963
964 reader
965 .seek(SeekFrom::Start(abs_offset))
966 .map_err(|e| MlxError::IoError(format!("seek to tensor '{}': {e}", info.name)))?;
967
968 let mut buf = vec![0u8; info.byte_len];
969 reader.read_exact(&mut buf).map_err(|e| {
970 MlxError::IoError(format!(
971 "read tensor '{}' ({} bytes at offset {}): {e}",
972 info.name, info.byte_len, abs_offset
973 ))
974 })?;
975
976 Ok(buf)
977 }
978
979 pub fn load_tensor(&self, name: &str, device: &MlxDevice) -> Result<MlxBuffer> {
991 let info = self.tensors.get(name).ok_or_else(|| {
992 MlxError::GgufParseError(format!("tensor '{name}' not found in GGUF file"))
993 })?;
994
995 let data = self.read_tensor_bytes(info)?;
996
997 match info.ggml_type {
998 GgmlType::F32 => {
999 let mut buf =
1000 device.alloc_buffer(info.byte_len, DType::F32, info.shape.clone())?;
1001 {
1002 let slice: &mut [u8] = buf.as_mut_slice()?;
1003 slice.copy_from_slice(&data);
1004 }
1005 Ok(buf)
1006 }
1007 GgmlType::F16 => {
1008 let mut buf =
1009 device.alloc_buffer(info.byte_len, DType::F16, info.shape.clone())?;
1010 {
1011 let slice: &mut [u8] = buf.as_mut_slice()?;
1012 slice.copy_from_slice(&data);
1013 }
1014 Ok(buf)
1015 }
1016 GgmlType::Q4_0
1017 | GgmlType::Q8_0
1018 | GgmlType::Q4_K
1019 | GgmlType::Q5_K
1020 | GgmlType::Q6_K
1021 | GgmlType::I16 => {
1022 let mut buf =
1046 device.alloc_buffer(info.byte_len, DType::U8, info.shape.clone())?;
1047 {
1048 let slice: &mut [u8] = buf.as_mut_slice()?;
1049 slice.copy_from_slice(&data);
1050 }
1051 Ok(buf)
1052 }
1053 }
1054 }
1055
1056 pub fn load_tensor_f32(&self, name: &str, device: &MlxDevice) -> Result<MlxBuffer> {
1067 let info = self.tensors.get(name).ok_or_else(|| {
1068 MlxError::GgufParseError(format!("tensor '{name}' not found in GGUF file"))
1069 })?;
1070
1071 let data = self.read_tensor_bytes(info)?;
1072 let total_elements: usize = info.shape.iter().product();
1073
1074 if total_elements == 0 {
1075 return Err(MlxError::GgufParseError(format!(
1076 "tensor '{name}' has zero elements"
1077 )));
1078 }
1079
1080 let f32_byte_len = total_elements * 4;
1081 let mut buf =
1082 device.alloc_buffer(f32_byte_len, DType::F32, info.shape.clone())?;
1083
1084 {
1085 let out_slice: &mut [f32] = buf.as_mut_slice()?;
1086 dequantize_to_f32(&data, info.ggml_type, out_slice)?;
1087 }
1088
1089 Ok(buf)
1090 }
1091
1092 pub fn load_tensor_into_pool(
1128 &self,
1129 name: &str,
1130 device: &MlxDevice,
1131 pool: &mut MlxBufferPool,
1132 ) -> Result<MlxBuffer> {
1133 let buf = self.load_tensor(name, device)?;
1134 pool.register_existing(device, &buf)?;
1135 Ok(buf)
1136 }
1137}
1138
/// Rounds `offset` up to the next multiple of `alignment`.
///
/// `alignment` must be a non-zero power of two; the caller is expected to have
/// validated this.
fn align_offset(offset: u64, alignment: u64) -> u64 {
    // Bump past the boundary, then drop the bits below the alignment.
    let bumped = offset + (alignment - 1);
    bumped - (bumped & (alignment - 1))
}