use crate::{Entity, MatMut, MatRef};
use assert2::{assert, debug_assert};
use core::mem::MaybeUninit;
use reborrow::*;
use seal::Seal;
/// Private home of the sealing trait.
///
/// `Seal` is a supertrait of `Mat` (declared later in this file). Because this module is not
/// `pub`, code that cannot name `seal::Seal` cannot add new `Mat` implementations.
mod seal {
    /// Marker trait with no items; implemented in this file for `MatRef` and `MatMut`.
    pub trait Seal {}
}
/// Diagonal-handling flag passed to triangular traversals such as
/// `for_each_triangular_lower` (see the test at the bottom of this file).
///
/// NOTE(review): the traversal code is generated into `zip.rs` and is not visible here; the
/// variant descriptions below are inferred from the names — confirm against the generator.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum Diag {
    /// Presumably: diagonal elements are not visited.
    Skip,
    /// Presumably: diagonal elements are visited along with the triangular part.
    Include,
}
/// Read-only handle to a single element of entity type `E`.
///
/// The element is represented as a group (`E::Group`) of shared references, one per unit of
/// `E`. The units are typed as `MaybeUninit`, so reading them requires `unsafe`; see
/// `Read::read`.
pub struct Read<'a, E: Entity> {
    // One shared reference per unit making up the element.
    ptr: E::Group<&'a MaybeUninit<E::Unit>>,
}
/// Read-write handle to a single element of entity type `E`.
///
/// Same layout as `Read`, but the group holds exclusive (`&mut`) references so the element
/// can be overwritten through `ReadWrite::write`.
pub struct ReadWrite<'a, E: Entity> {
    // One exclusive reference per unit making up the element.
    ptr: E::Group<&'a mut MaybeUninit<E::Unit>>,
}
impl<E: Entity> Read<'_, E> {
    /// Returns a copy of the referenced element.
    ///
    /// Every unit is cloned out of its slot and the units are reassembled into an `E` with
    /// `E::from_units`. The referenced storage is left untouched.
    #[inline(always)]
    pub fn read(&self) -> E {
        E::from_units(E::map(
            E::as_ref(&self.ptr),
            // Clone through a reference instead of `assume_init_read`: the latter produces a
            // bitwise duplicate even for unit types that are not `Copy`, which would leave two
            // owners of the same value. Cloning is identical for `Copy` units and matches what
            // `ReadWrite::read` does.
            //
            // SAFETY: assumes the slot behind every `Read` is initialized — handles are only
            // constructed in this file from pointers into existing matrix storage; TODO confirm
            // that no caller builds a view over uninitialized memory.
            #[inline(always)]
            |ptr| unsafe { ptr.assume_init_ref().clone() },
        ))
    }
}
impl<E: Entity> ReadWrite<'_, E> {
    /// Returns a copy of the referenced element.
    ///
    /// Each unit is cloned out of its slot; the clones are then reassembled into an `E`.
    #[inline(always)]
    pub fn read(&self) -> E {
        let units = E::map(
            E::as_ref(&self.ptr),
            // SAFETY: assumes the slot is initialized — TODO confirm for every producer of
            // `ReadWrite` handles.
            #[inline(always)]
            |slot| unsafe { slot.assume_init_ref().clone() },
        );
        E::from_units(units)
    }

    /// Overwrites the referenced element with `value`.
    ///
    /// `value` is split into units and each unit is assigned to the matching slot. Plain
    /// assignment is used, so the previous content of each slot is dropped first; the slot is
    /// therefore treated as already initialized.
    #[inline(always)]
    pub fn write(&mut self, value: E) {
        let units = E::into_units(value);
        let paired = E::zip(E::as_mut(&mut self.ptr), units);
        E::map(
            paired,
            // SAFETY: same initialization assumption as in `read`.
            #[inline(always)]
            |(slot, unit)| unsafe { *slot.assume_init_mut() = unit },
        );
    }
}
/// Sealed interface over the matrix view types (`MatRef`, `MatMut`) implemented in this file.
///
/// NOTE(review): presumably consumed by the `Zip` code generated into `zip.rs`, which is not
/// visible here — confirm.
///
/// `'short` is the lifetime of the borrow taken by [`Mat::get`] / [`Mat::get_column_slice`]
/// and appears in the `Item` type of both implementations. The defaulted
/// `Outlives = &'short Self` parameter is never used by any method; NOTE(review): this looks
/// like the implied-bounds trick for making `Self: 'short` available — confirm.
pub trait Mat<'short, Outlives = &'short Self>: Seal {
    /// Handle to a single element (`Read` for `MatRef`, `ReadWrite` for `MatMut`).
    type Item;
    /// Group of unit slices returned by [`Mat::get_column_slice`].
    type RawSlice;
    /// Consumes the view and returns a view of the same type; both implementations in this
    /// file delegate to the same-named method of the concrete view type.
    fn transpose(self) -> Self;
    /// Consumes the view and returns a view of the same type; delegates like `transpose`.
    fn reverse_rows(self) -> Self;
    /// Consumes the view and returns a view of the same type; delegates like `transpose`.
    fn reverse_cols(self) -> Self;
    /// Row count as reported by the underlying view.
    fn nrows(&self) -> usize;
    /// Column count as reported by the underlying view.
    fn ncols(&self) -> usize;
    /// Row stride as reported by the underlying view.
    fn row_stride(&self) -> isize;
    /// Column stride as reported by the underlying view.
    fn col_stride(&self) -> isize;
    /// Returns a handle to the element at position `(i, j)`.
    ///
    /// # Safety
    ///
    /// The implementations in this file dereference a raw pointer obtained from
    /// `ptr_inbounds_at(i, j)` without any visible bounds check, so the caller is presumably
    /// responsible for `(i, j)` being inside the matrix — confirm against `ptr_inbounds_at`.
    unsafe fn get(&'short mut self, i: usize, j: usize) -> Self::Item;
    /// Returns slices of `n_elems` units each, starting at the unit pointers for `(i, j)`.
    ///
    /// # Safety
    ///
    /// The implementations in this file build the slices with `slice::from_raw_parts(_mut)`,
    /// so the `n_elems` units starting at `(i, j)` must be valid and contiguous in memory.
    /// NOTE(review): given the name, this presumably requires a unit row stride — confirm.
    unsafe fn get_column_slice(
        &'short mut self,
        i: usize,
        j: usize,
        n_elems: usize,
    ) -> Self::RawSlice;
    /// Returns a handle to the element at index `idx` of a slice obtained from
    /// [`Mat::get_column_slice`].
    ///
    /// # Safety
    ///
    /// The implementations in this file index with `get_unchecked(_mut)`, so `idx` must be
    /// smaller than the slice length.
    unsafe fn get_slice_elem(slice: &mut Self::RawSlice, idx: usize) -> Self::Item;
}
// `Seal` is a supertrait of `Mat`, so this empty impl is what allows `MatRef` to implement `Mat`.
impl<'a, E: Entity> Seal for MatRef<'a, E> {}
/// `Mat` implementation for read-only views: elements are handed out as `Read` handles and
/// column slices as groups of shared unit slices.
impl<'a, 'short, E: Entity> Mat<'short> for MatRef<'a, E> {
    type Item = Read<'short, E>;
    type RawSlice = E::Group<&'a [MaybeUninit<E::Unit>]>;

