use core::ops::{Index, IndexMut};
use crate::error::{Error, Result};
use crate::view2::{ArrayView2, ArrayViewMut2, validate_view};
/// Read-only strided view over three-dimensional data stored in a borrowed slice.
///
/// The element at `(i, j, k)` is read from position
/// `offset + i * strides[0] + j * strides[1] + k * strides[2]` of `data`.
///
/// `Clone` and `Copy` are implemented by hand instead of derived: the derives would
/// add `T: Clone` / `T: Copy` bounds, although copying a view only duplicates a shared
/// reference plus a few integers and never touches the elements themselves.
#[derive(Debug)]
pub struct ArrayView3<'a, T> {
    /// Backing storage the view indexes into.
    pub(crate) data: &'a [T],
    /// Number of elements along each of the three axes.
    pub(crate) shape: [usize; 3],
    /// Step, in elements, between consecutive indices along each axis.
    pub(crate) strides: [isize; 3],
    /// Position in `data` of the element at `(0, 0, 0)`.
    pub(crate) offset: isize,
}

impl<T> Clone for ArrayView3<'_, T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for ArrayView3<'_, T> {}
/// Mutable strided view over three-dimensional data stored in a borrowed slice.
///
/// The element at `(i, j, k)` is accessed at position
/// `offset + i * strides[0] + j * strides[1] + k * strides[2]` of `data`.
#[derive(Debug)]
pub struct ArrayViewMut3<'a, T> {
/// Exclusively borrowed backing storage the view indexes into.
pub(crate) data: &'a mut [T],
/// Number of elements along each of the three axes.
pub(crate) shape: [usize; 3],
/// Step, in elements, between consecutive indices along each axis.
pub(crate) strides: [isize; 3],
/// Position in `data` of the element at `(0, 0, 0)`.
pub(crate) offset: isize,
}
impl<'a, T> ArrayView3<'a, T> {
pub fn new(
data: &'a [T],
shape: [usize; 3],
strides: [isize; 3],
offset: isize,
) -> Result<Self> {
validate_view(data.len(), &shape, &strides, offset)?;
Ok(Self {
data,
shape,
strides,
offset,
})
}
pub(crate) fn from_raw_parts(
data: &'a [T],
shape: [usize; 3],
strides: [isize; 3],
offset: isize,
) -> Self {
Self {
data,
shape,
strides,
offset,
}
}
pub fn shape(&self) -> [usize; 3] {
self.shape
}
pub fn strides(&self) -> [isize; 3] {
self.strides
}
pub fn len(&self) -> usize {
self.shape.iter().product()
}
pub fn is_empty(&self) -> bool {
self.len() == 0
}
pub fn is_contiguous(&self) -> bool {
self.shape.contains(&0)
|| (self.offset == 0
&& self.strides
== [
(self.shape[1] * self.shape[2]) as isize,
self.shape[2] as isize,
1,
]
&& self.len() == self.data.len())
}
pub fn as_slice(&self) -> Option<&'a [T]> {
self.is_contiguous().then_some(self.data)
}
pub fn get(&self, i: usize, j: usize, k: usize) -> Option<&'a T> {
(i < self.shape[0] && j < self.shape[1] && k < self.shape[2])
.then(|| &self.data[self.linear_index(i, j, k)])
}
pub fn matrix_at(&self, axis: usize, index: usize) -> Result<ArrayView2<'a, T>> {
if axis >= 3 {
return Err(Error::AxisOutOfBounds { axis, ndim: 3 });
}
if index >= self.shape[axis] {
return Err(Error::IndexOutOfBounds);
}
let axes: Vec<usize> = (0..3).filter(|&candidate| candidate != axis).collect();
Ok(ArrayView2::from_raw_parts(
self.data,
[self.shape[axes[0]], self.shape[axes[1]]],
[self.strides[axes[0]], self.strides[axes[1]]],
self.offset + index as isize * self.strides[axis],
))
}
pub fn for_each_matrix(
&self,
axis: usize,
mut f: impl FnMut(usize, ArrayView2<'a, T>) -> Result<()>,
) -> Result<()> {
if axis >= 3 {
return Err(Error::AxisOutOfBounds { axis, ndim: 3 });
}
for index in 0..self.shape[axis] {
f(index, self.matrix_at(axis, index)?)?;
}
Ok(())
}
#[inline]
pub(crate) fn linear_index(&self, i: usize, j: usize, k: usize) -> usize {
(self.offset
+ i as isize * self.strides[0]
+ j as isize * self.strides[1]
+ k as isize * self.strides[2]) as usize
}
}
impl<'a, T> ArrayViewMut3<'a, T> {
pub fn new(
data: &'a mut [T],
shape: [usize; 3],
strides: [isize; 3],
offset: isize,
) -> Result<Self> {
validate_view(data.len(), &shape, &strides, offset)?;
Ok(Self {
data,
shape,
strides,
offset,
})
}
pub(crate) fn from_raw_parts(
data: &'a mut [T],
shape: [usize; 3],
strides: [isize; 3],
offset: isize,
) -> Self {
Self {
data,
shape,
strides,
offset,
}
}
pub fn shape(&self) -> [usize; 3] {
self.shape
}
pub fn as_view(&self) -> ArrayView3<'_, T> {
ArrayView3 {
data: self.data,
shape: self.shape,
strides: self.strides,
offset: self.offset,
}
}
pub fn get(&self, i: usize, j: usize, k: usize) -> Option<&T> {
(i < self.shape[0] && j < self.shape[1] && k < self.shape[2])
.then(|| &self.data[self.linear_index(i, j, k)])
}
pub fn get_mut(&mut self, i: usize, j: usize, k: usize) -> Option<&mut T> {
if i >= self.shape[0] || j >= self.shape[1] || k >= self.shape[2] {
return None;
}
let index = self.linear_index(i, j, k);
Some(&mut self.data[index])
}
pub fn matrix_at_mut(&mut self, axis: usize, index: usize) -> Result<ArrayViewMut2<'_, T>> {
if axis >= 3 {
return Err(Error::AxisOutOfBounds { axis, ndim: 3 });
}
if index >= self.shape[axis] {
return Err(Error::IndexOutOfBounds);
}
let axes: Vec<usize> = (0..3).filter(|&candidate| candidate != axis).collect();
let offset = self.offset + index as isize * self.strides[axis];
ArrayViewMut2::new(
&mut *self.data,
[self.shape[axes[0]], self.shape[axes[1]]],
[self.strides[axes[0]], self.strides[axes[1]]],
offset,
)
}
pub fn for_each_matrix_mut(
&mut self,
axis: usize,
mut f: impl FnMut(usize, ArrayViewMut2<'_, T>) -> Result<()>,
) -> Result<()> {
if axis >= 3 {
return Err(Error::AxisOutOfBounds { axis, ndim: 3 });
}
for index in 0..self.shape[axis] {
f(index, self.matrix_at_mut(axis, index)?)?;
}
Ok(())
}
#[inline]
pub(crate) fn linear_index(&self, i: usize, j: usize, k: usize) -> usize {
(self.offset
+ i as isize * self.strides[0]
+ j as isize * self.strides[1]
+ k as isize * self.strides[2]) as usize
}
}
/// Panicking `view[(i, j, k)]` access for read-only views.
impl<T> Index<(usize, usize, usize)> for ArrayView3<'_, T> {
    type Output = T;

    /// Returns the element at `(i, j, k)`.
    ///
    /// # Panics
    ///
    /// Panics when `get` returns `None` for the given indices.
    fn index(&self, (i, j, k): (usize, usize, usize)) -> &Self::Output {
        let element = self.get(i, j, k);
        element.expect("view index out of bounds")
    }
}
/// Panicking read-only `view[(i, j, k)]` access for mutable views.
impl<T> Index<(usize, usize, usize)> for ArrayViewMut3<'_, T> {
    type Output = T;

    /// Returns the element at `(i, j, k)`.
    ///
    /// # Panics
    ///
    /// Panics when `get` returns `None` for the given indices.
    fn index(&self, (i, j, k): (usize, usize, usize)) -> &Self::Output {
        let element = self.get(i, j, k);
        element.expect("view index out of bounds")
    }
}
/// Panicking `view[(i, j, k)] = value` style access for mutable views.
impl<T> IndexMut<(usize, usize, usize)> for ArrayViewMut3<'_, T> {
    /// Returns a mutable reference to the element at `(i, j, k)`.
    ///
    /// # Panics
    ///
    /// Panics when `get_mut` returns `None` for the given indices.
    fn index_mut(&mut self, (i, j, k): (usize, usize, usize)) -> &mut Self::Output {
        let element = self.get_mut(i, j, k);
        element.expect("view index out of bounds")
    }
}