pub mod dtype_parse;
pub mod header;
use std::fs::File;
use std::io::{BufReader, BufWriter, Read, Write};
use std::path::Path;
use ferray_core::Array;
use ferray_core::dimension::{Dimension, IxDyn};
use ferray_core::dtype::{DType, Element};
use ferray_core::dynarray::DynArray;
use ferray_core::error::{FerrayError, FerrayResult};
use self::dtype_parse::Endianness;
/// Multiplies all extents in `shape` and returns the total element count.
///
/// An empty shape yields `1` (a zero-dimensional array holds one element).
///
/// # Errors
///
/// Returns an I/O error as soon as the running product no longer fits in a
/// `usize`.
pub(crate) fn checked_total_elements(shape: &[usize]) -> FerrayResult<usize> {
    let mut total = 1usize;
    for &extent in shape {
        total = match total.checked_mul(extent) {
            Some(product) => product,
            None => {
                return Err(FerrayError::io_error(
                    "shape overflow: total elements exceed usize::MAX",
                ));
            }
        };
    }
    Ok(total)
}
pub fn save<T: Element + NpyElement, D: Dimension, P: AsRef<Path>>(
path: P,
array: &Array<T, D>,
) -> FerrayResult<()> {
let file = File::create(path.as_ref()).map_err(|e| {
FerrayError::io_error(format!(
"failed to create file '{}': {e}",
path.as_ref().display()
))
})?;
let mut writer = BufWriter::new(file);
save_to_writer(&mut writer, array)
}
/// Serializes `array` in `.npy` form (header followed by raw element data)
/// into `writer`, then flushes the writer.
///
/// The header always declares `fortran_order = false`, and the element data is
/// taken from `array.as_slice()`.
/// NOTE(review): this assumes `as_slice()` only returns `Some` for row-major
/// (C-order) storage — confirm against `ferray_core::Array`.
///
/// # Errors
///
/// * An I/O error when the array is not contiguous. This is detected *before*
///   anything is written, so the writer is left untouched in that case.
/// * Any error from writing the header, the element data, or the final flush.
pub fn save_to_writer<T: Element + NpyElement, D: Dimension, W: Write>(
    writer: &mut W,
    array: &Array<T, D>,
) -> FerrayResult<()> {
    // Resolve the contiguous view first: rejecting the array after the header
    // has been emitted would leave a truncated, header-only stream behind.
    let slice = array.as_slice().ok_or_else(|| {
        FerrayError::io_error("cannot save non-contiguous array to .npy (make contiguous first)")
    })?;

    let fortran_order = false;
    header::write_header(writer, T::dtype(), array.shape(), fortran_order)?;
    T::write_slice(slice, writer)?;
    writer.flush()?;
    Ok(())
}
/// Loads a statically typed array from the `.npy` file at `path`.
///
/// The file is opened, wrapped in a [`BufReader`], and parsed by
/// [`load_from_reader`].
///
/// # Errors
///
/// Returns an I/O error naming the path when the file cannot be opened, and
/// forwards any error reported by [`load_from_reader`].
pub fn load<T: Element + NpyElement, D: Dimension, P: AsRef<Path>>(
    path: P,
) -> FerrayResult<Array<T, D>> {
    let source = path.as_ref();
    let mut input = match File::open(source) {
        Ok(handle) => BufReader::new(handle),
        Err(e) => {
            return Err(FerrayError::io_error(format!(
                "failed to open file '{}': {e}",
                source.display()
            )));
        }
    };
    load_from_reader(&mut input)
}
/// Reads a `.npy` stream from `reader` into an `Array<T, D>`.
///
/// Steps, in order: parse the header, verify that the stored dtype equals
/// `T::dtype()`, verify the rank against `D::NDIM` (when `D` has a fixed rank),
/// read the element data, build the dimension value, and finally construct the
/// array in the memory order announced by the header.
///
/// # Errors
///
/// * an invalid-dtype error when the file's dtype differs from `T::dtype()`;
/// * a shape-mismatch error when `D` has a fixed rank that differs from the
///   file's rank;
/// * any error from header parsing, element-count computation, data reading,
///   dimension construction, or array construction.
pub fn load_from_reader<T: Element + NpyElement, D: Dimension, R: Read>(
    reader: &mut R,
) -> FerrayResult<Array<T, D>> {
    let hdr = header::read_header(reader)?;

    let wanted = T::dtype();
    if hdr.dtype != wanted {
        return Err(FerrayError::invalid_dtype(format!(
            "expected dtype {:?} for type {}, but file has {:?}",
            wanted,
            std::any::type_name::<T>(),
            hdr.dtype,
        )));
    }

    let file_rank = hdr.shape.len();
    match D::NDIM {
        Some(ndim) if ndim != file_rank => {
            return Err(FerrayError::shape_mismatch(format!(
                "expected {} dimensions, but file has {} (shape {:?})",
                ndim, file_rank, hdr.shape,
            )));
        }
        _ => {}
    }

    let element_count = checked_total_elements(&hdr.shape)?;
    let values = T::read_vec(reader, element_count, hdr.endianness)?;
    let dim = build_dimension::<D>(&hdr.shape)?;

    if !hdr.fortran_order {
        Array::from_vec(dim, values)
    } else {
        Array::from_vec_f(dim, values)
    }
}
/// Loads the `.npy` file at `path` as a [`DynArray`], taking the element type
/// from the file header instead of a type parameter.
///
/// # Errors
///
/// Returns an I/O error naming the path when the file cannot be opened, and
/// forwards any error reported by [`load_dynamic_from_reader`].
pub fn load_dynamic<P: AsRef<Path>>(path: P) -> FerrayResult<DynArray> {
    let source = path.as_ref();
    let mut input = match File::open(source) {
        Ok(handle) => BufReader::new(handle),
        Err(e) => {
            return Err(FerrayError::io_error(format!(
                "failed to open file '{}': {e}",
                source.display()
            )));
        }
    };
    load_dynamic_from_reader(&mut input)
}
pub fn load_dynamic_from_reader<R: Read>(reader: &mut R) -> FerrayResult<DynArray> {
let hdr = header::read_header(reader)?;
let total = checked_total_elements(&hdr.shape)?;
let dim = IxDyn::new(&hdr.shape);
macro_rules! load_typed {
($ty:ty, $variant:ident) => {{
let data = <$ty as NpyElement>::read_vec(reader, total, hdr.endianness)?;
let arr = if hdr.fortran_order {
Array::<$ty, IxDyn>::from_vec_f(dim, data)?
} else {
Array::<$ty, IxDyn>::from_vec(dim, data)?
};
Ok(DynArray::$variant(arr))
}};
}
match hdr.dtype {
DType::Bool => load_typed!(bool, Bool),
DType::U8 => load_typed!(u8, U8),
DType::U16 => load_typed!(u16, U16),
DType::U32 => load_typed!(u32, U32),
DType::U64 => load_typed!(u64, U64),
DType::U128 => load_typed!(u128, U128),
DType::I8 => load_typed!(i8, I8),
DType::I16 => load_typed!(i16, I16),
DType::I32 => load_typed!(i32, I32),
DType::I64 => load_typed!(i64, I64),
DType::I128 => load_typed!(i128, I128),
DType::F32 => load_typed!(f32, F32),
DType::F64 => load_typed!(f64, F64),
DType::Complex32 => {
load_complex32_dynamic(reader, total, dim, hdr.fortran_order, hdr.endianness)
}
DType::Complex64 => {
load_complex64_dynamic(reader, total, dim, hdr.fortran_order, hdr.endianness)
}
_ => Err(FerrayError::invalid_dtype(format!(
"unsupported dtype {:?} for .npy loading",
hdr.dtype
))),
}
}
fn load_complex32_dynamic<R: Read>(
reader: &mut R,
total: usize,
dim: IxDyn,
fortran_order: bool,
endian: Endianness,
) -> FerrayResult<DynArray> {
let byte_count = total * 8;
let mut raw = vec![0u8; byte_count];
reader.read_exact(&mut raw)?;
if endian.needs_swap() {
for chunk in raw.chunks_exact_mut(4) {
chunk.reverse();
}
}
load_complex32_from_bytes_copy(&raw, total, dim, fortran_order)
}
/// Decodes native-endian `Complex<f32>` values from `bytes` (8 bytes per
/// element: real part, then imaginary part) and builds a
/// `DynArray::Complex32` with shape `dim`.
///
/// Trailing bytes that do not form a whole element are ignored by
/// `chunks_exact`. `total` is only a capacity hint: the preallocation is
/// capped by the number of elements actually present in `bytes`, so an
/// inconsistent `total` cannot trigger an oversized allocation.
///
/// # Errors
///
/// Forwards the error from `Array::from_vec` / `Array::from_vec_f` (for
/// example when the decoded element count does not fit `dim`).
fn load_complex32_from_bytes_copy(
    bytes: &[u8],
    total: usize,
    dim: IxDyn,
    fortran_order: bool,
) -> FerrayResult<DynArray> {
    use num_complex::Complex;

    let available = bytes.len() / 8;
    let mut data = Vec::with_capacity(total.min(available));
    data.extend(bytes.chunks_exact(8).map(|chunk| {
        let re = f32::from_ne_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]);
        let im = f32::from_ne_bytes([chunk[4], chunk[5], chunk[6], chunk[7]]);
        Complex::new(re, im)
    }));

    let arr = if fortran_order {
        Array::<Complex<f32>, IxDyn>::from_vec_f(dim, data)?
    } else {
        Array::<Complex<f32>, IxDyn>::from_vec(dim, data)?
    };
    Ok(DynArray::Complex32(arr))
}
fn load_complex64_dynamic<R: Read>(
reader: &mut R,
total: usize,
dim: IxDyn,
fortran_order: bool,
endian: Endianness,
) -> FerrayResult<DynArray> {
let byte_count = total * 16;
let mut raw = vec![0u8; byte_count];
reader.read_exact(&mut raw)?;
if endian.needs_swap() {
for chunk in raw.chunks_exact_mut(8) {
chunk.reverse();
}
}
load_complex64_from_bytes_copy(&raw, total, dim, fortran_order)
}
/// Decodes native-endian `Complex<f64>` values from `bytes` (16 bytes per
/// element: real part, then imaginary part) and builds a
/// `DynArray::Complex64` with shape `dim`.
///
/// Trailing bytes that do not form a whole element are ignored by
/// `chunks_exact`. `total` is only a capacity hint: the preallocation is
/// capped by the number of elements actually present in `bytes`, so an
/// inconsistent `total` cannot trigger an oversized allocation.
///
/// # Errors
///
/// Forwards the error from `Array::from_vec` / `Array::from_vec_f` (for
/// example when the decoded element count does not fit `dim`).
fn load_complex64_from_bytes_copy(
    bytes: &[u8],
    total: usize,
    dim: IxDyn,
    fortran_order: bool,
) -> FerrayResult<DynArray> {
    use num_complex::Complex;

    let available = bytes.len() / 16;
    let mut data = Vec::with_capacity(total.min(available));
    data.extend(bytes.chunks_exact(16).map(|chunk| {
        let re = f64::from_ne_bytes([
            chunk[0], chunk[1], chunk[2], chunk[3], chunk[4], chunk[5], chunk[6], chunk[7],
        ]);
        let im = f64::from_ne_bytes([
            chunk[8], chunk[9], chunk[10], chunk[11], chunk[12], chunk[13], chunk[14], chunk[15],
        ]);
        Complex::new(re, im)
    }));

    let arr = if fortran_order {
        Array::<Complex<f64>, IxDyn>::from_vec_f(dim, data)?
    } else {
        Array::<Complex<f64>, IxDyn>::from_vec(dim, data)?
    };
    Ok(DynArray::Complex64(arr))
}
/// Saves a [`DynArray`] to a `.npy` file at `path`.
///
/// The file is created with [`File::create`], wrapped in a [`BufWriter`], and
/// the serialization is delegated to [`save_dynamic_to_writer`].
///
/// # Errors
///
/// Returns an I/O error naming the path when the file cannot be created, and
/// forwards any error reported by [`save_dynamic_to_writer`].
pub fn save_dynamic<P: AsRef<Path>>(path: P, array: &DynArray) -> FerrayResult<()> {
    let target = path.as_ref();
    let handle = match File::create(target) {
        Ok(handle) => handle,
        Err(e) => {
            return Err(FerrayError::io_error(format!(
                "failed to create file '{}': {e}",
                target.display()
            )));
        }
    };
    let mut sink = BufWriter::new(handle);
    save_dynamic_to_writer(&mut sink, array)
}
/// Serializes a [`DynArray`] in `.npy` form (header followed by raw element
/// data) into `writer`, then flushes the writer.
///
/// The header always declares `fortran_order = false`.
///
/// # Errors
///
/// * An I/O error when the wrapped array is not contiguous. Contiguity is
///   checked *before* the header is emitted, so the writer is left untouched
///   in that case.
/// * An invalid-dtype error for variants without a saver here.
/// * Any error from writing the header, the element data, or the final flush.
pub fn save_dynamic_to_writer<W: Write>(writer: &mut W, array: &DynArray) -> FerrayResult<()> {
    // Writes one primitive-typed array: resolve the contiguous slice first so a
    // non-contiguous array is rejected without producing a header-only stream.
    macro_rules! save_typed {
        ($arr:expr, $dtype:expr, $ty:ty) => {{
            let s = $arr.as_slice().ok_or_else(|| {
                FerrayError::io_error("cannot save non-contiguous DynArray to .npy")
            })?;
            header::write_header(writer, $dtype, $arr.shape(), false)?;
            <$ty as NpyElement>::write_slice(s, writer)?;
        }};
    }

    match array {
        DynArray::Bool(a) => save_typed!(a, DType::Bool, bool),
        DynArray::U8(a) => save_typed!(a, DType::U8, u8),
        DynArray::U16(a) => save_typed!(a, DType::U16, u16),
        DynArray::U32(a) => save_typed!(a, DType::U32, u32),
        DynArray::U64(a) => save_typed!(a, DType::U64, u64),
        DynArray::U128(a) => save_typed!(a, DType::U128, u128),
        DynArray::I8(a) => save_typed!(a, DType::I8, i8),
        DynArray::I16(a) => save_typed!(a, DType::I16, i16),
        DynArray::I32(a) => save_typed!(a, DType::I32, i32),
        DynArray::I64(a) => save_typed!(a, DType::I64, i64),
        DynArray::I128(a) => save_typed!(a, DType::I128, i128),
        DynArray::F32(a) => save_typed!(a, DType::F32, f32),
        DynArray::F64(a) => save_typed!(a, DType::F64, f64),
        DynArray::Complex32(a) => {
            // Same ordering as the typed path: validate, then header, then data.
            let s = a.as_slice().ok_or_else(|| {
                FerrayError::io_error("cannot save non-contiguous complex array")
            })?;
            header::write_header(writer, DType::Complex32, a.shape(), false)?;
            save_complex_raw(Some(s), 8, writer)?;
        }
        DynArray::Complex64(a) => {
            let s = a.as_slice().ok_or_else(|| {
                FerrayError::io_error("cannot save non-contiguous complex array")
            })?;
            header::write_header(writer, DType::Complex64, a.shape(), false)?;
            save_complex_raw(Some(s), 16, writer)?;
        }
        _ => {
            return Err(FerrayError::invalid_dtype(
                "unsupported DynArray variant for .npy saving",
            ));
        }
    }
    writer.flush()?;
    Ok(())
}
/// Writes the in-memory bytes of `slice_opt` to `writer` without any
/// conversion (i.e. in native byte order).
///
/// `elem_size` is the element width the caller expects; it must equal
/// `size_of::<T>()`. The byte length of the raw view is derived from the slice
/// itself, never from `elem_size`, so a wrong value cannot cause an
/// out-of-bounds read.
///
/// # Errors
///
/// Returns an I/O error when `slice_opt` is `None` (non-contiguous array),
/// when `elem_size` does not match `size_of::<T>()`, or when writing fails.
fn save_complex_raw<T, W: Write>(
    slice_opt: Option<&[T]>,
    elem_size: usize,
    writer: &mut W,
) -> FerrayResult<()> {
    let slice = slice_opt
        .ok_or_else(|| FerrayError::io_error("cannot save non-contiguous complex array"))?;

    let actual_size = std::mem::size_of::<T>();
    if elem_size != actual_size {
        return Err(FerrayError::io_error(format!(
            "complex element size mismatch: caller expects {} bytes, element type has {}",
            elem_size, actual_size
        )));
    }

    let byte_len = std::mem::size_of_val(slice);
    // SAFETY: the pointer and `byte_len` describe exactly the memory occupied
    // by `slice`, which stays borrowed for the lifetime of `bytes`; `u8` has
    // alignment 1 so any address is suitably aligned. The view is only sound
    // for element types without padding bytes — the callers in this file pass
    // `Complex<f32>` / `Complex<f64>` (NOTE(review): confirm these are
    // padding-free `repr(C)` pairs in the `num_complex` version in use).
    let bytes = unsafe { std::slice::from_raw_parts(slice.as_ptr().cast::<u8>(), byte_len) };
    writer.write_all(bytes)?;
    Ok(())
}
/// Builds a dimension value of type `D` from a runtime `shape`.
///
/// Thin forwarding wrapper around [`build_dim_from_shape`]; errors are passed
/// through unchanged.
fn build_dimension<D: Dimension>(shape: &[usize]) -> FerrayResult<D> {
    let dim = build_dim_from_shape::<D>(shape)?;
    Ok(dim)
}
/// Converts a runtime `shape` into the concrete dimension type `D`.
///
/// `D` is identified through its `TypeId`: `IxDyn` accepts any rank, while the
/// fixed-rank types `Ix0` through `Ix6` are only considered when `shape` has the
/// matching number of extents. The value is built as the concrete type and
/// then downcast back to `D` through `Box<dyn Any>`.
///
/// # Errors
///
/// * a shape-mismatch error when `D::NDIM` is `Some(n)` and `shape.len() != n`;
/// * an I/O error when `D` is none of the dimension types handled here.
fn build_dim_from_shape<D: Dimension>(shape: &[usize]) -> FerrayResult<D> {
    use ferray_core::dimension::*;
    use std::any::Any;

    let rank = shape.len();
    match D::NDIM {
        Some(ndim) if rank != ndim => {
            return Err(FerrayError::shape_mismatch(format!(
                "expected {ndim} dimensions, got {}",
                rank
            )));
        }
        _ => {}
    }

    let requested = std::any::TypeId::of::<D>();

    // Returns `$value` as `D` when `D` is exactly `$candidate`; otherwise falls
    // through. The downcast cannot fail because the type ids were just
    // compared for equality.
    macro_rules! yield_if_requested {
        ($candidate:ty, $value:expr) => {
            if requested == std::any::TypeId::of::<$candidate>() {
                let erased: Box<dyn Any> = Box::new($value);
                return Ok(*erased.downcast::<D>().unwrap());
            }
        };
    }

    yield_if_requested!(IxDyn, IxDyn::new(shape));

    match *shape {
        [] => {
            yield_if_requested!(Ix0, Ix0);
        }
        [a] => {
            yield_if_requested!(Ix1, Ix1::new([a]));
        }
        [a, b] => {
            yield_if_requested!(Ix2, Ix2::new([a, b]));
        }
        [a, b, c] => {
            yield_if_requested!(Ix3, Ix3::new([a, b, c]));
        }
        [a, b, c, d] => {
            yield_if_requested!(Ix4, Ix4::new([a, b, c, d]));
        }
        [a, b, c, d, e] => {
            yield_if_requested!(Ix5, Ix5::new([a, b, c, d, e]));
        }
        [a, b, c, d, e, f] => {
            yield_if_requested!(Ix6, Ix6::new([a, b, c, d, e, f]));
        }
        _ => {}
    }

    Err(FerrayError::io_error(
        "unsupported dimension type for .npy loading",
    ))
}
/// Element types whose values can be written to and read from the data
/// section of a `.npy` stream.
///
/// The `private::NpySealed` supertrait lives in a private module, so this
/// trait can only be implemented from within this module.
pub trait NpyElement: Element + private::NpySealed {
/// Writes every element of `data` to `writer`.
///
/// The implementations in this file emit each value's native-endian bytes
/// (one byte per `bool`).
fn write_slice<W: Write>(data: &[Self], writer: &mut W) -> FerrayResult<()>;
/// Reads exactly `count` elements from `reader`.
///
/// `endian` describes the byte order of the stored data; the numeric
/// implementations in this file reverse each element's bytes when
/// `endian.needs_swap()` is true, and the `bool` implementation ignores it.
fn read_vec<R: Read>(
reader: &mut R,
count: usize,
endian: Endianness,
) -> FerrayResult<Vec<Self>>;
}
// Private module holding the sealing marker for `NpyElement`: the trait is
// `pub` so it can appear as a supertrait bound, but the module is not, so code
// outside this module cannot name it and therefore cannot implement it.
mod private {
/// Marker supertrait of `NpyElement`; implemented only in this file.
pub trait NpySealed {}
}
// Implements `NpySealed` and `NpyElement` for a fixed-width numeric type.
//
// `$ty` is the element type and `$size` its width in bytes (it must equal
// `size_of::<$ty>()`, since it sizes the buffer passed to `from_ne_bytes`).
//
// * `write_slice` emits each value's native-endian bytes.
// * `read_vec` reads `count` values, reversing each value's bytes first when
//   `endian.needs_swap()` is true. `count` originates from the file header, so
//   the output buffer is reserved fallibly: an absurd count yields an error
//   instead of a capacity-overflow panic or an allocation abort.
macro_rules! impl_npy_element {
    ($ty:ty, $size:expr) => {
        impl private::NpySealed for $ty {}

        impl NpyElement for $ty {
            fn write_slice<W: Write>(data: &[$ty], writer: &mut W) -> FerrayResult<()> {
                for &val in data {
                    writer.write_all(&val.to_ne_bytes())?;
                }
                Ok(())
            }

            fn read_vec<R: Read>(
                reader: &mut R,
                count: usize,
                endian: Endianness,
            ) -> FerrayResult<Vec<$ty>> {
                let mut result: Vec<$ty> = Vec::new();
                result.try_reserve_exact(count).map_err(|e| {
                    FerrayError::io_error(format!(
                        "cannot allocate buffer for {} elements: {}",
                        count, e
                    ))
                })?;

                let mut buf = [0u8; $size];
                let needs_swap = endian.needs_swap();
                for _ in 0..count {
                    reader.read_exact(&mut buf)?;
                    if needs_swap {
                        // Convert the stored byte order to the native one.
                        buf.reverse();
                    }
                    result.push(<$ty>::from_ne_bytes(buf));
                }
                Ok(result)
            }
        }
    };
}
impl private::NpySealed for bool {}
impl NpyElement for bool {
fn write_slice<W: Write>(data: &[bool], writer: &mut W) -> FerrayResult<()> {
for &val in data {
writer.write_all(&[val as u8])?;
}
Ok(())
}
fn read_vec<R: Read>(
reader: &mut R,
count: usize,
_endian: Endianness,
) -> FerrayResult<Vec<bool>> {
let mut result = Vec::with_capacity(count);
let mut buf = [0u8; 1];
for _ in 0..count {
reader.read_exact(&mut buf)?;
result.push(buf[0] != 0);
}
Ok(result)
}
}
// `NpyElement` implementations for the fixed-width numeric types. The second
// argument is the element width in bytes and must match `size_of` of the type.
impl_npy_element!(u8, 1);
impl_npy_element!(u16, 2);
impl_npy_element!(u32, 4);
impl_npy_element!(u64, 8);
impl_npy_element!(u128, 16);
impl_npy_element!(i8, 1);
impl_npy_element!(i16, 2);
impl_npy_element!(i32, 4);
impl_npy_element!(i64, 8);
impl_npy_element!(i128, 16);
impl_npy_element!(f32, 4);
impl_npy_element!(f64, 8);
#[cfg(test)]
mod tests {
    use super::*;
    use ferray_core::dimension::{Ix1, Ix2};
    use std::io::Cursor;

