candle_core/
npy.rs

1//! Numpy support for tensors.
2//!
3//! The spec for the npy format can be found in
4//! [npy-format](https://docs.scipy.org/doc/numpy-1.14.2/neps/npy-format.html).
5//! The functions from this module can be used to read tensors from npy/npz files
6//! or write tensors to these files. A npy file contains a single tensor (unnamed)
//! whereas a npz file is a zip archive that can contain multiple named tensors. The npz
//! files written by this module store their entries uncompressed.
8//!
9//! These two formats are easy to use in Python using the numpy library.
10//!
11//! ```python
12//! import numpy as np
13//! x = np.arange(10)
14//!
15//! # Write a npy file.
16//! np.save("test.npy", x)
17//!
18//! # Read a value from the npy file.
19//! x = np.load("test.npy")
20//!
21//! # Write multiple values to a npz file.
22//! values = { "x": x, "x_plus_one": x + 1 }
23//! np.savez("test.npz", **values)
24//!
25//! # Load multiple values from a npz file.
//! values = np.load("test.npz")
27//! ```
28use crate::{DType, Device, Error, Result, Shape, Tensor};
29use byteorder::{LittleEndian, ReadBytesExt};
30use half::{bf16, f16, slice::HalfFloatSliceExt};
31use std::collections::HashMap;
32use std::fs::File;
33use std::io::{BufReader, Read, Write};
34use std::path::Path;
35
36const NPY_MAGIC_STRING: &[u8] = b"\x93NUMPY";
37const NPY_SUFFIX: &str = ".npy";
38
39fn read_header<R: Read>(reader: &mut R) -> Result<String> {
40    let mut magic_string = vec![0u8; NPY_MAGIC_STRING.len()];
41    reader.read_exact(&mut magic_string)?;
42    if magic_string != NPY_MAGIC_STRING {
43        return Err(Error::Npy("magic string mismatch".to_string()));
44    }
45    let mut version = [0u8; 2];
46    reader.read_exact(&mut version)?;
47    let header_len_len = match version[0] {
48        1 => 2,
49        2 => 4,
50        otherwise => return Err(Error::Npy(format!("unsupported version {otherwise}"))),
51    };
52    let mut header_len = vec![0u8; header_len_len];
53    reader.read_exact(&mut header_len)?;
54    let header_len = header_len
55        .iter()
56        .rev()
57        .fold(0_usize, |acc, &v| 256 * acc + v as usize);
58    let mut header = vec![0u8; header_len];
59    reader.read_exact(&mut header)?;
60    Ok(String::from_utf8_lossy(&header).to_string())
61}
62
/// Parsed npy header, i.e. the `{'descr': ..., 'fortran_order': ..., 'shape': ...}`
/// dictionary that follows the npy preamble.
#[derive(Debug, PartialEq)]
struct Header {
    /// Element type, mapped from the numpy `descr` type string.
    descr: DType,
    /// Whether the array data is stored in Fortran order; the readers in this module
    /// reject such arrays.
    fortran_order: bool,
    /// Array dimensions; empty for a 0-dimensional array (`'shape': ()`).
    shape: Vec<usize>,
}
69
70impl Header {
71    fn shape(&self) -> Shape {
72        Shape::from(self.shape.as_slice())
73    }
74
75    fn to_string(&self) -> Result<String> {
76        let fortran_order = if self.fortran_order { "True" } else { "False" };
77        let mut shape = self
78            .shape
79            .iter()
80            .map(|x| x.to_string())
81            .collect::<Vec<_>>()
82            .join(",");
83        let descr = match self.descr {
84            DType::BF16 => Err(Error::Npy("bf16 is not supported".into()))?,
85            DType::F16 => "f2",
86            DType::F32 => "f4",
87            DType::F64 => "f8",
88            DType::I16 => "i2",
89            DType::I32 => "i4",
90            DType::I64 => "i8",
91            DType::U32 => "u4",
92            DType::U8 => "u1",
93            DType::F8E4M3 => Err(Error::Npy("f8e4m3 is not supported".into()))?,
94            DType::F6E2M3 => Err(Error::Npy("f6e2m3 is not supported".into()))?,
95            DType::F6E3M2 => Err(Error::Npy("f6e3m2 is not supported".into()))?,
96            DType::F4 => Err(Error::Npy("f4 is not supported".into()))?,
97            DType::F8E8M0 => Err(Error::Npy("f8e8m0 is not supported".into()))?,
98        };
99        if !shape.is_empty() {
100            shape.push(',')
101        }
102        Ok(format!(
103            "{{'descr': '<{descr}', 'fortran_order': {fortran_order}, 'shape': ({shape}), }}"
104        ))
105    }
106
107    // Hacky parser for the npy header, a typical example would be:
108    // {'descr': '<f8', 'fortran_order': False, 'shape': (128,), }
109    fn parse(header: &str) -> Result<Header> {
110        let header =
111            header.trim_matches(|c: char| c == '{' || c == '}' || c == ',' || c.is_whitespace());
112
113        let mut parts: Vec<String> = vec![];
114        let mut start_index = 0usize;
115        let mut cnt_parenthesis = 0i64;
116        for (index, c) in header.char_indices() {
117            match c {
118                '(' => cnt_parenthesis += 1,
119                ')' => cnt_parenthesis -= 1,
120                ',' => {
121                    if cnt_parenthesis == 0 {
122                        parts.push(header[start_index..index].to_owned());
123                        start_index = index + 1;
124                    }
125                }
126                _ => {}
127            }
128        }
129        parts.push(header[start_index..].to_owned());
130        let mut part_map: HashMap<String, String> = HashMap::new();
131        for part in parts.iter() {
132            let part = part.trim();
133            if !part.is_empty() {
134                match part.split(':').collect::<Vec<_>>().as_slice() {
135                    [key, value] => {
136                        let key = key.trim_matches(|c: char| c == '\'' || c.is_whitespace());
137                        let value = value.trim_matches(|c: char| c == '\'' || c.is_whitespace());
138                        let _ = part_map.insert(key.to_owned(), value.to_owned());
139                    }
140                    _ => return Err(Error::Npy(format!("unable to parse header {header}"))),
141                }
142            }
143        }
144        let fortran_order = match part_map.get("fortran_order") {
145            None => false,
146            Some(fortran_order) => match fortran_order.as_ref() {
147                "False" => false,
148                "True" => true,
149                _ => return Err(Error::Npy(format!("unknown fortran_order {fortran_order}"))),
150            },
151        };
152        let descr = match part_map.get("descr") {
153            None => return Err(Error::Npy("no descr in header".to_string())),
154            Some(descr) => {
155                if descr.is_empty() {
156                    return Err(Error::Npy("empty descr".to_string()));
157                }
158                if descr.starts_with('>') {
159                    return Err(Error::Npy(format!("little-endian descr {descr}")));
160                }
161                // the only supported types in tensor are:
162                //     float64, float32, float16,
163                //     complex64, complex128,
164                //     int64, int32, int16, int8,
165                //     uint8, and bool.
166                match descr.trim_matches(|c: char| c == '=' || c == '<' || c == '|') {
167                    "e" | "f2" => DType::F16,
168                    "f" | "f4" => DType::F32,
169                    "d" | "f8" => DType::F64,
170                    "i" | "i4" => DType::I32,
171                    "q" | "i8" => DType::I64,
172                    "h" | "i2" => DType::I16,
173                    // "b" | "i1" => DType::S8,
174                    "B" | "u1" => DType::U8,
175                    "I" | "u4" => DType::U32,
176                    "?" | "b1" => DType::U8,
177                    // "F" | "F4" => DType::C64,
178                    // "D" | "F8" => DType::C128,
179                    descr => return Err(Error::Npy(format!("unrecognized descr {descr}"))),
180                }
181            }
182        };
183        let shape = match part_map.get("shape") {
184            None => return Err(Error::Npy("no shape in header".to_string())),
185            Some(shape) => {
186                let shape = shape.trim_matches(|c: char| c == '(' || c == ')' || c == ',');
187                if shape.is_empty() {
188                    vec![]
189                } else {
190                    shape
191                        .split(',')
192                        .map(|v| v.trim().parse::<usize>())
193                        .collect::<std::result::Result<Vec<_>, _>>()?
194                }
195            }
196        };
197        Ok(Header {
198            descr,
199            fortran_order,
200            shape,
201        })
202    }
203}
204
205impl Tensor {
206    // TODO: Add the possibility to read directly to a device?
207    pub(crate) fn from_reader<R: std::io::Read>(
208        shape: Shape,
209        dtype: DType,
210        reader: &mut R,
211    ) -> Result<Self> {
212        let elem_count = shape.elem_count();
213        match dtype {
214            DType::BF16 => {
215                let mut data_t = vec![bf16::ZERO; elem_count];
216                reader.read_u16_into::<LittleEndian>(data_t.reinterpret_cast_mut())?;
217                Tensor::from_vec(data_t, shape, &Device::Cpu)
218            }
219            DType::F16 => {
220                let mut data_t = vec![f16::ZERO; elem_count];
221                reader.read_u16_into::<LittleEndian>(data_t.reinterpret_cast_mut())?;
222                Tensor::from_vec(data_t, shape, &Device::Cpu)
223            }
224            DType::F32 => {
225                let mut data_t = vec![0f32; elem_count];
226                reader.read_f32_into::<LittleEndian>(&mut data_t)?;
227                Tensor::from_vec(data_t, shape, &Device::Cpu)
228            }
229            DType::F64 => {
230                let mut data_t = vec![0f64; elem_count];
231                reader.read_f64_into::<LittleEndian>(&mut data_t)?;
232                Tensor::from_vec(data_t, shape, &Device::Cpu)
233            }
234            DType::U8 => {
235                let mut data_t = vec![0u8; elem_count];
236                reader.read_exact(&mut data_t)?;
237                Tensor::from_vec(data_t, shape, &Device::Cpu)
238            }
239            DType::U32 => {
240                let mut data_t = vec![0u32; elem_count];
241                reader.read_u32_into::<LittleEndian>(&mut data_t)?;
242                Tensor::from_vec(data_t, shape, &Device::Cpu)
243            }
244            DType::I16 => {
245                let mut data_t = vec![0i16; elem_count];
246                reader.read_i16_into::<LittleEndian>(&mut data_t)?;
247                Tensor::from_vec(data_t, shape, &Device::Cpu)
248            }
249            DType::I32 => {
250                let mut data_t = vec![0i32; elem_count];
251                reader.read_i32_into::<LittleEndian>(&mut data_t)?;
252                Tensor::from_vec(data_t, shape, &Device::Cpu)
253            }
254            DType::I64 => {
255                let mut data_t = vec![0i64; elem_count];
256                reader.read_i64_into::<LittleEndian>(&mut data_t)?;
257                Tensor::from_vec(data_t, shape, &Device::Cpu)
258            }
259            DType::F8E4M3 => {
260                let mut data_t = vec![0u8; elem_count];
261                reader.read_exact(&mut data_t)?;
262                let data_f8: Vec<float8::F8E4M3> =
263                    data_t.into_iter().map(float8::F8E4M3::from_bits).collect();
264                Tensor::from_vec(data_f8, shape, &Device::Cpu)
265            }
266            DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
267                Err(Error::UnsupportedDTypeForOp(dtype, "from_reader").bt())
268            }
269        }
270    }
271
272    /// Reads a npy file and return the stored multi-dimensional array as a tensor.
273    pub fn read_npy<T: AsRef<Path>>(path: T) -> Result<Self> {
274        let mut reader = File::open(path.as_ref())?;
275        let header = read_header(&mut reader)?;
276        let header = Header::parse(&header)?;
277        if header.fortran_order {
278            return Err(Error::Npy("fortran order not supported".to_string()));
279        }
280        Self::from_reader(header.shape(), header.descr, &mut reader)
281    }
282
283    /// Reads a npz file and returns the stored multi-dimensional arrays together with their names.
284    pub fn read_npz<T: AsRef<Path>>(path: T) -> Result<Vec<(String, Self)>> {
285        let zip_reader = BufReader::new(File::open(path.as_ref())?);
286        let mut zip = zip::ZipArchive::new(zip_reader)?;
287        let mut result = vec![];
288        for i in 0..zip.len() {
289            let mut reader = zip.by_index(i)?;
290            let name = {
291                let name = reader.name();
292                name.strip_suffix(NPY_SUFFIX).unwrap_or(name).to_owned()
293            };
294            let header = read_header(&mut reader)?;
295            let header = Header::parse(&header)?;
296            if header.fortran_order {
297                return Err(Error::Npy("fortran order not supported".to_string()));
298            }
299            let s = Self::from_reader(header.shape(), header.descr, &mut reader)?;
300            result.push((name, s))
301        }
302        Ok(result)
303    }
304
305    /// Reads a npz file and returns the stored multi-dimensional arrays for some specified names.
306    pub fn read_npz_by_name<T: AsRef<Path>>(path: T, names: &[&str]) -> Result<Vec<Self>> {
307        let zip_reader = BufReader::new(File::open(path.as_ref())?);
308        let mut zip = zip::ZipArchive::new(zip_reader)?;
309        let mut result = vec![];
310        for name in names.iter() {
311            let mut reader = match zip.by_name(&format!("{name}{NPY_SUFFIX}")) {
312                Ok(reader) => reader,
313                Err(_) => Err(Error::Npy(format!(
314                    "no array for {name} in {:?}",
315                    path.as_ref()
316                )))?,
317            };
318            let header = read_header(&mut reader)?;
319            let header = Header::parse(&header)?;
320            if header.fortran_order {
321                return Err(Error::Npy("fortran order not supported".to_string()));
322            }
323            let s = Self::from_reader(header.shape(), header.descr, &mut reader)?;
324            result.push(s)
325        }
326        Ok(result)
327    }
328
329    fn write<T: Write>(&self, f: &mut T) -> Result<()> {
330        f.write_all(NPY_MAGIC_STRING)?;
331        f.write_all(&[1u8, 0u8])?;
332        let header = Header {
333            descr: self.dtype(),
334            fortran_order: false,
335            shape: self.dims().to_vec(),
336        };
337        let mut header = header.to_string()?;
338        let pad = 16 - (NPY_MAGIC_STRING.len() + 5 + header.len()) % 16;
339        for _ in 0..pad % 16 {
340            header.push(' ')
341        }
342        header.push('\n');
343        f.write_all(&[(header.len() % 256) as u8, (header.len() / 256) as u8])?;
344        f.write_all(header.as_bytes())?;
345        self.write_bytes(f)
346    }
347
348    /// Writes a multi-dimensional array in the npy format.
349    pub fn write_npy<T: AsRef<Path>>(&self, path: T) -> Result<()> {
350        let mut f = File::create(path.as_ref())?;
351        self.write(&mut f)
352    }
353
354    /// Writes multiple multi-dimensional arrays using the npz format.
355    pub fn write_npz<S: AsRef<str>, T: AsRef<Tensor>, P: AsRef<Path>>(
356        ts: &[(S, T)],
357        path: P,
358    ) -> Result<()> {
359        let mut zip = zip::ZipWriter::new(File::create(path.as_ref())?);
360        let options: zip::write::FileOptions<()> =
361            zip::write::FileOptions::default().compression_method(zip::CompressionMethod::Stored);
362
363        for (name, tensor) in ts.iter() {
364            zip.start_file(format!("{}.npy", name.as_ref()), options)?;
365            tensor.as_ref().write(&mut zip)?
366        }
367        Ok(())
368    }
369}
370
/// Lazy tensor loader for npz archives.
///
/// Construction only records the names of the archive entries. Tensor data is read on
/// demand by `get`, which re-opens the archive at `path` on every call.
pub struct NpzTensors {
    /// Maps each tensor name (the entry name with any `.npy` suffix removed) to the index of
    /// its entry in the zip archive.
    index_per_name: HashMap<String, usize>,
    /// Location of the npz archive on disk.
    path: std::path::PathBuf,
    // We do not store a zip reader as it needs mutable access to extract data. Instead we
    // re-create a zip reader for each tensor.
}
378
379impl NpzTensors {
380    pub fn new<T: AsRef<Path>>(path: T) -> Result<Self> {
381        let path = path.as_ref().to_owned();
382        let zip_reader = BufReader::new(File::open(&path)?);
383        let mut zip = zip::ZipArchive::new(zip_reader)?;
384        let mut index_per_name = HashMap::new();
385        for i in 0..zip.len() {
386            let file = zip.by_index(i)?;
387            let name = {
388                let name = file.name();
389                name.strip_suffix(NPY_SUFFIX).unwrap_or(name).to_owned()
390            };
391            index_per_name.insert(name, i);
392        }
393        Ok(Self {
394            index_per_name,
395            path,
396        })
397    }
398
399    pub fn names(&self) -> Vec<&String> {
400        self.index_per_name.keys().collect()
401    }
402
403    /// This only returns the shape and dtype for a named tensor. Compared to `get`, this avoids
404    /// reading the whole tensor data.
405    pub fn get_shape_and_dtype(&self, name: &str) -> Result<(Shape, DType)> {
406        let index = match self.index_per_name.get(name) {
407            None => crate::bail!("cannot find tensor {name}"),
408            Some(index) => *index,
409        };
410        let zip_reader = BufReader::new(File::open(&self.path)?);
411        let mut zip = zip::ZipArchive::new(zip_reader)?;
412        let mut reader = zip.by_index(index)?;
413        let header = read_header(&mut reader)?;
414        let header = Header::parse(&header)?;
415        Ok((header.shape(), header.descr))
416    }
417
418    pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
419        let index = match self.index_per_name.get(name) {
420            None => return Ok(None),
421            Some(index) => *index,
422        };
423        // We hope that the file has not changed since first reading it.
424        let zip_reader = BufReader::new(File::open(&self.path)?);
425        let mut zip = zip::ZipArchive::new(zip_reader)?;
426        let mut reader = zip.by_index(index)?;
427        let header = read_header(&mut reader)?;
428        let header = Header::parse(&header)?;
429        if header.fortran_order {
430            return Err(Error::Npy("fortran order not supported".to_string()));
431        }
432        let tensor = Tensor::from_reader(header.shape(), header.descr, &mut reader)?;
433        Ok(Some(tensor))
434    }
435}
436
#[cfg(test)]
mod tests {
    use super::Header;
    use crate::DType;

