rten_simd/
writer.rs

1use std::mem::{MaybeUninit, transmute};
2
3use crate::Elem;
4use crate::ops::NumOps;
5
/// Incremental writer over a slice of uninitialized elements.
///
/// The slice is filled front to back, either a whole SIMD vector at a time or
/// one scalar at a time. Once writing is finished, the filled prefix can be
/// retrieved as an ordinary initialized slice.
pub struct SliceWriter<'a, T> {
    /// Destination storage. Only the first `n_init` elements hold values.
    buf: &'a mut [MaybeUninit<T>],
    /// Count of leading elements of `buf` that have been written so far.
    n_init: usize,
}
12
13impl<'a, T: Elem> SliceWriter<'a, T> {
14    /// Create a writer which initializes elements of `buf`.
15    pub fn new(buf: &'a mut [MaybeUninit<T>]) -> Self {
16        SliceWriter { buf, n_init: 0 }
17    }
18
19    /// Initialize the next `ops.len()` elements of the slice from the contents
20    /// of SIMD vector `xs`.
21    ///
22    /// Panics if the slice does not have space for `ops.len()` elements.
23    pub fn write_vec<O: NumOps<T>>(&mut self, ops: O, xs: O::Simd) {
24        let written = ops.store_uninit(xs, &mut self.buf[self.n_init..]);
25        self.n_init += written.len();
26    }
27
28    /// Initialize the next element of the slice from `x`.
29    ///
30    /// Panics if the slice does not have space for writing any more elements.
31    pub fn write_scalar(&mut self, x: T) {
32        self.buf[self.n_init].write(x);
33        self.n_init += 1;
34    }
35
36    /// Finish writing the slice and return the initialized portion.
37    pub fn into_mut_slice(self) -> &'a mut [T] {
38        let init = &mut self.buf[0..self.n_init];
39
40        // Safety: All elements in `init` have been initialized.
41        unsafe { transmute::<&mut [MaybeUninit<T>], &mut [T]>(init) }
42    }
43}
44
#[cfg(test)]
mod tests {
    use std::mem::MaybeUninit;

    use crate::ops::NumOps;
    use crate::{Isa, SimdOp, SliceWriter};

    /// Copy a slice through `SliceWriter` using the dispatched ISA and check
    /// that the initialized output matches the input.
    #[test]
    fn test_slice_writer() {
        /// SIMD operation that copies `src` into the uninitialized `dest`.
        struct MemCopy<'src, 'dest> {
            src: &'src [f32],
            dest: &'dest mut [MaybeUninit<f32>],
        }

        impl<'dest> SimdOp for MemCopy<'_, 'dest> {
            type Output = &'dest mut [f32];

            fn eval<I: Isa>(self, isa: I) -> Self::Output {
                let MemCopy { src, dest } = self;
                let ops = isa.f32();

                let mut writer = SliceWriter::new(dest);
                let mut chunks = src.chunks_exact(ops.len());

                // Vectorized body: one full SIMD vector per chunk.
                for chunk in &mut chunks {
                    writer.write_vec(ops, ops.load(chunk));
                }

                // Tail: elements left over after the last full vector.
                for &x in chunks.remainder() {
                    writer.write_scalar(x);
                }

                writer.into_mut_slice()
            }
        }

        // 17 elements should cover the vectorized body and tail cases for
        // every ISA.
        let src: Vec<f32> = (0..17).map(|i| i as f32).collect();
        let mut dest: Vec<f32> = Vec::with_capacity(src.len());

        let op = MemCopy {
            src: &src,
            dest: dest.spare_capacity_mut(),
        };
        let copied = op.dispatch();
        assert_eq!(copied, src);
    }
}