use crate::{
errors::{IntoArrayError, NotEqualError},
InOut,
};
use core::{marker::PhantomData, slice};
use generic_array::{ArrayLength, GenericArray};
/// Buffer of `len` elements that are read from an input location and written
/// to an output location.
///
/// The two locations are either the same memory (as built by
/// `From<&mut [T]>` and `from_mut`) or two separately borrowed buffers (as
/// built by `new` and `from_ref_mut`).
pub struct InOutBuf<'inp, 'out, T> {
    /// Start of the input elements; only ever read through.
    pub(crate) in_ptr: *const T,
    /// Start of the output elements; may equal `in_ptr` for in-place buffers.
    pub(crate) out_ptr: *mut T,
    /// Number of elements behind each of the two pointers.
    pub(crate) len: usize,
    /// Records the shared input borrow (`'inp`) and the mutable output borrow
    /// (`'out`) that the raw pointers stand in for.
    pub(crate) _pd: PhantomData<(&'inp T, &'out mut T)>,
}
impl<'a, T> From<&'a mut [T]> for InOutBuf<'a, 'a, T> {
    /// Build an in-place buffer over `buf`: the input and output pointers both
    /// refer to the start of the same slice, and the length is the slice length.
    #[inline(always)]
    fn from(buf: &'a mut [T]) -> Self {
        let len = buf.len();
        let start = buf.as_mut_ptr();
        Self {
            in_ptr: start as *const T,
            out_ptr: start,
            len,
            _pd: PhantomData,
        }
    }
}
impl<'a, T> InOutBuf<'a, 'a, T> {
#[inline(always)]
pub fn from_mut(val: &'a mut T) -> InOutBuf<'a, 'a, T> {
let p = val as *mut T;
Self {
in_ptr: p,
out_ptr: p,
len: 1,
_pd: PhantomData,
}
}
}
impl<'inp, 'out, T> IntoIterator for InOutBuf<'inp, 'out, T> {
    type Item = InOut<'inp, 'out, T>;
    type IntoIter = InOutBufIter<'inp, 'out, T>;

    /// Consume the buffer and return an iterator positioned at index 0 that
    /// yields an [`InOut`] for each element.
    #[inline(always)]
    fn into_iter(self) -> Self::IntoIter {
        let start = 0;
        InOutBufIter {
            pos: start,
            buf: self,
        }
    }
}
impl<'inp, 'out, T> InOutBuf<'inp, 'out, T> {
    /// Create a one-element buffer that reads from `in_val` and writes to
    /// `out_val`.
    #[inline(always)]
    pub fn from_ref_mut(in_val: &'inp T, out_val: &'out mut T) -> Self {
        Self {
            in_ptr: in_val as *const T,
            out_ptr: out_val as *mut T,
            len: 1,
            _pd: PhantomData,
        }
    }

    /// Create a buffer that reads from `in_buf` and writes to `out_buf`.
    ///
    /// # Errors
    /// Returns [`NotEqualError`] if the two slices have different lengths.
    #[inline(always)]
    pub fn new(in_buf: &'inp [T], out_buf: &'out mut [T]) -> Result<Self, NotEqualError> {
        if in_buf.len() != out_buf.len() {
            return Err(NotEqualError);
        }
        Ok(Self {
            in_ptr: in_buf.as_ptr(),
            out_ptr: out_buf.as_mut_ptr(),
            len: in_buf.len(),
            _pd: PhantomData,
        })
    }

