argminmax/scalar/
generic.rs

1#[cfg(feature = "float")]
2use num_traits::float::FloatCore;
3use num_traits::PrimInt;
4
5use super::super::dtype_strategy::Int;
6/// The DTypeStrategy for which we implement the ScalarArgMinMax trait
7#[cfg(any(feature = "float", feature = "half"))]
8use super::super::dtype_strategy::{FloatIgnoreNaN, FloatReturnNaN};
9
/// Per-strategy hooks consumed by the scalar arg-min / arg-max loops.
///
/// A strategy decides how the running extrema are seeded from the first
/// element of the slice and how NaN values are recognised and treated.
/// Implementations exist for:
/// - signed integers - `Int` DTypeStrategy
/// - unsigned integers - `Int` DTypeStrategy
/// - floats, returning at the first NaN - `FloatReturnNaN` DTypeStrategy
/// - floats, skipping NaNs - `FloatIgnoreNaN` DTypeStrategy
trait SCALARInit<ScalarDType>
where
    ScalarDType: Copy + PartialOrd,
{
    /// When `true`, the search loops return the index of the first value for
    /// which [`_nan_check`](Self::_nan_check) holds.
    const _RETURN_AT_NAN: bool;

    /// Produce the seed for the running minimum from the first element.
    fn _init_min(start_value: ScalarDType) -> ScalarDType;

    /// Produce the seed for the running maximum from the first element.
    fn _init_max(start_value: ScalarDType) -> ScalarDType;

    /// Tell whether the first non-NaN value encountered has to overwrite the
    /// seeded value(s) unconditionally.
    fn _allow_first_non_nan_update(start_value: ScalarDType) -> bool;

    /// Tell whether `v` counts as NaN for this strategy.
    fn _nan_check(v: ScalarDType) -> bool;
}
34
/// Scalar (non-SIMD) arg-min / arg-max over a slice.
///
/// How NaN values are treated is decided by the implementing type's
/// DTypeStrategy. The implementations generated by `impl_scalar!` in this
/// file assert that the slice is non-empty and therefore panic on `&[]`.
///
// This trait will be implemented for the different DTypeStrategy
pub trait ScalarArgMinMax<ScalarDType>
where
    ScalarDType: Copy + PartialOrd,
{
    /// Locate both extrema of `data` at once.
    ///
    /// # Arguments
    /// - `data` - the slice to search.
    ///
    /// # Returns
    /// `(min_index, max_index)`: the position of the minimum followed by the
    /// position of the maximum.
    ///
    fn argminmax(data: &[ScalarDType]) -> (usize, usize);

    /// Locate the minimum of `data`.
    ///
    /// # Arguments
    /// - `data` - the slice to search.
    ///
    /// # Returns
    /// The position of the minimum value.
    ///
    fn argmin(data: &[ScalarDType]) -> usize;

    /// Locate the maximum of `data`.
    ///
    /// # Arguments
    /// - `data` - the slice to search.
    ///
    /// # Returns
    /// The position of the maximum value.
    ///
    fn argmax(data: &[ScalarDType]) -> usize;
}
70
/// Marker type that selects a [datatype strategy](crate::dtype_strategy) at the type level.
///
/// `SCALAR<DTypeStrategy>` is the type on which the
/// [ScalarArgMinMax](crate::ScalarArgMinMax) trait is implemented, once per
/// supported element type and strategy.
///
// The ScalarArgMinMax impls are generated by the impl_scalar! macro further down in this file
pub struct SCALAR<DTypeStrategy> {
    /// Zero-sized marker tying the struct to its strategy; it stores no data.
    pub(crate) _dtype_strategy: std::marker::PhantomData<DTypeStrategy>,
}
79
// ------- Implement the SCALARInit trait for the different DTypeStrategy -------
81
82impl<ScalarDType> SCALARInit<ScalarDType> for SCALAR<Int>
83where
84    ScalarDType: PrimInt,
85{
86    const _RETURN_AT_NAN: bool = false;
87
88    #[inline(always)]
89    fn _init_min(start_value: ScalarDType) -> ScalarDType {
90        start_value
91    }
92
93    #[inline(always)]
94    fn _init_max(start_value: ScalarDType) -> ScalarDType {
95        start_value
96    }
97
98    #[inline(always)]
99    fn _allow_first_non_nan_update(_start_value: ScalarDType) -> bool {
100        false
101    }
102
103    #[inline(always)]
104    fn _nan_check(_v: ScalarDType) -> bool {
105        false
106    }
107}
108
#[cfg(feature = "float")]
impl<T: FloatCore> SCALARInit<T> for SCALAR<FloatReturnNaN> {
    /// The search returns as soon as it meets a NaN.
    const _RETURN_AT_NAN: bool = true;

