argminmax/scalar/
generic.rs1#[cfg(feature = "float")]
2use num_traits::float::FloatCore;
3use num_traits::PrimInt;
4
5use super::super::dtype_strategy::Int;
6#[cfg(any(feature = "float", feature = "half"))]
8use super::super::dtype_strategy::{FloatIgnoreNaN, FloatReturnNaN};
9
/// Per-strategy hooks that parameterise the generic scalar scan loops generated by
/// `impl_scalar!`.
///
/// Each strategy (`Int`, `FloatReturnNaN`, `FloatIgnoreNaN`) decides how the loops
/// treat NaN values and how the running minimum/maximum are seeded from the first
/// element of the slice.
trait SCALARInit<ScalarDType: Copy + PartialOrd> {
    /// When `true`, the scan returns the index of the first value for which
    /// [`_nan_check`](SCALARInit::_nan_check) holds, without looking further.
    const _RETURN_AT_NAN: bool;

    /// Returns whether `value` must be treated as NaN by the scan loops.
    fn _nan_check(value: ScalarDType) -> bool;

    /// Given the first element of the slice, returns whether the running extrema
    /// should be overwritten by the first value that is not flagged by
    /// [`_nan_check`](SCALARInit::_nan_check).
    fn _allow_first_non_nan_update(first: ScalarDType) -> bool;

    /// Derives the initial running minimum from the first element of the slice.
    fn _init_min(first: ScalarDType) -> ScalarDType;

    /// Derives the initial running maximum from the first element of the slice.
    fn _init_max(first: ScalarDType) -> ScalarDType;
}
34
/// Scalar (plain loop) argmin / argmax over a slice of `ScalarDType`.
///
/// How NaN values are handled depends on the implementing strategy; see the
/// `SCALAR<...>` implementations in this module.
pub trait ScalarArgMinMax<ScalarDType: Copy + PartialOrd> {
    /// Returns the index of the minimum value of `values`.
    fn argmin(values: &[ScalarDType]) -> usize;

    /// Returns the index of the maximum value of `values`.
    fn argmax(values: &[ScalarDType]) -> usize;

    /// Returns `(argmin, argmax)` of `values` computed in a single call.
    fn argminmax(values: &[ScalarDType]) -> (usize, usize);
}
70
/// Zero-sized marker type selecting the scalar (plain loop) implementations of
/// argmin/argmax in this module.
///
/// The type parameter picks the dtype strategy (`Int`, `FloatReturnNaN` or
/// `FloatIgnoreNaN` in this module), which in turn selects the matching
/// `SCALARInit` / `ScalarArgMinMax` implementations.
pub struct SCALAR<Strategy> {
    /// Carries the strategy only at the type level; no data is stored.
    pub(crate) _dtype_strategy: core::marker::PhantomData<Strategy>,
}
79
80impl<ScalarDType> SCALARInit<ScalarDType> for SCALAR<Int>
83where
84 ScalarDType: PrimInt,
85{
86 const _RETURN_AT_NAN: bool = false;
87
88 #[inline(always)]
89 fn _init_min(start_value: ScalarDType) -> ScalarDType {
90 start_value
91 }
92
93 #[inline(always)]
94 fn _init_max(start_value: ScalarDType) -> ScalarDType {
95 start_value
96 }
97
98 #[inline(always)]
99 fn _allow_first_non_nan_update(_start_value: ScalarDType) -> bool {
100 false
101 }
102
103 #[inline(always)]
104 fn _nan_check(_v: ScalarDType) -> bool {
105 false
106 }
107}
108
/// Float strategy that propagates NaN: the scan stops at the first NaN and reports
/// its index, so the extrema can be seeded directly from the first element.
#[cfg(feature = "float")]
impl<T: FloatCore> SCALARInit<T> for SCALAR<FloatReturnNaN> {
    const _RETURN_AT_NAN: bool = true;

    /// Any NaN value triggers the early return.
    #[inline(always)]
    fn _nan_check(value: T) -> bool {
        <T as FloatCore>::is_nan(value)
    }

    /// A leading NaN is returned immediately by the scan, so reseeding is never needed.
    #[inline(always)]
    fn _allow_first_non_nan_update(_first: T) -> bool {
        false
    }

