use crate::{DType, Device, Error, Result, Shape, Tensor};
use byteorder::{LittleEndian, ReadBytesExt};
use half::{bf16, f16, slice::HalfFloatSliceExt};
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::Path;
35
/// Magic prefix at the start of every `.npy` stream; `read_header` rejects input without it.
const NPY_MAGIC_STRING: &[u8] = b"\x93NUMPY";
/// Suffix of array entry names inside an `.npz` archive; stripped to obtain the array name.
const NPY_SUFFIX: &str = ".npy";
38
39fn read_header<R: Read>(reader: &mut R) -> Result<String> {
40 let mut magic_string = vec![0u8; NPY_MAGIC_STRING.len()];
41 reader.read_exact(&mut magic_string)?;
42 if magic_string != NPY_MAGIC_STRING {
43 return Err(Error::Npy("magic string mismatch".to_string()));
44 }
45 let mut version = [0u8; 2];
46 reader.read_exact(&mut version)?;
47 let header_len_len = match version[0] {
48 1 => 2,
49 2 => 4,
50 otherwise => return Err(Error::Npy(format!("unsupported version {otherwise}"))),
51 };
52 let mut header_len = vec![0u8; header_len_len];
53 reader.read_exact(&mut header_len)?;
54 let header_len = header_len
55 .iter()
56 .rev()
57 .fold(0_usize, |acc, &v| 256 * acc + v as usize);
58 let mut header = vec![0u8; header_len];
59 reader.read_exact(&mut header)?;
60 Ok(String::from_utf8_lossy(&header).to_string())
61}
62
/// The fields of a `.npy` header dictionary that this module understands.
#[derive(Debug, PartialEq)]
struct Header {
    /// Element type, decoded from the numpy `descr` type string.
    descr: DType,
    /// The header's `fortran_order` flag; readers in this file reject `true`.
    fortran_order: bool,
    /// Array dimensions; empty for a zero-dimensional (scalar) array, serialized as `()`.
    shape: Vec<usize>,
}
69
70impl Header {
71 fn shape(&self) -> Shape {
72 Shape::from(self.shape.as_slice())
73 }
74
75 fn to_string(&self) -> Result<String> {
76 let fortran_order = if self.fortran_order { "True" } else { "False" };
77 let mut shape = self
78 .shape
79 .iter()
80 .map(|x| x.to_string())
81 .collect::<Vec<_>>()
82 .join(",");
83 let descr = match self.descr {
84 DType::BF16 => Err(Error::Npy("bf16 is not supported".into()))?,
85 DType::F16 => "f2",
86 DType::F32 => "f4",
87 DType::F64 => "f8",
88 DType::I16 => "i2",
89 DType::I32 => "i4",
90 DType::I64 => "i8",
91 DType::U32 => "u4",
92 DType::U8 => "u1",
93 DType::F8E4M3 => Err(Error::Npy("f8e4m3 is not supported".into()))?,
94 DType::F6E2M3 => Err(Error::Npy("f6e2m3 is not supported".into()))?,
95 DType::F6E3M2 => Err(Error::Npy("f6e3m2 is not supported".into()))?,
96 DType::F4 => Err(Error::Npy("f4 is not supported".into()))?,
97 DType::F8E8M0 => Err(Error::Npy("f8e8m0 is not supported".into()))?,
98 };
99 if !shape.is_empty() {
100 shape.push(',')
101 }
102 Ok(format!(
103 "{{'descr': '<{descr}', 'fortran_order': {fortran_order}, 'shape': ({shape}), }}"
104 ))
105 }
106
107 fn parse(header: &str) -> Result<Header> {
110 let header =
111 header.trim_matches(|c: char| c == '{' || c == '}' || c == ',' || c.is_whitespace());
112
113 let mut parts: Vec<String> = vec![];
114 let mut start_index = 0usize;
115 let mut cnt_parenthesis = 0i64;
116 for (index, c) in header.char_indices() {
117 match c {
118 '(' => cnt_parenthesis += 1,
119 ')' => cnt_parenthesis -= 1,
120 ',' => {
121 if cnt_parenthesis == 0 {
122 parts.push(header[start_index..index].to_owned());
123 start_index = index + 1;
124 }
125 }
126 _ => {}
127 }
128 }
129 parts.push(header[start_index..].to_owned());
130 let mut part_map: HashMap<String, String> = HashMap::new();
131 for part in parts.iter() {
132 let part = part.trim();
133 if !part.is_empty() {
134 match part.split(':').collect::<Vec<_>>().as_slice() {
135 [key, value] => {
136 let key = key.trim_matches(|c: char| c == '\'' || c.is_whitespace());
137 let value = value.trim_matches(|c: char| c == '\'' || c.is_whitespace());
138 let _ = part_map.insert(key.to_owned(), value.to_owned());
139 }
140 _ => return Err(Error::Npy(format!("unable to parse header {header}"))),
141 }
142 }
143 }
144 let fortran_order = match part_map.get("fortran_order") {
145 None => false,
146 Some(fortran_order) => match fortran_order.as_ref() {
147 "False" => false,
148 "True" => true,
149 _ => return Err(Error::Npy(format!("unknown fortran_order {fortran_order}"))),
150 },
151 };
152 let descr = match part_map.get("descr") {
153 None => return Err(Error::Npy("no descr in header".to_string())),
154 Some(descr) => {
155 if descr.is_empty() {
156 return Err(Error::Npy("empty descr".to_string()));
157 }
158 if descr.starts_with('>') {
159 return Err(Error::Npy(format!("little-endian descr {descr}")));
160 }
161 match descr.trim_matches(|c: char| c == '=' || c == '<' || c == '|') {
167 "e" | "f2" => DType::F16,
168 "f" | "f4" => DType::F32,
169 "d" | "f8" => DType::F64,
170 "i" | "i4" => DType::I32,
171 "q" | "i8" => DType::I64,
172 "h" | "i2" => DType::I16,
173 "B" | "u1" => DType::U8,
175 "I" | "u4" => DType::U32,
176 "?" | "b1" => DType::U8,
177 descr => return Err(Error::Npy(format!("unrecognized descr {descr}"))),
180 }
181 }
182 };
183 let shape = match part_map.get("shape") {
184 None => return Err(Error::Npy("no shape in header".to_string())),
185 Some(shape) => {
186 let shape = shape.trim_matches(|c: char| c == '(' || c == ')' || c == ',');
187 if shape.is_empty() {
188 vec![]
189 } else {
190 shape
191 .split(',')
192 .map(|v| v.trim().parse::<usize>())
193 .collect::<std::result::Result<Vec<_>, _>>()?
194 }
195 }
196 };
197 Ok(Header {
198 descr,
199 fortran_order,
200 shape,
201 })
202 }
203}
204
205impl Tensor {
206 pub(crate) fn from_reader<R: std::io::Read>(
208 shape: Shape,
209 dtype: DType,
210 reader: &mut R,
211 ) -> Result<Self> {
212 let elem_count = shape.elem_count();
213 match dtype {
214 DType::BF16 => {
215 let mut data_t = vec![bf16::ZERO; elem_count];
216 reader.read_u16_into::<LittleEndian>(data_t.reinterpret_cast_mut())?;
217 Tensor::from_vec(data_t, shape, &Device::Cpu)
218 }
219 DType::F16 => {
220 let mut data_t = vec![f16::ZERO; elem_count];
221 reader.read_u16_into::<LittleEndian>(data_t.reinterpret_cast_mut())?;
222 Tensor::from_vec(data_t, shape, &Device::Cpu)
223 }
224 DType::F32 => {
225 let mut data_t = vec![0f32; elem_count];
226 reader.read_f32_into::<LittleEndian>(&mut data_t)?;
227 Tensor::from_vec(data_t, shape, &Device::Cpu)
228 }
229 DType::F64 => {
230 let mut data_t = vec![0f64; elem_count];
231 reader.read_f64_into::<LittleEndian>(&mut data_t)?;
232 Tensor::from_vec(data_t, shape, &Device::Cpu)
233 }
234 DType::U8 => {
235 let mut data_t = vec![0u8; elem_count];
236 reader.read_exact(&mut data_t)?;
237 Tensor::from_vec(data_t, shape, &Device::Cpu)
238 }
239 DType::U32 => {
240 let mut data_t = vec![0u32; elem_count];
241 reader.read_u32_into::<LittleEndian>(&mut data_t)?;
242 Tensor::from_vec(data_t, shape, &Device::Cpu)
243 }
244 DType::I16 => {
245 let mut data_t = vec![0i16; elem_count];
246 reader.read_i16_into::<LittleEndian>(&mut data_t)?;
247 Tensor::from_vec(data_t, shape, &Device::Cpu)
248 }
249 DType::I32 => {
250 let mut data_t = vec![0i32; elem_count];
251 reader.read_i32_into::<LittleEndian>(&mut data_t)?;
252 Tensor::from_vec(data_t, shape, &Device::Cpu)
253 }
254 DType::I64 => {
255 let mut data_t = vec![0i64; elem_count];
256 reader.read_i64_into::<LittleEndian>(&mut data_t)?;
257 Tensor::from_vec(data_t, shape, &Device::Cpu)
258 }
259 DType::F8E4M3 => {
260 let mut data_t = vec![0u8; elem_count];
261 reader.read_exact(&mut data_t)?;
262 let data_f8: Vec<float8::F8E4M3> =
263 data_t.into_iter().map(float8::F8E4M3::from_bits).collect();
264 Tensor::from_vec(data_f8, shape, &Device::Cpu)
265 }
266 DType::F6E2M3 | DType::F6E3M2 | DType::F4 | DType::F8E8M0 => {
267 Err(Error::UnsupportedDTypeForOp(dtype, "from_reader").bt())
268 }
269 }
270 }
271
272 pub fn read_npy<T: AsRef<Path>>(path: T) -> Result<Self> {
274 let mut reader = File::open(path.as_ref())?;
275 let header = read_header(&mut reader)?;
276 let header = Header::parse(&header)?;
277 if header.fortran_order {
278 return Err(Error::Npy("fortran order not supported".to_string()));
279 }
280 Self::from_reader(header.shape(), header.descr, &mut reader)
281 }
282
283 pub fn read_npz<T: AsRef<Path>>(path: T) -> Result<Vec<(String, Self)>> {
285 let zip_reader = BufReader::new(File::open(path.as_ref())?);
286 let mut zip = zip::ZipArchive::new(zip_reader)?;
287 let mut result = vec![];
288 for i in 0..zip.len() {
289 let mut reader = zip.by_index(i)?;
290 let name = {
291 let name = reader.name();
292 name.strip_suffix(NPY_SUFFIX).unwrap_or(name).to_owned()
293 };
294 let header = read_header(&mut reader)?;
295 let header = Header::parse(&header)?;
296 if header.fortran_order {
297 return Err(Error::Npy("fortran order not supported".to_string()));
298 }
299 let s = Self::from_reader(header.shape(), header.descr, &mut reader)?;
300 result.push((name, s))
301 }
302 Ok(result)
303 }
304
305 pub fn read_npz_by_name<T: AsRef<Path>>(path: T, names: &[&str]) -> Result<Vec<Self>> {
307 let zip_reader = BufReader::new(File::open(path.as_ref())?);
308 let mut zip = zip::ZipArchive::new(zip_reader)?;
309 let mut result = vec![];
310 for name in names.iter() {
311 let mut reader = match zip.by_name(&format!("{name}{NPY_SUFFIX}")) {
312 Ok(reader) => reader,
313 Err(_) => Err(Error::Npy(format!(
314 "no array for {name} in {:?}",
315 path.as_ref()
316 )))?,
317 };
318 let header = read_header(&mut reader)?;
319 let header = Header::parse(&header)?;
320 if header.fortran_order {
321 return Err(Error::Npy("fortran order not supported".to_string()));
322 }
323 let s = Self::from_reader(header.shape(), header.descr, &mut reader)?;
324 result.push(s)
325 }
326 Ok(result)
327 }
328
329 fn write<T: Write>(&self, f: &mut T) -> Result<()> {
330 f.write_all(NPY_MAGIC_STRING)?;
331 f.write_all(&[1u8, 0u8])?;
332 let header = Header {
333 descr: self.dtype(),
334 fortran_order: false,
335 shape: self.dims().to_vec(),
336 };
337 let mut header = header.to_string()?;
338 let pad = 16 - (NPY_MAGIC_STRING.len() + 5 + header.len()) % 16;
339 for _ in 0..pad % 16 {
340 header.push(' ')
341 }
342 header.push('\n');
343 f.write_all(&[(header.len() % 256) as u8, (header.len() / 256) as u8])?;
344 f.write_all(header.as_bytes())?;
345 self.write_bytes(f)
346 }
347
348 pub fn write_npy<T: AsRef<Path>>(&self, path: T) -> Result<()> {
350 let mut f = File::create(path.as_ref())?;
351 self.write(&mut f)
352 }
353
354 pub fn write_npz<S: AsRef<str>, T: AsRef<Tensor>, P: AsRef<Path>>(
356 ts: &[(S, T)],
357 path: P,
358 ) -> Result<()> {
359 let mut zip = zip::ZipWriter::new(File::create(path.as_ref())?);
360 let options: zip::write::FileOptions<()> =
361 zip::write::FileOptions::default().compression_method(zip::CompressionMethod::Stored);
362
363 for (name, tensor) in ts.iter() {
364 zip.start_file(format!("{}.npy", name.as_ref()), options)?;
365 tensor.as_ref().write(&mut zip)?
366 }
367 Ok(())
368 }
369}
370
/// Lazily loaded view of the arrays stored in a `.npz` archive.
///
/// Only the archive's entry names are read on construction. Each array is read on demand,
/// reopening the archive file at `path` for every access.
pub struct NpzTensors {
    /// Maps each array name (entry name with any trailing `.npy` removed) to its archive index.
    index_per_name: HashMap<String, usize>,
    /// Location of the archive, reopened by each accessor.
    path: std::path::PathBuf,
}
378
379impl NpzTensors {
380 pub fn new<T: AsRef<Path>>(path: T) -> Result<Self> {
381 let path = path.as_ref().to_owned();
382 let zip_reader = BufReader::new(File::open(&path)?);
383 let mut zip = zip::ZipArchive::new(zip_reader)?;
384 let mut index_per_name = HashMap::new();
385 for i in 0..zip.len() {
386 let file = zip.by_index(i)?;
387 let name = {
388 let name = file.name();
389 name.strip_suffix(NPY_SUFFIX).unwrap_or(name).to_owned()
390 };
391 index_per_name.insert(name, i);
392 }
393 Ok(Self {
394 index_per_name,
395 path,
396 })
397 }
398
399 pub fn names(&self) -> Vec<&String> {
400 self.index_per_name.keys().collect()
401 }
402
403 pub fn get_shape_and_dtype(&self, name: &str) -> Result<(Shape, DType)> {
406 let index = match self.index_per_name.get(name) {
407 None => crate::bail!("cannot find tensor {name}"),
408 Some(index) => *index,
409 };
410 let zip_reader = BufReader::new(File::open(&self.path)?);
411 let mut zip = zip::ZipArchive::new(zip_reader)?;
412 let mut reader = zip.by_index(index)?;
413 let header = read_header(&mut reader)?;
414 let header = Header::parse(&header)?;
415 Ok((header.shape(), header.descr))
416 }
417
418 pub fn get(&self, name: &str) -> Result<Option<Tensor>> {
419 let index = match self.index_per_name.get(name) {
420 None => return Ok(None),
421 Some(index) => *index,
422 };
423 let zip_reader = BufReader::new(File::open(&self.path)?);
425 let mut zip = zip::ZipArchive::new(zip_reader)?;
426 let mut reader = zip.by_index(index)?;
427 let header = read_header(&mut reader)?;
428 let header = Header::parse(&header)?;
429 if header.fortran_order {
430 return Err(Error::Npy("fortran order not supported".to_string()));
431 }
432 let tensor = Tensor::from_reader(header.shape(), header.descr, &mut reader)?;
433 Ok(Some(tensor))
434 }
435}
436
#[cfg(test)]
mod tests {
    use super::Header;
    use crate::DType;

