use core::ops::{Index, IndexMut};
use crate::array2::Array2;
use crate::error::{Error, Result};
/// Read-only, strided 2-D view into a borrowed slice.
///
/// Element `(row, col)` lives at `data[offset + row * strides[0] + col * strides[1]]`.
///
/// `Clone` and `Copy` are implemented by hand rather than derived: a derive
/// would add `T: Clone` / `T: Copy` bounds, although the view only stores a
/// shared reference and is therefore always cheap to copy.
#[derive(Debug)]
pub struct ArrayView2<'a, T> {
    /// Backing storage; the view may address only part of it.
    pub(crate) data: &'a [T],
    /// `[rows, cols]`.
    pub(crate) shape: [usize; 2],
    /// `[row_stride, col_stride]`, measured in elements of `T`; signed.
    pub(crate) strides: [isize; 2],
    /// Index into `data` of element `(0, 0)`.
    pub(crate) offset: isize,
}

impl<T> Clone for ArrayView2<'_, T> {
    #[inline]
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for ArrayView2<'_, T> {}
/// Mutable, strided 2-D view into an exclusively borrowed slice.
///
/// Uses the same addressing scheme as `ArrayView2`: element `(row, col)` lives
/// at `data[offset + row * strides[0] + col * strides[1]]`.
#[derive(Debug)]
pub struct ArrayViewMut2<'a, T> {
    /// Backing storage; the view may address only part of it.
    pub(crate) data: &'a mut [T],
    /// `[rows, cols]`.
    pub(crate) shape: [usize; 2],
    /// `[row_stride, col_stride]`, measured in elements of `T`; signed.
    pub(crate) strides: [isize; 2],
    /// Index into `data` of element `(0, 0)`.
    pub(crate) offset: isize,
}
impl<'a, T> ArrayView2<'a, T> {
    /// Creates a read-only view over `data` after validating its geometry.
    ///
    /// Element `(row, col)` of the view is
    /// `data[offset + row * strides[0] + col * strides[1]]`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `validate_view`, for example when `offset` is
    /// negative or when an element of the view would lie outside `data`.
    pub fn new(
        data: &'a [T],
        shape: [usize; 2],
        strides: [isize; 2],
        offset: isize,
    ) -> Result<Self> {
        validate_view(data.len(), &shape, &strides, offset)?;
        Ok(Self {
            data,
            shape,
            strides,
            offset,
        })
    }

    /// Creates a view without validating the geometry.
    ///
    /// The caller is responsible for every in-shape element mapping into
    /// `data`. Element access still goes through checked slice indexing, so a
    /// bad geometry panics instead of reading out of bounds.
    pub(crate) fn from_raw_parts(
        data: &'a [T],
        shape: [usize; 2],
        strides: [isize; 2],
        offset: isize,
    ) -> Self {
        Self {
            data,
            shape,
            strides,
            offset,
        }
    }

    /// `[rows, cols]` of the view.
    #[inline]
    pub fn shape(&self) -> [usize; 2] {
        self.shape
    }

    /// Number of rows (`shape()[0]`).
    #[inline]
    pub fn rows(&self) -> usize {
        self.shape[0]
    }

    /// Number of columns (`shape()[1]`).
    #[inline]
    pub fn cols(&self) -> usize {
        self.shape[1]
    }

    /// `[row_stride, col_stride]`, in elements of `T`.
    #[inline]
    pub fn strides(&self) -> [isize; 2] {
        self.strides
    }

    /// Signed distance in `data` from element `(r, c)` to `(r + 1, c)`.
    #[inline]
    pub fn row_stride(&self) -> isize {
        self.strides[0]
    }

    /// Signed distance in `data` from element `(r, c)` to `(r, c + 1)`.
    #[inline]
    pub fn col_stride(&self) -> isize {
        self.strides[1]
    }

    /// Same value as [`Self::row_stride`], exposed under the
    /// "leading dimension" name.
    #[inline]
    pub fn leading_dimension(&self) -> isize {
        self.strides[0]
    }

    /// Total number of elements, `rows * cols`.
    #[inline]
    pub fn len(&self) -> usize {
        self.shape[0] * self.shape[1]
    }

    /// `true` when the view has no elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// `true` when the elements form one dense row-major run in `data`
    /// (see `is_compact_row_major`). Empty views count as contiguous.
    #[inline]
    pub fn is_contiguous(&self) -> bool {
        is_compact_row_major(self.shape, self.strides)
    }