    /// The running minimum starts at the first element, NaN or not.
    #[inline(always)]
    fn _init_min(first: T) -> T {
        first
    }

    /// The running maximum starts at the first element, NaN or not.
    #[inline(always)]
    fn _init_max(first: T) -> T {
        first
    }

    /// Never needed here: a NaN ends the search instead of being skipped.
    #[inline(always)]
    fn _allow_first_non_nan_update(_first: T) -> bool {
        false
    }

    /// `true` when `value` is NaN.
    #[inline(always)]
    fn _nan_check(value: T) -> bool {
        value.is_nan()
    }
}
136
#[cfg(feature = "float")]
impl<T: FloatCore> SCALARInit<T> for SCALAR<FloatIgnoreNaN> {
    /// NaNs are skipped, so the search never returns early.
    const _RETURN_AT_NAN: bool = false;

    /// Seed the running minimum with the first element, or with `+inf` when
    /// that element is NaN.
    #[inline(always)]
    fn _init_min(first: T) -> T {
        if !first.is_nan() {
            return first;
        }
        T::infinity()
    }

    /// Seed the running maximum with the first element, or with `-inf` when
    /// that element is NaN.
    #[inline(always)]
    fn _init_max(first: T) -> T {
        if !first.is_nan() {
            return first;
        }
        T::neg_infinity()
    }

    /// A NaN first element means the seeds are placeholders, so the first
    /// non-NaN value has to replace them.
    #[inline(always)]
    fn _allow_first_non_nan_update(first: T) -> bool {
        first.is_nan()
    }

