1use rten_simd::ops::NumOps;
2use rten_simd::{Isa, Simd, SimdIterable, SimdOp};
3
/// SIMD operation that finds the smallest and largest values of an `f32` slice.
///
/// Build it with [`MinMax::new`] and evaluate it through the [`SimdOp`] trait.
pub struct MinMax<'a> {
    /// Values to scan.
    input: &'a [f32],
}

impl<'a> MinMax<'a> {
    /// Wraps `input` so its minimum and maximum can be computed.
    pub fn new(input: &'a [f32]) -> Self {
        Self { input }
    }
}
14
15impl SimdOp for MinMax<'_> {
16 type Output = (f32, f32);
17
18 #[inline(always)]
19 fn eval<I: Isa>(self, isa: I) -> Self::Output {
20 let ops = isa.f32();
21 let [vec_min, vec_max] = self.input.simd_iter(ops).fold_n_unroll::<2, 4>(
22 [ops.splat(f32::MAX), ops.splat(f32::MIN)],
23 #[inline(always)]
24 |[min, max], x| [ops.min(x, min), ops.max(x, max)],
25 #[inline(always)]
26 |[min_a, max_a], [min_b, max_b]| [ops.min(min_a, min_b), ops.max(max_a, max_b)],
27 );
28 let min = vec_min
29 .to_array()
30 .as_ref()
31 .iter()
32 .fold(f32::MAX, |min, x| x.min(min));
33 let max = vec_max
34 .to_array()
35 .as_ref()
36 .iter()
37 .fold(f32::MIN, |max, x| x.max(max));
38 (min, max)
39 }
40}
41
/// SIMD operation that finds the largest value of a slice, propagating NaN.
///
/// Build it with [`MaxNum::new`] and evaluate it through the [`SimdOp`] trait
/// (implemented for `f32` element types).
pub struct MaxNum<'a, T> {
    /// Values to scan.
    input: &'a [T],
}

impl<'a, T> MaxNum<'a, T> {
    /// Wraps `input` so its maximum can be computed.
    pub fn new(input: &'a [T]) -> Self {
        Self { input }
    }
}
52
53impl<'a> SimdOp for MaxNum<'a, f32> {
54 type Output = f32;
55
56 #[inline(always)]
57 fn eval<I: Isa>(self, isa: I) -> Self::Output {
58 let ops = isa.f32();
59
60 let max_num = |max, x| {
61 let not_nan = ops.eq(x, x);
62 let new_max = ops.max(max, x);
63 ops.select(new_max, x, not_nan)
64 };
65
66 let vec_max =
67 self.input
68 .simd_iter(ops)
69 .fold_unroll::<2>(ops.splat(f32::MIN), max_num, max_num);
70
71 vec_max
72 .to_array()
73 .as_ref()
74 .iter()
75 .copied()
76 .fold(f32::MIN, |max, x| {
77 if x.is_nan() {
78 x
79 } else if max.is_nan() {
80 max
81 } else {
82 x.max(max)
83 }
84 })
85 }
86}
87
/// SIMD operation that finds the smallest value of a slice, propagating NaN.
///
/// Build it with [`MinNum::new`] and evaluate it through the [`SimdOp`] trait
/// (implemented for `f32` element types).
pub struct MinNum<'a, T> {
    /// Values to scan.
    input: &'a [T],
}

impl<'a, T> MinNum<'a, T> {
    /// Wraps `input` so its minimum can be computed.
    pub fn new(input: &'a [T]) -> Self {
        Self { input }
    }
}
98
99impl<'a> SimdOp for MinNum<'a, f32> {
100 type Output = f32;
101
102 #[inline(always)]
103 fn eval<I: Isa>(self, isa: I) -> Self::Output {
104 let ops = isa.f32();
105
106 let min_num = |min, x| {
107 let not_nan = ops.eq(x, x);
108 let new_min = ops.min(min, x);
109 ops.select(new_min, x, not_nan)
110 };
111
112 let vec_min = self.input.simd_iter(ops).fold(ops.splat(f32::MAX), min_num);
113
114 vec_min
115 .to_array()
116 .as_ref()
117 .iter()
118 .copied()
119 .fold(f32::MAX, |min, x| {
120 if x.is_nan() {
121 x
122 } else if min.is_nan() {
123 min
124 } else {
125 x.min(min)
126 }
127 })
128 }
129}
130
#[cfg(test)]
mod tests {
    use super::{MaxNum, MinMax, MinNum};
    use rten_simd::SimdOp;

    const LEN: usize = 100;

    /// Short input with a single NaN in the middle, for NaN-propagation checks.
    const WITH_NAN: [f32; 7] = [0.1, 1.0, 0.2, f32::NAN, 0.4, 0.5, 0.6];

    /// Returns `LEN` evenly spaced values: `0.0, 0.1, 0.2, ...`.
    fn ramp() -> Vec<f32> {
        (0..LEN).map(|i| i as f32 * 0.1).collect()
    }

    /// Scalar reference returning `(min, max)` via `f32::min` / `f32::max`.
    fn reference_min_max(xs: &[f32]) -> (f32, f32) {
        xs.iter()
            .fold((f32::MAX, f32::MIN), |(lo, hi), &x| (x.min(lo), x.max(hi)))
    }

    #[test]
    fn test_min_max() {
        let xs = ramp();
        let want = reference_min_max(&xs);
        assert_eq!(MinMax::new(&xs).dispatch(), want);
    }

    #[test]
    fn test_max_num() {
        let xs = ramp();
        let (_, want) = reference_min_max(&xs);
        assert_eq!(MaxNum::new(&xs).dispatch(), want);

        let got = MaxNum::new(&WITH_NAN).dispatch();
        assert!(got.is_nan());
    }

    #[test]
    fn test_min_num() {
        let xs = ramp();
        let (want, _) = reference_min_max(&xs);
        assert_eq!(MinNum::new(&xs).dispatch(), want);

        let got = MinNum::new(&WITH_NAN).dispatch();
        assert!(got.is_nan());
    }
}