//! Saving objects to, and loading them from, `.npz` archives (zip files whose entries are `.npy` files).
use crate::numpy::{self, NpyError, NumpyDtype, NumpyShape, ReadNumbers, WriteNumbers};
use std::{
fs::File,
io::{BufReader, BufWriter, Read, Seek, Write},
path::Path,
};
use zip::{
result::{ZipError, ZipResult},
ZipArchive, ZipWriter,
};
/// Something that can be saved to a `.npz` (which is a `.zip`).
///
/// All [Module]s in nn implement SaveToNpz, and the zips are formatted in a `.npz` fashion.
pub trait SaveToNpz {
    /// Save this object into the `.npz` file located at `path`.
    ///
    /// The file is created through a buffered writer, [SaveToNpz::write] is called with an
    /// empty filename prefix, and the archive is then finished and flushed.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be created, if [SaveToNpz::write] fails, if the
    /// archive cannot be finished, or if flushing the buffered data to the file fails.
    ///
    /// Example:
    /// ```ignore
    /// # use dfdx::prelude::*;
    /// let model: (Linear<5, 10>, Linear<10, 5>) = Default::default();
    /// model.save("tst.npz")?;
    /// ```
    fn save<P: AsRef<Path>>(&self, path: P) -> ZipResult<()> {
        let f = File::create(path)?;
        let f = BufWriter::new(f);
        let mut zip = ZipWriter::new(f);
        self.write(&"".into(), &mut zip)?;
        // `finish` hands the `BufWriter` back. Flush it explicitly: dropping a `BufWriter`
        // ignores flush errors, which would let a truncated archive be reported as success.
        let mut f = zip.finish()?;
        f.flush()?;
        Ok(())
    }

    /// Write this object into [ZipWriter] `w` with a base filename of `filename_prefix`.
    ///
    /// The default implementation writes nothing and returns `Ok(())`.
    ///
    /// Example:
    /// ```ignore
    /// # use dfdx::prelude::*;
    /// let model: Linear<5, 10> = Default::default();
    /// let mut zip = ZipWriter::new(...);
    /// model.write(&"0.".to_string(), &mut zip)?;
    /// model.write(&"1.".to_string(), &mut zip)?;
    /// ```
    /// Will save a zip file with the following files in it:
    /// - `0.weight.npy`
    /// - `0.bias.npy`
    /// - `1.weight.npy`
    /// - `1.bias.npy`
    fn write<W>(&self, _filename_prefix: &String, _w: &mut ZipWriter<W>) -> ZipResult<()>
    where
        W: Write + Seek,
    {
        Ok(())
    }
}
/// Something that can be loaded from a `.npz` file (which is a `zip` file).
///
/// All [Module]s in nn implement LoadFromNpz, and the zips are formatted in a `.npz` fashion.
pub trait LoadFromNpz {
    /// Loads data from a `.npz` zip archive at the specified `path`.
    ///
    /// The file is opened through a buffered reader, interpreted as a zip archive, and
    /// handed to [LoadFromNpz::read] with an empty filename prefix.
    ///
    /// A failure to open the file is reported as `NpzError::Npy(NpyError::IoError(..))`;
    /// a failure to open the archive is reported as `NpzError::Zip(..)`.
    ///
    /// Example:
    /// ```ignore
    /// # use dfdx::prelude::*;
    /// let mut model: (Linear<5, 10>, Linear<10, 5>) = Default::default();
    /// model.load("tst.npz")?;
    /// ```
    fn load<P: AsRef<Path>>(&mut self, path: P) -> Result<(), NpzError> {
        let file = match File::open(path) {
            Ok(file) => file,
            Err(io_err) => return Err(NpzError::Npy(NpyError::IoError(io_err))),
        };
        let mut archive = match ZipArchive::new(BufReader::new(file)) {
            Ok(archive) => archive,
            Err(zip_err) => return Err(NpzError::Zip(zip_err)),
        };
        self.read(&String::new(), &mut archive)
    }

    /// Reads this object from a [ZipArchive] `r` with a base filename of `filename_prefix`.
    ///
    /// The default implementation reads nothing and returns `Ok(())`.
    ///
    /// Example:
    /// ```ignore
    /// # use dfdx::prelude::*;
    /// let mut model: Linear<5, 10> = Default::default();
    /// let mut zip = ZipArchive::new(...);
    /// model.read("0.", &mut zip)?;
    /// ```
    /// Will try to read data from the following files:
    /// - `0.weight.npy`
    /// - `0.bias.npy`
    fn read<R>(&mut self, _filename_prefix: &String, _r: &mut ZipArchive<R>) -> Result<(), NpzError>
    where
        R: Read + Seek,
    {
        Ok(())
    }
}
/// Error that can happen while loading data from a `.npz` zip archive.
///
/// Each variant wraps the error of the layer that failed.
// NOTE(review): no `Debug`/`Display` derive is visible on this enum here; confirm whether
// such impls exist elsewhere before using `unwrap`/`expect` on results carrying this type.
pub enum NpzError {
/// Something went wrong with reading from the `.zip` archive, e.g. the archive could not
/// be opened or a named entry could not be found in it.
Zip(ZipError),
/// Something went wrong with loading data from a `.npy` file. `LoadFromNpz::load` also
/// uses this variant (wrapping `NpyError::IoError`) when the file cannot be opened.
Npy(NpyError),
}
/// Writes `data` to a new file in a zip archive named `filename`.
///
/// A new entry called `filename` is started in `w` with default file options, and `data`
/// is then written into that entry via `numpy::write`. An error from either step is
/// returned to the caller.
///
/// Example:
/// ```ignore
/// let mut zip = ZipWriter::new(...);
/// let linear: Linear<5, 2> = Default::default();
/// npz_fwrite(&mut zip, "weight.npy".into(), linear.data());
/// ```
pub fn npz_fwrite<W: Write + Seek, T: NumpyDtype + NumpyShape + WriteNumbers>(
    w: &mut zip::ZipWriter<W>,
    filename: String,
    data: &T,
) -> ZipResult<()> {
    // Everything written to `w` after this call lands in the new entry.
    w.start_file(filename, Default::default())?;
    match numpy::write(w, data) {
        Ok(_) => Ok(()),
        Err(write_err) => Err(write_err.into()),
    }
}
/// Reads `data` from a file already in a zip archive named `filename`.
///
/// The entry called `filename` is looked up in `r` and its contents are read into `data`
/// via `numpy::read`. A lookup failure is returned as `NpzError::Zip`, and a failure from
/// `numpy::read` is returned as `NpzError::Npy`.
///
/// Example:
/// ```ignore
/// let mut zip = ZipArchive::new(...);
/// let mut linear: Linear<5, 2> = Default::default();
/// npz_fread(&mut zip, "weight.npy".into(), linear.weight.mut_data());
/// ```
pub fn npz_fread<R: Read + Seek, T: NumpyDtype + NumpyShape + ReadNumbers>(
    r: &mut zip::ZipArchive<R>,
    filename: String,
    data: &mut T,
) -> Result<(), NpzError> {
    let mut entry = match r.by_name(&filename) {
        Ok(entry) => entry,
        Err(zip_err) => return Err(NpzError::Zip(zip_err)),
    };
    match numpy::read(&mut entry, data) {
        Ok(_) => Ok(()),
        Err(npy_err) => Err(NpzError::Npy(npy_err)),
    }
}