Skip to main content

ferray_io/npy/
mod.rs

1// ferray-io: .npy file I/O
2//
3// REQ-1: save(path, &array) writes .npy format
4// REQ-2: load::<T, D>(path) reads .npy and returns Result<Array<T, D>, FerrayError>
5// REQ-3: load_dynamic(path) reads .npy and returns Result<DynArray, FerrayError>
6// REQ-6: Support format versions 1.0, 2.0, 3.0
7// REQ-12: Support reading/writing both little-endian and big-endian
8
9pub mod dtype_parse;
10pub mod header;
11
12use std::fs::File;
13use std::io::{BufReader, BufWriter, Read, Write};
14use std::path::Path;
15
16use ferray_core::Array;
17use ferray_core::dimension::{Dimension, IxDyn};
18use ferray_core::dtype::{DType, Element};
19use ferray_core::dynarray::DynArray;
20use ferray_core::error::{FerrayError, FerrayResult};
21
22use self::dtype_parse::Endianness;
23
24/// Compute the total number of elements from a shape, using checked
25/// multiplication to guard against overflow from untrusted `.npy` files.
26pub(crate) fn checked_total_elements(shape: &[usize]) -> FerrayResult<usize> {
27    shape.iter().try_fold(1usize, |acc, &dim| {
28        acc.checked_mul(dim).ok_or_else(|| {
29            FerrayError::io_error("shape overflow: total elements exceed usize::MAX")
30        })
31    })
32}
33
34/// Save an array to a `.npy` file.
35///
36/// The file is written in native byte order with C (row-major) layout.
37///
38/// # Errors
39/// Returns `FerrayError::IoError` if the file cannot be created or written.
40pub fn save<T: Element + NpyElement, D: Dimension, P: AsRef<Path>>(
41    path: P,
42    array: &Array<T, D>,
43) -> FerrayResult<()> {
44    let file = File::create(path.as_ref()).map_err(|e| {
45        FerrayError::io_error(format!(
46            "failed to create file '{}': {e}",
47            path.as_ref().display()
48        ))
49    })?;
50    let mut writer = BufWriter::new(file);
51    save_to_writer(&mut writer, array)
52}
53
54/// Save an array to a writer in `.npy` format.
55///
56/// If the array is C-contiguous, its data buffer is written directly.
57/// Otherwise, elements are iterated in logical (row-major) order and
58/// written individually — this handles transposed, sliced, or
59/// otherwise non-contiguous arrays transparently.
60pub fn save_to_writer<T: Element + NpyElement, D: Dimension, W: Write>(
61    writer: &mut W,
62    array: &Array<T, D>,
63) -> FerrayResult<()> {
64    let fortran_order = false;
65    header::write_header(writer, T::dtype(), array.shape(), fortran_order)?;
66
67    // Write data — fast path for contiguous, fallback for strided
68    if let Some(slice) = array.as_slice() {
69        T::write_slice(slice, writer)?;
70    } else {
71        // Non-contiguous: collect into logical order and write
72        let data: Vec<T> = array.iter().cloned().collect();
73        T::write_slice(&data, writer)?;
74    }
75
76    writer.flush()?;
77    Ok(())
78}
79
80/// Load an array from a `.npy` file with compile-time type and dimension.
81///
82/// # Errors
83/// - Returns `FerrayError::InvalidDtype` if the file's dtype doesn't match `T`.
84/// - Returns `FerrayError::ShapeMismatch` if the file's shape doesn't match `D`.
85/// - Returns `FerrayError::IoError` on file read failures.
86pub fn load<T: Element + NpyElement, D: Dimension, P: AsRef<Path>>(
87    path: P,
88) -> FerrayResult<Array<T, D>> {
89    let file = File::open(path.as_ref()).map_err(|e| {
90        FerrayError::io_error(format!(
91            "failed to open file '{}': {e}",
92            path.as_ref().display()
93        ))
94    })?;
95    let mut reader = BufReader::new(file);
96    load_from_reader(&mut reader)
97}
98
99/// Load an array from a reader in `.npy` format with compile-time type.
100pub fn load_from_reader<T: Element + NpyElement, D: Dimension, R: Read>(
101    reader: &mut R,
102) -> FerrayResult<Array<T, D>> {
103    let hdr = header::read_header(reader)?;
104
105    // Check dtype matches T
106    if hdr.dtype != T::dtype() {
107        return Err(FerrayError::invalid_dtype(format!(
108            "expected dtype {:?} for type {}, but file has {:?}",
109            T::dtype(),
110            std::any::type_name::<T>(),
111            hdr.dtype,
112        )));
113    }
114
115    // Check dimension compatibility
116    if let Some(ndim) = D::NDIM {
117        if ndim != hdr.shape.len() {
118            return Err(FerrayError::shape_mismatch(format!(
119                "expected {} dimensions, but file has {} (shape {:?})",
120                ndim,
121                hdr.shape.len(),
122                hdr.shape,
123            )));
124        }
125    }
126
127    let total_elements = checked_total_elements(&hdr.shape)?;
128    let data = T::read_vec(reader, total_elements, hdr.endianness)?;
129
130    let dim = build_dimension::<D>(&hdr.shape)?;
131
132    if hdr.fortran_order {
133        Array::from_vec_f(dim, data)
134    } else {
135        Array::from_vec(dim, data)
136    }
137}
138
139/// Load a `.npy` file with runtime type dispatch.
140///
141/// Returns a `DynArray` whose variant corresponds to the file's dtype.
142///
143/// # Errors
144/// Returns errors on I/O failures or unsupported dtypes.
145pub fn load_dynamic<P: AsRef<Path>>(path: P) -> FerrayResult<DynArray> {
146    let file = File::open(path.as_ref()).map_err(|e| {
147        FerrayError::io_error(format!(
148            "failed to open file '{}': {e}",
149            path.as_ref().display()
150        ))
151    })?;
152    let mut reader = BufReader::new(file);
153    load_dynamic_from_reader(&mut reader)
154}
155
156/// Load a `.npy` from a reader with runtime type dispatch.
157pub fn load_dynamic_from_reader<R: Read>(reader: &mut R) -> FerrayResult<DynArray> {
158    let hdr = header::read_header(reader)?;
159    let total = checked_total_elements(&hdr.shape)?;
160    let dim = IxDyn::new(&hdr.shape);
161
162    macro_rules! load_typed {
163        ($ty:ty, $variant:ident) => {{
164            let data = <$ty as NpyElement>::read_vec(reader, total, hdr.endianness)?;
165            let arr = if hdr.fortran_order {
166                Array::<$ty, IxDyn>::from_vec_f(dim, data)?
167            } else {
168                Array::<$ty, IxDyn>::from_vec(dim, data)?
169            };
170            Ok(DynArray::$variant(arr))
171        }};
172    }
173
174    match hdr.dtype {
175        DType::Bool => load_typed!(bool, Bool),
176        DType::U8 => load_typed!(u8, U8),
177        DType::U16 => load_typed!(u16, U16),
178        DType::U32 => load_typed!(u32, U32),
179        DType::U64 => load_typed!(u64, U64),
180        DType::U128 => load_typed!(u128, U128),
181        DType::I8 => load_typed!(i8, I8),
182        DType::I16 => load_typed!(i16, I16),
183        DType::I32 => load_typed!(i32, I32),
184        DType::I64 => load_typed!(i64, I64),
185        DType::I128 => load_typed!(i128, I128),
186        #[cfg(feature = "f16")]
187        DType::F16 => load_typed!(half::f16, F16),
188        DType::F32 => load_typed!(f32, F32),
189        DType::F64 => load_typed!(f64, F64),
190        #[cfg(feature = "bf16")]
191        DType::BF16 => load_typed!(half::bf16, BF16),
192        DType::Complex32 => {
193            load_complex32_dynamic(reader, total, dim, hdr.fortran_order, hdr.endianness)
194        }
195        DType::Complex64 => {
196            load_complex64_dynamic(reader, total, dim, hdr.fortran_order, hdr.endianness)
197        }
198        _ => Err(FerrayError::invalid_dtype(format!(
199            "unsupported dtype {:?} for .npy loading",
200            hdr.dtype
201        ))),
202    }
203}
204
205/// Read complex64 (Complex<f32>) data via raw bytes.
206fn load_complex32_dynamic<R: Read>(
207    reader: &mut R,
208    total: usize,
209    dim: IxDyn,
210    fortran_order: bool,
211    endian: Endianness,
212) -> FerrayResult<DynArray> {
213    let byte_count = total * 8;
214    let mut raw = vec![0u8; byte_count];
215    reader.read_exact(&mut raw)?;
216
217    if endian.needs_swap() {
218        for chunk in raw.chunks_exact_mut(4) {
219            chunk.reverse();
220        }
221    }
222
223    load_complex32_from_bytes_copy(&raw, total, dim, fortran_order)
224}
225
226/// Build a Complex32 DynArray from raw bytes, respecting Fortran order.
227fn load_complex32_from_bytes_copy(
228    bytes: &[u8],
229    total: usize,
230    dim: IxDyn,
231    fortran_order: bool,
232) -> FerrayResult<DynArray> {
233    use num_complex::Complex;
234
235    // Parse raw bytes into Vec<Complex<f32>> by reading f32 pairs
236    let mut data = Vec::with_capacity(total);
237    for chunk in bytes.chunks_exact(8) {
238        let re = f32::from_ne_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
239        let im = f32::from_ne_bytes([chunk[4], chunk[5], chunk[6], chunk[7]]);
240        data.push(Complex::new(re, im));
241    }
242
243    let arr = if fortran_order {
244        Array::<Complex<f32>, IxDyn>::from_vec_f(dim, data)?
245    } else {
246        Array::<Complex<f32>, IxDyn>::from_vec(dim, data)?
247    };
248    Ok(DynArray::Complex32(arr))
249}
250
251/// Read complex128 (Complex<f64>) data via raw bytes.
252fn load_complex64_dynamic<R: Read>(
253    reader: &mut R,
254    total: usize,
255    dim: IxDyn,
256    fortran_order: bool,
257    endian: Endianness,
258) -> FerrayResult<DynArray> {
259    let byte_count = total * 16;
260    let mut raw = vec![0u8; byte_count];
261    reader.read_exact(&mut raw)?;
262
263    if endian.needs_swap() {
264        for chunk in raw.chunks_exact_mut(8) {
265            chunk.reverse();
266        }
267    }
268
269    load_complex64_from_bytes_copy(&raw, total, dim, fortran_order)
270}
271
272fn load_complex64_from_bytes_copy(
273    bytes: &[u8],
274    total: usize,
275    dim: IxDyn,
276    fortran_order: bool,
277) -> FerrayResult<DynArray> {
278    use num_complex::Complex;
279
280    // Parse raw bytes into Vec<Complex<f64>> by reading f64 pairs
281    let mut data = Vec::with_capacity(total);
282    for chunk in bytes.chunks_exact(16) {
283        let re = f64::from_ne_bytes([
284            chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6], chunk[7],
285        ]);
286        let im = f64::from_ne_bytes([
287            chunk[8], chunk[9], chunk[10], chunk[11], chunk[12], chunk[13], chunk[14], chunk[15],
288        ]);
289        data.push(Complex::new(re, im));
290    }
291
292    let arr = if fortran_order {
293        Array::<Complex<f64>, IxDyn>::from_vec_f(dim, data)?
294    } else {
295        Array::<Complex<f64>, IxDyn>::from_vec(dim, data)?
296    };
297    Ok(DynArray::Complex64(arr))
298}
299
300/// Save a `DynArray` to a `.npy` file.
301pub fn save_dynamic<P: AsRef<Path>>(path: P, array: &DynArray) -> FerrayResult<()> {
302    let file = File::create(path.as_ref()).map_err(|e| {
303        FerrayError::io_error(format!(
304            "failed to create file '{}': {e}",
305            path.as_ref().display()
306        ))
307    })?;
308    let mut writer = BufWriter::new(file);
309    save_dynamic_to_writer(&mut writer, array)
310}
311
312/// Save a `DynArray` to a writer in `.npy` format.
313pub fn save_dynamic_to_writer<W: Write>(writer: &mut W, array: &DynArray) -> FerrayResult<()> {
314    macro_rules! save_typed {
315        ($arr:expr, $dtype:expr, $ty:ty) => {{
316            header::write_header(writer, $dtype, $arr.shape(), false)?;
317            if let Some(s) = $arr.as_slice() {
318                <$ty as NpyElement>::write_slice(s, writer)?;
319            } else {
320                let data: Vec<$ty> = $arr.iter().cloned().collect();
321                <$ty as NpyElement>::write_slice(&data, writer)?;
322            }
323        }};
324    }
325
326    match array {
327        DynArray::Bool(a) => save_typed!(a, DType::Bool, bool),
328        DynArray::U8(a) => save_typed!(a, DType::U8, u8),
329        DynArray::U16(a) => save_typed!(a, DType::U16, u16),
330        DynArray::U32(a) => save_typed!(a, DType::U32, u32),
331        DynArray::U64(a) => save_typed!(a, DType::U64, u64),
332        DynArray::U128(a) => save_typed!(a, DType::U128, u128),
333        DynArray::I8(a) => save_typed!(a, DType::I8, i8),
334        DynArray::I16(a) => save_typed!(a, DType::I16, i16),
335        DynArray::I32(a) => save_typed!(a, DType::I32, i32),
336        DynArray::I64(a) => save_typed!(a, DType::I64, i64),
337        DynArray::I128(a) => save_typed!(a, DType::I128, i128),
338        #[cfg(feature = "f16")]
339        DynArray::F16(a) => save_typed!(a, DType::F16, half::f16),
340        DynArray::F32(a) => save_typed!(a, DType::F32, f32),
341        DynArray::F64(a) => save_typed!(a, DType::F64, f64),
342        #[cfg(feature = "bf16")]
343        DynArray::BF16(a) => save_typed!(a, DType::BF16, half::bf16),
344        DynArray::Complex32(a) => {
345            header::write_header(writer, DType::Complex32, a.shape(), false)?;
346            save_complex_raw(a.as_slice(), 8, writer)?;
347        }
348        DynArray::Complex64(a) => {
349            header::write_header(writer, DType::Complex64, a.shape(), false)?;
350            save_complex_raw(a.as_slice(), 16, writer)?;
351        }
352        _ => {
353            return Err(FerrayError::invalid_dtype(
354                "unsupported DynArray variant for .npy saving",
355            ));
356        }
357    }
358
359    writer.flush()?;
360    Ok(())
361}
362
363/// Write complex array data as raw bytes without naming the Complex type.
364/// `elem_size` is the total size per element (8 for Complex<f32>, 16 for Complex<f64>).
365fn save_complex_raw<T, W: Write>(
366    slice_opt: Option<&[T]>,
367    elem_size: usize,
368    writer: &mut W,
369) -> FerrayResult<()> {
370    let slice = slice_opt
371        .ok_or_else(|| FerrayError::io_error("cannot save non-contiguous complex array"))?;
372    let byte_len = slice.len() * elem_size;
373    let bytes = unsafe { std::slice::from_raw_parts(slice.as_ptr() as *const u8, byte_len) };
374    writer.write_all(bytes)?;
375    Ok(())
376}
377
378/// Build a dimension value of type `D` from a shape slice.
379fn build_dimension<D: Dimension>(shape: &[usize]) -> FerrayResult<D> {
380    build_dim_from_shape::<D>(shape)
381}
382
/// Helper to build a dimension from a shape slice.
/// This works for all fixed dimensions (Ix0-Ix6) and IxDyn.
///
/// The concrete dimension type is selected by comparing `TypeId`s: the value
/// is built as its concrete type, boxed as `dyn Any`, and downcast back to
/// the generic `D`.
///
/// # Errors
/// - Shape mismatch if `D` has a fixed rank that differs from `shape.len()`.
/// - I/O error if `D` is none of the dimension types handled below.
fn build_dim_from_shape<D: Dimension>(shape: &[usize]) -> FerrayResult<D> {
    use ferray_core::dimension::*;
    use std::any::Any;

    // Fixed-rank dimension types must agree with the number of axes supplied.
    if let Some(ndim) = D::NDIM {
        if shape.len() != ndim {
            return Err(FerrayError::shape_mismatch(format!(
                "expected {ndim} dimensions, got {}",
                shape.len()
            )));
        }
    }

    let type_id = std::any::TypeId::of::<D>();

    // If `D` is exactly `$dim_ty`, return `$dim_val` converted to `D`.
    // The downcast's `unwrap` cannot fail: it only runs after the `TypeId`s
    // have compared equal.
    macro_rules! try_dim {
        ($dim_ty:ty, $dim_val:expr) => {
            if type_id == std::any::TypeId::of::<$dim_ty>() {
                let boxed: Box<dyn Any> = Box::new($dim_val);
                return Ok(*boxed.downcast::<D>().unwrap());
            }
        };
    }

    // The dynamic dimension accepts any rank, so it is tried first.
    try_dim!(IxDyn, IxDyn::new(shape));

    // Only the fixed-rank type matching the shape's length is attempted, which
    // keeps the `shape[i]` indexing below in bounds.
    match shape.len() {
        0 => {
            try_dim!(Ix0, Ix0);
        }
        1 => {
            try_dim!(Ix1, Ix1::new([shape[0]]));
        }
        2 => {
            try_dim!(Ix2, Ix2::new([shape[0], shape[1]]));
        }
        3 => {
            try_dim!(Ix3, Ix3::new([shape[0], shape[1], shape[2]]));
        }
        4 => {
            try_dim!(Ix4, Ix4::new([shape[0], shape[1], shape[2], shape[3]]));
        }
        5 => {
            try_dim!(
                Ix5,
                Ix5::new([shape[0], shape[1], shape[2], shape[3], shape[4]])
            );
        }
        6 => {
            try_dim!(
                Ix6,
                Ix6::new([shape[0], shape[1], shape[2], shape[3], shape[4], shape[5]])
            );
        }
        _ => {}
    }

    // `D` matched none of the types above.
    Err(FerrayError::io_error(
        "unsupported dimension type for .npy loading",
    ))
}
446
447// ---------------------------------------------------------------------------
448// NpyElement trait -- sealed, provides binary read/write for each element type
449// ---------------------------------------------------------------------------
450
/// Trait for element types that support .npy binary serialization.
///
/// This is sealed and implemented for all primitive `Element` types
/// (excluding Complex, which is handled via raw byte I/O in the dynamic path).
/// The `private::NpySealed` supertrait is what prevents implementations
/// outside this module.
pub trait NpyElement: Element + private::NpySealed {
    /// Write a contiguous slice of elements to a writer in native byte order.
    ///
    /// # Errors
    /// Returns an error if writing to `writer` fails.
    fn write_slice<W: Write>(data: &[Self], writer: &mut W) -> FerrayResult<()>;

