use std::collections::HashMap;
use std::sync::Arc;
use super::dtype::GgmlDtype;
use super::fetcher::{InMemoryFetcher, TensorFetcher};
use super::value::{GgufValue, GgufValueType};
use crate::error::{Result, RullamaError};
/// File magic: the ASCII bytes `"GGUF"` read as a little-endian `u32` (`0x4655_4747`).
const GGUF_MAGIC: u32 = u32::from_le_bytes(*b"GGUF");
/// The only GGUF container version this reader accepts.
const SUPPORTED_VERSION: u32 = 3;
/// Tensor-data alignment used when the header has no usable `general.alignment` entry.
const DEFAULT_ALIGNMENT: u64 = 32;
/// Size of the first header read in streaming mode (32 MiB); the streaming reader grows it
/// when the header turns out to be longer.
const STREAMING_HEADER_HINT: u64 = 1 << 25;
/// Descriptor of one tensor listed in a GGUF header's tensor-info table.
#[derive(Debug, Clone)]
pub struct TensorDesc {
    /// Tensor name as stored in the header (decoded lossily as UTF-8 by the header parser).
    pub name: String,
    /// Dimension sizes in file order; the header parser rejects more than 8.
    pub dims: Vec<u64>,
    /// Storage type of the tensor's elements.
    pub dtype: GgmlDtype,
    /// Byte offset of the tensor's data relative to the start of the data section.
    pub offset: u64,
}
impl TensorDesc {
    /// Total number of elements: the product of `dims` (1 for a zero-dimensional tensor).
    ///
    /// Dimensions come straight from the file, so the product is computed with checked
    /// arithmetic and saturates at `u64::MAX` instead of panicking (debug builds) or silently
    /// wrapping (release builds). Any zero-sized dimension yields 0.
    pub fn elem_count(&self) -> u64 {
        if self.dims.contains(&0) {
            return 0;
        }
        self.dims
            .iter()
            .try_fold(1u64, |acc, &d| acc.checked_mul(d))
            .unwrap_or(u64::MAX)
    }

    /// Size in bytes of the tensor's data: the number of whole `dtype` blocks times the
    /// block size in bytes.
    ///
    /// Everything is computed in `u64`, so tensors larger than `usize::MAX` bytes are not
    /// truncated on 32-bit targets (e.g. wasm32). The final multiplication saturates rather
    /// than overflowing. A trailing partial block is not counted; debug builds assert that
    /// there is none.
    pub fn byte_len(&self) -> u64 {
        let elems = self.elem_count();
        let block_elems = self.dtype.block_elems() as u64;
        let block_bytes = self.dtype.block_bytes() as u64;
        debug_assert!(
            elems.is_multiple_of(block_elems),
            "tensor {} has {} elems, not a multiple of block_elems {}",
            self.name,
            elems,
            block_elems
        );
        (elems / block_elems).saturating_mul(block_bytes)
    }
}
/// A parsed GGUF file: header metadata and tensor table, plus a fetcher for tensor data.
///
/// Built either from a buffer holding the whole file ([`GgufReader::new`]) or by parsing the
/// header through a [`TensorFetcher`] ([`GgufReader::new_streaming`]).
pub struct GgufReader {
    /// Source of tensor bytes; `new` wraps its own buffer in an `InMemoryFetcher`.
    fetcher: Arc<dyn TensorFetcher>,
    /// The whole file when built with `new`; `None` for streaming readers.
    in_memory: Option<Arc<[u8]>>,
    /// Header key/value metadata.
    metadata: HashMap<String, GgufValue>,
    /// Tensor descriptors in header order.
    tensors: Vec<TensorDesc>,
    /// Absolute file offset of the (aligned) tensor-data section.
    data_offset: usize,
    /// Data-section alignment (from `general.alignment`, otherwise the default).
    alignment: u64,
    /// GGUF container version read from the header.
    version: u32,
}
impl GgufReader {
    /// Parses a GGUF file that is already entirely in memory.
    ///
    /// The buffer becomes an `Arc<[u8]>`. The reader keeps it, which enables the synchronous
    /// [`Self::tensor_bytes`]. A clone of the `Arc` is also handed to an `InMemoryFetcher`, so
    /// the async fetch methods work as well.
    ///
    /// # Errors
    /// Returns `RullamaError::Gguf` if the header is malformed or truncated.
    pub fn new(data: Vec<u8>) -> Result<Self> {
        let arc: Arc<[u8]> = data.into();
        let header = parse_header(&arc)?;
        Ok(Self {
            fetcher: Arc::new(InMemoryFetcher::from_arc(arc.clone())),
            in_memory: Some(arc),
            metadata: header.metadata,
            tensors: header.tensors,
            data_offset: header.data_offset,
            alignment: header.alignment,
            version: header.version,
        })
    }

