1pub mod dtype_parse;
10pub mod header;
11
12use std::fs::File;
13use std::io::{BufReader, BufWriter, Read, Write};
14use std::path::Path;
15
16use ferray_core::Array;
17use ferray_core::dimension::{Dimension, IxDyn};
18use ferray_core::dtype::{DType, Element};
19use ferray_core::dynarray::DynArray;
20use ferray_core::error::{FerrayError, FerrayResult};
21
22use self::dtype_parse::Endianness;
23
24pub(crate) fn checked_total_elements(shape: &[usize]) -> FerrayResult<usize> {
27 shape.iter().try_fold(1usize, |acc, &dim| {
28 acc.checked_mul(dim).ok_or_else(|| {
29 FerrayError::io_error("shape overflow: total elements exceed usize::MAX")
30 })
31 })
32}
33
34pub fn save<T: Element + NpyElement, D: Dimension, P: AsRef<Path>>(
41 path: P,
42 array: &Array<T, D>,
43) -> FerrayResult<()> {
44 let file = File::create(path.as_ref()).map_err(|e| {
45 FerrayError::io_error(format!(
46 "failed to create file '{}': {e}",
47 path.as_ref().display()
48 ))
49 })?;
50 let mut writer = BufWriter::new(file);
51 save_to_writer(&mut writer, array)
52}
53
54pub fn save_to_writer<T: Element + NpyElement, D: Dimension, W: Write>(
61 writer: &mut W,
62 array: &Array<T, D>,
63) -> FerrayResult<()> {
64 let fortran_order = false;
65 header::write_header(writer, T::dtype(), array.shape(), fortran_order)?;
66
67 if let Some(slice) = array.as_slice() {
69 T::write_slice(slice, writer)?;
70 } else {
71 let data: Vec<T> = array.iter().cloned().collect();
73 T::write_slice(&data, writer)?;
74 }
75
76 writer.flush()?;
77 Ok(())
78}
79
80pub fn load<T: Element + NpyElement, D: Dimension, P: AsRef<Path>>(
87 path: P,
88) -> FerrayResult<Array<T, D>> {
89 let file = File::open(path.as_ref()).map_err(|e| {
90 FerrayError::io_error(format!(
91 "failed to open file '{}': {e}",
92 path.as_ref().display()
93 ))
94 })?;
95 let mut reader = BufReader::new(file);
96 load_from_reader(&mut reader)
97}
98
99pub fn load_from_reader<T: Element + NpyElement, D: Dimension, R: Read>(
101 reader: &mut R,
102) -> FerrayResult<Array<T, D>> {
103 let hdr = header::read_header(reader)?;
104
105 if hdr.dtype != T::dtype() {
107 return Err(FerrayError::invalid_dtype(format!(
108 "expected dtype {:?} for type {}, but file has {:?}",
109 T::dtype(),
110 std::any::type_name::<T>(),
111 hdr.dtype,
112 )));
113 }
114
115 if let Some(ndim) = D::NDIM {
117 if ndim != hdr.shape.len() {
118 return Err(FerrayError::shape_mismatch(format!(
119 "expected {} dimensions, but file has {} (shape {:?})",
120 ndim,
121 hdr.shape.len(),
122 hdr.shape,
123 )));
124 }
125 }
126
127 let total_elements = checked_total_elements(&hdr.shape)?;
128 let data = T::read_vec(reader, total_elements, hdr.endianness)?;
129
130 let dim = build_dimension::<D>(&hdr.shape)?;
131
132 if hdr.fortran_order {
133 Array::from_vec_f(dim, data)
134 } else {
135 Array::from_vec(dim, data)
136 }
137}
138
139pub fn load_dynamic<P: AsRef<Path>>(path: P) -> FerrayResult<DynArray> {
146 let file = File::open(path.as_ref()).map_err(|e| {
147 FerrayError::io_error(format!(
148 "failed to open file '{}': {e}",
149 path.as_ref().display()
150 ))
151 })?;
152 let mut reader = BufReader::new(file);
153 load_dynamic_from_reader(&mut reader)
154}
155
156pub fn load_dynamic_from_reader<R: Read>(reader: &mut R) -> FerrayResult<DynArray> {
158 let hdr = header::read_header(reader)?;
159 let total = checked_total_elements(&hdr.shape)?;
160 let dim = IxDyn::new(&hdr.shape);
161
162 macro_rules! load_typed {
163 ($ty:ty, $variant:ident) => {{
164 let data = <$ty as NpyElement>::read_vec(reader, total, hdr.endianness)?;
165 let arr = if hdr.fortran_order {
166 Array::<$ty, IxDyn>::from_vec_f(dim, data)?
167 } else {
168 Array::<$ty, IxDyn>::from_vec(dim, data)?
169 };
170 Ok(DynArray::$variant(arr))
171 }};
172 }
173
174 match hdr.dtype {
175 DType::Bool => load_typed!(bool, Bool),
176 DType::U8 => load_typed!(u8, U8),
177 DType::U16 => load_typed!(u16, U16),
178 DType::U32 => load_typed!(u32, U32),
179 DType::U64 => load_typed!(u64, U64),
180 DType::U128 => load_typed!(u128, U128),
181 DType::I8 => load_typed!(i8, I8),
182 DType::I16 => load_typed!(i16, I16),
183 DType::I32 => load_typed!(i32, I32),
184 DType::I64 => load_typed!(i64, I64),
185 DType::I128 => load_typed!(i128, I128),
186 #[cfg(feature = "f16")]
187 DType::F16 => load_typed!(half::f16, F16),
188 DType::F32 => load_typed!(f32, F32),
189 DType::F64 => load_typed!(f64, F64),
190 #[cfg(feature = "bf16")]
191 DType::BF16 => load_typed!(half::bf16, BF16),
192 DType::Complex32 => {
193 load_complex32_dynamic(reader, total, dim, hdr.fortran_order, hdr.endianness)
194 }
195 DType::Complex64 => {
196 load_complex64_dynamic(reader, total, dim, hdr.fortran_order, hdr.endianness)
197 }
198 DType::DateTime64(unit) => {
199 let data = <ferray_core::dtype::DateTime64 as NpyElement>::read_vec(
200 reader,
201 total,
202 hdr.endianness,
203 )?;
204 let arr = if hdr.fortran_order {
205 Array::<ferray_core::dtype::DateTime64, IxDyn>::from_vec_f(dim, data)?
206 } else {
207 Array::<ferray_core::dtype::DateTime64, IxDyn>::from_vec(dim, data)?
208 };
209 Ok(DynArray::DateTime64(arr, unit))
210 }
211 DType::Timedelta64(unit) => {
212 let data = <ferray_core::dtype::Timedelta64 as NpyElement>::read_vec(
213 reader,
214 total,
215 hdr.endianness,
216 )?;
217 let arr = if hdr.fortran_order {
218 Array::<ferray_core::dtype::Timedelta64, IxDyn>::from_vec_f(dim, data)?
219 } else {
220 Array::<ferray_core::dtype::Timedelta64, IxDyn>::from_vec(dim, data)?
221 };
222 Ok(DynArray::Timedelta64(arr, unit))
223 }
224 _ => Err(FerrayError::invalid_dtype(format!(
225 "unsupported dtype {:?} for .npy loading",
226 hdr.dtype
227 ))),
228 }
229}
230
231fn load_complex32_dynamic<R: Read>(
233 reader: &mut R,
234 total: usize,
235 dim: IxDyn,
236 fortran_order: bool,
237 endian: Endianness,
238) -> FerrayResult<DynArray> {
239 let byte_count = total * 8;
240 let mut raw = vec![0u8; byte_count];
241 reader.read_exact(&mut raw)?;
242
243 if endian.needs_swap() {
244 for chunk in raw.chunks_exact_mut(4) {
245 chunk.reverse();
246 }
247 }
248
249 load_complex32_from_bytes_copy(&raw, total, dim, fortran_order)
250}
251
252fn load_complex32_from_bytes_copy(
254 bytes: &[u8],
255 total: usize,
256 dim: IxDyn,
257 fortran_order: bool,
258) -> FerrayResult<DynArray> {
259 use num_complex::Complex;
260
261 let mut data = Vec::with_capacity(total);
263 for chunk in bytes.chunks_exact(8) {
264 let re = f32::from_ne_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
265 let im = f32::from_ne_bytes([chunk[4], chunk[5], chunk[6], chunk[7]]);
266 data.push(Complex::new(re, im));
267 }
268
269 let arr = if fortran_order {
270 Array::<Complex<f32>, IxDyn>::from_vec_f(dim, data)?
271 } else {
272 Array::<Complex<f32>, IxDyn>::from_vec(dim, data)?
273 };
274 Ok(DynArray::Complex32(arr))
275}
276
277fn load_complex64_dynamic<R: Read>(
279 reader: &mut R,
280 total: usize,
281 dim: IxDyn,
282 fortran_order: bool,
283 endian: Endianness,
284) -> FerrayResult<DynArray> {
285 let byte_count = total * 16;
286 let mut raw = vec![0u8; byte_count];
287 reader.read_exact(&mut raw)?;
288
289 if endian.needs_swap() {
290 for chunk in raw.chunks_exact_mut(8) {
291 chunk.reverse();
292 }
293 }
294
295 load_complex64_from_bytes_copy(&raw, total, dim, fortran_order)
296}
297
298fn load_complex64_from_bytes_copy(
299 bytes: &[u8],
300 total: usize,
301 dim: IxDyn,
302 fortran_order: bool,
303) -> FerrayResult<DynArray> {
304 use num_complex::Complex;
305
306 let mut data = Vec::with_capacity(total);
308 for chunk in bytes.chunks_exact(16) {
309 let re = f64::from_ne_bytes([
310 chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6], chunk[7],
311 ]);
312 let im = f64::from_ne_bytes([
313 chunk[8], chunk[9], chunk[10], chunk[11], chunk[12], chunk[13], chunk[14], chunk[15],
314 ]);
315 data.push(Complex::new(re, im));
316 }
317
318 let arr = if fortran_order {
319 Array::<Complex<f64>, IxDyn>::from_vec_f(dim, data)?
320 } else {
321 Array::<Complex<f64>, IxDyn>::from_vec(dim, data)?
322 };
323 Ok(DynArray::Complex64(arr))
324}
325
326pub fn save_dynamic<P: AsRef<Path>>(path: P, array: &DynArray) -> FerrayResult<()> {
328 let file = File::create(path.as_ref()).map_err(|e| {
329 FerrayError::io_error(format!(
330 "failed to create file '{}': {e}",
331 path.as_ref().display()
332 ))
333 })?;
334 let mut writer = BufWriter::new(file);
335 save_dynamic_to_writer(&mut writer, array)
336}
337
338pub fn save_dynamic_to_writer<W: Write>(writer: &mut W, array: &DynArray) -> FerrayResult<()> {
340 macro_rules! save_typed {
341 ($arr:expr, $dtype:expr, $ty:ty) => {{
342 header::write_header(writer, $dtype, $arr.shape(), false)?;
343 if let Some(s) = $arr.as_slice() {
344 <$ty as NpyElement>::write_slice(s, writer)?;
345 } else {
346 let data: Vec<$ty> = $arr.iter().cloned().collect();
347 <$ty as NpyElement>::write_slice(&data, writer)?;
348 }
349 }};
350 }
351
352 match array {
353 DynArray::Bool(a) => save_typed!(a, DType::Bool, bool),
354 DynArray::U8(a) => save_typed!(a, DType::U8, u8),
355 DynArray::U16(a) => save_typed!(a, DType::U16, u16),
356 DynArray::U32(a) => save_typed!(a, DType::U32, u32),
357 DynArray::U64(a) => save_typed!(a, DType::U64, u64),
358 DynArray::U128(a) => save_typed!(a, DType::U128, u128),
359 DynArray::I8(a) => save_typed!(a, DType::I8, i8),
360 DynArray::I16(a) => save_typed!(a, DType::I16, i16),
361 DynArray::I32(a) => save_typed!(a, DType::I32, i32),
362 DynArray::I64(a) => save_typed!(a, DType::I64, i64),
363 DynArray::I128(a) => save_typed!(a, DType::I128, i128),
364 #[cfg(feature = "f16")]
365 DynArray::F16(a) => save_typed!(a, DType::F16, half::f16),
366 DynArray::F32(a) => save_typed!(a, DType::F32, f32),
367 DynArray::F64(a) => save_typed!(a, DType::F64, f64),
368 #[cfg(feature = "bf16")]
369 DynArray::BF16(a) => save_typed!(a, DType::BF16, half::bf16),
370 DynArray::Complex32(a) => {
371 header::write_header(writer, DType::Complex32, a.shape(), false)?;
372 save_complex_raw(a.as_slice(), 8, writer)?;
373 }
374 DynArray::Complex64(a) => {
375 header::write_header(writer, DType::Complex64, a.shape(), false)?;
376 save_complex_raw(a.as_slice(), 16, writer)?;
377 }
378 _ => {
379 return Err(FerrayError::invalid_dtype(
380 "unsupported DynArray variant for .npy saving",
381 ));
382 }
383 }
384
385 writer.flush()?;
386 Ok(())
387}
388
389fn save_complex_raw<T, W: Write>(
392 slice_opt: Option<&[T]>,
393 elem_size: usize,
394 writer: &mut W,
395) -> FerrayResult<()> {
396 let slice = slice_opt
397 .ok_or_else(|| FerrayError::io_error("cannot save non-contiguous complex array"))?;
398 let byte_len = slice.len() * elem_size;
399 let bytes = unsafe { std::slice::from_raw_parts(slice.as_ptr().cast::<u8>(), byte_len) };
400 writer.write_all(bytes)?;
401 Ok(())
402}
403
404fn build_dimension<D: Dimension>(shape: &[usize]) -> FerrayResult<D> {
406 build_dim_from_shape::<D>(shape)
407}
408
409fn build_dim_from_shape<D: Dimension>(shape: &[usize]) -> FerrayResult<D> {
412 use ferray_core::dimension::{Ix0, Ix1, Ix2, Ix3, Ix4, Ix5, Ix6, IxDyn};
413 use std::any::Any;
414
415 if let Some(ndim) = D::NDIM {
416 if shape.len() != ndim {
417 return Err(FerrayError::shape_mismatch(format!(
418 "expected {ndim} dimensions, got {}",
419 shape.len()
420 )));
421 }
422 }
423
424 let type_id = std::any::TypeId::of::<D>();
425
426 macro_rules! try_dim {
427 ($dim_ty:ty, $dim_val:expr) => {
428 if type_id == std::any::TypeId::of::<$dim_ty>() {
429 let boxed: Box<dyn Any> = Box::new($dim_val);
430 return Ok(*boxed.downcast::<D>().unwrap());
431 }
432 };
433 }
434
435 try_dim!(IxDyn, IxDyn::new(shape));
436
437 match shape.len() {
438 0 => {
439 try_dim!(Ix0, Ix0);
440 }
441 1 => {
442 try_dim!(Ix1, Ix1::new([shape[0]]));
443 }
444 2 => {
445 try_dim!(Ix2, Ix2::new([shape[0], shape[1]]));
446 }
447 3 => {
448 try_dim!(Ix3, Ix3::new([shape[0], shape[1], shape[2]]));
449 }
450 4 => {
451 try_dim!(Ix4, Ix4::new([shape[0], shape[1], shape[2], shape[3]]));
452 }
453 5 => {
454 try_dim!(
455 Ix5,
456 Ix5::new([shape[0], shape[1], shape[2], shape[3], shape[4]])
457 );
458 }
459 6 => {
460 try_dim!(
461 Ix6,
462 Ix6::new([shape[0], shape[1], shape[2], shape[3], shape[4], shape[5]])
463 );
464 }
465 _ => {}
466 }
467
468 Err(FerrayError::io_error(
469 "unsupported dimension type for .npy loading",
470 ))
471}
472
/// Element types that can be written to and read from the data section of a
/// `.npy` file.
///
/// The trait is sealed through `private::NpySealed`, so it is implemented only
/// inside this module. The implementations here cover `bool`, the fixed-width
/// integers and floats, optionally `half::f16` / `half::bf16`, and the
/// `DateTime64` / `Timedelta64` time types.
pub trait NpyElement: Element + private::NpySealed {
    /// Writes every element of `data`, in slice order, to `writer` as raw bytes.
    ///
    /// The implementations in this module emit native byte order and do not
    /// write any header.
    fn write_slice<W: Write>(data: &[Self], writer: &mut W) -> FerrayResult<()>;

