1use num_complex::Complex;
4
5use crate::array::owned::Array;
6use crate::dimension::IxDyn;
7use crate::dtype::casting::CastKind;
8use crate::dtype::{DType, DateTime64, I256, TimeUnit, Timedelta64};
9use crate::error::{FerrayError, FerrayResult};
10
/// A dynamically typed array: the element type is selected at runtime and the
/// rank is dynamic (`IxDyn`).
///
/// Each variant wraps an `Array<T, IxDyn>` for one concrete element type `T`.
/// The `F16` / `BF16` variants only exist when the `f16` / `bf16` cargo
/// features are enabled. The `DateTime64` / `Timedelta64` variants also carry
/// a [`TimeUnit`], which `dtype()` reports as part of their [`DType`].
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum DynArray {
    /// `bool` elements.
    Bool(Array<bool, IxDyn>),
    /// `u8` elements.
    U8(Array<u8, IxDyn>),
    /// `u16` elements.
    U16(Array<u16, IxDyn>),
    /// `u32` elements.
    U32(Array<u32, IxDyn>),
    /// `u64` elements.
    U64(Array<u64, IxDyn>),
    /// `u128` elements.
    U128(Array<u128, IxDyn>),
    /// `i8` elements.
    I8(Array<i8, IxDyn>),
    /// `i16` elements.
    I16(Array<i16, IxDyn>),
    /// `i32` elements.
    I32(Array<i32, IxDyn>),
    /// `i64` elements.
    I64(Array<i64, IxDyn>),
    /// `i128` elements.
    I128(Array<i128, IxDyn>),
    /// Elements of the crate's [`I256`] type.
    I256(Array<I256, IxDyn>),
    /// `f32` elements.
    F32(Array<f32, IxDyn>),
    /// `f64` elements.
    F64(Array<f64, IxDyn>),
    /// `Complex<f32>` elements.
    Complex32(Array<Complex<f32>, IxDyn>),
    /// `Complex<f64>` elements.
    Complex64(Array<Complex<f64>, IxDyn>),
    #[cfg(feature = "f16")]
    /// `half::f16` elements (only with the `f16` feature).
    F16(Array<half::f16, IxDyn>),
    #[cfg(feature = "bf16")]
    /// `half::bf16` elements (only with the `bf16` feature).
    BF16(Array<half::bf16, IxDyn>),
    /// [`DateTime64`] elements plus the [`TimeUnit`] of the dtype
    /// (reported as `DType::DateTime64(unit)`).
    DateTime64(Array<DateTime64, IxDyn>, TimeUnit),
    /// [`Timedelta64`] elements plus the [`TimeUnit`] of the dtype
    /// (reported as `DType::Timedelta64(unit)`).
    Timedelta64(Array<Timedelta64, IxDyn>, TimeUnit),
}
70
// Evaluates `$body` once against whichever typed array a `DynArray` holds.
//
// In every arm `$arr` is bound to the wrapped `Array` (the `TimeUnit` stored
// by the datetime/timedelta variants is ignored), so `$body` must type-check
// for every element type. The arms are spelled with `Self::`, so the macro
// can only be expanded inside an `impl DynArray` block.
macro_rules! dispatch {
    ($target:expr, $arr:ident => $body:expr) => {
        match $target {
            Self::Bool($arr) => $body,
            // Time-based variants: only the array is bound, the unit is skipped.
            Self::DateTime64($arr, _) => $body,
            Self::Timedelta64($arr, _) => $body,
            // Unsigned integers.
            Self::U8($arr) => $body,
            Self::U16($arr) => $body,
            Self::U32($arr) => $body,
            Self::U64($arr) => $body,
            Self::U128($arr) => $body,
            // Signed integers.
            Self::I8($arr) => $body,
            Self::I16($arr) => $body,
            Self::I32($arr) => $body,
            Self::I64($arr) => $body,
            Self::I128($arr) => $body,
            Self::I256($arr) => $body,
            // Floating point, including the feature-gated half-precision types.
            Self::F32($arr) => $body,
            Self::F64($arr) => $body,
            #[cfg(feature = "f16")]
            Self::F16($arr) => $body,
            #[cfg(feature = "bf16")]
            Self::BF16($arr) => $body,
            // Complex.
            Self::Complex32($arr) => $body,
            Self::Complex64($arr) => $body,
        }
    };
}
105
106impl DynArray {
107 #[must_use]
109 pub const fn dtype(&self) -> DType {
110 match self {
111 Self::Bool(_) => DType::Bool,
112 Self::U8(_) => DType::U8,
113 Self::U16(_) => DType::U16,
114 Self::U32(_) => DType::U32,
115 Self::U64(_) => DType::U64,
116 Self::U128(_) => DType::U128,
117 Self::I8(_) => DType::I8,
118 Self::I16(_) => DType::I16,
119 Self::I32(_) => DType::I32,
120 Self::I64(_) => DType::I64,
121 Self::I128(_) => DType::I128,
122 Self::I256(_) => DType::I256,
123 Self::F32(_) => DType::F32,
124 Self::F64(_) => DType::F64,
125 Self::Complex32(_) => DType::Complex32,
126 Self::Complex64(_) => DType::Complex64,
127 #[cfg(feature = "f16")]
128 Self::F16(_) => DType::F16,
129 #[cfg(feature = "bf16")]
130 Self::BF16(_) => DType::BF16,
131 Self::DateTime64(_, u) => DType::DateTime64(*u),
132 Self::Timedelta64(_, u) => DType::Timedelta64(*u),
133 }
134 }
135
136 #[must_use]
138 pub fn shape(&self) -> &[usize] {
139 dispatch!(self, a => a.shape())
140 }
141
142 #[must_use]
144 pub fn ndim(&self) -> usize {
145 self.shape().len()
146 }
147
148 #[must_use]
150 pub fn size(&self) -> usize {
151 self.shape().iter().product()
152 }
153
154 #[must_use]
156 pub fn is_empty(&self) -> bool {
157 self.size() == 0
158 }
159
160 #[must_use]
162 pub fn itemsize(&self) -> usize {
163 self.dtype().size_of()
164 }
165
166 #[must_use]
168 pub fn nbytes(&self) -> usize {
169 self.size() * self.itemsize()
170 }
171
172 pub fn try_into_f64(self) -> FerrayResult<Array<f64, IxDyn>> {
177 match self {
178 Self::F64(a) => Ok(a),
179 other => Err(FerrayError::invalid_dtype(format!(
180 "expected float64, got {}",
181 other.dtype()
182 ))),
183 }
184 }
185
186 pub fn try_into_f32(self) -> FerrayResult<Array<f32, IxDyn>> {
188 match self {
189 Self::F32(a) => Ok(a),
190 other => Err(FerrayError::invalid_dtype(format!(
191 "expected float32, got {}",
192 other.dtype()
193 ))),
194 }
195 }
196
197 pub fn try_into_i64(self) -> FerrayResult<Array<i64, IxDyn>> {
199 match self {
200 Self::I64(a) => Ok(a),
201 other => Err(FerrayError::invalid_dtype(format!(
202 "expected int64, got {}",
203 other.dtype()
204 ))),
205 }
206 }
207
208 pub fn try_into_i32(self) -> FerrayResult<Array<i32, IxDyn>> {
210 match self {
211 Self::I32(a) => Ok(a),
212 other => Err(FerrayError::invalid_dtype(format!(
213 "expected int32, got {}",
214 other.dtype()
215 ))),
216 }
217 }
218
219 pub fn try_into_bool(self) -> FerrayResult<Array<bool, IxDyn>> {
221 match self {
222 Self::Bool(a) => Ok(a),
223 other => Err(FerrayError::invalid_dtype(format!(
224 "expected bool, got {}",
225 other.dtype()
226 ))),
227 }
228 }
229
230 pub fn astype(&self, target: DType, casting: CastKind) -> FerrayResult<Self> {
245 #[cfg(feature = "f16")]
247 if matches!(self, Self::F16(_)) || target == DType::F16 {
248 return Err(FerrayError::invalid_dtype(
249 "DynArray::astype does not yet support f16",
250 ));
251 }
252 #[cfg(feature = "bf16")]
253 if matches!(self, Self::BF16(_)) || target == DType::BF16 {
254 return Err(FerrayError::invalid_dtype(
255 "DynArray::astype does not yet support bf16",
256 ));
257 }
258 if matches!(self, Self::I256(_)) || target == DType::I256 {
264 return Err(FerrayError::invalid_dtype(
265 "DynArray::astype does not yet support I256 — construct I256 arrays directly",
266 ));
267 }
268 if matches!(self, Self::DateTime64(_, _) | Self::Timedelta64(_, _))
273 || matches!(target, DType::DateTime64(_) | DType::Timedelta64(_))
274 {
275 return Err(FerrayError::invalid_dtype(format!(
276 "DynArray::astype: cast involving {target} not supported \
277 — datetime/timedelta dtypes use dedicated arithmetic, not generic casts"
278 )));
279 }
280
281 macro_rules! cast_into {
284 ($U:ty) => {
285 match self {
286 Self::Bool(a) => a.cast::<$U>(casting),
287 Self::U8(a) => a.cast::<$U>(casting),
288 Self::U16(a) => a.cast::<$U>(casting),
289 Self::U32(a) => a.cast::<$U>(casting),
290 Self::U64(a) => a.cast::<$U>(casting),
291 Self::U128(a) => a.cast::<$U>(casting),
292 Self::I8(a) => a.cast::<$U>(casting),
293 Self::I16(a) => a.cast::<$U>(casting),
294 Self::I32(a) => a.cast::<$U>(casting),
295 Self::I64(a) => a.cast::<$U>(casting),
296 Self::I128(a) => a.cast::<$U>(casting),
297 Self::F32(a) => a.cast::<$U>(casting),
298 Self::F64(a) => a.cast::<$U>(casting),
299 Self::Complex32(a) => a.cast::<$U>(casting),
300 Self::Complex64(a) => a.cast::<$U>(casting),
301 Self::I256(_) => unreachable!("I256 source rejected above"),
302 #[cfg(feature = "f16")]
303 Self::F16(_) => unreachable!("f16 source rejected above"),
304 #[cfg(feature = "bf16")]
305 Self::BF16(_) => unreachable!("bf16 source rejected above"),
306 Self::DateTime64(_, _) | Self::Timedelta64(_, _) => {
307 unreachable!("time-dtype source rejected above")
308 }
309 }
310 };
311 }
312
313 Ok(match target {
314 DType::Bool => Self::Bool(cast_into!(bool)?),
315 DType::U8 => Self::U8(cast_into!(u8)?),
316 DType::U16 => Self::U16(cast_into!(u16)?),
317 DType::U32 => Self::U32(cast_into!(u32)?),
318 DType::U64 => Self::U64(cast_into!(u64)?),
319 DType::U128 => Self::U128(cast_into!(u128)?),
320 DType::I8 => Self::I8(cast_into!(i8)?),
321 DType::I16 => Self::I16(cast_into!(i16)?),
322 DType::I32 => Self::I32(cast_into!(i32)?),
323 DType::I64 => Self::I64(cast_into!(i64)?),
324 DType::I128 => Self::I128(cast_into!(i128)?),
325 DType::F32 => Self::F32(cast_into!(f32)?),
326 DType::F64 => Self::F64(cast_into!(f64)?),
327 DType::Complex32 => Self::Complex32(cast_into!(Complex<f32>)?),
328 DType::Complex64 => Self::Complex64(cast_into!(Complex<f64>)?),
329 DType::I256 => unreachable!("I256 target rejected above"),
330 #[cfg(feature = "f16")]
331 DType::F16 => unreachable!("f16 target rejected above"),
332 #[cfg(feature = "bf16")]
333 DType::BF16 => unreachable!("bf16 target rejected above"),
334 DType::DateTime64(_) | DType::Timedelta64(_) => {
335 unreachable!("time-dtype target rejected above")
336 }
337 DType::Struct(_) => {
342 return Err(FerrayError::invalid_dtype(
343 "DynArray cannot represent structured dtype targets yet",
344 ));
345 }
346 DType::FixedAscii(_) | DType::FixedUnicode(_) | DType::RawBytes(_) => {
350 return Err(FerrayError::invalid_dtype(
351 "DynArray cannot represent fixed-width string / void dtype targets yet (#741)",
352 ));
353 }
354 })
355 }
356
357 pub fn zeros(dtype: DType, shape: &[usize]) -> FerrayResult<Self> {
359 let dim = IxDyn::new(shape);
360 Ok(match dtype {
361 DType::Bool => Self::Bool(Array::zeros(dim)?),
362 DType::U8 => Self::U8(Array::zeros(dim)?),
363 DType::U16 => Self::U16(Array::zeros(dim)?),
364 DType::U32 => Self::U32(Array::zeros(dim)?),
365 DType::U64 => Self::U64(Array::zeros(dim)?),
366 DType::U128 => Self::U128(Array::zeros(dim)?),
367 DType::I8 => Self::I8(Array::zeros(dim)?),
368 DType::I16 => Self::I16(Array::zeros(dim)?),
369 DType::I32 => Self::I32(Array::zeros(dim)?),
370 DType::I64 => Self::I64(Array::zeros(dim)?),
371 DType::I128 => Self::I128(Array::zeros(dim)?),
372 DType::I256 => Self::I256(Array::zeros(dim)?),
373 DType::F32 => Self::F32(Array::zeros(dim)?),
374 DType::F64 => Self::F64(Array::zeros(dim)?),
375 DType::Complex32 => Self::Complex32(Array::zeros(dim)?),
376 DType::Complex64 => Self::Complex64(Array::zeros(dim)?),
377 #[cfg(feature = "f16")]
378 DType::F16 => Self::F16(Array::zeros(dim)?),
379 #[cfg(feature = "bf16")]
380 DType::BF16 => Self::BF16(Array::zeros(dim)?),
381 DType::DateTime64(unit) => Self::DateTime64(Array::zeros(dim)?, unit),
385 DType::Timedelta64(unit) => Self::Timedelta64(Array::zeros(dim)?, unit),
386 DType::Struct(_) => {
390 return Err(FerrayError::invalid_dtype(
391 "DynArray::zeros doesn't support structured dtypes yet",
392 ));
393 }
394 DType::FixedAscii(_) | DType::FixedUnicode(_) | DType::RawBytes(_) => {
398 return Err(FerrayError::invalid_dtype(
399 "DynArray::zeros doesn't support fixed-width string / void dtypes yet (#741)",
400 ));
401 }
402 })
403 }
404
405 #[must_use]
407 pub fn from_datetime64(arr: Array<DateTime64, IxDyn>, unit: TimeUnit) -> Self {
408 Self::DateTime64(arr, unit)
409 }
410
411 #[must_use]
413 pub fn from_timedelta64(arr: Array<Timedelta64, IxDyn>, unit: TimeUnit) -> Self {
414 Self::Timedelta64(arr, unit)
415 }
416
417 pub fn try_into_datetime64(self) -> FerrayResult<(Array<DateTime64, IxDyn>, TimeUnit)> {
423 match self {
424 Self::DateTime64(a, u) => Ok((a, u)),
425 other => Err(FerrayError::invalid_dtype(format!(
426 "expected datetime64, got {}",
427 other.dtype()
428 ))),
429 }
430 }
431
432 pub fn try_into_timedelta64(self) -> FerrayResult<(Array<Timedelta64, IxDyn>, TimeUnit)> {
438 match self {
439 Self::Timedelta64(a, u) => Ok((a, u)),
440 other => Err(FerrayError::invalid_dtype(format!(
441 "expected timedelta64, got {}",
442 other.dtype()
443 ))),
444 }
445 }
446}
447
448impl std::fmt::Display for DynArray {
449 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
450 dispatch!(self, a => write!(f, "{a}"))
451 }
452}
453
// Generates `impl From<Array<$elem, IxDyn>> for DynArray`, which wraps the
// typed array in the `DynArray::$tag` variant.
macro_rules! impl_from_array_dyn {
    ($elem:ty, $tag:ident) => {
        impl From<Array<$elem, IxDyn>> for DynArray {
            fn from(array: Array<$elem, IxDyn>) -> Self {
                DynArray::$tag(array)
            }
        }
    };
}
464
465impl_from_array_dyn!(bool, Bool);
466impl_from_array_dyn!(u8, U8);
467impl_from_array_dyn!(u16, U16);
468impl_from_array_dyn!(u32, U32);
469impl_from_array_dyn!(u64, U64);
470impl_from_array_dyn!(u128, U128);
471impl_from_array_dyn!(i8, I8);
472impl_from_array_dyn!(i16, I16);
473impl_from_array_dyn!(i32, I32);
474impl_from_array_dyn!(i64, I64);
475impl_from_array_dyn!(i128, I128);
476impl_from_array_dyn!(I256, I256);
477impl_from_array_dyn!(f32, F32);
478impl_from_array_dyn!(f64, F64);
479impl_from_array_dyn!(Complex<f32>, Complex32);
480impl_from_array_dyn!(Complex<f64>, Complex64);
481#[cfg(feature = "f16")]
482impl_from_array_dyn!(half::f16, F16);
483#[cfg(feature = "bf16")]
484impl_from_array_dyn!(half::bf16, BF16);
485
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn dynarray_zeros_f64() {
        let da = DynArray::zeros(DType::F64, &[2, 3]).unwrap();
        assert_eq!(da.dtype(), DType::F64);
        assert_eq!(da.shape(), &[2, 3]);
        assert_eq!(da.ndim(), 2);
        assert_eq!(da.size(), 6);
        assert_eq!(da.itemsize(), 8);
        assert_eq!(da.nbytes(), 48);
    }

    #[test]
    fn dynarray_zeros_i32() {
        let da = DynArray::zeros(DType::I32, &[4]).unwrap();
        assert_eq!(da.dtype(), DType::I32);
        assert_eq!(da.shape(), &[4]);
    }

    #[test]
    fn dynarray_try_into_f64() {
        let da = DynArray::zeros(DType::F64, &[3]).unwrap();
        let arr = da.try_into_f64().unwrap();
        assert_eq!(arr.shape(), &[3]);
    }

    #[test]
    fn dynarray_try_into_wrong_type() {
        let da = DynArray::zeros(DType::I32, &[3]).unwrap();
        assert!(da.try_into_f64().is_err());
    }

    // Every typed accessor succeeds on its own variant.
    #[test]
    fn dynarray_try_into_matching_variants() {
        assert!(DynArray::zeros(DType::F32, &[2]).unwrap().try_into_f32().is_ok());
        assert!(DynArray::zeros(DType::I64, &[2]).unwrap().try_into_i64().is_ok());
        assert!(DynArray::zeros(DType::I32, &[2]).unwrap().try_into_i32().is_ok());
        assert!(DynArray::zeros(DType::Bool, &[2]).unwrap().try_into_bool().is_ok());
    }

    // Every typed accessor rejects a mismatched variant instead of converting.
    #[test]
    fn dynarray_try_into_mismatched_variants() {
        let da = DynArray::zeros(DType::F64, &[2]).unwrap();
        assert!(da.clone().try_into_f32().is_err());
        assert!(da.clone().try_into_i64().is_err());
        assert!(da.clone().try_into_i32().is_err());
        assert!(da.clone().try_into_bool().is_err());
        assert!(da.clone().try_into_datetime64().is_err());
        assert!(da.try_into_timedelta64().is_err());
    }

    #[test]
    fn dynarray_astype_f64_to_i32_unsafe() {
        let arr = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[3]), vec![1.5, 2.7, -3.9]).unwrap();
        let dy = DynArray::F64(arr);
        let casted = dy.astype(DType::I32, CastKind::Unsafe).unwrap();
        assert_eq!(casted.dtype(), DType::I32);
        match casted {
            DynArray::I32(a) => assert_eq!(a.as_slice().unwrap(), &[1, 2, -3]),
            _ => panic!("expected I32"),
        }
    }

    #[test]
    fn dynarray_astype_safe_widening() {
        let arr = Array::<i32, IxDyn>::from_vec(IxDyn::new(&[3]), vec![10, 20, 30]).unwrap();
        let dy = DynArray::I32(arr);
        let casted = dy.astype(DType::I64, CastKind::Safe).unwrap();
        assert_eq!(casted.dtype(), DType::I64);
        match casted {
            DynArray::I64(a) => assert_eq!(a.as_slice().unwrap(), &[10i64, 20, 30]),
            _ => panic!("expected I64"),
        }
    }

    #[test]
    fn dynarray_astype_safe_narrowing_errors() {
        let arr = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2]), vec![1.0, 2.0]).unwrap();
        let dy = DynArray::F64(arr);
        assert!(dy.astype(DType::F32, CastKind::Safe).is_err());
    }

    #[test]
    fn dynarray_astype_complex_to_real_unsafe() {
        let arr = Array::<Complex<f64>, IxDyn>::from_vec(
            IxDyn::new(&[2]),
            vec![Complex::new(1.5, 9.0), Complex::new(2.5, -1.0)],
        )
        .unwrap();
        let dy = DynArray::Complex64(arr);
        let casted = dy.astype(DType::F64, CastKind::Unsafe).unwrap();
        match casted {
            DynArray::F64(a) => assert_eq!(a.as_slice().unwrap(), &[1.5, 2.5]),
            _ => panic!("expected F64"),
        }
    }

    #[test]
    fn dynarray_astype_bool_to_u8_safe() {
        let arr =
            Array::<bool, IxDyn>::from_vec(IxDyn::new(&[3]), vec![true, false, true]).unwrap();
        let dy = DynArray::Bool(arr);
        let casted = dy.astype(DType::U8, CastKind::Safe).unwrap();
        match casted {
            DynArray::U8(a) => assert_eq!(a.as_slice().unwrap(), &[1u8, 0, 1]),
            _ => panic!("expected U8"),
        }
    }

    #[test]
    fn dynarray_astype_no_kind_requires_identity() {
        let arr = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2]), vec![1.0, 2.0]).unwrap();
        let dy = DynArray::F64(arr);
        assert!(dy.astype(DType::F64, CastKind::No).is_ok());
        assert!(dy.astype(DType::F32, CastKind::No).is_err());
    }

    // I256 has no generic cast path: conversions to or from another dtype
    // must fail with an error rather than panic.
    #[test]
    fn dynarray_i256_zeros_and_astype_rejected() {
        let wide = DynArray::zeros(DType::I256, &[2]).unwrap();
        assert_eq!(wide.dtype(), DType::I256);
        assert_eq!(wide.shape(), &[2]);
        assert!(wide.astype(DType::F64, CastKind::Unsafe).is_err());

        let float = DynArray::zeros(DType::F64, &[2]).unwrap();
        assert!(float.astype(DType::I256, CastKind::Unsafe).is_err());
    }

    #[test]
    fn dynarray_from_typed() {
        let arr = Array::<f64, IxDyn>::zeros(IxDyn::new(&[2, 2])).unwrap();
        let da: DynArray = arr.into();
        assert_eq!(da.dtype(), DType::F64);
    }

    #[test]
    fn dynarray_u8_from_typed_byte_accounting() {
        let arr = Array::<u8, IxDyn>::from_vec(IxDyn::new(&[2, 2]), vec![1, 2, 3, 4]).unwrap();
        let da: DynArray = arr.into();
        assert_eq!(da.dtype(), DType::U8);
        assert_eq!(da.ndim(), 2);
        assert_eq!(da.size(), 4);
        assert_eq!(da.nbytes(), 4 * da.itemsize());
        assert!(!da.is_empty());
    }

    #[test]
    fn dynarray_display() {
        let da = DynArray::zeros(DType::I32, &[3]).unwrap();
        let s = format!("{da}");
        assert!(s.contains("[0, 0, 0]"));
    }

    #[test]
    fn dynarray_is_empty() {
        let da = DynArray::zeros(DType::F32, &[0]).unwrap();
        assert!(da.is_empty());
    }

    #[cfg(feature = "f16")]
    #[test]
    fn dynarray_f16_zeros_shape_and_dtype() {
        let da = DynArray::zeros(DType::F16, &[2, 3]).unwrap();
        assert_eq!(da.dtype(), DType::F16);
        assert_eq!(da.shape(), &[2, 3]);
        assert_eq!(da.size(), 6);
        assert_eq!(da.itemsize(), 2);
        assert_eq!(da.nbytes(), 12);
    }

    #[cfg(feature = "f16")]
    #[test]
    fn dynarray_f16_from_typed_roundtrips() {
        use half::f16;
        let raw = [f16::from_f32(1.0), f16::from_f32(2.5), f16::from_f32(-3.0)];
        let arr = Array::<f16, IxDyn>::from_vec(IxDyn::new(&[3]), raw.to_vec()).unwrap();
        let da: DynArray = arr.into();
        assert_eq!(da.dtype(), DType::F16);
        assert_eq!(da.shape(), &[3]);
    }

    #[cfg(feature = "bf16")]
    #[test]
    fn dynarray_bf16_zeros_shape_and_dtype() {
        let da = DynArray::zeros(DType::BF16, &[4]).unwrap();
        assert_eq!(da.dtype(), DType::BF16);
        assert_eq!(da.shape(), &[4]);
        assert_eq!(da.itemsize(), 2);
    }

    #[cfg(feature = "bf16")]
    #[test]
    fn dynarray_bf16_from_typed_roundtrips() {
        use half::bf16;
        let raw = [bf16::from_f32(1.0), bf16::from_f32(2.0)];
        let arr = Array::<bf16, IxDyn>::from_vec(IxDyn::new(&[2]), raw.to_vec()).unwrap();
        let da: DynArray = arr.into();
        assert_eq!(da.dtype(), DType::BF16);
    }
}