Skip to main content

ferray_core/
dynarray.rs

1// ferray-core: DynArray — runtime-typed array enum (REQ-30)
2
3use num_complex::Complex;
4
5use crate::array::owned::Array;
6use crate::dimension::IxDyn;
7use crate::dtype::casting::CastKind;
8use crate::dtype::{DType, I256};
9use crate::error::{FerrayError, FerrayResult};
10
/// A runtime-typed array whose element type is determined at runtime.
///
/// This is analogous to a Python `numpy.ndarray` where the dtype is a
/// runtime property. Each variant wraps an `Array<T, IxDyn>` for the
/// corresponding element type; [`DynArray::dtype`] reports which variant is
/// active as a [`DType`].
///
/// Use this when the element type is not known at compile time (e.g.,
/// loading from a file, receiving from Python/FFI).
///
/// The `F16` and `BF16` variants only exist when the `f16` / `bf16` crate
/// features are enabled. Because the enum is `#[non_exhaustive]`, code in
/// other crates that matches on it must include a wildcard arm.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum DynArray {
    /// `bool` elements
    Bool(Array<bool, IxDyn>),
    /// `u8` elements
    U8(Array<u8, IxDyn>),
    /// `u16` elements
    U16(Array<u16, IxDyn>),
    /// `u32` elements
    U32(Array<u32, IxDyn>),
    /// `u64` elements
    U64(Array<u64, IxDyn>),
    /// `u128` elements
    U128(Array<u128, IxDyn>),
    /// `i8` elements
    I8(Array<i8, IxDyn>),
    /// `i16` elements
    I16(Array<i16, IxDyn>),
    /// `i32` elements
    I32(Array<i32, IxDyn>),
    /// `i64` elements
    I64(Array<i64, IxDyn>),
    /// `i128` elements
    I128(Array<i128, IxDyn>),
    /// `I256` elements — 256-bit two's-complement signed integer,
    /// used as the promoted type for mixed `u128` + signed-int
    /// arithmetic (#375, #562).
    I256(Array<I256, IxDyn>),
    /// `f32` elements
    F32(Array<f32, IxDyn>),
    /// `f64` elements
    F64(Array<f64, IxDyn>),
    /// `Complex<f32>` elements.
    ///
    /// Named after the component width; this corresponds to what NumPy
    /// calls `complex64` (which names the total width).
    Complex32(Array<Complex<f32>, IxDyn>),
    /// `Complex<f64>` elements.
    ///
    /// Named after the component width; this corresponds to what NumPy
    /// calls `complex128`.
    Complex64(Array<Complex<f64>, IxDyn>),
    /// `f16` elements (only present with the `f16` feature)
    #[cfg(feature = "f16")]
    F16(Array<half::f16, IxDyn>),
    /// `bf16` (bfloat16) elements (only present with the `bf16` feature)
    #[cfg(feature = "bf16")]
    BF16(Array<half::bf16, IxDyn>),
}
63
/// Dispatch a single expression across every `DynArray` variant, binding
/// the inner `Array<T, IxDyn>` to `$binding`. This turns the repeated
/// per-variant match (16 always-present variants plus the feature-gated
/// `F16`/`BF16`) into one-line methods (see issue #125); the f16/bf16
/// arms are conditionally compiled in the same way as the enum.
///
/// `$expr` is expanded once per arm, so it must type-check for every
/// element type `T`, and every arm must yield the same result type.
/// The arms are spelled `Self::Variant`, so the macro can only be used
/// inside an `impl` whose `Self` is `DynArray`. When `$value` is a
/// reference (e.g. `self` in a `&self` method), match ergonomics bind
/// `$binding` as `&Array<T, IxDyn>`.
///
/// Every `DynArray` variant must be listed here: adding a variant to the
/// enum without adding an arm makes every `dispatch!` use fail to compile
/// with a non-exhaustive match error. That failure is intentional, so
/// there is no wildcard arm.
macro_rules! dispatch {
    ($value:expr, $binding:ident => $expr:expr) => {
        match $value {
            Self::Bool($binding) => $expr,
            Self::U8($binding) => $expr,
            Self::U16($binding) => $expr,
            Self::U32($binding) => $expr,
            Self::U64($binding) => $expr,
            Self::U128($binding) => $expr,
            Self::I8($binding) => $expr,
            Self::I16($binding) => $expr,
            Self::I32($binding) => $expr,
            Self::I64($binding) => $expr,
            Self::I128($binding) => $expr,
            Self::I256($binding) => $expr,
            Self::F32($binding) => $expr,
            Self::F64($binding) => $expr,
            Self::Complex32($binding) => $expr,
            Self::Complex64($binding) => $expr,
            #[cfg(feature = "f16")]
            Self::F16($binding) => $expr,
            #[cfg(feature = "bf16")]
            Self::BF16($binding) => $expr,
        }
    };
}
94
95impl DynArray {
96    /// The runtime dtype of the elements in this array.
97    #[must_use]
98    pub const fn dtype(&self) -> DType {
99        match self {
100            Self::Bool(_) => DType::Bool,
101            Self::U8(_) => DType::U8,
102            Self::U16(_) => DType::U16,
103            Self::U32(_) => DType::U32,
104            Self::U64(_) => DType::U64,
105            Self::U128(_) => DType::U128,
106            Self::I8(_) => DType::I8,
107            Self::I16(_) => DType::I16,
108            Self::I32(_) => DType::I32,
109            Self::I64(_) => DType::I64,
110            Self::I128(_) => DType::I128,
111            Self::I256(_) => DType::I256,
112            Self::F32(_) => DType::F32,
113            Self::F64(_) => DType::F64,
114            Self::Complex32(_) => DType::Complex32,
115            Self::Complex64(_) => DType::Complex64,
116            #[cfg(feature = "f16")]
117            Self::F16(_) => DType::F16,
118            #[cfg(feature = "bf16")]
119            Self::BF16(_) => DType::BF16,
120        }
121    }
122
123    /// Shape as a slice.
124    #[must_use]
125    pub fn shape(&self) -> &[usize] {
126        dispatch!(self, a => a.shape())
127    }
128
129    /// Number of dimensions.
130    #[must_use]
131    pub fn ndim(&self) -> usize {
132        self.shape().len()
133    }
134
135    /// Total number of elements.
136    #[must_use]
137    pub fn size(&self) -> usize {
138        self.shape().iter().product()
139    }
140
141    /// Whether the array has zero elements.
142    #[must_use]
143    pub fn is_empty(&self) -> bool {
144        self.size() == 0
145    }
146
147    /// Size in bytes of one element.
148    #[must_use]
149    pub const fn itemsize(&self) -> usize {
150        self.dtype().size_of()
151    }
152
153    /// Total size in bytes.
154    #[must_use]
155    pub fn nbytes(&self) -> usize {
156        self.size() * self.itemsize()
157    }
158
159    /// Try to extract the inner `Array<f64, IxDyn>`.
160    ///
161    /// # Errors
162    /// Returns `FerrayError::InvalidDtype` if the dtype is not `f64`.
163    pub fn try_into_f64(self) -> FerrayResult<Array<f64, IxDyn>> {
164        match self {
165            Self::F64(a) => Ok(a),
166            other => Err(FerrayError::invalid_dtype(format!(
167                "expected float64, got {}",
168                other.dtype()
169            ))),
170        }
171    }
172
173    /// Try to extract the inner `Array<f32, IxDyn>`.
174    pub fn try_into_f32(self) -> FerrayResult<Array<f32, IxDyn>> {
175        match self {
176            Self::F32(a) => Ok(a),
177            other => Err(FerrayError::invalid_dtype(format!(
178                "expected float32, got {}",
179                other.dtype()
180            ))),
181        }
182    }
183
184    /// Try to extract the inner `Array<i64, IxDyn>`.
185    pub fn try_into_i64(self) -> FerrayResult<Array<i64, IxDyn>> {
186        match self {
187            Self::I64(a) => Ok(a),
188            other => Err(FerrayError::invalid_dtype(format!(
189                "expected int64, got {}",
190                other.dtype()
191            ))),
192        }
193    }
194
195    /// Try to extract the inner `Array<i32, IxDyn>`.
196    pub fn try_into_i32(self) -> FerrayResult<Array<i32, IxDyn>> {
197        match self {
198            Self::I32(a) => Ok(a),
199            other => Err(FerrayError::invalid_dtype(format!(
200                "expected int32, got {}",
201                other.dtype()
202            ))),
203        }
204    }
205
206    /// Try to extract the inner `Array<bool, IxDyn>`.
207    pub fn try_into_bool(self) -> FerrayResult<Array<bool, IxDyn>> {
208        match self {
209            Self::Bool(a) => Ok(a),
210            other => Err(FerrayError::invalid_dtype(format!(
211                "expected bool, got {}",
212                other.dtype()
213            ))),
214        }
215    }
216
217    /// Cast this array to a different element dtype at the requested safety level.
218    ///
219    /// Mirrors `NumPy`'s `arr.astype(dtype, casting=...)`. The conversion routes
220    /// through [`crate::dtype::unsafe_cast::CastTo`] for the underlying typed
221    /// arrays, so it supports the same set of element pairs (every primitive
222    /// numeric, bool, and Complex<f32>/Complex<f64> in any combination).
223    ///
224    /// `f16` / `bf16` are not yet supported in this dispatch and will return
225    /// `FerrayError::InvalidDtype` — track via the umbrella casting issue.
226    ///
227    /// # Errors
228    /// Returns `FerrayError::InvalidDtype` if:
229    /// - The cast is not permitted at the chosen `casting` level, or
230    /// - either source or target dtype is `f16`/`bf16` (not yet wired).
231    pub fn astype(&self, target: DType, casting: CastKind) -> FerrayResult<Self> {
232        // Reject f16/bf16 — see method docs.
233        #[cfg(feature = "f16")]
234        if matches!(self, Self::F16(_)) || target == DType::F16 {
235            return Err(FerrayError::invalid_dtype(
236                "DynArray::astype does not yet support f16",
237            ));
238        }
239        #[cfg(feature = "bf16")]
240        if matches!(self, Self::BF16(_)) || target == DType::BF16 {
241            return Err(FerrayError::invalid_dtype(
242                "DynArray::astype does not yet support bf16",
243            ));
244        }
245        // I256 is accepted as a storage type (from promotion) but
246        // generic cast-through-`CastTo` is not yet wired for it.
247        // Reject both source and target explicitly — users who want
248        // I256 arrays should construct them directly rather than via
249        // astype (#562).
250        if matches!(self, Self::I256(_)) || target == DType::I256 {
251            return Err(FerrayError::invalid_dtype(
252                "DynArray::astype does not yet support I256 — construct I256 arrays directly",
253            ));
254        }
255
256        // Inner macro: dispatch on the *source* variant. The target type `$U`
257        // is fixed by the outer match below.
258        macro_rules! cast_into {
259            ($U:ty) => {
260                match self {
261                    Self::Bool(a) => a.cast::<$U>(casting),
262                    Self::U8(a) => a.cast::<$U>(casting),
263                    Self::U16(a) => a.cast::<$U>(casting),
264                    Self::U32(a) => a.cast::<$U>(casting),
265                    Self::U64(a) => a.cast::<$U>(casting),
266                    Self::U128(a) => a.cast::<$U>(casting),
267                    Self::I8(a) => a.cast::<$U>(casting),
268                    Self::I16(a) => a.cast::<$U>(casting),
269                    Self::I32(a) => a.cast::<$U>(casting),
270                    Self::I64(a) => a.cast::<$U>(casting),
271                    Self::I128(a) => a.cast::<$U>(casting),
272                    Self::F32(a) => a.cast::<$U>(casting),
273                    Self::F64(a) => a.cast::<$U>(casting),
274                    Self::Complex32(a) => a.cast::<$U>(casting),
275                    Self::Complex64(a) => a.cast::<$U>(casting),
276                    Self::I256(_) => unreachable!("I256 source rejected above"),
277                    #[cfg(feature = "f16")]
278                    Self::F16(_) => unreachable!("f16 source rejected above"),
279                    #[cfg(feature = "bf16")]
280                    Self::BF16(_) => unreachable!("bf16 source rejected above"),
281                }
282            };
283        }
284
285        Ok(match target {
286            DType::Bool => Self::Bool(cast_into!(bool)?),
287            DType::U8 => Self::U8(cast_into!(u8)?),
288            DType::U16 => Self::U16(cast_into!(u16)?),
289            DType::U32 => Self::U32(cast_into!(u32)?),
290            DType::U64 => Self::U64(cast_into!(u64)?),
291            DType::U128 => Self::U128(cast_into!(u128)?),
292            DType::I8 => Self::I8(cast_into!(i8)?),
293            DType::I16 => Self::I16(cast_into!(i16)?),
294            DType::I32 => Self::I32(cast_into!(i32)?),
295            DType::I64 => Self::I64(cast_into!(i64)?),
296            DType::I128 => Self::I128(cast_into!(i128)?),
297            DType::F32 => Self::F32(cast_into!(f32)?),
298            DType::F64 => Self::F64(cast_into!(f64)?),
299            DType::Complex32 => Self::Complex32(cast_into!(Complex<f32>)?),
300            DType::Complex64 => Self::Complex64(cast_into!(Complex<f64>)?),
301            DType::I256 => unreachable!("I256 target rejected above"),
302            #[cfg(feature = "f16")]
303            DType::F16 => unreachable!("f16 target rejected above"),
304            #[cfg(feature = "bf16")]
305            DType::BF16 => unreachable!("bf16 target rejected above"),
306        })
307    }
308
309    /// Create a `DynArray` of zeros with the given dtype and shape.
310    pub fn zeros(dtype: DType, shape: &[usize]) -> FerrayResult<Self> {
311        let dim = IxDyn::new(shape);
312        Ok(match dtype {
313            DType::Bool => Self::Bool(Array::zeros(dim)?),
314            DType::U8 => Self::U8(Array::zeros(dim)?),
315            DType::U16 => Self::U16(Array::zeros(dim)?),
316            DType::U32 => Self::U32(Array::zeros(dim)?),
317            DType::U64 => Self::U64(Array::zeros(dim)?),
318            DType::U128 => Self::U128(Array::zeros(dim)?),
319            DType::I8 => Self::I8(Array::zeros(dim)?),
320            DType::I16 => Self::I16(Array::zeros(dim)?),
321            DType::I32 => Self::I32(Array::zeros(dim)?),
322            DType::I64 => Self::I64(Array::zeros(dim)?),
323            DType::I128 => Self::I128(Array::zeros(dim)?),
324            DType::I256 => Self::I256(Array::zeros(dim)?),
325            DType::F32 => Self::F32(Array::zeros(dim)?),
326            DType::F64 => Self::F64(Array::zeros(dim)?),
327            DType::Complex32 => Self::Complex32(Array::zeros(dim)?),
328            DType::Complex64 => Self::Complex64(Array::zeros(dim)?),
329            #[cfg(feature = "f16")]
330            DType::F16 => Self::F16(Array::zeros(dim)?),
331            #[cfg(feature = "bf16")]
332            DType::BF16 => Self::BF16(Array::zeros(dim)?),
333        })
334    }
335}
336
337impl std::fmt::Display for DynArray {
338    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
339        dispatch!(self, a => write!(f, "{a}"))
340    }
341}
342
343// Conversion from typed arrays to DynArray
344macro_rules! impl_from_array_dyn {
345    ($ty:ty, $variant:ident) => {
346        impl From<Array<$ty, IxDyn>> for DynArray {
347            fn from(a: Array<$ty, IxDyn>) -> Self {
348                Self::$variant(a)
349            }
350        }
351    };
352}
353
354impl_from_array_dyn!(bool, Bool);
355impl_from_array_dyn!(u8, U8);
356impl_from_array_dyn!(u16, U16);
357impl_from_array_dyn!(u32, U32);
358impl_from_array_dyn!(u64, U64);
359impl_from_array_dyn!(u128, U128);
360impl_from_array_dyn!(i8, I8);
361impl_from_array_dyn!(i16, I16);
362impl_from_array_dyn!(i32, I32);
363impl_from_array_dyn!(i64, I64);
364impl_from_array_dyn!(i128, I128);
365impl_from_array_dyn!(I256, I256);
366impl_from_array_dyn!(f32, F32);
367impl_from_array_dyn!(f64, F64);
368impl_from_array_dyn!(Complex<f32>, Complex32);
369impl_from_array_dyn!(Complex<f64>, Complex64);
370#[cfg(feature = "f16")]
371impl_from_array_dyn!(half::f16, F16);
372#[cfg(feature = "bf16")]
373impl_from_array_dyn!(half::bf16, BF16);
374
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn dynarray_zeros_f64() {
        let da = DynArray::zeros(DType::F64, &[2, 3]).unwrap();
        assert_eq!(da.dtype(), DType::F64);
        assert_eq!(da.shape(), &[2, 3]);
        assert_eq!(da.ndim(), 2);
        assert_eq!(da.size(), 6);
        assert_eq!(da.itemsize(), 8);
        assert_eq!(da.nbytes(), 48);
    }

    #[test]
    fn dynarray_zeros_i32() {
        let da = DynArray::zeros(DType::I32, &[4]).unwrap();
        assert_eq!(da.dtype(), DType::I32);
        assert_eq!(da.shape(), &[4]);
    }

    #[test]
    fn dynarray_try_into_f64() {
        let da = DynArray::zeros(DType::F64, &[3]).unwrap();
        let arr = da.try_into_f64().unwrap();
        assert_eq!(arr.shape(), &[3]);
    }

    #[test]
    fn dynarray_try_into_wrong_type() {
        let da = DynArray::zeros(DType::I32, &[3]).unwrap();
        assert!(da.try_into_f64().is_err());
    }

    // Every extractor must accept its own variant...
    #[test]
    fn dynarray_try_into_extractors_accept_matching_dtype() {
        assert!(DynArray::zeros(DType::F32, &[2]).unwrap().try_into_f32().is_ok());
        assert!(DynArray::zeros(DType::I64, &[2]).unwrap().try_into_i64().is_ok());
        assert!(DynArray::zeros(DType::I32, &[2]).unwrap().try_into_i32().is_ok());
        assert!(DynArray::zeros(DType::Bool, &[2]).unwrap().try_into_bool().is_ok());
    }

    // ...and refuse any other variant, including close relatives.
    #[test]
    fn dynarray_try_into_extractors_reject_mismatched_dtype() {
        assert!(DynArray::zeros(DType::F64, &[1]).unwrap().try_into_f32().is_err());
        assert!(DynArray::zeros(DType::I32, &[1]).unwrap().try_into_i64().is_err());
        assert!(DynArray::zeros(DType::I64, &[1]).unwrap().try_into_i32().is_err());
        assert!(DynArray::zeros(DType::U8, &[1]).unwrap().try_into_bool().is_err());
    }

    // ----- DynArray::astype tests (issue #361) -----

    #[test]
    fn dynarray_astype_f64_to_i32_unsafe() {
        let arr = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[3]), vec![1.5, 2.7, -3.9]).unwrap();
        let dy = DynArray::F64(arr);
        let casted = dy.astype(DType::I32, CastKind::Unsafe).unwrap();
        assert_eq!(casted.dtype(), DType::I32);
        match casted {
            DynArray::I32(a) => assert_eq!(a.as_slice().unwrap(), &[1, 2, -3]),
            _ => panic!("expected I32"),
        }
    }

    #[test]
    fn dynarray_astype_safe_widening() {
        let arr = Array::<i32, IxDyn>::from_vec(IxDyn::new(&[3]), vec![10, 20, 30]).unwrap();
        let dy = DynArray::I32(arr);
        let casted = dy.astype(DType::I64, CastKind::Safe).unwrap();
        assert_eq!(casted.dtype(), DType::I64);
        match casted {
            DynArray::I64(a) => assert_eq!(a.as_slice().unwrap(), &[10i64, 20, 30]),
            _ => panic!("expected I64"),
        }
    }

    #[test]
    fn dynarray_astype_safe_narrowing_errors() {
        let arr = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2]), vec![1.0, 2.0]).unwrap();
        let dy = DynArray::F64(arr);
        assert!(dy.astype(DType::F32, CastKind::Safe).is_err());
    }

    #[test]
    fn dynarray_astype_complex_to_real_unsafe() {
        let arr = Array::<Complex<f64>, IxDyn>::from_vec(
            IxDyn::new(&[2]),
            vec![Complex::new(1.5, 9.0), Complex::new(2.5, -1.0)],
        )
        .unwrap();
        let dy = DynArray::Complex64(arr);
        let casted = dy.astype(DType::F64, CastKind::Unsafe).unwrap();
        match casted {
            DynArray::F64(a) => assert_eq!(a.as_slice().unwrap(), &[1.5, 2.5]),
            _ => panic!("expected F64"),
        }
    }

    #[test]
    fn dynarray_astype_bool_to_u8_safe() {
        let arr =
            Array::<bool, IxDyn>::from_vec(IxDyn::new(&[3]), vec![true, false, true]).unwrap();
        let dy = DynArray::Bool(arr);
        let casted = dy.astype(DType::U8, CastKind::Safe).unwrap();
        match casted {
            DynArray::U8(a) => assert_eq!(a.as_slice().unwrap(), &[1u8, 0, 1]),
            _ => panic!("expected U8"),
        }
    }

    #[test]
    fn dynarray_astype_no_kind_requires_identity() {
        let arr = Array::<f64, IxDyn>::from_vec(IxDyn::new(&[2]), vec![1.0, 2.0]).unwrap();
        let dy = DynArray::F64(arr);
        assert!(dy.astype(DType::F64, CastKind::No).is_ok());
        assert!(dy.astype(DType::F32, CastKind::No).is_err());
    }

    // I256 is storage-only for now: astype must refuse it in both directions,
    // even at the most permissive casting level (#562).
    #[test]
    fn dynarray_astype_rejects_i256_source_and_target() {
        let wide = DynArray::zeros(DType::I256, &[2]).unwrap();
        assert_eq!(wide.dtype(), DType::I256);
        assert!(wide.astype(DType::I64, CastKind::Unsafe).is_err());

        let float = DynArray::zeros(DType::F64, &[2]).unwrap();
        assert!(float.astype(DType::I256, CastKind::Unsafe).is_err());
    }

    #[test]
    fn dynarray_from_typed() {
        let arr = Array::<f64, IxDyn>::zeros(IxDyn::new(&[2, 2])).unwrap();
        let da: DynArray = arr.into();
        assert_eq!(da.dtype(), DType::F64);
    }

    #[test]
    fn dynarray_display() {
        let da = DynArray::zeros(DType::I32, &[3]).unwrap();
        let s = format!("{da}");
        assert!(s.contains("[0, 0, 0]"));
    }

    // The enum's Display must render exactly like the wrapped typed array.
    #[test]
    fn dynarray_display_matches_inner_array() {
        let arr = Array::<i32, IxDyn>::zeros(IxDyn::new(&[3])).unwrap();
        let da = DynArray::from(arr.clone());
        assert_eq!(format!("{da}"), format!("{arr}"));
    }

    #[test]
    fn dynarray_is_empty() {
        let da = DynArray::zeros(DType::F32, &[0]).unwrap();
        assert!(da.is_empty());
    }

    #[test]
    fn dynarray_is_empty_false_for_nonzero_size() {
        let da = DynArray::zeros(DType::F32, &[2, 1]).unwrap();
        assert!(!da.is_empty());
        assert_eq!(da.size(), 2);
    }

    // ----- f16 / bf16 DynArray coverage (#139) -----

    #[cfg(feature = "f16")]
    #[test]
    fn dynarray_f16_zeros_shape_and_dtype() {
        let da = DynArray::zeros(DType::F16, &[2, 3]).unwrap();
        assert_eq!(da.dtype(), DType::F16);
        assert_eq!(da.shape(), &[2, 3]);
        assert_eq!(da.size(), 6);
        assert_eq!(da.itemsize(), 2);
        assert_eq!(da.nbytes(), 12);
    }

    #[cfg(feature = "f16")]
    #[test]
    fn dynarray_f16_from_typed_roundtrips() {
        use half::f16;
        let raw = [f16::from_f32(1.0), f16::from_f32(2.5), f16::from_f32(-3.0)];
        let arr = Array::<f16, IxDyn>::from_vec(IxDyn::new(&[3]), raw.to_vec()).unwrap();
        let da: DynArray = arr.into();
        assert_eq!(da.dtype(), DType::F16);
        assert_eq!(da.shape(), &[3]);
    }

    #[cfg(feature = "f16")]
    #[test]
    fn dynarray_astype_rejects_f16_source_and_target() {
        let half = DynArray::zeros(DType::F16, &[1]).unwrap();
        assert!(half.astype(DType::F32, CastKind::Unsafe).is_err());
        let single = DynArray::zeros(DType::F32, &[1]).unwrap();
        assert!(single.astype(DType::F16, CastKind::Unsafe).is_err());
    }

    #[cfg(feature = "bf16")]
    #[test]
    fn dynarray_bf16_zeros_shape_and_dtype() {
        let da = DynArray::zeros(DType::BF16, &[4]).unwrap();
        assert_eq!(da.dtype(), DType::BF16);
        assert_eq!(da.shape(), &[4]);
        assert_eq!(da.itemsize(), 2);
    }

    #[cfg(feature = "bf16")]
    #[test]
    fn dynarray_bf16_from_typed_roundtrips() {
        use half::bf16;
        let raw = [bf16::from_f32(1.0), bf16::from_f32(2.0)];
        let arr = Array::<bf16, IxDyn>::from_vec(IxDyn::new(&[2]), raw.to_vec()).unwrap();
        let da: DynArray = arr.into();
        assert_eq!(da.dtype(), DType::BF16);
    }

    #[cfg(feature = "bf16")]
    #[test]
    fn dynarray_astype_rejects_bf16_source_and_target() {
        let brain = DynArray::zeros(DType::BF16, &[1]).unwrap();
        assert!(brain.astype(DType::F32, CastKind::Unsafe).is_err());
        let single = DynArray::zeros(DType::F32, &[1]).unwrap();
        assert!(single.astype(DType::BF16, CastKind::Unsafe).is_err());
    }
}