    /// Parses the header of a file read through `fetcher`, without loading tensor data.
    ///
    /// Parsing starts with the first `STREAMING_HEADER_HINT` bytes, or the whole file if it is
    /// smaller. Whenever parsing fails with an `"unexpected EOF"` error, the buffered prefix is
    /// doubled (capped at the file size) and parsing is retried from the start. Only the
    /// missing bytes are fetched on each retry.
    ///
    /// # Errors
    /// Returns an error for:
    /// - fetch errors;
    /// - any header error other than EOF;
    /// - an EOF error after the whole file has been read;
    /// - a fetch that returns no bytes.
    pub async fn new_streaming(fetcher: Arc<dyn TensorFetcher>) -> Result<Self> {
        let total = fetcher.total_len();
        let mut header_bytes = fetcher.fetch(0, STREAMING_HEADER_HINT.min(total)).await?;
        loop {
            match parse_header(&header_bytes) {
                Ok(h) => {
                    return Ok(Self {
                        fetcher,
                        in_memory: None,
                        metadata: h.metadata,
                        tensors: h.tensors,
                        data_offset: h.data_offset,
                        alignment: h.alignment,
                        version: h.version,
                    });
                }
                Err(RullamaError::Gguf(msg)) if msg.starts_with("unexpected EOF") => {
                    let have = header_bytes.len() as u64;
                    let want = have.saturating_mul(2).min(total);
                    if want <= have {
                        return Err(RullamaError::Gguf(format!(
                            "header parse failed even after reading the whole file: {msg}"
                        )));
                    }
                    // Fetch only the missing tail instead of re-transferring the prefix we
                    // already hold.
                    let more = fetcher.fetch(have, want - have).await?;
                    if more.is_empty() {
                        // Without this check a fetcher that makes no progress would loop forever.
                        return Err(RullamaError::Gguf(format!(
                            "header parse failed: fetcher returned no data at offset {have} of {total}: {msg}"
                        )));
                    }
                    header_bytes.extend_from_slice(&more);
                }
                Err(e) => return Err(e),
            }
        }
    }

    /// GGUF container version read from the header.
    pub fn version(&self) -> u32 {
        self.version
    }

    /// Alignment of the tensor-data section.
    pub fn alignment(&self) -> u64 {
        self.alignment
    }

    /// All header metadata, keyed by name.
    pub fn metadata(&self) -> &HashMap<String, GgufValue> {
        &self.metadata
    }

    /// Tensor descriptors in header order.
    pub fn tensors(&self) -> &[TensorDesc] {
        &self.tensors
    }

    /// Absolute byte offset at which tensor data begins.
    pub fn data_section_offset(&self) -> u64 {
        self.data_offset as u64
    }

    /// `true` when built with [`Self::new`] (whole file resident), `false` when streaming.
    pub fn is_in_memory(&self) -> bool {
        self.in_memory.is_some()
    }

    /// Returns a shared handle to the fetcher that serves tensor bytes.
    pub fn fetcher(&self) -> Arc<dyn TensorFetcher> {
        Arc::clone(&self.fetcher)
    }

