1pub mod dtype_parse;
10pub mod header;
11
12use std::fs::File;
13use std::io::{BufReader, BufWriter, Read, Write};
14use std::path::Path;
15
16use ferray_core::Array;
17use ferray_core::dimension::{Dimension, IxDyn};
18use ferray_core::dtype::{DType, Element};
19use ferray_core::dynarray::DynArray;
20use ferray_core::error::{FerrayError, FerrayResult};
21
22use self::dtype_parse::Endianness;
23
24pub fn save<T: Element + NpyElement, D: Dimension, P: AsRef<Path>>(
31 path: P,
32 array: &Array<T, D>,
33) -> FerrayResult<()> {
34 let file = File::create(path.as_ref()).map_err(|e| {
35 FerrayError::io_error(format!(
36 "failed to create file '{}': {e}",
37 path.as_ref().display()
38 ))
39 })?;
40 let mut writer = BufWriter::new(file);
41 save_to_writer(&mut writer, array)
42}
43
44pub fn save_to_writer<T: Element + NpyElement, D: Dimension, W: Write>(
46 writer: &mut W,
47 array: &Array<T, D>,
48) -> FerrayResult<()> {
49 let fortran_order = false;
50 header::write_header(writer, T::dtype(), array.shape(), fortran_order)?;
51
52 if let Some(slice) = array.as_slice() {
54 T::write_slice(slice, writer)?;
55 } else {
56 return Err(FerrayError::io_error(
57 "cannot save non-contiguous array to .npy (make contiguous first)",
58 ));
59 }
60
61 writer.flush()?;
62 Ok(())
63}
64
65pub fn load<T: Element + NpyElement, D: Dimension, P: AsRef<Path>>(
72 path: P,
73) -> FerrayResult<Array<T, D>> {
74 let file = File::open(path.as_ref()).map_err(|e| {
75 FerrayError::io_error(format!(
76 "failed to open file '{}': {e}",
77 path.as_ref().display()
78 ))
79 })?;
80 let mut reader = BufReader::new(file);
81 load_from_reader(&mut reader)
82}
83
84pub fn load_from_reader<T: Element + NpyElement, D: Dimension, R: Read>(
86 reader: &mut R,
87) -> FerrayResult<Array<T, D>> {
88 let hdr = header::read_header(reader)?;
89
90 if hdr.dtype != T::dtype() {
92 return Err(FerrayError::invalid_dtype(format!(
93 "expected dtype {:?} for type {}, but file has {:?}",
94 T::dtype(),
95 std::any::type_name::<T>(),
96 hdr.dtype,
97 )));
98 }
99
100 if let Some(ndim) = D::NDIM {
102 if ndim != hdr.shape.len() {
103 return Err(FerrayError::shape_mismatch(format!(
104 "expected {} dimensions, but file has {} (shape {:?})",
105 ndim,
106 hdr.shape.len(),
107 hdr.shape,
108 )));
109 }
110 }
111
112 let total_elements: usize = hdr.shape.iter().product();
113 let data = T::read_vec(reader, total_elements, hdr.endianness)?;
114
115 let dim = build_dimension::<D>(&hdr.shape)?;
116
117 if hdr.fortran_order {
118 Array::from_vec_f(dim, data)
119 } else {
120 Array::from_vec(dim, data)
121 }
122}
123
124pub fn load_dynamic<P: AsRef<Path>>(path: P) -> FerrayResult<DynArray> {
131 let file = File::open(path.as_ref()).map_err(|e| {
132 FerrayError::io_error(format!(
133 "failed to open file '{}': {e}",
134 path.as_ref().display()
135 ))
136 })?;
137 let mut reader = BufReader::new(file);
138 load_dynamic_from_reader(&mut reader)
139}
140
141pub fn load_dynamic_from_reader<R: Read>(reader: &mut R) -> FerrayResult<DynArray> {
143 let hdr = header::read_header(reader)?;
144 let total: usize = hdr.shape.iter().product();
145 let dim = IxDyn::new(&hdr.shape);
146
147 macro_rules! load_typed {
148 ($ty:ty, $variant:ident) => {{
149 let data = <$ty as NpyElement>::read_vec(reader, total, hdr.endianness)?;
150 let arr = if hdr.fortran_order {
151 Array::<$ty, IxDyn>::from_vec_f(dim, data)?
152 } else {
153 Array::<$ty, IxDyn>::from_vec(dim, data)?
154 };
155 Ok(DynArray::$variant(arr))
156 }};
157 }
158
159 match hdr.dtype {
160 DType::Bool => load_typed!(bool, Bool),
161 DType::U8 => load_typed!(u8, U8),
162 DType::U16 => load_typed!(u16, U16),
163 DType::U32 => load_typed!(u32, U32),
164 DType::U64 => load_typed!(u64, U64),
165 DType::U128 => load_typed!(u128, U128),
166 DType::I8 => load_typed!(i8, I8),
167 DType::I16 => load_typed!(i16, I16),
168 DType::I32 => load_typed!(i32, I32),
169 DType::I64 => load_typed!(i64, I64),
170 DType::I128 => load_typed!(i128, I128),
171 DType::F32 => load_typed!(f32, F32),
172 DType::F64 => load_typed!(f64, F64),
173 DType::Complex32 => {
174 load_complex32_dynamic(reader, total, dim, hdr.fortran_order, hdr.endianness)
175 }
176 DType::Complex64 => {
177 load_complex64_dynamic(reader, total, dim, hdr.fortran_order, hdr.endianness)
178 }
179 _ => Err(FerrayError::invalid_dtype(format!(
180 "unsupported dtype {:?} for .npy loading",
181 hdr.dtype
182 ))),
183 }
184}
185
186fn load_complex32_dynamic<R: Read>(
188 reader: &mut R,
189 total: usize,
190 dim: IxDyn,
191 fortran_order: bool,
192 endian: Endianness,
193) -> FerrayResult<DynArray> {
194 let byte_count = total * 8;
196 let mut raw = vec![0u8; byte_count];
197 reader.read_exact(&mut raw)?;
198
199 if endian.needs_swap() {
200 for chunk in raw.chunks_exact_mut(4) {
202 chunk.reverse();
203 }
204 }
205
206 assert_eq!(std::mem::size_of::<[f32; 2]>(), 8);
214
215 let mut data: Vec<u8> = raw;
222
223 if data.len() != total * 8 {
225 return Err(FerrayError::io_error(
226 "unexpected data length for complex32",
227 ));
228 }
229
230 let ptr = data.as_mut_ptr();
233 let cap = data.capacity();
234 std::mem::forget(data);
235
236 if (ptr as usize) % std::mem::align_of::<f32>() != 0 {
240 let data_bytes = unsafe { Vec::from_raw_parts(ptr, total * 8, cap) };
242 return load_complex32_from_bytes_copy(&data_bytes, total, dim, fortran_order);
243 }
244
245 let bytes = unsafe { Vec::from_raw_parts(ptr, total * 8, cap) };
252 load_complex32_from_bytes_copy(&bytes, total, dim, fortran_order)
253}
254
255fn load_complex32_from_bytes_copy(
257 bytes: &[u8],
258 total: usize,
259 dim: IxDyn,
260 fortran_order: bool,
261) -> FerrayResult<DynArray> {
262 let mut arr_dyn = DynArray::zeros(DType::Complex32, dim.as_slice())?;
264 if let DynArray::Complex32(ref mut arr) = arr_dyn {
265 if let Some(slice) = arr.as_slice_mut() {
266 let dst =
268 unsafe { std::slice::from_raw_parts_mut(slice.as_mut_ptr() as *mut u8, total * 8) };
269 dst.copy_from_slice(bytes);
270 }
271
272 if fortran_order {
275 }
282 }
283 Ok(arr_dyn)
284}
285
286fn load_complex64_dynamic<R: Read>(
288 reader: &mut R,
289 total: usize,
290 dim: IxDyn,
291 fortran_order: bool,
292 endian: Endianness,
293) -> FerrayResult<DynArray> {
294 let byte_count = total * 16;
295 let mut raw = vec![0u8; byte_count];
296 reader.read_exact(&mut raw)?;
297
298 if endian.needs_swap() {
299 for chunk in raw.chunks_exact_mut(8) {
300 chunk.reverse();
301 }
302 }
303
304 load_complex64_from_bytes_copy(&raw, total, dim, fortran_order)
305}
306
307fn load_complex64_from_bytes_copy(
308 bytes: &[u8],
309 total: usize,
310 dim: IxDyn,
311 _fortran_order: bool,
312) -> FerrayResult<DynArray> {
313 let mut arr_dyn = DynArray::zeros(DType::Complex64, dim.as_slice())?;
314 if let DynArray::Complex64(ref mut arr) = arr_dyn {
315 if let Some(slice) = arr.as_slice_mut() {
316 let dst = unsafe {
317 std::slice::from_raw_parts_mut(slice.as_mut_ptr() as *mut u8, total * 16)
318 };
319 dst.copy_from_slice(bytes);
320 }
321 }
322 Ok(arr_dyn)
323}
324
325pub fn save_dynamic<P: AsRef<Path>>(path: P, array: &DynArray) -> FerrayResult<()> {
327 let file = File::create(path.as_ref()).map_err(|e| {
328 FerrayError::io_error(format!(
329 "failed to create file '{}': {e}",
330 path.as_ref().display()
331 ))
332 })?;
333 let mut writer = BufWriter::new(file);
334 save_dynamic_to_writer(&mut writer, array)
335}
336
337pub fn save_dynamic_to_writer<W: Write>(writer: &mut W, array: &DynArray) -> FerrayResult<()> {
339 macro_rules! save_typed {
340 ($arr:expr, $dtype:expr, $ty:ty) => {{
341 header::write_header(writer, $dtype, $arr.shape(), false)?;
342 if let Some(s) = $arr.as_slice() {
343 <$ty as NpyElement>::write_slice(s, writer)?;
344 } else {
345 return Err(FerrayError::io_error(
346 "cannot save non-contiguous DynArray to .npy",
347 ));
348 }
349 }};
350 }
351
352 match array {
353 DynArray::Bool(a) => save_typed!(a, DType::Bool, bool),
354 DynArray::U8(a) => save_typed!(a, DType::U8, u8),
355 DynArray::U16(a) => save_typed!(a, DType::U16, u16),
356 DynArray::U32(a) => save_typed!(a, DType::U32, u32),
357 DynArray::U64(a) => save_typed!(a, DType::U64, u64),
358 DynArray::U128(a) => save_typed!(a, DType::U128, u128),
359 DynArray::I8(a) => save_typed!(a, DType::I8, i8),
360 DynArray::I16(a) => save_typed!(a, DType::I16, i16),
361 DynArray::I32(a) => save_typed!(a, DType::I32, i32),
362 DynArray::I64(a) => save_typed!(a, DType::I64, i64),
363 DynArray::I128(a) => save_typed!(a, DType::I128, i128),
364 DynArray::F32(a) => save_typed!(a, DType::F32, f32),
365 DynArray::F64(a) => save_typed!(a, DType::F64, f64),
366 DynArray::Complex32(a) => {
367 header::write_header(writer, DType::Complex32, a.shape(), false)?;
368 save_complex_raw(a.as_slice(), 8, writer)?;
369 }
370 DynArray::Complex64(a) => {
371 header::write_header(writer, DType::Complex64, a.shape(), false)?;
372 save_complex_raw(a.as_slice(), 16, writer)?;
373 }
374 _ => {
375 return Err(FerrayError::invalid_dtype(
376 "unsupported DynArray variant for .npy saving",
377 ));
378 }
379 }
380
381 writer.flush()?;
382 Ok(())
383}
384
385fn save_complex_raw<T, W: Write>(
388 slice_opt: Option<&[T]>,
389 elem_size: usize,
390 writer: &mut W,
391) -> FerrayResult<()> {
392 let slice = slice_opt
393 .ok_or_else(|| FerrayError::io_error("cannot save non-contiguous complex array"))?;
394 let byte_len = slice.len() * elem_size;
395 let bytes = unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const u8, byte_len) };
396 writer.write_all(bytes)?;
397 Ok(())
398}
399
400fn build_dimension<D: Dimension>(shape: &[usize]) -> FerrayResult<D> {
402 build_dim_from_shape::<D>(shape)
403}
404
405fn build_dim_from_shape<D: Dimension>(shape: &[usize]) -> FerrayResult<D> {
408 use ferray_core::dimension::*;
409 use std::any::Any;
410
411 if let Some(ndim) = D::NDIM {
412 if shape.len() != ndim {
413 return Err(FerrayError::shape_mismatch(format!(
414 "expected {ndim} dimensions, got {}",
415 shape.len()
416 )));
417 }
418 }
419
420 let type_id = std::any::TypeId::of::<D>();
421
422 macro_rules! try_dim {
423 ($dim_ty:ty, $dim_val:expr) => {
424 if type_id == std::any::TypeId::of::<$dim_ty>() {
425 let boxed: Box<dyn Any> = Box::new($dim_val);
426 return Ok(*boxed.downcast::<D>().unwrap());
427 }
428 };
429 }
430
431 try_dim!(IxDyn, IxDyn::new(shape));
432
433 match shape.len() {
434 0 => {
435 try_dim!(Ix0, Ix0);
436 }
437 1 => {
438 try_dim!(Ix1, Ix1::new([shape[0]]));
439 }
440 2 => {
441 try_dim!(Ix2, Ix2::new([shape[0], shape[1]]));
442 }
443 3 => {
444 try_dim!(Ix3, Ix3::new([shape[0], shape[1], shape[2]]));
445 }
446 4 => {
447 try_dim!(Ix4, Ix4::new([shape[0], shape[1], shape[2], shape[3]]));
448 }
449 5 => {
450 try_dim!(
451 Ix5,
452 Ix5::new([shape[0], shape[1], shape[2], shape[3], shape[4]])
453 );
454 }
455 6 => {
456 try_dim!(
457 Ix6,
458 Ix6::new([shape[0], shape[1], shape[2], shape[3], shape[4], shape[5]])
459 );
460 }
461 _ => {}
462 }
463
464 Err(FerrayError::io_error(
465 "unsupported dimension type for .npy loading",
466 ))
467}
468
/// Element types that can be written to and read from the data section of a
/// `.npy` file.
///
/// The `private::NpySealed` supertrait lives in a private module, so this
/// trait cannot be implemented outside this module.
pub trait NpyElement: Element + private::NpySealed {
    /// Write every element of `data` to `writer`.
    ///
    /// # Errors
    /// Returns an error if writing to `writer` fails.
    fn write_slice<W: Write>(data: &[Self], writer: &mut W) -> FerrayResult<()>;

