Skip to main content

radiate_utils/datatype/
arithmetic.rs

1use radiate_error::radiate_bail;
2
3use crate::AnyValue;
4use std::ops::{Add, Div, Mul, Rem, Sub};
5use std::ops::{BitAnd, BitOr, Not};
6
/// Internal helper: evaluate `lhs <op> rhs` for two numeric `AnyValue`s.
///
/// * Integers of the same signedness are promoted to the wider of the two
///   types (e.g. `Int8 <op> Int32 -> Int32`, `UInt8 <op> UInt128 -> UInt128`).
/// * `Float32` mixed with `Float64` is evaluated in `f64`.
/// * Every other combination (signed with unsigned, integer with float, or a
///   non-numeric operand) yields `AnyValue::Null`.
///
/// Integer `+`, `-`, `*`, `/` and `%` are evaluated with the matching
/// `wrapping_*` method. Overflow therefore wraps in every build profile instead
/// of panicking when overflow checks are enabled (e.g. debug builds), and
/// `MIN % -1` yields `0`. Integer `/` and `%` by zero still panic, so callers
/// must guard against a zero integer divisor.
macro_rules! bin_numeric_op {
    // Internal rules: apply `$op` to two integers that already share a type.
    (@int $a:expr, +, $b:expr) => {
        $a.wrapping_add($b)
    };
    (@int $a:expr, -, $b:expr) => {
        $a.wrapping_sub($b)
    };
    (@int $a:expr, *, $b:expr) => {
        $a.wrapping_mul($b)
    };
    (@int $a:expr, /, $b:expr) => {
        $a.wrapping_div($b)
    };
    (@int $a:expr, %, $b:expr) => {
        $a.wrapping_rem($b)
    };
    // Any other operator token is applied as written.
    (@int $a:expr, $op:tt, $b:expr) => {
        $a $op $b
    };
    ($lhs:expr, $rhs:expr, $op:tt) => {{
        use AnyValue::*;
        match ($lhs, $rhs) {
            (Int8(a), Int8(b)) => Int8(bin_numeric_op!(@int a, $op, b)),
            (Int8(a), Int16(b)) => Int16(bin_numeric_op!(@int i16::from(a), $op, b)),
            (Int8(a), Int32(b)) => Int32(bin_numeric_op!(@int i32::from(a), $op, b)),
            (Int8(a), Int64(b)) => Int64(bin_numeric_op!(@int i64::from(a), $op, b)),
            (Int8(a), Int128(b)) => Int128(bin_numeric_op!(@int i128::from(a), $op, b)),

            (Int16(a), Int8(b)) => Int16(bin_numeric_op!(@int a, $op, i16::from(b))),
            (Int16(a), Int16(b)) => Int16(bin_numeric_op!(@int a, $op, b)),
            (Int16(a), Int32(b)) => Int32(bin_numeric_op!(@int i32::from(a), $op, b)),
            (Int16(a), Int64(b)) => Int64(bin_numeric_op!(@int i64::from(a), $op, b)),
            (Int16(a), Int128(b)) => Int128(bin_numeric_op!(@int i128::from(a), $op, b)),

            (Int32(a), Int8(b)) => Int32(bin_numeric_op!(@int a, $op, i32::from(b))),
            (Int32(a), Int16(b)) => Int32(bin_numeric_op!(@int a, $op, i32::from(b))),
            (Int32(a), Int32(b)) => Int32(bin_numeric_op!(@int a, $op, b)),
            (Int32(a), Int64(b)) => Int64(bin_numeric_op!(@int i64::from(a), $op, b)),
            (Int32(a), Int128(b)) => Int128(bin_numeric_op!(@int i128::from(a), $op, b)),

            (Int64(a), Int8(b)) => Int64(bin_numeric_op!(@int a, $op, i64::from(b))),
            (Int64(a), Int16(b)) => Int64(bin_numeric_op!(@int a, $op, i64::from(b))),
            (Int64(a), Int32(b)) => Int64(bin_numeric_op!(@int a, $op, i64::from(b))),
            (Int64(a), Int64(b)) => Int64(bin_numeric_op!(@int a, $op, b)),
            (Int64(a), Int128(b)) => Int128(bin_numeric_op!(@int i128::from(a), $op, b)),

            (Int128(a), Int8(b)) => Int128(bin_numeric_op!(@int a, $op, i128::from(b))),
            (Int128(a), Int16(b)) => Int128(bin_numeric_op!(@int a, $op, i128::from(b))),
            (Int128(a), Int32(b)) => Int128(bin_numeric_op!(@int a, $op, i128::from(b))),
            (Int128(a), Int64(b)) => Int128(bin_numeric_op!(@int a, $op, i128::from(b))),
            (Int128(a), Int128(b)) => Int128(bin_numeric_op!(@int a, $op, b)),

            (UInt8(a), UInt8(b)) => UInt8(bin_numeric_op!(@int a, $op, b)),
            (UInt8(a), UInt16(b)) => UInt16(bin_numeric_op!(@int u16::from(a), $op, b)),
            (UInt8(a), UInt32(b)) => UInt32(bin_numeric_op!(@int u32::from(a), $op, b)),
            (UInt8(a), UInt64(b)) => UInt64(bin_numeric_op!(@int u64::from(a), $op, b)),
            (UInt8(a), UInt128(b)) => UInt128(bin_numeric_op!(@int u128::from(a), $op, b)),

            (UInt16(a), UInt8(b)) => UInt16(bin_numeric_op!(@int a, $op, u16::from(b))),
            (UInt16(a), UInt16(b)) => UInt16(bin_numeric_op!(@int a, $op, b)),
            (UInt16(a), UInt32(b)) => UInt32(bin_numeric_op!(@int u32::from(a), $op, b)),
            (UInt16(a), UInt64(b)) => UInt64(bin_numeric_op!(@int u64::from(a), $op, b)),
            (UInt16(a), UInt128(b)) => UInt128(bin_numeric_op!(@int u128::from(a), $op, b)),

            (UInt32(a), UInt8(b)) => UInt32(bin_numeric_op!(@int a, $op, u32::from(b))),
            (UInt32(a), UInt16(b)) => UInt32(bin_numeric_op!(@int a, $op, u32::from(b))),
            (UInt32(a), UInt32(b)) => UInt32(bin_numeric_op!(@int a, $op, b)),
            (UInt32(a), UInt64(b)) => UInt64(bin_numeric_op!(@int u64::from(a), $op, b)),
            (UInt32(a), UInt128(b)) => UInt128(bin_numeric_op!(@int u128::from(a), $op, b)),

            (UInt64(a), UInt8(b)) => UInt64(bin_numeric_op!(@int a, $op, u64::from(b))),
            (UInt64(a), UInt16(b)) => UInt64(bin_numeric_op!(@int a, $op, u64::from(b))),
            (UInt64(a), UInt32(b)) => UInt64(bin_numeric_op!(@int a, $op, u64::from(b))),
            (UInt64(a), UInt64(b)) => UInt64(bin_numeric_op!(@int a, $op, b)),
            (UInt64(a), UInt128(b)) => UInt128(bin_numeric_op!(@int u128::from(a), $op, b)),

            (UInt128(a), UInt8(b)) => UInt128(bin_numeric_op!(@int a, $op, u128::from(b))),
            (UInt128(a), UInt16(b)) => UInt128(bin_numeric_op!(@int a, $op, u128::from(b))),
            (UInt128(a), UInt32(b)) => UInt128(bin_numeric_op!(@int a, $op, u128::from(b))),
            (UInt128(a), UInt64(b)) => UInt128(bin_numeric_op!(@int a, $op, u128::from(b))),
            (UInt128(a), UInt128(b)) => UInt128(bin_numeric_op!(@int a, $op, b)),

            (Float32(a), Float32(b)) => Float32(a $op b),
            (Float32(a), Float64(b)) => Float64(f64::from(a) $op b),
            (Float64(a), Float32(b)) => Float64(a $op f64::from(b)),
            (Float64(a), Float64(b)) => Float64(a $op b),

            _ => Null,
        }
    }};
}
76
/// Internal helper: evaluate `lhs / rhs` for two numeric `AnyValue`s.
///
/// Types are promoted the same way as in `bin_numeric_op!`: integers of the
/// same signedness widen to the larger type, and `Float32`/`Float64` mixes are
/// evaluated in `f64`. Division never panics:
///
/// * integer division by zero returns the (promoted) dividend unchanged;
/// * integer overflow (`MIN / -1`) wraps to `MIN`;
/// * float division by zero yields `Null`;
/// * any other combination (signed with unsigned, integer with float, or a
///   non-numeric operand) yields `Null`.
macro_rules! bin_numeric_div {
    // Internal rule: integer division of two values that already share a type.
    (@int $a:expr, $b:expr) => {{
        let (dividend, divisor) = ($a, $b);
        if divisor == 0 {
            dividend
        } else {
            dividend.wrapping_div(divisor)
        }
    }};
    ($lhs:expr, $rhs:expr) => {{
        use AnyValue::*;
        match ($lhs, $rhs) {
            (Int8(a), Int8(b)) => Int8(bin_numeric_div!(@int a, b)),
            (Int8(a), Int16(b)) => Int16(bin_numeric_div!(@int i16::from(a), b)),
            (Int8(a), Int32(b)) => Int32(bin_numeric_div!(@int i32::from(a), b)),
            (Int8(a), Int64(b)) => Int64(bin_numeric_div!(@int i64::from(a), b)),
            (Int8(a), Int128(b)) => Int128(bin_numeric_div!(@int i128::from(a), b)),

            (Int16(a), Int8(b)) => Int16(bin_numeric_div!(@int a, i16::from(b))),
            (Int16(a), Int16(b)) => Int16(bin_numeric_div!(@int a, b)),
            (Int16(a), Int32(b)) => Int32(bin_numeric_div!(@int i32::from(a), b)),
            (Int16(a), Int64(b)) => Int64(bin_numeric_div!(@int i64::from(a), b)),
            (Int16(a), Int128(b)) => Int128(bin_numeric_div!(@int i128::from(a), b)),

            (Int32(a), Int8(b)) => Int32(bin_numeric_div!(@int a, i32::from(b))),
            (Int32(a), Int16(b)) => Int32(bin_numeric_div!(@int a, i32::from(b))),
            (Int32(a), Int32(b)) => Int32(bin_numeric_div!(@int a, b)),
            (Int32(a), Int64(b)) => Int64(bin_numeric_div!(@int i64::from(a), b)),
            (Int32(a), Int128(b)) => Int128(bin_numeric_div!(@int i128::from(a), b)),

            (Int64(a), Int8(b)) => Int64(bin_numeric_div!(@int a, i64::from(b))),
            (Int64(a), Int16(b)) => Int64(bin_numeric_div!(@int a, i64::from(b))),
            (Int64(a), Int32(b)) => Int64(bin_numeric_div!(@int a, i64::from(b))),
            (Int64(a), Int64(b)) => Int64(bin_numeric_div!(@int a, b)),
            (Int64(a), Int128(b)) => Int128(bin_numeric_div!(@int i128::from(a), b)),

            (Int128(a), Int8(b)) => Int128(bin_numeric_div!(@int a, i128::from(b))),
            (Int128(a), Int16(b)) => Int128(bin_numeric_div!(@int a, i128::from(b))),
            (Int128(a), Int32(b)) => Int128(bin_numeric_div!(@int a, i128::from(b))),
            (Int128(a), Int64(b)) => Int128(bin_numeric_div!(@int a, i128::from(b))),
            (Int128(a), Int128(b)) => Int128(bin_numeric_div!(@int a, b)),

            (UInt8(a), UInt8(b)) => UInt8(bin_numeric_div!(@int a, b)),
            (UInt8(a), UInt16(b)) => UInt16(bin_numeric_div!(@int u16::from(a), b)),
            (UInt8(a), UInt32(b)) => UInt32(bin_numeric_div!(@int u32::from(a), b)),
            (UInt8(a), UInt64(b)) => UInt64(bin_numeric_div!(@int u64::from(a), b)),
            (UInt8(a), UInt128(b)) => UInt128(bin_numeric_div!(@int u128::from(a), b)),

            (UInt16(a), UInt8(b)) => UInt16(bin_numeric_div!(@int a, u16::from(b))),
            (UInt16(a), UInt16(b)) => UInt16(bin_numeric_div!(@int a, b)),
            (UInt16(a), UInt32(b)) => UInt32(bin_numeric_div!(@int u32::from(a), b)),
            (UInt16(a), UInt64(b)) => UInt64(bin_numeric_div!(@int u64::from(a), b)),
            (UInt16(a), UInt128(b)) => UInt128(bin_numeric_div!(@int u128::from(a), b)),

            (UInt32(a), UInt8(b)) => UInt32(bin_numeric_div!(@int a, u32::from(b))),
            (UInt32(a), UInt16(b)) => UInt32(bin_numeric_div!(@int a, u32::from(b))),
            (UInt32(a), UInt32(b)) => UInt32(bin_numeric_div!(@int a, b)),
            (UInt32(a), UInt64(b)) => UInt64(bin_numeric_div!(@int u64::from(a), b)),
            (UInt32(a), UInt128(b)) => UInt128(bin_numeric_div!(@int u128::from(a), b)),

            (UInt64(a), UInt8(b)) => UInt64(bin_numeric_div!(@int a, u64::from(b))),
            (UInt64(a), UInt16(b)) => UInt64(bin_numeric_div!(@int a, u64::from(b))),
            (UInt64(a), UInt32(b)) => UInt64(bin_numeric_div!(@int a, u64::from(b))),
            (UInt64(a), UInt64(b)) => UInt64(bin_numeric_div!(@int a, b)),
            (UInt64(a), UInt128(b)) => UInt128(bin_numeric_div!(@int u128::from(a), b)),

            (UInt128(a), UInt8(b)) => UInt128(bin_numeric_div!(@int a, u128::from(b))),
            (UInt128(a), UInt16(b)) => UInt128(bin_numeric_div!(@int a, u128::from(b))),
            (UInt128(a), UInt32(b)) => UInt128(bin_numeric_div!(@int a, u128::from(b))),
            (UInt128(a), UInt64(b)) => UInt128(bin_numeric_div!(@int a, u128::from(b))),
            (UInt128(a), UInt128(b)) => UInt128(bin_numeric_div!(@int a, b)),

            // Float division by zero is reported as a missing value.
            (Float32(a), Float32(b)) => {
                if b == 0.0 {
                    Null
                } else {
                    Float32(a / b)
                }
            }
            (Float32(a), Float64(b)) => {
                if b == 0.0 {
                    Null
                } else {
                    Float64(f64::from(a) / b)
                }
            }
            (Float64(a), Float32(b)) => {
                if b == 0.0 {
                    Null
                } else {
                    Float64(a / f64::from(b))
                }
            }
            (Float64(a), Float64(b)) => {
                if b == 0.0 {
                    Null
                } else {
                    Float64(a / b)
                }
            }

            _ => Null,
        }
    }};
}
151
152impl Add for AnyValue<'_> {
153    type Output = Self;
154
155    #[inline(always)]
156    fn add(self, other: Self) -> Self {
157        use AnyValue::*;
158        let is_numeric = self.dtype().is_numeric() && other.dtype().is_numeric();
159        let is_nested = self.is_nested() && other.is_nested();
160
161        if !is_numeric && !is_nested {
162            return self;
163        }
164
165        match (self, other) {
166            (Bool(a), Bool(b)) => Bool(a || b),
167            (Vector(a), Vector(b)) => Vector(a.into_iter().zip(b).map(|(x, y)| x + y).collect()),
168            (Dict(a), Dict(b)) => {
169                if a.len() != b.len() {
170                    return Null;
171                }
172
173                Dict(a
174                    .into_iter()
175                    .zip(b)
176                    .map(|(one, two)| {
177                        if one.0 != two.0 {
178                            return (one.0, one.1, Null);
179                        }
180
181                        (one.0, one.1, one.2 + two.2)
182                    })
183                    .collect())
184            }
185            (lhs, rhs) => bin_numeric_op!(lhs, rhs, +),
186        }
187    }
188}
189
190impl Sub for AnyValue<'_> {
191    type Output = Self;
192
193    #[inline(always)]
194    fn sub(self, other: Self) -> Self {
195        use AnyValue::*;
196
197        let is_numeric = self.dtype().is_numeric() && other.dtype().is_numeric();
198        let is_nested = self.is_nested() && other.is_nested();
199
200        if !is_numeric && !is_nested {
201            return self;
202        }
203
204        match (self, other) {
205            (Bool(a), Bool(b)) => Bool(a ^ b),
206            (Vector(a), Vector(b)) => Vector(a.into_iter().zip(b).map(|(x, y)| x - y).collect()),
207            (Dict(a), Dict(b)) => {
208                if a.len() != b.len() {
209                    return Null;
210                }
211
212                Dict(a
213                    .into_iter()
214                    .zip(b)
215                    .map(|(one, two)| {
216                        if one.0 != two.0 {
217                            return (one.0, one.1, Null);
218                        }
219
220                        (one.0, one.1, one.2 - two.2)
221                    })
222                    .collect())
223            }
224            (lhs, rhs) => bin_numeric_op!(lhs, rhs, -),
225        }
226    }
227}
228
229impl Mul for AnyValue<'_> {
230    type Output = Self;
231
232    #[inline(always)]
233    fn mul(self, other: Self) -> Self {
234        use AnyValue::*;
235
236        let is_numeric = self.dtype().is_numeric() && other.dtype().is_numeric();
237        let is_nested = self.is_nested() && other.is_nested();
238
239        if !is_numeric && !is_nested {
240            return self;
241        }
242
243        match (self, other) {
244            (Bool(a), Bool(b)) => Bool(a && b),
245            (Vector(a), Vector(b)) => Vector(a.into_iter().zip(b).map(|(x, y)| x * y).collect()),
246            (Dict(a), Dict(b)) => {
247                if a.len() != b.len() {
248                    return Null;
249                }
250
251                Dict(a
252                    .into_iter()
253                    .zip(b)
254                    .map(|(one, two)| {
255                        if one.0 != two.0 {
256                            return (one.0, one.1, Null);
257                        }
258
259                        (one.0, one.1, one.2 * two.2)
260                    })
261                    .collect())
262            }
263            (lhs, rhs) => bin_numeric_op!(lhs, rhs, *),
264        }
265    }
266}
267
268impl Div for AnyValue<'_> {
269    type Output = Self;
270
271    #[inline(always)]
272    fn div(self, other: Self) -> Self {
273        use AnyValue::*;
274
275        let is_numeric = self.dtype().is_numeric() && other.dtype().is_numeric();
276        let is_nested = self.is_nested() && other.is_nested();
277
278        if !is_numeric && !is_nested {
279            return self;
280        }
281
282        match (self, other) {
283            (Vector(a), Vector(b)) => Vector(a.into_iter().zip(b).map(|(x, y)| x / y).collect()),
284            (Dict(a), Dict(b)) => {
285                if a.len() != b.len() {
286                    return Null;
287                }
288
289                Dict(a
290                    .into_iter()
291                    .zip(b)
292                    .map(|(one, two)| {
293                        if one.0 != two.0 {
294                            return (one.0, one.1, Null);
295                        }
296
297                        (one.0, one.1, one.2 / two.2)
298                    })
299                    .collect())
300            }
301            (lhs, rhs) => bin_numeric_div!(lhs, rhs),
302        }
303    }
304}
305
306impl Rem for AnyValue<'_> {
307    type Output = Self;
308
309    #[inline(always)]
310    fn rem(self, other: Self) -> Self {
311        use AnyValue::*;
312
313        let is_numeric = self.dtype().is_numeric() && other.dtype().is_numeric();
314
315        if !is_numeric {
316            return self;
317        }
318
319        match (self, other) {
320            (Vector(a), Vector(b)) => Vector(a.into_iter().zip(b).map(|(x, y)| x % y).collect()),
321            (Dict(a), Dict(b)) => {
322                if a.len() != b.len() {
323                    return Null;
324                }
325
326                Dict(a
327                    .into_iter()
328                    .zip(b)
329                    .map(|(one, two)| {
330                        if one.0 != two.0 {
331                            return (one.0, one.1, Null);
332                        }
333
334                        (one.0, one.1, one.2 % two.2)
335                    })
336                    .collect())
337            }
338            (lhs, rhs) => bin_numeric_op!(lhs, rhs, %),
339        }
340    }
341}
342
343impl<'a> BitAnd for AnyValue<'a> {
344    type Output = AnyValue<'static>;
345
346    fn bitand(self, rhs: Self) -> Self::Output {
347        match (self, rhs) {
348            (AnyValue::Bool(a), AnyValue::Bool(b)) => AnyValue::Bool(a & b),
349            _ => AnyValue::Null,
350        }
351    }
352}
353
354impl<'a> BitOr for AnyValue<'a> {
355    type Output = AnyValue<'static>;
356
357    fn bitor(self, rhs: Self) -> Self::Output {
358        match (self, rhs) {
359            (AnyValue::Bool(a), AnyValue::Bool(b)) => AnyValue::Bool(a | b),
360            _ => AnyValue::Null,
361        }
362    }
363}
364
365impl<'a> Not for AnyValue<'a> {
366    type Output = AnyValue<'static>;
367
368    fn not(self) -> Self::Output {
369        match self {
370            AnyValue::Bool(v) => AnyValue::Bool(!v),
371            _ => AnyValue::Null,
372        }
373    }
374}
375
376#[inline]
377pub fn pow_anyvalue(
378    base: &AnyValue<'_>,
379    exp: &AnyValue<'_>,
380) -> Result<AnyValue<'static>, radiate_error::RadiateError> {
381    use AnyValue::*;
382    match (base, exp) {
383        (Int8(a), Int8(b)) => Ok(Int8(a.pow(*b as u32))),
384        (Int16(a), Int8(b)) => Ok(Int16(a.pow(*b as u32))),
385        (Int32(a), Int8(b)) => Ok(Int32(a.pow(*b as u32))),
386        (Int64(a), Int8(b)) => Ok(Int64(a.pow(*b as u32))),
387        (Int128(a), Int8(b)) => Ok(Int128(a.pow(*b as u32))),
388
389        (Int16(a), Int16(b)) => Ok(Int16(a.pow(*b as u32))),
390        (Int32(a), Int16(b)) => Ok(Int32(a.pow(*b as u32))),
391        (Int64(a), Int16(b)) => Ok(Int64(a.pow(*b as u32))),
392        (Int128(a), Int16(b)) => Ok(Int128(a.pow(*b as u32))),
393
394        (Int32(a), Int32(b)) => Ok(Int32(a.pow(*b as u32))),
395        (Int64(a), Int32(b)) => Ok(Int64(a.pow(*b as u32))),
396        (Int128(a), Int32(b)) => Ok(Int128(a.pow(*b as u32))),
397
398        (Int64(a), Int64(b)) => Ok(Int64(a.pow(*b as u32))),
399        (Int128(a), Int64(b)) => Ok(Int128(a.pow(*b as u32))),
400
401        (Int128(a), Int128(b)) => Ok(Int128(a.pow(*b as u32))),
402
403        (UInt8(a), UInt8(b)) => Ok(UInt8(a.pow(*b as u32))),
404        (UInt8(a), UInt16(b)) => Ok(UInt16((u16::from(*a)).pow(u32::from(*b)))),
405        (UInt8(a), UInt32(b)) => Ok(UInt32((u32::from(*a)).pow(*b))),
406
407        (UInt16(a), UInt16(b)) => Ok(UInt16(a.pow(*b as u32))),
408        (UInt16(a), UInt32(b)) => Ok(UInt32((u32::from(*a)).pow(*b))),
409
410        (UInt32(a), UInt32(b)) => Ok(UInt32(a.pow(*b))),
411
412        (UInt64(a), UInt64(b)) => Ok(UInt64(a.pow(*b as u32))),
413
414        (UInt128(a), UInt128(b)) => Ok(UInt128(a.pow(*b as u32))),
415
416        (Float32(a), Float32(b)) => Ok(Float32(a.powf(*b))),
417        (Float32(a), Float64(b)) => Ok(Float64((*a as f64).powf(*b))),
418
419        (Float64(a), Float32(b)) => Ok(Float64(a.powf(*b as f64))),
420        (Float64(a), Float64(b)) => Ok(Float64(a.powf(*b))),
421        _ => {
422            radiate_bail!(Expr: "Exponentiation is only supported for numeric types, got base {:?} and exponent {:?}", base, exp)
423        }
424    }
425}
426
427#[inline]
428#[allow(dead_code)]
429fn mean_anyvalue(one: &AnyValue<'_>, two: &AnyValue<'_>) -> Option<AnyValue<'static>> {
430    use AnyValue::*;
431    if let Some(v) = mean_numeric(one, two) {
432        return Some(v);
433    }
434
435    match (one, two) {
436        (Bool(x), Bool(y)) => Some(Bool(*x && *y)),
437
438        (Vector(xs), Vector(ys)) => super::value::apply_zipped_slice(xs, ys, mean_anyvalue),
439        (Dict(xs), Dict(ys)) => super::value::apply_zipped_struct_slice(xs, ys, mean_anyvalue),
440        _ => None,
441    }
442}
443
444#[inline]
445#[allow(dead_code)]
446fn mean_numeric(a: &AnyValue<'_>, b: &AnyValue<'_>) -> Option<AnyValue<'static>> {
447    use AnyValue::*;
448    let out = match (a, b) {
449        (UInt8(x), UInt8(y)) => UInt8(((u16::from(*x) + u16::from(*y)) / 2) as u8),
450        (UInt16(x), UInt16(y)) => UInt16(((u32::from(*x) + u32::from(*y)) / 2) as u16),
451        (UInt32(x), UInt32(y)) => UInt32(((u64::from(*x) + u64::from(*y)) / 2) as u32),
452        (UInt64(x), UInt64(y)) => UInt64(((u128::from(*x) + u128::from(*y)) / 2) as u64),
453
454        (Int8(x), Int8(y)) => Int8(*x + ((*y as i16 - *x as i16) / 2) as i8),
455        (Int16(x), Int16(y)) => Int16(*x + ((*y as i32 - *x as i32) / 2) as i16),
456        (Int32(x), Int32(y)) => Int32(*x + ((*y as i64 - *x as i64) / 2) as i32),
457        (Int64(x), Int64(y)) => {
458            let dx = (*y as i128) - (*x as i128);
459            Int64(*x + (dx / 2) as i64)
460        }
461        (Int128(x), Int128(y)) => Int128(*x + ((*y - *x) / 2)),
462
463        (Float32(x), Float32(y)) => Float32((*x + *y) / 2.0),
464        (Float64(x), Float64(y)) => Float64((*x + *y) / 2.0),
465
466        _ => return None,
467    };
468
469    Some(out)
470}
471
#[cfg(test)]
mod tests {
    use crate::SmallStr;

