1use std::mem::MaybeUninit;
2
3use rten_simd::functional::simd_map;
4use rten_simd::ops::NumOps;
5use rten_simd::span::SrcDest;
6use rten_simd::{Isa, SimdIterable, SimdOp};
7
/// SIMD operation which normalizes a slice of `f32` values.
///
/// Each input element `x` at index `i` is transformed into
/// `(x - pre_scale_bias) * (scale * element_scale[i]) + (bias + element_bias[i])`,
/// with the parameters taken from [`NormalizeOptions`].
///
/// The source and destination may be separate buffers ([`Normalize::new`]) or a
/// single buffer that is updated in place ([`Normalize::new_mut`]).
pub struct Normalize<'src, 'dst> {
    /// Input values and the buffer that receives the normalized results.
    src_dest: SrcDest<'src, 'dst, f32>,
    /// Scale and bias parameters applied to every element.
    opts: NormalizeOptions<'src>,
}
23
24impl<'src, 'dst> Normalize<'src, 'dst> {
25 pub fn new(
28 input: &'src [f32],
29 output: &'dst mut [MaybeUninit<f32>],
30 opts: NormalizeOptions<'src>,
31 ) -> Self {
32 Normalize {
33 src_dest: (input, output).into(),
34 opts,
35 }
36 }
37
38 pub fn new_mut(input: &'dst mut [f32], opts: NormalizeOptions<'src>) -> Self
40 where
41 'dst: 'src,
42 {
43 Normalize {
44 src_dest: input.into(),
45 opts,
46 }
47 }
48}
49
/// Parameters for a [`Normalize`] operation.
///
/// Each output element is computed from the input element `x` at index `i` as:
///
/// ```text
/// y = (x - pre_scale_bias) * (scale * element_scale[i]) + (bias + element_bias[i])
/// ```
///
/// A missing `element_scale` behaves like a slice of ones and a missing
/// `element_bias` like a slice of zeros. The [`Default`] value is the identity
/// transform (`pre_scale_bias = 0`, `scale = 1`, `bias = 0`, no per-element values).
///
/// The type only holds numbers and borrowed slices, so it is `Copy`. The same
/// options can be reused for several operations.
#[derive(Clone, Copy, Debug)]
pub struct NormalizeOptions<'a> {
    /// Value subtracted from every input element before scaling.
    pub pre_scale_bias: f32,

    /// Scale factor applied to every element, multiplied with `element_scale`
    /// when that is present.
    pub scale: f32,

    /// Optional per-element scale factors.
    ///
    /// When present, the length must equal the input length, otherwise the
    /// operation panics.
    pub element_scale: Option<&'a [f32]>,

    /// Bias added to every element after scaling, added to `element_bias`
    /// when that is present.
    pub bias: f32,

    /// Optional per-element biases.
    ///
    /// When present, the length must equal the input length, otherwise the
    /// operation panics.
    pub element_bias: Option<&'a [f32]>,
}
68
69impl Default for NormalizeOptions<'_> {
70 fn default() -> Self {
71 NormalizeOptions {
72 pre_scale_bias: 0.,
73 scale: 1.,
74 element_scale: None,
75 bias: 0.,
76 element_bias: None,
77 }
78 }
79}
80
81impl<'dst> SimdOp for Normalize<'_, 'dst> {
82 type Output = &'dst mut [f32];
84
85 #[inline(always)]
86 fn eval<I: Isa>(self, isa: I) -> Self::Output {
87 let ops = isa.f32();
88
89 let Self {
90 src_dest,
91 opts:
92 NormalizeOptions {
93 pre_scale_bias,
94 scale,
95 element_scale,
96 bias,
97 element_bias,
98 },
99 } = self;
100
101 if let Some(scale) = element_scale {
102 assert_eq!(scale.len(), src_dest.len());
103 }
104 if let Some(bias) = element_bias {
105 assert_eq!(bias.len(), src_dest.len());
106 }
107
108 let one = ops.one();
109 let zero = ops.zero();
110 let pre_scale_bias_vec = ops.splat(pre_scale_bias);
111
112 match (element_scale, element_bias, scale, bias) {
113 (None, None, scale, bias) => {
114 let const_scale_vec = ops.splat(scale);
116 let const_bias_vec = ops.splat(bias);
117
118 simd_map(
119 ops,
120 src_dest,
121 #[inline(always)]
122 |x| {
123 let y = ops.sub(x, pre_scale_bias_vec);
124 ops.mul_add(y, const_scale_vec, const_bias_vec)
125 },
126 )
127 }
128 (Some(scale), None, const_scale, 0.) => {
129 let const_scale_vec = ops.splat(const_scale);
132 let mut scale_iter = scale.simd_iter_pad(ops);
133
134 simd_map(
135 ops,
136 src_dest,
137 #[inline(always)]
138 |x| {
139 let scale_vec = scale_iter.next().unwrap();
140 let scale_vec = ops.mul(scale_vec, const_scale_vec);
141
142 let y = ops.sub(x, pre_scale_bias_vec);
143 ops.mul(y, scale_vec)
144 },
145 )
146 }
147 (element_scale, element_bias, const_scale, const_bias) => {
148 let const_scale_vec = ops.splat(const_scale);
149 let const_bias_vec = ops.splat(const_bias);
150 let mut scale_iter = element_scale.map(|s| s.simd_iter_pad(ops));
151 let mut bias_iter = element_bias.map(|b| b.simd_iter_pad(ops));
152
153 simd_map(
154 ops,
155 src_dest,
156 #[inline(always)]
157 |x| {
158 let scale_vec = scale_iter.as_mut().and_then(|s| s.next()).unwrap_or(one);
159 let scale_vec = ops.mul(scale_vec, const_scale_vec);
160
161 let bias_vec = bias_iter.as_mut().and_then(|b| b.next()).unwrap_or(zero);
162 let bias_vec = ops.add(bias_vec, const_bias_vec);
163
164 let y = ops.sub(x, pre_scale_bias_vec);
165 ops.mul_add(y, scale_vec, bias_vec)
166 },
167 )
168 }
169 }
170 }
171}
172
#[cfg(test)]
mod tests {
    use super::{Normalize, NormalizeOptions};
    use rten_simd::SimdOp;

