1use num_complex::Complex;
4
5use crate::array::owned::Array;
6use crate::dimension::IxDyn;
7use crate::dtype::DType;
8use crate::error::{FerrayError, FerrayResult};
9
10#[derive(Debug, Clone)]
19#[non_exhaustive]
20pub enum DynArray {
21 Bool(Array<bool, IxDyn>),
23 U8(Array<u8, IxDyn>),
25 U16(Array<u16, IxDyn>),
27 U32(Array<u32, IxDyn>),
29 U64(Array<u64, IxDyn>),
31 U128(Array<u128, IxDyn>),
33 I8(Array<i8, IxDyn>),
35 I16(Array<i16, IxDyn>),
37 I32(Array<i32, IxDyn>),
39 I64(Array<i64, IxDyn>),
41 I128(Array<i128, IxDyn>),
43 F32(Array<f32, IxDyn>),
45 F64(Array<f64, IxDyn>),
47 Complex32(Array<Complex<f32>, IxDyn>),
49 Complex64(Array<Complex<f64>, IxDyn>),
51 #[cfg(feature = "f16")]
53 F16(Array<half::f16, IxDyn>),
54 #[cfg(feature = "bf16")]
56 BF16(Array<half::bf16, IxDyn>),
57}
58
59impl DynArray {
60 pub fn dtype(&self) -> DType {
62 match self {
63 Self::Bool(_) => DType::Bool,
64 Self::U8(_) => DType::U8,
65 Self::U16(_) => DType::U16,
66 Self::U32(_) => DType::U32,
67 Self::U64(_) => DType::U64,
68 Self::U128(_) => DType::U128,
69 Self::I8(_) => DType::I8,
70 Self::I16(_) => DType::I16,
71 Self::I32(_) => DType::I32,
72 Self::I64(_) => DType::I64,
73 Self::I128(_) => DType::I128,
74 Self::F32(_) => DType::F32,
75 Self::F64(_) => DType::F64,
76 Self::Complex32(_) => DType::Complex32,
77 Self::Complex64(_) => DType::Complex64,
78 #[cfg(feature = "f16")]
79 Self::F16(_) => DType::F16,
80 #[cfg(feature = "bf16")]
81 Self::BF16(_) => DType::BF16,
82 }
83 }
84
85 pub fn shape(&self) -> &[usize] {
87 match self {
88 Self::Bool(a) => a.shape(),
89 Self::U8(a) => a.shape(),
90 Self::U16(a) => a.shape(),
91 Self::U32(a) => a.shape(),
92 Self::U64(a) => a.shape(),
93 Self::U128(a) => a.shape(),
94 Self::I8(a) => a.shape(),
95 Self::I16(a) => a.shape(),
96 Self::I32(a) => a.shape(),
97 Self::I64(a) => a.shape(),
98 Self::I128(a) => a.shape(),
99 Self::F32(a) => a.shape(),
100 Self::F64(a) => a.shape(),
101 Self::Complex32(a) => a.shape(),
102 Self::Complex64(a) => a.shape(),
103 #[cfg(feature = "f16")]
104 Self::F16(a) => a.shape(),
105 #[cfg(feature = "bf16")]
106 Self::BF16(a) => a.shape(),
107 }
108 }
109
110 pub fn ndim(&self) -> usize {
112 self.shape().len()
113 }
114
115 pub fn size(&self) -> usize {
117 self.shape().iter().product()
118 }
119
120 pub fn is_empty(&self) -> bool {
122 self.size() == 0
123 }
124
125 pub fn itemsize(&self) -> usize {
127 self.dtype().size_of()
128 }
129
130 pub fn nbytes(&self) -> usize {
132 self.size() * self.itemsize()
133 }
134
135 pub fn try_into_f64(self) -> FerrayResult<Array<f64, IxDyn>> {
140 match self {
141 Self::F64(a) => Ok(a),
142 other => Err(FerrayError::invalid_dtype(format!(
143 "expected float64, got {}",
144 other.dtype()
145 ))),
146 }
147 }
148
149 pub fn try_into_f32(self) -> FerrayResult<Array<f32, IxDyn>> {
151 match self {
152 Self::F32(a) => Ok(a),
153 other => Err(FerrayError::invalid_dtype(format!(
154 "expected float32, got {}",
155 other.dtype()
156 ))),
157 }
158 }
159
160 pub fn try_into_i64(self) -> FerrayResult<Array<i64, IxDyn>> {
162 match self {
163 Self::I64(a) => Ok(a),
164 other => Err(FerrayError::invalid_dtype(format!(
165 "expected int64, got {}",
166 other.dtype()
167 ))),
168 }
169 }
170
171 pub fn try_into_i32(self) -> FerrayResult<Array<i32, IxDyn>> {
173 match self {
174 Self::I32(a) => Ok(a),
175 other => Err(FerrayError::invalid_dtype(format!(
176 "expected int32, got {}",
177 other.dtype()
178 ))),
179 }
180 }
181
182 pub fn try_into_bool(self) -> FerrayResult<Array<bool, IxDyn>> {
184 match self {
185 Self::Bool(a) => Ok(a),
186 other => Err(FerrayError::invalid_dtype(format!(
187 "expected bool, got {}",
188 other.dtype()
189 ))),
190 }
191 }
192
193 pub fn zeros(dtype: DType, shape: &[usize]) -> FerrayResult<Self> {
195 let dim = IxDyn::new(shape);
196 Ok(match dtype {
197 DType::Bool => Self::Bool(Array::zeros(dim)?),
198 DType::U8 => Self::U8(Array::zeros(dim)?),
199 DType::U16 => Self::U16(Array::zeros(dim)?),
200 DType::U32 => Self::U32(Array::zeros(dim)?),
201 DType::U64 => Self::U64(Array::zeros(dim)?),
202 DType::U128 => Self::U128(Array::zeros(dim)?),
203 DType::I8 => Self::I8(Array::zeros(dim)?),
204 DType::I16 => Self::I16(Array::zeros(dim)?),
205 DType::I32 => Self::I32(Array::zeros(dim)?),
206 DType::I64 => Self::I64(Array::zeros(dim)?),
207 DType::I128 => Self::I128(Array::zeros(dim)?),
208 DType::F32 => Self::F32(Array::zeros(dim)?),
209 DType::F64 => Self::F64(Array::zeros(dim)?),
210 DType::Complex32 => Self::Complex32(Array::zeros(dim)?),
211 DType::Complex64 => Self::Complex64(Array::zeros(dim)?),
212 #[cfg(feature = "f16")]
213 DType::F16 => Self::F16(Array::zeros(dim)?),
214 #[cfg(feature = "bf16")]
215 DType::BF16 => Self::BF16(Array::zeros(dim)?),
216 })
217 }
218}
219
220impl std::fmt::Display for DynArray {
221 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
222 match self {
223 Self::Bool(a) => write!(f, "{a}"),
224 Self::U8(a) => write!(f, "{a}"),
225 Self::U16(a) => write!(f, "{a}"),
226 Self::U32(a) => write!(f, "{a}"),
227 Self::U64(a) => write!(f, "{a}"),
228 Self::U128(a) => write!(f, "{a}"),
229 Self::I8(a) => write!(f, "{a}"),
230 Self::I16(a) => write!(f, "{a}"),
231 Self::I32(a) => write!(f, "{a}"),
232 Self::I64(a) => write!(f, "{a}"),
233 Self::I128(a) => write!(f, "{a}"),
234 Self::F32(a) => write!(f, "{a}"),
235 Self::F64(a) => write!(f, "{a}"),
236 Self::Complex32(a) => write!(f, "{a}"),
237 Self::Complex64(a) => write!(f, "{a}"),
238 #[cfg(feature = "f16")]
239 Self::F16(a) => write!(f, "{a}"),
240 #[cfg(feature = "bf16")]
241 Self::BF16(a) => write!(f, "{a}"),
242 }
243 }
244}
245
246macro_rules! impl_from_array_dyn {
248 ($ty:ty, $variant:ident) => {
249 impl From<Array<$ty, IxDyn>> for DynArray {
250 fn from(a: Array<$ty, IxDyn>) -> Self {
251 Self::$variant(a)
252 }
253 }
254 };
255}
256
257impl_from_array_dyn!(bool, Bool);
258impl_from_array_dyn!(u8, U8);
259impl_from_array_dyn!(u16, U16);
260impl_from_array_dyn!(u32, U32);
261impl_from_array_dyn!(u64, U64);
262impl_from_array_dyn!(u128, U128);
263impl_from_array_dyn!(i8, I8);
264impl_from_array_dyn!(i16, I16);
265impl_from_array_dyn!(i32, I32);
266impl_from_array_dyn!(i64, I64);
267impl_from_array_dyn!(i128, I128);
268impl_from_array_dyn!(f32, F32);
269impl_from_array_dyn!(f64, F64);
270impl_from_array_dyn!(Complex<f32>, Complex32);
271impl_from_array_dyn!(Complex<f64>, Complex64);
272#[cfg(feature = "f16")]
273impl_from_array_dyn!(half::f16, F16);
274#[cfg(feature = "bf16")]
275impl_from_array_dyn!(half::bf16, BF16);
276
#[cfg(test)]
mod tests {
    use super::*;

