rten_simd/
iter.rs

1//! Tools for vectorized iteration over slices.
2
3use crate::ops::NumOps;
4use crate::{Elem, Simd};
5
/// Methods for creating vectorized iterators.
pub trait SimdIterable {
    /// Element type in the slice.
    type Elem: Elem;

    /// Iterate over SIMD-sized chunks of the input.
    ///
    /// If the input length is not divisible by the SIMD vector width, the
    /// iterator yields only the full chunks. The tail is accessible via the
    /// iterator's [`tail`](Iter::tail) method.
    fn simd_iter<O: NumOps<Self::Elem>>(&self, ops: O) -> Iter<'_, Self::Elem, O>;

    /// Iterate over SIMD-sized chunks of the input.
    ///
    /// If the input length is not divisible by the SIMD vector width, the final
    /// chunk will be padded with zeros.
    fn simd_iter_pad<O: NumOps<Self::Elem>>(
        &self,
        ops: O,
    ) -> impl ExactSizeIterator<Item = O::Simd>;
}
27
// Blanket implementation for slices of any supported element type. Both
// methods simply delegate to the corresponding iterator constructor.
impl<T: Elem> SimdIterable for [T] {
    type Elem = T;

    #[inline]
    fn simd_iter<O: NumOps<T>>(&self, ops: O) -> Iter<'_, T, O> {
        Iter::new(ops, self)
    }

    #[inline]
    fn simd_iter_pad<O: NumOps<T>>(&self, ops: O) -> impl ExactSizeIterator<Item = O::Simd> {
        IterPad::new(ops, self)
    }
}
41
/// Iterator which yields chunks of a slice as a SIMD vector.
///
/// This type is created by [`SimdIterable::simd_iter`].
pub struct Iter<'a, T: Elem, O: NumOps<T>> {
    // SIMD operations used to load chunks of `xs` into vectors.
    ops: O,
    // Elements not yet visited; shrinks from the front as chunks are yielded.
    xs: &'a [T],
    // Number of full SIMD-width chunks in the slice at construction time.
    n_full_chunks: usize,
}
50
51impl<'a, T: Elem, O: NumOps<T>> Iter<'a, T, O> {
52    #[inline]
53    fn new(ops: O, xs: &'a [T]) -> Self {
54        let n_full_chunks = xs.len() / ops.len();
55        Iter {
56            ops,
57            xs,
58            n_full_chunks,
59        }
60    }
61
62    /// Reduce an iterator to a single SIMD vector.
63    ///
64    /// This is like [`Iterator::fold`] but the `fold` function receives SIMD
65    /// vectors instead of single elements. If the iterator length is not a
66    /// multiple of the SIMD vector length, the final vector will be padded with
67    /// zeros. If a padded vector is used for the final update, the accumulator
68    /// vector will only be updated with the results from the non-padding lanes.
69    #[inline]
70    pub fn fold<F: FnMut(O::Simd, O::Simd) -> O::Simd>(
71        mut self,
72        mut accum: O::Simd,
73        mut fold: F,
74    ) -> O::Simd {
75        for chunk in &mut self {
76            accum = fold(accum, chunk);
77        }
78
79        if let Some((tail, mask)) = self.tail() {
80            let new_accum = fold(accum, tail);
81            accum = self.ops.select(new_accum, accum, mask);
82        }
83
84        accum
85    }
86
87    /// Variant of [`fold`](Self::fold) which is unrolled `UNROLL` times.
88    ///
89    /// In each iteration, `UNROLL` SIMD vectors are loaded and used to update
90    /// `UNROLL` separate accumulators via `fold(acc[i], x[i])`. The
91    /// accumulators are reduced to a single SIMD vector at the end using
92    /// `fold_acc`.
93    ///
94    /// When the `fold` operation is simple, this can improve performance over
95    /// `self.fold` by achieving better instruction level parallelism.
96    #[inline]
97    pub fn fold_unroll<const UNROLL: usize>(
98        mut self,
99        accum: O::Simd,
100        mut fold: impl FnMut(O::Simd, O::Simd) -> O::Simd,
101        mut fold_acc: impl FnMut(O::Simd, O::Simd) -> O::Simd,
102    ) -> O::Simd {
103        let mut acc = [accum; UNROLL];
104        let v_len = self.ops.len();
105
106        while let Some((chunk, tail)) = self.xs.split_at_checked(v_len * UNROLL) {
107            let xs: [_; UNROLL] = std::array::from_fn(|i| unsafe {
108                // Safety: `i < UNROLL` and `chunk` length is `v_len * UNROLL`
109                self.ops.load_ptr(chunk.as_ptr().add(v_len * i))
110            });
111            for i in 0..UNROLL {
112                acc[i] = fold(acc[i], xs[i]);
113            }
114            self.xs = tail;
115        }
116        for i in 1..UNROLL {
117            acc[0] = fold_acc(acc[0], acc[i]);
118        }
119        self.fold(acc[0], fold)
120    }
121
122    /// Variant of [`fold`](Self::fold) that computes multiple accumulator
123    /// values in a single pass.
124    #[inline]
125    pub fn fold_n<const N: usize>(
126        mut self,
127        mut accum: [O::Simd; N],
128        mut fold: impl FnMut([O::Simd; N], O::Simd) -> [O::Simd; N],
129    ) -> [O::Simd; N] {
130        for chunk in &mut self {
131            accum = fold(accum, chunk);
132        }
133
134        if let Some((tail, mask)) = self.tail() {
135            let new_accum = fold(accum, tail);
136            for i in 0..N {
137                accum[i] = self.ops.select(new_accum[i], accum[i], mask);
138            }
139        }
140
141        accum
142    }
143
144    /// Variant of [`fold_n`](Self::fold_n) which unrolls computation `UNROLL`
145    /// times, like [`fold_unroll`](Self::fold_unroll).
146    #[inline]
147    pub fn fold_n_unroll<const N: usize, const UNROLL: usize>(
148        mut self,
149        accum: [O::Simd; N],
150        mut fold: impl FnMut([O::Simd; N], O::Simd) -> [O::Simd; N],
151        mut fold_acc: impl FnMut([O::Simd; N], [O::Simd; N]) -> [O::Simd; N],
152    ) -> [O::Simd; N] {
153        let mut acc = [accum; UNROLL];
154        let v_len = self.ops.len();
155
156        while let Some((chunk, tail)) = self.xs.split_at_checked(v_len * UNROLL) {
157            let xs: [_; UNROLL] = std::array::from_fn(|i| unsafe {
158                // Safety: `i < UNROLL` and `chunk` length is `v_len * UNROLL`
159                self.ops.load_ptr(chunk.as_ptr().add(v_len * i))
160            });
161            for i in 0..UNROLL {
162                acc[i] = fold(acc[i], xs[i]);
163            }
164            self.xs = tail;
165        }
166        for i in 1..UNROLL {
167            acc[0] = fold_acc(acc[0], acc[i]);
168        }
169        self.fold_n(acc[0], fold)
170    }
171
172    /// Return a SIMD vector and mask for the left-over elements in the
173    /// slice after iterating over all full SIMD chunks.
174    ///
175    /// Elements of the SIMD vector that correspond to positions where the mask
176    /// is false will be set to zero.
177    #[inline]
178    pub fn tail(&self) -> Option<(O::Simd, <O::Simd as Simd>::Mask)> {
179        let n = self.xs.len();
180        if n > 0 {
181            Some(self.ops.load_pad(self.xs))
182        } else {
183            None
184        }
185    }
186}
187
188impl<T: Elem, O: NumOps<T>> Iterator for Iter<'_, T, O> {
189    type Item = O::Simd;
190
191    #[inline]
192    fn next(&mut self) -> Option<Self::Item> {
193        let v_len = self.ops.len();
194        if let Some((chunk, tail)) = self.xs.split_at_checked(v_len) {
195            self.xs = tail;
196
197            // Safety: `chunk.as_ptr()` points to `v_len` elements.
198            let x = unsafe { self.ops.load_ptr(chunk.as_ptr()) };
199
200            Some(x)
201        } else {
202            None
203        }
204    }
205
206    #[inline]
207    fn size_hint(&self) -> (usize, Option<usize>) {
208        (self.n_full_chunks, Some(self.n_full_chunks))
209    }
210}
211
// `size_hint` returns identical lower and upper bounds, so an exact length
// can be reported.
impl<T: Elem, O: NumOps<T>> ExactSizeIterator for Iter<'_, T, O> {}