    // The wrappers below forward to the same-named methods that `MatRef` provides elsewhere
    // in the crate. The exact call syntax (`self.x()` / `(*self).x()`) is kept on purpose:
    // it decides whether the call resolves to that method or back to this trait method.

    /// Forwards to `MatRef`'s own `transpose`.
    #[inline(always)]
    fn transpose(self) -> Self {
        self.transpose()
    }

    /// Forwards to `MatRef`'s own `reverse_rows`.
    #[inline(always)]
    fn reverse_rows(self) -> Self {
        self.reverse_rows()
    }

    /// Forwards to `MatRef`'s own `reverse_cols`.
    #[inline(always)]
    fn reverse_cols(self) -> Self {
        self.reverse_cols()
    }

    /// Forwards to `MatRef`'s own `nrows`.
    #[inline(always)]
    fn nrows(&self) -> usize {
        (*self).nrows()
    }

    /// Forwards to `MatRef`'s own `ncols`.
    #[inline(always)]
    fn ncols(&self) -> usize {
        (*self).ncols()
    }

    /// Forwards to `MatRef`'s own `row_stride`.
    #[inline(always)]
    fn row_stride(&self) -> isize {
        (*self).row_stride()
    }

    /// Forwards to `MatRef`'s own `col_stride`.
    #[inline(always)]
    fn col_stride(&self) -> isize {
        (*self).col_stride()
    }

