1use num_complex::Complex;
4
5use crate::array::owned::Array;
6use crate::dimension::IxDyn;
7use crate::dtype::casting::CastKind;
8use crate::dtype::{DType, I256};
9use crate::error::{FerrayError, FerrayResult};
10
11#[derive(Debug, Clone)]
20#[non_exhaustive]
21pub enum DynArray {
22 Bool(Array<bool, IxDyn>),
24 U8(Array<u8, IxDyn>),
26 U16(Array<u16, IxDyn>),
28 U32(Array<u32, IxDyn>),
30 U64(Array<u64, IxDyn>),
32 U128(Array<u128, IxDyn>),
34 I8(Array<i8, IxDyn>),
36 I16(Array<i16, IxDyn>),
38 I32(Array<i32, IxDyn>),
40 I64(Array<i64, IxDyn>),
42 I128(Array<i128, IxDyn>),
44 I256(Array<I256, IxDyn>),
48 F32(Array<f32, IxDyn>),
50 F64(Array<f64, IxDyn>),
52 Complex32(Array<Complex<f32>, IxDyn>),
54 Complex64(Array<Complex<f64>, IxDyn>),
56 #[cfg(feature = "f16")]
58 F16(Array<half::f16, IxDyn>),
59 #[cfg(feature = "bf16")]
61 BF16(Array<half::bf16, IxDyn>),
62}
63
// Evaluates one expression against the payload of whichever variant is held.
//
// `dispatch!(subject, name => body)` expands to a `match` with one arm per
// variant; each arm binds the payload to `name` and evaluates `body`. The
// body is type-checked separately for every payload type, so it may only use
// operations that every element type supports.
//
// The arms are spelled `Self::Variant`, so the macro can only be invoked
// where `Self` names the enum (i.e. inside one of its `impl` blocks).
macro_rules! dispatch {
    ($subject:expr, $slot:ident => $body:expr) => {
        match $subject {
            Self::Bool($slot) => $body,
            Self::U8($slot) => $body,
            Self::U16($slot) => $body,
            Self::U32($slot) => $body,
            Self::U64($slot) => $body,
            Self::U128($slot) => $body,
            Self::I8($slot) => $body,
            Self::I16($slot) => $body,
            Self::I32($slot) => $body,
            Self::I64($slot) => $body,
            Self::I128($slot) => $body,
            Self::I256($slot) => $body,
            Self::F32($slot) => $body,
            Self::F64($slot) => $body,
            Self::Complex32($slot) => $body,
            Self::Complex64($slot) => $body,
            #[cfg(feature = "f16")]
            Self::F16($slot) => $body,
            #[cfg(feature = "bf16")]
            Self::BF16($slot) => $body,
        }
    };
}
94
95impl DynArray {
96 #[must_use]
98 pub const fn dtype(&self) -> DType {
99 match self {
100 Self::Bool(_) => DType::Bool,
101 Self::U8(_) => DType::U8,
102 Self::U16(_) => DType::U16,
103 Self::U32(_) => DType::U32,
104 Self::U64(_) => DType::U64,
105 Self::U128(_) => DType::U128,
106 Self::I8(_) => DType::I8,
107 Self::I16(_) => DType::I16,
108 Self::I32(_) => DType::I32,
109 Self::I64(_) => DType::I64,
110 Self::I128(_) => DType::I128,
111 Self::I256(_) => DType::I256,
112 Self::F32(_) => DType::F32,
113 Self::F64(_) => DType::F64,
114 Self::Complex32(_) => DType::Complex32,
115 Self::Complex64(_) => DType::Complex64,
116 #[cfg(feature = "f16")]
117 Self::F16(_) => DType::F16,
118 #[cfg(feature = "bf16")]
119 Self::BF16(_) => DType::BF16,
120 }
121 }
122
123 #[must_use]
125 pub fn shape(&self) -> &[usize] {
126 dispatch!(self, a => a.shape())
127 }
128
129 #[must_use]
131 pub fn ndim(&self) -> usize {
132 self.shape().len()
133 }
134
135 #[must_use]
137 pub fn size(&self) -> usize {
138 self.shape().iter().product()
139 }
140
141 #[must_use]
143 pub fn is_empty(&self) -> bool {
144 self.size() == 0
145 }
146
147 #[must_use]
149 pub const fn itemsize(&self) -> usize {
150 self.dtype().size_of()
151 }
152
153 #[must_use]
155 pub fn nbytes(&self) -> usize {
156 self.size() * self.itemsize()
157 }
158
159 pub fn try_into_f64(self) -> FerrayResult<Array<f64, IxDyn>> {
164 match self {
165 Self::F64(a) => Ok(a),
166 other => Err(FerrayError::invalid_dtype(format!(
167 "expected float64, got {}",
168 other.dtype()
169 ))),
170 }
171 }
172
173 pub fn try_into_f32(self) -> FerrayResult<Array<f32, IxDyn>> {
175 match self {
176 Self::F32(a) => Ok(a),
177 other => Err(FerrayError::invalid_dtype(format!(
178 "expected float32, got {}",
179 other.dtype()
180 ))),
181 }
182 }
183
184 pub fn try_into_i64(self) -> FerrayResult<Array<i64, IxDyn>> {
186 match self {
187 Self::I64(a) => Ok(a),
188 other => Err(FerrayError::invalid_dtype(format!(
189 "expected int64, got {}",
190 other.dtype()
191 ))),
192 }
193 }
194
195 pub fn try_into_i32(self) -> FerrayResult<Array<i32, IxDyn>> {
197 match self {
198 Self::I32(a) => Ok(a),
199 other => Err(FerrayError::invalid_dtype(format!(
200 "expected int32, got {}",
201 other.dtype()
202 ))),
203 }
204 }
205
206 pub fn try_into_bool(self) -> FerrayResult<Array<bool, IxDyn>> {
208 match self {
209 Self::Bool(a) => Ok(a),
210 other => Err(FerrayError::invalid_dtype(format!(
211 "expected bool, got {}",
212 other.dtype()
213 ))),
214 }
215 }
216
217 pub fn astype(&self, target: DType, casting: CastKind) -> FerrayResult<Self> {
232 #[cfg(feature = "f16")]
234 if matches!(self, Self::F16(_)) || target == DType::F16 {
235 return Err(FerrayError::invalid_dtype(
236 "DynArray::astype does not yet support f16",
237 ));
238 }
239 #[cfg(feature = "bf16")]
240 if matches!(self, Self::BF16(_)) || target == DType::BF16 {
241 return Err(FerrayError::invalid_dtype(
242 "DynArray::astype does not yet support bf16",
243 ));
244 }
245 if matches!(self, Self::I256(_)) || target == DType::I256 {
251 return Err(FerrayError::invalid_dtype(
252 "DynArray::astype does not yet support I256 — construct I256 arrays directly",
253 ));
254 }
255
256 macro_rules! cast_into {
259 ($U:ty) => {
260 match self {
261 Self::Bool(a) => a.cast::<$U>(casting),
262 Self::U8(a) => a.cast::<$U>(casting),
263 Self::U16(a) => a.cast::<$U>(casting),
264 Self::U32(a) => a.cast::<$U>(casting),
265 Self::U64(a) => a.cast::<$U>(casting),
266 Self::U128(a) => a.cast::<$U>(casting),
267 Self::I8(a) => a.cast::<$U>(casting),
268 Self::I16(a) => a.cast::<$U>(casting),
269 Self::I32(a) => a.cast::<$U>(casting),
270 Self::I64(a) => a.cast::<$U>(casting),
271 Self::I128(a) => a.cast::<$U>(casting),
272 Self::F32(a) => a.cast::<$U>(casting),
273 Self::F64(a) => a.cast::<$U>(casting),
274 Self::Complex32(a) => a.cast::<$U>(casting),
275 Self::Complex64(a) => a.cast::<$U>(casting),
276 Self::I256(_) => unreachable!("I256 source rejected above"),
277 #[cfg(feature = "f16")]
278 Self::F16(_) => unreachable!("f16 source rejected above"),
279 #[cfg(feature = "bf16")]
280 Self::BF16(_) => unreachable!("bf16 source rejected above"),
281 }
282 };
283 }
284
285 Ok(match target {
286 DType::Bool => Self::Bool(cast_into!(bool)?),
287 DType::U8 => Self::U8(cast_into!(u8)?),
288 DType::U16 => Self::U16(cast_into!(u16)?),
289 DType::U32 => Self::U32(cast_into!(u32)?),
290 DType::U64 => Self::U64(cast_into!(u64)?),
291 DType::U128 => Self::U128(cast_into!(u128)?),
292 DType::I8 => Self::I8(cast_into!(i8)?),
293 DType::I16 => Self::I16(cast_into!(i16)?),
294 DType::I32 => Self::I32(cast_into!(i32)?),
295 DType::I64 => Self::I64(cast_into!(i64)?),
296 DType::I128 => Self::I128(cast_into!(i128)?),
297 DType::F32 => Self::F32(cast_into!(f32)?),
298 DType::F64 => Self::F64(cast_into!(f64)?),
299 DType::Complex32 => Self::Complex32(cast_into!(Complex<f32>)?),
300 DType::Complex64 => Self::Complex64(cast_into!(Complex<f64>)?),
301 DType::I256 => unreachable!("I256 target rejected above"),
302 #[cfg(feature = "f16")]
303 DType::F16 => unreachable!("f16 target rejected above"),
304 #[cfg(feature = "bf16")]
305 DType::BF16 => unreachable!("bf16 target rejected above"),
306 })
307 }
308
309 pub fn zeros(dtype: DType, shape: &[usize]) -> FerrayResult<Self> {
311 let dim = IxDyn::new(shape);
312 Ok(match dtype {
313 DType::Bool => Self::Bool(Array::zeros(dim)?),
314 DType::U8 => Self::U8(Array::zeros(dim)?),
315 DType::U16 => Self::U16(Array::zeros(dim)?),
316 DType::U32 => Self::U32(Array::zeros(dim)?),
317 DType::U64 => Self::U64(Array::zeros(dim)?),
318 DType::U128 => Self::U128(Array::zeros(dim)?),
319 DType::I8 => Self::I8(Array::zeros(dim)?),
320 DType::I16 => Self::I16(Array::zeros(dim)?),
321 DType::I32 => Self::I32(Array::zeros(dim)?),
322 DType::I64 => Self::I64(Array::zeros(dim)?),
323 DType::I128 => Self::I128(Array::zeros(dim)?),
324 DType::I256 => Self::I256(Array::zeros(dim)?),
325 DType::F32 => Self::F32(Array::zeros(dim)?),
326 DType::F64 => Self::F64(Array::zeros(dim)?),
327 DType::Complex32 => Self::Complex32(Array::zeros(dim)?),
328 DType::Complex64 => Self::Complex64(Array::zeros(dim)?),
329 #[cfg(feature = "f16")]
330 DType::F16 => Self::F16(Array::zeros(dim)?),
331 #[cfg(feature = "bf16")]
332 DType::BF16 => Self::BF16(Array::zeros(dim)?),
333 })
334 }
335}
336
337impl std::fmt::Display for DynArray {
338 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
339 dispatch!(self, a => write!(f, "{a}"))
340 }
341}
342
343macro_rules! impl_from_array_dyn {
345 ($ty:ty, $variant:ident) => {
346 impl From<Array<$ty, IxDyn>> for DynArray {
347 fn from(a: Array<$ty, IxDyn>) -> Self {
348 Self::$variant(a)
349 }
350 }
351 };
352}
353
354impl_from_array_dyn!(bool, Bool);
355impl_from_array_dyn!(u8, U8);
356impl_from_array_dyn!(u16, U16);
357impl_from_array_dyn!(u32, U32);
358impl_from_array_dyn!(u64, U64);
359impl_from_array_dyn!(u128, U128);
360impl_from_array_dyn!(i8, I8);
361impl_from_array_dyn!(i16, I16);
362impl_from_array_dyn!(i32, I32);
363impl_from_array_dyn!(i64, I64);
364impl_from_array_dyn!(i128, I128);
365impl_from_array_dyn!(I256, I256);
366impl_from_array_dyn!(f32, F32);
367impl_from_array_dyn!(f64, F64);
368impl_from_array_dyn!(Complex<f32>, Complex32);
369impl_from_array_dyn!(Complex<f64>, Complex64);
370#[cfg(feature = "f16")]
371impl_from_array_dyn!(half::f16, F16);
372#[cfg(feature = "bf16")]
373impl_from_array_dyn!(half::bf16, BF16);
374
#[cfg(test)]
mod tests {
    use super::*;

