use crate::tensors::dimensions;
use crate::tensors::views::{DataLayout, TensorMut, TensorRef, TensorView};
use crate::tensors::{Dimension, InvalidDimensionsError, InvalidShapeError};
use std::error::Error;
use std::fmt;
use std::marker::PhantomData;
use std::num::NonZeroUsize;
pub use crate::matrices::views::IndexRange;
/// A view over a source tensor that only exposes a contiguous range of indexes in each
/// dimension.
///
/// The view keeps the source's dimension names; the length of each dimension in the view is
/// the length of that dimension's range.
#[derive(Clone, Debug)]
pub struct TensorRange<T, S, const D: usize> {
    /// The tensor being viewed.
    source: S,
    /// One range per dimension, in the same order as the source's shape. The constructors
    /// clip each range to the source's shape.
    range: [IndexRange; D],
    /// Records the element type `T` without storing any values of it.
    _type: PhantomData<T>,
}
/// A view over a source tensor that hides a contiguous range of indexes in each dimension.
///
/// The view keeps the source's dimension names; the length of each dimension in the view is
/// the source's length minus the length of that dimension's mask.
#[derive(Clone, Debug)]
pub struct TensorMask<T, S, const D: usize> {
    /// The tensor being viewed.
    source: S,
    /// One mask per dimension, in the same order as the source's shape. A mask of length 0
    /// hides nothing. The constructors clip each mask to the source's shape.
    mask: [IndexRange; D],
    /// Records the element type `T` without storing any values of it.
    _type: PhantomData<T>,
}
/// Error returned when constructing a [`TensorRange`] or [`TensorMask`] from named dimensions
/// fails validation.
///
/// `D` is the number of dimensions in the source tensor and `P` is the number of
/// `(Dimension, range)` pairs that were provided.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum IndexRangeValidationError<const D: usize, const P: usize> {
    /// The ranges or masks, once clipped to the source, produce a shape that is not valid
    /// (for example a dimension left with a length of 0).
    InvalidShape(InvalidShapeError<D>),
    /// A provided dimension name is not in the source, or a name was provided more than once.
    InvalidDimensions(InvalidDimensionsError<D, P>),
}
impl<const D: usize, const P: usize> fmt::Display for IndexRangeValidationError<D, P> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
IndexRangeValidationError::InvalidShape(error) => write!(f, "{:?}", error),
IndexRangeValidationError::InvalidDimensions(error) => write!(f, "{:?}", error),
}
}
}
impl<const D: usize, const P: usize> Error for IndexRangeValidationError<D, P> {
fn source(&self) -> Option<&(dyn Error + 'static)> {
match self {
IndexRangeValidationError::InvalidShape(error) => Some(error),
IndexRangeValidationError::InvalidDimensions(error) => Some(error),
}
}
}
/// Error returned by the strict constructors (`from_strict` and `from_all_strict`) of
/// [`TensorRange`] and [`TensorMask`].
///
/// The non strict constructors clip ranges that extend past the end of a dimension. The strict
/// constructors reject them with `OutsideShape` instead.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum StrictIndexRangeValidationError<const D: usize, const P: usize> {
    /// At least one range or mask extends past the end of its dimension in the source.
    OutsideShape {
        /// The shape of the source tensor.
        shape: [(Dimension, usize); D],
        /// The range or mask for every dimension, `None` where none was provided.
        index_range: [Option<IndexRange>; D],
    },
    /// Any error the non strict constructors can also return.
    Error(IndexRangeValidationError<D, P>),
}
impl<const D: usize, const P: usize> fmt::Display for StrictIndexRangeValidationError<D, P> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
use StrictIndexRangeValidationError as S;
match self {
S::OutsideShape { shape, index_range } => write!(
f,
"IndexRange array {:?} is out of bounds of shape {:?}",
index_range, shape
),
S::Error(error) => write!(f, "{:?}", error),
}
}
}
impl<const D: usize, const P: usize> Error for StrictIndexRangeValidationError<D, P> {
fn source(&self) -> Option<&(dyn Error + 'static)> {
use StrictIndexRangeValidationError as S;
match self {
S::OutsideShape {
shape: _,
index_range: _,
} => None,
S::Error(error) => Some(error),
}
}
}
fn from_named_to_all_specific_error<T, S, R, const D: usize, const P: usize>(
source: &S,
ranges: [(Dimension, R); P],
) -> Result<[Option<IndexRange>; D], InvalidDimensionsError<D, P>>
where
S: TensorRef<T, D>,
R: Into<IndexRange>,
{
let shape = source.view_shape();
let ranges = ranges.map(|(d, r)| (d, r.into()));
let dimensions = InvalidDimensionsError {
provided: ranges.clone().map(|(d, _)| d),
valid: shape.map(|(d, _)| d),
};
if dimensions.has_duplicates() {
return Err(dimensions);
}
let mut all_ranges: [Option<IndexRange>; D] = std::array::from_fn(|_| None);
for (name, range) in ranges.into_iter() {
match crate::tensors::dimensions::position_of(&shape, name) {
Some(d) => all_ranges[d] = Some(range),
None => return Err(dimensions),
};
}
Ok(all_ranges)
}
/// Same as `from_named_to_all_specific_error`, but wraps the error in
/// [`IndexRangeValidationError::InvalidDimensions`] so that constructors can propagate it
/// with `?`.
fn from_named_to_all<T, S, R, const D: usize, const P: usize>(
    source: &S,
    ranges: [(Dimension, R); P],
) -> Result<[Option<IndexRange>; D], IndexRangeValidationError<D, P>>
where
    S: TensorRef<T, D>,
    R: Into<IndexRange>,
{
    from_named_to_all_specific_error(source, ranges)
        .map_err(IndexRangeValidationError::InvalidDimensions)
}
impl<T, S, const D: usize> TensorRange<T, S, D>
where
S: TensorRef<T, D>,
{
pub fn from<R, const P: usize>(
source: S,
ranges: [(Dimension, R); P],
) -> Result<TensorRange<T, S, D>, IndexRangeValidationError<D, P>>
where
R: Into<IndexRange>,
{
let all_ranges = from_named_to_all(&source, ranges)?;
match TensorRange::from_all(source, all_ranges) {
Ok(tensor_range) => Ok(tensor_range),
Err(invalid_shape) => Err(IndexRangeValidationError::InvalidShape(invalid_shape)),
}
}
pub fn from_strict<R, const P: usize>(
source: S,
ranges: [(Dimension, R); P],
) -> Result<TensorRange<T, S, D>, StrictIndexRangeValidationError<D, P>>
where
R: Into<IndexRange>,
{
use StrictIndexRangeValidationError as S;
let all_ranges = match from_named_to_all(&source, ranges) {
Ok(all_ranges) => all_ranges,
Err(error) => return Err(S::Error(error)),
};
match TensorRange::from_all_strict(source, all_ranges) {
Ok(tensor_range) => Ok(tensor_range),
Err(S::OutsideShape { shape, index_range }) => {
Err(S::OutsideShape { shape, index_range })
}
Err(S::Error(IndexRangeValidationError::InvalidShape(error))) => {
Err(S::Error(IndexRangeValidationError::InvalidShape(error)))
}
Err(S::Error(IndexRangeValidationError::InvalidDimensions(_))) => panic!(
"Unexpected InvalidDimensions error case after validating for InvalidDimensions already"
),
}
}
pub fn from_all<R>(
source: S,
ranges: [Option<R>; D],
) -> Result<TensorRange<T, S, D>, InvalidShapeError<D>>
where
R: Into<IndexRange>,
{
TensorRange::clip_from(
source,
ranges.map(|option| option.map(|range| range.into())),
)
}
fn clip_from(
source: S,
ranges: [Option<IndexRange>; D],
) -> Result<TensorRange<T, S, D>, InvalidShapeError<D>> {
let shape = source.view_shape();
let mut ranges = std::array::from_fn(|d| {
ranges[d]
.clone()
.unwrap_or_else(|| IndexRange::new(0, shape[d].1))
});
let shape = InvalidShapeError {
shape: clip_range_shape(&shape, &mut ranges),
};
if !shape.is_valid() {
return Err(shape);
}
Ok(TensorRange {
source,
range: ranges,
_type: PhantomData,
})
}
pub fn from_all_strict<R>(
source: S,
range: [Option<R>; D],
) -> Result<TensorRange<T, S, D>, StrictIndexRangeValidationError<D, D>>
where
R: Into<IndexRange>,
{
let shape = source.view_shape();
let range = range.map(|option| option.map(|range| range.into()));
if range_exceeds_bounds(&shape, &range) {
return Err(StrictIndexRangeValidationError::OutsideShape {
shape,
index_range: range,
});
}
match TensorRange::clip_from(source, range) {
Ok(tensor_range) => Ok(tensor_range),
Err(invalid_shape) => Err(StrictIndexRangeValidationError::Error(
IndexRangeValidationError::InvalidShape(invalid_shape),
)),
}
}
#[allow(dead_code)]
pub fn source(self) -> S {
self.source
}
#[allow(dead_code)]
pub fn source_ref(&self) -> &S {
&self.source
}
}
/// Returns true if any provided range extends past the end of its dimension.
///
/// `source` is the shape of the tensor. `range` holds one optional range per dimension in the
/// same order, and `None` entries never exceed the bounds.
///
/// A range whose `start + length` does not fit in a `usize` counts as out of bounds. Adding
/// them directly would overflow instead: a panic in debug builds, and in release builds a
/// wrap-around to a small, seemingly in-bounds end.
fn range_exceeds_bounds<const D: usize>(
    source: &[(Dimension, usize); D],
    range: &[Option<IndexRange>; D],
) -> bool {
    source
        .iter()
        .zip(range.iter())
        .any(|((_, length), range)| match range {
            None => false,
            Some(range) => match range.start.checked_add(range.length) {
                Some(end) => end > *length,
                // The end cannot be represented at all, so it is certainly past the dimension.
                None => true,
            },
        })
}
/// Clips every range in place so it does not extend past its dimension in `source`.
///
/// Returns the shape of the resulting view: the source's dimension names, each with the
/// clipped length of its range.
fn clip_range_shape<const D: usize>(
    source: &[(Dimension, usize); D],
    range: &mut [IndexRange; D],
) -> [(Dimension, usize); D] {
    let mut clipped = *source;
    for ((_, length), index_range) in clipped.iter_mut().zip(range.iter_mut()) {
        index_range.clip(*length);
        *length = index_range.length;
    }
    clipped
}
impl<T, S, const D: usize> TensorMask<T, S, D>
where
S: TensorRef<T, D>,
{
pub fn from<R, const P: usize>(
source: S,
masks: [(Dimension, R); P],
) -> Result<TensorMask<T, S, D>, IndexRangeValidationError<D, P>>
where
R: Into<IndexRange>,
{
let all_masks = from_named_to_all(&source, masks)?;
match TensorMask::from_all(source, all_masks) {
Ok(tensor_mask) => Ok(tensor_mask),
Err(invalid_shape) => Err(IndexRangeValidationError::InvalidShape(invalid_shape)),
}
}
pub fn from_strict<R, const P: usize>(
source: S,
masks: [(Dimension, R); P],
) -> Result<TensorMask<T, S, D>, StrictIndexRangeValidationError<D, P>>
where
R: Into<IndexRange>,
{
use StrictIndexRangeValidationError as S;
let all_masks = match from_named_to_all(&source, masks) {
Ok(all_masks) => all_masks,
Err(error) => return Err(S::Error(error)),
};
match TensorMask::from_all_strict(source, all_masks) {
Ok(tensor_mask) => Ok(tensor_mask),
Err(S::OutsideShape { shape, index_range }) => {
Err(S::OutsideShape { shape, index_range })
}
Err(S::Error(IndexRangeValidationError::InvalidShape(error))) => {
Err(S::Error(IndexRangeValidationError::InvalidShape(error)))
}
Err(S::Error(IndexRangeValidationError::InvalidDimensions(_))) => panic!(
"Unexpected InvalidDimensions error case after validating for InvalidDimensions already"
),
}
}
pub fn from_all<R>(
source: S,
mask: [Option<R>; D],
) -> Result<TensorMask<T, S, D>, InvalidShapeError<D>>
where
R: Into<IndexRange>,
{
TensorMask::clip_from(source, mask.map(|option| option.map(|mask| mask.into())))
}
fn clip_from(
source: S,
masks: [Option<IndexRange>; D],
) -> Result<TensorMask<T, S, D>, InvalidShapeError<D>> {
let shape = source.view_shape();
let mut masks = masks.map(|option| option.unwrap_or_else(|| IndexRange::new(0, 0)));
let shape = InvalidShapeError {
shape: clip_masked_shape(&shape, &mut masks),
};
if !shape.is_valid() {
return Err(shape);
}
Ok(TensorMask {
source,
mask: masks,
_type: PhantomData,
})
}
pub fn from_all_strict<R>(
source: S,
masks: [Option<R>; D],
) -> Result<TensorMask<T, S, D>, StrictIndexRangeValidationError<D, D>>
where
R: Into<IndexRange>,
{
let shape = source.view_shape();
let masks = masks.map(|option| option.map(|mask| mask.into()));
if mask_exceeds_bounds(&shape, &masks) {
return Err(StrictIndexRangeValidationError::OutsideShape {
shape,
index_range: masks,
});
}
match TensorMask::clip_from(source, masks) {
Ok(tensor_mask) => Ok(tensor_mask),
Err(invalid_shape) => Err(StrictIndexRangeValidationError::Error(
IndexRangeValidationError::InvalidShape(invalid_shape),
)),
}
}
pub fn start_and_end_of(
source: S,
dimension: Dimension,
start_and_end: NonZeroUsize,
) -> Result<TensorMask<T, S, D>, InvalidDimensionsError<D, 1>> {
let shape = source.view_shape();
let range = match dimensions::length_of(&shape, dimension) {
None => {
return Err(InvalidDimensionsError::new(
[dimension],
dimensions::names_of(&shape),
));
}
Some(length) => {
let x = start_and_end.get();
let retain_start = std::cmp::min(x, length - 1);
let retain_end = length.saturating_sub(x);
let mut range: IndexRange = (retain_start..retain_end).into();
range.clip(length - 1);
range
}
};
Ok(TensorMask {
source,
mask: std::array::from_fn(|d| {
if shape[d].0 == dimension {
range.clone()
} else {
IndexRange::new(0, 0)
}
}),
_type: PhantomData,
})
}
#[track_caller]
pub(crate) fn panicking_start_and_end_of(
source: S,
dimension: Dimension,
start_and_end: usize,
) -> TensorView<T, TensorMask<T, S, D>, D> {
match NonZeroUsize::new(start_and_end) {
Some(non_zero) => match TensorMask::start_and_end_of(source, dimension, non_zero) {
Ok(tensor) => TensorView::from(tensor),
Err(error) => panic!(
"Dimension name provided {:?} must be in the set of dimension names in the tensor: {:?}",
dimension, error.valid,
),
},
None => panic!("start_and_end must be greater than 0"),
}
}
#[allow(dead_code)]
pub fn source(self) -> S {
self.source
}
#[allow(dead_code)]
pub fn source_ref(&self) -> &S {
&self.source
}
}
/// Clips every mask in place so it does not extend past its dimension in `source`.
///
/// Returns the shape that remains visible: the source's dimension names, each with the
/// source length minus the clipped mask length.
fn clip_masked_shape<const D: usize>(
    source: &[(Dimension, usize); D],
    mask: &mut [IndexRange; D],
) -> [(Dimension, usize); D] {
    let mut remaining = *source;
    for ((_, length), hidden) in remaining.iter_mut().zip(mask.iter_mut()) {
        hidden.clip(*length);
        *length -= hidden.length;
    }
    remaining
}
/// Returns true if any provided mask extends past the end of its dimension.
///
/// Masks are bounds-checked exactly like ranges, so this delegates to `range_exceeds_bounds`.
fn mask_exceeds_bounds<const D: usize>(
    source: &[(Dimension, usize); D],
    mask: &[Option<IndexRange>; D],
) -> bool {
    range_exceeds_bounds::<D>(source, mask)
}
/// Translates view indexes into source indexes through each dimension's range.
///
/// Returns `None` as soon as any dimension's `IndexRange::map` does, which is how an index
/// outside the range is rejected.
fn map_indexes_by_range<const D: usize>(
    indexes: [usize; D],
    ranges: &[IndexRange; D],
) -> Option<[usize; D]> {
    let mut source_indexes = indexes;
    for (index, range) in source_indexes.iter_mut().zip(ranges.iter()) {
        *index = range.map(*index)?;
    }
    Some(source_indexes)
}
// SAFETY: TensorRange only forwards to its source. It translates view indexes into source
// indexes through its ranges, which the constructors clip to the source's shape.
// NOTE(review): the full TensorRef safety contract is defined outside this file. Confirm that
// reporting the source's names with each range's length meets it.
unsafe impl<T, S, const D: usize> TensorRef<T, D> for TensorRange<T, S, D>
where
    S: TensorRef<T, D>,
{
    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
        let source_indexes = map_indexes_by_range(indexes, &self.range)?;
        self.source.get_reference(source_indexes)
    }

    /// The source's dimension names, each with the length of its range.
    fn view_shape(&self) -> [(Dimension, usize); D] {
        let source_shape = self.source.view_shape();
        std::array::from_fn(|d| (source_shape[d].0, self.range[d].length))
    }

    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
        // SAFETY: the caller guarantees `indexes` is within this view's shape, where each
        // dimension's length is its range's length. This assumes IndexRange::map returns Some
        // for any index below the range length, so the mapping cannot fail. The mapped indexes
        // then lie inside ranges clipped to the source's shape.
        unsafe {
            let source_indexes = map_indexes_by_range(indexes, &self.range).unwrap_unchecked();
            self.source.get_reference_unchecked(source_indexes)
        }
    }

    /// Reports a non-linear layout: the selected ranges are not assumed to be contiguous in
    /// the source.
    fn data_layout(&self) -> DataLayout<D> {
        DataLayout::NonLinear
    }
}
// SAFETY: mutable access uses the same index translation as the TensorRef impl and only
// forwards to the source's own TensorMut methods.
unsafe impl<T, S, const D: usize> TensorMut<T, D> for TensorRange<T, S, D>
where
    S: TensorMut<T, D>,
{
    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
        let source_indexes = map_indexes_by_range(indexes, &self.range)?;
        self.source.get_reference_mut(source_indexes)
    }

    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
        // SAFETY: same reasoning as `get_reference_unchecked`. In-shape indexes always map
        // inside the clipped ranges, and so inside the source.
        unsafe {
            let source_indexes = map_indexes_by_range(indexes, &self.range).unwrap_unchecked();
            self.source.get_reference_unchecked_mut(source_indexes)
        }
    }
}
/// Translates view indexes into source indexes by applying `IndexRange::mask` for each
/// dimension's mask. Out of bounds indexes are left for the source to reject.
fn map_indexes_by_mask<const D: usize>(indexes: [usize; D], masks: &[IndexRange; D]) -> [usize; D] {
    std::array::from_fn(|d| masks[d].mask(indexes[d]))
}
// SAFETY: TensorMask only forwards to its source. It translates view indexes into source
// indexes by skipping each dimension's mask, and the constructors clip masks to the source's
// shape.
// NOTE(review): the full TensorRef safety contract is defined outside this file. Confirm that
// reporting the source's lengths minus the mask lengths meets it.
unsafe impl<T, S, const D: usize> TensorRef<T, D> for TensorMask<T, S, D>
where
    S: TensorRef<T, D>,
{
    fn get_reference(&self, indexes: [usize; D]) -> Option<&T> {
        let source_indexes = map_indexes_by_mask(indexes, &self.mask);
        self.source.get_reference(source_indexes)
    }

    /// The source's dimension names, each with the source length minus the mask length.
    fn view_shape(&self) -> [(Dimension, usize); D] {
        let source_shape = self.source.view_shape();
        std::array::from_fn(|d| (source_shape[d].0, source_shape[d].1 - self.mask[d].length))
    }

    unsafe fn get_reference_unchecked(&self, indexes: [usize; D]) -> &T {
        // SAFETY: the caller guarantees `indexes` is within this view's shape. Presumably
        // skipping each clipped mask maps such indexes inside the source's shape.
        // TODO(review): confirm against IndexRange::mask.
        unsafe {
            let source_indexes = map_indexes_by_mask(indexes, &self.mask);
            self.source.get_reference_unchecked(source_indexes)
        }
    }

    /// Reports a non-linear layout: the visible indexes skip over the masked ones in the
    /// source.
    fn data_layout(&self) -> DataLayout<D> {
        DataLayout::NonLinear
    }
}
// SAFETY: mutable access uses the same index translation as the TensorRef impl and only
// forwards to the source's own TensorMut methods.
unsafe impl<T, S, const D: usize> TensorMut<T, D> for TensorMask<T, S, D>
where
    S: TensorMut<T, D>,
{
    fn get_reference_mut(&mut self, indexes: [usize; D]) -> Option<&mut T> {
        let source_indexes = map_indexes_by_mask(indexes, &self.mask);
        self.source.get_reference_mut(source_indexes)
    }

    unsafe fn get_reference_unchecked_mut(&mut self, indexes: [usize; D]) -> &mut T {
        // SAFETY: same reasoning as `get_reference_unchecked` on the TensorRef impl.
        unsafe {
            let source_indexes = map_indexes_by_mask(indexes, &self.mask);
            self.source.get_reference_unchecked_mut(source_indexes)
        }
    }
}
#[test]
#[rustfmt::skip]
fn test_constructors() {
    use crate::tensors::Tensor;
    use crate::tensors::views::TensorView;
    let tensor = Tensor::from([("rows", 3), ("columns", 3)], (0..9).collect());
    assert_eq!(
        TensorView::from(TensorRange::from(&tensor, [("rows", IndexRange::new(1, 2))]).unwrap()),
        Tensor::from([("rows", 2), ("columns", 3)], vec![
            3, 4, 5,
            6, 7, 8
        ])
    );
    assert_eq!(
        TensorView::from(TensorRange::from(&tensor, [("columns", 2..3)]).unwrap()),
        Tensor::from([("rows", 3), ("columns", 1)], vec![
            2,
            5,
            8
        ])
    );
    assert_eq!(
        TensorView::from(TensorRange::from(&tensor, [("rows", (1, 1)), ("columns", (2, 1))]).unwrap()),
        Tensor::from([("rows", 1), ("columns", 1)], vec![5])
    );
    assert_eq!(
        TensorView::from(TensorRange::from(&tensor, [("columns", 1..3)]).unwrap()),
        Tensor::from([("rows", 3), ("columns", 2)], vec![
            1, 2,
            4, 5,
            7, 8
        ])
    );
    assert_eq!(
        TensorView::from(TensorMask::from(&tensor, [("rows", IndexRange::new(1, 1))]).unwrap()),
        Tensor::from([("rows", 2), ("columns", 3)], vec![
            0, 1, 2,
            6, 7, 8
        ])
    );
    assert_eq!(
        TensorView::from(TensorMask::from(&tensor, [("rows", 2..3), ("columns", 0..1)]).unwrap()),
        Tensor::from([("rows", 2), ("columns", 2)], vec![
            1, 2,
            4, 5
        ])
    );
    use IndexRangeValidationError as IRVError;
    use InvalidShapeError as ShapeError;
    use StrictIndexRangeValidationError::Error as SError;
    use StrictIndexRangeValidationError::OutsideShape as OutsideShape;
    use InvalidDimensionsError as DError;
    assert_eq!(
        TensorRange::from(&tensor, [("invalid", 1..2)]).unwrap_err(),
        IRVError::InvalidDimensions(DError::new(["invalid"], ["rows", "columns"]))
    );
    assert_eq!(
        TensorMask::from(&tensor, [("wrong", 0..1)]).unwrap_err(),
        IRVError::InvalidDimensions(DError::new(["wrong"], ["rows", "columns"]))
    );
    assert_eq!(
        TensorRange::from_strict(&tensor, [("invalid", 1..2)]).unwrap_err(),
        SError(IRVError::InvalidDimensions(DError::new(["invalid"], ["rows", "columns"])))
    );
    assert_eq!(
        TensorMask::from_strict(&tensor, [("wrong", 0..1)]).unwrap_err(),
        SError(IRVError::InvalidDimensions(DError::new(["wrong"], ["rows", "columns"])))
    );
    assert_eq!(
        TensorRange::from(&tensor, [("rows", 0..0)]).unwrap_err(),
        IRVError::InvalidShape(ShapeError::new([("rows", 0), ("columns", 3)]))
    );
    assert_eq!(
        TensorMask::from(&tensor, [("columns", 0..3)]).unwrap_err(),
        IRVError::InvalidShape(ShapeError::new([("rows", 3), ("columns", 0)]))
    );
    assert_eq!(
        TensorRange::from_strict(&tensor, [("rows", 0..0)]).unwrap_err(),
        SError(IRVError::InvalidShape(ShapeError::new([("rows", 0), ("columns", 3)])))
    );
    assert_eq!(
        TensorMask::from_strict(&tensor, [("columns", 0..3)]).unwrap_err(),
        SError(IRVError::InvalidShape(ShapeError::new([("rows", 3), ("columns", 0)])))
    );
    assert_eq!(
        TensorRange::from(&tensor, [("rows", 1..2), ("rows", 2..3)]).unwrap_err(),
        IRVError::InvalidDimensions(DError::new(["rows", "rows"], ["rows", "columns"]))
    );
    assert_eq!(
        TensorMask::from(&tensor, [("columns", 1..2), ("columns", 2..3)]).unwrap_err(),
        IRVError::InvalidDimensions(DError::new(["columns", "columns"], ["rows", "columns"]))
    );
    assert_eq!(
        TensorRange::from_strict(&tensor, [("rows", 1..2), ("rows", 2..3)]).unwrap_err(),
        SError(IRVError::InvalidDimensions(DError::new(["rows", "rows"], ["rows", "columns"])))
    );
    assert_eq!(
        TensorMask::from_strict(&tensor, [("columns", 1..2), ("columns", 2..3)]).unwrap_err(),
        SError(IRVError::InvalidDimensions(DError::new(["columns", "columns"], ["rows", "columns"])))
    );
    assert!(
        TensorView::from(TensorRange::from(&tensor, [("rows", 0..4)]).unwrap()).eq(&tensor),
    );
    assert_eq!(
        TensorRange::from_strict(&tensor, [("rows", 0..4)]).unwrap_err(),
        OutsideShape {
            shape: [("rows", 3), ("columns", 3)],
            index_range: [Some(IndexRange::new(0, 4)), None],
        }
    );
    assert_eq!(
        TensorView::from(TensorMask::from(&tensor, [("columns", 1..4)]).unwrap()),
        Tensor::from([("rows", 3), ("columns", 1)], vec![
            0,
            3,
            6,
        ])
    );
    assert_eq!(
        TensorMask::from_strict(&tensor, [("columns", 1..4)]).unwrap_err(),
        OutsideShape {
            shape: [("rows", 3), ("columns", 3)],
            index_range: [None, Some(IndexRange::new(1, 3))],
        }
    );
    // start_and_end_of keeps the first and last `n` indexes of one dimension and hides the
    // indexes in between, leaving every other dimension untouched.
    assert_eq!(
        TensorView::from(TensorMask::start_and_end_of(&tensor, "rows", NonZeroUsize::new(1).unwrap()).unwrap()),
        Tensor::from([("rows", 2), ("columns", 3)], vec![
            0, 1, 2,
            6, 7, 8
        ])
    );
    assert_eq!(
        TensorView::from(TensorMask::start_and_end_of(&tensor, "columns", NonZeroUsize::new(1).unwrap()).unwrap()),
        Tensor::from([("rows", 3), ("columns", 2)], vec![
            0, 2,
            3, 5,
            6, 8
        ])
    );
    assert_eq!(
        TensorMask::start_and_end_of(&tensor, "invalid", NonZeroUsize::new(1).unwrap()).unwrap_err(),
        DError::new(["invalid"], ["rows", "columns"])
    );
}