Skip to main content

llama_cpp_bindings/
gguf_context.rs

1use std::ffi::{CStr, CString};
2use std::path::Path;
3use std::ptr::NonNull;
4
5use crate::gguf_context_error::GgufContextError;
6use crate::gguf_type::GgufType;
7
8#[derive(Debug)]
9pub struct GgufContext {
10    context: NonNull<llama_cpp_bindings_sys::gguf_context>,
11}
12
13impl GgufContext {
14    /// # Errors
15    ///
16    /// Returns [`GgufContextError::InitFailed`] if the file cannot be opened or parsed.
17    /// Returns [`GgufContextError::PathToStrError`] if the path is not valid UTF-8.
18    /// Returns [`GgufContextError::NulError`] if the path contains a null byte.
19    pub fn from_file(path: impl AsRef<Path>) -> Result<Self, GgufContextError> {
20        let path_ref = path.as_ref();
21        let path_str = path_ref
22            .to_str()
23            .ok_or_else(|| GgufContextError::PathToStrError(path_ref.to_path_buf()))?;
24        let c_path = CString::new(path_str)?;
25
26        let init_params = llama_cpp_bindings_sys::gguf_init_params {
27            no_alloc: true,
28            ctx: std::ptr::null_mut(),
29        };
30
31        let raw =
32            unsafe { llama_cpp_bindings_sys::gguf_init_from_file(c_path.as_ptr(), init_params) };
33        let context = NonNull::new(raw)
34            .ok_or_else(|| GgufContextError::InitFailed(path_ref.to_path_buf()))?;
35
36        Ok(Self { context })
37    }
38
39    #[must_use]
40    pub fn n_kv(&self) -> i64 {
41        unsafe { llama_cpp_bindings_sys::gguf_get_n_kv(self.context.as_ptr()) }
42    }
43
44    /// # Errors
45    ///
46    /// Returns [`GgufContextError::KeyNotFound`] if the key does not exist.
47    /// Returns [`GgufContextError::NulError`] if the key contains a null byte.
48    pub fn find_key(&self, key: &str) -> Result<i64, GgufContextError> {
49        let c_key = CString::new(key)?;
50        let index =
51            unsafe { llama_cpp_bindings_sys::gguf_find_key(self.context.as_ptr(), c_key.as_ptr()) };
52
53        if index < 0 {
54            return Err(GgufContextError::KeyNotFound {
55                key: key.to_string(),
56            });
57        }
58
59        Ok(index)
60    }
61
62    /// # Safety considerations
63    ///
64    /// The caller must ensure `key_id` is in range `[0, n_kv())`.
65    ///
66    /// # Errors
67    ///
68    /// Returns [`GgufContextError::Utf8Error`] if the key name is not valid UTF-8.
69    pub fn key_at(&self, key_id: i64) -> Result<&str, GgufContextError> {
70        let c_str = unsafe {
71            CStr::from_ptr(llama_cpp_bindings_sys::gguf_get_key(
72                self.context.as_ptr(),
73                key_id,
74            ))
75        };
76
77        Ok(c_str.to_str()?)
78    }
79
80    /// # Safety considerations
81    ///
82    /// The caller must ensure `key_id` is in range `[0, n_kv())`.
83    #[must_use]
84    pub fn kv_type(&self, key_id: i64) -> Option<GgufType> {
85        let raw =
86            unsafe { llama_cpp_bindings_sys::gguf_get_kv_type(self.context.as_ptr(), key_id) };
87
88        GgufType::from_raw(raw)
89    }
90
91    /// # Safety considerations
92    ///
93    /// The caller must ensure the key at `key_id` has type [`GgufType::Uint32`].
94    #[must_use]
95    pub fn val_u32(&self, key_id: i64) -> u32 {
96        unsafe { llama_cpp_bindings_sys::gguf_get_val_u32(self.context.as_ptr(), key_id) }
97    }
98
99    /// # Safety considerations
100    ///
101    /// The caller must ensure the key at `key_id` has type [`GgufType::Int32`].
102    #[must_use]
103    pub fn val_i32(&self, key_id: i64) -> i32 {
104        unsafe { llama_cpp_bindings_sys::gguf_get_val_i32(self.context.as_ptr(), key_id) }
105    }
106
107    /// # Safety considerations
108    ///
109    /// The caller must ensure the key at `key_id` has type [`GgufType::Uint64`].
110    #[must_use]
111    pub fn val_u64(&self, key_id: i64) -> u64 {
112        unsafe { llama_cpp_bindings_sys::gguf_get_val_u64(self.context.as_ptr(), key_id) }
113    }
114
115    /// # Safety considerations
116    ///
117    /// The caller must ensure the key at `key_id` has type [`GgufType::String`].
118    ///
119    /// # Errors
120    ///
121    /// Returns [`GgufContextError::Utf8Error`] if the string value is not valid UTF-8.
122    pub fn val_str(&self, key_id: i64) -> Result<&str, GgufContextError> {
123        let c_str = unsafe {
124            CStr::from_ptr(llama_cpp_bindings_sys::gguf_get_val_str(
125                self.context.as_ptr(),
126                key_id,
127            ))
128        };
129
130        Ok(c_str.to_str()?)
131    }
132
133    #[must_use]
134    pub fn n_tensors(&self) -> i64 {
135        unsafe { llama_cpp_bindings_sys::gguf_get_n_tensors(self.context.as_ptr()) }
136    }
137}
138
139impl Drop for GgufContext {
140    fn drop(&mut self) {
141        unsafe { llama_cpp_bindings_sys::gguf_free(self.context.as_ptr()) }
142    }
143}
144
#[cfg(test)]
mod tests {
    use std::path::PathBuf;

