#[allow(unused_imports)]
use super::*;
#[allow(unused_imports)]
use crate::assert;
#[allow(unused_imports)]
use complex_native::{c32, c64};
#[cfg(feature = "npy")]
#[cfg_attr(docsrs, doc(cfg(feature = "npy")))]
pub struct Npy<'a> {
aligned_bytes: &'a [u8],
nrows: usize,
ncols: usize,
prefix_len: usize,
dtype: NpyDType,
fortran_order: bool,
}
#[cfg(feature = "npy")]
#[cfg_attr(docsrs, doc(cfg(feature = "npy")))]
#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum NpyDType {
F32,
F64,
C32,
C64,
Other,
}
#[cfg(feature = "npy")]
#[cfg_attr(docsrs, doc(cfg(feature = "npy")))]
pub trait FromNpy: faer_entity::SimpleEntity {
const DTYPE: NpyDType;
}
#[cfg(feature = "npy")]
#[cfg_attr(docsrs, doc(cfg(feature = "npy")))]
impl FromNpy for f32 {
const DTYPE: NpyDType = NpyDType::F32;
}
#[cfg(feature = "npy")]
#[cfg_attr(docsrs, doc(cfg(feature = "npy")))]
impl FromNpy for f64 {
const DTYPE: NpyDType = NpyDType::F64;
}
#[cfg(feature = "npy")]
#[cfg_attr(docsrs, doc(cfg(feature = "npy")))]
impl FromNpy for c32 {
const DTYPE: NpyDType = NpyDType::C32;
}
#[cfg(feature = "npy")]
#[cfg_attr(docsrs, doc(cfg(feature = "npy")))]
impl FromNpy for c64 {
const DTYPE: NpyDType = NpyDType::C64;
}
#[cfg(feature = "npy")]
#[cfg_attr(docsrs, doc(cfg(feature = "npy")))]
impl<'a> Npy<'a> {
fn parse_npyz(
data: &[u8],
npyz: npyz::NpyFile<&[u8]>,
) -> Result<(NpyDType, usize, usize, usize, bool), std::io::Error> {
let ver_major = data[6] - b'\x00';
let length = if ver_major <= 1 {
2usize
} else if ver_major <= 3 {
4usize
} else {
return Err(std::io::Error::new(
std::io::ErrorKind::Other,
"unsupported version",
));
};
let header_len = if length == 2 {
u16::from_le_bytes(data[8..10].try_into().unwrap()) as usize
} else {
u32::from_le_bytes(data[8..12].try_into().unwrap()) as usize
};
let dtype = || -> NpyDType {
match npyz.dtype() {
npyz::DType::Plain(str) => {
let is_complex = match str.type_char() {
npyz::TypeChar::Float => false,
npyz::TypeChar::Complex => true,
_ => return NpyDType::Other,
};
let byte_size = str.size_field();
if byte_size == 8 && is_complex {
NpyDType::C32
} else if byte_size == 16 && is_complex {
NpyDType::C64
} else if byte_size == 4 && !is_complex {
NpyDType::F32
} else if byte_size == 16 && !is_complex {
NpyDType::F64
} else {
NpyDType::Other
}
}
_ => NpyDType::Other,
}
};
let dtype = dtype();
let order = npyz.header().order();
let shape = npyz.shape();
let nrows = shape.get(0).copied().unwrap_or(1) as usize;
let ncols = shape.get(1).copied().unwrap_or(1) as usize;
let prefix_len = 8 + length + header_len;
let fortran_order = order == npyz::Order::Fortran;
Ok((dtype, nrows, ncols, prefix_len, fortran_order))
}
#[inline]
pub fn new(data: &'a [u8]) -> Result<Self, std::io::Error> {
let npyz = npyz::NpyFile::new(data)?;
let (dtype, nrows, ncols, prefix_len, fortran_order) = Self::parse_npyz(data, npyz)?;
Ok(Self {
aligned_bytes: data,
prefix_len,
nrows,
ncols,
dtype,
fortran_order,
})
}
#[inline]
pub fn dtype(&self) -> NpyDType {
self.dtype
}
#[inline]
pub fn is_aligned(&self) -> bool {
self.aligned_bytes.as_ptr().align_offset(64) == 0
}
#[inline]
pub fn as_aligned_ref<E: FromNpy>(&self) -> MatRef<'_, E> {
assert!(self.is_aligned());
assert!(self.dtype == E::DTYPE);
if self.fortran_order {
crate::mat::from_column_major_slice(
bytemuck::cast_slice(&self.aligned_bytes[self.prefix_len..]),
self.nrows,
self.ncols,
)
} else {
crate::mat::from_row_major_slice(
bytemuck::cast_slice(&self.aligned_bytes[self.prefix_len..]),
self.nrows,
self.ncols,
)
}
}
#[inline]
pub fn to_mat<E: FromNpy>(&self) -> Mat<E> {
assert!(self.dtype == E::DTYPE);
let mut mat = Mat::<E>::with_capacity(self.nrows, self.ncols);
unsafe { mat.set_dims(self.nrows, self.ncols) };
let data = &self.aligned_bytes[self.prefix_len..];
if self.fortran_order {
for j in 0..self.ncols {
bytemuck::cast_slice_mut(mat.col_as_slice_mut(j)).copy_from_slice(
&data[j * self.nrows * core::mem::size_of::<E>()..]
[..self.nrows * core::mem::size_of::<E>()],
)
}
} else {
for j in 0..self.ncols {
for i in 0..self.nrows {
bytemuck::cast_slice_mut(&mut mat.col_as_slice_mut(j)[i..i + 1])
.copy_from_slice(
&data[(i * self.ncols + j) * core::mem::size_of::<E>()..]
[..core::mem::size_of::<E>()],
)
}
}
};
mat
}
}