    /// The running minimum starts at the first element.
    #[inline(always)]
    fn _init_min(first: T) -> T {
        first
    }

    /// The running maximum starts at the first element.
    #[inline(always)]
    fn _init_max(first: T) -> T {
        first
    }
}
136
/// Float strategy that skips NaN values.
///
/// When the first element is NaN, the running extrema start at `+inf` / `-inf`.
/// The scan is then allowed to reseed them from the first non-NaN value it meets.
#[cfg(feature = "float")]
impl<T: FloatCore> SCALARInit<T> for SCALAR<FloatIgnoreNaN> {
    const _RETURN_AT_NAN: bool = false;

    /// NaN values are recognised so that they can be skipped while reseeding.
    #[inline(always)]
    fn _nan_check(value: T) -> bool {
        <T as FloatCore>::is_nan(value)
    }

    /// Reseeding is only needed when the slice starts with NaN.
    #[inline(always)]
    fn _allow_first_non_nan_update(first: T) -> bool {
        <T as FloatCore>::is_nan(first)
    }

    /// Starts at `+inf` when the first element is NaN, otherwise at the first element.
    #[inline(always)]
    fn _init_min(first: T) -> T {
        if <T as FloatCore>::is_nan(first) {
            return <T as FloatCore>::infinity();
        }
        first
    }

    /// Starts at `-inf` when the first element is NaN, otherwise at the first element.
    #[inline(always)]
    fn _init_max(first: T) -> T {
        if <T as FloatCore>::is_nan(first) {
            return <T as FloatCore>::neg_infinity();
        }
        first
    }
}
172
/// Implements `ScalarArgMinMax<$dtype>` for `SCALAR<$dtype_strategy>` for every
/// listed `$dtype`.
///
/// The generated methods are single linear scans whose NaN handling is driven by
/// the `SCALARInit<$dtype>` implementation of `SCALAR<$dtype_strategy>`:
///
/// * if `_RETURN_AT_NAN` is true, the index of the first value for which
///   `_nan_check` holds is returned immediately (as `(i, i)` for `argminmax`);
/// * if `_allow_first_non_nan_update(arr[0])` is true, the running extrema are
///   reseeded from the first value for which `_nan_check` is false;
/// * otherwise the running extrema start at `_init_min(arr[0])` / `_init_max(arr[0])`
///   and are only replaced by a strictly smaller / larger value, so ties resolve
///   to the lowest index.
///
/// Elements are visited with a safe slice iterator, so no `unsafe` indexing is needed.
///
/// # Panics
///
/// Every generated method panics if `arr` is empty.
macro_rules! impl_scalar {
    ($dtype_strategy:ty, $($dtype:ty),*) => {
        $(
            impl ScalarArgMinMax<$dtype> for SCALAR<$dtype_strategy>
            {
                #[inline(always)]
                fn argminmax(arr: &[$dtype]) -> (usize, usize) {
                    assert!(!arr.is_empty());
                    let start_value: $dtype = arr[0];
                    let mut low_index: usize = 0;
                    let mut high_index: usize = 0;
                    let mut low: $dtype = Self::_init_min(start_value);
                    let mut high: $dtype = Self::_init_max(start_value);
                    let mut first_non_nan_update: bool =
                        Self::_allow_first_non_nan_update(start_value);
                    // Index 0 is visited too, so a leading NaN is still seen by the
                    // `_RETURN_AT_NAN` strategy.
                    for (i, &v) in arr.iter().enumerate() {
                        if <Self as SCALARInit<$dtype>>::_RETURN_AT_NAN && Self::_nan_check(v) {
                            return (i, i);
                        }
                        if first_non_nan_update {
                            // Seed both extrema from the first non-NaN value.
                            if !Self::_nan_check(v) {
                                low = v;
                                low_index = i;
                                high = v;
                                high_index = i;
                                first_non_nan_update = false;
                            }
                        } else if v < low {
                            // Once seeded, `low <= high`, so a new minimum can never
                            // also be a new maximum.
                            low = v;
                            low_index = i;
                        } else if v > high {
                            high = v;
                            high_index = i;
                        }
                    }
                    (low_index, high_index)
                }

                #[inline(always)]
                fn argmin(arr: &[$dtype]) -> usize {
                    assert!(!arr.is_empty());
                    let start_value: $dtype = arr[0];
                    let mut low_index: usize = 0;
                    let mut low: $dtype = Self::_init_min(start_value);
                    let mut first_non_nan_update: bool =
                        Self::_allow_first_non_nan_update(start_value);
                    for (i, &v) in arr.iter().enumerate() {
                        if <Self as SCALARInit<$dtype>>::_RETURN_AT_NAN && Self::_nan_check(v) {
                            return i;
                        }
                        if first_non_nan_update {
                            if !Self::_nan_check(v) {
                                low = v;
                                low_index = i;
                                first_non_nan_update = false;
                            }
                        } else if v < low {
                            low = v;
                            low_index = i;
                        }
                    }
                    low_index
                }

                #[inline(always)]
                fn argmax(arr: &[$dtype]) -> usize {
                    assert!(!arr.is_empty());
                    let start_value: $dtype = arr[0];
                    let mut high_index: usize = 0;
                    let mut high: $dtype = Self::_init_max(start_value);
                    let mut first_non_nan_update: bool =
                        Self::_allow_first_non_nan_update(start_value);
                    for (i, &v) in arr.iter().enumerate() {
                        if <Self as SCALARInit<$dtype>>::_RETURN_AT_NAN && Self::_nan_check(v) {
                            return i;
                        }
                        if first_non_nan_update {
                            if !Self::_nan_check(v) {
                                high = v;
                                high_index = i;
                                first_non_nan_update = false;
                            }
                        } else if v > high {
                            high = v;
                            high_index = i;
                        }
                    }
                    high_index
                }
            }
        )*
    };
}
286
287impl_scalar!(Int, i8, i16, i32, i64, u8, u16, u32, u64);
288#[cfg(feature = "float")]
289impl_scalar!(FloatReturnNaN, f32, f64);
290#[cfg(feature = "float")]
291impl_scalar!(FloatIgnoreNaN, f32, f64);
292
293#[cfg(feature = "half")]
296use super::scalar_f16::{
297 scalar_argmax_f16_ignore_nan, scalar_argmin_f16_ignore_nan, scalar_argminmax_f16_ignore_nan,
298};
299#[cfg(feature = "half")]
300use super::scalar_f16::{
301 scalar_argmax_f16_return_nan, scalar_argmin_f16_return_nan, scalar_argminmax_f16_return_nan,
302};
303
304#[cfg(feature = "half")]
305use half::f16;
306
/// `f16` with NaN propagation: forwards to the dedicated `scalar_f16` routines.
#[cfg(feature = "half")]
impl ScalarArgMinMax<f16> for SCALAR<FloatReturnNaN> {
    /// Forwards to `scalar_argmin_f16_return_nan`.
    #[inline(always)]
    fn argmin(data: &[f16]) -> usize {
        scalar_argmin_f16_return_nan(data)
    }

