Skip to main content

ferray_core/
dynarray.rs

// ferray-core: DynArray — runtime-typed array enum (REQ-30)
2
3use num_complex::Complex;
4
5use crate::array::owned::Array;
6use crate::dimension::IxDyn;
7use crate::dtype::DType;
8use crate::error::{FerrayError, FerrayResult};
9
10/// A runtime-typed array whose element type is determined at runtime.
11///
12/// This is analogous to a Python `numpy.ndarray` where the dtype is a
13/// runtime property. Each variant wraps an `Array<T, IxDyn>` for the
14/// corresponding element type.
15///
16/// Use this when the element type is not known at compile time (e.g.,
17/// loading from a file, receiving from Python/FFI).
18#[derive(Debug, Clone)]
19#[non_exhaustive]
20pub enum DynArray {
21    /// `bool` elements
22    Bool(Array<bool, IxDyn>),
23    /// `u8` elements
24    U8(Array<u8, IxDyn>),
25    /// `u16` elements
26    U16(Array<u16, IxDyn>),
27    /// `u32` elements
28    U32(Array<u32, IxDyn>),
29    /// `u64` elements
30    U64(Array<u64, IxDyn>),
31    /// `u128` elements
32    U128(Array<u128, IxDyn>),
33    /// `i8` elements
34    I8(Array<i8, IxDyn>),
35    /// `i16` elements
36    I16(Array<i16, IxDyn>),
37    /// `i32` elements
38    I32(Array<i32, IxDyn>),
39    /// `i64` elements
40    I64(Array<i64, IxDyn>),
41    /// `i128` elements
42    I128(Array<i128, IxDyn>),
43    /// `f32` elements
44    F32(Array<f32, IxDyn>),
45    /// `f64` elements
46    F64(Array<f64, IxDyn>),
47    /// `Complex<f32>` elements
48    Complex32(Array<Complex<f32>, IxDyn>),
49    /// `Complex<f64>` elements
50    Complex64(Array<Complex<f64>, IxDyn>),
51    /// `f16` elements (feature-gated)
52    #[cfg(feature = "f16")]
53    F16(Array<half::f16, IxDyn>),
54    /// `bf16` (bfloat16) elements (feature-gated)
55    #[cfg(feature = "bf16")]
56    BF16(Array<half::bf16, IxDyn>),
57}
58
59impl DynArray {
60    /// The runtime dtype of the elements in this array.
61    pub fn dtype(&self) -> DType {
62        match self {
63            Self::Bool(_) => DType::Bool,
64            Self::U8(_) => DType::U8,
65            Self::U16(_) => DType::U16,
66            Self::U32(_) => DType::U32,
67            Self::U64(_) => DType::U64,
68            Self::U128(_) => DType::U128,
69            Self::I8(_) => DType::I8,
70            Self::I16(_) => DType::I16,
71            Self::I32(_) => DType::I32,
72            Self::I64(_) => DType::I64,
73            Self::I128(_) => DType::I128,
74            Self::F32(_) => DType::F32,
75            Self::F64(_) => DType::F64,
76            Self::Complex32(_) => DType::Complex32,
77            Self::Complex64(_) => DType::Complex64,
78            #[cfg(feature = "f16")]
79            Self::F16(_) => DType::F16,
80            #[cfg(feature = "bf16")]
81            Self::BF16(_) => DType::BF16,
82        }
83    }
84
85    /// Shape as a slice.
86    pub fn shape(&self) -> &[usize] {
87        match self {
88            Self::Bool(a) => a.shape(),
89            Self::U8(a) => a.shape(),
90            Self::U16(a) => a.shape(),
91            Self::U32(a) => a.shape(),
92            Self::U64(a) => a.shape(),
93            Self::U128(a) => a.shape(),
94            Self::I8(a) => a.shape(),
95            Self::I16(a) => a.shape(),
96            Self::I32(a) => a.shape(),
97            Self::I64(a) => a.shape(),
98            Self::I128(a) => a.shape(),
99            Self::F32(a) => a.shape(),
100            Self::F64(a) => a.shape(),
101            Self::Complex32(a) => a.shape(),
102            Self::Complex64(a) => a.shape(),
103            #[cfg(feature = "f16")]
104            Self::F16(a) => a.shape(),
105            #[cfg(feature = "bf16")]
106            Self::BF16(a) => a.shape(),
107        }
108    }
109
110    /// Number of dimensions.
111    pub fn ndim(&self) -> usize {
112        self.shape().len()
113    }
114
115    /// Total number of elements.
116    pub fn size(&self) -> usize {
117        self.shape().iter().product()
118    }
119
120    /// Whether the array has zero elements.
121    pub fn is_empty(&self) -> bool {
122        self.size() == 0
123    }
124
125    /// Size in bytes of one element.
126    pub fn itemsize(&self) -> usize {
127        self.dtype().size_of()
128    }
129
130    /// Total size in bytes.
131    pub fn nbytes(&self) -> usize {
132        self.size() * self.itemsize()
133    }
134
135    /// Try to extract the inner `Array<f64, IxDyn>`.
136    ///
137    /// # Errors
138    /// Returns `FerrayError::InvalidDtype` if the dtype is not `f64`.
139    pub fn try_into_f64(self) -> FerrayResult<Array<f64, IxDyn>> {
140        match self {
141            Self::F64(a) => Ok(a),
142            other => Err(FerrayError::invalid_dtype(format!(
143                "expected float64, got {}",
144                other.dtype()
145            ))),
146        }
147    }
148
149    /// Try to extract the inner `Array<f32, IxDyn>`.
150    pub fn try_into_f32(self) -> FerrayResult<Array<f32, IxDyn>> {
151        match self {
152            Self::F32(a) => Ok(a),
153            other => Err(FerrayError::invalid_dtype(format!(
154                "expected float32, got {}",
155                other.dtype()
156            ))),
157        }
158    }
159
160    /// Try to extract the inner `Array<i64, IxDyn>`.
161    pub fn try_into_i64(self) -> FerrayResult<Array<i64, IxDyn>> {
162        match self {
163            Self::I64(a) => Ok(a),
164            other => Err(FerrayError::invalid_dtype(format!(
165                "expected int64, got {}",
166                other.dtype()
167            ))),
168        }
169    }
170
171    /// Try to extract the inner `Array<i32, IxDyn>`.
172    pub fn try_into_i32(self) -> FerrayResult<Array<i32, IxDyn>> {
173        match self {
174            Self::I32(a) => Ok(a),
175            other => Err(FerrayError::invalid_dtype(format!(
176                "expected int32, got {}",
177                other.dtype()
178            ))),
179        }
180    }
181
182    /// Try to extract the inner `Array<bool, IxDyn>`.
183    pub fn try_into_bool(self) -> FerrayResult<Array<bool, IxDyn>> {
184        match self {
185            Self::Bool(a) => Ok(a),
186            other => Err(FerrayError::invalid_dtype(format!(
187                "expected bool, got {}",
188                other.dtype()
189            ))),
190        }
191    }
192
193    /// Create a `DynArray` of zeros with the given dtype and shape.
194    pub fn zeros(dtype: DType, shape: &[usize]) -> FerrayResult<Self> {
195        let dim = IxDyn::new(shape);
196        Ok(match dtype {
197            DType::Bool => Self::Bool(Array::zeros(dim)?),
198            DType::U8 => Self::U8(Array::zeros(dim)?),
199            DType::U16 => Self::U16(Array::zeros(dim)?),
200            DType::U32 => Self::U32(Array::zeros(dim)?),
201            DType::U64 => Self::U64(Array::zeros(dim)?),
202            DType::U128 => Self::U128(Array::zeros(dim)?),
203            DType::I8 => Self::I8(Array::zeros(dim)?),
204            DType::I16 => Self::I16(Array::zeros(dim)?),
205            DType::I32 => Self::I32(Array::zeros(dim)?),
206            DType::I64 => Self::I64(Array::zeros(dim)?),
207            DType::I128 => Self::I128(Array::zeros(dim)?),
208            DType::F32 => Self::F32(Array::zeros(dim)?),
209            DType::F64 => Self::F64(Array::zeros(dim)?),
210            DType::Complex32 => Self::Complex32(Array::zeros(dim)?),
211            DType::Complex64 => Self::Complex64(Array::zeros(dim)?),
212            #[cfg(feature = "f16")]
213            DType::F16 => Self::F16(Array::zeros(dim)?),
214            #[cfg(feature = "bf16")]
215            DType::BF16 => Self::BF16(Array::zeros(dim)?),
216        })
217    }
218}
219
220impl std::fmt::Display for DynArray {
221    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
222        match self {
223            Self::Bool(a) => write!(f, "{a}"),
224            Self::U8(a) => write!(f, "{a}"),
225            Self::U16(a) => write!(f, "{a}"),
226            Self::U32(a) => write!(f, "{a}"),
227            Self::U64(a) => write!(f, "{a}"),
228            Self::U128(a) => write!(f, "{a}"),
229            Self::I8(a) => write!(f, "{a}"),
230            Self::I16(a) => write!(f, "{a}"),
231            Self::I32(a) => write!(f, "{a}"),
232            Self::I64(a) => write!(f, "{a}"),
233            Self::I128(a) => write!(f, "{a}"),
234            Self::F32(a) => write!(f, "{a}"),
235            Self::F64(a) => write!(f, "{a}"),
236            Self::Complex32(a) => write!(f, "{a}"),
237            Self::Complex64(a) => write!(f, "{a}"),
238            #[cfg(feature = "f16")]
239            Self::F16(a) => write!(f, "{a}"),
240            #[cfg(feature = "bf16")]
241            Self::BF16(a) => write!(f, "{a}"),
242        }
243    }
244}
245
// Generates `impl From<Array<$elem, IxDyn>> for DynArray`, which wraps a
// statically typed array in the matching `DynArray::$arm` variant.
// Invoked as `impl_from_array_dyn!(element_type, Variant);`.
macro_rules! impl_from_array_dyn {
    ($elem:ty, $arm:ident) => {
        impl From<Array<$elem, IxDyn>> for DynArray {
            fn from(array: Array<$elem, IxDyn>) -> Self {
                DynArray::$arm(array)
            }
        }
    };
}
256
257impl_from_array_dyn!(bool, Bool);
258impl_from_array_dyn!(u8, U8);
259impl_from_array_dyn!(u16, U16);
260impl_from_array_dyn!(u32, U32);
261impl_from_array_dyn!(u64, U64);
262impl_from_array_dyn!(u128, U128);
263impl_from_array_dyn!(i8, I8);
264impl_from_array_dyn!(i16, I16);
265impl_from_array_dyn!(i32, I32);
266impl_from_array_dyn!(i64, I64);
267impl_from_array_dyn!(i128, I128);
268impl_from_array_dyn!(f32, F32);
269impl_from_array_dyn!(f64, F64);
270impl_from_array_dyn!(Complex<f32>, Complex32);
271impl_from_array_dyn!(Complex<f64>, Complex64);
272#[cfg(feature = "f16")]
273impl_from_array_dyn!(half::f16, F16);
274#[cfg(feature = "bf16")]
275impl_from_array_dyn!(half::bf16, BF16);
276
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn dynarray_zeros_f64() {
        let da = DynArray::zeros(DType::F64, &[2, 3]).unwrap();
        assert_eq!(da.dtype(), DType::F64);
        assert_eq!(da.shape(), &[2, 3]);
        assert_eq!(da.ndim(), 2);
        assert_eq!(da.size(), 6);
        assert_eq!(da.itemsize(), 8);
        assert_eq!(da.nbytes(), 48);
    }

    #[test]
    fn dynarray_zeros_i32() {
        let da = DynArray::zeros(DType::I32, &[4]).unwrap();
        assert_eq!(da.dtype(), DType::I32);
        assert_eq!(da.shape(), &[4]);
    }

    #[test]
    fn dynarray_try_into_f64() {
        let da = DynArray::zeros(DType::F64, &[3]).unwrap();
        let arr = da.try_into_f64().unwrap();
        assert_eq!(arr.shape(), &[3]);
    }

    #[test]
    fn dynarray_try_into_wrong_type() {
        let da = DynArray::zeros(DType::I32, &[3]).unwrap();
        assert!(da.try_into_f64().is_err());
    }

    /// Every extractor succeeds when the stored dtype matches and keeps the shape.
    #[test]
    fn dynarray_try_into_each_matching_dtype() {
        let f32s = DynArray::zeros(DType::F32, &[2]).unwrap().try_into_f32().unwrap();
        assert_eq!(f32s.shape(), &[2]);
        let i64s = DynArray::zeros(DType::I64, &[2, 1]).unwrap().try_into_i64().unwrap();
        assert_eq!(i64s.shape(), &[2, 1]);
        let i32s = DynArray::zeros(DType::I32, &[5]).unwrap().try_into_i32().unwrap();
        assert_eq!(i32s.shape(), &[5]);
        let bools = DynArray::zeros(DType::Bool, &[3]).unwrap().try_into_bool().unwrap();
        assert_eq!(bools.shape(), &[3]);
    }

    /// Every extractor rejects a dtype it does not handle.
    #[test]
    fn dynarray_try_into_mismatch_for_each_extractor() {
        let make = || DynArray::zeros(DType::U8, &[1]).unwrap();
        assert!(make().try_into_f64().is_err());
        assert!(make().try_into_f32().is_err());
        assert!(make().try_into_i64().is_err());
        assert!(make().try_into_i32().is_err());
        assert!(make().try_into_bool().is_err());
    }

    /// Builds a 3-element zero array and checks itemsize/nbytes agree with the
    /// element type actually stored in the matching variant.
    fn assert_zeros_layout(dtype: DType, itemsize: usize) {
        let da = DynArray::zeros(dtype, &[3]).unwrap();
        assert_eq!(da.size(), 3);
        assert_eq!(da.itemsize(), itemsize);
        assert_eq!(da.nbytes(), 3 * itemsize);
    }

    #[test]
    fn dynarray_zeros_layout_per_dtype() {
        assert_zeros_layout(DType::Bool, std::mem::size_of::<bool>());
        assert_zeros_layout(DType::U8, std::mem::size_of::<u8>());
        assert_zeros_layout(DType::I16, std::mem::size_of::<i16>());
        assert_zeros_layout(DType::U32, std::mem::size_of::<u32>());
        assert_zeros_layout(DType::I64, std::mem::size_of::<i64>());
        assert_zeros_layout(DType::F32, std::mem::size_of::<f32>());
        assert_zeros_layout(DType::Complex32, std::mem::size_of::<Complex<f32>>());
        assert_zeros_layout(DType::Complex64, std::mem::size_of::<Complex<f64>>());
    }

    #[test]
    fn dynarray_from_typed() {
        let arr = Array::<f64, IxDyn>::zeros(IxDyn::new(&[2, 2])).unwrap();
        let da: DynArray = arr.into();
        assert_eq!(da.dtype(), DType::F64);
    }

    /// Extracting and re-wrapping preserves both dtype and shape.
    #[test]
    fn dynarray_try_into_then_from_round_trip() {
        let typed = DynArray::zeros(DType::I64, &[2]).unwrap().try_into_i64().unwrap();
        let back: DynArray = typed.into();
        assert_eq!(back.dtype(), DType::I64);
        assert_eq!(back.shape(), &[2]);
    }

    #[test]
    fn dynarray_display() {
        let da = DynArray::zeros(DType::I32, &[3]).unwrap();
        let s = format!("{da}");
        assert!(s.contains("[0, 0, 0]"));
    }

    #[test]
    fn dynarray_is_empty() {
        let da = DynArray::zeros(DType::F32, &[0]).unwrap();
        assert!(da.is_empty());
    }

    #[test]
    fn dynarray_non_empty_is_not_empty() {
        let da = DynArray::zeros(DType::F32, &[2]).unwrap();
        assert!(!da.is_empty());
    }
}