    /// Parsing of hand-written headers and serialization of simple headers.
    #[test]
    fn parse() {
        let h = "{'descr': '<f8', 'fortran_order': False, 'shape': (128,), }";
        assert_eq!(
            Header::parse(h).unwrap(),
            Header {
                descr: DType::F64,
                fortran_order: false,
                shape: vec![128]
            }
        );
        let h = "{'descr': '<f4', 'fortran_order': True, 'shape': (256,1,128), }";
        let h = Header::parse(h).unwrap();
        assert_eq!(
            h,
            Header {
                descr: DType::F32,
                fortran_order: true,
                shape: vec![256, 1, 128]
            }
        );
        assert_eq!(
            h.to_string().unwrap(),
            "{'descr': '<f4', 'fortran_order': True, 'shape': (256,1,128,), }"
        );

        let h = Header {
            descr: DType::U32,
            fortran_order: false,
            shape: vec![],
        };
        assert_eq!(
            h.to_string().unwrap(),
            "{'descr': '<u4', 'fortran_order': False, 'shape': (), }"
        );
    }

    /// A header as it appears on disk: ", "-separated dimensions plus space padding and '\n'.
    #[test]
    fn parse_padded_header() {
        let h = "{'descr': '<i8', 'fortran_order': False, 'shape': (3, 4), }          \n";
        assert_eq!(
            Header::parse(h).unwrap(),
            Header {
                descr: DType::I64,
                fortran_order: false,
                shape: vec![3, 4]
            }
        );
    }

