inout/
inout.rs

1use crate::InOutBuf;
2use core::{marker::PhantomData, ops::Mul, ptr};
3use hybrid_array::{Array, ArraySize, typenum::Prod};
4
/// Pointer pair made of one read-only (input) pointer and one writable
/// (output) pointer.
///
/// The two pointers either refer to the very same location (in-place
/// operation) or to memory regions which do not overlap.
pub struct InOut<'inp, 'out, T> {
    /// Pointer to the input value.
    pub(crate) in_ptr: *const T,
    /// Pointer to the output value.
    pub(crate) out_ptr: *mut T,
    /// Marker tying the raw pointers to an `&'inp T` / `&'out mut T` pair.
    pub(crate) _pd: PhantomData<(&'inp T, &'out mut T)>,
}
12
13impl<'inp, 'out, T> InOut<'inp, 'out, T> {
14    /// Reborrow `self`.
15    #[inline(always)]
16    pub fn reborrow(&mut self) -> InOut<'_, '_, T> {
17        Self {
18            in_ptr: self.in_ptr,
19            out_ptr: self.out_ptr,
20            _pd: PhantomData,
21        }
22    }
23
24    /// Get immutable reference to the input value.
25    #[inline(always)]
26    pub fn get_in(&self) -> &T {
27        unsafe { &*self.in_ptr }
28    }
29
30    /// Get mutable reference to the output value.
31    #[inline(always)]
32    pub fn get_out(&mut self) -> &mut T {
33        unsafe { &mut *self.out_ptr }
34    }
35
36    /// Consume `self` and get mutable reference to the output value with lifetime `'out`
37    /// and output value equal to the input value.
38    ///
39    /// In the case if the input and output references are the same, simply returns
40    /// the output reference. Otherwise, copies data from the former to the latter
41    /// before returning the output reference.
42    pub fn into_out_with_copied_in(self) -> &'out mut T
43    where
44        T: Copy,
45    {
46        if !core::ptr::eq(self.in_ptr, self.out_ptr) {
47            unsafe {
48                ptr::copy(self.in_ptr, self.out_ptr, 1);
49            }
50        }
51        unsafe { &mut *self.out_ptr }
52    }
53
54    /// Consume `self` and get mutable reference to the output value with lifetime `'out`.
55    #[inline(always)]
56    pub fn into_out(self) -> &'out mut T {
57        unsafe { &mut *self.out_ptr }
58    }
59
60    /// Convert `self` to a pair of raw input and output pointers.
61    #[inline(always)]
62    pub fn into_raw(self) -> (*const T, *mut T) {
63        (self.in_ptr, self.out_ptr)
64    }
65
66    /// Create `InOut` from raw input and output pointers.
67    ///
68    /// # Safety
69    /// Behavior is undefined if any of the following conditions are violated:
70    /// - `in_ptr` must point to a properly initialized value of type `T` and
71    ///   must be valid for reads.
72    /// - `out_ptr` must point to a properly initialized value of type `T` and
73    ///   must be valid for both reads and writes.
74    /// - `in_ptr` and `out_ptr` must be either equal or non-overlapping.
75    /// - If `in_ptr` and `out_ptr` are equal, then the memory referenced by
76    ///   them must not be accessed through any other pointer (not derived from
77    ///   the return value) for the duration of lifetime 'a. Both read and write
78    ///   accesses are forbidden.
79    /// - If `in_ptr` and `out_ptr` are not equal, then the memory referenced by
80    ///   `out_ptr` must not be accessed through any other pointer (not derived from
81    ///   the return value) for the duration of lifetime `'a`. Both read and write
82    ///   accesses are forbidden. The memory referenced by `in_ptr` must not be
83    ///   mutated for the duration of lifetime `'a`, except inside an `UnsafeCell`.
84    #[inline(always)]
85    pub unsafe fn from_raw(in_ptr: *const T, out_ptr: *mut T) -> InOut<'inp, 'out, T> {
86        Self {
87            in_ptr,
88            out_ptr,
89            _pd: PhantomData,
90        }
91    }
92}
93
94impl<T: Clone> InOut<'_, '_, T> {
95    /// Clone input value and return it.
96    #[inline(always)]
97    pub fn clone_in(&self) -> T {
98        unsafe { (*self.in_ptr).clone() }
99    }
100}
101
102impl<'a, T> From<&'a mut T> for InOut<'a, 'a, T> {
103    #[inline(always)]
104    fn from(val: &'a mut T) -> Self {
105        let p = val as *mut T;
106        Self {
107            in_ptr: p,
108            out_ptr: p,
109            _pd: PhantomData,
110        }
111    }
112}
113
114impl<'inp, 'out, T> From<(&'inp T, &'out mut T)> for InOut<'inp, 'out, T> {
115    #[inline(always)]
116    fn from((in_val, out_val): (&'inp T, &'out mut T)) -> Self {
117        Self {
118            in_ptr: in_val as *const T,
119            out_ptr: out_val as *mut T,
120            _pd: Default::default(),
121        }
122    }
123}
124
125impl<'inp, 'out, T, N: ArraySize> InOut<'inp, 'out, Array<T, N>> {
126    /// Returns `InOut` for the given position.
127    ///
128    /// # Panics
129    /// If `pos` greater or equal to array length.
130    #[inline(always)]
131    pub fn get(&mut self, pos: usize) -> InOut<'_, '_, T> {
132        assert!(pos < N::USIZE);
133        unsafe {
134            InOut {
135                in_ptr: (self.in_ptr as *const T).add(pos),
136                out_ptr: (self.out_ptr as *mut T).add(pos),
137                _pd: PhantomData,
138            }
139        }
140    }
141
142    /// Convert `InOut` array to `InOutBuf`.
143    #[inline(always)]
144    pub fn into_buf(self) -> InOutBuf<'inp, 'out, T> {
145        InOutBuf {
146            in_ptr: self.in_ptr as *const T,
147            out_ptr: self.out_ptr as *mut T,
148            len: N::USIZE,
149            _pd: PhantomData,
150        }
151    }
152}
153
154impl<'inp, 'out, T, N, M> From<InOut<'inp, 'out, Array<T, Prod<N, M>>>>
155    for Array<InOut<'inp, 'out, Array<T, N>>, M>
156where
157    N: ArraySize,
158    M: ArraySize,
159    N: Mul<M>,
160    Prod<N, M>: ArraySize,
161{
162    fn from(buf: InOut<'inp, 'out, Array<T, Prod<N, M>>>) -> Self {
163        let in_ptr: *const Array<T, N> = buf.in_ptr.cast();
164        let out_ptr: *mut Array<T, N> = buf.out_ptr.cast();
165
166        Array::from_fn(|i| unsafe {
167            InOut {
168                in_ptr: in_ptr.add(i),
169                out_ptr: out_ptr.add(i),
170                _pd: PhantomData,
171            }
172        })
173    }
174}
175
176impl<N: ArraySize> InOut<'_, '_, Array<u8, N>> {
177    /// XOR `data` with values behind the input slice and write
178    /// result to the output slice.
179    ///
180    /// # Panics
181    /// If `data` length is not equal to the buffer length.
182    #[inline(always)]
183    #[allow(clippy::needless_range_loop)]
184    pub fn xor_in2out(&mut self, data: &Array<u8, N>) {
185        unsafe {
186            let input = ptr::read(self.in_ptr);
187            let mut temp = Array::<u8, N>::default();
188            for i in 0..N::USIZE {
189                temp[i] = input[i] ^ data[i];
190            }
191            ptr::write(self.out_ptr, temp);
192        }
193    }
194}
195
196impl<N, M> InOut<'_, '_, Array<Array<u8, N>, M>>
197where
198    N: ArraySize,
199    M: ArraySize,
200{
201    /// XOR `data` with values behind the input slice and write
202    /// result to the output slice.
203    ///
204    /// # Panics
205    /// If `data` length is not equal to the buffer length.
206    #[inline(always)]
207    #[allow(clippy::needless_range_loop)]
208    pub fn xor_in2out(&mut self, data: &Array<Array<u8, N>, M>) {
209        unsafe {
210            let input = ptr::read(self.in_ptr);
211            let mut temp = Array::<Array<u8, N>, M>::default();
212            for i in 0..M::USIZE {
213                for j in 0..N::USIZE {
214                    temp[i][j] = input[i][j] ^ data[i][j];
215                }
216            }
217            ptr::write(self.out_ptr, temp);
218        }
219    }
220}