Skip to main content

faer/
io.rs

1use crate::prelude::*;
2/// npy format conversions
3#[cfg(feature = "npy")]
4pub mod npy {
5	use super::*;
6	/// memory view over a buffer in `npy` format
7	pub struct Npy<'a> {
8		aligned_bytes: &'a [u8],
9		nrows: usize,
10		ncols: usize,
11		prefix_len: usize,
12		dtype: NpyDType,
13		fortran_order: bool,
14	}
15	/// data type of an `npy` buffer
16	#[derive(Debug, Copy, Clone, PartialEq, Eq)]
17	pub enum NpyDType {
18		/// 32-bit floating point
19		F32,
20		/// 64-bit floating point
21		F64,
22		/// 32-bit complex floating point
23		C32,
24		/// 64-bit complex floating point
25		C64,
26		/// unknown type
27		Other,
28	}
29	/// trait implemented for native types that can be read from a `npy` buffer
30	pub trait FromNpy: bytemuck::Pod {
31		/// data type of the buffer data
32		const DTYPE: NpyDType;
33	}
34	impl FromNpy for f32 {
35		const DTYPE: NpyDType = NpyDType::F32;
36	}
37	impl FromNpy for f64 {
38		const DTYPE: NpyDType = NpyDType::F64;
39	}
40	impl FromNpy for c32 {
41		const DTYPE: NpyDType = NpyDType::C32;
42	}
43	impl FromNpy for c64 {
44		const DTYPE: NpyDType = NpyDType::C64;
45	}
46	impl<'a> Npy<'a> {
47		fn parse_npyz(
48			data: &[u8],
49			npyz: npyz::NpyFile<&[u8]>,
50		) -> Result<(NpyDType, usize, usize, usize, bool), std::io::Error> {
51			let ver_major = data[6] - b'\x00';
52			let length = if ver_major <= 1 {
53				2usize
54			} else if ver_major <= 3 {
55				4usize
56			} else {
57				return Err(std::io::Error::new(
58					std::io::ErrorKind::Other,
59					"unsupported version",
60				));
61			};
62			let header_len = if length == 2 {
63				u16::from_le_bytes(data[8..10].try_into().unwrap()) as usize
64			} else {
65				u32::from_le_bytes(data[8..12].try_into().unwrap()) as usize
66			};
67			let dtype = || -> NpyDType {
68				match npyz.dtype() {
69					npyz::DType::Plain(str) => {
70						let is_complex = match str.type_char() {
71							npyz::TypeChar::Float => false,
72							npyz::TypeChar::Complex => true,
73							_ => return NpyDType::Other,
74						};
75						let byte_size = str.size_field();
76						if byte_size == 8 && is_complex {
77							NpyDType::C32
78						} else if byte_size == 16 && is_complex {
79							NpyDType::C64
80						} else if byte_size == 4 && !is_complex {
81							NpyDType::F32
82						} else if byte_size == 8 && !is_complex {
83							NpyDType::F64
84						} else {
85							NpyDType::Other
86						}
87					},
88					_ => NpyDType::Other,
89				}
90			};
91			let dtype = dtype();
92			let order = npyz.header().order();
93			let shape = npyz.shape();
94			let nrows = shape.get(0).copied().unwrap_or(1) as usize;
95			let ncols = shape.get(1).copied().unwrap_or(1) as usize;
96			let prefix_len = 8 + length + header_len;
97			let fortran_order = order == npyz::Order::Fortran;
98			Ok((dtype, nrows, ncols, prefix_len, fortran_order))
99		}
100
101		/// parse a npy file from a memory buffer
102		#[inline]
103		pub fn new(data: &'a [u8]) -> Result<Self, std::io::Error> {
104			let npyz = npyz::NpyFile::new(data)?;
105			let (dtype, nrows, ncols, prefix_len, fortran_order) =
106				Self::parse_npyz(data, npyz)?;
107			Ok(Self {
108				aligned_bytes: data,
109				prefix_len,
110				nrows,
111				ncols,
112				dtype,
113				fortran_order,
114			})
115		}
116
117		/// returns the data type of the memory buffer
118		#[inline]
119		pub fn dtype(&self) -> NpyDType {
120			self.dtype
121		}
122
123		/// checks if the memory buffer is aligned, in which case the data can
124		/// be referenced in-place
125		#[inline]
126		pub fn is_aligned(&self) -> bool {
127			self.aligned_bytes.as_ptr().align_offset(64) == 0
128		}
129
130		/// if the memory buffer is aligned, and the provided type matches the
131		/// one stored in the buffer, returns a matrix view over the data
132		#[inline]
133		pub fn as_aligned_ref<T: FromNpy>(&self) -> MatRef<'_, T> {
134			assert!(self.is_aligned());
135			assert!(self.dtype == T::DTYPE);
136			if self.fortran_order {
137				MatRef::from_column_major_slice(
138					bytemuck::cast_slice(
139						&self.aligned_bytes[self.prefix_len..],
140					),
141					self.nrows,
142					self.ncols,
143				)
144			} else {
145				MatRef::from_row_major_slice(
146					bytemuck::cast_slice(
147						&self.aligned_bytes[self.prefix_len..],
148					),
149					self.nrows,
150					self.ncols,
151				)
152			}
153		}
154
155		/// if the provided type matches the one stored in the buffer, returns a
156		/// matrix containing the data
157		#[inline]
158		pub fn to_mat<T: FromNpy>(&self) -> Mat<T> {
159			assert!(self.dtype == T::DTYPE);
160			let mut mat = Mat::<T>::with_capacity(self.nrows, self.ncols);
161			unsafe { mat.set_dims(self.nrows, self.ncols) };
162			let data = &self.aligned_bytes[self.prefix_len..];
163			if self.fortran_order {
164				for j in 0..self.ncols {
165					bytemuck::cast_slice_mut(mat.col_as_slice_mut(j))
166						.copy_from_slice(
167							&data[j * self.nrows * core::mem::size_of::<T>()..]
168								[..self.nrows * core::mem::size_of::<T>()],
169						)
170				}
171			} else {
172				for j in 0..self.ncols {
173					for i in 0..self.nrows {
174						mat[(i, j)] = bytemuck::cast_slice::<u8, T>(
175							&data[(i * self.ncols + j)
176								* core::mem::size_of::<T>()..][..core::mem::size_of::<T>()],
177						)[0];
178					}
179				}
180			};
181			mat
182		}
183	}
184}