use std::iter::repeat;
use std::ops::Range;
use smallvec::{smallvec, SmallVec};
use crate::errors::{DimensionError, FromDataError, ReshapeError, SliceError};
use crate::index_iterator::{DynIndices, NdIndices};
use crate::overlap::{is_contiguous, may_have_internal_overlap};
use crate::slice_range::{IntoSliceItems, SliceItem};
use crate::type_num::{OptionalUInt, Unknown, U0, U1, U2, U3, U4, U5};
/// Return true if `permutation` is a permutation of the dimension indices
/// `0..ndim`: it has exactly `ndim` entries and each index in that range
/// appears exactly once.
pub fn is_valid_permutation(ndim: usize, permutation: &[usize]) -> bool {
    if permutation.len() != ndim {
        return false;
    }
    // With the length fixed at `ndim`, rejecting out-of-range and repeated
    // entries is equivalent to requiring every dimension to appear once.
    let mut seen = vec![false; ndim];
    for &dim in permutation {
        match seen.get_mut(dim) {
            Some(slot) if !*slot => *slot = true,
            _ => return false,
        }
    }
    true
}
pub trait Layout {
type Index<'a>: AsRef<[usize]> + Clone + std::fmt::Debug + PartialEq<Self::Index<'a>>;
type Indices;
#[inline]
fn offset(&self, index: Self::Index<'_>) -> usize {
self.try_offset(index.clone()).unwrap_or_else(|| {
panic!(
"index {:?} out of bounds for shape {:?}",
index.as_ref(),
self.shape().as_ref()
);
})
}
fn offset_unchecked(&self, index: Self::Index<'_>) -> usize {
index
.as_ref()
.iter()
.zip(self.strides().as_ref())
.map(|(idx, stride)| *idx * *stride)
.sum()
}
fn try_offset(&self, index: Self::Index<'_>) -> Option<usize>;
fn ndim(&self) -> usize;
fn len(&self) -> usize;
fn is_contiguous(&self) -> bool {
is_contiguous(self.shape(), self.strides())
}
fn is_broadcast(&self) -> bool {
!self.is_empty() && self.strides().as_ref().iter().any(|&stride| stride == 0)
}
fn is_empty(&self) -> bool {
self.len() == 0
}
fn shape(&self) -> Self::Index<'_>;
fn size(&self, dim: usize) -> usize {
self.shape().as_ref()[dim]
}
fn strides(&self) -> Self::Index<'_>;
fn stride(&self, dim: usize) -> usize {
self.strides().as_ref()[dim]
}
fn indices(&self) -> Self::Indices;
fn can_broadcast_to(&self, target_shape: &[usize]) -> bool {
if self.shape().as_ref() == target_shape {
return true;
} else if self.ndim() > target_shape.len() {
return false;
}
let target_dims = target_shape[target_shape.len() - self.shape().as_ref().len()..]
.iter()
.copied();
self.shape()
.as_ref()
.iter()
.copied()
.zip(target_dims)
.all(|(a, b)| a == b || a == 1)
}
fn can_broadcast_with(&self, shape: &[usize]) -> bool {
if self.shape().as_ref() == shape {
return true;
}
let current_shape = self.shape();
let a = current_shape.as_ref();
let b = shape;
let a_pad = b.len().saturating_sub(a.len());
let b_pad = a.len().saturating_sub(b.len());
let a_iter = a.iter().copied().rev().chain(repeat(1).take(a_pad));
let b_iter = b.iter().copied().rev().chain(repeat(1).take(b_pad));
a_iter.zip(b_iter).all(|(a, b)| a == b || a == 1 || b == 1)
}
fn min_data_len(&self) -> usize {
if self.shape().as_ref().iter().any(|&size| size == 0) {
return 0;
}
let max_offset: usize = self
.shape()
.as_ref()
.iter()
.zip(self.strides().as_ref())
.map(|(size, stride)| (size - 1) * stride)
.sum();
max_offset + 1
}
}
/// Accessors for two-dimensional (matrix) layouts.
pub trait MatrixLayout {
    /// Number of rows (the size of dimension 0 in the `NdLayout<2>` impl).
    fn rows(&self) -> usize;
    /// Number of columns (the size of dimension 1 in the `NdLayout<2>` impl).
    fn cols(&self) -> usize;
    /// Buffer offset step between consecutive rows.
    fn row_stride(&self) -> usize;
    /// Buffer offset step between consecutive columns.
    fn col_stride(&self) -> usize;
}
/// Controls whether a layout constructor accepts layouts in which distinct
/// indices may map to the same buffer offset.
///
/// This is passed to constructors such as `try_from_shape_and_strides`. With
/// `DisallowOverlap`, layouts that may have internal overlap are rejected.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OverlapPolicy {
    /// Accept layouts where distinct indices may share an offset, such as
    /// layouts with zero strides.
    AllowOverlap,
    /// Reject layouts that may have internal overlap.
    DisallowOverlap,
}
/// Layout with a statically known number of dimensions, `N`.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct NdLayout<const N: usize> {
    /// Size of each dimension.
    shape: [usize; N],
    /// Buffer offset step between consecutive positions along each dimension.
    strides: [usize; N],
}
impl<const N: usize> Layout for NdLayout<N> {
    type Index<'a> = [usize; N];
    type Indices = NdIndices<N>;

    /// A static-rank layout always has exactly `N` dimensions.
    fn ndim(&self) -> usize {
        N
    }

    /// Product of the dimension sizes (1 for a zero-dimensional layout).
    fn len(&self) -> usize {
        self.shape.iter().fold(1, |count, &size| count * size)
    }

    /// Bounds-check `index` against the shape, then compute its offset.
    #[inline]
    fn try_offset(&self, index: [usize; N]) -> Option<usize> {
        self.index_valid(index)
            .then(|| self.offset_unchecked(index))
    }

    /// Sum of each index component multiplied by its stride, with no bounds
    /// check.
    #[inline]
    fn offset_unchecked(&self, index: [usize; N]) -> usize {
        index
            .iter()
            .zip(&self.strides)
            .map(|(&i, &stride)| i * stride)
            .sum()
    }

    #[inline]
    fn shape(&self) -> [usize; N] {
        self.shape
    }

    #[inline]
    fn strides(&self) -> [usize; N] {
        self.strides
    }

    fn indices(&self) -> NdIndices<N> {
        NdIndices::from_shape(self.shape)
    }
}
impl MatrixLayout for NdLayout<2> {
    /// Size of dimension 0.
    #[inline]
    fn rows(&self) -> usize {
        let [rows, _] = self.shape;
        rows
    }

    /// Size of dimension 1.
    #[inline]
    fn cols(&self) -> usize {
        let [_, cols] = self.shape;
        cols
    }

    /// Stride of dimension 0.
    #[inline]
    fn row_stride(&self) -> usize {
        let [row_stride, _] = self.strides;
        row_stride
    }