    /// Scratch directory under the system temp dir, keyed by process id.
    /// Creation is best-effort; a failure surfaces later when a test writes.
    fn test_dir() -> std::path::PathBuf {
        let mut dir = std::env::temp_dir();
        dir.push(format!("ferray_io_test_{}", std::process::id()));
        let _ = std::fs::create_dir_all(&dir);
        dir
    }

    /// Path of the file `name` inside the scratch directory.
    fn test_file(name: &str) -> std::path::PathBuf {
        test_dir().join(name)
    }

    /// 1-D `f64` array survives a save/load cycle through a file.
    #[test]
    fn roundtrip_f64_1d() {
        let values = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0];
        let original = Array::<f64, Ix1>::from_vec(Ix1::new([5]), values.clone()).unwrap();
        let file = test_file("rt_f64_1d.npy");

        save(&file, &original).unwrap();
        let restored = load::<f64, Ix1, _>(&file).unwrap();

        assert_eq!(restored.shape(), &[5]);
        assert_eq!(restored.as_slice().unwrap(), &values[..]);
        let _ = std::fs::remove_file(&file);
    }

    /// 2-D `f32` array keeps both its shape and its contents.
    #[test]
    fn roundtrip_f32_2d() {
        let values = vec![1.0_f32, 2.0, 3.0, 4.0, 5.0, 6.0];
        let original = Array::<f32, Ix2>::from_vec(Ix2::new([2, 3]), values.clone()).unwrap();
        let file = test_file("rt_f32_2d.npy");

        save(&file, &original).unwrap();
        let restored = load::<f32, Ix2, _>(&file).unwrap();

        assert_eq!(restored.shape(), &[2, 3]);
        assert_eq!(restored.as_slice().unwrap(), &values[..]);
        let _ = std::fs::remove_file(&file);
    }

    /// `i32` contents survive a save/load cycle.
    #[test]
    fn roundtrip_i32() {
        let values = vec![10i32, 20, 30, 40];
        let original = Array::<i32, Ix1>::from_vec(Ix1::new([4]), values.clone()).unwrap();
        let file = test_file("rt_i32.npy");

        save(&file, &original).unwrap();
        let restored = load::<i32, Ix1, _>(&file).unwrap();

        assert_eq!(restored.as_slice().unwrap(), &values[..]);
        let _ = std::fs::remove_file(&file);
    }

    /// `i64` contents survive a save/load cycle.
    #[test]
    fn roundtrip_i64() {
        let values = vec![100i64, 200, 300];
        let original = Array::<i64, Ix1>::from_vec(Ix1::new([3]), values.clone()).unwrap();
        let file = test_file("rt_i64.npy");

        save(&file, &original).unwrap();
        let restored = load::<i64, Ix1, _>(&file).unwrap();

        assert_eq!(restored.as_slice().unwrap(), &values[..]);
        let _ = std::fs::remove_file(&file);
    }

    /// `u8` contents, including both extremes, survive a save/load cycle.
    #[test]
    fn roundtrip_u8() {
        let values = vec![0u8, 128, 255];
        let original = Array::<u8, Ix1>::from_vec(Ix1::new([3]), values.clone()).unwrap();
        let file = test_file("rt_u8.npy");

        save(&file, &original).unwrap();
        let restored = load::<u8, Ix1, _>(&file).unwrap();

        assert_eq!(restored.as_slice().unwrap(), &values[..]);
        let _ = std::fs::remove_file(&file);
    }

    /// `bool` contents survive a save/load cycle.
    #[test]
    fn roundtrip_bool() {
        let values = vec![true, false, true, true, false];
        let original = Array::<bool, Ix1>::from_vec(Ix1::new([5]), values.clone()).unwrap();
        let file = test_file("rt_bool.npy");

        save(&file, &original).unwrap();
        let restored = load::<bool, Ix1, _>(&file).unwrap();

        assert_eq!(restored.as_slice().unwrap(), &values[..]);
        let _ = std::fs::remove_file(&file);
    }

    /// Writer/reader entry points work on an in-memory buffer.
    #[test]
    fn roundtrip_in_memory() {
        let values = vec![1.0_f64, 2.0, 3.0];
        let original = Array::<f64, Ix1>::from_vec(Ix1::new([3]), values.clone()).unwrap();

        let mut encoded = Vec::new();
        save_to_writer(&mut encoded, &original).unwrap();

        let mut source = Cursor::new(encoded);
        let restored = load_from_reader::<f64, Ix1, _>(&mut source).unwrap();
        assert_eq!(restored.as_slice().unwrap(), &values[..]);
    }

    /// Dynamic loading reports the dtype and shape stored in the file.
    #[test]
    fn load_dynamic_f64() {
        let original =
            Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0_f64, 2.0, 3.0]).unwrap();
        let file = test_file("dyn_f64.npy");
        save(&file, &original).unwrap();

        let restored = load_dynamic(&file).unwrap();
        assert_eq!(restored.dtype(), DType::F64);
        assert_eq!(restored.shape(), &[3]);
        let _ = std::fs::remove_file(&file);
    }

    /// Requesting a different element type than the file holds is an error.
    #[test]
    fn load_wrong_dtype_error() {
        let original =
            Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0_f64, 2.0, 3.0]).unwrap();
        let file = test_file("wrong_dtype.npy");
        save(&file, &original).unwrap();

        let outcome = load::<f32, Ix1, _>(&file);
        assert!(outcome.is_err());
        let _ = std::fs::remove_file(&file);
    }

    /// Requesting a different rank than the file holds is an error.
    #[test]
    fn load_wrong_ndim_error() {
        let original =
            Array::<f64, Ix1>::from_vec(Ix1::new([3]), vec![1.0_f64, 2.0, 3.0]).unwrap();
        let file = test_file("wrong_ndim.npy");
        save(&file, &original).unwrap();

        let outcome = load::<f64, Ix2, _>(&file);
        assert!(outcome.is_err());
        let _ = std::fs::remove_file(&file);
    }

    /// A `DynArray` survives the dynamic save/load cycle.
    #[test]
    fn roundtrip_dynamic() {
        let values = vec![10i32, 20, 30];
        let wrapped = DynArray::I32(
            Array::<i32, IxDyn>::from_vec(IxDyn::new(&[3]), values.clone()).unwrap(),
        );
        let file = test_file("rt_dynamic.npy");
        save_dynamic(&file, &wrapped).unwrap();

        let restored = load_dynamic(&file).unwrap();
        assert_eq!(restored.dtype(), DType::I32);
        assert_eq!(restored.shape(), &[3]);

        let typed = restored.try_into_i32().unwrap();
        assert_eq!(typed.as_slice().unwrap(), &values[..]);
        let _ = std::fs::remove_file(&file);
    }

    /// A file written from a fixed-rank array can be loaded as `IxDyn`.
    #[test]
    fn load_dynamic_ixdyn() {
        let values = vec![1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0];
        let original = Array::<f64, Ix2>::from_vec(Ix2::new([2, 3]), values.clone()).unwrap();
        let file = test_file("dyn_ixdyn.npy");
        save(&file, &original).unwrap();

        let restored = load::<f64, IxDyn, _>(&file).unwrap();
        assert_eq!(restored.shape(), &[2, 3]);
        assert_eq!(restored.as_slice().unwrap(), &values[..]);
        let _ = std::fs::remove_file(&file);
    }
}