llama_cpp_bindings/
gguf_context.rs1use std::ffi::{CStr, CString};
2use std::path::Path;
3use std::ptr::NonNull;
4
5use crate::gguf_context_error::GgufContextError;
6use crate::gguf_type::GgufType;
7
8#[derive(Debug)]
9pub struct GgufContext {
10 context: NonNull<llama_cpp_bindings_sys::gguf_context>,
11}
12
13impl GgufContext {
14 pub fn from_file(path: impl AsRef<Path>) -> Result<Self, GgufContextError> {
20 let path_ref = path.as_ref();
21 let path_str = path_ref
22 .to_str()
23 .ok_or_else(|| GgufContextError::PathToStrError(path_ref.to_path_buf()))?;
24 let c_path = CString::new(path_str)?;
25
26 let init_params = llama_cpp_bindings_sys::gguf_init_params {
27 no_alloc: true,
28 ctx: std::ptr::null_mut(),
29 };
30
31 let raw =
32 unsafe { llama_cpp_bindings_sys::gguf_init_from_file(c_path.as_ptr(), init_params) };
33 let context = NonNull::new(raw)
34 .ok_or_else(|| GgufContextError::InitFailed(path_ref.to_path_buf()))?;
35
36 Ok(Self { context })
37 }
38
39 #[must_use]
40 pub fn n_kv(&self) -> i64 {
41 unsafe { llama_cpp_bindings_sys::gguf_get_n_kv(self.context.as_ptr()) }
42 }
43
44 pub fn find_key(&self, key: &str) -> Result<i64, GgufContextError> {
49 let c_key = CString::new(key)?;
50 let index =
51 unsafe { llama_cpp_bindings_sys::gguf_find_key(self.context.as_ptr(), c_key.as_ptr()) };
52
53 if index < 0 {
54 return Err(GgufContextError::KeyNotFound {
55 key: key.to_string(),
56 });
57 }
58
59 Ok(index)
60 }
61
62 pub fn key_at(&self, key_id: i64) -> Result<&str, GgufContextError> {
70 let c_str = unsafe {
71 CStr::from_ptr(llama_cpp_bindings_sys::gguf_get_key(
72 self.context.as_ptr(),
73 key_id,
74 ))
75 };
76
77 Ok(c_str.to_str()?)
78 }
79
80 #[must_use]
84 pub fn kv_type(&self, key_id: i64) -> Option<GgufType> {
85 let raw =
86 unsafe { llama_cpp_bindings_sys::gguf_get_kv_type(self.context.as_ptr(), key_id) };
87
88 GgufType::from_raw(raw)
89 }
90
91 #[must_use]
95 pub fn val_u32(&self, key_id: i64) -> u32 {
96 unsafe { llama_cpp_bindings_sys::gguf_get_val_u32(self.context.as_ptr(), key_id) }
97 }
98
99 #[must_use]
103 pub fn val_i32(&self, key_id: i64) -> i32 {
104 unsafe { llama_cpp_bindings_sys::gguf_get_val_i32(self.context.as_ptr(), key_id) }
105 }
106
107 #[must_use]
111 pub fn val_u64(&self, key_id: i64) -> u64 {
112 unsafe { llama_cpp_bindings_sys::gguf_get_val_u64(self.context.as_ptr(), key_id) }
113 }
114
115 pub fn val_str(&self, key_id: i64) -> Result<&str, GgufContextError> {
123 let c_str = unsafe {
124 CStr::from_ptr(llama_cpp_bindings_sys::gguf_get_val_str(
125 self.context.as_ptr(),
126 key_id,
127 ))
128 };
129
130 Ok(c_str.to_str()?)
131 }
132
133 #[must_use]
134 pub fn n_tensors(&self) -> i64 {
135 unsafe { llama_cpp_bindings_sys::gguf_get_n_tensors(self.context.as_ptr()) }
136 }
137}
138
139impl Drop for GgufContext {
140 fn drop(&mut self) {
141 unsafe { llama_cpp_bindings_sys::gguf_free(self.context.as_ptr()) }
142 }
143}
144
#[cfg(test)]
mod tests {
    use std::ffi::CString;
    use std::mem::Discriminant;
    use std::path::PathBuf;

    use super::GgufContext;
    use crate::gguf_context_error::GgufContextError;
    use crate::gguf_type::GgufType;

    // Raw GGUF value-type tags written by `SyntheticGgufFile` (values of `enum gguf_type` in
    // ggml's `gguf.h`).
    const GGUF_TYPE_UINT32: u32 = 4;
    const GGUF_TYPE_INT32: u32 = 5;
    const GGUF_TYPE_STRING: u32 = 8;
    const GGUF_TYPE_UINT64: u32 = 10;

