1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
use crate::numpy::{self, NpyError, NumpyDtype, NumpyShape, ReadNumbers, WriteNumbers};
use std::{
    fs::File,
    io::{BufReader, BufWriter, Read, Seek, Write},
    path::Path,
};
use zip::{
    result::{ZipError, ZipResult},
    ZipArchive, ZipWriter,
};

/// Something that can be saved to a `.npz` (which is a `.zip`).
///
/// All [Module]s in nn implement SaveToNpz, and the zips are formatted in a `.npz` fashion.
pub trait SaveToNpz {
    /// Save this object into the `.npz` file located at `path`.
    ///
    /// The file is created through [File::create], the object is written with
    /// [SaveToNpz::write] using an empty filename prefix, the archive is finalized,
    /// and the buffered output is flushed before returning.
    ///
    /// # Errors
    /// Returns an error if the file cannot be created, if [SaveToNpz::write] fails,
    /// if the archive cannot be finalized, or if the final flush fails.
    ///
    /// Example:
    /// ```ignore
    /// # use dfdx::prelude::*;
    /// let model: (Linear<5, 10>, Linear<10, 5>) = Default::default();
    /// model.save("tst.npz")?;
    /// ```
    fn save<P: AsRef<Path>>(&self, path: P) -> ZipResult<()> {
        let file = File::create(path)?;
        let mut zip = ZipWriter::new(BufWriter::new(file));
        self.write(&String::new(), &mut zip)?;
        // `finish` hands the `BufWriter` back. Flush it explicitly: errors raised while a
        // `BufWriter` flushes on drop are discarded, which would hide a failed final write.
        let mut buffered = zip.finish()?;
        buffered.flush()?;
        Ok(())
    }

    /// Write this object into [ZipWriter] `w` with a base filename of `filename_prefix`.
    ///
    /// The default implementation writes nothing and returns `Ok(())`.
    ///
    /// Example:
    /// ```ignore
    /// # use dfdx::prelude::*;
    /// let model: Linear<5, 10> = Default::default();
    /// let mut zip = ZipWriter::new(...);
    /// model.write(&"0.".into(), &mut zip)?;
    /// model.write(&"1.".into(), &mut zip)?;
    /// ```
    /// Will save a zip file with the following files in it:
    /// - `0.weight.npy`
    /// - `0.bias.npy`
    /// - `1.weight.npy`
    /// - `1.bias.npy`
    fn write<W>(&self, _filename_prefix: &String, _w: &mut ZipWriter<W>) -> ZipResult<()>
    where
        W: Write + Seek,
    {
        Ok(())
    }
}

/// Something that can be loaded from a `.npz` file (which is a `zip` file).
///
/// All [Module]s in nn implement LoadFromNpz, and the zips are formatted in a `.npz` fashion.
pub trait LoadFromNpz {
    /// Loads data from a `.npz` zip archive at the specified `path`.
    ///
    /// The file is opened, wrapped in a [BufReader] and parsed as a [ZipArchive]; the
    /// object is then filled in through [LoadFromNpz::read] with an empty filename prefix.
    ///
    /// A failure to open the file is reported as [NpzError::Npy] wrapping
    /// [NpyError::IoError]; a failure to parse the archive is reported as [NpzError::Zip].
    ///
    /// Example:
    /// ```ignore
    /// # use dfdx::prelude::*;
    /// let mut model: (Linear<5, 10>, Linear<10, 5>) = Default::default();
    /// model.load("tst.npz")?;
    /// ```
    fn load<P: AsRef<Path>>(&mut self, path: P) -> Result<(), NpzError> {
        let file = match File::open(path) {
            Ok(file) => file,
            Err(io_err) => return Err(NpzError::Npy(NpyError::IoError(io_err))),
        };
        let mut archive = match ZipArchive::new(BufReader::new(file)) {
            Ok(archive) => archive,
            Err(zip_err) => return Err(NpzError::Zip(zip_err)),
        };
        self.read(&String::new(), &mut archive)
    }

    /// Reads this object from the [ZipArchive] `r`, using `filename_prefix` as the base
    /// filename.
    ///
    /// The default implementation reads nothing and returns `Ok(())`.
    ///
    /// Example:
    /// ```ignore
    /// # use dfdx::prelude::*;
    /// let mut model: Linear<5, 10> = Default::default();
    /// let mut zip = ZipArchive::new(...);
    /// model.read(&"0.".into(), &mut zip)?;
    /// ```
    /// Will try to read data from the following files:
    /// - `0.weight.npy`
    /// - `0.bias.npy`
    fn read<R>(&mut self, _filename_prefix: &String, _r: &mut ZipArchive<R>) -> Result<(), NpzError>
    where
        R: Read + Seek,
    {
        Ok(())
    }
}

/// Error that can happen while loading data from a `.npz` zip archive.
///
/// Returned by [LoadFromNpz::load], [LoadFromNpz::read] and [npz_fread].
// NOTE(review): no `Debug`/`Display`/`std::error::Error` impl is visible in this part of the
// file, so callers may be unable to `unwrap()` or box this error. Consider deriving `Debug`
// if `NpyError` supports it (to be confirmed).
pub enum NpzError {
    /// Something went wrong with reading from the `.zip` archive, e.g. the archive could not
    /// be parsed or an entry could not be looked up by name.
    Zip(ZipError),

    /// Something went wrong with loading data from a `.npy` file. A failure to open the
    /// `.npz` file itself in [LoadFromNpz::load] is also reported here, wrapped in
    /// [NpyError::IoError].
    Npy(NpyError),
}

/// Starts a new file named `filename` inside the zip archive `w` and writes `data` into it
/// through `numpy::write`.
///
/// The new entry is created with default file options. Any error from starting the entry
/// or from writing the data is returned to the caller.
///
/// Example:
/// ```ignore
/// let mut zip = ZipWriter::new(...);
/// let linear: Linear<5, 2> = Default::default();
/// npz_fwrite(&mut zip, "weight.npy".into(), linear.data());
/// ```
pub fn npz_fwrite<W: Write + Seek, T: NumpyDtype + NumpyShape + WriteNumbers>(
    w: &mut zip::ZipWriter<W>,
    filename: String,
    data: &T,
) -> ZipResult<()> {
    let options = Default::default();
    w.start_file(filename, options)?;
    // Once the entry is started, the archive writer itself is the sink for the array data.
    if let Err(write_err) = numpy::write(w, data) {
        return Err(write_err.into());
    }
    Ok(())
}

/// Looks up the file named `filename` in the zip archive `r` and reads its contents into
/// `data` through `numpy::read`.
///
/// A failed lookup is reported as [NpzError::Zip]; a failure while reading the data is
/// reported as [NpzError::Npy].
///
/// Example:
/// ```ignore
/// let mut zip = ZipArchive::new(...);
/// let mut linear: Linear<5, 2> = Default::default();
/// npz_fread(&mut zip, "weight.npy".into(), linear.weight.mut_data());
/// ```
pub fn npz_fread<R: Read + Seek, T: NumpyDtype + NumpyShape + ReadNumbers>(
    r: &mut zip::ZipArchive<R>,
    filename: String,
    data: &mut T,
) -> Result<(), NpzError> {
    let mut entry = match r.by_name(&filename) {
        Ok(entry) => entry,
        Err(zip_err) => return Err(NpzError::Zip(zip_err)),
    };
    if let Err(npy_err) = numpy::read(&mut entry, data) {
        return Err(NpzError::Npy(npy_err));
    }
    Ok(())
}