use std::ops::{Index, IndexMut};
use std::slice;
use crate::addressing::{Order, address};
use crate::errors::ShapeError;
/// A mutable `N`-dimensional view over a borrowed flat slice.
///
/// Elements are accessed with `[usize; N]` indices, which are translated to a
/// flat offset by `address` using the stored `order` and `shape`.
#[derive(Debug)]
pub struct NdSliceMut<'s, T, const N: usize> {
/// Borrowed flat storage holding every element of the view.
slice: &'s mut [T],
/// Extent of each of the `N` dimensions; the constructor requires their
/// product to equal `slice.len()`.
shape: [usize; N],
/// Ordering handed to `address` when translating an index to an offset.
order: Order,
}
/// Result of building an [`NdSliceMut`]: the view on success, or a
/// [`ShapeError`] when construction is rejected.
type ConstructionResult<'s, T, const N: usize> = Result<NdSliceMut<'s, T, N>, ShapeError<'s, T, N>>;
impl<'s, T, const N: usize> NdSliceMut<'s, T, N> {
    /// Wraps `slice` as an `N`-dimensional mutable view with the given `shape`
    /// and `order`.
    ///
    /// # Errors
    ///
    /// Returns `ShapeError::new(slice, shape)` when the product of the extents
    /// in `shape` differs from `slice.len()`. A product that does not fit in a
    /// `usize` can never match a slice length and is rejected the same way
    /// instead of panicking (debug) or silently wrapping (release).
    pub fn new(slice: &'s mut [T], shape: [usize; N], order: Order) -> ConstructionResult<'s, T, N> {
        // A zero extent makes the whole product zero. Handle it first so that a
        // shape such as `[usize::MAX, 2, 0]` is not misreported as overflowing.
        let expected_len = if shape.contains(&0) {
            Some(0)
        } else {
            shape
                .iter()
                .try_fold(1usize, |acc, &extent| acc.checked_mul(extent))
        };

        if expected_len == Some(slice.len()) {
            Ok(Self { slice, shape, order })
        } else {
            Err(ShapeError::new(slice, shape))
        }
    }

    /// Builds a view over the `len` elements starting at `ptr`.
    ///
    /// The raw parts are turned into a slice with
    /// [`slice::from_raw_parts_mut`] and then validated by [`NdSliceMut::new`].
    ///
    /// # Safety
    ///
    /// `ptr` and `len` are forwarded unchanged to
    /// [`slice::from_raw_parts_mut`], so the caller must uphold every
    /// requirement of that function for the whole lifetime `'s` (valid,
    /// aligned, initialized memory for `len` elements that is not accessed
    /// through any other pointer while the view exists).
    ///
    /// # Errors
    ///
    /// Same as [`NdSliceMut::new`].
    pub unsafe fn from_ptr(ptr: *mut T, len: usize, shape: [usize; N], order: Order) -> ConstructionResult<'s, T, N> {
        // SAFETY: the caller promises that `ptr`/`len` satisfy the contract of
        // `slice::from_raw_parts_mut` for `'s` (see `# Safety` above).
        let slice = unsafe { slice::from_raw_parts_mut(ptr, len) };
        Self::new(slice, shape, order)
    }

    /// Shorthand for [`NdSliceMut::new`] with `Order::RowMajor`.
    ///
    /// # Errors
    ///
    /// Same as [`NdSliceMut::new`].
    pub fn new_row_ordered(slice: &'s mut [T], shape: [usize; N]) -> ConstructionResult<'s, T, N> {
        Self::new(slice, shape, Order::RowMajor)
    }

    /// Shorthand for [`NdSliceMut::from_ptr`] with `Order::RowMajor`.
    ///
    /// # Safety
    ///
    /// Same requirements as [`NdSliceMut::from_ptr`].
    ///
    /// # Errors
    ///
    /// Same as [`NdSliceMut::new`].
    pub unsafe fn row_ordered_from_ptr(ptr: *mut T, len: usize, shape: [usize; N]) -> ConstructionResult<'s, T, N> {
        // SAFETY: the caller upholds the contract of `from_ptr`.
        unsafe { Self::from_ptr(ptr, len, shape, Order::RowMajor) }
    }

    /// Shorthand for [`NdSliceMut::new`] with `Order::ColumnMajor`.
    ///
    /// # Errors
    ///
    /// Same as [`NdSliceMut::new`].
    pub fn col_ordered(slice: &'s mut [T], shape: [usize; N]) -> ConstructionResult<'s, T, N> {
        Self::new(slice, shape, Order::ColumnMajor)
    }

    /// Shorthand for [`NdSliceMut::from_ptr`] with `Order::ColumnMajor`.
    ///
    /// # Safety
    ///
    /// Same requirements as [`NdSliceMut::from_ptr`].
    ///
    /// # Errors
    ///
    /// Same as [`NdSliceMut::new`].
    pub unsafe fn col_ordered_from_ptr(ptr: *mut T, len: usize, shape: [usize; N]) -> ConstructionResult<'s, T, N> {
        // SAFETY: the caller upholds the contract of `from_ptr`.
        unsafe { Self::from_ptr(ptr, len, shape, Order::ColumnMajor) }
    }
}
/// Shared element access by `N`-dimensional index.
impl<T, const N: usize> Index<[usize; N]> for NdSliceMut<'_, T, N> {
    type Output = T;

    /// Returns a reference to the element at `index`.
    ///
    /// The flat offset is obtained from `address` using this view's order and
    /// shape.
    ///
    /// # Panics
    ///
    /// Panics if the computed offset is outside the underlying slice.
    fn index(&self, index: [usize; N]) -> &Self::Output {
        let offset = address(&self.order, &self.shape, &index);
        &self.slice[offset]
    }
}
/// Mutable element access by `N`-dimensional index.
impl<T, const N: usize> IndexMut<[usize; N]> for NdSliceMut<'_, T, N> {
    /// Returns a mutable reference to the element at `index`.
    ///
    /// The flat offset is obtained from `address` using this view's order and
    /// shape.
    ///
    /// # Panics
    ///
    /// Panics if the computed offset is outside the underlying slice.
    fn index_mut(&mut self, index: [usize; N]) -> &mut Self::Output {
        let offset = address(&self.order, &self.shape, &index);
        &mut self.slice[offset]
    }
}