    /// Stride of dimension 1.
    #[inline]
    fn col_stride(&self) -> usize {
        let [_, col_stride] = self.strides;
        col_stride
    }
}
/// Compute the shape, strides and start offset produced by applying `range`
/// to a layout with shape `in_shape` and strides `in_strides`.
///
/// A dimension indexed with `SliceItem::Index` is removed from the output.
/// Negative indices count back from the end. A dimension sliced with
/// `SliceItem::Range`, or not covered by `range`, is kept. The kept sizes and
/// strides are written, in order, to the start of `out_shape` and
/// `out_strides`.
///
/// Returns `(ndim, offset)`:
/// - `ndim` is the number of dimensions the slice produces.
/// - `offset` is the buffer offset of its first element, or 0 if the slice
///   is empty.
///
/// If `ndim` is larger than the output buffers, only the dims that fit are
/// written. Callers must compare `ndim` with the buffer length to detect this.
///
/// # Errors
///
/// Returns an error in any of these cases:
/// - an index is out of bounds;
/// - a range cannot be resolved for its dimension's size;
/// - a range's step cannot be converted to `usize` (e.g. it is negative).
fn slice_layout<I: AsRef<[usize]>, O: AsMut<[usize]>>(
    in_shape: I,
    in_strides: I,
    mut out_shape: O,
    mut out_strides: O,
    range: &[SliceItem],
) -> Result<(usize, usize), SliceError> {
    let in_shape = in_shape.as_ref();
    let in_strides = in_strides.as_ref();
    let out_shape = out_shape.as_mut();
    let out_strides = out_strides.as_mut();

    // Number of output dims produced so far, and the slice's start offset.
    let mut ndim = 0;
    let mut offset = 0;

    for (in_dim, (&size, &stride)) in in_shape.iter().zip(in_strides.iter()).enumerate() {
        let (offset_adjust, new_size_stride) = match range.get(in_dim) {
            // Indexing selects a single position and removes the dimension.
            Some(&SliceItem::Index(idx)) => {
                let pos_idx = if idx >= 0 { idx } else { idx + size as isize };
                if pos_idx < 0 || pos_idx >= size as isize {
                    return Err(SliceError::InvalidIndex {
                        axis: in_dim,
                        index: idx,
                        size,
                    });
                }
                (stride * pos_idx as usize, None)
            }
            // A range keeps the dimension, with its stride scaled by the step.
            Some(SliceItem::Range(item_range)) => {
                let resolved = item_range.resolve(size).ok_or(SliceError::InvalidRange {
                    axis: in_dim,
                    range: *item_range,
                    size,
                })?;
                let step: usize =
                    item_range
                        .step()
                        .try_into()
                        .map_err(|_| SliceError::InvalidStep {
                            axis: in_dim,
                            step: item_range.step(),
                        })?;
                let new_size = if step == 1 {
                    resolved.end - resolved.start
                } else {
                    item_range.index_range(size).steps()
                };
                (stride * resolved.start, Some((new_size, stride * step)))
            }
            // Dimensions not covered by `range` are kept unchanged.
            None => (0, Some((size, stride))),
        };

        offset += offset_adjust;

        if let Some((new_size, new_stride)) = new_size_stride {
            // Write only the dims that fit in the output buffers. If the slice
            // produces more dims than they can hold (eg. `NdLayout::slice`
            // called with too small an output rank), keep counting. The caller
            // can then report the mismatch instead of this function panicking
            // on an out-of-bounds write.
            if ndim < out_shape.len() && ndim < out_strides.len() {
                out_shape[ndim] = new_size;
                out_strides[ndim] = new_stride;
            }
            ndim += 1;
        }
    }

    // An empty slice has no first element, so normalize its offset to 0.
    // Only the dims actually written are inspected. Trailing entries of a
    // larger output buffer are not part of the result.
    let written = ndim.min(out_shape.len());
    if out_shape[..written].contains(&0) {
        offset = 0;
    }

    Ok((ndim, offset))
}
/// Compute the strides for broadcasting a layout with shape `from_shape` and
/// strides `from_strides` to the shape `to_shape`.
///
/// Dimensions are aligned at the trailing end.
/// - Leading dimensions added by the broadcast get stride 0.
/// - An existing size-1 dimension expanded to a larger size also gets stride 0.
/// - Every other dimension keeps its stride.
///
/// `to_shape` must have at least as many dimensions as `from_shape`.
fn broadcast_strides<'a>(
    from_shape: &'a [usize],
    from_strides: &'a [usize],
    to_shape: &'a [usize],
) -> impl Iterator<Item = usize> + 'a {
    let pad = to_shape.len() - from_shape.len();
    let aligned_target = &to_shape[pad..];
    let existing = from_shape
        .iter()
        .zip(from_strides)
        .zip(aligned_target)
        .map(|((&size, &stride), &target)| {
            let expanded = size == 1 && target > 1;
            if expanded {
                0
            } else {
                stride
            }
        });
    repeat(0).take(pad).chain(existing)
}
impl<const N: usize> NdLayout<N> {
pub fn from_dyn(l: DynLayout) -> Self {
assert!(l.ndim() == N, "Dynamic layout dims != {}", N);
NdLayout {
shape: l.shape().try_into().unwrap(),
strides: l.strides().try_into().unwrap(),
}
}
pub fn as_dyn(&self) -> DynLayout {
self.into()
}
pub fn index_valid(&self, index: [usize; N]) -> bool {
let mut valid = true;
for i in 0..N {
valid = valid && index[i] < self.shape[i]
}
valid
}
pub fn contiguous_strides(shape: [usize; N]) -> [usize; N] {
let mut strides = [0; N];
for i in 0..N {
strides[i] = shape[i + 1..].iter().product();
}
strides
}
pub fn from_shape(shape: [usize; N]) -> Self {
Self {
shape,
strides: Self::contiguous_strides(shape),
}
}
pub fn try_from_shape_and_strides(
shape: [usize; N],
strides: [usize; N],
overlap: OverlapPolicy,
) -> Result<NdLayout<N>, FromDataError> {
let layout = NdLayout { shape, strides };
match overlap {
OverlapPolicy::DisallowOverlap => {
if may_have_internal_overlap(&layout.shape, &layout.strides) {
return Err(FromDataError::MayOverlap);
}
}
OverlapPolicy::AllowOverlap => {}
}
Ok(layout)
}
pub fn broadcast<const M: usize>(&self, to_shape: [usize; M]) -> NdLayout<M> {
assert!(
self.can_broadcast_to(&to_shape),
"Cannot broadcast to specified shape"
);
let mut strides = [0usize; M];
for (i, stride) in broadcast_strides(&self.shape(), &self.strides(), &to_shape).enumerate()
{
strides[i] = stride;
}
NdLayout {
shape: to_shape,
strides,
}
}
pub fn permuted(&self, dims: [usize; N]) -> Self {
assert!(is_valid_permutation(N, &dims), "permutation is invalid");
let mut shape = [0; N];
let mut strides = [0; N];
for i in 0..N {
shape[i] = self.shape[dims[i]];
strides[i] = self.strides[dims[i]];
}
NdLayout { shape, strides }
}
pub fn transposed(&self) -> Self {
let dims = std::array::from_fn(|i| N - i - 1);
self.permuted(dims)
}
pub fn slice<const M: usize>(
&self,
range: &[SliceItem],
) -> Result<(Range<usize>, NdLayout<M>), SliceError> {
if self.ndim() < range.len() {
return Err(SliceError::TooManyDims {
ndim: self.ndim(),
range_ndim: range.len(),
});
}
let mut shape: [usize; M] = [0; M];
let mut strides: [usize; M] = [0; M];
let (ndim, offset) =
slice_layout(&self.shape, &self.strides, &mut shape, &mut strides, range)?;
if ndim != M {
return Err(SliceError::OutputDimsMismatch {
actual: ndim,
expected: M,
});
}
let layout = NdLayout { shape, strides };
Ok((offset..offset + layout.min_data_len(), layout))
}
pub fn resize_dim(&mut self, dim: usize, new_size: usize) {
self.shape[dim] = new_size;
}
}
impl<'a, const N: usize> TryFrom<&'a DynLayout> for NdLayout<N> {
    type Error = DimensionError;

    /// Convert a dynamic-rank layout into a static-rank one.
    ///
    /// Fails with [`DimensionError`] if `value` does not have exactly `N`
    /// dimensions.
    fn try_from(value: &'a DynLayout) -> Result<NdLayout<N>, DimensionError> {
        let to_array = |dims: &[usize]| -> Result<[usize; N], DimensionError> {
            dims.try_into().map_err(|_| DimensionError {})
        };
        Ok(NdLayout {
            shape: to_array(value.shape())?,
            strides: to_array(value.strides())?,
        })
    }
}
/// Layout whose number of dimensions is only known at runtime.
#[derive(Debug, PartialEq)]
pub struct DynLayout {
    // The sizes of all dimensions, followed by their strides. The number of
    // dimensions is half the buffer length. Layouts with up to 4 dimensions
    // fit in the inline capacity of 8 without a heap allocation.
    shape_and_strides: SmallVec<[usize; 8]>,
}
impl Clone for DynLayout {
fn clone(&self) -> DynLayout {
DynLayout {
shape_and_strides: SmallVec::from_slice(self.shape_and_strides.as_slice()),
}
}
}
impl Layout for DynLayout {
    type Index<'a> = &'a [usize];
    type Indices = DynIndices;

