use serde::{Deserialize, Serialize};
use serde_json::Value as JsonValue;
use std::collections::BTreeMap;
use std::fs;
use std::path::Path;
/// Describes one tensor stored in an APR file: its name, element type,
/// logical shape, and the byte span its payload occupies in the data section.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AprTensorDescriptor {
    /// Tensor name as recorded in the file's tensor directory.
    pub name: String,
    /// Element dtype rendered as a string via `Debug` (e.g. "F32", "Q4K").
    pub dtype: String,
    /// Logical dimensions; the element count is the product of all entries.
    pub shape: Vec<usize>,
    /// Byte offset of the payload, relative to the start of the data section.
    pub offset: usize,
    /// Payload length in bytes.
    pub size: usize,
}
/// Generic model metadata: a sorted map from key to arbitrary JSON value
/// (`BTreeMap` gives deterministic iteration order).
pub type AprMetadata = BTreeMap<String, JsonValue>;
/// Read-only view of a parsed APR file: decoded metadata, the tensor
/// directory, and the raw file bytes from which tensor payloads are sliced
/// on demand by the `read_tensor_*` accessors.
#[derive(Debug)]
pub struct AprReader {
    /// Model metadata decoded from the v2 header into generic JSON values.
    pub metadata: AprMetadata,
    /// Descriptors for each tensor found in the file.
    pub tensors: Vec<AprTensorDescriptor>,
    // Entire file contents; tensor reads slice into this buffer.
    data: Vec<u8>,
    // Byte offset where the tensor data section begins (from the v2 header).
    data_offset: usize,
}
impl AprReader {
pub fn open<P: AsRef<Path>>(path: P) -> Result<Self, String> {
let data = fs::read(path).map_err(|e| format!("Failed to read file: {e}"))?;
Self::from_bytes(data)
}
pub fn from_bytes(data: Vec<u8>) -> Result<Self, String> {
use crate::format::v2::AprV2ReaderRef;
let reader =
AprV2ReaderRef::from_bytes(&data).map_err(|e| format!("Invalid APR file: {e}"))?;
let meta = reader.metadata();
let mut metadata = AprMetadata::new();
if !meta.model_type.is_empty() {
metadata.insert(
"model_type".to_string(),
JsonValue::String(meta.model_type.clone()),
);
}
if let Some(ref name) = meta.name {
metadata.insert("model_name".to_string(), JsonValue::String(name.clone()));
}
if let Some(ref desc) = meta.description {
metadata.insert("description".to_string(), JsonValue::String(desc.clone()));
}
if let Some(ref author) = meta.author {
metadata.insert("author".to_string(), JsonValue::String(author.clone()));
}
if let Some(ref license) = meta.license {
metadata.insert("license".to_string(), JsonValue::String(license.clone()));
}
if let Some(ref version) = meta.version {
metadata.insert("version".to_string(), JsonValue::String(version.clone()));
}
if let Some(ref arch) = meta.architecture {
metadata.insert("architecture".to_string(), JsonValue::String(arch.clone()));
}
if let Some(v) = meta.hidden_size {
metadata.insert(
"hidden_size".into(),
JsonValue::Number(serde_json::Number::from(v)),
);
}
if let Some(v) = meta.num_layers {
metadata.insert(
"num_layers".into(),
JsonValue::Number(serde_json::Number::from(v)),
);
}
if let Some(v) = meta.num_heads {
metadata.insert(
"num_heads".into(),
JsonValue::Number(serde_json::Number::from(v)),
);
}
if let Some(v) = meta.num_kv_heads {
metadata.insert(
"num_kv_heads".into(),
JsonValue::Number(serde_json::Number::from(v)),
);
}
if let Some(v) = meta.vocab_size {
metadata.insert(
"vocab_size".into(),
JsonValue::Number(serde_json::Number::from(v)),
);
}
if let Some(v) = meta.intermediate_size {
metadata.insert(
"intermediate_size".into(),
JsonValue::Number(serde_json::Number::from(v)),
);
}
if let Some(v) = meta.max_position_embeddings {
metadata.insert(
"max_position_embeddings".into(),
JsonValue::Number(serde_json::Number::from(v)),
);
}
if let Some(v) = meta.rope_theta {
if let Some(n) = serde_json::Number::from_f64(v as f64) {
metadata.insert("rope_theta".into(), JsonValue::Number(n));
}
}
if let Some(v) = meta.rms_norm_eps {
if let Some(n) = serde_json::Number::from_f64(v as f64) {
metadata.insert("rms_norm_eps".into(), JsonValue::Number(n));
}
}
if let Some(v) = meta.head_dim {
metadata.insert(
"head_dim".into(),
JsonValue::Number(serde_json::Number::from(v)),
);
}
for (k, v) in &meta.custom {
metadata.insert(k.clone(), v.clone());
}
let tensor_names = reader.tensor_names();
let mut tensors = Vec::new();
for name in tensor_names {
if let Some(entry) = reader.get_tensor(name) {
tensors.push(AprTensorDescriptor {
name: entry.name.clone(),
dtype: format!("{:?}", entry.dtype),
shape: entry.shape.clone(),
offset: entry.offset as usize,
size: entry.size as usize,
});
}
}
let data_offset = reader.header().data_offset as usize;
Ok(Self {
metadata,
tensors,
data,
data_offset,
})
}
pub fn open_filtered<P, F>(path: P, filter: F) -> Result<Self, String>
where
P: AsRef<Path>,
F: Fn(&str) -> bool,
{
let data = fs::read(path).map_err(|e| format!("Failed to read file: {e}"))?;
Self::from_bytes_filtered(data, filter)
}
pub fn from_bytes_filtered<F>(data: Vec<u8>, filter: F) -> Result<Self, String>
where
F: Fn(&str) -> bool,
{
let mut reader = Self::from_bytes(data)?;
reader.tensors.retain(|t| filter(&t.name));
Ok(reader)
}
#[must_use]
pub fn get_metadata(&self, key: &str) -> Option<&JsonValue> {
self.metadata.get(key)
}
pub fn all_metadata(&self) -> impl Iterator<Item = (&String, &JsonValue)> {
self.metadata.iter()
}
fn get_tensor_bytes(&self, name: &str) -> Result<(&AprTensorDescriptor, &[u8]), String> {
let desc = self
.tensors
.iter()
.find(|t| t.name == name)
.ok_or_else(|| format!("Tensor not found: {name}"))?;
let start = self.data_offset + desc.offset;
let end = start + desc.size;
if end > self.data.len() {
return Err(format!(
"Tensor '{name}' data out of bounds: {end} > {}",
self.data.len()
));
}
Ok((desc, &self.data[start..end]))
}
pub fn read_tensor_f32(&self, name: &str) -> Result<Vec<f32>, String> {
let (desc, bytes) = self.get_tensor_bytes(name)?;
if desc.dtype != "F32" {
return Err(format!(
"Tensor not found or not F32: {name} (dtype={})",
desc.dtype
));
}
Ok(bytes
.chunks_exact(4)
.map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
.collect())
}
pub fn read_tensor_as_f32(&self, name: &str) -> Result<Vec<f32>, String> {
use crate::format::gguf::dequant::{dequantize_q4_k, dequantize_q6_k};
let (desc, bytes) = self.get_tensor_bytes(name)?;
let element_count: usize = desc.shape.iter().product();
match desc.dtype.as_str() {
"F32" => Ok(bytes
.chunks_exact(4)
.map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
.collect()),
"F16" => Ok(bytes
.chunks_exact(2)
.map(|c| trueno::f16_to_f32(u16::from_le_bytes([c[0], c[1]])))
.collect()),
"BF16" => Ok(bytes
.chunks_exact(2)
.map(|c| {
let bits = u16::from_le_bytes([c[0], c[1]]);
f32::from_bits(u32::from(bits) << 16)
})
.collect()),
"Q8" => {
if bytes.len() < 4 {
return Err(format!("Tensor '{name}' Q8 data too short"));
}
let scale = f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
Ok(bytes[4..]
.iter()
.map(|&b| f32::from(b as i8) * scale)
.collect())
}
"Q4" => Ok(Self::dequantize_q4_inline(bytes, element_count)),
"Q4K" => dequantize_q4_k(bytes, 0, element_count)
.map_err(|e| format!("Tensor '{name}' Q4K dequant failed: {e}")),
"Q6K" => dequantize_q6_k(bytes, 0, element_count)
.map_err(|e| format!("Tensor '{name}' Q6K dequant failed: {e}")),
other => Err(format!(
"Tensor '{name}' not found or unsupported dtype: {other}"
)),
}
}
fn dequantize_q4_inline(data: &[u8], element_count: usize) -> Vec<f32> {
use crate::format::f16_safety::F16_MIN_NORMAL;
const BLOCK_SIZE: usize = 32;
let mut result = Vec::with_capacity(element_count);
let mut pos = 0;
let mut remaining = element_count;
while remaining > 0 && pos + 2 <= data.len() {
let scale_bits = u16::from_le_bytes([data[pos], data[pos + 1]]);
let scale_raw = trueno::f16_to_f32(scale_bits);
let scale = if scale_raw.is_nan()
|| scale_raw.is_infinite()
|| scale_raw.abs() < F16_MIN_NORMAL
{
0.0
} else {
scale_raw
};
pos += 2;
let values_in_block = remaining.min(BLOCK_SIZE);
for i in 0..values_in_block {
let byte_idx = pos + i / 2;
if byte_idx >= data.len() {
break;
}
let byte = data[byte_idx];
let nibble = if i % 2 == 0 { byte & 0x0F } else { byte >> 4 };
let q = (nibble as i8) - 8;
result.push(f32::from(q) * scale);
}
pos += 16;
remaining = remaining.saturating_sub(BLOCK_SIZE);
}
result.resize(element_count, 0.0);
result
}
pub fn read_tensor_f32_checked(&self, name: &str) -> Result<Vec<f32>, String> {
let data = self.read_tensor_f32(name)?;
for (i, &v) in data.iter().enumerate() {
if !v.is_finite() {
return Err(format!(
"F-CKPT-013: tensor '{name}' contains non-finite value at index {i}: {v}"
));
}
}
Ok(data)
}
pub fn validate_tensor_shape(
&self,
name: &str,
expected_elements: usize,
) -> Result<(), String> {
let desc = self
.tensors
.iter()
.find(|t| t.name == name)
.ok_or_else(|| format!("F-CKPT-014: tensor '{name}' not found"))?;
let actual_elements: usize = desc.shape.iter().product();
if actual_elements != expected_elements {
return Err(format!(
"F-CKPT-014: tensor '{name}' shape mismatch: \
expected {expected_elements} elements, got {actual_elements} (shape {:?})",
desc.shape,
));
}
Ok(())
}
}
/// Builder that accumulates generic metadata and f32 tensors, then
/// serializes them into the APR v2 container format (in memory or to disk).
#[derive(Debug, Default)]
pub struct AprWriter {
    // Metadata to embed; well-known keys map onto typed v2 fields, the rest
    // are carried through as custom entries.
    metadata: AprMetadata,
    // Queued tensors as (name, shape, f32 data) tuples, in insertion order.
    tensors: Vec<(String, Vec<usize>, Vec<f32>)>,
}
impl AprWriter {
    /// Create an empty writer with no metadata or tensors.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Set (or replace) a metadata entry.
    pub fn set_metadata(&mut self, key: impl Into<String>, value: JsonValue) {
        self.metadata.insert(key.into(), value);
    }

