// ferray_core/dynarray.rs

// ferray-core: DynArray — runtime-typed array enum (REQ-30)
2
3use num_complex::Complex;
4
5use crate::array::owned::Array;
6use crate::dimension::IxDyn;
7use crate::dtype::casting::CastKind;
8use crate::dtype::{DType, I256};
9use crate::error::{FerrayError, FerrayResult};
10
11/// A runtime-typed array whose element type is determined at runtime.
12///
13/// This is analogous to a Python `numpy.ndarray` where the dtype is a
14/// runtime property. Each variant wraps an `Array<T, IxDyn>` for the
15/// corresponding element type.
16///
17/// Use this when the element type is not known at compile time (e.g.,
18/// loading from a file, receiving from Python/FFI).
19#[derive(Debug, Clone)]
20#[non_exhaustive]
21pub enum DynArray {
22    /// `bool` elements
23    Bool(Array<bool, IxDyn>),
24    /// `u8` elements
25    U8(Array<u8, IxDyn>),
26    /// `u16` elements
27    U16(Array<u16, IxDyn>),
28    /// `u32` elements
29    U32(Array<u32, IxDyn>),
30    /// `u64` elements
31    U64(Array<u64, IxDyn>),
32    /// `u128` elements
33    U128(Array<u128, IxDyn>),
34    /// `i8` elements
35    I8(Array<i8, IxDyn>),
36    /// `i16` elements
37    I16(Array<i16, IxDyn>),
38    /// `i32` elements
39    I32(Array<i32, IxDyn>),
40    /// `i64` elements
41    I64(Array<i64, IxDyn>),
42    /// `i128` elements
43    I128(Array<i128, IxDyn>),
44    /// `I256` elements — 256-bit two's-complement signed integer,
45    /// used as the promoted type for mixed `u128` + signed-int
46    /// arithmetic (#375, #562).
47    I256(Array<I256, IxDyn>),
48    /// `f32` elements
49    F32(Array<f32, IxDyn>),
50    /// `f64` elements
51    F64(Array<f64, IxDyn>),
52    /// `Complex<f32>` elements
53    Complex32(Array<Complex<f32>, IxDyn>),
54    /// `Complex<f64>` elements
55    Complex64(Array<Complex<f64>, IxDyn>),
56    /// `f16` elements (feature-gated)
57    #[cfg(feature = "f16")]
58    F16(Array<half::f16, IxDyn>),
59    /// `bf16` (bfloat16) elements (feature-gated)
60    #[cfg(feature = "bf16")]
61    BF16(Array<half::bf16, IxDyn>),
62}
63
/// Evaluate one expression against whichever `DynArray` variant is present,
/// with the wrapped `Array<T, IxDyn>` bound to the identifier written before
/// `=>`. The expression is expanded once per variant, so it has to type-check
/// for every element type. This collapses what would otherwise be a
/// hand-written match with one arm per variant into a one-line method body
/// (see issue #125); the f16/bf16 arms are conditionally compiled in the same
/// way as the enum variants.
///
/// The expansion names variants through `Self::`, so it is meant to be used
/// inside an `impl` whose `Self` is `DynArray`.
macro_rules! dispatch {
    ($subject:expr, $inner:ident => $body:expr) => {
        match $subject {
            Self::Bool($inner) => $body,
            Self::U8($inner) => $body,
            Self::U16($inner) => $body,
            Self::U32($inner) => $body,
            Self::U64($inner) => $body,
            Self::U128($inner) => $body,
            Self::I8($inner) => $body,
            Self::I16($inner) => $body,
            Self::I32($inner) => $body,
            Self::I64($inner) => $body,
            Self::I128($inner) => $body,
            Self::I256($inner) => $body,
            Self::F32($inner) => $body,
            Self::F64($inner) => $body,
            Self::Complex32($inner) => $body,
            Self::Complex64($inner) => $body,
            #[cfg(feature = "f16")]
            Self::F16($inner) => $body,
            #[cfg(feature = "bf16")]
            Self::BF16($inner) => $body,
        }
    };
}
94
95impl DynArray {
96    /// The runtime dtype of the elements in this array.
97    pub fn dtype(&self) -> DType {
98        match self {
99            Self::Bool(_) => DType::Bool,
100            Self::U8(_) => DType::U8,
101            Self::U16(_) => DType::U16,
102            Self::U32(_) => DType::U32,
103            Self::U64(_) => DType::U64,
104            Self::U128(_) => DType::U128,
105            Self::I8(_) => DType::I8,
106            Self::I16(_) => DType::I16,
107            Self::I32(_) => DType::I32,
108            Self::I64(_) => DType::I64,
109            Self::I128(_) => DType::I128,
110            Self::I256(_) => DType::I256,
111            Self::F32(_) => DType::F32,
112            Self::F64(_) => DType::F64,
113            Self::Complex32(_) => DType::Complex32,
114            Self::Complex64(_) => DType::Complex64,
115            #[cfg(feature = "f16")]
116            Self::F16(_) => DType::F16,
117            #[cfg(feature = "bf16")]
118            Self::BF16(_) => DType::BF16,
119        }
120    }
121
122    /// Shape as a slice.
123    pub fn shape(&self) -> &[usize] {
124        dispatch!(self, a => a.shape())
125    }
126
127    /// Number of dimensions.
128    pub fn ndim(&self) -> usize {
129        self.shape().len()
130    }
131
132    /// Total number of elements.
133    pub fn size(&self) -> usize {
134        self.shape().iter().product()
135    }
136
137    /// Whether the array has zero elements.
138    pub fn is_empty(&self) -> bool {
139        self.size() == 0
140    }
141
142    /// Size in bytes of one element.
143    pub fn itemsize(&self) -> usize {
144        self.dtype().size_of()
145    }
146
147    /// Total size in bytes.
148    pub fn nbytes(&self) -> usize {
149        self.size() * self.itemsize()
150    }
151
152    /// Try to extract the inner `Array<f64, IxDyn>`.
153    ///
154    /// # Errors
155    /// Returns `FerrayError::InvalidDtype` if the dtype is not `f64`.
156    pub fn try_into_f64(self) -> FerrayResult<Array<f64, IxDyn>> {
157        match self {
158            Self::F64(a) => Ok(a),
159            other => Err(FerrayError::invalid_dtype(format!(
160                "expected float64, got {}",
161                other.dtype()
162            ))),
163        }
164    }
165
166    /// Try to extract the inner `Array<f32, IxDyn>`.
167    pub fn try_into_f32(self) -> FerrayResult<Array<f32, IxDyn>> {
168        match self {
169            Self::F32(a) => Ok(a),
170            other => Err(FerrayError::invalid_dtype(format!(
171                "expected float32, got {}",
172                other.dtype()
173            ))),
174        }
175    }
176
177    /// Try to extract the inner `Array<i64, IxDyn>`.
178    pub fn try_into_i64(self) -> FerrayResult<Array<i64, IxDyn>> {
179        match self {
180            Self::I64(a) => Ok(a),
181            other => Err(FerrayError::invalid_dtype(format!(
182                "expected int64, got {}",
183                other.dtype()
184            ))),
185        }
186    }
187
188    /// Try to extract the inner `Array<i32, IxDyn>`.
189    pub fn try_into_i32(self) -> FerrayResult<Array<i32, IxDyn>> {
190        match self {
191            Self::I32(a) => Ok(a),
192            other => Err(FerrayError::invalid_dtype(format!(
193                "expected int32, got {}",
194                other.dtype()
195            ))),
196        }
197    }
198
199    /// Try to extract the inner `Array<bool, IxDyn>`.
200    pub fn try_into_bool(self) -> FerrayResult<Array<bool, IxDyn>> {
201        match self {
202            Self::Bool(a) => Ok(a),
203            other => Err(FerrayError::invalid_dtype(format!(
204                "expected bool, got {}",
205                other.dtype()
206            ))),
207        }
208    }
209
210    /// Cast this array to a different element dtype at the requested safety level.
211    ///
212    /// Mirrors NumPy's `arr.astype(dtype, casting=...)`. The conversion routes
213    /// through [`crate::dtype::unsafe_cast::CastTo`] for the underlying typed
214    /// arrays, so it supports the same set of element pairs (every primitive
215    /// numeric, bool, and Complex<f32>/Complex<f64> in any combination).
216    ///
217    /// `f16` / `bf16` are not yet supported in this dispatch and will return
218    /// `FerrayError::InvalidDtype` — track via the umbrella casting issue.
219    ///
220    /// # Errors
221    /// Returns `FerrayError::InvalidDtype` if:
222    /// - The cast is not permitted at the chosen `casting` level, or
223    /// - either source or target dtype is `f16`/`bf16` (not yet wired).
224    pub fn astype(&self, target: DType, casting: CastKind) -> FerrayResult<Self> {
225        // Reject f16/bf16 — see method docs.
226        #[cfg(feature = "f16")]
227        if matches!(self, Self::F16(_)) || target == DType::F16 {
228            return Err(FerrayError::invalid_dtype(
229                "DynArray::astype does not yet support f16",
230            ));
231        }
232        #[cfg(feature = "bf16")]
233        if matches!(self, Self::BF16(_)) || target == DType::BF16 {
234            return Err(FerrayError::invalid_dtype(
235                "DynArray::astype does not yet support bf16",
236            ));
237        }
238        // I256 is accepted as a storage type (from promotion) but
239        // generic cast-through-`CastTo` is not yet wired for it.
240        // Reject both source and target explicitly — users who want
241        // I256 arrays should construct them directly rather than via
242        // astype (#562).
243        if matches!(self, Self::I256(_)) || target == DType::I256 {
244            return Err(FerrayError::invalid_dtype(
245                "DynArray::astype does not yet support I256 — construct I256 arrays directly",
246            ));
247        }
248
249        // Inner macro: dispatch on the *source* variant. The target type `$U`
250        // is fixed by the outer match below.
251        macro_rules! cast_into {
252            ($U:ty) => {
253                match self {
254                    Self::Bool(a) => a.cast::<$U>(casting),
255                    Self::U8(a) => a.cast::<$U>(casting),
256                    Self::U16(a) => a.cast::<$U>(casting),
257                    Self::U32(a) => a.cast::<$U>(casting),
258                    Self::U64(a) => a.cast::<$U>(casting),
259                    Self::U128(a) => a.cast::<$U>(casting),
260                    Self::I8(a) => a.cast::<$U>(casting),
261                    Self::I16(a) => a.cast::<$U>(casting),
262                    Self::I32(a) => a.cast::<$U>(casting),
263                    Self::I64(a) => a.cast::<$U>(casting),
264                    Self::I128(a) => a.cast::<$U>(casting),
265                    Self::F32(a) => a.cast::<$U>(casting),
266                    Self::F64(a) => a.cast::<$U>(casting),
267                    Self::Complex32(a) => a.cast::<$U>(casting),
268                    Self::Complex64(a) => a.cast::<$U>(casting),
269                    Self::I256(_) => unreachable!("I256 source rejected above"),
270                    #[cfg(feature = "f16")]
271                    Self::F16(_) => unreachable!("f16 source rejected above"),
272                    #[cfg(feature = "bf16")]
273                    Self::BF16(_) => unreachable!("bf16 source rejected above"),
274                }
275            };
276        }
277
278        Ok(match target {
279            DType::Bool => Self::Bool(cast_into!(bool)?),
280            DType::U8 => Self::U8(cast_into!(u8)?),
281            DType::U16 => Self::U16(cast_into!(u16)?),
282            DType::U32 => Self::U32(cast_into!(u32)?),
283            DType::U64 => Self::U64(cast_into!(u64)?),
284            DType::U128 => Self::U128(cast_into!(u128)?),
285            DType::I8 => Self::I8(cast_into!(i8)?),
286            DType::I16 => Self::I16(cast_into!(i16)?),
287            DType::I32 => Self::I32(cast_into!(i32)?),
288            DType::I64 => Self::I64(cast_into!(i64)?),
289            DType::I128 => Self::I128(cast_into!(i128)?),
290            DType::F32 => Self::F32(cast_into!(f32)?),
291            DType::F64 => Self::F64(cast_into!(f64)?),
292            DType::Complex32 => Self::Complex32(cast_into!(Complex<f32>)?),
293            DType::Complex64 => Self::Complex64(cast_into!(Complex<f64>)?),
294            DType::I256 => unreachable!("I256 target rejected above"),
295            #[cfg(feature = "f16")]
296            DType::F16 => unreachable!("f16 target rejected above"),
297            #[cfg(feature = "bf16")]
298            DType::BF16 => unreachable!("bf16 target rejected above"),
299        })
300    }
301
302    /// Create a `DynArray` of zeros with the given dtype and shape.
303    pub fn zeros(dtype: DType, shape: &[usize]) -> FerrayResult<Self> {
304        let dim = IxDyn::new(shape);
305        Ok(match dtype {
306            DType::Bool => Self::Bool(Array::zeros(dim)?),
307            DType::U8 => Self::U8(Array::zeros(dim)?),
308            DType::U16 => Self::U16(Array::zeros(dim)?),
309            DType::U32 => Self::U32(Array::zeros(dim)?),
310            DType::U64 => Self::U64(Array::zeros(dim)?),
311            DType::U128 => Self::U128(Array::zeros(dim)?),
312            DType::I8 => Self::I8(Array::zeros(dim)?),
313            DType::I16 => Self::I16(Array::zeros(dim)?),
314            DType::I32 => Self::I32(Array::zeros(dim)?),
315            DType::I64 => Self::I64(Array::zeros(dim)?),
316            DType::I128 => Self::I128(Array::zeros(dim)?),
317            DType::I256 => Self::I256(Array::zeros(dim)?),
318            DType::F32 => Self::F32(Array::zeros(dim)?),
319            DType::F64 => Self::F64(Array::zeros(dim)?),
320            DType::Complex32 => Self::Complex32(Array::zeros(dim)?),
321            DType::Complex64 => Self::Complex64(Array::zeros(dim)?),
322            #[cfg(feature = "f16")]
323            DType::F16 => Self::F16(Array::zeros(dim)?),
324            #[cfg(feature = "bf16")]
325            DType::BF16 => Self::BF16(Array::zeros(dim)?),
326        })
327    }
328}
329
330impl std::fmt::Display for DynArray {
331    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
332        dispatch!(self, a => write!(f, "{a}"))
333    }
334}
335
336// Conversion from typed arrays to DynArray
337macro_rules! impl_from_array_dyn {
338    ($ty:ty, $variant:ident) => {
339        impl From<Array<$ty, IxDyn>> for DynArray {
340            fn from(a: Array<$ty, IxDyn>) -> Self {
341                Self::$variant(a)
342            }
343        }
344    };
345}
346
347impl_from_array_dyn!(bool, Bool);
348impl_from_array_dyn!(u8, U8);
349impl_from_array_dyn!(u16, U16);
350impl_from_array_dyn!(u32, U32);
351impl_from_array_dyn!(u64, U64);
352impl_from_array_dyn!(u128, U128);
353impl_from_array_dyn!(i8, I8);
354impl_from_array_dyn!(i16, I16);
355impl_from_array_dyn!(i32, I32);
356impl_from_array_dyn!(i64, I64);
357impl_from_array_dyn!(i128, I128);
358impl_from_array_dyn!(I256, I256);
359impl_from_array_dyn!(f32, F32);
360impl_from_array_dyn!(f64, F64);
361impl_from_array_dyn!(Complex<f32>, Complex32);
362impl_from_array_dyn!(Complex<f64>, Complex64);
363#[cfg(feature = "f16")]
364impl_from_array_dyn!(half::f16, F16);
365#[cfg(feature = "bf16")]
366impl_from_array_dyn!(half::bf16, BF16);
367
#[cfg(test)]
mod tests {
    use super::*;

