use std::collections::HashMap;
use std::io::{BufRead, Read};
use super::quantization::GgufQuantType;
use super::tensors::TensorInfo;
use crate::error::{Result, RuvLLMError};
/// Fixed-size leading fields of a GGUF file, in the order `parse_header` reads them.
///
/// All fields are stored exactly as decoded; this type performs no validation.
#[derive(Clone, Debug)]
pub struct GgufHeader {
    /// Magic number, decoded as a little-endian `u32`.
    pub magic: u32,
    /// Format version, decoded as a little-endian `u32`.
    pub version: u32,
    /// Number of tensors declared by the file (little-endian `u64`).
    pub tensor_count: u64,
    /// Number of metadata key/value pairs declared by the file (little-endian `u64`).
    pub metadata_kv_count: u64,
}
/// A decoded metadata value.
///
/// Each scalar variant wraps the Rust type of the same name; `Array` owns its
/// elements, which may themselves be arrays.
#[derive(Clone, Debug)]
pub enum GgufValue {
    /// Unsigned 8-bit integer.
    U8(u8),
    /// Signed 8-bit integer.
    I8(i8),
    /// Unsigned 16-bit integer.
    U16(u16),
    /// Signed 16-bit integer.
    I16(i16),
    /// Unsigned 32-bit integer.
    U32(u32),
    /// Signed 32-bit integer.
    I32(i32),
    /// Unsigned 64-bit integer.
    U64(u64),
    /// Signed 64-bit integer.
    I64(i64),
    /// 32-bit floating point number.
    F32(f32),
    /// 64-bit floating point number.
    F64(f64),
    /// Boolean flag.
    Bool(bool),
    /// Owned UTF-8 string.
    String(String),
    /// Sequence of values.
    Array(Vec<GgufValue>),
}
impl GgufValue {
pub fn as_str(&self) -> Option<&str> {
match self {
GgufValue::String(s) => Some(s),
_ => None,
}
}
pub fn as_u64(&self) -> Option<u64> {
match self {
GgufValue::U8(v) => Some(*v as u64),
GgufValue::U16(v) => Some(*v as u64),
GgufValue::U32(v) => Some(*v as u64),
GgufValue::U64(v) => Some(*v),
GgufValue::I8(v) if *v >= 0 => Some(*v as u64),
GgufValue::I16(v) if *v >= 0 => Some(*v as u64),
GgufValue::I32(v) if *v >= 0 => Some(*v as u64),
GgufValue::I64(v) if *v >= 0 => Some(*v as u64),
_ => None,
}
}
pub fn as_i64(&self) -> Option<i64> {
match self {
GgufValue::I8(v) => Some(*v as i64),
GgufValue::I16(v) => Some(*v as i64),
GgufValue::I32(v) => Some(*v as i64),
GgufValue::I64(v) => Some(*v),
GgufValue::U8(v) => Some(*v as i64),
GgufValue::U16(v) => Some(*v as i64),
GgufValue::U32(v) => Some(*v as i64),
GgufValue::U64(v) if *v <= i64::MAX as u64 => Some(*v as i64),
_ => None,
}
}
pub fn as_f32(&self) -> Option<f32> {
match self {
GgufValue::F32(v) => Some(*v),
GgufValue::F64(v) => Some(*v as f32),
GgufValue::I8(v) => Some(*v as f32),
GgufValue::I16(v) => Some(*v as f32),
GgufValue::I32(v) => Some(*v as f32),
GgufValue::U8(v) => Some(*v as f32),
GgufValue::U16(v) => Some(*v as f32),
GgufValue::U32(v) => Some(*v as f32),
_ => None,
}
}
pub fn as_f64(&self) -> Option<f64> {
match self {
GgufValue::F64(v) => Some(*v),
GgufValue::F32(v) => Some(*v as f64),
GgufValue::I8(v) => Some(*v as f64),
GgufValue::I16(v) => Some(*v as f64),
GgufValue::I32(v) => Some(*v as f64),
GgufValue::I64(v) => Some(*v as f64),
GgufValue::U8(v) => Some(*v as f64),
GgufValue::U16(v) => Some(*v as f64),
GgufValue::U32(v) => Some(*v as f64),
GgufValue::U64(v) => Some(*v as f64),
_ => None,
}
}
pub fn as_bool(&self) -> Option<bool> {
match self {
GgufValue::Bool(v) => Some(*v),
GgufValue::U8(v) => Some(*v != 0),
GgufValue::I8(v) => Some(*v != 0),
_ => None,
}
}
pub fn as_array(&self) -> Option<&[GgufValue]> {
match self {
GgufValue::Array(arr) => Some(arr),
_ => None,
}
}
}
/// Numeric type tag that precedes every encoded metadata value.
///
/// The discriminants are the raw `u32` ids accepted by the `TryFrom<u32>`
/// conversion for this type.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
#[repr(u32)]
pub enum GgufValueType {
    /// Unsigned 8-bit integer.
    U8 = 0,
    /// Signed 8-bit integer.
    I8 = 1,
    /// Unsigned 16-bit integer.
    U16 = 2,
    /// Signed 16-bit integer.
    I16 = 3,
    /// Unsigned 32-bit integer.
    U32 = 4,
    /// Signed 32-bit integer.
    I32 = 5,
    /// 32-bit floating point number.
    F32 = 6,
    /// Boolean.
    Bool = 7,
    /// Length-prefixed string.
    String = 8,
    /// Typed, length-prefixed array.
    Array = 9,
    /// Unsigned 64-bit integer.
    U64 = 10,
    /// Signed 64-bit integer.
    I64 = 11,
    /// 64-bit floating point number.
    F64 = 12,
}
impl TryFrom<u32> for GgufValueType {
type Error = RuvLLMError;
fn try_from(value: u32) -> Result<Self> {
match value {
0 => Ok(Self::U8),
1 => Ok(Self::I8),
2 => Ok(Self::U16),
3 => Ok(Self::I16),
4 => Ok(Self::U32),
5 => Ok(Self::I32),
6 => Ok(Self::F32),
7 => Ok(Self::Bool),
8 => Ok(Self::String),
9 => Ok(Self::Array),
10 => Ok(Self::U64),
11 => Ok(Self::I64),
12 => Ok(Self::F64),
_ => Err(RuvLLMError::Model(format!(
"Unknown GGUF value type: {}",
value
))),
}
}
}
/// Reads the fixed-size header from `reader`.
///
/// Consumes, in order: magic (`u32`), version (`u32`), tensor count (`u64`)
/// and metadata pair count (`u64`), all little-endian. No field is validated
/// here.
///
/// # Errors
///
/// Returns an error if `reader` ends before all four fields have been read.
pub fn parse_header<R: Read>(reader: &mut R) -> Result<GgufHeader> {
    // Struct-literal fields are evaluated top to bottom, which fixes the read order.
    Ok(GgufHeader {
        magic: read_u32(reader)?,
        version: read_u32(reader)?,
        tensor_count: read_u64(reader)?,
        metadata_kv_count: read_u64(reader)?,
    })
}
/// Reads `count` metadata key/value pairs from `reader`.
///
/// Each pair is a length-prefixed string key followed by a type-tagged value.
/// If a key occurs more than once, the last value read wins.
///
/// `count` normally comes straight from the file header, so it is treated as
/// untrusted: it only bounds the loop and is never used to size an allocation
/// directly. A bogus count therefore ends in a read error once the input runs
/// out, instead of a capacity-overflow panic or an out-of-memory abort.
///
/// # Errors
///
/// Returns an error if any key or value fails to decode, including when
/// `reader` ends before `count` pairs have been read.
pub fn parse_metadata<R: Read>(reader: &mut R, count: u64) -> Result<HashMap<String, GgufValue>> {
    /// Upper bound on the number of entries reserved before any data is read.
    const MAX_PREALLOCATED_ENTRIES: u64 = 1024;

    // The cast cannot truncate: the value is at most MAX_PREALLOCATED_ENTRIES.
    let initial_capacity = count.min(MAX_PREALLOCATED_ENTRIES) as usize;
    let mut metadata = HashMap::with_capacity(initial_capacity);
    for _ in 0..count {
        let key = read_string(reader)?;
        let value = read_value(reader)?;
        metadata.insert(key, value);
    }
    Ok(metadata)
}
pub fn parse_tensor_infos<R: Read>(reader: &mut R, count: u64) -> Result<Vec<TensorInfo>> {
let mut tensors = Vec::with_capacity(count as usize);
for _ in 0..count {
let name = read_string(reader)?;
let n_dims = read_u32(reader)? as usize;
let mut shape = Vec::with_capacity(n_dims);
for _ in 0..n_dims {
shape.push(read_u64(reader)? as usize);
}
let dtype_id = read_u32(reader)?;
let dtype = GgufQuantType::try_from(dtype_id)?;
let offset = read_u64(reader)?;
tensors.push(TensorInfo {
name,
shape,
dtype,
offset,
});
}
Ok(tensors)
}
/// Reads one type-tagged value: a `u32` type id followed by its payload.
///
/// Booleans are stored as a single byte where any non-zero value means
/// `true`; arrays are delegated to `read_array`.
///
/// # Errors
///
/// Returns an error if the type id is unknown or the payload cannot be read.
fn read_value<R: Read>(reader: &mut R) -> Result<GgufValue> {
    let raw_tag = read_u32(reader)?;
    let tag = GgufValueType::try_from(raw_tag)?;
    let decoded = match tag {
        GgufValueType::U8 => GgufValue::U8(read_u8(reader)?),
        GgufValueType::I8 => GgufValue::I8(read_i8(reader)?),
        GgufValueType::U16 => GgufValue::U16(read_u16(reader)?),
        GgufValueType::I16 => GgufValue::I16(read_i16(reader)?),
        GgufValueType::U32 => GgufValue::U32(read_u32(reader)?),
        GgufValueType::I32 => GgufValue::I32(read_i32(reader)?),
        GgufValueType::U64 => GgufValue::U64(read_u64(reader)?),
        GgufValueType::I64 => GgufValue::I64(read_i64(reader)?),
        GgufValueType::F32 => GgufValue::F32(read_f32(reader)?),
        GgufValueType::F64 => GgufValue::F64(read_f64(reader)?),
        GgufValueType::Bool => GgufValue::Bool(read_u8(reader)? != 0),
        GgufValueType::String => GgufValue::String(read_string(reader)?),
        // `read_array` already produces the finished `GgufValue::Array`.
        GgufValueType::Array => return read_array(reader),
    };
    Ok(decoded)
}
const MAX_ARRAY_SIZE: usize = 10_000_000;
fn read_array<R: Read>(reader: &mut R) -> Result<GgufValue> {
let elem_type_id = read_u32(reader)?;
let elem_type = GgufValueType::try_from(elem_type_id)?;
let count = read_u64(reader)?;
if count > MAX_ARRAY_SIZE as u64 {
return Err(RuvLLMError::Model(format!(
"Array size {} exceeds maximum allowed size {}",
count, MAX_ARRAY_SIZE
)));
}
let count = count as usize;
let mut values = Vec::with_capacity(count);
for _ in 0..count {
let value = match elem_type {
GgufValueType::U8 => GgufValue::U8(read_u8(reader)?),
GgufValueType::I8 => GgufValue::I8(read_i8(reader)?),
GgufValueType::U16 => GgufValue::U16(read_u16(reader)?),
GgufValueType::I16 => GgufValue::I16(read_i16(reader)?),
GgufValueType::U32 => GgufValue::U32(read_u32(reader)?),
GgufValueType::I32 => GgufValue::I32(read_i32(reader)?),
GgufValueType::U64 => GgufValue::U64(read_u64(reader)?),
GgufValueType::I64 => GgufValue::I64(read_i64(reader)?),
GgufValueType::F32 => GgufValue::F32(read_f32(reader)?),
GgufValueType::F64 => GgufValue::F64(read_f64(reader)?),
GgufValueType::Bool => GgufValue::Bool(read_u8(reader)? != 0),
GgufValueType::String => GgufValue::String(read_string(reader)?),
GgufValueType::Array => read_array(reader)?,
};
values.push(value);
}
Ok(GgufValue::Array(values))
}
fn read_u8<R: Read>(reader: &mut R) -> Result<u8> {
let mut buf = [0u8; 1];
reader.read_exact(&mut buf).map_err(read_err)?;
Ok(buf[0])
}
/// Reads a single byte from `reader` and reinterprets it as a signed value.
///
/// # Errors
///
/// Propagates any error from `read_u8`.
fn read_i8<R: Read>(reader: &mut R) -> Result<i8> {
    // Same bit pattern, signed interpretation.
    read_u8(reader).map(|byte| i8::from_le_bytes([byte]))
}
fn read_u16<R: Read>(reader: &mut R) -> Result<u16> {
let mut buf = [0u8; 2];
reader.read_exact(&mut buf).map_err(read_err)?;
Ok(u16::from_le_bytes(buf))
}
fn read_i16<R: Read>(reader: &mut R) -> Result<i16> {
let mut buf = [0u8; 2];
reader.read_exact(&mut buf).map_err(read_err)?;
Ok(i16::from_le_bytes(buf))
}
fn read_u32<R: Read>(reader: &mut R) -> Result<u32> {
let mut buf = [0u8; 4];
reader.read_exact(&mut buf).map_err(read_err)?;
Ok(u32::from_le_bytes(buf))
}
fn read_i32<R: Read>(reader: &mut R) -> Result<i32> {
let mut buf = [0u8; 4];
reader.read_exact(&mut buf).map_err(read_err)?;
Ok(i32::from_le_bytes(buf))
}
fn read_u64<R: Read>(reader: &mut R) -> Result<u64> {
let mut buf = [0u8; 8];
reader.read_exact(&mut buf).map_err(read_err)?;
Ok(u64::from_le_bytes(buf))
}
fn read_i64<R: Read>(reader: &mut R) -> Result<i64> {
let mut buf = [0u8; 8];
reader.read_exact(&mut buf).map_err(read_err)?;
Ok(i64::from_le_bytes(buf))
}
fn read_f32<R: Read>(reader: &mut R) -> Result<f32> {
let mut buf = [0u8; 4];
reader.read_exact(&mut buf).map_err(read_err)?;
Ok(f32::from_le_bytes(buf))
}
fn read_f64<R: Read>(reader: &mut R) -> Result<f64> {
let mut buf = [0u8; 8];
reader.read_exact(&mut buf).map_err(read_err)?;
Ok(f64::from_le_bytes(buf))
}
const MAX_STRING_SIZE: usize = 65536;
fn read_string<R: Read>(reader: &mut R) -> Result<String> {
let len = read_u64(reader)? as usize;
if len > MAX_STRING_SIZE {
return Err(RuvLLMError::Model(format!(
"String too long: {} bytes (max: {} bytes)",
len, MAX_STRING_SIZE
)));
}
let mut buf = vec![0u8; len];
reader.read_exact(&mut buf).map_err(read_err)?;
String::from_utf8(buf).map_err(|e| RuvLLMError::Model(format!("Invalid UTF-8 string: {}", e)))
}
/// Converts an I/O failure into a `RuvLLMError::Model` carrying the I/O
/// error's display text.
fn read_err(e: std::io::Error) -> RuvLLMError {
    let message = format!("Failed to read: {}", e);
    RuvLLMError::Model(message)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Primitive readers decode little-endian values.
    #[test]
    fn test_read_primitives() {
        // The ASCII bytes "GGUF" decode to 0x46554747 as a little-endian u32.
        let data = [0x47u8, 0x47, 0x55, 0x46];
        let mut cursor = Cursor::new(data);
        assert_eq!(read_u32(&mut cursor).unwrap(), 0x46554747);

        let data = [0x01u8, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
        let mut cursor = Cursor::new(data);
        assert_eq!(read_u64(&mut cursor).unwrap(), 1);

        let data = 1.0f32.to_le_bytes();
        let mut cursor = Cursor::new(data);
        assert_eq!(read_f32(&mut cursor).unwrap(), 1.0);
    }

    /// A reader that ends mid-value yields an error rather than a partial value.
    #[test]
    fn test_read_primitive_truncated_input() {
        let mut cursor = Cursor::new([0u8; 2]);
        assert!(read_u32(&mut cursor).is_err());

        let mut cursor = Cursor::new([0u8; 7]);
        assert!(read_u64(&mut cursor).is_err());
    }

    /// Strings are a u64 byte length followed by the bytes themselves.
    #[test]
    fn test_read_string() {
        let mut data = vec![];
        data.extend_from_slice(&5u64.to_le_bytes());
        data.extend_from_slice(b"hello");
        let mut cursor = Cursor::new(data);
        assert_eq!(read_string(&mut cursor).unwrap(), "hello");
    }

    /// Declared lengths above `MAX_STRING_SIZE` are rejected.
    #[test]
    fn test_read_string_rejects_oversized_length() {
        let data = (MAX_STRING_SIZE as u64 + 1).to_le_bytes();
        let mut cursor = Cursor::new(data);
        assert!(read_string(&mut cursor).is_err());
    }

    /// Bytes that are not valid UTF-8 are rejected.
    #[test]
    fn test_read_string_rejects_invalid_utf8() {
        let mut data = vec![];
        data.extend_from_slice(&2u64.to_le_bytes());
        data.extend_from_slice(&[0xff, 0xfe]);
        let mut cursor = Cursor::new(data);
        assert!(read_string(&mut cursor).is_err());
    }

    /// Header fields are read in the order magic, version, tensor count, metadata count.
    #[test]
    fn test_parse_header() {
        let mut data = vec![];
        data.extend_from_slice(&0x46554747u32.to_le_bytes());
        data.extend_from_slice(&3u32.to_le_bytes());
        data.extend_from_slice(&10u64.to_le_bytes());
        data.extend_from_slice(&5u64.to_le_bytes());
        let mut cursor = Cursor::new(data);
        let header = parse_header(&mut cursor).unwrap();
        assert_eq!(header.magic, 0x46554747);
        assert_eq!(header.version, 3);
        assert_eq!(header.tensor_count, 10);
        assert_eq!(header.metadata_kv_count, 5);
    }

    /// A single string-valued pair round-trips through `parse_metadata`.
    #[test]
    fn test_parse_metadata() {
        let mut data = vec![];
        data.extend_from_slice(&4u64.to_le_bytes());
        data.extend_from_slice(b"name");
        data.extend_from_slice(&(GgufValueType::String as u32).to_le_bytes());
        data.extend_from_slice(&5u64.to_le_bytes());
        data.extend_from_slice(b"llama");
        let mut cursor = Cursor::new(data);
        let metadata = parse_metadata(&mut cursor, 1).unwrap();
        assert_eq!(metadata.len(), 1);
        assert_eq!(
            metadata.get("name").and_then(|value| value.as_str()),
            Some("llama")
        );
    }

    /// Tagged scalar values decode to the matching variant; unknown tags fail.
    #[test]
    fn test_read_value() {
        let mut data = vec![];
        data.extend_from_slice(&(GgufValueType::U32 as u32).to_le_bytes());
        data.extend_from_slice(&7u32.to_le_bytes());
        let mut cursor = Cursor::new(data);
        assert_eq!(read_value(&mut cursor).unwrap().as_u64(), Some(7));

        let mut data = vec![];
        data.extend_from_slice(&(GgufValueType::Bool as u32).to_le_bytes());
        data.push(1);
        let mut cursor = Cursor::new(data);
        assert_eq!(read_value(&mut cursor).unwrap().as_bool(), Some(true));

        let data = 99u32.to_le_bytes();
        let mut cursor = Cursor::new(data);
        assert!(read_value(&mut cursor).is_err());
    }

    /// Array elements are untagged values of the declared element type.
    #[test]
    fn test_read_array() {
        let mut data = vec![];
        data.extend_from_slice(&(GgufValueType::U32 as u32).to_le_bytes());
        data.extend_from_slice(&2u64.to_le_bytes());
        data.extend_from_slice(&7u32.to_le_bytes());
        data.extend_from_slice(&9u32.to_le_bytes());
        let mut cursor = Cursor::new(data);
        let value = read_array(&mut cursor).unwrap();
        let items = value.as_array().unwrap();
        assert_eq!(items.len(), 2);
        assert_eq!(items[0].as_u64(), Some(7));
        assert_eq!(items[1].as_u64(), Some(9));
    }

    /// Declared element counts above `MAX_ARRAY_SIZE` are rejected.
    #[test]
    fn test_read_array_rejects_oversized_count() {
        let mut data = vec![];
        data.extend_from_slice(&(GgufValueType::U8 as u32).to_le_bytes());
        data.extend_from_slice(&(MAX_ARRAY_SIZE as u64 + 1).to_le_bytes());
        let mut cursor = Cursor::new(data);
        assert!(read_array(&mut cursor).is_err());
    }

    /// Accessors convert between compatible variants and reject the rest.
    #[test]
    fn test_gguf_value_conversions() {
        let val = GgufValue::String("test".to_string());
        assert_eq!(val.as_str(), Some("test"));
        assert_eq!(val.as_u64(), None);

        let val = GgufValue::U32(42);
        assert_eq!(val.as_u64(), Some(42));
        assert_eq!(val.as_i64(), Some(42));
        assert_eq!(val.as_f32(), Some(42.0));
        assert_eq!(val.as_str(), None);

        let val = GgufValue::I32(-5);
        assert_eq!(val.as_i64(), Some(-5));
        assert_eq!(val.as_u64(), None);

        // 2.5 is exactly representable in both f32 and f64, and avoids
        // Clippy's `approx_constant` lint that a 3.14 literal triggers.
        let val = GgufValue::F32(2.5);
        assert_eq!(val.as_f32(), Some(2.5));
        assert_eq!(val.as_f64(), Some(2.5));

        let val = GgufValue::Bool(true);
        assert_eq!(val.as_bool(), Some(true));

        let val = GgufValue::Array(vec![GgufValue::U32(1), GgufValue::U32(2)]);
        assert_eq!(val.as_array().unwrap().len(), 2);
    }

    /// Raw type ids map to the expected variants; unknown ids are errors.
    #[test]
    fn test_value_type_conversion() {
        assert_eq!(GgufValueType::try_from(0u32).unwrap(), GgufValueType::U8);
        assert_eq!(GgufValueType::try_from(6u32).unwrap(), GgufValueType::F32);
        assert_eq!(GgufValueType::try_from(8u32).unwrap(), GgufValueType::String);
        assert_eq!(GgufValueType::try_from(12u32).unwrap(), GgufValueType::F64);
        assert!(GgufValueType::try_from(13u32).is_err());
        assert!(GgufValueType::try_from(100u32).is_err());
    }
}