use crate::error::{Error, Result};
use crate::numeric::Float;
use crate::view2::validate_view;
/// Owned N-dimensional array backed by one flat `Vec`.
///
/// The constructors in this module fill `strides` with row-major strides
/// derived from `shape`, so element `i` lives at `data[Σ i[k] * strides[k]]`.
#[derive(Debug, Clone, PartialEq)]
pub struct ArrayN<T> {
    /// Flat element storage.
    data: Vec<T>,
    /// Extent of each axis, outermost first.
    shape: Vec<usize>,
    /// Element step per axis used to turn a multi-index into a position in `data`.
    strides: Vec<isize>,
}
/// Borrowed, read-only N-dimensional view addressed through per-axis strides.
///
/// The element at multi-index `i` is `data[offset + Σ i[k] * strides[k]]`.
#[derive(Debug, Clone)]
pub struct ArrayViewN<'a, T> {
    /// Backing storage; the view may cover only part of it.
    data: &'a [T],
    /// Extent of each axis of the view.
    shape: Vec<usize>,
    /// Signed element step per axis.
    strides: Vec<isize>,
    /// Position in `data` of the element at the all-zero index.
    offset: isize,
}
/// Borrowed, mutable N-dimensional view addressed through per-axis strides.
///
/// The element at multi-index `i` is `data[offset + Σ i[k] * strides[k]]`.
#[derive(Debug)]
pub struct ArrayViewMutN<'a, T> {
/// Backing storage; the view may cover only part of it.
data: &'a mut [T],
/// Extent of each axis of the view.
shape: Vec<usize>,
/// Signed element step per axis.
strides: Vec<isize>,
/// Position in `data` of the element at the all-zero index.
offset: isize,
}
impl<T> ArrayN<T> {
pub fn from_vec(shape: Vec<usize>, data: Vec<T>) -> Result<Self> {
let expected = checked_len(&shape)?;
if data.len() != expected {
return Err(Error::shape(vec![expected], vec![data.len()]));
}
let strides = row_major_strides(&shape);
Ok(Self {
data,
shape,
strides,
})
}
pub fn shape(&self) -> &[usize] {
&self.shape
}
pub fn strides(&self) -> &[isize] {
&self.strides
}
pub fn ndim(&self) -> usize {
self.shape.len()
}
pub fn len(&self) -> usize {
self.data.len()
}
pub fn is_empty(&self) -> bool {
self.data.is_empty()
}
pub fn as_slice(&self) -> &[T] {
&self.data
}
pub fn as_mut_slice(&mut self) -> &mut [T] {
&mut self.data
}
pub fn view(&self) -> ArrayViewN<'_, T> {
ArrayViewN {
data: &self.data,
shape: self.shape.clone(),
strides: self.strides.clone(),
offset: 0,
}
}
pub fn view_mut(&mut self) -> ArrayViewMutN<'_, T> {
ArrayViewMutN {
data: &mut self.data,
shape: self.shape.clone(),
strides: self.strides.clone(),
offset: 0,
}
}
pub fn get(&self, index: &[usize]) -> Option<&T> {
self.linear_index(index).map(|idx| &self.data[idx])
}
pub fn slice_axis(&self, axis: usize, index: usize) -> Result<ArrayViewN<'_, T>> {
self.view().slice_axis(axis, index)
}
pub fn slice_axis_mut(&mut self, axis: usize, index: usize) -> Result<ArrayViewMutN<'_, T>> {
if axis >= self.ndim() {
return Err(Error::AxisOutOfBounds {
axis,
ndim: self.ndim(),
});
}
if index >= self.shape[axis] {
return Err(Error::IndexOutOfBounds);
}
let mut shape = self.shape.clone();
let mut strides = self.strides.clone();
let offset = index as isize * strides[axis];
shape.remove(axis);
strides.remove(axis);
Ok(ArrayViewMutN {
data: &mut self.data,
shape,
strides,
offset,
})
}
pub fn get_mut(&mut self, index: &[usize]) -> Option<&mut T> {
self.linear_index(index).map(|idx| &mut self.data[idx])
}
fn linear_index(&self, index: &[usize]) -> Option<usize> {
if index.len() != self.ndim() {
return None;
}
let mut linear = 0usize;
for ((&idx, &dim), &stride) in index.iter().zip(&self.shape).zip(&self.strides) {
if idx >= dim {
return None;
}
linear += idx * stride as usize;
}
Some(linear)
}
}
impl<T: Clone> ArrayN<T> {
    /// Creates an array of the given `shape` with every element set to `value`.
    ///
    /// # Panics
    ///
    /// Panics if the number of elements implied by `shape` overflows `usize`,
    /// or if the buffer cannot be allocated. Use [`ArrayN::try_filled`] to get
    /// an error instead.
    pub fn filled(shape: Vec<usize>, value: T) -> Self {
        // Use the same checked element count as `try_filled`/`from_vec`. An
        // unchecked product would wrap in release builds and silently produce a
        // buffer whose length disagrees with `shape`.
        let len = match checked_len(&shape) {
            Ok(len) => len,
            Err(_) => panic!("ArrayN::filled: number of elements in shape overflows usize"),
        };
        let strides = row_major_strides(&shape);
        Self {
            data: vec![value; len],
            shape,
            strides,
        }
    }

