jax_rs/ops/unary.rs
1//! Unary operations on arrays.
2
3use crate::trace::{is_tracing, trace_unary, Primitive};
4use crate::{buffer::Buffer, Array, DType, Device};
5
6#[cfg(test)]
7use crate::Shape;
8
/// Natural logarithm of the gamma function for `x > 0`.
///
/// Uses the Lanczos-type series from Numerical Recipes (`gammln`), evaluated
/// in `f64` and rounded to `f32` at the end. Despite the old description, this
/// is not Stirling's series: it is accurate for any positive argument. The
/// `lgamma` method in this file shifts its arguments to `x >= 7` before calling.
///
/// Returns `+inf` for `x = +inf`, where the series itself would evaluate
/// `inf - inf` and yield NaN. Non-positive inputs are outside the domain.
fn lgamma_impl(x: f32) -> f32 {
    if x == f32::INFINITY {
        return f32::INFINITY;
    }

    const COEFFS: [f64; 6] = [
        76.18009172947146,
        -86.50532032941677,
        24.01409824083091,
        -1.231739572450155,
        0.1208650973866179e-2,
        -0.5395239384953e-5,
    ];

    let x = f64::from(x);
    let t = x + 5.5;
    let t = t - (x + 0.5) * t.ln();
    // Partial-fraction sum over x + 1 ..= x + 6, accumulated in the same order as the reference.
    let series = COEFFS
        .iter()
        .enumerate()
        .fold(1.000000000190015, |acc, (i, &c)| acc + c / (x + (i + 1) as f64));
    // 2.5066282746310005 ≈ sqrt(2π).
    (-t + (2.5066282746310005 * series / x).ln()) as f32
}
28
29/// Apply a unary function element-wise to an array.
30fn unary_op<F>(input: &Array, op: Primitive, f: F) -> Array
31where
32 F: Fn(f32) -> f32,
33{
34 assert_eq!(input.dtype(), DType::Float32, "Only Float32 supported");
35 assert_eq!(input.device(), Device::Cpu, "Only CPU supported for now");
36
37 let data = input.to_vec();
38 let result_data: Vec<f32> = data.iter().map(|&x| f(x)).collect();
39 let buffer = Buffer::from_f32(result_data, Device::Cpu);
40
41 let result = Array::from_buffer(buffer, input.shape().clone());
42
43 // Register with trace context if tracing is active
44 if is_tracing() {
45 trace_unary(result.id(), op, input);
46 }
47
48 result
49}
50
51impl Array {
52 /// Negate the array element-wise.
53 ///
54 /// # Examples
55 ///
56 /// ```
57 /// # use jax_rs::{Array, Shape};
58 /// let a = Array::from_vec(vec![1.0, -2.0, 3.0], Shape::new(vec![3]));
59 /// let b = a.neg();
60 /// assert_eq!(b.to_vec(), vec![-1.0, 2.0, -3.0]);
61 /// ```
62 pub fn neg(&self) -> Array {
63 unary_op(self, Primitive::Neg, |x| -x)
64 }
65
66 /// Absolute value element-wise.
67 pub fn abs(&self) -> Array {
68 unary_op(self, Primitive::Abs, |x| x.abs())
69 }
70
71 /// Sine element-wise.
72 pub fn sin(&self) -> Array {
73 unary_op(self, Primitive::Sin, |x| x.sin())
74 }
75
76 /// Cosine element-wise.
77 pub fn cos(&self) -> Array {
78 unary_op(self, Primitive::Cos, |x| x.cos())
79 }
80
81 /// Tangent element-wise.
82 pub fn tan(&self) -> Array {
83 unary_op(self, Primitive::Tan, |x| x.tan())
84 }
85
86 /// Hyperbolic tangent element-wise.
87 pub fn tanh(&self) -> Array {
88 unary_op(self, Primitive::Tanh, |x| x.tanh())
89 }
90
91 /// Natural exponential (e^x) element-wise.
92 pub fn exp(&self) -> Array {
93 unary_op(self, Primitive::Exp, |x| x.exp())
94 }
95
96 /// Natural logarithm element-wise.
97 pub fn log(&self) -> Array {
98 unary_op(self, Primitive::Log, |x| x.ln())
99 }
100
101 /// Square root element-wise.
102 pub fn sqrt(&self) -> Array {
103 unary_op(self, Primitive::Sqrt, |x| x.sqrt())
104 }
105
106 /// Reciprocal (1/x) element-wise.
107 pub fn reciprocal(&self) -> Array {
108 unary_op(self, Primitive::Reciprocal, |x| 1.0 / x)
109 }
110
111 /// Square (x^2) element-wise.
112 pub fn square(&self) -> Array {
113 unary_op(self, Primitive::Square, |x| x * x)
114 }
115
116 /// Sign function element-wise (-1, 0, or 1).
117 pub fn sign(&self) -> Array {
118 unary_op(self, Primitive::Sign, |x| {
119 if x > 0.0 {
120 1.0
121 } else if x < 0.0 {
122 -1.0
123 } else {
124 0.0
125 }
126 })
127 }
128
129 /// Hyperbolic sine element-wise.
130 ///
131 /// # Examples
132 ///
133 /// ```
134 /// # use jax_rs::{Array, Shape};
135 /// let a = Array::from_vec(vec![0.0], Shape::new(vec![1]));
136 /// let b = a.sinh();
137 /// assert_eq!(b.to_vec()[0], 0.0);
138 /// ```
139 pub fn sinh(&self) -> Array {
140 unary_op(self, Primitive::Sin, |x| x.sinh())
141 }
142
143 /// Hyperbolic cosine element-wise.
144 ///
145 /// # Examples
146 ///
147 /// ```
148 /// # use jax_rs::{Array, Shape};
149 /// let a = Array::from_vec(vec![0.0], Shape::new(vec![1]));
150 /// let b = a.cosh();
151 /// assert_eq!(b.to_vec()[0], 1.0);
152 /// ```
153 pub fn cosh(&self) -> Array {
154 unary_op(self, Primitive::Cos, |x| x.cosh())
155 }
156
157 /// Arcsine element-wise.
158 ///
159 /// # Examples
160 ///
161 /// ```
162 /// # use jax_rs::{Array, Shape};
163 /// let a = Array::from_vec(vec![0.0, 1.0], Shape::new(vec![2]));
164 /// let b = a.asin();
165 /// assert!((b.to_vec()[0] - 0.0).abs() < 1e-6);
166 /// assert!((b.to_vec()[1] - std::f32::consts::FRAC_PI_2).abs() < 1e-6);
167 /// ```
168 pub fn asin(&self) -> Array {
169 unary_op(self, Primitive::Sin, |x| x.asin())
170 }
171
172 /// Arccosine element-wise.
173 ///
174 /// # Examples
175 ///
176 /// ```
177 /// # use jax_rs::{Array, Shape};
178 /// let a = Array::from_vec(vec![1.0], Shape::new(vec![1]));
179 /// let b = a.acos();
180 /// assert!((b.to_vec()[0] - 0.0).abs() < 1e-6);
181 /// ```
182 pub fn acos(&self) -> Array {
183 unary_op(self, Primitive::Cos, |x| x.acos())
184 }
185
186 /// Arctangent element-wise.
187 ///
188 /// # Examples
189 ///
190 /// ```
191 /// # use jax_rs::{Array, Shape};
192 /// let a = Array::from_vec(vec![0.0, 1.0], Shape::new(vec![2]));
193 /// let b = a.atan();
194 /// assert!((b.to_vec()[0] - 0.0).abs() < 1e-6);
195 /// assert!((b.to_vec()[1] - std::f32::consts::FRAC_PI_4).abs() < 1e-6);
196 /// ```
197 pub fn atan(&self) -> Array {
198 unary_op(self, Primitive::Tan, |x| x.atan())
199 }
200
201 /// Inverse hyperbolic sine element-wise.
202 ///
203 /// # Examples
204 ///
205 /// ```
206 /// # use jax_rs::{Array, Shape};
207 /// let a = Array::from_vec(vec![0.0], Shape::new(vec![1]));
208 /// let b = a.asinh();
209 /// assert!((b.to_vec()[0] - 0.0).abs() < 1e-6);
210 /// ```
211 pub fn asinh(&self) -> Array {
212 unary_op(self, Primitive::Sin, |x| x.asinh())
213 }
214
215 /// Inverse hyperbolic cosine element-wise.
216 ///
217 /// # Examples
218 ///
219 /// ```
220 /// # use jax_rs::{Array, Shape};
221 /// let a = Array::from_vec(vec![1.0], Shape::new(vec![1]));
222 /// let b = a.acosh();
223 /// assert!((b.to_vec()[0] - 0.0).abs() < 1e-6);
224 /// ```
225 pub fn acosh(&self) -> Array {
226 unary_op(self, Primitive::Cos, |x| x.acosh())
227 }
228
229 /// Inverse hyperbolic tangent element-wise.
230 ///
231 /// # Examples
232 ///
233 /// ```
234 /// # use jax_rs::{Array, Shape};
235 /// let a = Array::from_vec(vec![0.0], Shape::new(vec![1]));
236 /// let b = a.atanh();
237 /// assert!((b.to_vec()[0] - 0.0).abs() < 1e-6);
238 /// ```
239 pub fn atanh(&self) -> Array {
240 unary_op(self, Primitive::Tanh, |x| x.atanh())
241 }
242
243 /// Ceiling function element-wise.
244 ///
245 /// # Examples
246 ///
247 /// ```
248 /// # use jax_rs::{Array, Shape};
249 /// let a = Array::from_vec(vec![1.2, 2.7, -0.5], Shape::new(vec![3]));
250 /// let b = a.ceil();
251 /// assert_eq!(b.to_vec(), vec![2.0, 3.0, 0.0]);
252 /// ```
253 pub fn ceil(&self) -> Array {
254 unary_op(self, Primitive::Sign, |x| x.ceil())
255 }
256
257 /// Floor function element-wise.
258 ///
259 /// # Examples
260 ///
261 /// ```
262 /// # use jax_rs::{Array, Shape};
263 /// let a = Array::from_vec(vec![1.2, 2.7, -0.5], Shape::new(vec![3]));
264 /// let b = a.floor();
265 /// assert_eq!(b.to_vec(), vec![1.0, 2.0, -1.0]);
266 /// ```
267 pub fn floor(&self) -> Array {
268 unary_op(self, Primitive::Sign, |x| x.floor())
269 }
270
271 /// Round to nearest integer element-wise.
272 ///
273 /// # Examples
274 ///
275 /// ```
276 /// # use jax_rs::{Array, Shape};
277 /// let a = Array::from_vec(vec![1.2, 2.7, -0.5], Shape::new(vec![3]));
278 /// let b = a.round();
279 /// assert_eq!(b.to_vec(), vec![1.0, 3.0, -1.0]);
280 /// ```
281 pub fn round(&self) -> Array {
282 unary_op(self, Primitive::Sign, |x| x.round())
283 }
284
285 /// Truncate to integer element-wise (round toward zero).
286 ///
287 /// # Examples
288 ///
289 /// ```
290 /// # use jax_rs::{Array, Shape};
291 /// let a = Array::from_vec(vec![1.7, 2.3, -1.7], Shape::new(vec![3]));
292 /// let b = a.trunc();
293 /// assert_eq!(b.to_vec(), vec![1.0, 2.0, -1.0]);
294 /// ```
295 pub fn trunc(&self) -> Array {
296 unary_op(self, Primitive::Sign, |x| x.trunc())
297 }
298
299 /// Exponential minus 1 (e^x - 1) element-wise.
300 ///
301 /// More accurate than exp(x) - 1 for small values of x.
302 ///
303 /// # Examples
304 ///
305 /// ```
306 /// # use jax_rs::{Array, Shape};
307 /// let a = Array::from_vec(vec![0.0], Shape::new(vec![1]));
308 /// let b = a.expm1();
309 /// assert!((b.to_vec()[0] - 0.0).abs() < 1e-6);
310 /// ```
311 pub fn expm1(&self) -> Array {
312 unary_op(self, Primitive::Exp, |x| x.exp_m1())
313 }
314
315 /// Natural logarithm of 1 + x element-wise.
316 ///
317 /// More accurate than log(1 + x) for small values of x.
318 ///
319 /// # Examples
320 ///
321 /// ```
322 /// # use jax_rs::{Array, Shape};
323 /// let a = Array::from_vec(vec![0.0], Shape::new(vec![1]));
324 /// let b = a.log1p();
325 /// assert!((b.to_vec()[0] - 0.0).abs() < 1e-6);
326 /// ```
327 pub fn log1p(&self) -> Array {
328 unary_op(self, Primitive::Log, |x| x.ln_1p())
329 }
330
331 /// Safe reciprocal that returns 0 where x == 0.
332 ///
333 /// Returns 1/x where x != 0, and 0 where x == 0.
334 ///
335 /// # Examples
336 ///
337 /// ```
338 /// # use jax_rs::{Array, Shape};
339 /// let a = Array::from_vec(vec![2.0, 0.0, 4.0], Shape::new(vec![3]));
340 /// let b = a.reciprocal_no_nan();
341 /// assert_eq!(b.to_vec(), vec![0.5, 0.0, 0.25]);
342 /// ```
343 pub fn reciprocal_no_nan(&self) -> Array {
344 unary_op(self, Primitive::Reciprocal, |x| {
345 if x == 0.0 {
346 0.0
347 } else {
348 1.0 / x
349 }
350 })
351 }
352
353 /// Convert degrees to radians.
354 ///
355 /// # Examples
356 ///
357 /// ```
358 /// # use jax_rs::{Array, Shape};
359 /// let degrees = Array::from_vec(vec![0.0, 90.0, 180.0], Shape::new(vec![3]));
360 /// let radians = degrees.deg2rad();
361 /// assert!((radians.to_vec()[1] - std::f32::consts::PI / 2.0).abs() < 1e-5);
362 /// ```
363 pub fn deg2rad(&self) -> Array {
364 unary_op(self, Primitive::Mul, |x| x * std::f32::consts::PI / 180.0)
365 }
366
367 /// Convert radians to degrees.
368 ///
369 /// # Examples
370 ///
371 /// ```
372 /// # use jax_rs::{Array, Shape};
373 /// let radians = Array::from_vec(vec![0.0, std::f32::consts::PI / 2.0, std::f32::consts::PI], Shape::new(vec![3]));
374 /// let degrees = radians.rad2deg();
375 /// assert!((degrees.to_vec()[1] - 90.0).abs() < 1e-5);
376 /// ```
377 pub fn rad2deg(&self) -> Array {
378 unary_op(self, Primitive::Mul, |x| x * 180.0 / std::f32::consts::PI)
379 }
380
381 /// Compute the sinc function: sin(x) / x.
382 ///
383 /// # Examples
384 ///
385 /// ```
386 /// # use jax_rs::{Array, Shape};
387 /// let x = Array::from_vec(vec![0.0, 1.0, 2.0], Shape::new(vec![3]));
388 /// let y = x.sinc();
389 /// assert_eq!(y.to_vec()[0], 1.0); // sinc(0) = 1
390 /// ```
391 pub fn sinc(&self) -> Array {
392 unary_op(self, Primitive::Sin, |x| {
393 if x.abs() < 1e-10 {
394 1.0
395 } else {
396 x.sin() / x
397 }
398 })
399 }
400
401 /// Compute the cube root.
402 ///
403 /// # Examples
404 ///
405 /// ```
406 /// # use jax_rs::{Array, Shape};
407 /// let a = Array::from_vec(vec![8.0, 27.0, 64.0], Shape::new(vec![3]));
408 /// let b = a.cbrt();
409 /// assert_eq!(b.to_vec(), vec![2.0, 3.0, 4.0]);
410 /// ```
411 pub fn cbrt(&self) -> Array {
412 unary_op(self, Primitive::Pow, |x| x.cbrt())
413 }
414
415 /// Compute the inverse sine (arcsine) element-wise.
416 ///
417 /// Returns values in the range [-π/2, π/2].
418 ///
419 /// # Examples
420 ///
421 /// ```
422 /// # use jax_rs::{Array, Shape};
423 /// let a = Array::from_vec(vec![0.0, 0.5, 1.0], Shape::new(vec![3]));
424 /// let b = a.arcsin();
425 /// // Result: [0.0, ~0.524, ~1.571] (radians)
426 /// ```
427 pub fn arcsin(&self) -> Array {
428 unary_op(self, Primitive::Sin, |x| x.asin())
429 }
430
431 /// Compute the inverse cosine (arccosine) element-wise.
432 ///
433 /// Returns values in the range [0, π].
434 ///
435 /// # Examples
436 ///
437 /// ```
438 /// # use jax_rs::{Array, Shape};
439 /// let a = Array::from_vec(vec![1.0, 0.5, 0.0], Shape::new(vec![3]));
440 /// let b = a.arccos();
441 /// // Result: [0.0, ~1.047, ~1.571] (radians)
442 /// ```
443 pub fn arccos(&self) -> Array {
444 unary_op(self, Primitive::Cos, |x| x.acos())
445 }
446
447 /// Compute the inverse tangent (arctangent) element-wise.
448 ///
449 /// Returns values in the range [-π/2, π/2].
450 ///
451 /// # Examples
452 ///
453 /// ```
454 /// # use jax_rs::{Array, Shape};
455 /// let a = Array::from_vec(vec![0.0, 1.0, -1.0], Shape::new(vec![3]));
456 /// let b = a.arctan();
457 /// // Result: [0.0, ~0.785, ~-0.785] (radians)
458 /// ```
459 pub fn arctan(&self) -> Array {
460 unary_op(self, Primitive::Tan, |x| x.atan())
461 }
462
463 /// Compute the inverse hyperbolic sine element-wise.
464 ///
465 /// # Examples
466 ///
467 /// ```
468 /// # use jax_rs::{Array, Shape};
469 /// let a = Array::from_vec(vec![0.0, 1.0, 2.0], Shape::new(vec![3]));
470 /// let b = a.arcsinh();
471 /// // Result: [0.0, ~0.881, ~1.444]
472 /// ```
473 pub fn arcsinh(&self) -> Array {
474 unary_op(self, Primitive::Sin, |x| x.asinh())
475 }
476
477 /// Compute the inverse hyperbolic cosine element-wise.
478 ///
479 /// # Examples
480 ///
481 /// ```
482 /// # use jax_rs::{Array, Shape};
483 /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
484 /// let b = a.arccosh();
485 /// // Result: [0.0, ~1.317, ~1.763]
486 /// ```
487 pub fn arccosh(&self) -> Array {
488 unary_op(self, Primitive::Cos, |x| x.acosh())
489 }
490
491 /// Compute the inverse hyperbolic tangent element-wise.
492 ///
493 /// # Examples
494 ///
495 /// ```
496 /// # use jax_rs::{Array, Shape};
497 /// let a = Array::from_vec(vec![0.0, 0.5, -0.5], Shape::new(vec![3]));
498 /// let b = a.arctanh();
499 /// // Result: [0.0, ~0.549, ~-0.549]
500 /// ```
501 pub fn arctanh(&self) -> Array {
502 unary_op(self, Primitive::Tan, |x| x.atanh())
503 }
504
505 /// Compute the base-10 logarithm element-wise.
506 ///
507 /// # Examples
508 ///
509 /// ```
510 /// # use jax_rs::{Array, Shape};
511 /// let a = Array::from_vec(vec![1.0, 10.0, 100.0, 1000.0], Shape::new(vec![4]));
512 /// let b = a.log10();
513 /// assert_eq!(b.to_vec(), vec![0.0, 1.0, 2.0, 3.0]);
514 /// ```
515 pub fn log10(&self) -> Array {
516 unary_op(self, Primitive::Log, |x| x.log10())
517 }
518
519 /// Compute the base-2 logarithm element-wise.
520 ///
521 /// # Examples
522 ///
523 /// ```
524 /// # use jax_rs::{Array, Shape};
525 /// let a = Array::from_vec(vec![1.0, 2.0, 4.0, 8.0], Shape::new(vec![4]));
526 /// let b = a.log2();
527 /// assert_eq!(b.to_vec(), vec![0.0, 1.0, 2.0, 3.0]);
528 /// ```
529 pub fn log2(&self) -> Array {
530 unary_op(self, Primitive::Log, |x| x.log2())
531 }
532
533 /// Round to n decimal places.
534 ///
535 /// # Examples
536 ///
537 /// ```
538 /// # use jax_rs::{Array, Shape};
539 /// let a = Array::from_vec(vec![1.234, 5.678, 9.012], Shape::new(vec![3]));
540 /// let b = a.around(1);
541 /// // Result: [1.2, 5.7, 9.0]
542 /// ```
543 pub fn around(&self, decimals: i32) -> Array {
544 let factor = 10_f32.powi(decimals);
545 unary_op(self, Primitive::Mul, move |x| (x * factor).round() / factor)
546 }
547
548 /// Round toward zero (truncate decimal part).
549 ///
550 /// # Examples
551 ///
552 /// ```
553 /// # use jax_rs::{Array, Shape};
554 /// let a = Array::from_vec(vec![1.7, -2.3, 3.9], Shape::new(vec![3]));
555 /// let b = a.fix();
556 /// assert_eq!(b.to_vec(), vec![1.0, -2.0, 3.0]);
557 /// ```
558 pub fn fix(&self) -> Array {
559 unary_op(self, Primitive::Abs, |x| x.trunc())
560 }
561
562 /// Check if sign bit is set (negative number).
563 ///
564 /// Returns 1.0 for negative numbers, 0.0 for positive.
565 ///
566 /// # Examples
567 ///
568 /// ```
569 /// # use jax_rs::{Array, Shape};
570 /// let a = Array::from_vec(vec![1.0, -2.0, 0.0, -0.0], Shape::new(vec![4]));
571 /// let b = a.signbit();
572 /// // Result: [0.0, 1.0, 0.0, 1.0]
573 /// ```
574 pub fn signbit(&self) -> Array {
575 unary_op(self, Primitive::Sign, |x| if x.is_sign_negative() { 1.0 } else { 0.0 })
576 }
577
578 /// Unary positive (identity operation).
579 ///
580 /// # Examples
581 ///
582 /// ```
583 /// # use jax_rs::{Array, Shape};
584 /// let a = Array::from_vec(vec![1.0, -2.0, 3.0], Shape::new(vec![3]));
585 /// let b = a.positive();
586 /// assert_eq!(b.to_vec(), vec![1.0, -2.0, 3.0]);
587 /// ```
588 pub fn positive(&self) -> Array {
589 self.clone()
590 }
591
592 /// Unary negative (same as neg).
593 ///
594 /// # Examples
595 ///
596 /// ```
597 /// # use jax_rs::{Array, Shape};
598 /// let a = Array::from_vec(vec![1.0, -2.0, 3.0], Shape::new(vec![3]));
599 /// let b = a.negative();
600 /// assert_eq!(b.to_vec(), vec![-1.0, 2.0, -3.0]);
601 /// ```
602 pub fn negative(&self) -> Array {
603 self.neg()
604 }
605
606 /// Inverse (1/x) with safe handling of zeros.
607 ///
608 /// Returns infinity for zero values instead of panicking.
609 ///
610 /// # Examples
611 ///
612 /// ```
613 /// # use jax_rs::{Array, Shape};
614 /// let a = Array::from_vec(vec![1.0, 2.0, 4.0, 0.5], Shape::new(vec![4]));
615 /// let b = a.invert();
616 /// assert_eq!(b.to_vec(), vec![1.0, 0.5, 0.25, 2.0]);
617 /// ```
618 pub fn invert(&self) -> Array {
619 self.reciprocal()
620 }
621
622 /// Convert angles from radians to degrees (alias).
623 ///
624 /// # Examples
625 ///
626 /// ```
627 /// # use jax_rs::{Array, Shape};
628 /// let a = Array::from_vec(vec![0.0, std::f32::consts::PI, std::f32::consts::PI * 2.0], Shape::new(vec![3]));
629 /// let b = a.degrees();
630 /// // Result: [0.0, 180.0, 360.0]
631 /// ```
632 pub fn degrees(&self) -> Array {
633 self.rad2deg()
634 }
635
636 /// Convert angles from degrees to radians (alias).
637 ///
638 /// # Examples
639 ///
640 /// ```
641 /// # use jax_rs::{Array, Shape};
642 /// let a = Array::from_vec(vec![0.0, 180.0, 360.0], Shape::new(vec![3]));
643 /// let b = a.radians();
644 /// // Result: [0.0, π, 2π]
645 /// ```
646 pub fn radians(&self) -> Array {
647 self.deg2rad()
648 }
649
650 /// Return the spacing to the next representable float.
651 ///
652 /// For simplicity, returns a constant small value.
653 ///
654 /// # Examples
655 ///
656 /// ```
657 /// # use jax_rs::{Array, Shape};
658 /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
659 /// let b = a.spacing();
660 /// // Returns small epsilon values
661 /// ```
662 pub fn spacing(&self) -> Array {
663 unary_op(self, Primitive::Abs, |x| {
664 let next = f32::from_bits(x.to_bits() + 1);
665 (next - x).abs()
666 })
667 }
668
669 /// Return a copy of the array (alias for clone).
670 ///
671 /// # Examples
672 ///
673 /// ```
674 /// # use jax_rs::{Array, Shape};
675 /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
676 /// let b = a.copy();
677 /// assert_eq!(b.to_vec(), vec![1.0, 2.0, 3.0]);
678 /// ```
679 pub fn copy(&self) -> Array {
680 self.clone()
681 }
682
683 /// Return element-wise natural logarithm (alias for log).
684 ///
685 /// # Examples
686 ///
687 /// ```
688 /// # use jax_rs::{Array, Shape};
689 /// let a = Array::from_vec(vec![1.0, std::f32::consts::E, std::f32::consts::E * std::f32::consts::E], Shape::new(vec![3]));
690 /// let b = a.ln();
691 /// // Result: [0.0, 1.0, 2.0]
692 /// ```
693 pub fn ln(&self) -> Array {
694 self.log()
695 }
696
697 /// Return element-wise maximum with zero.
698 ///
699 /// # Examples
700 ///
701 /// ```
702 /// # use jax_rs::{Array, Shape};
703 /// let a = Array::from_vec(vec![-1.0, 0.0, 1.0, 2.0], Shape::new(vec![4]));
704 /// let b = a.clip_min(0.0);
705 /// assert_eq!(b.to_vec(), vec![0.0, 0.0, 1.0, 2.0]);
706 /// ```
707 pub fn clip_min(&self, min: f32) -> Array {
708 unary_op(self, Primitive::Max, |x| x.max(min))
709 }
710
711 /// Return element-wise minimum with a maximum bound.
712 ///
713 /// # Examples
714 ///
715 /// ```
716 /// # use jax_rs::{Array, Shape};
717 /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![4]));
718 /// let b = a.clip_max(2.5);
719 /// assert_eq!(b.to_vec(), vec![1.0, 2.0, 2.5, 2.5]);
720 /// ```
721 pub fn clip_max(&self, max: f32) -> Array {
722 unary_op(self, Primitive::Min, |x| x.min(max))
723 }
724
725 /// Return the conjugate of the array (identity for real numbers).
726 ///
727 /// # Examples
728 ///
729 /// ```
730 /// # use jax_rs::{Array, Shape};
731 /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
732 /// let b = a.conj();
733 /// assert_eq!(b.to_vec(), vec![1.0, 2.0, 3.0]);
734 /// ```
735 pub fn conj(&self) -> Array {
736 self.clone()
737 }
738
739 /// Return the conjugate (alias for conj).
740 ///
741 /// # Examples
742 ///
743 /// ```
744 /// # use jax_rs::{Array, Shape};
745 /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
746 /// let b = a.conjugate();
747 /// assert_eq!(b.to_vec(), vec![1.0, 2.0, 3.0]);
748 /// ```
749 pub fn conjugate(&self) -> Array {
750 self.clone()
751 }
752
753 /// Return the angle of complex numbers (phase).
754 /// For real numbers, returns 0 for positive, PI for negative.
755 ///
756 /// # Examples
757 ///
758 /// ```
759 /// # use jax_rs::{Array, Shape};
760 /// let a = Array::from_vec(vec![1.0, -1.0, 0.0], Shape::new(vec![3]));
761 /// let angles = a.angle();
762 /// // Positive: 0, Negative: PI, Zero: 0
763 /// ```
764 pub fn angle(&self) -> Array {
765 unary_op(self, Primitive::Sign, |x| {
766 if x > 0.0 {
767 0.0
768 } else if x < 0.0 {
769 std::f32::consts::PI
770 } else {
771 0.0
772 }
773 })
774 }
775
776 /// Return the real part of complex numbers.
777 /// For real arrays, this is the identity function.
778 ///
779 /// # Examples
780 ///
781 /// ```
782 /// # use jax_rs::{Array, Shape};
783 /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
784 /// let r = a.real();
785 /// assert_eq!(r.to_vec(), vec![1.0, 2.0, 3.0]);
786 /// ```
787 pub fn real(&self) -> Array {
788 self.clone()
789 }
790
791 /// Return the imaginary part of complex numbers.
792 /// For real arrays, returns zeros.
793 ///
794 /// # Examples
795 ///
796 /// ```
797 /// # use jax_rs::{Array, Shape};
798 /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
799 /// let im = a.imag();
800 /// assert_eq!(im.to_vec(), vec![0.0, 0.0, 0.0]);
801 /// ```
802 pub fn imag(&self) -> Array {
803 Array::zeros(self.shape().clone(), DType::Float32)
804 }
805
806 /// Bitwise NOT operation.
807 /// Inverts all bits in the bit representation of Float32 values.
808 ///
809 /// # Examples
810 ///
811 /// ```
812 /// # use jax_rs::{Array, Shape};
813 /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
814 /// let b = a.bitwise_not();
815 /// ```
816 pub fn bitwise_not(&self) -> Array {
817 unary_op(self, Primitive::Neg, |x| {
818 let bits = x.to_bits();
819 f32::from_bits(!bits)
820 })
821 }
822
823 /// Return the reciprocal of the square root (1/sqrt(x)).
824 ///
825 /// # Examples
826 ///
827 /// ```
828 /// # use jax_rs::{Array, Shape};
829 /// let a = Array::from_vec(vec![1.0, 4.0, 9.0], Shape::new(vec![3]));
830 /// let b = a.rsqrt();
831 /// assert!((b.to_vec()[0] - 1.0).abs() < 1e-6);
832 /// assert!((b.to_vec()[1] - 0.5).abs() < 1e-6);
833 /// ```
834 pub fn rsqrt(&self) -> Array {
835 unary_op(self, Primitive::Sqrt, |x| 1.0 / x.sqrt())
836 }
837
838 /// Return the fractional and integer parts of an array element-wise.
839 /// Returns a tuple of (fractional_part, integer_part).
840 ///
841 /// # Examples
842 ///
843 /// ```
844 /// # use jax_rs::{Array, Shape};
845 /// let a = Array::from_vec(vec![1.5, 2.7, -3.2], Shape::new(vec![3]));
846 /// let (frac, int) = a.modf();
847 /// assert!((frac.to_vec()[0] - 0.5).abs() < 1e-6);
848 /// assert!((int.to_vec()[0] - 1.0).abs() < 1e-6);
849 /// ```
850 pub fn modf(&self) -> (Array, Array) {
851 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
852
853 let data = self.to_vec();
854 let frac_data: Vec<f32> = data.iter().map(|&x| x.fract()).collect();
855 let int_data: Vec<f32> = data.iter().map(|&x| x.trunc()).collect();
856
857 let frac = Array::from_vec(frac_data, self.shape().clone());
858 let int = Array::from_vec(int_data, self.shape().clone());
859
860 (frac, int)
861 }
862
863 /// Compute x * 2^exp for each element.
864 /// Equivalent to ldexp function from C math library.
865 ///
866 /// # Examples
867 ///
868 /// ```
869 /// # use jax_rs::{Array, Shape};
870 /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
871 /// let b = a.ldexp(2);
872 /// assert_eq!(b.to_vec(), vec![4.0, 8.0, 12.0]); // multiply by 2^2 = 4
873 /// ```
874 pub fn ldexp(&self, exp: i32) -> Array {
875 let multiplier = 2_f32.powi(exp);
876 unary_op(self, Primitive::Mul, move |x| x * multiplier)
877 }
878
879 /// Decompose x into mantissa and exponent: x = m * 2^e.
880 /// Returns (mantissa, exponent) where mantissa is in [0.5, 1.0).
881 ///
882 /// # Examples
883 ///
884 /// ```
885 /// # use jax_rs::{Array, Shape};
886 /// let a = Array::from_vec(vec![4.0, 8.0, 0.5], Shape::new(vec![3]));
887 /// let (mantissa, exp) = a.frexp();
888 /// // 4.0 = 0.5 * 2^3, 8.0 = 0.5 * 2^4, 0.5 = 0.5 * 2^0
889 /// ```
890 pub fn frexp(&self) -> (Array, Array) {
891 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
892
893 let data = self.to_vec();
894 let mut mantissa_data = Vec::with_capacity(data.len());
895 let mut exp_data = Vec::with_capacity(data.len());
896
897 for &x in &data {
898 if x == 0.0 {
899 mantissa_data.push(0.0);
900 exp_data.push(0.0);
901 } else {
902 let bits = x.to_bits();
903 let sign = (bits >> 31) & 1;
904 let exponent = ((bits >> 23) & 0xFF) as i32 - 126;
905 // Create mantissa in [0.5, 1.0)
906 let mantissa_bits = (sign << 31) | (126 << 23) | (bits & 0x7FFFFF);
907 let mantissa = f32::from_bits(mantissa_bits);
908 mantissa_data.push(mantissa);
909 exp_data.push(exponent as f32);
910 }
911 }
912
913 let mantissa = Array::from_vec(mantissa_data, self.shape().clone());
914 let exp = Array::from_vec(exp_data, self.shape().clone());
915
916 (mantissa, exp)
917 }
918
919 /// Divide arrays element-wise with safe handling of division by zero.
920 /// Returns 0 when dividing by zero instead of NaN/Inf.
921 ///
922 /// # Examples
923 ///
924 /// ```
925 /// # use jax_rs::{Array, Shape};
926 /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
927 /// let b = a.safe_divide_scalar(2.0);
928 /// assert_eq!(b.to_vec(), vec![0.5, 1.0, 1.5]);
929 /// let c = a.safe_divide_scalar(0.0);
930 /// assert_eq!(c.to_vec(), vec![0.0, 0.0, 0.0]); // Returns 0 instead of Inf
931 /// ```
932 pub fn safe_divide_scalar(&self, divisor: f32) -> Array {
933 if divisor == 0.0 {
934 Array::zeros(self.shape().clone(), DType::Float32)
935 } else {
936 unary_op(self, Primitive::Reciprocal, move |x| x / divisor)
937 }
938 }
939
940 /// Compute the modified Bessel function of the first kind, order 0.
941 /// Approximation using polynomial expansion.
942 ///
943 /// # Examples
944 ///
945 /// ```
946 /// # use jax_rs::{Array, Shape};
947 /// let a = Array::from_vec(vec![0.0, 1.0, 2.0], Shape::new(vec![3]));
948 /// let b = a.i0();
949 /// assert!((b.to_vec()[0] - 1.0).abs() < 1e-4); // i0(0) = 1
950 /// ```
951 pub fn i0(&self) -> Array {
952 unary_op(self, Primitive::Exp, |x| {
953 // Polynomial approximation for I0
954 let ax = x.abs();
955 if ax < 3.75 {
956 let y = (x / 3.75).powi(2);
957 1.0 + y * (3.5156229 + y * (3.0899424 + y * (1.2067492 + y * (0.2659732 + y * (0.0360768 + y * 0.0045813)))))
958 } else {
959 let y = 3.75 / ax;
960 (ax.exp() / ax.sqrt()) * (0.398_942_3 + y * (0.01328592 + y * (0.00225319 + y * (-0.00157565 + y * (0.00916281 + y * (-0.02057706 + y * (0.02635537 + y * (-0.01647633 + y * 0.00392377))))))))
961 }
962 })
963 }
964
965 /// Compute the natural logarithm of the absolute value of the gamma function.
966 ///
967 /// # Examples
968 ///
969 /// ```
970 /// # use jax_rs::{Array, Shape};
971 /// let a = Array::from_vec(vec![1.0, 2.0, 5.0], Shape::new(vec![3]));
972 /// let b = a.lgamma();
973 /// assert!((b.to_vec()[0]).abs() < 1e-6); // lgamma(1) = 0
974 /// assert!((b.to_vec()[1]).abs() < 1e-6); // lgamma(2) = 0
975 /// ```
976 pub fn lgamma(&self) -> Array {
977 unary_op(self, Primitive::Log, |x| {
978 // Stirling's approximation for larger values
979 if x <= 0.0 {
980 f32::INFINITY
981 } else if x < 7.0 {
982 // Use recurrence relation for small values
983 let n = (7.0 - x).ceil() as i32;
984 let mut y = x;
985 let mut prod = 1.0;
986 for _ in 0..n {
987 prod *= y;
988 y += 1.0;
989 }
990 lgamma_impl(y) - prod.ln()
991 } else {
992 lgamma_impl(x)
993 }
994 })
995 }
996
997}
998
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_neg() {
        let input = Array::from_vec(vec![1.0, -2.0, 3.0], Shape::new(vec![3]));
        let negated = input.neg().to_vec();
        assert_eq!(negated, vec![-1.0, 2.0, -3.0]);
    }

    #[test]
    fn test_abs() {
        let input = Array::from_vec(vec![1.0, -2.0, 3.0, -4.0], Shape::new(vec![4]));
        let magnitudes = input.abs().to_vec();
        assert_eq!(magnitudes, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_sin_cos() {
        let angles = Array::from_vec(
            vec![0.0, std::f32::consts::PI / 2.0],
            Shape::new(vec![2]),
        );
        let sines = angles.sin().to_vec();
        let cosines = angles.cos().to_vec();

        // sin(0) = 0, sin(pi/2) = 1; cos(0) = 1, cos(pi/2) = 0.
        let checks: [(f32, f32); 4] = [
            (sines[0], 0.0),
            (sines[1], 1.0),
            (cosines[0], 1.0),
            (cosines[1], 0.0),
        ];
        for &(got, want) in checks.iter() {
            assert_abs_diff_eq!(got, want, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_exp_log() {
        let input = Array::from_vec(vec![0.0, 1.0, 2.0], Shape::new(vec![3]));
        let exponentiated = input.exp();
        let round_trip = exponentiated.log();

        let exp_vals = exponentiated.to_vec();
        assert_abs_diff_eq!(exp_vals[0], 1.0, epsilon = 1e-6);
        assert_abs_diff_eq!(exp_vals[1], std::f32::consts::E, epsilon = 1e-6);

        // log(exp(x)) recovers x.
        let expected = [0.0_f32, 1.0, 2.0];
        for (got, want) in round_trip.to_vec().into_iter().zip(expected.iter()) {
            assert_abs_diff_eq!(got, *want, epsilon = 1e-5);
        }
    }

    #[test]
    fn test_sqrt() {
        let squares = Array::from_vec(vec![0.0, 1.0, 4.0, 9.0], Shape::new(vec![4]));
        assert_eq!(squares.sqrt().to_vec(), vec![0.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_tanh() {
        let input = Array::from_vec(vec![0.0, 1.0], Shape::new(vec![2]));
        let out = input.tanh().to_vec();
        assert_abs_diff_eq!(out[0], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(out[1], 0.761_594_2, epsilon = 1e-6);
    }

    #[test]
    fn test_reciprocal() {
        let input = Array::from_vec(vec![1.0, 2.0, 4.0], Shape::new(vec![3]));
        let expected = [1.0_f32, 0.5, 0.25];
        for (got, want) in input.reciprocal().to_vec().into_iter().zip(expected.iter()) {
            assert_abs_diff_eq!(got, *want, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_reciprocal_no_nan() {
        let input = Array::from_vec(vec![2.0, 0.0, 4.0], Shape::new(vec![3]));
        let safe = input.reciprocal_no_nan().to_vec();
        assert_eq!(safe, vec![0.5, 0.0, 0.25]);
    }

    #[test]
    fn test_square() {
        let input = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
        assert_eq!(input.square().to_vec(), vec![1.0, 4.0, 9.0]);
    }

    #[test]
    fn test_sign() {
        // Both signed zeros map to a value equal to 0.0.
        let input = Array::from_vec(vec![-2.0, -0.0, 0.0, 3.0], Shape::new(vec![4]));
        let signs = input.sign().to_vec();
        assert_eq!(signs, vec![-1.0, 0.0, 0.0, 1.0]);
    }
}