    /// Number of elements: the product of the dimension sizes.
    fn len(&self) -> usize {
        self.shape().iter().product()
    }

    /// Return the offset of `index`, or `None` in either of these cases:
    /// - `index` has the wrong number of components;
    /// - any component is out of bounds.
    #[inline]
    fn try_offset(&self, index: Self::Index<'_>) -> Option<usize> {
        let shape = self.shape();
        let strides = self.strides();
        if index.len() != shape.len() {
            return None;
        }
        let mut offset = 0;
        for ((&idx, &size), &stride) in index.iter().zip(shape).zip(strides) {
            // Validate before multiplying. An arbitrarily large out-of-range
            // component could otherwise overflow (panicking in debug builds)
            // instead of yielding `None`.
            if idx >= size {
                return None;
            }
            offset += idx * stride;
        }
        Some(offset)
    }

    fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// The buffer holds `ndim` sizes followed by `ndim` strides.
    #[inline]
    fn ndim(&self) -> usize {
        self.shape_and_strides.len() / 2
    }

    #[inline]
    fn shape(&self) -> &[usize] {
        &self.shape_and_strides[0..self.ndim()]
    }

    /// Return the size of dimension `dim`.
    ///
    /// # Panics
    ///
    /// Panics if `dim >= self.ndim()`.
    #[inline]
    fn size(&self, dim: usize) -> usize {
        // Index through the shape half of the buffer so an out-of-range `dim`
        // panics instead of silently returning one of the strides.
        self.shape()[dim]
    }

    #[inline]
    fn strides(&self) -> &[usize] {
        &self.shape_and_strides[self.ndim()..]
    }

    #[inline]
    fn stride(&self, dim: usize) -> usize {
        self.shape_and_strides[self.ndim() + dim]
    }

    fn indices(&self) -> DynIndices {
        DynIndices::from_shape(self.shape())
    }
}
impl DynLayout {
    /// Create a contiguous layout for `shape`. The last dimension has stride 1
    /// and each earlier stride is the product of the sizes after it.
    pub fn from_shape(shape: &[usize]) -> DynLayout {
        DynLayout {
            shape_and_strides: Self::contiguous_shape_and_strides(shape),
        }
    }

    /// Create a layout from explicit `shape` and `strides`.
    ///
    /// # Errors
    ///
    /// Returns [`FromDataError::MayOverlap`] if `overlap` is
    /// [`OverlapPolicy::DisallowOverlap`] and `may_have_internal_overlap`
    /// reports that the layout may overlap.
    ///
    /// # Panics
    ///
    /// Panics if `shape` and `strides` have different lengths.
    pub fn try_from_shape_and_strides(
        shape: &[usize],
        strides: &[usize],
        overlap: OverlapPolicy,
    ) -> Result<DynLayout, FromDataError> {
        // Sizes and strides share one buffer that is split at its midpoint, so
        // mismatched lengths would silently mix sizes with strides.
        assert_eq!(
            shape.len(),
            strides.len(),
            "shape and strides must have the same length"
        );
        let mut shape_and_strides = SmallVec::with_capacity(shape.len() + strides.len());
        shape_and_strides.extend_from_slice(shape);
        shape_and_strides.extend_from_slice(strides);
        let layout = DynLayout { shape_and_strides };
        match overlap {
            OverlapPolicy::DisallowOverlap => {
                if may_have_internal_overlap(layout.shape(), layout.strides()) {
                    return Err(FromDataError::MayOverlap);
                }
            }
            OverlapPolicy::AllowOverlap => {}
        }
        Ok(layout)
    }

    /// Create a `DynLayout` with the same shape and strides as `layout`.
    /// Overlapping layouts are accepted.
    pub fn from_layout<L: Layout>(layout: &L) -> DynLayout {
        DynLayout::try_from_shape_and_strides(
            layout.shape().as_ref(),
            layout.strides().as_ref(),
            OverlapPolicy::AllowOverlap,
        )
        .expect("invalid layout")
    }

    /// Broadcast this layout to `to_shape`. Added leading dimensions and
    /// expanded size-1 dimensions get stride 0.
    ///
    /// # Panics
    ///
    /// Panics if this layout cannot be broadcast to `to_shape`.
    pub fn broadcast(&self, to_shape: &[usize]) -> DynLayout {
        assert!(
            self.can_broadcast_to(to_shape),
            "Cannot broadcast to specified shape"
        );
        let mut shape_and_strides = SmallVec::with_capacity(to_shape.len() * 2);
        shape_and_strides.extend(to_shape.iter().copied());
        shape_and_strides.extend(broadcast_strides(self.shape(), self.strides(), to_shape));
        DynLayout { shape_and_strides }
    }

    /// Move dimension `from` to position `to`, shifting the dimensions in
    /// between.
    ///
    /// # Panics
    ///
    /// Panics if `from` or `to` is not less than `self.ndim()`.
    pub fn move_axis(&mut self, from: usize, to: usize) {
        let ndim = self.ndim();
        assert!(from < ndim && to < ndim);
        let size = self.shape_and_strides.remove(from);
        // With one size removed, the strides now start at `ndim - 1`.
        let stride = self.shape_and_strides.remove(ndim - 1 + from);
        self.shape_and_strides.insert(to, size);
        // With the size re-inserted, the strides start at `ndim` again.
        self.shape_and_strides.insert(ndim + to, stride);
    }

    /// Slice this layout with `range`, returning the buffer range covered by
    /// the slice together with its layout.
    ///
    /// # Errors
    ///
    /// Returns an error if `range` has more items than this layout has
    /// dimensions, or if an item is invalid for its dimension.
    pub fn slice(&self, range: &[SliceItem]) -> Result<(Range<usize>, DynLayout), SliceError> {
        if self.ndim() < range.len() {
            return Err(SliceError::TooManyDims {
                ndim: self.ndim(),
                range_ndim: range.len(),
            });
        }
        // Each `Index` item removes one dimension. All other dims are kept.
        let out_dims = self.ndim()
            - range
                .iter()
                .filter(|item| matches!(item, SliceItem::Index(_)))
                .count();
        let mut shape_and_strides = smallvec![0; out_dims * 2];
        let (out_shape, out_strides) = shape_and_strides.as_mut_slice().split_at_mut(out_dims);
        let (_ndim, offset) =
            slice_layout(self.shape(), self.strides(), out_shape, out_strides, range)?;
        let layout = Self { shape_and_strides };
        Ok((offset..offset + layout.min_data_len(), layout))
    }

    /// Set the size of dimension `dim` to `new_size`, keeping its stride.
    ///
    /// # Panics
    ///
    /// Panics if `dim >= self.ndim()`.
    pub fn resize_dim(&mut self, dim: usize, new_size: usize) {
        let ndim = self.ndim();
        // Restrict indexing to the shape half of the buffer so an out-of-range
        // `dim` panics instead of overwriting a stride.
        self.shape_and_strides[..ndim][dim] = new_size;
    }

    /// Replace the strides with contiguous strides for the current shape.
    pub fn make_contiguous(&mut self) {
        self.shape_and_strides = Self::contiguous_shape_and_strides(self.shape());
    }

    /// Reorder dimensions so that output dimension `i` is the input dimension
    /// yielded at position `i` by `dims`. `dims` is not validated.
    fn permute_iter<I: Clone + Iterator<Item = usize>>(&mut self, dims: I) {
        let strides = self.strides();
        let shape = self.shape();
        let shape_iter = dims.clone().map(|dim| shape[dim]);
        let stride_iter = dims.map(|dim| strides[dim]);
        self.shape_and_strides = shape_iter.chain(stride_iter).collect();
    }