    /// Queue an f32 tensor for writing, copying `data` into the writer.
    pub fn add_tensor_f32(&mut self, name: impl Into<String>, shape: Vec<usize>, data: &[f32]) {
        self.tensors.push((name.into(), shape, data.to_vec()));
    }

    /// Queue an f32 tensor for writing, taking ownership of `data`
    /// (no intermediate copy).
    pub fn add_tensor_f32_owned(
        &mut self,
        name: impl Into<String>,
        shape: Vec<usize>,
        data: Vec<f32>,
    ) {
        self.tensors.push((name.into(), shape, data));
    }

    /// Translate the generic JSON metadata map into the typed v2 metadata
    /// struct. Well-known keys (including aliases such as
    /// `num_hidden_layers`/`num_layers`) populate dedicated fields; every
    /// other key is passed through in `custom`. A well-known key whose JSON
    /// value has the wrong type is dropped (the `as_*` conversion yields
    /// `None`).
    fn build_v2_metadata(&self) -> crate::format::v2::AprV2Metadata {
        use crate::format::v2::AprV2Metadata;
        let mut v2_meta = AprV2Metadata::default();
        for (key, value) in &self.metadata {
            match key.as_str() {
                "model_type" => {
                    if let Some(s) = value.as_str() {
                        v2_meta.model_type = s.to_string();
                    }
                }
                "model_name" => v2_meta.name = value.as_str().map(String::from),
                "description" => v2_meta.description = value.as_str().map(String::from),
                "author" => v2_meta.author = value.as_str().map(String::from),
                "license" => v2_meta.license = value.as_str().map(String::from),
                "version" => v2_meta.version = value.as_str().map(String::from),
                "architecture" => v2_meta.architecture = value.as_str().map(String::from),
                "hidden_size" => v2_meta.hidden_size = value.as_u64().map(|v| v as usize),
                "num_hidden_layers" | "num_layers" => {
                    v2_meta.num_layers = value.as_u64().map(|v| v as usize);
                }
                "num_attention_heads" | "num_heads" => {
                    v2_meta.num_heads = value.as_u64().map(|v| v as usize);
                }
                "num_kv_heads" => v2_meta.num_kv_heads = value.as_u64().map(|v| v as usize),
                "vocab_size" => v2_meta.vocab_size = value.as_u64().map(|v| v as usize),
                "intermediate_size" => {
                    v2_meta.intermediate_size = value.as_u64().map(|v| v as usize);
                }
                "max_position_embeddings" => {
                    v2_meta.max_position_embeddings = value.as_u64().map(|v| v as usize);
                }
                "rope_theta" => v2_meta.rope_theta = value.as_f64().map(|v| v as f32),
                "rms_norm_eps" => v2_meta.rms_norm_eps = value.as_f64().map(|v| v as f32),
                "head_dim" => v2_meta.head_dim = value.as_u64().map(|v| v as usize),
                "num_experts" => v2_meta.num_experts = value.as_u64().map(|v| v as usize),
                "num_experts_per_tok" => {
                    v2_meta.num_experts_per_tok = value.as_u64().map(|v| v as usize);
                }
                "moe_intermediate_size" => {
                    v2_meta.moe_intermediate_size = value.as_u64().map(|v| v as usize);
                }
                _ => {
                    v2_meta.custom.insert(key.clone(), value.clone());
                }
            }
        }
        v2_meta
    }

