use std::fs::File;
use std::io::{BufReader, BufWriter, Read, Write};
use crate::tensors::Tensor;
use std::error::Error;
use briny::prelude::*;
const BPAT_MAGIC: &[u8; 4] = b"bpat";
/// Raw tensor payload as read from a `bpat` file, before validation.
///
/// `shape` holds the dimension sizes in the order they appear in the file and
/// `data` holds the element values read after them. The [`Validate`] impl
/// checks that the number of elements matches the product of the dimensions.
#[derive(Debug, Clone, PartialEq)]
struct PackedTensor {
    /// Dimension sizes, stored in the file as little-endian `u64`s.
    shape: Vec<u64>,
    /// Element values, stored in the file as little-endian `f64`s.
    data: Vec<f64>,
}
impl Validate for PackedTensor {
    /// Checks that `data` holds exactly as many elements as `shape` describes.
    ///
    /// The expected element count is computed with checked arithmetic. A shape
    /// whose product overflows `u64`, or does not fit in `usize`, is rejected.
    /// Previously it would panic in debug builds or silently wrap in release
    /// builds, where a wrapped product could falsely match `data.len()`.
    ///
    /// # Errors
    /// Returns `ValidationError` if the element count overflows or does not
    /// equal `data.len()`.
    fn validate(&self) -> Result<(), ValidationError> {
        let expected = self
            .shape
            .iter()
            .try_fold(1u64, |acc, &dim| acc.checked_mul(dim))
            .filter(|&n| n <= usize::MAX as u64)
            .ok_or(ValidationError)?;
        // `expected` was just checked to fit in `usize`, so this cast is lossless.
        if self.data.len() != expected as usize {
            return Err(ValidationError);
        }
        Ok(())
    }
}
pub fn save_model(path: &str, tensors: &[Tensor<f64>]) -> Result<(), Box<dyn Error>> {
let mut file = BufWriter::new(File::create(path)?);
file.write_all(b"bpat")?;
file.write_all(&[tensors.len() as u8])?;
for tensor in tensors {
assert_eq!(
tensor.data.len(),
tensor.shape.iter().product(),
"tensor shape/data mismatch"
);
let dims = tensor.shape.len() as u64;
file.write_all(&dims.to_le_bytes())?;
for &dim in &tensor.shape {
file.write_all(&(dim as u64).to_le_bytes())?;
}
for &val in &tensor.data {
file.write_all(&val.to_le_bytes())?;
}
}
Ok(())
}
/// Reads the tensors stored in a `bpat` file at `path`.
///
/// Expects the layout written by `save_model`:
/// - the 4-byte magic
/// - a 1-byte tensor count
/// - per tensor: a `u64` rank, `rank` little-endian `u64` dimensions, and the
///   little-endian `f64` element values
///
/// # Errors
/// Returns an error if any of the following happens:
/// - the file cannot be opened or read, or ends early
/// - the magic header does not match
/// - a dimension or the element count does not fit in `usize`, or the count
///   overflows
/// - a tensor fails validation in `TrustedData::new`
pub fn load_model(path: &str) -> Result<Vec<Tensor<f64>>, Box<dyn Error>> {
    // Upper bound on speculative preallocation. Header values come from the file
    // and are untrusted; beyond this, vectors grow only as data is actually read,
    // so a corrupt header cannot trigger a capacity-overflow panic or a huge allocation.
    const MAX_PREALLOC: usize = 1 << 16;

    // Reads one little-endian `u64` from `reader`.
    fn read_u64<R: Read>(reader: &mut R) -> std::io::Result<u64> {
        let mut buf = [0u8; 8];
        reader.read_exact(&mut buf)?;
        Ok(u64::from_le_bytes(buf))
    }

    let mut file = BufReader::new(File::open(path)?);
    let mut magic = [0u8; 4];
    file.read_exact(&mut magic)?;
    if &magic != BPAT_MAGIC {
        return Err("invalid magic header".into());
    }
    let mut count = [0u8; 1];
    file.read_exact(&mut count)?;
    let count = count[0] as usize;
    let mut tensors = Vec::with_capacity(count);
    for _ in 0..count {
        let ndim = read_u64(&mut file)?;
        let mut shape = Vec::with_capacity(ndim.min(MAX_PREALLOC as u64) as usize);
        for _ in 0..ndim {
            let dim = read_u64(&mut file)?;
            if dim > usize::MAX as u64 {
                return Err(format!("tensor dimension {} does not fit in usize", dim).into());
            }
            shape.push(dim);
        }
        let size = shape
            .iter()
            .try_fold(1u64, |acc, &dim| acc.checked_mul(dim))
            .filter(|&n| n <= usize::MAX as u64)
            .map(|n| n as usize)
            .ok_or("tensor element count overflows")?;
        let mut data = Vec::with_capacity(size.min(MAX_PREALLOC));
        for _ in 0..size {
            // Same bit pattern as `f64::from_le_bytes` on the raw 8 bytes.
            data.push(f64::from_bits(read_u64(&mut file)?));
        }
        let raw_tensor = PackedTensor { shape, data };
        let trusted = TrustedData::new(raw_tensor)?;
        let inner = trusted.into_inner();
        // Each dimension was checked to fit in `usize` above, so these casts are lossless.
        let shape_usize: Vec<usize> = inner.shape.iter().map(|&x| x as usize).collect();
        tensors.push(Tensor::new(shape_usize, inner.data));
    }
    Ok(tensors)
}