    /// A zero-filled `F64` array reports consistent dtype, shape and sizes.
    #[test]
    fn dynarray_zeros_f64() {
        let zeros = DynArray::zeros(DType::F64, &[2, 3]).unwrap();

        assert_eq!(zeros.dtype(), DType::F64);
        assert_eq!(zeros.shape(), &[2, 3]);
        assert_eq!(zeros.ndim(), 2);
        assert_eq!(zeros.size(), 6);
        assert_eq!(zeros.itemsize(), 8);
        assert_eq!(zeros.nbytes(), 48);
    }

    /// A zero-filled `I32` array keeps the requested dtype and shape.
    #[test]
    fn dynarray_zeros_i32() {
        let zeros = DynArray::zeros(DType::I32, &[4]).unwrap();

        assert_eq!(zeros.dtype(), DType::I32);
        assert_eq!(zeros.shape(), &[4]);
    }

    /// Downcasting to the stored element type succeeds and keeps the shape.
    #[test]
    fn dynarray_try_into_f64() {
        let dynamic = DynArray::zeros(DType::F64, &[3]).unwrap();
        let typed = dynamic.try_into_f64().unwrap();

        assert_eq!(typed.shape(), &[3]);
    }

    /// Downcasting to a different element type is reported as an error.
    #[test]
    fn dynarray_try_into_wrong_type() {
        let dynamic = DynArray::zeros(DType::I32, &[3]).unwrap();

        assert!(dynamic.try_into_f64().is_err());
    }

    /// A typed array converts into the variant matching its element type.
    #[test]
    fn dynarray_from_typed() {
        let typed = Array::<f64, IxDyn>::zeros(IxDyn::new(&[2, 2])).unwrap();
        let dynamic = DynArray::from(typed);

        assert_eq!(dynamic.dtype(), DType::F64);
    }

    /// `Display` output contains the rendering of the inner elements.
    #[test]
    fn dynarray_display() {
        let dynamic = DynArray::zeros(DType::I32, &[3]).unwrap();
        let rendered = dynamic.to_string();

        assert!(rendered.contains("[0, 0, 0]"));
    }

    /// A zero-length axis makes the array empty.
    #[test]
    fn dynarray_is_empty() {
        let dynamic = DynArray::zeros(DType::F32, &[0]).unwrap();

        assert!(dynamic.is_empty());
    }
}