1use std::collections::HashMap;
20use std::io::{BufReader, Read, Seek, SeekFrom};
21use std::path::Path;
22use std::sync::Mutex;
23
24use half::f16;
25
26use crate::ops::quantized_matmul_ggml::GgmlType;
27use crate::{DType, MlxBuffer, MlxDevice, MlxError, Result};
28
/// GGUF magic number: the ASCII bytes `GGUF` interpreted as a little-endian `u32`.
/// `GgufFile::open` compares the first four bytes of the file against it.
const GGUF_MAGIC: u32 = 0x4655_4747;

/// The only GGUF container version accepted by `GgufFile::open`.
const GGUF_VERSION: u32 = 3;

/// Alignment used for the start of the tensor data section when the
/// `general.alignment` metadata key is missing or not readable as a `u32`.
const GGUF_DEFAULT_ALIGNMENT: u64 = 32;

/// Metadata key whose value overrides [`GGUF_DEFAULT_ALIGNMENT`].
const GGUF_ALIGNMENT_KEY: &str = "general.alignment";
44
// Type tags that precede every metadata value in the file. `read_metadata_value`
// dispatches on these IDs; each one maps to the `MetadataValue` variant of the
// same name.
const GGUF_TYPE_UINT8: u32 = 0;
const GGUF_TYPE_INT8: u32 = 1;
const GGUF_TYPE_UINT16: u32 = 2;
const GGUF_TYPE_INT16: u32 = 3;
const GGUF_TYPE_UINT32: u32 = 4;
const GGUF_TYPE_INT32: u32 = 5;
const GGUF_TYPE_FLOAT32: u32 = 6;
// Stored as one byte; any non-zero byte is read as `true`.
const GGUF_TYPE_BOOL: u32 = 7;
// Stored as a u64 byte length followed by that many UTF-8 bytes.
const GGUF_TYPE_STRING: u32 = 8;
// Stored as a u32 element type tag, a u64 element count, then the elements.
const GGUF_TYPE_ARRAY: u32 = 9;
const GGUF_TYPE_UINT64: u32 = 10;
const GGUF_TYPE_INT64: u32 = 11;
const GGUF_TYPE_FLOAT64: u32 = 12;
62
// Tensor element type IDs as stored in the tensor directory. Only the IDs
// listed here are accepted; `ggml_type_from_u32` maps each to the `GgmlType`
// variant of the same name and rejects every other ID.
const GGML_TYPE_F32: u32 = 0;
const GGML_TYPE_F16: u32 = 1;
const GGML_TYPE_Q4_0: u32 = 2;
const GGML_TYPE_Q8_0: u32 = 8;
const GGML_TYPE_Q4_K: u32 = 12;
const GGML_TYPE_Q6_K: u32 = 14;
73
/// A single value from the GGUF metadata section.
///
/// There is one variant per metadata type tag understood by the parser;
/// arrays hold already-decoded values of their element type.
#[derive(Debug, Clone)]
pub enum MetadataValue {
    /// Unsigned 8-bit integer.
    Uint8(u8),
    /// Signed 8-bit integer.
    Int8(i8),
    /// Unsigned 16-bit integer.
    Uint16(u16),
    /// Signed 16-bit integer.
    Int16(i16),
    /// Unsigned 32-bit integer.
    Uint32(u32),
    /// Signed 32-bit integer.
    Int32(i32),
    /// 32-bit float.
    Float32(f32),
    /// Boolean.
    Bool(bool),
    /// UTF-8 string.
    String(String),
    /// Sequence of values.
    Array(Vec<MetadataValue>),
    /// Unsigned 64-bit integer.
    Uint64(u64),
    /// Signed 64-bit integer.
    Int64(i64),
    /// 64-bit float.
    Float64(f64),
}

impl MetadataValue {
    /// Borrows the text of a `String` value; every other variant yields `None`.
    pub fn as_str(&self) -> Option<&str> {
        if let MetadataValue::String(text) = self {
            Some(text.as_str())
        } else {
            None
        }
    }

