1use std::collections::HashMap;
4use std::fmt;
5use std::path::Path;
6
/// Fixed-size header found at the start of a GGUF file.
///
/// Produced by `GgufLoader::parse_header` from the first 24 bytes of the input;
/// every integer field is decoded as little-endian.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GgufHeader {
    /// Leading magic bytes; `parse_header` only accepts `b"GGUF"`.
    pub magic: [u8; 4],
    /// Format version; `parse_header` only accepts versions 2 and 3.
    pub version: u32,
    /// Number of tensor-info entries the header declares.
    pub tensor_count: u64,
    /// Number of metadata key/value pairs the header declares.
    pub metadata_kv_count: u64,
}
21
/// A single typed metadata value.
///
/// One variant per fixed-width scalar kind, plus `Bool`, `String`, and a
/// (possibly nested) `Array` of further values.
/// NOTE(review): presumably mirrors the GGUF metadata value-type table — confirm
/// against the format spec when a metadata reader is added.
///
/// Derives `PartialEq` but not `Eq`: the `F32`/`F64` variants have no total equality.
#[derive(Debug, Clone, PartialEq)]
pub enum GgufValue {
    /// Unsigned 8-bit integer.
    U8(u8),
    /// Signed 8-bit integer.
    I8(i8),
    /// Unsigned 16-bit integer.
    U16(u16),
    /// Signed 16-bit integer.
    I16(i16),
    /// Unsigned 32-bit integer.
    U32(u32),
    /// Signed 32-bit integer.
    I32(i32),
    /// Unsigned 64-bit integer.
    U64(u64),
    /// Signed 64-bit integer.
    I64(i64),
    /// 32-bit float.
    F32(f32),
    /// 64-bit float.
    F64(f64),
    /// Boolean flag.
    Bool(bool),
    /// UTF-8 string.
    String(String),
    /// Sequence of values; elements are themselves `GgufValue`s.
    Array(Vec<GgufValue>),
}
39
/// Descriptor for one tensor stored in a GGUF file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GgufTensorInfo {
    /// Tensor name.
    pub name: String,
    /// Number of dimensions. NOTE(review): presumably equals `dims.len()` — not enforced here.
    pub n_dims: u32,
    /// Extent of each dimension.
    pub dims: Vec<u64>,
    /// Raw element-type identifier. NOTE(review): presumably the ggml type id — confirm.
    pub dtype: u32,
    /// Byte offset of the tensor's data. NOTE(review): presumably relative to the start
    /// of the tensor-data section rather than the file — confirm.
    pub offset: u64,
}
54
/// Result alias used by every fallible GGUF operation.
pub type GgufResult<T> = Result<T, GgufError>;

/// Errors produced while validating or decoding a GGUF file.
///
/// Derives `PartialEq`/`Eq` so callers (and tests) can match a specific failure by value.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GgufError {
    /// The first four bytes were not `GGUF`; carries the bytes actually found.
    InvalidMagic([u8; 4]),
    /// The header declared a version this loader does not accept.
    UnsupportedVersion(u32),
    /// A filesystem problem, described as text (the source `io::Error` is not kept,
    /// which keeps this type `Clone`).
    Io(String),
    /// The input was malformed or otherwise unacceptable (e.g. truncated header, wrong extension).
    InvalidData(String),
    /// A tensor with the given name was requested but not present.
    /// NOTE(review): presumably raised by a lookup defined outside this view.
    TensorNotFound(String),
}

impl fmt::Display for GgufError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            // Magic bytes are shown with `{:?}`, i.e. as a decimal byte array.
            GgufError::InvalidMagic(magic) => write!(f, "Invalid GGUF magic: {:?}", magic),
            GgufError::UnsupportedVersion(v) => write!(f, "Unsupported GGUF version: {}", v),
            GgufError::Io(msg) => write!(f, "IO error: {}", msg),
            GgufError::InvalidData(msg) => write!(f, "Invalid data: {}", msg),
            GgufError::TensorNotFound(name) => write!(f, "Tensor not found: {}", name),
        }
    }
}

impl std::error::Error for GgufError {}
90
91#[derive(Debug)]
93pub struct GgufLoader {
94 path: String,
96 header: Option<GgufHeader>,
98 tensors: Vec<GgufTensorInfo>,
100 metadata: HashMap<String, GgufValue>,
102}
103
104impl GgufLoader {
105 pub fn new(path: impl AsRef<Path>) -> Self {
107 Self {
108 path: path.as_ref().to_string_lossy().to_string(),
109 header: None,
110 tensors: Vec::new(),
111 metadata: HashMap::new(),
112 }
113 }
114
115 pub fn validate_path(&self) -> GgufResult<()> {
117 let path = Path::new(&self.path);
118 if !path.exists() {
119 return Err(GgufError::Io(format!("File not found: {}", self.path)));
120 }
121 if path.extension().map_or(true, |ext| ext != "gguf") {
122 return Err(GgufError::InvalidData(
123 "File does not have .gguf extension".to_string(),
124 ));
125 }
126 Ok(())
127 }
128
129 pub fn parse_header(&mut self, data: &[u8]) -> GgufResult<()> {
131 if data.len() < 24 {
132 return Err(GgufError::InvalidData(
133 "File too small for header".to_string(),
134 ));
135 }
136
137 let magic: [u8; 4] = data[0..4].try_into().expect("invariant: slice is 4 bytes");
139 if &magic != b"GGUF" {
140 return Err(GgufError::InvalidMagic(magic));
141 }
142
143 let version =
144 u32::from_le_bytes(data[4..8].try_into().expect("invariant: slice is 4 bytes"));
145 if !(2..=3).contains(&version) {
146 return Err(GgufError::UnsupportedVersion(version));
147 }
148
149 let tensor_count =
150 u64::from_le_bytes(data[8..16].try_into().expect("invariant: slice is 8 bytes"));
151 let metadata_kv_count = u64::from_le_bytes(
152 data[16..24]
153 .try_into()
154 .expect("invariant: slice is 8 bytes"),
155 );
156
157 self.header = Some(GgufHeader {
158 magic,
159 version,
160 tensor_count,
161 metadata_kv_count,
162 });
163
164 Ok(())
165 }
166
167 pub fn header(&self) -> Option<&GgufHeader> {
169 self.header.as_ref()
170 }
171
172 pub fn tensor_count(&self) -> u64 {
174 self.header.as_ref().map_or(0, |h| h.tensor_count)
175 }
176
177 pub fn path(&self) -> &str {
179 &self.path
180 }
181}