numrs/array/
promotion.rs

1//! Type promotion rules for operations between different dtypes
2//!
3//! Cuando dos arrays de tipos diferentes se operan, necesitamos determinar
4//! el tipo del resultado siguiendo reglas de promoción (como NumPy).
5//!
6//! Jerarquía de promoción:
7//! bool < u8 < i8 < i32 < f16 < bf16 < f32 < f64
8
9use crate::array::{Array, DType, DTypeValue};
10use anyhow::{Result, bail};
11
12/// Determine el tipo de resultado cuando se combinan dos dtypes
13/// 
14/// Reglas:
15/// - Operaciones entre el mismo tipo retornan ese tipo
16/// - Operaciones entre tipos diferentes promueven al tipo "más grande"
17/// - Float siempre gana sobre int
18/// - Tipos más anchos (más bits) ganan sobre tipos más estrechos
19pub fn promoted_dtype(a: DType, b: DType) -> DType {
20    if a == b {
21        return a;
22    }
23
24    // Lookup table para promoción
25    match (a, b) {
26        // F64 es el tipo más grande - siempre gana
27        (DType::F64, _) | (_, DType::F64) => DType::F64,
28        
29        // F32 gana sobre todo excepto F64
30        (DType::F32, DType::F16) | (DType::F16, DType::F32) => DType::F32,
31        (DType::F32, DType::BF16) | (DType::BF16, DType::F32) => DType::F32,
32        (DType::F32, _) | (_, DType::F32) => DType::F32,
33        
34        // BF16 vs F16 -> F32 (ambos son 16-bit, pero incompatibles)
35        (DType::BF16, DType::F16) | (DType::F16, DType::BF16) => DType::F32,
36        
37        // F16 gana sobre enteros
38        (DType::F16, _) | (_, DType::F16) => DType::F16,
39        
40        // BF16 gana sobre enteros
41        (DType::BF16, _) | (_, DType::BF16) => DType::BF16,
42        
43        // I32 gana sobre tipos más pequeños
44        (DType::I32, _) | (_, DType::I32) => DType::I32,
45        
46        // I8 vs U8 -> I32 (para evitar overflow)
47        (DType::I8, DType::U8) | (DType::U8, DType::I8) => DType::I32,
48        
49        // I8 gana sobre Bool
50        (DType::I8, DType::Bool) | (DType::Bool, DType::I8) => DType::I8,
51        
52        // U8 gana sobre Bool
53        (DType::U8, DType::Bool) | (DType::Bool, DType::U8) => DType::U8,
54        
55        // Casos base (no deberían llegar aquí si la tabla está completa)
56        _ => DType::F32, // Default seguro: F32
57    }
58}
59
60/// Convierte un Array<T> a Array<U> (cast)
61/// 
62/// Esta función hace la conversión real de datos entre tipos.
63pub fn cast_array<T, U>(arr: &Array<T>) -> Array<U>
64where
65    T: DTypeValue,
66    U: DTypeValue,
67{
68    let data: Vec<U> = arr.data.iter().map(|&val| {
69        U::from_f32(val.to_f32())
70    }).collect();
71    
72    Array::new(arr.shape.clone(), data)
73}
74
75/// Promociona dos arrays al tipo común y retorna las versiones convertidas
76/// 
77/// Esta es la función principal que se usa en operaciones.
78/// Si ambos arrays tienen el mismo tipo, no hace nada (zero-cost).
79/// Si son diferentes, convierte ambos al tipo común.
80pub fn promote_arrays<T1, T2>(
81    a: &Array<T1>,
82    b: &Array<T2>,
83) -> Result<(DType, Vec<f32>, Vec<f32>)>
84where
85    T1: DTypeValue,
86    T2: DTypeValue,
87{
88    let dtype_a = a.dtype;
89    let dtype_b = b.dtype;
90    
91    if dtype_a == dtype_b {
92        // Mismo tipo - solo convertir a f32 para procesamiento
93        return Ok((dtype_a, a.data.iter().map(|&x| x.to_f32()).collect(), 
94                            b.data.iter().map(|&x| x.to_f32()).collect()));
95    }
96    
97    // Determinar tipo de resultado
98    let result_dtype = promoted_dtype(dtype_a, dtype_b);
99    
100    // Convertir ambos arrays a f32 (forma intermedia común)
101    let a_f32: Vec<f32> = a.data.iter().map(|&x| x.to_f32()).collect();
102    let b_f32: Vec<f32> = b.data.iter().map(|&x| x.to_f32()).collect();
103    
104    Ok((result_dtype, a_f32, b_f32))
105}
106
107/// Valida que dos arrays puedan operarse juntos
108/// 
109/// Verifica:
110/// - Que los shapes sean compatibles
111/// - Que los dtypes sean compatibles
112pub fn validate_binary_op<T1, T2>(
113    a: &Array<T1>,
114    b: &Array<T2>,
115    op_name: &str,
116) -> Result<()>
117where
118    T1: DTypeValue,
119    T2: DTypeValue,
120{
121    // MATMUL es un caso especial - necesita validación diferente
122    if op_name == "matmul" {
123        return validate_matmul_shapes(a, b);
124    }
125    
126    // Para operaciones elementwise, validar broadcasting
127    if !shapes_are_broadcastable(&a.shape, &b.shape) {
128        bail!(
129            "{}: shape mismatch and not broadcastable: {:?} vs {:?}",
130            op_name,
131            a.shape,
132            b.shape
133        );
134    }
135    
136    // Por ahora, todos los dtypes son compatibles (promoción automática)
137    // En el futuro podríamos rechazar ciertas combinaciones
138    
139    Ok(())
140}
141
142/// Verifica si dos shapes son compatibles para broadcasting según reglas de NumPy:
143/// - Los shapes se comparan desde el final hacia el inicio
144/// - Cada dimensión debe ser igual O una de ellas debe ser 1
145/// Ejemplos:
146/// - [3, 4] y [4] -> broadcastable (se expande [4] a [1, 4])
147/// - [3, 4] y [3, 1] -> broadcastable
148/// - [3, 4] y [5] -> NO broadcastable
149fn shapes_are_broadcastable(shape1: &[usize], shape2: &[usize]) -> bool {
150    let len1 = shape1.len();
151    let len2 = shape2.len();
152    let max_len = len1.max(len2);
153    
154    for i in 0..max_len {
155        let dim1 = if i < len1 {
156            shape1[len1 - 1 - i]
157        } else {
158            1
159        };
160        
161        let dim2 = if i < len2 {
162            shape2[len2 - 1 - i]
163        } else {
164            1
165        };
166        
167        // Cada dimensión debe ser igual O una debe ser 1
168        if dim1 != dim2 && dim1 != 1 && dim2 != 1 {
169            return false;
170        }
171    }
172    
173    true
174}
175
176/// Valida shapes para matmul, soportando múltiples configuraciones como NumPy:
177/// - 2D @ 2D: [M, K] @ [K, N] -> [M, N]
178/// - 1D @ 2D: [K] @ [K, N] -> [N] (vector-matrix)
179/// - 2D @ 1D: [M, K] @ [K] -> [M] (matrix-vector)
180/// - 1D @ 1D: [K] @ [K] -> [] (dot product)
181/// - nD @ mD: batched matmul con broadcasting
182fn validate_matmul_shapes<T1, T2>(a: &Array<T1>, b: &Array<T2>) -> Result<()>
183where
184    T1: DTypeValue,
185    T2: DTypeValue,
186{
187    let a_ndim = a.shape.len();
188    let b_ndim = b.shape.len();
189    
190    // Casos soportados
191    match (a_ndim, b_ndim) {
192        // 2D @ 2D: Caso estándar de matriz
193        (2, 2) => {
194            let a_cols = a.shape[1];
195            let b_rows = b.shape[0];
196            if a_cols != b_rows {
197                bail!(
198                    "matmul: inner dimensions must match: [{}] @ [{}] incompatible",
199                    a_cols, b_rows
200                );
201            }
202            Ok(())
203        }
204        
205        // 1D @ 2D: vector @ matrix -> vector
206        (1, 2) => {
207            let a_len = a.shape[0];
208            let b_rows = b.shape[0];
209            if a_len != b_rows {
210                bail!(
211                    "matmul: vector-matrix dimensions incompatible: [{}] @ [{}, {}]",
212                    a_len, b_rows, b.shape[1]
213                );
214            }
215            Ok(())
216        }
217        
218        // 2D @ 1D: matrix @ vector -> vector
219        (2, 1) => {
220            let a_cols = a.shape[1];
221            let b_len = b.shape[0];
222            if a_cols != b_len {
223                bail!(
224                    "matmul: matrix-vector dimensions incompatible: [{}, {}] @ [{}]",
225                    a.shape[0], a_cols, b_len
226                );
227            }
228            Ok(())
229        }
230        
231        // 1D @ 1D: dot product
232        (1, 1) => {
233            if a.shape[0] != b.shape[0] {
234                bail!(
235                    "matmul: vectors must have same length: [{}] @ [{}]",
236                    a.shape[0], b.shape[0]
237                );
238            }
239            Ok(())
240        }
241        
242        // TODO: Batched matmul (3D+)
243        _ => {
244            bail!(
245                "matmul: unsupported dimensions: {}D @ {}D (currently only 1D/2D supported)",
246                a_ndim, b_ndim
247            );
248        }
249    }
250}
251
252#[cfg(test)]
253mod tests {
254    use super::*;
255
256    #[test]
257    fn test_same_type_promotion() {
258        assert_eq!(promoted_dtype(DType::F32, DType::F32), DType::F32);
259        assert_eq!(promoted_dtype(DType::I32, DType::I32), DType::I32);
260    }
261
262    #[test]
263    fn test_float_hierarchy() {
264        // F64 > F32 > F16/BF16
265        assert_eq!(promoted_dtype(DType::F32, DType::F64), DType::F64);
266        assert_eq!(promoted_dtype(DType::F16, DType::F32), DType::F32);
267        assert_eq!(promoted_dtype(DType::BF16, DType::F32), DType::F32);
268        assert_eq!(promoted_dtype(DType::F16, DType::F64), DType::F64);
269    }
270
271    #[test]
272    fn test_float_vs_int() {
273        // Float siempre gana
274        assert_eq!(promoted_dtype(DType::F32, DType::I32), DType::F32);
275        assert_eq!(promoted_dtype(DType::F16, DType::I32), DType::F16);
276        assert_eq!(promoted_dtype(DType::I8, DType::F32), DType::F32);
277    }
278
279    #[test]
280    fn test_int_promotion() {
281        // I32 > I8/U8 > Bool
282        assert_eq!(promoted_dtype(DType::I32, DType::I8), DType::I32);
283        assert_eq!(promoted_dtype(DType::I32, DType::U8), DType::I32);
284        assert_eq!(promoted_dtype(DType::I8, DType::Bool), DType::I8);
285        assert_eq!(promoted_dtype(DType::U8, DType::Bool), DType::U8);
286    }
287
288    #[test]
289    fn test_mixed_sign_ints() {
290        // I8 + U8 -> I32 (para seguridad)
291        assert_eq!(promoted_dtype(DType::I8, DType::U8), DType::I32);
292    }
293
294    #[test]
295    fn test_f16_vs_bf16() {
296        // F16 + BF16 -> F32 (incompatibles entre sí)
297        assert_eq!(promoted_dtype(DType::F16, DType::BF16), DType::F32);
298    }
299
300    #[test]
301    fn test_cast_array() {
302        let a = Array::new(vec![3], vec![1.0_f32, 2.0, 3.0]);
303        
304        // F32 -> I32
305        let b: Array<i32> = cast_array(&a);
306        assert_eq!(b.dtype, DType::I32);
307        assert_eq!(b.data, vec![1, 2, 3]);
308        
309        // F32 -> F64
310        let c: Array<f64> = cast_array(&a);
311        assert_eq!(c.dtype, DType::F64);
312        assert_eq!(c.data, vec![1.0_f64, 2.0, 3.0]);
313    }
314
315    #[test]
316    fn test_promote_same_type() -> Result<()> {
317        let a = Array::new(vec![2], vec![1.0_f32, 2.0]);
318        let b = Array::new(vec![2], vec![3.0_f32, 4.0]);
319        
320        let (dtype, a_data, b_data) = promote_arrays(&a, &b)?;
321        
322        assert_eq!(dtype, DType::F32);
323        assert_eq!(a_data, vec![1.0, 2.0]);
324        assert_eq!(b_data, vec![3.0, 4.0]);
325        
326        Ok(())
327    }
328
329    #[test]
330    fn test_promote_different_types() -> Result<()> {
331        let a = Array::new(vec![2], vec![1_i32, 2]);
332        let b = Array::new(vec![2], vec![3.0_f32, 4.0]);
333        
334        let (dtype, a_data, b_data) = promote_arrays(&a, &b)?;
335        
336        assert_eq!(dtype, DType::F32);
337        assert_eq!(a_data, vec![1.0, 2.0]);
338        assert_eq!(b_data, vec![3.0, 4.0]);
339        
340        Ok(())
341    }
342
343    #[test]
344    fn test_validate_binary_op_ok() -> Result<()> {
345        let a = Array::new(vec![2, 3], vec![1.0_f32; 6]);
346        let b = Array::new(vec![2, 3], vec![2.0_f64; 6]);
347        
348        validate_binary_op(&a, &b, "add")?;
349        Ok(())
350    }
351
352    #[test]
353    fn test_validate_binary_op_shape_mismatch() {
354        let a = Array::new(vec![2], vec![1.0_f32, 2.0]);
355        let b = Array::new(vec![3], vec![1.0_f32, 2.0, 3.0]);
356        
357        let result = validate_binary_op(&a, &b, "add");
358        assert!(result.is_err());
359        assert!(result.unwrap_err().to_string().contains("shape mismatch"));
360    }
361}