use core::fmt::Debug;
use core::hash::Hash;
use crate::dim::Dims;
use crate::layout::{Dense, Layout, Strided};
use crate::shape::{DynRank, Shape};
/// Array layout mapping.
///
/// A mapping holds an array shape and translates multidimensional indices into
/// memory offsets through one stride per dimension.
pub trait Mapping: Clone + Debug + Default + Eq + Hash + Send + Sync {
    /// Array shape type.
    type Shape: Shape;

    /// Array layout type, whose mapping for `Self::Shape` is `Self`.
    type Layout: Layout<Mapping<Self::Shape> = Self>;

    /// Returns `true` if the array elements are stored contiguously in memory.
    fn is_contiguous(&self) -> bool;

    /// Returns the array shape.
    fn shape(&self) -> &Self::Shape;

    /// Returns the stride of the specified dimension, i.e. the offset step
    /// between consecutive indices along that dimension.
    fn stride(&self, index: usize) -> isize;

    /// Returns the number of elements in the specified dimension.
    #[inline]
    fn dim(&self, index: usize) -> usize {
        self.shape().dim(index)
    }

    /// Returns the number of elements in each dimension (dynamic-rank shapes only).
    #[inline]
    fn dims(&self) -> &[usize]
    where
        Self: Mapping<Shape = DynRank>,
    {
        self.shape().dims()
    }

    /// Returns the number of dimensions.
    #[inline]
    fn rank(&self) -> usize {
        self.shape().rank()
    }

    /// Returns the number of elements, as reported by the shape.
    #[inline]
    fn len(&self) -> usize {
        self.shape().len()
    }

    /// Returns `true` if the shape reports no elements.
    #[inline]
    fn is_empty(&self) -> bool {
        self.shape().is_empty()
    }

    /// Calls `f(dimension, stride)` once for every dimension.
    ///
    /// The visiting order is implementation-defined.
    #[doc(hidden)]
    fn for_each_stride<F: FnMut(usize, isize)>(&self, f: F);

    /// Returns the stride of the innermost (last) dimension.
    #[doc(hidden)]
    fn inner_stride(&self) -> isize;

    /// Returns the memory offset of the `index`-th element, counting elements
    /// with the last dimension varying fastest.
    #[doc(hidden)]
    fn linear_offset(&self, index: usize) -> isize;

    /// Builds a mapping whose dimension `i` is dimension `perm[i]` of `mapping`.
    #[doc(hidden)]
    fn permute<M: Mapping>(mapping: &M, perm: &[usize]) -> Self;

    /// Builds a mapping from `mapping` with a new leading dimension of the
    /// given size and stride.
    #[doc(hidden)]
    fn prepend_dim<M: Mapping>(mapping: &M, size: usize, stride: isize) -> Self;

    /// Converts `mapping` into this mapping type, keeping its shape and strides.
    #[doc(hidden)]
    fn remap<M: Mapping>(mapping: &M) -> Self;

    /// Builds a mapping from `mapping` with dimension `index` removed.
    #[doc(hidden)]
    fn remove_dim<M: Mapping>(mapping: &M, index: usize) -> Self;

    /// Returns a mapping of the same layout type for `new_shape`.
    #[doc(hidden)]
    fn reshape<S: Shape>(&self, new_shape: S) -> <Self::Layout as Layout>::Mapping<S>;

    /// Builds a mapping from `mapping` with dimension `index` resized to `new_size`.
    #[doc(hidden)]
    fn resize_dim<M: Mapping>(mapping: &M, index: usize, new_size: usize) -> Self;

    /// Returns a mutable reference to the array shape.
    #[doc(hidden)]
    fn shape_mut(&mut self) -> &mut Self::Shape;

    /// Builds a mapping from `mapping` with the dimension order reversed.
    #[doc(hidden)]
    fn transpose<M: Mapping<Shape: Shape<Reverse = Self::Shape>>>(mapping: &M) -> Self;

