use std::ffi::{CStr, CString};
use std::path::Path;
use std::ptr::NonNull;
use crate::gguf_context_error::GgufContextError;
use crate::gguf_type::GgufType;
/// Owned handle to the parsed metadata of a GGUF file (a ggml `gguf_context`).
///
/// The underlying context is created by [`GgufContext::from_file`] and released when this
/// value is dropped. Because it stores a [`NonNull`] pointer, the type is neither `Send`
/// nor `Sync`.
#[derive(Debug)]
pub struct GgufContext {
    /// Non-null pointer returned by `gguf_init_from_file`; owned exclusively by this struct.
    context: NonNull<llama_cpp_bindings_sys::gguf_context>,
}
impl GgufContext {
    /// Opens the GGUF file at `path` and parses its metadata.
    ///
    /// The context is created with `no_alloc: true` and a null `ctx` out-pointer, so no
    /// ggml context is requested for tensor data.
    ///
    /// # Errors
    ///
    /// - [`GgufContextError::PathToStrError`] if `path` is not valid UTF-8.
    /// - [`GgufContextError::NulError`] if `path` contains an interior NUL byte.
    /// - [`GgufContextError::InitFailed`] if `gguf_init_from_file` returns null, e.g. because
    ///   the file does not exist.
    pub fn from_file(path: impl AsRef<Path>) -> Result<Self, GgufContextError> {
        let path_ref = path.as_ref();
        let path_str = path_ref
            .to_str()
            .ok_or_else(|| GgufContextError::PathToStrError(path_ref.to_path_buf()))?;
        let c_path = CString::new(path_str)?;
        let init_params = llama_cpp_bindings_sys::gguf_init_params {
            no_alloc: true,
            ctx: std::ptr::null_mut(),
        };
        // SAFETY: `c_path` is a NUL-terminated string that outlives the call, and
        // `init_params` is passed by value with a null `ctx` out-pointer.
        let raw =
            unsafe { llama_cpp_bindings_sys::gguf_init_from_file(c_path.as_ptr(), init_params) };
        let context = NonNull::new(raw)
            .ok_or_else(|| GgufContextError::InitFailed(path_ref.to_path_buf()))?;
        Ok(Self { context })
    }

    /// Returns the number of key/value pairs in the file's metadata.
    ///
    /// Valid key ids are `0..self.n_kv()`.
    #[must_use]
    pub fn n_kv(&self) -> i64 {
        // SAFETY: `self.context` is a live context from `gguf_init_from_file`; it is only
        // freed in `Drop`.
        unsafe { llama_cpp_bindings_sys::gguf_get_n_kv(self.context.as_ptr()) }
    }

    /// Looks up `key` and returns its (non-negative) key id.
    ///
    /// The id can be passed to [`Self::key_at`], [`Self::kv_type`] and the `val_*` getters.
    ///
    /// # Errors
    ///
    /// - [`GgufContextError::NulError`] if `key` contains an interior NUL byte.
    /// - [`GgufContextError::KeyNotFound`] if `gguf_find_key` reports no match.
    pub fn find_key(&self, key: &str) -> Result<i64, GgufContextError> {
        let c_key = CString::new(key)?;
        // SAFETY: the context is live and `c_key` is a NUL-terminated string that outlives
        // the call.
        let index =
            unsafe { llama_cpp_bindings_sys::gguf_find_key(self.context.as_ptr(), c_key.as_ptr()) };
        if index < 0 {
            return Err(GgufContextError::KeyNotFound {
                key: key.to_string(),
            });
        }
        Ok(index)
    }

    /// Returns the name of the key stored at `key_id`.
    ///
    /// # Errors
    ///
    /// Returns an error if the key name is not valid UTF-8.
    ///
    /// # Panics
    ///
    /// Panics if `key_id` is not in `0..self.n_kv()`.
    pub fn key_at(&self, key_id: i64) -> Result<&str, GgufContextError> {
        // The C getter does not handle bad indices gracefully (it asserts/aborts or reads
        // out of bounds), so reject them on the Rust side first.
        self.assert_key_id_in_range(key_id);
        // SAFETY: `key_id` was bounds-checked above; the returned string is owned by the
        // context, which stays alive for the lifetime of `&self`.
        let c_str = unsafe {
            CStr::from_ptr(llama_cpp_bindings_sys::gguf_get_key(
                self.context.as_ptr(),
                key_id,
            ))
        };
        Ok(c_str.to_str()?)
    }