    /// Fallible counterpart of [`ArrayN::filled`].
    ///
    /// # Errors
    ///
    /// Returns the error from the element-count check when `shape` overflows
    /// `usize`, or `Error::AllocationFailed` when the buffer cannot be reserved.
    pub fn try_filled(shape: Vec<usize>, value: T) -> Result<Self> {
        let len = checked_len(&shape)?;
        let strides = row_major_strides(&shape);
        let mut data = Vec::new();
        data.try_reserve_exact(len)
            .map_err(|_| Error::AllocationFailed)?;
        // Capacity is already reserved, so this fill does not reallocate.
        data.resize(len, value);
        Ok(Self {
            data,
            shape,
            strides,
        })
    }
}
impl<T: Float> ArrayN<T> {
    /// Array of `shape` whose every element is `T::zero()`.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`ArrayN::filled`].
    pub fn zeros(shape: Vec<usize>) -> Self {
        let zero = T::zero();
        Self::filled(shape, zero)
    }

    /// Fallible counterpart of [`ArrayN::zeros`].
    ///
    /// # Errors
    ///
    /// Same as [`ArrayN::try_filled`].
    pub fn try_zeros(shape: Vec<usize>) -> Result<Self> {
        let zero = T::zero();
        Self::try_filled(shape, zero)
    }

    /// Array of `shape` whose every element is `T::one()`.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`ArrayN::filled`].
    pub fn ones(shape: Vec<usize>) -> Self {
        let one = T::one();
        Self::filled(shape, one)
    }

    /// Fallible counterpart of [`ArrayN::ones`].
    ///
    /// # Errors
    ///
    /// Same as [`ArrayN::try_filled`].
    pub fn try_ones(shape: Vec<usize>) -> Result<Self> {
        let one = T::one();
        Self::try_filled(shape, one)
    }
}
impl<'a, T> ArrayViewN<'a, T> {
    /// Creates a read-only strided view over `data`.
    ///
    /// `shape` and `strides` are copied into the view, so they only need to be
    /// borrowed for the duration of this call, not for `'a`.
    ///
    /// # Errors
    ///
    /// Returns whatever error `validate_view` reports for
    /// `(data.len(), shape, strides, offset)`.
    pub fn new(
        data: &'a [T],
        shape: &[usize],
        strides: &[isize],
        offset: isize,
    ) -> Result<Self> {
        validate_view(data.len(), shape, strides, offset)?;
        Ok(Self {
            data,
            shape: shape.to_vec(),
            strides: strides.to_vec(),
            offset,
        })
    }

    /// Extent of each axis of the view.
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Signed element step per axis.
    pub fn strides(&self) -> &[isize] {
        &self.strides
    }

    /// Number of axes.
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Number of elements covered by the view (product of `shape`; 1 for a 0-d view).
    pub fn len(&self) -> usize {
        self.shape.iter().product()
    }

    /// `true` when the view covers no elements.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// `true` when the view's strides describe a compact row-major block,
    /// as decided by `is_compact_row_major`.
    pub fn is_contiguous(&self) -> bool {
        is_compact_row_major(&self.shape, &self.strides)
    }