    /// Returns the memory offset of the element at a multidimensional index,
    /// i.e. the sum of `stride * position` over all dimensions.
    ///
    /// The rank and the per-dimension bounds are only checked with `debug_assert!`.
    #[doc(hidden)]
    #[inline]
    fn offset(&self, index: &[usize]) -> isize {
        debug_assert!(index.len() == self.rank(), "invalid rank");
        let mut total = 0;
        self.for_each_stride(|dim_index, stride| {
            let pos = index[dim_index];
            debug_assert!(pos < self.dim(dim_index), "index out of bounds");
            total += stride * pos as isize;
        });
        total
    }
}
/// Mapping for a dense layout, where the strides follow from the shape alone.
///
/// The last dimension has stride 1, and the stride of every other dimension is
/// the product of the sizes of all dimensions after it.
#[derive(PartialEq, Eq, Hash, Default, Debug)]
pub struct DenseMapping<S>
where
    S: Shape,
{
    /// Array shape; strides are computed from it on demand.
    shape: S,
}
/// Mapping for a strided layout, storing an explicit stride for each dimension.
#[derive(PartialEq, Eq, Hash, Debug)]
pub struct StridedMapping<S>
where
    S: Shape,
{
    /// Array shape.
    shape: S,
    /// Stride of each dimension, one entry per dimension of `shape`.
    strides: S::Dims<isize>,
}
impl<S: Shape> DenseMapping<S> {
#[inline]
pub fn new(shape: S) -> Self {
Self { shape }
}
}
impl<S> Clone for DenseMapping<S>
where
    S: Shape,
{
    #[inline]
    fn clone(&self) -> Self {
        let shape = self.shape.clone();
        Self { shape }
    }

    /// Clones `source` into `self`, letting the shape reuse its existing storage.
    #[inline]
    fn clone_from(&mut self, source: &Self) {
        let target = &mut self.shape;
        target.clone_from(&source.shape);
    }
}
/// A dense mapping is `Copy` whenever its shape is.
impl<S> Copy for DenseMapping<S> where S: Shape + Copy {}
impl<S: Shape> Mapping for DenseMapping<S> {
type Shape = S;
type Layout = Dense;
#[inline]
fn is_contiguous(&self) -> bool {
true
}
#[inline]
fn shape(&self) -> &S {
&self.shape
}
#[inline]
fn stride(&self, index: usize) -> isize {
assert!(index < self.rank(), "invalid dimension");
let mut stride = 1;
for i in index + 1..self.rank() {
stride *= self.dim(i);
}
stride as isize
}
#[inline]
fn for_each_stride<F: FnMut(usize, isize)>(&self, mut f: F) {
let mut stride = 1;
for i in (0..self.rank()).rev() {
f(i, stride as isize);
stride *= self.dim(i);
}
}
#[inline]
fn inner_stride(&self) -> isize {
if S::RANK == Some(0) { 0 } else { 1 }
}
#[inline]
fn linear_offset(&self, index: usize) -> isize {
debug_assert!(index < self.len(), "index out of bounds");
index as isize
}
#[inline]
fn permute<M: Mapping>(mapping: &M, perm: &[usize]) -> Self {
assert!(perm.len() == mapping.rank(), "invalid permutation");
for i in 0..mapping.rank() {
assert!(perm[i] == i, "invalid permutation");
}
Self::remap(mapping)
}
#[inline]
fn prepend_dim<M: Mapping>(mapping: &M, size: usize, stride: isize) -> Self {
assert!(M::Layout::IS_DENSE, "invalid layout");
assert!(stride == mapping.len() as isize, "invalid stride");
Self::new(mapping.shape().prepend_dim(size))
}
#[inline]
fn remap<M: Mapping>(mapping: &M) -> Self {
assert!(mapping.is_contiguous(), "mapping not contiguous");
Self::new(mapping.shape().with_dims(Shape::from_dims))
}
#[inline]
fn remove_dim<M: Mapping>(mapping: &M, index: usize) -> Self {
assert!(M::Layout::IS_DENSE, "invalid layout");
assert!(index == 0, "invalid dimension");
Self::new(mapping.shape().remove_dim(index))
}
#[inline]
fn reshape<R: Shape>(&self, new_shape: R) -> DenseMapping<R> {
DenseMapping::new(self.shape.reshape(new_shape))
}
#[inline]
fn resize_dim<M: Mapping>(mapping: &M, index: usize, new_size: usize) -> Self {
assert!(M::Layout::IS_DENSE, "invalid layout");
assert!(index == 0, "invalid dimension");
Self::new(mapping.shape().resize_dim(index, new_size))
}
#[inline]
fn shape_mut(&mut self) -> &mut S {
&mut self.shape
}
#[inline]
fn transpose<M: Mapping<Shape: Shape<Reverse = S>>>(mapping: &M) -> Self {
assert!(mapping.rank() < 2 && M::Layout::IS_DENSE, "invalid layout");
Self::new(mapping.shape().reverse())
}
}
impl<S: Shape> StridedMapping<S> {
#[inline]
pub fn new(shape: S, strides: &[isize]) -> Self {
assert!(shape.rank() == strides.len(), "length mismatch");
Self { shape, strides: TryFrom::try_from(strides).expect("invalid rank") }
}
#[inline]
pub fn strides(&self) -> &[isize] {
self.strides.as_ref()
}
}
impl<S> Default for StridedMapping<S>
where
    S: Shape,
{
    /// Creates a mapping with the default shape and one default-initialized
    /// stride per dimension of that shape.
    #[inline]
    fn default() -> Self {
        let shape = S::default();
        let strides: S::Dims<isize> = S::Dims::new(shape.rank());
        Self { shape, strides }
    }
}
impl<S: Shape> Clone for StridedMapping<S> {
#[inline]
fn clone(&self) -> Self {
Self { shape: self.shape.clone(), strides: self.strides.clone() }
}
#[inline]
fn clone_from(&mut self, source: &Self) {
self.shape.clone_from(&source.shape);
self.strides.clone_from(&source.strides);
}
}
/// A strided mapping is `Copy` whenever both its shape and its stride storage are.
impl<S> Copy for StridedMapping<S>
where
    S: Shape + Copy,
    S::Dims<isize>: Copy,
{
}
impl<S: Shape> Mapping for StridedMapping<S> {
    type Shape = S;
    type Layout = Strided;

