1use crate::InOutBuf;
2use core::{marker::PhantomData, ops::Mul, ptr};
3use hybrid_array::{Array, ArraySize, typenum::Prod};
4
/// A pair of pointers to a `T`: an input pointer that is only read through and
/// an output pointer that is read and written through.
///
/// The two pointers may be identical (in-place processing, as built by the
/// `From<&mut T>` impl) or come from a shared and a unique reference (as built
/// by the `From<(&T, &mut T)>` impl), which by Rust's aliasing rules cannot
/// overlap.
pub struct InOut<'inp, 'out, T> {
    /// Pointer to the input value; in this file it is only ever read through.
    pub(crate) in_ptr: *const T,
    /// Pointer to the output value; handed out as `&mut T` by the accessors.
    pub(crate) out_ptr: *mut T,
    /// Zero-sized marker that ties the input to `'inp` as a shared borrow and
    /// the output to `'out` as a unique borrow (and fixes variance accordingly).
    pub(crate) _pd: PhantomData<(&'inp T, &'out mut T)>,
}
12
13impl<'inp, 'out, T> InOut<'inp, 'out, T> {
14 #[inline(always)]
16 pub fn reborrow(&mut self) -> InOut<'_, '_, T> {
17 Self {
18 in_ptr: self.in_ptr,
19 out_ptr: self.out_ptr,
20 _pd: PhantomData,
21 }
22 }
23
24 #[inline(always)]
26 pub fn get_in(&self) -> &T {
27 unsafe { &*self.in_ptr }
28 }
29
30 #[inline(always)]
32 pub fn get_out(&mut self) -> &mut T {
33 unsafe { &mut *self.out_ptr }
34 }
35
36 pub fn into_out_with_copied_in(self) -> &'out mut T
43 where
44 T: Copy,
45 {
46 if !core::ptr::eq(self.in_ptr, self.out_ptr) {
47 unsafe {
48 ptr::copy(self.in_ptr, self.out_ptr, 1);
49 }
50 }
51 unsafe { &mut *self.out_ptr }
52 }
53
54 #[inline(always)]
56 pub fn into_out(self) -> &'out mut T {
57 unsafe { &mut *self.out_ptr }
58 }
59
60 #[inline(always)]
62 pub fn into_raw(self) -> (*const T, *mut T) {
63 (self.in_ptr, self.out_ptr)
64 }
65
66 #[inline(always)]
85 pub unsafe fn from_raw(in_ptr: *const T, out_ptr: *mut T) -> InOut<'inp, 'out, T> {
86 Self {
87 in_ptr,
88 out_ptr,
89 _pd: PhantomData,
90 }
91 }
92}
93
94impl<T: Clone> InOut<'_, '_, T> {
95 #[inline(always)]
97 pub fn clone_in(&self) -> T {
98 unsafe { (*self.in_ptr).clone() }
99 }
100}
101
102impl<'a, T> From<&'a mut T> for InOut<'a, 'a, T> {
103 #[inline(always)]
104 fn from(val: &'a mut T) -> Self {
105 let p = val as *mut T;
106 Self {
107 in_ptr: p,
108 out_ptr: p,
109 _pd: PhantomData,
110 }
111 }
112}
113
114impl<'inp, 'out, T> From<(&'inp T, &'out mut T)> for InOut<'inp, 'out, T> {
115 #[inline(always)]
116 fn from((in_val, out_val): (&'inp T, &'out mut T)) -> Self {
117 Self {
118 in_ptr: in_val as *const T,
119 out_ptr: out_val as *mut T,
120 _pd: Default::default(),
121 }
122 }
123}
124
125impl<'inp, 'out, T, N: ArraySize> InOut<'inp, 'out, Array<T, N>> {
126 #[inline(always)]
131 pub fn get(&mut self, pos: usize) -> InOut<'_, '_, T> {
132 assert!(pos < N::USIZE);
133 unsafe {
134 InOut {
135 in_ptr: (self.in_ptr as *const T).add(pos),
136 out_ptr: (self.out_ptr as *mut T).add(pos),
137 _pd: PhantomData,
138 }
139 }
140 }
141
142 #[inline(always)]
144 pub fn into_buf(self) -> InOutBuf<'inp, 'out, T> {
145 InOutBuf {
146 in_ptr: self.in_ptr as *const T,
147 out_ptr: self.out_ptr as *mut T,
148 len: N::USIZE,
149 _pd: PhantomData,
150 }
151 }
152}
153
154impl<'inp, 'out, T, N, M> From<InOut<'inp, 'out, Array<T, Prod<N, M>>>>
155 for Array<InOut<'inp, 'out, Array<T, N>>, M>
156where
157 N: ArraySize,
158 M: ArraySize,
159 N: Mul<M>,
160 Prod<N, M>: ArraySize,
161{
162 fn from(buf: InOut<'inp, 'out, Array<T, Prod<N, M>>>) -> Self {
163 let in_ptr: *const Array<T, N> = buf.in_ptr.cast();
164 let out_ptr: *mut Array<T, N> = buf.out_ptr.cast();
165
166 Array::from_fn(|i| unsafe {
167 InOut {
168 in_ptr: in_ptr.add(i),
169 out_ptr: out_ptr.add(i),
170 _pd: PhantomData,
171 }
172 })
173 }
174}
175
176impl<N: ArraySize> InOut<'_, '_, Array<u8, N>> {
177 #[inline(always)]
183 #[allow(clippy::needless_range_loop)]
184 pub fn xor_in2out(&mut self, data: &Array<u8, N>) {
185 unsafe {
186 let input = ptr::read(self.in_ptr);
187 let mut temp = Array::<u8, N>::default();
188 for i in 0..N::USIZE {
189 temp[i] = input[i] ^ data[i];
190 }
191 ptr::write(self.out_ptr, temp);
192 }
193 }
194}
195
196impl<N, M> InOut<'_, '_, Array<Array<u8, N>, M>>
197where
198 N: ArraySize,
199 M: ArraySize,
200{
201 #[inline(always)]
207 #[allow(clippy::needless_range_loop)]
208 pub fn xor_in2out(&mut self, data: &Array<Array<u8, N>, M>) {
209 unsafe {
210 let input = ptr::read(self.in_ptr);
211 let mut temp = Array::<Array<u8, N>, M>::default();
212 for i in 0..M::USIZE {
213 for j in 0..N::USIZE {
214 temp[i][j] = input[i][j] ^ data[i][j];
215 }
216 }
217 ptr::write(self.out_ptr, temp);
218 }
219 }
220}