    #[test]
    fn parse() {
        let h = "{'descr': '<f8', 'fortran_order': False, 'shape': (128,), }";
        assert_eq!(
            Header::parse(h).unwrap(),
            Header {
                descr: DType::F64,
                fortran_order: false,
                shape: vec![128]
            }
        );
        let h = "{'descr': '<f4', 'fortran_order': True, 'shape': (256,1,128), }";
        let h = Header::parse(h).unwrap();
        assert_eq!(
            h,
            Header {
                descr: DType::F32,
                fortran_order: true,
                shape: vec![256, 1, 128]
            }
        );
        assert_eq!(
            h.to_string().unwrap(),
            "{'descr': '<f4', 'fortran_order': True, 'shape': (256,1,128,), }"
        );

        let h = Header {
            descr: DType::U32,
            fortran_order: false,
            shape: vec![],
        };
        assert_eq!(
            h.to_string().unwrap(),
            "{'descr': '<u4', 'fortran_order': False, 'shape': (), }"
        );
    }

    /// A 0-dimensional boolean array: booleans load as u8 and `()` yields an empty shape.
    #[test]
    fn parse_scalar_bool() {
        let h = "{'descr': '|b1', 'fortran_order': False, 'shape': (), }";
        assert_eq!(
            Header::parse(h).unwrap(),
            Header {
                descr: DType::U8,
                fortran_order: false,
                shape: vec![]
            }
        );
    }

