use crate::error::{BonsaiError, BonsaiResult};
use crate::gguf::header::GgufHeader;
use crate::gguf::metadata::MetadataStore;
use crate::gguf::tensor_info::TensorStore;
/// Alignment, in bytes, applied to the start of the tensor-data region when
/// the metadata does not supply a usable `general.alignment` value.
const DEFAULT_ALIGNMENT: usize = 32;
/// A parsed view over a GGUF image held in a borrowed byte slice.
///
/// The structural sections (header, metadata, tensor info) are parsed into
/// owned stores, while the raw bytes stay borrowed for lifetime `'a` so tensor
/// payloads can be handed out as sub-slices of `data`.
#[derive(Debug)]
pub struct GgufFile<'a> {
/// Header parsed from the very beginning of `data`.
pub header: GgufHeader,
/// Metadata entries parsed immediately after the header.
pub metadata: MetadataStore,
/// Tensor descriptors parsed immediately after the metadata.
pub tensors: TensorStore,
/// Offset into `data` of the tensor-data region: the end of the tensor-info
/// section rounded up to the alignment. Each tensor's own offset is added
/// to this value to locate its bytes.
pub data_offset: usize,
/// The complete input slice that was passed to `parse`.
pub data: &'a [u8],
}
impl<'a> GgufFile<'a> {
    /// Parses a complete GGUF image held in `data`.
    ///
    /// The header, the metadata section and the tensor-info section are read in
    /// that order, each starting at the offset where the previous one ended.
    /// The start of the tensor-data region is the end of the tensor-info
    /// section rounded up to the alignment read from the `general.alignment`
    /// metadata key. `DEFAULT_ALIGNMENT` is used instead when that key is
    /// missing, yields no `u32` value, or is zero.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the header, metadata or tensor-info
    /// parsers.
    pub fn parse(data: &'a [u8]) -> BonsaiResult<Self> {
        let (header, offset) = GgufHeader::parse(data, 0)?;
        tracing::debug!(
            version = header.version,
            tensors = header.tensor_count,
            metadata = header.metadata_kv_count,
            "parsed GGUF header"
        );

        let (metadata, offset) = MetadataStore::parse(data, offset, header.metadata_kv_count)?;
        tracing::debug!(entries = metadata.len(), "parsed metadata");

        let (tensors, offset) = TensorStore::parse(data, offset, header.tensor_count)?;
        tracing::debug!(count = tensors.len(), "parsed tensor info");

        let declared_alignment = metadata
            .get("general.alignment")
            .and_then(|v| v.as_u32());
        let alignment = match declared_alignment {
            // The value comes straight from the file. Zero cannot be used as an
            // alignment and would underflow the rounding arithmetic, so it is
            // treated like a missing key instead of being trusted.
            Some(0) => {
                tracing::warn!("general.alignment is 0; using default alignment");
                DEFAULT_ALIGNMENT
            }
            Some(value) => value as usize,
            None => DEFAULT_ALIGNMENT,
        };
        let data_offset = align_offset(offset, alignment);

        Ok(GgufFile {
            header,
            metadata,
            tensors,
            data_offset,
            data,
        })
    }

    /// Returns the raw bytes of the tensor called `name`.
    ///
    /// The tensor's bytes start at `data_offset` plus the tensor's own offset
    /// and span `data_size()` bytes. The returned slice borrows from the
    /// original input, not from `self`.
    ///
    /// # Errors
    ///
    /// * Whatever `TensorStore::require` returns when `name` is not accepted.
    /// * `BonsaiError::UnexpectedEof` when the tensor's byte range extends past
    ///   the end of `data`. The reported offset is the required end position,
    ///   or `u64::MAX` when that position is not representable.
    pub fn tensor_data(&self, name: &str) -> BonsaiResult<&'a [u8]> {
        let info = self.tensors.require(name)?;

        // Offsets and sizes are file-controlled. Do the arithmetic in `u64`
        // with overflow checks so a hostile value cannot wrap around, slip
        // past the bounds check and panic (or alias the wrong bytes) below.
        let data_len = self.data.len() as u64;
        let span = (self.data_offset as u64)
            .checked_add(info.offset as u64)
            .and_then(|start| Some((start, start.checked_add(info.data_size() as u64)?)));

        match span {
            // `start <= end <= data.len()`, so both values fit in `usize`.
            Some((start, end)) if end <= data_len => {
                Ok(&self.data[start as usize..end as usize])
            }
            Some((_, end)) => Err(BonsaiError::UnexpectedEof { offset: end }),
            None => Err(BonsaiError::UnexpectedEof { offset: u64::MAX }),
        }
    }
}
/// Rounds `offset` up to the next multiple of `alignment`.
///
/// An `offset` that is already a multiple of `alignment` is returned
/// unchanged. Any positive `alignment` is supported, not only powers of two.
/// An `alignment` of 0 or 1 imposes no constraint, so `offset` is returned
/// as is. If the rounded value would not fit in `usize` the result saturates
/// at `usize::MAX`.
fn align_offset(offset: usize, alignment: usize) -> usize {
    if alignment <= 1 {
        return offset;
    }
    match offset % alignment {
        0 => offset,
        // `remainder < alignment`, so the subtraction cannot underflow.
        remainder => offset.saturating_add(alignment - remainder),
    }
}
/// Opens the file at `path` and memory-maps it, returning the mapping.
///
/// Only compiled when the `mmap` feature is enabled.
///
/// # Errors
///
/// Returns the converted I/O error if the file cannot be opened or mapped.
#[cfg(feature = "mmap")]
pub fn mmap_gguf_file(path: &std::path::Path) -> BonsaiResult<memmap2::Mmap> {
    let handle = std::fs::File::open(path)?;
    // SAFETY: NOTE(review) - per the memmap2 documentation, a file-backed
    // mapping is only sound while the underlying file is not modified or
    // truncated by anyone else. Nothing here enforces that; callers are
    // presumably expected to treat model files as immutable - confirm.
    let mapped = unsafe { memmap2::Mmap::map(&handle) };
    Ok(mapped?)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks rounding to a 32-byte boundary at, just below and just above
    /// multiples of the alignment.
    #[test]
    fn align_offset_works() {
        // (input offset, expected aligned offset)
        let cases = [(0, 0), (1, 32), (31, 32), (32, 32), (33, 64)];
        for &(offset, expected) in cases.iter() {
            assert_eq!(align_offset(offset, 32), expected, "offset {}", offset);
        }
    }
}