use ndarray::{Array, Array1, ArrayBase, Dimension, ShapeError, StrideShape};
use ndarray_npy::{NpzReader, ReadNpzError, ReadableElement};
use std::io::{Cursor, Read, Seek};
use std::{fs, path::Path};
use thiserror::Error;
/// Shorthand for a `Result` whose error type is [`WeightError`].
type WeightResult<T> = Result<T, WeightError>;
/// Errors reported while locating, reading, or shaping weights.
///
/// The variants carrying `#[from]` are produced implicitly when `?` is applied
/// to a result holding the wrapped error type.
#[derive(Error, Debug)]
pub enum WeightError {
/// No weight is stored under the requested name; the payload is that name.
// NOTE(review): not constructed anywhere in the visible part of this file.
#[error("No weights with name {0} found")]
WeightKeyError(String),
/// The weight file does not have the expected layout.
// NOTE(review): the message describes a JSON layout and the variant is not
// constructed in the visible code; presumably used by another loader - confirm.
#[error("Weight file didn't have the correct format (required: JSON dict of pairs (key, flattened array of weights))")]
WeightFormatError,
/// Wraps any `std::io::Error`, e.g. a failure to open the weight file.
#[error("Weight file not found. Filesystem reported error\n {0}.")]
WeightFileNotFoundError(#[from] std::io::Error),
/// Wraps a `ReadNpzError` raised while reading an NPZ archive.
#[error("Weight file not readable. Filesystem reported error\n {0}.")]
WeightFileNpzError(#[from] ReadNpzError),
/// Wraps an `ndarray` `ShapeError` raised while shaping a weight.
#[error("Wrong shape for weight:\n {0}.")]
WeightShapeError(#[from] ShapeError),
}
/// A source of named weight arrays.
pub trait WeightLoader {
/// Fetches the weight stored under `param_name`.
///
/// * `param_name` - name identifying the weight in the underlying source.
/// * `shape` - the shape the caller expects the returned array to have.
///
/// The element type `P` and the dimensionality `D` of the returned array are
/// chosen by the caller. Failures are reported as a [`WeightError`].
fn get_weight<D, Sh, P: ReadableElement + Copy>(
&mut self,
param_name: &str,
shape: Sh,
) -> WeightResult<Array<P, D>>
where
D: Dimension,
Sh: Into<StrideShape<D>>;
}
/// Weight loader that reads arrays from a seekable byte source `R`, which its
/// `WeightLoader` implementation parses as an NPZ archive.
pub struct NpzWeightLoader<R>
where
R: Seek + Read,
{
// Underlying reader; borrowed mutably each time a weight is requested.
handle: R,
}
impl NpzWeightLoader<std::fs::File> {
pub fn from_path<P: AsRef<Path>>(path: P) -> WeightResult<NpzWeightLoader<std::fs::File>> {
let handle = std::fs::File::open(path)?;
Ok(NpzWeightLoader { handle })
}
}
impl NpzWeightLoader<Cursor<&[u8]>> {
    /// Builds a loader that reads from the in-memory slice `bytes_array`.
    ///
    /// The slice is borrowed, not copied, and is not inspected here. This
    /// function always returns `Ok`.
    pub fn from_buffer(bytes_array: &[u8]) -> WeightResult<NpzWeightLoader<Cursor<&[u8]>>> {
        let handle = Cursor::new(bytes_array);
        Ok(NpzWeightLoader { handle })
    }
}
impl<R> WeightLoader for NpzWeightLoader<R>
where
R: Seek + Read,
{
fn get_weight<D, Sh, P: Copy + ReadableElement>(
&mut self,
param_name: &str,
shape: Sh,
) -> WeightResult<Array<P, D>>
where
D: Dimension,
Sh: Into<StrideShape<D>>,
{
let mut reader = NpzReader::new(&mut self.handle)?;
let arr: Result<ArrayBase<_, D>, _> = reader.by_name(param_name);
Ok(match arr {
Ok(a) => {
debug_assert_eq!(&a.raw_dim(), shape.into().raw_dim());
a
}
Err(_) => {
let arr_flat: Array1<P> = reader.by_name(param_name)?;
let arr_reshaped =
Array::from_shape_vec(shape, arr_flat.iter().copied().collect())?;
arr_reshaped
}
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::{array, Array2, Ix1, Ix2};
    use std::fs::File;
    use std::path::PathBuf;
    use tempfile::{tempdir, TempDir};

    /// The 2x3 matrix stored under "a" and, flattened, under "flat".
    fn matrix() -> Array2<f32> {
        array![[1., 2., 3.], [4., 5., 6.]]
    }

    /// The vector stored under "b".
    fn vector() -> Array1<f32> {
        array![7., 8., 9.]
    }

    /// Writes an NPZ fixture with entries "a", "b" and "flat" into `dir` and
    /// returns its path.
    fn write_fixture(dir: &TempDir) -> PathBuf {
        let file_path = dir.path().join("temp-weights.npz");
        let file = File::create(&file_path).unwrap();
        let mut npz = ndarray_npy::NpzWriter::new(file);
        let flat: Array1<f32> = array![1., 2., 3., 4., 5., 6.];
        npz.add_array("a", &matrix()).unwrap();
        npz.add_array("b", &vector()).unwrap();
        npz.add_array("flat", &flat).unwrap();
        npz.finish().unwrap();
        file_path
    }

    #[test]
    fn test_npz_weight_loader() {
        let dir = tempdir().unwrap();
        let file_path = write_fixture(&dir);
        let mut loader = NpzWeightLoader::from_path(&file_path).unwrap();
        assert_eq!(
            loader.get_weight::<_, _, f32>("a", (2, 3)).unwrap(),
            matrix()
        );
        assert_eq!(loader.get_weight::<_, _, f32>("b", 3).unwrap(), vector());
        // Release the file handle before the directory is removed.
        drop(loader);
        dir.close().unwrap();
    }

    #[test]
    fn test_flattened_weight_is_reshaped() {
        let dir = tempdir().unwrap();
        let file_path = write_fixture(&dir);
        let mut loader = NpzWeightLoader::from_path(&file_path).unwrap();
        assert_eq!(
            loader.get_weight::<_, _, f32>("flat", (2, 3)).unwrap(),
            matrix()
        );
        drop(loader);
        dir.close().unwrap();
    }

    #[test]
    fn test_flattened_weight_with_wrong_size_is_shape_error() {
        let dir = tempdir().unwrap();
        let file_path = write_fixture(&dir);
        let mut loader = NpzWeightLoader::from_path(&file_path).unwrap();
        let result = loader.get_weight::<Ix2, _, f32>("flat", (2, 2));
        assert!(matches!(result, Err(WeightError::WeightShapeError(_))));
        drop(loader);
        dir.close().unwrap();
    }

    #[test]
    fn test_missing_weight_is_error() {
        let dir = tempdir().unwrap();
        let file_path = write_fixture(&dir);
        let mut loader = NpzWeightLoader::from_path(&file_path).unwrap();
        let result = loader.get_weight::<Ix1, _, f32>("missing", 3);
        assert!(result.is_err());
        drop(loader);
        dir.close().unwrap();
    }

    #[test]
    fn test_from_buffer() {
        let dir = tempdir().unwrap();
        let file_path = write_fixture(&dir);
        let bytes = std::fs::read(&file_path).unwrap();
        let mut loader = NpzWeightLoader::from_buffer(&bytes).unwrap();
        assert_eq!(
            loader.get_weight::<_, _, f32>("a", (2, 3)).unwrap(),
            matrix()
        );
        assert_eq!(loader.get_weight::<_, _, f32>("b", 3).unwrap(), vector());
        dir.close().unwrap();
    }
}