// tensors-rs 0.1.2
//
// Compact NumPy-like dense tensor primitives for safe numerical Rust.
// Documentation
//! Three-dimensional strided array views.

use core::ops::{Index, IndexMut};

use crate::error::{Error, Result};
use crate::view2::{ArrayView2, ArrayViewMut2, validate_view};

/// Immutable 3D array view.
///
/// The view only borrows its backing slice, so it is `Copy` and `Clone` for
/// every element type `T`, including element types that are neither.
#[derive(Debug)]
pub struct ArrayView3<'a, T> {
    /// Backing storage the view indexes into.
    pub(crate) data: &'a [T],
    /// Extent of each axis, as `[dim0, dim1, dim2]`.
    pub(crate) shape: [usize; 3],
    /// Per-axis step between consecutive indices, in elements.
    pub(crate) strides: [isize; 3],
    /// Position in `data` of logical element `(0, 0, 0)`, in elements.
    pub(crate) offset: isize,
}

// `#[derive(Clone, Copy)]` would add `T: Clone` / `T: Copy` bounds, although
// copying a view never copies an element. Implement both by hand instead.
impl<T> Clone for ArrayView3<'_, T> {
    fn clone(&self) -> Self {
        *self
    }
}

impl<T> Copy for ArrayView3<'_, T> {}

/// Mutable 3D array view.
///
/// Holds an exclusive borrow of its backing slice, so unlike the immutable
/// view it derives neither `Clone` nor `Copy`.
#[derive(Debug)]
pub struct ArrayViewMut3<'a, T> {
    /// Backing storage the view indexes into.
    pub(crate) data: &'a mut [T],
    /// Extent of each axis, as `[dim0, dim1, dim2]`.
    pub(crate) shape: [usize; 3],
    /// Per-axis step between consecutive indices, in elements.
    pub(crate) strides: [isize; 3],
    /// Position in `data` of logical element `(0, 0, 0)`, in elements.
    pub(crate) offset: isize,
}

impl<'a, T> ArrayView3<'a, T> {
    /// Create a checked immutable view.
    pub fn new(
        data: &'a [T],
        shape: [usize; 3],
        strides: [isize; 3],
        offset: isize,
    ) -> Result<Self> {
        validate_view(data.len(), &shape, &strides, offset)?;
        Ok(Self {
            data,
            shape,
            strides,
            offset,
        })
    }

    pub(crate) fn from_raw_parts(
        data: &'a [T],
        shape: [usize; 3],
        strides: [isize; 3],
        offset: isize,
    ) -> Self {
        Self {
            data,
            shape,
            strides,
            offset,
        }
    }

    /// Shape as `[dim0, dim1, dim2]`.
    pub fn shape(&self) -> [usize; 3] {
        self.shape
    }

    /// Strides in elements.
    pub fn strides(&self) -> [isize; 3] {
        self.strides
    }

    /// Number of logical elements.
    pub fn len(&self) -> usize {
        self.shape.iter().product()
    }

    /// Whether the view is empty.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Whether the view is compact row-major contiguous.
    pub fn is_contiguous(&self) -> bool {
        self.shape.contains(&0)
            || (self.offset == 0
                && self.strides
                    == [
                        (self.shape[1] * self.shape[2]) as isize,
                        self.shape[2] as isize,
                        1,
                    ]
                && self.len() == self.data.len())
    }

    /// Borrow the backing slice if this view covers it contiguously.
    pub fn as_slice(&self) -> Option<&'a [T]> {
        self.is_contiguous().then_some(self.data)
    }

    /// Get an element reference.
    pub fn get(&self, i: usize, j: usize, k: usize) -> Option<&'a T> {
        (i < self.shape[0] && j < self.shape[1] && k < self.shape[2])
            .then(|| &self.data[self.linear_index(i, j, k)])
    }

    /// Extract a 2D matrix view by fixing one axis.
    pub fn matrix_at(&self, axis: usize, index: usize) -> Result<ArrayView2<'a, T>> {
        if axis >= 3 {
            return Err(Error::AxisOutOfBounds { axis, ndim: 3 });
        }
        if index >= self.shape[axis] {
            return Err(Error::IndexOutOfBounds);
        }
        let axes: Vec<usize> = (0..3).filter(|&candidate| candidate != axis).collect();
        Ok(ArrayView2::from_raw_parts(
            self.data,
            [self.shape[axes[0]], self.shape[axes[1]]],
            [self.strides[axes[0]], self.strides[axes[1]]],
            self.offset + index as isize * self.strides[axis],
        ))
    }

    /// Visit each 2D matrix slice along `axis` in order.
    pub fn for_each_matrix(
        &self,
        axis: usize,
        mut f: impl FnMut(usize, ArrayView2<'a, T>) -> Result<()>,
    ) -> Result<()> {
        if axis >= 3 {
            return Err(Error::AxisOutOfBounds { axis, ndim: 3 });
        }
        for index in 0..self.shape[axis] {
            f(index, self.matrix_at(axis, index)?)?;
        }
        Ok(())
    }

    #[inline]
    pub(crate) fn linear_index(&self, i: usize, j: usize, k: usize) -> usize {
        (self.offset
            + i as isize * self.strides[0]
            + j as isize * self.strides[1]
            + k as isize * self.strides[2]) as usize
    }
}