    /// Serialize metadata and tensors into an in-memory APR byte buffer,
    /// borrowing the queued tensor data.
    ///
    /// # Errors
    /// Returns a message if v2 serialization fails.
    pub fn to_bytes(&self) -> Result<Vec<u8>, String> {
        use crate::format::v2::AprV2Writer as V2Writer;
        let mut writer = V2Writer::new(self.build_v2_metadata());
        for (name, shape, data) in &self.tensors {
            writer.add_f32_tensor(name, shape.clone(), data);
        }
        writer
            .write()
            .map_err(|e| format!("APR serialization failed: {e}"))
    }

    /// Consuming variant of [`Self::to_bytes`]: hands tensor buffers to the
    /// v2 writer by value, avoiding per-tensor data copies.
    ///
    /// # Errors
    /// Returns a message if v2 serialization fails.
    pub fn into_bytes(self) -> Result<Vec<u8>, String> {
        use crate::format::v2::AprV2Writer as V2Writer;
        let mut writer = V2Writer::new(self.build_v2_metadata());
        for (name, shape, data) in self.tensors {
            writer.add_tensor_f32_owned(name, shape, data);
        }
        writer
            .write()
            .map_err(|e| format!("APR serialization failed: {e}"))
    }

    /// Atomically write the serialized container to `path`: serialize in
    /// memory, write to a sibling `.apr.tmp` file, fsync, then rename over
    /// the destination. On any failure the temp file is removed (best
    /// effort) so no stale `.apr.tmp` is left behind.
    ///
    /// # Errors
    /// Returns a message describing the first failing step.
    pub fn write<P: AsRef<Path>>(&self, path: P) -> Result<(), String> {
        use std::io::Write;
        let path = path.as_ref();
        let bytes = self.to_bytes()?;
        let tmp_path = path.with_extension("apr.tmp");
        let result = (|| {
            let mut file = fs::File::create(&tmp_path)
                .map_err(|e| format!("Failed to create temp file: {e}"))?;
            file.write_all(&bytes)
                .map_err(|e| format!("Failed to write temp file: {e}"))?;
            // fsync before the rename so a crash cannot publish a torn file.
            file.sync_all()
                .map_err(|e| format!("Failed to fsync temp file: {e}"))?;
            drop(file);
            fs::rename(&tmp_path, path).map_err(|e| format!("Failed to rename temp file: {e}"))
        })();
        if result.is_err() {
            // Best-effort cleanup: don't leak the temp file on failure.
            let _ = fs::remove_file(&tmp_path);
        }
        result
    }