    /// `true` when `value` is NaN.
    #[inline(always)]
    fn _nan_check(value: T) -> bool {
        value.is_nan()
    }
}
172
// ------- Implement the ScalarArgMinMax trait for the different DTypeStrategy -------
174
// Generates `ScalarArgMinMax<$dtype>` for `SCALAR<$dtype_strategy>`, once per listed dtype.
//
// All three generated functions share one scheme:
// - the running extrema are seeded through `SCALARInit` from the first element,
// - a strategy with `_RETURN_AT_NAN` returns the index of the first NaN it meets,
// - a strategy that allows a "first non-NaN update" overwrites the seed with the
//   first non-NaN value, after which plain `<` / `>` comparisons take over.
macro_rules! impl_scalar {
    ($dtype_strategy:ty, $($dtype:ty),*) => {
        $(
            impl ScalarArgMinMax<$dtype> for SCALAR<$dtype_strategy> {
                /// Single pass returning `(min_index, max_index)`.
                ///
                /// Panics when `arr` is empty.
                #[inline(always)]
                fn argminmax(arr: &[$dtype]) -> (usize, usize) {
                    assert!(!arr.is_empty());
                    // The assert above guarantees that index 0 exists.
                    let first: $dtype = arr[0];
                    let mut min_pos: usize = 0;
                    let mut max_pos: usize = 0;
                    let mut min_value: $dtype = <Self as SCALARInit<$dtype>>::_init_min(first);
                    let mut max_value: $dtype = <Self as SCALARInit<$dtype>>::_init_max(first);
                    // True while the seeds are placeholders waiting for a non-NaN value.
                    let mut awaiting_first_value: bool =
                        <Self as SCALARInit<$dtype>>::_allow_first_non_nan_update(first);
                    // Index loop + get_unchecked is kept on purpose: it was found to be
                    // remarkably faster than .iter().enumerate() (with a fold).
                    for pos in 0..arr.len() {
                        // SAFETY: `pos` comes from `0..arr.len()`, so it is in bounds.
                        let value: $dtype = unsafe { *arr.get_unchecked(pos) };
                        if <Self as SCALARInit<$dtype>>::_RETURN_AT_NAN
                            && <Self as SCALARInit<$dtype>>::_nan_check(value)
                        {
                            // NaN met under a return-at-NaN strategy -> report its index
                            return (pos, pos);
                        }
                        if awaiting_first_value {
                            // Only reachable for strategies that skip NaNs
                            if <Self as SCALARInit<$dtype>>::_nan_check(value) {
                                continue;
                            }
                            // First non-NaN value: it becomes both extrema
                            min_value = value;
                            min_pos = pos;
                            max_value = value;
                            max_pos = pos;
                            awaiting_first_value = false;
                            continue;
                        }
                        if value < min_value {
                            min_value = value;
                            min_pos = pos;
                        } else if value > max_value {
                            max_value = value;
                            max_pos = pos;
                        }
                    }
                    (min_pos, max_pos)
                }

                /// Single pass returning the index of the minimum.
                ///
                /// Panics when `arr` is empty.
                #[inline(always)]
                fn argmin(arr: &[$dtype]) -> usize {
                    assert!(!arr.is_empty());
                    // The assert above guarantees that index 0 exists.
                    let first: $dtype = arr[0];
                    let mut min_pos: usize = 0;
                    let mut min_value: $dtype = <Self as SCALARInit<$dtype>>::_init_min(first);
                    // True while the seed is a placeholder waiting for a non-NaN value.
                    let mut awaiting_first_value: bool =
                        <Self as SCALARInit<$dtype>>::_allow_first_non_nan_update(first);
                    // Index loop + get_unchecked is kept on purpose: it was found to be
                    // remarkably faster than .iter().enumerate() (with a fold).
                    for pos in 0..arr.len() {
                        // SAFETY: `pos` comes from `0..arr.len()`, so it is in bounds.
                        let value: $dtype = unsafe { *arr.get_unchecked(pos) };
                        if <Self as SCALARInit<$dtype>>::_RETURN_AT_NAN
                            && <Self as SCALARInit<$dtype>>::_nan_check(value)
                        {
                            // NaN met under a return-at-NaN strategy -> report its index
                            return pos;
                        }
                        if awaiting_first_value {
                            // Only reachable for strategies that skip NaNs
                            if <Self as SCALARInit<$dtype>>::_nan_check(value) {
                                continue;
                            }
                            // First non-NaN value: it becomes the minimum
                            min_value = value;
                            min_pos = pos;
                            awaiting_first_value = false;
                            continue;
                        }
                        if value < min_value {
                            min_value = value;
                            min_pos = pos;
                        }
                    }
                    min_pos
                }

                /// Single pass returning the index of the maximum.
                ///
                /// Panics when `arr` is empty.
                #[inline(always)]
                fn argmax(arr: &[$dtype]) -> usize {
                    assert!(!arr.is_empty());
                    // The assert above guarantees that index 0 exists.
                    let first: $dtype = arr[0];
                    let mut max_pos: usize = 0;
                    let mut max_value: $dtype = <Self as SCALARInit<$dtype>>::_init_max(first);
                    // True while the seed is a placeholder waiting for a non-NaN value.
                    let mut awaiting_first_value: bool =
                        <Self as SCALARInit<$dtype>>::_allow_first_non_nan_update(first);
                    // Index loop + get_unchecked is kept on purpose: it was found to be
                    // remarkably faster than .iter().enumerate() (with a fold).
                    for pos in 0..arr.len() {
                        // SAFETY: `pos` comes from `0..arr.len()`, so it is in bounds.
                        let value: $dtype = unsafe { *arr.get_unchecked(pos) };
                        if <Self as SCALARInit<$dtype>>::_RETURN_AT_NAN
                            && <Self as SCALARInit<$dtype>>::_nan_check(value)
                        {
                            // NaN met under a return-at-NaN strategy -> report its index
                            return pos;
                        }
                        if awaiting_first_value {
                            // Only reachable for strategies that skip NaNs
                            if <Self as SCALARInit<$dtype>>::_nan_check(value) {
                                continue;
                            }
                            // First non-NaN value: it becomes the maximum
                            max_value = value;
                            max_pos = pos;
                            awaiting_first_value = false;
                            continue;
                        }
                        if value > max_value {
                            max_value = value;
                            max_pos = pos;
                        }
                    }
                    max_pos
                }
            }
        )*
    };
}
286
287impl_scalar!(Int, i8, i16, i32, i64, u8, u16, u32, u64);
288#[cfg(feature = "float")]
289impl_scalar!(FloatReturnNaN, f32, f64);
290#[cfg(feature = "float")]
291impl_scalar!(FloatIgnoreNaN, f32, f64);
292
293// --- Optional data types
294
295#[cfg(feature = "half")]
296use super::scalar_f16::{
297    scalar_argmax_f16_ignore_nan, scalar_argmin_f16_ignore_nan, scalar_argminmax_f16_ignore_nan,
298};
299#[cfg(feature = "half")]
300use super::scalar_f16::{
301    scalar_argmax_f16_return_nan, scalar_argmin_f16_return_nan, scalar_argminmax_f16_return_nan,
302};
303
304#[cfg(feature = "half")]
305use half::f16;
306
// f16 with the FloatReturnNaN strategy: every call is forwarded to the
// dedicated `*_f16_return_nan` function imported from `scalar_f16`.
#[cfg(feature = "half")]
impl ScalarArgMinMax<f16> for SCALAR<FloatReturnNaN> {
    /// Forwards to `scalar_argminmax_f16_return_nan`.
    #[inline(always)]
    fn argminmax(data: &[f16]) -> (usize, usize) {
        scalar_argminmax_f16_return_nan(data)
    }

