1use crate::trace::{is_tracing, trace_reduce, Primitive};
4use crate::{buffer::Buffer, Array, DType, Device, Shape};
5
6fn reduce_all<F>(input: &Array, init: f32, f: F) -> f32
8where
9 F: Fn(f32, f32) -> f32,
10{
11 assert_eq!(input.dtype(), DType::Float32, "Only Float32 supported");
12
13 let data = input.to_vec();
15 data.iter().fold(init, |acc, &x| f(acc, x))
16}
17
18fn reduce_all_gpu_aware(input: &Array, op: &str) -> f32 {
20 assert_eq!(input.dtype(), DType::Float32, "Only Float32 supported");
21
22 match input.device() {
23 Device::WebGpu => {
24 let output_buffer = Buffer::zeros(1, DType::Float32, Device::WebGpu);
26
27 crate::backend::ops::gpu_reduce_all(
28 input.buffer(),
29 &output_buffer,
30 op,
31 );
32
33 output_buffer.to_f32_vec()[0]
35 }
36 Device::Cpu | Device::Wasm => {
37 let (init, f): (f32, Box<dyn Fn(f32, f32) -> f32>) = match op {
39 "sum" => (0.0, Box::new(|acc, x| acc + x)),
40 "max" => (f32::NEG_INFINITY, Box::new(|acc, x| acc.max(x))),
41 "min" => (f32::INFINITY, Box::new(|acc, x| acc.min(x))),
42 "prod" => (1.0, Box::new(|acc, x| acc * x)),
43 _ => panic!("Unknown reduction op: {}", op),
44 };
45
46 let data = input.to_vec();
47 data.iter().fold(init, |acc, &x| f(acc, x))
48 }
49 }
50}
51
/// Reduces `input` along `axis` with the fold `f` (seeded with `init`),
/// returning an Array whose shape is the input shape with `axis`
/// removed (scalar when the input is 1-D). `op` identifies the
/// reduction node recorded on the trace when tracing is active.
/// CPU Float32 only; panics otherwise or when `axis` is out of bounds.
fn reduce_axis<F>(
    input: &Array,
    axis: usize,
    op: Primitive,
    init: f32,
    f: F,
) -> Array
where
    F: Fn(f32, f32) -> f32,
{
    assert_eq!(input.dtype(), DType::Float32, "Only Float32 supported");
    assert_eq!(input.device(), Device::Cpu, "Only CPU supported for now");
    assert!(axis < input.ndim(), "Axis out of bounds");

    let shape = input.shape();
    let dims = shape.as_slice();

    // Output shape = input shape minus the reduced axis.
    let mut result_dims: Vec<usize> = dims.to_vec();
    result_dims.remove(axis);
    let result_shape = if result_dims.is_empty() {
        Shape::scalar()
    } else {
        Shape::new(result_dims.clone())
    };

    let input_data = input.to_vec();
    let result_size = result_shape.size();
    let mut result_data = vec![init; result_size];

    // Row-major strides of the input (innermost dim has stride 1).
    let mut strides = vec![1; dims.len()];
    for i in (0..dims.len() - 1).rev() {
        strides[i] = strides[i + 1] * dims[i + 1];
    }

    for (result_idx, item) in result_data.iter_mut().enumerate() {
        // Decompose the flat output index into output coordinates.
        let mut result_multi = vec![0; result_dims.len()];
        let mut idx = result_idx;
        for i in (0..result_dims.len()).rev() {
            result_multi[i] = idx % result_shape.as_slice()[i];
            idx /= result_shape.as_slice()[i];
        }

        // Fold every input element whose non-axis coordinates match.
        let mut acc = init;
        for axis_idx in 0..dims[axis] {
            // Rebuild the full input coordinate by splicing the axis
            // position back in among the output coordinates.
            let mut input_multi = Vec::with_capacity(dims.len());
            let mut result_i = 0;
            for i in 0..dims.len() {
                if i == axis {
                    input_multi.push(axis_idx);
                } else {
                    input_multi.push(result_multi[result_i]);
                    result_i += 1;
                }
            }

            // Dot product of coordinates and strides gives the flat index.
            let flat_idx: usize = input_multi
                .iter()
                .zip(strides.iter())
                .map(|(idx, stride)| idx * stride)
                .sum();

            acc = f(acc, input_data[flat_idx]);
        }

        *item = acc;
    }

    let buffer = Buffer::from_f32(result_data, Device::Cpu);
    let result = Array::from_buffer(buffer, result_shape.clone());

    if is_tracing() {
        trace_reduce(result.id(), op, input, result_shape);
    }

    result
}
136
137impl Array {
/// Sums every element, returning a plain f32. GPU-accelerated when the
/// array lives on WebGpu; panics on non-Float32 input.
pub fn sum_all(&self) -> f32 {
    reduce_all_gpu_aware(self, "sum")
}
151
152 pub fn sum_all_array(&self) -> Array {
156 let val = self.sum_all();
157 let result = Array::from_vec(vec![val], crate::Shape::scalar());
158
159 if is_tracing() {
161 trace_reduce(
162 result.id(),
163 Primitive::SumAll,
164 self,
165 crate::Shape::scalar(),
166 );
167 }
168
169 result
170 }
171
/// Sum along `axis`; the result shape drops that axis (scalar for 1-D
/// input). CPU only — see `reduce_axis`.
pub fn sum(&self, axis: usize) -> Array {
    reduce_axis(self, axis, Primitive::Sum { axis }, 0.0, |acc, x| acc + x)
}
187
/// Mean of every element: `sum_all()` divided by the element count.
/// An empty array yields NaN (0.0 / 0.0).
pub fn mean_all(&self) -> f32 {
    self.sum_all() / (self.size() as f32)
}
192
193 pub fn mean_all_array(&self) -> Array {
195 let val = self.mean_all();
196 let result = Array::from_vec(vec![val], crate::Shape::scalar());
197
198 if is_tracing() {
200 trace_reduce(
201 result.id(),
202 Primitive::MeanAll,
203 self,
204 crate::Shape::scalar(),
205 );
206 }
207
208 result
209 }
210
/// Mean along `axis`. Each element is divided by the axis length inside
/// the fold rather than dividing the final sum once — same result up to
/// float rounding.
pub fn mean(&self, axis: usize) -> Array {
    reduce_axis(self, axis, Primitive::Mean { axis }, 0.0, |acc, x| {
        acc + x / (self.shape().as_slice()[axis] as f32)
    })
}
217
/// Maximum over all elements (-inf for an empty array), device-aware.
pub fn max_all(&self) -> f32 {
    reduce_all_gpu_aware(self, "max")
}
222
223 pub fn max(&self, axis: usize) -> Array {
225 reduce_axis(
226 self,
227 axis,
228 Primitive::MaxAxis { axis },
229 f32::NEG_INFINITY,
230 |acc, x| acc.max(x),
231 )
232 }
233
/// Minimum over all elements (+inf for an empty array), device-aware.
pub fn min_all(&self) -> f32 {
    reduce_all_gpu_aware(self, "min")
}
238
239 pub fn min(&self, axis: usize) -> Array {
241 reduce_axis(
242 self,
243 axis,
244 Primitive::MinAxis { axis },
245 f32::INFINITY,
246 |acc, x| acc.min(x),
247 )
248 }
249
/// Product of every element (identity 1.0).
/// NOTE(review): unlike sum_all/max_all/min_all this calls the CPU-only
/// `reduce_all`, so WebGpu arrays are reduced host-side even though
/// `reduce_all_gpu_aware` has a "prod" arm — confirm whether the GPU
/// path supports "prod" and should be used here.
pub fn prod_all(&self) -> f32 {
    reduce_all(self, 1.0, |acc, x| acc * x)
}
263
264 pub fn prod(&self, axis: usize) -> Array {
266 reduce_axis(self, axis, Primitive::ProdAxis { axis }, 1.0, |acc, x| {
267 acc * x
268 })
269 }
270
271 pub fn argmin(&self) -> usize {
282 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
283 let data = self.to_vec();
284 data.iter()
285 .enumerate()
286 .min_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
287 .map(|(idx, _)| idx)
288 .unwrap()
289 }
290
291 pub fn argmax(&self) -> usize {
302 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
303 let data = self.to_vec();
304 data.iter()
305 .enumerate()
306 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
307 .map(|(idx, _)| idx)
308 .unwrap()
309 }
310
311 pub fn var(&self) -> f32 {
322 let mean = self.mean_all();
323 let data = self.to_vec();
324 let sum_sq_diff: f32 = data.iter().map(|&x| (x - mean).powi(2)).sum();
325 sum_sq_diff / data.len() as f32
326 }
327
/// Population standard deviation: square root of `var`.
pub fn std(&self) -> f32 {
    self.var().sqrt()
}
341
/// Population variance along `axis` (divides by the axis length, not
/// n-1). Output shape is the input shape with `axis` removed (scalar
/// for 1-D input). Computed on the host copy of the data.
pub fn var_axis(&self, axis: usize) -> Array {
    assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
    assert!(axis < self.ndim(), "Axis out of bounds");

    // Per-slice means along the axis; aligned index-for-index with the output.
    let mean = self.mean(axis);
    let mean_data = mean.to_vec();

    let shape = self.shape().as_slice();
    let data = self.to_vec();

    // Output shape = input shape minus the reduced axis.
    let mut result_dims: Vec<usize> = shape.to_vec();
    result_dims.remove(axis);
    let result_shape = if result_dims.is_empty() {
        Shape::scalar()
    } else {
        Shape::new(result_dims.clone())
    };

    let result_size = result_shape.size();
    let mut result_data = vec![0.0; result_size];

    // Row-major strides of the input (innermost dim has stride 1).
    let mut strides = vec![1; shape.len()];
    for i in (0..shape.len() - 1).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }

    for (result_idx, item) in result_data.iter_mut().enumerate() {
        // Decompose the flat output index into output coordinates.
        let mut result_multi = vec![0; result_dims.len()];
        let mut idx = result_idx;
        for i in (0..result_dims.len()).rev() {
            result_multi[i] = idx % result_shape.as_slice()[i];
            idx /= result_shape.as_slice()[i];
        }

        let mean_val = mean_data[result_idx];
        let mut sum_sq = 0.0;

        // Accumulate squared deviations over the reduced axis.
        for axis_idx in 0..shape[axis] {
            // Splice the axis position back in among the output coordinates.
            let mut input_multi = Vec::with_capacity(shape.len());
            let mut result_i = 0;
            for i in 0..shape.len() {
                if i == axis {
                    input_multi.push(axis_idx);
                } else {
                    input_multi.push(result_multi[result_i]);
                    result_i += 1;
                }
            }

            // Dot product of coordinates and strides gives the flat index.
            let flat_idx: usize = input_multi
                .iter()
                .zip(strides.iter())
                .map(|(idx, stride)| idx * stride)
                .sum();

            let diff = data[flat_idx] - mean_val;
            sum_sq += diff * diff;
        }

        *item = sum_sq / shape[axis] as f32;
    }

    Array::from_vec(result_data, result_shape)
}
422
423 pub fn std_axis(&self, axis: usize) -> Array {
434 let var = self.var_axis(axis);
435 let data = var.to_vec();
436 let result: Vec<f32> = data.iter().map(|&x| x.sqrt()).collect();
437 Array::from_vec(result, var.shape().clone())
438 }
439
440 pub fn median(&self) -> f32 {
451 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
452 let mut data = self.to_vec();
453 data.sort_by(|a, b| a.partial_cmp(b).unwrap());
454
455 let len = data.len();
456 if len.is_multiple_of(2) {
457 (data[len / 2 - 1] + data[len / 2]) / 2.0
458 } else {
459 data[len / 2]
460 }
461 }
462
463 pub fn percentile(&self, q: f32) -> f32 {
478 assert!((0.0..=100.0).contains(&q), "Percentile must be between 0 and 100");
479 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
480
481 let mut data = self.to_vec();
482 data.sort_by(|a, b| a.partial_cmp(b).unwrap());
483
484 let len = data.len();
485 if len == 1 {
486 return data[0];
487 }
488
489 let index = (q / 100.0) * (len - 1) as f32;
490 let lower = index.floor() as usize;
491 let upper = index.ceil() as usize;
492
493 if lower == upper {
494 data[lower]
495 } else {
496 let weight = index - lower as f32;
497 data[lower] * (1.0 - weight) + data[upper] * weight
498 }
499 }
500
501 pub fn cumsum(&self) -> Array {
514 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
515 let data = self.to_vec();
516 let mut result = Vec::with_capacity(data.len());
517 let mut sum = 0.0;
518
519 for &val in data.iter() {
520 sum += val;
521 result.push(sum);
522 }
523
524 Array::from_vec(result, self.shape().clone())
525 }
526
527 pub fn cumprod(&self) -> Array {
540 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
541 let data = self.to_vec();
542 let mut result = Vec::with_capacity(data.len());
543 let mut prod = 1.0;
544
545 for &val in data.iter() {
546 prod *= val;
547 result.push(prod);
548 }
549
550 Array::from_vec(result, self.shape().clone())
551 }
552
553 pub fn cummax(&self) -> Array {
566 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
567 let data = self.to_vec();
568 let mut result = Vec::with_capacity(data.len());
569 let mut max = f32::NEG_INFINITY;
570
571 for &val in data.iter() {
572 max = max.max(val);
573 result.push(max);
574 }
575
576 Array::from_vec(result, self.shape().clone())
577 }
578
579 pub fn cummin(&self) -> Array {
592 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
593 let data = self.to_vec();
594 let mut result = Vec::with_capacity(data.len());
595 let mut min = f32::INFINITY;
596
597 for &val in data.iter() {
598 min = min.min(val);
599 result.push(min);
600 }
601
602 Array::from_vec(result, self.shape().clone())
603 }
604
605 pub fn diff(&self) -> Array {
618 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
619 let data = self.to_vec();
620 assert!(!data.is_empty(), "Array must have at least 1 element");
621
622 if data.len() == 1 {
623 return Array::from_vec(vec![], Shape::new(vec![0]));
624 }
625
626 let mut result = Vec::with_capacity(data.len() - 1);
627 for i in 1..data.len() {
628 result.push(data[i] - data[i - 1]);
629 }
630
631 let len = result.len();
632 Array::from_vec(result, Shape::new(vec![len]))
633 }
634
635 pub fn diff_n(&self, n: usize) -> Array {
648 if n == 0 {
649 return self.clone();
650 }
651
652 let mut result = self.diff();
653 for _ in 1..n {
654 result = result.diff();
655 }
656 result
657 }
658
659 pub fn nansum(&self) -> f32 {
670 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
671 let data = self.to_vec();
672 data.iter().filter(|x| !x.is_nan()).sum()
673 }
674
675 pub fn nanmean(&self) -> f32 {
686 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
687 let data = self.to_vec();
688 let valid: Vec<f32> = data.iter().copied().filter(|x| !x.is_nan()).collect();
689 if valid.is_empty() {
690 return f32::NAN;
691 }
692 valid.iter().sum::<f32>() / valid.len() as f32
693 }
694
695 pub fn nanmax(&self) -> f32 {
706 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
707 let data = self.to_vec();
708 data.iter()
709 .copied()
710 .filter(|x| !x.is_nan())
711 .fold(f32::NEG_INFINITY, f32::max)
712 }
713
714 pub fn nanmin(&self) -> f32 {
725 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
726 let data = self.to_vec();
727 data.iter()
728 .copied()
729 .filter(|x| !x.is_nan())
730 .fold(f32::INFINITY, f32::min)
731 }
732
/// Sample standard deviation ignoring NaNs: square root of `nanvar`.
pub fn nanstd(&self) -> f32 {
    self.nanvar().sqrt()
}
746
747 pub fn nanvar(&self) -> f32 {
758 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
759 let data = self.to_vec();
760 let valid: Vec<f32> = data.iter().copied().filter(|x| !x.is_nan()).collect();
761
762 if valid.is_empty() || valid.len() == 1 {
763 return f32::NAN;
764 }
765
766 let mean = valid.iter().sum::<f32>() / valid.len() as f32;
767 let variance = valid
768 .iter()
769 .map(|x| {
770 let diff = x - mean;
771 diff * diff
772 })
773 .sum::<f32>()
774 / (valid.len() - 1) as f32; variance
777 }
778
779 pub fn nanmedian(&self) -> f32 {
790 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
791 let data = self.to_vec();
792 let mut valid: Vec<f32> = data.iter().copied().filter(|x| !x.is_nan()).collect();
793
794 if valid.is_empty() {
795 return f32::NAN;
796 }
797
798 valid.sort_by(|a, b| a.partial_cmp(b).unwrap());
799 let len = valid.len();
800
801 if len.is_multiple_of(2) {
802 (valid[len / 2 - 1] + valid[len / 2]) / 2.0
803 } else {
804 valid[len / 2]
805 }
806 }
807
808 pub fn ptp(&self) -> f32 {
818 let max = self.max_all();
819 let min = self.min_all();
820 max - min
821 }
822
823 pub fn ptp_axis(&self, axis: usize) -> Array {
834 let max = self.max(axis);
835 let min = self.min(axis);
836 max.sub(&min)
837 }
838
839 pub fn quantile(&self, q: f32) -> f32 {
854 assert!(
855 (0.0..=1.0).contains(&q),
856 "Quantile must be between 0 and 1"
857 );
858
859 let mut data = self.to_vec();
860 data.sort_by(|a, b| a.partial_cmp(b).unwrap());
861
862 let n = data.len();
863 if n == 0 {
864 return f32::NAN;
865 }
866
867 let index = q * (n - 1) as f32;
868 let lower = index.floor() as usize;
869 let upper = index.ceil() as usize;
870
871 if lower == upper {
872 data[lower]
873 } else {
874 let weight = index - lower as f32;
875 data[lower] * (1.0 - weight) + data[upper] * weight
876 }
877 }
878
879 pub fn quantile_axis(&self, q: f32, axis: usize) -> Array {
890 assert!(
891 (0.0..=1.0).contains(&q),
892 "Quantile must be between 0 and 1"
893 );
894 assert!(axis < self.ndim(), "Axis out of bounds");
895
896 let shape = self.shape();
897 let dims = shape.as_slice();
898 let axis_size = dims[axis];
899
900 let mut output_dims = dims.to_vec();
902 output_dims.remove(axis);
903 let output_shape = Shape::new(output_dims);
904 let output_size = output_shape.size();
905
906 let data = self.to_vec();
907 let mut result = Vec::with_capacity(output_size);
908
909 for output_idx in 0..output_size {
911 let mut values = Vec::with_capacity(axis_size);
913
914 for axis_idx in 0..axis_size {
915 let mut input_idx = 0;
917 let mut remaining = output_idx;
918 let mut stride = 1;
919
920 for (dim_idx, &dim_size) in dims.iter().enumerate().rev() {
921 if dim_idx == axis {
922 input_idx += axis_idx * stride;
923 stride *= dim_size;
924 } else {
925 let out_dim_size = if dim_idx < axis {
926 dims[dim_idx]
927 } else {
928 dims[dim_idx]
929 };
930 let coord = remaining % out_dim_size;
931 input_idx += coord * stride;
932 remaining /= out_dim_size;
933 stride *= dim_size;
934 }
935 }
936
937 values.push(data[input_idx]);
938 }
939
940 values.sort_by(|a, b| a.partial_cmp(b).unwrap());
942 let n = values.len();
943 let index = q * (n - 1) as f32;
944 let lower = index.floor() as usize;
945 let upper = index.ceil() as usize;
946
947 let quantile = if lower == upper {
948 values[lower]
949 } else {
950 let weight = index - lower as f32;
951 values[lower] * (1.0 - weight) + values[upper] * weight
952 };
953
954 result.push(quantile);
955 }
956
957 Array::from_vec(result, output_shape)
958 }
959
960 pub fn trapz(&self) -> f32 {
971 let data = self.to_vec();
972 if data.len() < 2 {
973 return 0.0;
974 }
975
976 let mut sum = 0.0;
977 for i in 0..data.len() - 1 {
978 sum += (data[i] + data[i + 1]) / 2.0;
979 }
980 sum
981 }
982
983 pub fn trapz_axis(&self, axis: usize) -> Array {
994 assert!(axis < self.ndim(), "Axis out of bounds");
995
996 let shape = self.shape();
997 let dims = shape.as_slice();
998 let axis_size = dims[axis];
999
1000 if axis_size < 2 {
1001 let mut output_dims = dims.to_vec();
1003 output_dims.remove(axis);
1004 let output_shape = Shape::new(output_dims);
1005 return Array::zeros(output_shape, self.dtype());
1006 }
1007
1008 let mut output_dims = dims.to_vec();
1010 output_dims.remove(axis);
1011 let output_shape = Shape::new(output_dims);
1012 let output_size = output_shape.size();
1013
1014 let data = self.to_vec();
1015 let mut result = Vec::with_capacity(output_size);
1016
1017 for output_idx in 0..output_size {
1019 let mut sum = 0.0;
1020
1021 for i in 0..axis_size - 1 {
1023 let idx1 = self.compute_axis_index(output_idx, axis, i, &output_shape);
1025 let idx2 = self.compute_axis_index(output_idx, axis, i + 1, &output_shape);
1026
1027 sum += (data[idx1] + data[idx2]) / 2.0;
1028 }
1029
1030 result.push(sum);
1031 }
1032
1033 Array::from_vec(result, output_shape)
1034 }
1035
1036 pub fn gradient(&self) -> Array {
1052 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1053 let data = self.to_vec();
1054 let n = data.len();
1055
1056 if n == 0 {
1057 return self.clone();
1058 }
1059
1060 if n == 1 {
1061 return Array::from_vec(vec![0.0], self.shape().clone());
1062 }
1063
1064 let mut result = Vec::with_capacity(n);
1065
1066 result.push(data[1] - data[0]);
1068
1069 for i in 1..n - 1 {
1071 result.push((data[i + 1] - data[i - 1]) / 2.0);
1072 }
1073
1074 result.push(data[n - 1] - data[n - 2]);
1076
1077 Array::from_vec(result, self.shape().clone())
1078 }
1079
1080 pub fn ediff1d(&self) -> Array {
1094 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1095 let data = self.to_vec();
1096
1097 if data.is_empty() {
1098 return Array::zeros(Shape::new(vec![0]), DType::Float32);
1099 }
1100
1101 if data.len() == 1 {
1102 return Array::zeros(Shape::new(vec![0]), DType::Float32);
1103 }
1104
1105 let mut result = Vec::with_capacity(data.len() - 1);
1106 for i in 0..data.len() - 1 {
1107 result.push(data[i + 1] - data[i]);
1108 }
1109
1110 let len = result.len();
1111 Array::from_vec(result, Shape::new(vec![len]))
1112 }
1113
1114 pub fn nanargmax(&self) -> usize {
1125 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1126 let data = self.to_vec();
1127
1128 let mut max_val = f32::NEG_INFINITY;
1129 let mut max_idx = 0;
1130
1131 for (i, &val) in data.iter().enumerate() {
1132 if !val.is_nan() && val > max_val {
1133 max_val = val;
1134 max_idx = i;
1135 }
1136 }
1137
1138 max_idx
1139 }
1140
1141 pub fn nanargmin(&self) -> usize {
1152 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1153 let data = self.to_vec();
1154
1155 let mut min_val = f32::INFINITY;
1156 let mut min_idx = 0;
1157
1158 for (i, &val) in data.iter().enumerate() {
1159 if !val.is_nan() && val < min_val {
1160 min_val = val;
1161 min_idx = i;
1162 }
1163 }
1164
1165 min_idx
1166 }
1167
1168 pub fn average(&self, weights: &Array) -> f32 {
1180 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1181 assert_eq!(weights.dtype(), DType::Float32, "Only Float32 supported");
1182 assert_eq!(
1183 self.size(),
1184 weights.size(),
1185 "Values and weights must have same size"
1186 );
1187
1188 let data = self.to_vec();
1189 let weight_data = weights.to_vec();
1190
1191 let weighted_sum: f32 = data
1192 .iter()
1193 .zip(weight_data.iter())
1194 .map(|(v, w)| v * w)
1195 .sum();
1196
1197 let weight_sum: f32 = weight_data.iter().sum();
1198
1199 weighted_sum / weight_sum
1200 }
1201
1202 pub fn cov(&self, other: &Array) -> f32 {
1214 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1215 assert_eq!(other.dtype(), DType::Float32, "Only Float32 supported");
1216 assert_eq!(
1217 self.size(),
1218 other.size(),
1219 "Arrays must have same size for covariance"
1220 );
1221
1222 let x = self.to_vec();
1223 let y = other.to_vec();
1224 let n = x.len() as f32;
1225
1226 if n == 0.0 {
1227 return 0.0;
1228 }
1229
1230 let x_mean: f32 = x.iter().sum::<f32>() / n;
1231 let y_mean: f32 = y.iter().sum::<f32>() / n;
1232
1233 let cov: f32 = x
1234 .iter()
1235 .zip(y.iter())
1236 .map(|(&xi, &yi)| (xi - x_mean) * (yi - y_mean))
1237 .sum();
1238
1239 cov / (n - 1.0)
1240 }
1241
/// Maps a flat index into the axis-reduced output back to a flat index
/// into this array, with the reduced axis pinned at `axis_pos`.
/// Walks dimensions from the innermost outward, peeling output
/// coordinates off `output_idx` while accumulating input strides.
fn compute_axis_index(
    &self,
    output_idx: usize,
    axis: usize,
    axis_pos: usize,
    output_shape: &Shape,
) -> usize {
    let dims = self.shape().as_slice();
    let output_dims = output_shape.as_slice();

    let mut input_idx = 0;
    let mut remaining = output_idx;
    let mut stride = 1;

    for (dim_idx, &dim_size) in dims.iter().enumerate().rev() {
        if dim_idx == axis {
            // The reduced axis contributes only the fixed position.
            input_idx += axis_pos * stride;
        } else {
            // Output dims skip the reduced axis, so input dims above it
            // map to output index dim_idx - 1.
            let out_dim_idx = if dim_idx > axis {
                dim_idx - 1
            } else {
                dim_idx
            };
            let coord = remaining % output_dims[out_dim_idx];
            input_idx += coord * stride;
            remaining /= output_dims[out_dim_idx];
        }
        // Stride advances over every input dimension, including `axis`.
        stride *= dim_size;
    }

    input_idx
}
1275
1276 pub fn nanprod(&self) -> f32 {
1287 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1288 let data = self.to_vec();
1289 data.iter()
1290 .filter(|x| !x.is_nan())
1291 .fold(1.0, |acc, &x| acc * x)
1292 }
1293
1294 pub fn nancumsum(&self) -> Array {
1306 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1307 let data = self.to_vec();
1308 let mut result = Vec::with_capacity(data.len());
1309 let mut sum = 0.0;
1310
1311 for &val in data.iter() {
1312 if !val.is_nan() {
1313 sum += val;
1314 }
1315 result.push(sum);
1316 }
1317
1318 Array::from_vec(result, self.shape().clone())
1319 }
1320
1321 pub fn nancumprod(&self) -> Array {
1333 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1334 let data = self.to_vec();
1335 let mut result = Vec::with_capacity(data.len());
1336 let mut prod = 1.0;
1337
1338 for &val in data.iter() {
1339 if !val.is_nan() {
1340 prod *= val;
1341 }
1342 result.push(prod);
1343 }
1344
1345 Array::from_vec(result, self.shape().clone())
1346 }
1347
/// Arithmetic–geometric mean (AGM): iterates a ← (a+g)/2, g ← √(a·g)
/// from the data's arithmetic and geometric means, for at most 20
/// rounds or until `a` moves by less than 1e-10. NaN for empty input.
/// NOTE(review): the geometric seed takes a fractional power of the raw
/// product, so negative inputs produce NaN — confirm intended domain.
pub fn agmean(&self) -> f32 {
    assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
    let data = self.to_vec();
    if data.is_empty() {
        return f32::NAN;
    }

    let arith = data.iter().sum::<f32>() / data.len() as f32;
    let geom = data.iter().fold(1.0, |acc, &x| acc * x).powf(1.0 / data.len() as f32);

    let mut a = arith;
    let mut g = geom;

    for _ in 0..20 {
        let new_a = (a + g) / 2.0;
        let new_g = (a * g).sqrt();
        // NOTE(review): the break fires before `a`/`g` take the new
        // values, so the returned `a` lags one iteration behind new_a —
        // confirm that is intended (difference is below the tolerance).
        if (new_a - a).abs() < 1e-10 {
            break;
        }
        a = new_a;
        g = new_g;
    }

    a
}
1384
1385 pub fn rms(&self) -> f32 {
1397 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1398 let data = self.to_vec();
1399 if data.is_empty() {
1400 return 0.0;
1401 }
1402
1403 let sum_sq: f32 = data.iter().map(|&x| x * x).sum();
1404 (sum_sq / data.len() as f32).sqrt()
1405 }
1406
1407 pub fn harmonic_mean(&self) -> f32 {
1418 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1419 let data = self.to_vec();
1420 if data.is_empty() {
1421 return f32::NAN;
1422 }
1423
1424 let sum_inv: f32 = data.iter().map(|&x| 1.0 / x).sum();
1425 data.len() as f32 / sum_inv
1426 }
1427
1428 pub fn geometric_mean(&self) -> f32 {
1439 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1440 let data = self.to_vec();
1441 if data.is_empty() {
1442 return f32::NAN;
1443 }
1444
1445 let product: f32 = data.iter().fold(1.0, |acc, &x| acc * x);
1446 product.powf(1.0 / data.len() as f32)
1447 }
1448
1449 pub fn nanpercentile(&self, q: f32) -> f32 {
1460 assert!((0.0..=100.0).contains(&q), "Percentile must be in [0, 100]");
1461 assert_eq!(self.dtype(), DType::Float32, "Only Float32 supported");
1462
1463 let mut data: Vec<f32> = self.to_vec().into_iter().filter(|x| !x.is_nan()).collect();
1464 if data.is_empty() {
1465 return f32::NAN;
1466 }
1467
1468 data.sort_by(|a, b| a.partial_cmp(b).unwrap());
1469
1470 let n = data.len();
1471 let idx = q / 100.0 * (n - 1) as f32;
1472 let lo = idx.floor() as usize;
1473 let hi = idx.ceil() as usize;
1474 let frac = idx - lo as f32;
1475
1476 if lo == hi {
1477 data[lo]
1478 } else {
1479 data[lo] * (1.0 - frac) + data[hi] * frac
1480 }
1481 }
1482
/// Quantile in [0, 1] over the non-NaN elements; delegates to
/// `nanpercentile` after scaling to a percentage.
pub fn nanquantile(&self, q: f32) -> f32 {
    assert!((0.0..=1.0).contains(&q), "Quantile must be in [0, 1]");
    self.nanpercentile(q * 100.0)
}
1497}
1498
// Unit tests for the reduction/statistics API; all fixtures are small
// CPU Float32 arrays.
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    #[test]
    fn test_sum_all() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![4]));
        assert_eq!(a.sum_all(), 10.0);
    }

    #[test]
    fn test_sum_axis() {
        let a = Array::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            Shape::new(vec![2, 3]),
        );

        // Reducing axis 0 collapses the two rows.
        let sum_axis0 = a.sum(0);
        assert_eq!(sum_axis0.shape().as_slice(), &[3]);
        assert_eq!(sum_axis0.to_vec(), vec![5.0, 7.0, 9.0]);

        // Reducing axis 1 collapses the three columns.
        let sum_axis1 = a.sum(1);
        assert_eq!(sum_axis1.shape().as_slice(), &[2]);
        assert_eq!(sum_axis1.to_vec(), vec![6.0, 15.0]);
    }

    #[test]
    fn test_mean_all() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![4]));
        assert_abs_diff_eq!(a.mean_all(), 2.5, epsilon = 1e-6);
    }

    #[test]
    fn test_mean_axis() {
        let a = Array::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            Shape::new(vec![2, 3]),
        );

        let mean_axis0 = a.mean(0);
        assert_eq!(mean_axis0.to_vec(), vec![2.5, 3.5, 4.5]);

        let mean_axis1 = a.mean(1);
        assert_eq!(mean_axis1.to_vec(), vec![2.0, 5.0]);
    }

    #[test]
    fn test_max_all() {
        let a = Array::from_vec(vec![1.0, 5.0, 3.0, 2.0], Shape::new(vec![4]));
        assert_eq!(a.max_all(), 5.0);
    }

    #[test]
    fn test_max_axis() {
        let a = Array::from_vec(
            vec![1.0, 5.0, 3.0, 2.0, 4.0, 6.0],
            Shape::new(vec![2, 3]),
        );

        let max_axis0 = a.max(0);
        assert_eq!(max_axis0.to_vec(), vec![2.0, 5.0, 6.0]);

        let max_axis1 = a.max(1);
        assert_eq!(max_axis1.to_vec(), vec![5.0, 6.0]);
    }

    #[test]
    fn test_min_all() {
        let a = Array::from_vec(vec![3.0, 1.0, 5.0, 2.0], Shape::new(vec![4]));
        assert_eq!(a.min_all(), 1.0);
    }

    #[test]
    fn test_min_axis() {
        let a = Array::from_vec(
            vec![3.0, 5.0, 2.0, 1.0, 4.0, 6.0],
            Shape::new(vec![2, 3]),
        );

        let min_axis0 = a.min(0);
        assert_eq!(min_axis0.to_vec(), vec![1.0, 4.0, 2.0]);

        let min_axis1 = a.min(1);
        assert_eq!(min_axis1.to_vec(), vec![2.0, 1.0]);
    }

    #[test]
    fn test_reduce_3d() {
        let a = Array::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
            Shape::new(vec![2, 2, 2]),
        );

        // Reducing the middle axis of a 2x2x2 cube leaves a 2x2 result.
        let sum_axis1 = a.sum(1);
        assert_eq!(sum_axis1.shape().as_slice(), &[2, 2]);
        assert_eq!(sum_axis1.to_vec(), vec![4.0, 6.0, 12.0, 14.0]);
    }

    #[test]
    fn test_cumsum() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![4]));
        let cumsum = a.cumsum();
        assert_eq!(cumsum.to_vec(), vec![1.0, 3.0, 6.0, 10.0]);
    }

    #[test]
    fn test_cumprod() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0], Shape::new(vec![4]));
        let cumprod = a.cumprod();
        assert_eq!(cumprod.to_vec(), vec![1.0, 2.0, 6.0, 24.0]);
    }

    #[test]
    fn test_cummax() {
        let a = Array::from_vec(vec![3.0, 1.0, 4.0, 2.0], Shape::new(vec![4]));
        let cummax = a.cummax();
        assert_eq!(cummax.to_vec(), vec![3.0, 3.0, 4.0, 4.0]);
    }

    #[test]
    fn test_cummin() {
        let a = Array::from_vec(vec![3.0, 1.0, 4.0, 2.0], Shape::new(vec![4]));
        let cummin = a.cummin();
        assert_eq!(cummin.to_vec(), vec![3.0, 1.0, 1.0, 1.0]);
    }

    #[test]
    fn test_diff() {
        let a = Array::from_vec(vec![1.0, 3.0, 6.0, 10.0], Shape::new(vec![4]));
        let diff = a.diff();
        assert_eq!(diff.to_vec(), vec![2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_diff_n() {
        let a = Array::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0],
            Shape::new(vec![5]),
        );
        // Second差 of a linear sequence is all zeros.
        let diff2 = a.diff_n(2);
        assert_eq!(diff2.to_vec(), vec![0.0, 0.0, 0.0]);

        // diff_n(0) is an identity copy.
        let diff0 = a.diff_n(0);
        assert_eq!(diff0.to_vec(), a.to_vec());
    }

    #[test]
    fn test_nansum() {
        let a = Array::from_vec(
            vec![1.0, f32::NAN, 3.0, 4.0],
            Shape::new(vec![4]),
        );
        let sum = a.nansum();
        assert_eq!(sum, 8.0);
    }

    #[test]
    fn test_nanmean() {
        let a = Array::from_vec(
            vec![1.0, f32::NAN, 3.0, 4.0],
            Shape::new(vec![4]),
        );
        let mean = a.nanmean();
        assert_abs_diff_eq!(mean, 8.0 / 3.0, epsilon = 1e-6);
    }

    #[test]
    fn test_nanmax() {
        let a = Array::from_vec(
            vec![1.0, f32::NAN, 4.0, 2.0],
            Shape::new(vec![4]),
        );
        let max = a.nanmax();
        assert_eq!(max, 4.0);
    }

    #[test]
    fn test_nanmin() {
        let a = Array::from_vec(
            vec![1.0, f32::NAN, 4.0, 2.0],
            Shape::new(vec![4]),
        );
        let min = a.nanmin();
        assert_eq!(min, 1.0);
    }

    #[test]
    fn test_nanstd() {
        let a = Array::from_vec(
            vec![1.0, f32::NAN, 3.0, 5.0],
            Shape::new(vec![4]),
        );
        let std = a.nanstd();
        assert_abs_diff_eq!(std, 2.0, epsilon = 1e-5);
    }

    #[test]
    fn test_nanvar() {
        let a = Array::from_vec(
            vec![1.0, f32::NAN, 3.0, 5.0],
            Shape::new(vec![4]),
        );
        let var = a.nanvar();
        assert_abs_diff_eq!(var, 4.0, epsilon = 1e-5);
    }

    #[test]
    fn test_nanmedian() {
        let a = Array::from_vec(
            vec![1.0, f32::NAN, 3.0, 5.0, 2.0],
            Shape::new(vec![5]),
        );
        let median = a.nanmedian();
        assert_eq!(median, 2.5);
    }

    #[test]
    fn test_ptp() {
        let a = Array::from_vec(vec![1.0, 5.0, 2.0, 8.0], Shape::new(vec![4]));
        assert_eq!(a.ptp(), 7.0);
    }

    #[test]
    fn test_ptp_axis() {
        let a = Array::from_vec(vec![1.0, 5.0, 2.0, 8.0], Shape::new(vec![2, 2]));
        let ptp = a.ptp_axis(0);
        assert_eq!(ptp.to_vec(), vec![1.0, 3.0]);
    }

    #[test]
    fn test_quantile() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], Shape::new(vec![5]));
        assert_abs_diff_eq!(a.quantile(0.0), 1.0, epsilon = 1e-6);
        assert_abs_diff_eq!(a.quantile(0.5), 3.0, epsilon = 1e-6);
        assert_abs_diff_eq!(a.quantile(1.0), 5.0, epsilon = 1e-6);
        assert_abs_diff_eq!(a.quantile(0.25), 2.0, epsilon = 1e-6);
    }

    #[test]
    fn test_quantile_axis() {
        let a = Array::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            Shape::new(vec![2, 3]),
        );
        let q = a.quantile_axis(0.5, 0);
        assert_eq!(q.shape().as_slice(), &[3]);
        assert_abs_diff_eq!(q.to_vec()[0], 2.5, epsilon = 1e-6);
        assert_abs_diff_eq!(q.to_vec()[1], 3.5, epsilon = 1e-6);
        assert_abs_diff_eq!(q.to_vec()[2], 4.5, epsilon = 1e-6);
    }

    #[test]
    fn test_trapz() {
        let a = Array::from_vec(vec![1.0, 2.0, 3.0], Shape::new(vec![3]));
        assert_eq!(a.trapz(), 4.0);

        let b = Array::from_vec(vec![0.0, 1.0, 0.0], Shape::new(vec![3]));
        assert_eq!(b.trapz(), 1.0);
    }

    #[test]
    fn test_trapz_axis() {
        let a = Array::from_vec(
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            Shape::new(vec![2, 3]),
        );
        let integral = a.trapz_axis(1);
        assert_eq!(integral.shape().as_slice(), &[2]);
        assert_eq!(integral.to_vec(), vec![4.0, 10.0]);
    }

    #[test]
    fn test_gradient() {
        let a = Array::from_vec(vec![1.0, 2.0, 4.0, 7.0], Shape::new(vec![4]));
        let grad = a.gradient();
        assert_eq!(grad.to_vec(), vec![1.0, 1.5, 2.5, 3.0]);
    }

    #[test]
    fn test_gradient_constant() {
        let a = Array::from_vec(vec![5.0, 5.0, 5.0, 5.0], Shape::new(vec![4]));
        let grad = a.gradient();
        assert_eq!(grad.to_vec(), vec![0.0, 0.0, 0.0, 0.0]);
    }

    #[test]
    fn test_gradient_linear() {
        let a = Array::from_vec(vec![0.0, 1.0, 2.0, 3.0], Shape::new(vec![4]));
        let grad = a.gradient();
        assert_eq!(grad.to_vec(), vec![1.0, 1.0, 1.0, 1.0]);
    }

    #[test]
    fn test_ediff1d() {
        let a = Array::from_vec(vec![1.0, 3.0, 6.0, 10.0], Shape::new(vec![4]));
        let edges = a.ediff1d();
        assert_eq!(edges.to_vec(), vec![2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_ediff1d_single() {
        let a = Array::from_vec(vec![5.0], Shape::new(vec![1]));
        let edges = a.ediff1d();
        assert_eq!(edges.shape().as_slice(), &[0]);
        assert_eq!(edges.size(), 0);
    }
}