    /// Returns all elements as one row-major slice, or `None` when the view is
    /// not contiguous.
    ///
    /// An empty view always yields an empty slice and its `offset` is not
    /// used. Empty sub-views such as `rows_range(n, n)` can carry an offset
    /// past the end of `data`, or a negative one under negative strides, and
    /// slicing at that offset would panic.
    pub fn as_slice(&self) -> Option<&'a [T]> {
        if !self.is_contiguous() {
            return None;
        }
        if self.is_empty() {
            return Some(&self.data[0..0]);
        }
        // Non-empty and compact: exactly `len()` consecutive entries starting at `offset`.
        let start = self.offset as usize;
        let end = start + self.len();
        Some(&self.data[start..end])
    }

    /// Returns element `(row, col)`, or `None` if either coordinate is out of
    /// range. The reference borrows `data` for `'a`, not `self`.
    #[inline]
    pub fn get(&self, row: usize, col: usize) -> Option<&'a T> {
        if row >= self.rows() || col >= self.cols() {
            return None;
        }
        Some(&self.data[self.linear_index(row, col)])
    }

    /// Swaps the two axes without moving any data. The shape and strides are
    /// exchanged and the offset is kept.
    #[inline]
    pub fn transpose(self) -> Self {
        Self {
            data: self.data,
            shape: [self.shape[1], self.shape[0]],
            strides: [self.strides[1], self.strides[0]],
            offset: self.offset,
        }
    }

    /// Returns row `row` as a `1 x cols` view sharing this view's strides.
    ///
    /// # Errors
    ///
    /// `Error::IndexOutOfBounds` when `row >= self.rows()`.
    pub fn row(&self, row: usize) -> Result<Self> {
        if row >= self.rows() {
            return Err(Error::IndexOutOfBounds);
        }
        Ok(Self {
            data: self.data,
            shape: [1, self.cols()],
            strides: self.strides,
            offset: self.offset + row as isize * self.strides[0],
        })
    }

    /// Returns row `row` as a borrowed slice when its elements are adjacent in
    /// `data`.
    ///
    /// Yields `Ok(Some(&[]))` when the view has no columns. Yields `Ok(None)`
    /// when the column stride is not 1.
    ///
    /// # Errors
    ///
    /// `Error::IndexOutOfBounds` when `row >= self.rows()`.
    pub fn row_slice(&self, row: usize) -> Result<Option<&'a [T]>> {
        if row >= self.rows() {
            return Err(Error::IndexOutOfBounds);
        }
        if self.cols() == 0 {
            return Ok(Some(&self.data[0..0]));
        }
        if self.strides[1] != 1 {
            return Ok(None);
        }
        let start = self.linear_index(row, 0);
        let end = start + self.cols();
        Ok(Some(&self.data[start..end]))
    }

    /// Returns column `col` as a `rows x 1` view sharing this view's strides.
    ///
    /// # Errors
    ///
    /// `Error::IndexOutOfBounds` when `col >= self.cols()`.
    pub fn col(&self, col: usize) -> Result<Self> {
        if col >= self.cols() {
            return Err(Error::IndexOutOfBounds);
        }
        Ok(Self {
            data: self.data,
            shape: [self.rows(), 1],
            strides: self.strides,
            offset: self.offset + col as isize * self.strides[1],
        })
    }

    /// Returns rows `start..end` (half-open) as a view sharing this view's
    /// strides. `start == end` is allowed and yields an empty view.
    ///
    /// # Errors
    ///
    /// `Error::IndexOutOfBounds` when `start > end` or `end > self.rows()`.
    pub fn rows_range(&self, start: usize, end: usize) -> Result<Self> {
        if start > end || end > self.rows() {
            return Err(Error::IndexOutOfBounds);
        }
        Ok(Self {
            data: self.data,
            shape: [end - start, self.cols()],
            strides: self.strides,
            offset: self.offset + start as isize * self.strides[0],
        })
    }

    /// Returns columns `start..end` (half-open) as a view sharing this view's
    /// strides. `start == end` is allowed and yields an empty view.
    ///
    /// # Errors
    ///
    /// `Error::IndexOutOfBounds` when `start > end` or `end > self.cols()`.
    pub fn cols_range(&self, start: usize, end: usize) -> Result<Self> {
        if start > end || end > self.cols() {
            return Err(Error::IndexOutOfBounds);
        }
        Ok(Self {
            data: self.data,
            shape: [self.rows(), end - start],
            strides: self.strides,
            offset: self.offset + start as isize * self.strides[1],
        })
    }

    /// Maps `(row, col)` to an index into `data`:
    /// `offset + row * strides[0] + col * strides[1]`.
    ///
    /// The coordinates are not checked against the shape. A negative sum
    /// wraps to a huge `usize`, so indexing `data` with it panics.
    #[inline]
    pub(crate) fn linear_index(&self, row: usize, col: usize) -> usize {
        (self.offset + row as isize * self.strides[0] + col as isize * self.strides[1]) as usize
    }
}
impl<T: Clone> ArrayView2<'_, T> {
    /// Copies the viewed elements into a new `Array2` of the same shape,
    /// cloning each element.
    pub fn to_row_major(&self) -> Array2<T> {
        let element_at = |row: usize, col: usize| -> T { self[(row, col)].clone() };
        Array2::from_fn(self.shape(), element_at)
    }

    /// Clones the viewed elements into a `Vec` in column-major order: all of
    /// column 0 top to bottom, then column 1, and so on.
    pub fn to_col_major_vec(&self) -> Vec<T> {
        let this = self;
        let mut out = Vec::with_capacity(this.len());
        out.extend((0..this.cols()).flat_map(move |col| {
            (0..this.rows()).map(move |row| this[(row, col)].clone())
        }));
        out
    }
}
impl<'a, T> ArrayViewMut2<'a, T> {
    /// Creates a mutable view over `data` after validating its geometry.
    ///
    /// Element `(row, col)` of the view is
    /// `data[offset + row * strides[0] + col * strides[1]]`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `validate_view`, for example when `offset` is
    /// negative or when an element of the view would lie outside `data`.
    pub fn new(
        data: &'a mut [T],
        shape: [usize; 2],
        strides: [isize; 2],
        offset: isize,
    ) -> Result<Self> {
        validate_view(data.len(), &shape, &strides, offset)?;
        Ok(Self {
            data,
            shape,
            strides,
            offset,
        })
    }

