1use crate::array::{Array, DType, DTypeValue};
10use anyhow::{Result, bail};
11
12pub fn promoted_dtype(a: DType, b: DType) -> DType {
20 if a == b {
21 return a;
22 }
23
24 match (a, b) {
26 (DType::F64, _) | (_, DType::F64) => DType::F64,
28
29 (DType::F32, DType::F16) | (DType::F16, DType::F32) => DType::F32,
31 (DType::F32, DType::BF16) | (DType::BF16, DType::F32) => DType::F32,
32 (DType::F32, _) | (_, DType::F32) => DType::F32,
33
34 (DType::BF16, DType::F16) | (DType::F16, DType::BF16) => DType::F32,
36
37 (DType::F16, _) | (_, DType::F16) => DType::F16,
39
40 (DType::BF16, _) | (_, DType::BF16) => DType::BF16,
42
43 (DType::I32, _) | (_, DType::I32) => DType::I32,
45
46 (DType::I8, DType::U8) | (DType::U8, DType::I8) => DType::I32,
48
49 (DType::I8, DType::Bool) | (DType::Bool, DType::I8) => DType::I8,
51
52 (DType::U8, DType::Bool) | (DType::Bool, DType::U8) => DType::U8,
54
55 _ => DType::F32, }
58}
59
60pub fn cast_array<T, U>(arr: &Array<T>) -> Array<U>
64where
65 T: DTypeValue,
66 U: DTypeValue,
67{
68 let data: Vec<U> = arr.data.iter().map(|&val| {
69 U::from_f32(val.to_f32())
70 }).collect();
71
72 Array::new(arr.shape.clone(), data)
73}
74
75pub fn promote_arrays<T1, T2>(
81 a: &Array<T1>,
82 b: &Array<T2>,
83) -> Result<(DType, Vec<f32>, Vec<f32>)>
84where
85 T1: DTypeValue,
86 T2: DTypeValue,
87{
88 let dtype_a = a.dtype;
89 let dtype_b = b.dtype;
90
91 if dtype_a == dtype_b {
92 return Ok((dtype_a, a.data.iter().map(|&x| x.to_f32()).collect(),
94 b.data.iter().map(|&x| x.to_f32()).collect()));
95 }
96
97 let result_dtype = promoted_dtype(dtype_a, dtype_b);
99
100 let a_f32: Vec<f32> = a.data.iter().map(|&x| x.to_f32()).collect();
102 let b_f32: Vec<f32> = b.data.iter().map(|&x| x.to_f32()).collect();
103
104 Ok((result_dtype, a_f32, b_f32))
105}
106
107pub fn validate_binary_op<T1, T2>(
113 a: &Array<T1>,
114 b: &Array<T2>,
115 op_name: &str,
116) -> Result<()>
117where
118 T1: DTypeValue,
119 T2: DTypeValue,
120{
121 if op_name == "matmul" {
123 return validate_matmul_shapes(a, b);
124 }
125
126 if !shapes_are_broadcastable(&a.shape, &b.shape) {
128 bail!(
129 "{}: shape mismatch and not broadcastable: {:?} vs {:?}",
130 op_name,
131 a.shape,
132 b.shape
133 );
134 }
135
136 Ok(())
140}
141
/// Reports whether two shapes can be broadcast against each other.
///
/// The shapes are compared dimension by dimension starting from the trailing
/// (right-most) axis. A pair of dimensions is compatible when the two sizes
/// are equal or when either of them is `1`. Leading dimensions that exist in
/// only one of the shapes are treated as size `1` and therefore never cause a
/// mismatch.
fn shapes_are_broadcastable(shape1: &[usize], shape2: &[usize]) -> bool {
    // `zip` stops at the shorter shape, which matches padding the missing
    // leading dimensions with 1: such a pair is always compatible.
    shape1
        .iter()
        .rev()
        .zip(shape2.iter().rev())
        .all(|(&lhs, &rhs)| lhs == rhs || lhs == 1 || rhs == 1)
}
175
176fn validate_matmul_shapes<T1, T2>(a: &Array<T1>, b: &Array<T2>) -> Result<()>
183where
184 T1: DTypeValue,
185 T2: DTypeValue,
186{
187 let a_ndim = a.shape.len();
188 let b_ndim = b.shape.len();
189
190 match (a_ndim, b_ndim) {
192 (2, 2) => {
194 let a_cols = a.shape[1];
195 let b_rows = b.shape[0];
196 if a_cols != b_rows {
197 bail!(
198 "matmul: inner dimensions must match: [{}] @ [{}] incompatible",
199 a_cols, b_rows
200 );
201 }
202 Ok(())
203 }
204
205 (1, 2) => {
207 let a_len = a.shape[0];
208 let b_rows = b.shape[0];
209 if a_len != b_rows {
210 bail!(
211 "matmul: vector-matrix dimensions incompatible: [{}] @ [{}, {}]",
212 a_len, b_rows, b.shape[1]
213 );
214 }
215 Ok(())
216 }
217
218 (2, 1) => {
220 let a_cols = a.shape[1];
221 let b_len = b.shape[0];
222 if a_cols != b_len {
223 bail!(
224 "matmul: matrix-vector dimensions incompatible: [{}, {}] @ [{}]",
225 a.shape[0], a_cols, b_len
226 );
227 }
228 Ok(())
229 }
230
231 (1, 1) => {
233 if a.shape[0] != b.shape[0] {
234 bail!(
235 "matmul: vectors must have same length: [{}] @ [{}]",
236 a.shape[0], b.shape[0]
237 );
238 }
239 Ok(())
240 }
241
242 _ => {
244 bail!(
245 "matmul: unsupported dimensions: {}D @ {}D (currently only 1D/2D supported)",
246 a_ndim, b_ndim
247 );
248 }
249 }
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts `promoted_dtype(lhs, rhs) == expected` for each
    /// `(lhs, rhs, expected)` triple, in order.
    fn check_promotions(cases: &[(DType, DType, DType)]) {
        for &(lhs, rhs, expected) in cases {
            assert_eq!(promoted_dtype(lhs, rhs), expected);
        }
    }

    // Identical dtypes promote to themselves.
    #[test]
    fn test_same_type_promotion() {
        check_promotions(&[
            (DType::F32, DType::F32, DType::F32),
            (DType::I32, DType::I32, DType::I32),
        ]);
    }

    // Among floats, the wider type wins.
    #[test]
    fn test_float_hierarchy() {
        check_promotions(&[
            (DType::F32, DType::F64, DType::F64),
            (DType::F16, DType::F32, DType::F32),
            (DType::BF16, DType::F32, DType::F32),
            (DType::F16, DType::F64, DType::F64),
        ]);
    }

    // A float paired with an integer keeps the float dtype.
    #[test]
    fn test_float_vs_int() {
        check_promotions(&[
            (DType::F32, DType::I32, DType::F32),
            (DType::F16, DType::I32, DType::F16),
            (DType::I8, DType::F32, DType::F32),
        ]);
    }

    // Integer / bool combinations.
    #[test]
    fn test_int_promotion() {
        check_promotions(&[
            (DType::I32, DType::I8, DType::I32),
            (DType::I32, DType::U8, DType::I32),
            (DType::I8, DType::Bool, DType::I8),
            (DType::U8, DType::Bool, DType::U8),
        ]);
    }

    // Signed and unsigned 8-bit integers promote to I32.
    #[test]
    fn test_mixed_sign_ints() {
        check_promotions(&[(DType::I8, DType::U8, DType::I32)]);
    }

    // The two half-precision formats promote to F32 when mixed.
    #[test]
    fn test_f16_vs_bf16() {
        check_promotions(&[(DType::F16, DType::BF16, DType::F32)]);
    }

    // Casting an f32 array to i32 and to f64 keeps the values and updates the dtype.
    #[test]
    fn test_cast_array() {
        let source = Array::new(vec![3], vec![1.0_f32, 2.0, 3.0]);

        let as_i32: Array<i32> = cast_array(&source);
        assert_eq!(as_i32.dtype, DType::I32);
        assert_eq!(as_i32.data, vec![1, 2, 3]);

        let as_f64: Array<f64> = cast_array(&source);
        assert_eq!(as_f64.dtype, DType::F64);
        assert_eq!(as_f64.data, vec![1.0_f64, 2.0, 3.0]);
    }

    // Two f32 arrays: dtype stays F32 and both buffers come back unchanged.
    #[test]
    fn test_promote_same_type() -> Result<()> {
        let lhs = Array::new(vec![2], vec![1.0_f32, 2.0]);
        let rhs = Array::new(vec![2], vec![3.0_f32, 4.0]);

        let (dtype, lhs_values, rhs_values) = promote_arrays(&lhs, &rhs)?;

        assert_eq!(dtype, DType::F32);
        assert_eq!(lhs_values, vec![1.0, 2.0]);
        assert_eq!(rhs_values, vec![3.0, 4.0]);

        Ok(())
    }

    // i32 with f32: result dtype is F32 and the integers are converted to f32.
    #[test]
    fn test_promote_different_types() -> Result<()> {
        let lhs = Array::new(vec![2], vec![1_i32, 2]);
        let rhs = Array::new(vec![2], vec![3.0_f32, 4.0]);

        let (dtype, lhs_values, rhs_values) = promote_arrays(&lhs, &rhs)?;

        assert_eq!(dtype, DType::F32);
        assert_eq!(lhs_values, vec![1.0, 2.0]);
        assert_eq!(rhs_values, vec![3.0, 4.0]);

        Ok(())
    }

    // Equal shapes with different element types pass validation.
    #[test]
    fn test_validate_binary_op_ok() -> Result<()> {
        let lhs = Array::new(vec![2, 3], vec![1.0_f32; 6]);
        let rhs = Array::new(vec![2, 3], vec![2.0_f64; 6]);

        validate_binary_op(&lhs, &rhs, "add")
    }

    // Shapes [2] and [3] are not broadcastable and must be rejected.
    #[test]
    fn test_validate_binary_op_shape_mismatch() {
        let lhs = Array::new(vec![2], vec![1.0_f32, 2.0]);
        let rhs = Array::new(vec![3], vec![1.0_f32, 2.0, 3.0]);

        let err = validate_binary_op(&lhs, &rhs, "add")
            .expect_err("shapes [2] and [3] must not validate");
        assert!(err.to_string().contains("shape mismatch"));
    }
}