1use crate::{DType, Device, Error, Result, Shape, Tensor};
29use byteorder::{LittleEndian, ReadBytesExt};
30use half::{bf16, f16, slice::HalfFloatSliceExt};
31use std::collections::HashMap;
32use std::fs::File;
33use std::io::{BufReader, Read, Write};
34use std::path::Path;
35
36const NPY_MAGIC_STRING: &[u8] = b"\x93NUMPY";
37const NPY_SUFFIX: &str = ".npy";
38
39fn read_header<R: Read>(reader: &mut R) -> Result<String> {
40 let mut magic_string = vec![0u8; NPY_MAGIC_STRING.len()];
41 reader.read_exact(&mut magic_string)?;
42 if magic_string != NPY_MAGIC_STRING {
43 return Err(Error::Npy("magic string mismatch".to_string()));
44 }
45 let mut version = [0u8; 2];
46 reader.read_exact(&mut version)?;
47 let header_len_len = match version[0] {
48 1 => 2,
49 2 => 4,
50 otherwise => return Err(Error::Npy(format!("unsupported version {otherwise}"))),
51 };
52 let mut header_len = vec![0u8; header_len_len];
53 reader.read_exact(&mut header_len)?;
54 let header_len = header_len
55 .iter()
56 .rev()
57 .fold(0_usize, |acc, &v| 256 * acc + v as usize);
58 let mut header = vec![0u8; header_len];
59 reader.read_exact(&mut header)?;
60 Ok(String::from_utf8_lossy(&header).to_string())
61}
62
63#[derive(Debug, PartialEq)]
64struct Header {
65 descr: DType,
66 fortran_order: bool,
67 shape: Vec<usize>,
68}
69
70impl Header {
71 fn shape(&self) -> Shape {
72 Shape::from(self.shape.as_slice())
73 }
74
75 fn to_string(&self) -> Result<String> {
76 let fortran_order = if self.fortran_order { "True" } else { "False" };
77 let mut shape = self
78 .shape
79 .iter()
80 .map(|x| x.to_string())
81 .collect::<Vec<_>>()
82 .join(",");
83 let descr = match self.descr {
84 DType::BF16 => Err(Error::Npy("bf16 is not supported".into()))?,
85 DType::F16 => "f2",
86 DType::F32 => "f4",
87 DType::F64 => "f8",
88 DType::I64 => "i8",
89 DType::U32 => "u4",
90 DType::U8 => "u1",
91 };
92 if !shape.is_empty() {
93 shape.push(',')
94 }
95 Ok(format!(
96 "{{'descr': '<{descr}', 'fortran_order': {fortran_order}, 'shape': ({shape}), }}"
97 ))
98 }
99
100 fn parse(header: &str) -> Result<Header> {
103 let header =
104 header.trim_matches(|c: char| c == '{' || c == '}' || c == ',' || c.is_whitespace());
105
106 let mut parts: Vec<String> = vec![];
107 let mut start_index = 0usize;
108 let mut cnt_parenthesis = 0i64;
109 for (index, c) in header.chars().enumerate() {
110 match c {
111 '(' => cnt_parenthesis += 1,
112 ')' => cnt_parenthesis -= 1,
113 ',' => {
114 if cnt_parenthesis == 0 {
115 parts.push(header[start_index..index].to_owned());
116 start_index = index + 1;
117 }
118 }
119 _ => {}
120 }
121 }
122 parts.push(header[start_index..].to_owned());
123 let mut part_map: HashMap<String, String> = HashMap::new();
124 for part in parts.iter() {
125 let part = part.trim();
126 if !part.is_empty() {
127 match part.split(':').collect::<Vec<_>>().as_slice() {
128 [key, value] => {
129 let key = key.trim_matches(|c: char| c == '\'' || c.is_whitespace());
130 let value = value.trim_matches(|c: char| c == '\'' || c.is_whitespace());
131 let _ = part_map.insert(key.to_owned(), value.to_owned());
132 }
133 _ => return Err(Error::Npy(format!("unable to parse header {header}"))),
134 }
135 }
136 }
137 let fortran_order = match part_map.get("fortran_order") {
138 None => false,
139 Some(fortran_order) => match fortran_order.as_ref() {
140 "False" => false,
141 "True" => true,
142 _ => return Err(Error::Npy(format!("unknown fortran_order {fortran_order}"))),
143 },
144 };
145 let descr = match part_map.get("descr") {
146 None => return Err(Error::Npy("no descr in header".to_string())),
147 Some(descr) => {
148 if descr.is_empty() {
149 return Err(Error::Npy("empty descr".to_string()));
150 }
151 if descr.starts_with('>') {
152 return Err(Error::Npy(format!("little-endian descr {descr}")));
153 }
154 match descr.trim_matches(|c: char| c == '=' || c == '<' || c == '|') {
160 "e" | "f2" => DType::F16,
161 "f" | "f4" => DType::F32,
162 "d" | "f8" => DType::F64,
163 "q" | "i8" => DType::I64,
165 "B" | "u1" => DType::U8,
168 "I" | "u4" => DType::U32,
169 "?" | "b1" => DType::U8,
170 descr => return Err(Error::Npy(format!("unrecognized descr {descr}"))),
173 }
174 }
175 };
176 let shape = match part_map.get("shape") {
177 None => return Err(Error::Npy("no shape in header".to_string())),
178 Some(shape) => {
179 let shape = shape.trim_matches(|c: char| c == '(' || c == ')' || c == ',');
180 if shape.is_empty() {
181 vec![]
182 } else {
183 shape
184 .split(',')
185 .map(|v| v.trim().parse::<usize>())
186 .collect::<std::result::Result<Vec<_>, _>>()?
187 }
188 }
189 };
190 Ok(Header {
191 descr,
192 fortran_order,
193 shape,
194 })
195 }
196}
197
198impl Tensor {
199 pub(crate) fn from_reader<R: std::io::Read>(
201 shape: Shape,
202 dtype: DType,
203 reader: &mut R,
204 ) -> Result<Self> {
205 let elem_count = shape.elem_count();
206 match dtype {
207 DType::BF16 => {
208 let mut data_t = vec![bf16::ZERO; elem_count];
209 reader.read_u16_into::<LittleEndian>(data_t.reinterpret_cast_mut())?;
210 Tensor::from_vec(data_t, shape, &Device::Cpu)
211 }
212 DType::F16 => {
213 let mut data_t = vec![f16::ZERO; elem_count];
214 reader.read_u16_into::<LittleEndian>(data_t.reinterpret_cast_mut())?;
215 Tensor::from_vec(data_t, shape, &Device::Cpu)
216 }
217 DType::F32 => {
218 let mut data_t = vec![0f32; elem_count];
219 reader.read_f32_into::<LittleEndian>(&mut data_t)?;
220 Tensor::from_vec(data_t, shape, &Device::Cpu)
221 }
222 DType::F64 => {
223 let mut data_t = vec![0f64; elem_count];
224 reader.read_f64_into::<LittleEndian>(&mut data_t)?;
225 Tensor::from_vec(data_t, shape, &Device::Cpu)
226 }
227 DType::U8 => {
228 let mut data_t = vec![0u8; elem_count];
229 reader.read_exact(&mut data_t)?;
230 Tensor::from_vec(data_t, shape, &Device::Cpu)
231 }
232 DType::U32 => {
233 let mut data_t = vec![0u32; elem_count];
234 reader.read_u32_into::<LittleEndian>(&mut data_t)?;
235 Tensor::from_vec(data_t, shape, &Device::Cpu)
236 }
237 DType::I64 => {
238 let mut data_t = vec![0i64; elem_count];
239 reader.read_i64_into::<LittleEndian>(&mut data_t)?;
240 Tensor::from_vec(data_t, shape, &Device::Cpu)
241 }
242 }
243 }
244
245 pub fn read_npy<T: AsRef<Path>>(path: T) -> Result<Self> {
247 let mut reader = File::open(path.as_ref())?;
248 let header = read_header(&mut reader)?;
249 let header = Header::parse(&header)?;
250 if header.fortran_order {
251 return Err(Error::Npy("fortran order not supported".to_string()));
252 }
253 Self::from_reader(header.shape(), header.descr, &mut reader)
254 }
255
256 pub fn read_npz<T: AsRef<Path>>(path: T) -> Result<Vec<(String, Self)>> {
258 let zip_reader = BufReader::new(File::open(path.as_ref())?);
259 let mut zip = zip::ZipArchive::new(zip_reader)?;
260 let mut result = vec![];
261 for i in 0..zip.len() {
262 let mut reader = zip.by_index(i)?;
263 let name = {
264 let name = reader.name();
265 name.strip_suffix(NPY_SUFFIX).unwrap_or(name).to_owned()
266 };
267 let header = read_header(&mut reader)?;
268 let header = Header::parse(&header)?;
269 if header.fortran_order {
270 return Err(Error::Npy("fortran order not supported".to_string()));
271 }
272 let s = Self::from_reader(header.shape(), header.descr, &mut reader)?;
273 result.push((name, s))
274 }
275 Ok(result)
276 }
277
278 pub fn read_npz_by_name<T: AsRef<Path>>(path: T, names: &[&str]) -> Result<Vec<Self>> {
280 let zip_reader = BufReader::new(File::open(path.as_ref())?);
281 let mut zip = zip::ZipArchive::new(zip_reader)?;
282 let mut result = vec![];
283 for name in names.iter() {
284 let mut reader = match zip.by_name(&format!("{name}{NPY_SUFFIX}")) {
285 Ok(reader) => reader,
286 Err(_) => Err(Error::Npy(format!(
287 "no array for {name} in {:?}",
288 path.as_ref()
289 )))?,
290 };
291 let header = read_header(&mut reader)?;
292 let header = Header::parse(&header)?;
293 if header.fortran_order {
294 return Err(Error::Npy("fortran order not supported".to_string()));
295 }
296 let s = Self::from_reader(header.shape(), header.descr, &mut reader)?;
297 result.push(s)
298 }
299 Ok(result)
300 }
301
302 fn write<T: Write>(&self, f: &mut T) -> Result<()> {
303 f.write_all(NPY_MAGIC_STRING)?;
304 f.write_all(&[1u8, 0u8])?;
305 let header = Header {
306 descr: self.dtype(),
307 fortran_order: false,
308 shape: self.dims().to_vec(),
309 };
310 let mut header = header.to_string()?;
311 let pad = 16 - (NPY_MAGIC_STRING.len() + 5 + header.len()) % 16;
312 for _ in 0..pad % 16 {
313 header.push(' ')
314 }
315 header.push('\n');
316 f.write_all(&[(header.len() % 256) as u8, (header.len() / 256) as u8])?;
317 f.write_all(header.as_bytes())?;
318 self.write_bytes(f)
319 }
320
321 pub fn write_npy<T: AsRef<Path>>(&self, path: T) -> Result<()> {
323 let mut f = File::create(path.as_ref())?;
324 self.write(&mut f)
325 }
326
327 pub fn write_npz<S: AsRef<str>, T: AsRef<Tensor>, P: AsRef<Path>>(
329 ts: &[(S, T)],
330 path: P,
331 ) -> Result<()> {
332 let mut zip = zip::ZipWriter::new(File::create(path.as_ref())?);
333 let options: zip::write::FileOptions<()> =
334 zip::write::FileOptions::default().compression_method(zip::CompressionMethod::Stored);
335
336 for (name, tensor) in ts.iter() {
337 zip.start_file(format!("{}.npy", name.as_ref()), options)?;
338 tensor.as_ref().write(&mut zip)?
339 }
340 Ok(())
341 }
342}
343
344pub struct NpzTensors {
346 index_per_name: HashMap<String, usize>,
347 path: std::path::PathBuf,
348 }
351
352impl NpzTensors {
353 pub fn new<T: AsRef<Path>>(path: T) -> Result<Self> {
354 let path = path.as_ref().to_owned();
355 let zip_reader = BufReader::new(File::open(&path)?);
356 let mut zip = zip::ZipArchive::new(zip_reader)?;
357 let mut index_per_name = HashMap::new();
358 for i in 0..zip.len() {
359 let file = zip.by_index(i)?;
360 let name = {
361 let name = file.name();
362 name.strip_suffix(NPY_SUFFIX).unwrap_or(name).to_owned()
363 };
364 index_per_name.insert(name, i);
365 }
366 Ok(Self {
367 index_per_name,
368 path,
369 })
370 }
371
372 pub fn names(&self) -> Vec<&String> {
373 self.index_per_name.keys().collect()
374 }
375
376 pub fn get_shape_and_dtype(&self, name: &str) -> Result<(Shape, DType)> {
379 let index = match self.index_per_name.get(name) {
380 None => crate::bail!("cannot find tensor {name}"),
381 Some(index) => *index,
382 };
383 let zip_reader = BufReader::new(File::open(&self.path)?);
384 let mut zip = zip::ZipArchive::new(zip_reader)?;
385 let mut reader = zip.by_index(index)?;
386 let header = read_header(&mut reader)?;
387 let header = Header::parse(&header)?;
388 Ok((header.shape(), header.descr))
389 }
390
391 pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
392 let index = match self.index_per_name.get(name) {
393 None => return Ok(None),
394 Some(index) => *index,
395 };
396 let zip_reader = BufReader::new(File::open(&self.path)?);
398 let mut zip = zip::ZipArchive::new(zip_reader)?;
399 let mut reader = zip.by_index(index)?;
400 let header = read_header(&mut reader)?;
401 let header = Header::parse(&header)?;
402 if header.fortran_order {
403 return Err(Error::Npy("fortran order not supported".to_string()));
404 }
405 let tensor = Tensor::from_reader(header.shape(), header.descr, &mut reader)?;
406 Ok(Some(tensor))
407 }
408}
409
#[cfg(test)]
mod tests {
    use super::Header;
    use crate::DType;

