1use crate::Elem;
4use crate::ops::NumOps;
5use crate::span::SrcDest;
6
7#[inline(always)]
17pub fn simd_map<'src, 'dst, T: Elem + 'static, O: NumOps<T>, Op: FnMut(O::Simd) -> O::Simd>(
18 ops: O,
19 src_dest: impl Into<SrcDest<'src, 'dst, T>>,
20 mut op: Op,
21) -> &'dst mut [T] {
22 let mut src_dest = src_dest.into();
23 let (mut in_ptr, mut out_ptr, mut n) = src_dest.src_dest_ptr();
24
25 let v_len = ops.len();
26 while n >= v_len {
27 let x = unsafe { ops.load_ptr(in_ptr) };
30 let y = op(x);
31 unsafe { ops.store_ptr(y, out_ptr as *mut T) };
32
33 n -= v_len;
36 unsafe {
37 in_ptr = in_ptr.add(v_len);
38 out_ptr = out_ptr.add(v_len);
39 }
40 }
41
42 if n > 0 {
43 let mask = ops.first_n_mask(n);
44
45 let x = unsafe { ops.load_ptr_mask(in_ptr, mask) };
48 let y = op(x);
49 unsafe {
50 ops.store_ptr_mask(y, out_ptr as *mut T, mask);
51 }
52 }
53
54 unsafe { src_dest.dest_assume_init() }
56}
57
58#[inline(always)]
67pub fn simd_apply<
68 T: Elem + 'static,
69 O: NumOps<T>,
70 Op: FnMut(O::Simd) -> O::Simd,
71 const UNROLL: usize,
72>(
73 ops: O,
74 dest: &mut [T],
75 mut op: Op,
76) -> &mut [T] {
77 let v_len = ops.len();
78 let mut chunks = dest.chunks_exact_mut(v_len * UNROLL);
79 for chunk in &mut chunks {
80 for i in 0..UNROLL {
81 let x = unsafe { ops.load_ptr(chunk.as_ptr().add(i * v_len)) };
83 let y = op(x);
84 unsafe {
85 ops.store_ptr(y, chunk.as_mut_ptr().add(i * v_len));
86 }
87 }
88 }
89
90 let mut tail_chunks = chunks.into_remainder().chunks_exact_mut(v_len);
91 for chunk in &mut tail_chunks {
92 let x = ops.load(chunk);
93 let y = op(x);
94 ops.store(y, chunk);
95 }
96
97 let tail = tail_chunks.into_remainder();
98 if !tail.is_empty() {
99 let mask = ops.first_n_mask(tail.len());
100
101 let x = unsafe { ops.load_ptr_mask(tail.as_ptr(), mask) };
103 let y = op(x);
104 unsafe {
105 ops.store_ptr_mask(y, tail.as_mut_ptr(), mask);
106 }
107 }
108
109 dest
110}
111
#[cfg(test)]
mod tests {
    use crate::ops::NumOps;
    use crate::{Isa, SimdOp};

    use super::{simd_apply, simd_map};

    /// Base length for the input-length sweeps below.
    ///
    /// `test_simd_map` tests every length in `0..=TEST_LEN` and
    /// `test_simd_apply` every length in `0..=TEST_LEN * 4`. For vector widths
    /// of up to 16 lanes this covers:
    /// - empty input;
    /// - input shorter than one vector;
    /// - exact multiples of the vector length;
    /// - inputs that need a masked tail;
    /// - for `simd_apply`, inputs that reach all three of its stages.
    const TEST_LEN: usize = 18;

    #[test]
    fn test_simd_map() {
        struct Square<'a> {
            xs: &'a mut [f32],
        }

        impl<'a> SimdOp for Square<'a> {
            type Output = &'a mut [f32];

            fn eval<I: Isa>(self, isa: I) -> Self::Output {
                let ops = isa.f32();
                simd_map(ops, self.xs, |x| ops.mul(x, x))
            }
        }

        for len in 0..=TEST_LEN {
            let mut buf: Vec<_> = (0..len).map(|x| x as f32).collect();
            let expected: Vec<_> = buf.iter().map(|x| *x * *x).collect();

            let squared = Square { xs: &mut buf }.dispatch();

            assert_eq!(squared, &expected, "mismatch for len {len}");
        }
    }

    #[test]
    fn test_simd_apply() {
        struct Square<'a> {
            xs: &'a mut [f32],
        }

        impl<'a> SimdOp for Square<'a> {
            type Output = &'a mut [f32];

            fn eval<I: Isa>(self, isa: I) -> Self::Output {
                let ops = isa.f32();
                const UNROLL: usize = 2;
                simd_apply::<_, _, _, UNROLL>(ops, self.xs, |x| ops.mul(x, x))
            }
        }

        // A single fixed length (e.g. 72) can be an exact multiple of the
        // unrolled block size, which would skip the single-vector and masked
        // stages. Sweep all lengths so every stage is hit.
        for len in 0..=TEST_LEN * 4 {
            let mut buf: Vec<_> = (0..len).map(|x| x as f32).collect();
            let expected: Vec<_> = buf.iter().map(|x| *x * *x).collect();

            let squared = Square { xs: &mut buf }.dispatch();

            assert_eq!(squared, &expected, "mismatch for len {len}");
        }
    }
}