// Once fewer than `v_len` elements remain, `next` keeps returning `None`,
// so the iterator is fused.
impl<T: Elem, O: NumOps<T>> std::iter::FusedIterator for Iter<'_, T, O> {}
215
/// Iterator which yields chunks of a slice as a SIMD vector.
///
/// This type is created by [`SimdIterable::simd_iter_pad`].
pub struct IterPad<'a, T: Elem, O: NumOps<T>> {
    // Underlying iterator over the full SIMD-width chunks.
    iter: Iter<'a, T, O>,
    // True while a partial, zero-padded final chunk remains to be yielded.
    has_tail: bool,
}
223
224impl<'a, T: Elem, O: NumOps<T>> IterPad<'a, T, O> {
225    #[inline]
226    fn new(ops: O, xs: &'a [T]) -> Self {
227        let iter = Iter::new(ops, xs);
228        let has_tail = !xs.len().is_multiple_of(ops.len());
229        Self { iter, has_tail }
230    }
231}
232
233impl<T: Elem, O: NumOps<T>> Iterator for IterPad<'_, T, O> {
234    type Item = O::Simd;
235
236    #[inline]
237    fn next(&mut self) -> Option<Self::Item> {
238        if let Some(chunk) = self.iter.next() {
239            Some(chunk)
240        } else if self.has_tail {
241            let (tail, _mask) = self.iter.tail().unwrap();
242            self.has_tail = false;
243            Some(tail)
244        } else {
245            None
246        }
247    }
248
249    #[inline]
250    fn size_hint(&self) -> (usize, Option<usize>) {
251        let n_tail = if self.has_tail { 1 } else { 0 };
252        let n_chunks = self.iter.len() + n_tail;
253        (n_chunks, Some(n_chunks))
254    }
255}
256
// `size_hint` returns identical lower and upper bounds, so an exact length
// can be reported.
impl<T: Elem, O: NumOps<T>> ExactSizeIterator for IterPad<'_, T, O> {}

