use crate::InOutBuf;
use core::{marker::PhantomData, ops::Mul, ptr};
use hybrid_array::{Array, ArraySize, typenum::Prod};
/// A pair of pointers to an input value and an output value of type `T`.
///
/// The two pointers may be equal (the "in-place" case, produced by the
/// `From<&mut T>` impl in this module) or refer to two different values
/// (produced by the `From<(&T, &mut T)>` impl). Within this module `in_ptr`
/// is only ever read through, while `out_ptr` is read and written.
pub struct InOut<'inp, 'out, T> {
    /// Pointer to the input value.
    pub(crate) in_ptr: *const T,
    /// Pointer to the output value; may be equal to `in_ptr`.
    pub(crate) out_ptr: *mut T,
    /// Makes the type behave, for borrow checking and variance, as if it held
    /// a shared borrow `&'inp T` and a mutable borrow `&'out mut T`.
    pub(crate) _pd: PhantomData<(&'inp T, &'out mut T)>,
}
impl<'inp, 'out, T> InOut<'inp, 'out, T> {
    /// Reborrow `self`, producing an `InOut` whose lifetimes are limited to the
    /// mutable borrow of `self`.
    ///
    /// This allows handing the pair to a function by value while keeping
    /// `self` usable afterwards.
    #[inline(always)]
    pub fn reborrow(&mut self) -> InOut<'_, '_, T> {
        Self {
            in_ptr: self.in_ptr,
            out_ptr: self.out_ptr,
            _pd: PhantomData,
        }
    }

    /// Get a shared reference to the input value.
    #[inline(always)]
    pub fn get_in(&self) -> &T {
        // SAFETY: `in_ptr` is valid for reads of `T` for `'inp` (established by
        // the safe constructors, required of callers of `from_raw`), and the
        // returned reference cannot outlive the borrow of `self`.
        unsafe { &*self.in_ptr }
    }

    /// Get a mutable reference to the output value.
    ///
    /// When the input and output pointers are equal, writes made through this
    /// reference are observed by later calls to [`InOut::get_in`].
    #[inline(always)]
    pub fn get_out(&mut self) -> &mut T {
        // SAFETY: `out_ptr` is valid for reads and writes of `T` for `'out`, and
        // taking `&mut self` ensures no other reference obtained from `self`
        // (e.g. one returned by `get_in`) is alive at the same time.
        unsafe { &mut *self.out_ptr }
    }

    /// Consume `self` and return a mutable reference to the output value after
    /// making it equal to the input value.
    ///
    /// If the input and output pointers are equal, nothing is copied.
    /// Otherwise the input value is copied into the output before the output
    /// reference is returned.
    pub fn into_out_with_copied_in(self) -> &'out mut T
    where
        T: Copy,
    {
        if !ptr::eq(self.in_ptr, self.out_ptr) {
            // SAFETY: `in_ptr` is valid for reads and `out_ptr` for writes of one
            // `T`. `ptr::copy` permits overlapping regions, and `T: Copy` means
            // overwriting the old output value cannot skip a destructor.
            unsafe {
                ptr::copy(self.in_ptr, self.out_ptr, 1);
            }
        }
        // SAFETY: `out_ptr` is valid for reads and writes for `'out`; `self` is
        // consumed, so no reference derived from a borrow of it is still alive.
        unsafe { &mut *self.out_ptr }
    }

    /// Consume `self` and return a mutable reference to the output value with
    /// the full lifetime `'out`.
    ///
    /// The output is returned as-is; the input value is *not* copied into it
    /// (see [`InOut::into_out_with_copied_in`] for that).
    #[inline(always)]
    pub fn into_out(self) -> &'out mut T {
        // SAFETY: `out_ptr` is valid for reads and writes for `'out`; `self` is
        // consumed, so no reference derived from a borrow of it is still alive.
        unsafe { &mut *self.out_ptr }
    }

    /// Consume `self` and return the raw `(input, output)` pointers.
    #[inline(always)]
    pub fn into_raw(self) -> (*const T, *mut T) {
        (self.in_ptr, self.out_ptr)
    }

    /// Create an `InOut` from raw input and output pointers.
    ///
    /// # Safety
    /// The caller chooses the lifetimes `'inp` and `'out` and must guarantee,
    /// for their whole duration:
    /// - `in_ptr` is non-null, properly aligned, points to an initialized `T`
    ///   and is valid for reads.
    /// - `out_ptr` is non-null, properly aligned, points to an initialized `T`
    ///   and is valid for reads and writes.
    /// - `in_ptr` and `out_ptr` are either equal or point to non-overlapping
    ///   memory (the only two configurations the safe constructors produce).
    /// - The memory behind `out_ptr` is not accessed through any pointer that
    ///   is not derived from the returned value, and the memory behind
    ///   `in_ptr` is not mutated except through `out_ptr` when the two are
    ///   equal.
    #[inline(always)]
    pub unsafe fn from_raw(in_ptr: *const T, out_ptr: *mut T) -> InOut<'inp, 'out, T> {
        Self {
            in_ptr,
            out_ptr,
            _pd: PhantomData,
        }
    }
}
impl<T: Clone> InOut<'_, '_, T> {
    /// Return an owned clone of the input value.
    ///
    /// Only the input is read; the output value is left untouched.
    #[inline(always)]
    pub fn clone_in(&self) -> T {
        // SAFETY: `in_ptr` is valid for reads of `T` while `self` is alive, and
        // the temporary reference does not escape this call.
        let input: &T = unsafe { &*self.in_ptr };
        T::clone(input)
    }
}
impl<'a, T> From<&'a mut T> for InOut<'a, 'a, T> {
#[inline(always)]
fn from(val: &'a mut T) -> Self {
let p = val as *mut T;
Self {
in_ptr: p,
out_ptr: p,
_pd: PhantomData,
}
}
}
impl<'inp, 'out, T> From<(&'inp T, &'out mut T)> for InOut<'inp, 'out, T> {
#[inline(always)]
fn from((in_val, out_val): (&'inp T, &'out mut T)) -> Self {
Self {
in_ptr: in_val as *const T,
out_ptr: out_val as *mut T,
_pd: Default::default(),
}
}
}
impl<'inp, 'out, T, N: ArraySize> InOut<'inp, 'out, Array<T, N>> {
    /// Get an `InOut` over the element at index `pos`, borrowing `self`
    /// mutably for the lifetime of the result.
    ///
    /// # Panics
    /// If `pos` is not less than the array length `N::USIZE`.
    #[inline(always)]
    pub fn get(&mut self, pos: usize) -> InOut<'_, '_, T> {
        assert!(pos < N::USIZE);
        let first_in = self.in_ptr.cast::<T>();
        let first_out = self.out_ptr.cast::<T>();
        // SAFETY: `pos < N::USIZE` was checked above, so (treating
        // `Array<T, N>` as `N` contiguous `T`s) both offsets stay inside the
        // arrays behind `in_ptr` and `out_ptr`.
        let (in_ptr, out_ptr) = unsafe { (first_in.add(pos), first_out.add(pos)) };
        InOut {
            in_ptr,
            out_ptr,
            _pd: PhantomData,
        }
    }

