llama_cpp_bindings/
gguf_context.rs1use std::ffi::{CStr, CString};
6use std::path::Path;
7use std::ptr::NonNull;
8
9use crate::gguf_context_error::GgufContextError;
10use crate::gguf_type::GgufType;
11
12pub struct GgufContext {
17 context: NonNull<llama_cpp_bindings_sys::gguf_context>,
18}
19
20impl GgufContext {
21 pub fn from_file(path: impl AsRef<Path>) -> Result<Self, GgufContextError> {
29 let path_ref = path.as_ref();
30 let path_str = path_ref
31 .to_str()
32 .ok_or_else(|| GgufContextError::PathToStrError(path_ref.to_path_buf()))?;
33 let c_path = CString::new(path_str)?;
34
35 let init_params = llama_cpp_bindings_sys::gguf_init_params {
36 no_alloc: true,
37 ctx: std::ptr::null_mut(),
38 };
39
40 let raw =
41 unsafe { llama_cpp_bindings_sys::gguf_init_from_file(c_path.as_ptr(), init_params) };
42 let context = NonNull::new(raw)
43 .ok_or_else(|| GgufContextError::InitFailed(path_ref.to_path_buf()))?;
44
45 Ok(Self { context })
46 }
47
48 #[must_use]
50 pub fn n_kv(&self) -> i64 {
51 unsafe { llama_cpp_bindings_sys::gguf_get_n_kv(self.context.as_ptr()) }
52 }
53
54 pub fn find_key(&self, key: &str) -> Result<i64, GgufContextError> {
61 let c_key = CString::new(key)?;
62 let index =
63 unsafe { llama_cpp_bindings_sys::gguf_find_key(self.context.as_ptr(), c_key.as_ptr()) };
64
65 if index < 0 {
66 return Err(GgufContextError::KeyNotFound {
67 key: key.to_string(),
68 });
69 }
70
71 Ok(index)
72 }
73
74 pub fn key_at(&self, key_id: i64) -> Result<&str, GgufContextError> {
84 let c_str = unsafe {
85 CStr::from_ptr(llama_cpp_bindings_sys::gguf_get_key(
86 self.context.as_ptr(),
87 key_id,
88 ))
89 };
90
91 Ok(c_str.to_str()?)
92 }
93
94 #[must_use]
100 pub fn kv_type(&self, key_id: i64) -> Option<GgufType> {
101 let raw =
102 unsafe { llama_cpp_bindings_sys::gguf_get_kv_type(self.context.as_ptr(), key_id) };
103
104 GgufType::from_raw(raw)
105 }
106
107 #[must_use]
113 pub fn val_u32(&self, key_id: i64) -> u32 {
114 unsafe { llama_cpp_bindings_sys::gguf_get_val_u32(self.context.as_ptr(), key_id) }
115 }
116
117 #[must_use]
123 pub fn val_i32(&self, key_id: i64) -> i32 {
124 unsafe { llama_cpp_bindings_sys::gguf_get_val_i32(self.context.as_ptr(), key_id) }
125 }
126
127 #[must_use]
133 pub fn val_u64(&self, key_id: i64) -> u64 {
134 unsafe { llama_cpp_bindings_sys::gguf_get_val_u64(self.context.as_ptr(), key_id) }
135 }
136
137 pub fn val_str(&self, key_id: i64) -> Result<&str, GgufContextError> {
147 let c_str = unsafe {
148 CStr::from_ptr(llama_cpp_bindings_sys::gguf_get_val_str(
149 self.context.as_ptr(),
150 key_id,
151 ))
152 };
153
154 Ok(c_str.to_str()?)
155 }
156
157 #[must_use]
159 pub fn n_tensors(&self) -> i64 {
160 unsafe { llama_cpp_bindings_sys::gguf_get_n_tensors(self.context.as_ptr()) }
161 }
162}
163
164impl Drop for GgufContext {
165 fn drop(&mut self) {
166 unsafe { llama_cpp_bindings_sys::gguf_free(self.context.as_ptr()) }
167 }
168}
169
#[cfg(test)]
mod tests {
    use super::GgufContext;
    use crate::gguf_context_error::GgufContextError;
    use crate::gguf_type::GgufType;