    use super::*;
    use AnyValue::*;

    fn make_vec(xs: Vec<AnyValue<'static>>) -> AnyValue<'static> {
        AnyValue::Vector(xs)
    }

    fn make_dict(pairs: Vec<(&'static str, AnyValue<'static>)>) -> AnyValue<'static> {
        let fields = pairs
            .into_iter()
            .map(|(name, val)| (SmallStr::from(name), val.dtype(), val))
            .collect();
        AnyValue::Dict(fields)
    }

    // ---------- Numeric: happy paths (same-type) ----------
    //
    // NOTE(review): the `Bool` assertions in the four tests below all expect the
    // left operand back unchanged (e.g. `true * false == true` and `true / false
    // == true`). That only holds if `Bool` is not numeric according to
    // `dtype().is_numeric()`, so the operators take their early `return self`
    // path. Confirm this is the intended contract.
    #[test]
    fn numeric_add_same_type() {
        assert_eq!(Bool(true) + Bool(false), Bool(true));

        assert_eq!(UInt8(10) + UInt8(5), UInt8(15));
        assert_eq!(UInt16(10) + UInt16(5), UInt16(15));
        assert_eq!(UInt32(10) + UInt32(5), UInt32(15));
        assert_eq!(UInt64(10) + UInt64(5), UInt64(15));

        assert_eq!(Int8(10) + Int8(5), Int8(15));
        assert_eq!(Int16(10) + Int16(5), Int16(15));
        assert_eq!(Int32(10) + Int32(5), Int32(15));
        assert_eq!(Int64(10) + Int64(5), Int64(15));
        assert_eq!(Int128(10) + Int128(5), Int128(15));

        assert_eq!(Float32(1.5) + Float32(2.0), Float32(3.5));
        assert_eq!(Float64(1.5) + Float64(2.0), Float64(3.5));
    }