    // `zeros` for f64 reports consistent dtype, shape and byte accounting.
    #[test]
    fn dynarray_zeros_f64() {
        let zeros = DynArray::zeros(DType::F64, &[2, 3]).unwrap();
        assert_eq!(zeros.dtype(), DType::F64);
        assert_eq!(zeros.shape(), &[2, 3]);
        assert_eq!(zeros.ndim(), 2);
        assert_eq!(zeros.size(), 6);
        assert_eq!(zeros.itemsize(), 8);
        assert_eq!(zeros.nbytes(), 48);
    }

    // `zeros` for i32 keeps the requested dtype and 1-D shape.
    #[test]
    fn dynarray_zeros_i32() {
        let zeros = DynArray::zeros(DType::I32, &[4]).unwrap();
        assert_eq!(zeros.dtype(), DType::I32);
        assert_eq!(zeros.shape(), &[4]);
    }

    // Extracting the matching typed array succeeds and preserves shape.
    #[test]
    fn dynarray_try_into_f64() {
        let typed = DynArray::zeros(DType::F64, &[3])
            .unwrap()
            .try_into_f64()
            .unwrap();
        assert_eq!(typed.shape(), &[3]);
    }

    // Extracting with the wrong element type is an error.
    #[test]
    fn dynarray_try_into_wrong_type() {
        let ints = DynArray::zeros(DType::I32, &[3]).unwrap();
        let outcome = ints.try_into_f64();
        assert!(outcome.is_err());
    }