// After the tail has been yielded, `has_tail` is false and `next` keeps
// returning `None`, so the iterator is fused.
impl<T: Elem, O: NumOps<T>> std::iter::FusedIterator for IterPad<'_, T, O> {}
260
#[cfg(test)]
mod tests {
    use super::SimdIterable;
    use crate::dispatch::test_simd_op;
    use crate::ops::NumOps;
    use crate::{Isa, Simd, SimdOp};

    // f32 vector length, chosen to exercise main and tail loops for all ISAs.
    const TEST_LEN: usize = 18;

    #[test]
    fn test_iter() {
        test_simd_op!(isa, {
            let buf: Vec<_> = (0..TEST_LEN).map(|x| x as f32).collect();
            let chunks = buf.chunks_exact(isa.f32().len());

            let iter = buf.simd_iter(isa.f32());
            assert_eq!(iter.len(), chunks.len());

            for (scalar_chunk, simd_chunk) in chunks.zip(iter) {
                assert_eq!(simd_chunk.to_array().as_ref(), scalar_chunk);
            }
        });
    }

    #[test]
    fn test_iter_pad() {
        test_simd_op!(isa, {
            let buf: Vec<_> = (0..TEST_LEN).map(|x| x as f32).collect();
            let chunks = buf.chunks(isa.f32().len());

            let iter = buf.simd_iter_pad(isa.f32());
            assert_eq!(iter.len(), chunks.len());

            for (scalar_chunk, simd_chunk) in chunks.zip(iter) {
                let simd_elts = simd_chunk.to_array();
                let simd_elts = simd_elts.as_ref();
                assert_eq!(&simd_elts[..scalar_chunk.len()], scalar_chunk);
                // Lanes beyond the scalar chunk must be zero padding.
                if simd_elts.len() > scalar_chunk.len() {
                    assert!(simd_elts[scalar_chunk.len()..].iter().all(|x| *x == 0.));
                }
            }
        });
    }