    #[test]
    fn numeric_sub_same_type() {
        assert_eq!(Bool(true) - Bool(false), Bool(true));

        assert_eq!(UInt8(10) - UInt8(3), UInt8(7));
        assert_eq!(UInt16(10) - UInt16(3), UInt16(7));
        assert_eq!(UInt32(10) - UInt32(3), UInt32(7));
        assert_eq!(UInt64(10) - UInt64(3), UInt64(7));

        assert_eq!(Int8(10) - Int8(4), Int8(6));
        assert_eq!(Int16(10) - Int16(4), Int16(6));
        assert_eq!(Int32(10) - Int32(4), Int32(6));
        assert_eq!(Int64(10) - Int64(4), Int64(6));
        assert_eq!(Int128(10) - Int128(4), Int128(6));

        assert_eq!(Float32(5.0) - Float32(2.5), Float32(2.5));
        assert_eq!(Float64(5.0) - Float64(2.5), Float64(2.5));
    }

    #[test]
    fn numeric_mul_same_type() {
        assert_eq!(Bool(true) * Bool(false), Bool(true));

        assert_eq!(UInt8(7) * UInt8(6), UInt8(42));
        assert_eq!(UInt16(7) * UInt16(6), UInt16(42));
        assert_eq!(UInt32(7) * UInt32(6), UInt32(42));
        assert_eq!(UInt64(7) * UInt64(6), UInt64(42));

        assert_eq!(Int8(7) * Int8(6), Int8(42));
        assert_eq!(Int16(7) * Int16(6), Int16(42));
        assert_eq!(Int32(7) * Int32(6), Int32(42));
        assert_eq!(Int64(7) * Int64(6), Int64(42));
        assert_eq!(Int128(7) * Int128(6), Int128(42));

        assert_eq!(Float32(1.5) * Float32(2.0), Float32(3.0));
        assert_eq!(Float64(1.5) * Float64(2.0), Float64(3.0));
    }