    /// The view's elements as one contiguous slice of the backing storage,
    /// or `None` if the view is not compact row-major.
    ///
    /// An empty view always yields an empty slice. Its offset is ignored because
    /// slicing an axis of an array that has a zero-length axis can leave the
    /// offset past the end of `data`.
    pub fn as_slice(&self) -> Option<&'a [T]> {
        if !self.is_contiguous() {
            return None;
        }
        let len = self.len();
        if len == 0 {
            return Some(&self.data[..0]);
        }
        let start = self.offset as usize;
        Some(&self.data[start..start + len])
    }

    /// Element at multi-index `index`. Returns `None` if the index has the wrong
    /// number of axes, is out of range on some axis, or resolves to a negative
    /// position.
    pub fn get(&self, index: &[usize]) -> Option<&'a T> {
        self.linear_index(index).map(|idx| &self.data[idx])
    }

    /// Fixes `axis` at `index`, returning a view with one fewer dimension over
    /// the same backing storage.
    ///
    /// # Errors
    ///
    /// `Error::AxisOutOfBounds` if `axis >= self.ndim()`;
    /// `Error::IndexOutOfBounds` if `index >= self.shape()[axis]`.
    pub fn slice_axis(&self, axis: usize, index: usize) -> Result<Self> {
        if axis >= self.ndim() {
            return Err(Error::AxisOutOfBounds {
                axis,
                ndim: self.ndim(),
            });
        }
        if index >= self.shape[axis] {
            return Err(Error::IndexOutOfBounds);
        }
        let mut shape = self.shape.clone();
        let mut strides = self.strides.clone();
        let offset = self.offset + index as isize * strides[axis];
        shape.remove(axis);
        strides.remove(axis);
        Ok(Self {
            data: self.data,
            shape,
            strides,
            offset,
        })
    }

    /// `offset` plus the stride-weighted sum of `index`. Returns `None` on a rank
    /// mismatch, an out-of-range component, or a negative result.
    fn linear_index(&self, index: &[usize]) -> Option<usize> {
        if index.len() != self.ndim() {
            return None;
        }
        let mut linear = self.offset;
        for ((&idx, &dim), &stride) in index.iter().zip(&self.shape).zip(&self.strides) {
            if idx >= dim {
                return None;
            }
            linear += idx as isize * stride;
        }
        (linear >= 0).then_some(linear as usize)
    }
}
impl<'a, T> ArrayViewMutN<'a, T> {
    /// Creates a mutable strided view over `data`.
    ///
    /// # Errors
    ///
    /// Returns whatever error `validate_view` reports for
    /// `(data.len(), shape, strides, offset)`.
    pub fn new(
        data: &'a mut [T],
        shape: Vec<usize>,
        strides: Vec<isize>,
        offset: isize,
    ) -> Result<Self> {
        validate_view(data.len(), &shape, &strides, offset)?;
        Ok(Self {
            data,
            shape,
            strides,
            offset,
        })
    }

    /// Extent of each axis of the view.
    pub fn shape(&self) -> &[usize] {
        &self.shape
    }

    /// Signed element step per axis.
    pub fn strides(&self) -> &[isize] {
        &self.strides
    }

    /// Number of axes.
    pub fn ndim(&self) -> usize {
        self.shape.len()
    }

    /// Number of elements covered by the view (product of `shape`; 1 for a 0-d view).
    pub fn len(&self) -> usize {
        self.shape.iter().product()
    }

    /// `true` when the view covers no elements.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// `true` when the view's strides describe a compact row-major block,
    /// as decided by `is_compact_row_major`.
    pub fn is_contiguous(&self) -> bool {
        is_compact_row_major(&self.shape, &self.strides)
    }

    /// The view's elements as one contiguous mutable slice of the backing
    /// storage, or `None` if the view is not compact row-major.
    ///
    /// An empty view always yields an empty slice. Its offset is ignored because
    /// slicing an axis of an array that has a zero-length axis can leave the
    /// offset past the end of `data`.
    pub fn as_mut_slice(&mut self) -> Option<&mut [T]> {
        if !self.is_contiguous() {
            return None;
        }
        let len = self.len();
        if len == 0 {
            return Some(&mut self.data[..0]);
        }
        let start = self.offset as usize;
        Some(&mut self.data[start..start + len])
    }