    /// Read `count` elements from `reader`.
    ///
    /// `endian` describes the byte order of the stored data; implementations
    /// use it to decide whether each element's bytes must be swapped.
    ///
    /// # Errors
    /// Returns an error if `reader` cannot supply the requested data.
    fn read_vec<R: Read>(
        reader: &mut R,
        count: usize,
        endian: Endianness,
    ) -> FerrayResult<Vec<Self>>;
}
488
/// Private module holding the sealing supertrait of `NpyElement`.
mod private {
    /// Marker supertrait of `NpyElement`. It is public so it can appear in
    /// the bound, but unreachable from outside because the enclosing module
    /// is private.
    pub trait NpySealed {}
}
492
/// Implements `NpySealed` and `NpyElement` for a primitive numeric type.
///
/// `$ty` must provide `to_ne_bytes` / `from_ne_bytes`, and `$size` must be
/// its size in bytes (it sizes the per-element read buffer).
macro_rules! impl_npy_element {
    ($ty:ty, $size:expr) => {
        impl private::NpySealed for $ty {}

        impl NpyElement for $ty {
            /// Writes each element as its native-endian bytes.
            fn write_slice<W: Write>(data: &[$ty], writer: &mut W) -> FerrayResult<()> {
                for &val in data {
                    writer.write_all(&val.to_ne_bytes())?;
                }
                Ok(())
            }

            /// Reads `count` elements of `$size` bytes each, reversing the
            /// bytes of every element when `endian` requires a swap.
            fn read_vec<R: Read>(
                reader: &mut R,
                count: usize,
                endian: Endianness,
            ) -> FerrayResult<Vec<$ty>> {
                // `count` originates from the file header. Reserving it in
                // full up front would let a corrupt or truncated file force
                // an enormous allocation (or a capacity-overflow panic)
                // before a single byte is read, so cap the initial
                // reservation and let the vector grow with the data.
                const MAX_PREALLOC_BYTES: usize = 1 << 20;
                let mut result = Vec::with_capacity(count.min(MAX_PREALLOC_BYTES / $size));
                let mut buf = [0u8; $size];
                let needs_swap = endian.needs_swap();
                for _ in 0..count {
                    reader.read_exact(&mut buf)?;
                    if needs_swap {
                        buf.reverse();
                    }
                    result.push(<$ty>::from_ne_bytes(buf));
                }
                Ok(result)
            }
        }
    };
}
534
535impl private::NpySealed for bool {}
537
538impl NpyElement for bool {
539 fn write_slice<W: Write>(data: &[bool], writer: &mut W) -> FerrayResult<()> {
540 for &val in data {
541 writer.write_all(&[val as u8])?;
542 }
543 Ok(())
544 }
545
546 fn read_vec<R: Read>(
547 reader: &mut R,
548 count: usize,
549 _endian: Endianness,
550 ) -> FerrayResult<Vec<bool>> {
551 let mut result = Vec::with_capacity(count);
552 let mut buf = [0u8; 1];
553 for _ in 0..count {
554 reader.read_exact(&mut buf)?;
555 result.push(buf[0] != 0);
556 }
557 Ok(result)
558 }
559}
560
// `NpyElement` implementations for the fixed-width integer and float types.
// The second argument is the element's size in bytes; the macro uses it to
// size the per-element read buffer.
impl_npy_element!(u8, 1);
impl_npy_element!(u16, 2);
impl_npy_element!(u32, 4);
impl_npy_element!(u64, 8);
impl_npy_element!(u128, 16);
impl_npy_element!(i8, 1);
impl_npy_element!(i16, 2);
impl_npy_element!(i32, 4);
impl_npy_element!(i64, 8);
impl_npy_element!(i128, 16);
impl_npy_element!(f32, 4);
impl_npy_element!(f64, 8);
573
// Round-trip and error-path tests for the `.npy` save/load functions.
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::{Ix1, Ix2};
    use std::io::Cursor;

    /// Returns (creating it if needed) a temp directory whose name includes
    /// the current process id.
    fn test_dir() -> std::path::PathBuf {
        let dir = std::env::temp_dir().join(format!("ferray_io_test_{}", std::process::id()));
        // Best effort: a creation failure surfaces later when a test saves.
        let _ = std::fs::create_dir_all(&dir);
        dir
    }

    /// Path of the file `name` inside the test directory.
    fn test_file(name: &str) -> std::path::PathBuf {
        let dir = test_dir();
        dir.join(name)
    }

    /// A 1-D `f64` array survives save + load through a file.
    #[test]
    fn roundtrip_f64_1d() {
        let data = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0];
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([5]), data.clone()).unwrap();

        let path = test_file("rt_f64_1d.npy");
        save(&path, &arr).unwrap();
        let loaded: Array<f64, Ix1> = load(&path).unwrap();

        assert_eq!(loaded.shape(), &[5]);
        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
        let _ = std::fs::remove_file(&path);
    }

    /// A 2-D `f32` array keeps both its shape and its element order.
    #[test]
    fn roundtrip_f32_2d() {
        let data = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let arr = Array::<f32, Ix2>::from_vec(Ix2::new([2, 3]), data.clone()).unwrap();

        let path = test_file("rt_f32_2d.npy");
        save(&path, &arr).unwrap();
        let loaded: Array<f32, Ix2> = load(&path).unwrap();

        assert_eq!(loaded.shape(), &[2, 3]);
        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
        let _ = std::fs::remove_file(&path);
    }

    /// `i32` elements round-trip through a file.
    #[test]
    fn roundtrip_i32() {
        let data = vec![10i32, 20, 30, 40];
        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([4]), data.clone()).unwrap();

        let path = test_file("rt_i32.npy");
        save(&path, &arr).unwrap();
        let loaded: Array<i32, Ix1> = load(&path).unwrap();
        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
        let _ = std::fs::remove_file(&path);
    }

    /// `i64` elements round-trip through a file.
    #[test]
    fn roundtrip_i64() {
        let data = vec![100i64, 200, 300];
        let arr = Array::<i64, Ix1>::from_vec(Ix1::new([3]), data.clone()).unwrap();

        let path = test_file("rt_i64.npy");
        save(&path, &arr).unwrap();
        let loaded: Array<i64, Ix1> = load(&path).unwrap();
        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
        let _ = std::fs::remove_file(&path);
    }

    /// `u8` elements round-trip, including the extremes 0 and 255.
    #[test]
    fn roundtrip_u8() {
        let data = vec![0u8, 128, 255];
        let arr = Array::<u8, Ix1>::from_vec(Ix1::new([3]), data.clone()).unwrap();

        let path = test_file("rt_u8.npy");
        save(&path, &arr).unwrap();
        let loaded: Array<u8, Ix1> = load(&path).unwrap();
        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
        let _ = std::fs::remove_file(&path);
    }

    /// `bool` elements round-trip through a file.
    #[test]
    fn roundtrip_bool() {
        let data = vec![true, false, true, true, false];
        let arr = Array::<bool, Ix1>::from_vec(Ix1::new([5]), data.clone()).unwrap();

        let path = test_file("rt_bool.npy");
        save(&path, &arr).unwrap();
        let loaded: Array<bool, Ix1> = load(&path).unwrap();
        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
        let _ = std::fs::remove_file(&path);
    }

    /// The writer/reader entry points work on an in-memory buffer.
    #[test]
    fn roundtrip_in_memory() {
        let data = vec![1.0_f64, 2.0, 3.0];
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), data.clone()).unwrap();

        let mut buf = Vec::new();
        save_to_writer(&mut buf, &arr).unwrap();

        let mut cursor = Cursor::new(buf);
        let loaded: Array<f64, Ix1> = load_from_reader(&mut cursor).unwrap();
        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
    }

    /// `load_dynamic` reports the dtype and shape stored by the typed `save`.
    #[test]
    fn load_dynamic_f64() {
        let data = vec![1.0_f64, 2.0, 3.0];
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), data).unwrap();

        let path = test_file("dyn_f64.npy");
        save(&path, &arr).unwrap();
        let dyn_arr = load_dynamic(&path).unwrap();

        assert_eq!(dyn_arr.dtype(), DType::F64);
        assert_eq!(dyn_arr.shape(), &[3]);
        let _ = std::fs::remove_file(&path);
    }

    /// Loading an `f64` file as `f32` is rejected.
    #[test]
    fn load_wrong_dtype_error() {
        let data = vec![1.0_f64, 2.0, 3.0];
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), data).unwrap();

        let path = test_file("wrong_dtype.npy");
        save(&path, &arr).unwrap();

        let result = load::<f32, Ix1, _>(&path);
        assert!(result.is_err());
        let _ = std::fs::remove_file(&path);
    }

    /// Loading a 1-D file as a 2-D array is rejected.
    #[test]
    fn load_wrong_ndim_error() {
        let data = vec![1.0_f64, 2.0, 3.0];
        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), data).unwrap();

        let path = test_file("wrong_ndim.npy");
        save(&path, &arr).unwrap();

        let result = load::<f64, Ix2, _>(&path);
        assert!(result.is_err());
        let _ = std::fs::remove_file(&path);
    }

    /// A `DynArray` survives `save_dynamic` + `load_dynamic`.
    #[test]
    fn roundtrip_dynamic() {
        let data = vec![10i32, 20, 30];
        let arr = Array::<i32, IxDyn>::from_vec(IxDyn::new(&[3]), data.clone()).unwrap();
        let dyn_arr = DynArray::I32(arr);

        let path = test_file("rt_dynamic.npy");
        save_dynamic(&path, &dyn_arr).unwrap();

        let loaded = load_dynamic(&path).unwrap();
        assert_eq!(loaded.dtype(), DType::I32);
        assert_eq!(loaded.shape(), &[3]);

        let loaded_arr = loaded.try_into_i32().unwrap();
        assert_eq!(loaded_arr.as_slice().unwrap(), &data[..]);
        let _ = std::fs::remove_file(&path);
    }

    /// A file saved from a fixed-rank array can be loaded with `IxDyn`.
    #[test]
    fn load_dynamic_ixdyn() {
        let data = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), data.clone()).unwrap();

        let path = test_file("dyn_ixdyn.npy");
        save(&path, &arr).unwrap();

        let loaded: Array<f64, IxDyn> = load(&path).unwrap();
        assert_eq!(loaded.shape(), &[2, 3]);
        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
        let _ = std::fs::remove_file(&path);
    }
}