numrs/array/
dyn_array.rs

//! Dynamic (type-erased) array that can hold any dtype.
//!
//! This module provides `DynArray`, an enum that can wrap an `Array<T>` of any
//! concrete element type. This lets functions return arrays whose dtype is only
//! determined at runtime (as in NumPy).

7use crate::array::{Array, DType, DTypeValue};
8use anyhow::{Result, bail};
9
10/// Array con tipo borrado (type-erased) que puede contener cualquier dtype
11///
12/// Similar a como NumPy maneja arrays internamente: el dtype se conoce en runtime.
13/// 
14/// # Ejemplo
15/// ```ignore
16/// let a = DynArray::F32(Array::new(vec![2], vec![1.0, 2.0]));
17/// let b = DynArray::I32(Array::new(vec![2], vec![3, 4]));
18/// 
19/// // Operaciones devuelven DynArray con tipo promovido
20/// let result = ops::add_dyn(&a, &b)?; // -> DynArray::F32
21/// 
22/// // Pattern matching para extraer el tipo concreto
23/// match result {
24///     DynArray::F32(arr) => println!("f32: {:?}", arr.data),
25///     DynArray::I32(arr) => println!("i32: {:?}", arr.data),
26///     _ => {}
27/// }
28/// ```
29#[derive(Debug, Clone)]
30pub enum DynArray {
31    F32(Array<f32>),
32    F64(Array<f64>),
33    I32(Array<i32>),
34    I8(Array<i8>),
35    U8(Array<u8>),
36    Bool(Array<bool>),
37}
38
39impl DynArray {
40    /// Obtener el DType de este array
41    pub fn dtype(&self) -> DType {
42        match self {
43            DynArray::F32(_) => DType::F32,
44            DynArray::F64(_) => DType::F64,
45            DynArray::I32(_) => DType::I32,
46            DynArray::I8(_) => DType::I8,
47            DynArray::U8(_) => DType::U8,
48            DynArray::Bool(_) => DType::Bool,
49        }
50    }
51    
52    /// Obtener el shape de este array
53    pub fn shape(&self) -> &[usize] {
54        match self {
55            DynArray::F32(a) => &a.shape,
56            DynArray::F64(a) => &a.shape,
57            DynArray::I32(a) => &a.shape,
58            DynArray::I8(a) => &a.shape,
59            DynArray::U8(a) => &a.shape,
60            DynArray::Bool(a) => &a.shape,
61        }
62    }
63    
64    /// Número de elementos
65    pub fn len(&self) -> usize {
66        self.shape().iter().product()
67    }
68    
69    /// Si el array está vacío
70    pub fn is_empty(&self) -> bool {
71        self.len() == 0
72    }
73    
74    /// Convertir a Array<f32> (con conversión si es necesario)
75    pub fn to_f32(&self) -> Array<f32> {
76        match self {
77            DynArray::F32(a) => a.clone(),
78            DynArray::F64(a) => {
79                let data: Vec<f32> = a.data.iter().map(|&x| x as f32).collect();
80                Array::new(a.shape.clone(), data)
81            }
82            DynArray::I32(a) => {
83                let data: Vec<f32> = a.data.iter().map(|&x| x as f32).collect();
84                Array::new(a.shape.clone(), data)
85            }
86            DynArray::I8(a) => {
87                let data: Vec<f32> = a.data.iter().map(|&x| x as f32).collect();
88                Array::new(a.shape.clone(), data)
89            }
90            DynArray::U8(a) => {
91                let data: Vec<f32> = a.data.iter().map(|&x| x as f32).collect();
92                Array::new(a.shape.clone(), data)
93            }
94            DynArray::Bool(a) => {
95                let data: Vec<f32> = a.data.iter().map(|&x| if x { 1.0 } else { 0.0 }).collect();
96                Array::new(a.shape.clone(), data)
97            }
98        }
99    }
100    
101    /// Aplicar una función a los datos internos
102    pub fn map_data<F, R>(&self, f: F) -> Result<R>
103    where
104        F: FnOnce(&dyn std::any::Any) -> Result<R>,
105    {
106        match self {
107            DynArray::F32(a) => f(&a.data),
108            DynArray::F64(a) => f(&a.data),
109            DynArray::I32(a) => f(&a.data),
110            DynArray::I8(a) => f(&a.data),
111            DynArray::U8(a) => f(&a.data),
112            DynArray::Bool(a) => f(&a.data),
113        }
114    }
115    
116    /// Extract the inner Array<T> with zero-copy if types match, or cast if they don't.
117    /// 
118    /// # Automatic Casting
119    /// If the internal dtype differs from T, this function will automatically cast/convert
120    /// the data to T. This enables "automatic narrowing" (F64 -> F32) or normal promotion casting.
121    pub fn into_typed<T: DTypeValue>(self) -> Result<Array<T>> {
122        use std::any::TypeId;
123        use std::mem;
124        
125        // Fast path: types match (zero copy)
126        // We use TypeId check for safety before transmute
127        match self {
128            DynArray::F32(arr) if TypeId::of::<T>() == TypeId::of::<f32>() => {
129                return Ok(unsafe { mem::transmute::<Array<f32>, Array<T>>(arr) });
130            }
131            DynArray::F64(arr) if TypeId::of::<T>() == TypeId::of::<f64>() => {
132                return Ok(unsafe { mem::transmute::<Array<f64>, Array<T>>(arr) });
133            }
134            DynArray::I32(arr) if TypeId::of::<T>() == TypeId::of::<i32>() => {
135                return Ok(unsafe { mem::transmute::<Array<i32>, Array<T>>(arr) });
136            }
137            DynArray::I8(arr) if TypeId::of::<T>() == TypeId::of::<i8>() => {
138                return Ok(unsafe { mem::transmute::<Array<i8>, Array<T>>(arr) });
139            }
140            DynArray::U8(arr) if TypeId::of::<T>() == TypeId::of::<u8>() => {
141                return Ok(unsafe { mem::transmute::<Array<u8>, Array<T>>(arr) });
142            }
143            DynArray::Bool(arr) if TypeId::of::<T>() == TypeId::of::<bool>() => {
144                return Ok(unsafe { mem::transmute::<Array<bool>, Array<T>>(arr) });
145            }
146            _ => {
147                // Slow path: Type mismatch -> Cast required
148                match self {
149                    DynArray::F32(arr) => Ok(crate::array::promotion::cast_array(&arr)),
150                    DynArray::F64(arr) => Ok(crate::array::promotion::cast_array(&arr)),
151                    DynArray::I32(arr) => Ok(crate::array::promotion::cast_array(&arr)),
152                    DynArray::I8(arr) => Ok(crate::array::promotion::cast_array(&arr)),
153                    DynArray::U8(arr) => Ok(crate::array::promotion::cast_array(&arr)),
154                    DynArray::Bool(arr) => Ok(crate::array::promotion::cast_array(&arr)),
155                }
156            }
157        }
158    }
159    
160    /// Envolver un Array<T> genérico en DynArray basándose en su dtype
161    /// 
162    /// Esto usa unsafe para transmute, pero es seguro porque verificamos el dtype
163    pub fn from_generic<T: DTypeValue>(arr: Array<T>) -> Self {
164        use std::any::TypeId;
165        use std::mem;
166        
167        // Verificar el tipo en compiletime si es posible
168        if TypeId::of::<T>() == TypeId::of::<f32>() {
169            let arr_f32 = unsafe { mem::transmute::<Array<T>, Array<f32>>(arr) };
170            return DynArray::F32(arr_f32);
171        }
172        if TypeId::of::<T>() == TypeId::of::<f64>() {
173            let arr_f64 = unsafe { mem::transmute::<Array<T>, Array<f64>>(arr) };
174            return DynArray::F64(arr_f64);
175        }
176        if TypeId::of::<T>() == TypeId::of::<i32>() {
177            let arr_i32 = unsafe { mem::transmute::<Array<T>, Array<i32>>(arr) };
178            return DynArray::I32(arr_i32);
179        }
180        if TypeId::of::<T>() == TypeId::of::<i8>() {
181            let arr_i8 = unsafe { mem::transmute::<Array<T>, Array<i8>>(arr) };
182            return DynArray::I8(arr_i8);
183        }
184        if TypeId::of::<T>() == TypeId::of::<u8>() {
185            let arr_u8 = unsafe { mem::transmute::<Array<T>, Array<u8>>(arr) };
186            return DynArray::U8(arr_u8);
187        }
188        if TypeId::of::<T>() == TypeId::of::<bool>() {
189            let arr_bool = unsafe { mem::transmute::<Array<T>, Array<bool>>(arr) };
190            return DynArray::Bool(arr_bool);
191        }
192        
193        // Fallback: no debería llegar aquí si DTypeValue está bien implementado
194        panic!("Unsupported dtype for DynArray::from_generic");
195    }
196}
197
198// Conversiones convenientes desde Array<T> a DynArray
199impl From<Array<f32>> for DynArray {
200    fn from(arr: Array<f32>) -> Self {
201        DynArray::F32(arr)
202    }
203}
204
205impl From<Array<f64>> for DynArray {
206    fn from(arr: Array<f64>) -> Self {
207        DynArray::F64(arr)
208    }
209}
210
211impl From<Array<i32>> for DynArray {
212    fn from(arr: Array<i32>) -> Self {
213        DynArray::I32(arr)
214    }
215}
216
217impl From<Array<i8>> for DynArray {
218    fn from(arr: Array<i8>) -> Self {
219        DynArray::I8(arr)
220    }
221}
222
223impl From<Array<u8>> for DynArray {
224    fn from(arr: Array<u8>) -> Self {
225        DynArray::U8(arr)
226    }
227}
228
229impl From<Array<bool>> for DynArray {
230    fn from(arr: Array<bool>) -> Self {
231        DynArray::Bool(arr)
232    }
233}
234
235// Conversiones hacia Array<T> con TryFrom (puede fallar si el tipo no coincide)
236impl TryFrom<DynArray> for Array<f32> {
237    type Error = anyhow::Error;
238    
239    fn try_from(dyn_arr: DynArray) -> Result<Self> {
240        match dyn_arr {
241            DynArray::F32(a) => Ok(a),
242            _ => bail!("Expected F32, got {:?}", dyn_arr.dtype()),
243        }
244    }
245}
246
247impl TryFrom<DynArray> for Array<f64> {
248    type Error = anyhow::Error;
249    
250    fn try_from(dyn_arr: DynArray) -> Result<Self> {
251        match dyn_arr {
252            DynArray::F64(a) => Ok(a),
253            _ => bail!("Expected F64, got {:?}", dyn_arr.dtype()),
254        }
255    }
256}
257
258impl TryFrom<DynArray> for Array<i32> {
259    type Error = anyhow::Error;
260    
261    fn try_from(dyn_arr: DynArray) -> Result<Self> {
262        match dyn_arr {
263            DynArray::I32(a) => Ok(a),
264            _ => bail!("Expected I32, got {:?}", dyn_arr.dtype()),
265        }
266    }
267}
268
269impl TryFrom<DynArray> for Array<i8> {
270    type Error = anyhow::Error;
271    
272    fn try_from(dyn_arr: DynArray) -> Result<Self> {
273        match dyn_arr {
274            DynArray::I8(a) => Ok(a),
275            _ => bail!("Expected I8, got {:?}", dyn_arr.dtype()),
276        }
277    }
278}
279
280impl TryFrom<DynArray> for Array<u8> {
281    type Error = anyhow::Error;
282    
283    fn try_from(dyn_arr: DynArray) -> Result<Self> {
284        match dyn_arr {
285            DynArray::U8(a) => Ok(a),
286            _ => bail!("Expected U8, got {:?}", dyn_arr.dtype()),
287        }
288    }
289}
290
291impl TryFrom<DynArray> for Array<bool> {
292    type Error = anyhow::Error;
293    
294    fn try_from(dyn_arr: DynArray) -> Result<Self> {
295        match dyn_arr {
296            DynArray::Bool(a) => Ok(a),
297            _ => bail!("Expected Bool, got {:?}", dyn_arr.dtype()),
298        }
299    }
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dyn_array_creation() {
        let a = DynArray::F32(Array::new(vec![2], vec![1.0, 2.0]));
        assert_eq!(a.dtype(), DType::F32);
        assert_eq!(a.shape(), &[2]);
        assert_eq!(a.len(), 2);
    }

    #[test]
    fn test_dyn_array_conversions() {
        let arr_f32 = Array::new(vec![3], vec![1.0, 2.0, 3.0]);
        let dyn_arr: DynArray = arr_f32.clone().into();

        assert_eq!(dyn_arr.dtype(), DType::F32);

        let back: Array<f32> = dyn_arr.try_into().unwrap();
        assert_eq!(back.data, arr_f32.data);
    }

    #[test]
    fn test_dyn_array_type_mismatch() {
        let dyn_arr = DynArray::I32(Array::new(vec![2], vec![1, 2]));

        let result: Result<Array<f32>> = dyn_arr.try_into();
        assert!(result.is_err());
    }

    #[test]
    fn test_to_f32_conversion() {
        let i32_arr = DynArray::I32(Array::new(vec![3], vec![1, 2, 3]));
        let f32_arr = i32_arr.to_f32();

        assert_eq!(f32_arr.data, vec![1.0, 2.0, 3.0]);
    }

    /// Bool elements map to 1.0 / 0.0.
    #[test]
    fn test_to_f32_from_bool() {
        let b = DynArray::Bool(Array::new(vec![3], vec![true, false, true]));
        assert_eq!(b.to_f32().data, vec![1.0, 0.0, 1.0]);
    }

    /// A zero-sized dimension makes the array empty.
    #[test]
    fn test_empty_array() {
        let a = DynArray::U8(Array::new(vec![0], vec![]));
        assert_eq!(a.len(), 0);
        assert!(a.is_empty());
    }

    /// Exercises the unsafe zero-copy path of `into_typed` (dtype matches).
    #[test]
    fn test_into_typed_same_dtype() {
        let dyn_arr = DynArray::I32(Array::new(vec![2, 2], vec![1, 2, 3, 4]));
        let arr: Array<i32> = dyn_arr.into_typed().unwrap();
        assert_eq!(&arr.shape[..], &[2, 2]);
        assert_eq!(arr.data, vec![1, 2, 3, 4]);
    }

    /// Exercises the casting fallback of `into_typed` (dtype differs).
    #[test]
    fn test_into_typed_casts_on_mismatch() {
        let dyn_arr = DynArray::I32(Array::new(vec![3], vec![1, 2, 3]));
        let arr: Array<f64> = dyn_arr.into_typed().unwrap();
        assert_eq!(arr.data, vec![1.0, 2.0, 3.0]);
    }

    /// `from_generic` must pick the variant matching `T` and keep the data intact.
    #[test]
    fn test_from_generic_round_trip() {
        let wrapped = DynArray::from_generic(Array::new(vec![2], vec![1.5f64, 2.5]));
        assert_eq!(wrapped.dtype(), DType::F64);
        let back: Array<f64> = wrapped.try_into().unwrap();
        assert_eq!(back.data, vec![1.5, 2.5]);
    }

    /// `map_data` returns whatever the closure returns, success or error.
    #[test]
    fn test_map_data_passes_result_through() {
        let a = DynArray::I8(Array::new(vec![1], vec![5]));
        assert_eq!(a.map_data(|_| Ok(42)).unwrap(), 42);

        let failed: Result<()> = a.map_data(|_| bail!("boom"));
        assert!(failed.is_err());
    }
}
340}