    #[test]
    fn numeric_div_same_type() {
        assert_eq!(Bool(true) / Bool(false), Bool(true));

        assert_eq!(UInt8(42) / UInt8(6), UInt8(7));
        assert_eq!(UInt16(42) / UInt16(6), UInt16(7));
        assert_eq!(UInt32(42) / UInt32(6), UInt32(7));
        assert_eq!(UInt64(42) / UInt64(6), UInt64(7));

        assert_eq!(Int8(42) / Int8(6), Int8(7));
        assert_eq!(Int16(42) / Int16(6), Int16(7));
        assert_eq!(Int32(42) / Int32(6), Int32(7));
        assert_eq!(Int64(42) / Int64(6), Int64(7));
        assert_eq!(Int128(42) / Int128(6), Int128(7));

        assert_eq!(Float32(7.5) / Float32(2.5), Float32(3.0));
        assert_eq!(Float64(7.5) / Float64(2.5), Float64(3.0));
    }

    #[test]
    fn int_div_by_zero_returns_dividend() {
        // Integer division by zero does not panic and does not produce `Null`:
        // the dividend comes back unchanged.
        assert_eq!(Int32(5) / Int32(0), Int32(5));
        assert_eq!(UInt64(7) / UInt64(0), UInt64(7));
    }