    /// Raw GGUF value-type tags written into the synthetic fixture.
    const GGUF_TYPE_UINT32: u32 = 4;
    const GGUF_TYPE_INT32: u32 = 5;
    const GGUF_TYPE_STRING: u32 = 8;
    const GGUF_TYPE_UINT64: u32 = 10;

    /// Number of key-value pairs written by `SyntheticGgufFile::new`; must match its body.
    const SYNTHETIC_KV_COUNT: u64 = 4;

    /// Path to the checked-in GGUF fixture under `fixtures/`.
    fn fixture_path() -> std::path::PathBuf {
        std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("fixtures")
            .join("ggml-vocab-bert-bge.gguf")
    }

    #[test]
    fn from_file_opens_valid_gguf() {
        let error = GgufContext::from_file(fixture_path()).err();

        assert!(error.is_none(), "expected fixture to open, got {:?}", error);
    }

    #[test]
    fn from_file_nonexistent_returns_init_failed() {
        let result = GgufContext::from_file("/nonexistent/file.gguf");

        assert!(matches!(result, Err(GgufContextError::InitFailed(_))));
    }

    #[test]
    fn n_kv_returns_positive_count() {
        let context = GgufContext::from_file(fixture_path()).unwrap();

        assert!(context.n_kv() > 0);
    }

    #[test]
    fn n_tensors_returns_count() {
        let context = GgufContext::from_file(fixture_path()).unwrap();

        assert!(context.n_tensors() >= 0);
    }

    #[test]
    fn find_key_returns_valid_index_for_known_key() {
        let context = GgufContext::from_file(fixture_path()).unwrap();
        let index = context
            .find_key("general.architecture")
            .expect("fixture must contain general.architecture");

        assert!(index >= 0);
    }

    #[test]
    fn find_key_returns_error_for_missing_key() {
        let context = GgufContext::from_file(fixture_path()).unwrap();
        let result = context.find_key("nonexistent.key");

        assert!(matches!(result, Err(GgufContextError::KeyNotFound { .. })));
    }

    #[test]
    fn key_at_returns_expected_name() {
        let context = GgufContext::from_file(fixture_path()).unwrap();
        let index = context.find_key("general.architecture").unwrap();
        let key_name = context.key_at(index).unwrap();

        assert_eq!(key_name, "general.architecture");
    }

    #[test]
    fn kv_type_returns_expected_type_for_string_key() {
        let context = GgufContext::from_file(fixture_path()).unwrap();
        let index = context.find_key("general.architecture").unwrap();
        let value_type = context.kv_type(index);

        assert_eq!(value_type, Some(GgufType::String));
    }

    #[test]
    fn val_str_returns_architecture_value() {
        let context = GgufContext::from_file(fixture_path()).unwrap();
        let index = context.find_key("general.architecture").unwrap();
        let value = context.val_str(index).unwrap();

        assert!(!value.is_empty());
    }

    #[cfg(unix)]
    #[test]
    fn from_file_non_utf8_path_returns_error() {
        use std::ffi::OsStr;
        use std::os::unix::ffi::OsStrExt;

        let non_utf8_path = std::path::Path::new(OsStr::from_bytes(b"/tmp/\xff\xfe.gguf"));
        let result = GgufContext::from_file(non_utf8_path);

        assert!(matches!(result, Err(GgufContextError::PathToStrError(_))));
    }

    #[test]
    fn from_file_with_null_byte_in_path_returns_error() {
        let result = GgufContext::from_file("/tmp/foo\0bar.gguf");

        assert!(matches!(result, Err(GgufContextError::NulError(_))));
    }

    #[test]
    fn find_key_with_null_byte_in_key_returns_error() {
        let context = GgufContext::from_file(fixture_path()).unwrap();
        let result = context.find_key("foo\0bar");

        assert!(matches!(result, Err(GgufContextError::NulError(_))));
    }

    #[test]
    fn val_u32_returns_value_for_uint32_key() {
        let context = GgufContext::from_file(fixture_path()).unwrap();

        let key_id = (0..context.n_kv())
            .find(|&id| context.kv_type(id) == Some(GgufType::Uint32))
            .expect("fixture must contain at least one uint32 key");

        // Smoke test against the real fixture; the exact value is pinned by the synthetic
        // round-trip test below.
        let _ = context.val_u32(key_id);
    }