    /// Builds a `Read` handle over the units located at `(i, j)`.
    ///
    /// # Safety
    ///
    /// The unit pointers come from `ptr_inbounds_at(i, j)` and are dereferenced without any
    /// check here; presumably `(i, j)` must be inside the matrix — confirm against
    /// `ptr_inbounds_at`.
    #[inline(always)]
    unsafe fn get(&mut self, i: usize, j: usize) -> Self::Item {
        let unit_ptrs = self.ptr_inbounds_at(i, j);
        let ptr = E::map(
            unit_ptrs,
            // Reinterpret each unit pointer as a pointer to `MaybeUninit<Unit>` and borrow it.
            #[inline(always)]
            |unit| &*(unit as *const MaybeUninit<E::Unit>),
        );
        Read { ptr }
    }

    /// Builds shared slices of `n_elems` units starting at the unit pointers for `(i, j)`.
    ///
    /// # Safety
    ///
    /// The slices are created with `slice::from_raw_parts`, so `n_elems` units starting at
    /// `(i, j)` must be valid for reads and contiguous in memory.
    #[inline(always)]
    unsafe fn get_column_slice(&mut self, i: usize, j: usize, n_elems: usize) -> Self::RawSlice {
        let start = self.ptr_at(i, j);
        E::map(
            start,
            #[inline(always)]
            |base| core::slice::from_raw_parts(base as *const MaybeUninit<E::Unit>, n_elems),
        )
    }

    /// Builds a `Read` handle over element `idx` of a slice group.
    ///
    /// # Safety
    ///
    /// Uses `get_unchecked`, so `idx` must be smaller than the length of every slice in
    /// the group.
    #[inline(always)]
    unsafe fn get_slice_elem(slice: &mut Self::RawSlice, idx: usize) -> Self::Item {
        let ptr = E::map(
            E::as_mut(slice),
            #[inline(always)]
            |units| units.get_unchecked(idx),
        );
        Read { ptr }
    }
}
// `Seal` is a supertrait of `Mat`, so this empty impl is what allows `MatMut` to implement `Mat`.
impl<'a, E: Entity> Seal for MatMut<'a, E> {}
/// `Mat` implementation for mutable views: elements are handed out as `ReadWrite` handles and
/// column slices as groups of exclusive unit slices.
impl<'a, 'short, E: Entity> Mat<'short> for MatMut<'a, E> {
    type Item = ReadWrite<'short, E>;
    type RawSlice = E::Group<&'a mut [MaybeUninit<E::Unit>]>;

    // The wrappers below forward to the same-named methods that `MatMut` provides elsewhere
    // in the crate. The exact call syntax (`self.x()` / `(*self).x()`) is kept on purpose:
    // it decides whether the call resolves to that method or back to this trait method.

    /// Forwards to `MatMut`'s own `transpose`.
    #[inline(always)]
    fn transpose(self) -> Self {
        self.transpose()
    }

    /// Forwards to `MatMut`'s own `reverse_rows`.
    #[inline(always)]
    fn reverse_rows(self) -> Self {
        self.reverse_rows()
    }

    /// Forwards to `MatMut`'s own `reverse_cols`.
    #[inline(always)]
    fn reverse_cols(self) -> Self {
        self.reverse_cols()
    }

    /// Forwards to `MatMut`'s own `nrows`.
    #[inline(always)]
    fn nrows(&self) -> usize {
        (*self).nrows()
    }

    /// Forwards to `MatMut`'s own `ncols`.
    #[inline(always)]
    fn ncols(&self) -> usize {
        (*self).ncols()
    }

    /// Forwards to `MatMut`'s own `row_stride`.
    #[inline(always)]
    fn row_stride(&self) -> isize {
        (*self).row_stride()
    }

    /// Forwards to `MatMut`'s own `col_stride`.
    #[inline(always)]
    fn col_stride(&self) -> isize {
        (*self).col_stride()
    }

    /// Builds a `ReadWrite` handle over the units located at `(i, j)`.
    ///
    /// # Safety
    ///
    /// The unit pointers come from `ptr_inbounds_at(i, j)` on a reborrow of the view and are
    /// turned into `&mut` references without any check here; presumably `(i, j)` must be
    /// inside the matrix and no other live reference may alias the element — confirm against
    /// `ptr_inbounds_at` and the callers.
    #[inline(always)]
    unsafe fn get(&mut self, i: usize, j: usize) -> Self::Item {
        let unit_ptrs = self.rb_mut().ptr_inbounds_at(i, j);
        let ptr = E::map(
            unit_ptrs,
            // Reinterpret each unit pointer as a pointer to `MaybeUninit<Unit>` and borrow it
            // exclusively.
            #[inline(always)]
            |unit| &mut *(unit as *mut MaybeUninit<E::Unit>),
        );
        ReadWrite { ptr }
    }

