1pub mod dtype_parse;
10pub mod header;
11
12use std::fs::File;
13use std::io::{BufReader, BufWriter, Read, Write};
14use std::path::Path;
15
16use ferray_core::Array;
17use ferray_core::dimension::{Dimension, IxDyn};
18use ferray_core::dtype::{DType, Element};
19use ferray_core::dynarray::DynArray;
20use ferray_core::error::{FerrayError, FerrayResult};
21
22use self::dtype_parse::Endianness;
23
24pub(crate) fn checked_total_elements(shape: &[usize]) -> FerrayResult<usize> {
27 shape.iter().try_fold(1usize, |acc, &dim| {
28 acc.checked_mul(dim).ok_or_else(|| {
29 FerrayError::io_error("shape overflow: total elements exceed usize::MAX")
30 })
31 })
32}
33
34pub fn save<T: Element + NpyElement, D: Dimension, P: AsRef<Path>>(
41 path: P,
42 array: &Array<T, D>,
43) -> FerrayResult<()> {
44 let file = File::create(path.as_ref()).map_err(|e| {
45 FerrayError::io_error(format!(
46 "failed to create file '{}': {e}",
47 path.as_ref().display()
48 ))
49 })?;
50 let mut writer = BufWriter::new(file);
51 save_to_writer(&mut writer, array)
52}
53
54pub fn save_to_writer<T: Element + NpyElement, D: Dimension, W: Write>(
61 writer: &mut W,
62 array: &Array<T, D>,
63) -> FerrayResult<()> {
64 let fortran_order = false;
65 header::write_header(writer, T::dtype(), array.shape(), fortran_order)?;
66
67 if let Some(slice) = array.as_slice() {
69 T::write_slice(slice, writer)?;
70 } else {
71 let data: Vec<T> = array.iter().cloned().collect();
73 T::write_slice(&data, writer)?;
74 }
75
76 writer.flush()?;
77 Ok(())
78}
79
80pub fn load<T: Element + NpyElement, D: Dimension, P: AsRef<Path>>(
87 path: P,
88) -> FerrayResult<Array<T, D>> {
89 let file = File::open(path.as_ref()).map_err(|e| {
90 FerrayError::io_error(format!(
91 "failed to open file '{}': {e}",
92 path.as_ref().display()
93 ))
94 })?;
95 let mut reader = BufReader::new(file);
96 load_from_reader(&mut reader)
97}
98
99pub fn load_from_reader<T: Element + NpyElement, D: Dimension, R: Read>(
101 reader: &mut R,
102) -> FerrayResult<Array<T, D>> {
103 let hdr = header::read_header(reader)?;
104
105 if hdr.dtype != T::dtype() {
107 return Err(FerrayError::invalid_dtype(format!(
108 "expected dtype {:?} for type {}, but file has {:?}",
109 T::dtype(),
110 std::any::type_name::<T>(),
111 hdr.dtype,
112 )));
113 }
114
115 if let Some(ndim) = D::NDIM {
117 if ndim != hdr.shape.len() {
118 return Err(FerrayError::shape_mismatch(format!(
119 "expected {} dimensions, but file has {} (shape {:?})",
120 ndim,
121 hdr.shape.len(),
122 hdr.shape,
123 )));
124 }
125 }
126
127 let total_elements = checked_total_elements(&hdr.shape)?;
128 let data = T::read_vec(reader, total_elements, hdr.endianness)?;
129
130 let dim = build_dimension::<D>(&hdr.shape)?;
131
132 if hdr.fortran_order {
133 Array::from_vec_f(dim, data)
134 } else {
135 Array::from_vec(dim, data)
136 }
137}
138
139pub fn load_dynamic<P: AsRef<Path>>(path: P) -> FerrayResult<DynArray> {
146 let file = File::open(path.as_ref()).map_err(|e| {
147 FerrayError::io_error(format!(
148 "failed to open file '{}': {e}",
149 path.as_ref().display()
150 ))
151 })?;
152 let mut reader = BufReader::new(file);
153 load_dynamic_from_reader(&mut reader)
154}
155
156pub fn load_dynamic_from_reader<R: Read>(reader: &mut R) -> FerrayResult<DynArray> {
158 let hdr = header::read_header(reader)?;
159 let total = checked_total_elements(&hdr.shape)?;
160 let dim = IxDyn::new(&hdr.shape);
161
162 macro_rules! load_typed {
163 ($ty:ty, $variant:ident) => {{
164 let data = <$ty as NpyElement>::read_vec(reader, total, hdr.endianness)?;
165 let arr = if hdr.fortran_order {
166 Array::<$ty, IxDyn>::from_vec_f(dim, data)?
167 } else {
168 Array::<$ty, IxDyn>::from_vec(dim, data)?
169 };
170 Ok(DynArray::$variant(arr))
171 }};
172 }
173
174 match hdr.dtype {
175 DType::Bool => load_typed!(bool, Bool),
176 DType::U8 => load_typed!(u8, U8),
177 DType::U16 => load_typed!(u16, U16),
178 DType::U32 => load_typed!(u32, U32),
179 DType::U64 => load_typed!(u64, U64),
180 DType::U128 => load_typed!(u128, U128),
181 DType::I8 => load_typed!(i8, I8),
182 DType::I16 => load_typed!(i16, I16),
183 DType::I32 => load_typed!(i32, I32),
184 DType::I64 => load_typed!(i64, I64),
185 DType::I128 => load_typed!(i128, I128),
186 #[cfg(feature = "f16")]
187 DType::F16 => load_typed!(half::f16, F16),
188 DType::F32 => load_typed!(f32, F32),
189 DType::F64 => load_typed!(f64, F64),
190 #[cfg(feature = "bf16")]
191 DType::BF16 => load_typed!(half::bf16, BF16),
192 DType::Complex32 => {
193 load_complex32_dynamic(reader, total, dim, hdr.fortran_order, hdr.endianness)
194 }
195 DType::Complex64 => {
196 load_complex64_dynamic(reader, total, dim, hdr.fortran_order, hdr.endianness)
197 }
198 _ => Err(FerrayError::invalid_dtype(format!(
199 "unsupported dtype {:?} for .npy loading",
200 hdr.dtype
201 ))),
202 }
203}
204
205fn load_complex32_dynamic<R: Read>(
207 reader: &mut R,
208 total: usize,
209 dim: IxDyn,
210 fortran_order: bool,
211 endian: Endianness,
212) -> FerrayResult<DynArray> {
213 let byte_count = total * 8;
214 let mut raw = vec![0u8; byte_count];
215 reader.read_exact(&mut raw)?;
216
217 if endian.needs_swap() {
218 for chunk in raw.chunks_exact_mut(4) {
219 chunk.reverse();
220 }
221 }
222
223 load_complex32_from_bytes_copy(&raw, total, dim, fortran_order)
224}
225
226fn load_complex32_from_bytes_copy(
228 bytes: &[u8],
229 total: usize,
230 dim: IxDyn,
231 fortran_order: bool,
232) -> FerrayResult<DynArray> {
233 use num_complex::Complex;
234
235 let mut data = Vec::with_capacity(total);
237 for chunk in bytes.chunks_exact(8) {
238 let re = f32::from_ne_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
239 let im = f32::from_ne_bytes([chunk[4], chunk[5], chunk[6], chunk[7]]);
240 data.push(Complex::new(re, im));
241 }
242
243 let arr = if fortran_order {
244 Array::<Complex<f32>, IxDyn>::from_vec_f(dim, data)?
245 } else {
246 Array::<Complex<f32>, IxDyn>::from_vec(dim, data)?
247 };
248 Ok(DynArray::Complex32(arr))
249}
250
251fn load_complex64_dynamic<R: Read>(
253 reader: &mut R,
254 total: usize,
255 dim: IxDyn,
256 fortran_order: bool,
257 endian: Endianness,
258) -> FerrayResult<DynArray> {
259 let byte_count = total * 16;
260 let mut raw = vec![0u8; byte_count];
261 reader.read_exact(&mut raw)?;
262
263 if endian.needs_swap() {
264 for chunk in raw.chunks_exact_mut(8) {
265 chunk.reverse();
266 }
267 }
268
269 load_complex64_from_bytes_copy(&raw, total, dim, fortran_order)
270}
271
272fn load_complex64_from_bytes_copy(
273 bytes: &[u8],
274 total: usize,
275 dim: IxDyn,
276 fortran_order: bool,
277) -> FerrayResult<DynArray> {
278 use num_complex::Complex;
279
280 let mut data = Vec::with_capacity(total);
282 for chunk in bytes.chunks_exact(16) {
283 let re = f64::from_ne_bytes([
284 chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6], chunk[7],
285 ]);
286 let im = f64::from_ne_bytes([
287 chunk[8], chunk[9], chunk[10], chunk[11], chunk[12], chunk[13], chunk[14], chunk[15],
288 ]);
289 data.push(Complex::new(re, im));
290 }
291
292 let arr = if fortran_order {
293 Array::<Complex<f64>, IxDyn>::from_vec_f(dim, data)?
294 } else {
295 Array::<Complex<f64>, IxDyn>::from_vec(dim, data)?
296 };
297 Ok(DynArray::Complex64(arr))
298}
299
300pub fn save_dynamic<P: AsRef<Path>>(path: P, array: &DynArray) -> FerrayResult<()> {
302 let file = File::create(path.as_ref()).map_err(|e| {
303 FerrayError::io_error(format!(
304 "failed to create file '{}': {e}",
305 path.as_ref().display()
306 ))
307 })?;
308 let mut writer = BufWriter::new(file);
309 save_dynamic_to_writer(&mut writer, array)
310}
311
312pub fn save_dynamic_to_writer<W: Write>(writer: &mut W, array: &DynArray) -> FerrayResult<()> {
314 macro_rules! save_typed {
315 ($arr:expr, $dtype:expr, $ty:ty) => {{
316 header::write_header(writer, $dtype, $arr.shape(), false)?;
317 if let Some(s) = $arr.as_slice() {
318 <$ty as NpyElement>::write_slice(s, writer)?;
319 } else {
320 let data: Vec<$ty> = $arr.iter().cloned().collect();
321 <$ty as NpyElement>::write_slice(&data, writer)?;
322 }
323 }};
324 }
325
326 match array {
327 DynArray::Bool(a) => save_typed!(a, DType::Bool, bool),
328 DynArray::U8(a) => save_typed!(a, DType::U8, u8),
329 DynArray::U16(a) => save_typed!(a, DType::U16, u16),
330 DynArray::U32(a) => save_typed!(a, DType::U32, u32),
331 DynArray::U64(a) => save_typed!(a, DType::U64, u64),
332 DynArray::U128(a) => save_typed!(a, DType::U128, u128),
333 DynArray::I8(a) => save_typed!(a, DType::I8, i8),
334 DynArray::I16(a) => save_typed!(a, DType::I16, i16),
335 DynArray::I32(a) => save_typed!(a, DType::I32, i32),
336 DynArray::I64(a) => save_typed!(a, DType::I64, i64),
337 DynArray::I128(a) => save_typed!(a, DType::I128, i128),
338 #[cfg(feature = "f16")]
339 DynArray::F16(a) => save_typed!(a, DType::F16, half::f16),
340 DynArray::F32(a) => save_typed!(a, DType::F32, f32),
341 DynArray::F64(a) => save_typed!(a, DType::F64, f64),
342 #[cfg(feature = "bf16")]
343 DynArray::BF16(a) => save_typed!(a, DType::BF16, half::bf16),
344 DynArray::Complex32(a) => {
345 header::write_header(writer, DType::Complex32, a.shape(), false)?;
346 save_complex_raw(a.as_slice(), 8, writer)?;
347 }
348 DynArray::Complex64(a) => {
349 header::write_header(writer, DType::Complex64, a.shape(), false)?;
350 save_complex_raw(a.as_slice(), 16, writer)?;
351 }
352 _ => {
353 return Err(FerrayError::invalid_dtype(
354 "unsupported DynArray variant for .npy saving",
355 ));
356 }
357 }
358
359 writer.flush()?;
360 Ok(())
361}
362
363fn save_complex_raw<T, W: Write>(
366 slice_opt: Option<&[T]>,
367 elem_size: usize,
368 writer: &mut W,
369) -> FerrayResult<()> {
370 let slice = slice_opt
371 .ok_or_else(|| FerrayError::io_error("cannot save non-contiguous complex array"))?;
372 let byte_len = slice.len() * elem_size;
373 let bytes = unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const u8, byte_len) };
374 writer.write_all(bytes)?;
375 Ok(())
376}
377
378fn build_dimension<D: Dimension>(shape: &[usize]) -> FerrayResult<D> {
380 build_dim_from_shape::<D>(shape)
381}
382
/// Builds a value of the concrete dimension type `D` from a runtime `shape`.
///
/// `D` is only known generically here, so the function compares
/// `TypeId::of::<D>()` against each dimension type it knows how to construct
/// (`IxDyn`, and `Ix0` through `Ix6`). On a match it builds that type and
/// downcasts it back to `D` through `Box<dyn Any>`.
///
/// # Errors
/// - `shape_mismatch` if `D` has a fixed rank (`D::NDIM`) that differs from
///   `shape.len()`;
/// - `io_error("unsupported dimension type for .npy loading")` if `D` is none
///   of the types tried below.
fn build_dim_from_shape<D: Dimension>(shape: &[usize]) -> FerrayResult<D> {
    use ferray_core::dimension::*;
    use std::any::Any;

    // Fixed-rank dimension types must agree with the file's rank.
    if let Some(ndim) = D::NDIM {
        if shape.len() != ndim {
            return Err(FerrayError::shape_mismatch(format!(
                "expected {ndim} dimensions, got {}",
                shape.len()
            )));
        }
    }

    let type_id = std::any::TypeId::of::<D>();

    // Returns `Ok(D)` early when `D` is exactly `$dim_ty`. The `unwrap` on
    // `downcast` cannot fail, because the `TypeId` equality was checked first.
    macro_rules! try_dim {
        ($dim_ty:ty, $dim_val:expr) => {
            if type_id == std::any::TypeId::of::<$dim_ty>() {
                let boxed: Box<dyn Any> = Box::new($dim_val);
                return Ok(*boxed.downcast::<D>().unwrap());
            }
        };
    }

    // IxDyn accepts any rank, so it is tried regardless of `shape.len()`.
    try_dim!(IxDyn, IxDyn::new(shape));

    // Fixed-rank types: only the type whose rank equals `shape.len()` is
    // tried. This also keeps the `shape[i]` indexing in each arm in bounds.
    match shape.len() {
        0 => {
            try_dim!(Ix0, Ix0);
        }
        1 => {
            try_dim!(Ix1, Ix1::new([shape[0]]));
        }
        2 => {
            try_dim!(Ix2, Ix2::new([shape[0], shape[1]]));
        }
        3 => {
            try_dim!(Ix3, Ix3::new([shape[0], shape[1], shape[2]]));
        }
        4 => {
            try_dim!(Ix4, Ix4::new([shape[0], shape[1], shape[2], shape[3]]));
        }
        5 => {
            try_dim!(
                Ix5,
                Ix5::new([shape[0], shape[1], shape[2], shape[3], shape[4]])
            );
        }
        6 => {
            try_dim!(
                Ix6,
                Ix6::new([shape[0], shape[1], shape[2], shape[3], shape[4], shape[5]])
            );
        }
        _ => {}
    }

    // Reached when `D` is not one of the types tried above.
    Err(FerrayError::io_error(
        "unsupported dimension type for .npy loading",
    ))
}
446
/// Element types that can be written to, and read from, the data section of a
/// `.npy` file.
///
/// The trait is sealed: its supertrait `private::NpySealed` lives in a private
/// module, so code outside this module tree cannot implement `NpyElement`.
pub trait NpyElement: Element + private::NpySealed {
    /// Writes every element of `data`, in slice order, to `writer` as the
    /// payload bytes of a `.npy` data section.
    fn write_slice<W: Write>(data: &[Self], writer: &mut W) -> FerrayResult<()>;