    #[test]
    fn test_fold() {
        struct Sum<'a> {
            xs: &'a [f32],
        }

        impl<'a> SimdOp for Sum<'a> {
            type Output = f32;

            fn eval<I: Isa>(self, isa: I) -> Self::Output {
                let ops = isa.f32();
                let vec_sum = self
                    .xs
                    .simd_iter(ops)
                    .fold(ops.zero(), |sum, x| ops.add(sum, x));
                vec_sum.to_array().into_iter().fold(0., |sum, x| sum + x)
            }
        }

        let buf: Vec<_> = (0..TEST_LEN).map(|x| x as f32).collect();
        // Sum of the arithmetic series 0..TEST_LEN.
        let expected = (buf.len() as f32 * buf[buf.len() - 1]) / 2.;

        let sum = Sum { xs: &buf }.dispatch();
        assert_eq!(sum, expected);
    }

    #[test]
    fn test_fold_unroll() {
        const UNROLL: usize = 4;

        struct SumSquare<'a> {
            xs: &'a [i32],
        }

        impl<'a> SimdOp for SumSquare<'a> {
            type Output = i32;

            fn eval<I: Isa>(self, isa: I) -> Self::Output {
                let ops = isa.i32();
                let vec_sum = self.xs.simd_iter(ops).fold_unroll::<UNROLL>(
                    ops.zero(),
                    |sum, x| ops.mul_add(x, x, sum),
                    |sum, x| ops.add(sum, x),
                );
                vec_sum.to_array().into_iter().fold(0, |sum, x| sum + x)
            }
        }

        // Large enough to take the unrolled main loop at least once.
        let buf: Vec<_> = (0..TEST_LEN * UNROLL).map(|x| x as i32).collect();
        let expected = buf.iter().fold(0, |acc, &x| acc + x * x);

        let sum = SumSquare { xs: &buf }.dispatch();
        assert_eq!(sum, expected);
    }

    const UNROLL: usize = 4;

    // Computes the (min, max) of a slice, via either the unrolled or
    // non-unrolled multi-accumulator fold.
    struct MinMax<'a> {
        xs: &'a [f32],
        unroll: bool,
    }

    impl<'a> SimdOp for MinMax<'a> {
        type Output = (f32, f32);

        fn eval<I: Isa>(self, isa: I) -> Self::Output {
            let ops = isa.f32();
            let [vec_min, vec_max] = if self.unroll {
                self.xs.simd_iter(ops).fold_n_unroll::<2, UNROLL>(
                    [ops.splat(f32::MAX), ops.splat(f32::MIN)],
                    |[min, max], x| [ops.min(min, x), ops.max(max, x)],
                    |[min_a, max_a], [min_b, max_b]| [ops.min(min_a, min_b), ops.max(max_a, max_b)],
                )
            } else {
                self.xs.simd_iter(ops).fold_n(
                    [ops.splat(f32::MAX), ops.splat(f32::MIN)],
                    |[min, max], x| [ops.min(min, x), ops.max(max, x)],
                )
            };
            let min = vec_min
                .to_array()
                .into_iter()
                .reduce(|min, x| min.min(x))
                .unwrap();
            let max = vec_max
                .to_array()
                .into_iter()
                .reduce(|max, x| max.max(x))
                .unwrap();
            (min, max)
        }
    }

    #[test]
    fn test_fold_n() {
        let buf: Vec<_> = (0..TEST_LEN).map(|x| x as f32).collect();
        let (min, max) = MinMax {
            xs: &buf,
            unroll: false,
        }
        .dispatch();
        assert_eq!(min, 0.0);
        assert_eq!(max, (TEST_LEN - 1) as f32);
    }

    #[test]
    fn test_fold_n_unroll() {
        let buf: Vec<_> = (0..TEST_LEN * UNROLL).map(|x| x as f32).collect();
        // Fix: this test previously set `unroll: false`, so the unrolled
        // `fold_n_unroll` path it is named for was never exercised.
        let (min, max) = MinMax {
            xs: &buf,
            unroll: true,
        }
        .dispatch();
        assert_eq!(min, 0.0);
        assert_eq!(max, (TEST_LEN * UNROLL - 1) as f32);
    }
}