    /// Convert `self` into an `InOutBuf` spanning all `N::USIZE` elements of
    /// the input and output arrays.
    #[inline(always)]
    pub fn into_buf(self) -> InOutBuf<'inp, 'out, T> {
        let len = N::USIZE;
        InOutBuf {
            len,
            in_ptr: self.in_ptr.cast::<T>(),
            out_ptr: self.out_ptr.cast::<T>(),
            _pd: PhantomData,
        }
    }
}
impl<'inp, 'out, T, N, M> From<InOut<'inp, 'out, Array<T, Prod<N, M>>>>
    for Array<InOut<'inp, 'out, Array<T, N>>, M>
where
    N: ArraySize,
    M: ArraySize,
    N: Mul<M>,
    Prod<N, M>: ArraySize,
{
    /// Split an `InOut` over an `N * M`-element array into `M` `InOut`s, each
    /// covering one consecutive `N`-element chunk of the input and output.
    fn from(buf: InOut<'inp, 'out, Array<T, Prod<N, M>>>) -> Self {
        let chunks_in = buf.in_ptr.cast::<Array<T, N>>();
        let chunks_out = buf.out_ptr.cast::<Array<T, N>>();
        Array::from_fn(|idx| {
            // SAFETY: `from_fn` supplies `idx` in `0..M`, so chunk `idx` of `N`
            // elements lies within the `N * M`-element arrays behind `buf`.
            let (in_ptr, out_ptr) = unsafe { (chunks_in.add(idx), chunks_out.add(idx)) };
            InOut {
                in_ptr,
                out_ptr,
                _pd: PhantomData,
            }
        })
    }
}
impl<N: ArraySize> InOut<'_, '_, Array<u8, N>> {
    /// XOR the input array with `data` and write the result to the output array.
    ///
    /// The whole input is read before anything is written, so this also works
    /// when input and output are the same array.
    #[inline(always)]
    pub fn xor_in2out(&mut self, data: &Array<u8, N>) {
        // SAFETY: `in_ptr` is valid for reads of one `Array<u8, N>`; the value is
        // copied out, so the later write cannot disturb it.
        let mut block: Array<u8, N> = unsafe { ptr::read(self.in_ptr) };
        for (byte, mask) in block.iter_mut().zip(data.iter()) {
            *byte ^= *mask;
        }
        // SAFETY: `out_ptr` is valid for writes of one `Array<u8, N>`; byte
        // arrays have no destructor, so not dropping the old value leaks nothing.
        unsafe { ptr::write(self.out_ptr, block) };
    }
}
impl<N, M> InOut<'_, '_, Array<Array<u8, N>, M>>
where
    N: ArraySize,
    M: ArraySize,
{
    /// XOR each of the `M` input blocks with the matching block of `data` and
    /// write the result to the output.
    ///
    /// The whole input is read before anything is written, so this also works
    /// when input and output are the same array.
    #[inline(always)]
    pub fn xor_in2out(&mut self, data: &Array<Array<u8, N>, M>) {
        // SAFETY: `in_ptr` is valid for reads of the nested array; the value is
        // copied out, so the later write cannot disturb it.
        let mut blocks: Array<Array<u8, N>, M> = unsafe { ptr::read(self.in_ptr) };
        for (block, data_block) in blocks.iter_mut().zip(data.iter()) {
            for (byte, mask) in block.iter_mut().zip(data_block.iter()) {
                *byte ^= *mask;
            }
        }
        // SAFETY: `out_ptr` is valid for writes of the nested array; byte arrays
        // have no destructor, so not dropping the old value leaks nothing.
        unsafe { ptr::write(self.out_ptr, blocks) };
    }
}