rten_vecmath/
normalize.rs

1use std::mem::MaybeUninit;
2
3use rten_simd::functional::simd_map;
4use rten_simd::ops::NumOps;
5use rten_simd::span::SrcDest;
6use rten_simd::{Isa, SimdIterable, SimdOp};
7
8/// Normalize the mean and variance of elements in a slice.
9///
10/// This normalizes elements according to the formula:
11///
12/// ```text
13/// output[i] = (input[i] - pre_scale_bias) * scale * element_scale[i] + bias + element_bias[i]
14/// ```
15///
16/// # Panics
17///
18/// Dispatching the operation panics if any of the slices have different lengths.
19pub struct Normalize<'src, 'dst> {
20    src_dest: SrcDest<'src, 'dst, f32>,
21    opts: NormalizeOptions<'src>,
22}
23
24impl<'src, 'dst> Normalize<'src, 'dst> {
25    /// Create a normalize operation which reads `input` and writes the normalized
26    /// output to `output`.
27    pub fn new(
28        input: &'src [f32],
29        output: &'dst mut [MaybeUninit<f32>],
30        opts: NormalizeOptions<'src>,
31    ) -> Self {
32        Normalize {
33            src_dest: (input, output).into(),
34            opts,
35        }
36    }
37
38    /// Create a normalize operation which normalizes `input` in-place.
39    pub fn new_mut(input: &'dst mut [f32], opts: NormalizeOptions<'src>) -> Self
40    where
41        'dst: 'src,
42    {
43        Normalize {
44            src_dest: input.into(),
45            opts,
46        }
47    }
48}
49
/// Configuration for the [`Normalize`] operation.
///
/// All fields are plain values or shared slices, so the options are cheap to
/// copy and can be reused for several operations.
#[derive(Clone, Copy, Debug)]
pub struct NormalizeOptions<'a> {
    /// Bias to subtract before scaling. This is usually the mean of the data.
    pub pre_scale_bias: f32,

    /// Constant scale to multiply each element by. This is usually the inverse
    /// standard deviation of the data.
    pub scale: f32,

    /// Per-element scale to multiply each element by.
    ///
    /// `None` is equivalent to a scale of `1` for every element. When present,
    /// the slice must be as long as the input, otherwise dispatching panics.
    pub element_scale: Option<&'a [f32]>,

    /// Constant bias to add after scaling.
    pub bias: f32,

    /// Per-element bias to add after scaling.
    ///
    /// `None` is equivalent to a bias of `0` for every element. When present,
    /// the slice must be as long as the input, otherwise dispatching panics.
    pub element_bias: Option<&'a [f32]>,
}
68
69impl Default for NormalizeOptions<'_> {
70    fn default() -> Self {
71        NormalizeOptions {
72            pre_scale_bias: 0.,
73            scale: 1.,
74            element_scale: None,
75            bias: 0.,
76            element_bias: None,
77        }
78    }
79}
80
81impl<'dst> SimdOp for Normalize<'_, 'dst> {
82    /// The normalized elements.
83    type Output = &'dst mut [f32];
84
85    #[inline(always)]
86    fn eval<I: Isa>(self, isa: I) -> Self::Output {
87        let ops = isa.f32();
88
89        let Self {
90            src_dest,
91            opts:
92                NormalizeOptions {
93                    pre_scale_bias,
94                    scale,
95                    element_scale,
96                    bias,
97                    element_bias,
98                },
99        } = self;
100
101        if let Some(scale) = element_scale {
102            assert_eq!(scale.len(), src_dest.len());
103        }
104        if let Some(bias) = element_bias {
105            assert_eq!(bias.len(), src_dest.len());
106        }
107
108        let one = ops.one();
109        let zero = ops.zero();
110        let pre_scale_bias_vec = ops.splat(pre_scale_bias);
111
112        match (element_scale, element_bias, scale, bias) {
113            (None, None, scale, bias) => {
114                // Per channel scale and bias only. Used for BatchNormalization.
115                let const_scale_vec = ops.splat(scale);
116                let const_bias_vec = ops.splat(bias);
117
118                simd_map(
119                    ops,
120                    src_dest,
121                    #[inline(always)]
122                    |x| {
123                        let y = ops.sub(x, pre_scale_bias_vec);
124                        ops.mul_add(y, const_scale_vec, const_bias_vec)
125                    },
126                )
127            }
128            (Some(scale), None, const_scale, 0.) => {
129                // Scale only. Used by eg. LayerNormalization when there is no
130                // bias and RMS normalization.
131                let const_scale_vec = ops.splat(const_scale);
132                let mut scale_iter = scale.simd_iter_pad(ops);
133
134                simd_map(
135                    ops,
136                    src_dest,
137                    #[inline(always)]
138                    |x| {
139                        let scale_vec = scale_iter.next().unwrap();
140                        let scale_vec = ops.mul(scale_vec, const_scale_vec);
141
142                        let y = ops.sub(x, pre_scale_bias_vec);
143                        ops.mul(y, scale_vec)
144                    },
145                )
146            }
147            (element_scale, element_bias, const_scale, const_bias) => {
148                let const_scale_vec = ops.splat(const_scale);
149                let const_bias_vec = ops.splat(const_bias);
150                let mut scale_iter = element_scale.map(|s| s.simd_iter_pad(ops));
151                let mut bias_iter = element_bias.map(|b| b.simd_iter_pad(ops));
152
153                simd_map(
154                    ops,
155                    src_dest,
156                    #[inline(always)]
157                    |x| {
158                        let scale_vec = scale_iter.as_mut().and_then(|s| s.next()).unwrap_or(one);
159                        let scale_vec = ops.mul(scale_vec, const_scale_vec);
160
161                        let bias_vec = bias_iter.as_mut().and_then(|b| b.next()).unwrap_or(zero);
162                        let bias_vec = ops.add(bias_vec, const_bias_vec);
163
164                        let y = ops.sub(x, pre_scale_bias_vec);
165                        ops.mul_add(y, scale_vec, bias_vec)
166                    },
167                )
168            }
169        }
170    }
171}
172
#[cfg(test)]
mod tests {
    use std::mem::MaybeUninit;

