1pub mod dtype_parse;
10pub mod header;
11
12use std::fs::File;
13use std::io::{BufReader, BufWriter, Read, Write};
14use std::path::Path;
15
16use ferray_core::Array;
17use ferray_core::dimension::{Dimension, IxDyn};
18use ferray_core::dtype::{DType, Element};
19use ferray_core::dynarray::DynArray;
20use ferray_core::error::{FerrayError, FerrayResult};
21
22use self::dtype_parse::Endianness;
23
24pub(crate) fn checked_total_elements(shape: &[usize]) -> FerrayResult<usize> {
27 shape.iter().try_fold(1usize, |acc, &dim| {
28 acc.checked_mul(dim)
29 .ok_or_else(|| FerrayError::io_error("shape overflow: total elements exceed usize::MAX"))
30 })
31}
32
33pub fn save<T: Element + NpyElement, D: Dimension, P: AsRef<Path>>(
40 path: P,
41 array: &Array<T, D>,
42) -> FerrayResult<()> {
43 let file = File::create(path.as_ref()).map_err(|e| {
44 FerrayError::io_error(format!(
45 "failed to create file '{}': {e}",
46 path.as_ref().display()
47 ))
48 })?;
49 let mut writer = BufWriter::new(file);
50 save_to_writer(&mut writer, array)
51}
52
53pub fn save_to_writer<T: Element + NpyElement, D: Dimension, W: Write>(
55 writer: &mut W,
56 array: &Array<T, D>,
57) -> FerrayResult<()> {
58 let fortran_order = false;
59 header::write_header(writer, T::dtype(), array.shape(), fortran_order)?;
60
61 if let Some(slice) = array.as_slice() {
63 T::write_slice(slice, writer)?;
64 } else {
65 return Err(FerrayError::io_error(
66 "cannot save non-contiguous array to .npy (make contiguous first)",
67 ));
68 }
69
70 writer.flush()?;
71 Ok(())
72}
73
74pub fn load<T: Element + NpyElement, D: Dimension, P: AsRef<Path>>(
81 path: P,
82) -> FerrayResult<Array<T, D>> {
83 let file = File::open(path.as_ref()).map_err(|e| {
84 FerrayError::io_error(format!(
85 "failed to open file '{}': {e}",
86 path.as_ref().display()
87 ))
88 })?;
89 let mut reader = BufReader::new(file);
90 load_from_reader(&mut reader)
91}
92
93pub fn load_from_reader<T: Element + NpyElement, D: Dimension, R: Read>(
95 reader: &mut R,
96) -> FerrayResult<Array<T, D>> {
97 let hdr = header::read_header(reader)?;
98
99 if hdr.dtype != T::dtype() {
101 return Err(FerrayError::invalid_dtype(format!(
102 "expected dtype {:?} for type {}, but file has {:?}",
103 T::dtype(),
104 std::any::type_name::<T>(),
105 hdr.dtype,
106 )));
107 }
108
109 if let Some(ndim) = D::NDIM {
111 if ndim != hdr.shape.len() {
112 return Err(FerrayError::shape_mismatch(format!(
113 "expected {} dimensions, but file has {} (shape {:?})",
114 ndim,
115 hdr.shape.len(),
116 hdr.shape,
117 )));
118 }
119 }
120
121 let total_elements = checked_total_elements(&hdr.shape)?;
122 let data = T::read_vec(reader, total_elements, hdr.endianness)?;
123
124 let dim = build_dimension::<D>(&hdr.shape)?;
125
126 if hdr.fortran_order {
127 Array::from_vec_f(dim, data)
128 } else {
129 Array::from_vec(dim, data)
130 }
131}
132
133pub fn load_dynamic<P: AsRef<Path>>(path: P) -> FerrayResult<DynArray> {
140 let file = File::open(path.as_ref()).map_err(|e| {
141 FerrayError::io_error(format!(
142 "failed to open file '{}': {e}",
143 path.as_ref().display()
144 ))
145 })?;
146 let mut reader = BufReader::new(file);
147 load_dynamic_from_reader(&mut reader)
148}
149
150pub fn load_dynamic_from_reader<R: Read>(reader: &mut R) -> FerrayResult<DynArray> {
152 let hdr = header::read_header(reader)?;
153 let total = checked_total_elements(&hdr.shape)?;
154 let dim = IxDyn::new(&hdr.shape);
155
156 macro_rules! load_typed {
157 ($ty:ty, $variant:ident) => {{
158 let data = <$ty as NpyElement>::read_vec(reader, total, hdr.endianness)?;
159 let arr = if hdr.fortran_order {
160 Array::<$ty, IxDyn>::from_vec_f(dim, data)?
161 } else {
162 Array::<$ty, IxDyn>::from_vec(dim, data)?
163 };
164 Ok(DynArray::$variant(arr))
165 }};
166 }
167
168 match hdr.dtype {
169 DType::Bool => load_typed!(bool, Bool),
170 DType::U8 => load_typed!(u8, U8),
171 DType::U16 => load_typed!(u16, U16),
172 DType::U32 => load_typed!(u32, U32),
173 DType::U64 => load_typed!(u64, U64),
174 DType::U128 => load_typed!(u128, U128),
175 DType::I8 => load_typed!(i8, I8),
176 DType::I16 => load_typed!(i16, I16),
177 DType::I32 => load_typed!(i32, I32),
178 DType::I64 => load_typed!(i64, I64),
179 DType::I128 => load_typed!(i128, I128),
180 DType::F32 => load_typed!(f32, F32),
181 DType::F64 => load_typed!(f64, F64),
182 DType::Complex32 => {
183 load_complex32_dynamic(reader, total, dim, hdr.fortran_order, hdr.endianness)
184 }
185 DType::Complex64 => {
186 load_complex64_dynamic(reader, total, dim, hdr.fortran_order, hdr.endianness)
187 }
188 _ => Err(FerrayError::invalid_dtype(format!(
189 "unsupported dtype {:?} for .npy loading",
190 hdr.dtype
191 ))),
192 }
193}
194
195fn load_complex32_dynamic<R: Read>(
197 reader: &mut R,
198 total: usize,
199 dim: IxDyn,
200 fortran_order: bool,
201 endian: Endianness,
202) -> FerrayResult<DynArray> {
203 let byte_count = total * 8;
204 let mut raw = vec![0u8; byte_count];
205 reader.read_exact(&mut raw)?;
206
207 if endian.needs_swap() {
208 for chunk in raw.chunks_exact_mut(4) {
209 chunk.reverse();
210 }
211 }
212
213 load_complex32_from_bytes_copy(&raw, total, dim, fortran_order)
214}
215
216fn load_complex32_from_bytes_copy(
218 bytes: &[u8],
219 total: usize,
220 dim: IxDyn,
221 fortran_order: bool,
222) -> FerrayResult<DynArray> {
223 use num_complex::Complex;
224
225 let mut data = Vec::with_capacity(total);
227 for chunk in bytes.chunks_exact(8) {
228 let re = f32::from_ne_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
229 let im = f32::from_ne_bytes([chunk[4], chunk[5], chunk[6], chunk[7]]);
230 data.push(Complex::new(re, im));
231 }
232
233 let arr = if fortran_order {
234 Array::<Complex<f32>, IxDyn>::from_vec_f(dim, data)?
235 } else {
236 Array::<Complex<f32>, IxDyn>::from_vec(dim, data)?
237 };
238 Ok(DynArray::Complex32(arr))
239}
240
241fn load_complex64_dynamic<R: Read>(
243 reader: &mut R,
244 total: usize,
245 dim: IxDyn,
246 fortran_order: bool,
247 endian: Endianness,
248) -> FerrayResult<DynArray> {
249 let byte_count = total * 16;
250 let mut raw = vec![0u8; byte_count];
251 reader.read_exact(&mut raw)?;
252
253 if endian.needs_swap() {
254 for chunk in raw.chunks_exact_mut(8) {
255 chunk.reverse();
256 }
257 }
258
259 load_complex64_from_bytes_copy(&raw, total, dim, fortran_order)
260}
261
262fn load_complex64_from_bytes_copy(
263 bytes: &[u8],
264 total: usize,
265 dim: IxDyn,
266 fortran_order: bool,
267) -> FerrayResult<DynArray> {
268 use num_complex::Complex;
269
270 let mut data = Vec::with_capacity(total);
272 for chunk in bytes.chunks_exact(16) {
273 let re = f64::from_ne_bytes([
274 chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6], chunk[7],
275 ]);
276 let im = f64::from_ne_bytes([
277 chunk[8], chunk[9], chunk[10], chunk[11], chunk[12], chunk[13], chunk[14], chunk[15],
278 ]);
279 data.push(Complex::new(re, im));
280 }
281
282 let arr = if fortran_order {
283 Array::<Complex<f64>, IxDyn>::from_vec_f(dim, data)?
284 } else {
285 Array::<Complex<f64>, IxDyn>::from_vec(dim, data)?
286 };
287 Ok(DynArray::Complex64(arr))
288}
289
290pub fn save_dynamic<P: AsRef<Path>>(path: P, array: &DynArray) -> FerrayResult<()> {
292 let file = File::create(path.as_ref()).map_err(|e| {
293 FerrayError::io_error(format!(
294 "failed to create file '{}': {e}",
295 path.as_ref().display()
296 ))
297 })?;
298 let mut writer = BufWriter::new(file);
299 save_dynamic_to_writer(&mut writer, array)
300}
301
302pub fn save_dynamic_to_writer<W: Write>(writer: &mut W, array: &DynArray) -> FerrayResult<()> {
304 macro_rules! save_typed {
305 ($arr:expr, $dtype:expr, $ty:ty) => {{
306 header::write_header(writer, $dtype, $arr.shape(), false)?;
307 if let Some(s) = $arr.as_slice() {
308 <$ty as NpyElement>::write_slice(s, writer)?;
309 } else {
310 return Err(FerrayError::io_error(
311 "cannot save non-contiguous DynArray to .npy",
312 ));
313 }
314 }};
315 }
316
317 match array {
318 DynArray::Bool(a) => save_typed!(a, DType::Bool, bool),
319 DynArray::U8(a) => save_typed!(a, DType::U8, u8),
320 DynArray::U16(a) => save_typed!(a, DType::U16, u16),
321 DynArray::U32(a) => save_typed!(a, DType::U32, u32),
322 DynArray::U64(a) => save_typed!(a, DType::U64, u64),
323 DynArray::U128(a) => save_typed!(a, DType::U128, u128),
324 DynArray::I8(a) => save_typed!(a, DType::I8, i8),
325 DynArray::I16(a) => save_typed!(a, DType::I16, i16),
326 DynArray::I32(a) => save_typed!(a, DType::I32, i32),
327 DynArray::I64(a) => save_typed!(a, DType::I64, i64),
328 DynArray::I128(a) => save_typed!(a, DType::I128, i128),
329 DynArray::F32(a) => save_typed!(a, DType::F32, f32),
330 DynArray::F64(a) => save_typed!(a, DType::F64, f64),
331 DynArray::Complex32(a) => {
332 header::write_header(writer, DType::Complex32, a.shape(), false)?;
333 save_complex_raw(a.as_slice(), 8, writer)?;
334 }
335 DynArray::Complex64(a) => {
336 header::write_header(writer, DType::Complex64, a.shape(), false)?;
337 save_complex_raw(a.as_slice(), 16, writer)?;
338 }
339 _ => {
340 return Err(FerrayError::invalid_dtype(
341 "unsupported DynArray variant for .npy saving",
342 ));
343 }
344 }
345
346 writer.flush()?;
347 Ok(())
348}
349
350fn save_complex_raw<T, W: Write>(
353 slice_opt: Option<&[T]>,
354 elem_size: usize,
355 writer: &mut W,
356) -> FerrayResult<()> {
357 let slice = slice_opt
358 .ok_or_else(|| FerrayError::io_error("cannot save non-contiguous complex array"))?;
359 let byte_len = slice.len() * elem_size;
360 let bytes = unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const u8, byte_len) };
361 writer.write_all(bytes)?;
362 Ok(())
363}
364
365fn build_dimension<D: Dimension>(shape: &[usize]) -> FerrayResult<D> {
367 build_dim_from_shape::<D>(shape)
368}
369
370fn build_dim_from_shape<D: Dimension>(shape: &[usize]) -> FerrayResult<D> {
373 use ferray_core::dimension::*;
374 use std::any::Any;
375
376 if let Some(ndim) = D::NDIM {
377 if shape.len() != ndim {
378 return Err(FerrayError::shape_mismatch(format!(
379 "expected {ndim} dimensions, got {}",
380 shape.len()
381 )));
382 }
383 }
384
385 let type_id = std::any::TypeId::of::<D>();
386
387 macro_rules! try_dim {
388 ($dim_ty:ty, $dim_val:expr) => {
389 if type_id == std::any::TypeId::of::<$dim_ty>() {
390 let boxed: Box<dyn Any> = Box::new($dim_val);
391 return Ok(*boxed.downcast::<D>().unwrap());
392 }
393 };
394 }
395
396 try_dim!(IxDyn, IxDyn::new(shape));
397
398 match shape.len() {
399 0 => {
400 try_dim!(Ix0, Ix0);
401 }
402 1 => {
403 try_dim!(Ix1, Ix1::new([shape[0]]));
404 }
405 2 => {
406 try_dim!(Ix2, Ix2::new([shape[0], shape[1]]));
407 }
408 3 => {
409 try_dim!(Ix3, Ix3::new([shape[0], shape[1], shape[2]]));
410 }
411 4 => {
412 try_dim!(Ix4, Ix4::new([shape[0], shape[1], shape[2], shape[3]]));
413 }
414 5 => {
415 try_dim!(
416 Ix5,
417 Ix5::new([shape[0], shape[1], shape[2], shape[3], shape[4]])
418 );
419 }
420 6 => {
421 try_dim!(
422 Ix6,
423 Ix6::new([shape[0], shape[1], shape[2], shape[3], shape[4], shape[5]])
424 );
425 }
426 _ => {}
427 }
428
429 Err(FerrayError::io_error(
430 "unsupported dimension type for .npy loading",
431 ))
432}
433
434pub trait NpyElement: Element + private::NpySealed {
443 fn write_slice<W: Write>(data: &[Self], writer: &mut W) -> FerrayResult<()>;
445
446 fn read_vec<R: Read>(
448 reader: &mut R,
449 count: usize,
450 endian: Endianness,
451 ) -> FerrayResult<Vec<Self>>;
452}
453
/// Non-public module holding the sealing trait used as a supertrait of `NpyElement`.
mod private {
    /// Marker trait; implemented in this file for each supported element type.
    pub trait NpySealed {}
}
457
/// Implements `private::NpySealed` and `NpyElement` for a primitive numeric type.
///
/// * `$ty` — the primitive type (must provide `to_ne_bytes` / `from_ne_bytes`).
/// * `$size` — the size of `$ty` in bytes.
///
/// `write_slice` emits each value's native-endian bytes. `read_vec` reads
/// `count` values, reversing each value's bytes first when
/// `endian.needs_swap()` is true.
///
/// `count` originates from the file header and is untrusted, so the output
/// vector's initial capacity is capped (about 1 MiB worth of elements). A
/// bogus, huge count then ends in a read error once the stream runs dry
/// instead of a capacity-overflow panic or an enormous up-front allocation.
macro_rules! impl_npy_element {
    ($ty:ty, $size:expr) => {
        impl private::NpySealed for $ty {}

        impl NpyElement for $ty {
            fn write_slice<W: Write>(data: &[$ty], writer: &mut W) -> FerrayResult<()> {
                for &val in data {
                    writer.write_all(&val.to_ne_bytes())?;
                }
                Ok(())
            }

            fn read_vec<R: Read>(
                reader: &mut R,
                count: usize,
                endian: Endianness,
            ) -> FerrayResult<Vec<$ty>> {
                // Upper bound on bytes reserved before any data has been seen.
                const MAX_PREALLOC_BYTES: usize = 1 << 20;

                let mut result = Vec::with_capacity(count.min(MAX_PREALLOC_BYTES / $size));
                let mut buf = [0u8; $size];
                let needs_swap = endian.needs_swap();
                for _ in 0..count {
                    reader.read_exact(&mut buf)?;
                    if needs_swap {
                        // Convert the stored byte order to the native one.
                        buf.reverse();
                    }
                    result.push(<$ty>::from_ne_bytes(buf));
                }
                Ok(result)
            }
        }
    };
}
499
500impl private::NpySealed for bool {}
502
503impl NpyElement for bool {
504 fn write_slice<W: Write>(data: &[bool], writer: &mut W) -> FerrayResult<()> {
505 for &val in data {
506 writer.write_all(&[val as u8])?;
507 }
508 Ok(())
509 }
510
511 fn read_vec<R: Read>(
512 reader: &mut R,
513 count: usize,
514 _endian: Endianness,
515 ) -> FerrayResult<Vec<bool>> {
516 let mut result = Vec::with_capacity(count);
517 let mut buf = [0u8; 1];
518 for _ in 0..count {
519 reader.read_exact(&mut buf)?;
520 result.push(buf[0] != 0);
521 }
522 Ok(result)
523 }
524}
525
526impl_npy_element!(u8, 1);
527impl_npy_element!(u16, 2);
528impl_npy_element!(u32, 4);
529impl_npy_element!(u64, 8);
530impl_npy_element!(u128, 16);
531impl_npy_element!(i8, 1);
532impl_npy_element!(i16, 2);
533impl_npy_element!(i32, 4);
534impl_npy_element!(i64, 8);
535impl_npy_element!(i128, 16);
536impl_npy_element!(f32, 4);
537impl_npy_element!(f64, 8);
538
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::{Ix1, Ix2};
    use std::io::Cursor;

    /// Per-process scratch directory under the system temp dir (created on demand).
    fn test_dir() -> std::path::PathBuf {
        let mut dir = std::env::temp_dir();
        dir.push(format!("ferray_io_test_{}", std::process::id()));
        let _ = std::fs::create_dir_all(&dir);
        dir
    }

    /// Path of a scratch file called `name` inside [`test_dir`].
    fn test_file(name: &str) -> std::path::PathBuf {
        test_dir().join(name)
    }

    #[test]
    fn roundtrip_f64_1d() {
        let values = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0];
        let original = Array::<f64, Ix1>::from_vec(Ix1::new([5]), values.clone()).unwrap();
        let npy = test_file("rt_f64_1d.npy");

        save(&npy, &original).unwrap();
        let restored: Array<f64, Ix1> = load(&npy).unwrap();

        assert_eq!(restored.shape(), &[5]);
        assert_eq!(restored.as_slice().unwrap(), values.as_slice());
        let _ = std::fs::remove_file(&npy);
    }

    #[test]
    fn roundtrip_f32_2d() {
        let values = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let original = Array::<f32, Ix2>::from_vec(Ix2::new([2, 3]), values.clone()).unwrap();
        let npy = test_file("rt_f32_2d.npy");

        save(&npy, &original).unwrap();
        let restored: Array<f32, Ix2> = load(&npy).unwrap();

        assert_eq!(restored.shape(), &[2, 3]);
        assert_eq!(restored.as_slice().unwrap(), values.as_slice());
        let _ = std::fs::remove_file(&npy);
    }

    #[test]
    fn roundtrip_i32() {
        let values = vec![10i32, 20, 30, 40];
        let original = Array::<i32, Ix1>::from_vec(Ix1::new([4]), values.clone()).unwrap();
        let npy = test_file("rt_i32.npy");

        save(&npy, &original).unwrap();
        let restored: Array<i32, Ix1> = load(&npy).unwrap();

        assert_eq!(restored.as_slice().unwrap(), values.as_slice());
        let _ = std::fs::remove_file(&npy);
    }

    #[test]
    fn roundtrip_i64() {
        let values = vec![100i64, 200, 300];
        let original = Array::<i64, Ix1>::from_vec(Ix1::new([3]), values.clone()).unwrap();
        let npy = test_file("rt_i64.npy");

        save(&npy, &original).unwrap();
        let restored: Array<i64, Ix1> = load(&npy).unwrap();

        assert_eq!(restored.as_slice().unwrap(), values.as_slice());
        let _ = std::fs::remove_file(&npy);
    }

    #[test]
    fn roundtrip_u8() {
        let values = vec![0u8, 128, 255];
        let original = Array::<u8, Ix1>::from_vec(Ix1::new([3]), values.clone()).unwrap();
        let npy = test_file("rt_u8.npy");

        save(&npy, &original).unwrap();
        let restored: Array<u8, Ix1> = load(&npy).unwrap();

        assert_eq!(restored.as_slice().unwrap(), values.as_slice());
        let _ = std::fs::remove_file(&npy);
    }

    #[test]
    fn roundtrip_bool() {
        let values = vec![true, false, true, true, false];
        let original = Array::<bool, Ix1>::from_vec(Ix1::new([5]), values.clone()).unwrap();
        let npy = test_file("rt_bool.npy");

        save(&npy, &original).unwrap();
        let restored: Array<bool, Ix1> = load(&npy).unwrap();

        assert_eq!(restored.as_slice().unwrap(), values.as_slice());
        let _ = std::fs::remove_file(&npy);
    }

    #[test]
    fn roundtrip_in_memory() {
        let values = vec![1.0_f64, 2.0, 3.0];
        let original = Array::<f64, Ix1>::from_vec(Ix1::new([3]), values.clone()).unwrap();

        // Encode into a plain byte vector, then decode straight from it.
        let mut encoded = Vec::new();
        save_to_writer(&mut encoded, &original).unwrap();

        let mut input = Cursor::new(encoded);
        let restored: Array<f64, Ix1> = load_from_reader(&mut input).unwrap();
        assert_eq!(restored.as_slice().unwrap(), values.as_slice());
    }

    #[test]
    fn load_dynamic_f64() {
        let original =
            Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0_f64, 2.0, 3.0]).unwrap();
        let npy = test_file("dyn_f64.npy");

        save(&npy, &original).unwrap();
        let restored = load_dynamic(&npy).unwrap();

        assert_eq!(restored.dtype(), DType::F64);
        assert_eq!(restored.shape(), &[3]);
        let _ = std::fs::remove_file(&npy);
    }

    #[test]
    fn load_wrong_dtype_error() {
        let original =
            Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0_f64, 2.0, 3.0]).unwrap();
        let npy = test_file("wrong_dtype.npy");
        save(&npy, &original).unwrap();

        // The file holds f64 data; asking for f32 must be rejected.
        let outcome = load::<f32, Ix1, _>(&npy);
        assert!(outcome.is_err());
        let _ = std::fs::remove_file(&npy);
    }

    #[test]
    fn load_wrong_ndim_error() {
        let original =
            Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0_f64, 2.0, 3.0]).unwrap();
        let npy = test_file("wrong_ndim.npy");
        save(&npy, &original).unwrap();

        // The file is 1-D; asking for a 2-D array must be rejected.
        let outcome = load::<f64, Ix2, _>(&npy);
        assert!(outcome.is_err());
        let _ = std::fs::remove_file(&npy);
    }

    #[test]
    fn roundtrip_dynamic() {
        let values = vec![10i32, 20, 30];
        let typed = Array::<i32, IxDyn>::from_vec(IxDyn::new(&[3]), values.clone()).unwrap();
        let original = DynArray::I32(typed);
        let npy = test_file("rt_dynamic.npy");

        save_dynamic(&npy, &original).unwrap();
        let restored = load_dynamic(&npy).unwrap();

        assert_eq!(restored.dtype(), DType::I32);
        assert_eq!(restored.shape(), &[3]);

        let restored_typed = restored.try_into_i32().unwrap();
        assert_eq!(restored_typed.as_slice().unwrap(), values.as_slice());
        let _ = std::fs::remove_file(&npy);
    }

    #[test]
    fn load_dynamic_ixdyn() {
        let values = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let original = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), values.clone()).unwrap();
        let npy = test_file("dyn_ixdyn.npy");
        save(&npy, &original).unwrap();

        // A statically typed element with a dynamic-rank dimension.
        let restored: Array<f64, IxDyn> = load(&npy).unwrap();

        assert_eq!(restored.shape(), &[2, 3]);
        assert_eq!(restored.as_slice().unwrap(), values.as_slice());
        let _ = std::fs::remove_file(&npy);
    }
}