1use crate::InOutBuf;
2use core::{marker::PhantomData, ptr};
3use hybrid_array::{Array, ArraySize};
4
/// Custom pointer type bundling one immutable (input) pointer and one mutable
/// (output) pointer to a value of type `T`.
///
/// The safe constructors in this module (`From<&mut T>` and
/// `From<(&T, &mut T)>`) produce pointers that are either equal (in-place
/// processing, from a single `&mut T`) or, coming from a `&T` and a `&mut T`
/// held at the same time, point to non-overlapping values.
pub struct InOut<'inp, 'out, T> {
    /// Pointer to the input value; the methods in this module only read through it.
    pub(crate) in_ptr: *const T,
    /// Pointer to the output value; may be written through.
    pub(crate) out_ptr: *mut T,
    /// Ties the raw pointers to the input lifetime `'inp` and the output
    /// lifetime `'out`, as if the struct held a `&'inp T` and a `&'out mut T`.
    pub(crate) _pd: PhantomData<(&'inp T, &'out mut T)>,
}
12
13impl<'inp, 'out, T> InOut<'inp, 'out, T> {
14 #[inline(always)]
16 pub fn reborrow<'a>(&'a mut self) -> InOut<'a, 'a, T> {
17 Self {
18 in_ptr: self.in_ptr,
19 out_ptr: self.out_ptr,
20 _pd: PhantomData,
21 }
22 }
23
24 #[inline(always)]
26 pub fn get_in<'a>(&'a self) -> &'a T {
27 unsafe { &*self.in_ptr }
28 }
29
30 #[inline(always)]
32 pub fn get_out<'a>(&'a mut self) -> &'a mut T {
33 unsafe { &mut *self.out_ptr }
34 }
35
36 #[inline(always)]
38 pub fn into_out(self) -> &'out mut T {
39 unsafe { &mut *self.out_ptr }
40 }
41
42 #[inline(always)]
44 pub fn into_raw(self) -> (*const T, *mut T) {
45 (self.in_ptr, self.out_ptr)
46 }
47
48 #[inline(always)]
67 pub unsafe fn from_raw(in_ptr: *const T, out_ptr: *mut T) -> InOut<'inp, 'out, T> {
68 Self {
69 in_ptr,
70 out_ptr,
71 _pd: PhantomData,
72 }
73 }
74}
75
76impl<'inp, 'out, T: Clone> InOut<'inp, 'out, T> {
77 #[inline(always)]
79 pub fn clone_in(&self) -> T {
80 unsafe { (*self.in_ptr).clone() }
81 }
82}
83
84impl<'a, T> From<&'a mut T> for InOut<'a, 'a, T> {
85 #[inline(always)]
86 fn from(val: &'a mut T) -> Self {
87 let p = val as *mut T;
88 Self {
89 in_ptr: p,
90 out_ptr: p,
91 _pd: PhantomData,
92 }
93 }
94}
95
96impl<'inp, 'out, T> From<(&'inp T, &'out mut T)> for InOut<'inp, 'out, T> {
97 #[inline(always)]
98 fn from((in_val, out_val): (&'inp T, &'out mut T)) -> Self {
99 Self {
100 in_ptr: in_val as *const T,
101 out_ptr: out_val as *mut T,
102 _pd: Default::default(),
103 }
104 }
105}
106
107impl<'inp, 'out, T, N: ArraySize> InOut<'inp, 'out, Array<T, N>> {
108 #[inline(always)]
113 pub fn get<'a>(&'a mut self, pos: usize) -> InOut<'a, 'a, T> {
114 assert!(pos < N::USIZE);
115 unsafe {
116 InOut {
117 in_ptr: (self.in_ptr as *const T).add(pos),
118 out_ptr: (self.out_ptr as *mut T).add(pos),
119 _pd: PhantomData,
120 }
121 }
122 }
123
124 #[inline(always)]
126 pub fn into_buf(self) -> InOutBuf<'inp, 'out, T> {
127 InOutBuf {
128 in_ptr: self.in_ptr as *const T,
129 out_ptr: self.out_ptr as *mut T,
130 len: N::USIZE,
131 _pd: PhantomData,
132 }
133 }
134}
135
136impl<'inp, 'out, N: ArraySize> InOut<'inp, 'out, Array<u8, N>> {
137 #[inline(always)]
143 #[allow(clippy::needless_range_loop)]
144 pub fn xor_in2out(&mut self, data: &Array<u8, N>) {
145 unsafe {
146 let input = ptr::read(self.in_ptr);
147 let mut temp = Array::<u8, N>::default();
148 for i in 0..N::USIZE {
149 temp[i] = input[i] ^ data[i];
150 }
151 ptr::write(self.out_ptr, temp);
152 }
153 }
154}
155
156impl<'inp, 'out, N, M> InOut<'inp, 'out, Array<Array<u8, N>, M>>
157where
158 N: ArraySize,
159 M: ArraySize,
160{
161 #[inline(always)]
167 #[allow(clippy::needless_range_loop)]
168 pub fn xor_in2out(&mut self, data: &Array<Array<u8, N>, M>) {
169 unsafe {
170 let input = ptr::read(self.in_ptr);
171 let mut temp = Array::<Array<u8, N>, M>::default();
172 for i in 0..M::USIZE {
173 for j in 0..N::USIZE {
174 temp[i][j] = input[i][j] ^ data[i][j];
175 }
176 }
177 ptr::write(self.out_ptr, temp);
178 }
179 }
180}