    use super::GgufContext;
    use crate::gguf_context_error::GgufContextError;
    use crate::gguf_type::GgufType;

    /// Location of the GGUF fixture shipped under `fixtures/` in the crate root.
    fn fixture_path() -> PathBuf {
        let mut location = PathBuf::from(env!("CARGO_MANIFEST_DIR"));
        location.push("fixtures");
        location.push("ggml-vocab-bert-bge.gguf");
        location
    }

    /// Opens the bundled fixture, panicking when it cannot be loaded.
    fn open_fixture() -> GgufContext {
        GgufContext::from_file(fixture_path()).unwrap()
    }

    #[test]
    fn from_file_opens_valid_gguf() {
        assert!(GgufContext::from_file(fixture_path()).is_ok());
    }

    #[test]
    fn from_file_nonexistent_returns_init_failed() {
        let err = GgufContext::from_file("/nonexistent/file.gguf").unwrap_err();

        assert!(
            matches!(err, GgufContextError::InitFailed(..)),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn n_kv_returns_positive_count() {
        let entry_count = open_fixture().n_kv();

        assert!(entry_count > 0);
    }

    #[test]
    fn n_tensors_returns_count() {
        let tensor_count = open_fixture().n_tensors();

        assert!(tensor_count >= 0);
    }

    #[test]
    fn find_key_returns_valid_index_for_known_key() {
        let gguf = open_fixture();

        // `unwrap` fails the test on `Err`, covering the "is ok" half of the check.
        let position = gguf.find_key("general.architecture").unwrap();

        assert!(position >= 0);
    }

    #[test]
    fn find_key_returns_error_for_missing_key() {
        let gguf = open_fixture();
        let err = gguf.find_key("nonexistent.key").unwrap_err();

        assert!(
            matches!(err, GgufContextError::KeyNotFound { .. }),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn key_at_returns_expected_name() {
        let gguf = open_fixture();
        let position = gguf.find_key("general.architecture").unwrap();

        assert_eq!(gguf.key_at(position).unwrap(), "general.architecture");
    }

    #[test]
    fn kv_type_returns_expected_type_for_string_key() {
        let gguf = open_fixture();
        let position = gguf.find_key("general.architecture").unwrap();

        assert_eq!(gguf.kv_type(position), Some(GgufType::String));
    }

    #[test]
    fn val_str_returns_architecture_value() {
        let gguf = open_fixture();
        let position = gguf.find_key("general.architecture").unwrap();
        let architecture = gguf.val_str(position).unwrap();

        assert!(!architecture.is_empty());
    }

    #[cfg(unix)]
    #[test]
    fn from_file_non_utf8_path_returns_error() {
        use std::ffi::OsStr;
        use std::os::unix::ffi::OsStrExt;

        // 0xFF and 0xFE never occur in well-formed UTF-8, so `Path::to_str` must fail.
        let raw_name = OsStr::from_bytes(b"/tmp/\xff\xfe.gguf");
        let err = GgufContext::from_file(std::path::Path::new(raw_name)).unwrap_err();

        assert!(
            matches!(err, GgufContextError::PathToStrError(..)),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn from_file_with_null_byte_in_path_returns_error() {
        let err = GgufContext::from_file("/tmp/foo\0bar.gguf").unwrap_err();

        assert!(
            matches!(err, GgufContextError::NulError(..)),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn find_key_with_null_byte_in_key_returns_error() {
        let gguf = open_fixture();
        let err = gguf.find_key("foo\0bar").unwrap_err();

        assert!(
            matches!(err, GgufContextError::NulError(..)),
            "unexpected error: {err:?}"
        );
    }

    #[test]
    fn val_u32_returns_value_for_uint32_key() {
        let gguf = open_fixture();

        let uint32_position = (0..gguf.n_kv())
            .find(|candidate| gguf.kv_type(*candidate) == Some(GgufType::Uint32))
            .expect("fixture must contain at least one uint32 key");

        // Only checks that the call completes; the fixture's actual value is not pinned.
        let _ = gguf.val_u32(uint32_position);
    }

    /// Appends `payload` preceded by its length as a little-endian `u64`.
    fn push_length_prefixed(buffer: &mut Vec<u8>, payload: &[u8]) {
        buffer.extend_from_slice(&(payload.len() as u64).to_le_bytes());
        buffer.extend_from_slice(payload);
    }

    /// Appends a length-prefixed key name followed by its little-endian `u32` type tag.
    fn push_entry_header(buffer: &mut Vec<u8>, key: &[u8], type_tag: u32) {
        push_length_prefixed(buffer, key);
        buffer.extend_from_slice(&type_tag.to_le_bytes());
    }

    /// Hand-built GGUF file in the system temp directory, removed again on drop.
    struct SyntheticGgufFile {
        path: PathBuf,
    }

    impl SyntheticGgufFile {
        /// Writes a small GGUF file holding one string, one `i32` and one `u64` entry.
        ///
        /// The file name embeds the process id and `test_name`.
        fn new(test_name: &str) -> Self {
            let file_name = format!(
                "llama_cpp_bindings_synthetic_{}_{}.gguf",
                std::process::id(),
                test_name,
            );
            let path = std::env::temp_dir().join(file_name);

            let mut bytes: Vec<u8> = Vec::new();

            // Header: magic, then (per the GGUF layout) version 3, zero tensors and three
            // key/value entries, all little-endian.
            bytes.extend_from_slice(b"GGUF");
            bytes.extend_from_slice(&3u32.to_le_bytes());
            bytes.extend_from_slice(&0u64.to_le_bytes());
            bytes.extend_from_slice(&3u64.to_le_bytes());

            // Type tag 8: a length-prefixed string value.
            push_entry_header(&mut bytes, b"general.architecture", 8);
            push_length_prefixed(&mut bytes, b"synthetic");

            // Type tag 5: read back as `GgufType::Int32` by the round-trip test.
            push_entry_header(&mut bytes, b"synthetic.i32_value", 5);
            bytes.extend_from_slice(&(-12345i32).to_le_bytes());

            // Type tag 10: read back as `GgufType::Uint64` by the round-trip test.
            push_entry_header(&mut bytes, b"synthetic.u64_value", 10);
            bytes.extend_from_slice(&987_654_321u64.to_le_bytes());

            std::fs::write(&path, &bytes).unwrap();

            Self { path }
        }
    }

    impl Drop for SyntheticGgufFile {
        fn drop(&mut self) {
            // Best-effort cleanup: a failed removal must not fail the test.
            let _ = std::fs::remove_file(&self.path);
        }
    }

    #[test]
    fn val_i32_and_val_u64_round_trip_through_synthetic_fixture() {
        let synthetic = SyntheticGgufFile::new("val_i32_and_val_u64_round_trip");
        let gguf = GgufContext::from_file(&synthetic.path).unwrap();

        let signed_position = gguf.find_key("synthetic.i32_value").unwrap();
        assert_eq!(gguf.kv_type(signed_position), Some(GgufType::Int32));
        assert_eq!(gguf.val_i32(signed_position), -12345);

        let unsigned_position = gguf.find_key("synthetic.u64_value").unwrap();
        assert_eq!(gguf.kv_type(unsigned_position), Some(GgufType::Uint64));
        assert_eq!(gguf.val_u64(unsigned_position), 987_654_321);
    }
}