jax_rs/ops/reduce.rs

//! Reduction operations on arrays.

use crate::trace::{is_tracing, trace_reduce, Primitive};
use crate::{buffer::Buffer, Array, DType, Device, Shape};

/// Reduce over all elements with a binary operation.
fn reduce_all<F>(input: &Array, init: f32, f: F) -> f32
where
    F: Fn(f32, f32) -> f32,
{
    assert_eq!(input.dtype(), DType::Float32, "Only Float32 supported");

    // CPU path - simple fold
    let data = input.to_vec();
    data.iter().fold(init, |acc, &x| f(acc, x))
}

/// GPU-aware reduce_all that takes an operation string.
fn reduce_all_gpu_aware(input: &Array, op: &str) -> f32 {
    assert_eq!(input.dtype(), DType::Float32, "Only Float32 supported");

    match input.device() {
        Device::WebGpu => {
            // GPU path
            let output_buffer = Buffer::zeros(1, DType::Float32, Device::WebGpu);

            crate::backend::ops::gpu_reduce_all(
                input.buffer(),
                &output_buffer,
                op,
            );

            // Read result back from GPU
            output_buffer.to_f32_vec()[0]
        }
        Device::Cpu | Device::Wasm => {
            // CPU fallback - use appropriate init and operation
            let (init, f): (f32, Box<dyn Fn(f32, f32) -> f32>) = match op {
                "sum" => (0.0, Box::new(|acc, x| acc + x)),
                "max" => (f32::NEG_INFINITY, Box::new(|acc, x| acc.max(x))),
                "min" => (f32::INFINITY, Box::new(|acc, x| acc.min(x))),
                "prod" => (1.0, Box::new(|acc, x| acc * x)),
                _ => panic!("Unknown reduction op: {}", op),
            };

            let data = input.to_vec();
            data.iter().fold(init, |acc, &x| f(acc, x))
        }
    }
}

/// Reduce along a specific axis.
fn reduce_axis<F>(
    input: &Array,
    axis: usize,
    op: Primitive,
    init: f32,
    f: F,
) -> Array
where
    F: Fn(f32, f32) -> f32,
{
    assert_eq!(input.dtype(), DType::Float32, "Only Float32 supported");
    assert_eq!(input.device(), Device::Cpu, "Only CPU supported for now");
    assert!(axis < input.ndim(), "Axis out of bounds");

    let shape = input.shape();
    let dims = shape.as_slice();

    // Result shape has the reduced axis removed
    let mut result_dims: Vec<usize> = dims.to_vec();
    result_dims.remove(axis);
    let result_shape = if result_dims.is_empty() {
        Shape::scalar()
    } else {
        Shape::new(result_dims.clone())
    };

    let input_data = input.to_vec();
    let result_size = result_shape.size();
    let mut result_data = vec![init; result_size];

    // Compute strides for input
    let mut strides = vec![1; dims.len()];
    for i in (0..dims.len() - 1).rev() {
        strides[i] = strides[i + 1] * dims[i + 1];
    }

    // Iterate over result indices
    for (result_idx, item) in result_data.iter_mut().enumerate() {
        // Convert flat result index to multi-dimensional
        let mut result_multi = vec![0; result_dims.len()];
        let mut idx = result_idx;
        for i in (0..result_dims.len()).rev() {
            result_multi[i] = idx % result_shape.as_slice()[i];
            idx /= result_shape.as_slice()[i];
        }

        // Insert the reduced axis and iterate over it
        let mut acc = init;
        for axis_idx in 0..dims[axis] {
            let mut input_multi = Vec::with_capacity(dims.len());
            let mut result_i = 0;
            for i in 0..dims.len() {
                if i == axis {
                    input_multi.push(axis_idx);
                } else {
                    input_multi.push(result_multi[result_i]);
                    result_i += 1;
                }
            }

            // Convert multi-dimensional index to flat
            let flat_idx: usize = input_multi
                .iter()
                .zip(strides.iter())
                .map(|(idx, stride)| idx * stride)
                .sum();

            acc = f(acc, input_data[flat_idx]);
        }

        *item = acc;
    }

    let buffer = Buffer::from_f32(result_data, Device::Cpu);
    let result = Array::from_buffer(buffer, result_shape.clone());

    // Register with trace context if tracing is active
    if is_tracing() {
        trace_reduce(result.id(), op, input, result_shape);
    }

    result
}

impl Array {
    /// Sum of all elements.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![4]));
    /// let sum = a.sum_all();
    /// assert_eq!(sum, 10.0);
    /// ```
    pub fn sum_all(&self) -> f32 {
        reduce_all_gpu_aware(self, "sum")
    }

    /// Sum of all elements, returned as a scalar Array.
    ///
    /// This is a convenience method for autodiff that wraps `sum_all()`.
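    ///
    /// # Examples
    ///
    /// A small illustrative example (a sketch: it assumes `to_vec()` on a
    /// scalar-shaped array yields its single element):
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
    /// let total = a.sum_all_array();
    /// assert_eq!(total.to_vec(), vec![6.0]);
    /// ```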
    pub fn sum_all_array(&self) -> Array {
        let val = self.sum_all();
        let result = Array::from_vec(vec![val], crate::Shape::scalar());

        // Register with trace context if tracing is active
        if is_tracing() {
            trace_reduce(
                result.id(),
                Primitive::SumAll,
                self,
                crate::Shape::scalar(),
            );
        }

        result
    }

    /// Sum along a specific axis.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], Shape::new(vec![2, 3]));
    /// let sum_axis0 = a.sum(0);
    /// assert_eq!(sum_axis0.to_vec(), vec![5.0, 7.0, 9.0]);
    /// let sum_axis1 = a.sum(1);
    /// assert_eq!(sum_axis1.to_vec(), vec![6.0, 15.0]);
    /// ```
    pub fn sum(&self, axis: usize) -> Array {
        reduce_axis(self, axis, Primitive::Sum { axis }, 0.0, |acc, x| acc + x)
    }

    /// Mean of all elements.
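    ///
    /// # Examples
    ///
    /// A short example; the expected value mirrors `test_mean_all` below:
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![4]));
    /// assert_eq!(a.mean_all(), 2.5);
    /// ```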
    pub fn mean_all(&self) -> f32 {
        self.sum_all() / (self.size() as f32)
    }

    /// Mean of all elements, returning a scalar array.
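    ///
    /// # Examples
    ///
    /// Sketch example (same scalar-array assumption as `sum_all_array`):
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![4]));
    /// assert_eq!(a.mean_all_array().to_vec(), vec![2.5]);
    /// ```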
    pub fn mean_all_array(&self) -> Array {
        let val = self.mean_all();
        let result = Array::from_vec(vec![val], crate::Shape::scalar());

        // Register with trace context if tracing is active
        if is_tracing() {
            trace_reduce(
                result.id(),
                Primitive::MeanAll,
                self,
                crate::Shape::scalar(),
            );
        }

        result
    }

    /// Mean along a specific axis.
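    ///
    /// # Examples
    ///
    /// Example values taken from `test_mean_axis` below:
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], Shape::new(vec![2, 3]));
    /// assert_eq!(a.mean(0).to_vec(), vec![2.5, 3.5, 4.5]);
    /// assert_eq!(a.mean(1).to_vec(), vec![2.0, 5.0]);
    /// ```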
    pub fn mean(&self, axis: usize) -> Array {
        reduce_axis(self, axis, Primitive::Mean { axis }, 0.0, |acc, x| {
            acc + x / (self.shape().as_slice()[axis] as f32)
        })
    }

    /// Maximum of all elements.
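    ///
    /// # Examples
    ///
    /// Mirrors `test_max_all` below:
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 5.0, 3.0, 2.0], Shape::new(vec![4]));
    /// assert_eq!(a.max_all(), 5.0);
    /// ```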
    pub fn max_all(&self) -> f32 {
        reduce_all_gpu_aware(self, "max")
    }

    /// Maximum along a specific axis.
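    ///
    /// # Examples
    ///
    /// Example values taken from `test_max_axis` below:
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 5.0, 3.0, 2.0, 4.0, 6.0], Shape::new(vec![2, 3]));
    /// assert_eq!(a.max(0).to_vec(), vec![2.0, 5.0, 6.0]);
    /// assert_eq!(a.max(1).to_vec(), vec![5.0, 6.0]);
    /// ```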
    pub fn max(&self, axis: usize) -> Array {
        reduce_axis(
            self,
            axis,
            Primitive::MaxAxis { axis },
            f32::NEG_INFINITY,
            |acc, x| acc.max(x),
        )
    }

    /// Minimum of all elements.
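    ///
    /// # Examples
    ///
    /// Mirrors `test_min_all` below:
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![3.0, 1.0, 5.0, 2.0], Shape::new(vec![4]));
    /// assert_eq!(a.min_all(), 1.0);
    /// ```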
    pub fn min_all(&self) -> f32 {
        reduce_all_gpu_aware(self, "min")
    }

    /// Minimum along a specific axis.
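    ///
    /// # Examples
    ///
    /// Example values taken from `test_min_axis` below:
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![3.0, 5.0, 2.0, 1.0, 4.0, 6.0], Shape::new(vec![2, 3]));
    /// assert_eq!(a.min(0).to_vec(), vec![1.0, 4.0, 2.0]);
    /// assert_eq!(a.min(1).to_vec(), vec![2.0, 1.0]);
    /// ```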
    pub fn min(&self, axis: usize) -> Array {
        reduce_axis(
            self,
            axis,
            Primitive::MinAxis { axis },
            f32::INFINITY,
            |acc, x| acc.min(x),
        )
    }

    /// Product of all elements.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![4]));
    /// let prod = a.prod_all();
    /// assert_eq!(prod, 24.0);
    /// ```
    pub fn prod_all(&self) -> f32 {
        reduce_all(self, 1.0, |acc, x| acc * x)
    }

    /// Product along a specific axis.
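    ///
    /// # Examples
    ///
    /// A sketch example, assuming the same axis semantics as `sum`:
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![2, 2]));
    /// assert_eq!(a.prod(0).to_vec(), vec![3.0, 8.0]);
    /// assert_eq!(a.prod(1).to_vec(), vec![2.0, 12.0]);
    /// ```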
    pub fn prod(&self, axis: usize) -> Array {
        reduce_axis(self, axis, Primitive::ProdAxis { axis }, 1.0, |acc, x| {
            acc * x
        })
    }

    /// Index of minimum element.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![3.0, 1.0, 4.0, 2.0], Shape::new(vec![4]));
    /// let idx = a.argmin();
    /// assert_eq!(idx, 1);
    /// ```
    pub fn argmin(&self) -> usize {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let data = self.to_vec();
        data.iter()
            .enumerate()
            .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .map(|(idx, _)| idx)
            .unwrap()
    }

    /// Index of maximum element.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![3.0, 1.0, 4.0, 2.0], Shape::new(vec![4]));
    /// let idx = a.argmax();
    /// assert_eq!(idx, 2);
    /// ```
    pub fn argmax(&self) -> usize {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let data = self.to_vec();
        data.iter()
            .enumerate()
            .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
            .map(|(idx, _)| idx)
            .unwrap()
    }

    /// Variance of all elements.
    ///
    /// This is the population variance (sum of squared deviations divided by `n`).
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![4]));
    /// let var = a.var();
    /// assert!((var - 1.25).abs() < 1e-6);
    /// ```
    pub fn var(&self) -> f32 {
        let mean = self.mean_all();
        let data = self.to_vec();
        let sum_sq_diff: f32 = data.iter().map(|&x| (x - mean).powi(2)).sum();
        sum_sq_diff / data.len() as f32
    }

    /// Standard deviation of all elements.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![4]));
    /// let std = a.std();
    /// assert!((std - 1.118).abs() < 0.01);
    /// ```
    pub fn std(&self) -> f32 {
        self.var().sqrt()
    }

    /// Variance along a specific axis.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![2, 2]));
    /// let var_axis0 = a.var_axis(0);
    /// assert_eq!(var_axis0.shape().as_slice(), &[2]);
    /// ```
    pub fn var_axis(&self, axis: usize) -> Array {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        assert!(axis < self.ndim(), "Axis out of bounds");

        let mean = self.mean(axis);
        let mean_data = mean.to_vec();

        let shape = self.shape().as_slice();
        let data = self.to_vec();

        // Result shape has the reduced axis removed
        let mut result_dims: Vec<usize> = shape.to_vec();
        result_dims.remove(axis);
        let result_shape = if result_dims.is_empty() {
            Shape::scalar()
        } else {
            Shape::new(result_dims.clone())
        };

        let result_size = result_shape.size();
        let mut result_data = vec![0.0; result_size];

        // Compute strides for input
        let mut strides = vec![1; shape.len()];
        for i in (0..shape.len() - 1).rev() {
            strides[i] = strides[i + 1] * shape[i + 1];
        }

        // Iterate over result indices
        for (result_idx, item) in result_data.iter_mut().enumerate() {
            // Convert flat result index to multi-dimensional
            let mut result_multi = vec![0; result_dims.len()];
            let mut idx = result_idx;
            for i in (0..result_dims.len()).rev() {
                result_multi[i] = idx % result_shape.as_slice()[i];
                idx /= result_shape.as_slice()[i];
            }

            let mean_val = mean_data[result_idx];
            let mut sum_sq = 0.0;

            // Iterate over the reduced axis
            for axis_idx in 0..shape[axis] {
                let mut input_multi = Vec::with_capacity(shape.len());
                let mut result_i = 0;
                for i in 0..shape.len() {
                    if i == axis {
                        input_multi.push(axis_idx);
                    } else {
                        input_multi.push(result_multi[result_i]);
                        result_i += 1;
                    }
                }

                // Convert multi-dimensional index to flat
                let flat_idx: usize = input_multi
                    .iter()
                    .zip(strides.iter())
                    .map(|(idx, stride)| idx * stride)
                    .sum();

                let diff = data[flat_idx] - mean_val;
                sum_sq += diff * diff;
            }

            *item = sum_sq / shape[axis] as f32;
        }

        Array::from_vec(result_data, result_shape)
    }

    /// Standard deviation along a specific axis.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![2, 2]));
    /// let std_axis0 = a.std_axis(0);
    /// assert_eq!(std_axis0.shape().as_slice(), &[2]);
    /// ```
    pub fn std_axis(&self, axis: usize) -> Array {
        let var = self.var_axis(axis);
        let data = var.to_vec();
        let result: Vec<f32> = data.iter().map(|&x| x.sqrt()).collect();
        Array::from_vec(result, var.shape().clone())
    }

    /// Median of all elements.
    ///
    /// # Panics
    ///
    /// Panics if the array is empty.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 3.0, 2.0], Shape::new(vec![3]));
    /// let med = a.median();
    /// assert_eq!(med, 2.0);
    /// ```
    pub fn median(&self) -> f32 {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let mut data = self.to_vec();
        assert!(!data.is_empty(), "median of an empty array is undefined");
        data.sort_by(|a, b| a.partial_cmp(b).unwrap());

        let len = data.len();
        if len.is_multiple_of(2) {
            (data[len / 2 - 1] + data[len / 2]) / 2.0
        } else {
            data[len / 2]
        }
    }

    /// Percentile of all elements.
    ///
    /// # Arguments
    ///
    /// * `q` - Percentile to compute (0-100)
    ///
    /// # Panics
    ///
    /// Panics if the array is empty or `q` is outside `[0, 100]`.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], Shape::new(vec![5]));
    /// let p50 = a.percentile(50.0);
    /// assert_eq!(p50, 3.0);
    /// ```
    pub fn percentile(&self, q: f32) -> f32 {
        assert!((0.0..=100.0).contains(&q), "Percentile must be between 0 and 100");
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");

        let mut data = self.to_vec();
        assert!(!data.is_empty(), "percentile of an empty array is undefined");
        data.sort_by(|a, b| a.partial_cmp(b).unwrap());

        let len = data.len();
        if len == 1 {
            return data[0];
        }

        let index = (q / 100.0) * (len - 1) as f32;
        let lower = index.floor() as usize;
        let upper = index.ceil() as usize;

        if lower == upper {
            data[lower]
        } else {
            let weight = index - lower as f32;
            data[lower] * (1.0 - weight) + data[upper] * weight
        }
    }

    /// Cumulative sum of array elements.
    ///
    /// Returns an array of the same shape with cumulative sums.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![4]));
    /// let cumsum = a.cumsum();
    /// assert_eq!(cumsum.to_vec(), vec![1.0, 3.0, 6.0, 10.0]);
    /// ```
    pub fn cumsum(&self) -> Array {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let data = self.to_vec();
        let mut result = Vec::with_capacity(data.len());
        let mut sum = 0.0;

        for &val in data.iter() {
            sum += val;
            result.push(sum);
        }

        Array::from_vec(result, self.shape().clone())
    }

    /// Cumulative product of array elements.
    ///
    /// Returns an array of the same shape with cumulative products.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![4]));
    /// let cumprod = a.cumprod();
    /// assert_eq!(cumprod.to_vec(), vec![1.0, 2.0, 6.0, 24.0]);
    /// ```
    pub fn cumprod(&self) -> Array {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let data = self.to_vec();
        let mut result = Vec::with_capacity(data.len());
        let mut prod = 1.0;

        for &val in data.iter() {
            prod *= val;
            result.push(prod);
        }

        Array::from_vec(result, self.shape().clone())
    }

    /// Cumulative maximum of array elements.
    ///
    /// Returns an array of the same shape with cumulative maximums.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![3.0, 1.0, 4.0, 2.0], Shape::new(vec![4]));
    /// let cummax = a.cummax();
    /// assert_eq!(cummax.to_vec(), vec![3.0, 3.0, 4.0, 4.0]);
    /// ```
    pub fn cummax(&self) -> Array {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let data = self.to_vec();
        let mut result = Vec::with_capacity(data.len());
        let mut max = f32::NEG_INFINITY;

        for &val in data.iter() {
            max = max.max(val);
            result.push(max);
        }

        Array::from_vec(result, self.shape().clone())
    }

    /// Cumulative minimum of array elements.
    ///
    /// Returns an array of the same shape with cumulative minimums.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![3.0, 1.0, 4.0, 2.0], Shape::new(vec![4]));
    /// let cummin = a.cummin();
    /// assert_eq!(cummin.to_vec(), vec![3.0, 1.0, 1.0, 1.0]);
    /// ```
    pub fn cummin(&self) -> Array {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let data = self.to_vec();
        let mut result = Vec::with_capacity(data.len());
        let mut min = f32::INFINITY;

        for &val in data.iter() {
            min = min.min(val);
            result.push(min);
        }

        Array::from_vec(result, self.shape().clone())
    }

    /// Calculate the discrete difference along the array.
    ///
    /// Computes the difference between consecutive elements.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 3.0, 6.0, 10.0], Shape::new(vec![4]));
    /// let diff = a.diff();
    /// assert_eq!(diff.to_vec(), vec![2.0, 3.0, 4.0]);
    /// ```
    pub fn diff(&self) -> Array {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let data = self.to_vec();
        assert!(!data.is_empty(), "Array must have at least 1 element");

        if data.len() == 1 {
            return Array::from_vec(vec![], Shape::new(vec![0]));
        }

        let mut result = Vec::with_capacity(data.len() - 1);
        for i in 1..data.len() {
            result.push(data[i] - data[i - 1]);
        }

        let len = result.len();
        Array::from_vec(result, Shape::new(vec![len]))
    }

    /// Calculate the n-th discrete difference.
    ///
    /// Recursively applies diff n times.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], Shape::new(vec![5]));
    /// let diff2 = a.diff_n(2);
    /// assert_eq!(diff2.to_vec(), vec![0.0, 0.0, 0.0]);
    /// ```
    pub fn diff_n(&self, n: usize) -> Array {
        if n == 0 {
            return self.clone();
        }

        let mut result = self.diff();
        for _ in 1..n {
            result = result.diff();
        }
        result
    }

    /// Sum of array elements, ignoring NaN values.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, f32::NAN, 3.0, 4.0], Shape::new(vec![4]));
    /// let sum = a.nansum();
    /// assert_eq!(sum, 8.0);
    /// ```
    pub fn nansum(&self) -> f32 {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let data = self.to_vec();
        data.iter().filter(|x| !x.is_nan()).sum()
    }

    /// Mean of array elements, ignoring NaN values.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, f32::NAN, 3.0, 4.0], Shape::new(vec![4]));
    /// let mean = a.nanmean();
    /// assert_eq!(mean, 8.0 / 3.0);
    /// ```
    pub fn nanmean(&self) -> f32 {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let data = self.to_vec();
        let valid: Vec<f32> = data.iter().copied().filter(|x| !x.is_nan()).collect();
        if valid.is_empty() {
            return f32::NAN;
        }
        valid.iter().sum::<f32>() / valid.len() as f32
    }

    /// Maximum of array elements, ignoring NaN values.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, f32::NAN, 4.0, 2.0], Shape::new(vec![4]));
    /// let max = a.nanmax();
    /// assert_eq!(max, 4.0);
    /// ```
    pub fn nanmax(&self) -> f32 {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let data = self.to_vec();
        data.iter()
            .copied()
            .filter(|x| !x.is_nan())
            .fold(f32::NEG_INFINITY, f32::max)
    }

    /// Minimum of array elements, ignoring NaN values.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, f32::NAN, 4.0, 2.0], Shape::new(vec![4]));
    /// let min = a.nanmin();
    /// assert_eq!(min, 1.0);
    /// ```
    pub fn nanmin(&self) -> f32 {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let data = self.to_vec();
        data.iter()
            .copied()
            .filter(|x| !x.is_nan())
            .fold(f32::INFINITY, f32::min)
    }

    /// Standard deviation of array elements, ignoring NaN values.
    ///
    /// Like `nanvar`, this uses Bessel's correction (divides by `n - 1`).
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, f32::NAN, 3.0, 5.0], Shape::new(vec![4]));
    /// let std = a.nanstd();
    /// assert!((std - 2.0).abs() < 1e-5);
    /// ```
    pub fn nanstd(&self) -> f32 {
        self.nanvar().sqrt()
    }

    /// Variance of array elements, ignoring NaN values.
    ///
    /// Uses Bessel's correction (divides by `n - 1`), unlike `var`, which
    /// divides by `n`.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, f32::NAN, 3.0, 5.0], Shape::new(vec![4]));
    /// let var = a.nanvar();
    /// assert!((var - 4.0).abs() < 1e-5);
    /// ```
    pub fn nanvar(&self) -> f32 {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let data = self.to_vec();
        let valid: Vec<f32> = data.iter().copied().filter(|x| !x.is_nan()).collect();

        if valid.len() < 2 {
            return f32::NAN;
        }

        let mean = valid.iter().sum::<f32>() / valid.len() as f32;
        valid
            .iter()
            .map(|x| {
                let diff = x - mean;
                diff * diff
            })
            .sum::<f32>()
            / (valid.len() - 1) as f32 // Bessel's correction (sample variance)
    }

    /// Median of array elements, ignoring NaN values.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, f32::NAN, 3.0, 5.0, 2.0], Shape::new(vec![5]));
    /// let median = a.nanmedian();
    /// assert_eq!(median, 2.5);
    /// ```
    pub fn nanmedian(&self) -> f32 {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let data = self.to_vec();
        let mut valid: Vec<f32> = data.iter().copied().filter(|x| !x.is_nan()).collect();

        if valid.is_empty() {
            return f32::NAN;
        }

        valid.sort_by(|a, b| a.partial_cmp(b).unwrap());
        let len = valid.len();

        if len.is_multiple_of(2) {
            (valid[len / 2 - 1] + valid[len / 2]) / 2.0
        } else {
            valid[len / 2]
        }
    }

    /// Peak-to-peak (maximum - minimum) value.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 5.0, 2.0, 8.0], Shape::new(vec![4]));
    /// assert_eq!(a.ptp(), 7.0);
    /// ```
    pub fn ptp(&self) -> f32 {
        let max = self.max_all();
        let min = self.min_all();
        max - min
    }

    /// Peak-to-peak (maximum - minimum) along an axis.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 5.0, 2.0, 8.0], Shape::new(vec![2, 2]));
    /// let ptp = a.ptp_axis(0);
    /// assert_eq!(ptp.to_vec(), vec![1.0, 3.0]);
    /// ```
    pub fn ptp_axis(&self, axis: usize) -> Array {
        let max = self.max(axis);
        let min = self.min(axis);
        max.sub(&min)
    }

    /// Compute the q-th quantile of the data.
    ///
    /// # Arguments
    ///
    /// * `q` - Quantile to compute (between 0.0 and 1.0)
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], Shape::new(vec![5]));
    /// let q = a.quantile(0.5); // Median
    /// assert!((q - 3.0).abs() < 1e-6);
    /// ```
    pub fn quantile(&self, q: f32) -> f32 {
        assert!(
            (0.0..=1.0).contains(&q),
            "Quantile must be between 0 and 1"
        );

        let mut data = self.to_vec();
        data.sort_by(|a, b| a.partial_cmp(b).unwrap());

        let n = data.len();
        if n == 0 {
            return f32::NAN;
        }

        let index = q * (n - 1) as f32;
        let lower = index.floor() as usize;
        let upper = index.ceil() as usize;

        if lower == upper {
            data[lower]
        } else {
            let weight = index - lower as f32;
            data[lower] * (1.0 - weight) + data[upper] * weight
        }
    }

    /// Compute the q-th quantile along an axis.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], Shape::new(vec![2, 3]));
    /// let q = a.quantile_axis(0.5, 0);
    /// assert_eq!(q.shape().as_slice(), &[3]);
    /// ```
    pub fn quantile_axis(&self, q: f32, axis: usize) -> Array {
        assert!(
            (0.0..=1.0).contains(&q),
            "Quantile must be between 0 and 1"
        );
        assert!(axis < self.ndim(), "Axis out of bounds");

        let shape = self.shape();
        let dims = shape.as_slice();
        let axis_size = dims[axis];

        // Compute output shape
        let mut output_dims = dims.to_vec();
        output_dims.remove(axis);
        let output_shape = Shape::new(output_dims);
        let output_size = output_shape.size();

        let data = self.to_vec();
        let mut result = Vec::with_capacity(output_size);

        // For each position in the output
        for output_idx in 0..output_size {
            // Collect values along the axis
            let mut values = Vec::with_capacity(axis_size);

            for axis_idx in 0..axis_size {
                // Compute input index
                let mut input_idx = 0;
                let mut remaining = output_idx;
                let mut stride = 1;

                for (dim_idx, &dim_size) in dims.iter().enumerate().rev() {
                    if dim_idx == axis {
                        input_idx += axis_idx * stride;
                        stride *= dim_size;
                    } else {
                        // Non-reduced dims keep the same extent in the output,
                        // so the output coordinate uses the input dim size.
                        let out_dim_size = dims[dim_idx];
                        let coord = remaining % out_dim_size;
                        input_idx += coord * stride;
                        remaining /= out_dim_size;
                        stride *= dim_size;
                    }
                }

                values.push(data[input_idx]);
            }

            // Compute quantile of collected values
            values.sort_by(|a, b| a.partial_cmp(b).unwrap());
            let n = values.len();
            let index = q * (n - 1) as f32;
            let lower = index.floor() as usize;
            let upper = index.ceil() as usize;

            let quantile = if lower == upper {
                values[lower]
            } else {
                let weight = index - lower as f32;
                values[lower] * (1.0 - weight) + values[upper] * weight
            };

            result.push(quantile);
        }

        Array::from_vec(result, output_shape)
    }

    /// Integrate along the array using the composite trapezoidal rule.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
    /// let integral = a.trapz();
    /// assert_eq!(integral, 4.0); // (1+2)/2 + (2+3)/2 = 1.5 + 2.5 = 4.0
    /// ```
    pub fn trapz(&self) -> f32 {
        let data = self.to_vec();
        if data.len() < 2 {
            return 0.0;
        }

        let mut sum = 0.0;
        for i in 0..data.len() - 1 {
            sum += (data[i] + data[i + 1]) / 2.0;
        }
        sum
    }

    /// Integrate along an axis using the composite trapezoidal rule.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], Shape::new(vec![2, 3]));
    /// let integral = a.trapz_axis(1);
    /// assert_eq!(integral.shape().as_slice(), &[2]);
    /// ```
    pub fn trapz_axis(&self, axis: usize) -> Array {
        assert!(axis < self.ndim(), "Axis out of bounds");

        let shape = self.shape();
        let dims = shape.as_slice();
        let axis_size = dims[axis];

        // Output shape has the integrated axis removed
        let mut output_dims = dims.to_vec();
        output_dims.remove(axis);
        let output_shape = Shape::new(output_dims);

        if axis_size < 2 {
            // Can't integrate with less than 2 points
            return Array::zeros(output_shape, self.dtype());
        }

        let output_size = output_shape.size();

        let data = self.to_vec();
        let mut result = Vec::with_capacity(output_size);

        // For each position in the output
        for output_idx in 0..output_size {
            let mut sum = 0.0;

            // Integrate along the axis
            for i in 0..axis_size - 1 {
                // Get values at i and i+1
                let idx1 = self.compute_axis_index(output_idx, axis, i, &output_shape);
                let idx2 = self.compute_axis_index(output_idx, axis, i + 1, &output_shape);

                sum += (data[idx1] + data[idx2]) / 2.0;
            }

            result.push(sum);
        }

        Array::from_vec(result, output_shape)
    }

    /// Compute the gradient (numerical derivative) of an array.
    ///
    /// For an array [a, b, c, d], returns [b-a, (c-a)/2, (d-b)/2, d-c].
    /// Uses forward differences at the start, backward differences at the end,
    /// and central differences in the middle.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 4.0, 7.0], Shape::new(vec![4]));
    /// let grad = a.gradient();
    /// // [2-1, (4-1)/2, (7-2)/2, 7-4] = [1.0, 1.5, 2.5, 3.0]
    /// assert_eq!(grad.to_vec(), vec![1.0, 1.5, 2.5, 3.0]);
    /// ```
    pub fn gradient(&self) -> Array {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let data = self.to_vec();
        let n = data.len();

        if n == 0 {
            return self.clone();
        }

        if n == 1 {
            return Array::from_vec(vec![0.0], self.shape().clone());
        }

        let mut result = Vec::with_capacity(n);

        // Forward difference at start
        result.push(data[1] - data[0]);

        // Central differences in the middle
        for i in 1..n - 1 {
            result.push((data[i + 1] - data[i - 1]) / 2.0);
        }

        // Backward difference at end
        result.push(data[n - 1] - data[n - 2]);

        Array::from_vec(result, self.shape().clone())
    }

    /// Compute differences between consecutive elements.
    ///
    /// This is equivalent to `diff` for 1-D data; the name mirrors NumPy's `ediff1d`.
    /// Returns an array of length n-1 where result[i] = array[i+1] - array[i].
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 3.0, 6.0, 10.0], Shape::new(vec![4]));
    /// let edges = a.ediff1d();
    /// assert_eq!(edges.to_vec(), vec![2.0, 3.0, 4.0]);
    /// ```
    pub fn ediff1d(&self) -> Array {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let data = self.to_vec();

        if data.len() < 2 {
            // Empty or single-element input has no consecutive differences
            return Array::zeros(Shape::new(vec![0]), DType::Float32);
        }

        let mut result = Vec::with_capacity(data.len() - 1);
        for i in 0..data.len() - 1 {
            result.push(data[i + 1] - data[i]);
        }

        let len = result.len();
        Array::from_vec(result, Shape::new(vec![len]))
    }

    /// Find the index of the maximum value, ignoring NaN.
    ///
    /// Returns 0 if the array is empty or every element is NaN.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, f32::NAN, 5.0, 3.0], Shape::new(vec![4]));
    /// let idx = a.nanargmax();
    /// assert_eq!(idx, 2); // Index of 5.0
    /// ```
    pub fn nanargmax(&self) -> usize {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let data = self.to_vec();

        let mut max_val = f32::NEG_INFINITY;
        let mut max_idx = 0;

        for (i, &val) in data.iter().enumerate() {
            if !val.is_nan() && val > max_val {
                max_val = val;
                max_idx = i;
            }
        }

        max_idx
    }

    /// Find the index of the minimum value, ignoring NaN.
    ///
    /// Returns 0 if the array is empty or every element is NaN.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![5.0, f32::NAN, 1.0, 3.0], Shape::new(vec![4]));
    /// let idx = a.nanargmin();
    /// assert_eq!(idx, 2); // Index of 1.0
    /// ```
    pub fn nanargmin(&self) -> usize {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let data = self.to_vec();

        let mut min_val = f32::INFINITY;
        let mut min_idx = 0;

        for (i, &val) in data.iter().enumerate() {
            if !val.is_nan() && val < min_val {
                min_val = val;
                min_idx = i;
            }
        }

        min_idx
    }

    /// Compute the weighted average of an array.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let values = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![4]));
    /// let weights = Array::from_vec(vec![1.0, 1.0, 1.0, 1.0], Shape::new(vec![4]));
    /// let avg = values.average(&weights);
    /// assert_eq!(avg, 2.5);
    /// ```
    pub fn average(&self, weights: &Array) -> f32 {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        assert_eq!(weights.dtype(), DType::Float32, "Only Float32 supported");
        assert_eq!(
            self.size(),
            weights.size(),
            "Values and weights must have same size"
        );

        let data = self.to_vec();
        let weight_data = weights.to_vec();

        let weighted_sum: f32 = data
            .iter()
            .zip(weight_data.iter())
            .map(|(v, w)| v * w)
            .sum();

        let weight_sum: f32 = weight_data.iter().sum();

        weighted_sum / weight_sum
    }

    /// Compute covariance between two arrays.
    ///
    /// Returns the sample covariance (divides by `n - 1`).
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let x = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![4]));
    /// let y = Array::from_vec(vec![2.0, 4.0, 6.0, 8.0], Shape::new(vec![4]));
    /// let cov = x.cov(&y);
    /// // Covariance should be positive since both increase together
    /// ```
    pub fn cov(&self, other: &Array) -> f32 {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");
        assert_eq!(
            self.size(),
            other.size(),
            "Arrays must have same size for covariance"
        );

        let x = self.to_vec();
        let y = other.to_vec();
        let n = x.len() as f32;

        if n == 0.0 {
            return 0.0;
        }

        let x_mean: f32 = x.iter().sum::<f32>() / n;
        let y_mean: f32 = y.iter().sum::<f32>() / n;

        let cov: f32 = x
            .iter()
            .zip(y.iter())
            .map(|(&xi, &yi)| (xi - x_mean) * (yi - y_mean))
            .sum();

        cov / (n - 1.0)
    }

    // Helper function to compute flat index given output index and axis position
    fn compute_axis_index(
        &self,
        output_idx: usize,
        axis: usize,
        axis_pos: usize,
        output_shape: &Shape,
    ) -> usize {
        let dims = self.shape().as_slice();
        let output_dims = output_shape.as_slice();

        let mut input_idx = 0;
        let mut remaining = output_idx;
        let mut stride = 1;

        for (dim_idx, &dim_size) in dims.iter().enumerate().rev() {
            if dim_idx == axis {
                input_idx += axis_pos * stride;
            } else {
                let out_dim_idx = if dim_idx > axis {
                    dim_idx - 1
                } else {
                    dim_idx
                };
                let coord = remaining % output_dims[out_dim_idx];
                input_idx += coord * stride;
                remaining /= output_dims[out_dim_idx];
            }
            stride *= dim_size;
        }

        input_idx
    }

    /// Product of array elements, ignoring NaN values.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, f32::NAN, 3.0, 4.0], Shape::new(vec![4]));
    /// let prod = a.nanprod();
    /// assert_eq!(prod, 12.0); // 1 * 3 * 4
    /// ```
    pub fn nanprod(&self) -> f32 {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let data = self.to_vec();
        data.iter()
            .filter(|x| !x.is_nan())
            .fold(1.0, |acc, &x| acc * x)
    }

    /// Cumulative sum of array elements, ignoring NaN values.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, f32::NAN, 3.0, 4.0], Shape::new(vec![4]));
    /// let cumsum = a.nancumsum();
    /// // Treats NaN as 0
    /// assert_eq!(cumsum.to_vec(), vec![1.0, 1.0, 4.0, 8.0]);
    /// ```
    pub fn nancumsum(&self) -> Array {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let data = self.to_vec();
        let mut result = Vec::with_capacity(data.len());
        let mut sum = 0.0;

        for &val in data.iter() {
            if !val.is_nan() {
                sum += val;
            }
            result.push(sum);
        }

        Array::from_vec(result, self.shape().clone())
    }

    /// Cumulative product of array elements, ignoring NaN values.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, f32::NAN, 3.0, 4.0], Shape::new(vec![4]));
    /// let cumprod = a.nancumprod();
    /// // Treats NaN as 1
    /// assert_eq!(cumprod.to_vec(), vec![1.0, 1.0, 3.0, 12.0]);
    /// ```
    pub fn nancumprod(&self) -> Array {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let data = self.to_vec();
        let mut result = Vec::with_capacity(data.len());
        let mut prod = 1.0;

        for &val in data.iter() {
            if !val.is_nan() {
                prod *= val;
            }
            result.push(prod);
        }

        Array::from_vec(result, self.shape().clone())
    }

    /// Compute the arithmetic-geometric mean (AGM) of the data's arithmetic
    /// and geometric means.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![4]));
    /// let agm = a.agmean();
    /// // AGM is between arithmetic and geometric means
    /// ```
    pub fn agmean(&self) -> f32 {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let data = self.to_vec();
        if data.is_empty() {
            return f32::NAN;
        }

        let arith = data.iter().sum::<f32>() / data.len() as f32;
        let geom = data.iter().fold(1.0, |acc, &x| acc * x).powf(1.0 / data.len() as f32);

        // AGM iteration
        let mut a = arith;
        let mut g = geom;

        for _ in 0..20 {
            let new_a = (a + g) / 2.0;
            let new_g = (a * g).sqrt();
            if (new_a - a).abs() < 1e-10 {
                break;
            }
            a = new_a;
            g = new_g;
        }

        a
    }

    /// Root mean square of array elements.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
    /// let rms = a.rms();
    /// // sqrt((1 + 4 + 9) / 3) = sqrt(14/3)
    /// assert!((rms - (14.0_f32 / 3.0).sqrt()).abs() < 1e-6);
    /// ```
    pub fn rms(&self) -> f32 {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let data = self.to_vec();
        if data.is_empty() {
            return 0.0;
        }

        let sum_sq: f32 = data.iter().map(|&x| x * x).sum();
        (sum_sq / data.len() as f32).sqrt()
    }

    /// Harmonic mean of array elements.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 4.0], Shape::new(vec![3]));
    /// let hm = a.harmonic_mean();
    /// // 3 / (1/1 + 1/2 + 1/4) = 3 / 1.75
    /// ```
    pub fn harmonic_mean(&self) -> f32 {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let data = self.to_vec();
        if data.is_empty() {
            return f32::NAN;
        }

        let sum_inv: f32 = data.iter().map(|&x| 1.0 / x).sum();
        data.len() as f32 / sum_inv
    }

    /// Geometric mean of array elements.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, 2.0, 4.0, 8.0], Shape::new(vec![4]));
    /// let gm = a.geometric_mean();
    /// // (1 * 2 * 4 * 8)^(1/4) = 64^(1/4) = 2.83...
    /// ```
    pub fn geometric_mean(&self) -> f32 {
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
        let data = self.to_vec();
        if data.is_empty() {
            return f32::NAN;
        }

        let product: f32 = data.iter().fold(1.0, |acc, &x| acc * x);
        product.powf(1.0 / data.len() as f32)
    }

    /// Compute percentile of array elements, ignoring NaN values.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, f32::NAN, 3.0, 5.0, 7.0], Shape::new(vec![5]));
    /// let p = a.nanpercentile(50.0);
    /// assert!((p - 4.0).abs() < 1e-6);  // median of [1, 3, 5, 7] = 4
    /// ```
    pub fn nanpercentile(&self, q: f32) -> f32 {
        assert!((0.0..=100.0).contains(&q), "Percentile must be in [0, 100]");
        assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");

        let mut data: Vec<f32> = self.to_vec().into_iter().filter(|x| !x.is_nan()).collect();
        if data.is_empty() {
            return f32::NAN;
        }

        data.sort_by(|a, b| a.partial_cmp(b).unwrap());

        let n = data.len();
        let idx = q / 100.0 * (n - 1) as f32;
        let lo = idx.floor() as usize;
        let hi = idx.ceil() as usize;
        let frac = idx - lo as f32;

        if lo == hi {
            data[lo]
        } else {
            data[lo] * (1.0 - frac) + data[hi] * frac
        }
    }

    /// Compute quantile of array elements, ignoring NaN values.
    ///
    /// # Examples
    ///
    /// ```
    /// # use jax_rs::{Array, Shape};
    /// let a = Array::from_vec(vec![1.0, f32::NAN, 3.0, 5.0, 7.0], Shape::new(vec![5]));
    /// let q = a.nanquantile(0.5);
    /// assert!((q - 4.0).abs() < 1e-6);  // median of [1, 3, 5, 7] = 4
    /// ```
    pub fn nanquantile(&self, q: f32) -> f32 {
        assert!((0.0..=1.0).contains(&q), "Quantile must be in [0, 1]");
        self.nanpercentile(q * 100.0)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_sum_all() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![4]));
        assert_eq!(a.sum_all(), 10.0);
    }

    #[test]
    fn test_sum_axis() {
        // 2x3 array: [[1, 2, 3], [4, 5, 6]]
        let a = Array::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            Shape::new(vec![2, 3]),
        );

        // Sum along axis 0 (collapse rows): [5, 7, 9]
        let sum_axis0 = a.sum(0);
        assert_eq!(sum_axis0.shape().as_slice(), &[3]);
        assert_eq!(sum_axis0.to_vec(), vec![5.0, 7.0, 9.0]);

        // Sum along axis 1 (collapse columns): [6, 15]
        let sum_axis1 = a.sum(1);
        assert_eq!(sum_axis1.shape().as_slice(), &[2]);
        assert_eq!(sum_axis1.to_vec(), vec![6.0, 15.0]);
    }

    #[test]
    fn test_mean_all() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![4]));
        assert_abs_diff_eq!(a.mean_all(), 2.5, epsilon = 1e-6);
    }

    #[test]
    fn test_mean_axis() {
        let a = Array::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            Shape::new(vec![2, 3]),
        );

        let mean_axis0 = a.mean(0);
        assert_eq!(mean_axis0.to_vec(), vec![2.5, 3.5, 4.5]);

        let mean_axis1 = a.mean(1);
        assert_eq!(mean_axis1.to_vec(), vec![2.0, 5.0]);
    }

    #[test]
    fn test_max_all() {
        let a = Array::from_vec(vec![1.0, 5.0, 3.0, 2.0], Shape::new(vec![4]));
        assert_eq!(a.max_all(), 5.0);
    }

    #[test]
    fn test_max_axis() {
        let a = Array::from_vec(
            vec![1.0, 5.0, 3.0, 2.0, 4.0, 6.0],
            Shape::new(vec![2, 3]),
        );

        let max_axis0 = a.max(0);
        assert_eq!(max_axis0.to_vec(), vec![2.0, 5.0, 6.0]);

        let max_axis1 = a.max(1);
        assert_eq!(max_axis1.to_vec(), vec![5.0, 6.0]);
    }

    #[test]
    fn test_min_all() {
        let a = Array::from_vec(vec![3.0, 1.0, 5.0, 2.0], Shape::new(vec![4]));
        assert_eq!(a.min_all(), 1.0);
    }

    #[test]
    fn test_min_axis() {
        let a = Array::from_vec(
            vec![3.0, 5.0, 2.0, 1.0, 4.0, 6.0],
            Shape::new(vec![2, 3]),
        );

        let min_axis0 = a.min(0);
        assert_eq!(min_axis0.to_vec(), vec![1.0, 4.0, 2.0]);

        let min_axis1 = a.min(1);
        assert_eq!(min_axis1.to_vec(), vec![2.0, 1.0]);
    }

    #[test]
    fn test_reduce_3d() {
        // 2x2x2 array
        let a = Array::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
            Shape::new(vec![2, 2, 2]),
        );

        // Sum along middle axis
        let sum_axis1 = a.sum(1);
        assert_eq!(sum_axis1.shape().as_slice(), &[2, 2]);
        assert_eq!(sum_axis1.to_vec(), vec![4.0, 6.0, 12.0, 14.0]);
    }

    #[test]
    fn test_cumsum() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![4]));
        let cumsum = a.cumsum();
        assert_eq!(cumsum.to_vec(), vec![1.0, 3.0, 6.0, 10.0]);
    }

    #[test]
    fn test_cumprod() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![4]));
        let cumprod = a.cumprod();
        assert_eq!(cumprod.to_vec(), vec![1.0, 2.0, 6.0, 24.0]);
    }

    #[test]
    fn test_cummax() {
        let a = Array::from_vec(vec![3.0, 1.0, 4.0, 2.0], Shape::new(vec![4]));
        let cummax = a.cummax();
        assert_eq!(cummax.to_vec(), vec![3.0, 3.0, 4.0, 4.0]);
    }

    #[test]
    fn test_cummin() {
        let a = Array::from_vec(vec![3.0, 1.0, 4.0, 2.0], Shape::new(vec![4]));
        let cummin = a.cummin();
        assert_eq!(cummin.to_vec(), vec![3.0, 1.0, 1.0, 1.0]);
    }

    #[test]
    fn test_diff() {
        let a = Array::from_vec(vec![1.0, 3.0, 6.0, 10.0], Shape::new(vec![4]));
        let diff = a.diff();
        assert_eq!(diff.to_vec(), vec![2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_diff_n() {
        // Linear sequence - second derivative should be 0
        let a = Array::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
            Shape::new(vec![5]),
        );
        let diff2 = a.diff_n(2);
        assert_eq!(diff2.to_vec(), vec![0.0, 0.0, 0.0]);

        // diff_n(0) should return the original array
        let diff0 = a.diff_n(0);
        assert_eq!(diff0.to_vec(), a.to_vec());
    }

    #[test]
    fn test_nansum() {
        let a = Array::from_vec(
            vec![1.0, f32::NAN, 3.0, 4.0],
            Shape::new(vec![4]),
        );
        let sum = a.nansum();
        assert_eq!(sum, 8.0);
    }

    #[test]
    fn test_nanmean() {
        let a = Array::from_vec(
            vec![1.0, f32::NAN, 3.0, 4.0],
            Shape::new(vec![4]),
        );
        let mean = a.nanmean();
        assert_abs_diff_eq!(mean, 8.0 / 3.0, epsilon = 1e-6);
    }

    #[test]
    fn test_nanmax() {
        let a = Array::from_vec(
            vec![1.0, f32::NAN, 4.0, 2.0],
            Shape::new(vec![4]),
        );
        let max = a.nanmax();
        assert_eq!(max, 4.0);
    }

    #[test]
    fn test_nanmin() {
        let a = Array::from_vec(
            vec![1.0, f32::NAN, 4.0, 2.0],
            Shape::new(vec![4]),
        );
        let min = a.nanmin();
        assert_eq!(min, 1.0);
    }

    #[test]
    fn test_nanstd() {
        let a = Array::from_vec(
            vec![1.0, f32::NAN, 3.0, 5.0],
            Shape::new(vec![4]),
        );
        let std = a.nanstd();
        assert_abs_diff_eq!(std, 2.0, epsilon = 1e-5);
    }

    #[test]
    fn test_nanvar() {
        let a = Array::from_vec(
            vec![1.0, f32::NAN, 3.0, 5.0],
            Shape::new(vec![4]),
        );
        let var = a.nanvar();
        assert_abs_diff_eq!(var, 4.0, epsilon = 1e-5);
    }

    #[test]
    fn test_nanmedian() {
        let a = Array::from_vec(
            vec![1.0, f32::NAN, 3.0, 5.0, 2.0],
            Shape::new(vec![5]),
        );
        let median = a.nanmedian();
        assert_eq!(median, 2.5);
    }

    #[test]
    fn test_ptp() {
        let a = Array::from_vec(vec![1.0, 5.0, 2.0, 8.0], Shape::new(vec![4]));
        assert_eq!(a.ptp(), 7.0);
    }

    #[test]
    fn test_ptp_axis() {
        let a = Array::from_vec(vec![1.0, 5.0, 2.0, 8.0], Shape::new(vec![2, 2]));
        let ptp = a.ptp_axis(0);
        assert_eq!(ptp.to_vec(), vec![1.0, 3.0]);
    }

    #[test]
    fn test_quantile() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], Shape::new(vec![5]));
        assert_abs_diff_eq!(a.quantile(0.0), 1.0, epsilon = 1e-6);
        assert_abs_diff_eq!(a.quantile(0.5), 3.0, epsilon = 1e-6);
        assert_abs_diff_eq!(a.quantile(1.0), 5.0, epsilon = 1e-6);
        assert_abs_diff_eq!(a.quantile(0.25), 2.0, epsilon = 1e-6);
    }

    #[test]
    fn test_quantile_axis() {
        let a = Array::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            Shape::new(vec![2, 3]),
        );
        let q = a.quantile_axis(0.5, 0);
        assert_eq!(q.shape().as_slice(), &[3]);
        assert_abs_diff_eq!(q.to_vec()[0], 2.5, epsilon = 1e-6);
        assert_abs_diff_eq!(q.to_vec()[1], 3.5, epsilon = 1e-6);
        assert_abs_diff_eq!(q.to_vec()[2], 4.5, epsilon = 1e-6);
    }

    #[test]
    fn test_trapz() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
        assert_eq!(a.trapz(), 4.0);

        let b = Array::from_vec(vec![0.0, 1.0, 0.0], Shape::new(vec![3]));
        assert_eq!(b.trapz(), 1.0);
    }

    #[test]
    fn test_trapz_axis() {
        let a = Array::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            Shape::new(vec![2, 3]),
        );
        let integral = a.trapz_axis(1);
        assert_eq!(integral.shape().as_slice(), &[2]);
        assert_eq!(integral.to_vec(), vec![4.0, 10.0]);
    }

    #[test]
    fn test_gradient() {
        let a = Array::from_vec(vec![1.0, 2.0, 4.0, 7.0], Shape::new(vec![4]));
        let grad = a.gradient();
        // [2-1, (4-1)/2, (7-2)/2, 7-4] = [1.0, 1.5, 2.5, 3.0]
        assert_eq!(grad.to_vec(), vec![1.0, 1.5, 2.5, 3.0]);
    }

    #[test]
    fn test_gradient_constant() {
        let a = Array::from_vec(vec![5.0, 5.0, 5.0, 5.0], Shape::new(vec![4]));
        let grad = a.gradient();
        // All gradients should be 0 for constant function
        assert_eq!(grad.to_vec(), vec![0.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_gradient_linear() {
        let a = Array::from_vec(vec![0.0, 1.0, 2.0, 3.0], Shape::new(vec![4]));
        let grad = a.gradient();
        // All gradients should be 1.0 for linear function
        assert_eq!(grad.to_vec(), vec![1.0, 1.0, 1.0, 1.0]);
    }

    #[test]
    fn test_ediff1d() {
        let a = Array::from_vec(vec![1.0, 3.0, 6.0, 10.0], Shape::new(vec![4]));
        let edges = a.ediff1d();
        assert_eq!(edges.to_vec(), vec![2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_ediff1d_single() {
        let a = Array::from_vec(vec![5.0], Shape::new(vec![1]));
        let edges = a.ediff1d();
        assert_eq!(edges.shape().as_slice(), &[0]);
        assert_eq!(edges.size(), 0);
    }
}