    /// Read-only view of the same elements, borrowing `self`.
    pub fn as_view(&self) -> ArrayViewN<'_, T> {
        ArrayViewN {
            data: self.data,
            shape: self.shape.clone(),
            strides: self.strides.clone(),
            offset: self.offset,
        }
    }

    /// Element at multi-index `index`. Returns `None` if the index has the wrong
    /// number of axes, is out of range on some axis, or resolves to a negative
    /// position.
    pub fn get(&self, index: &[usize]) -> Option<&T> {
        self.linear_index(index).map(|idx| &self.data[idx])
    }

    /// Mutable element at multi-index `index`; `None` under the same conditions
    /// as [`ArrayViewMutN::get`].
    pub fn get_mut(&mut self, index: &[usize]) -> Option<&mut T> {
        let linear = self.linear_index(index)?;
        Some(&mut self.data[linear])
    }

    /// Fixes `axis` at `index`, returning a mutable view with one fewer
    /// dimension that reborrows this view's storage.
    ///
    /// # Errors
    ///
    /// `Error::AxisOutOfBounds` if `axis >= self.ndim()`;
    /// `Error::IndexOutOfBounds` if `index >= self.shape()[axis]`.
    pub fn slice_axis_mut(&mut self, axis: usize, index: usize) -> Result<ArrayViewMutN<'_, T>> {
        if axis >= self.ndim() {
            return Err(Error::AxisOutOfBounds {
                axis,
                ndim: self.ndim(),
            });
        }
        if index >= self.shape[axis] {
            return Err(Error::IndexOutOfBounds);
        }
        let mut shape = self.shape.clone();
        let mut strides = self.strides.clone();
        let offset = self.offset + index as isize * strides[axis];
        shape.remove(axis);
        strides.remove(axis);
        Ok(ArrayViewMutN {
            data: &mut *self.data,
            shape,
            strides,
            offset,
        })
    }

    /// `offset` plus the stride-weighted sum of `index`. Returns `None` on a rank
    /// mismatch, an out-of-range component, or a negative result.
    fn linear_index(&self, index: &[usize]) -> Option<usize> {
        if index.len() != self.ndim() {
            return None;
        }
        let mut linear = self.offset;
        for ((&idx, &dim), &stride) in index.iter().zip(&self.shape).zip(&self.strides) {
            if idx >= dim {
                return None;
            }
            linear += idx as isize * stride;
        }
        (linear >= 0).then_some(linear as usize)
    }
}
fn checked_len(shape: &[usize]) -> Result<usize> {
shape
.iter()
.try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
.ok_or(Error::DimensionTooLarge)
}
/// Row-major (C-order) strides for `shape`: the last axis has stride 1, and
/// each earlier axis steps over the product of all later extents.
///
/// The running product is accumulated with saturating `usize` arithmetic, and
/// each stored stride is clamped to `isize::MAX`. As a result, shapes whose
/// trailing extents multiply past `isize::MAX` never overflow. This covers
/// shapes with a zero-length leading axis, and the product after the first
/// axis, which is never stored. Results are exact whenever the shape's
/// element count fits in `isize`.
fn row_major_strides(shape: &[usize]) -> Vec<isize> {
    let mut strides = vec![0isize; shape.len()];
    let mut acc = 1usize;
    for (stride, &dim) in strides.iter_mut().zip(shape).rev() {
        *stride = acc.min(isize::MAX as usize) as isize;
        acc = acc.saturating_mul(dim);
    }
    strides
}
/// Reports whether `strides` lay out `shape` as a dense row-major block.
///
/// Any shape with a zero-length axis counts as compact. Axes of extent 1 may
/// carry any stride, because they are never stepped across.
fn is_compact_row_major(shape: &[usize], strides: &[isize]) -> bool {
    if shape.iter().any(|&extent| extent == 0) {
        return true;
    }
    shape
        .iter()
        .zip(strides)
        .rev()
        .try_fold(1isize, |expected, (&dim, &stride)| {
            if dim > 1 && stride != expected {
                None
            } else {
                Some(expected * dim as isize)
            }
        })
        .is_some()
}