use super::constants::GgmlType;
use std::collections::HashMap;
/// A single typed GGUF metadata value.
///
/// `PartialEq` is derived so values can be compared directly (e.g. in tests or
/// when checking an expected architecture string). `Eq` cannot be derived
/// because of the `Float32`/`Float64` variants (NaN is not equal to itself).
///
/// NOTE(review): the variant order appears to mirror the GGUF value-type ids
/// (UINT8 = 0 through FLOAT64 = 12). No explicit discriminants are assigned, so
/// nothing in this type depends on it, but keep the order in case a parser maps
/// ids by position — confirm.
#[derive(Debug, Clone, PartialEq)]
pub enum MetadataValue {
    Uint8(u8),
    Int8(i8),
    Uint16(u16),
    Int16(i16),
    Uint32(u32),
    Int32(i32),
    Float32(f32),
    Bool(bool),
    String(String),
    Array(MetadataArray),
    Uint64(u64),
    Int64(i64),
    Float64(f64),
}

/// An array-valued metadata entry.
///
/// Elements are stored as individual [`MetadataValue`]s; this type itself does
/// not require them to share one variant.
#[derive(Debug, Clone, PartialEq)]
pub struct MetadataArray {
    /// The array elements, in stored order.
    pub values: Vec<MetadataValue>,
}
/// Descriptor for a single tensor: its name, shape, element type and data offset.
#[derive(Debug, Clone)]
pub struct TensorInfo {
/// Tensor name; `GgufData::get_tensor` matches against this field.
pub name: String,
/// Declared number of dimensions.
/// NOTE(review): presumably equal to `dims.len()`; nothing in this file enforces
/// that — confirm against the parser.
pub n_dims: u32,
/// Extent of each dimension; `n_elements` is the product of these.
pub dims: Vec<u64>,
/// Element type; provides the `block_size`/`type_size` that `data_size` uses.
pub dtype: GgmlType,
/// Offset of this tensor's data.
/// NOTE(review): presumably relative to `GgufData::data_offset` rather than the
/// start of the file — confirm.
pub offset: u64,
}
impl TensorInfo {
    /// Total element count: the product of every entry in `dims`.
    ///
    /// An empty `dims` gives 1 (the empty product).
    pub fn n_elements(&self) -> u64 {
        self.dims.iter().fold(1u64, |acc, &dim| acc * dim)
    }

    /// Size of the tensor's data, computed as whole blocks times the per-block
    /// size reported by `dtype`.
    ///
    /// `type_size()` is assumed to be the byte size of one block of
    /// `block_size()` elements — see `GgmlType`. Integer division means any
    /// trailing elements that do not fill a complete block are not counted.
    pub fn data_size(&self) -> usize {
        let total = self.n_elements() as usize;
        let (per_block, block_bytes) = (self.dtype.block_size(), self.dtype.type_size());
        let whole_blocks = total / per_block;
        whole_blocks * block_bytes
    }
}
/// Header fields of a GGUF file: the format version and the declared entry counts.
///
/// Every field is a plain integer, so the header derives `Copy` and `Eq` and can
/// be duplicated and compared freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct GgufHeader {
    /// Format version number.
    pub version: u32,
    /// Number of tensor descriptors the header declares.
    pub tensor_count: u64,
    /// Number of metadata key/value pairs the header declares.
    pub metadata_kv_count: u64,
}
/// GGUF file contents held in memory: header, key/value metadata and tensor descriptors.
#[derive(Debug)]
pub struct GgufData {
/// Format version and declared entry counts.
pub header: GgufHeader,
/// Metadata entries keyed by name (e.g. `"gemma4.attention.head_count_kv"`);
/// read them through the typed `get_*` accessors.
pub metadata: HashMap<String, MetadataValue>,
/// Tensor descriptors; `get_tensor` searches them linearly by name.
pub tensors: Vec<TensorInfo>,
/// Offset at which tensor data begins.
/// NOTE(review): presumably an absolute file offset to which each
/// `TensorInfo::offset` is added — confirm against the parser.
pub data_offset: u64,
}
impl GgufData {
pub fn get_string(&self, key: &str) -> Option<&str> {
match self.metadata.get(key)? {
MetadataValue::String(s) => Some(s.as_str()),
_ => None,
}
}
pub fn get_u32(&self, key: &str) -> Option<u32> {
match self.metadata.get(key)? {
MetadataValue::Uint32(v) => Some(*v),
_ => None,
}
}
pub fn get_u64(&self, key: &str) -> Option<u64> {
match self.metadata.get(key)? {
MetadataValue::Uint64(v) => Some(*v),
_ => None,
}
}
pub fn get_f32(&self, key: &str) -> Option<f32> {
match self.metadata.get(key)? {
MetadataValue::Float32(v) => Some(*v),
_ => None,
}
}
pub fn get_bool(&self, key: &str) -> Option<bool> {
match self.metadata.get(key)? {
MetadataValue::Bool(v) => Some(*v),
_ => None,
}
}
pub fn get_tensor(&self, name: &str) -> Option<&TensorInfo> {
self.tensors.iter().find(|t| t.name == name)
}
pub fn get_u32_array(&self, key: &str) -> Option<Vec<u32>> {
match self.metadata.get(key)? {
MetadataValue::Uint32(v) => Some(vec![*v]),
MetadataValue::Array(arr) => {
let mut out = Vec::with_capacity(arr.values.len());
for v in &arr.values {
match v {
MetadataValue::Uint32(x) => out.push(*x),
MetadataValue::Uint64(x) => out.push(*x as u32),
MetadataValue::Int32(x) => out.push(*x as u32),
_ => return None,
}
}
Some(out)
}
_ => None,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `GgufData` holding exactly one metadata entry and no tensors.
    fn data_with(key: &str, value: MetadataValue) -> GgufData {
        let metadata: HashMap<String, MetadataValue> =
            std::iter::once((key.to_string(), value)).collect();
        let header = GgufHeader {
            version: 3,
            tensor_count: 0,
            metadata_kv_count: 1,
        };
        GgufData {
            header,
            metadata,
            tensors: Vec::new(),
            data_offset: 0,
        }
    }

    #[test]
    fn test_get_u32_array_reads_per_layer_array() {
        const KEY: &str = "gemma4.attention.head_count_kv";
        let per_layer: Vec<MetadataValue> = [8u32, 8, 1]
            .iter()
            .map(|&count| MetadataValue::Uint32(count))
            .collect();
        let d = data_with(KEY, MetadataValue::Array(MetadataArray { values: per_layer }));

        assert_eq!(d.get_u32_array(KEY), Some(vec![8, 8, 1]));
        // An array is not a scalar, so the strict scalar getter must reject it.
        assert_eq!(d.get_u32(KEY), None);
        assert_eq!(d.get_u32_array("missing"), None);
    }

    #[test]
    fn test_get_u32_array_wraps_scalar() {
        let d = data_with("k", MetadataValue::Uint32(7));
        let expected = vec![7];
        assert_eq!(d.get_u32_array("k"), Some(expected));
    }
}