    /// Number of elements in the buffer.
    #[inline(always)]
    pub fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` if the buffer contains no elements.
    #[inline(always)]
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Get an [`InOut`] for the element at index `pos`.
    ///
    /// The result mutably borrows `self` for as long as it lives.
    ///
    /// # Panics
    /// If `pos >= self.len()`.
    #[inline(always)]
    pub fn get(&mut self, pos: usize) -> InOut<'_, '_, T> {
        assert!(pos < self.len);
        // SAFETY: `pos < self.len`, so both offsets stay inside the `len`
        // elements the pointers refer to.
        unsafe {
            InOut {
                in_ptr: self.in_ptr.add(pos),
                out_ptr: self.out_ptr.add(pos),
                _pd: PhantomData,
            }
        }
    }

    /// View the input elements as a shared slice.
    #[inline(always)]
    pub fn get_in(&self) -> &[T] {
        // SAFETY: `in_ptr` is valid for reads of `len` elements (type
        // invariant). Taking `&self` prevents `get_out` (which needs
        // `&mut self`) from producing a mutable view of possibly aliased
        // memory while this slice is alive.
        unsafe { slice::from_raw_parts(self.in_ptr, self.len) }
    }

    /// View the output elements as a mutable slice.
    #[inline(always)]
    pub fn get_out(&mut self) -> &mut [T] {
        // SAFETY: `out_ptr` is valid for reads and writes of `len` elements
        // (type invariant), and `&mut self` makes this view exclusive.
        unsafe { slice::from_raw_parts_mut(self.out_ptr, self.len) }
    }

    /// Consume the buffer and return the output elements as a mutable slice
    /// that lives for the full `'out` lifetime.
    #[inline(always)]
    pub fn into_out(self) -> &'out mut [T] {
        // SAFETY: `out_ptr` is valid for `len` elements for `'out`, and `self`
        // is consumed, so no other access through this buffer remains.
        unsafe { slice::from_raw_parts_mut(self.out_ptr, self.len) }
    }

    /// Consume the buffer and return its raw `(input, output)` pointers.
    ///
    /// The length is not included; read it with [`len`](Self::len) first if
    /// it is needed.
    #[inline(always)]
    pub fn into_raw(self) -> (*const T, *mut T) {
        (self.in_ptr, self.out_ptr)
    }

    /// Reborrow the buffer, with both lifetimes shortened to the borrow of
    /// `self`.
    #[inline(always)]
    pub fn reborrow(&mut self) -> InOutBuf<'_, '_, T> {
        InOutBuf {
            in_ptr: self.in_ptr,
            out_ptr: self.out_ptr,
            len: self.len,
            _pd: PhantomData,
        }
    }

    /// Create a buffer from raw input/output pointers and an element count.
    ///
    /// # Safety
    /// For the chosen lifetimes `'inp` and `'out`, the caller must guarantee
    /// all of the following:
    /// - `in_ptr` is non-null, properly aligned, and valid for reads of `len`
    ///   initialized values of `T` for `'inp`.
    /// - `out_ptr` is non-null, properly aligned, and valid for reads and
    ///   writes of `len` initialized values of `T` for `'out`.
    /// - Non-null and aligned are required even when `len == 0`.
    /// - `in_ptr` and `out_ptr` are either equal or point to non-overlapping
    ///   memory. These are the only two configurations the safe constructors
    ///   produce.
    /// - While the buffer (or anything derived from it) is alive, the memory
    ///   behind `out_ptr` is not accessed through any pointer not derived
    ///   from the returned buffer. The same applies to `in_ptr` when the two
    ///   pointers are equal.
    /// - When the pointers are distinct, the memory behind `in_ptr` is not
    ///   mutated during that time.
    /// - `len * size_of::<T>()` does not exceed `isize::MAX`.
    #[inline(always)]
    pub unsafe fn from_raw(
        in_ptr: *const T,
        out_ptr: *mut T,
        len: usize,
    ) -> InOutBuf<'inp, 'out, T> {
        Self {
            in_ptr,
            out_ptr,
            len,
            _pd: PhantomData,
        }
    }

    /// Split the buffer at index `mid`.
    ///
    /// Returns `(head, tail)`. `head` covers elements `0..mid`; `tail` covers
    /// elements `mid..len`.
    ///
    /// # Panics
    /// If `mid > self.len()`.
    #[inline(always)]
    pub fn split_at(self, mid: usize) -> (InOutBuf<'inp, 'out, T>, InOutBuf<'inp, 'out, T>) {
        assert!(mid <= self.len);
        // SAFETY: `mid <= self.len`, so each offset pointer is at most
        // one-past-the-end of its buffer.
        let (tail_in_ptr, tail_out_ptr) =
            unsafe { (self.in_ptr.add(mid), self.out_ptr.add(mid)) };
        let head = InOutBuf {
            in_ptr: self.in_ptr,
            out_ptr: self.out_ptr,
            len: mid,
            _pd: PhantomData,
        };
        let tail = InOutBuf {
            in_ptr: tail_in_ptr,
            out_ptr: tail_out_ptr,
            len: self.len - mid,
            _pd: PhantomData,
        };
        (head, tail)
    }

    /// Partition the buffer into `N`-element arrays plus a remainder.
    ///
    /// Returns `(chunks, tail)`. `chunks` holds `len / N` arrays of `N`
    /// elements. `tail` holds the remaining `len % N` elements.
    ///
    /// # Panics
    /// If `N` is zero.
    #[inline(always)]
    pub fn into_chunks<N: ArrayLength<T>>(
        self,
    ) -> (
        InOutBuf<'inp, 'out, GenericArray<T, N>>,
        InOutBuf<'inp, 'out, T>,
    ) {
        // A zero-length chunk cannot partition anything; fail with a clear
        // message instead of an implicit divide-by-zero.
        assert_ne!(N::USIZE, 0, "chunk size `N` must be non-zero");
        let chunk_count = self.len / N::USIZE;
        let tail_pos = N::USIZE * chunk_count;
        let tail_len = self.len - tail_pos;

        // Viewing `chunk_count * N` contiguous `T`s as `chunk_count`
        // `GenericArray<T, N>`s relies on `GenericArray<T, N>` having the
        // layout of `[T; N]`.
        let chunks = InOutBuf {
            in_ptr: self.in_ptr.cast::<GenericArray<T, N>>(),
            out_ptr: self.out_ptr.cast::<GenericArray<T, N>>(),
            len: chunk_count,
            _pd: PhantomData,
        };
        // SAFETY: `tail_pos <= self.len`, so each offset pointer is at most
        // one-past-the-end of its buffer.
        let (tail_in_ptr, tail_out_ptr) =
            unsafe { (self.in_ptr.add(tail_pos), self.out_ptr.add(tail_pos)) };
        let tail = InOutBuf {
            in_ptr: tail_in_ptr,
            out_ptr: tail_out_ptr,
            len: tail_len,
            _pd: PhantomData,
        };
        (chunks, tail)
    }
}
impl<'inp, 'out> InOutBuf<'inp, 'out, u8> {
    /// XOR each input byte with the byte of `data` at the same index and
    /// store the result in the output buffer (`out[i] = in[i] ^ data[i]`).
    ///
    /// Input and output may be the same memory. For that reason each byte is
    /// accessed through raw pointers, one at a time, rather than through two
    /// simultaneously live slices.
    ///
    /// # Panics
    /// If `data.len() != self.len()`.
    #[inline(always)]
    pub fn xor_in2out(&mut self, data: &[u8]) {
        assert_eq!(self.len(), data.len());
        for (offset, &key) in data.iter().enumerate() {
            // SAFETY: `offset < data.len() == self.len()`, so both pointers
            // stay inside their buffers. The input byte is read before the
            // output byte is written, which is also correct when they alias.
            unsafe {
                let src = self.in_ptr.add(offset);
                let dst = self.out_ptr.add(offset);
                *dst = *src ^ key;
            }
        }
    }
}
impl<'inp, 'out, T, N> TryInto<InOut<'inp, 'out, GenericArray<T, N>>> for InOutBuf<'inp, 'out, T>
where
N: ArrayLength<T>,
{
type Error = IntoArrayError;
#[inline(always)]
fn try_into(self) -> Result<InOut<'inp, 'out, GenericArray<T, N>>, Self::Error> {
if self.len() == N::USIZE {
Ok(InOut {
in_ptr: self.in_ptr as *const _,
out_ptr: self.out_ptr as *mut _,
_pd: PhantomData,
})
} else {
Err(IntoArrayError)
}
}
}
/// Iterator over the elements of an [`InOutBuf`], created by its
/// `IntoIterator` impl; yields one [`InOut`] per element in index order.
pub struct InOutBufIter<'inp, 'out, T> {
    /// Buffer being iterated.
    buf: InOutBuf<'inp, 'out, T>,
    /// Index of the next element to yield; starts at 0 and stops advancing
    /// once it equals `buf.len()`.
    pos: usize,
}
impl<'inp, 'out, T> Iterator for InOutBufIter<'inp, 'out, T> {
    type Item = InOut<'inp, 'out, T>;

    /// Yield an [`InOut`] for the element at the current position and advance.
    #[inline(always)]
    fn next(&mut self) -> Option<Self::Item> {
        if self.pos >= self.buf.len() {
            return None;
        }
        // SAFETY: `self.pos < self.buf.len()` here, so both offsets stay
        // inside the `len` elements the buffer's pointers refer to.
        let item = unsafe {
            InOut {
                in_ptr: self.buf.in_ptr.add(self.pos),
                out_ptr: self.buf.out_ptr.add(self.pos),
                _pd: PhantomData,
            }
        };
        self.pos += 1;
        Some(item)
    }

    /// The number of remaining elements is known exactly, so report it.
    /// This lets consumers such as `collect` preallocate.
    #[inline(always)]
    fn size_hint(&self) -> (usize, Option<usize>) {
        // `pos` starts at 0 and `next` never advances it past `buf.len()`.
        let remaining = self.buf.len() - self.pos;
        (remaining, Some(remaining))
    }
}

/// `size_hint` is exact, so `len()` reports the number of remaining elements.
impl<'inp, 'out, T> ExactSizeIterator for InOutBufIter<'inp, 'out, T> {}