    /// Scalar reference implementation of the normalization, applied in place.
    fn reference_normalize_mut(
        data: &mut [f32],
        pre_scale_bias: f32,
        scale: f32,
        element_scale: Option<&[f32]>,
        bias: f32,
        element_bias: Option<&[f32]>,
    ) {
        for (i, value) in data.iter_mut().enumerate() {
            let elem_scale = element_scale.map_or(1., |es| es[i]);
            let elem_bias = element_bias.map_or(0., |eb| eb[i]);
            *value = (*value - pre_scale_bias).mul_add(scale * elem_scale, bias + elem_bias);
        }
    }

    /// Runs the scalar reference and the SIMD operation on separate copies of
    /// `data` and checks that the results are identical.
    fn check_against_reference(data: &[f32], opts: NormalizeOptions<'_>) {
        let mut expected = data.to_vec();
        reference_normalize_mut(
            &mut expected[..],
            opts.pre_scale_bias,
            opts.scale,
            opts.element_scale,
            opts.bias,
            opts.element_bias,
        );

        let mut actual = data.to_vec();
        Normalize::new_mut(&mut actual[..], opts).dispatch();

        assert_eq!(actual, expected);
    }

    #[test]
    fn test_normalize_mut() {
        let data: Vec<f32> = (0..10).map(|i| i as f32 * 0.1).collect();
        let pre_scale_bias = 0.5;
        let scale = 0.123;
        let bias = 0.3;
        let element_scale: Vec<f32> = (0..data.len()).map(|i| 1.0 + i as f32 * 0.1).collect();
        let element_bias: Vec<f32> = (0..data.len()).map(|i| -0.5 + i as f32 * 0.2).collect();

        // Per-element scale and bias (general path).
        check_against_reference(
            &data,
            NormalizeOptions {
                pre_scale_bias,
                scale,
                element_scale: Some(&element_scale),
                bias,
                element_bias: Some(&element_bias),
            },
        );

        // Per-element scale without any bias (specialized path).
        check_against_reference(
            &data,
            NormalizeOptions {
                pre_scale_bias,
                scale,
                element_scale: Some(&element_scale),
                bias: 0.,
                element_bias: None,
            },
        );

        // Uniform scale and bias only (specialized path).
        check_against_reference(
            &data,
            NormalizeOptions {
                pre_scale_bias,
                scale,
                element_scale: None,
                bias,
                element_bias: None,
            },
        );
    }
}