    /// Returns `true` if every stride equals the stride a dense layout of the
    /// same shape would have (last dimension 1, each earlier one the product of
    /// the later sizes).
    ///
    /// NOTE(review): a dimension of size 1 must also carry the dense stride to
    /// pass, even though its stride never contributes to an offset.
    #[inline]
    fn is_contiguous(&self) -> bool {
        let mut stride = 1;
        for i in (0..self.rank()).rev() {
            if self.strides.as_ref()[i] != stride {
                return false;
            }
            stride *= self.dim(i) as isize;
        }
        true
    }

    #[inline]
    fn shape(&self) -> &S {
        &self.shape
    }

    /// # Panics
    ///
    /// Panics if `index` is not less than the rank.
    #[inline]
    fn stride(&self, index: usize) -> isize {
        assert!(index < self.rank(), "invalid dimension");
        self.strides.as_ref()[index]
    }

    /// Visits the dimensions from first to last.
    #[inline]
    fn for_each_stride<F: FnMut(usize, isize)>(&self, mut f: F) {
        for i in 0..self.rank() {
            f(i, self.strides.as_ref()[i])
        }
    }

    /// Returns the stride of the last dimension, or 0 for rank 0.
    #[inline]
    fn inner_stride(&self) -> isize {
        if self.rank() > 0 { self.strides.as_ref()[self.rank() - 1] } else { 0 }
    }

    /// Decomposes `index` into per-dimension positions, last dimension fastest,
    /// and sums `stride * position`.
    #[inline]
    fn linear_offset(&self, index: usize) -> isize {
        debug_assert!(index < self.len(), "index out of bounds");
        let mut dividend = index;
        let mut offset = 0;
        for i in (0..self.rank()).rev() {
            offset += self.strides.as_ref()[i] * (dividend % self.dim(i)) as isize;
            dividend /= self.dim(i);
        }
        offset
    }

    /// Builds a mapping whose dimension `i` (size and stride) is dimension
    /// `perm[i]` of `mapping`.
    ///
    /// # Panics
    ///
    /// Panics if `perm` is not a permutation of `0..mapping.rank()`.
    #[inline]
    fn permute<M: Mapping>(mapping: &M, perm: &[usize]) -> Self {
        let rank = mapping.rank();
        assert!(perm.len() == rank, "invalid permutation");
        // `rank` in-range, pairwise distinct entries form a permutation of
        // `0..rank`. Distinctness is checked pairwise rather than with a `usize`
        // bitmask, because `1 << p` and `usize::MAX << rank` overflow once
        // `rank >= usize::BITS`. The quadratic cost is negligible for typical
        // (small) ranks.
        for (i, &p) in perm.iter().enumerate() {
            assert!(p < rank && !perm[..i].contains(&p), "invalid permutation");
        }
        let mut shape = S::new(rank);
        let mut strides = S::Dims::new(rank);
        shape.with_mut_dims(|dims| {
            // Temporarily store the inverse permutation in `dims` to scatter strides.
            for i in 0..rank {
                dims[perm[i]] = i;
            }
            mapping.for_each_stride(|i, stride| strides.as_mut()[dims[i]] = stride);
            for i in 0..rank {
                dims[i] = mapping.dim(perm[i]);
            }
        });
        Self { shape, strides }
    }

