arr_rs/core/operations/
search.rs1use crate::{
2 core::prelude::*,
3 errors::prelude::*,
4 extensions::prelude::*,
5};
6
7pub trait ArraySearch<T: ArrayElement> where Self: Sized + Clone {
9
10 fn argmax(&self, axis: Option<isize>, keepdims: Option<bool>) -> Result<Array<usize>, ArrayError>;
32
33 fn argmin(&self, axis: Option<isize>, keepdims: Option<bool>) -> Result<Array<usize>, ArrayError>;
59}
60
61impl <T: ArrayElement> ArraySearch<T> for Array<T> {
62
63 fn argmax(&self, axis: Option<isize>, keepdims: Option<bool>) -> Result<Array<usize>, ArrayError> {
64 if let Some(axis) = axis {
65 let axis = self.normalize_axis(axis);
66 let result = self.apply_along_axis(axis, |arr| arr.argmax(None, keepdims));
67 if keepdims == Some(true) { result }
68 else { result.reshape(&self.get_shape()?.remove_at(axis)) }
69 } else {
70 if self.is_empty()? { return Err(ArrayError::ParameterError { param: "`array`", message: "cannot be empty" }) }
71 let result = if let Some(i) = self.get_elements()?.iter().position(ArrayElement::is_nan) { Array::single(i) } else {
72 let sorted = self.sort(None, Some("quicksort"))?.get_elements()?;
73 let max_pos = self.get_elements()?.iter().position(|item| item == &sorted[sorted.len() - 1]).unwrap();
74 Array::single(max_pos)
75 };
76
77 if keepdims == Some(true) { result.atleast(self.ndim()?) }
78 else { result }
79 }
80 }
81
82 fn argmin(&self, axis: Option<isize>, keepdims: Option<bool>) -> Result<Array<usize>, ArrayError> {
83 if let Some(axis) = axis {
84 let axis = self.normalize_axis(axis);
85 let result = self.apply_along_axis(axis, |arr| arr.argmin(None, keepdims));
86 if keepdims == Some(true) { result }
87 else { result.reshape(&self.get_shape()?.remove_at(axis)) }
88 } else {
89 if self.is_empty()? { return Err(ArrayError::ParameterError { param: "`array`", message: "cannot be empty" }) }
90 let result = if let Some(i) = self.get_elements()?.iter().position(ArrayElement::is_nan) { Array::single(i) } else {
91 let sorted = self.sort(None, Some("quicksort"))?.get_elements()?;
92 let max_pos = self.get_elements()?.iter().position(|item| item == &sorted[0]).unwrap();
93 Array::single(max_pos)
94 };
95
96 if keepdims == Some(true) { result.atleast(self.ndim()?) }
97 else { result }
98 }
99 }
100}
101
102impl <T: ArrayElement> ArraySearch<T> for Result<Array<T>, ArrayError> {
103
104 fn argmax(&self, axis: Option<isize>, keepdims: Option<bool>) -> Result<Array<usize>, ArrayError> {
105 self.clone()?.argmax(axis, keepdims)
106 }
107
108 fn argmin(&self, axis: Option<isize>, keepdims: Option<bool>) -> Result<Array<usize>, ArrayError> {
109 self.clone()?.argmin(axis, keepdims)
110 }
111}