    /// Returns the metadata value for `key`.
    ///
    /// # Errors
    /// `RullamaError::Gguf` naming the key when it is absent.
    pub fn get(&self, key: &str) -> Result<&GgufValue> {
        self.metadata
            .get(key)
            .ok_or_else(|| RullamaError::Gguf(format!("missing metadata key: {key}")))
    }

    /// Returns the metadata value for `key`, or `None` when it is absent.
    pub fn get_opt(&self, key: &str) -> Option<&GgufValue> {
        self.metadata.get(key)
    }

    /// Looks up a tensor descriptor by name (linear scan of the tensor table).
    ///
    /// # Errors
    /// `RullamaError::Gguf` naming the tensor when it is absent.
    pub fn tensor(&self, name: &str) -> Result<&TensorDesc> {
        self.tensors
            .iter()
            .find(|t| t.name == name)
            .ok_or_else(|| RullamaError::Gguf(format!("missing tensor: {name}")))
    }

    /// Absolute `(start, len)` byte range of tensor `name` within the file.
    ///
    /// The offsets come from the file, so both `start` and `start + len` are checked to fit in
    /// `u64`. This turns a corrupt offset into an error instead of a wrapped (or, in debug
    /// builds, panicking) address. Callers may then add `start + len`, or any prefix of `len`,
    /// without re-checking.
    ///
    /// # Errors
    /// Missing tensor, or a range that overflows `u64`.
    fn tensor_range(&self, name: &str) -> Result<(u64, u64)> {
        let t = self.tensor(name)?;
        let len = t.byte_len();
        let start = self
            .data_section_offset()
            .checked_add(t.offset)
            .filter(|start| start.checked_add(len).is_some())
            .ok_or_else(|| {
                RullamaError::Gguf(format!(
                    "tensor {name}: range overflows u64 (data offset {}, tensor offset {}, len {len})",
                    self.data_offset, t.offset
                ))
            })?;
        Ok((start, len))
    }

    /// Borrows tensor `name`'s bytes directly from the in-memory buffer.
    ///
    /// # Errors
    /// Fails if:
    /// - the reader is streaming (use [`Self::fetch_tensor_bytes`]);
    /// - the tensor is missing or its range overflows;
    /// - the tensor extends past the end of the buffer.
    pub fn tensor_bytes(&self, name: &str) -> Result<&[u8]> {
        let bytes = self.in_memory.as_ref().ok_or_else(|| {
            RullamaError::Gguf(format!(
                "tensor_bytes({name}): reader is streaming; use fetch_tensor_bytes().await"
            ))
        })?;
        let (start, len) = self.tensor_range(name)?;
        // Cannot overflow: tensor_range validated start + len.
        let end = start + len;
        // Bounds-check in u64 *before* converting, so a 32-bit usize cannot truncate the range.
        if end > bytes.len() as u64 {
            return Err(RullamaError::Gguf(format!(
                "tensor {name} extends past buffer end ({end} > {})",
                bytes.len()
            )));
        }
        // Both bounds are <= bytes.len(), so these casts are lossless.
        Ok(&bytes[start as usize..end as usize])
    }

    /// Fetches tensor `name`'s bytes through the fetcher (works for both reader kinds).
    pub async fn fetch_tensor_bytes(&self, name: &str) -> Result<Vec<u8>> {
        let (start, len) = self.tensor_range(name)?;
        self.fetcher.fetch(start, len).await
    }