    /// Adds a leading dimension with the given size and stride; the existing
    /// strides are kept.
    #[inline]
    fn prepend_dim<M: Mapping>(mapping: &M, size: usize, stride: isize) -> Self {
        let mut strides = S::Dims::new(mapping.rank() + 1);
        strides.as_mut()[0] = stride;
        mapping.for_each_stride(|i, stride| strides.as_mut()[i + 1] = stride);
        Self { shape: mapping.shape().prepend_dim(size), strides }
    }

    /// Copies the shape and strides of any mapping.
    #[inline]
    fn remap<M: Mapping>(mapping: &M) -> Self {
        let mut strides = S::Dims::new(mapping.rank());
        mapping.for_each_stride(|i, stride| strides.as_mut()[i] = stride);
        Self { shape: mapping.shape().with_dims(Shape::from_dims), strides }
    }

    /// # Panics
    ///
    /// Panics if `index` is not less than the rank of `mapping`.
    #[inline]
    fn remove_dim<M: Mapping>(mapping: &M, index: usize) -> Self {
        assert!(index < mapping.rank(), "invalid dimension");
        let mut strides = S::Dims::new(mapping.rank() - 1);
        mapping.for_each_stride(|i, stride| {
            if i < index {
                strides.as_mut()[i] = stride;
            } else if i > index {
                strides.as_mut()[i - 1] = stride;
            }
        });
        Self { shape: mapping.shape().remove_dim(index), strides }
    }

    /// Computes strides for `new_shape` that address the same elements.
    ///
    /// Old and new dimensions are matched from the innermost outwards by their
    /// running element counts. Old dimensions merged into one group must be
    /// contiguous relative to each other.
    ///
    /// # Panics
    ///
    /// Panics with "memory layout not compatible" if a merged group is not
    /// contiguous and the array is non-empty.
    ///
    /// NOTE(review): a size-1 dimension inside a merged group must also carry
    /// the contiguous stride. For example, shape [2, 1, 3] with strides
    /// [3, 100, 1] reshaped to [6] is rejected.
    #[inline]
    fn reshape<R: Shape>(&self, new_shape: R) -> StridedMapping<R> {
        let new_shape = self.shape.reshape(new_shape);
        let mut new_strides = R::Dims::new(new_shape.rank());
        let mut old_len = 1usize;
        let mut new_len = 1usize;
        let mut old_stride = 1;
        let mut new_stride = 1;
        let mut valid_layout = true;
        let mut j = new_shape.rank();
        for i in (0..self.rank()).rev() {
            if old_len == new_len {
                // Group boundary: start a new group from this dimension's stride.
                old_stride = self.strides.as_ref()[i];
                new_stride = old_stride;
            } else {
                // Inside a group: this dimension must continue the previous one.
                valid_layout &= old_stride == self.strides.as_ref()[i];
            }
            old_len *= self.dim(i);
            old_stride *= self.dim(i) as isize;
            while j > 0 {
                if new_len * new_shape.dim(j - 1) > old_len {
                    break;
                }
                j -= 1;
                new_strides.as_mut()[j] = new_stride;
                new_len *= new_shape.dim(j);
                new_stride *= new_shape.dim(j) as isize;
            }
        }
        // Remaining (outer) new dimensions continue from the last stride.
        while j > 0 {
            j -= 1;
            new_strides.as_mut()[j] = new_stride;
            new_len *= new_shape.dim(j);
            new_stride *= new_shape.dim(j) as isize;
        }
        assert!(new_len == 0 || valid_layout, "memory layout not compatible");
        StridedMapping { shape: new_shape, strides: new_strides }
    }

    /// Resizes dimension `index` and keeps all strides unchanged.
    #[inline]
    fn resize_dim<M: Mapping>(mapping: &M, index: usize, new_size: usize) -> Self {
        let mut strides = S::Dims::new(mapping.rank());
        mapping.for_each_stride(|i, stride| strides.as_mut()[i] = stride);
        Self { shape: mapping.shape().resize_dim(index, new_size), strides }
    }

    #[inline]
    fn shape_mut(&mut self) -> &mut S {
        &mut self.shape
    }

    /// Reverses the order of both the dimensions and the strides.
    #[inline]
    fn transpose<M: Mapping<Shape: Shape<Reverse = S>>>(mapping: &M) -> Self {
        let mut strides = S::Dims::new(mapping.rank());
        mapping.for_each_stride(|i, stride| strides.as_mut()[mapping.rank() - 1 - i] = stride);
        Self { shape: mapping.shape().reverse(), strides }
    }
}