rten_vecmath/
min_max.rs

1use rten_simd::ops::NumOps;
2use rten_simd::{Isa, Simd, SimdIterable, SimdOp};
3
/// Computes both the minimum and the maximum of an `f32` slice.
///
/// The result is produced as a `(min, max)` pair when the operation is
/// evaluated through [`SimdOp`] (for example with `dispatch`).
pub struct MinMax<'a> {
    input: &'a [f32],
}

impl<'a> MinMax<'a> {
    /// Wraps `input` so that its extrema can be computed.
    pub fn new(input: &'a [f32]) -> Self {
        Self { input }
    }
}
14
15impl SimdOp for MinMax<'_> {
16    type Output = (f32, f32);
17
18    #[inline(always)]
19    fn eval<I: Isa>(self, isa: I) -> Self::Output {
20        let ops = isa.f32();
21        let [vec_min, vec_max] = self.input.simd_iter(ops).fold_n_unroll::<2, 4>(
22            [ops.splat(f32::MAX), ops.splat(f32::MIN)],
23            #[inline(always)]
24            |[min, max], x| [ops.min(x, min), ops.max(x, max)],
25            #[inline(always)]
26            |[min_a, max_a], [min_b, max_b]| [ops.min(min_a, min_b), ops.max(max_a, max_b)],
27        );
28        let min = vec_min
29            .to_array()
30            .as_ref()
31            .iter()
32            .fold(f32::MAX, |min, x| x.min(min));
33        let max = vec_max
34            .to_array()
35            .as_ref()
36            .iter()
37            .fold(f32::MIN, |max, x| x.max(max));
38        (min, max)
39    }
40}
41
/// Computes the maximum value in a slice, propagating NaNs.
///
/// The type is generic over the element type; only the `f32` instantiation
/// implements [`SimdOp`] here.
pub struct MaxNum<'a, T> {
    input: &'a [T],
}

impl<'a, T> MaxNum<'a, T> {
    /// Creates the operation over `input`.
    pub fn new(input: &'a [T]) -> Self {
        Self { input }
    }
}
52
53impl<'a> SimdOp for MaxNum<'a, f32> {
54    type Output = f32;
55
56    #[inline(always)]
57    fn eval<I: Isa>(self, isa: I) -> Self::Output {
58        let ops = isa.f32();
59
60        let max_num = |max, x| {
61            let not_nan = ops.eq(x, x);
62            let new_max = ops.max(max, x);
63            ops.select(new_max, x, not_nan)
64        };
65
66        let vec_max =
67            self.input
68                .simd_iter(ops)
69                .fold_unroll::<2>(ops.splat(f32::MIN), max_num, max_num);
70
71        vec_max
72            .to_array()
73            .as_ref()
74            .iter()
75            .copied()
76            .fold(f32::MIN, |max, x| {
77                if x.is_nan() {
78                    x
79                } else if max.is_nan() {
80                    max
81                } else {
82                    x.max(max)
83                }
84            })
85    }
86}
87
/// Computes the minimum value in a slice, propagating NaNs.
///
/// The type is generic over the element type; only the `f32` instantiation
/// implements [`SimdOp`] here.
pub struct MinNum<'a, T> {
    input: &'a [T],
}

impl<'a, T> MinNum<'a, T> {
    /// Creates the operation over `input`.
    pub fn new(input: &'a [T]) -> Self {
        Self { input }
    }
}
98
99impl<'a> SimdOp for MinNum<'a, f32> {
100    type Output = f32;
101
102    #[inline(always)]
103    fn eval<I: Isa>(self, isa: I) -> Self::Output {
104        let ops = isa.f32();
105
106        let min_num = |min, x| {
107            let not_nan = ops.eq(x, x);
108            let new_min = ops.min(min, x);
109            ops.select(new_min, x, not_nan)
110        };
111
112        let vec_min = self.input.simd_iter(ops).fold(ops.splat(f32::MAX), min_num);
113
114        vec_min
115            .to_array()
116            .as_ref()
117            .iter()
118            .copied()
119            .fold(f32::MAX, |min, x| {
120                if x.is_nan() {
121                    x
122                } else if min.is_nan() {
123                    min
124                } else {
125                    x.min(min)
126                }
127            })
128    }
129}
130
#[cfg(test)]
mod tests {
    use super::{MaxNum, MinMax, MinNum};
    use rten_simd::SimdOp;

    // Deliberately not a multiple of the vector width, so every op also has to
    // process a partial tail vector.
    const LEN: usize = 100;

    // Short input with a single NaN in the middle, for the NaN-propagation checks.
    const WITH_NAN: [f32; 7] = [0.1, 1.0, 0.2, f32::NAN, 0.4, 0.5, 0.6];

    /// Returns `[0.0, 0.1, 0.2, ...]` with `LEN` elements.
    fn ramp() -> Vec<f32> {
        (0..LEN).map(|i| i as f32 * 0.1).collect()
    }

    /// Scalar `(min, max)` reference, seeded with `f32::MAX` / `f32::MIN`.
    fn reference_min_max(xs: &[f32]) -> (f32, f32) {
        xs.iter()
            .fold((f32::MAX, f32::MIN), |(lo, hi), &x| (x.min(lo), x.max(hi)))
    }

    #[test]
    fn test_min_max() {
        let xs = ramp();
        assert_eq!(MinMax::new(&xs).dispatch(), reference_min_max(&xs));
    }

    #[test]
    fn test_max_num() {
        let xs = ramp();
        let (_, expected_max) = reference_min_max(&xs);
        assert_eq!(MaxNum::new(&xs).dispatch(), expected_max);
        assert!(MaxNum::new(&WITH_NAN).dispatch().is_nan());
    }

    #[test]
    fn test_min_num() {
        let xs = ramp();
        let (expected_min, _) = reference_min_max(&xs);
        assert_eq!(MinNum::new(&xs).dispatch(), expected_min);
        assert!(MinNum::new(&WITH_NAN).dispatch().is_nan());
    }
}