    /// Creates a mutable view without validating the geometry.
    ///
    /// The caller is responsible for every in-shape element mapping into
    /// `data`. Element access still goes through checked slice indexing, so a
    /// bad geometry panics instead of writing out of bounds.
    pub(crate) fn from_raw_parts(
        data: &'a mut [T],
        shape: [usize; 2],
        strides: [isize; 2],
        offset: isize,
    ) -> Self {
        Self {
            data,
            shape,
            strides,
            offset,
        }
    }

    /// `[rows, cols]` of the view.
    #[inline]
    pub fn shape(&self) -> [usize; 2] {
        self.shape
    }

    /// Number of rows (`shape()[0]`).
    #[inline]
    pub fn rows(&self) -> usize {
        self.shape[0]
    }

    /// Number of columns (`shape()[1]`).
    #[inline]
    pub fn cols(&self) -> usize {
        self.shape[1]
    }

    /// `[row_stride, col_stride]`, in elements of `T`.
    #[inline]
    pub fn strides(&self) -> [isize; 2] {
        self.strides
    }

    /// Signed distance in `data` from element `(r, c)` to `(r + 1, c)`.
    #[inline]
    pub fn row_stride(&self) -> isize {
        self.strides[0]
    }

    /// Signed distance in `data` from element `(r, c)` to `(r, c + 1)`.
    #[inline]
    pub fn col_stride(&self) -> isize {
        self.strides[1]
    }

    /// Same value as [`Self::row_stride`], exposed under the
    /// "leading dimension" name.
    #[inline]
    pub fn leading_dimension(&self) -> isize {
        self.strides[0]
    }

    /// Total number of elements, `rows * cols`.
    #[inline]
    pub fn len(&self) -> usize {
        self.shape[0] * self.shape[1]
    }

    /// `true` when the view has no elements.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// `true` when the elements form one dense row-major run in `data`
    /// (see `is_compact_row_major`). Empty views count as contiguous.
    #[inline]
    pub fn is_contiguous(&self) -> bool {
        is_compact_row_major(self.shape, self.strides)
    }

