1use crate::{
2 InOut,
3 errors::{IntoArrayError, NotEqualError},
4};
5use core::{marker::PhantomData, slice};
6use hybrid_array::{Array, ArraySize};
7
/// Pair of equally sized buffers: one read as input, one written as output.
///
/// The input and output pointers may refer to the same memory (the
/// `From<&mut [T]>` impl sets both to the same pointer). For this reason, element
/// access goes through raw pointers instead of simultaneously held Rust
/// references.
pub struct InOutBuf<'inp, 'out, T> {
    /// Start of the input buffer (`len` elements).
    pub(crate) in_ptr: *const T,
    /// Start of the output buffer (`len` elements); may be equal to `in_ptr`.
    pub(crate) out_ptr: *mut T,
    /// Number of elements in each of the two buffers.
    pub(crate) len: usize,
    /// Ties the raw pointers to the shared input borrow (`'inp`) and the
    /// mutable output borrow (`'out`).
    pub(crate) _pd: PhantomData<(&'inp T, &'out mut T)>,
}
17
18impl<'a, T> From<&'a mut [T]> for InOutBuf<'a, 'a, T> {
19 #[inline(always)]
20 fn from(buf: &'a mut [T]) -> Self {
21 let p = buf.as_mut_ptr();
22 Self {
23 in_ptr: p,
24 out_ptr: p,
25 len: buf.len(),
26 _pd: PhantomData,
27 }
28 }
29}
30
31impl<'a, T> InOutBuf<'a, 'a, T> {
32 #[inline(always)]
34 pub fn from_mut(val: &'a mut T) -> InOutBuf<'a, 'a, T> {
35 let p = val as *mut T;
36 Self {
37 in_ptr: p,
38 out_ptr: p,
39 len: 1,
40 _pd: PhantomData,
41 }
42 }
43}
44
45impl<'inp, 'out, T> IntoIterator for InOutBuf<'inp, 'out, T> {
46 type Item = InOut<'inp, 'out, T>;
47 type IntoIter = InOutBufIter<'inp, 'out, T>;
48
49 #[inline(always)]
50 fn into_iter(self) -> Self::IntoIter {
51 InOutBufIter { buf: self, pos: 0 }
52 }
53}
54
55impl<'inp, 'out, T> InOutBuf<'inp, 'out, T> {
56 #[inline(always)]
58 pub fn from_ref_mut(in_val: &'inp T, out_val: &'out mut T) -> Self {
59 Self {
60 in_ptr: in_val as *const T,
61 out_ptr: out_val as *mut T,
62 len: 1,
63 _pd: PhantomData,
64 }
65 }
66
67 #[inline(always)]
71 pub fn new(in_buf: &'inp [T], out_buf: &'out mut [T]) -> Result<Self, NotEqualError> {
72 if in_buf.len() != out_buf.len() {
73 Err(NotEqualError)
74 } else {
75 Ok(Self {
76 in_ptr: in_buf.as_ptr(),
77 out_ptr: out_buf.as_mut_ptr(),
78 len: in_buf.len(),
79 _pd: Default::default(),
80 })
81 }
82 }
83
84 #[inline(always)]
86 pub fn len(&self) -> usize {
87 self.len
88 }
89
90 #[inline(always)]
92 pub fn is_empty(&self) -> bool {
93 self.len == 0
94 }
95
96 #[inline(always)]
101 pub fn get(&mut self, pos: usize) -> InOut<'_, '_, T> {
102 assert!(pos < self.len);
103 unsafe {
104 InOut {
105 in_ptr: self.in_ptr.add(pos),
106 out_ptr: self.out_ptr.add(pos),
107 _pd: PhantomData,
108 }
109 }
110 }
111
112 #[inline(always)]
114 pub fn get_in(&self) -> &[T] {
115 unsafe { slice::from_raw_parts(self.in_ptr, self.len) }
116 }
117
118 #[inline(always)]
120 pub fn get_out(&mut self) -> &mut [T] {
121 unsafe { slice::from_raw_parts_mut(self.out_ptr, self.len) }
122 }
123
124 pub fn into_out_with_copied_in(self) -> &'out mut [T]
131 where
132 T: Copy,
133 {
134 if !core::ptr::eq(self.in_ptr, self.out_ptr) {
135 unsafe {
136 core::ptr::copy(self.in_ptr, self.out_ptr, self.len);
137 }
138 }
139 unsafe { slice::from_raw_parts_mut(self.out_ptr, self.len) }
140 }
141
142 #[inline(always)]
144 pub fn into_out(self) -> &'out mut [T] {
145 unsafe { slice::from_raw_parts_mut(self.out_ptr, self.len) }
146 }
147
148 #[inline(always)]
150 pub fn into_raw(self) -> (*const T, *mut T) {
151 (self.in_ptr, self.out_ptr)
152 }
153
154 #[inline(always)]
156 pub fn reborrow(&mut self) -> InOutBuf<'_, '_, T> {
157 Self {
158 in_ptr: self.in_ptr,
159 out_ptr: self.out_ptr,
160 len: self.len,
161 _pd: PhantomData,
162 }
163 }
164
165 #[inline(always)]
186 pub unsafe fn from_raw(
187 in_ptr: *const T,
188 out_ptr: *mut T,
189 len: usize,
190 ) -> InOutBuf<'inp, 'out, T> {
191 Self {
192 in_ptr,
193 out_ptr,
194 len,
195 _pd: PhantomData,
196 }
197 }
198
199 #[inline(always)]
209 pub fn split_at(self, mid: usize) -> (InOutBuf<'inp, 'out, T>, InOutBuf<'inp, 'out, T>) {
210 assert!(mid <= self.len);
211 let (tail_in_ptr, tail_out_ptr) = unsafe { (self.in_ptr.add(mid), self.out_ptr.add(mid)) };
212 (
213 InOutBuf {
214 in_ptr: self.in_ptr,
215 out_ptr: self.out_ptr,
216 len: mid,
217 _pd: PhantomData,
218 },
219 InOutBuf {
220 in_ptr: tail_in_ptr,
221 out_ptr: tail_out_ptr,
222 len: self.len() - mid,
223 _pd: PhantomData,
224 },
225 )
226 }
227
228 #[inline(always)]
230 pub fn into_chunks<N: ArraySize>(
231 self,
232 ) -> (InOutBuf<'inp, 'out, Array<T, N>>, InOutBuf<'inp, 'out, T>) {
233 let chunks = self.len() / N::USIZE;
234 let tail_pos = N::USIZE * chunks;
235 let tail_len = self.len() - tail_pos;
236 unsafe {
237 let chunks = InOutBuf {
238 in_ptr: self.in_ptr as *const Array<T, N>,
239 out_ptr: self.out_ptr as *mut Array<T, N>,
240 len: chunks,
241 _pd: PhantomData,
242 };
243 let tail = InOutBuf {
244 in_ptr: self.in_ptr.add(tail_pos),
245 out_ptr: self.out_ptr.add(tail_pos),
246 len: tail_len,
247 _pd: PhantomData,
248 };
249 (chunks, tail)
250 }
251 }
252}
253
254impl InOutBuf<'_, '_, u8> {
255 #[inline(always)]
261 #[allow(clippy::needless_range_loop)]
262 pub fn xor_in2out(&mut self, data: &[u8]) {
263 assert_eq!(self.len(), data.len());
264 unsafe {
265 for i in 0..data.len() {
266 let in_ptr = self.in_ptr.add(i);
267 let out_ptr = self.out_ptr.add(i);
268 *out_ptr = *in_ptr ^ data[i];
269 }
270 }
271 }
272}
273
274impl<'inp, 'out, T, N> TryInto<InOut<'inp, 'out, Array<T, N>>> for InOutBuf<'inp, 'out, T>
275where
276 N: ArraySize,
277{
278 type Error = IntoArrayError;
279
280 #[inline(always)]
281 fn try_into(self) -> Result<InOut<'inp, 'out, Array<T, N>>, Self::Error> {
282 if self.len() == N::USIZE {
283 Ok(InOut {
284 in_ptr: self.in_ptr as *const _,
285 out_ptr: self.out_ptr as *mut _,
286 _pd: PhantomData,
287 })
288 } else {
289 Err(IntoArrayError)
290 }
291 }
292}
293
/// Iterator over the elements of an [`InOutBuf`], yielding one [`InOut`]
/// pair per position.
pub struct InOutBufIter<'inp, 'out, T> {
    /// Buffer being iterated; consumed by `IntoIterator::into_iter`.
    buf: InOutBuf<'inp, 'out, T>,
    /// Index of the next element to yield; equals `buf.len()` once exhausted.
    pos: usize,
}
299
300impl<'inp, 'out, T> Iterator for InOutBufIter<'inp, 'out, T> {
301 type Item = InOut<'inp, 'out, T>;
302
303 #[inline(always)]
304 fn next(&mut self) -> Option<Self::Item> {
305 if self.buf.len() == self.pos {
306 return None;
307 }
308 let res = unsafe {
309 InOut {
310 in_ptr: self.buf.in_ptr.add(self.pos),
311 out_ptr: self.buf.out_ptr.add(self.pos),
312 _pd: PhantomData,
313 }
314 };
315 self.pos += 1;
316 Some(res)
317 }
318}