jax_rs/ops/
unary.rs

1//! Unary operations on arrays.
2
3use crate::trace::{is_tracing, trace_unary, Primitive};
4use crate::{buffer::Buffer, Array, DType, Device};
5
6#[cfg(test)]
7use crate::Shape;
8
/// Approximates `ln Γ(x)` for positive `x` with a six-term Lanczos-style series.
///
/// The series is evaluated in `f64` and only the final value is rounded to
/// `f32`. Within this file it is only ever called with arguments `>= 7`.
fn lgamma_impl(x: f32) -> f32 {
    // Series coefficients; term `k` (1-based) is divided by `x + k`.
    const SERIES: [f64; 6] = [
        76.18009172947146,
        -86.50532032941677,
        24.01409824083091,
        -1.231739572450155,
        0.1208650973866179e-2,
        -0.5395239384953e-5,
    ];
    // sqrt(2π)
    const SQRT_TWO_PI: f64 = 2.5066282746310005;

    let xd = f64::from(x);
    let shifted = xd + 5.5;
    // (x + 0.5)·ln(x + 5.5) − (x + 5.5)
    let leading = (xd + 0.5) * shifted.ln() - shifted;

    let series = SERIES
        .iter()
        .zip(1_i32..)
        .fold(1.000000000190015_f64, |acc, (&coeff, k)| {
            acc + coeff / (xd + f64::from(k))
        });

    (leading + (SQRT_TWO_PI * series / xd).ln()) as f32
}
28
29/// Apply a unary function element-wise to an array.
30fn unary_op<F>(input: &Array, op: Primitive, f: F) -> Array
31where
32    F: Fn(f32) -> f32,
33{
34    assert_eq!(input.dtype(), DType::Float32, "Only Float32 supported");
35    assert_eq!(input.device(), Device::Cpu, "Only CPU supported for now");
36
37    let data = input.to_vec();
38    let result_data: Vec<f32> = data.iter().map(|&x| f(x)).collect();
39    let buffer = Buffer::from_f32(result_data, Device::Cpu);
40
41    let result = Array::from_buffer(buffer, input.shape().clone());
42
43    // Register with trace context if tracing is active
44    if is_tracing() {
45        trace_unary(result.id(), op, input);
46    }
47
48    result
49}
50
51impl Array {
52    /// Negate the array element-wise.
53    ///
54    /// # Examples
55    ///
56    /// ```
57    /// # use jax_rs::{Array, Shape};
58    /// let a = Array::from_vec(vec![1.0, -2.0, 3.0], Shape::new(vec![3]));
59    /// let b = a.neg();
60    /// assert_eq!(b.to_vec(), vec![-1.0, 2.0, -3.0]);
61    /// ```
62    pub fn neg(&self) -> Array {
63        unary_op(self, Primitive::Neg, |x| -x)
64    }
65
66    /// Absolute value element-wise.
67    pub fn abs(&self) -> Array {
68        unary_op(self, Primitive::Abs, |x| x.abs())
69    }
70
71    /// Sine element-wise.
72    pub fn sin(&self) -> Array {
73        unary_op(self, Primitive::Sin, |x| x.sin())
74    }
75
76    /// Cosine element-wise.
77    pub fn cos(&self) -> Array {
78        unary_op(self, Primitive::Cos, |x| x.cos())
79    }
80
81    /// Tangent element-wise.
82    pub fn tan(&self) -> Array {
83        unary_op(self, Primitive::Tan, |x| x.tan())
84    }
85
86    /// Hyperbolic tangent element-wise.
87    pub fn tanh(&self) -> Array {
88        unary_op(self, Primitive::Tanh, |x| x.tanh())
89    }
90
91    /// Natural exponential (e^x) element-wise.
92    pub fn exp(&self) -> Array {
93        unary_op(self, Primitive::Exp, |x| x.exp())
94    }
95
96    /// Natural logarithm element-wise.
97    pub fn log(&self) -> Array {
98        unary_op(self, Primitive::Log, |x| x.ln())
99    }
100
101    /// Square root element-wise.
102    pub fn sqrt(&self) -> Array {
103        unary_op(self, Primitive::Sqrt, |x| x.sqrt())
104    }
105
106    /// Reciprocal (1/x) element-wise.
107    pub fn reciprocal(&self) -> Array {
108        unary_op(self, Primitive::Reciprocal, |x| 1.0 / x)
109    }
110
111    /// Square (x^2) element-wise.
112    pub fn square(&self) -> Array {
113        unary_op(self, Primitive::Square, |x| x * x)
114    }
115
116    /// Sign function element-wise (-1, 0, or 1).
117    pub fn sign(&self) -> Array {
118        unary_op(self, Primitive::Sign, |x| {
119            if x > 0.0 {
120                1.0
121            } else if x < 0.0 {
122                -1.0
123            } else {
124                0.0
125            }
126        })
127    }
128
129    /// Hyperbolic sine element-wise.
130    ///
131    /// # Examples
132    ///
133    /// ```
134    /// # use jax_rs::{Array, Shape};
135    /// let a = Array::from_vec(vec![0.0], Shape::new(vec![1]));
136    /// let b = a.sinh();
137    /// assert_eq!(b.to_vec()[0], 0.0);
138    /// ```
139    pub fn sinh(&self) -> Array {
140        unary_op(self, Primitive::Sin, |x| x.sinh())
141    }
142
143    /// Hyperbolic cosine element-wise.
144    ///
145    /// # Examples
146    ///
147    /// ```
148    /// # use jax_rs::{Array, Shape};
149    /// let a = Array::from_vec(vec![0.0], Shape::new(vec![1]));
150    /// let b = a.cosh();
151    /// assert_eq!(b.to_vec()[0], 1.0);
152    /// ```
153    pub fn cosh(&self) -> Array {
154        unary_op(self, Primitive::Cos, |x| x.cosh())
155    }
156
157    /// Arcsine element-wise.
158    ///
159    /// # Examples
160    ///
161    /// ```
162    /// # use jax_rs::{Array, Shape};
163    /// let a = Array::from_vec(vec![0.0, 1.0], Shape::new(vec![2]));
164    /// let b = a.asin();
165    /// assert!((b.to_vec()[0] - 0.0).abs() < 1e-6);
166    /// assert!((b.to_vec()[1] - std::f32::consts::FRAC_PI_2).abs() < 1e-6);
167    /// ```
168    pub fn asin(&self) -> Array {
169        unary_op(self, Primitive::Sin, |x| x.asin())
170    }
171
172    /// Arccosine element-wise.
173    ///
174    /// # Examples
175    ///
176    /// ```
177    /// # use jax_rs::{Array, Shape};
178    /// let a = Array::from_vec(vec![1.0], Shape::new(vec![1]));
179    /// let b = a.acos();
180    /// assert!((b.to_vec()[0] - 0.0).abs() < 1e-6);
181    /// ```
182    pub fn acos(&self) -> Array {
183        unary_op(self, Primitive::Cos, |x| x.acos())
184    }
185
186    /// Arctangent element-wise.
187    ///
188    /// # Examples
189    ///
190    /// ```
191    /// # use jax_rs::{Array, Shape};
192    /// let a = Array::from_vec(vec![0.0, 1.0], Shape::new(vec![2]));
193    /// let b = a.atan();
194    /// assert!((b.to_vec()[0] - 0.0).abs() < 1e-6);
195    /// assert!((b.to_vec()[1] - std::f32::consts::FRAC_PI_4).abs() < 1e-6);
196    /// ```
197    pub fn atan(&self) -> Array {
198        unary_op(self, Primitive::Tan, |x| x.atan())
199    }
200
201    /// Inverse hyperbolic sine element-wise.
202    ///
203    /// # Examples
204    ///
205    /// ```
206    /// # use jax_rs::{Array, Shape};
207    /// let a = Array::from_vec(vec![0.0], Shape::new(vec![1]));
208    /// let b = a.asinh();
209    /// assert!((b.to_vec()[0] - 0.0).abs() < 1e-6);
210    /// ```
211    pub fn asinh(&self) -> Array {
212        unary_op(self, Primitive::Sin, |x| x.asinh())
213    }
214
215    /// Inverse hyperbolic cosine element-wise.
216    ///
217    /// # Examples
218    ///
219    /// ```
220    /// # use jax_rs::{Array, Shape};
221    /// let a = Array::from_vec(vec![1.0], Shape::new(vec![1]));
222    /// let b = a.acosh();
223    /// assert!((b.to_vec()[0] - 0.0).abs() < 1e-6);
224    /// ```
225    pub fn acosh(&self) -> Array {
226        unary_op(self, Primitive::Cos, |x| x.acosh())
227    }
228
229    /// Inverse hyperbolic tangent element-wise.
230    ///
231    /// # Examples
232    ///
233    /// ```
234    /// # use jax_rs::{Array, Shape};
235    /// let a = Array::from_vec(vec![0.0], Shape::new(vec![1]));
236    /// let b = a.atanh();
237    /// assert!((b.to_vec()[0] - 0.0).abs() < 1e-6);
238    /// ```
239    pub fn atanh(&self) -> Array {
240        unary_op(self, Primitive::Tanh, |x| x.atanh())
241    }
242
243    /// Ceiling function element-wise.
244    ///
245    /// # Examples
246    ///
247    /// ```
248    /// # use jax_rs::{Array, Shape};
249    /// let a = Array::from_vec(vec![1.2, 2.7, -0.5], Shape::new(vec![3]));
250    /// let b = a.ceil();
251    /// assert_eq!(b.to_vec(), vec![2.0, 3.0, 0.0]);
252    /// ```
253    pub fn ceil(&self) -> Array {
254        unary_op(self, Primitive::Sign, |x| x.ceil())
255    }
256
257    /// Floor function element-wise.
258    ///
259    /// # Examples
260    ///
261    /// ```
262    /// # use jax_rs::{Array, Shape};
263    /// let a = Array::from_vec(vec![1.2, 2.7, -0.5], Shape::new(vec![3]));
264    /// let b = a.floor();
265    /// assert_eq!(b.to_vec(), vec![1.0, 2.0, -1.0]);
266    /// ```
267    pub fn floor(&self) -> Array {
268        unary_op(self, Primitive::Sign, |x| x.floor())
269    }
270
271    /// Round to nearest integer element-wise.
272    ///
273    /// # Examples
274    ///
275    /// ```
276    /// # use jax_rs::{Array, Shape};
277    /// let a = Array::from_vec(vec![1.2, 2.7, -0.5], Shape::new(vec![3]));
278    /// let b = a.round();
279    /// assert_eq!(b.to_vec(), vec![1.0, 3.0, -1.0]);
280    /// ```
281    pub fn round(&self) -> Array {
282        unary_op(self, Primitive::Sign, |x| x.round())
283    }
284
285    /// Truncate to integer element-wise (round toward zero).
286    ///
287    /// # Examples
288    ///
289    /// ```
290    /// # use jax_rs::{Array, Shape};
291    /// let a = Array::from_vec(vec![1.7, 2.3, -1.7], Shape::new(vec![3]));
292    /// let b = a.trunc();
293    /// assert_eq!(b.to_vec(), vec![1.0, 2.0, -1.0]);
294    /// ```
295    pub fn trunc(&self) -> Array {
296        unary_op(self, Primitive::Sign, |x| x.trunc())
297    }
298
299    /// Exponential minus 1 (e^x - 1) element-wise.
300    ///
301    /// More accurate than exp(x) - 1 for small values of x.
302    ///
303    /// # Examples
304    ///
305    /// ```
306    /// # use jax_rs::{Array, Shape};
307    /// let a = Array::from_vec(vec![0.0], Shape::new(vec![1]));
308    /// let b = a.expm1();
309    /// assert!((b.to_vec()[0] - 0.0).abs() < 1e-6);
310    /// ```
311    pub fn expm1(&self) -> Array {
312        unary_op(self, Primitive::Exp, |x| x.exp_m1())
313    }
314
315    /// Natural logarithm of 1 + x element-wise.
316    ///
317    /// More accurate than log(1 + x) for small values of x.
318    ///
319    /// # Examples
320    ///
321    /// ```
322    /// # use jax_rs::{Array, Shape};
323    /// let a = Array::from_vec(vec![0.0], Shape::new(vec![1]));
324    /// let b = a.log1p();
325    /// assert!((b.to_vec()[0] - 0.0).abs() < 1e-6);
326    /// ```
327    pub fn log1p(&self) -> Array {
328        unary_op(self, Primitive::Log, |x| x.ln_1p())
329    }
330
331    /// Safe reciprocal that returns 0 where x == 0.
332    ///
333    /// Returns 1/x where x != 0, and 0 where x == 0.
334    ///
335    /// # Examples
336    ///
337    /// ```
338    /// # use jax_rs::{Array, Shape};
339    /// let a = Array::from_vec(vec![2.0, 0.0, 4.0], Shape::new(vec![3]));
340    /// let b = a.reciprocal_no_nan();
341    /// assert_eq!(b.to_vec(), vec![0.5, 0.0, 0.25]);
342    /// ```
343    pub fn reciprocal_no_nan(&self) -> Array {
344        unary_op(self, Primitive::Reciprocal, |x| {
345            if x == 0.0 {
346                0.0
347            } else {
348                1.0 / x
349            }
350        })
351    }
352
353    /// Convert degrees to radians.
354    ///
355    /// # Examples
356    ///
357    /// ```
358    /// # use jax_rs::{Array, Shape};
359    /// let degrees = Array::from_vec(vec![0.0, 90.0, 180.0], Shape::new(vec![3]));
360    /// let radians = degrees.deg2rad();
361    /// assert!((radians.to_vec()[1] - std::f32::consts::PI / 2.0).abs() < 1e-5);
362    /// ```
363    pub fn deg2rad(&self) -> Array {
364        unary_op(self, Primitive::Mul, |x| x * std::f32::consts::PI / 180.0)
365    }
366
367    /// Convert radians to degrees.
368    ///
369    /// # Examples
370    ///
371    /// ```
372    /// # use jax_rs::{Array, Shape};
373    /// let radians = Array::from_vec(vec![0.0, std::f32::consts::PI / 2.0, std::f32::consts::PI], Shape::new(vec![3]));
374    /// let degrees = radians.rad2deg();
375    /// assert!((degrees.to_vec()[1] - 90.0).abs() < 1e-5);
376    /// ```
377    pub fn rad2deg(&self) -> Array {
378        unary_op(self, Primitive::Mul, |x| x * 180.0 / std::f32::consts::PI)
379    }
380
381    /// Compute the sinc function: sin(x) / x.
382    ///
383    /// # Examples
384    ///
385    /// ```
386    /// # use jax_rs::{Array, Shape};
387    /// let x = Array::from_vec(vec![0.0, 1.0, 2.0], Shape::new(vec![3]));
388    /// let y = x.sinc();
389    /// assert_eq!(y.to_vec()[0], 1.0); // sinc(0) = 1
390    /// ```
391    pub fn sinc(&self) -> Array {
392        unary_op(self, Primitive::Sin, |x| {
393            if x.abs() < 1e-10 {
394                1.0
395            } else {
396                x.sin() / x
397            }
398        })
399    }
400
401    /// Compute the cube root.
402    ///
403    /// # Examples
404    ///
405    /// ```
406    /// # use jax_rs::{Array, Shape};
407    /// let a = Array::from_vec(vec![8.0, 27.0, 64.0], Shape::new(vec![3]));
408    /// let b = a.cbrt();
409    /// assert_eq!(b.to_vec(), vec![2.0, 3.0, 4.0]);
410    /// ```
411    pub fn cbrt(&self) -> Array {
412        unary_op(self, Primitive::Pow, |x| x.cbrt())
413    }
414
415    /// Compute the inverse sine (arcsine) element-wise.
416    ///
417    /// Returns values in the range [-π/2, π/2].
418    ///
419    /// # Examples
420    ///
421    /// ```
422    /// # use jax_rs::{Array, Shape};
423    /// let a = Array::from_vec(vec![0.0, 0.5, 1.0], Shape::new(vec![3]));
424    /// let b = a.arcsin();
425    /// // Result: [0.0, ~0.524, ~1.571] (radians)
426    /// ```
427    pub fn arcsin(&self) -> Array {
428        unary_op(self, Primitive::Sin, |x| x.asin())
429    }
430
431    /// Compute the inverse cosine (arccosine) element-wise.
432    ///
433    /// Returns values in the range [0, π].
434    ///
435    /// # Examples
436    ///
437    /// ```
438    /// # use jax_rs::{Array, Shape};
439    /// let a = Array::from_vec(vec![1.0, 0.5, 0.0], Shape::new(vec![3]));
440    /// let b = a.arccos();
441    /// // Result: [0.0, ~1.047, ~1.571] (radians)
442    /// ```
443    pub fn arccos(&self) -> Array {
444        unary_op(self, Primitive::Cos, |x| x.acos())
445    }
446
447    /// Compute the inverse tangent (arctangent) element-wise.
448    ///
449    /// Returns values in the range [-π/2, π/2].
450    ///
451    /// # Examples
452    ///
453    /// ```
454    /// # use jax_rs::{Array, Shape};
455    /// let a = Array::from_vec(vec![0.0, 1.0, -1.0], Shape::new(vec![3]));
456    /// let b = a.arctan();
457    /// // Result: [0.0, ~0.785, ~-0.785] (radians)
458    /// ```
459    pub fn arctan(&self) -> Array {
460        unary_op(self, Primitive::Tan, |x| x.atan())
461    }
462
463    /// Compute the inverse hyperbolic sine element-wise.
464    ///
465    /// # Examples
466    ///
467    /// ```
468    /// # use jax_rs::{Array, Shape};
469    /// let a = Array::from_vec(vec![0.0, 1.0, 2.0], Shape::new(vec![3]));
470    /// let b = a.arcsinh();
471    /// // Result: [0.0, ~0.881, ~1.444]
472    /// ```
473    pub fn arcsinh(&self) -> Array {
474        unary_op(self, Primitive::Sin, |x| x.asinh())
475    }
476
477    /// Compute the inverse hyperbolic cosine element-wise.
478    ///
479    /// # Examples
480    ///
481    /// ```
482    /// # use jax_rs::{Array, Shape};
483    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
484    /// let b = a.arccosh();
485    /// // Result: [0.0, ~1.317, ~1.763]
486    /// ```
487    pub fn arccosh(&self) -> Array {
488        unary_op(self, Primitive::Cos, |x| x.acosh())
489    }
490
491    /// Compute the inverse hyperbolic tangent element-wise.
492    ///
493    /// # Examples
494    ///
495    /// ```
496    /// # use jax_rs::{Array, Shape};
497    /// let a = Array::from_vec(vec![0.0, 0.5, -0.5], Shape::new(vec![3]));
498    /// let b = a.arctanh();
499    /// // Result: [0.0, ~0.549, ~-0.549]
500    /// ```
501    pub fn arctanh(&self) -> Array {
502        unary_op(self, Primitive::Tan, |x| x.atanh())
503    }
504
505    /// Compute the base-10 logarithm element-wise.
506    ///
507    /// # Examples
508    ///
509    /// ```
510    /// # use jax_rs::{Array, Shape};
511    /// let a = Array::from_vec(vec![1.0, 10.0, 100.0, 1000.0], Shape::new(vec![4]));
512    /// let b = a.log10();
513    /// assert_eq!(b.to_vec(), vec![0.0, 1.0, 2.0, 3.0]);
514    /// ```
515    pub fn log10(&self) -> Array {
516        unary_op(self, Primitive::Log, |x| x.log10())
517    }
518
519    /// Compute the base-2 logarithm element-wise.
520    ///
521    /// # Examples
522    ///
523    /// ```
524    /// # use jax_rs::{Array, Shape};
525    /// let a = Array::from_vec(vec![1.0, 2.0, 4.0, 8.0], Shape::new(vec![4]));
526    /// let b = a.log2();
527    /// assert_eq!(b.to_vec(), vec![0.0, 1.0, 2.0, 3.0]);
528    /// ```
529    pub fn log2(&self) -> Array {
530        unary_op(self, Primitive::Log, |x| x.log2())
531    }
532
533    /// Round to n decimal places.
534    ///
535    /// # Examples
536    ///
537    /// ```
538    /// # use jax_rs::{Array, Shape};
539    /// let a = Array::from_vec(vec![1.234, 5.678, 9.012], Shape::new(vec![3]));
540    /// let b = a.around(1);
541    /// // Result: [1.2, 5.7, 9.0]
542    /// ```
543    pub fn around(&self, decimals: i32) -> Array {
544        let factor = 10_f32.powi(decimals);
545        unary_op(self, Primitive::Mul, move |x| (x * factor).round() / factor)
546    }
547
548    /// Round toward zero (truncate decimal part).
549    ///
550    /// # Examples
551    ///
552    /// ```
553    /// # use jax_rs::{Array, Shape};
554    /// let a = Array::from_vec(vec![1.7, -2.3, 3.9], Shape::new(vec![3]));
555    /// let b = a.fix();
556    /// assert_eq!(b.to_vec(), vec![1.0, -2.0, 3.0]);
557    /// ```
558    pub fn fix(&self) -> Array {
559        unary_op(self, Primitive::Abs, |x| x.trunc())
560    }
561
562    /// Check if sign bit is set (negative number).
563    ///
564    /// Returns 1.0 for negative numbers, 0.0 for positive.
565    ///
566    /// # Examples
567    ///
568    /// ```
569    /// # use jax_rs::{Array, Shape};
570    /// let a = Array::from_vec(vec![1.0, -2.0, 0.0, -0.0], Shape::new(vec![4]));
571    /// let b = a.signbit();
572    /// // Result: [0.0, 1.0, 0.0, 1.0]
573    /// ```
574    pub fn signbit(&self) -> Array {
575        unary_op(self, Primitive::Sign, |x| if x.is_sign_negative() { 1.0 } else { 0.0 })
576    }
577
578    /// Unary positive (identity operation).
579    ///
580    /// # Examples
581    ///
582    /// ```
583    /// # use jax_rs::{Array, Shape};
584    /// let a = Array::from_vec(vec![1.0, -2.0, 3.0], Shape::new(vec![3]));
585    /// let b = a.positive();
586    /// assert_eq!(b.to_vec(), vec![1.0, -2.0, 3.0]);
587    /// ```
588    pub fn positive(&self) -> Array {
589        self.clone()
590    }
591
592    /// Unary negative (same as neg).
593    ///
594    /// # Examples
595    ///
596    /// ```
597    /// # use jax_rs::{Array, Shape};
598    /// let a = Array::from_vec(vec![1.0, -2.0, 3.0], Shape::new(vec![3]));
599    /// let b = a.negative();
600    /// assert_eq!(b.to_vec(), vec![-1.0, 2.0, -3.0]);
601    /// ```
602    pub fn negative(&self) -> Array {
603        self.neg()
604    }
605
606    /// Inverse (1/x) with safe handling of zeros.
607    ///
608    /// Returns infinity for zero values instead of panicking.
609    ///
610    /// # Examples
611    ///
612    /// ```
613    /// # use jax_rs::{Array, Shape};
614    /// let a = Array::from_vec(vec![1.0, 2.0, 4.0, 0.5], Shape::new(vec![4]));
615    /// let b = a.invert();
616    /// assert_eq!(b.to_vec(), vec![1.0, 0.5, 0.25, 2.0]);
617    /// ```
618    pub fn invert(&self) -> Array {
619        self.reciprocal()
620    }
621
622    /// Convert angles from radians to degrees (alias).
623    ///
624    /// # Examples
625    ///
626    /// ```
627    /// # use jax_rs::{Array, Shape};
628    /// let a = Array::from_vec(vec![0.0, std::f32::consts::PI, std::f32::consts::PI * 2.0], Shape::new(vec![3]));
629    /// let b = a.degrees();
630    /// // Result: [0.0, 180.0, 360.0]
631    /// ```
632    pub fn degrees(&self) -> Array {
633        self.rad2deg()
634    }
635
636    /// Convert angles from degrees to radians (alias).
637    ///
638    /// # Examples
639    ///
640    /// ```
641    /// # use jax_rs::{Array, Shape};
642    /// let a = Array::from_vec(vec![0.0, 180.0, 360.0], Shape::new(vec![3]));
643    /// let b = a.radians();
644    /// // Result: [0.0, π, 2π]
645    /// ```
646    pub fn radians(&self) -> Array {
647        self.deg2rad()
648    }
649
650    /// Return the spacing to the next representable float.
651    ///
652    /// For simplicity, returns a constant small value.
653    ///
654    /// # Examples
655    ///
656    /// ```
657    /// # use jax_rs::{Array, Shape};
658    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
659    /// let b = a.spacing();
660    /// // Returns small epsilon values
661    /// ```
662    pub fn spacing(&self) -> Array {
663        unary_op(self, Primitive::Abs, |x| {
664            let next = f32::from_bits(x.to_bits() + 1);
665            (next - x).abs()
666        })
667    }
668
669    /// Return a copy of the array (alias for clone).
670    ///
671    /// # Examples
672    ///
673    /// ```
674    /// # use jax_rs::{Array, Shape};
675    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
676    /// let b = a.copy();
677    /// assert_eq!(b.to_vec(), vec![1.0, 2.0, 3.0]);
678    /// ```
679    pub fn copy(&self) -> Array {
680        self.clone()
681    }
682
683    /// Return element-wise natural logarithm (alias for log).
684    ///
685    /// # Examples
686    ///
687    /// ```
688    /// # use jax_rs::{Array, Shape};
689    /// let a = Array::from_vec(vec![1.0, std::f32::consts::E, std::f32::consts::E * std::f32::consts::E], Shape::new(vec![3]));
690    /// let b = a.ln();
691    /// // Result: [0.0, 1.0, 2.0]
692    /// ```
693    pub fn ln(&self) -> Array {
694        self.log()
695    }
696
697    /// Return element-wise maximum with zero.
698    ///
699    /// # Examples
700    ///
701    /// ```
702    /// # use jax_rs::{Array, Shape};
703    /// let a = Array::from_vec(vec![-1.0, 0.0, 1.0, 2.0], Shape::new(vec![4]));
704    /// let b = a.clip_min(0.0);
705    /// assert_eq!(b.to_vec(), vec![0.0, 0.0, 1.0, 2.0]);
706    /// ```
707    pub fn clip_min(&self, min: f32) -> Array {
708        unary_op(self, Primitive::Max, |x| x.max(min))
709    }
710
711    /// Return element-wise minimum with a maximum bound.
712    ///
713    /// # Examples
714    ///
715    /// ```
716    /// # use jax_rs::{Array, Shape};
717    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![4]));
718    /// let b = a.clip_max(2.5);
719    /// assert_eq!(b.to_vec(), vec![1.0, 2.0, 2.5, 2.5]);
720    /// ```
721    pub fn clip_max(&self, max: f32) -> Array {
722        unary_op(self, Primitive::Min, |x| x.min(max))
723    }
724
725    /// Return the conjugate of the array (identity for real numbers).
726    ///
727    /// # Examples
728    ///
729    /// ```
730    /// # use jax_rs::{Array, Shape};
731    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
732    /// let b = a.conj();
733    /// assert_eq!(b.to_vec(), vec![1.0, 2.0, 3.0]);
734    /// ```
735    pub fn conj(&self) -> Array {
736        self.clone()
737    }
738
739    /// Return the conjugate (alias for conj).
740    ///
741    /// # Examples
742    ///
743    /// ```
744    /// # use jax_rs::{Array, Shape};
745    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
746    /// let b = a.conjugate();
747    /// assert_eq!(b.to_vec(), vec![1.0, 2.0, 3.0]);
748    /// ```
749    pub fn conjugate(&self) -> Array {
750        self.clone()
751    }
752
753    /// Return the angle of complex numbers (phase).
754    /// For real numbers, returns 0 for positive, PI for negative.
755    ///
756    /// # Examples
757    ///
758    /// ```
759    /// # use jax_rs::{Array, Shape};
760    /// let a = Array::from_vec(vec![1.0, -1.0, 0.0], Shape::new(vec![3]));
761    /// let angles = a.angle();
762    /// // Positive: 0, Negative: PI, Zero: 0
763    /// ```
764    pub fn angle(&self) -> Array {
765        unary_op(self, Primitive::Sign, |x| {
766            if x > 0.0 {
767                0.0
768            } else if x < 0.0 {
769                std::f32::consts::PI
770            } else {
771                0.0
772            }
773        })
774    }
775
776    /// Return the real part of complex numbers.
777    /// For real arrays, this is the identity function.
778    ///
779    /// # Examples
780    ///
781    /// ```
782    /// # use jax_rs::{Array, Shape};
783    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
784    /// let r = a.real();
785    /// assert_eq!(r.to_vec(), vec![1.0, 2.0, 3.0]);
786    /// ```
787    pub fn real(&self) -> Array {
788        self.clone()
789    }
790
791    /// Return the imaginary part of complex numbers.
792    /// For real arrays, returns zeros.
793    ///
794    /// # Examples
795    ///
796    /// ```
797    /// # use jax_rs::{Array, Shape};
798    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
799    /// let im = a.imag();
800    /// assert_eq!(im.to_vec(), vec![0.0, 0.0, 0.0]);
801    /// ```
802    pub fn imag(&self) -> Array {
803        Array::zeros(self.shape().clone(), DType::Float32)
804    }
805
806    /// Bitwise NOT operation.
807    /// Inverts all bits in the bit representation of Float32 values.
808    ///
809    /// # Examples
810    ///
811    /// ```
812    /// # use jax_rs::{Array, Shape};
813    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
814    /// let b = a.bitwise_not();
815    /// ```
816    pub fn bitwise_not(&self) -> Array {
817        unary_op(self, Primitive::Neg, |x| {
818            let bits = x.to_bits();
819            f32::from_bits(!bits)
820        })
821    }
822
823    /// Return the reciprocal of the square root (1/sqrt(x)).
824    ///
825    /// # Examples
826    ///
827    /// ```
828    /// # use jax_rs::{Array, Shape};
829    /// let a = Array::from_vec(vec![1.0, 4.0, 9.0], Shape::new(vec![3]));
830    /// let b = a.rsqrt();
831    /// assert!((b.to_vec()[0] - 1.0).abs() < 1e-6);
832    /// assert!((b.to_vec()[1] - 0.5).abs() < 1e-6);
833    /// ```
834    pub fn rsqrt(&self) -> Array {
835        unary_op(self, Primitive::Sqrt, |x| 1.0 / x.sqrt())
836    }
837
838    /// Return the fractional and integer parts of an array element-wise.
839    /// Returns a tuple of (fractional_part, integer_part).
840    ///
841    /// # Examples
842    ///
843    /// ```
844    /// # use jax_rs::{Array, Shape};
845    /// let a = Array::from_vec(vec![1.5, 2.7, -3.2], Shape::new(vec![3]));
846    /// let (frac, int) = a.modf();
847    /// assert!((frac.to_vec()[0] - 0.5).abs() < 1e-6);
848    /// assert!((int.to_vec()[0] - 1.0).abs() < 1e-6);
849    /// ```
850    pub fn modf(&self) -> (Array, Array) {
851        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
852
853        let data = self.to_vec();
854        let frac_data: Vec<f32> = data.iter().map(|&x| x.fract()).collect();
855        let int_data: Vec<f32> = data.iter().map(|&x| x.trunc()).collect();
856
857        let frac = Array::from_vec(frac_data, self.shape().clone());
858        let int = Array::from_vec(int_data, self.shape().clone());
859
860        (frac, int)
861    }
862
863    /// Compute x * 2^exp for each element.
864    /// Equivalent to ldexp function from C math library.
865    ///
866    /// # Examples
867    ///
868    /// ```
869    /// # use jax_rs::{Array, Shape};
870    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
871    /// let b = a.ldexp(2);
872    /// assert_eq!(b.to_vec(), vec![4.0, 8.0, 12.0]); // multiply by 2^2 = 4
873    /// ```
874    pub fn ldexp(&self, exp: i32) -> Array {
875        let multiplier = 2_f32.powi(exp);
876        unary_op(self, Primitive::Mul, move |x| x * multiplier)
877    }
878
879    /// Decompose x into mantissa and exponent: x = m * 2^e.
880    /// Returns (mantissa, exponent) where mantissa is in [0.5, 1.0).
881    ///
882    /// # Examples
883    ///
884    /// ```
885    /// # use jax_rs::{Array, Shape};
886    /// let a = Array::from_vec(vec![4.0, 8.0, 0.5], Shape::new(vec![3]));
887    /// let (mantissa, exp) = a.frexp();
888    /// // 4.0 = 0.5 * 2^3, 8.0 = 0.5 * 2^4, 0.5 = 0.5 * 2^0
889    /// ```
890    pub fn frexp(&self) -> (Array, Array) {
891        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
892
893        let data = self.to_vec();
894        let mut mantissa_data = Vec::with_capacity(data.len());
895        let mut exp_data = Vec::with_capacity(data.len());
896
897        for &x in &data {
898            if x == 0.0 {
899                mantissa_data.push(0.0);
900                exp_data.push(0.0);
901            } else {
902                let bits = x.to_bits();
903                let sign = (bits >> 31) & 1;
904                let exponent = ((bits >> 23) & 0xFF) as i32 - 126;
905                // Create mantissa in [0.5, 1.0)
906                let mantissa_bits = (sign << 31) | (126 << 23) | (bits & 0x7FFFFF);
907                let mantissa = f32::from_bits(mantissa_bits);
908                mantissa_data.push(mantissa);
909                exp_data.push(exponent as f32);
910            }
911        }
912
913        let mantissa = Array::from_vec(mantissa_data, self.shape().clone());
914        let exp = Array::from_vec(exp_data, self.shape().clone());
915
916        (mantissa, exp)
917    }
918
919    /// Divide arrays element-wise with safe handling of division by zero.
920    /// Returns 0 when dividing by zero instead of NaN/Inf.
921    ///
922    /// # Examples
923    ///
924    /// ```
925    /// # use jax_rs::{Array, Shape};
926    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
927    /// let b = a.safe_divide_scalar(2.0);
928    /// assert_eq!(b.to_vec(), vec![0.5, 1.0, 1.5]);
929    /// let c = a.safe_divide_scalar(0.0);
930    /// assert_eq!(c.to_vec(), vec![0.0, 0.0, 0.0]); // Returns 0 instead of Inf
931    /// ```
932    pub fn safe_divide_scalar(&self, divisor: f32) -> Array {
933        if divisor == 0.0 {
934            Array::zeros(self.shape().clone(), DType::Float32)
935        } else {
936            unary_op(self, Primitive::Reciprocal, move |x| x / divisor)
937        }
938    }
939
940    /// Compute the modified Bessel function of the first kind, order 0.
941    /// Approximation using polynomial expansion.
942    ///
943    /// # Examples
944    ///
945    /// ```
946    /// # use jax_rs::{Array, Shape};
947    /// let a = Array::from_vec(vec![0.0, 1.0, 2.0], Shape::new(vec![3]));
948    /// let b = a.i0();
949    /// assert!((b.to_vec()[0] - 1.0).abs() < 1e-4);  // i0(0) = 1
950    /// ```
951    pub fn i0(&self) -> Array {
952        unary_op(self, Primitive::Exp, |x| {
953            // Polynomial approximation for I0
954            let ax = x.abs();
955            if ax < 3.75 {
956                let y = (x / 3.75).powi(2);
957                1.0 + y * (3.5156229 + y * (3.0899424 + y * (1.2067492 + y * (0.2659732 + y * (0.0360768 + y * 0.0045813)))))
958            } else {
959                let y = 3.75 / ax;
960                (ax.exp() / ax.sqrt()) * (0.398_942_3 + y * (0.01328592 + y * (0.00225319 + y * (-0.00157565 + y * (0.00916281 + y * (-0.02057706 + y * (0.02635537 + y * (-0.01647633 + y * 0.00392377))))))))
961            }
962        })
963    }
964
965    /// Compute the natural logarithm of the absolute value of the gamma function.
966    ///
967    /// # Examples
968    ///
969    /// ```
970    /// # use jax_rs::{Array, Shape};
971    /// let a = Array::from_vec(vec![1.0, 2.0, 5.0], Shape::new(vec![3]));
972    /// let b = a.lgamma();
973    /// assert!((b.to_vec()[0]).abs() < 1e-6);  // lgamma(1) = 0
974    /// assert!((b.to_vec()[1]).abs() < 1e-6);  // lgamma(2) = 0
975    /// ```
976    pub fn lgamma(&self) -> Array {
977        unary_op(self, Primitive::Log, |x| {
978            // Stirling's approximation for larger values
979            if x <= 0.0 {
980                f32::INFINITY
981            } else if x < 7.0 {
982                // Use recurrence relation for small values
983                let n = (7.0 - x).ceil() as i32;
984                let mut y = x;
985                let mut prod = 1.0;
986                for _ in 0..n {
987                    prod *= y;
988                    y += 1.0;
989                }
990                lgamma_impl(y) - prod.ln()
991            } else {
992                lgamma_impl(x)
993            }
994        })
995    }
996
997}
998
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    // Each test builds a small rank-1 array, applies one operation, and
    // checks the flattened output.

    #[test]
    fn test_neg() {
        let input = Array::from_vec(vec![1.0, -2.0, 3.0], Shape::new(vec![3]));
        let out = input.neg().to_vec();
        assert_eq!(out, vec![-1.0, 2.0, -3.0]);
    }

    #[test]
    fn test_abs() {
        let input =
            Array::from_vec(vec![1.0, -2.0, 3.0, -4.0], Shape::new(vec![4]));
        let out = input.abs().to_vec();
        assert_eq!(out, vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_sin_cos() {
        let angles = Array::from_vec(
            vec![0.0, std::f32::consts::PI / 2.0],
            Shape::new(vec![2]),
        );
        let sines = angles.sin().to_vec();
        let cosines = angles.cos().to_vec();

        assert_abs_diff_eq!(sines[0], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(sines[1], 1.0, epsilon = 1e-6);
        assert_abs_diff_eq!(cosines[0], 1.0, epsilon = 1e-6);
        assert_abs_diff_eq!(cosines[1], 0.0, epsilon = 1e-6);
    }

    #[test]
    fn test_exp_log() {
        let input = Array::from_vec(vec![0.0, 1.0, 2.0], Shape::new(vec![3]));
        let exp_arr = input.exp();
        let exp_vals = exp_arr.to_vec();
        let round_trip = exp_arr.log().to_vec();

        assert_abs_diff_eq!(exp_vals[0], 1.0, epsilon = 1e-6);
        assert_abs_diff_eq!(exp_vals[1], std::f32::consts::E, epsilon = 1e-6);

        // log(exp(x)) should equal x
        assert_abs_diff_eq!(round_trip[0], 0.0, epsilon = 1e-5);
        assert_abs_diff_eq!(round_trip[1], 1.0, epsilon = 1e-5);
        assert_abs_diff_eq!(round_trip[2], 2.0, epsilon = 1e-5);
    }

    #[test]
    fn test_sqrt() {
        let input =
            Array::from_vec(vec![0.0, 1.0, 4.0, 9.0], Shape::new(vec![4]));
        let out = input.sqrt().to_vec();
        assert_eq!(out, vec![0.0, 1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_tanh() {
        let input = Array::from_vec(vec![0.0, 1.0], Shape::new(vec![2]));
        let out = input.tanh().to_vec();
        assert_abs_diff_eq!(out[0], 0.0, epsilon = 1e-6);
        assert_abs_diff_eq!(out[1], 0.761_594_2, epsilon = 1e-6);
    }

    #[test]
    fn test_reciprocal() {
        let input = Array::from_vec(vec![1.0, 2.0, 4.0], Shape::new(vec![3]));
        let out = input.reciprocal().to_vec();
        assert_abs_diff_eq!(out[0], 1.0, epsilon = 1e-6);
        assert_abs_diff_eq!(out[1], 0.5, epsilon = 1e-6);
        assert_abs_diff_eq!(out[2], 0.25, epsilon = 1e-6);
    }

    #[test]
    fn test_reciprocal_no_nan() {
        let input = Array::from_vec(vec![2.0, 0.0, 4.0], Shape::new(vec![3]));
        let out = input.reciprocal_no_nan().to_vec();
        assert_eq!(out, vec![0.5, 0.0, 0.25]);
    }

    #[test]
    fn test_square() {
        let input = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
        let out = input.square().to_vec();
        assert_eq!(out, vec![1.0, 4.0, 9.0]);
    }

    #[test]
    fn test_sign() {
        // Both signed zeros are expected to map to zero.
        let input =
            Array::from_vec(vec![-2.0, -0.0, 0.0, 3.0], Shape::new(vec![4]));
        let out = input.sign().to_vec();
        assert_eq!(out, vec![-1.0, 0.0, 0.0, 1.0]);
    }
}
1098}