candle_core/
npy.rs

1//! Numpy support for tensors.
2//!
3//! The spec for the npy format can be found in
4//! [npy-format](https://docs.scipy.org/doc/numpy-1.14.2/neps/npy-format.html).
5//! The functions from this module can be used to read tensors from npy/npz files
//! or write tensors to these files. A npy file contains a single tensor (unnamed)
//! whereas a npz file is a zip archive that can contain multiple named tensors, one npy
//! entry per tensor. npz files written by this module are not compressed (as with `np.savez`).
8//!
9//! These two formats are easy to use in Python using the numpy library.
10//!
11//! ```python
12//! import numpy as np
13//! x = np.arange(10)
14//!
15//! # Write a npy file.
16//! np.save("test.npy", x)
17//!
18//! # Read a value from the npy file.
19//! x = np.load("test.npy")
20//!
21//! # Write multiple values to a npz file.
22//! values = { "x": x, "x_plus_one": x + 1 }
23//! np.savez("test.npz", **values)
24//!
25//! # Load multiple values from a npz file.
//! values = np.load("test.npz")
27//! ```
use crate::{DType, Device, Error, Result, Shape, Tensor};
use byteorder::{LittleEndian, ReadBytesExt};
use half::{bf16, f16, slice::HalfFloatSliceExt};
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::Path;
35
/// Prefix that opens every npy file: the byte `0x93` followed by the ASCII text `NUMPY`.
const NPY_MAGIC_STRING: &[u8] = &[0x93, b'N', b'U', b'M', b'P', b'Y'];
/// Suffix carried by the name of each array entry inside a npz archive.
const NPY_SUFFIX: &str = ".npy";
38
39fn read_header<R: Read>(reader: &mut R) -> Result<String> {
40    let mut magic_string = vec![0u8; NPY_MAGIC_STRING.len()];
41    reader.read_exact(&mut magic_string)?;
42    if magic_string != NPY_MAGIC_STRING {
43        return Err(Error::Npy("magic string mismatch".to_string()));
44    }
45    let mut version = [0u8; 2];
46    reader.read_exact(&mut version)?;
47    let header_len_len = match version[0] {
48        1 => 2,
49        2 => 4,
50        otherwise => return Err(Error::Npy(format!("unsupported version {otherwise}"))),
51    };
52    let mut header_len = vec![0u8; header_len_len];
53    reader.read_exact(&mut header_len)?;
54    let header_len = header_len
55        .iter()
56        .rev()
57        .fold(0_usize, |acc, &v| 256 * acc + v as usize);
58    let mut header = vec![0u8; header_len];
59    reader.read_exact(&mut header)?;
60    Ok(String::from_utf8_lossy(&header).to_string())
61}
62
63#[derive(Debug, PartialEq)]
64struct Header {
65    descr: DType,
66    fortran_order: bool,
67    shape: Vec<usize>,
68}
69
70impl Header {
71    fn shape(&self) -> Shape {
72        Shape::from(self.shape.as_slice())
73    }
74
75    fn to_string(&self) -> Result<String> {
76        let fortran_order = if self.fortran_order { "True" } else { "False" };
77        let mut shape = self
78            .shape
79            .iter()
80            .map(|x| x.to_string())
81            .collect::<Vec<_>>()
82            .join(",");
83        let descr = match self.descr {
84            DType::BF16 => Err(Error::Npy("bf16 is not supported".into()))?,
85            DType::F16 => "f2",
86            DType::F32 => "f4",
87            DType::F64 => "f8",
88            DType::I64 => "i8",
89            DType::U32 => "u4",
90            DType::U8 => "u1",
91        };
92        if !shape.is_empty() {
93            shape.push(',')
94        }
95        Ok(format!(
96            "{{'descr': '<{descr}', 'fortran_order': {fortran_order}, 'shape': ({shape}), }}"
97        ))
98    }
99
100    // Hacky parser for the npy header, a typical example would be:
101    // {'descr': '<f8', 'fortran_order': False, 'shape': (128,), }
102    fn parse(header: &str) -> Result<Header> {
103        let header =
104            header.trim_matches(|c: char| c == '{' || c == '}' || c == ',' || c.is_whitespace());
105
106        let mut parts: Vec<String> = vec![];
107        let mut start_index = 0usize;
108        let mut cnt_parenthesis = 0i64;
109        for (index, c) in header.chars().enumerate() {
110            match c {
111                '(' => cnt_parenthesis += 1,
112                ')' => cnt_parenthesis -= 1,
113                ',' => {
114                    if cnt_parenthesis == 0 {
115                        parts.push(header[start_index..index].to_owned());
116                        start_index = index + 1;
117                    }
118                }
119                _ => {}
120            }
121        }
122        parts.push(header[start_index..].to_owned());
123        let mut part_map: HashMap<String, String> = HashMap::new();
124        for part in parts.iter() {
125            let part = part.trim();
126            if !part.is_empty() {
127                match part.split(':').collect::<Vec<_>>().as_slice() {
128                    [key, value] => {
129                        let key = key.trim_matches(|c: char| c == '\'' || c.is_whitespace());
130                        let value = value.trim_matches(|c: char| c == '\'' || c.is_whitespace());
131                        let _ = part_map.insert(key.to_owned(), value.to_owned());
132                    }
133                    _ => return Err(Error::Npy(format!("unable to parse header {header}"))),
134                }
135            }
136        }
137        let fortran_order = match part_map.get("fortran_order") {
138            None => false,
139            Some(fortran_order) => match fortran_order.as_ref() {
140                "False" => false,
141                "True" => true,
142                _ => return Err(Error::Npy(format!("unknown fortran_order {fortran_order}"))),
143            },
144        };
145        let descr = match part_map.get("descr") {
146            None => return Err(Error::Npy("no descr in header".to_string())),
147            Some(descr) => {
148                if descr.is_empty() {
149                    return Err(Error::Npy("empty descr".to_string()));
150                }
151                if descr.starts_with('>') {
152                    return Err(Error::Npy(format!("little-endian descr {descr}")));
153                }
154                // the only supported types in tensor are:
155                //     float64, float32, float16,
156                //     complex64, complex128,
157                //     int64, int32, int16, int8,
158                //     uint8, and bool.
159                match descr.trim_matches(|c: char| c == '=' || c == '<' || c == '|') {
160                    "e" | "f2" => DType::F16,
161                    "f" | "f4" => DType::F32,
162                    "d" | "f8" => DType::F64,
163                    // "i" | "i4" => DType::S32,
164                    "q" | "i8" => DType::I64,
165                    // "h" | "i2" => DType::S16,
166                    // "b" | "i1" => DType::S8,
167                    "B" | "u1" => DType::U8,
168                    "I" | "u4" => DType::U32,
169                    "?" | "b1" => DType::U8,
170                    // "F" | "F4" => DType::C64,
171                    // "D" | "F8" => DType::C128,
172                    descr => return Err(Error::Npy(format!("unrecognized descr {descr}"))),
173                }
174            }
175        };
176        let shape = match part_map.get("shape") {
177            None => return Err(Error::Npy("no shape in header".to_string())),
178            Some(shape) => {
179                let shape = shape.trim_matches(|c: char| c == '(' || c == ')' || c == ',');
180                if shape.is_empty() {
181                    vec![]
182                } else {
183                    shape
184                        .split(',')
185                        .map(|v| v.trim().parse::<usize>())
186                        .collect::<std::result::Result<Vec<_>, _>>()?
187                }
188            }
189        };
190        Ok(Header {
191            descr,
192            fortran_order,
193            shape,
194        })
195    }
196}
197
198impl Tensor {
199    // TODO: Add the possibility to read directly to a device?
200    pub(crate) fn from_reader<R: std::io::Read>(
201        shape: Shape,
202        dtype: DType,
203        reader: &mut R,
204    ) -> Result<Self> {
205        let elem_count = shape.elem_count();
206        match dtype {
207            DType::BF16 => {
208                let mut data_t = vec![bf16::ZERO; elem_count];
209                reader.read_u16_into::<LittleEndian>(data_t.reinterpret_cast_mut())?;
210                Tensor::from_vec(data_t, shape, &Device::Cpu)
211            }
212            DType::F16 => {
213                let mut data_t = vec![f16::ZERO; elem_count];
214                reader.read_u16_into::<LittleEndian>(data_t.reinterpret_cast_mut())?;
215                Tensor::from_vec(data_t, shape, &Device::Cpu)
216            }
217            DType::F32 => {
218                let mut data_t = vec![0f32; elem_count];
219                reader.read_f32_into::<LittleEndian>(&mut data_t)?;
220                Tensor::from_vec(data_t, shape, &Device::Cpu)
221            }
222            DType::F64 => {
223                let mut data_t = vec![0f64; elem_count];
224                reader.read_f64_into::<LittleEndian>(&mut data_t)?;
225                Tensor::from_vec(data_t, shape, &Device::Cpu)
226            }
227            DType::U8 => {
228                let mut data_t = vec![0u8; elem_count];
229                reader.read_exact(&mut data_t)?;
230                Tensor::from_vec(data_t, shape, &Device::Cpu)
231            }
232            DType::U32 => {
233                let mut data_t = vec![0u32; elem_count];
234                reader.read_u32_into::<LittleEndian>(&mut data_t)?;
235                Tensor::from_vec(data_t, shape, &Device::Cpu)
236            }
237            DType::I64 => {
238                let mut data_t = vec![0i64; elem_count];
239                reader.read_i64_into::<LittleEndian>(&mut data_t)?;
240                Tensor::from_vec(data_t, shape, &Device::Cpu)
241            }
242        }
243    }
244
245    /// Reads a npy file and return the stored multi-dimensional array as a tensor.
246    pub fn read_npy<T: AsRef<Path>>(path: T) -> Result<Self> {
247        let mut reader = File::open(path.as_ref())?;
248        let header = read_header(&mut reader)?;
249        let header = Header::parse(&header)?;
250        if header.fortran_order {
251            return Err(Error::Npy("fortran order not supported".to_string()));
252        }
253        Self::from_reader(header.shape(), header.descr, &mut reader)
254    }
255
256    /// Reads a npz file and returns the stored multi-dimensional arrays together with their names.
257    pub fn read_npz<T: AsRef<Path>>(path: T) -> Result<Vec<(String, Self)>> {
258        let zip_reader = BufReader::new(File::open(path.as_ref())?);
259        let mut zip = zip::ZipArchive::new(zip_reader)?;
260        let mut result = vec![];
261        for i in 0..zip.len() {
262            let mut reader = zip.by_index(i)?;
263            let name = {
264                let name = reader.name();
265                name.strip_suffix(NPY_SUFFIX).unwrap_or(name).to_owned()
266            };
267            let header = read_header(&mut reader)?;
268            let header = Header::parse(&header)?;
269            if header.fortran_order {
270                return Err(Error::Npy("fortran order not supported".to_string()));
271            }
272            let s = Self::from_reader(header.shape(), header.descr, &mut reader)?;
273            result.push((name, s))
274        }
275        Ok(result)
276    }
277
278    /// Reads a npz file and returns the stored multi-dimensional arrays for some specified names.
279    pub fn read_npz_by_name<T: AsRef<Path>>(path: T, names: &[&str]) -> Result<Vec<Self>> {
280        let zip_reader = BufReader::new(File::open(path.as_ref())?);
281        let mut zip = zip::ZipArchive::new(zip_reader)?;
282        let mut result = vec![];
283        for name in names.iter() {
284            let mut reader = match zip.by_name(&format!("{name}{NPY_SUFFIX}")) {
285                Ok(reader) => reader,
286                Err(_) => Err(Error::Npy(format!(
287                    "no array for {name} in {:?}",
288                    path.as_ref()
289                )))?,
290            };
291            let header = read_header(&mut reader)?;
292            let header = Header::parse(&header)?;
293            if header.fortran_order {
294                return Err(Error::Npy("fortran order not supported".to_string()));
295            }
296            let s = Self::from_reader(header.shape(), header.descr, &mut reader)?;
297            result.push(s)
298        }
299        Ok(result)
300    }
301
302    fn write<T: Write>(&self, f: &mut T) -> Result<()> {
303        f.write_all(NPY_MAGIC_STRING)?;
304        f.write_all(&[1u8, 0u8])?;
305        let header = Header {
306            descr: self.dtype(),
307            fortran_order: false,
308            shape: self.dims().to_vec(),
309        };
310        let mut header = header.to_string()?;
311        let pad = 16 - (NPY_MAGIC_STRING.len() + 5 + header.len()) % 16;
312        for _ in 0..pad % 16 {
313            header.push(' ')
314        }
315        header.push('\n');
316        f.write_all(&[(header.len() % 256) as u8, (header.len() / 256) as u8])?;
317        f.write_all(header.as_bytes())?;
318        self.write_bytes(f)
319    }
320
321    /// Writes a multi-dimensional array in the npy format.
322    pub fn write_npy<T: AsRef<Path>>(&self, path: T) -> Result<()> {
323        let mut f = File::create(path.as_ref())?;
324        self.write(&mut f)
325    }
326
327    /// Writes multiple multi-dimensional arrays using the npz format.
328    pub fn write_npz<S: AsRef<str>, T: AsRef<Tensor>, P: AsRef<Path>>(
329        ts: &[(S, T)],
330        path: P,
331    ) -> Result<()> {
332        let mut zip = zip::ZipWriter::new(File::create(path.as_ref())?);
333        let options: zip::write::FileOptions<()> =
334            zip::write::FileOptions::default().compression_method(zip::CompressionMethod::Stored);
335
336        for (name, tensor) in ts.iter() {
337            zip.start_file(format!("{}.npy", name.as_ref()), options)?;
338            tensor.as_ref().write(&mut zip)?
339        }
340        Ok(())
341    }
342}
343
/// Lazy tensor loader for npz archives.
///
/// Construction only records the archive's entry names; tensor data is decoded on demand.
pub struct NpzTensors {
    /// Location of the npz archive, re-opened on each access.
    path: std::path::PathBuf,
    /// Array name (entry name without the `.npy` suffix) to entry index within the archive.
    index_per_name: HashMap<String, usize>,
    // No zip reader is kept here: extracting data needs mutable access to it, so a fresh
    // reader is created for each tensor instead.
}
351
352impl NpzTensors {
353    pub fn new<T: AsRef<Path>>(path: T) -> Result<Self> {
354        let path = path.as_ref().to_owned();
355        let zip_reader = BufReader::new(File::open(&path)?);
356        let mut zip = zip::ZipArchive::new(zip_reader)?;
357        let mut index_per_name = HashMap::new();
358        for i in 0..zip.len() {
359            let file = zip.by_index(i)?;
360            let name = {
361                let name = file.name();
362                name.strip_suffix(NPY_SUFFIX).unwrap_or(name).to_owned()
363            };
364            index_per_name.insert(name, i);
365        }
366        Ok(Self {
367            index_per_name,
368            path,
369        })
370    }
371
372    pub fn names(&self) -> Vec<&String> {
373        self.index_per_name.keys().collect()
374    }
375
376    /// This only returns the shape and dtype for a named tensor. Compared to `get`, this avoids
377    /// reading the whole tensor data.
378    pub fn get_shape_and_dtype(&self, name: &str) -> Result<(Shape, DType)> {
379        let index = match self.index_per_name.get(name) {
380            None => crate::bail!("cannot find tensor {name}"),
381            Some(index) => *index,
382        };
383        let zip_reader = BufReader::new(File::open(&self.path)?);
384        let mut zip = zip::ZipArchive::new(zip_reader)?;
385        let mut reader = zip.by_index(index)?;
386        let header = read_header(&mut reader)?;
387        let header = Header::parse(&header)?;
388        Ok((header.shape(), header.descr))
389    }
390
391    pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
392        let index = match self.index_per_name.get(name) {
393            None => return Ok(None),
394            Some(index) => *index,
395        };
396        // We hope that the file has not changed since first reading it.
397        let zip_reader = BufReader::new(File::open(&self.path)?);
398        let mut zip = zip::ZipArchive::new(zip_reader)?;
399        let mut reader = zip.by_index(index)?;
400        let header = read_header(&mut reader)?;
401        let header = Header::parse(&header)?;
402        if header.fortran_order {
403            return Err(Error::Npy("fortran order not supported".to_string()));
404        }
405        let tensor = Tensor::from_reader(header.shape(), header.descr, &mut reader)?;
406        Ok(Some(tensor))
407    }
408}
409
#[cfg(test)]
mod tests {
    use super::Header;
    use crate::DType;

    #[test]
    fn parse_one_dimensional_header() {
        let parsed =
            Header::parse("{'descr': '<f8', 'fortran_order': False, 'shape': (128,), }").unwrap();
        let expected = Header {
            descr: DType::F64,
            fortran_order: false,
            shape: vec![128],
        };
        assert_eq!(parsed, expected);
    }

    #[test]
    fn parse_fortran_header_and_format_it_back() {
        let parsed =
            Header::parse("{'descr': '<f4', 'fortran_order': True, 'shape': (256,1,128), }")
                .unwrap();
        let expected = Header {
            descr: DType::F32,
            fortran_order: true,
            shape: vec![256, 1, 128],
        };
        assert_eq!(parsed, expected);
        // Formatting always emits a trailing comma after every dimension.
        assert_eq!(
            parsed.to_string().unwrap(),
            "{'descr': '<f4', 'fortran_order': True, 'shape': (256,1,128,), }"
        );
    }

    #[test]
    fn format_scalar_header() {
        let scalar = Header {
            descr: DType::U32,
            fortran_order: false,
            shape: vec![],
        };
        assert_eq!(
            scalar.to_string().unwrap(),
            "{'descr': '<u4', 'fortran_order': False, 'shape': (), }"
        );
    }
}