    /// Byte-order prefixes and type strings accepted for each supported dtype.
    #[test]
    fn parse_descr_aliases() {
        let cases = [
            ("<f2", DType::F16),
            ("<f8", DType::F64),
            ("=i4", DType::I32),
            ("<i2", DType::I16),
            ("<i8", DType::I64),
            ("|u1", DType::U8),
            ("<u4", DType::U32),
            ("|b1", DType::U8),
        ];
        for (descr, dtype) in cases.iter() {
            let h = format!("{{'descr': '{descr}', 'fortran_order': False, 'shape': (), }}");
            let h = Header::parse(&h).unwrap();
            assert_eq!(h.descr, *dtype, "descr {descr}");
            assert!(h.shape.is_empty());
        }
    }

    /// Every malformed or unsupported header must be rejected rather than misread.
    #[test]
    fn parse_rejects_invalid_headers() {
        let invalid = [
            "{'descr': '>f4', 'fortran_order': False, 'shape': (2,), }",
            "{'descr': '<c8', 'fortran_order': False, 'shape': (2,), }",
            "{'descr': '', 'fortran_order': False, 'shape': (2,), }",
            "{'fortran_order': False, 'shape': (2,), }",
            "{'descr': '<f4', 'fortran_order': False, }",
            "{'descr': '<f4', 'fortran_order': 0, 'shape': (2,), }",
            "{'descr': '<f4', 'fortran_order': False, 'shape': (2,x), }",
            "{'descr' '<f4', 'fortran_order': False, 'shape': (2,), }",
        ];
        for h in invalid.iter() {
            assert!(Header::parse(h).is_err(), "expected an error for {h}");
        }
    }

    /// `parse(to_string(h)) == h` for scalar, 1-d and n-d shapes.
    #[test]
    fn header_round_trip() {
        let shapes: [&[usize]; 3] = [&[], &[5], &[2, 3, 4]];
        for shape in shapes.iter() {
            let h = Header {
                descr: DType::F32,
                fortran_order: false,
                shape: shape.to_vec(),
            };
            let rendered = h.to_string().unwrap();
            assert_eq!(Header::parse(&rendered).unwrap(), h, "header {rendered}");
        }
    }

    /// bf16 has no numpy type string, so serialization must fail.
    #[test]
    fn to_string_rejects_bf16() {
        let h = Header {
            descr: DType::BF16,
            fortran_order: false,
            shape: vec![2],
        };
        assert!(h.to_string().is_err());
    }
}