    /// Read `count` elements from a reader, applying byte-swapping if needed.
    ///
    /// `endian` is the byte order the data was stored in.
    ///
    /// # Errors
    /// Returns an error if `reader` cannot supply enough bytes for `count`
    /// elements.
    fn read_vec<R: Read>(
        reader: &mut R,
        count: usize,
        endian: Endianness,
    ) -> FerrayResult<Vec<Self>>;
}
466
// Private module holding the sealing supertrait for `NpyElement`. The module
// itself is not `pub`, so the trait cannot be named from outside this module.
mod private {
    /// Marker supertrait of `NpyElement`; carries no methods.
    pub trait NpySealed {}
}
470
471// ---------------------------------------------------------------------------
472// Macro for implementing NpyElement for primitive numeric types
473// ---------------------------------------------------------------------------
474
/// Implements `NpySealed` and `NpyElement` for a primitive numeric type.
///
/// `$ty` must provide `from_ne_bytes([u8; $size])`, and `$size` must be the
/// type's size in bytes (verified at compile time).
macro_rules! impl_npy_element {
    ($ty:ty, $size:expr) => {
        impl private::NpySealed for $ty {}

        impl NpyElement for $ty {
            fn write_slice<W: Write>(data: &[$ty], writer: &mut W) -> FerrayResult<()> {
                // Reject a wrong `$size` argument at compile time rather than
                // letting it skew the byte arithmetic in this impl.
                const _: () = assert!(std::mem::size_of::<$ty>() == $size);

                // Bulk write: view the typed slice as raw bytes and write it
                // in a single call, in native byte order.
                let byte_len = std::mem::size_of_val(data);
                // SAFETY: `data` is a valid, initialized slice of a primitive
                // numeric type (no padding bytes), `byte_len` is exactly its
                // size in bytes, and `u8` has alignment 1.
                let bytes =
                    unsafe { std::slice::from_raw_parts(data.as_ptr().cast::<u8>(), byte_len) };
                writer.write_all(bytes)?;
                Ok(())
            }

            fn read_vec<R: Read>(
                reader: &mut R,
                count: usize,
                endian: Endianness,
            ) -> FerrayResult<Vec<$ty>> {
                // `count` comes from an untrusted header: a plain multiply
                // would panic in debug builds and wrap in release builds.
                let byte_len = count.checked_mul($size).ok_or_else(|| {
                    FerrayError::io_error("data size overflow: byte length exceeds usize::MAX")
                })?;
                let mut raw = vec![0u8; byte_len];
                reader.read_exact(&mut raw)?;

                // Foreign-endian data is reversed per element before decoding;
                // native-endian data is decoded as-is.
                let swap = endian.needs_swap();
                let mut result = Vec::with_capacity(count);
                for chunk in raw.chunks_exact_mut($size) {
                    if swap {
                        chunk.reverse();
                    }
                    let mut arr = [0u8; $size];
                    arr.copy_from_slice(chunk);
                    result.push(<$ty>::from_ne_bytes(arr));
                }
                Ok(result)
            }
        }
    };
}
526
527// Bool -- special case
528impl private::NpySealed for bool {}
529
530impl NpyElement for bool {
531    fn write_slice<W: Write>(data: &[bool], writer: &mut W) -> FerrayResult<()> {
532        for &val in data {
533            writer.write_all(&[val as u8])?;
534        }
535        Ok(())
536    }
537
538    fn read_vec<R: Read>(
539        reader: &mut R,
540        count: usize,
541        _endian: Endianness,
542    ) -> FerrayResult<Vec<bool>> {
543        let mut result = Vec::with_capacity(count);
544        let mut buf = [0u8; 1];
545        for _ in 0..count {
546            reader.read_exact(&mut buf)?;
547            result.push(buf[0] != 0);
548        }
549        Ok(result)
550    }
551}
552
// Instantiate `NpyElement` for every primitive numeric type. The second
// argument is the element width in bytes.
impl_npy_element!(u8, 1);
impl_npy_element!(u16, 2);
impl_npy_element!(u32, 4);
impl_npy_element!(u64, 8);
impl_npy_element!(u128, 16);
impl_npy_element!(i8, 1);
impl_npy_element!(i16, 2);
impl_npy_element!(i32, 4);
impl_npy_element!(i64, 8);
impl_npy_element!(i128, 16);
impl_npy_element!(f32, 4);
impl_npy_element!(f64, 8);

