1use crate::array::{Array, DType, DTypeValue};
8use anyhow::{Result, bail};
9
10#[derive(Debug, Clone)]
30pub enum DynArray {
31 F32(Array<f32>),
32 F64(Array<f64>),
33 I32(Array<i32>),
34 I8(Array<i8>),
35 U8(Array<u8>),
36 Bool(Array<bool>),
37}
38
39impl DynArray {
40 pub fn dtype(&self) -> DType {
42 match self {
43 DynArray::F32(_) => DType::F32,
44 DynArray::F64(_) => DType::F64,
45 DynArray::I32(_) => DType::I32,
46 DynArray::I8(_) => DType::I8,
47 DynArray::U8(_) => DType::U8,
48 DynArray::Bool(_) => DType::Bool,
49 }
50 }
51
52 pub fn shape(&self) -> &[usize] {
54 match self {
55 DynArray::F32(a) => &a.shape,
56 DynArray::F64(a) => &a.shape,
57 DynArray::I32(a) => &a.shape,
58 DynArray::I8(a) => &a.shape,
59 DynArray::U8(a) => &a.shape,
60 DynArray::Bool(a) => &a.shape,
61 }
62 }
63
64 pub fn len(&self) -> usize {
66 self.shape().iter().product()
67 }
68
69 pub fn is_empty(&self) -> bool {
71 self.len() == 0
72 }
73
74 pub fn to_f32(&self) -> Array<f32> {
76 match self {
77 DynArray::F32(a) => a.clone(),
78 DynArray::F64(a) => {
79 let data: Vec<f32> = a.data.iter().map(|&x| x as f32).collect();
80 Array::new(a.shape.clone(), data)
81 }
82 DynArray::I32(a) => {
83 let data: Vec<f32> = a.data.iter().map(|&x| x as f32).collect();
84 Array::new(a.shape.clone(), data)
85 }
86 DynArray::I8(a) => {
87 let data: Vec<f32> = a.data.iter().map(|&x| x as f32).collect();
88 Array::new(a.shape.clone(), data)
89 }
90 DynArray::U8(a) => {
91 let data: Vec<f32> = a.data.iter().map(|&x| x as f32).collect();
92 Array::new(a.shape.clone(), data)
93 }
94 DynArray::Bool(a) => {
95 let data: Vec<f32> = a.data.iter().map(|&x| if x { 1.0 } else { 0.0 }).collect();
96 Array::new(a.shape.clone(), data)
97 }
98 }
99 }
100
101 pub fn map_data<F, R>(&self, f: F) -> Result<R>
103 where
104 F: FnOnce(&dyn std::any::Any) -> Result<R>,
105 {
106 match self {
107 DynArray::F32(a) => f(&a.data),
108 DynArray::F64(a) => f(&a.data),
109 DynArray::I32(a) => f(&a.data),
110 DynArray::I8(a) => f(&a.data),
111 DynArray::U8(a) => f(&a.data),
112 DynArray::Bool(a) => f(&a.data),
113 }
114 }
115
116 pub fn into_typed<T: DTypeValue>(self) -> Result<Array<T>> {
122 use std::any::TypeId;
123 use std::mem;
124
125 match self {
128 DynArray::F32(arr) if TypeId::of::<T>() == TypeId::of::<f32>() => {
129 return Ok(unsafe { mem::transmute::<Array<f32>, Array<T>>(arr) });
130 }
131 DynArray::F64(arr) if TypeId::of::<T>() == TypeId::of::<f64>() => {
132 return Ok(unsafe { mem::transmute::<Array<f64>, Array<T>>(arr) });
133 }
134 DynArray::I32(arr) if TypeId::of::<T>() == TypeId::of::<i32>() => {
135 return Ok(unsafe { mem::transmute::<Array<i32>, Array<T>>(arr) });
136 }
137 DynArray::I8(arr) if TypeId::of::<T>() == TypeId::of::<i8>() => {
138 return Ok(unsafe { mem::transmute::<Array<i8>, Array<T>>(arr) });
139 }
140 DynArray::U8(arr) if TypeId::of::<T>() == TypeId::of::<u8>() => {
141 return Ok(unsafe { mem::transmute::<Array<u8>, Array<T>>(arr) });
142 }
143 DynArray::Bool(arr) if TypeId::of::<T>() == TypeId::of::<bool>() => {
144 return Ok(unsafe { mem::transmute::<Array<bool>, Array<T>>(arr) });
145 }
146 _ => {
147 match self {
149 DynArray::F32(arr) => Ok(crate::array::promotion::cast_array(&arr)),
150 DynArray::F64(arr) => Ok(crate::array::promotion::cast_array(&arr)),
151 DynArray::I32(arr) => Ok(crate::array::promotion::cast_array(&arr)),
152 DynArray::I8(arr) => Ok(crate::array::promotion::cast_array(&arr)),
153 DynArray::U8(arr) => Ok(crate::array::promotion::cast_array(&arr)),
154 DynArray::Bool(arr) => Ok(crate::array::promotion::cast_array(&arr)),
155 }
156 }
157 }
158 }
159
160 pub fn from_generic<T: DTypeValue>(arr: Array<T>) -> Self {
164 use std::any::TypeId;
165 use std::mem;
166
167 if TypeId::of::<T>() == TypeId::of::<f32>() {
169 let arr_f32 = unsafe { mem::transmute::<Array<T>, Array<f32>>(arr) };
170 return DynArray::F32(arr_f32);
171 }
172 if TypeId::of::<T>() == TypeId::of::<f64>() {
173 let arr_f64 = unsafe { mem::transmute::<Array<T>, Array<f64>>(arr) };
174 return DynArray::F64(arr_f64);
175 }
176 if TypeId::of::<T>() == TypeId::of::<i32>() {
177 let arr_i32 = unsafe { mem::transmute::<Array<T>, Array<i32>>(arr) };
178 return DynArray::I32(arr_i32);
179 }
180 if TypeId::of::<T>() == TypeId::of::<i8>() {
181 let arr_i8 = unsafe { mem::transmute::<Array<T>, Array<i8>>(arr) };
182 return DynArray::I8(arr_i8);
183 }
184 if TypeId::of::<T>() == TypeId::of::<u8>() {
185 let arr_u8 = unsafe { mem::transmute::<Array<T>, Array<u8>>(arr) };
186 return DynArray::U8(arr_u8);
187 }
188 if TypeId::of::<T>() == TypeId::of::<bool>() {
189 let arr_bool = unsafe { mem::transmute::<Array<T>, Array<bool>>(arr) };
190 return DynArray::Bool(arr_bool);
191 }
192
193 panic!("Unsupported dtype for DynArray::from_generic");
195 }
196}
197
198impl From<Array<f32>> for DynArray {
200 fn from(arr: Array<f32>) -> Self {
201 DynArray::F32(arr)
202 }
203}
204
205impl From<Array<f64>> for DynArray {
206 fn from(arr: Array<f64>) -> Self {
207 DynArray::F64(arr)
208 }
209}
210
211impl From<Array<i32>> for DynArray {
212 fn from(arr: Array<i32>) -> Self {
213 DynArray::I32(arr)
214 }
215}
216
217impl From<Array<i8>> for DynArray {
218 fn from(arr: Array<i8>) -> Self {
219 DynArray::I8(arr)
220 }
221}
222
223impl From<Array<u8>> for DynArray {
224 fn from(arr: Array<u8>) -> Self {
225 DynArray::U8(arr)
226 }
227}
228
229impl From<Array<bool>> for DynArray {
230 fn from(arr: Array<bool>) -> Self {
231 DynArray::Bool(arr)
232 }
233}
234
235impl TryFrom<DynArray> for Array<f32> {
237 type Error = anyhow::Error;
238
239 fn try_from(dyn_arr: DynArray) -> Result<Self> {
240 match dyn_arr {
241 DynArray::F32(a) => Ok(a),
242 _ => bail!("Expected F32, got {:?}", dyn_arr.dtype()),
243 }
244 }
245}
246
247impl TryFrom<DynArray> for Array<f64> {
248 type Error = anyhow::Error;
249
250 fn try_from(dyn_arr: DynArray) -> Result<Self> {
251 match dyn_arr {
252 DynArray::F64(a) => Ok(a),
253 _ => bail!("Expected F64, got {:?}", dyn_arr.dtype()),
254 }
255 }
256}
257
258impl TryFrom<DynArray> for Array<i32> {
259 type Error = anyhow::Error;
260
261 fn try_from(dyn_arr: DynArray) -> Result<Self> {
262 match dyn_arr {
263 DynArray::I32(a) => Ok(a),
264 _ => bail!("Expected I32, got {:?}", dyn_arr.dtype()),
265 }
266 }
267}
268
269impl TryFrom<DynArray> for Array<i8> {
270 type Error = anyhow::Error;
271
272 fn try_from(dyn_arr: DynArray) -> Result<Self> {
273 match dyn_arr {
274 DynArray::I8(a) => Ok(a),
275 _ => bail!("Expected I8, got {:?}", dyn_arr.dtype()),
276 }
277 }
278}
279
280impl TryFrom<DynArray> for Array<u8> {
281 type Error = anyhow::Error;
282
283 fn try_from(dyn_arr: DynArray) -> Result<Self> {
284 match dyn_arr {
285 DynArray::U8(a) => Ok(a),
286 _ => bail!("Expected U8, got {:?}", dyn_arr.dtype()),
287 }
288 }
289}
290
291impl TryFrom<DynArray> for Array<bool> {
292 type Error = anyhow::Error;
293
294 fn try_from(dyn_arr: DynArray) -> Result<Self> {
295 match dyn_arr {
296 DynArray::Bool(a) => Ok(a),
297 _ => bail!("Expected Bool, got {:?}", dyn_arr.dtype()),
298 }
299 }
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dyn_array_creation() {
        let a = DynArray::F32(Array::new(vec![2], vec![1.0, 2.0]));
        assert_eq!(a.dtype(), DType::F32);
        assert_eq!(a.shape(), &[2]);
        assert_eq!(a.len(), 2);
    }

    #[test]
    fn test_dyn_array_conversions() {
        let arr_f32 = Array::new(vec![3], vec![1.0, 2.0, 3.0]);
        let dyn_arr: DynArray = arr_f32.clone().into();

        assert_eq!(dyn_arr.dtype(), DType::F32);

        let back: Array<f32> = dyn_arr.try_into().unwrap();
        assert_eq!(back.data, arr_f32.data);
    }

    #[test]
    fn test_dyn_array_type_mismatch() {
        let dyn_arr = DynArray::I32(Array::new(vec![2], vec![1, 2]));

        let result: Result<Array<f32>> = dyn_arr.try_into();
        assert!(result.is_err());
    }

    #[test]
    fn test_to_f32_conversion() {
        let i32_arr = DynArray::I32(Array::new(vec![3], vec![1, 2, 3]));
        let f32_arr = i32_arr.to_f32();

        assert_eq!(f32_arr.data, vec![1.0, 2.0, 3.0]);
    }

    // Exercises the same-dtype fast path of `into_typed`, which must hand the
    // data back unchanged rather than going through a cast.
    #[test]
    fn test_into_typed_same_dtype_keeps_data() {
        let dyn_arr = DynArray::F64(Array::new(vec![2], vec![1.5, -2.25]));
        let typed: Array<f64> = dyn_arr.into_typed().unwrap();

        assert_eq!(typed.data, vec![1.5, -2.25]);
    }

    // `from_generic` must select the variant matching the static element type.
    #[test]
    fn test_from_generic_picks_matching_variant() {
        let dyn_arr = DynArray::from_generic(Array::new(vec![2], vec![7i32, -8]));
        assert_eq!(dyn_arr.dtype(), DType::I32);

        let back: Array<i32> = dyn_arr.try_into().unwrap();
        assert_eq!(back.data, vec![7, -8]);
    }

    #[test]
    fn test_bool_to_f32_conversion() {
        let mask = DynArray::Bool(Array::new(vec![3], vec![true, false, true]));

        assert_eq!(mask.to_f32().data, vec![1.0, 0.0, 1.0]);
    }

    #[test]
    fn test_zero_extent_is_empty() {
        let empty = DynArray::U8(Array::new(vec![0], Vec::new()));

        assert_eq!(empty.len(), 0);
        assert!(empty.is_empty());
    }
}