inout/
inout_buf.rs

1use crate::{
2    InOut,
3    errors::{IntoArrayError, NotEqualError},
4};
5use core::{marker::PhantomData, slice};
6use hybrid_array::{Array, ArraySize};
7
/// A paired view over one immutable (input) slice and one mutable (output)
/// slice that always have the same length. The two slices either refer to
/// exactly the same memory (in-place operation) or do not overlap at all.
pub struct InOutBuf<'inp, 'out, T> {
    /// Pointer to the first element of the input slice.
    pub(crate) in_ptr: *const T,
    /// Pointer to the first element of the output slice.
    pub(crate) out_ptr: *mut T,
    /// Number of `T` elements reachable through each of the two pointers.
    pub(crate) len: usize,
    /// Ties the raw pointers to the `'inp` shared borrow and the `'out`
    /// exclusive borrow without storing an actual reference.
    pub(crate) _pd: PhantomData<(&'inp T, &'out mut T)>,
}
17
18impl<'a, T> From<&'a mut [T]> for InOutBuf<'a, 'a, T> {
19    #[inline(always)]
20    fn from(buf: &'a mut [T]) -> Self {
21        let p = buf.as_mut_ptr();
22        Self {
23            in_ptr: p,
24            out_ptr: p,
25            len: buf.len(),
26            _pd: PhantomData,
27        }
28    }
29}
30
31impl<'a, T> InOutBuf<'a, 'a, T> {
32    /// Create `InOutBuf` from a single mutable reference.
33    #[inline(always)]
34    pub fn from_mut(val: &'a mut T) -> InOutBuf<'a, 'a, T> {
35        let p = val as *mut T;
36        Self {
37            in_ptr: p,
38            out_ptr: p,
39            len: 1,
40            _pd: PhantomData,
41        }
42    }
43}
44
45impl<'inp, 'out, T> IntoIterator for InOutBuf<'inp, 'out, T> {
46    type Item = InOut<'inp, 'out, T>;
47    type IntoIter = InOutBufIter<'inp, 'out, T>;
48
49    #[inline(always)]
50    fn into_iter(self) -> Self::IntoIter {
51        InOutBufIter { buf: self, pos: 0 }
52    }
53}
54
55impl<'inp, 'out, T> InOutBuf<'inp, 'out, T> {
56    /// Create `InOutBuf` from a pair of immutable and mutable references.
57    #[inline(always)]
58    pub fn from_ref_mut(in_val: &'inp T, out_val: &'out mut T) -> Self {
59        Self {
60            in_ptr: in_val as *const T,
61            out_ptr: out_val as *mut T,
62            len: 1,
63            _pd: PhantomData,
64        }
65    }
66
67    /// Create `InOutBuf` from immutable and mutable slices.
68    ///
69    /// Returns an error if length of slices is not equal to each other.
70    #[inline(always)]
71    pub fn new(in_buf: &'inp [T], out_buf: &'out mut [T]) -> Result<Self, NotEqualError> {
72        if in_buf.len() != out_buf.len() {
73            Err(NotEqualError)
74        } else {
75            Ok(Self {
76                in_ptr: in_buf.as_ptr(),
77                out_ptr: out_buf.as_mut_ptr(),
78                len: in_buf.len(),
79                _pd: Default::default(),
80            })
81        }
82    }
83
84    /// Get length of the inner buffers.
85    #[inline(always)]
86    pub fn len(&self) -> usize {
87        self.len
88    }
89
90    /// Returns `true` if the buffer has a length of 0.
91    #[inline(always)]
92    pub fn is_empty(&self) -> bool {
93        self.len == 0
94    }
95
96    /// Returns `InOut` for given position.
97    ///
98    /// # Panics
99    /// If `pos` greater or equal to buffer length.
100    #[inline(always)]
101    pub fn get(&mut self, pos: usize) -> InOut<'_, '_, T> {
102        assert!(pos < self.len);
103        unsafe {
104            InOut {
105                in_ptr: self.in_ptr.add(pos),
106                out_ptr: self.out_ptr.add(pos),
107                _pd: PhantomData,
108            }
109        }
110    }
111
112    /// Get input slice.
113    #[inline(always)]
114    pub fn get_in(&self) -> &[T] {
115        unsafe { slice::from_raw_parts(self.in_ptr, self.len) }
116    }
117
118    /// Get output slice.
119    #[inline(always)]
120    pub fn get_out(&mut self) -> &mut [T] {
121        unsafe { slice::from_raw_parts_mut(self.out_ptr, self.len) }
122    }
123
124    /// Consume `self` and get the output slice with lifetime `'out` filled with data from
125    /// the input slice.
126    ///
127    /// In the case if the input and output slices point to the same memory, simply returns
128    /// the output slice. Otherwise, copies data from the former to the latter
129    /// before returning the output slice.
130    pub fn into_out_with_copied_in(self) -> &'out mut [T]
131    where
132        T: Copy,
133    {
134        if !core::ptr::eq(self.in_ptr, self.out_ptr) {
135            unsafe {
136                core::ptr::copy(self.in_ptr, self.out_ptr, self.len);
137            }
138        }
139        unsafe { slice::from_raw_parts_mut(self.out_ptr, self.len) }
140    }
141
142    /// Consume `self` and get output slice with lifetime `'out`.
143    #[inline(always)]
144    pub fn into_out(self) -> &'out mut [T] {
145        unsafe { slice::from_raw_parts_mut(self.out_ptr, self.len) }
146    }
147
148    /// Get raw input and output pointers.
149    #[inline(always)]
150    pub fn into_raw(self) -> (*const T, *mut T) {
151        (self.in_ptr, self.out_ptr)
152    }
153
154    /// Reborrow `self`.
155    #[inline(always)]
156    pub fn reborrow(&mut self) -> InOutBuf<'_, '_, T> {
157        Self {
158            in_ptr: self.in_ptr,
159            out_ptr: self.out_ptr,
160            len: self.len,
161            _pd: PhantomData,
162        }
163    }
164
165    /// Create [`InOutBuf`] from raw input and output pointers.
166    ///
167    /// # Safety
168    /// Behavior is undefined if any of the following conditions are violated:
169    /// - `in_ptr` must point to a properly initialized value of type `T` and
170    ///   must be valid for reads for `len * mem::size_of::<T>()` many bytes.
171    /// - `out_ptr` must point to a properly initialized value of type `T` and
172    ///   must be valid for both reads and writes for `len * mem::size_of::<T>()`
173    ///   many bytes.
174    /// - `in_ptr` and `out_ptr` must be either equal or non-overlapping.
175    /// - If `in_ptr` and `out_ptr` are equal, then the memory referenced by
176    ///   them must not be accessed through any other pointer (not derived from
177    ///   the return value) for the duration of lifetime 'a. Both read and write
178    ///   accesses are forbidden.
179    /// - If `in_ptr` and `out_ptr` are not equal, then the memory referenced by
180    ///   `out_ptr` must not be accessed through any other pointer (not derived from
181    ///   the return value) for the duration of lifetime 'a. Both read and write
182    ///   accesses are forbidden. The memory referenced by `in_ptr` must not be
183    ///   mutated for the duration of lifetime `'a`, except inside an `UnsafeCell`.
184    /// - The total size `len * mem::size_of::<T>()`  must be no larger than `isize::MAX`.
185    #[inline(always)]
186    pub unsafe fn from_raw(
187        in_ptr: *const T,
188        out_ptr: *mut T,
189        len: usize,
190    ) -> InOutBuf<'inp, 'out, T> {
191        Self {
192            in_ptr,
193            out_ptr,
194            len,
195            _pd: PhantomData,
196        }
197    }
198
199    /// Divides one buffer into two at `mid` index.
200    ///
201    /// The first will contain all indices from `[0, mid)` (excluding
202    /// the index `mid` itself) and the second will contain all
203    /// indices from `[mid, len)` (excluding the index `len` itself).
204    ///
205    /// # Panics
206    ///
207    /// Panics if `mid > len`.
208    #[inline(always)]
209    pub fn split_at(self, mid: usize) -> (InOutBuf<'inp, 'out, T>, InOutBuf<'inp, 'out, T>) {
210        assert!(mid <= self.len);
211        let (tail_in_ptr, tail_out_ptr) = unsafe { (self.in_ptr.add(mid), self.out_ptr.add(mid)) };
212        (
213            InOutBuf {
214                in_ptr: self.in_ptr,
215                out_ptr: self.out_ptr,
216                len: mid,
217                _pd: PhantomData,
218            },
219            InOutBuf {
220                in_ptr: tail_in_ptr,
221                out_ptr: tail_out_ptr,
222                len: self.len() - mid,
223                _pd: PhantomData,
224            },
225        )
226    }
227
228    /// Partition buffer into 2 parts: buffer of arrays and tail.
229    #[inline(always)]
230    pub fn into_chunks<N: ArraySize>(
231        self,
232    ) -> (InOutBuf<'inp, 'out, Array<T, N>>, InOutBuf<'inp, 'out, T>) {
233        let chunks = self.len() / N::USIZE;
234        let tail_pos = N::USIZE * chunks;
235        let tail_len = self.len() - tail_pos;
236        unsafe {
237            let chunks = InOutBuf {
238                in_ptr: self.in_ptr as *const Array<T, N>,
239                out_ptr: self.out_ptr as *mut Array<T, N>,
240                len: chunks,
241                _pd: PhantomData,
242            };
243            let tail = InOutBuf {
244                in_ptr: self.in_ptr.add(tail_pos),
245                out_ptr: self.out_ptr.add(tail_pos),
246                len: tail_len,
247                _pd: PhantomData,
248            };
249            (chunks, tail)
250        }
251    }
252}
253
254impl InOutBuf<'_, '_, u8> {
255    /// XORs `data` with values behind the input slice and write
256    /// result to the output slice.
257    ///
258    /// # Panics
259    /// If `data` length is not equal to the buffer length.
260    #[inline(always)]
261    #[allow(clippy::needless_range_loop)]
262    pub fn xor_in2out(&mut self, data: &[u8]) {
263        assert_eq!(self.len(), data.len());
264        unsafe {
265            for i in 0..data.len() {
266                let in_ptr = self.in_ptr.add(i);
267                let out_ptr = self.out_ptr.add(i);
268                *out_ptr = *in_ptr ^ data[i];
269            }
270        }
271    }
272}
273
274impl<'inp, 'out, T, N> TryInto<InOut<'inp, 'out, Array<T, N>>> for InOutBuf<'inp, 'out, T>
275where
276    N: ArraySize,
277{
278    type Error = IntoArrayError;
279
280    #[inline(always)]
281    fn try_into(self) -> Result<InOut<'inp, 'out, Array<T, N>>, Self::Error> {
282        if self.len() == N::USIZE {
283            Ok(InOut {
284                in_ptr: self.in_ptr as *const _,
285                out_ptr: self.out_ptr as *mut _,
286                _pd: PhantomData,
287            })
288        } else {
289            Err(IntoArrayError)
290        }
291    }
292}
293
294/// Iterator over [`InOutBuf`].
295pub struct InOutBufIter<'inp, 'out, T> {
296    buf: InOutBuf<'inp, 'out, T>,
297    pos: usize,
298}
299
300impl<'inp, 'out, T> Iterator for InOutBufIter<'inp, 'out, T> {
301    type Item = InOut<'inp, 'out, T>;
302
303    #[inline(always)]
304    fn next(&mut self) -> Option<Self::Item> {
305        if self.buf.len() == self.pos {
306            return None;
307        }
308        let res = unsafe {
309            InOut {
310                in_ptr: self.buf.in_ptr.add(self.pos),
311                out_ptr: self.buf.out_ptr.add(self.pos),
312                _pd: PhantomData,
313            }
314        };
315        self.pos += 1;
316        Some(res)
317    }
318}