nd_array/ndarray/array/
calc.rs

1use std::{
2    borrow::Cow,
3    ops::{Add, Div, Mul, Sub},
4};
5
6use num_traits::{FromPrimitive, One, Zero};
7
8use crate::Array;
9
10impl<'a, T: Clone + Ord, const D: usize> Array<'a, T, D> {
11    pub fn max(&self) -> Option<T> {
12        self.flat().max().cloned()
13    }
14
15    pub fn arg_max(&self) -> Vec<usize> {
16        let mut positions = vec![];
17
18        if let Some(max) = self.max() {
19            for (index, value) in self.flat().enumerate() {
20                if value == &max {
21                    positions.push(index)
22                }
23            }
24        }
25
26        positions
27    }
28
29    pub fn max_across(&self, axis: usize) -> Vec<Option<T>> {
30        self.axis_view(axis).map(|view| view.max()).collect()
31    }
32
33    pub fn arg_max_across(&self, axis: usize) -> Vec<Option<usize>> {
34        self.axis_view(axis)
35            .map(|view| view.arg_max().get(0).copied())
36            .collect()
37    }
38
39    pub fn min(&self) -> Option<T> {
40        self.flat().min().cloned()
41    }
42
43    pub fn arg_min(&self) -> Vec<usize> {
44        let mut positions = vec![];
45
46        if let Some(min) = self.min() {
47            for (index, value) in self.flat().enumerate() {
48                if value == &min {
49                    positions.push(index)
50                }
51            }
52        }
53
54        positions
55    }
56
57    pub fn min_across(&self, axis: usize) -> Vec<Option<T>> {
58        self.axis_view(axis).map(|view| view.min()).collect()
59    }
60
61    pub fn arg_min_across(&self, axis: usize) -> Vec<Option<usize>> {
62        self.axis_view(axis)
63            .map(|view| view.arg_min().get(0).copied())
64            .collect()
65    }
66
67    pub fn clip(&self, min: &T, max: &T) -> Array<'a, T, D> {
68        let vec: Vec<T> = self
69            .vec
70            .iter()
71            .map(|val| val.clamp(min, max).clone())
72            .collect();
73
74        let shape = self.shape.clone();
75        let strides = self.strides.clone();
76        let idx_maps = self.idx_maps.clone();
77
78        Array {
79            vec: Cow::from(vec),
80            shape,
81            strides,
82            idx_maps,
83        }
84    }
85}
86
87impl<'a, T, const D: usize> Array<'a, T, D>
88where
89    T: Clone + Ord + Sub<Output = T>,
90{
91    pub fn ptp(&self) -> Option<T> {
92        self.max().and_then(|max| self.min().map(|min| max - min))
93    }
94
95    pub fn ptp_across(&self, axis: usize) -> Vec<Option<T>> {
96        self.axis_view(axis).map(|view| view.ptp()).collect()
97    }
98}
99
100impl<'a, T, const D: usize> Array<'a, T, D>
101where
102    T: Clone + Add<Output = T> + Zero,
103{
104    pub fn sum(&self) -> T {
105        self.flat().fold(T::zero(), |acc, val| acc + val.clone())
106    }
107
108    pub fn sum_across(&self, axis: usize) -> Vec<T> {
109        self.axis_view(axis).map(|view| view.sum()).collect()
110    }
111}
112
113impl<'a, T, const D: usize> Array<'a, T, D>
114where
115    T: Clone + Mul<Output = T> + One,
116{
117    pub fn prod(&self) -> T {
118        self.flat().fold(T::one(), |acc, val| acc * val.clone())
119    }
120
121    pub fn prod_across(&self, axis: usize) -> Vec<T> {
122        self.axis_view(axis).map(|view| view.prod()).collect()
123    }
124}
125
126impl<'a, T, const D: usize> Array<'a, T, D>
127where
128    T: Clone + Add<Output = T> + FromPrimitive + Div<T, Output = T> + Zero,
129{
130    pub fn mean(&self) -> T {
131        self.sum() / T::from_usize(self.shape().iter().product()).unwrap()
132    }
133
134    pub fn mean_across(&self, axis: usize) -> Vec<T> {
135        self.axis_view(axis).map(|view| view.mean()).collect()
136    }
137}
138
139impl<'a, T, const D: usize> Array<'a, T, D>
140where
141    T: Clone + Sub<Output = T> + FromPrimitive + Div<T, Output = T> + Mul<Output = T> + Zero,
142{
143    pub fn var(&self) -> T {
144        let mean = self.mean();
145
146        self.flat().fold(T::zero(), |acc, val| {
147            acc + (val.clone() - mean.clone()) * (val.clone() - mean.clone())
148        }) / T::from_usize(self.shape().iter().product()).unwrap()
149    }
150
151    pub fn var_across(&self, axis: usize) -> Vec<T> {
152        self.axis_view(axis).map(|view| view.var()).collect()
153    }
154}
155
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn max() {
        // 2-D array:
        // 0 1
        // 2 3
        let array = Array::init(vec![0, 1, 2, 3], [2, 2]);

        assert_eq!(array.max().unwrap(), 3);
    }

    #[test]
    fn arg_max() {
        // 2-D array:
        // 0 1
        // 2 3
        let array = Array::init(vec![0, 1, 2, 3], [2, 2]);

        assert_eq!(array.arg_max()[0], 3);
    }

    #[test]
    fn arg_max_with_ties() {
        // 2-D array (flat order 3, 1, 3, 0):
        // 3 1
        // 3 0
        let array = Array::init(vec![3, 1, 3, 0], [2, 2]);

        // Every tied position is reported, in ascending order.
        assert_eq!(array.arg_max(), vec![0, 2]);
        // Per-axis variants keep only the first tied position (column 0 is 3, 3).
        assert_eq!(array.arg_max_across(1), vec![Some(0), Some(0)]);
    }

    #[test]
    fn max_across() {
        // 2-D array:
        // 0 1
        // 2 3
        let array = Array::init(vec![0, 1, 2, 3], [2, 2]);

        assert_eq!(array.max_across(1), vec![Some(2), Some(3)]);
        assert_eq!(array.max_across(0), vec![Some(1), Some(3)]);
    }

    #[test]
    fn arg_max_across() {
        // 2-D array:
        // 0 1
        // 2 3
        let array = Array::init(vec![0, 1, 2, 3], [2, 2]);

        assert_eq!(array.arg_max_across(1), vec![Some(1), Some(1)]);
        assert_eq!(array.arg_max_across(0), vec![Some(1), Some(1)]);
    }

    #[test]
    fn min() {
        // 2-D array:
        // 0 1
        // 2 3
        let array = Array::init(vec![0, 1, 2, 3], [2, 2]);

        assert_eq!(array.min().unwrap(), 0);
    }

    #[test]
    fn arg_min() {
        // 2-D array:
        // 0 1
        // 2 3
        let array = Array::init(vec![0, 1, 2, 3], [2, 2]);

        assert_eq!(array.arg_min()[0], 0);
    }

    #[test]
    fn arg_min_with_ties() {
        // 2-D array (flat order 0, 0, 1, 0):
        // 0 0
        // 1 0
        let array = Array::init(vec![0, 0, 1, 0], [2, 2]);

        // Every tied position is reported, in ascending order.
        assert_eq!(array.arg_min(), vec![0, 1, 3]);
        // Per-axis variants keep only the first tied position (row 0 is 0, 0).
        assert_eq!(array.arg_min_across(0), vec![Some(0), Some(1)]);
    }

    #[test]
    fn min_across() {
        // 2-D array:
        // 0 1
        // 2 3
        let array = Array::init(vec![0, 1, 2, 3], [2, 2]);

        assert_eq!(array.min_across(1), vec![Some(0), Some(1)]);
        assert_eq!(array.min_across(0), vec![Some(0), Some(2)]);
    }

    #[test]
    fn arg_min_across() {
        // 2-D array:
        // 0 1
        // 2 3
        let array = Array::init(vec![0, 1, 2, 3], [2, 2]);

        assert_eq!(array.arg_min_across(1), vec![Some(0), Some(0)]);
        assert_eq!(array.arg_min_across(0), vec![Some(0), Some(0)]);
    }

    #[test]
    fn clip() {
        let array = Array::arange(0..10);

        let clipped = array.clip(&1, &8);

        assert_eq!(
            clipped.flat().copied().collect::<Vec<i32>>(),
            vec![1, 1, 2, 3, 4, 5, 6, 7, 8, 8]
        );
    }

    #[test]
    fn ptp() {
        let array = Array::init(vec![4, 9, 2, 10, 6, 9, 7, 12], [2, 4]);

        assert_eq!(array.ptp().unwrap(), 10)
    }

    #[test]
    fn ptp_across() {
        let array = Array::init(vec![4, 9, 2, 10, 6, 9, 7, 12], [2, 4]);

        assert_eq!(array.ptp_across(0), vec![Some(8), Some(6)]);
        assert_eq!(
            array.ptp_across(1),
            vec![Some(2), Some(0), Some(5), Some(2)]
        )
    }

    #[test]
    fn sum() {
        // 1 2
        // 3 4
        let array = Array::arange(1..5).reshape([2, 2]);

        assert_eq!(array.sum(), 10);
    }

    #[test]
    fn sum_across() {
        // 1 2
        // 3 4
        let array = Array::arange(1..5).reshape([2, 2]);

        assert_eq!(array.sum_across(0), vec![3, 7]);
        assert_eq!(array.sum_across(1), vec![4, 6]);
    }

    #[test]
    fn prod() {
        // 1 2
        // 3 4
        let array = Array::arange(1..5).reshape([2, 2]);

        assert_eq!(array.prod(), 24);
    }

    #[test]
    fn prod_across() {
        // 1 2
        // 3 4
        let array = Array::arange(1..5).reshape([2, 2]);

        assert_eq!(array.prod_across(0), vec![2, 12]);
        assert_eq!(array.prod_across(1), vec![3, 8]);
    }

    #[test]
    fn mean() {
        // 1 2
        // 3 4
        let array = Array::arange(1..5).reshape([2, 2]);

        assert_eq!(array.mean(), 2);
    }

    #[test]
    fn mean_across() {
        // 1 2
        // 3 4
        let array = Array::arange(1..5).reshape([2, 2]);

        assert_eq!(array.mean_across(0), vec![1, 3]);
        assert_eq!(array.mean_across(1), vec![2, 3]);
    }

    #[test]
    fn var() {
        // 1 2
        // 3 4
        let array = Array::init(vec![1.0, 2.0, 3.0, 4.0], [2, 2]);

        assert_eq!(array.var(), 1.25);
    }

    #[test]
    fn var_across() {
        // 1 2
        // 3 4
        let array = Array::init(vec![1.0, 2.0, 3.0, 4.0], [2, 2]);

        assert_eq!(array.var_across(0), vec![0.25, 0.25]);
        assert_eq!(array.var_across(1), vec![1.0, 1.0]);
    }
}