    /// `zeros` reports the requested dtype/shape and consistent size metadata.
    #[test]
    fn dynarray_zeros_f64() {
        let zeros = DynArray::zeros(DType::F64, &[2, 3]).unwrap();
        assert_eq!(zeros.dtype(), DType::F64);
        assert_eq!(zeros.shape(), &[2, 3]);
        assert_eq!(zeros.ndim(), 2);
        assert_eq!(zeros.size(), 6);
        assert_eq!(zeros.itemsize(), 8);
        assert_eq!(zeros.nbytes(), 48);
    }

    /// `zeros` works for an integer dtype and a 1-D shape.
    #[test]
    fn dynarray_zeros_i32() {
        let zeros = DynArray::zeros(DType::I32, &[4]).unwrap();
        assert_eq!(zeros.dtype(), DType::I32);
        assert_eq!(zeros.shape(), &[4]);
    }

    /// Unwrapping with the matching accessor yields the typed array.
    #[test]
    fn dynarray_try_into_f64() {
        let typed = DynArray::zeros(DType::F64, &[3])
            .unwrap()
            .try_into_f64()
            .unwrap();
        assert_eq!(typed.shape(), &[3]);
    }

    /// Unwrapping with a mismatched accessor is an error, not a panic.
    #[test]
    fn dynarray_try_into_wrong_type() {
        let ints = DynArray::zeros(DType::I32, &[3]).unwrap();
        assert!(ints.try_into_f64().is_err());
    }