    /// Every rejection branch of the parser reports an error rather than a bogus header.
    #[test]
    fn parse_errors() {
        // Big-endian data.
        assert!(Header::parse("{'descr': '>f4', 'fortran_order': False, 'shape': (2,), }").is_err());
        // Missing descr.
        assert!(Header::parse("{'fortran_order': False, 'shape': (2,), }").is_err());
        // Missing shape.
        assert!(Header::parse("{'descr': '<f4', 'fortran_order': False, }").is_err());
        // Invalid fortran_order value.
        assert!(Header::parse("{'descr': '<f4', 'fortran_order': Maybe, 'shape': (2,), }").is_err());
        // Complex types have no candle equivalent.
        assert!(Header::parse("{'descr': '<c8', 'fortran_order': False, 'shape': (2,), }").is_err());
        // Non-integer dimension.
        assert!(Header::parse("{'descr': '<f4', 'fortran_order': False, 'shape': (x,), }").is_err());
    }

    /// Headers produced by `to_string` parse back to the same header.
    #[test]
    fn to_string_round_trip() {
        let cases = [
            (DType::F16, false, vec![]),
            (DType::I64, false, vec![7]),
            (DType::U8, true, vec![2, 3, 4]),
        ];
        for (descr, fortran_order, shape) in cases {
            let h = Header {
                descr,
                fortran_order,
                shape,
            };
            let serialized = h.to_string().unwrap();
            assert_eq!(Header::parse(&serialized).unwrap(), h);
        }
    }
}