    // Values stored in the synthetic fixture. The u32 exceeds `i32::MAX` and the u64 exceeds
    // `u32::MAX`, so sign or width mix-ups in the bindings show up as mismatches.
    const SYNTHETIC_ARCH: &str = "synthetic";
    const SYNTHETIC_U32: u32 = 4_000_000_000;
    const SYNTHETIC_I32: i32 = -12345;
    const SYNTHETIC_U64: u64 = 9_876_543_210;
    // Number of key/value pairs written by `SyntheticGgufFile::new`.
    const SYNTHETIC_KV_COUNT: u64 = 4;

    fn fixture_path() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("fixtures")
            .join("ggml-vocab-bert-bge.gguf")
    }

    fn init_failed_disc() -> Discriminant<GgufContextError> {
        std::mem::discriminant(&GgufContextError::InitFailed(PathBuf::new()))
    }

    fn key_not_found_disc() -> Discriminant<GgufContextError> {
        std::mem::discriminant(&GgufContextError::KeyNotFound { key: String::new() })
    }

    fn nul_error_disc() -> Discriminant<GgufContextError> {
        let nul_err = CString::new(b"a\0b".to_vec()).unwrap_err();
        std::mem::discriminant(&GgufContextError::NulError(nul_err))
    }

    #[cfg(unix)]
    fn path_to_str_error_disc() -> Discriminant<GgufContextError> {
        std::mem::discriminant(&GgufContextError::PathToStrError(PathBuf::new()))
    }

    #[test]
    fn from_file_opens_valid_gguf() {
        let context = GgufContext::from_file(fixture_path());

        assert!(context.is_ok());
    }

    #[test]
    fn from_file_nonexistent_returns_init_failed() {
        let err = GgufContext::from_file("/nonexistent/file.gguf").unwrap_err();

        assert_eq!(std::mem::discriminant(&err), init_failed_disc());
    }

    #[test]
    fn n_kv_returns_positive_count() {
        let context = GgufContext::from_file(fixture_path()).unwrap();

        assert!(context.n_kv() > 0);
    }

    #[test]
    fn n_tensors_returns_count() {
        let context = GgufContext::from_file(fixture_path()).unwrap();

        assert!(context.n_tensors() >= 0);
    }

    #[test]
    fn find_key_returns_valid_index_for_known_key() {
        let context = GgufContext::from_file(fixture_path()).unwrap();
        let index = context.find_key("general.architecture");

        assert!(index.is_ok());
        assert!(index.unwrap() >= 0);
    }

    #[test]
    fn find_key_returns_error_for_missing_key() {
        let context = GgufContext::from_file(fixture_path()).unwrap();
        let err = context.find_key("nonexistent.key").unwrap_err();

        assert_eq!(std::mem::discriminant(&err), key_not_found_disc());
    }

    #[test]
    fn key_at_returns_expected_name() {
        let context = GgufContext::from_file(fixture_path()).unwrap();
        let index = context.find_key("general.architecture").unwrap();
        let key_name = context.key_at(index).unwrap();

        assert_eq!(key_name, "general.architecture");
    }

    #[test]
    fn kv_type_returns_expected_type_for_string_key() {
        let context = GgufContext::from_file(fixture_path()).unwrap();
        let index = context.find_key("general.architecture").unwrap();
        let value_type = context.kv_type(index);

        assert_eq!(value_type, Some(GgufType::String));
    }

    #[test]
    fn val_str_returns_architecture_value() {
        let context = GgufContext::from_file(fixture_path()).unwrap();
        let index = context.find_key("general.architecture").unwrap();
        let value = context.val_str(index).unwrap();

        assert!(!value.is_empty());
    }

    #[cfg(unix)]
    #[test]
    fn from_file_non_utf8_path_returns_error() {
        use std::ffi::OsStr;
        use std::os::unix::ffi::OsStrExt;

        let non_utf8_path = std::path::Path::new(OsStr::from_bytes(b"/tmp/\xff\xfe.gguf"));
        let err = GgufContext::from_file(non_utf8_path).unwrap_err();

        assert_eq!(std::mem::discriminant(&err), path_to_str_error_disc());
    }

    #[test]
    fn from_file_with_null_byte_in_path_returns_error() {
        let err = GgufContext::from_file("/tmp/foo\0bar.gguf").unwrap_err();

        assert_eq!(std::mem::discriminant(&err), nul_error_disc());
    }

    #[test]
    fn find_key_with_null_byte_in_key_returns_error() {
        let context = GgufContext::from_file(fixture_path()).unwrap();
        let err = context.find_key("foo\0bar").unwrap_err();

        assert_eq!(std::mem::discriminant(&err), nul_error_disc());
    }

    #[test]
    fn val_u32_returns_value_for_uint32_key() {
        let context = GgufContext::from_file(fixture_path()).unwrap();

        let key_id = (0..context.n_kv())
            .find(|&id| context.kv_type(id) == Some(GgufType::Uint32))
            .expect("fixture must contain at least one uint32 key");

        // The real fixture's value is not pinned here; this only checks that the read succeeds.
        // Exact round-tripping is covered by `val_u32_round_trips_through_synthetic_fixture`.
        let _ = context.val_u32(key_id);
    }

    /// Temporary GGUF v3 file with no tensors and a few scalar key/value pairs.
    /// The file is deleted when the value is dropped.
    struct SyntheticGgufFile {
        path: PathBuf,
    }

    /// Appends a GGUF string: a little-endian u64 byte length followed by the raw bytes.
    fn push_gguf_string(bytes: &mut Vec<u8>, value: &[u8]) {
        bytes.extend_from_slice(&(value.len() as u64).to_le_bytes());
        bytes.extend_from_slice(value);
    }

    /// Appends a key/value header: the key string followed by its little-endian u32 type tag.
    fn push_kv_header(bytes: &mut Vec<u8>, key: &[u8], value_type: u32) {
        push_gguf_string(bytes, key);
        bytes.extend_from_slice(&value_type.to_le_bytes());
    }

    impl SyntheticGgufFile {
        fn new(test_name: &str) -> Self {
            // The process id plus the per-test name keeps concurrently running tests from
            // sharing a file.
            let path = std::env::temp_dir().join(format!(
                "llama_cpp_bindings_synthetic_{}_{}.gguf",
                std::process::id(),
                test_name,
            ));

            let mut bytes: Vec<u8> = Vec::new();
            bytes.extend_from_slice(b"GGUF");
            bytes.extend_from_slice(&3u32.to_le_bytes()); // format version
            bytes.extend_from_slice(&0u64.to_le_bytes()); // tensor count
            bytes.extend_from_slice(&SYNTHETIC_KV_COUNT.to_le_bytes()); // key/value count

            push_kv_header(&mut bytes, b"general.architecture", GGUF_TYPE_STRING);
            push_gguf_string(&mut bytes, SYNTHETIC_ARCH.as_bytes());

            push_kv_header(&mut bytes, b"synthetic.i32_value", GGUF_TYPE_INT32);
            bytes.extend_from_slice(&SYNTHETIC_I32.to_le_bytes());

            push_kv_header(&mut bytes, b"synthetic.u64_value", GGUF_TYPE_UINT64);
            bytes.extend_from_slice(&SYNTHETIC_U64.to_le_bytes());

            push_kv_header(&mut bytes, b"synthetic.u32_value", GGUF_TYPE_UINT32);
            bytes.extend_from_slice(&SYNTHETIC_U32.to_le_bytes());

            std::fs::write(&path, bytes).unwrap();

            Self { path }
        }
    }

    impl Drop for SyntheticGgufFile {
        fn drop(&mut self) {
            std::fs::remove_file(&self.path).ok();
        }
    }

    #[test]
    fn synthetic_fixture_reports_expected_counts() {
        let fixture = SyntheticGgufFile::new("synthetic_counts");

        let context = GgufContext::from_file(&fixture.path).unwrap();

        assert_eq!(context.n_kv(), i64::try_from(SYNTHETIC_KV_COUNT).unwrap());
        assert_eq!(context.n_tensors(), 0);
    }

    #[test]
    fn key_at_and_val_str_round_trip_through_synthetic_fixture() {
        let fixture = SyntheticGgufFile::new("key_at_and_val_str_round_trip");

        let context = GgufContext::from_file(&fixture.path).unwrap();

        let index = context.find_key("general.architecture").unwrap();
        assert_eq!(context.key_at(index).unwrap(), "general.architecture");
        assert_eq!(context.kv_type(index), Some(GgufType::String));
        assert_eq!(context.val_str(index).unwrap(), SYNTHETIC_ARCH);
    }

    #[test]
    fn val_u32_round_trips_through_synthetic_fixture() {
        let fixture = SyntheticGgufFile::new("val_u32_round_trip");

        let context = GgufContext::from_file(&fixture.path).unwrap();

        let u32_index = context.find_key("synthetic.u32_value").unwrap();
        assert_eq!(context.kv_type(u32_index), Some(GgufType::Uint32));
        assert_eq!(context.val_u32(u32_index), SYNTHETIC_U32);
    }

    #[test]
    fn val_i32_and_val_u64_round_trip_through_synthetic_fixture() {
        let fixture = SyntheticGgufFile::new("val_i32_and_val_u64_round_trip");

        let context = GgufContext::from_file(&fixture.path).unwrap();

        let i32_index = context.find_key("synthetic.i32_value").unwrap();
        assert_eq!(context.kv_type(i32_index), Some(GgufType::Int32));
        assert_eq!(context.val_i32(i32_index), SYNTHETIC_I32);

        let u64_index = context.find_key("synthetic.u64_value").unwrap();
        assert_eq!(context.kv_type(u64_index), Some(GgufType::Uint64));
        assert_eq!(context.val_u64(u64_index), SYNTHETIC_U64);
    }
}