    /// An `Unsafe` float-to-int cast succeeds and yields `[1, 2, -3]` for
    /// `[1.5, 2.7, -3.9]`.
    #[test]
    fn dynarray_astype_f64_to_i32_unsafe() {
        let source = DynArray::F64(
            Array::<f64, IxDyn>::from_vec(IxDyn::new(&[3]), vec![1.5, 2.7, -3.9]).unwrap(),
        );
        let casted = source.astype(DType::I32, CastKind::Unsafe).unwrap();
        assert_eq!(casted.dtype(), DType::I32);
        if let DynArray::I32(values) = casted {
            assert_eq!(values.as_slice().unwrap(), &[1, 2, -3]);
        } else {
            panic!("expected I32");
        }
    }

    /// Widening `i32` to `i64` is accepted under `Safe` and keeps the values.
    #[test]
    fn dynarray_astype_safe_widening() {
        let source = DynArray::I32(
            Array::<i32, IxDyn>::from_vec(IxDyn::new(&[3]), vec![10, 20, 30]).unwrap(),
        );
        let casted = source.astype(DType::I64, CastKind::Safe).unwrap();
        assert_eq!(casted.dtype(), DType::I64);
        if let DynArray::I64(values) = casted {
            assert_eq!(values.as_slice().unwrap(), &[10i64, 20, 30]);
        } else {
            panic!("expected I64");
        }
    }