    /// Checks header parsing for one- and three-dimensional shapes, and serialization for a
    /// three-dimensional and an empty shape.
    #[test]
    fn parse() {
        // One-dimensional, C-ordered f64 array.
        let expected = Header {
            descr: DType::F64,
            fortran_order: false,
            shape: vec![128],
        };
        let parsed =
            Header::parse("{'descr': '<f8', 'fortran_order': False, 'shape': (128,), }").unwrap();
        assert_eq!(parsed, expected);

        // Three-dimensional, Fortran-ordered f32 array.
        let expected = Header {
            descr: DType::F32,
            fortran_order: true,
            shape: vec![256, 1, 128],
        };
        let parsed =
            Header::parse("{'descr': '<f4', 'fortran_order': True, 'shape': (256,1,128), }")
                .unwrap();
        assert_eq!(parsed, expected);

        // Serialization appends a comma after every dimension.
        let serialized = parsed.to_string().unwrap();
        assert_eq!(
            serialized,
            "{'descr': '<f4', 'fortran_order': True, 'shape': (256,1,128,), }"
        );

        // An empty shape is written as an empty tuple.
        let scalar = Header {
            descr: DType::U32,
            fortran_order: false,
            shape: vec![],
        };
        let serialized = scalar.to_string().unwrap();
        assert_eq!(
            serialized,
            "{'descr': '<u4', 'fortran_order': False, 'shape': (), }"
        );
    }
}