    #[test]
    fn int_rem_same_type() {
        assert_eq!(Int32(7) % Int32(3), Int32(1));
        assert_eq!(UInt8(10) % UInt8(4), UInt8(2));
    }

    // ---------- Vector elementwise ----------
    #[test]
    fn vector_elementwise_add_ok() {
        let a = make_vec(vec![Int32(1), Int32(2), Int32(3)]);
        let b = make_vec(vec![Int32(4), Int32(5), Int32(6)]);
        let out = make_vec(vec![Int32(5), Int32(7), Int32(9)]);
        assert_eq!(a + b, out);
    }

    #[test]
    fn vector_length_mismatch_truncates_to_shorter() {
        // Vectors are zipped, so the result is as long as the shorter operand.
        let a = make_vec(vec![Int32(1), Int32(2)]);
        let b = make_vec(vec![Int32(3)]);
        assert_eq!(a + b, Vector(vec![Int32(4)]));
    }

    // ---------- Struct fieldwise ----------
    #[test]
    fn struct_same_shape_by_order() {
        // Fields are paired by position after a length check; a name mismatch at a
        // position turns that slot's value into Null (keeping the left field).
        let a = make_dict(vec![("x", Int32(1)), ("y", Int32(2))]);
        let b = make_dict(vec![("x", Int32(3)), ("y", Int32(4))]);
        let out = make_dict(vec![("x", Int32(4)), ("y", Int32(6))]);
        assert_eq!(a + b, out);
    }

