Skip to main content

hanzo_ml/
npy.rs

1//! Numpy support for tensors.
2//!
3//! The spec for the npy format can be found in
4//! [npy-format](https://docs.scipy.org/doc/numpy-1.14.2/neps/npy-format.html).
5//! The functions from this module can be used to read tensors from npy/npz files
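//! or write tensors to these files. An npy file contains a single (unnamed) tensor,
//! whereas an npz file is a zip archive that can contain multiple named tensors. The entries
//! of an npz file may be compressed (e.g. when written by `np.savez_compressed`); `write_npz`
//! stores them uncompressed.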
8//!
9//! These two formats are easy to use in Python using the numpy library.
10//!
11//! ```python
12//! import numpy as np
13//! x = np.arange(10)
14//!
15//! # Write a npy file.
16//! np.save("test.npy", x)
17//!
18//! # Read a value from the npy file.
19//! x = np.load("test.npy")
20//!
21//! # Write multiple values to a npz file.
22//! values = { "x": x, "x_plus_one": x + 1 }
23//! np.savez("test.npz", **values)
24//!
25//! # Load multiple values from a npz file.
//! values = np.load("test.npz")
27//! ```
use crate::{DType, Device, Error, Result, Shape, Tensor};
use byteorder::{LittleEndian, ReadBytesExt};
use half::{bf16, f16, slice::HalfFloatSliceExt};
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::Path;
35
/// Six-byte signature that opens every npy stream: the byte 0x93 followed by `NUMPY`.
const NPY_MAGIC_STRING: &[u8] = &[0x93, b'N', b'U', b'M', b'P', b'Y'];
/// File-name suffix of each array entry inside an npz archive.
const NPY_SUFFIX: &str = ".npy";
38
39fn read_header<R: Read>(reader: &mut R) -> Result<String> {
40    let mut magic_string = vec![0u8; NPY_MAGIC_STRING.len()];
41    reader.read_exact(&mut magic_string)?;
42    if magic_string != NPY_MAGIC_STRING {
43        return Err(Error::Npy("magic string mismatch".to_string()));
44    }
45    let mut version = [0u8; 2];
46    reader.read_exact(&mut version)?;
47    let header_len_len = match version[0] {
48        1 => 2,
49        2 => 4,
50        otherwise => return Err(Error::Npy(format!("unsupported version {otherwise}"))),
51    };
52    let mut header_len = vec![0u8; header_len_len];
53    reader.read_exact(&mut header_len)?;
54    let header_len = header_len
55        .iter()
56        .rev()
57        .fold(0_usize, |acc, &v| 256 * acc + v as usize);
58    let mut header = vec![0u8; header_len];
59    reader.read_exact(&mut header)?;
60    Ok(String::from_utf8_lossy(&header).to_string())
61}
62
63#[derive(Debug, PartialEq)]
64struct Header {
65    descr: DType,
66    fortran_order: bool,
67    shape: Vec<usize>,
68}
69
70impl Header {
71    fn shape(&self) -> Shape {
72        Shape::from(self.shape.as_slice())
73    }
74
75    fn to_string(&self) -> Result<String> {
76        let fortran_order = if self.fortran_order { "True" } else { "False" };
77        let mut shape = self
78            .shape
79            .iter()
80            .map(|x| x.to_string())
81            .collect::<Vec<_>>()
82            .join(",");
83        let descr = match self.descr {
84            DType::BF16 => Err(Error::Npy("bf16 is not supported".into()))?,
85            DType::F16 => "f2",
86            DType::F32 => "f4",
87            DType::F64 => "f8",
88            DType::I16 => "i2",
89            DType::I32 => "i4",
90            DType::I64 => "i8",
91            DType::U32 => "u4",
92            DType::U8 => "u1",
93            DType::F8E4M3 => Err(Error::Npy("f8e4m3 is not supported".into()))?,
94            DType::F6E2M3 => Err(Error::Npy("f6e2m3 is not supported".into()))?,
95            DType::F6E3M2 => Err(Error::Npy("f6e3m2 is not supported".into()))?,
96            DType::F4 => Err(Error::Npy("f4 is not supported".into()))?,
97            DType::F8E8M0 => Err(Error::Npy("f8e8m0 is not supported".into()))?,
98        };
99        if !shape.is_empty() {
100            shape.push(',')
101        }
102        Ok(format!(
103            "{{'descr': '<{descr}', 'fortran_order': {fortran_order}, 'shape': ({shape}), }}"
104        ))
105    }
106
107    // Hacky parser for the npy header, a typical example would be:
108    // {'descr': '<f8', 'fortran_order': False, 'shape': (128,), }
109    fn parse(header: &str) -> Result<Header> {
110        let header =
111            header.trim_matches(|c: char| c == '{' || c == '}' || c == ',' || c.is_whitespace());
112
113        let mut parts: Vec<String> = vec![];
114        let mut start_index = 0usize;
115        let mut cnt_parenthesis = 0i64;
116        for (index, c) in header.char_indices() {
117            match c {
118                '(' => cnt_parenthesis += 1,
119                ')' => cnt_parenthesis -= 1,
120                ',' if cnt_parenthesis == 0 => {
121                    parts.push(header[start_index..index].to_owned());
122                    start_index = index + 1;
123                }
124                _ => {}
125            }
126        }
127        parts.push(header[start_index..].to_owned());
128        let mut part_map: HashMap<String, String> = HashMap::new();
129        for part in parts.iter() {
130            let part = part.trim();
131            if !part.is_empty() {
132                match part.split(':').collect::<Vec<_>>().as_slice() {
133                    [key, value] => {
134                        let key = key.trim_matches(|c: char| c == '\'' || c.is_whitespace());
135                        let value = value.trim_matches(|c: char| c == '\'' || c.is_whitespace());
136                        let _ = part_map.insert(key.to_owned(), value.to_owned());
137                    }
138                    _ => return Err(Error::Npy(format!("unable to parse header {header}"))),
139                }
140            }
141        }
142        let fortran_order = match part_map.get("fortran_order") {
143            None => false,
144            Some(fortran_order) => match fortran_order.as_ref() {
145                "False" => false,
146                "True" => true,
147                _ => return Err(Error::Npy(format!("unknown fortran_order {fortran_order}"))),
148            },
149        };
150        let descr = match part_map.get("descr") {
151            None => return Err(Error::Npy("no descr in header".to_string())),
152            Some(descr) => {
153                if descr.is_empty() {
154                    return Err(Error::Npy("empty descr".to_string()));
155                }
156                if descr.starts_with('>') {
157                    return Err(Error::Npy(format!("little-endian descr {descr}")));
158                }
159                // the only supported types in tensor are:
160                //     float64, float32, float16,
161                //     complex64, complex128,
162                //     int64, int32, int16, int8,
163                //     uint8, and bool.
164                match descr.trim_matches(|c: char| c == '=' || c == '<' || c == '|') {
165                    "e" | "f2" => DType::F16,
166                    "f" | "f4" => DType::F32,
167                    "d" | "f8" => DType::F64,
168                    "i" | "i4" => DType::I32,
169                    "q" | "i8" => DType::I64,
170                    "h" | "i2" => DType::I16,
171                    // "b" | "i1" => DType::S8,
172                    "B" | "u1" => DType::U8,
173                    "I" | "u4" => DType::U32,
174                    "?" | "b1" => DType::U8,
175                    // "F" | "F4" => DType::C64,
176                    // "D" | "F8" => DType::C128,
177                    descr => return Err(Error::Npy(format!("unrecognized descr {descr}"))),
178                }
179            }
180        };
181        let shape = match part_map.get("shape") {
182            None => return Err(Error::Npy("no shape in header".to_string())),
183            Some(shape) => {
184                let shape = shape.trim_matches(|c: char| c == '(' || c == ')' || c == ',');
185                if shape.is_empty() {
186                    vec![]
187                } else {
188                    shape
189                        .split(',')
190                        .map(|v| v.trim().parse::<usize>())
191                        .collect::<std::result::Result<Vec<_>, _>>()?
192                }
193            }
194        };
195        Ok(Header {
196            descr,
197            fortran_order,
198            shape,
199        })
200    }
201}
202
203impl Tensor {
204    // TODO: Add the possibility to read directly to a device?
205    pub(crate) fn from_reader<R: std::io::Read>(
206        shape: Shape,
207        dtype: DType,
208        reader: &mut R,
209    ) -> Result<Self> {
210        let elem_count = shape.elem_count();
211        match dtype {
212            DType::BF16 => {
213                let mut data_t = vec![bf16::ZERO; elem_count];
214                reader.read_u16_into::<LittleEndian>(data_t.reinterpret_cast_mut())?;
215                Tensor::from_vec(data_t, shape, &Device::Cpu)
216            }
217            DType::F16 => {
218                let mut data_t = vec![f16::ZERO; elem_count];
219                reader.read_u16_into::<LittleEndian>(data_t.reinterpret_cast_mut())?;
220                Tensor::from_vec(data_t, shape, &Device::Cpu)
221            }
222            DType::F32 => {
223                let mut data_t = vec![0f32; elem_count];
224                reader.read_f32_into::<LittleEndian>(&mut data_t)?;
225                Tensor::from_vec(data_t, shape, &Device::Cpu)
226            }
227            DType::F64 => {
228                let mut data_t = vec![0f64; elem_count];
229                reader.read_f64_into::<LittleEndian>(&mut data_t)?;
230                Tensor::from_vec(data_t, shape, &Device::Cpu)
231            }
232            DType::U8 => {
233                let mut data_t = vec![0u8; elem_count];
234                reader.read_exact(&mut data_t)?;
235                Tensor::from_vec(data_t, shape, &Device::Cpu)
236            }
237            DType::U32 => {
238                let mut data_t = vec![0u32; elem_count];
239                reader.read_u32_into::<LittleEndian>(&mut data_t)?;
240                Tensor::from_vec(data_t, shape, &Device::Cpu)
241            }
242            DType::I16 => {
243                let mut data_t = vec![0i16; elem_count];
244                reader.read_i16_into::<LittleEndian>(&mut data_t)?;
245                Tensor::from_vec(data_t, shape, &Device::Cpu)
246            }
247            DType::I32 => {
248                let mut data_t = vec![0i32; elem_count];
249                reader.read_i32_into::<LittleEndian>(&mut data_t)?;
250                Tensor::from_vec(data_t, shape, &Device::Cpu)
251            }
252            DType::I64 => {
253                let mut data_t = vec![0i64; elem_count];
254                reader.read_i64_into::<LittleEndian>(&mut data_t)?;
255                Tensor::from_vec(data_t, shape, &Device::Cpu)
256            }
257            DType::F8E4M3 => {
258                let mut data_t = vec![0u8; elem_count];
259                reader.read_exact(&mut data_t)?;
260                let data_f8: Vec<float8::F8E4M3> =
261                    data_t.into_iter().map(float8::F8E4M3::from_bits).collect();
262                Tensor::from_vec(data_f8, shape, &Device::Cpu)
263            }
264            DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
265                Err(Error::UnsupportedDTypeForOp(dtype, "from_reader").bt())
266            }
267        }
268    }
269
270    /// Reads a npy file and return the stored multi-dimensional array as a tensor.
271    pub fn read_npy<T: AsRef<Path>>(path: T) -> Result<Self> {
272        let mut reader = File::open(path.as_ref())?;
273        let header = read_header(&mut reader)?;
274        let header = Header::parse(&header)?;
275        if header.fortran_order {
276            return Err(Error::Npy("fortran order not supported".to_string()));
277        }
278        Self::from_reader(header.shape(), header.descr, &mut reader)
279    }
280
281    /// Reads a npz file and returns the stored multi-dimensional arrays together with their names.
282    pub fn read_npz<T: AsRef<Path>>(path: T) -> Result<Vec<(String, Self)>> {
283        let zip_reader = BufReader::new(File::open(path.as_ref())?);
284        let mut zip = zip::ZipArchive::new(zip_reader)?;
285        let mut result = vec![];
286        for i in 0..zip.len() {
287            let mut reader = zip.by_index(i)?;
288            let name = {
289                let name = reader.name();
290                name.strip_suffix(NPY_SUFFIX).unwrap_or(name).to_owned()
291            };
292            let header = read_header(&mut reader)?;
293            let header = Header::parse(&header)?;
294            if header.fortran_order {
295                return Err(Error::Npy("fortran order not supported".to_string()));
296            }
297            let s = Self::from_reader(header.shape(), header.descr, &mut reader)?;
298            result.push((name, s))
299        }
300        Ok(result)
301    }
302
303    /// Reads a npz file and returns the stored multi-dimensional arrays for some specified names.
304    pub fn read_npz_by_name<T: AsRef<Path>>(path: T, names: &[&str]) -> Result<Vec<Self>> {
305        let zip_reader = BufReader::new(File::open(path.as_ref())?);
306        let mut zip = zip::ZipArchive::new(zip_reader)?;
307        let mut result = vec![];
308        for name in names.iter() {
309            let mut reader = match zip.by_name(&format!("{name}{NPY_SUFFIX}")) {
310                Ok(reader) => reader,
311                Err(_) => Err(Error::Npy(format!(
312                    "no array for {name} in {:?}",
313                    path.as_ref()
314                )))?,
315            };
316            let header = read_header(&mut reader)?;
317            let header = Header::parse(&header)?;
318            if header.fortran_order {
319                return Err(Error::Npy("fortran order not supported".to_string()));
320            }
321            let s = Self::from_reader(header.shape(), header.descr, &mut reader)?;
322            result.push(s)
323        }
324        Ok(result)
325    }
326
327    fn write<T: Write>(&self, f: &mut T) -> Result<()> {
328        f.write_all(NPY_MAGIC_STRING)?;
329        f.write_all(&[1u8, 0u8])?;
330        let header = Header {
331            descr: self.dtype(),
332            fortran_order: false,
333            shape: self.dims().to_vec(),
334        };
335        let mut header = header.to_string()?;
336        let pad = 16 - (NPY_MAGIC_STRING.len() + 5 + header.len()) % 16;
337        for _ in 0..pad % 16 {
338            header.push(' ')
339        }
340        header.push('\n');
341        f.write_all(&[(header.len() % 256) as u8, (header.len() / 256) as u8])?;
342        f.write_all(header.as_bytes())?;
343        self.write_bytes(f)
344    }
345
346    /// Writes a multi-dimensional array in the npy format.
347    pub fn write_npy<T: AsRef<Path>>(&self, path: T) -> Result<()> {
348        let mut f = File::create(path.as_ref())?;
349        self.write(&mut f)
350    }
351
352    /// Writes multiple multi-dimensional arrays using the npz format.
353    pub fn write_npz<S: AsRef<str>, T: AsRef<Tensor>, P: AsRef<Path>>(
354        ts: &[(S, T)],
355        path: P,
356    ) -> Result<()> {
357        let mut zip = zip::ZipWriter::new(File::create(path.as_ref())?);
358        let options: zip::write::FileOptions<()> =
359            zip::write::FileOptions::default().compression_method(zip::CompressionMethod::Stored);
360
361        for (name, tensor) in ts.iter() {
362            zip.start_file(format!("{}.npy", name.as_ref()), options)?;
363            tensor.as_ref().write(&mut zip)?
364        }
365        Ok(())
366    }
367}
368
/// Lazy tensor loader for npz archives.
///
/// Only the entry names are read when the loader is created. Tensor data is read on
/// demand.
#[derive(Debug)]
pub struct NpzTensors {
    index_per_name: HashMap<String, usize>,
    path: std::path::PathBuf,
    // We do not store a zip reader as it needs mutable access to extract data. Instead we
    // re-create a zip reader for each tensor.
}
376
377impl NpzTensors {
378    pub fn new<T: AsRef<Path>>(path: T) -> Result<Self> {
379        let path = path.as_ref().to_owned();
380        let zip_reader = BufReader::new(File::open(&path)?);
381        let mut zip = zip::ZipArchive::new(zip_reader)?;
382        let mut index_per_name = HashMap::new();
383        for i in 0..zip.len() {
384            let file = zip.by_index(i)?;
385            let name = {
386                let name = file.name();
387                name.strip_suffix(NPY_SUFFIX).unwrap_or(name).to_owned()
388            };
389            index_per_name.insert(name, i);
390        }
391        Ok(Self {
392            index_per_name,
393            path,
394        })
395    }
396
397    pub fn names(&self) -> Vec<&String> {
398        self.index_per_name.keys().collect()
399    }
400
401    /// This only returns the shape and dtype for a named tensor. Compared to `get`, this avoids
402    /// reading the whole tensor data.
403    pub fn get_shape_and_dtype(&self, name: &str) -> Result<(Shape, DType)> {
404        let index = match self.index_per_name.get(name) {
405            None => crate::bail!("cannot find tensor {name}"),
406            Some(index) => *index,
407        };
408        let zip_reader = BufReader::new(File::open(&self.path)?);
409        let mut zip = zip::ZipArchive::new(zip_reader)?;
410        let mut reader = zip.by_index(index)?;
411        let header = read_header(&mut reader)?;
412        let header = Header::parse(&header)?;
413        Ok((header.shape(), header.descr))
414    }
415
416    pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
417        let index = match self.index_per_name.get(name) {
418            None => return Ok(None),
419            Some(index) => *index,
420        };
421        // We hope that the file has not changed since first reading it.
422        let zip_reader = BufReader::new(File::open(&self.path)?);
423        let mut zip = zip::ZipArchive::new(zip_reader)?;
424        let mut reader = zip.by_index(index)?;
425        let header = read_header(&mut reader)?;
426        let header = Header::parse(&header)?;
427        if header.fortran_order {
428            return Err(Error::Npy("fortran order not supported".to_string()));
429        }
430        let tensor = Tensor::from_reader(header.shape(), header.descr, &mut reader)?;
431        Ok(Some(tensor))
432    }
433}
434
#[cfg(test)]
mod tests {
    use super::Header;

    #[test]
    fn parse() {
        let h = "{'descr': '<f8', 'fortran_order': False, 'shape': (128,), }";
        assert_eq!(
            Header::parse(h).unwrap(),
            Header {
                descr: crate::DType::F64,
                fortran_order: false,
                shape: vec![128]
            }
        );
        let h = "{'descr': '<f4', 'fortran_order': True, 'shape': (256,1,128), }";
        let h = Header::parse(h).unwrap();
        assert_eq!(
            h,
            Header {
                descr: crate::DType::F32,
                fortran_order: true,
                shape: vec![256, 1, 128]
            }
        );
        assert_eq!(
            h.to_string().unwrap(),
            "{'descr': '<f4', 'fortran_order': True, 'shape': (256,1,128,), }"
        );

        let h = Header {
            descr: crate::DType::U32,
            fortran_order: false,
            shape: vec![],
        };
        assert_eq!(
            h.to_string().unwrap(),
            "{'descr': '<u4', 'fortran_order': False, 'shape': (), }"
        );
    }

    /// Headers written by numpy carry alignment padding and a trailing newline, and
    /// use spaces between dimensions.
    #[test]
    fn parse_numpy_padding() {
        let h = "{'descr': '<i8', 'fortran_order': False, 'shape': (2, 3), }          \n";
        assert_eq!(
            Header::parse(h).unwrap(),
            Header {
                descr: crate::DType::I64,
                fortran_order: false,
                shape: vec![2, 3]
            }
        );
    }

    /// Scalar shapes parse to an empty vector; numpy bools are read as u8.
    #[test]
    fn parse_scalar_bool() {
        let h = "{'descr': '|b1', 'fortran_order': False, 'shape': (), }";
        assert_eq!(
            Header::parse(h).unwrap(),
            Header {
                descr: crate::DType::U8,
                fortran_order: false,
                shape: vec![]
            }
        );
    }

    #[test]
    fn parse_rejects_big_endian_and_unknown() {
        let big_endian = "{'descr': '>f4', 'fortran_order': False, 'shape': (3,), }";
        assert!(Header::parse(big_endian).is_err());
        let unknown = "{'descr': '<c8', 'fortran_order': False, 'shape': (3,), }";
        assert!(Header::parse(unknown).is_err());
        let no_shape = "{'descr': '<f4', 'fortran_order': False, }";
        assert!(Header::parse(no_shape).is_err());
    }

    #[test]
    fn to_string_round_trips() {
        let h = Header {
            descr: crate::DType::I64,
            fortran_order: false,
            shape: vec![2, 3],
        };
        let s = h.to_string().unwrap();
        assert_eq!(Header::parse(&s).unwrap(), h);
    }
}