    /// Reorder dimensions in place so that output dimension `i` is input
    /// dimension `dims[i]`.
    ///
    /// # Panics
    ///
    /// Panics if `dims` is not a permutation of `0..self.ndim()`.
    pub fn permute(&mut self, dims: &[usize]) {
        assert!(
            is_valid_permutation(self.ndim(), dims),
            "permutation is invalid"
        );
        self.permute_iter(dims.iter().copied());
    }

    /// Return a copy with dimensions reordered as described in
    /// [`DynLayout::permute`].
    pub fn permuted(&self, dims: &[usize]) -> DynLayout {
        let mut permuted = self.clone();
        permuted.permute(dims);
        permuted
    }

    /// Reverse the order of the dimensions in place.
    pub fn transpose(&mut self) {
        self.permute_iter((0..self.ndim()).rev());
    }

    /// Return a copy with the order of the dimensions reversed.
    pub fn transposed(&self) -> DynLayout {
        let mut transposed = self.clone();
        transposed.transpose();
        transposed
    }

    /// Insert a size-1 dimension at position `dim`.
    ///
    /// The new stride is the largest existing stride multiplied by the size of
    /// the dimension that has it, or 1 for a zero-dimensional layout.
    ///
    /// # Panics
    ///
    /// Panics if `dim > self.ndim()`.
    pub fn insert_dim(&mut self, dim: usize) {
        let ndim = self.ndim();
        // Check up front. Otherwise a `dim` just past `ndim` would insert a size
        // into the strides region before a later insert panics, corrupting
        // the layout.
        assert!(
            dim <= ndim,
            "cannot insert dim {} into tensor with {} dims",
            dim,
            ndim
        );
        let new_size = 1;
        let (max_stride, size_for_max_stride) = self
            .strides()
            .iter()
            .copied()
            .zip(self.shape().iter().copied())
            .max_by_key(|(stride, _size)| *stride)
            .unwrap_or((1, 1));
        let new_stride = max_stride * size_for_max_stride;
        self.shape_and_strides.insert(dim, new_size);
        // After inserting the size, the strides start at `ndim + 1`.
        self.shape_and_strides.insert(ndim + 1 + dim, new_stride);
    }

    /// Return the buffer offset obtained by fixing the first `index.len()`
    /// dimensions at the given positions.
    ///
    /// # Panics
    ///
    /// Panics in either of these cases:
    /// - `index` has more entries than this layout has dimensions;
    /// - any entry is out of bounds.
    pub fn slice_offset<Idx: AsRef<[usize]>>(&self, index: Idx) -> usize {
        let index = index.as_ref();
        assert!(index.len() <= self.ndim());
        let shape = self.shape();
        let mut offset = 0;
        for i in 0..index.len() {
            assert!(
                index[i] < shape[i],
                "Invalid index {} for dim {}",
                index[i],
                i
            );
            offset += index[i] * self.stride(i)
        }
        offset
    }

    /// Return a copy with all size-1 dimensions removed.
    pub fn squeezed(&self) -> DynLayout {
        let shape = self.shape().iter().copied().filter(|&size| size != 1);
        let strides = self
            .shape()
            .iter()
            .zip(self.strides())
            .filter_map(|(&size, &stride)| if size != 1 { Some(stride) } else { None });
        DynLayout {
            shape_and_strides: shape.chain(strides).collect(),
        }
    }

    /// Return the shape as a fixed-size array.
    ///
    /// # Panics
    ///
    /// Panics if the layout does not have exactly `N` dimensions.
    pub fn dims<const N: usize>(&self) -> [usize; N] {
        assert!(
            self.ndim() == N,
            "Cannot extract {} dim tensor as {} dim array",
            self.ndim(),
            N
        );
        self.shape().try_into().unwrap()
    }

