1use crate::ops::NumOps;
4use crate::{Elem, Simd};
5
6pub trait SimdIterable {
8 type Elem: Elem;
10
11 fn simd_iter<O: NumOps<Self::Elem>>(&self, ops: O) -> Iter<'_, Self::Elem, O>;
17
18 fn simd_iter_pad<O: NumOps<Self::Elem>>(
23 &self,
24 ops: O,
25 ) -> impl ExactSizeIterator<Item = O::Simd>;
26}
27
28impl<T: Elem> SimdIterable for [T] {
29 type Elem = T;
30
31 #[inline]
32 fn simd_iter<O: NumOps<T>>(&self, ops: O) -> Iter<'_, T, O> {
33 Iter::new(ops, self)
34 }
35
36 #[inline]
37 fn simd_iter_pad<O: NumOps<T>>(&self, ops: O) -> impl ExactSizeIterator<Item = O::Simd> {
38 IterPad::new(ops, self)
39 }
40}
41
42pub struct Iter<'a, T: Elem, O: NumOps<T>> {
46 ops: O,
47 xs: &'a [T],
48 n_full_chunks: usize,
49}
50
51impl<'a, T: Elem, O: NumOps<T>> Iter<'a, T, O> {
52 #[inline]
53 fn new(ops: O, xs: &'a [T]) -> Self {
54 let n_full_chunks = xs.len() / ops.len();
55 Iter {
56 ops,
57 xs,
58 n_full_chunks,
59 }
60 }
61
62 #[inline]
70 pub fn fold<F: FnMut(O::Simd, O::Simd) -> O::Simd>(
71 mut self,
72 mut accum: O::Simd,
73 mut fold: F,
74 ) -> O::Simd {
75 for chunk in &mut self {
76 accum = fold(accum, chunk);
77 }
78
79 if let Some((tail, mask)) = self.tail() {
80 let new_accum = fold(accum, tail);
81 accum = self.ops.select(new_accum, accum, mask);
82 }
83
84 accum
85 }
86
87 #[inline]
97 pub fn fold_unroll<const UNROLL: usize>(
98 mut self,
99 accum: O::Simd,
100 mut fold: impl FnMut(O::Simd, O::Simd) -> O::Simd,
101 mut fold_acc: impl FnMut(O::Simd, O::Simd) -> O::Simd,
102 ) -> O::Simd {
103 let mut acc = [accum; UNROLL];
104 let v_len = self.ops.len();
105
106 while let Some((chunk, tail)) = self.xs.split_at_checked(v_len * UNROLL) {
107 let xs: [_; UNROLL] = std::array::from_fn(|i| unsafe {
108 self.ops.load_ptr(chunk.as_ptr().add(v_len * i))
110 });
111 for i in 0..UNROLL {
112 acc[i] = fold(acc[i], xs[i]);
113 }
114 self.xs = tail;
115 }
116 for i in 1..UNROLL {
117 acc[0] = fold_acc(acc[0], acc[i]);
118 }
119 self.fold(acc[0], fold)
120 }
121
122 #[inline]
125 pub fn fold_n<const N: usize>(
126 mut self,
127 mut accum: [O::Simd; N],
128 mut fold: impl FnMut([O::Simd; N], O::Simd) -> [O::Simd; N],
129 ) -> [O::Simd; N] {
130 for chunk in &mut self {
131 accum = fold(accum, chunk);
132 }
133
134 if let Some((tail, mask)) = self.tail() {
135 let new_accum = fold(accum, tail);
136 for i in 0..N {
137 accum[i] = self.ops.select(new_accum[i], accum[i], mask);
138 }
139 }
140
141 accum
142 }
143
144 #[inline]
147 pub fn fold_n_unroll<const N: usize, const UNROLL: usize>(
148 mut self,
149 accum: [O::Simd; N],
150 mut fold: impl FnMut([O::Simd; N], O::Simd) -> [O::Simd; N],
151 mut fold_acc: impl FnMut([O::Simd; N], [O::Simd; N]) -> [O::Simd; N],
152 ) -> [O::Simd; N] {
153 let mut acc = [accum; UNROLL];
154 let v_len = self.ops.len();
155
156 while let Some((chunk, tail)) = self.xs.split_at_checked(v_len * UNROLL) {
157 let xs: [_; UNROLL] = std::array::from_fn(|i| unsafe {
158 self.ops.load_ptr(chunk.as_ptr().add(v_len * i))
160 });
161 for i in 0..UNROLL {
162 acc[i] = fold(acc[i], xs[i]);
163 }
164 self.xs = tail;
165 }
166 for i in 1..UNROLL {
167 acc[0] = fold_acc(acc[0], acc[i]);
168 }
169 self.fold_n(acc[0], fold)
170 }
171
172 #[inline]
178 pub fn tail(&self) -> Option<(O::Simd, <O::Simd as Simd>::Mask)> {
179 let n = self.xs.len();
180 if n > 0 {
181 Some(self.ops.load_pad(self.xs))
182 } else {
183 None
184 }
185 }
186}
187
188impl<T: Elem, O: NumOps<T>> Iterator for Iter<'_, T, O> {
189 type Item = O::Simd;
190
191 #[inline]
192 fn next(&mut self) -> Option<Self::Item> {
193 let v_len = self.ops.len();
194 if let Some((chunk, tail)) = self.xs.split_at_checked(v_len) {
195 self.xs = tail;
196
197 let x = unsafe { self.ops.load_ptr(chunk.as_ptr()) };
199
200 Some(x)
201 } else {
202 None
203 }
204 }
205
206 #[inline]
207 fn size_hint(&self) -> (usize, Option<usize>) {
208 (self.n_full_chunks, Some(self.n_full_chunks))
209 }
210}
211
212impl<T: Elem, O: NumOps<T>> ExactSizeIterator for Iter<'_, T, O> {}
213
214impl<T: Elem, O: NumOps<T>> std::iter::FusedIterator for Iter<'_, T, O> {}
215
216pub struct IterPad<'a, T: Elem, O: NumOps<T>> {
220 iter: Iter<'a, T, O>,
221 has_tail: bool,
222}
223
224impl<'a, T: Elem, O: NumOps<T>> IterPad<'a, T, O> {
225 #[inline]
226 fn new(ops: O, xs: &'a [T]) -> Self {
227 let iter = Iter::new(ops, xs);
228 let has_tail = !xs.len().is_multiple_of(ops.len());
229 Self { iter, has_tail }
230 }
231}
232
233impl<T: Elem, O: NumOps<T>> Iterator for IterPad<'_, T, O> {
234 type Item = O::Simd;
235
236 #[inline]
237 fn next(&mut self) -> Option<Self::Item> {
238 if let Some(chunk) = self.iter.next() {
239 Some(chunk)
240 } else if self.has_tail {
241 let (tail, _mask) = self.iter.tail().unwrap();
242 self.has_tail = false;
243 Some(tail)
244 } else {
245 None
246 }
247 }
248
249 #[inline]
250 fn size_hint(&self) -> (usize, Option<usize>) {
251 let n_tail = if self.has_tail { 1 } else { 0 };
252 let n_chunks = self.iter.len() + n_tail;
253 (n_chunks, Some(n_chunks))
254 }
255}
256
257impl<T: Elem, O: NumOps<T>> ExactSizeIterator for IterPad<'_, T, O> {}
258
259impl<T: Elem, O: NumOps<T>> std::iter::FusedIterator for IterPad<'_, T, O> {}
260
#[cfg(test)]
mod tests {
    use super::SimdIterable;
    use crate::dispatch::test_simd_op;
    use crate::ops::NumOps;
    use crate::{Isa, Simd, SimdOp};

