1use crate::prelude::*;
#[cfg(feature = "npy")]
/// Reading dense matrices stored in NumPy's `.npy` file format.
pub mod npy {
    use super::*;

    /// A parsed `.npy` file backed by a borrowed byte buffer.
    ///
    /// The header is parsed once by [`Npy::new`]; the element data is only touched by
    /// [`Npy::as_aligned_ref`] (zero-copy view) and [`Npy::to_mat`] (owned copy).
    pub struct Npy<'a> {
        /// The whole file contents, header included.
        aligned_bytes: &'a [u8],
        /// First dimension of the array shape (`1` if the array has no dimensions).
        nrows: usize,
        /// Second dimension of the array shape (`1` if the array has fewer than two).
        ncols: usize,
        /// Offset in `aligned_bytes` at which the element data begins.
        prefix_len: usize,
        /// Element type recorded in the header.
        dtype: NpyDType,
        /// `true` if elements are stored column-major (Fortran order), `false` for row-major.
        fortran_order: bool,
    }

    /// Element type of a `.npy` array, as recognized by this module.
    #[derive(Debug, Copy, Clone, PartialEq, Eq)]
    pub enum NpyDType {
        /// 4-byte floating point, native byte order.
        F32,
        /// 8-byte floating point, native byte order.
        F64,
        /// 8-byte complex number (two 4-byte floats), native byte order.
        C32,
        /// 16-byte complex number (two 8-byte floats), native byte order.
        C64,
        /// Any other element type, including floats stored in non-native byte order.
        Other,
    }

    /// Element types that can be read out of a `.npy` file.
    pub trait FromNpy: bytemuck::Pod {
        /// The [`NpyDType`] a file must have to be read as `Self`.
        const DTYPE: NpyDType;
    }

    impl FromNpy for f32 {
        const DTYPE: NpyDType = NpyDType::F32;
    }
    impl FromNpy for f64 {
        const DTYPE: NpyDType = NpyDType::F64;
    }
    impl FromNpy for c32 {
        const DTYPE: NpyDType = NpyDType::C32;
    }
    impl FromNpy for c64 {
        const DTYPE: NpyDType = NpyDType::C64;
    }

    /// Byte size of one element of `dtype`, or `None` for [`NpyDType::Other`].
    fn dtype_size(dtype: NpyDType) -> Option<usize> {
        match dtype {
            NpyDType::F32 => Some(4),
            NpyDType::F64 | NpyDType::C32 => Some(8),
            NpyDType::C64 => Some(16),
            NpyDType::Other => None,
        }
    }

    /// Maps an `npyz` dtype onto [`NpyDType`].
    ///
    /// Types whose byte order differs from the target's are reported as
    /// [`NpyDType::Other`]: elements are reinterpreted without byte swapping, so reading
    /// them as native values would produce garbage.
    fn classify_dtype(dtype: &npyz::DType) -> NpyDType {
        let type_str = match dtype {
            npyz::DType::Plain(type_str) => type_str,
            _ => return NpyDType::Other,
        };
        let native_order = match type_str.endianness() {
            npyz::Endianness::Little => cfg!(target_endian = "little"),
            npyz::Endianness::Big => cfg!(target_endian = "big"),
            // Remaining variant(s): byte order marked as not applicable.
            _ => true,
        };
        if !native_order {
            return NpyDType::Other;
        }
        match (type_str.type_char(), type_str.size_field()) {
            (npyz::TypeChar::Float, 4) => NpyDType::F32,
            (npyz::TypeChar::Float, 8) => NpyDType::F64,
            (npyz::TypeChar::Complex, 8) => NpyDType::C32,
            (npyz::TypeChar::Complex, 16) => NpyDType::C64,
            _ => NpyDType::Other,
        }
    }

    /// Builds an [`std::io::ErrorKind::InvalidData`] error for a malformed or unsupported file.
    fn invalid_data(msg: &'static str) -> std::io::Error {
        std::io::Error::new(std::io::ErrorKind::InvalidData, msg)
    }

    impl<'a> Npy<'a> {
        /// Extracts `(dtype, nrows, ncols, prefix_len, fortran_order)` from a parsed header.
        ///
        /// `data` must be the buffer `npyz` was created from; it is read directly to find
        /// where the header ends and to check that the array data is complete.
        fn parse_npyz(
            data: &[u8],
            npyz: npyz::NpyFile<&[u8]>,
        ) -> Result<(NpyDType, usize, usize, usize, bool), std::io::Error> {
            // Per the NPY format: a 6-byte magic string, the major and minor version bytes,
            // then a little-endian header length that is 2 bytes wide in version 1.x and
            // 4 bytes wide in versions 2.x/3.x. `NpyFile::new` has already parsed this
            // header successfully, so these indices are in bounds.
            let (len_field_size, header_len) = match data[6] {
                0..=1 => (2, usize::from(u16::from_le_bytes([data[8], data[9]]))),
                2..=3 => {
                    let len = u32::from_le_bytes([data[8], data[9], data[10], data[11]]);
                    let len = usize::try_from(len)
                        .map_err(|_| invalid_data("header length does not fit in usize"))?;
                    (4, len)
                }
                _ => {
                    return Err(std::io::Error::new(
                        std::io::ErrorKind::Other,
                        "unsupported version",
                    ))
                }
            };
            let prefix_len = 8 + len_field_size + header_len;

            // Only 0-, 1- and 2-dimensional arrays map onto a matrix; reading just the first
            // two dimensions of a higher-rank array would silently yield the wrong data.
            let shape = npyz.shape();
            if shape.len() > 2 {
                return Err(invalid_data(
                    "only arrays with at most two dimensions are supported",
                ));
            }
            let dim = |axis: usize| {
                usize::try_from(shape.get(axis).copied().unwrap_or(1))
                    .map_err(|_| invalid_data("array dimension does not fit in usize"))
            };
            let nrows = dim(0)?;
            let ncols = dim(1)?;

            let dtype = classify_dtype(&npyz.dtype());
            let fortran_order = npyz.header().order() == npyz::Order::Fortran;

            // Check up front that the whole array data is present (and that its size does
            // not overflow), so the accessors can slice it without panicking.
            if let Some(elem_size) = dtype_size(dtype) {
                let end = nrows
                    .checked_mul(ncols)
                    .and_then(|count| count.checked_mul(elem_size))
                    .and_then(|bytes| bytes.checked_add(prefix_len));
                if end.map_or(true, |end| end > data.len()) {
                    return Err(invalid_data("array data is truncated"));
                }
            }

            Ok((dtype, nrows, ncols, prefix_len, fortran_order))
        }

        /// Parses the header of the `.npy` file stored in `data`, without copying it.
        ///
        /// # Errors
        /// Returns an error if `npyz` cannot parse the header, if the format's major version
        /// is above 3, if the array has more than two dimensions or a dimension that does not
        /// fit in `usize`, or (for the recognized dtypes) if `data` is shorter than the array
        /// data the header describes.
        #[inline]
        pub fn new(data: &'a [u8]) -> Result<Self, std::io::Error> {
            let npyz = npyz::NpyFile::new(data)?;
            let (dtype, nrows, ncols, prefix_len, fortran_order) =
                Self::parse_npyz(data, npyz)?;
            Ok(Self {
                aligned_bytes: data,
                prefix_len,
                nrows,
                ncols,
                dtype,
                fortran_order,
            })
        }

        /// Returns the element type recorded in the file header.
        #[inline]
        pub fn dtype(&self) -> NpyDType {
            self.dtype
        }

        /// Returns `true` if the underlying byte buffer starts on a 64-byte boundary,
        /// which is the precondition asserted by [`Npy::as_aligned_ref`].
        #[inline]
        pub fn is_aligned(&self) -> bool {
            self.aligned_bytes.as_ptr().align_offset(64) == 0
        }

        /// The array data read as `T`: exactly `nrows * ncols` elements, with the header
        /// and any bytes after the array data stripped.
        fn payload<T: FromNpy>(&self) -> &[u8] {
            let len = self.nrows * self.ncols * core::mem::size_of::<T>();
            &self.aligned_bytes[self.prefix_len..][..len]
        }

        /// Returns a zero-copy matrix view of the array data.
        ///
        /// # Panics
        /// Panics if the buffer is not 64-byte aligned (see [`Npy::is_aligned`]), if the
        /// file's dtype is not `T::DTYPE`, or if the array data is not aligned for `T`.
        #[inline]
        pub fn as_aligned_ref<T: FromNpy>(&self) -> MatRef<'_, T> {
            assert!(self.is_aligned());
            assert!(self.dtype == T::DTYPE);
            let elems: &[T] = bytemuck::cast_slice(self.payload::<T>());
            if self.fortran_order {
                MatRef::from_column_major_slice(elems, self.nrows, self.ncols)
            } else {
                MatRef::from_row_major_slice(elems, self.nrows, self.ncols)
            }
        }

        /// Copies the array data into a newly allocated matrix.
        ///
        /// Unlike [`Npy::as_aligned_ref`], this works whatever the alignment of the buffer.
        ///
        /// # Panics
        /// Panics if the file's dtype is not `T::DTYPE`.
        #[inline]
        pub fn to_mat<T: FromNpy>(&self) -> Mat<T> {
            assert!(self.dtype == T::DTYPE);
            let (nrows, ncols) = (self.nrows, self.ncols);
            let elem = core::mem::size_of::<T>();
            let data = self.payload::<T>();

            let mut mat = Mat::<T>::with_capacity(nrows, ncols);
            // SAFETY: every entry of the `nrows x ncols` matrix is overwritten by one of the
            // loops below before `mat` is returned, and `T: Pod` has no drop glue.
            unsafe { mat.set_dims(nrows, ncols) };

            if self.fortran_order {
                // Column-major data has the same layout as each column of `mat`, so copy the
                // raw bytes a column at a time (byte copies have no alignment requirement).
                for j in 0..ncols {
                    bytemuck::cast_slice_mut::<T, u8>(mat.col_as_slice_mut(j))
                        .copy_from_slice(&data[j * nrows * elem..][..nrows * elem]);
                }
            } else {
                // Row-major data: element (i, j) is at index `i * ncols + j`. The bytes may be
                // unaligned for `T`, so read each element without casting the slice.
                for j in 0..ncols {
                    for i in 0..nrows {
                        let offset = (i * ncols + j) * elem;
                        mat[(i, j)] = bytemuck::pod_read_unaligned(&data[offset..][..elem]);
                    }
                }
            }
            mat
        }
    }
}