1use std::{
2 borrow::Cow,
3 ops::{Add, Div, Mul, Sub},
4};
5
6use num_traits::{FromPrimitive, One, Zero};
7
8use crate::Array;
9
10impl<'a, T: Clone + Ord, const D: usize> Array<'a, T, D> {
11 pub fn max(&self) -> Option<T> {
12 self.flat().max().cloned()
13 }
14
15 pub fn arg_max(&self) -> Vec<usize> {
16 let mut positions = vec![];
17
18 if let Some(max) = self.max() {
19 for (index, value) in self.flat().enumerate() {
20 if value == &max {
21 positions.push(index)
22 }
23 }
24 }
25
26 positions
27 }
28
29 pub fn max_across(&self, axis: usize) -> Vec<Option<T>> {
30 self.axis_view(axis).map(|view| view.max()).collect()
31 }
32
33 pub fn arg_max_across(&self, axis: usize) -> Vec<Option<usize>> {
34 self.axis_view(axis)
35 .map(|view| view.arg_max().get(0).copied())
36 .collect()
37 }
38
39 pub fn min(&self) -> Option<T> {
40 self.flat().min().cloned()
41 }
42
43 pub fn arg_min(&self) -> Vec<usize> {
44 let mut positions = vec![];
45
46 if let Some(min) = self.min() {
47 for (index, value) in self.flat().enumerate() {
48 if value == &min {
49 positions.push(index)
50 }
51 }
52 }
53
54 positions
55 }
56
57 pub fn min_across(&self, axis: usize) -> Vec<Option<T>> {
58 self.axis_view(axis).map(|view| view.min()).collect()
59 }
60
61 pub fn arg_min_across(&self, axis: usize) -> Vec<Option<usize>> {
62 self.axis_view(axis)
63 .map(|view| view.arg_min().get(0).copied())
64 .collect()
65 }
66
67 pub fn clip(&self, min: &T, max: &T) -> Array<'a, T, D> {
68 let vec: Vec<T> = self
69 .vec
70 .iter()
71 .map(|val| val.clamp(min, max).clone())
72 .collect();
73
74 let shape = self.shape.clone();
75 let strides = self.strides.clone();
76 let idx_maps = self.idx_maps.clone();
77
78 Array {
79 vec: Cow::from(vec),
80 shape,
81 strides,
82 idx_maps,
83 }
84 }
85}
86
87impl<'a, T, const D: usize> Array<'a, T, D>
88where
89 T: Clone + Ord + Sub<Output = T>,
90{
91 pub fn ptp(&self) -> Option<T> {
92 self.max().and_then(|max| self.min().map(|min| max - min))
93 }
94
95 pub fn ptp_across(&self, axis: usize) -> Vec<Option<T>> {
96 self.axis_view(axis).map(|view| view.ptp()).collect()
97 }
98}
99
100impl<'a, T, const D: usize> Array<'a, T, D>
101where
102 T: Clone + Add<Output = T> + Zero,
103{
104 pub fn sum(&self) -> T {
105 self.flat().fold(T::zero(), |acc, val| acc + val.clone())
106 }
107
108 pub fn sum_across(&self, axis: usize) -> Vec<T> {
109 self.axis_view(axis).map(|view| view.sum()).collect()
110 }
111}
112
113impl<'a, T, const D: usize> Array<'a, T, D>
114where
115 T: Clone + Mul<Output = T> + One,
116{
117 pub fn prod(&self) -> T {
118 self.flat().fold(T::one(), |acc, val| acc * val.clone())
119 }
120
121 pub fn prod_across(&self, axis: usize) -> Vec<T> {
122 self.axis_view(axis).map(|view| view.prod()).collect()
123 }
124}
125
126impl<'a, T, const D: usize> Array<'a, T, D>
127where
128 T: Clone + Add<Output = T> + FromPrimitive + Div<T, Output = T> + Zero,
129{
130 pub fn mean(&self) -> T {
131 self.sum() / T::from_usize(self.shape().iter().product()).unwrap()
132 }
133
134 pub fn mean_across(&self, axis: usize) -> Vec<T> {
135 self.axis_view(axis).map(|view| view.mean()).collect()
136 }
137}
138
139impl<'a, T, const D: usize> Array<'a, T, D>
140where
141 T: Clone + Sub<Output = T> + FromPrimitive + Div<T, Output = T> + Mul<Output = T> + Zero,
142{
143 pub fn var(&self) -> T {
144 let mean = self.mean();
145
146 self.flat().fold(T::zero(), |acc, val| {
147 acc + (val.clone() - mean.clone()) * (val.clone() - mean.clone())
148 }) / T::from_usize(self.shape().iter().product()).unwrap()
149 }
150
151 pub fn var_across(&self, axis: usize) -> Vec<T> {
152 self.axis_view(axis).map(|view| view.var()).collect()
153 }
154}
155
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn max() {
        let array = Array::init(vec![0, 1, 2, 3], [2, 2]);

        assert_eq!(array.max(), Some(3));
    }

    #[test]
    fn arg_max() {
        let array = Array::init(vec![0, 1, 2, 3], [2, 2]);

        // Exactly one position: the unique maximum.
        assert_eq!(array.arg_max(), vec![3]);
    }

    #[test]
    fn max_across() {
        let array = Array::init(vec![0, 1, 2, 3], [2, 2]);

        assert_eq!(array.max_across(1), vec![Some(2), Some(3)]);
        assert_eq!(array.max_across(0), vec![Some(1), Some(3)]);
    }

    #[test]
    fn arg_max_across() {
        let array = Array::init(vec![0, 1, 2, 3], [2, 2]);

        assert_eq!(array.arg_max_across(1), vec![Some(1), Some(1)]);
        assert_eq!(array.arg_max_across(0), vec![Some(1), Some(1)]);
    }

    #[test]
    fn min() {
        let array = Array::init(vec![0, 1, 2, 3], [2, 2]);

        assert_eq!(array.min(), Some(0));
    }

    #[test]
    fn arg_min() {
        let array = Array::init(vec![0, 1, 2, 3], [2, 2]);

        // Exactly one position: the unique minimum.
        assert_eq!(array.arg_min(), vec![0]);
    }

    #[test]
    fn min_across() {
        let array = Array::init(vec![0, 1, 2, 3], [2, 2]);

        assert_eq!(array.min_across(1), vec![Some(0), Some(1)]);
        assert_eq!(array.min_across(0), vec![Some(0), Some(2)]);
    }

    #[test]
    fn arg_min_across() {
        let array = Array::init(vec![0, 1, 2, 3], [2, 2]);

        assert_eq!(array.arg_min_across(1), vec![Some(0), Some(0)]);
        assert_eq!(array.arg_min_across(0), vec![Some(0), Some(0)]);
    }

    #[test]
    fn arg_max_and_arg_min_report_every_tie() {
        let array = Array::init(vec![7, 1, 7, 4, 1, 2], [2, 3]);

        assert_eq!(array.arg_max(), vec![0, 2]);
        assert_eq!(array.arg_min(), vec![1, 4]);
    }

    #[test]
    fn arg_across_picks_first_tie() {
        // Rows (axis 0): [2, 7, 7] and [4, 4, 1].
        let array = Array::init(vec![2, 7, 7, 4, 4, 1], [2, 3]);

        assert_eq!(array.arg_max_across(0), vec![Some(1), Some(0)]);
        assert_eq!(array.arg_min_across(0), vec![Some(0), Some(2)]);
    }

    #[test]
    fn clip() {
        let array = Array::arange(0..10);

        let clipped = array.clip(&1, &8);

        assert_eq!(
            clipped.flat().copied().collect::<Vec<i32>>(),
            vec![1, 1, 2, 3, 4, 5, 6, 7, 8, 8]
        );
    }

    #[test]
    fn ptp() {
        let array = Array::init(vec![4, 9, 2, 10, 6, 9, 7, 12], [2, 4]);

        assert_eq!(array.ptp(), Some(10));
    }

    #[test]
    fn ptp_of_constant_array_is_zero() {
        let array = Array::init(vec![5, 5, 5, 5], [2, 2]);

        assert_eq!(array.ptp(), Some(0));
    }

    #[test]
    fn ptp_across() {
        let array = Array::init(vec![4, 9, 2, 10, 6, 9, 7, 12], [2, 4]);

        assert_eq!(array.ptp_across(0), vec![Some(8), Some(6)]);
        assert_eq!(
            array.ptp_across(1),
            vec![Some(2), Some(0), Some(5), Some(2)]
        );
    }

    #[test]
    fn sum() {
        let array = Array::arange(1..5).reshape([2, 2]);

        assert_eq!(array.sum(), 10);
    }

    #[test]
    fn sum_across() {
        let array = Array::arange(1..5).reshape([2, 2]);

        assert_eq!(array.sum_across(0), vec![3, 7]);
        assert_eq!(array.sum_across(1), vec![4, 6]);
    }

    #[test]
    fn prod() {
        let array = Array::arange(1..5).reshape([2, 2]);

        assert_eq!(array.prod(), 24);
    }

    #[test]
    fn prod_across() {
        let array = Array::arange(1..5).reshape([2, 2]);

        assert_eq!(array.prod_across(0), vec![2, 12]);
        assert_eq!(array.prod_across(1), vec![3, 8]);
    }

    #[test]
    fn mean() {
        let array = Array::arange(1..5).reshape([2, 2]);

        // Integer mean truncates: 10 / 4 == 2.
        assert_eq!(array.mean(), 2);
    }

    #[test]
    fn mean_across() {
        let array = Array::arange(1..5).reshape([2, 2]);

        assert_eq!(array.mean_across(0), vec![1, 3]);
        assert_eq!(array.mean_across(1), vec![2, 3]);
    }

    #[test]
    fn var() {
        let array = Array::init(vec![1.0, 2.0, 3.0, 4.0], [2, 2]);

        assert_eq!(array.var(), 1.25);
    }

    #[test]
    fn var_across() {
        let array = Array::init(vec![1.0, 2.0, 3.0, 4.0], [2, 2]);

        assert_eq!(array.var_across(0), vec![0.25, 0.25]);
        assert_eq!(array.var_across(1), vec![1.0, 1.0]);
    }
}