    /// Input length for most tests. 18 is not a multiple of 4, 8 or 16, so for
    /// those vector widths the tail-handling paths are exercised.
    const TEST_LEN: usize = 18;

    /// Unroll factor used by the `*_unroll` tests.
    const UNROLL: usize = 4;

    #[test]
    fn test_iter() {
        test_simd_op!(isa, {
            let buf: Vec<_> = (0..TEST_LEN).map(|x| x as f32).collect();
            let chunks = buf.chunks_exact(isa.f32().len());

            let iter = buf.simd_iter(isa.f32());
            assert_eq!(iter.len(), chunks.len());

            for (scalar_chunk, simd_chunk) in chunks.zip(iter) {
                assert_eq!(simd_chunk.to_array().as_ref(), scalar_chunk);
            }
        });
    }

    #[test]
    fn test_iter_pad() {
        test_simd_op!(isa, {
            let buf: Vec<_> = (0..TEST_LEN).map(|x| x as f32).collect();
            let chunks = buf.chunks(isa.f32().len());

            let iter = buf.simd_iter_pad(isa.f32());
            assert_eq!(iter.len(), chunks.len());

            for (scalar_chunk, simd_chunk) in chunks.zip(iter) {
                let simd_elts = simd_chunk.to_array();
                let simd_elts = simd_elts.as_ref();
                assert_eq!(&simd_elts[..scalar_chunk.len()], scalar_chunk);
                if simd_elts.len() > scalar_chunk.len() {
                    // Padding lanes of the final vector must be zero.
                    assert!(simd_elts[scalar_chunk.len()..].iter().all(|x| *x == 0.));
                }
            }
        });
    }