    /// Forwards to `scalar_argmin_f16_return_nan`.
    #[inline(always)]
    fn argmin(data: &[f16]) -> usize {
        scalar_argmin_f16_return_nan(data)
    }

    /// Forwards to `scalar_argmax_f16_return_nan`.
    #[inline(always)]
    fn argmax(data: &[f16]) -> usize {
        scalar_argmax_f16_return_nan(data)
    }
}
324
// f16 with the FloatIgnoreNaN strategy: every call is forwarded to the
// dedicated `*_f16_ignore_nan` function imported from `scalar_f16`.
#[cfg(feature = "half")]
impl ScalarArgMinMax<f16> for SCALAR<FloatIgnoreNaN> {
    /// Forwards to `scalar_argminmax_f16_ignore_nan`.
    #[inline(always)]
    fn argminmax(data: &[f16]) -> (usize, usize) {
        scalar_argminmax_f16_ignore_nan(data)
    }

    /// Forwards to `scalar_argmin_f16_ignore_nan`.
    #[inline(always)]
    fn argmin(data: &[f16]) -> usize {
        scalar_argmin_f16_ignore_nan(data)
    }

    /// Forwards to `scalar_argmax_f16_ignore_nan`.
    #[inline(always)]
    fn argmax(data: &[f16]) -> usize {
        scalar_argmax_f16_ignore_nan(data)
    }
}