arr_rs/core/operations/search.rs

1use crate::{
2    core::prelude::*,
3    errors::prelude::*,
4    extensions::prelude::*,
5};
6
7/// `ArrayTrait` - Array Search functions
8pub trait ArraySearch<T: ArrayElement> where Self: Sized + Clone {
9
10    /// Returns the indices of the maximum values along an axis.
11    ///
12    /// # Arguments
13    ///
14    /// * `axis` - axis along which to search. if None, array is flattened
15    /// * `keepdims` - if true, the result will broadcast correctly against the input
16    ///
17    /// # Examples
18    ///
19    /// ```
20    /// use arr_rs::prelude::*;
21    ///
22    /// let arr = array!(i32, [[10, 11, 12], [13, 14, 15]]);
23    /// assert_eq!(array!(usize, [5]), arr.argmax(None, None));
24    /// let arr = array!(f64, [[f64::NAN, 4.], [2., 3.]]);
25    /// assert_eq!(array!(usize, [0]), arr.argmax(None, None));
26    /// ```
27    ///
28    /// # Errors
29    ///
30    /// may returns `ArrayError`
31    fn argmax(&self, axis: Option<isize>, keepdims: Option<bool>) -> Result<Array<usize>, ArrayError>;
32
33    /// Returns the indices of the minimum values along an axis.
34    ///
35    /// # Arguments
36    ///
37    /// * `axis` - axis along which to search. if None, array is flattened
38    /// * `keepdims` - if true, the result will broadcast correctly against the input
39    ///
40    /// # Examples
41    ///
42    /// ```
43    /// use arr_rs::prelude::*;
44    ///
45    /// let arr = array!(i32, [[10, 11, 12], [13, 14, 15]]);
46    /// assert_eq!(array!(usize, [0]), arr.argmin(None, None));
47    /// assert_eq!(array!(usize, [0, 0, 0]), arr.argmin(Some(0), None));
48    /// assert_eq!(array!(usize, [0, 0]), arr.argmin(Some(1), None));
49    /// let arr = array!(f64, [[f64::NAN, 4.], [2., 3.]]);
50    /// assert_eq!(array!(usize, [0]), arr.argmin(None, None));
51    /// assert_eq!(array!(usize, [0, 1]), arr.argmin(Some(0), None));
52    /// assert_eq!(array!(usize, [0, 0]), arr.argmin(Some(1), None));
53    /// ```
54    ///
55    /// # Errors
56    ///
57    /// may returns `ArrayError`
58    fn argmin(&self, axis: Option<isize>, keepdims: Option<bool>) -> Result<Array<usize>, ArrayError>;
59}
60
61impl <T: ArrayElement> ArraySearch<T> for Array<T> {
62
63    fn argmax(&self, axis: Option<isize>, keepdims: Option<bool>) -> Result<Array<usize>, ArrayError> {
64        if let Some(axis) = axis {
65            let axis = self.normalize_axis(axis);
66            let result = self.apply_along_axis(axis, |arr| arr.argmax(None, keepdims));
67            if keepdims == Some(true) { result }
68            else { result.reshape(&self.get_shape()?.remove_at(axis)) }
69        } else {
70            if self.is_empty()? { return Err(ArrayError::ParameterError { param: "`array`", message: "cannot be empty" }) }
71            let result = if let Some(i) = self.get_elements()?.iter().position(ArrayElement::is_nan) { Array::single(i) } else {
72                let sorted = self.sort(None, Some("quicksort"))?.get_elements()?;
73                let max_pos = self.get_elements()?.iter().position(|item| item == &sorted[sorted.len() - 1]).unwrap();
74                Array::single(max_pos)
75            };
76
77            if keepdims == Some(true) { result.atleast(self.ndim()?) }
78            else { result }
79        }
80    }
81
82    fn argmin(&self, axis: Option<isize>, keepdims: Option<bool>) -> Result<Array<usize>, ArrayError> {
83        if let Some(axis) = axis {
84            let axis = self.normalize_axis(axis);
85            let result = self.apply_along_axis(axis, |arr| arr.argmin(None, keepdims));
86            if keepdims == Some(true) { result }
87            else { result.reshape(&self.get_shape()?.remove_at(axis)) }
88        } else {
89            if self.is_empty()? { return Err(ArrayError::ParameterError { param: "`array`", message: "cannot be empty" }) }
90            let result = if let Some(i) = self.get_elements()?.iter().position(ArrayElement::is_nan) { Array::single(i) } else {
91                let sorted = self.sort(None, Some("quicksort"))?.get_elements()?;
92                let max_pos = self.get_elements()?.iter().position(|item| item == &sorted[0]).unwrap();
93                Array::single(max_pos)
94            };
95
96            if keepdims == Some(true) { result.atleast(self.ndim()?) }
97            else { result }
98        }
99    }
100}
101
102impl <T: ArrayElement> ArraySearch<T> for Result<Array<T>, ArrayError> {
103
104    fn argmax(&self, axis: Option<isize>, keepdims: Option<bool>) -> Result<Array<usize>, ArrayError> {
105        self.clone()?.argmax(axis, keepdims)
106    }
107
108    fn argmin(&self, axis: Option<isize>, keepdims: Option<bool>) -> Result<Array<usize>, ArrayError> {
109        self.clone()?.argmin(axis, keepdims)
110    }
111}