    /// Streaming variant of [`Self::write`] that consumes the writer and
    /// feeds tensors to the v2 streaming writer one at a time instead of
    /// serializing the whole container in memory first. Uses the same
    /// atomic temp-file + fsync + rename protocol, and likewise removes the
    /// temp file (best effort) on failure.
    ///
    /// # Errors
    /// Returns a message describing the first failing step.
    pub fn write_into<P: AsRef<Path>>(self, path: P) -> Result<(), String> {
        use crate::format::v2::AprV2StreamingWriter;
        let path = path.as_ref();
        let tmp_path = path.with_extension("apr.tmp");
        let v2_meta = self.build_v2_metadata();
        let tensors = self.tensors;
        let result = (|| {
            let mut writer = AprV2StreamingWriter::new(v2_meta)
                .map_err(|e| format!("Failed to create streaming writer: {e}"))?;
            for (name, shape, data) in tensors {
                writer
                    .add_f32_tensor(name, shape, &data)
                    .map_err(|e| format!("Failed to add tensor: {e}"))?;
            }
            writer
                .finalize(&tmp_path)
                .map_err(|e| format!("Failed to finalize streaming write: {e}"))?;
            // Reopen to fsync: `finalize` owns its own handle internally.
            let file = fs::File::open(&tmp_path)
                .map_err(|e| format!("Failed to open tmp file for fsync: {e}"))?;
            file.sync_all()
                .map_err(|e| format!("Failed to fsync tmp file: {e}"))?;
            drop(file);
            fs::rename(&tmp_path, path).map_err(|e| format!("Failed to rename temp file: {e}"))
        })();
        if result.is_err() {
            // Best-effort cleanup: don't leak the temp file on failure.
            let _ = fs::remove_file(&tmp_path);
        }
        result
    }
}
mod crc32;
#[cfg(test)]
mod tests;