// ferray_core/dynarray.rs
// ferray-core: DynArray — runtime-typed array enum (REQ-30)
2
3use num_complex::Complex;
4
5use crate::array::owned::Array;
6use crate::dimension::IxDyn;
7use crate::dtype::DType;
8use crate::error::{FerrayError, FerrayResult};
9
10/// A runtime-typed array whose element type is determined at runtime.
11///
12/// This is analogous to a Python `numpy.ndarray` where the dtype is a
13/// runtime property. Each variant wraps an `Array<T, IxDyn>` for the
14/// corresponding element type.
15///
16/// Use this when the element type is not known at compile time (e.g.,
17/// loading from a file, receiving from Python/FFI).
18#[derive(Debug, Clone)]
19#[non_exhaustive]
20pub enum DynArray {
21    /// `bool` elements
22    Bool(Array<bool, IxDyn>),
23    /// `u8` elements
24    U8(Array<u8, IxDyn>),
25    /// `u16` elements
26    U16(Array<u16, IxDyn>),
27    /// `u32` elements
28    U32(Array<u32, IxDyn>),
29    /// `u64` elements
30    U64(Array<u64, IxDyn>),
31    /// `u128` elements
32    U128(Array<u128, IxDyn>),
33    /// `i8` elements
34    I8(Array<i8, IxDyn>),
35    /// `i16` elements
36    I16(Array<i16, IxDyn>),
37    /// `i32` elements
38    I32(Array<i32, IxDyn>),
39    /// `i64` elements
40    I64(Array<i64, IxDyn>),
41    /// `i128` elements
42    I128(Array<i128, IxDyn>),
43    /// `f32` elements
44    F32(Array<f32, IxDyn>),
45    /// `f64` elements
46    F64(Array<f64, IxDyn>),
47    /// `Complex<f32>` elements
48    Complex32(Array<Complex<f32>, IxDyn>),
49    /// `Complex<f64>` elements
50    Complex64(Array<Complex<f64>, IxDyn>),
51    /// `f16` elements (feature-gated)
52    #[cfg(feature = "f16")]
53    F16(Array<half::f16, IxDyn>),
54}
55
56impl DynArray {
57    /// The runtime dtype of the elements in this array.
58    pub fn dtype(&self) -> DType {
59        match self {
60            Self::Bool(_) => DType::Bool,
61            Self::U8(_) => DType::U8,
62            Self::U16(_) => DType::U16,
63            Self::U32(_) => DType::U32,
64            Self::U64(_) => DType::U64,
65            Self::U128(_) => DType::U128,
66            Self::I8(_) => DType::I8,
67            Self::I16(_) => DType::I16,
68            Self::I32(_) => DType::I32,
69            Self::I64(_) => DType::I64,
70            Self::I128(_) => DType::I128,
71            Self::F32(_) => DType::F32,
72            Self::F64(_) => DType::F64,
73            Self::Complex32(_) => DType::Complex32,
74            Self::Complex64(_) => DType::Complex64,
75            #[cfg(feature = "f16")]
76            Self::F16(_) => DType::F16,
77        }
78    }
79
80    /// Shape as a slice.
81    pub fn shape(&self) -> &[usize] {
82        match self {
83            Self::Bool(a) => a.shape(),
84            Self::U8(a) => a.shape(),
85            Self::U16(a) => a.shape(),
86            Self::U32(a) => a.shape(),
87            Self::U64(a) => a.shape(),
88            Self::U128(a) => a.shape(),
89            Self::I8(a) => a.shape(),
90            Self::I16(a) => a.shape(),
91            Self::I32(a) => a.shape(),
92            Self::I64(a) => a.shape(),
93            Self::I128(a) => a.shape(),
94            Self::F32(a) => a.shape(),
95            Self::F64(a) => a.shape(),
96            Self::Complex32(a) => a.shape(),
97            Self::Complex64(a) => a.shape(),
98            #[cfg(feature = "f16")]
99            Self::F16(a) => a.shape(),
100        }
101    }
102
103    /// Number of dimensions.
104    pub fn ndim(&self) -> usize {
105        self.shape().len()
106    }
107
108    /// Total number of elements.
109    pub fn size(&self) -> usize {
110        self.shape().iter().product()
111    }
112
113    /// Whether the array has zero elements.
114    pub fn is_empty(&self) -> bool {
115        self.size() == 0
116    }
117
118    /// Size in bytes of one element.
119    pub fn itemsize(&self) -> usize {
120        self.dtype().size_of()
121    }
122
123    /// Total size in bytes.
124    pub fn nbytes(&self) -> usize {
125        self.size() * self.itemsize()
126    }
127
128    /// Try to extract the inner `Array<f64, IxDyn>`.
129    ///
130    /// # Errors
131    /// Returns `FerrayError::InvalidDtype` if the dtype is not `f64`.
132    pub fn try_into_f64(self) -> FerrayResult<Array<f64, IxDyn>> {
133        match self {
134            Self::F64(a) => Ok(a),
135            other => Err(FerrayError::invalid_dtype(format!(
136                "expected float64, got {}",
137                other.dtype()
138            ))),
139        }
140    }
141
142    /// Try to extract the inner `Array<f32, IxDyn>`.
143    pub fn try_into_f32(self) -> FerrayResult<Array<f32, IxDyn>> {
144        match self {
145            Self::F32(a) => Ok(a),
146            other => Err(FerrayError::invalid_dtype(format!(
147                "expected float32, got {}",
148                other.dtype()
149            ))),
150        }
151    }
152
153    /// Try to extract the inner `Array<i64, IxDyn>`.
154    pub fn try_into_i64(self) -> FerrayResult<Array<i64, IxDyn>> {
155        match self {
156            Self::I64(a) => Ok(a),
157            other => Err(FerrayError::invalid_dtype(format!(
158                "expected int64, got {}",
159                other.dtype()
160            ))),
161        }
162    }
163
164    /// Try to extract the inner `Array<i32, IxDyn>`.
165    pub fn try_into_i32(self) -> FerrayResult<Array<i32, IxDyn>> {
166        match self {
167            Self::I32(a) => Ok(a),
168            other => Err(FerrayError::invalid_dtype(format!(
169                "expected int32, got {}",
170                other.dtype()
171            ))),
172        }
173    }
174
175    /// Try to extract the inner `Array<bool, IxDyn>`.
176    pub fn try_into_bool(self) -> FerrayResult<Array<bool, IxDyn>> {
177        match self {
178            Self::Bool(a) => Ok(a),
179            other => Err(FerrayError::invalid_dtype(format!(
180                "expected bool, got {}",
181                other.dtype()
182            ))),
183        }
184    }
185
186    /// Create a `DynArray` of zeros with the given dtype and shape.
187    pub fn zeros(dtype: DType, shape: &[usize]) -> FerrayResult<Self> {
188        let dim = IxDyn::new(shape);
189        Ok(match dtype {
190            DType::Bool => Self::Bool(Array::zeros(dim)?),
191            DType::U8 => Self::U8(Array::zeros(dim)?),
192            DType::U16 => Self::U16(Array::zeros(dim)?),
193            DType::U32 => Self::U32(Array::zeros(dim)?),
194            DType::U64 => Self::U64(Array::zeros(dim)?),
195            DType::U128 => Self::U128(Array::zeros(dim)?),
196            DType::I8 => Self::I8(Array::zeros(dim)?),
197            DType::I16 => Self::I16(Array::zeros(dim)?),
198            DType::I32 => Self::I32(Array::zeros(dim)?),
199            DType::I64 => Self::I64(Array::zeros(dim)?),
200            DType::I128 => Self::I128(Array::zeros(dim)?),
201            DType::F32 => Self::F32(Array::zeros(dim)?),
202            DType::F64 => Self::F64(Array::zeros(dim)?),
203            DType::Complex32 => Self::Complex32(Array::zeros(dim)?),
204            DType::Complex64 => Self::Complex64(Array::zeros(dim)?),
205            #[cfg(feature = "f16")]
206            DType::F16 => Self::F16(Array::zeros(dim)?),
207        })
208    }
209}
210
211impl std::fmt::Display for DynArray {
212    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
213        match self {
214            Self::Bool(a) => write!(f, "{a}"),
215            Self::U8(a) => write!(f, "{a}"),
216            Self::U16(a) => write!(f, "{a}"),
217            Self::U32(a) => write!(f, "{a}"),
218            Self::U64(a) => write!(f, "{a}"),
219            Self::U128(a) => write!(f, "{a}"),
220            Self::I8(a) => write!(f, "{a}"),
221            Self::I16(a) => write!(f, "{a}"),
222            Self::I32(a) => write!(f, "{a}"),
223            Self::I64(a) => write!(f, "{a}"),
224            Self::I128(a) => write!(f, "{a}"),
225            Self::F32(a) => write!(f, "{a}"),
226            Self::F64(a) => write!(f, "{a}"),
227            Self::Complex32(a) => write!(f, "{a}"),
228            Self::Complex64(a) => write!(f, "{a}"),
229            #[cfg(feature = "f16")]
230            Self::F16(a) => write!(f, "{a}"),
231        }
232    }
233}
234
235// Conversion from typed arrays to DynArray
236macro_rules! impl_from_array_dyn {
237    ($ty:ty, $variant:ident) => {
238        impl From<Array<$ty, IxDyn>> for DynArray {
239            fn from(a: Array<$ty, IxDyn>) -> Self {
240                Self::$variant(a)
241            }
242        }
243    };
244}
245
246impl_from_array_dyn!(bool, Bool);
247impl_from_array_dyn!(u8, U8);
248impl_from_array_dyn!(u16, U16);
249impl_from_array_dyn!(u32, U32);
250impl_from_array_dyn!(u64, U64);
251impl_from_array_dyn!(u128, U128);
252impl_from_array_dyn!(i8, I8);
253impl_from_array_dyn!(i16, I16);
254impl_from_array_dyn!(i32, I32);
255impl_from_array_dyn!(i64, I64);
256impl_from_array_dyn!(i128, I128);
257impl_from_array_dyn!(f32, F32);
258impl_from_array_dyn!(f64, F64);
259impl_from_array_dyn!(Complex<f32>, Complex32);
260impl_from_array_dyn!(Complex<f64>, Complex64);
261#[cfg(feature = "f16")]
262impl_from_array_dyn!(half::f16, F16);
263
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn dynarray_zeros_f64() {
        let da = DynArray::zeros(DType::F64, &[2, 3]).unwrap();
        assert_eq!(da.dtype(), DType::F64);
        assert_eq!(da.shape(), &[2, 3]);
        assert_eq!(da.ndim(), 2);
        assert_eq!(da.size(), 6);
        assert_eq!(da.itemsize(), 8);
        assert_eq!(da.nbytes(), 48);
    }

    #[test]
    fn dynarray_zeros_i32() {
        let da = DynArray::zeros(DType::I32, &[4]).unwrap();
        assert_eq!(da.dtype(), DType::I32);
        assert_eq!(da.shape(), &[4]);
    }

    #[test]
    fn dynarray_nbytes_is_size_times_itemsize() {
        let da = DynArray::zeros(DType::I32, &[2, 5]).unwrap();
        assert_eq!(da.size(), 10);
        assert_eq!(da.itemsize(), 4);
        assert_eq!(da.nbytes(), 40);
    }

    #[test]
    fn dynarray_zero_dim_has_one_element() {
        // An empty shape is a 0-d array: no axes, but one element.
        let da = DynArray::zeros(DType::F64, &[]).unwrap();
        assert!(da.shape().is_empty());
        assert_eq!(da.ndim(), 0);
        assert_eq!(da.size(), 1);
        assert!(!da.is_empty());
    }

    #[test]
    fn dynarray_try_into_f64() {
        let da = DynArray::zeros(DType::F64, &[3]).unwrap();
        let arr = da.try_into_f64().unwrap();
        assert_eq!(arr.shape(), &[3]);
    }

    #[test]
    fn dynarray_try_into_each_matching_variant() {
        let f32s = DynArray::zeros(DType::F32, &[2]).unwrap().try_into_f32().unwrap();
        assert_eq!(f32s.shape(), &[2]);

        let i64s = DynArray::zeros(DType::I64, &[3]).unwrap().try_into_i64().unwrap();
        assert_eq!(i64s.shape(), &[3]);

        let i32s = DynArray::zeros(DType::I32, &[4]).unwrap().try_into_i32().unwrap();
        assert_eq!(i32s.shape(), &[4]);

        let bools = DynArray::zeros(DType::Bool, &[1, 2]).unwrap().try_into_bool().unwrap();
        assert_eq!(bools.shape(), &[1, 2]);
    }

    #[test]
    fn dynarray_try_into_wrong_type() {
        let da = DynArray::zeros(DType::I32, &[3]).unwrap();
        assert!(da.try_into_f64().is_err());

        assert!(DynArray::zeros(DType::F64, &[1]).unwrap().try_into_f32().is_err());
        assert!(DynArray::zeros(DType::I32, &[1]).unwrap().try_into_i64().is_err());
        assert!(DynArray::zeros(DType::I64, &[1]).unwrap().try_into_i32().is_err());
        assert!(DynArray::zeros(DType::U8, &[1]).unwrap().try_into_bool().is_err());
    }

    #[test]
    fn dynarray_from_typed() {
        let arr = Array::<f64, IxDyn>::zeros(IxDyn::new(&[2, 2])).unwrap();
        let da: DynArray = arr.into();
        assert_eq!(da.dtype(), DType::F64);
        assert_eq!(da.shape(), &[2, 2]);
    }

    #[test]
    fn dynarray_display() {
        let da = DynArray::zeros(DType::I32, &[3]).unwrap();
        let s = format!("{da}");
        assert!(s.contains("[0, 0, 0]"));
    }

    #[test]
    fn dynarray_is_empty() {
        let da = DynArray::zeros(DType::F32, &[0]).unwrap();
        assert!(da.is_empty());

        let non_empty = DynArray::zeros(DType::F32, &[2]).unwrap();
        assert!(!non_empty.is_empty());
    }
}