    /// Build the combined buffer for a contiguous layout of `shape`: the sizes
    /// followed by strides. The last stride is 1 and each earlier stride is
    /// the product of the sizes after it.
    fn contiguous_shape_and_strides(shape: &[usize]) -> SmallVec<[usize; 8]> {
        let mut strides_and_shape: SmallVec<[usize; 8]> = SmallVec::from_slice(shape);
        strides_and_shape.resize(shape.len() * 2, 0);
        let mut stride = 1;
        for i in (0..shape.len()).rev() {
            strides_and_shape[shape.len() + i] = stride;
            stride *= shape[i];
        }
        strides_and_shape
    }
}
impl<const N: usize> From<&NdLayout<N>> for DynLayout {
fn from(value: &NdLayout<N>) -> DynLayout {
DynLayout::from_layout(value)
}
}
impl<const N: usize> From<NdLayout<N>> for DynLayout {
fn from(value: NdLayout<N>) -> DynLayout {
DynLayout::from_layout(&value)
}
}
/// Layouts that can be constructed from shapes and strides and transformed
/// into new layouts.
pub trait MutLayout: Layout + Clone {
    /// Create a contiguous layout with the given shape.
    fn from_shape(shape: Self::Index<'_>) -> Self;

    /// Create a layout from explicit shape and strides.
    ///
    /// # Errors
    ///
    /// The implementations in this module return an error if `overlap` is
    /// [`OverlapPolicy::DisallowOverlap`] and the layout may overlap.
    fn from_shape_and_strides(
        shape: Self::Index<'_>,
        strides: Self::Index<'_>,
        overlap: OverlapPolicy,
    ) -> Result<Self, FromDataError>;

    /// Fix dimension `axis` at position `index`. Returns the buffer range of the
    /// resulting view and its layout, which no longer has that dimension.
    ///
    /// # Panics
    ///
    /// Panics if `axis >= self.ndim()` or `index >= self.size(axis)`.
    fn index_axis(&self, axis: usize, index: usize) -> (Range<usize>, <Self as RemoveDim>::Output)
    where
        Self: RemoveDim,
    {
        assert!(axis < self.ndim());
        assert!(index < self.size(axis));
        let layout = self.remove_dim(axis);
        let start_offset = self.stride(axis) * index;
        (start_offset..start_offset + layout.min_data_len(), layout)
    }

    /// Move dimension `from` to position `to`.
    fn move_axis(&mut self, from: usize, to: usize);

    /// Return a copy with the dimensions reordered according to `order`.
    fn permuted(&self, order: Self::Index<'_>) -> Self;

    /// Return a contiguous layout with shape `shape` that views the same data.
    ///
    /// # Errors
    ///
    /// Returns either of these errors:
    /// - [`ReshapeError::NotContiguous`] if this layout is not contiguous;
    /// - [`ReshapeError::LengthMismatch`] if `shape` has a different number of
    ///   elements.
    fn reshaped_for_view<S: IntoLayout>(&self, shape: S) -> Result<S::Layout, ReshapeError> {
        if !self.is_contiguous() {
            return Err(ReshapeError::NotContiguous);
        }
        self.reshaped_for_copy(shape)
    }

    /// Return a contiguous layout with shape `shape`. This layout's
    /// contiguity is not checked. Presumably the caller copies the data into
    /// contiguous order first.
    ///
    /// # Errors
    ///
    /// Returns [`ReshapeError::LengthMismatch`] if `shape` has a different
    /// number of elements.
    fn reshaped_for_copy<S: IntoLayout>(&self, shape: S) -> Result<S::Layout, ReshapeError> {
        let layout = shape.into_layout();
        if layout.len() != self.len() {
            return Err(ReshapeError::LengthMismatch);
        }
        Ok(layout)
    }

    /// Set the size of dimension `dim` to `size`, keeping its stride.
    fn resize_dim(&mut self, dim: usize, size: usize);

    /// Return a copy with the order of the dimensions reversed.
    fn transposed(&self) -> Self;

    /// Slice with `range`, producing an `M`-dimensional layout plus the buffer
    /// range it covers.
    fn slice<const M: usize>(
        &self,
        range: &[SliceItem],
    ) -> Result<(Range<usize>, NdLayout<M>), SliceError>;

    /// Slice with `range`, producing a dynamic-rank layout plus the buffer
    /// range it covers.
    fn slice_dyn(&self, range: &[SliceItem]) -> Result<(Range<usize>, DynLayout), SliceError>;

    /// Restrict dimension `axis` to `range`, keeping the strides. Returns the
    /// buffer range of the result (`0..0` if it is empty) and its layout.
    ///
    /// # Panics
    ///
    /// Panics in any of these cases:
    /// - `range.start > range.end`;
    /// - `axis >= self.ndim()`;
    /// - `range.end > self.size(axis)`.
    fn slice_axis(&self, axis: usize, range: Range<usize>) -> (Range<usize>, Self) {
        assert!(range.end >= range.start);
        // Validate the axis before reading its size or resizing it. Keep the
        // range within the existing dimension so the result never describes
        // elements outside this layout.
        assert!(axis < self.ndim());
        assert!(range.end <= self.size(axis));
        let mut sliced_layout = self.clone();
        sliced_layout.resize_dim(axis, range.len());
        let range = if sliced_layout.is_empty() {
            0..0
        } else {
            let start_offset = range.start * sliced_layout.stride(axis);
            let end_offset = start_offset + sliced_layout.min_data_len();
            start_offset..end_offset
        };
        (range, sliced_layout)
    }

    /// Return a dynamic-rank copy with all size-1 dimensions removed.
    fn squeezed(&self) -> DynLayout;

    /// Split along `axis` at position `mid`. Returns the buffer range and
    /// layout of the part before `mid`, then of the part after it.
    fn split(&self, axis: usize, mid: usize) -> ((Range<usize>, Self), (Range<usize>, Self));
}
pub trait BroadcastLayout<L: MutLayout> {
fn broadcast<S: IntoLayout<Layout = L>>(&self, shape: S) -> L;
}
impl<const N: usize, const M: usize> BroadcastLayout<NdLayout<M>> for NdLayout<N> {
fn broadcast<S: IntoLayout<Layout = NdLayout<M>>>(&self, shape: S) -> NdLayout<M> {
let shape: [usize; M] = shape.as_ref().try_into().unwrap();
self.broadcast(shape)
}
}
impl<const N: usize> BroadcastLayout<DynLayout> for NdLayout<N> {
fn broadcast<S: IntoLayout<Layout = DynLayout>>(&self, shape: S) -> DynLayout {
let dyn_layout: DynLayout = self.into();
dyn_layout.broadcast(shape.as_ref())
}
}
impl BroadcastLayout<DynLayout> for DynLayout {
fn broadcast<S: IntoLayout<Layout = DynLayout>>(&self, shape: S) -> DynLayout {
self.broadcast(shape.as_ref())
}
}
impl<const N: usize> BroadcastLayout<NdLayout<N>> for DynLayout {
fn broadcast<S: IntoLayout<Layout = NdLayout<N>>>(&self, shape: S) -> NdLayout<N> {
let dyn_broadcast = self.broadcast(shape.as_ref());
(&dyn_broadcast).try_into().unwrap()
}
}
impl<const N: usize> MutLayout for NdLayout<N> {
fn from_shape(shape: [usize; N]) -> Self {
Self::from_shape(shape)
}
fn from_shape_and_strides(
shape: Self::Index<'_>,
strides: Self::Index<'_>,
overlap: OverlapPolicy,
) -> Result<Self, FromDataError> {
Self::try_from_shape_and_strides(shape, strides, overlap)
}
fn move_axis(&mut self, from: usize, to: usize) {
assert!(from < N && to < N);
let mut dyn_layout = self.as_dyn();
dyn_layout.move_axis(from, to);
*self = NdLayout::try_from(&dyn_layout).unwrap();
}
fn permuted(&self, order: [usize; N]) -> NdLayout<N> {
self.permuted(order)
}
fn resize_dim(&mut self, dim: usize, size: usize) {
self.resize_dim(dim, size)
}
fn transposed(&self) -> NdLayout<N> {
self.transposed()
}
fn slice<const M: usize>(
&self,
range: &[SliceItem],
) -> Result<(Range<usize>, NdLayout<M>), SliceError> {
self.slice(range)
}
fn slice_dyn(&self, range: &[SliceItem]) -> Result<(Range<usize>, DynLayout), SliceError> {
self.as_dyn().slice(range)
}
fn squeezed(&self) -> DynLayout {
self.as_dyn().squeezed()
}
fn split(&self, axis: usize, mid: usize) -> ((Range<usize>, Self), (Range<usize>, Self)) {
assert!(axis < self.ndim());
assert!(mid <= self.size(axis));
let left_shape = std::array::from_fn(|i| if i == axis { mid } else { self.shape[i] });
let right_shape = std::array::from_fn(|i| {
if i == axis {
self.size(axis) - mid
} else {
self.shape[i]
}
});
let left = NdLayout {
shape: left_shape,
strides: self.strides,
};
let right = NdLayout {
shape: right_shape,
strides: self.strides,
};
let mid_offset = mid * self.strides[axis];
let left_offsets = 0..left.min_data_len();
let end_offset = self.min_data_len();
let right_offsets = if right.is_empty() {
end_offset..end_offset
} else {
mid_offset..end_offset
};
((left_offsets, left), (right_offsets, right))
}
}
impl MutLayout for DynLayout {
    fn from_shape(shape: &[usize]) -> Self {
        Self::from_shape(shape)
    }

    fn from_shape_and_strides(
        shape: &[usize],
        strides: &[usize],
        overlap: OverlapPolicy,
    ) -> Result<Self, FromDataError> {
        Self::try_from_shape_and_strides(shape, strides, overlap)
    }

    fn move_axis(&mut self, from: usize, to: usize) {
        self.move_axis(from, to)
    }

    fn permuted(&self, order: &[usize]) -> DynLayout {
        self.permuted(order)
    }

    fn resize_dim(&mut self, dim: usize, size: usize) {
        self.resize_dim(dim, size)
    }

    fn transposed(&self) -> DynLayout {
        self.transposed()
    }

    /// Slice dynamically, then convert the result to a static rank. Fails with
    /// `OutputDimsMismatch` if it does not have `M` dimensions.
    fn slice<const M: usize>(
        &self,
        range: &[SliceItem],
    ) -> Result<(Range<usize>, NdLayout<M>), SliceError> {
        let (offset_range, dyn_layout) = self.slice(range)?;
        match NdLayout::<M>::try_from(&dyn_layout) {
            Ok(nd_layout) => Ok((offset_range, nd_layout)),
            Err(_) => Err(SliceError::OutputDimsMismatch {
                actual: dyn_layout.ndim(),
                expected: M,
            }),
        }
    }

    fn slice_dyn(&self, range: &[SliceItem]) -> Result<(Range<usize>, DynLayout), SliceError> {
        self.slice(range)
    }

    fn squeezed(&self) -> DynLayout {
        self.squeezed()
    }

    fn split(&self, axis: usize, mid: usize) -> ((Range<usize>, Self), (Range<usize>, Self)) {
        assert!(axis < self.ndim());
        assert!(mid <= self.size(axis));

        // Both parts keep the original strides. Only the size of `axis` differs.
        let with_axis_size = |size: usize| {
            let mut part = self.clone();
            part.shape_and_strides[axis] = size;
            part
        };
        let left = with_axis_size(mid);
        let right = with_axis_size(self.size(axis) - mid);

        let end_offset = self.min_data_len();
        let left_offsets = 0..left.min_data_len();
        // An empty right part is given an empty range at the end.
        let right_start = if right.is_empty() {
            end_offset
        } else {
            mid * self.stride(axis)
        };
        ((left_offsets, left), (right_start..end_offset, right))
    }
}
/// Shapes that can be converted into a layout.
pub trait IntoLayout: AsRef<[usize]> {
    /// The layout type produced.
    type Layout: MutLayout;

    /// Create a layout with this shape.
    fn into_layout(self) -> Self::Layout;
}

/// A fixed-size array shape becomes a contiguous static-rank layout.
impl<const N: usize> IntoLayout for [usize; N] {
    type Layout = NdLayout<N>;

    #[inline]
    fn into_layout(self) -> NdLayout<N> {
        NdLayout::from_shape(self)
    }
}

/// A slice shape becomes a contiguous dynamic-rank layout.
impl IntoLayout for &[usize] {
    type Layout = DynLayout;

    #[inline]
    fn into_layout(self) -> DynLayout {
        DynLayout::from_shape(self)
    }
}
/// Layouts whose number of dimensions can change in place.
pub trait ResizeLayout: MutLayout {
    /// Insert a new dimension at position `index`. The `DynLayout`
    /// implementation inserts a size-1 dimension.
    fn insert_axis(&mut self, index: usize);
    /// Remove the dimension at position `index`. The `DynLayout`
    /// implementation requires it to have size 1.
    fn remove_axis(&mut self, index: usize);
    /// Merge adjacent dimensions where possible, reducing the number of
    /// dimensions.
    fn merge_axes(&mut self);
}
impl ResizeLayout for DynLayout {
    /// Insert a size-1 dimension at position `index`. See
    /// [`DynLayout::insert_dim`].
    fn insert_axis(&mut self, index: usize) {
        self.insert_dim(index)
    }

    /// Remove dimension `index`, which must have size 1.
    ///
    /// # Panics
    ///
    /// Panics if `index >= self.ndim()` or the dimension's size is not 1.
    fn remove_axis(&mut self, index: usize) {
        // Check the index before reading the size. An out-of-range index must
        // not read, or remove, an entry from the strides half of the buffer.
        assert!(index < self.ndim());
        assert!(self.size(index) == 1);
        self.shape_and_strides.remove(index);
        // With the size removed, `ndim()` is one smaller and the strides start
        // at the new `ndim()`.
        self.shape_and_strides.remove(self.ndim() + index);
    }

    /// Merge adjacent dimensions that can be traversed as one.
    ///
    /// Scanning from the innermost dimension outwards, an outer dimension is
    /// folded into the current inner one in either of these cases:
    /// - it has size 1;
    /// - its stride equals the inner stride times the inner size.
    ///
    /// A layout with no dimensions is left unchanged.
    fn merge_axes(&mut self) {
        if self.ndim() == 0 {
            return;
        }
        let mut shape = SmallVec::<[usize; 4]>::new();
        let mut strides = SmallVec::<[usize; 4]>::new();
        shape.push(self.size(self.ndim() - 1));
        strides.push(self.stride(self.ndim() - 1));
        for (&outer_size, &outer_stride) in
            self.shape().iter().zip(self.strides().iter()).rev().skip(1)
        {
            let inner_stride = strides.last().unwrap();
            let inner_size = shape.last().unwrap();
            let can_merge = outer_size == 1 || (outer_stride == inner_stride * inner_size);
            if can_merge {
                let prev_size = shape.last_mut().unwrap();
                *prev_size *= outer_size;
            } else {
                shape.push(outer_size);
                strides.push(outer_stride);
            }
        }
        // Dimensions were collected innermost-first.
        shape.reverse();
        strides.reverse();
        self.shape_and_strides = shape.iter().chain(strides.iter()).copied().collect();
    }
}
/// Conversion of a value into the index type of layout `L`.
pub trait AsIndex<L: Layout> {
    /// Return `self` as an index for layout `L`.
    fn as_index(&self) -> L::Index<'_>;
}

/// Any value viewable as a `usize` slice can index a [`DynLayout`].
impl<T: AsRef<[usize]>> AsIndex<DynLayout> for T {
    fn as_index(&self) -> &[usize] {
        <T as AsRef<[usize]>>::as_ref(self)
    }
}

/// An array of `N` indices is copied and used directly as an index into an
/// `NdLayout<N>`.
impl<const N: usize> AsIndex<NdLayout<N>> for [usize; N] {
    fn as_index(&self) -> [usize; N] {
        *self
    }
}
/// Layouts that can produce a copy of themselves with one dimension removed.
pub trait RemoveDim {
    /// Layout type of the result. `NdLayout<N>` maps to `NdLayout<N - 1>` and
    /// `DynLayout` maps to itself.
    type Output: MutLayout;
    /// Return a copy of this layout with dimension `dim` removed.
    fn remove_dim(&self, dim: usize) -> Self::Output;
}
impl RemoveDim for DynLayout {
    type Output = DynLayout;

    /// Return a copy of this layout with dimension `dim` removed.
    ///
    /// # Panics
    ///
    /// Panics if the layout has no dimensions or `dim >= self.ndim()`.
    fn remove_dim(&self, dim: usize) -> DynLayout {
        let ndim = self.ndim();
        assert!(ndim > 0, "cannot remove axis from tensor with 0 dims");
        // Without this check an out-of-range `dim` would silently drop the
        // last dimension instead of failing.
        assert!(
            dim < ndim,
            "dim {} out of range for tensor with {} dims",
            dim,
            ndim
        );
        let shape = self
            .shape()
            .iter()
            .enumerate()
            .filter(|&(i, _)| i != dim)
            .map(|(_, &size)| size);
        let strides = self
            .strides()
            .iter()
            .enumerate()
            .filter(|&(i, _)| i != dim)
            .map(|(_, &stride)| stride);
        DynLayout {
            shape_and_strides: shape.chain(strides).collect(),
        }
    }
}
macro_rules! impl_remove_dim {
($in_dims:expr, $out_dims:expr) => {
impl RemoveDim for NdLayout<$in_dims> {
type Output = NdLayout<$out_dims>;
fn remove_dim(&self, dim: usize) -> Self::Output {
let shape = std::array::from_fn(|i| {
if i < dim {
self.shape[i]
} else {
self.shape[i + 1]
}
});
let strides = std::array::from_fn(|i| {
if i < dim {
self.strides[i]
} else {
self.strides[i + 1]
}
});
NdLayout { shape, strides }
}
}
};
}
impl_remove_dim!(1, 0);
impl_remove_dim!(2, 1);
impl_remove_dim!(3, 2);
impl_remove_dim!(4, 3);
impl_remove_dim!(5, 4);
/// Slice a layout with a range of slice items.
///
/// `IdxCount` is the statically known number of dimensions that the range
/// removes by indexing, or `Unknown`. It selects the output layout type.
pub trait SliceWith<R: IntoSliceItems, IdxCount: OptionalUInt> {
    /// Layout type of the slice.
    type Layout: MutLayout;
    /// Return the buffer range and layout of the slice of `self` selected by
    /// `range`.
    fn slice_with(&self, range: R) -> Result<(Range<usize>, Self::Layout), SliceError>;
}
// Any layout can be sliced when the number of indexed dimensions is not known
// statically. The result is a dynamic-rank layout.
impl<R: IntoSliceItems, L: MutLayout> SliceWith<R, Unknown> for L {
    type Layout = DynLayout;

    fn slice_with(&self, range: R) -> Result<(Range<usize>, Self::Layout), SliceError> {
        let items = range.into_slice_items();
        self.slice_dyn(items.as_ref())
    }
}

// Slicing a static-rank layout without indexing any dimension keeps its rank.
impl<R: IntoSliceItems, const N: usize> SliceWith<R, U0> for NdLayout<N> {
    type Layout = NdLayout<N>;

    fn slice_with(&self, range: R) -> Result<(Range<usize>, Self::Layout), SliceError> {
        let items = range.into_slice_items();
        self.slice(items.as_ref())
    }
}
macro_rules! impl_slice_with_dynlayout {
($range_ndim:ty) => {
impl<R: IntoSliceItems> SliceWith<R, $range_ndim> for DynLayout {
type Layout = DynLayout;
fn slice_with(&self, range: R) -> Result<(Range<usize>, Self::Layout), SliceError> {
self.slice_dyn(range.into_slice_items().as_ref())
}
}
};
}
impl_slice_with_dynlayout!(U0);
impl_slice_with_dynlayout!(U1);
impl_slice_with_dynlayout!(U2);
impl_slice_with_dynlayout!(U3);
impl_slice_with_dynlayout!(U4);
impl_slice_with_dynlayout!(U5);
// Implement `SliceWith` for `NdLayout<$ndim>` when the range removes
// `$range_ndim` dimensions by indexing, giving an `NdLayout<$out_ndim>`.
// Accepts either a single `ndim, range_ndim, out_ndim` triple or a
// comma-separated list of parenthesized triples.
macro_rules! impl_slice_with {
    ($ndim:literal, $range_ndim:ty, $out_ndim:literal) => {
        impl<R: IntoSliceItems> SliceWith<R, $range_ndim> for NdLayout<$ndim> {
            type Layout = NdLayout<$out_ndim>;

            fn slice_with(&self, range: R) -> Result<(Range<usize>, Self::Layout), SliceError> {
                let items = range.into_slice_items();
                self.slice(items.as_ref())
            }
        }
    };
    ($(($ndim:literal, $range_ndim:ty, $out_ndim:literal)),+ $(,)?) => {
        $(impl_slice_with!($ndim, $range_ndim, $out_ndim);)+
    };
}

impl_slice_with!(
    (1, U1, 0),
    (2, U1, 1),
    (2, U2, 0),
    (3, U1, 2),
    (3, U2, 1),
    (3, U3, 0),
    (4, U1, 3),
    (4, U2, 2),
    (4, U3, 1),
    (4, U4, 0),
    (5, U1, 4),
    (5, U2, 3),
    (5, U3, 2),
    (5, U4, 1),
    (5, U5, 0),
);
#[cfg(test)]
mod tests {
use std::ops::Range;
use super::OverlapPolicy;
use crate::errors::{ReshapeError, SliceError};
use crate::layout::{DynLayout, Layout, MutLayout, NdLayout, ResizeLayout};
use crate::SliceItem;
fn layout_with_strides<const N: usize>(shape: [usize; N], strides: [usize; N]) -> NdLayout<N> {
NdLayout::try_from_shape_and_strides(shape, strides, OverlapPolicy::AllowOverlap).unwrap()
}
#[test]
fn test_is_broadcast() {
let layout = DynLayout::from_shape(&[5, 5]);
assert!(!layout.is_broadcast());
let layout = DynLayout::from_shape(&[5, 0]);
assert!(!layout.is_broadcast());
let layout =
DynLayout::try_from_shape_and_strides(&[5, 5], &[0, 0], OverlapPolicy::AllowOverlap)
.unwrap();
assert!(layout.is_broadcast());
}
#[test]
fn test_try_from_shape_and_strides() {
struct Case<'a> {
shape: &'a [usize],
strides: &'a [usize],
}
let cases = [
Case {
shape: &[10, 10],
strides: &[10, 1],
},
Case {
shape: &[10, 10],
strides: &[10, 0],
},
];
for case in cases {
let layout = DynLayout::try_from_shape_and_strides(
case.shape,
case.strides,
OverlapPolicy::AllowOverlap,
)
.unwrap();
assert_eq!(layout.shape(), case.shape);
assert_eq!(layout.strides(), case.strides);
}
}
#[test]
fn test_index_axis() {
struct Case {
layout: NdLayout<2>,
axis: usize,
index: usize,
expected: (usize, NdLayout<1>), }
let cases = [
Case {
layout: NdLayout::from_shape([3, 4]),
axis: 0,
index: 1,
expected: (4, layout_with_strides([4], [1])),
},
Case {
layout: NdLayout::from_shape([3, 4]),
axis: 1,
index: 2,
expected: (2, layout_with_strides([3], [4])),
},
];
for Case {
layout,
axis,
index,
expected,
} in cases
{
let (expected_start, expected_layout) = expected;
let (offsets, sliced_layout) = layout.index_axis(axis, index);
assert_eq!(sliced_layout, expected_layout);
assert_eq!(offsets.start, expected_start);
assert_eq!(offsets.len(), expected_layout.min_data_len());
let (_, sliced_layout_dyn) = layout.as_dyn().index_axis(axis, index);
assert_eq!(sliced_layout_dyn, expected_layout.as_dyn());
}
}
#[test]
#[should_panic(expected = "axis < self.ndim()")]
fn test_index_axis_invalid_axis() {
NdLayout::from_shape([2, 3]).index_axis(2, 0);
}
#[test]
#[should_panic(expected = "index < self.size(axis)")]
fn test_index_axis_invalid_index() {
NdLayout::from_shape([2, 3]).index_axis(0, 3);
}
#[test]
fn test_move_axis() {
let mut layout = DynLayout::from_shape(&[2, 4, 8]);
assert_eq!(layout.strides(), [32, 8, 1]);
layout.move_axis(1, 0);
assert_eq!(layout.shape(), [4, 2, 8]);
assert_eq!(layout.strides(), [8, 32, 1]);
layout.move_axis(0, 1);
assert_eq!(layout.shape(), [2, 4, 8]);
assert_eq!(layout.strides(), [32, 8, 1]);
layout.move_axis(2, 1);
assert_eq!(layout.shape(), [2, 8, 4]);
assert_eq!(layout.strides(), [32, 1, 8]);
}
#[test]
#[should_panic]
fn test_move_axis_invalid_from() {
let mut layout = DynLayout::from_shape(&[2, 4, 8]);
layout.move_axis(3, 0);
}
#[test]
#[should_panic]
fn test_move_axis_invalid_to() {
let mut layout = DynLayout::from_shape(&[2, 4, 8]);
layout.move_axis(0, 3);
}
#[test]
#[should_panic(expected = "permutation is invalid")]
fn test_permute_invalid_len() {
let mut layout = DynLayout::from_shape(&[5, 5]);
layout.permute(&[1, 0, 3]);
}
#[test]
#[should_panic(expected = "permutation is invalid")]
fn test_permute_too_few_dims() {
let mut layout = DynLayout::from_shape(&[5, 5]);
layout.permute(&[1]);
}
#[test]
#[should_panic(expected = "permutation is invalid")]
fn test_permute_repeated_dims() {
let mut layout = DynLayout::from_shape(&[5, 5]);
layout.permute(&[1, 1]);
}
#[test]
fn test_reshaped() {
struct Case<'a> {
layout: DynLayout,
new_shape: &'a [usize],
for_copy: bool,
error: Option<ReshapeError>,
}
let cases = [
Case {
layout: DynLayout::from_shape(&[2, 2]),
new_shape: &[4],
for_copy: false,
error: None,
},
Case {
layout: DynLayout::from_shape(&[2, 2]).transposed(),
new_shape: &[4],
for_copy: false,
error: Some(ReshapeError::NotContiguous),
},
Case {
layout: DynLayout::from_shape(&[2, 2]),
new_shape: &[3],
for_copy: false,
error: Some(ReshapeError::LengthMismatch),
},
Case {
layout: DynLayout::from_shape(&[2, 2]).transposed(),
new_shape: &[4],
for_copy: true,
error: None,
},
Case {
layout: DynLayout::from_shape(&[2, 2]),
new_shape: &[3],
for_copy: false,
error: Some(ReshapeError::LengthMismatch),
},
];
for Case {
layout,
new_shape,
for_copy,
error,
} in cases
{
let reshaped = if for_copy {
layout.reshaped_for_copy(new_shape)
} else {
layout.reshaped_for_view(new_shape)
};
assert_eq!(reshaped.as_ref().err(), error.as_ref());
if let Ok(new_layout) = reshaped {
assert_eq!(new_layout.shape(), new_shape);
}
}
}
#[test]
fn test_squeezed() {
let layout = DynLayout::from_shape(&[1, 1, 10, 20]);
let squeezed = layout.squeezed();
assert_eq!(squeezed.shape(), &[10, 20]);
assert_eq!(squeezed.strides(), &[20, 1]);
}
/// Slicing a single axis shrinks that axis and leaves the strides untouched.
/// It also returns the range of offsets spanned by the sliced view.
#[test]
fn test_slice_axis() {
    /// Slices a fresh layout of `shape` along `axis` and checks the result.
    fn check(
        shape: &[usize],
        axis: usize,
        range: Range<usize>,
        sliced_shape: &[usize],
        offsets: Range<usize>,
    ) {
        let layout = DynLayout::from_shape(shape);
        let (offset_range, sliced_layout) = layout.slice_axis(axis, range);
        assert_eq!(sliced_layout.shape(), sliced_shape);
        assert_eq!(sliced_layout.strides(), layout.strides());
        assert_eq!(offset_range, offsets);
    }

    // Columns 2..4 of a 3x5 layout (strides [5, 1]). The view begins at
    // offset 2, and its last element is at 2 * 5 + 3 = 13.
    check(&[3, 5], 1, 2..4, &[3, 2], 2..14);
}
/// Out-of-bounds indices, out-of-bounds ranges and negative steps are each
/// reported with the matching `SliceError` variant.
#[test]
fn test_slice_invalid() {
    /// Slices `layout` with `ranges` and expects exactly `expected` as the error.
    fn check(layout: DynLayout, ranges: &[SliceItem], expected: SliceError) {
        assert_eq!(layout.slice(ranges), Err(expected));
    }

    let fresh_layout = || DynLayout::from_shape(&[3, 5]);

    // Positive index past the end of axis 0.
    check(
        fresh_layout(),
        &[SliceItem::Index(4), SliceItem::Index(0)],
        SliceError::InvalidIndex {
            axis: 0,
            index: 4,
            size: 3,
        },
    );

    // Bounded range whose end exceeds axis 0.
    check(
        fresh_layout(),
        &[SliceItem::Range((1..4).into()), SliceItem::Index(0)],
        SliceError::InvalidRange {
            axis: 0,
            range: (1..4).into(),
            size: 3,
        },
    );

    // Negative index that reaches before the start of axis 0.
    check(
        fresh_layout(),
        &[SliceItem::Index(-4)],
        SliceError::InvalidIndex {
            axis: 0,
            index: -4,
            size: 3,
        },
    );

    // Open-ended range starting past the end of axis 0.
    check(
        fresh_layout(),
        &[SliceItem::Range((4..).into()), SliceItem::Index(0)],
        SliceError::InvalidRange {
            axis: 0,
            range: (4..).into(),
            size: 3,
        },
    );

    // Negative step on axis 1.
    check(
        fresh_layout(),
        &[SliceItem::full_range(), SliceItem::range(0, None, -1)],
        SliceError::InvalidStep { axis: 1, step: -1 },
    );
}
/// The per-dimension accessors `size` and `stride` agree with the full
/// `shape()` and `strides()` lists.
#[test]
fn test_size_stride() {
    let layout = DynLayout::from_shape(&[10, 20, 30]);

    let sizes: Vec<usize> = (0..layout.ndim()).map(|dim| layout.size(dim)).collect();
    let strides: Vec<usize> = (0..layout.ndim()).map(|dim| layout.stride(dim)).collect();

    assert_eq!(sizes, layout.shape());
    assert_eq!(strides, layout.strides());
}
/// Splitting a layout at `mid` along `axis` should produce two halves with
/// these properties:
/// - they keep the parent's strides;
/// - their sizes along `axis` are `mid` and `size - mid`;
/// - their offset ranges stay within the parent's minimum data length.
///
/// Both `NdLayout` and `DynLayout` are checked.
#[test]
fn test_split() {
    /// Splits `layout` and validates both halves against the parent.
    fn check_split<L: MutLayout>(layout: L, axis: usize, mid: usize) {
        let ((left_offsets, left_layout), (right_offsets, right_layout)) =
            layout.split(axis, mid);

        // Both halves view the parent's storage, so strides are unchanged.
        assert_eq!(left_layout.strides(), layout.strides());
        assert_eq!(right_layout.strides(), layout.strides());

        // Each offset range covers exactly the data its half needs.
        assert_eq!(left_offsets.len(), left_layout.min_data_len());
        assert_eq!(right_offsets.len(), right_layout.min_data_len());

        // Neither half may reach past the end of the parent's data.
        let parent_len = layout.min_data_len();
        for offsets in [&left_offsets, &right_offsets] {
            assert!(offsets.start <= parent_len && offsets.end <= parent_len);
        }

        for dim in 0..layout.ndim() {
            let (expected_left, expected_right) = if dim == axis {
                (mid, layout.size(dim) - mid)
            } else {
                (layout.size(dim), layout.size(dim))
            };
            assert_eq!(left_layout.size(dim), expected_left);
            assert_eq!(right_layout.size(dim), expected_right);
        }
    }

    struct Case {
        shape: [usize; 2],
        strides: Option<[usize; 2]>,
        axis: usize,
        mid: usize,
    }

    // Split a contiguous [4, 2] layout along each axis at every `mid` in
    // `0..size`.
    let base_shape: [usize; 2] = [4, 2];
    let mut cases: Vec<Case> = (0..base_shape.len())
        .flat_map(|axis| {
            (0..base_shape[axis]).map(move |mid| Case {
                shape: base_shape,
                strides: None,
                axis,
                mid,
            })
        })
        .collect();

    cases.extend([
        // Empty layout.
        Case {
            shape: [0, 0],
            strides: None,
            axis: 0,
            mid: 0,
        },
        // Layout with a zero stride, split so that the right half is empty.
        Case {
            shape: [1, 4],
            strides: Some([10, 0]),
            axis: 0,
            mid: 1,
        },
    ]);

    for Case {
        shape,
        strides,
        axis,
        mid,
    } in cases
    {
        match strides {
            Some(strides) => {
                let nd = NdLayout::try_from_shape_and_strides(
                    shape,
                    strides,
                    OverlapPolicy::AllowOverlap,
                )
                .unwrap();
                check_split(nd, axis, mid);

                let dynamic = DynLayout::try_from_shape_and_strides(
                    shape.as_slice(),
                    strides.as_slice(),
                    OverlapPolicy::AllowOverlap,
                )
                .unwrap();
                check_split(dynamic, axis, mid);
            }
            None => {
                check_split(NdLayout::from_shape(shape), axis, mid);
                check_split(DynLayout::from_shape(shape.as_slice()), axis, mid);
            }
        }
    }
}
/// `merge_axes` collapses adjacent axes whose strides allow it into a single
/// axis. Axes that cannot be merged are left alone, and size-1 axes do not
/// block a merge.
#[test]
fn test_merge_axes() {
    struct Case<'a> {
        shape: &'a [usize],
        strides: &'a [usize],
        merged_shape: &'a [usize],
        merged_strides: &'a [usize],
    }