    /// Reads `count` elements from `reader`.
    ///
    /// Multi-byte implementations byte-swap each element when
    /// `endian.needs_swap()` is true. A stream that ends early is an error.
    fn read_vec<R: Read>(
        reader: &mut R,
        count: usize,
        endian: Endianness,
    ) -> FerrayResult<Vec<Self>>;
}
492
// Sealing module: `NpySealed` is `pub` but lives in a private module, so code
// outside this module tree cannot name it. Because `NpyElement` requires it as
// a supertrait, only the impls in this file can exist.
mod private {
    /// Marker supertrait of `NpyElement`; intentionally has no items.
    pub trait NpySealed {}
}
496
/// Implements the sealed marker and [`NpyElement`] for a fixed-width numeric
/// type `$ty` that is `$size` bytes wide and provides `from_ne_bytes`.
///
/// Writes are a single native-endian byte copy of the slice. Reads validate the
/// byte count, read only as much data as the stream really contains, and
/// byte-swap each element when the file's endianness differs from the host's.
macro_rules! impl_npy_element {
    ($ty:ty, $size:expr) => {
        // Compile-time guard: the byte views and fixed-size arrays below trust
        // `$size`, so it must match the type's real size.
        const _: () = assert!(::std::mem::size_of::<$ty>() == $size);

        impl private::NpySealed for $ty {}

        impl NpyElement for $ty {
            fn write_slice<W: Write>(data: &[$ty], writer: &mut W) -> FerrayResult<()> {
                // SAFETY: pointer and length come from a live `&[$ty]`, so
                // `size_of_val(data)` bytes are valid for reads for the borrow's
                // lifetime. `u8` has alignment 1, and the primitive numeric types
                // this macro is used with have no padding bytes.
                let bytes = unsafe {
                    std::slice::from_raw_parts(
                        data.as_ptr().cast::<u8>(),
                        std::mem::size_of_val(data),
                    )
                };
                writer.write_all(bytes)?;
                Ok(())
            }

            fn read_vec<R: Read>(
                reader: &mut R,
                count: usize,
                endian: Endianness,
            ) -> FerrayResult<Vec<$ty>> {
                // `count` comes from the file header; an unchecked multiply
                // would panic in debug builds and wrap in release builds.
                let byte_len = count.checked_mul($size).ok_or_else(|| {
                    FerrayError::io_error(format!(
                        "shape overflow: {count} elements of {} bytes exceed usize::MAX",
                        $size
                    ))
                })?;

                // `take` + `read_to_end` grows the buffer with the bytes actually
                // present, rather than pre-allocating from an untrusted header.
                let mut raw = Vec::new();
                reader
                    .by_ref()
                    .take(u64::try_from(byte_len).unwrap_or(u64::MAX))
                    .read_to_end(&mut raw)?;
                if raw.len() != byte_len {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::UnexpectedEof,
                        format!(
                            "expected {byte_len} bytes of array data, found {}",
                            raw.len()
                        ),
                    )
                    .into());
                }

                let swap = endian.needs_swap();
                Ok(raw
                    .chunks_exact($size)
                    .map(|chunk| {
                        let mut bytes: [u8; $size] = chunk
                            .try_into()
                            .expect("chunks_exact yields chunks of the element size");
                        if swap {
                            bytes.reverse();
                        }
                        <$ty>::from_ne_bytes(bytes)
                    })
                    .collect())
            }
        }
    };
}
552
553impl private::NpySealed for bool {}
555
556impl NpyElement for bool {
557 fn write_slice<W: Write>(data: &[Self], writer: &mut W) -> FerrayResult<()> {
558 for &val in data {
559 writer.write_all(&[u8::from(val)])?;
560 }
561 Ok(())
562 }
563
564 fn read_vec<R: Read>(
565 reader: &mut R,
566 count: usize,
567 _endian: Endianness,
568 ) -> FerrayResult<Vec<Self>> {
569 let mut result = Vec::with_capacity(count);
570 let mut buf = [0u8; 1];
571 for _ in 0..count {
572 reader.read_exact(&mut buf)?;
573 result.push(buf[0] != 0);
574 }
575 Ok(result)
576 }
577}
578
579impl_npy_element!(u8, 1);
580impl_npy_element!(u16, 2);
581impl_npy_element!(u32, 4);
582impl_npy_element!(u64, 8);
583impl_npy_element!(u128, 16);
584impl_npy_element!(i8, 1);
585impl_npy_element!(i16, 2);
586impl_npy_element!(i32, 4);
587impl_npy_element!(i64, 8);
588impl_npy_element!(i128, 16);
589impl_npy_element!(f32, 4);
590impl_npy_element!(f64, 8);
591
592#[cfg(feature = "f16")]
593impl_npy_element!(half::f16, 2);
594#[cfg(feature = "bf16")]
595impl_npy_element!(half::bf16, 2);
596
597impl private::NpySealed for ferray_core::dtype::DateTime64 {}
603impl private::NpySealed for ferray_core::dtype::Timedelta64 {}
604
605macro_rules! impl_npy_time_element {
606 ($ty:path) => {
607 impl NpyElement for $ty {
608 fn write_slice<W: Write>(data: &[Self], writer: &mut W) -> FerrayResult<()> {
609 for v in data {
610 writer.write_all(&v.0.to_ne_bytes())?;
611 }
612 Ok(())
613 }
614
615 fn read_vec<R: Read>(
616 reader: &mut R,
617 count: usize,
618 endian: Endianness,
619 ) -> FerrayResult<Vec<Self>> {
620 let mut out = Vec::with_capacity(count);
621 let mut buf = [0u8; 8];
622 for _ in 0..count {
623 reader.read_exact(&mut buf)?;
624 if endian.needs_swap() {
625 buf.reverse();
626 }
627 out.push(Self(i64::from_ne_bytes(buf)));
628 }
629 Ok(out)
630 }
631 }
632 };
633}
634
635impl_npy_time_element!(ferray_core::dtype::DateTime64);
636impl_npy_time_element!(ferray_core::dtype::Timedelta64);
637
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::{Ix1, Ix2};
    use std::io::Cursor;

    /// Per-process scratch directory for tests that go through the filesystem.
    fn test_dir() -> std::path::PathBuf {
        let dir = std::env::temp_dir().join(format!("ferray_io_test_{}", std::process::id()));
        let _ = std::fs::create_dir_all(&dir);
        dir
    }

    /// Path of a scratch file called `name` inside the per-process test directory.
    fn test_file(name: &str) -> std::path::PathBuf {
        test_dir().join(name)
    }

    /// Builds a version-1.0 `.npy` byte stream from a header dict literal and a
    /// raw payload. The header is padded with spaces plus a trailing newline so
    /// the data section starts on a 64-byte boundary.
    fn npy_v1(header: &str, payload: &[u8]) -> Vec<u8> {
        // magic (6) + version (2) + little-endian u16 header length (2)
        let unpadded = 6 + 2 + 2 + header.len();
        let padding = 64 - (unpadded % 64);
        let padded_len =
            u16::try_from(header.len() + padding).expect("test header length fits in u16");

        let mut buf = Vec::with_capacity(unpadded + padding + payload.len());
        buf.extend_from_slice(b"\x93NUMPY");
        buf.push(1);
        buf.push(0);
        buf.extend_from_slice(&padded_len.to_le_bytes());
        buf.extend_from_slice(header.as_bytes());
        buf.extend(std::iter::repeat_n(b' ', padding - 1));
        buf.push(b'\n');
        buf.extend_from_slice(payload);
        buf
    }

    #[test]
    fn roundtrip_f64_1d() {
        let data = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0];
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([5]), data.clone()).unwrap();

        let path = test_file("rt_f64_1d.npy");
        save(&path, &arr).unwrap();
        let loaded: Array<f64, Ix1> = load(&path).unwrap();

        assert_eq!(loaded.shape(), &[5]);
        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn roundtrip_f32_2d() {
        let data = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let arr = Array::<f32, Ix2>::from_vec(Ix2::new([2, 3]), data.clone()).unwrap();

        let path = test_file("rt_f32_2d.npy");
        save(&path, &arr).unwrap();
        let loaded: Array<f32, Ix2> = load(&path).unwrap();

        assert_eq!(loaded.shape(), &[2, 3]);
        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn roundtrip_i32() {
        let data = vec![10i32, 20, 30, 40];
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([4]), data.clone()).unwrap();

        let path = test_file("rt_i32.npy");
        save(&path, &arr).unwrap();
        let loaded: Array<i32, Ix1> = load(&path).unwrap();
        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn roundtrip_i64() {
        let data = vec![100i64, 200, 300];
        let arr = Array::<i64, Ix1>::from_vec(Ix1::new([3]), data.clone()).unwrap();

        let path = test_file("rt_i64.npy");
        save(&path, &arr).unwrap();
        let loaded: Array<i64, Ix1> = load(&path).unwrap();
        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn roundtrip_u8() {
        let data = vec![0u8, 128, 255];
        let arr = Array::<u8, Ix1>::from_vec(Ix1::new([3]), data.clone()).unwrap();

        let path = test_file("rt_u8.npy");
        save(&path, &arr).unwrap();
        let loaded: Array<u8, Ix1> = load(&path).unwrap();
        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn roundtrip_bool() {
        let data = vec![true, false, true, true, false];
        let arr = Array::<bool, Ix1>::from_vec(Ix1::new([5]), data.clone()).unwrap();

        let path = test_file("rt_bool.npy");
        save(&path, &arr).unwrap();
        let loaded: Array<bool, Ix1> = load(&path).unwrap();
        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn roundtrip_in_memory() {
        let data = vec![1.0_f64, 2.0, 3.0];
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), data.clone()).unwrap();

        let mut buf = Vec::new();
        save_to_writer(&mut buf, &arr).unwrap();

        let mut cursor = Cursor::new(buf);
        let loaded: Array<f64, Ix1> = load_from_reader(&mut cursor).unwrap();
        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
    }

    #[test]
    fn load_dynamic_f64() {
        let data = vec![1.0_f64, 2.0, 3.0];
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), data).unwrap();

        let path = test_file("dyn_f64.npy");
        save(&path, &arr).unwrap();
        let dyn_arr = load_dynamic(&path).unwrap();

        assert_eq!(dyn_arr.dtype(), DType::F64);
        assert_eq!(dyn_arr.shape(), &[3]);
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn load_wrong_dtype_error() {
        let data = vec![1.0_f64, 2.0, 3.0];
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), data).unwrap();

        let path = test_file("wrong_dtype.npy");
        save(&path, &arr).unwrap();

        let result = load::<f32, Ix1, _>(&path);
        assert!(result.is_err());
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn load_wrong_ndim_error() {
        let data = vec![1.0_f64, 2.0, 3.0];
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), data).unwrap();

        let path = test_file("wrong_ndim.npy");
        save(&path, &arr).unwrap();

        let result = load::<f64, Ix2, _>(&path);
        assert!(result.is_err());
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn roundtrip_dynamic() {
        let data = vec![10i32, 20, 30];
        let arr = Array::<i32, IxDyn>::from_vec(IxDyn::new(&[3]), data.clone()).unwrap();
        let dyn_arr = DynArray::I32(arr);

        let path = test_file("rt_dynamic.npy");
        save_dynamic(&path, &dyn_arr).unwrap();

        let loaded = load_dynamic(&path).unwrap();
        assert_eq!(loaded.dtype(), DType::I32);
        assert_eq!(loaded.shape(), &[3]);

        let loaded_arr = loaded.try_into_i32().unwrap();
        assert_eq!(loaded_arr.as_slice().unwrap(), &data[..]);
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn roundtrip_complex64_dynamic() {
        use num_complex::Complex;
        let data = vec![Complex::new(1.0_f64, -2.0), Complex::new(0.5, 3.25)];
        let arr =
            Array::<Complex<f64>, IxDyn>::from_vec(IxDyn::new(&[2]), data.clone()).unwrap();

        let mut buf = Vec::new();
        save_dynamic_to_writer(&mut buf, &DynArray::Complex64(arr)).unwrap();

        let loaded = load_dynamic_from_reader(&mut Cursor::new(buf)).unwrap();
        assert_eq!(loaded.shape(), &[2]);
        match loaded {
            DynArray::Complex64(a) => assert_eq!(a.as_slice().unwrap(), &data[..]),
            _ => panic!("expected Complex64 variant"),
        }
    }

    #[test]
    fn load_dynamic_ixdyn() {
        let data = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), data.clone()).unwrap();

        let path = test_file("dyn_ixdyn.npy");
        save(&path, &arr).unwrap();

        let loaded: Array<f64, IxDyn> = load(&path).unwrap();
        assert_eq!(loaded.shape(), &[2, 3]);
        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn load_fortran_order_npy() {
        // Column-major storage of [[1, 2, 3], [4, 5, 6]].
        let payload: Vec<u8> = [1.0_f64, 4.0, 2.0, 5.0, 3.0, 6.0]
            .iter()
            .flat_map(|v| v.to_le_bytes())
            .collect();
        let buf = npy_v1(
            "{'descr': '<f8', 'fortran_order': True, 'shape': (2, 3), }",
            &payload,
        );

        let mut cursor = Cursor::new(buf);
        let loaded: Array<f64, Ix2> = load_from_reader(&mut cursor).unwrap();
        assert_eq!(loaded.shape(), &[2, 3]);
        // Logical (row-major) iteration order must match the original matrix.
        let flat: Vec<f64> = loaded.iter().copied().collect();
        assert_eq!(flat, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
    }

    #[test]
    fn roundtrip_from_vec_f() {
        // Column-major input; saving writes logical (C) order.
        let data = vec![1.0_f64, 4.0, 2.0, 5.0, 3.0, 6.0];
        let arr = Array::<f64, Ix2>::from_vec_f(Ix2::new([2, 3]), data).unwrap();
        assert_eq!(arr.shape(), &[2, 3]);

        let mut buf = Vec::new();
        save_to_writer(&mut buf, &arr).unwrap();

        let mut cursor = Cursor::new(buf);
        let loaded: Array<f64, Ix2> = load_from_reader(&mut cursor).unwrap();
        assert_eq!(loaded.shape(), &[2, 3]);
        let orig: Vec<f64> = arr.iter().copied().collect();
        let back: Vec<f64> = loaded.iter().copied().collect();
        assert_eq!(orig, back);
    }

    #[test]
    fn malformed_bad_magic() {
        let data = b"NOT_NPY_FILE_DATA_HERE";
        let mut cursor = Cursor::new(data.to_vec());
        let result = load_from_reader::<f64, Ix1, _>(&mut cursor);
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(
            msg.contains("magic") || msg.contains("not a valid"),
            "got: {msg}"
        );
    }

    #[test]
    fn malformed_truncated_header() {
        // Magic and version only; the header length and dict are missing.
        let mut data = Vec::new();
        data.extend_from_slice(b"\x93NUMPY");
        data.push(1);
        data.push(0);
        let mut cursor = Cursor::new(data);
        let result = load_from_reader::<f64, Ix1, _>(&mut cursor);
        assert!(result.is_err());
    }

    #[test]
    fn malformed_truncated_data() {
        // Header promises 100 elements but only 3 are present.
        let payload: Vec<u8> = [1.0_f64, 2.0, 3.0]
            .iter()
            .flat_map(|v| v.to_le_bytes())
            .collect();
        let buf = npy_v1(
            "{'descr': '<f8', 'fortran_order': False, 'shape': (100,), }",
            &payload,
        );

        let mut cursor = Cursor::new(buf);
        let result = load_from_reader::<f64, Ix1, _>(&mut cursor);
        assert!(result.is_err(), "should fail with truncated data");
    }

    #[test]
    fn load_dynamic_truncated_data() {
        // Same truncation as above, through the runtime-typed loader.
        let payload: Vec<u8> = [1.0_f64, 2.0, 3.0]
            .iter()
            .flat_map(|v| v.to_le_bytes())
            .collect();
        let buf = npy_v1(
            "{'descr': '<f8', 'fortran_order': False, 'shape': (100,), }",
            &payload,
        );

        let result = load_dynamic_from_reader(&mut Cursor::new(buf));
        assert!(result.is_err(), "should fail with truncated data");
    }

    #[test]
    fn malformed_unsupported_version() {
        let mut data = Vec::new();
        data.extend_from_slice(b"\x93NUMPY");
        data.push(9);
        data.push(0);
        data.extend_from_slice(&[10, 0]);
        data.extend_from_slice(b"0123456789");
        let mut cursor = Cursor::new(data);
        let result = load_from_reader::<f64, Ix1, _>(&mut cursor);
        assert!(result.is_err());
        let msg = result.unwrap_err().to_string();
        assert!(msg.contains("version"), "got: {msg}");
    }

    #[test]
    fn malformed_empty_file() {
        let mut cursor = Cursor::new(Vec::<u8>::new());
        let result = load_from_reader::<f64, Ix1, _>(&mut cursor);
        assert!(result.is_err());
    }

    #[test]
    fn load_big_endian_f64() {
        let payload: Vec<u8> = [1.0_f64, 2.5, -4.75]
            .iter()
            .flat_map(|v| v.to_be_bytes())
            .collect();
        let buf = npy_v1(
            "{'descr': '>f8', 'fortran_order': False, 'shape': (3,), }",
            &payload,
        );

        let mut cursor = Cursor::new(buf);
        let loaded: Array<f64, Ix1> = load_from_reader(&mut cursor).unwrap();
        assert_eq!(loaded.shape(), &[3]);
        let data = loaded.as_slice().unwrap();
        assert!((data[0] - 1.0).abs() < 1e-15);
        assert!((data[1] - 2.5).abs() < 1e-15);
        assert!((data[2] - (-4.75)).abs() < 1e-15);
    }

    #[test]
    fn load_big_endian_i32() {
        let payload: Vec<u8> = [1_i32, -2, 1000, i32::MAX]
            .iter()
            .flat_map(|v| v.to_be_bytes())
            .collect();
        let buf = npy_v1(
            "{'descr': '>i4', 'fortran_order': False, 'shape': (4,), }",
            &payload,
        );

        let mut cursor = Cursor::new(buf);
        let loaded: Array<i32, Ix1> = load_from_reader(&mut cursor).unwrap();
        assert_eq!(loaded.shape(), &[4]);
        let data = loaded.as_slice().unwrap();
        assert_eq!(data, &[1, -2, 1000, i32::MAX]);
    }

    #[cfg(feature = "f16")]
    #[test]
    fn roundtrip_f16_1d() {
        use half::f16;
        let data: Vec<f16> = [0.0, 1.0, -1.5, 2.25, 3.5, -0.125]
            .iter()
            .map(|&v: &f32| f16::from_f32(v))
            .collect();
        let arr = Array::<f16, Ix1>::from_vec(Ix1::new([6]), data.clone()).unwrap();

        let path = test_file("rt_f16_1d.npy");
        save(&path, &arr).unwrap();
        let loaded: Array<f16, Ix1> = load(&path).unwrap();
        assert_eq!(loaded.shape(), &[6]);
        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
        let _ = std::fs::remove_file(&path);
    }

    #[cfg(feature = "f16")]
    #[test]
    fn roundtrip_f16_2d() {
        use half::f16;
        let data: Vec<f16> = (0..12)
            .map(|i| f16::from_f32((i as f32).mul_add(0.25, -1.0)))
            .collect();
        let arr = Array::<f16, Ix2>::from_vec(Ix2::new([3, 4]), data.clone()).unwrap();

        let path = test_file("rt_f16_2d.npy");
        save(&path, &arr).unwrap();
        let loaded: Array<f16, Ix2> = load(&path).unwrap();
        assert_eq!(loaded.shape(), &[3, 4]);
        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
        let _ = std::fs::remove_file(&path);
    }

    #[cfg(feature = "f16")]
    #[test]
    fn roundtrip_f16_dynamic() {
        use half::f16;
        let data: Vec<f16> = (0..8).map(|i| f16::from_f32(i as f32)).collect();
        let arr = Array::<f16, IxDyn>::from_vec(IxDyn::new(&[2, 4]), data.clone()).unwrap();
        let dyn_in = DynArray::F16(arr);

        let path = test_file("rt_f16_dyn.npy");
        save_dynamic(&path, &dyn_in).unwrap();
        let loaded = load_dynamic(&path).unwrap();
        assert_eq!(loaded.dtype(), DType::F16);
        assert_eq!(loaded.shape(), &[2, 4]);
        match loaded {
            DynArray::F16(a) => assert_eq!(a.as_slice().unwrap(), &data[..]),
            _ => panic!("expected F16 variant"),
        }
        let _ = std::fs::remove_file(&path);
    }

    #[cfg(feature = "f16")]
    #[test]
    fn f16_descriptor_is_f2() {
        use half::f16;
        let arr = Array::<f16, Ix1>::from_vec(Ix1::new([2]), vec![f16::ZERO, f16::ONE]).unwrap();
        let mut buf = Vec::new();
        save_to_writer(&mut buf, &arr).unwrap();
        // Strip the payload (2 elements x 2 bytes) to leave only the header.
        let header_len = buf.len().saturating_sub(4);
        let header = String::from_utf8_lossy(&buf[..header_len]);
        assert!(
            header.contains("f2"),
            "expected 'f2' in header, got: {header}"
        );
    }

    #[cfg(feature = "bf16")]
    #[test]
    fn roundtrip_bf16_1d() {
        use half::bf16;
        let data: Vec<bf16> = [0.0, 1.0, -1.5, 2.25, 3.5, -0.125]
            .iter()
            .map(|&v: &f32| bf16::from_f32(v))
            .collect();
        let arr = Array::<bf16, Ix1>::from_vec(Ix1::new([6]), data.clone()).unwrap();

        let path = test_file("rt_bf16_1d.npy");
        save(&path, &arr).unwrap();
        let loaded: Array<bf16, Ix1> = load(&path).unwrap();
        assert_eq!(loaded.shape(), &[6]);
        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
        let _ = std::fs::remove_file(&path);
    }

    #[cfg(feature = "bf16")]
    #[test]
    fn roundtrip_bf16_dynamic() {
        use half::bf16;
        let data: Vec<bf16> = (0..6).map(|i| bf16::from_f32(i as f32 * 0.5)).collect();
        let arr = Array::<bf16, IxDyn>::from_vec(IxDyn::new(&[2, 3]), data.clone()).unwrap();
        let dyn_in = DynArray::BF16(arr);

        let path = test_file("rt_bf16_dyn.npy");
        save_dynamic(&path, &dyn_in).unwrap();
        let loaded = load_dynamic(&path).unwrap();
        assert_eq!(loaded.dtype(), DType::BF16);
        assert_eq!(loaded.shape(), &[2, 3]);
        match loaded {
            DynArray::BF16(a) => assert_eq!(a.as_slice().unwrap(), &data[..]),
            _ => panic!("expected BF16 variant"),
        }
        let _ = std::fs::remove_file(&path);
    }

    #[cfg(feature = "bf16")]
    #[test]
    fn bf16_descriptor_is_bf16_tag() {
        use half::bf16;
        let arr = Array::<bf16, Ix1>::from_vec(Ix1::new([2]), vec![bf16::ZERO, bf16::ONE]).unwrap();
        let mut buf = Vec::new();
        save_to_writer(&mut buf, &arr).unwrap();
        // Strip the payload (2 elements x 2 bytes) to leave only the header.
        let header_len = buf.len().saturating_sub(4);
        let header = String::from_utf8_lossy(&buf[..header_len]);
        assert!(
            header.contains("bf16"),
            "expected 'bf16' in header, got: {header}"
        );
    }

    #[test]
    fn datetime64_npy_roundtrip() {
        use ferray_core::dtype::DateTime64;
        let original = vec![
            DateTime64(0),
            DateTime64::nat(),
            DateTime64(1_700_000_000_000_000_000),
            DateTime64(-1),
        ];
        let arr = Array::<DateTime64, Ix1>::from_vec(Ix1::new([4]), original.clone()).unwrap();
        let mut buf = Vec::new();
        save_to_writer(&mut buf, &arr).unwrap();
        let header = String::from_utf8_lossy(&buf);
        assert!(
            header.contains("M8[ns]"),
            "expected 'M8[ns]' in header, got: {header}"
        );

        let mut cursor = std::io::Cursor::new(buf);
        let back: Array<DateTime64, Ix1> = load_from_reader(&mut cursor).unwrap();
        let v: Vec<DateTime64> = back.iter().copied().collect();
        // `zip` stops at the shorter side, so check the length explicitly.
        assert_eq!(v.len(), original.len(), "element count changed in round-trip");
        for (a, b) in v.iter().zip(original.iter()) {
            // Compare the raw i64 payloads so NaT is checked bit-for-bit.
            assert_eq!(a.0, b.0, "DateTime64 i64 round-trip mismatch");
        }
    }

    #[test]
    fn timedelta64_npy_roundtrip() {
        use ferray_core::dtype::Timedelta64;
        let original = vec![Timedelta64(1000), Timedelta64::nat(), Timedelta64(0)];
        let arr = Array::<Timedelta64, Ix1>::from_vec(Ix1::new([3]), original.clone()).unwrap();
        let mut buf = Vec::new();
        save_to_writer(&mut buf, &arr).unwrap();
        let header = String::from_utf8_lossy(&buf);
        assert!(
            header.contains("m8[ns]"),
            "expected 'm8[ns]' in header, got: {header}"
        );

        let mut cursor = std::io::Cursor::new(buf);
        let back: Array<Timedelta64, Ix1> = load_from_reader(&mut cursor).unwrap();
        let v: Vec<Timedelta64> = back.iter().copied().collect();
        assert_eq!(v.len(), original.len(), "element count changed in round-trip");
        for (a, b) in v.iter().zip(original.iter()) {
            assert_eq!(a.0, b.0);
        }
    }

    #[test]
    fn datetime64_dynarray_load_preserves_unit() {
        use ferray_core::dtype::{DateTime64, TimeUnit};
        let arr = Array::<DateTime64, Ix1>::from_vec(
            Ix1::new([3]),
            vec![DateTime64(0), DateTime64(1), DateTime64::nat()],
        )
        .unwrap();
        let mut buf = Vec::new();
        save_to_writer(&mut buf, &arr).unwrap();

        let mut cursor = std::io::Cursor::new(buf);
        let dyn_arr = load_dynamic_from_reader(&mut cursor).unwrap();
        assert_eq!(
            dyn_arr.dtype(),
            ferray_core::DType::DateTime64(TimeUnit::Ns)
        );
        let (typed, unit) = dyn_arr.try_into_datetime64().unwrap();
        assert_eq!(unit, TimeUnit::Ns);
        let vals: Vec<i64> = typed.iter().map(|v| v.0).collect();
        assert_eq!(vals[0], 0);
        assert_eq!(vals[1], 1);
        assert_eq!(vals[2], i64::MIN);
    }
}