inout/
inout.rs

1use crate::InOutBuf;
2use core::{marker::PhantomData, ptr};
3use hybrid_array::{Array, ArraySize};
4
5/// Custom pointer type which contains one immutable (input) and one mutable
6/// (output) pointer, which are either equal or non-overlapping.
/// Custom pointer type which contains one immutable (input) and one mutable
/// (output) pointer, which are either equal or non-overlapping.
///
/// The input side behaves like a `&'inp T` and the output side like a
/// `&'out mut T`. When both pointers are equal, reads and writes target the
/// same value (in-place processing).
pub struct InOut<'inp, 'out, T> {
    pub(crate) in_ptr: *const T,
    pub(crate) out_ptr: *mut T,
    pub(crate) _pd: PhantomData<(&'inp T, &'out mut T)>,
}

impl<'inp, 'out, T> InOut<'inp, 'out, T> {
    /// Reborrow `self`.
    ///
    /// The returned `InOut` uses the same pointers but is limited to the
    /// lifetime `'a` of the mutable borrow of `self`.
    #[inline(always)]
    pub fn reborrow<'a>(&'a mut self) -> InOut<'a, 'a, T> {
        InOut {
            in_ptr: self.in_ptr,
            out_ptr: self.out_ptr,
            _pd: PhantomData,
        }
    }

    /// Get immutable reference to the input value.
    #[inline(always)]
    pub fn get_in<'a>(&'a self) -> &'a T {
        // SAFETY: by the `InOut` contract `in_ptr` points to an initialized `T`
        // valid for reads. Borrowing `self` immutably prevents `get_out` from
        // handing out an aliasing `&mut T` while this reference is alive.
        unsafe { &*self.in_ptr }
    }

    /// Get mutable reference to the output value.
    #[inline(always)]
    pub fn get_out<'a>(&'a mut self) -> &'a mut T {
        // SAFETY: by the `InOut` contract `out_ptr` points to an initialized `T`
        // valid for reads and writes, and `&mut self` guarantees no other
        // reference obtained through this `InOut` is alive.
        unsafe { &mut *self.out_ptr }
    }

    /// Consume `self` and get mutable reference to the output value with lifetime `'out`.
    #[inline(always)]
    pub fn into_out(self) -> &'out mut T {
        // SAFETY: consuming `self` hands the exclusive output borrow for `'out`
        // to the caller; no accessor of this `InOut` can be used afterwards.
        unsafe { &mut *self.out_ptr }
    }

    /// Convert `self` to a pair of raw input and output pointers.
    #[inline(always)]
    pub fn into_raw(self) -> (*const T, *mut T) {
        (self.in_ptr, self.out_ptr)
    }

    /// Create `InOut` from raw input and output pointers.
    ///
    /// # Safety
    /// Behavior is undefined if any of the following conditions are violated:
    /// - `in_ptr` must point to a properly initialized value of type `T` and
    ///   must be valid for reads.
    /// - `out_ptr` must point to a properly initialized value of type `T` and
    ///   must be valid for both reads and writes.
    /// - `in_ptr` and `out_ptr` must be either equal or non-overlapping.
    /// - If `in_ptr` and `out_ptr` are equal, then the memory referenced by
    ///   them must not be accessed through any other pointer (not derived from
    ///   the return value) for the duration of lifetimes `'inp` and `'out`.
    ///   Both read and write accesses are forbidden.
    /// - If `in_ptr` and `out_ptr` are not equal, then the memory referenced by
    ///   `out_ptr` must not be accessed through any other pointer (not derived
    ///   from the return value) for the duration of lifetime `'out`. Both read
    ///   and write accesses are forbidden. The memory referenced by `in_ptr`
    ///   must not be mutated for the duration of lifetime `'inp`, except inside
    ///   an `UnsafeCell`.
    #[inline(always)]
    pub unsafe fn from_raw(in_ptr: *const T, out_ptr: *mut T) -> InOut<'inp, 'out, T> {
        Self {
            in_ptr,
            out_ptr,
            _pd: PhantomData,
        }
    }
}
75
76impl<'inp, 'out, T: Clone> InOut<'inp, 'out, T> {
77    /// Clone input value and return it.
78    #[inline(always)]
79    pub fn clone_in(&self) -> T {
80        unsafe { (*self.in_ptr).clone() }
81    }
82}
83
84impl<'a, T> From<&'a mut T> for InOut<'a, 'a, T> {
85    #[inline(always)]
86    fn from(val: &'a mut T) -> Self {
87        let p = val as *mut T;
88        Self {
89            in_ptr: p,
90            out_ptr: p,
91            _pd: PhantomData,
92        }
93    }
94}
95
96impl<'inp, 'out, T> From<(&'inp T, &'out mut T)> for InOut<'inp, 'out, T> {
97    #[inline(always)]
98    fn from((in_val, out_val): (&'inp T, &'out mut T)) -> Self {
99        Self {
100            in_ptr: in_val as *const T,
101            out_ptr: out_val as *mut T,
102            _pd: Default::default(),
103        }
104    }
105}
106
107impl<'inp, 'out, T, N: ArraySize> InOut<'inp, 'out, Array<T, N>> {
108    /// Returns `InOut` for the given position.
109    ///
110    /// # Panics
111    /// If `pos` greater or equal to array length.
112    #[inline(always)]
113    pub fn get<'a>(&'a mut self, pos: usize) -> InOut<'a, 'a, T> {
114        assert!(pos < N::USIZE);
115        unsafe {
116            InOut {
117                in_ptr: (self.in_ptr as *const T).add(pos),
118                out_ptr: (self.out_ptr as *mut T).add(pos),
119                _pd: PhantomData,
120            }
121        }
122    }
123
124    /// Convert `InOut` array to `InOutBuf`.
125    #[inline(always)]
126    pub fn into_buf(self) -> InOutBuf<'inp, 'out, T> {
127        InOutBuf {
128            in_ptr: self.in_ptr as *const T,
129            out_ptr: self.out_ptr as *mut T,
130            len: N::USIZE,
131            _pd: PhantomData,
132        }
133    }
134}
135
136impl<'inp, 'out, N: ArraySize> InOut<'inp, 'out, Array<u8, N>> {
137    /// XOR `data` with values behind the input slice and write
138    /// result to the output slice.
139    ///
140    /// # Panics
141    /// If `data` length is not equal to the buffer length.
142    #[inline(always)]
143    #[allow(clippy::needless_range_loop)]
144    pub fn xor_in2out(&mut self, data: &Array<u8, N>) {
145        unsafe {
146            let input = ptr::read(self.in_ptr);
147            let mut temp = Array::<u8, N>::default();
148            for i in 0..N::USIZE {
149                temp[i] = input[i] ^ data[i];
150            }
151            ptr::write(self.out_ptr, temp);
152        }
153    }
154}
155
156impl<'inp, 'out, N, M> InOut<'inp, 'out, Array<Array<u8, N>, M>>
157where
158    N: ArraySize,
159    M: ArraySize,
160{
161    /// XOR `data` with values behind the input slice and write
162    /// result to the output slice.
163    ///
164    /// # Panics
165    /// If `data` length is not equal to the buffer length.
166    #[inline(always)]
167    #[allow(clippy::needless_range_loop)]
168    pub fn xor_in2out(&mut self, data: &Array<Array<u8, N>, M>) {
169        unsafe {
170            let input = ptr::read(self.in_ptr);
171            let mut temp = Array::<Array<u8, N>, M>::default();
172            for i in 0..M::USIZE {
173                for j in 0..N::USIZE {
174                    temp[i][j] = input[i][j] ^ data[i][j];
175                }
176            }
177            ptr::write(self.out_ptr, temp);
178        }
179    }
180}