    /// Reborrows this view as a read-only [`ArrayView2`] with the same shape,
    /// strides and offset. The returned view borrows `self`.
    pub fn as_view(&self) -> ArrayView2<'_, T> {
        ArrayView2 {
            data: self.data,
            shape: self.shape,
            strides: self.strides,
            offset: self.offset,
        }
    }

    /// Returns all elements as one mutable row-major slice, or `None` when the
    /// view is not contiguous.
    ///
    /// An empty view always yields an empty slice and its `offset` is not
    /// used. Empty sub-views such as `rows_range_mut(n, n)` can carry an offset
    /// past the end of `data`, or a negative one under negative strides, and
    /// slicing at that offset would panic.
    pub fn as_mut_slice(&mut self) -> Option<&mut [T]> {
        if !self.is_contiguous() {
            return None;
        }
        if self.is_empty() {
            return Some(&mut self.data[0..0]);
        }
        // Non-empty and compact: exactly `len()` consecutive entries starting at `offset`.
        let start = self.offset as usize;
        let end = start + self.len();
        Some(&mut self.data[start..end])
    }

    /// Returns a shared reference to element `(row, col)`, or `None` if either
    /// coordinate is out of range.
    #[inline]
    pub fn get(&self, row: usize, col: usize) -> Option<&T> {
        if row >= self.rows() || col >= self.cols() {
            return None;
        }
        Some(&self.data[self.linear_index(row, col)])
    }

    /// Returns a mutable reference to element `(row, col)`, or `None` if
    /// either coordinate is out of range.
    #[inline]
    pub fn get_mut(&mut self, row: usize, col: usize) -> Option<&mut T> {
        if row >= self.rows() || col >= self.cols() {
            return None;
        }
        let index = self.linear_index(row, col);
        Some(&mut self.data[index])
    }

    /// Swaps the two axes without moving any data. The shape and strides are
    /// exchanged and the offset is kept.
    #[inline]
    pub fn transpose(self) -> Self {
        Self {
            data: self.data,
            shape: [self.shape[1], self.shape[0]],
            strides: [self.strides[1], self.strides[0]],
            offset: self.offset,
        }
    }

    /// Returns row `row` as a mutable `1 x cols` view. The sub-view reborrows
    /// `self` mutably for its lifetime.
    ///
    /// # Errors
    ///
    /// `Error::IndexOutOfBounds` when `row >= self.rows()`.
    pub fn row_mut(&mut self, row: usize) -> Result<ArrayViewMut2<'_, T>> {
        if row >= self.rows() {
            return Err(Error::IndexOutOfBounds);
        }
        let cols = self.cols();
        let strides = self.strides;
        let offset = self.offset + row as isize * strides[0];
        Ok(ArrayViewMut2 {
            data: &mut *self.data,
            shape: [1, cols],
            strides,
            offset,
        })
    }

    /// Returns row `row` as a mutable slice when its elements are adjacent in
    /// `data`.
    ///
    /// Yields `Ok(Some(&mut []))` when the view has no columns. Yields
    /// `Ok(None)` when the column stride is not 1.
    ///
    /// # Errors
    ///
    /// `Error::IndexOutOfBounds` when `row >= self.rows()`.
    pub fn row_slice_mut(&mut self, row: usize) -> Result<Option<&mut [T]>> {
        if row >= self.rows() {
            return Err(Error::IndexOutOfBounds);
        }
        if self.cols() == 0 {
            return Ok(Some(&mut self.data[0..0]));
        }
        if self.strides[1] != 1 {
            return Ok(None);
        }
        let start = self.linear_index(row, 0);
        let end = start + self.cols();
        Ok(Some(&mut self.data[start..end]))
    }

    /// Returns column `col` as a mutable `rows x 1` view. The sub-view
    /// reborrows `self` mutably for its lifetime.
    ///
    /// # Errors
    ///
    /// `Error::IndexOutOfBounds` when `col >= self.cols()`.
    pub fn col_mut(&mut self, col: usize) -> Result<ArrayViewMut2<'_, T>> {
        if col >= self.cols() {
            return Err(Error::IndexOutOfBounds);
        }
        let rows = self.rows();
        let strides = self.strides;
        let offset = self.offset + col as isize * strides[1];
        Ok(ArrayViewMut2 {
            data: &mut *self.data,
            shape: [rows, 1],
            strides,
            offset,
        })
    }

    /// Returns rows `start..end` (half-open) as a mutable view sharing this
    /// view's strides. `start == end` yields an empty view.
    ///
    /// # Errors
    ///
    /// `Error::IndexOutOfBounds` when `start > end` or `end > self.rows()`.
    pub fn rows_range_mut(&mut self, start: usize, end: usize) -> Result<ArrayViewMut2<'_, T>> {
        if start > end || end > self.rows() {
            return Err(Error::IndexOutOfBounds);
        }
        let cols = self.cols();
        let strides = self.strides;
        let offset = self.offset + start as isize * strides[0];
        Ok(ArrayViewMut2 {
            data: &mut *self.data,
            shape: [end - start, cols],
            strides,
            offset,
        })
    }

    /// Returns columns `start..end` (half-open) as a mutable view sharing this
    /// view's strides. `start == end` yields an empty view.
    ///
    /// # Errors
    ///
    /// `Error::IndexOutOfBounds` when `start > end` or `end > self.cols()`.
    pub fn cols_range_mut(&mut self, start: usize, end: usize) -> Result<ArrayViewMut2<'_, T>> {
        if start > end || end > self.cols() {
            return Err(Error::IndexOutOfBounds);
        }
        let rows = self.rows();
        let strides = self.strides;
        let offset = self.offset + start as isize * strides[1];
        Ok(ArrayViewMut2 {
            data: &mut *self.data,
            shape: [rows, end - start],
            strides,
            offset,
        })
    }

    /// Maps `(row, col)` to an index into `data`:
    /// `offset + row * strides[0] + col * strides[1]`.
    ///
    /// The coordinates are not checked against the shape. A negative sum
    /// wraps to a huge `usize`, so indexing `data` with it panics.
    #[inline]
    pub(crate) fn linear_index(&self, row: usize, col: usize) -> usize {
        (self.offset + row as isize * self.strides[0] + col as isize * self.strides[1]) as usize
    }
}
impl<T: Clone> ArrayViewMut2<'_, T> {
pub fn to_row_major(&self) -> Array2<T> {
self.as_view().to_row_major()
}
pub fn to_col_major_vec(&self) -> Vec<T> {
self.as_view().to_col_major_vec()
}
pub fn copy_from_view(&mut self, other: ArrayView2<'_, T>) -> Result<()> {
if self.shape() != other.shape() {
return Err(Error::shape(self.shape(), other.shape()));
}
for i in 0..self.rows() {
for j in 0..self.cols() {
self[(i, j)] = other[(i, j)].clone();
}
}
Ok(())
}
}
/// Returns `true` when a 2-D view with this `shape` and `strides` stores its
/// elements densely in row-major order, i.e. as one contiguous run.
///
/// Views without elements are always compact. A view with at most one row
/// only needs a unit column stride, because its row stride never reaches
/// another element.
#[inline]
pub(crate) fn is_compact_row_major(shape: [usize; 2], strides: [isize; 2]) -> bool {
    let [rows, cols] = shape;
    let [row_stride, col_stride] = strides;
    if rows == 0 || cols == 0 {
        return true;
    }
    if col_stride != 1 {
        return false;
    }
    rows <= 1 || row_stride == cols as isize
}
impl<T> Index<(usize, usize)> for ArrayView2<'_, T> {
    type Output = T;

    /// Returns the element at `(row, col)`.
    ///
    /// # Panics
    ///
    /// Panics with "view index out of bounds" when either coordinate lies
    /// outside the view's shape.
    fn index(&self, (row, col): (usize, usize)) -> &Self::Output {
        match self.get(row, col) {
            Some(element) => element,
            None => panic!("view index out of bounds"),
        }
    }
}
impl<T> Index<(usize, usize)> for ArrayViewMut2<'_, T> {
    type Output = T;

    /// Shared access to the element at `(row, col)`.
    ///
    /// # Panics
    ///
    /// Panics with "view index out of bounds" when either coordinate lies
    /// outside the view's shape.
    fn index(&self, (row, col): (usize, usize)) -> &Self::Output {
        self.get(row, col)
            .unwrap_or_else(|| panic!("view index out of bounds"))
    }
}
impl<T> IndexMut<(usize, usize)> for ArrayViewMut2<'_, T> {
    /// Mutable access to the element at `(row, col)`.
    ///
    /// # Panics
    ///
    /// Panics with "view index out of bounds" when either coordinate lies
    /// outside the view's shape.
    fn index_mut(&mut self, (row, col): (usize, usize)) -> &mut Self::Output {
        match self.get_mut(row, col) {
            Some(slot) => slot,
            None => panic!("view index out of bounds"),
        }
    }
}
pub(crate) fn validate_view(
len: usize,
shape: &[usize],
strides: &[isize],
offset: isize,
) -> Result<()> {
if shape.len() != strides.len() || offset < 0 {
return Err(Error::InvalidStride);
}
if shape.contains(&0) {
return Ok(());
}
let mut min = offset;
let mut max = offset;
for (&dim, &stride) in shape.iter().zip(strides) {
let span = (dim - 1) as isize * stride;
if span >= 0 {
max += span;
} else {
min += span;
}
}
if min < 0 || max < 0 || max as usize >= len {
return Err(Error::InvalidStride);
}
Ok(())
}