jax_rs/ops/
binary.rs

1//! Binary operations on arrays.
2
3use crate::trace::{is_tracing, trace_binary, Primitive};
4use crate::{buffer::Buffer, Array, DType, Device, Shape};
5
6/// Apply a binary function element-wise to two arrays with broadcasting.
7fn binary_op<F>(lhs: &Array, rhs: &Array, op: Primitive, f: F) -> Array
8where
9    F: Fn(f32, f32) -> f32,
10{
11    assert_eq!(lhs.dtype(), DType::Float32, "Only Float32 supported");
12    assert_eq!(rhs.dtype(), DType::Float32, "Only Float32 supported");
13
14    // Check if shapes are broadcast-compatible
15    let result_shape = lhs
16        .shape()
17        .broadcast_with(rhs.shape())
18        .expect("Shapes are not broadcast-compatible");
19
20    // Dispatch based on device
21    let result = match (lhs.device(), rhs.device()) {
22        (Device::WebGpu, Device::WebGpu) => {
23            // GPU path - no broadcasting support yet, shapes must match exactly
24            assert_eq!(
25                lhs.shape(),
26                rhs.shape(),
27                "GPU operations do not support broadcasting yet"
28            );
29
30            // Map primitive to WGSL operator
31            let op_str = match &op {
32                Primitive::Add => "+",
33                Primitive::Sub => "-",
34                Primitive::Mul => "*",
35                Primitive::Div => "/",
36                _ => {
37                    // Fallback to CPU for unsupported ops
38                    return binary_op_cpu(lhs, rhs, op.clone(), f);
39                }
40            };
41
42            // Create output buffer on GPU
43            let output_buffer = Buffer::zeros(
44                result_shape.size(),
45                DType::Float32,
46                Device::WebGpu,
47            );
48
49            // Execute on GPU
50            crate::backend::ops::gpu_binary_op(
51                lhs.buffer(),
52                rhs.buffer(),
53                &output_buffer,
54                op_str,
55            );
56
57            Array::from_buffer(output_buffer, result_shape)
58        }
59        (Device::Cpu, Device::Cpu) | (Device::Wasm, Device::Wasm) => {
60            // CPU path with broadcasting support
61            binary_op_cpu(lhs, rhs, op.clone(), f)
62        }
63        _ => {
64            panic!("Mixed device operations not supported. Both arrays must be on the same device.");
65        }
66    };
67
68    // Register with trace context if tracing is active
69    if is_tracing() {
70        trace_binary(result.id(), op, lhs, rhs);
71    }
72
73    result
74}
75
76/// CPU implementation of binary operation with broadcasting support.
77fn binary_op_cpu<F>(lhs: &Array, rhs: &Array, _op: Primitive, f: F) -> Array
78where
79    F: Fn(f32, f32) -> f32,
80{
81    let result_shape = lhs
82        .shape()
83        .broadcast_with(rhs.shape())
84        .expect("Shapes are not broadcast-compatible");
85
86    let lhs_data = lhs.to_vec();
87    let rhs_data = rhs.to_vec();
88
89    let result_data = if lhs.shape() == rhs.shape() {
90        // Same shape - simple element-wise operation
91        lhs_data.iter().zip(rhs_data.iter()).map(|(&a, &b)| f(a, b)).collect()
92    } else {
93        // Need broadcasting
94        broadcast_binary(
95            &lhs_data,
96            lhs.shape(),
97            &rhs_data,
98            rhs.shape(),
99            &result_shape,
100            f,
101        )
102    };
103
104    let buffer = Buffer::from_f32(result_data, Device::Cpu);
105    Array::from_buffer(buffer, result_shape)
106}
107
108/// Helper function to perform binary operation with broadcasting.
109fn broadcast_binary<F>(
110    lhs_data: &[f32],
111    lhs_shape: &Shape,
112    rhs_data: &[f32],
113    rhs_shape: &Shape,
114    result_shape: &Shape,
115    f: F,
116) -> Vec<f32>
117where
118    F: Fn(f32, f32) -> f32,
119{
120    let size = result_shape.size();
121    let mut result = Vec::with_capacity(size);
122
123    for i in 0..size {
124        let lhs_idx = broadcast_index(i, result_shape, lhs_shape);
125        let rhs_idx = broadcast_index(i, result_shape, rhs_shape);
126        result.push(f(lhs_data[lhs_idx], rhs_data[rhs_idx]));
127    }
128
129    result
130}
131
132/// Convert a flat index in the result array to an index in the source array,
133/// accounting for broadcasting.
134pub(crate) fn broadcast_index(
135    flat_idx: usize,
136    result_shape: &Shape,
137    src_shape: &Shape,
138) -> usize {
139    let result_dims = result_shape.as_slice();
140    let src_dims = src_shape.as_slice();
141
142    // Convert flat index to multi-dimensional index
143    let mut multi_idx = Vec::with_capacity(result_dims.len());
144    let mut idx = flat_idx;
145    for &dim in result_dims.iter().rev() {
146        multi_idx.push(idx % dim);
147        idx /= dim;
148    }
149    multi_idx.reverse();
150
151    // Map to source index with broadcasting
152    let offset = result_dims.len() - src_dims.len();
153    let mut src_idx = 0;
154    let mut stride = 1;
155
156    for i in (0..src_dims.len()).rev() {
157        let result_i = offset + i;
158        let dim_idx = if src_dims[i] == 1 {
159            0 // Broadcast dimension
160        } else {
161            multi_idx[result_i]
162        };
163        src_idx += dim_idx * stride;
164        stride *= src_dims[i];
165    }
166
167    src_idx
168}
169
170impl Array {
171    /// Add two arrays element-wise with broadcasting.
172    ///
173    /// # Examples
174    ///
175    /// ```
176    /// # use jax_rs::{Array, Shape};
177    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
178    /// let b = Array::from_vec(vec![10.0, 20.0, 30.0], Shape::new(vec![3]));
179    /// let c = a.add(&b);
180    /// assert_eq!(c.to_vec(), vec![11.0, 22.0, 33.0]);
181    /// ```
182    pub fn add(&self, other: &Array) -> Array {
183        binary_op(self, other, Primitive::Add, |a, b| a + b)
184    }
185
186    /// Subtract two arrays element-wise with broadcasting.
187    pub fn sub(&self, other: &Array) -> Array {
188        binary_op(self, other, Primitive::Sub, |a, b| a - b)
189    }
190
191    /// Multiply two arrays element-wise with broadcasting.
192    pub fn mul(&self, other: &Array) -> Array {
193        binary_op(self, other, Primitive::Mul, |a, b| a * b)
194    }
195
196    /// Divide two arrays element-wise with broadcasting.
197    pub fn div(&self, other: &Array) -> Array {
198        binary_op(self, other, Primitive::Div, |a, b| a / b)
199    }
200
201    /// Raise elements to a power element-wise with broadcasting.
202    pub fn pow(&self, other: &Array) -> Array {
203        binary_op(self, other, Primitive::Pow, |a, b| a.powf(b))
204    }
205
206    /// Element-wise minimum.
207    pub fn minimum(&self, other: &Array) -> Array {
208        binary_op(self, other, Primitive::Min, |a, b| a.min(b))
209    }
210
211    /// Element-wise maximum.
212    pub fn maximum(&self, other: &Array) -> Array {
213        binary_op(self, other, Primitive::Max, |a, b| a.max(b))
214    }
215
216    /// Safe division that returns 0 where division by zero would occur.
217    ///
218    /// Returns x / y where y != 0, and 0 where y == 0.
219    ///
220    /// # Examples
221    ///
222    /// ```
223    /// # use jax_rs::{Array, Shape};
224    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
225    /// let b = Array::from_vec(vec![2.0, 0.0, 3.0], Shape::new(vec![3]));
226    /// let c = a.divide_no_nan(&b);
227    /// assert_eq!(c.to_vec(), vec![0.5, 0.0, 1.0]);
228    /// ```
229    pub fn divide_no_nan(&self, other: &Array) -> Array {
230        binary_op(self, other, Primitive::Div, |a, b| {
231            if b == 0.0 {
232                0.0
233            } else {
234                a / b
235            }
236        })
237    }
238
239    /// Squared difference: (a - b)^2.
240    ///
241    /// Useful for computing mean squared error and similar metrics.
242    ///
243    /// # Examples
244    ///
245    /// ```
246    /// # use jax_rs::{Array, Shape};
247    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
248    /// let b = Array::from_vec(vec![2.0, 2.0, 1.0], Shape::new(vec![3]));
249    /// let c = a.squared_difference(&b);
250    /// assert_eq!(c.to_vec(), vec![1.0, 0.0, 4.0]);
251    /// ```
252    pub fn squared_difference(&self, other: &Array) -> Array {
253        binary_op(self, other, Primitive::Sub, |a, b| {
254            let diff = a - b;
255            diff * diff
256        })
257    }
258
259    /// Element-wise modulo operation.
260    ///
261    /// # Examples
262    ///
263    /// ```
264    /// # use jax_rs::{Array, Shape};
265    /// let a = Array::from_vec(vec![5.0, 7.0, 9.0], Shape::new(vec![3]));
266    /// let b = Array::from_vec(vec![3.0, 3.0, 3.0], Shape::new(vec![3]));
267    /// let c = a.mod_op(&b);
268    /// assert_eq!(c.to_vec(), vec![2.0, 1.0, 0.0]);
269    /// ```
270    pub fn mod_op(&self, other: &Array) -> Array {
271        binary_op(self, other, Primitive::Div, |a, b| a % b)
272    }
273
274    /// Element-wise arctangent of a/b.
275    ///
276    /// Correctly handles signs to determine quadrant.
277    ///
278    /// # Examples
279    ///
280    /// ```
281    /// # use jax_rs::{Array, Shape};
282    /// let y = Array::from_vec(vec![1.0, -1.0], Shape::new(vec![2]));
283    /// let x = Array::from_vec(vec![1.0, 1.0], Shape::new(vec![2]));
284    /// let angle = y.atan2(&x);
285    /// # // We just check it compiles and runs
286    /// ```
287    pub fn atan2(&self, other: &Array) -> Array {
288        binary_op(self, other, Primitive::Div, |a, b| a.atan2(b))
289    }
290
291    /// Element-wise hypot: sqrt(a^2 + b^2).
292    ///
293    /// Computes the hypotenuse in a numerically stable way.
294    ///
295    /// # Examples
296    ///
297    /// ```
298    /// # use jax_rs::{Array, Shape};
299    /// let a = Array::from_vec(vec![3.0, 4.0], Shape::new(vec![2]));
300    /// let b = Array::from_vec(vec![4.0, 3.0], Shape::new(vec![2]));
301    /// let c = a.hypot(&b);
302    /// assert_eq!(c.to_vec(), vec![5.0, 5.0]);
303    /// ```
304    pub fn hypot(&self, other: &Array) -> Array {
305        binary_op(self, other, Primitive::Add, |a, b| a.hypot(b))
306    }
307
308    /// Element-wise copysign: magnitude of a with sign of b.
309    ///
310    /// # Examples
311    ///
312    /// ```
313    /// # use jax_rs::{Array, Shape};
314    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
315    /// let b = Array::from_vec(vec![-1.0, 1.0, -1.0], Shape::new(vec![3]));
316    /// let c = a.copysign(&b);
317    /// assert_eq!(c.to_vec(), vec![-1.0, 2.0, -3.0]);
318    /// ```
319    pub fn copysign(&self, other: &Array) -> Array {
320        binary_op(self, other, Primitive::Mul, |a, b| a.copysign(b))
321    }
322
323    /// Element-wise next representable float in direction of b.
324    ///
325    /// # Examples
326    ///
327    /// ```
328    /// # use jax_rs::{Array, Shape};
329    /// let a = Array::from_vec(vec![1.0, 2.0], Shape::new(vec![2]));
330    /// let b = Array::from_vec(vec![2.0, 1.0], Shape::new(vec![2]));
331    /// let c = a.next_after(&b);
332    /// # // Just verify it compiles
333    /// ```
334    pub fn next_after(&self, other: &Array) -> Array {
335        binary_op(self, other, Primitive::Add, |a, b| {
336            if a < b {
337                // Next float towards positive infinity
338                f32::from_bits(a.to_bits() + 1)
339            } else if a > b {
340                // Next float towards negative infinity
341                f32::from_bits(a.to_bits() - 1)
342            } else {
343                b
344            }
345        })
346    }
347
348    /// Logarithm of sum of exponentials (numerically stable).
349    ///
350    /// Computes log(exp(x) + exp(y)) in a numerically stable way.
351    ///
352    /// # Examples
353    ///
354    /// ```
355    /// # use jax_rs::{Array, Shape};
356    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
357    /// let b = Array::from_vec(vec![2.0, 3.0, 4.0], Shape::new(vec![3]));
358    /// let c = a.logaddexp(&b);
359    /// // Result: log(exp(1)+exp(2)), log(exp(2)+exp(3)), log(exp(3)+exp(4))
360    /// ```
361    pub fn logaddexp(&self, other: &Array) -> Array {
362        binary_op(self, other, Primitive::Add, |a, b| {
363            let max = a.max(b);
364            max + ((a - max).exp() + (b - max).exp()).ln()
365        })
366    }
367
368    /// Base-2 logarithm of sum of exponentials.
369    ///
370    /// Computes log2(2^x + 2^y) in a numerically stable way.
371    ///
372    /// # Examples
373    ///
374    /// ```
375    /// # use jax_rs::{Array, Shape};
376    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
377    /// let b = Array::from_vec(vec![2.0, 3.0, 4.0], Shape::new(vec![3]));
378    /// let c = a.logaddexp2(&b);
379    /// ```
380    pub fn logaddexp2(&self, other: &Array) -> Array {
381        binary_op(self, other, Primitive::Add, |a, b| {
382            let max = a.max(b);
383            max + ((a - max).exp2() + (b - max).exp2()).log2()
384        })
385    }
386
387    /// Heaviside step function.
388    ///
389    /// Returns 0 where x < 0, h0 where x == 0, and 1 where x > 0.
390    ///
391    /// # Examples
392    ///
393    /// ```
394    /// # use jax_rs::{Array, Shape};
395    /// let x = Array::from_vec(vec![-1.0, 0.0, 1.0], Shape::new(vec![3]));
396    /// let h0 = Array::from_vec(vec![0.5, 0.5, 0.5], Shape::new(vec![3]));
397    /// let h = x.heaviside(&h0);
398    /// assert_eq!(h.to_vec(), vec![0.0, 0.5, 1.0]);
399    /// ```
400    pub fn heaviside(&self, h0: &Array) -> Array {
401        binary_op(self, h0, Primitive::Max, |x, h0_val| {
402            if x < 0.0 {
403                0.0
404            } else if x == 0.0 {
405                h0_val
406            } else {
407                1.0
408            }
409        })
410    }
411
412    /// Floor division (division rounding toward negative infinity).
413    ///
414    /// # Examples
415    ///
416    /// ```
417    /// # use jax_rs::{Array, Shape};
418    /// let a = Array::from_vec(vec![7.0, 7.0, -7.0], Shape::new(vec![3]));
419    /// let b = Array::from_vec(vec![3.0, -3.0, 3.0], Shape::new(vec![3]));
420    /// let c = a.floor_divide(&b);
421    /// assert_eq!(c.to_vec(), vec![2.0, -3.0, -3.0]);
422    /// ```
423    pub fn floor_divide(&self, other: &Array) -> Array {
424        binary_op(self, other, Primitive::Div, |a, b| (a / b).floor())
425    }
426
427    /// Fused multiply-add: a * b + c.
428    ///
429    /// # Examples
430    ///
431    /// ```
432    /// # use jax_rs::{Array, Shape};
433    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
434    /// let b = Array::from_vec(vec![2.0, 3.0, 4.0], Shape::new(vec![3]));
435    /// let c = Array::from_vec(vec![1.0, 1.0, 1.0], Shape::new(vec![3]));
436    /// let result = a.fma(&b, &c);
437    /// assert_eq!(result.to_vec(), vec![3.0, 7.0, 13.0]); // [1*2+1, 2*3+1, 3*4+1]
438    /// ```
439    pub fn fma(&self, b: &Array, c: &Array) -> Array {
440        let product = self.mul(b);
441        product.add(c)
442    }
443
444    /// Greatest common divisor element-wise.
445    ///
446    /// # Examples
447    ///
448    /// ```
449    /// # use jax_rs::{Array, Shape};
450    /// let a = Array::from_vec(vec![12.0, 15.0, 24.0], Shape::new(vec![3]));
451    /// let b = Array::from_vec(vec![8.0, 10.0, 18.0], Shape::new(vec![3]));
452    /// let c = a.gcd(&b);
453    /// assert_eq!(c.to_vec(), vec![4.0, 5.0, 6.0]);
454    /// ```
455    pub fn gcd(&self, other: &Array) -> Array {
456        binary_op(self, other, Primitive::Min, |mut a, mut b| {
457            a = a.abs();
458            b = b.abs();
459            while b > 0.5 {
460                let temp = b;
461                b = a % b;
462                a = temp;
463            }
464            a
465        })
466    }
467
468    /// Least common multiple element-wise.
469    ///
470    /// # Examples
471    ///
472    /// ```
473    /// # use jax_rs::{Array, Shape};
474    /// let a = Array::from_vec(vec![12.0, 15.0, 24.0], Shape::new(vec![3]));
475    /// let b = Array::from_vec(vec![8.0, 10.0, 18.0], Shape::new(vec![3]));
476    /// let c = a.lcm(&b);
477    /// assert_eq!(c.to_vec(), vec![24.0, 30.0, 72.0]);
478    /// ```
479    pub fn lcm(&self, other: &Array) -> Array {
480        binary_op(self, other, Primitive::Mul, |mut a, mut b| {
481            a = a.abs();
482            b = b.abs();
483            if a < 0.5 || b < 0.5 {
484                return 0.0;
485            }
486            let mut gcd_val = a;
487            let mut temp = b;
488            while temp > 0.5 {
489                let r = gcd_val % temp;
490                gcd_val = temp;
491                temp = r;
492            }
493            (a * b) / gcd_val
494        })
495    }
496
497    /// Bitwise AND operation.
498    /// Operates on the bit representation of Float32 values.
499    ///
500    /// # Examples
501    ///
502    /// ```
503    /// # use jax_rs::{Array, Shape};
504    /// let a = Array::from_vec(vec![15.0, 31.0, 63.0], Shape::new(vec![3]));
505    /// let b = Array::from_vec(vec![7.0, 15.0, 31.0], Shape::new(vec![3]));
506    /// let c = a.bitwise_and(&b);
507    /// ```
508    pub fn bitwise_and(&self, other: &Array) -> Array {
509        binary_op(self, other, Primitive::Min, |a, b| {
510            let a_bits = a.to_bits();
511            let b_bits = b.to_bits();
512            f32::from_bits(a_bits & b_bits)
513        })
514    }
515
516    /// Bitwise OR operation.
517    /// Operates on the bit representation of Float32 values.
518    ///
519    /// # Examples
520    ///
521    /// ```
522    /// # use jax_rs::{Array, Shape};
523    /// let a = Array::from_vec(vec![8.0, 16.0, 32.0], Shape::new(vec![3]));
524    /// let b = Array::from_vec(vec![4.0, 8.0, 16.0], Shape::new(vec![3]));
525    /// let c = a.bitwise_or(&b);
526    /// ```
527    pub fn bitwise_or(&self, other: &Array) -> Array {
528        binary_op(self, other, Primitive::Max, |a, b| {
529            let a_bits = a.to_bits();
530            let b_bits = b.to_bits();
531            f32::from_bits(a_bits | b_bits)
532        })
533    }
534
535    /// Bitwise XOR operation.
536    /// Operates on the bit representation of Float32 values.
537    ///
538    /// # Examples
539    ///
540    /// ```
541    /// # use jax_rs::{Array, Shape};
542    /// let a = Array::from_vec(vec![12.0, 15.0, 18.0], Shape::new(vec![3]));
543    /// let b = Array::from_vec(vec![10.0, 5.0, 20.0], Shape::new(vec![3]));
544    /// let c = a.bitwise_xor(&b);
545    /// ```
546    pub fn bitwise_xor(&self, other: &Array) -> Array {
547        binary_op(self, other, Primitive::Add, |a, b| {
548            let a_bits = a.to_bits();
549            let b_bits = b.to_bits();
550            f32::from_bits(a_bits ^ b_bits)
551        })
552    }
553
554    /// Left bit shift operation.
555    /// Shifts the bit representation of Float32 values left.
556    ///
557    /// # Examples
558    ///
559    /// ```
560    /// # use jax_rs::{Array, Shape};
561    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
562    /// let b = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
563    /// let c = a.left_shift(&b);
564    /// ```
565    pub fn left_shift(&self, other: &Array) -> Array {
566        binary_op(self, other, Primitive::Mul, |a, b| {
567            let a_bits = a.to_bits();
568            let shift = b as u32;
569            f32::from_bits(a_bits << shift)
570        })
571    }
572
573    /// Right bit shift operation.
574    /// Shifts the bit representation of Float32 values right.
575    ///
576    /// # Examples
577    ///
578    /// ```
579    /// # use jax_rs::{Array, Shape};
580    /// let a = Array::from_vec(vec![4.0, 8.0, 16.0], Shape::new(vec![3]));
581    /// let b = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
582    /// let c = a.right_shift(&b);
583    /// ```
584    pub fn right_shift(&self, other: &Array) -> Array {
585        binary_op(self, other, Primitive::Div, |a, b| {
586            let a_bits = a.to_bits();
587            let shift = b as u32;
588            f32::from_bits(a_bits >> shift)
589        })
590    }
591
592    /// Element-wise maximum, ignoring NaNs.
593    ///
594    /// # Examples
595    ///
596    /// ```
597    /// # use jax_rs::{Array, Shape};
598    /// let a = Array::from_vec(vec![1.0, f32::NAN, 3.0], Shape::new(vec![3]));
599    /// let b = Array::from_vec(vec![2.0, 2.0, 2.0], Shape::new(vec![3]));
600    /// let c = a.fmax(&b);
601    /// assert_eq!(c.to_vec()[0], 2.0);
602    /// assert_eq!(c.to_vec()[1], 2.0);
603    /// assert_eq!(c.to_vec()[2], 3.0);
604    /// ```
605    pub fn fmax(&self, other: &Array) -> Array {
606        binary_op(self, other, Primitive::Max, |a, b| {
607            if a.is_nan() { b }
608            else if b.is_nan() { a }
609            else { a.max(b) }
610        })
611    }
612
613    /// Element-wise minimum, ignoring NaNs.
614    ///
615    /// # Examples
616    ///
617    /// ```
618    /// # use jax_rs::{Array, Shape};
619    /// let a = Array::from_vec(vec![1.0, f32::NAN, 3.0], Shape::new(vec![3]));
620    /// let b = Array::from_vec(vec![2.0, 2.0, 2.0], Shape::new(vec![3]));
621    /// let c = a.fmin(&b);
622    /// assert_eq!(c.to_vec()[0], 1.0);
623    /// assert_eq!(c.to_vec()[1], 2.0);
624    /// assert_eq!(c.to_vec()[2], 2.0);
625    /// ```
626    pub fn fmin(&self, other: &Array) -> Array {
627        binary_op(self, other, Primitive::Min, |a, b| {
628            if a.is_nan() { b }
629            else if b.is_nan() { a }
630            else { a.min(b) }
631        })
632    }
633
634    /// Element-wise arc tangent of x1/x2 choosing the quadrant correctly.
635    ///
636    /// The quadrant (i.e., branch) is chosen so that arctan2(x1, x2) is
637    /// the signed angle in radians between the ray ending at the origin
638    /// and passing through the point (1,0), and the ray ending at the
639    /// origin and passing through the point (x2, x1).
640    ///
641    /// # Examples
642    ///
643    /// ```
644    /// # use jax_rs::{Array, Shape};
645    /// let y = Array::from_vec(vec![1.0, -1.0, 1.0, -1.0], Shape::new(vec![4]));
646    /// let x = Array::from_vec(vec![1.0, 1.0, -1.0, -1.0], Shape::new(vec![4]));
647    /// let angles = y.arctan2(&x);
648    /// // First quadrant: pi/4, Second: -pi/4, Third: 3pi/4, Fourth: -3pi/4
649    /// ```
650    pub fn arctan2(&self, other: &Array) -> Array {
651        binary_op(self, other, Primitive::Div, |y, x| y.atan2(x))
652    }
653
654    /// Element-wise remainder of division (fmod).
655    ///
656    /// # Examples
657    ///
658    /// ```
659    /// # use jax_rs::{Array, Shape};
660    /// let a = Array::from_vec(vec![5.0, 7.0, 10.0], Shape::new(vec![3]));
661    /// let b = Array::from_vec(vec![2.0, 3.0, 4.0], Shape::new(vec![3]));
662    /// let c = a.fmod(&b);
663    /// assert_eq!(c.to_vec(), vec![1.0, 1.0, 2.0]);
664    /// ```
665    pub fn fmod(&self, other: &Array) -> Array {
666        binary_op(self, other, Primitive::Div, |a, b| a % b)
667    }
668
669    /// Return the next floating-point value after x1 towards x2.
670    ///
671    /// # Examples
672    ///
673    /// ```
674    /// # use jax_rs::{Array, Shape};
675    /// let a = Array::from_vec(vec![1.0, 2.0], Shape::new(vec![2]));
676    /// let b = Array::from_vec(vec![2.0, 1.0], Shape::new(vec![2]));
677    /// let c = a.nextafter(&b);
678    /// // First element goes up slightly, second goes down
679    /// ```
680    pub fn nextafter(&self, other: &Array) -> Array {
681        binary_op(self, other, Primitive::Add, |x1, x2| {
682            if x1 == x2 {
683                x2
684            } else if x2 > x1 {
685                // Next float toward positive infinity
686                let bits = x1.to_bits();
687                if x1 >= 0.0 {
688                    f32::from_bits(bits + 1)
689                } else {
690                    f32::from_bits(bits - 1)
691                }
692            } else {
693                // Next float toward negative infinity
694                let bits = x1.to_bits();
695                if x1 > 0.0 {
696                    f32::from_bits(bits - 1)
697                } else if x1 == 0.0 {
698                    f32::from_bits(1 | (1 << 31)) // Negative zero direction
699                } else {
700                    f32::from_bits(bits + 1)
701                }
702            }
703        })
704    }
705
706    /// Compute the safe element-wise division, returning 0 where denominator is 0.
707    ///
708    /// # Examples
709    ///
710    /// ```
711    /// # use jax_rs::{Array, Shape};
712    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
713    /// let b = Array::from_vec(vec![1.0, 0.0, 3.0], Shape::new(vec![3]));
714    /// let c = a.safe_divide(&b);
715    /// assert_eq!(c.to_vec(), vec![1.0, 0.0, 1.0]);
716    /// ```
717    pub fn safe_divide(&self, other: &Array) -> Array {
718        binary_op(self, other, Primitive::Div, |a, b| {
719            if b == 0.0 { 0.0 } else { a / b }
720        })
721    }
722
723    /// Compute element-wise true division.
724    ///
725    /// # Examples
726    ///
727    /// ```
728    /// # use jax_rs::{Array, Shape};
729    /// let a = Array::from_vec(vec![5.0, 7.0, 9.0], Shape::new(vec![3]));
730    /// let b = Array::from_vec(vec![2.0, 2.0, 2.0], Shape::new(vec![3]));
731    /// let c = a.true_divide(&b);
732    /// assert_eq!(c.to_vec(), vec![2.5, 3.5, 4.5]);
733    /// ```
734    pub fn true_divide(&self, other: &Array) -> Array {
735        self.div(other)
736    }
737
738    /// Compute element-wise remainder, with the same sign as divisor.
739    ///
740    /// # Examples
741    ///
742    /// ```
743    /// # use jax_rs::{Array, Shape};
744    /// let a = Array::from_vec(vec![7.0, -7.0, 7.0], Shape::new(vec![3]));
745    /// let b = Array::from_vec(vec![3.0, 3.0, -3.0], Shape::new(vec![3]));
746    /// let c = a.remainder(&b);
747    /// // Python-style modulo: result has same sign as divisor
748    /// ```
749    pub fn remainder(&self, other: &Array) -> Array {
750        binary_op(self, other, Primitive::Div, |a, b| {
751            let r = a % b;
752            if (r > 0.0 && b < 0.0) || (r < 0.0 && b > 0.0) {
753                r + b
754            } else {
755                r
756            }
757        })
758    }
759
760    /// Compute element-wise difference raised to a power.
761    ///
762    /// # Examples
763    ///
764    /// ```
765    /// # use jax_rs::{Array, Shape};
766    /// let a = Array::from_vec(vec![3.0, 5.0, 7.0], Shape::new(vec![3]));
767    /// let b = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
768    /// let c = a.diff_pow(&b, 2.0);  // (a - b)^2
769    /// assert_eq!(c.to_vec(), vec![4.0, 9.0, 16.0]);
770    /// ```
771    pub fn diff_pow(&self, other: &Array, power: f32) -> Array {
772        binary_op(self, other, Primitive::Sub, move |a, b| (a - b).powf(power))
773    }
774
775    /// Compute element-wise squared difference.
776    ///
777    /// # Examples
778    ///
779    /// ```
780    /// # use jax_rs::{Array, Shape};
781    /// let a = Array::from_vec(vec![3.0, 5.0, 7.0], Shape::new(vec![3]));
782    /// let b = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
783    /// let c = a.squared_diff(&b);  // (a - b)^2
784    /// assert_eq!(c.to_vec(), vec![4.0, 9.0, 16.0]);
785    /// ```
786    pub fn squared_diff(&self, other: &Array) -> Array {
787        binary_op(self, other, Primitive::Sub, |a, b| {
788            let d = a - b;
789            d * d
790        })
791    }
792
793    /// Compute element-wise average of two arrays.
794    ///
795    /// # Examples
796    ///
797    /// ```
798    /// # use jax_rs::{Array, Shape};
799    /// let a = Array::from_vec(vec![2.0, 4.0, 6.0], Shape::new(vec![3]));
800    /// let b = Array::from_vec(vec![4.0, 6.0, 8.0], Shape::new(vec![3]));
801    /// let c = a.average_with(&b);
802    /// assert_eq!(c.to_vec(), vec![3.0, 5.0, 7.0]);
803    /// ```
804    pub fn average_with(&self, other: &Array) -> Array {
805        binary_op(self, other, Primitive::Add, |a, b| (a + b) / 2.0)
806    }
807}
808
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_add() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
        let b = Array::from_vec(vec![10.0, 20.0, 30.0], Shape::new(vec![3]));
        let c = a.add(&b);
        assert_eq!(c.to_vec(), vec![11.0, 22.0, 33.0]);
    }

    #[test]
    fn test_sub() {
        let a = Array::from_vec(vec![10.0, 20.0, 30.0], Shape::new(vec![3]));
        let b = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
        let c = a.sub(&b);
        assert_eq!(c.to_vec(), vec![9.0, 18.0, 27.0]);
    }

    #[test]
    fn test_mul() {
        let a = Array::from_vec(vec![2.0, 3.0, 4.0], Shape::new(vec![3]));
        let b = Array::from_vec(vec![5.0, 6.0, 7.0], Shape::new(vec![3]));
        let c = a.mul(&b);
        assert_eq!(c.to_vec(), vec![10.0, 18.0, 28.0]);
    }

    #[test]
    fn test_div() {
        let a = Array::from_vec(vec![10.0, 20.0, 30.0], Shape::new(vec![3]));
        let b = Array::from_vec(vec![2.0, 4.0, 5.0], Shape::new(vec![3]));
        let c = a.div(&b);
        assert_eq!(c.to_vec(), vec![5.0, 5.0, 6.0]);
    }

    #[test]
    fn test_pow() {
        let a = Array::from_vec(vec![2.0, 3.0, 4.0], Shape::new(vec![3]));
        let b = Array::from_vec(vec![2.0, 2.0, 2.0], Shape::new(vec![3]));
        let c = a.pow(&b);
        assert_eq!(c.to_vec(), vec![4.0, 9.0, 16.0]);
    }

    #[test]
    fn test_broadcast_scalar() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
        let b = Array::from_vec(vec![10.0], Shape::new(vec![1]));
        let c = a.add(&b);
        assert_eq!(c.to_vec(), vec![11.0, 12.0, 13.0]);
    }

    #[test]
    fn test_broadcast_2d() {
        // [2, 3] + [1, 3] -> [2, 3]
        let a = Array::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            Shape::new(vec![2, 3]),
        );
        let b =
            Array::from_vec(vec![10.0, 20.0, 30.0], Shape::new(vec![1, 3]));
        let c = a.add(&b);
        assert_eq!(c.shape().as_slice(), &[2, 3]);
        assert_eq!(c.to_vec(), vec![11.0, 22.0, 33.0, 14.0, 25.0, 36.0]);
    }

    #[test]
    fn test_broadcast_row_col() {
        // [3, 1] + [1, 3] -> [3, 3]
        let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3, 1]));
        let b =
            Array::from_vec(vec![10.0, 20.0, 30.0], Shape::new(vec![1, 3]));
        let c = a.add(&b);
        assert_eq!(c.shape().as_slice(), &[3, 3]);
        assert_eq!(
            c.to_vec(),
            vec![11.0, 21.0, 31.0, 12.0, 22.0, 32.0, 13.0, 23.0, 33.0]
        );
    }

    #[test]
    fn test_minimum_maximum() {
        let a = Array::from_vec(vec![1.0, 5.0, 3.0], Shape::new(vec![3]));
        let b = Array::from_vec(vec![2.0, 4.0, 6.0], Shape::new(vec![3]));

        let min_ab = a.minimum(&b);
        assert_eq!(min_ab.to_vec(), vec![1.0, 4.0, 3.0]);

        let max_ab = a.maximum(&b);
        assert_eq!(max_ab.to_vec(), vec![2.0, 5.0, 6.0]);
    }

    #[test]
    fn test_fmax_fmin_ignore_nan() {
        let a = Array::from_vec(vec![1.0, f32::NAN, 3.0], Shape::new(vec![3]));
        let b = Array::from_vec(vec![2.0, 2.0, 2.0], Shape::new(vec![3]));
        assert_eq!(a.fmax(&b).to_vec(), vec![2.0, 2.0, 3.0]);
        assert_eq!(a.fmin(&b).to_vec(), vec![1.0, 2.0, 2.0]);
    }

    #[test]
    fn test_divide_no_nan() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
        let b = Array::from_vec(vec![2.0, 0.0, 3.0], Shape::new(vec![3]));
        let c = a.divide_no_nan(&b);
        assert_eq!(c.to_vec(), vec![0.5, 0.0, 1.0]);
    }

    #[test]
    fn test_squared_difference() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
        let b = Array::from_vec(vec![2.0, 2.0, 1.0], Shape::new(vec![3]));
        let c = a.squared_difference(&b);
        assert_eq!(c.to_vec(), vec![1.0, 0.0, 4.0]);
    }

    #[test]
    fn test_mod_op() {
        let a = Array::from_vec(vec![5.0, 7.0, 9.0], Shape::new(vec![3]));
        let b = Array::from_vec(vec![3.0, 3.0, 3.0], Shape::new(vec![3]));
        let c = a.mod_op(&b);
        assert_eq!(c.to_vec(), vec![2.0, 1.0, 0.0]);
    }

    #[test]
    fn test_remainder_and_fmod_signs() {
        let a = Array::from_vec(vec![7.0, -7.0, 7.0], Shape::new(vec![3]));
        let b = Array::from_vec(vec![3.0, 3.0, -3.0], Shape::new(vec![3]));
        // remainder follows the divisor's sign ...
        assert_eq!(a.remainder(&b).to_vec(), vec![1.0, 2.0, -2.0]);
        // ... fmod follows the dividend's sign.
        assert_eq!(a.fmod(&b).to_vec(), vec![1.0, -1.0, 1.0]);
    }

    #[test]
    fn test_atan2() {
        use std::f32::consts::FRAC_PI_4;

        let y = Array::from_vec(vec![1.0, 1.0, -1.0, -1.0], Shape::new(vec![4]));
        let x = Array::from_vec(vec![1.0, -1.0, 1.0, -1.0], Shape::new(vec![4]));
        let got = y.atan2(&x).to_vec();
        // One point in each quadrant.
        let want = [FRAC_PI_4, 3.0 * FRAC_PI_4, -FRAC_PI_4, -3.0 * FRAC_PI_4];
        assert_eq!(got.len(), want.len());
        for (g, w) in got.iter().zip(want.iter()) {
            assert!((g - w).abs() < 1e-6, "atan2 gave {g}, expected {w}");
        }
    }

    #[test]
    fn test_hypot() {
        let a = Array::from_vec(vec![3.0, 4.0], Shape::new(vec![2]));
        let b = Array::from_vec(vec![4.0, 3.0], Shape::new(vec![2]));
        let c = a.hypot(&b);
        assert_eq!(c.to_vec(), vec![5.0, 5.0]);
    }

    #[test]
    fn test_copysign() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
        let b = Array::from_vec(vec![-1.0, 1.0, -1.0], Shape::new(vec![3]));
        let c = a.copysign(&b);
        assert_eq!(c.to_vec(), vec![-1.0, 2.0, -3.0]);
    }

    #[test]
    fn test_next_after() {
        let a = Array::from_vec(vec![1.0], Shape::new(vec![1]));
        let b = Array::from_vec(vec![2.0], Shape::new(vec![1]));
        let c = a.next_after(&b);
        // Should be slightly larger than 1.0
        assert!(c.to_vec()[0] > 1.0);
        assert!(c.to_vec()[0] < 1.0 + 1e-6);
    }

    #[test]
    fn test_nextafter_steps_one_ulp() {
        let a = Array::from_vec(vec![1.0, 2.0], Shape::new(vec![2]));
        let b = Array::from_vec(vec![2.0, 1.0], Shape::new(vec![2]));
        assert_eq!(
            a.nextafter(&b).to_vec(),
            vec![1.0 + f32::EPSILON, 2.0 - f32::EPSILON]
        );
    }

    #[test]
    fn test_gcd_lcm() {
        let a = Array::from_vec(vec![12.0, 15.0, -24.0, 0.0], Shape::new(vec![4]));
        let b = Array::from_vec(vec![8.0, 10.0, 18.0, 5.0], Shape::new(vec![4]));
        assert_eq!(a.gcd(&b).to_vec(), vec![4.0, 5.0, 6.0, 5.0]);
        assert_eq!(a.lcm(&b).to_vec(), vec![24.0, 30.0, 72.0, 0.0]);
    }

    #[test]
    fn test_floor_divide_and_fma() {
        let a = Array::from_vec(vec![7.0, 7.0, -7.0], Shape::new(vec![3]));
        let b = Array::from_vec(vec![3.0, -3.0, 3.0], Shape::new(vec![3]));
        assert_eq!(a.floor_divide(&b).to_vec(), vec![2.0, -3.0, -3.0]);

        let c = Array::from_vec(vec![1.0, 1.0, 1.0], Shape::new(vec![3]));
        assert_eq!(a.fma(&b, &c).to_vec(), vec![22.0, -20.0, -20.0]);
    }

    #[test]
    fn test_heaviside() {
        let x = Array::from_vec(vec![-2.0, 0.0, 3.0], Shape::new(vec![3]));
        let h0 = Array::from_vec(vec![0.5, 0.25, 0.5], Shape::new(vec![3]));
        assert_eq!(x.heaviside(&h0).to_vec(), vec![0.0, 0.25, 1.0]);
    }

    #[test]
    fn test_broadcast_index() {
        // Test broadcast_index function
        let result_shape = Shape::new(vec![2, 3]);
        let src_shape = Shape::new(vec![1, 3]);

        // For result shape [2,3], indices 0-5 map to positions:
        // 0: [0,0] -> [0,0] in [1,3] -> flat 0
        // 1: [0,1] -> [0,1] in [1,3] -> flat 1
        // 2: [0,2] -> [0,2] in [1,3] -> flat 2
        // 3: [1,0] -> [0,0] in [1,3] -> flat 0 (broadcast)
        // 4: [1,1] -> [0,1] in [1,3] -> flat 1 (broadcast)
        // 5: [1,2] -> [0,2] in [1,3] -> flat 2 (broadcast)
        assert_eq!(broadcast_index(0, &result_shape, &src_shape), 0);
        assert_eq!(broadcast_index(1, &result_shape, &src_shape), 1);
        assert_eq!(broadcast_index(2, &result_shape, &src_shape), 2);
        assert_eq!(broadcast_index(3, &result_shape, &src_shape), 0);
        assert_eq!(broadcast_index(4, &result_shape, &src_shape), 1);
        assert_eq!(broadcast_index(5, &result_shape, &src_shape), 2);
    }

    #[test]
    fn test_broadcast_index_lower_rank_and_column() {
        let result_shape = Shape::new(vec![2, 3]);

        // A rank-1 [3] source is right-aligned against the [2, 3] result.
        let row = Shape::new(vec![3]);
        let mapped: Vec<usize> = (0..6)
            .map(|i| broadcast_index(i, &result_shape, &row))
            .collect();
        assert_eq!(mapped, vec![0, 1, 2, 0, 1, 2]);

        // A [2, 1] column repeats along the last axis.
        let col = Shape::new(vec![2, 1]);
        let mapped: Vec<usize> = (0..6)
            .map(|i| broadcast_index(i, &result_shape, &col))
            .collect();
        assert_eq!(mapped, vec![0, 0, 0, 1, 1, 1]);
    }
}
982}