rten_simd/
functional.rs

1//! Vectorized higher-order operations (map etc.)
2
3use crate::Elem;
4use crate::ops::NumOps;
5use crate::span::SrcDest;
6
7/// Transform a slice by applying a vectorized map function to its elements.
8///
9/// This function can be applied both in-place (mutable destination) and
10/// with separate source/destination buffers.
11///
12/// If the slice is not a multiple of the vector length, the final call to
13/// `op` will use a vector padded with zeros.
14///
15/// The map function must have the same input and output type.
16#[inline(always)]
17pub fn simd_map<'src, 'dst, T: Elem + 'static, O: NumOps<T>, Op: FnMut(O::Simd) -> O::Simd>(
18    ops: O,
19    src_dest: impl Into<SrcDest<'src, 'dst, T>>,
20    mut op: Op,
21) -> &'dst mut [T] {
22    let mut src_dest = src_dest.into();
23    let (mut in_ptr, mut out_ptr, mut n) = src_dest.src_dest_ptr();
24
25    let v_len = ops.len();
26    while n >= v_len {
27        // Safety: `in_ptr` and `out_ptr` point to a buffer with at least `v_len`
28        // elements.
29        let x = unsafe { ops.load_ptr(in_ptr) };
30        let y = op(x);
31        unsafe { ops.store_ptr(y, out_ptr as *mut T) };
32
33        // Safety: `in_ptr` and `out_ptr` are pointers into buffers of the same
34        // length, with at least `v_len` elements.
35        n -= v_len;
36        unsafe {
37            in_ptr = in_ptr.add(v_len);
38            out_ptr = out_ptr.add(v_len);
39        }
40    }
41
42    if n > 0 {
43        let mask = ops.first_n_mask(n);
44
45        // Safety: Mask bit `i` is only set if `in_ptr.add(i)` and
46        // `out_ptr.add(i)` are valid.
47        let x = unsafe { ops.load_ptr_mask(in_ptr, mask) };
48        let y = op(x);
49        unsafe {
50            ops.store_ptr_mask(y, out_ptr as *mut T, mask);
51        }
52    }
53
54    // Safety: All elements in `src_dest` have been initialized.
55    unsafe { src_dest.dest_assume_init() }
56}
57
58/// Transform a slice in-place by applying a vectorized map function to its
59/// elements.
60///
61/// If the slice is not a multiple of the vector length, the final call to
62/// `op` will use a vector padded with zeros.
63///
64/// `UNROLL` specifies a loop unrolling factor. When the operation is very
65/// cheap, explicit unrolling can improve instruction level parallelism.
66#[inline(always)]
67pub fn simd_apply<
68    T: Elem + 'static,
69    O: NumOps<T>,
70    Op: FnMut(O::Simd) -> O::Simd,
71    const UNROLL: usize,
72>(
73    ops: O,
74    dest: &mut [T],
75    mut op: Op,
76) -> &mut [T] {
77    let v_len = ops.len();
78    let mut chunks = dest.chunks_exact_mut(v_len * UNROLL);
79    for chunk in &mut chunks {
80        for i in 0..UNROLL {
81            // Safety: Sliced chunk points to `v_len` elements.
82            let x = unsafe { ops.load_ptr(chunk.as_ptr().add(i * v_len)) };
83            let y = op(x);
84            unsafe {
85                ops.store_ptr(y, chunk.as_mut_ptr().add(i * v_len));
86            }
87        }
88    }
89
90    let mut tail_chunks = chunks.into_remainder().chunks_exact_mut(v_len);
91    for chunk in &mut tail_chunks {
92        let x = ops.load(chunk);
93        let y = op(x);
94        ops.store(y, chunk);
95    }
96
97    let tail = tail_chunks.into_remainder();
98    if !tail.is_empty() {
99        let mask = ops.first_n_mask(tail.len());
100
101        // Safety: `mask[i]` is true where `tail.add(i)` is valid.
102        let x = unsafe { ops.load_ptr_mask(tail.as_ptr(), mask) };
103        let y = op(x);
104        unsafe {
105            ops.store_ptr_mask(y, tail.as_mut_ptr(), mask);
106        }
107    }
108
109    dest
110}
111
#[cfg(test)]
mod tests {
    use crate::ops::NumOps;
    use crate::{Isa, SimdOp};

    use super::{simd_apply, simd_map};

    /// Largest slice length tested.
    ///
    /// Every length in `0..=MAX_TEST_LEN` is tested, including the empty
    /// slice. With `UNROLL = 2` and f32 vectors of up to 16 lanes, an
    /// unrolled chunk holds at most 32 elements. This range therefore
    /// contains lengths that reach every combination of the unrolled loop,
    /// the whole-vector tail loop and the masked tail.
    const MAX_TEST_LEN: usize = 70;

    /// Return `(input, expected)`, where `input` is `[0.0, 1.0, ...]` with
    /// `len` elements and `expected` holds the square of each input element.
    fn squares_case(len: usize) -> (Vec<f32>, Vec<f32>) {
        let input: Vec<f32> = (0..len).map(|x| x as f32).collect();
        let expected = input.iter().map(|x| x * x).collect();
        (input, expected)
    }

    #[test]
    fn test_simd_map() {
        struct Square<'a> {
            xs: &'a mut [f32],
        }

        impl<'a> SimdOp for Square<'a> {
            type Output = &'a mut [f32];

            fn eval<I: Isa>(self, isa: I) -> Self::Output {
                let ops = isa.f32();
                simd_map(ops, self.xs, |x| ops.mul(x, x))
            }
        }

        for len in 0..=MAX_TEST_LEN {
            let (mut buf, expected) = squares_case(len);
            let squared = Square { xs: &mut buf }.dispatch();
            assert_eq!(squared, &expected, "len {}", len);
        }
    }

    #[test]
    fn test_simd_apply() {
        struct Square<'a> {
            xs: &'a mut [f32],
        }

        impl<'a> SimdOp for Square<'a> {
            type Output = &'a mut [f32];

            fn eval<I: Isa>(self, isa: I) -> Self::Output {
                let ops = isa.f32();
                const UNROLL: usize = 2;
                simd_apply::<_, _, _, UNROLL>(ops, self.xs, |x| ops.mul(x, x))
            }
        }

        for len in 0..=MAX_TEST_LEN {
            let (mut buf, expected) = squares_case(len);
            let squared = Square { xs: &mut buf }.dispatch();
            assert_eq!(squared, &expected, "len {}", len);
        }
    }
}