impl<'a, T> ArrayViewMut3<'a, T> {
    /// Create a checked mutable view.
    pub fn new(
        data: &'a mut [T],
        shape: [usize; 3],
        strides: [isize; 3],
        offset: isize,
    ) -> Result<Self> {
        validate_view(data.len(), &shape, &strides, offset)?;
        Ok(Self {
            data,
            shape,
            strides,
            offset,
        })
    }

    pub(crate) fn from_raw_parts(
        data: &'a mut [T],
        shape: [usize; 3],
        strides: [isize; 3],
        offset: isize,
    ) -> Self {
        Self {
            data,
            shape,
            strides,
            offset,
        }
    }

    /// Shape as `[dim0, dim1, dim2]`.
    pub fn shape(&self) -> [usize; 3] {
        self.shape
    }

    /// Immutable view over the same region.
    pub fn as_view(&self) -> ArrayView3<'_, T> {
        ArrayView3 {
            data: self.data,
            shape: self.shape,
            strides: self.strides,
            offset: self.offset,
        }
    }

    /// Get an element reference.
    pub fn get(&self, i: usize, j: usize, k: usize) -> Option<&T> {
        (i < self.shape[0] && j < self.shape[1] && k < self.shape[2])
            .then(|| &self.data[self.linear_index(i, j, k)])
    }

    /// Get a mutable element reference.
    pub fn get_mut(&mut self, i: usize, j: usize, k: usize) -> Option<&mut T> {
        if i >= self.shape[0] || j >= self.shape[1] || k >= self.shape[2] {
            return None;
        }
        let index = self.linear_index(i, j, k);
        Some(&mut self.data[index])
    }

    /// Extract a mutable 2D matrix view by fixing one axis.
    pub fn matrix_at_mut(&mut self, axis: usize, index: usize) -> Result<ArrayViewMut2<'_, T>> {
        if axis >= 3 {
            return Err(Error::AxisOutOfBounds { axis, ndim: 3 });
        }
        if index >= self.shape[axis] {
            return Err(Error::IndexOutOfBounds);
        }
        let axes: Vec<usize> = (0..3).filter(|&candidate| candidate != axis).collect();
        let offset = self.offset + index as isize * self.strides[axis];
        ArrayViewMut2::new(
            &mut *self.data,
            [self.shape[axes[0]], self.shape[axes[1]]],
            [self.strides[axes[0]], self.strides[axes[1]]],
            offset,
        )
    }

    /// Visit each mutable 2D matrix slice along `axis` in order.
    pub fn for_each_matrix_mut(
        &mut self,
        axis: usize,
        mut f: impl FnMut(usize, ArrayViewMut2<'_, T>) -> Result<()>,
    ) -> Result<()> {
        if axis >= 3 {
            return Err(Error::AxisOutOfBounds { axis, ndim: 3 });
        }
        for index in 0..self.shape[axis] {
            f(index, self.matrix_at_mut(axis, index)?)?;
        }
        Ok(())
    }

    #[inline]
    pub(crate) fn linear_index(&self, i: usize, j: usize, k: usize) -> usize {
        (self.offset
            + i as isize * self.strides[0]
            + j as isize * self.strides[1]
            + k as isize * self.strides[2]) as usize
    }
}

/// `view[(i, j, k)]` element access for immutable views.
///
/// Panics with `"view index out of bounds"` when `get` returns `None`.
impl<T> Index<(usize, usize, usize)> for ArrayView3<'_, T> {
    type Output = T;

    fn index(&self, (i, j, k): (usize, usize, usize)) -> &Self::Output {
        let element = self.get(i, j, k);
        element.expect("view index out of bounds")
    }
}

/// `view[(i, j, k)]` read access for mutable views.
///
/// Panics with `"view index out of bounds"` when `get` returns `None`.
impl<T> Index<(usize, usize, usize)> for ArrayViewMut3<'_, T> {
    type Output = T;

    fn index(&self, (i, j, k): (usize, usize, usize)) -> &Self::Output {
        let element = self.get(i, j, k);
        element.expect("view index out of bounds")
    }
}

/// `view[(i, j, k)] = value` write access for mutable views.
///
/// Panics with `"view index out of bounds"` when `get_mut` returns `None`.
impl<T> IndexMut<(usize, usize, usize)> for ArrayViewMut3<'_, T> {
    fn index_mut(&mut self, (i, j, k): (usize, usize, usize)) -> &mut Self::Output {
        let element = self.get_mut(i, j, k);
        element.expect("view index out of bounds")
    }
}