    // ----- DynArray::astype tests (issue #361) -----

    // Unsafe float -> int casting is allowed; expected values are truncated.
    #[test]
    fn dynarray_astype_f64_to_i32_unsafe() {
        let typed =
            Array::<f64, IxDyn>::from_vec(IxDyn::new(&[3]), vec![1.5, 2.7, -3.9]).unwrap();
        let source = DynArray::F64(typed);
        let casted = source.astype(DType::I32, CastKind::Unsafe).unwrap();
        assert_eq!(casted.dtype(), DType::I32);
        if let DynArray::I32(values) = casted {
            assert_eq!(values.as_slice().unwrap(), &[1, 2, -3]);
        } else {
            panic!("expected I32");
        }
    }

    // Widening i32 -> i64 is permitted under `Safe` and keeps the values.
    #[test]
    fn dynarray_astype_safe_widening() {
        let typed = Array::<i32, IxDyn>::from_vec(IxDyn::new(&[3]), vec![10, 20, 30]).unwrap();
        let source = DynArray::I32(typed);
        let casted = source.astype(DType::I64, CastKind::Safe).unwrap();
        assert_eq!(casted.dtype(), DType::I64);
        if let DynArray::I64(values) = casted {
            assert_eq!(values.as_slice().unwrap(), &[10i64, 20, 30]);
        } else {
            panic!("expected I64");
        }
    }