    /// Reads `count` elements from `reader`.
    ///
    /// `endian` is the byte order declared in the file header. The fixed-width
    /// implementations in this file byte-swap each element when
    /// `endian.needs_swap()` is true.
    fn read_vec<R: Read>(
        reader: &mut R,
        count: usize,
        endian: Endianness,
    ) -> FerrayResult<Vec<Self>>;
}
466
// Holds the sealing supertrait for `NpyElement`. Because this module is
// private, only code in this module tree can name (and so implement) it.
mod private {
    /// Marker supertrait that seals `NpyElement`.
    pub trait NpySealed {}
}
470
/// Implements `NpyElement` (plus its sealing marker) for a fixed-width numeric
/// type `$ty` whose in-memory size is `$size` bytes.
///
/// Writing emits the raw native-endian bytes of the slice. Reading decodes
/// element by element and byte-swaps each one when the header's endianness
/// requires it. A compile-time assertion checks that `$size` equals
/// `size_of::<$ty>()`.
macro_rules! impl_npy_element {
    ($ty:ty, $size:expr) => {
        impl private::NpySealed for $ty {}

        impl NpyElement for $ty {
            fn write_slice<W: Write>(data: &[$ty], writer: &mut W) -> FerrayResult<()> {
                // Take the byte length from the slice itself rather than from
                // `$size`, so the view below can never extend past `data`.
                let byte_len = std::mem::size_of_val(data);
                // SAFETY: `data` is a valid, fully initialised slice occupying
                // exactly `byte_len` bytes of one allocation, and `u8` has
                // alignment 1. The primitive numeric types this macro is
                // instantiated with contain no padding bytes.
                let bytes =
                    unsafe { std::slice::from_raw_parts(data.as_ptr().cast::<u8>(), byte_len) };
                writer.write_all(bytes)?;
                Ok(())
            }

            fn read_vec<R: Read>(
                reader: &mut R,
                count: usize,
                endian: Endianness,
            ) -> FerrayResult<Vec<$ty>> {
                const SIZE: usize = $size;
                // Bytes decoded per `read_exact`. It is a multiple of every
                // supported element size, so no element straddles two reads.
                const CHUNK_BYTES: usize = 8192;
                const _: () = assert!(
                    SIZE == std::mem::size_of::<$ty>() && CHUNK_BYTES % SIZE == 0,
                    "impl_npy_element!: size does not match the type"
                );

                // `count` comes from an untrusted header. Overflow-check the
                // byte length and reserve fallibly, so a bogus shape yields an
                // error instead of a wrapped length, a capacity panic, or an
                // allocation abort.
                let byte_len = count.checked_mul(SIZE).ok_or_else(|| {
                    FerrayError::io_error("shape overflow: data size in bytes exceeds usize::MAX")
                })?;
                let mut result: Vec<$ty> = Vec::new();
                result.try_reserve_exact(count).map_err(|e| {
                    FerrayError::io_error(format!("cannot allocate {count} elements: {e}"))
                })?;

                let swap = endian.needs_swap();
                // Decode through a fixed stack buffer instead of staging the
                // whole payload in a second heap allocation.
                let mut buf = [0u8; CHUNK_BYTES];
                let mut remaining = byte_len;
                while remaining > 0 {
                    let filled = &mut buf[..remaining.min(CHUNK_BYTES)];
                    reader.read_exact(filled)?;
                    for chunk in filled.chunks_exact_mut(SIZE) {
                        if swap {
                            chunk.reverse();
                        }
                        let mut word = [0u8; SIZE];
                        word.copy_from_slice(chunk);
                        result.push(<$ty>::from_ne_bytes(word));
                    }
                    remaining -= filled.len();
                }
                Ok(result)
            }
        }
    };
}
526
527impl private::NpySealed for bool {}
529
530impl NpyElement for bool {
531 fn write_slice<W: Write>(data: &[bool], writer: &mut W) -> FerrayResult<()> {
532 for &val in data {
533 writer.write_all(&[val as u8])?;
534 }
535 Ok(())
536 }
537
538 fn read_vec<R: Read>(
539 reader: &mut R,
540 count: usize,
541 _endian: Endianness,
542 ) -> FerrayResult<Vec<bool>> {
543 let mut result = Vec::with_capacity(count);
544 let mut buf = [0u8; 1];
545 for _ in 0..count {
546 reader.read_exact(&mut buf)?;
547 result.push(buf[0] != 0);
548 }
549 Ok(result)
550 }
551}
552
553impl_npy_element!(u8, 1);
554impl_npy_element!(u16, 2);
555impl_npy_element!(u32, 4);
556impl_npy_element!(u64, 8);
557impl_npy_element!(u128, 16);
558impl_npy_element!(i8, 1);
559impl_npy_element!(i16, 2);
560impl_npy_element!(i32, 4);
561impl_npy_element!(i64, 8);
562impl_npy_element!(i128, 16);
563impl_npy_element!(f32, 4);
564impl_npy_element!(f64, 8);
565
566#[cfg(feature = "f16")]
567impl_npy_element!(half::f16, 2);
568#[cfg(feature = "bf16")]
569impl_npy_element!(half::bf16, 2);
570
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::{Ix1, Ix2};
    use std::io::Cursor;

    /// Per-process scratch directory shared by the file-based tests.
    fn test_dir() -> std::path::PathBuf {
        let dir = std::env::temp_dir().join(format!("ferray_io_test_{}", std::process::id()));
        let _ = std::fs::create_dir_all(&dir);
        dir
    }

    /// Path of the scratch file `name` inside `test_dir()`.
    fn test_file(name: &str) -> std::path::PathBuf {
        test_dir().join(name)
    }

    /// Hand-builds a version 1.0 `.npy` preamble for `header_str`. It consists
    /// of the magic, the version bytes, the little-endian `u16` header length,
    /// the header dict, and space padding ending in `\n`, so that the data
    /// section starts on a 64-byte boundary. The caller appends payload bytes.
    fn npy_v1_preamble(header_str: &str) -> Vec<u8> {
        // magic (6) + version (2) + header-length field (2) + header text
        let total_before_pad = 6 + 2 + 2 + header_str.len();
        let padding = 64 - (total_before_pad % 64);
        let padded_header_len = header_str.len() + padding;

        let mut buf = Vec::with_capacity(total_before_pad + padding);
        buf.extend_from_slice(b"\x93NUMPY");
        buf.push(1); // major version
        buf.push(0); // minor version
        buf.extend_from_slice(&(padded_header_len as u16).to_le_bytes());
        buf.extend_from_slice(header_str.as_bytes());
        buf.extend(std::iter::repeat_n(b' ', padding - 1));
        buf.push(b'\n');
        buf
    }

    #[test]
    fn roundtrip_f64_1d() {
        let data = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0];
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([5]), data.clone()).unwrap();

        let path = test_file("rt_f64_1d.npy");
        save(&path, &arr).unwrap();
        let loaded: Array<f64, Ix1> = load(&path).unwrap();

        assert_eq!(loaded.shape(), &[5]);
        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn roundtrip_f32_2d() {
        let data = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let arr = Array::<f32, Ix2>::from_vec(Ix2::new([2, 3]), data.clone()).unwrap();

        let path = test_file("rt_f32_2d.npy");
        save(&path, &arr).unwrap();
        let loaded: Array<f32, Ix2> = load(&path).unwrap();

        assert_eq!(loaded.shape(), &[2, 3]);
        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn roundtrip_i32() {
        let data = vec![10i32, 20, 30, 40];
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([4]), data.clone()).unwrap();

        let path = test_file("rt_i32.npy");
        save(&path, &arr).unwrap();
        let loaded: Array<i32, Ix1> = load(&path).unwrap();
        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn roundtrip_i64() {
        let data = vec![100i64, 200, 300];
        let arr = Array::<i64, Ix1>::from_vec(Ix1::new([3]), data.clone()).unwrap();

        let path = test_file("rt_i64.npy");
        save(&path, &arr).unwrap();
        let loaded: Array<i64, Ix1> = load(&path).unwrap();
        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn roundtrip_u8() {
        let data = vec![0u8, 128, 255];
        let arr = Array::<u8, Ix1>::from_vec(Ix1::new([3]), data.clone()).unwrap();

        let path = test_file("rt_u8.npy");
        save(&path, &arr).unwrap();
        let loaded: Array<u8, Ix1> = load(&path).unwrap();
        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn roundtrip_bool() {
        let data = vec![true, false, true, true, false];
        let arr = Array::<bool, Ix1>::from_vec(Ix1::new([5]), data.clone()).unwrap();

        let path = test_file("rt_bool.npy");
        save(&path, &arr).unwrap();
        let loaded: Array<bool, Ix1> = load(&path).unwrap();
        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn roundtrip_in_memory() {
        let data = vec![1.0_f64, 2.0, 3.0];
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), data.clone()).unwrap();

        let mut buf = Vec::new();
        save_to_writer(&mut buf, &arr).unwrap();

        let mut cursor = Cursor::new(buf);
        let loaded: Array<f64, Ix1> = load_from_reader(&mut cursor).unwrap();
        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
    }

    #[test]
    fn load_dynamic_f64() {
        let data = vec![1.0_f64, 2.0, 3.0];
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), data).unwrap();

        let path = test_file("dyn_f64.npy");
        save(&path, &arr).unwrap();
        let dyn_arr = load_dynamic(&path).unwrap();

        assert_eq!(dyn_arr.dtype(), DType::F64);
        assert_eq!(dyn_arr.shape(), &[3]);
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn load_wrong_dtype_error() {
        let data = vec![1.0_f64, 2.0, 3.0];
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), data).unwrap();

        let path = test_file("wrong_dtype.npy");
        save(&path, &arr).unwrap();

        let result = load::<f32, Ix1, _>(&path);
        assert!(result.is_err());
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn load_wrong_ndim_error() {
        let data = vec![1.0_f64, 2.0, 3.0];
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), data).unwrap();

        let path = test_file("wrong_ndim.npy");
        save(&path, &arr).unwrap();

        let result = load::<f64, Ix2, _>(&path);
        assert!(result.is_err());
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn roundtrip_dynamic() {
        let data = vec![10i32, 20, 30];
        let arr = Array::<i32, IxDyn>::from_vec(IxDyn::new(&[3]), data.clone()).unwrap();
        let dyn_arr = DynArray::I32(arr);

        let path = test_file("rt_dynamic.npy");
        save_dynamic(&path, &dyn_arr).unwrap();

        let loaded = load_dynamic(&path).unwrap();
        assert_eq!(loaded.dtype(), DType::I32);
        assert_eq!(loaded.shape(), &[3]);

        let loaded_arr = loaded.try_into_i32().unwrap();
        assert_eq!(loaded_arr.as_slice().unwrap(), &data[..]);
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn load_dynamic_ixdyn() {
        let data = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), data.clone()).unwrap();

        let path = test_file("dyn_ixdyn.npy");
        save(&path, &arr).unwrap();

        // A 2-d file can be loaded into a dynamically-ranked array.
        let loaded: Array<f64, IxDyn> = load(&path).unwrap();
        assert_eq!(loaded.shape(), &[2, 3]);
        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn load_fortran_order_npy() {
        let mut buf =
            npy_v1_preamble("{'descr': '<f8', 'fortran_order': True, 'shape': (2, 3), }");

        // Column-major payload for the logical matrix [[1, 2, 3], [4, 5, 6]].
        for &v in &[1.0_f64, 4.0, 2.0, 5.0, 3.0, 6.0] {
            buf.extend_from_slice(&v.to_le_bytes());
        }

        let mut cursor = Cursor::new(buf);
        let loaded: Array<f64, Ix2> = load_from_reader(&mut cursor).unwrap();
        assert_eq!(loaded.shape(), &[2, 3]);
        let flat: Vec<f64> = loaded.iter().copied().collect();
        // Logical (row-major) iteration order.
        assert_eq!(flat, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn roundtrip_from_vec_f() {
        // Column-major storage of [[1, 2, 3], [4, 5, 6]].
        let data = vec![1.0_f64, 4.0, 2.0, 5.0, 3.0, 6.0];
        let arr = Array::<f64, Ix2>::from_vec_f(Ix2::new([2, 3]), data).unwrap();
        assert_eq!(arr.shape(), &[2, 3]);

        let mut buf = Vec::new();
        save_to_writer(&mut buf, &arr).unwrap();

        let mut cursor = Cursor::new(buf);
        let loaded: Array<f64, Ix2> = load_from_reader(&mut cursor).unwrap();
        assert_eq!(loaded.shape(), &[2, 3]);
        // The logical element order must survive the round trip.
        let orig: Vec<f64> = arr.iter().copied().collect();
        let back: Vec<f64> = loaded.iter().copied().collect();
        assert_eq!(orig, back);
    }

    #[test]
    fn malformed_bad_magic() {
        let data = b"NOT_NPY_FILE_DATA_HERE";
        let mut cursor = Cursor::new(data.to_vec());
        let result = load_from_reader::<f64, Ix1, _>(&mut cursor);
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("magic") || msg.contains("not a valid"),
            "got: {msg}"
        );
    }

    #[test]
    fn malformed_truncated_header() {
        // Magic and version 1.0, but no header-length field or header.
        let mut data = Vec::new();
        data.extend_from_slice(b"\x93NUMPY");
        data.push(1);
        data.push(0);
        let mut cursor = Cursor::new(data);
        let result = load_from_reader::<f64, Ix1, _>(&mut cursor);
        assert!(result.is_err());
    }

    #[test]
    fn malformed_truncated_data() {
        // The header promises 100 elements, but only 3 follow.
        let mut buf =
            npy_v1_preamble("{'descr': '<f8', 'fortran_order': False, 'shape': (100,), }");
        for &v in &[1.0_f64, 2.0, 3.0] {
            buf.extend_from_slice(&v.to_le_bytes());
        }

        let mut cursor = Cursor::new(buf);
        let result = load_from_reader::<f64, Ix1, _>(&mut cursor);
        assert!(result.is_err(), "should fail with truncated data");
    }

    #[test]
    fn malformed_unsupported_version() {
        let mut data = Vec::new();
        data.extend_from_slice(b"\x93NUMPY");
        data.push(9);
        data.push(0);
        data.extend_from_slice(&[10, 0]);
        data.extend_from_slice(b"0123456789");
        let mut cursor = Cursor::new(data);
        let result = load_from_reader::<f64, Ix1, _>(&mut cursor);
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("version"), "got: {msg}");
    }

    #[test]
    fn malformed_empty_file() {
        let mut cursor = Cursor::new(Vec::<u8>::new());
        let result = load_from_reader::<f64, Ix1, _>(&mut cursor);
        assert!(result.is_err());
    }

    #[test]
    fn load_big_endian_f64() {
        let mut buf = npy_v1_preamble("{'descr': '>f8', 'fortran_order': False, 'shape': (3,), }");
        for &v in &[1.0_f64, 2.5, -4.75] {
            buf.extend_from_slice(&v.to_be_bytes());
        }

        let mut cursor = Cursor::new(buf);
        let loaded: Array<f64, Ix1> = load_from_reader(&mut cursor).unwrap();
        assert_eq!(loaded.shape(), &[3]);
        let data = loaded.as_slice().unwrap();
        assert!((data[0] - 1.0).abs() < 1e-15);
        assert!((data[1] - 2.5).abs() < 1e-15);
        assert!((data[2] - (-4.75)).abs() < 1e-15);
    }

    #[test]
    fn load_big_endian_i32() {
        let mut buf = npy_v1_preamble("{'descr': '>i4', 'fortran_order': False, 'shape': (4,), }");
        for &v in &[1_i32, -2, 1000, i32::MAX] {
            buf.extend_from_slice(&v.to_be_bytes());
        }

        let mut cursor = Cursor::new(buf);
        let loaded: Array<i32, Ix1> = load_from_reader(&mut cursor).unwrap();
        assert_eq!(loaded.shape(), &[4]);
        let data = loaded.as_slice().unwrap();
        assert_eq!(data, &[1, -2, 1000, i32::MAX]);
    }

    #[cfg(feature = "f16")]
    #[test]
    fn roundtrip_f16_1d() {
        use half::f16;
        let data: Vec<f16> = [0.0, 1.0, -1.5, 2.25, 3.5, -0.125]
            .iter()
            .map(|&v: &f32| f16::from_f32(v))
            .collect();
        let arr = Array::<f16, Ix1>::from_vec(Ix1::new([6]), data.clone()).unwrap();

        let path = test_file("rt_f16_1d.npy");
        save(&path, &arr).unwrap();
        let loaded: Array<f16, Ix1> = load(&path).unwrap();
        assert_eq!(loaded.shape(), &[6]);
        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
        let _ = std::fs::remove_file(&path);
    }

    #[cfg(feature = "f16")]
    #[test]
    fn roundtrip_f16_2d() {
        use half::f16;
        let data: Vec<f16> = (0..12)
            .map(|i| f16::from_f32(i as f32 * 0.25 - 1.0))
            .collect();
        let arr = Array::<f16, Ix2>::from_vec(Ix2::new([3, 4]), data.clone()).unwrap();

        let path = test_file("rt_f16_2d.npy");
        save(&path, &arr).unwrap();
        let loaded: Array<f16, Ix2> = load(&path).unwrap();
        assert_eq!(loaded.shape(), &[3, 4]);
        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
        let _ = std::fs::remove_file(&path);
    }

    #[cfg(feature = "f16")]
    #[test]
    fn roundtrip_f16_dynamic() {
        use half::f16;
        let data: Vec<f16> = (0..8).map(|i| f16::from_f32(i as f32)).collect();
        let arr = Array::<f16, IxDyn>::from_vec(IxDyn::new(&[2, 4]), data.clone()).unwrap();
        let dyn_in = DynArray::F16(arr);

        let path = test_file("rt_f16_dyn.npy");
        save_dynamic(&path, &dyn_in).unwrap();
        let loaded = load_dynamic(&path).unwrap();
        assert_eq!(loaded.dtype(), DType::F16);
        assert_eq!(loaded.shape(), &[2, 4]);
        match loaded {
            DynArray::F16(a) => assert_eq!(a.as_slice().unwrap(), &data[..]),
            _ => panic!("expected F16 variant"),
        }
        let _ = std::fs::remove_file(&path);
    }

    #[cfg(feature = "f16")]
    #[test]
    fn f16_descriptor_is_f2() {
        use half::f16;
        let arr = Array::<f16, Ix1>::from_vec(Ix1::new([2]), vec![f16::ZERO, f16::ONE]).unwrap();
        let mut buf = Vec::new();
        save_to_writer(&mut buf, &arr).unwrap();
        // Drop the 4 payload bytes (2 elements x 2 bytes) to keep only the header.
        let header_len = buf.len().saturating_sub(4);
        let header = String::from_utf8_lossy(&buf[..header_len]);
        assert!(
            header.contains("f2"),
            "expected 'f2' in header, got: {header}"
        );
    }

    #[cfg(feature = "bf16")]
    #[test]
    fn roundtrip_bf16_1d() {
        use half::bf16;
        let data: Vec<bf16> = [0.0, 1.0, -1.5, 2.25, 3.5, -0.125]
            .iter()
            .map(|&v: &f32| bf16::from_f32(v))
            .collect();
        let arr = Array::<bf16, Ix1>::from_vec(Ix1::new([6]), data.clone()).unwrap();

        let path = test_file("rt_bf16_1d.npy");
        save(&path, &arr).unwrap();
        let loaded: Array<bf16, Ix1> = load(&path).unwrap();
        assert_eq!(loaded.shape(), &[6]);
        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
        let _ = std::fs::remove_file(&path);
    }

    #[cfg(feature = "bf16")]
    #[test]
    fn roundtrip_bf16_dynamic() {
        use half::bf16;
        let data: Vec<bf16> = (0..6).map(|i| bf16::from_f32(i as f32 * 0.5)).collect();
        let arr = Array::<bf16, IxDyn>::from_vec(IxDyn::new(&[2, 3]), data.clone()).unwrap();
        let dyn_in = DynArray::BF16(arr);

        let path = test_file("rt_bf16_dyn.npy");
        save_dynamic(&path, &dyn_in).unwrap();
        let loaded = load_dynamic(&path).unwrap();
        assert_eq!(loaded.dtype(), DType::BF16);
        assert_eq!(loaded.shape(), &[2, 3]);
        match loaded {
            DynArray::BF16(a) => assert_eq!(a.as_slice().unwrap(), &data[..]),
            _ => panic!("expected BF16 variant"),
        }
        let _ = std::fs::remove_file(&path);
    }

    #[cfg(feature = "bf16")]
    #[test]
    fn bf16_descriptor_is_bf16_tag() {
        use half::bf16;
        let arr = Array::<bf16, Ix1>::from_vec(Ix1::new([2]), vec![bf16::ZERO, bf16::ONE]).unwrap();
        let mut buf = Vec::new();
        save_to_writer(&mut buf, &arr).unwrap();
        // Drop the 4 payload bytes (2 elements x 2 bytes) to keep only the header.
        let header_len = buf.len().saturating_sub(4);
        let header = String::from_utf8_lossy(&buf[..header_len]);
        assert!(
            header.contains("bf16"),
            "expected 'bf16' in header, got: {header}"
        );
    }
}