    use super::{Normalize, NormalizeOptions};
    use rten_simd::SimdOp;

    /// Scalar reference implementation of [`Normalize`], applied in-place.
    ///
    /// Uses `mul_add` so that results match the fused SIMD implementation
    /// exactly.
    fn reference_normalize_mut(
        data: &mut [f32],
        pre_scale_bias: f32,
        scale: f32,
        element_scale: Option<&[f32]>,
        bias: f32,
        element_bias: Option<&[f32]>,
    ) {
        for (i, x) in data.iter_mut().enumerate() {
            let x_scale = scale * element_scale.map(|es| es[i]).unwrap_or(1.);
            let x_bias = bias + element_bias.map(|eb| eb[i]).unwrap_or(0.);
            *x = (*x - pre_scale_bias).mul_add(x_scale, x_bias);
        }
    }

    /// Normalize `data` with the given parameters, both in-place (`new_mut`)
    /// and into a separate uninitialized buffer (`new`), and check both results
    /// against the scalar reference.
    fn check_normalize(
        data: &[f32],
        pre_scale_bias: f32,
        scale: f32,
        element_scale: Option<&[f32]>,
        bias: f32,
        element_bias: Option<&[f32]>,
    ) {
        let make_opts = move || NormalizeOptions {
            pre_scale_bias,
            scale,
            element_scale,
            bias,
            element_bias,
        };

        let mut expected = data.to_vec();
        reference_normalize_mut(
            &mut expected,
            pre_scale_bias,
            scale,
            element_scale,
            bias,
            element_bias,
        );

        // In-place.
        let mut in_place = data.to_vec();
        Normalize::new_mut(&mut in_place, make_opts()).dispatch();
        assert_eq!(in_place, expected);

        // Out-of-place, writing into uninitialized memory.
        let mut output: Vec<MaybeUninit<f32>> = vec![MaybeUninit::uninit(); data.len()];
        let out_of_place = Normalize::new(data, &mut output, make_opts()).dispatch();
        assert_eq!(&out_of_place[..], &expected[..]);
    }

    #[test]
    fn test_normalize() {
        // Lengths below and above common SIMD widths, not multiples of them, so
        // that both whole vectors and tail handling are exercised.
        for len in [1, 10, 37] {
            let data: Vec<f32> = (0..len).map(|i| i as f32 * 0.1).collect();
            let pre_scale_bias = 0.5;
            let scale = 0.123;
            let element_scale: Vec<f32> = (0..len).map(|i| 1.0 + i as f32 * 0.1).collect();
            let bias = 0.3;
            let element_bias: Vec<f32> = (0..len).map(|i| -0.5 + i as f32 * 0.2).collect();

            // Per-element scale and bias.
            check_normalize(
                &data,
                pre_scale_bias,
                scale,
                Some(&element_scale),
                bias,
                Some(&element_bias),
            );

            // Per-element scale, but no bias.
            check_normalize(
                &data,
                pre_scale_bias,
                scale,
                Some(&element_scale),
                0.,
                None,
            );

            // Per-element bias, but constant scale.
            check_normalize(
                &data,
                pre_scale_bias,
                scale,
                None,
                bias,
                Some(&element_bias),
            );

            // Per-channel (ie. constant) scale and bias only.
            check_normalize(&data, pre_scale_bias, scale, None, bias, None);
        }
    }
}