    #[test]
    fn test_fold() {
        struct Sum<'a> {
            xs: &'a [f32],
        }

        impl<'a> SimdOp for Sum<'a> {
            type Output = f32;

            fn eval<I: Isa>(self, isa: I) -> Self::Output {
                let ops = isa.f32();
                let vec_sum = self
                    .xs
                    .simd_iter(ops)
                    .fold(ops.zero(), |sum, x| ops.add(sum, x));
                vec_sum.to_array().into_iter().fold(0., |sum, x| sum + x)
            }
        }

        let buf: Vec<_> = (0..TEST_LEN).map(|x| x as f32).collect();
        // Sum of 0..n is n * (n - 1) / 2.
        let expected = (buf.len() as f32 * buf[buf.len() - 1]) / 2.;

        let sum = Sum { xs: &buf }.dispatch();
        assert_eq!(sum, expected);
    }

    #[test]
    fn test_fold_unroll() {
        struct SumSquare<'a> {
            xs: &'a [i32],
        }

        impl<'a> SimdOp for SumSquare<'a> {
            type Output = i32;

            fn eval<I: Isa>(self, isa: I) -> Self::Output {
                let ops = isa.i32();
                let vec_sum = self.xs.simd_iter(ops).fold_unroll::<UNROLL>(
                    ops.zero(),
                    |sum, x| ops.mul_add(x, x, sum),
                    |sum, x| ops.add(sum, x),
                );
                vec_sum.to_array().into_iter().fold(0, |sum, x| sum + x)
            }
        }

        let buf: Vec<_> = (0..TEST_LEN * UNROLL).map(|x| x as i32).collect();
        let expected: i32 = buf.iter().map(|&x| x * x).sum();

        let sum = SumSquare { xs: &buf }.dispatch();
        assert_eq!(sum, expected);
    }

    /// Computes the min and max of `xs`, via `fold_n_unroll` when `unroll` is
    /// set and via `fold_n` otherwise.
    struct MinMax<'a> {
        xs: &'a [f32],
        unroll: bool,
    }

    impl<'a> SimdOp for MinMax<'a> {
        type Output = (f32, f32);

        fn eval<I: Isa>(self, isa: I) -> Self::Output {
            let ops = isa.f32();
            let [vec_min, vec_max] = if self.unroll {
                self.xs.simd_iter(ops).fold_n_unroll::<2, UNROLL>(
                    [ops.splat(f32::MAX), ops.splat(f32::MIN)],
                    |[min, max], x| [ops.min(min, x), ops.max(max, x)],
                    |[min_a, max_a], [min_b, max_b]| [ops.min(min_a, min_b), ops.max(max_a, max_b)],
                )
            } else {
                self.xs.simd_iter(ops).fold_n(
                    [ops.splat(f32::MAX), ops.splat(f32::MIN)],
                    |[min, max], x| [ops.min(min, x), ops.max(max, x)],
                )
            };
            let min = vec_min
                .to_array()
                .into_iter()
                .reduce(|min, x| min.min(x))
                .unwrap();
            let max = vec_max
                .to_array()
                .into_iter()
                .reduce(|max, x| max.max(x))
                .unwrap();
            (min, max)
        }
    }

    #[test]
    fn test_fold_n() {
        let buf: Vec<_> = (0..TEST_LEN).map(|x| x as f32).collect();
        let (min, max) = MinMax {
            xs: &buf,
            unroll: false,
        }
        .dispatch();
        assert_eq!(min, 0.0);
        assert_eq!(max, (TEST_LEN - 1) as f32);
    }

    #[test]
    fn test_fold_n_unroll() {
        let buf: Vec<_> = (0..TEST_LEN * UNROLL).map(|x| x as f32).collect();
        let (min, max) = MinMax {
            xs: &buf,
            // Previously `false`, which silently tested `fold_n` instead.
            unroll: true,
        }
        .dispatch();
        assert_eq!(min, 0.0);
        assert_eq!(max, (TEST_LEN * UNROLL - 1) as f32);
    }
}