    // Narrowing f64 -> f32 is refused under `Safe`.
    #[test]
    fn dynarray_astype_safe_narrowing_errors() {
        let typed = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2]), vec![1.0, 2.0]).unwrap();
        let source = DynArray::F64(typed);
        let outcome = source.astype(DType::F32, CastKind::Safe);
        assert!(outcome.is_err());
    }

    // Unsafe complex -> real casting yields the real parts.
    #[test]
    fn dynarray_astype_complex_to_real_unsafe() {
        let samples = vec![Complex::new(1.5, 9.0), Complex::new(2.5, -1.0)];
        let typed = Array::<Complex<f64>, IxDyn>::from_vec(IxDyn::new(&[2]), samples).unwrap();
        let source = DynArray::Complex64(typed);
        let casted = source.astype(DType::F64, CastKind::Unsafe).unwrap();
        if let DynArray::F64(values) = casted {
            assert_eq!(values.as_slice().unwrap(), &[1.5, 2.5]);
        } else {
            panic!("expected F64");
        }
    }

    // bool -> u8 is a safe cast mapping true/false to 1/0.
    #[test]
    fn dynarray_astype_bool_to_u8_safe() {
        let typed =
            Array::<bool, IxDyn>::from_vec(IxDyn::new(&[3]), vec![true, false, true]).unwrap();
        let source = DynArray::Bool(typed);
        let casted = source.astype(DType::U8, CastKind::Safe).unwrap();
        if let DynArray::U8(values) = casted {
            assert_eq!(values.as_slice().unwrap(), &[1u8, 0, 1]);
        } else {
            panic!("expected U8");
        }
    }

    // `CastKind::No` only accepts a cast to the identical dtype.
    #[test]
    fn dynarray_astype_no_kind_requires_identity() {
        let typed = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2]), vec![1.0, 2.0]).unwrap();
        let source = DynArray::F64(typed);
        let same_dtype = source.astype(DType::F64, CastKind::No);
        assert!(same_dtype.is_ok());
        let other_dtype = source.astype(DType::F32, CastKind::No);
        assert!(other_dtype.is_err());
    }

    // `From<Array<f64, IxDyn>>` selects the F64 variant.
    #[test]
    fn dynarray_from_typed() {
        let typed = Array::<f64, IxDyn>::zeros(IxDyn::new(&[2, 2])).unwrap();
        let wrapped = DynArray::from(typed);
        assert_eq!(wrapped.dtype(), DType::F64);
    }

    // Display output includes the element listing of the inner array.
    #[test]
    fn dynarray_display() {
        let zeros = DynArray::zeros(DType::I32, &[3]).unwrap();
        let rendered = zeros.to_string();
        assert!(rendered.contains("[0, 0, 0]"));
    }

    // A zero-length axis makes the array empty.
    #[test]
    fn dynarray_is_empty() {
        let empty = DynArray::zeros(DType::F32, &[0]).unwrap();
        assert!(empty.is_empty());
    }

    // ----- f16 / bf16 DynArray coverage (#139) -----

    // f16 zeros report a 2-byte item size and matching totals.
    #[cfg(feature = "f16")]
    #[test]
    fn dynarray_f16_zeros_shape_and_dtype() {
        let zeros = DynArray::zeros(DType::F16, &[2, 3]).unwrap();
        assert_eq!(zeros.dtype(), DType::F16);
        assert_eq!(zeros.shape(), &[2, 3]);
        assert_eq!(zeros.size(), 6);
        assert_eq!(zeros.itemsize(), 2);
        assert_eq!(zeros.nbytes(), 12);
    }

    // A typed f16 array converts into the F16 variant with its shape intact.
    #[cfg(feature = "f16")]
    #[test]
    fn dynarray_f16_from_typed_roundtrips() {
        use half::f16;
        let samples = vec![f16::from_f32(1.0), f16::from_f32(2.5), f16::from_f32(-3.0)];
        let typed = Array::<f16, IxDyn>::from_vec(IxDyn::new(&[3]), samples).unwrap();
        let wrapped = DynArray::from(typed);
        assert_eq!(wrapped.dtype(), DType::F16);
        assert_eq!(wrapped.shape(), &[3]);
    }

    // bf16 zeros report the requested shape and a 2-byte item size.
    #[cfg(feature = "bf16")]
    #[test]
    fn dynarray_bf16_zeros_shape_and_dtype() {
        let zeros = DynArray::zeros(DType::BF16, &[4]).unwrap();
        assert_eq!(zeros.dtype(), DType::BF16);
        assert_eq!(zeros.shape(), &[4]);
        assert_eq!(zeros.itemsize(), 2);
    }

    // A typed bf16 array converts into the BF16 variant.
    #[cfg(feature = "bf16")]
    #[test]
    fn dynarray_bf16_from_typed_roundtrips() {
        use half::bf16;
        let samples = vec![bf16::from_f32(1.0), bf16::from_f32(2.0)];
        let typed = Array::<bf16, IxDyn>::from_vec(IxDyn::new(&[2]), samples).unwrap();
        let wrapped = DynArray::from(typed);
        assert_eq!(wrapped.dtype(), DType::BF16);
    }
}