    let cases = [
        // Zero-dimensional layout: nothing to merge.
        Case {
            shape: &[],
            strides: &[],
            merged_shape: &[],
            merged_strides: &[],
        },
        // A single axis is returned unchanged.
        Case {
            shape: &[10],
            strides: &[2],
            merged_shape: &[10],
            merged_strides: &[2],
        },
        // Contiguous 2D collapses to 1D.
        Case {
            shape: &[10, 10],
            strides: &[10, 1],
            merged_shape: &[100],
            merged_strides: &[1],
        },
        // Transposed strides prevent merging.
        Case {
            shape: &[10, 10],
            strides: &[1, 10],
            merged_shape: &[10, 10],
            merged_strides: &[1, 10],
        },
        // A leading size-1 axis is dropped; the remaining axes do not merge.
        Case {
            shape: &[1, 10, 10],
            strides: &[10, 1, 10],
            merged_shape: &[10, 10],
            merged_strides: &[1, 10],
        },
        // Interior size-1 axes do not block a merge...
        Case {
            shape: &[2, 1, 1, 2],
            strides: &[2, 2, 2, 1],
            merged_shape: &[4],
            merged_strides: &[1],
        },
        // ...regardless of the strides recorded for those size-1 axes.
        Case {
            shape: &[2, 1, 1, 2],
            strides: &[2, 4, 4, 1],
            merged_shape: &[4],
            merged_strides: &[1],
        },
    ];

    for case in &cases {
        let mut layout = DynLayout::try_from_shape_and_strides(
            case.shape,
            case.strides,
            OverlapPolicy::AllowOverlap,
        )
        .unwrap();
        layout.merge_axes();
        assert_eq!(layout.shape(), case.merged_shape);
        assert_eq!(layout.strides(), case.merged_strides);
    }
}
}