jax_rs/ops/binary.rs
//! Element-wise binary operations on arrays, with broadcasting.
2
3use crate::trace::{is_tracing, trace_binary, Primitive};
4use crate::{buffer::Buffer, Array, DType, Device, Shape};
5
6/// Apply a binary function element-wise to two arrays with broadcasting.
7fn binary_op<F>(lhs: &Array, rhs: &Array, op: Primitive, f: F) -> Array
8where
9 F: Fn(f32, f32) -> f32,
10{
11 assert_eq!(lhs.dtype(), DType::Float32, "Only Float32 supported");
12 assert_eq!(rhs.dtype(), DType::Float32, "Only Float32 supported");
13
14 // Check if shapes are broadcast-compatible
15 let result_shape = lhs
16 .shape()
17 .broadcast_with(rhs.shape())
18 .expect("Shapes are not broadcast-compatible");
19
20 // Dispatch based on device
21 let result = match (lhs.device(), rhs.device()) {
22 (Device::WebGpu, Device::WebGpu) => {
23 // GPU path - no broadcasting support yet, shapes must match exactly
24 assert_eq!(
25 lhs.shape(),
26 rhs.shape(),
27 "GPU operations do not support broadcasting yet"
28 );
29
30 // Map primitive to WGSL operator
31 let op_str = match &op {
32 Primitive::Add => "+",
33 Primitive::Sub => "-",
34 Primitive::Mul => "*",
35 Primitive::Div => "/",
36 _ => {
37 // Fallback to CPU for unsupported ops
38 return binary_op_cpu(lhs, rhs, op.clone(), f);
39 }
40 };
41
42 // Create output buffer on GPU
43 let output_buffer = Buffer::zeros(
44 result_shape.size(),
45 DType::Float32,
46 Device::WebGpu,
47 );
48
49 // Execute on GPU
50 crate::backend::ops::gpu_binary_op(
51 lhs.buffer(),
52 rhs.buffer(),
53 &output_buffer,
54 op_str,
55 );
56
57 Array::from_buffer(output_buffer, result_shape)
58 }
59 (Device::Cpu, Device::Cpu) | (Device::Wasm, Device::Wasm) => {
60 // CPU path with broadcasting support
61 binary_op_cpu(lhs, rhs, op.clone(), f)
62 }
63 _ => {
64 panic!("Mixed device operations not supported. Both arrays must be on the same device.");
65 }
66 };
67
68 // Register with trace context if tracing is active
69 if is_tracing() {
70 trace_binary(result.id(), op, lhs, rhs);
71 }
72
73 result
74}
75
76/// CPU implementation of binary operation with broadcasting support.
77fn binary_op_cpu<F>(lhs: &Array, rhs: &Array, _op: Primitive, f: F) -> Array
78where
79 F: Fn(f32, f32) -> f32,
80{
81 let result_shape = lhs
82 .shape()
83 .broadcast_with(rhs.shape())
84 .expect("Shapes are not broadcast-compatible");
85
86 let lhs_data = lhs.to_vec();
87 let rhs_data = rhs.to_vec();
88
89 let result_data = if lhs.shape() == rhs.shape() {
90 // Same shape - simple element-wise operation
91 lhs_data.iter().zip(rhs_data.iter()).map(|(&a, &b)| f(a, b)).collect()
92 } else {
93 // Need broadcasting
94 broadcast_binary(
95 &lhs_data,
96 lhs.shape(),
97 &rhs_data,
98 rhs.shape(),
99 &result_shape,
100 f,
101 )
102 };
103
104 let buffer = Buffer::from_f32(result_data, Device::Cpu);
105 Array::from_buffer(buffer, result_shape)
106}
107
108/// Helper function to perform binary operation with broadcasting.
109fn broadcast_binary<F>(
110 lhs_data: &[f32],
111 lhs_shape: &Shape,
112 rhs_data: &[f32],
113 rhs_shape: &Shape,
114 result_shape: &Shape,
115 f: F,
116) -> Vec<f32>
117where
118 F: Fn(f32, f32) -> f32,
119{
120 let size = result_shape.size();
121 let mut result = Vec::with_capacity(size);
122
123 for i in 0..size {
124 let lhs_idx = broadcast_index(i, result_shape, lhs_shape);
125 let rhs_idx = broadcast_index(i, result_shape, rhs_shape);
126 result.push(f(lhs_data[lhs_idx], rhs_data[rhs_idx]));
127 }
128
129 result
130}
131
132/// Convert a flat index in the result array to an index in the source array,
133/// accounting for broadcasting.
134pub(crate) fn broadcast_index(
135 flat_idx: usize,
136 result_shape: &Shape,
137 src_shape: &Shape,
138) -> usize {
139 let result_dims = result_shape.as_slice();
140 let src_dims = src_shape.as_slice();
141
142 // Convert flat index to multi-dimensional index
143 let mut multi_idx = Vec::with_capacity(result_dims.len());
144 let mut idx = flat_idx;
145 for &dim in result_dims.iter().rev() {
146 multi_idx.push(idx % dim);
147 idx /= dim;
148 }
149 multi_idx.reverse();
150
151 // Map to source index with broadcasting
152 let offset = result_dims.len() - src_dims.len();
153 let mut src_idx = 0;
154 let mut stride = 1;
155
156 for i in (0..src_dims.len()).rev() {
157 let result_i = offset + i;
158 let dim_idx = if src_dims[i] == 1 {
159 0 // Broadcast dimension
160 } else {
161 multi_idx[result_i]
162 };
163 src_idx += dim_idx * stride;
164 stride *= src_dims[i];
165 }
166
167 src_idx
168}
169
170impl Array {
171 /// Add two arrays element-wise with broadcasting.
172 ///
173 /// # Examples
174 ///
175 /// ```
176 /// # use jax_rs::{Array, Shape};
177 /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
178 /// let b = Array::from_vec(vec![10.0, 20.0, 30.0], Shape::new(vec![3]));
179 /// let c = a.add(&b);
180 /// assert_eq!(c.to_vec(), vec![11.0, 22.0, 33.0]);
181 /// ```
182 pub fn add(&self, other: &Array) -> Array {
183 binary_op(self, other, Primitive::Add, |a, b| a + b)
184 }
185
186 /// Subtract two arrays element-wise with broadcasting.
187 pub fn sub(&self, other: &Array) -> Array {
188 binary_op(self, other, Primitive::Sub, |a, b| a - b)
189 }
190
191 /// Multiply two arrays element-wise with broadcasting.
192 pub fn mul(&self, other: &Array) -> Array {
193 binary_op(self, other, Primitive::Mul, |a, b| a * b)
194 }
195
196 /// Divide two arrays element-wise with broadcasting.
197 pub fn div(&self, other: &Array) -> Array {
198 binary_op(self, other, Primitive::Div, |a, b| a / b)
199 }
200
201 /// Raise elements to a power element-wise with broadcasting.
202 pub fn pow(&self, other: &Array) -> Array {
203 binary_op(self, other, Primitive::Pow, |a, b| a.powf(b))
204 }
205
206 /// Element-wise minimum.
207 pub fn minimum(&self, other: &Array) -> Array {
208 binary_op(self, other, Primitive::Min, |a, b| a.min(b))
209 }
210
211 /// Element-wise maximum.
212 pub fn maximum(&self, other: &Array) -> Array {
213 binary_op(self, other, Primitive::Max, |a, b| a.max(b))
214 }
215
216 /// Safe division that returns 0 where division by zero would occur.
217 ///
218 /// Returns x / y where y != 0, and 0 where y == 0.
219 ///
220 /// # Examples
221 ///
222 /// ```
223 /// # use jax_rs::{Array, Shape};
224 /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
225 /// let b = Array::from_vec(vec![2.0, 0.0, 3.0], Shape::new(vec![3]));
226 /// let c = a.divide_no_nan(&b);
227 /// assert_eq!(c.to_vec(), vec![0.5, 0.0, 1.0]);
228 /// ```
229 pub fn divide_no_nan(&self, other: &Array) -> Array {
230 binary_op(self, other, Primitive::Div, |a, b| {
231 if b == 0.0 {
232 0.0
233 } else {
234 a / b
235 }
236 })
237 }
238
239 /// Squared difference: (a - b)^2.
240 ///
241 /// Useful for computing mean squared error and similar metrics.
242 ///
243 /// # Examples
244 ///
245 /// ```
246 /// # use jax_rs::{Array, Shape};
247 /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
248 /// let b = Array::from_vec(vec![2.0, 2.0, 1.0], Shape::new(vec![3]));
249 /// let c = a.squared_difference(&b);
250 /// assert_eq!(c.to_vec(), vec![1.0, 0.0, 4.0]);
251 /// ```
252 pub fn squared_difference(&self, other: &Array) -> Array {
253 binary_op(self, other, Primitive::Sub, |a, b| {
254 let diff = a - b;
255 diff * diff
256 })
257 }
258
259 /// Element-wise modulo operation.
260 ///
261 /// # Examples
262 ///
263 /// ```
264 /// # use jax_rs::{Array, Shape};
265 /// let a = Array::from_vec(vec![5.0, 7.0, 9.0], Shape::new(vec![3]));
266 /// let b = Array::from_vec(vec![3.0, 3.0, 3.0], Shape::new(vec![3]));
267 /// let c = a.mod_op(&b);
268 /// assert_eq!(c.to_vec(), vec![2.0, 1.0, 0.0]);
269 /// ```
270 pub fn mod_op(&self, other: &Array) -> Array {
271 binary_op(self, other, Primitive::Div, |a, b| a % b)
272 }
273
274 /// Element-wise arctangent of a/b.
275 ///
276 /// Correctly handles signs to determine quadrant.
277 ///
278 /// # Examples
279 ///
280 /// ```
281 /// # use jax_rs::{Array, Shape};
282 /// let y = Array::from_vec(vec![1.0, -1.0], Shape::new(vec![2]));
283 /// let x = Array::from_vec(vec![1.0, 1.0], Shape::new(vec![2]));
284 /// let angle = y.atan2(&x);
285 /// # // We just check it compiles and runs
286 /// ```
287 pub fn atan2(&self, other: &Array) -> Array {
288 binary_op(self, other, Primitive::Div, |a, b| a.atan2(b))
289 }
290
291 /// Element-wise hypot: sqrt(a^2 + b^2).
292 ///
293 /// Computes the hypotenuse in a numerically stable way.
294 ///
295 /// # Examples
296 ///
297 /// ```
298 /// # use jax_rs::{Array, Shape};
299 /// let a = Array::from_vec(vec![3.0, 4.0], Shape::new(vec![2]));
300 /// let b = Array::from_vec(vec![4.0, 3.0], Shape::new(vec![2]));
301 /// let c = a.hypot(&b);
302 /// assert_eq!(c.to_vec(), vec![5.0, 5.0]);
303 /// ```
304 pub fn hypot(&self, other: &Array) -> Array {
305 binary_op(self, other, Primitive::Add, |a, b| a.hypot(b))
306 }
307
308 /// Element-wise copysign: magnitude of a with sign of b.
309 ///
310 /// # Examples
311 ///
312 /// ```
313 /// # use jax_rs::{Array, Shape};
314 /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
315 /// let b = Array::from_vec(vec![-1.0, 1.0, -1.0], Shape::new(vec![3]));
316 /// let c = a.copysign(&b);
317 /// assert_eq!(c.to_vec(), vec![-1.0, 2.0, -3.0]);
318 /// ```
319 pub fn copysign(&self, other: &Array) -> Array {
320 binary_op(self, other, Primitive::Mul, |a, b| a.copysign(b))
321 }
322
323 /// Element-wise next representable float in direction of b.
324 ///
325 /// # Examples
326 ///
327 /// ```
328 /// # use jax_rs::{Array, Shape};
329 /// let a = Array::from_vec(vec![1.0, 2.0], Shape::new(vec![2]));
330 /// let b = Array::from_vec(vec![2.0, 1.0], Shape::new(vec![2]));
331 /// let c = a.next_after(&b);
332 /// # // Just verify it compiles
333 /// ```
334 pub fn next_after(&self, other: &Array) -> Array {
335 binary_op(self, other, Primitive::Add, |a, b| {
336 if a < b {
337 // Next float towards positive infinity
338 f32::from_bits(a.to_bits() + 1)
339 } else if a > b {
340 // Next float towards negative infinity
341 f32::from_bits(a.to_bits() - 1)
342 } else {
343 b
344 }
345 })
346 }
347
348 /// Logarithm of sum of exponentials (numerically stable).
349 ///
350 /// Computes log(exp(x) + exp(y)) in a numerically stable way.
351 ///
352 /// # Examples
353 ///
354 /// ```
355 /// # use jax_rs::{Array, Shape};
356 /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
357 /// let b = Array::from_vec(vec![2.0, 3.0, 4.0], Shape::new(vec![3]));
358 /// let c = a.logaddexp(&b);
359 /// // Result: log(exp(1)+exp(2)), log(exp(2)+exp(3)), log(exp(3)+exp(4))
360 /// ```
361 pub fn logaddexp(&self, other: &Array) -> Array {
362 binary_op(self, other, Primitive::Add, |a, b| {
363 let max = a.max(b);
364 max + ((a - max).exp() + (b - max).exp()).ln()
365 })
366 }
367
368 /// Base-2 logarithm of sum of exponentials.
369 ///
370 /// Computes log2(2^x + 2^y) in a numerically stable way.
371 ///
372 /// # Examples
373 ///
374 /// ```
375 /// # use jax_rs::{Array, Shape};
376 /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
377 /// let b = Array::from_vec(vec![2.0, 3.0, 4.0], Shape::new(vec![3]));
378 /// let c = a.logaddexp2(&b);
379 /// ```
380 pub fn logaddexp2(&self, other: &Array) -> Array {
381 binary_op(self, other, Primitive::Add, |a, b| {
382 let max = a.max(b);
383 max + ((a - max).exp2() + (b - max).exp2()).log2()
384 })
385 }
386
387 /// Heaviside step function.
388 ///
389 /// Returns 0 where x < 0, h0 where x == 0, and 1 where x > 0.
390 ///
391 /// # Examples
392 ///
393 /// ```
394 /// # use jax_rs::{Array, Shape};
395 /// let x = Array::from_vec(vec![-1.0, 0.0, 1.0], Shape::new(vec![3]));
396 /// let h0 = Array::from_vec(vec![0.5, 0.5, 0.5], Shape::new(vec![3]));
397 /// let h = x.heaviside(&h0);
398 /// assert_eq!(h.to_vec(), vec![0.0, 0.5, 1.0]);
399 /// ```
400 pub fn heaviside(&self, h0: &Array) -> Array {
401 binary_op(self, h0, Primitive::Max, |x, h0_val| {
402 if x < 0.0 {
403 0.0
404 } else if x == 0.0 {
405 h0_val
406 } else {
407 1.0
408 }
409 })
410 }
411
412 /// Floor division (division rounding toward negative infinity).
413 ///
414 /// # Examples
415 ///
416 /// ```
417 /// # use jax_rs::{Array, Shape};
418 /// let a = Array::from_vec(vec![7.0, 7.0, -7.0], Shape::new(vec![3]));
419 /// let b = Array::from_vec(vec![3.0, -3.0, 3.0], Shape::new(vec![3]));
420 /// let c = a.floor_divide(&b);
421 /// assert_eq!(c.to_vec(), vec![2.0, -3.0, -3.0]);
422 /// ```
423 pub fn floor_divide(&self, other: &Array) -> Array {
424 binary_op(self, other, Primitive::Div, |a, b| (a / b).floor())
425 }
426
427 /// Fused multiply-add: a * b + c.
428 ///
429 /// # Examples
430 ///
431 /// ```
432 /// # use jax_rs::{Array, Shape};
433 /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
434 /// let b = Array::from_vec(vec![2.0, 3.0, 4.0], Shape::new(vec![3]));
435 /// let c = Array::from_vec(vec![1.0, 1.0, 1.0], Shape::new(vec![3]));
436 /// let result = a.fma(&b, &c);
437 /// assert_eq!(result.to_vec(), vec![3.0, 7.0, 13.0]); // [1*2+1, 2*3+1, 3*4+1]
438 /// ```
439 pub fn fma(&self, b: &Array, c: &Array) -> Array {
440 let product = self.mul(b);
441 product.add(c)
442 }
443
444 /// Greatest common divisor element-wise.
445 ///
446 /// # Examples
447 ///
448 /// ```
449 /// # use jax_rs::{Array, Shape};
450 /// let a = Array::from_vec(vec![12.0, 15.0, 24.0], Shape::new(vec![3]));
451 /// let b = Array::from_vec(vec![8.0, 10.0, 18.0], Shape::new(vec![3]));
452 /// let c = a.gcd(&b);
453 /// assert_eq!(c.to_vec(), vec![4.0, 5.0, 6.0]);
454 /// ```
455 pub fn gcd(&self, other: &Array) -> Array {
456 binary_op(self, other, Primitive::Min, |mut a, mut b| {
457 a = a.abs();
458 b = b.abs();
459 while b > 0.5 {
460 let temp = b;
461 b = a % b;
462 a = temp;
463 }
464 a
465 })
466 }
467
468 /// Least common multiple element-wise.
469 ///
470 /// # Examples
471 ///
472 /// ```
473 /// # use jax_rs::{Array, Shape};
474 /// let a = Array::from_vec(vec![12.0, 15.0, 24.0], Shape::new(vec![3]));
475 /// let b = Array::from_vec(vec![8.0, 10.0, 18.0], Shape::new(vec![3]));
476 /// let c = a.lcm(&b);
477 /// assert_eq!(c.to_vec(), vec![24.0, 30.0, 72.0]);
478 /// ```
479 pub fn lcm(&self, other: &Array) -> Array {
480 binary_op(self, other, Primitive::Mul, |mut a, mut b| {
481 a = a.abs();
482 b = b.abs();
483 if a < 0.5 || b < 0.5 {
484 return 0.0;
485 }
486 let mut gcd_val = a;
487 let mut temp = b;
488 while temp > 0.5 {
489 let r = gcd_val % temp;
490 gcd_val = temp;
491 temp = r;
492 }
493 (a * b) / gcd_val
494 })
495 }
496
497 /// Bitwise AND operation.
498 /// Operates on the bit representation of Float32 values.
499 ///
500 /// # Examples
501 ///
502 /// ```
503 /// # use jax_rs::{Array, Shape};
504 /// let a = Array::from_vec(vec![15.0, 31.0, 63.0], Shape::new(vec![3]));
505 /// let b = Array::from_vec(vec![7.0, 15.0, 31.0], Shape::new(vec![3]));
506 /// let c = a.bitwise_and(&b);
507 /// ```
508 pub fn bitwise_and(&self, other: &Array) -> Array {
509 binary_op(self, other, Primitive::Min, |a, b| {
510 let a_bits = a.to_bits();
511 let b_bits = b.to_bits();
512 f32::from_bits(a_bits & b_bits)
513 })
514 }
515
516 /// Bitwise OR operation.
517 /// Operates on the bit representation of Float32 values.
518 ///
519 /// # Examples
520 ///
521 /// ```
522 /// # use jax_rs::{Array, Shape};
523 /// let a = Array::from_vec(vec![8.0, 16.0, 32.0], Shape::new(vec![3]));
524 /// let b = Array::from_vec(vec![4.0, 8.0, 16.0], Shape::new(vec![3]));
525 /// let c = a.bitwise_or(&b);
526 /// ```
527 pub fn bitwise_or(&self, other: &Array) -> Array {
528 binary_op(self, other, Primitive::Max, |a, b| {
529 let a_bits = a.to_bits();
530 let b_bits = b.to_bits();
531 f32::from_bits(a_bits | b_bits)
532 })
533 }
534
535 /// Bitwise XOR operation.
536 /// Operates on the bit representation of Float32 values.
537 ///
538 /// # Examples
539 ///
540 /// ```
541 /// # use jax_rs::{Array, Shape};
542 /// let a = Array::from_vec(vec![12.0, 15.0, 18.0], Shape::new(vec![3]));
543 /// let b = Array::from_vec(vec![10.0, 5.0, 20.0], Shape::new(vec![3]));
544 /// let c = a.bitwise_xor(&b);
545 /// ```
546 pub fn bitwise_xor(&self, other: &Array) -> Array {
547 binary_op(self, other, Primitive::Add, |a, b| {
548 let a_bits = a.to_bits();
549 let b_bits = b.to_bits();
550 f32::from_bits(a_bits ^ b_bits)
551 })
552 }
553
554 /// Left bit shift operation.
555 /// Shifts the bit representation of Float32 values left.
556 ///
557 /// # Examples
558 ///
559 /// ```
560 /// # use jax_rs::{Array, Shape};
561 /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
562 /// let b = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
563 /// let c = a.left_shift(&b);
564 /// ```
565 pub fn left_shift(&self, other: &Array) -> Array {
566 binary_op(self, other, Primitive::Mul, |a, b| {
567 let a_bits = a.to_bits();
568 let shift = b as u32;
569 f32::from_bits(a_bits << shift)
570 })
571 }
572
573 /// Right bit shift operation.
574 /// Shifts the bit representation of Float32 values right.
575 ///
576 /// # Examples
577 ///
578 /// ```
579 /// # use jax_rs::{Array, Shape};
580 /// let a = Array::from_vec(vec![4.0, 8.0, 16.0], Shape::new(vec![3]));
581 /// let b = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
582 /// let c = a.right_shift(&b);
583 /// ```
584 pub fn right_shift(&self, other: &Array) -> Array {
585 binary_op(self, other, Primitive::Div, |a, b| {
586 let a_bits = a.to_bits();
587 let shift = b as u32;
588 f32::from_bits(a_bits >> shift)
589 })
590 }
591
592 /// Element-wise maximum, ignoring NaNs.
593 ///
594 /// # Examples
595 ///
596 /// ```
597 /// # use jax_rs::{Array, Shape};
598 /// let a = Array::from_vec(vec![1.0, f32::NAN, 3.0], Shape::new(vec![3]));
599 /// let b = Array::from_vec(vec![2.0, 2.0, 2.0], Shape::new(vec![3]));
600 /// let c = a.fmax(&b);
601 /// assert_eq!(c.to_vec()[0], 2.0);
602 /// assert_eq!(c.to_vec()[1], 2.0);
603 /// assert_eq!(c.to_vec()[2], 3.0);
604 /// ```
605 pub fn fmax(&self, other: &Array) -> Array {
606 binary_op(self, other, Primitive::Max, |a, b| {
607 if a.is_nan() { b }
608 else if b.is_nan() { a }
609 else { a.max(b) }
610 })
611 }
612
613 /// Element-wise minimum, ignoring NaNs.
614 ///
615 /// # Examples
616 ///
617 /// ```
618 /// # use jax_rs::{Array, Shape};
619 /// let a = Array::from_vec(vec![1.0, f32::NAN, 3.0], Shape::new(vec![3]));
620 /// let b = Array::from_vec(vec![2.0, 2.0, 2.0], Shape::new(vec![3]));
621 /// let c = a.fmin(&b);
622 /// assert_eq!(c.to_vec()[0], 1.0);
623 /// assert_eq!(c.to_vec()[1], 2.0);
624 /// assert_eq!(c.to_vec()[2], 2.0);
625 /// ```
626 pub fn fmin(&self, other: &Array) -> Array {
627 binary_op(self, other, Primitive::Min, |a, b| {
628 if a.is_nan() { b }
629 else if b.is_nan() { a }
630 else { a.min(b) }
631 })
632 }
633
634 /// Element-wise arc tangent of x1/x2 choosing the quadrant correctly.
635 ///
636 /// The quadrant (i.e., branch) is chosen so that arctan2(x1, x2) is
637 /// the signed angle in radians between the ray ending at the origin
638 /// and passing through the point (1,0), and the ray ending at the
639 /// origin and passing through the point (x2, x1).
640 ///
641 /// # Examples
642 ///
643 /// ```
644 /// # use jax_rs::{Array, Shape};
645 /// let y = Array::from_vec(vec![1.0, -1.0, 1.0, -1.0], Shape::new(vec![4]));
646 /// let x = Array::from_vec(vec![1.0, 1.0, -1.0, -1.0], Shape::new(vec![4]));
647 /// let angles = y.arctan2(&x);
648 /// // First quadrant: pi/4, Second: -pi/4, Third: 3pi/4, Fourth: -3pi/4
649 /// ```
650 pub fn arctan2(&self, other: &Array) -> Array {
651 binary_op(self, other, Primitive::Div, |y, x| y.atan2(x))
652 }
653
654 /// Element-wise remainder of division (fmod).
655 ///
656 /// # Examples
657 ///
658 /// ```
659 /// # use jax_rs::{Array, Shape};
660 /// let a = Array::from_vec(vec![5.0, 7.0, 10.0], Shape::new(vec![3]));
661 /// let b = Array::from_vec(vec![2.0, 3.0, 4.0], Shape::new(vec![3]));
662 /// let c = a.fmod(&b);
663 /// assert_eq!(c.to_vec(), vec![1.0, 1.0, 2.0]);
664 /// ```
665 pub fn fmod(&self, other: &Array) -> Array {
666 binary_op(self, other, Primitive::Div, |a, b| a % b)
667 }
668
669 /// Return the next floating-point value after x1 towards x2.
670 ///
671 /// # Examples
672 ///
673 /// ```
674 /// # use jax_rs::{Array, Shape};
675 /// let a = Array::from_vec(vec![1.0, 2.0], Shape::new(vec![2]));
676 /// let b = Array::from_vec(vec![2.0, 1.0], Shape::new(vec![2]));
677 /// let c = a.nextafter(&b);
678 /// // First element goes up slightly, second goes down
679 /// ```
680 pub fn nextafter(&self, other: &Array) -> Array {
681 binary_op(self, other, Primitive::Add, |x1, x2| {
682 if x1 == x2 {
683 x2
684 } else if x2 > x1 {
685 // Next float toward positive infinity
686 let bits = x1.to_bits();
687 if x1 >= 0.0 {
688 f32::from_bits(bits + 1)
689 } else {
690 f32::from_bits(bits - 1)
691 }
692 } else {
693 // Next float toward negative infinity
694 let bits = x1.to_bits();
695 if x1 > 0.0 {
696 f32::from_bits(bits - 1)
697 } else if x1 == 0.0 {
698 f32::from_bits(1 | (1 << 31)) // Negative zero direction
699 } else {
700 f32::from_bits(bits + 1)
701 }
702 }
703 })
704 }
705
706 /// Compute the safe element-wise division, returning 0 where denominator is 0.
707 ///
708 /// # Examples
709 ///
710 /// ```
711 /// # use jax_rs::{Array, Shape};
712 /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
713 /// let b = Array::from_vec(vec![1.0, 0.0, 3.0], Shape::new(vec![3]));
714 /// let c = a.safe_divide(&b);
715 /// assert_eq!(c.to_vec(), vec![1.0, 0.0, 1.0]);
716 /// ```
717 pub fn safe_divide(&self, other: &Array) -> Array {
718 binary_op(self, other, Primitive::Div, |a, b| {
719 if b == 0.0 { 0.0 } else { a / b }
720 })
721 }
722
723 /// Compute element-wise true division.
724 ///
725 /// # Examples
726 ///
727 /// ```
728 /// # use jax_rs::{Array, Shape};
729 /// let a = Array::from_vec(vec![5.0, 7.0, 9.0], Shape::new(vec![3]));
730 /// let b = Array::from_vec(vec![2.0, 2.0, 2.0], Shape::new(vec![3]));
731 /// let c = a.true_divide(&b);
732 /// assert_eq!(c.to_vec(), vec![2.5, 3.5, 4.5]);
733 /// ```
734 pub fn true_divide(&self, other: &Array) -> Array {
735 self.div(other)
736 }
737
738 /// Compute element-wise remainder, with the same sign as divisor.
739 ///
740 /// # Examples
741 ///
742 /// ```
743 /// # use jax_rs::{Array, Shape};
744 /// let a = Array::from_vec(vec![7.0, -7.0, 7.0], Shape::new(vec![3]));
745 /// let b = Array::from_vec(vec![3.0, 3.0, -3.0], Shape::new(vec![3]));
746 /// let c = a.remainder(&b);
747 /// // Python-style modulo: result has same sign as divisor
748 /// ```
749 pub fn remainder(&self, other: &Array) -> Array {
750 binary_op(self, other, Primitive::Div, |a, b| {
751 let r = a % b;
752 if (r > 0.0 && b < 0.0) || (r < 0.0 && b > 0.0) {
753 r + b
754 } else {
755 r
756 }
757 })
758 }
759
760 /// Compute element-wise difference raised to a power.
761 ///
762 /// # Examples
763 ///
764 /// ```
765 /// # use jax_rs::{Array, Shape};
766 /// let a = Array::from_vec(vec![3.0, 5.0, 7.0], Shape::new(vec![3]));
767 /// let b = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
768 /// let c = a.diff_pow(&b, 2.0); // (a - b)^2
769 /// assert_eq!(c.to_vec(), vec![4.0, 9.0, 16.0]);
770 /// ```
771 pub fn diff_pow(&self, other: &Array, power: f32) -> Array {
772 binary_op(self, other, Primitive::Sub, move |a, b| (a - b).powf(power))
773 }
774
775 /// Compute element-wise squared difference.
776 ///
777 /// # Examples
778 ///
779 /// ```
780 /// # use jax_rs::{Array, Shape};
781 /// let a = Array::from_vec(vec![3.0, 5.0, 7.0], Shape::new(vec![3]));
782 /// let b = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
783 /// let c = a.squared_diff(&b); // (a - b)^2
784 /// assert_eq!(c.to_vec(), vec![4.0, 9.0, 16.0]);
785 /// ```
786 pub fn squared_diff(&self, other: &Array) -> Array {
787 binary_op(self, other, Primitive::Sub, |a, b| {
788 let d = a - b;
789 d * d
790 })
791 }
792
793 /// Compute element-wise average of two arrays.
794 ///
795 /// # Examples
796 ///
797 /// ```
798 /// # use jax_rs::{Array, Shape};
799 /// let a = Array::from_vec(vec![2.0, 4.0, 6.0], Shape::new(vec![3]));
800 /// let b = Array::from_vec(vec![4.0, 6.0, 8.0], Shape::new(vec![3]));
801 /// let c = a.average_with(&b);
802 /// assert_eq!(c.to_vec(), vec![3.0, 5.0, 7.0]);
803 /// ```
804 pub fn average_with(&self, other: &Array) -> Array {
805 binary_op(self, other, Primitive::Add, |a, b| (a + b) / 2.0)
806 }
807}
808
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a Float32 test array with the given values and dimensions.
    fn arr(values: &[f32], dims: &[usize]) -> Array {
        Array::from_vec(values.to_vec(), Shape::new(dims.to_vec()))
    }

    #[test]
    fn test_add() {
        let sum = arr(&[1.0, 2.0, 3.0], &[3]).add(&arr(&[10.0, 20.0, 30.0], &[3]));
        assert_eq!(sum.to_vec(), vec![11.0, 22.0, 33.0]);
    }

    #[test]
    fn test_sub() {
        let diff = arr(&[10.0, 20.0, 30.0], &[3]).sub(&arr(&[1.0, 2.0, 3.0], &[3]));
        assert_eq!(diff.to_vec(), vec![9.0, 18.0, 27.0]);
    }

    #[test]
    fn test_mul() {
        let prod = arr(&[2.0, 3.0, 4.0], &[3]).mul(&arr(&[5.0, 6.0, 7.0], &[3]));
        assert_eq!(prod.to_vec(), vec![10.0, 18.0, 28.0]);
    }

    #[test]
    fn test_div() {
        let quot = arr(&[10.0, 20.0, 30.0], &[3]).div(&arr(&[2.0, 4.0, 5.0], &[3]));
        assert_eq!(quot.to_vec(), vec![5.0, 5.0, 6.0]);
    }

    #[test]
    fn test_pow() {
        let squares = arr(&[2.0, 3.0, 4.0], &[3]).pow(&arr(&[2.0, 2.0, 2.0], &[3]));
        assert_eq!(squares.to_vec(), vec![4.0, 9.0, 16.0]);
    }

    #[test]
    fn test_broadcast_scalar() {
        let shifted = arr(&[1.0, 2.0, 3.0], &[3]).add(&arr(&[10.0], &[1]));
        assert_eq!(shifted.to_vec(), vec![11.0, 12.0, 13.0]);
    }

    #[test]
    fn test_broadcast_2d() {
        // [2, 3] + [1, 3] -> [2, 3]
        let grid = arr(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let row = arr(&[10.0, 20.0, 30.0], &[1, 3]);
        let out = grid.add(&row);
        assert_eq!(out.shape().as_slice(), &[2, 3]);
        assert_eq!(out.to_vec(), vec![11.0, 22.0, 33.0, 14.0, 25.0, 36.0]);
    }

    #[test]
    fn test_broadcast_row_col() {
        // [3, 1] + [1, 3] -> [3, 3]
        let column = arr(&[1.0, 2.0, 3.0], &[3, 1]);
        let row = arr(&[10.0, 20.0, 30.0], &[1, 3]);
        let out = column.add(&row);
        assert_eq!(out.shape().as_slice(), &[3, 3]);
        assert_eq!(
            out.to_vec(),
            vec![11.0, 21.0, 31.0, 12.0, 22.0, 32.0, 13.0, 23.0, 33.0]
        );
    }

    #[test]
    fn test_minimum_maximum() {
        let lhs = arr(&[1.0, 5.0, 3.0], &[3]);
        let rhs = arr(&[2.0, 4.0, 6.0], &[3]);

        assert_eq!(lhs.minimum(&rhs).to_vec(), vec![1.0, 4.0, 3.0]);
        assert_eq!(lhs.maximum(&rhs).to_vec(), vec![2.0, 5.0, 6.0]);
    }

    #[test]
    fn test_divide_no_nan() {
        let out = arr(&[1.0, 2.0, 3.0], &[3]).divide_no_nan(&arr(&[2.0, 0.0, 3.0], &[3]));
        assert_eq!(out.to_vec(), vec![0.5, 0.0, 1.0]);
    }

    #[test]
    fn test_squared_difference() {
        let out = arr(&[1.0, 2.0, 3.0], &[3]).squared_difference(&arr(&[2.0, 2.0, 1.0], &[3]));
        assert_eq!(out.to_vec(), vec![1.0, 0.0, 4.0]);
    }

    #[test]
    fn test_mod_op() {
        let out = arr(&[5.0, 7.0, 9.0], &[3]).mod_op(&arr(&[3.0, 3.0, 3.0], &[3]));
        assert_eq!(out.to_vec(), vec![2.0, 1.0, 0.0]);
    }

    #[test]
    fn test_atan2() {
        let ys = arr(&[1.0, 1.0, -1.0, -1.0], &[4]);
        let xs = arr(&[1.0, -1.0, 1.0, -1.0], &[4]);
        let angles = ys.atan2(&xs).to_vec();
        // Loose sanity bounds only.
        assert!(angles[0] > 0.0 && angles[0] < 1.6); // ~π/4
        assert!(angles[1] > 2.0 && angles[1] < 3.2); // ~3π/4
    }

    #[test]
    fn test_hypot() {
        let out = arr(&[3.0, 4.0], &[2]).hypot(&arr(&[4.0, 3.0], &[2]));
        assert_eq!(out.to_vec(), vec![5.0, 5.0]);
    }

    #[test]
    fn test_copysign() {
        let out = arr(&[1.0, 2.0, 3.0], &[3]).copysign(&arr(&[-1.0, 1.0, -1.0], &[3]));
        assert_eq!(out.to_vec(), vec![-1.0, 2.0, -3.0]);
    }

    #[test]
    fn test_next_after() {
        let stepped = arr(&[1.0], &[1]).next_after(&arr(&[2.0], &[1])).to_vec();
        // Should land just above 1.0.
        assert!(stepped[0] > 1.0);
        assert!(stepped[0] < 1.0 + 1e-6);
    }

    #[test]
    fn test_broadcast_index() {
        // Result [2, 3] against source [1, 3]: the single source row is reused
        // for both result rows, so flat indices 0..6 map to 0,1,2,0,1,2.
        let result_shape = Shape::new(vec![2, 3]);
        let src_shape = Shape::new(vec![1, 3]);

        let expected = [0, 1, 2, 0, 1, 2];
        for (flat, &want) in expected.iter().enumerate() {
            assert_eq!(broadcast_index(flat, &result_shape, &src_shape), want);
        }
    }
}