    /// Fetches `byte_len` bytes starting `byte_offset` bytes into tensor `name`.
    ///
    /// # Errors
    /// Missing tensor, an overflowing request, or a request past the end of the tensor.
    pub async fn fetch_tensor_range(
        &self,
        name: &str,
        byte_offset: u64,
        byte_len: u64,
    ) -> Result<Vec<u8>> {
        let (start, total) = self.tensor_range(name)?;
        let end = byte_offset.checked_add(byte_len).ok_or_else(|| {
            RullamaError::Gguf(format!(
                "fetch_tensor_range({name}): range overflow {byte_offset}+{byte_len}"
            ))
        })?;
        if end > total {
            return Err(RullamaError::Gguf(format!(
                "fetch_tensor_range({name}): range {byte_offset}..{end} extends past tensor end ({total})"
            )));
        }
        // byte_offset <= total and tensor_range validated start + total, so this cannot overflow.
        self.fetcher.fetch(start + byte_offset, byte_len).await
    }
}
/// Rounds `x` up to the nearest multiple of `a`.
///
/// An alignment of 0 or 1 leaves `x` unchanged.
fn align_up(x: u64, a: u64) -> u64 {
    match a {
        0 | 1 => x,
        _ => x.next_multiple_of(a),
    }
}
/// Everything `parse_header` extracts from a GGUF header.
struct ParsedHeader {
    /// Metadata key/value pairs; if a key repeats, the last value read wins (`HashMap::insert`).
    metadata: HashMap<String, GgufValue>,
    /// Tensor descriptors in file order.
    tensors: Vec<TensorDesc>,
    /// Start of the tensor-data section: the end of the header rounded up to `alignment`.
    data_offset: usize,
    /// `general.alignment` when present and readable via `as_u64`, else `DEFAULT_ALIGNMENT`.
    alignment: u64,
    /// GGUF version read from the header.
    version: u32,
}
fn parse_header(data: &[u8]) -> Result<ParsedHeader> {
let mut c = Cursor { buf: data, pos: 0 };
let magic = c.read_u32()?;
if magic != GGUF_MAGIC {
return Err(RullamaError::Gguf(format!(
"bad magic 0x{magic:08x}, expected 0x{GGUF_MAGIC:08x} (GGUF)"
)));
}
let version = c.read_u32()?;
if version != SUPPORTED_VERSION {
return Err(RullamaError::Gguf(format!(
"unsupported GGUF version {version}, expected {SUPPORTED_VERSION}"
)));
}
let tensor_count = c.read_u64()? as usize;
let kv_count = c.read_u64()? as usize;
let mut metadata: HashMap<String, GgufValue> = HashMap::with_capacity(kv_count);
for _ in 0..kv_count {
let key = c.read_string()?;
let vt = GgufValueType::from_u32(c.read_u32()?)?;
let val = read_value(&mut c, vt)?;
metadata.insert(key, val);
}
let mut tensors: Vec<TensorDesc> = Vec::with_capacity(tensor_count);
for _ in 0..tensor_count {
let name = c.read_string()?;
let n_dims = c.read_u32()? as usize;
if n_dims > 8 {
return Err(RullamaError::Gguf(format!(
"tensor {name} has {n_dims} dims (>8)"
)));
}
let mut dims = Vec::with_capacity(n_dims);
for _ in 0..n_dims {
dims.push(c.read_u64()?);
}
let dtype = GgmlDtype::from_u32(c.read_u32()?)?;
let offset = c.read_u64()?;
tensors.push(TensorDesc {
name,
dims,
dtype,
offset,
});
}
let alignment = metadata
.get("general.alignment")
.and_then(|v| v.as_u64().ok())
.unwrap_or(DEFAULT_ALIGNMENT);
let unaligned = c.pos as u64;
let data_offset = align_up(unaligned, alignment) as usize;
Ok(ParsedHeader {
metadata,
tensors,
data_offset,
alignment,
version,
})
}
fn read_value(c: &mut Cursor<'_>, vt: GgufValueType) -> Result<GgufValue> {
Ok(match vt {
GgufValueType::U8 => GgufValue::U8(c.read_u8()?),
GgufValueType::I8 => GgufValue::I8(c.read_u8()? as i8),
GgufValueType::U16 => GgufValue::U16(c.read_u16()?),
GgufValueType::I16 => GgufValue::I16(c.read_u16()? as i16),
GgufValueType::U32 => GgufValue::U32(c.read_u32()?),
GgufValueType::I32 => GgufValue::I32(c.read_u32()? as i32),
GgufValueType::U64 => GgufValue::U64(c.read_u64()?),
GgufValueType::I64 => GgufValue::I64(c.read_u64()? as i64),
GgufValueType::F32 => GgufValue::F32(f32::from_bits(c.read_u32()?)),
GgufValueType::F64 => GgufValue::F64(f64::from_bits(c.read_u64()?)),
GgufValueType::Bool => GgufValue::Bool(c.read_u8()? != 0),
GgufValueType::String => GgufValue::String(c.read_string()?),
GgufValueType::Array => {
let elem = GgufValueType::from_u32(c.read_u32()?)?;
let n = c.read_u64()? as usize;
read_array(c, elem, n)?
}
})
}
/// Decodes `n` consecutive elements of type `elem` at the cursor into the matching
/// `GgufValue::Array*` variant.
///
/// Signed and floating-point elements are reinterpreted from the unsigned little-endian
/// readers (`as` casts / `from_bits`). Bools are non-zero bytes.
///
/// # Errors
/// Truncated input, a failing string element, or a nested array (not supported).
fn read_array(c: &mut Cursor<'_>, elem: GgufValueType, n: usize) -> Result<GgufValue> {
    Ok(match elem {
        GgufValueType::U8 => {
            let bytes = c.read_bytes(n)?.to_vec();
            GgufValue::ArrayU8(bytes)
        }
        GgufValueType::I8 => {
            let raw = c.read_bytes(n)?;
            GgufValue::ArrayI8(raw.iter().map(|&b| b as i8).collect())
        }
        GgufValueType::U16 => GgufValue::ArrayU16(c.read_u16_vec(n)?),
        GgufValueType::I16 => {
            GgufValue::ArrayI16(c.read_u16_vec(n)?.into_iter().map(|x| x as i16).collect())
        }
        GgufValueType::U32 => GgufValue::ArrayU32(c.read_u32_vec(n)?),
        GgufValueType::I32 => {
            GgufValue::ArrayI32(c.read_u32_vec(n)?.into_iter().map(|x| x as i32).collect())
        }
        GgufValueType::U64 => GgufValue::ArrayU64(c.read_u64_vec(n)?),
        GgufValueType::I64 => {
            GgufValue::ArrayI64(c.read_u64_vec(n)?.into_iter().map(|x| x as i64).collect())
        }
        GgufValueType::F32 => {
            GgufValue::ArrayF32(c.read_u32_vec(n)?.into_iter().map(f32::from_bits).collect())
        }
        GgufValueType::F64 => {
            GgufValue::ArrayF64(c.read_u64_vec(n)?.into_iter().map(f64::from_bits).collect())
        }
        GgufValueType::Bool => {
            let raw = c.read_bytes(n)?;
            GgufValue::ArrayBool(raw.iter().map(|&b| b != 0).collect())
        }
        GgufValueType::String => {
            // Every string needs at least its 8-byte length prefix, so the unread part of the
            // buffer bounds how many strings can really follow. Capping the capacity hint keeps
            // a corrupt `n` from requesting a huge allocation, which would panic with "capacity
            // overflow" or abort on allocation failure, before EOF is detected.
            let max_fit = c.buf.len().saturating_sub(c.pos) / 8;
            let mut out = Vec::with_capacity(n.min(max_fit));
            for _ in 0..n {
                out.push(c.read_string()?);
            }
            GgufValue::ArrayString(out)
        }
        GgufValueType::Array => {
            return Err(RullamaError::Gguf(
                "nested arrays are not supported by GGUF v3".into(),
            ));
        }
    })
}
/// Forward-only reader over a borrowed byte slice; multi-byte integers are little-endian.
struct Cursor<'a> {
    /// The bytes being parsed.
    buf: &'a [u8],
    /// Offset of the next unread byte in `buf`.
    pos: usize,
}
impl<'a> Cursor<'a> {
fn need(&self, n: usize) -> Result<()> {
if self.pos + n > self.buf.len() {
Err(RullamaError::Gguf(format!(
"unexpected EOF: needed {n} bytes at {}, buffer len {}",
self.pos,
self.buf.len()
)))
} else {
Ok(())
}
}
fn read_bytes(&mut self, n: usize) -> Result<&'a [u8]> {
self.need(n)?;
let s = &self.buf[self.pos..self.pos + n];
self.pos += n;
Ok(s)
}
fn read_u8(&mut self) -> Result<u8> {
self.need(1)?;
let v = self.buf[self.pos];
self.pos += 1;
Ok(v)
}
fn read_u16(&mut self) -> Result<u16> {
self.need(2)?;
let v = u16::from_le_bytes(self.buf[self.pos..self.pos + 2].try_into().unwrap());
self.pos += 2;
Ok(v)
}
fn read_u32(&mut self) -> Result<u32> {
self.need(4)?;
let v = u32::from_le_bytes(self.buf[self.pos..self.pos + 4].try_into().unwrap());
self.pos += 4;
Ok(v)
}
fn read_u64(&mut self) -> Result<u64> {
self.need(8)?;
let v = u64::from_le_bytes(self.buf[self.pos..self.pos + 8].try_into().unwrap());
self.pos += 8;
Ok(v)
}
fn read_string(&mut self) -> Result<String> {
let n = self.read_u64()? as usize;
let bytes = self.read_bytes(n)?;
Ok(String::from_utf8_lossy(bytes).into_owned())
}
fn read_u16_vec(&mut self, n: usize) -> Result<Vec<u16>> {
let mut out = Vec::with_capacity(n);
for _ in 0..n {
out.push(self.read_u16()?);
}
Ok(out)
}
fn read_u32_vec(&mut self, n: usize) -> Result<Vec<u32>> {
let mut out = Vec::with_capacity(n);
for _ in 0..n {
out.push(self.read_u32()?);
}
Ok(out)
}
fn read_u64_vec(&mut self, n: usize) -> Result<Vec<u64>> {
let mut out = Vec::with_capacity(n);
for _ in 0..n {
out.push(self.read_u64()?);
}
Ok(out)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Hand-assembles a minimal GGUF v3 file, zero-padded to a 32-byte boundary.
    ///
    /// The file has zero tensors and a single `u32` metadata entry `"x"` with value 42.
    fn synth() -> Vec<u8> {
        let key: &[u8] = b"x";
        let key_len = (key.len() as u64).to_le_bytes();
        let value_type = (GgufValueType::U32 as u32).to_le_bytes();
        let parts: [&[u8]; 8] = [
            b"GGUF",             // magic
            &3u32.to_le_bytes(), // version
            &0u64.to_le_bytes(), // tensor count
            &1u64.to_le_bytes(), // metadata kv count
            &key_len,
            key,
            &value_type,
            &42u32.to_le_bytes(),
        ];
        let mut out = parts.concat();
        out.resize(out.len().next_multiple_of(32), 0);
        out
    }

    #[test]
    fn parses_minimal_gguf() {
        let reader = GgufReader::new(synth()).expect("parse");
        assert_eq!(reader.version(), 3);
        assert!(reader.tensors().is_empty());
        assert_eq!(reader.get("x").unwrap().as_u32().unwrap(), 42);
        assert!(reader.is_in_memory());
    }

    #[test]
    fn streaming_reader_matches_in_memory() {
        let bytes = synth();
        let in_mem = GgufReader::new(bytes.clone()).expect("parse");
        let fetcher: Arc<dyn TensorFetcher> = Arc::new(InMemoryFetcher::new(bytes));
        let streamed = pollster::block_on(GgufReader::new_streaming(fetcher)).expect("stream");

        assert_eq!(streamed.version(), in_mem.version());
        assert_eq!(streamed.tensors().len(), in_mem.tensors().len());
        let x_streamed = streamed.get("x").unwrap().as_u32().unwrap();
        let x_in_mem = in_mem.get("x").unwrap().as_u32().unwrap();
        assert_eq!(x_streamed, x_in_mem);
        assert!(!streamed.is_in_memory());
        // A streaming reader holds no resident buffer, so the synchronous accessor must refuse.
        assert!(streamed.tensor_bytes("anything").is_err());
    }
}