    #[test]
    fn struct_length_mismatch_yields_null() {
        let a = make_dict(vec![("x", Int32(1))]);
        let b = make_dict(vec![("x", Int32(2)), ("y", Int32(3))]);
        assert_eq!(a + b, Null);
    }

    #[test]
    fn struct_field_name_mismatch_sets_field_null_under_current_rules() {
        let a = make_dict(vec![("x", Int32(1)), ("y", Int32(2))]);
        let b = make_dict(vec![("x", Int32(3)), ("z", Int32(9))]);
        // When names differ at a position, that *slot* becomes Null; the rest proceed.
        let expected = make_dict(vec![("x", Int32(4)), ("y", Null)]);
        assert_eq!(a + b, expected);
    }

    #[test]
    fn struct_fields_pair_by_position_not_name() {
        // Fields are not looked up by name: the same names in a different order
        // mismatch at every position, so every slot becomes Null.
        let a = make_dict(vec![("x", Int32(1)), ("y", Int32(2))]);
        let b = make_dict(vec![("y", Int32(4)), ("x", Int32(3))]);
        let out = make_dict(vec![("x", Null), ("y", Null)]);
        assert_eq!(a + b, out);
    }

    // ---------- Non-numeric operands ----------
    #[test]
    fn non_numeric_operand_returns_lhs() {
        // If either side is non-numeric (and the pair is not nested), the
        // operators hand back the left operand unchanged rather than propagating
        // Null.
        assert_eq!(Null + Int32(5), Null);
        assert_eq!(Float64(2.0) * Null, Float64(2.0));
        assert_eq!(Null / Null, Null);
    }