    /// Returns the value as a `u32` when it is a `Uint8`, `Uint16`, `Uint32`,
    /// or a non-negative `Int32`; every other variant (and negative `Int32`)
    /// yields `None`.
    pub fn as_u32(&self) -> Option<u32> {
        match *self {
            MetadataValue::Uint8(small) => Some(u32::from(small)),
            MetadataValue::Uint16(medium) => Some(u32::from(medium)),
            MetadataValue::Uint32(exact) => Some(exact),
            MetadataValue::Int32(signed) => {
                if signed >= 0 {
                    Some(signed as u32)
                } else {
                    None
                }
            }
            _ => None,
        }
    }

    /// Returns the value as an `f32` when it is a `Float32`, or a `Float64`
    /// narrowed with an `as` cast; every other variant yields `None`.
    pub fn as_f32(&self) -> Option<f32> {
        match *self {
            MetadataValue::Float32(single) => Some(single),
            MetadataValue::Float64(double) => Some(double as f32),
            _ => None,
        }
    }
}
125
/// Directory entry describing one tensor stored in a GGUF file.
#[derive(Debug, Clone)]
pub struct TensorInfo {
    /// Tensor name as read from the file; also the key it is indexed under.
    pub name: String,
    /// Tensor dimensions. `GgufFile::open` reverses the order in which the
    /// dimensions appear in the file before storing them here.
    pub shape: Vec<usize>,
    /// Element/quantization type of the stored data.
    pub ggml_type: GgmlType,
    /// Byte offset of the tensor data, relative to the start of the tensor
    /// data section (it is added to the section offset when reading).
    pub offset: u64,
    /// Size of the stored tensor data in bytes, derived from `shape` and
    /// `ggml_type` when the file is opened.
    pub byte_len: usize,
}
140
/// An opened GGUF file: parsed metadata and tensor directory plus a reader
/// used to fetch tensor data on demand.
pub struct GgufFile {
    /// Metadata key/value pairs, keyed by metadata key.
    metadata: HashMap<String, MetadataValue>,
    /// Tensor directory, keyed by tensor name.
    tensors: HashMap<String, TensorInfo>,
    /// Absolute file offset of the tensor data section: the position right
    /// after the tensor directory, rounded up to the file's alignment.
    tensor_data_offset: u64,
    /// Reader over the underlying file. Wrapped in a `Mutex` because tensor
    /// loads seek and read through `&self`.
    reader: Mutex<BufReader<std::fs::File>>,
}
153
154fn read_u8<R: Read>(r: &mut R) -> Result<u8> {
160 let mut buf = [0u8; 1];
161 r.read_exact(&mut buf)
162 .map_err(|e| MlxError::GgufParseError(format!("read u8: {e}")))?;
163 Ok(buf[0])
164}
165
166fn read_i8<R: Read>(r: &mut R) -> Result<i8> {
168 Ok(read_u8(r)? as i8)
169}
170
171fn read_u16<R: Read>(r: &mut R) -> Result<u16> {
173 let mut buf = [0u8; 2];
174 r.read_exact(&mut buf)
175 .map_err(|e| MlxError::GgufParseError(format!("read u16: {e}")))?;
176 Ok(u16::from_le_bytes(buf))
177}
178
179fn read_i16<R: Read>(r: &mut R) -> Result<i16> {
181 let mut buf = [0u8; 2];
182 r.read_exact(&mut buf)
183 .map_err(|e| MlxError::GgufParseError(format!("read i16: {e}")))?;
184 Ok(i16::from_le_bytes(buf))
185}
186
187fn read_u32<R: Read>(r: &mut R) -> Result<u32> {
189 let mut buf = [0u8; 4];
190 r.read_exact(&mut buf)
191 .map_err(|e| MlxError::GgufParseError(format!("read u32: {e}")))?;
192 Ok(u32::from_le_bytes(buf))
193}
194
195fn read_i32<R: Read>(r: &mut R) -> Result<i32> {
197 let mut buf = [0u8; 4];
198 r.read_exact(&mut buf)
199 .map_err(|e| MlxError::GgufParseError(format!("read i32: {e}")))?;
200 Ok(i32::from_le_bytes(buf))
201}
202
203fn read_u64<R: Read>(r: &mut R) -> Result<u64> {
205 let mut buf = [0u8; 8];
206 r.read_exact(&mut buf)
207 .map_err(|e| MlxError::GgufParseError(format!("read u64: {e}")))?;
208 Ok(u64::from_le_bytes(buf))
209}
210
211fn read_i64<R: Read>(r: &mut R) -> Result<i64> {
213 let mut buf = [0u8; 8];
214 r.read_exact(&mut buf)
215 .map_err(|e| MlxError::GgufParseError(format!("read i64: {e}")))?;
216 Ok(i64::from_le_bytes(buf))
217}
218
219fn read_f32<R: Read>(r: &mut R) -> Result<f32> {
221 let mut buf = [0u8; 4];
222 r.read_exact(&mut buf)
223 .map_err(|e| MlxError::GgufParseError(format!("read f32: {e}")))?;
224 Ok(f32::from_le_bytes(buf))
225}
226
227fn read_f64<R: Read>(r: &mut R) -> Result<f64> {
229 let mut buf = [0u8; 8];
230 r.read_exact(&mut buf)
231 .map_err(|e| MlxError::GgufParseError(format!("read f64: {e}")))?;
232 Ok(f64::from_le_bytes(buf))
233}
234
235fn read_gguf_string<R: Read>(r: &mut R) -> Result<String> {
238 let len = read_u64(r)? as usize;
239 if len > 256 * 1024 * 1024 {
240 return Err(MlxError::GgufParseError(format!(
241 "string length {len} exceeds 256 MiB safety limit"
242 )));
243 }
244 let mut buf = vec![0u8; len];
245 r.read_exact(&mut buf)
246 .map_err(|e| MlxError::GgufParseError(format!("read string bytes: {e}")))?;
247 String::from_utf8(buf)
248 .map_err(|e| MlxError::GgufParseError(format!("invalid UTF-8 in string: {e}")))
249}
250
251fn read_metadata_value<R: Read>(r: &mut R, value_type: u32) -> Result<MetadataValue> {
257 match value_type {
258 GGUF_TYPE_UINT8 => Ok(MetadataValue::Uint8(read_u8(r)?)),
259 GGUF_TYPE_INT8 => Ok(MetadataValue::Int8(read_i8(r)?)),
260 GGUF_TYPE_UINT16 => Ok(MetadataValue::Uint16(read_u16(r)?)),
261 GGUF_TYPE_INT16 => Ok(MetadataValue::Int16(read_i16(r)?)),
262 GGUF_TYPE_UINT32 => Ok(MetadataValue::Uint32(read_u32(r)?)),
263 GGUF_TYPE_INT32 => Ok(MetadataValue::Int32(read_i32(r)?)),
264 GGUF_TYPE_FLOAT32 => Ok(MetadataValue::Float32(read_f32(r)?)),
265 GGUF_TYPE_BOOL => {
266 let byte = read_u8(r)?;
267 Ok(MetadataValue::Bool(byte != 0))
268 }
269 GGUF_TYPE_STRING => Ok(MetadataValue::String(read_gguf_string(r)?)),
270 GGUF_TYPE_ARRAY => {
271 let elem_type = read_u32(r)?;
272 let count = read_u64(r)? as usize;
273 if count > 64 * 1024 * 1024 {
274 return Err(MlxError::GgufParseError(format!(
275 "array count {count} exceeds 64M element safety limit"
276 )));
277 }
278 let mut elems = Vec::with_capacity(count);
279 for _ in 0..count {
280 elems.push(read_metadata_value(r, elem_type)?);
281 }
282 Ok(MetadataValue::Array(elems))
283 }
284 GGUF_TYPE_UINT64 => Ok(MetadataValue::Uint64(read_u64(r)?)),
285 GGUF_TYPE_INT64 => Ok(MetadataValue::Int64(read_i64(r)?)),
286 GGUF_TYPE_FLOAT64 => Ok(MetadataValue::Float64(read_f64(r)?)),
287 other => Err(MlxError::GgufParseError(format!(
288 "unknown metadata value type {other}"
289 ))),
290 }
291}
292
293fn ggml_type_from_u32(id: u32) -> Result<GgmlType> {
299 match id {
300 GGML_TYPE_F32 => Ok(GgmlType::F32),
301 GGML_TYPE_F16 => Ok(GgmlType::F16),
302 GGML_TYPE_Q4_0 => Ok(GgmlType::Q4_0),
303 GGML_TYPE_Q8_0 => Ok(GgmlType::Q8_0),
304 GGML_TYPE_Q4_K => Ok(GgmlType::Q4_K),
305 GGML_TYPE_Q6_K => Ok(GgmlType::Q6_K),
306 other => Err(MlxError::GgufParseError(format!(
307 "unsupported GGML type ID {other}"
308 ))),
309 }
310}
311
312fn compute_byte_len(shape: &[usize], ggml_type: GgmlType) -> Result<usize> {
317 let total_elements: usize = shape.iter().product();
318 if total_elements == 0 {
319 return Ok(0);
320 }
321
322 let elems_per_block = ggml_type.block_values() as usize;
323 let bytes_per_block = ggml_type.block_bytes() as usize;
324
325 if total_elements % elems_per_block != 0 {
326 return Err(MlxError::GgufParseError(format!(
327 "total elements {total_elements} not divisible by block size {elems_per_block} \
328 for type {:?}",
329 ggml_type
330 )));
331 }
332
333 Ok((total_elements / elems_per_block) * bytes_per_block)
334}
335
336#[inline]
342fn f16_from_le_bytes(bytes: [u8; 2]) -> f32 {
343 f16::from_le_bytes(bytes).to_f32()
344}
345
346fn dequantize_q4_0(data: &[u8], output: &mut [f32]) -> Result<()> {
352 const BLOCK_BYTES: usize = 18;
353 const BLOCK_ELEMS: usize = 32;
354
355 if data.len() % BLOCK_BYTES != 0 {
356 return Err(MlxError::GgufParseError(format!(
357 "Q4_0 data length {} not divisible by block size {BLOCK_BYTES}",
358 data.len()
359 )));
360 }
361
362 let num_blocks = data.len() / BLOCK_BYTES;
363 if output.len() < num_blocks * BLOCK_ELEMS {
364 return Err(MlxError::GgufParseError(
365 "Q4_0 output buffer too small".into(),
366 ));
367 }
368
369 for i in 0..num_blocks {
370 let block = &data[i * BLOCK_BYTES..(i + 1) * BLOCK_BYTES];
371 let d = f16_from_le_bytes([block[0], block[1]]);
372 let qs = &block[2..18]; let out = &mut output[i * BLOCK_ELEMS..(i + 1) * BLOCK_ELEMS];
375
376 for j in 0..16 {
377 let x0 = (qs[j] & 0x0F) as i16 - 8;
378 let x1 = (qs[j] >> 4) as i16 - 8;
379 out[j] = x0 as f32 * d;
380 out[j + 16] = x1 as f32 * d;
381 }
382 }
383 Ok(())
384}
385
386fn dequantize_q8_0(data: &[u8], output: &mut [f32]) -> Result<()> {
392 const BLOCK_BYTES: usize = 34;
393 const BLOCK_ELEMS: usize = 32;
394
395 if data.len() % BLOCK_BYTES != 0 {
396 return Err(MlxError::GgufParseError(format!(
397 "Q8_0 data length {} not divisible by block size {BLOCK_BYTES}",
398 data.len()
399 )));
400 }
401
402 let num_blocks = data.len() / BLOCK_BYTES;
403 if output.len() < num_blocks * BLOCK_ELEMS {
404 return Err(MlxError::GgufParseError(
405 "Q8_0 output buffer too small".into(),
406 ));
407 }
408
409 for i in 0..num_blocks {
410 let block = &data[i * BLOCK_BYTES..(i + 1) * BLOCK_BYTES];
411 let d = f16_from_le_bytes([block[0], block[1]]);
412 let qs = &block[2..34]; let out = &mut output[i * BLOCK_ELEMS..(i + 1) * BLOCK_ELEMS];
415
416 for j in 0..32 {
417 out[j] = (qs[j] as i8) as f32 * d;
418 }
419 }
420 Ok(())
421}
422
/// Extracts the 6-bit scale and 6-bit min for sub-block `j` from the packed
/// scale bytes of a Q4_K block.
///
/// For `j < 4` the scale is the low 6 bits of `scales[j]` and the min the low
/// 6 bits of `scales[j + 4]`. For `j >= 4` the low 4 bits of each value come
/// from the two nibbles of `scales[j + 4]`, and the upper 2 bits come from
/// the top two bits of `scales[j - 4]` (scale) and `scales[j]` (min).
///
/// Returns `(scale, min)`. Panics if `scales` is too short for the index.
#[inline]
fn get_scale_min_k4(j: usize, scales: &[u8]) -> (u8, u8) {
    const LOW_SIX_BITS: u8 = 63;

    if j >= 4 {
        let nibbles = scales[j + 4];
        let scale_top = scales[j - 4] >> 6;
        let min_top = scales[j] >> 6;
        let scale = (nibbles & 0xF) | (scale_top << 4);
        let min = (nibbles >> 4) | (min_top << 4);
        return (scale, min);
    }

    (scales[j] & LOW_SIX_BITS, scales[j + 4] & LOW_SIX_BITS)
}
447
448fn dequantize_q4_k(data: &[u8], output: &mut [f32]) -> Result<()> {
460 const BLOCK_BYTES: usize = 144;
461 const BLOCK_ELEMS: usize = 256;
462
463 if data.len() % BLOCK_BYTES != 0 {
464 return Err(MlxError::GgufParseError(format!(
465 "Q4_K data length {} not divisible by block size {BLOCK_BYTES}",
466 data.len()
467 )));
468 }
469
470 let num_blocks = data.len() / BLOCK_BYTES;
471 if output.len() < num_blocks * BLOCK_ELEMS {
472 return Err(MlxError::GgufParseError(
473 "Q4_K output buffer too small".into(),
474 ));
475 }
476
477 for i in 0..num_blocks {
478 let block = &data[i * BLOCK_BYTES..(i + 1) * BLOCK_BYTES];
479
480 let d = f16_from_le_bytes([block[0], block[1]]);
481 let dmin = f16_from_le_bytes([block[2], block[3]]);
482 let scales = &block[4..16]; let qs = &block[16..144]; let out = &mut output[i * BLOCK_ELEMS..(i + 1) * BLOCK_ELEMS];
486
487 let mut is = 0usize;
491 let mut ys_index = 0usize;
492
493 let mut j = 0usize;
496 while j < 128 {
497 let q = &qs[j..j + 32];
498 let (sc1, m1) = get_scale_min_k4(is, scales);
499 let d1 = d * sc1 as f32;
500 let min1 = dmin * m1 as f32;
501 let (sc2, m2) = get_scale_min_k4(is + 1, scales);
502 let d2 = d * sc2 as f32;
503 let min2 = dmin * m2 as f32;
504
505 for byte in q.iter() {
507 out[ys_index] = d1 * (*byte & 0xF) as f32 - min1;
508 ys_index += 1;
509 }
510 for byte in q.iter() {
512 out[ys_index] = d2 * (*byte >> 4) as f32 - min2;
513 ys_index += 1;
514 }
515
516 is += 2;
517 j += 32;
518 }
519 }
520 Ok(())
521}
522
523fn dequantize_q6_k(data: &[u8], output: &mut [f32]) -> Result<()> {
534 const BLOCK_BYTES: usize = 210;
535 const BLOCK_ELEMS: usize = 256;
536
537 if data.len() % BLOCK_BYTES != 0 {
538 return Err(MlxError::GgufParseError(format!(
539 "Q6_K data length {} not divisible by block size {BLOCK_BYTES}",
540 data.len()
541 )));
542 }
543
544 let num_blocks = data.len() / BLOCK_BYTES;
545 if output.len() < num_blocks * BLOCK_ELEMS {
546 return Err(MlxError::GgufParseError(
547 "Q6_K output buffer too small".into(),
548 ));
549 }
550
551 for i in 0..num_blocks {
552 let block = &data[i * BLOCK_BYTES..(i + 1) * BLOCK_BYTES];
553
554 let ql = &block[0..128];
555 let qh = &block[128..192];
556 let sc = &block[192..208]; let d = f16_from_le_bytes([block[208], block[209]]);
558
559 let out = &mut output[i * BLOCK_ELEMS..(i + 1) * BLOCK_ELEMS];
560
561 for idx in 0..2 {
563 let ql_base = &ql[64 * idx..];
564 let qh_base = &qh[32 * idx..];
565 let sc_base = &sc[8 * idx..];
566 let out_base = &mut out[128 * idx..];
567
568 for l in 0..32 {
569 let is = l / 16; let q1 = ((ql_base[l] & 0xF) | ((qh_base[l] & 3) << 4)) as i8 - 32_i8;
572 let q2 = ((ql_base[l + 32] & 0xF) | (((qh_base[l] >> 2) & 3) << 4)) as i8
573 - 32_i8;
574 let q3 = ((ql_base[l] >> 4) | (((qh_base[l] >> 4) & 3) << 4)) as i8 - 32_i8;
575 let q4 = ((ql_base[l + 32] >> 4) | (((qh_base[l] >> 6) & 3) << 4)) as i8
576 - 32_i8;
577
578 out_base[l] = d * sc_base[is] as i8 as f32 * q1 as f32;
579 out_base[l + 32] = d * sc_base[is + 2] as i8 as f32 * q2 as f32;
580 out_base[l + 64] = d * sc_base[is + 4] as i8 as f32 * q3 as f32;
581 out_base[l + 96] = d * sc_base[is + 6] as i8 as f32 * q4 as f32;
582 }
583 }
584 }
585 Ok(())
586}
587
588fn dequantize_f16(data: &[u8], output: &mut [f32]) -> Result<()> {
590 if data.len() % 2 != 0 {
591 return Err(MlxError::GgufParseError(
592 "F16 data length not even".into(),
593 ));
594 }
595 let count = data.len() / 2;
596 if output.len() < count {
597 return Err(MlxError::GgufParseError(
598 "F16 output buffer too small".into(),
599 ));
600 }
601 for i in 0..count {
602 output[i] = f16_from_le_bytes([data[2 * i], data[2 * i + 1]]);
603 }
604 Ok(())
605}
606
607fn copy_f32(data: &[u8], output: &mut [f32]) -> Result<()> {
609 if data.len() % 4 != 0 {
610 return Err(MlxError::GgufParseError(
611 "F32 data length not multiple of 4".into(),
612 ));
613 }
614 let count = data.len() / 4;
615 if output.len() < count {
616 return Err(MlxError::GgufParseError(
617 "F32 output buffer too small".into(),
618 ));
619 }
620 for i in 0..count {
621 output[i] = f32::from_le_bytes([
622 data[4 * i],
623 data[4 * i + 1],
624 data[4 * i + 2],
625 data[4 * i + 3],
626 ]);
627 }
628 Ok(())
629}
630
631fn dequantize_to_f32(data: &[u8], ggml_type: GgmlType, output: &mut [f32]) -> Result<()> {
633 match ggml_type {
634 GgmlType::F32 => copy_f32(data, output),
635 GgmlType::F16 => dequantize_f16(data, output),
636 GgmlType::Q4_0 => dequantize_q4_0(data, output),
637 GgmlType::Q8_0 => dequantize_q8_0(data, output),
638 GgmlType::Q4_K => dequantize_q4_k(data, output),
639 GgmlType::Q6_K => dequantize_q6_k(data, output),
640 }
641}
642
643impl GgufFile {
648 pub fn open(path: &Path) -> Result<Self> {
660 let file = std::fs::File::open(path).map_err(|e| {
661 MlxError::IoError(format!("cannot open GGUF file '{}': {e}", path.display()))
662 })?;
663 let mut reader = BufReader::new(file);
664
665 let magic = read_u32(&mut reader)?;
667 if magic != GGUF_MAGIC {
668 return Err(MlxError::GgufParseError(format!(
669 "bad magic: expected 0x{GGUF_MAGIC:08X}, got 0x{magic:08X}"
670 )));
671 }
672
673 let version = read_u32(&mut reader)?;
674 if version != GGUF_VERSION {
675 return Err(MlxError::GgufParseError(format!(
676 "unsupported GGUF version {version} (only v3 is supported)"
677 )));
678 }
679
680 let tensor_count = read_u64(&mut reader)? as usize;
681 let metadata_kv_count = read_u64(&mut reader)? as usize;
682
683 if tensor_count > 100_000 {
685 return Err(MlxError::GgufParseError(format!(
686 "tensor_count {tensor_count} exceeds 100k safety limit"
687 )));
688 }
689 if metadata_kv_count > 1_000_000 {
690 return Err(MlxError::GgufParseError(format!(
691 "metadata_kv_count {metadata_kv_count} exceeds 1M safety limit"
692 )));
693 }
694
695 let mut metadata = HashMap::with_capacity(metadata_kv_count);
697 for _ in 0..metadata_kv_count {
698 let key = read_gguf_string(&mut reader)?;
699 let value_type = read_u32(&mut reader)?;
700 let value = read_metadata_value(&mut reader, value_type)?;
701 metadata.insert(key, value);
702 }
703
704 let alignment = metadata
706 .get(GGUF_ALIGNMENT_KEY)
707 .and_then(|v| v.as_u32())
708 .map(|v| v as u64)
709 .unwrap_or(GGUF_DEFAULT_ALIGNMENT);
710
711 if alignment == 0 || (alignment & (alignment - 1)) != 0 {
712 return Err(MlxError::GgufParseError(format!(
713 "alignment {alignment} is not a power of two"
714 )));
715 }
716
717 let mut tensors = HashMap::with_capacity(tensor_count);
719 for _ in 0..tensor_count {
720 let name = read_gguf_string(&mut reader)?;
721 let n_dims = read_u32(&mut reader)? as usize;
722
723 if n_dims > 8 {
724 return Err(MlxError::GgufParseError(format!(
725 "tensor '{name}' has {n_dims} dimensions (max 8)"
726 )));
727 }
728
729 let mut shape = Vec::with_capacity(n_dims);
730 for _ in 0..n_dims {
731 shape.push(read_u64(&mut reader)? as usize);
732 }
733 shape.reverse();
737
738 let ggml_type_id = read_u32(&mut reader)?;
739 let ggml_type = ggml_type_from_u32(ggml_type_id).map_err(|e| {
740 MlxError::GgufParseError(format!("tensor '{name}': {e}"))
741 })?;
742
743 let offset = read_u64(&mut reader)?;
744 let byte_len = compute_byte_len(&shape, ggml_type).map_err(|e| {
745 MlxError::GgufParseError(format!("tensor '{name}': {e}"))
746 })?;
747
748 tensors.insert(
749 name.clone(),
750 TensorInfo {
751 name,
752 shape,
753 ggml_type,
754 offset,
755 byte_len,
756 },
757 );
758 }
759
760 let pos = reader
764 .stream_position()
765 .map_err(|e| MlxError::GgufParseError(format!("stream_position: {e}")))?;
766 let tensor_data_offset = align_offset(pos, alignment);
767
768 Ok(GgufFile {
769 metadata,
770 tensors,
771 tensor_data_offset,
772 reader: Mutex::new(reader),
773 })
774 }
775
776 pub fn metadata(&self, key: &str) -> Option<&MetadataValue> {
782 self.metadata.get(key)
783 }
784
785 pub fn metadata_string(&self, key: &str) -> Option<&str> {
787 self.metadata.get(key).and_then(|v| v.as_str())
788 }
789
790 pub fn metadata_u32(&self, key: &str) -> Option<u32> {
792 self.metadata.get(key).and_then(|v| v.as_u32())
793 }
794
795 pub fn metadata_f32(&self, key: &str) -> Option<f32> {
797 self.metadata.get(key).and_then(|v| v.as_f32())
798 }
799
800 pub fn tensor_names(&self) -> Vec<&str> {
806 self.tensors.keys().map(|s| s.as_str()).collect()
807 }
808
809 pub fn tensor_info(&self, name: &str) -> Option<&TensorInfo> {
811 self.tensors.get(name)
812 }
813
814 pub fn tensor_count(&self) -> usize {
816 self.tensors.len()
817 }
818
819 pub fn metadata_count(&self) -> usize {
821 self.metadata.len()
822 }
823
824 fn read_tensor_bytes(&self, info: &TensorInfo) -> Result<Vec<u8>> {
833 let abs_offset = self.tensor_data_offset + info.offset;
834 let mut reader = self
835 .reader
836 .lock()
837 .map_err(|_| MlxError::GgufParseError("reader mutex poisoned".into()))?;
838
839 reader
840 .seek(SeekFrom::Start(abs_offset))
841 .map_err(|e| MlxError::IoError(format!("seek to tensor '{}': {e}", info.name)))?;
842
843 let mut buf = vec![0u8; info.byte_len];
844 reader.read_exact(&mut buf).map_err(|e| {
845 MlxError::IoError(format!(
846 "read tensor '{}' ({} bytes at offset {}): {e}",
847 info.name, info.byte_len, abs_offset
848 ))
849 })?;
850
851 Ok(buf)
852 }
853
854 pub fn load_tensor(&self, name: &str, device: &MlxDevice) -> Result<MlxBuffer> {
866 let info = self.tensors.get(name).ok_or_else(|| {
867 MlxError::GgufParseError(format!("tensor '{name}' not found in GGUF file"))
868 })?;
869
870 let data = self.read_tensor_bytes(info)?;
871
872 match info.ggml_type {
873 GgmlType::F32 => {
874 let mut buf =
875 device.alloc_buffer(info.byte_len, DType::F32, info.shape.clone())?;
876 {
877 let slice: &mut [u8] = buf.as_mut_slice()?;
878 slice.copy_from_slice(&data);
879 }
880 Ok(buf)
881 }
882 GgmlType::F16 => {
883 let mut buf =
884 device.alloc_buffer(info.byte_len, DType::F16, info.shape.clone())?;
885 {
886 let slice: &mut [u8] = buf.as_mut_slice()?;
887 slice.copy_from_slice(&data);
888 }
889 Ok(buf)
890 }
891 GgmlType::Q4_0 | GgmlType::Q8_0 | GgmlType::Q4_K | GgmlType::Q6_K => {
892 let mut buf =
894 device.alloc_buffer(info.byte_len, DType::U8, info.shape.clone())?;
895 {
896 let slice: &mut [u8] = buf.as_mut_slice()?;
897 slice.copy_from_slice(&data);
898 }
899 Ok(buf)
900 }
901 }
902 }
903
904 pub fn load_tensor_f32(&self, name: &str, device: &MlxDevice) -> Result<MlxBuffer> {
915 let info = self.tensors.get(name).ok_or_else(|| {
916 MlxError::GgufParseError(format!("tensor '{name}' not found in GGUF file"))
917 })?;
918
919 let data = self.read_tensor_bytes(info)?;
920 let total_elements: usize = info.shape.iter().product();
921
922 if total_elements == 0 {
923 return Err(MlxError::GgufParseError(format!(
924 "tensor '{name}' has zero elements"
925 )));
926 }
927
928 let f32_byte_len = total_elements * 4;
929 let mut buf =
930 device.alloc_buffer(f32_byte_len, DType::F32, info.shape.clone())?;
931
932 {
933 let out_slice: &mut [f32] = buf.as_mut_slice()?;
934 dequantize_to_f32(&data, info.ggml_type, out_slice)?;
935 }
936
937 Ok(buf)
938 }
939}
940
/// Rounds `offset` up to the next multiple of `alignment`.
///
/// `alignment` must be a non-zero power of two; the caller validates this
/// before calling.
fn align_offset(offset: u64, alignment: u64) -> u64 {
    let low_bits = alignment - 1;
    let bumped = offset + low_bits;
    // Clearing the low bits is the same as subtracting them.
    bumped - (bumped & low_bits)
}