    /// Forwards to `scalar_argmax_f16_return_nan`.
    #[inline(always)]
    fn argmax(data: &[f16]) -> usize {
        scalar_argmax_f16_return_nan(data)
    }

    /// Forwards to `scalar_argminmax_f16_return_nan`.
    #[inline(always)]
    fn argminmax(data: &[f16]) -> (usize, usize) {
        scalar_argminmax_f16_return_nan(data)
    }
}
324
/// `f16` with NaN skipping: forwards to the dedicated `scalar_f16` routines.
#[cfg(feature = "half")]
impl ScalarArgMinMax<f16> for SCALAR<FloatIgnoreNaN> {
    /// Forwards to `scalar_argmin_f16_ignore_nan`.
    #[inline(always)]
    fn argmin(data: &[f16]) -> usize {
        scalar_argmin_f16_ignore_nan(data)
    }

    /// Forwards to `scalar_argmax_f16_ignore_nan`.
    #[inline(always)]
    fn argmax(data: &[f16]) -> usize {
        scalar_argmax_f16_ignore_nan(data)
    }

    /// Forwards to `scalar_argminmax_f16_ignore_nan`.
    #[inline(always)]
    fn argminmax(data: &[f16]) -> (usize, usize) {
        scalar_argminmax_f16_ignore_nan(data)
    }
}