    // ---------- Logical operators ----------
    #[test]
    fn bool_logical_ops() {
        assert_eq!(Bool(true) & Bool(false), Bool(false));
        assert_eq!(Bool(true) | Bool(false), Bool(true));
        assert_eq!(!Bool(true), Bool(false));
        assert_eq!(Bool(true) & Int32(1), Null);
        assert_eq!(!Int32(1), Null);
    }

    // ---------- Exponentiation ----------
    #[test]
    fn pow_basic() {
        assert_eq!(pow_anyvalue(&Int32(2), &Int32(10)).ok(), Some(Int32(1024)));
        assert_eq!(pow_anyvalue(&UInt8(3), &UInt16(4)).ok(), Some(UInt16(81)));
        assert_eq!(pow_anyvalue(&Float64(2.0), &Float64(3.0)).ok(), Some(Float64(8.0)));
        assert!(pow_anyvalue(&Bool(true), &Int32(1)).is_err());
    }

    // ---------- Mean ----------
    #[test]
    fn mean_numeric_pairs() {
        assert_eq!(mean_anyvalue(&Int32(2), &Int32(4)), Some(Int32(3)));
        assert_eq!(mean_anyvalue(&UInt8(10), &UInt8(20)), Some(UInt8(15)));
        assert_eq!(
            mean_anyvalue(&Float64(1.0), &Float64(3.0)),
            Some(Float64(2.0))
        );
    }

    #[test]
    fn mean_bool_is_and() {
        assert_eq!(mean_anyvalue(&Bool(true), &Bool(false)), Some(Bool(false)));
        assert_eq!(mean_anyvalue(&Bool(true), &Bool(true)), Some(Bool(true)));
    }

    // ---------- Algebraic sanity checks ----------
    #[test]
    fn add_commutative_for_numeric() {
        assert_eq!(Int64(7) + Int64(5), Int64(5) + Int64(7));
    }

    #[test]
    fn mul_commutative_for_numeric() {
        assert_eq!(Int16(3) * Int16(9), Int16(9) * Int16(3));
    }

    #[test]
    fn sub_non_commutative_for_numeric() {
        assert_ne!(Int32(10) - Int32(4), Int32(4) - Int32(10));
    }
}