// Half-precision floats are only available behind their cargo features.
#[cfg(feature = "f16")]
impl_npy_element!(half::f16, 2);
#[cfg(feature = "bf16")]
impl_npy_element!(half::bf16, 2);
570
571#[cfg(test)]
572mod tests {
573    use super::*;
574    use ferray_core::dimension::{Ix1, Ix2};
575    use std::io::Cursor;
576
    /// Return a per-process scratch directory under the system temp dir,
    /// creating it if it does not exist yet (creation errors are ignored).
    ///
    /// Nothing removes the directory itself: the returned value is a plain
    /// `PathBuf`, so tests are responsible for deleting the files they create.
    fn test_dir() -> std::path::PathBuf {
        let dir = std::env::temp_dir().join(format!("ferray_io_test_{}", std::process::id()));
        let _ = std::fs::create_dir_all(&dir);
        dir
    }
583
584    fn test_file(name: &str) -> std::path::PathBuf {
585        let dir = test_dir();
586        dir.join(name)
587    }
588
589    #[test]
590    fn roundtrip_f64_1d() {
591        let data = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0];
592        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([5]), data.clone()).unwrap();
593
594        let path = test_file("rt_f64_1d.npy");
595        save(&path, &arr).unwrap();
596        let loaded: Array<f64, Ix1> = load(&path).unwrap();
597
598        assert_eq!(loaded.shape(), &[5]);
599        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
600        let _ = std::fs::remove_file(&path);
601    }
602
603    #[test]
604    fn roundtrip_f32_2d() {
605        let data = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
606        let arr = Array::<f32, Ix2>::from_vec(Ix2::new([2, 3]), data.clone()).unwrap();
607
608        let path = test_file("rt_f32_2d.npy");
609        save(&path, &arr).unwrap();
610        let loaded: Array<f32, Ix2> = load(&path).unwrap();
611
612        assert_eq!(loaded.shape(), &[2, 3]);
613        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
614        let _ = std::fs::remove_file(&path);
615    }
616
617    #[test]
618    fn roundtrip_i32() {
619        let data = vec![10i32, 20, 30, 40];
620        let arr = Array::<i32, Ix1>::from_vec(Ix1::new([4]), data.clone()).unwrap();
621
622        let path = test_file("rt_i32.npy");
623        save(&path, &arr).unwrap();
624        let loaded: Array<i32, Ix1> = load(&path).unwrap();
625        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
626        let _ = std::fs::remove_file(&path);
627    }
628
629    #[test]
630    fn roundtrip_i64() {
631        let data = vec![100i64, 200, 300];
632        let arr = Array::<i64, Ix1>::from_vec(Ix1::new([3]), data.clone()).unwrap();
633
634        let path = test_file("rt_i64.npy");
635        save(&path, &arr).unwrap();
636        let loaded: Array<i64, Ix1> = load(&path).unwrap();
637        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
638        let _ = std::fs::remove_file(&path);
639    }
640
641    #[test]
642    fn roundtrip_u8() {
643        let data = vec![0u8, 128, 255];
644        let arr = Array::<u8, Ix1>::from_vec(Ix1::new([3]), data.clone()).unwrap();
645
646        let path = test_file("rt_u8.npy");
647        save(&path, &arr).unwrap();
648        let loaded: Array<u8, Ix1> = load(&path).unwrap();
649        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
650        let _ = std::fs::remove_file(&path);
651    }
652
653    #[test]
654    fn roundtrip_bool() {
655        let data = vec![true, false, true, true, false];
656        let arr = Array::<bool, Ix1>::from_vec(Ix1::new([5]), data.clone()).unwrap();
657
658        let path = test_file("rt_bool.npy");
659        save(&path, &arr).unwrap();
660        let loaded: Array<bool, Ix1> = load(&path).unwrap();
661        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
662        let _ = std::fs::remove_file(&path);
663    }
664
665    #[test]
666    fn roundtrip_in_memory() {
667        let data = vec![1.0_f64, 2.0, 3.0];
668        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), data.clone()).unwrap();
669
670        let mut buf = Vec::new();
671        save_to_writer(&mut buf, &arr).unwrap();
672
673        let mut cursor = Cursor::new(buf);
674        let loaded: Array<f64, Ix1> = load_from_reader(&mut cursor).unwrap();
675        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
676    }
677
678    #[test]
679    fn load_dynamic_f64() {
680        let data = vec![1.0_f64, 2.0, 3.0];
681        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), data).unwrap();
682
683        let path = test_file("dyn_f64.npy");
684        save(&path, &arr).unwrap();
685        let dyn_arr = load_dynamic(&path).unwrap();
686
687        assert_eq!(dyn_arr.dtype(), DType::F64);
688        assert_eq!(dyn_arr.shape(), &[3]);
689        let _ = std::fs::remove_file(&path);
690    }
691
692    #[test]
693    fn load_wrong_dtype_error() {
694        let data = vec![1.0_f64, 2.0, 3.0];
695        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), data).unwrap();
696
697        let path = test_file("wrong_dtype.npy");
698        save(&path, &arr).unwrap();
699
700        let result = load::<f32, Ix1, _>(&path);
701        assert!(result.is_err());
702        let _ = std::fs::remove_file(&path);
703    }
704
705    #[test]
706    fn load_wrong_ndim_error() {
707        let data = vec![1.0_f64, 2.0, 3.0];
708        let arr = Array::<f64, Ix1>::from_vec(Ix1::new([3]), data).unwrap();
709
710        let path = test_file("wrong_ndim.npy");
711        save(&path, &arr).unwrap();
712
713        let result = load::<f64, Ix2, _>(&path);
714        assert!(result.is_err());
715        let _ = std::fs::remove_file(&path);
716    }
717
718    #[test]
719    fn roundtrip_dynamic() {
720        let data = vec![10i32, 20, 30];
721        let arr = Array::<i32, IxDyn>::from_vec(IxDyn::new(&[3]), data.clone()).unwrap();
722        let dyn_arr = DynArray::I32(arr);
723
724        let path = test_file("rt_dynamic.npy");
725        save_dynamic(&path, &dyn_arr).unwrap();
726
727        let loaded = load_dynamic(&path).unwrap();
728        assert_eq!(loaded.dtype(), DType::I32);
729        assert_eq!(loaded.shape(), &[3]);
730
731        let loaded_arr = loaded.try_into_i32().unwrap();
732        assert_eq!(loaded_arr.as_slice().unwrap(), &data[..]);
733        let _ = std::fs::remove_file(&path);
734    }
735
736    #[test]
737    fn load_dynamic_ixdyn() {
738        let data = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
739        let arr = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), data.clone()).unwrap();
740
741        let path = test_file("dyn_ixdyn.npy");
742        save(&path, &arr).unwrap();
743
744        // Load as IxDyn
745        let loaded: Array<f64, IxDyn> = load(&path).unwrap();
746        assert_eq!(loaded.shape(), &[2, 3]);
747        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
748        let _ = std::fs::remove_file(&path);
749    }
750
751    #[test]
752    fn load_fortran_order_npy() {
753        // Manually construct a .npy file with fortran_order=True.
754        // Data for a 2x3 f64 array in column-major order:
755        //   logical layout: [[1, 2, 3], [4, 5, 6]]
756        //   Fortran storage: [1, 4, 2, 5, 3, 6] (columns first)
757        let mut buf = Vec::new();
758        // Write header manually
759        let header_str = "{'descr': '<f8', 'fortran_order': True, 'shape': (2, 3), }";
760        let header_len = header_str.len();
761        // Pad to 64-byte alignment (magic=6 + version=2 + hdr_len=2 + header)
762        let total_before_pad = 6 + 2 + 2 + header_len;
763        let padding = 64 - (total_before_pad % 64);
764        let padded_header_len = header_len + padding;
765
766        // Magic
767        buf.extend_from_slice(b"\x93NUMPY");
768        // Version 1.0
769        buf.push(1);
770        buf.push(0);
771        // Header length (little-endian u16)
772        buf.extend_from_slice(&(padded_header_len as u16).to_le_bytes());
773        // Header string
774        buf.extend_from_slice(header_str.as_bytes());
775        // Padding (spaces + newline)
776        buf.extend(std::iter::repeat_n(b' ', padding - 1));
777        buf.push(b'\n');
778
779        // Data: 6 f64 values in Fortran (column-major) order
780        // Logical: [[1, 2, 3], [4, 5, 6]]
781        // Fortran storage: col0=[1,4], col1=[2,5], col2=[3,6]
782        for &v in &[1.0_f64, 4.0, 2.0, 5.0, 3.0, 6.0] {
783            buf.extend_from_slice(&v.to_le_bytes());
784        }
785
786        let mut cursor = Cursor::new(buf);
787        let loaded: Array<f64, Ix2> = load_from_reader(&mut cursor).unwrap();
788        assert_eq!(loaded.shape(), &[2, 3]);
789        // The logical data should be [[1,2,3],[4,5,6]] regardless of storage order
790        let flat: Vec<f64> = loaded.iter().copied().collect();
791        assert_eq!(flat, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
792    }
793
794    #[test]
795    fn roundtrip_from_vec_f() {
796        // Create a Fortran-order array, save, and reload
797        let data = vec![1.0_f64, 4.0, 2.0, 5.0, 3.0, 6.0];
798        let arr = Array::<f64, Ix2>::from_vec_f(Ix2::new([2, 3]), data).unwrap();
799        assert_eq!(arr.shape(), &[2, 3]);
800
801        // Save (will iterate in logical order for non-contiguous)
802        let mut buf = Vec::new();
803        save_to_writer(&mut buf, &arr).unwrap();
804
805        let mut cursor = Cursor::new(buf);
806        let loaded: Array<f64, Ix2> = load_from_reader(&mut cursor).unwrap();
807        assert_eq!(loaded.shape(), &[2, 3]);
808        // Compare logical element order
809        let orig: Vec<f64> = arr.iter().copied().collect();
810        let back: Vec<f64> = loaded.iter().copied().collect();
811        assert_eq!(orig, back);
812    }
813
814    // --- Malformed .npy file tests ---
815
816    #[test]
817    fn malformed_bad_magic() {
818        let data = b"NOT_NPY_FILE_DATA_HERE";
819        let mut cursor = Cursor::new(data.to_vec());
820        let result = load_from_reader::<f64, Ix1, _>(&mut cursor);
821        assert!(result.is_err());
822        let msg = result.unwrap_err().to_string();
823        assert!(
824            msg.contains("magic") || msg.contains("not a valid"),
825            "got: {msg}"
826        );
827    }
828
829    #[test]
830    fn malformed_truncated_header() {
831        // Valid magic + version but truncated before header length
832        let mut data = Vec::new();
833        data.extend_from_slice(b"\x93NUMPY");
834        data.push(1); // version 1
835        data.push(0);
836        // Missing header length bytes
837        let mut cursor = Cursor::new(data);
838        let result = load_from_reader::<f64, Ix1, _>(&mut cursor);
839        assert!(result.is_err());
840    }
841
842    #[test]
843    fn malformed_truncated_data() {
844        // Valid header but not enough data bytes
845        let mut buf = Vec::new();
846        let header_str = "{'descr': '<f8', 'fortran_order': False, 'shape': (100,), }";
847        let header_len = header_str.len();
848        let total = 6 + 2 + 2 + header_len;
849        let padding = 64 - (total % 64);
850        let padded_len = header_len + padding;
851
852        buf.extend_from_slice(b"\x93NUMPY");
853        buf.push(1);
854        buf.push(0);
855        buf.extend_from_slice(&(padded_len as u16).to_le_bytes());
856        buf.extend_from_slice(header_str.as_bytes());
857        buf.extend(std::iter::repeat_n(b' ', padding - 1));
858        buf.push(b'\n');
859        // Only write 3 f64 values instead of 100
860        for &v in &[1.0_f64, 2.0, 3.0] {
861            buf.extend_from_slice(&v.to_le_bytes());
862        }
863
864        let mut cursor = Cursor::new(buf);
865        let result = load_from_reader::<f64, Ix1, _>(&mut cursor);
866        assert!(result.is_err(), "should fail with truncated data");
867    }
868
869    #[test]
870    fn malformed_unsupported_version() {
871        let mut data = Vec::new();
872        data.extend_from_slice(b"\x93NUMPY");
873        data.push(9); // version 9.0 — unsupported
874        data.push(0);
875        data.extend_from_slice(&[10, 0]); // header length
876        data.extend_from_slice(b"0123456789"); // dummy header
877        let mut cursor = Cursor::new(data);
878        let result = load_from_reader::<f64, Ix1, _>(&mut cursor);
879        assert!(result.is_err());
880        let msg = result.unwrap_err().to_string();
881        assert!(msg.contains("version"), "got: {msg}");
882    }
883
884    #[test]
885    fn malformed_empty_file() {
886        let cursor = Cursor::new(Vec::<u8>::new());
887        let result = load_from_reader::<f64, Ix1, _>(&mut cursor.clone());
888        assert!(result.is_err());
889    }
890
891    #[test]
892    fn load_big_endian_f64() {
893        // Construct a .npy file with big-endian f64 data ('>f8')
894        let mut buf = Vec::new();
895        let header_str = "{'descr': '>f8', 'fortran_order': False, 'shape': (3,), }";
896        let header_len = header_str.len();
897        let total = 6 + 2 + 2 + header_len;
898        let padding = 64 - (total % 64);
899        let padded_len = header_len + padding;
900
901        buf.extend_from_slice(b"\x93NUMPY");
902        buf.push(1);
903        buf.push(0);
904        buf.extend_from_slice(&(padded_len as u16).to_le_bytes());
905        buf.extend_from_slice(header_str.as_bytes());
906        buf.extend(std::iter::repeat_n(b' ', padding - 1));
907        buf.push(b'\n');
908
909        // Write 3 f64 values in BIG-endian byte order
910        for &v in &[1.0_f64, 2.5, -4.75] {
911            buf.extend_from_slice(&v.to_be_bytes());
912        }
913
914        let mut cursor = Cursor::new(buf);
915        let loaded: Array<f64, Ix1> = load_from_reader(&mut cursor).unwrap();
916        assert_eq!(loaded.shape(), &[3]);
917        let data = loaded.as_slice().unwrap();
918        assert!((data[0] - 1.0).abs() < 1e-15);
919        assert!((data[1] - 2.5).abs() < 1e-15);
920        assert!((data[2] - (-4.75)).abs() < 1e-15);
921    }
922
923    #[test]
924    fn load_big_endian_i32() {
925        // Big-endian i32 ('>i4')
926        let mut buf = Vec::new();
927        let header_str = "{'descr': '>i4', 'fortran_order': False, 'shape': (4,), }";
928        let header_len = header_str.len();
929        let total = 6 + 2 + 2 + header_len;
930        let padding = 64 - (total % 64);
931        let padded_len = header_len + padding;
932
933        buf.extend_from_slice(b"\x93NUMPY");
934        buf.push(1);
935        buf.push(0);
936        buf.extend_from_slice(&(padded_len as u16).to_le_bytes());
937        buf.extend_from_slice(header_str.as_bytes());
938        buf.extend(std::iter::repeat_n(b' ', padding - 1));
939        buf.push(b'\n');
940
941        for &v in &[1_i32, -2, 1000, i32::MAX] {
942            buf.extend_from_slice(&v.to_be_bytes());
943        }
944
945        let mut cursor = Cursor::new(buf);
946        let loaded: Array<i32, Ix1> = load_from_reader(&mut cursor).unwrap();
947        assert_eq!(loaded.shape(), &[4]);
948        let data = loaded.as_slice().unwrap();
949        assert_eq!(data, &[1, -2, 1000, i32::MAX]);
950    }
951
952    // -----------------------------------------------------------------------
953    // f16 / bf16 round-trip tests (issue #118)
954    // -----------------------------------------------------------------------
955
956    #[cfg(feature = "f16")]
957    #[test]
958    fn roundtrip_f16_1d() {
959        use half::f16;
960        let data: Vec<f16> = [0.0, 1.0, -1.5, 2.25, 3.5, -0.125]
961            .iter()
962            .map(|&v: &f32| f16::from_f32(v))
963            .collect();
964        let arr = Array::<f16, Ix1>::from_vec(Ix1::new([6]), data.clone()).unwrap();
965
966        let path = test_file("rt_f16_1d.npy");
967        save(&path, &arr).unwrap();
968        let loaded: Array<f16, Ix1> = load(&path).unwrap();
969        assert_eq!(loaded.shape(), &[6]);
970        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
971        let _ = std::fs::remove_file(&path);
972    }
973
974    #[cfg(feature = "f16")]
975    #[test]
976    fn roundtrip_f16_2d() {
977        use half::f16;
978        let data: Vec<f16> = (0..12)
979            .map(|i| f16::from_f32(i as f32 * 0.25 - 1.0))
980            .collect();
981        let arr = Array::<f16, Ix2>::from_vec(Ix2::new([3, 4]), data.clone()).unwrap();
982
983        let path = test_file("rt_f16_2d.npy");
984        save(&path, &arr).unwrap();
985        let loaded: Array<f16, Ix2> = load(&path).unwrap();
986        assert_eq!(loaded.shape(), &[3, 4]);
987        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
988        let _ = std::fs::remove_file(&path);
989    }
990
991    #[cfg(feature = "f16")]
992    #[test]
993    fn roundtrip_f16_dynamic() {
994        use half::f16;
995        let data: Vec<f16> = (0..8).map(|i| f16::from_f32(i as f32)).collect();
996        let arr = Array::<f16, IxDyn>::from_vec(IxDyn::new(&[2, 4]), data.clone()).unwrap();
997        let dyn_in = DynArray::F16(arr);
998
999        let path = test_file("rt_f16_dyn.npy");
1000        save_dynamic(&path, &dyn_in).unwrap();
1001        let loaded = load_dynamic(&path).unwrap();
1002        assert_eq!(loaded.dtype(), DType::F16);
1003        assert_eq!(loaded.shape(), &[2, 4]);
1004        match loaded {
1005            DynArray::F16(a) => assert_eq!(a.as_slice().unwrap(), &data[..]),
1006            _ => panic!("expected F16 variant"),
1007        }
1008        let _ = std::fs::remove_file(&path);
1009    }
1010
1011    #[cfg(feature = "f16")]
1012    #[test]
1013    fn f16_descriptor_is_f2() {
1014        use half::f16;
1015        let arr = Array::<f16, Ix1>::from_vec(Ix1::new([2]), vec![f16::ZERO, f16::ONE]).unwrap();
1016        let mut buf = Vec::new();
1017        save_to_writer(&mut buf, &arr).unwrap();
1018        // The header must contain the standard NumPy `f2` dtype descriptor
1019        // so files are interoperable with vanilla NumPy. Inspect only the
1020        // header region — trailing bytes are raw data, not text.
1021        let header_len = buf.len().saturating_sub(4); // strip the 4 data bytes
1022        let header = String::from_utf8_lossy(&buf[..header_len]);
1023        assert!(
1024            header.contains("f2"),
1025            "expected 'f2' in header, got: {header}"
1026        );
1027    }
1028
1029    #[cfg(feature = "bf16")]
1030    #[test]
1031    fn roundtrip_bf16_1d() {
1032        use half::bf16;
1033        let data: Vec<bf16> = [0.0, 1.0, -1.5, 2.25, 3.5, -0.125]
1034            .iter()
1035            .map(|&v: &f32| bf16::from_f32(v))
1036            .collect();
1037        let arr = Array::<bf16, Ix1>::from_vec(Ix1::new([6]), data.clone()).unwrap();
1038
1039        let path = test_file("rt_bf16_1d.npy");
1040        save(&path, &arr).unwrap();
1041        let loaded: Array<bf16, Ix1> = load(&path).unwrap();
1042        assert_eq!(loaded.shape(), &[6]);
1043        assert_eq!(loaded.as_slice().unwrap(), &data[..]);
1044        let _ = std::fs::remove_file(&path);
1045    }
1046
1047    #[cfg(feature = "bf16")]
1048    #[test]
1049    fn roundtrip_bf16_dynamic() {
1050        use half::bf16;
1051        let data: Vec<bf16> = (0..6).map(|i| bf16::from_f32(i as f32 * 0.5)).collect();
1052        let arr = Array::<bf16, IxDyn>::from_vec(IxDyn::new(&[2, 3]), data.clone()).unwrap();
1053        let dyn_in = DynArray::BF16(arr);
1054
1055        let path = test_file("rt_bf16_dyn.npy");
1056        save_dynamic(&path, &dyn_in).unwrap();
1057        let loaded = load_dynamic(&path).unwrap();
1058        assert_eq!(loaded.dtype(), DType::BF16);
1059        assert_eq!(loaded.shape(), &[2, 3]);
1060        match loaded {
1061            DynArray::BF16(a) => assert_eq!(a.as_slice().unwrap(), &data[..]),
1062            _ => panic!("expected BF16 variant"),
1063        }
1064        let _ = std::fs::remove_file(&path);
1065    }
1066
1067    #[cfg(feature = "bf16")]
1068    #[test]
1069    fn bf16_descriptor_is_bf16_tag() {
1070        use half::bf16;
1071        let arr = Array::<bf16, Ix1>::from_vec(Ix1::new([2]), vec![bf16::ZERO, bf16::ONE]).unwrap();
1072        let mut buf = Vec::new();
1073        save_to_writer(&mut buf, &arr).unwrap();
1074        // ferray-specific non-NumPy tag for bfloat16.
1075        let header_len = buf.len().saturating_sub(4); // strip the 4 data bytes
1076        let header = String::from_utf8_lossy(&buf[..header_len]);
1077        assert!(
1078            header.contains("bf16"),
1079            "expected 'bf16' in header, got: {header}"
1080        );
1081    }
1082}