    /// Appends a GGUF string: a little-endian `u64` byte length followed by the raw bytes.
    fn push_gguf_string(bytes: &mut Vec<u8>, value: &[u8]) {
        bytes.extend_from_slice(&(value.len() as u64).to_le_bytes());
        bytes.extend_from_slice(value);
    }

    /// Appends the key name and little-endian `u32` type tag that precede every KV value.
    fn push_kv_header(bytes: &mut Vec<u8>, key: &[u8], value_type: u32) {
        push_gguf_string(bytes, key);
        bytes.extend_from_slice(&value_type.to_le_bytes());
    }

    /// Temporary, hand-built GGUF v3 file with known metadata and no tensors.
    ///
    /// The file is removed (best effort) when this value is dropped.
    struct SyntheticGgufFile {
        path: std::path::PathBuf,
    }

    impl SyntheticGgufFile {
        /// Writes the synthetic file to the temp dir; `test_name` keeps paths unique per test.
        fn new(test_name: &str) -> Self {
            let path = std::env::temp_dir().join(format!(
                "llama_cpp_bindings_synthetic_{}_{}.gguf",
                std::process::id(),
                test_name,
            ));

            let mut bytes: Vec<u8> = Vec::new();
            // Header: magic, format version 3, tensor count, key-value count.
            bytes.extend_from_slice(b"GGUF");
            bytes.extend_from_slice(&3u32.to_le_bytes());
            bytes.extend_from_slice(&0u64.to_le_bytes());
            bytes.extend_from_slice(&SYNTHETIC_KV_COUNT.to_le_bytes());

            push_kv_header(&mut bytes, b"general.architecture", GGUF_TYPE_STRING);
            push_gguf_string(&mut bytes, b"synthetic");

            push_kv_header(&mut bytes, b"synthetic.i32_value", GGUF_TYPE_INT32);
            bytes.extend_from_slice(&(-12345i32).to_le_bytes());

            push_kv_header(&mut bytes, b"synthetic.u64_value", GGUF_TYPE_UINT64);
            bytes.extend_from_slice(&987_654_321u64.to_le_bytes());

            // Larger than `i32::MAX`, so a signed/unsigned mix-up would be caught.
            push_kv_header(&mut bytes, b"synthetic.u32_value", GGUF_TYPE_UINT32);
            bytes.extend_from_slice(&4_000_000_000u32.to_le_bytes());

            std::fs::write(&path, &bytes).unwrap();

            Self { path }
        }
    }

    impl Drop for SyntheticGgufFile {
        fn drop(&mut self) {
            std::fs::remove_file(&self.path).ok();
        }
    }

    #[test]
    fn val_i32_and_val_u64_round_trip_through_synthetic_fixture() {
        let fixture = SyntheticGgufFile::new("val_i32_and_val_u64_round_trip");

        let context = GgufContext::from_file(&fixture.path).unwrap();

        let i32_index = context.find_key("synthetic.i32_value").unwrap();
        assert_eq!(context.kv_type(i32_index), Some(GgufType::Int32));
        assert_eq!(context.val_i32(i32_index), -12345);

        let u64_index = context.find_key("synthetic.u64_value").unwrap();
        assert_eq!(context.kv_type(u64_index), Some(GgufType::Uint64));
        assert_eq!(context.val_u64(u64_index), 987_654_321);
    }

    #[test]
    fn val_u32_and_metadata_round_trip_through_synthetic_fixture() {
        let fixture = SyntheticGgufFile::new("val_u32_and_metadata_round_trip");

        let context = GgufContext::from_file(&fixture.path).unwrap();

        assert_eq!(context.n_kv(), SYNTHETIC_KV_COUNT as i64);
        assert_eq!(context.n_tensors(), 0);

        let arch_index = context.find_key("general.architecture").unwrap();
        assert_eq!(context.kv_type(arch_index), Some(GgufType::String));
        assert_eq!(context.val_str(arch_index).unwrap(), "synthetic");

        let u32_index = context.find_key("synthetic.u32_value").unwrap();
        assert_eq!(context.key_at(u32_index).unwrap(), "synthetic.u32_value");
        assert_eq!(context.kv_type(u32_index), Some(GgufType::Uint32));
        assert_eq!(context.val_u32(u32_index), 4_000_000_000);
    }
}