    /// Narrowing `f64` to `f32` is refused under `Safe`.
    #[test]
    fn dynarray_astype_safe_narrowing_errors() {
        let source = DynArray::F64(
            Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2]), vec![1.0, 2.0]).unwrap(),
        );
        let outcome = source.astype(DType::F32, CastKind::Safe);
        assert!(outcome.is_err());
    }

    /// An `Unsafe` complex-to-real cast keeps only the real parts.
    #[test]
    fn dynarray_astype_complex_to_real_unsafe() {
        let samples = vec![Complex::new(1.5, 9.0), Complex::new(2.5, -1.0)];
        let source = DynArray::Complex64(
            Array::<Complex<f64>, IxDyn>::from_vec(IxDyn::new(&[2]), samples).unwrap(),
        );
        let casted = source.astype(DType::F64, CastKind::Unsafe).unwrap();
        if let DynArray::F64(values) = casted {
            assert_eq!(values.as_slice().unwrap(), &[1.5, 2.5]);
        } else {
            panic!("expected F64");
        }
    }

    /// `bool` to `u8` is accepted under `Safe` and maps true/false to 1/0.
    #[test]
    fn dynarray_astype_bool_to_u8_safe() {
        let source = DynArray::Bool(
            Array::<bool, IxDyn>::from_vec(IxDyn::new(&[3]), vec![true, false, true]).unwrap(),
        );
        let casted = source.astype(DType::U8, CastKind::Safe).unwrap();
        if let DynArray::U8(values) = casted {
            assert_eq!(values.as_slice().unwrap(), &[1u8, 0, 1]);
        } else {
            panic!("expected U8");
        }
    }

    /// `CastKind::No` only accepts a cast to the identical dtype.
    #[test]
    fn dynarray_astype_no_kind_requires_identity() {
        let source = DynArray::F64(
            Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2]), vec![1.0, 2.0]).unwrap(),
        );
        let same_dtype = source.astype(DType::F64, CastKind::No);
        assert!(same_dtype.is_ok());
        let other_dtype = source.astype(DType::F32, CastKind::No);
        assert!(other_dtype.is_err());
    }

    /// The `From` impl wraps a typed array in the matching variant.
    #[test]
    fn dynarray_from_typed() {
        let typed = Array::<f64, IxDyn>::zeros(IxDyn::new(&[2, 2])).unwrap();
        let wrapped = DynArray::from(typed);
        assert_eq!(wrapped.dtype(), DType::F64);
    }

    /// `Display` output includes the element listing of the inner array.
    #[test]
    fn dynarray_display() {
        let rendered = DynArray::zeros(DType::I32, &[3]).unwrap().to_string();
        assert!(rendered.contains("[0, 0, 0]"));
    }

    /// A zero-length axis makes the array empty.
    #[test]
    fn dynarray_is_empty() {
        let nothing = DynArray::zeros(DType::F32, &[0]).unwrap();
        assert!(nothing.is_empty());
    }

    /// With the `f16` feature, `zeros` supports `F16` with 2-byte elements.
    #[cfg(feature = "f16")]
    #[test]
    fn dynarray_f16_zeros_shape_and_dtype() {
        let zeros = DynArray::zeros(DType::F16, &[2, 3]).unwrap();
        assert_eq!(zeros.dtype(), DType::F16);
        assert_eq!(zeros.shape(), &[2, 3]);
        assert_eq!(zeros.size(), 6);
        assert_eq!(zeros.itemsize(), 2);
        assert_eq!(zeros.nbytes(), 12);
    }

    /// With the `f16` feature, a typed `f16` array converts via `From`.
    #[cfg(feature = "f16")]
    #[test]
    fn dynarray_f16_from_typed_roundtrips() {
        use half::f16;
        let samples = vec![f16::from_f32(1.0), f16::from_f32(2.5), f16::from_f32(-3.0)];
        let typed = Array::<f16, IxDyn>::from_vec(IxDyn::new(&[3]), samples).unwrap();
        let wrapped = DynArray::from(typed);
        assert_eq!(wrapped.dtype(), DType::F16);
        assert_eq!(wrapped.shape(), &[3]);
    }

    /// With the `bf16` feature, `zeros` supports `BF16` with 2-byte elements.
    #[cfg(feature = "bf16")]
    #[test]
    fn dynarray_bf16_zeros_shape_and_dtype() {
        let zeros = DynArray::zeros(DType::BF16, &[4]).unwrap();
        assert_eq!(zeros.dtype(), DType::BF16);
        assert_eq!(zeros.shape(), &[4]);
        assert_eq!(zeros.itemsize(), 2);
    }

    /// With the `bf16` feature, a typed `bf16` array converts via `From`.
    #[cfg(feature = "bf16")]
    #[test]
    fn dynarray_bf16_from_typed_roundtrips() {
        use half::bf16;
        let samples = vec![bf16::from_f32(1.0), bf16::from_f32(2.0)];
        let typed = Array::<bf16, IxDyn>::from_vec(IxDyn::new(&[2]), samples).unwrap();
        let wrapped = DynArray::from(typed);
        assert_eq!(wrapped.dtype(), DType::BF16);
    }
}