    /// Returns the value type of the entry at `key_id`.
    ///
    /// Returns `None` if `key_id` is not in `0..self.n_kv()`, or if
    /// [`GgufType::from_raw`] does not recognise the raw type tag.
    #[must_use]
    pub fn kv_type(&self, key_id: i64) -> Option<GgufType> {
        if !self.contains_key_id(key_id) {
            return None;
        }
        // SAFETY: the context is live and `key_id` was bounds-checked above.
        let raw =
            unsafe { llama_cpp_bindings_sys::gguf_get_kv_type(self.context.as_ptr(), key_id) };
        GgufType::from_raw(raw)
    }

    /// Returns the `u32` value stored at `key_id`.
    ///
    /// # Panics
    ///
    /// Panics if `key_id` is out of range or the entry's type is not [`GgufType::Uint32`].
    #[must_use]
    pub fn val_u32(&self, key_id: i64) -> u32 {
        self.assert_kv_type(key_id, GgufType::Uint32);
        // SAFETY: the context is live, and `key_id` is in range with type UINT32.
        unsafe { llama_cpp_bindings_sys::gguf_get_val_u32(self.context.as_ptr(), key_id) }
    }

    /// Returns the `i32` value stored at `key_id`.
    ///
    /// # Panics
    ///
    /// Panics if `key_id` is out of range or the entry's type is not [`GgufType::Int32`].
    #[must_use]
    pub fn val_i32(&self, key_id: i64) -> i32 {
        self.assert_kv_type(key_id, GgufType::Int32);
        // SAFETY: the context is live, and `key_id` is in range with type INT32.
        unsafe { llama_cpp_bindings_sys::gguf_get_val_i32(self.context.as_ptr(), key_id) }
    }

    /// Returns the `u64` value stored at `key_id`.
    ///
    /// # Panics
    ///
    /// Panics if `key_id` is out of range or the entry's type is not [`GgufType::Uint64`].
    #[must_use]
    pub fn val_u64(&self, key_id: i64) -> u64 {
        self.assert_kv_type(key_id, GgufType::Uint64);
        // SAFETY: the context is live, and `key_id` is in range with type UINT64.
        unsafe { llama_cpp_bindings_sys::gguf_get_val_u64(self.context.as_ptr(), key_id) }
    }

    /// Returns the string value stored at `key_id`.
    ///
    /// # Errors
    ///
    /// Returns an error if the stored string is not valid UTF-8.
    ///
    /// # Panics
    ///
    /// Panics if `key_id` is out of range or the entry's type is not [`GgufType::String`].
    pub fn val_str(&self, key_id: i64) -> Result<&str, GgufContextError> {
        self.assert_kv_type(key_id, GgufType::String);
        // SAFETY: `key_id` is in range with type STRING; the returned string is owned by
        // the context, which stays alive for the lifetime of `&self`.
        let c_str = unsafe {
            CStr::from_ptr(llama_cpp_bindings_sys::gguf_get_val_str(
                self.context.as_ptr(),
                key_id,
            ))
        };
        Ok(c_str.to_str()?)
    }

    /// Returns the number of tensor infos described by the file.
    #[must_use]
    pub fn n_tensors(&self) -> i64 {
        // SAFETY: `self.context` is a live context; it is only freed in `Drop`.
        unsafe { llama_cpp_bindings_sys::gguf_get_n_tensors(self.context.as_ptr()) }
    }

    /// Returns `true` if `key_id` indexes an existing key/value pair.
    fn contains_key_id(&self, key_id: i64) -> bool {
        (0..self.n_kv()).contains(&key_id)
    }

    /// Panics with a descriptive message unless `key_id` is in `0..self.n_kv()`.
    fn assert_key_id_in_range(&self, key_id: i64) {
        let n_kv = self.n_kv();
        assert!(
            (0..n_kv).contains(&key_id),
            "GGUF key id {key_id} is out of range (n_kv = {n_kv})"
        );
    }