    /// Builds exclusive slices of `n_elems` units starting at the unit pointers for `(i, j)`.
    ///
    /// # Safety
    ///
    /// The slices are created with `slice::from_raw_parts_mut`, so `n_elems` units starting
    /// at `(i, j)` must be valid for reads and writes, contiguous in memory, and not aliased
    /// by another live reference.
    #[inline(always)]
    unsafe fn get_column_slice(&mut self, i: usize, j: usize, n_elems: usize) -> Self::RawSlice {
        let start = self.rb_mut().ptr_at(i, j);
        E::map(
            start,
            #[inline(always)]
            |base| core::slice::from_raw_parts_mut(base as *mut MaybeUninit<E::Unit>, n_elems),
        )
    }

    /// Builds a `ReadWrite` handle over element `idx` of a slice group.
    ///
    /// # Safety
    ///
    /// Uses `get_unchecked_mut`, so `idx` must be smaller than the length of every slice in
    /// the group. The returned reference is detached from the borrow of `slice` through a raw
    /// pointer round-trip, so the caller must not create overlapping handles to the same
    /// element.
    #[inline(always)]
    unsafe fn get_slice_elem(slice: &mut Self::RawSlice, idx: usize) -> Self::Item {
        let ptr = E::map(
            E::as_mut(slice),
            #[inline(always)]
            |units| &mut *(units.get_unchecked_mut(idx) as *mut MaybeUninit<E::Unit>),
        );
        ReadWrite { ptr }
    }
}
/// Wrapper around a tuple of zipped operands.
///
/// NOTE(review): presumably `Tuple` is the tuple of matrix views assembled by the `zipped!`
/// macro, and the iteration methods (e.g. `for_each_triangular_lower`, used in the test
/// below) are provided by the generated `zip.rs` included after this struct — confirm.
pub struct Zip<Tuple> {
    // The wrapped tuple; visible to the rest of the crate.
    pub(crate) tuple: Tuple,
}
// Textually includes `zip.rs` from the build output directory (`OUT_DIR`), i.e. code that is
// produced at build time rather than checked in next to this file.
include!(concat!(env!("OUT_DIR"), "/zip.rs"));
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{zipped, ComplexField, Mat};
    use assert2::assert;

    /// Copies a matrix of ones into a matrix of zeros over the lower triangle, once on plain
    /// `m x n` matrices (the reference) and once through views that are optionally transposed
    /// and/or row-reversed, then checks that the destination view equals the reference.
    ///
    /// Only `reverse_rows` is exercised here; `reverse_cols` is not.
    #[test]
    fn test_zip() {
        let flags = [false, true];
        for (m, n) in [(2, 2), (4, 2), (2, 4)] {
            for rev_dst in flags {
                for rev_src in flags {
                    for transpose_dst in flags {
                        for transpose_src in flags {
                            for diag in [Diag::Include, Diag::Skip] {
                                // Storage is allocated with swapped dimensions when the view
                                // will be transposed, so every view ends up `m x n`.
                                let (dst_rows, dst_cols) =
                                    if transpose_dst { (n, m) } else { (m, n) };
                                let (src_rows, src_cols) =
                                    if transpose_src { (n, m) } else { (m, n) };
                                let mut dst_storage =
                                    Mat::with_dims(dst_rows, dst_cols, |_, _| f64::zero());
                                let src_storage =
                                    Mat::with_dims(src_rows, src_cols, |_, _| f64::one());

                                // Reference result computed on untransformed matrices.
                                let mut expected = Mat::with_dims(m, n, |_, _| f64::zero());
                                let expected_src = Mat::with_dims(m, n, |_, _| f64::one());
                                zipped!(expected.as_mut(), expected_src.as_ref())
                                    .for_each_triangular_lower(diag, |mut out, inp| {
                                        out.write(inp.read())
                                    });

                                // Apply the requested view transformations.
                                let mut dst = dst_storage.as_mut();
                                if transpose_dst {
                                    dst = dst.transpose();
                                }
                                if rev_dst {
                                    dst = dst.reverse_rows();
                                }
                                let mut src = src_storage.as_ref();
                                if transpose_src {
                                    src = src.transpose();
                                }
                                if rev_src {
                                    src = src.reverse_rows();
                                }

                                zipped!(dst.rb_mut(), src)
                                    .for_each_triangular_lower(diag, |mut out, inp| {
                                        out.write(inp.read())
                                    });
                                assert!(dst.rb() == expected.as_ref());
                            }
                        }
                    }
                }
            }
        }
    }
}