    /// Panics unless the entry at `key_id` exists and has type `expected`.
    fn assert_kv_type(&self, key_id: i64, expected: GgufType) {
        self.assert_key_id_in_range(key_id);
        let actual = self.kv_type(key_id);
        assert!(
            actual.as_ref() == Some(&expected),
            "GGUF key id {key_id} has type {actual:?}, expected {expected:?}"
        );
    }
}
impl Drop for GgufContext {
    // Hands the context back to ggml so the parsed metadata is released.
    fn drop(&mut self) {
        let raw_context = self.context.as_ptr();
        // SAFETY: `raw_context` came from `gguf_init_from_file` and is freed exactly once,
        // here; borrows handed out through `&self` cannot outlive `self`.
        unsafe {
            llama_cpp_bindings_sys::gguf_free(raw_context);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::GgufContext;
    use crate::gguf_context_error::GgufContextError;
    use crate::gguf_type::GgufType;

    // On-disk GGUF value type tags used by the synthetic fixture
    // (values of ggml's `enum gguf_type`).
    const GGUF_TYPE_UINT32: u32 = 4;
    const GGUF_TYPE_INT32: u32 = 5;
    const GGUF_TYPE_STRING: u32 = 8;
    const GGUF_TYPE_UINT64: u32 = 10;

    fn fixture_path() -> std::path::PathBuf {
        std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR"))
            .join("fixtures")
            .join("ggml-vocab-bert-bge.gguf")
    }

    #[test]
    fn from_file_opens_valid_gguf() {
        if let Err(error) = GgufContext::from_file(fixture_path()) {
            panic!("failed to open fixture: {error:?}");
        }
    }

    #[test]
    fn from_file_nonexistent_returns_init_failed() {
        let result = GgufContext::from_file("/nonexistent/file.gguf");
        assert!(matches!(result, Err(GgufContextError::InitFailed(_))));
    }

    #[test]
    fn n_kv_returns_positive_count() {
        let context = GgufContext::from_file(fixture_path()).unwrap();
        assert!(context.n_kv() > 0);
    }

    #[test]
    fn n_tensors_returns_count() {
        let context = GgufContext::from_file(fixture_path()).unwrap();
        assert!(context.n_tensors() >= 0);
    }

    #[test]
    fn find_key_returns_valid_index_for_known_key() {
        let context = GgufContext::from_file(fixture_path()).unwrap();
        let index = context.find_key("general.architecture");
        assert!(index.is_ok());
        assert!(index.unwrap() >= 0);
    }

    #[test]
    fn find_key_returns_error_for_missing_key() {
        let context = GgufContext::from_file(fixture_path()).unwrap();
        let result = context.find_key("nonexistent.key");
        assert!(matches!(result, Err(GgufContextError::KeyNotFound { .. })));
    }

    #[test]
    fn key_at_returns_expected_name() {
        let context = GgufContext::from_file(fixture_path()).unwrap();
        let index = context.find_key("general.architecture").unwrap();
        let key_name = context.key_at(index).unwrap();
        assert_eq!(key_name, "general.architecture");
    }

    #[test]
    fn kv_type_returns_expected_type_for_string_key() {
        let context = GgufContext::from_file(fixture_path()).unwrap();
        let index = context.find_key("general.architecture").unwrap();
        let value_type = context.kv_type(index);
        assert_eq!(value_type, Some(GgufType::String));
    }

    #[test]
    fn val_str_returns_architecture_value() {
        let context = GgufContext::from_file(fixture_path()).unwrap();
        let index = context.find_key("general.architecture").unwrap();
        let value = context.val_str(index).unwrap();
        assert!(!value.is_empty());
    }

    #[cfg(unix)]
    #[test]
    fn from_file_non_utf8_path_returns_error() {
        use std::ffi::OsStr;
        use std::os::unix::ffi::OsStrExt;
        let non_utf8_path = std::path::Path::new(OsStr::from_bytes(b"/tmp/\xff\xfe.gguf"));
        let result = GgufContext::from_file(non_utf8_path);
        assert!(matches!(result, Err(GgufContextError::PathToStrError(_))));
    }

    #[test]
    fn from_file_with_null_byte_in_path_returns_error() {
        let result = GgufContext::from_file("/tmp/foo\0bar.gguf");
        assert!(matches!(result, Err(GgufContextError::NulError(_))));
    }

    #[test]
    fn find_key_with_null_byte_in_key_returns_error() {
        let context = GgufContext::from_file(fixture_path()).unwrap();
        let result = context.find_key("foo\0bar");
        assert!(matches!(result, Err(GgufContextError::NulError(_))));
    }

    #[test]
    fn val_u32_returns_value_for_uint32_key() {
        let context = GgufContext::from_file(fixture_path()).unwrap();
        let key_id = (0..context.n_kv())
            .find(|&id| context.kv_type(id) == Some(GgufType::Uint32))
            .expect("fixture must contain at least one uint32 key");
        let _ = context.val_u32(key_id);
    }

    /// Appends a GGUF string: a little-endian `u64` byte length followed by the raw bytes.
    fn push_gguf_string(bytes: &mut Vec<u8>, value: &[u8]) {
        bytes.extend_from_slice(&(value.len() as u64).to_le_bytes());
        bytes.extend_from_slice(value);
    }

    /// Appends the start of a key/value pair: the key string and its `u32` type tag.
    fn push_kv_header(bytes: &mut Vec<u8>, key: &[u8], value_type: u32) {
        push_gguf_string(bytes, key);
        bytes.extend_from_slice(&value_type.to_le_bytes());
    }

    /// A minimal GGUF v3 file (no tensors, four metadata entries) written to the temp dir
    /// and removed on drop.
    struct SyntheticGgufFile {
        path: std::path::PathBuf,
    }

    impl SyntheticGgufFile {
        fn new(test_name: &str) -> Self {
            // The pid plus test name keeps parallel tests (and concurrent runs) apart.
            let path = std::env::temp_dir().join(format!(
                "llama_cpp_bindings_synthetic_{}_{}.gguf",
                std::process::id(),
                test_name,
            ));

            let mut bytes: Vec<u8> = Vec::new();
            // Header: magic, format version 3, tensor count, key/value count.
            bytes.extend_from_slice(b"GGUF");
            bytes.extend_from_slice(&3u32.to_le_bytes());
            bytes.extend_from_slice(&0u64.to_le_bytes());
            bytes.extend_from_slice(&4u64.to_le_bytes());

            push_kv_header(&mut bytes, b"general.architecture", GGUF_TYPE_STRING);
            push_gguf_string(&mut bytes, b"synthetic");

            push_kv_header(&mut bytes, b"synthetic.i32_value", GGUF_TYPE_INT32);
            bytes.extend_from_slice(&(-12345i32).to_le_bytes());

            push_kv_header(&mut bytes, b"synthetic.u64_value", GGUF_TYPE_UINT64);
            bytes.extend_from_slice(&987_654_321u64.to_le_bytes());

            // Above i32::MAX so a signed/unsigned mix-up would be caught.
            push_kv_header(&mut bytes, b"synthetic.u32_value", GGUF_TYPE_UINT32);
            bytes.extend_from_slice(&3_000_000_000u32.to_le_bytes());

            std::fs::write(&path, &bytes).unwrap();
            Self { path }
        }
    }

    impl Drop for SyntheticGgufFile {
        fn drop(&mut self) {
            // Best-effort cleanup; a leftover temp file must not fail the test.
            std::fs::remove_file(&self.path).ok();
        }
    }

    #[test]
    fn val_i32_and_val_u64_round_trip_through_synthetic_fixture() {
        let fixture = SyntheticGgufFile::new("val_i32_and_val_u64_round_trip");
        let context = GgufContext::from_file(&fixture.path).unwrap();
        let i32_index = context.find_key("synthetic.i32_value").unwrap();
        assert_eq!(context.kv_type(i32_index), Some(GgufType::Int32));
        assert_eq!(context.val_i32(i32_index), -12345);
        let u64_index = context.find_key("synthetic.u64_value").unwrap();
        assert_eq!(context.kv_type(u64_index), Some(GgufType::Uint64));
        assert_eq!(context.val_u64(u64_index), 987_654_321);
    }

    #[test]
    fn val_u32_round_trips_through_synthetic_fixture() {
        let fixture = SyntheticGgufFile::new("val_u32_round_trip");
        let context = GgufContext::from_file(&fixture.path).unwrap();
        let u32_index = context.find_key("synthetic.u32_value").unwrap();
        assert_eq!(context.kv_type(u32_index), Some(GgufType::Uint32));
        assert_eq!(context.val_u32(u32_index), 3_000_000_000);
    }

    #[test]
    fn synthetic_fixture_metadata_is_read_back() {
        let fixture = SyntheticGgufFile::new("metadata_read_back");
        let context = GgufContext::from_file(&fixture.path).unwrap();
        assert_eq!(context.n_kv(), 4);
        assert_eq!(context.n_tensors(), 0);
        let arch_index = context.find_key("general.architecture").unwrap();
        assert_eq!(context.key_at(arch_index).unwrap(), "general.architecture");
        assert_eq!(context.val_str(arch_index).unwrap(), "synthetic");
    }
}