use crate::prelude_dev::*;
/// A single basic-indexing item, as accepted by `dim_slice` / the `s!` macro.
///
/// Each `Slice` or `Select` consumes one axis of the input layout, `Insert`
/// adds a new axis of length 1, and `Ellipsis` stands for as many full slices
/// (`..`) as are needed to cover the axes not consumed explicitly.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Indexer {
    /// Narrow one axis with a `start:stop:step` slice (applied via `dim_narrow`).
    Slice(SliceI),
    /// Pick one position along an axis and remove that axis (applied via
    /// `dim_select`); negative values count from the end of the axis.
    Select(isize),
    /// Insert a new axis of length 1 (applied via `dim_insert`); `None`
    /// converts into this variant, and it is re-exported as `NewAxis`.
    Insert,
    /// Placeholder for "all remaining axes"; at most one may appear.
    Ellipsis,
}
// Re-export the special variants so callers can write `Ellipsis` / `NewAxis`
// directly inside `s![...]`.
pub use Indexer::Ellipsis;
pub use Indexer::Insert as NewAxis;
impl<R> From<R> for Indexer
where
R: Into<SliceI>,
{
fn from(slice: R) -> Self {
Self::Slice(slice.into())
}
}
impl From<Option<usize>> for Indexer {
fn from(opt: Option<usize>) -> Self {
match opt {
Some(_) => panic!("Option<T> should not be used in Indexer."),
None => Self::Insert,
}
}
}
macro_rules! impl_from_int_into_indexer {
($($t:ty),*) => {
$(
impl From<$t> for Indexer {
fn from(index: $t) -> Self {
Self::Select(index as isize)
}
}
)*
};
}
impl_from_int_into_indexer!(usize, isize, u32, i32, u64, i64);
macro_rules! impl_into_axes_index {
($($t:ty),*) => {
$(
impl TryFrom<$t> for AxesIndex<Indexer> {
type Error = Error;
fn try_from(index: $t) -> Result<Self> {
Ok(AxesIndex::Val(index.try_into()?))
}
}
impl<const N: usize> TryFrom<[$t; N]> for AxesIndex<Indexer> {
type Error = Error;
fn try_from(index: [$t; N]) -> Result<Self> {
let index = index.iter().map(|v| v.clone().into()).collect::<Vec<_>>();
Ok(AxesIndex::Vec(index))
}
}
impl TryFrom<Vec<$t>> for AxesIndex<Indexer> {
type Error = Error;
fn try_from(index: Vec<$t>) -> Result<Self> {
let index = index.iter().map(|v| v.clone().into()).collect::<Vec<_>>();
Ok(AxesIndex::Vec(index))
}
}
)*
};
}
impl_into_axes_index!(usize, isize, u32, i32, u64, i64);
impl_into_axes_index!(Option<usize>);
impl_into_axes_index!(
Slice<isize>,
core::ops::Range<isize>,
core::ops::RangeFrom<isize>,
core::ops::RangeTo<isize>,
core::ops::Range<usize>,
core::ops::RangeFrom<usize>,
core::ops::RangeTo<usize>,
core::ops::Range<i32>,
core::ops::RangeFrom<i32>,
core::ops::RangeTo<i32>,
core::ops::RangeFull
);
impl_from_tuple_to_axes_index!(Indexer);
/// Indexing operations that keep the number of dimensions unchanged.
pub trait IndexerPreserveAPI: Sized {
    /// Narrow `axis` to the elements selected by `slice`.
    ///
    /// Bounds follow Python slice semantics:
    /// - a negative `axis` counts from the last dimension;
    /// - a missing `step` defaults to 1;
    /// - negative `start`/`stop` count from the end of the axis for either step sign
    ///   (so with a negative step, `stop = -1` means "the last element"; leave `stop`
    ///   as `None` to run through index 0);
    /// - out-of-range bounds are clipped, and an empty selection gives length 0.
    ///
    /// The returned layout has that axis' length replaced, its stride multiplied by
    /// `step`, and the offset moved to the first selected element (the offset is left
    /// unchanged when nothing is selected). The full slice `..` and axes of length 0
    /// return a clone of the input.
    ///
    /// # Errors
    ///
    /// - `ValueOutOfRange` if `axis` is out of bounds;
    /// - `InvalidValue` if the step is zero;
    /// - any error from constructing the new layout.
    fn dim_narrow(&self, axis: isize, slice: SliceI) -> Result<Self>;
}

impl<D> IndexerPreserveAPI for Layout<D>
where
    D: DimDevAPI,
{
    fn dim_narrow(&self, axis: isize, slice: SliceI) -> Result<Self> {
        // normalize negative axis
        let axis = if axis < 0 { self.ndim() as isize + axis } else { axis };
        rstsr_pattern!(axis, 0..self.ndim() as isize, ValueOutOfRange)?;
        let axis = axis as usize;

        // `..` selects everything: nothing to do
        if slice == Slice::new(None, None, None) {
            return Ok(self.clone());
        }

        let step = slice.step().unwrap_or(1);
        rstsr_assert!(step != 0, InvalidValue)?;

        let len_prev = self.shape()[axis] as isize;
        if len_prev == 0 {
            return Ok(self.clone());
        }

        // Turn a user-supplied bound into an absolute index: negative values count
        // from the end, then the result is clipped into `[lo, hi]`.
        let resolve = |bound: isize, lo: isize, hi: isize| {
            let bound = if bound < 0 { len_prev + bound } else { bound };
            bound.clamp(lo, hi)
        };

        // `start` is the first selected index, `len_new` the number of selected elements.
        let (start, len_new) = if step > 0 {
            let start = slice.start().map_or(0, |v| resolve(v, 0, len_prev));
            let stop = slice.stop().map_or(len_prev, |v| resolve(v, 0, len_prev));
            let len_new = if stop > start { (stop - start + step - 1) / step } else { 0 };
            (start, len_new)
        } else {
            // For a negative step, `-1` is only used internally as "one before index 0";
            // user-supplied negative bounds are always resolved from the end first.
            let start = slice.start().map_or(len_prev - 1, |v| resolve(v, -1, len_prev - 1));
            let stop = slice.stop().map_or(-1, |v| resolve(v, -1, len_prev - 1));
            let len_new = if start > stop { (stop - start + step + 1) / step } else { 0 };
            (start, len_new)
        };

        let mut shape = self.shape().clone();
        let mut stride = self.stride().clone();
        // When at least one element is selected, `start` lies in `[0, len_prev)`, so the
        // new offset addresses a real element. For an empty selection `start` may be
        // `len_prev` or `-1`, which could push the offset out of range (or below zero
        // before the `usize` cast), so the offset is kept as is.
        let offset = if len_new > 0 {
            (self.offset() as isize + stride[axis] * start) as usize
        } else {
            self.offset()
        };
        shape[axis] = len_new as usize;
        stride[axis] *= step;
        Self::new(shape, stride, offset)
    }
}
/// Indexing operations that remove exactly one dimension.
pub trait IndexerSmallerOneAPI {
    /// Output dimension type (one axis fewer than the input).
    type DOut: DimDevAPI;

    /// Fix `axis` at position `index` and drop that axis.
    ///
    /// Negative `axis` and `index` count from the end. The stride of the removed
    /// axis times `index` is added to the offset.
    ///
    /// # Errors
    ///
    /// `ValueOutOfRange` if `axis` or `index` is out of bounds, or any error from
    /// constructing / converting the new layout.
    fn dim_select(&self, axis: isize, index: isize) -> Result<Layout<Self::DOut>>;

    /// Drop `axis`, which must have length 1.
    ///
    /// # Errors
    ///
    /// `ValueOutOfRange` if `axis` is out of bounds, `InvalidValue` if the axis
    /// length is not 1.
    fn dim_eliminate(&self, axis: isize) -> Result<Layout<Self::DOut>>;

    /// Drop `axis` whatever its length, keeping the offset unchanged (for a
    /// non-empty axis this is the same as selecting index 0).
    ///
    /// # Errors
    ///
    /// `ValueOutOfRange` if `axis` is out of bounds.
    fn dim_chop(&self, axis: isize) -> Result<Layout<Self::DOut>>;
}

impl<D> IndexerSmallerOneAPI for Layout<D>
where
    D: DimDevAPI + DimSmallerOneAPI,
    D::SmallerOne: DimDevAPI,
{
    type DOut = <D as DimSmallerOneAPI>::SmallerOne;

    fn dim_select(&self, axis: isize, index: isize) -> Result<Layout<Self::DOut>> {
        let axis = if axis < 0 { self.ndim() as isize + axis } else { axis };
        rstsr_pattern!(axis, 0..self.ndim() as isize, ValueOutOfRange)?;
        let axis = axis as usize;

        // pull the selected axis out of the shape/stride lists
        let mut shape_new = self.shape().as_ref().to_vec();
        let mut stride_new = self.stride().as_ref().to_vec();
        let d = shape_new.remove(axis);
        let s = stride_new.remove(axis);

        let idx = if index < 0 { d as isize + index } else { index };
        rstsr_pattern!(idx, 0..d as isize, ValueOutOfRange)?;

        // fixing the index folds its contribution into the offset
        let offset = (self.offset() as isize + s * idx) as usize;
        Layout::<IxD>::new(shape_new, stride_new, offset)?.into_dim()
    }

    fn dim_eliminate(&self, axis: isize) -> Result<Layout<Self::DOut>> {
        let axis = if axis < 0 { self.ndim() as isize + axis } else { axis };
        rstsr_pattern!(axis, 0..self.ndim() as isize, ValueOutOfRange)?;
        if self.shape()[axis as usize] != 1 {
            rstsr_raise!(InvalidValue, "Dimension to be eliminated is not 1.")?;
        }
        // a verified length-1 axis can simply be chopped off
        IndexerSmallerOneAPI::dim_chop(self, axis)
    }

    fn dim_chop(&self, axis: isize) -> Result<Layout<Self::DOut>> {
        let axis = if axis < 0 { self.ndim() as isize + axis } else { axis };
        rstsr_pattern!(axis, 0..self.ndim() as isize, ValueOutOfRange)?;
        let axis = axis as usize;

        // copy of `values` without the entry at position `skip`
        fn without<T: Copy>(values: &[T], skip: usize) -> Vec<T> {
            values.iter().enumerate().filter(|&(i, _)| i != skip).map(|(_, &v)| v).collect()
        }

        let shape = without(self.shape().as_ref(), axis);
        let stride = without(self.stride().as_ref(), axis);
        Layout::<IxD>::new(shape, stride, self.offset())?.into_dim()
    }
}
pub trait IndexerLargerOneAPI {
type DOut: DimDevAPI;
fn dim_insert(&self, axis: isize) -> Result<Layout<Self::DOut>>;
}
impl<D> IndexerLargerOneAPI for Layout<D>
where
D: DimDevAPI + DimLargerOneAPI,
D::LargerOne: DimDevAPI,
{
type DOut = <D as DimLargerOneAPI>::LargerOne;
fn dim_insert(&self, axis: isize) -> Result<Layout<Self::DOut>> {
let axis = if axis < 0 { self.ndim() as isize + axis + 1 } else { axis };
rstsr_pattern!(axis, 0..(self.ndim() + 1) as isize, ValueOutOfRange)?;
let axis = axis as usize;
let is_f_prefer = self.f_prefer();
let mut shape = self.shape().as_ref().to_vec();
let mut stride = self.stride().as_ref().to_vec();
let offset = self.offset();
if is_f_prefer {
if axis == 0 {
shape.insert(0, 1);
stride.insert(0, 1);
} else {
shape.insert(axis, 1);
stride.insert(axis, stride[axis - 1]);
}
} else if axis == self.ndim() {
shape.push(1);
stride.push(1);
} else {
shape.insert(axis, 1);
stride.insert(axis, stride[axis]);
}
let layout = Layout::new(shape, stride, offset)?;
return layout.into_dim();
}
}
pub trait IndexerDynamicAPI: IndexerPreserveAPI {
fn dim_slice(&self, indexers: &[Indexer]) -> Result<Layout<IxD>>;
fn dim_split_at(&self, axis: isize) -> Result<(Layout<IxD>, Layout<IxD>)>;
fn dim_split_axes(&self, axes: &[isize]) -> Result<(Layout<IxD>, Layout<IxD>)>;
}
impl<D> IndexerDynamicAPI for Layout<D>
where
D: DimDevAPI,
{
fn dim_slice(&self, indexers: &[Indexer]) -> Result<Layout<IxD>> {
let shape = self.shape().as_ref().to_vec();
let stride = self.stride().as_ref().to_vec();
let mut layout = Layout::new(shape, stride, self.offset)?;
let mut indexers = indexers.to_vec();
let mut counter_slice = 0;
let mut counter_select = 0;
let mut idx_ellipsis = None;
for (n, indexer) in indexers.iter().enumerate() {
match indexer {
Indexer::Slice(_) => counter_slice += 1,
Indexer::Select(_) => counter_select += 1,
Indexer::Ellipsis => match idx_ellipsis {
Some(_) => rstsr_raise!(InvalidValue, "Only one ellipsis indexer allowed.")?,
None => idx_ellipsis = Some(n),
},
_ => {},
}
}
rstsr_pattern!(counter_slice + counter_select, 0..=self.ndim(), ValueOutOfRange)?;
let n_ellipsis = self.ndim() - counter_slice - counter_select;
if n_ellipsis == 0 {
if let Some(idx) = idx_ellipsis {
indexers.remove(idx);
}
} else if let Some(idx_ellipsis) = idx_ellipsis {
indexers[idx_ellipsis] = SliceI::new(None, None, None).into();
if n_ellipsis > 1 {
for _ in 1..n_ellipsis {
indexers.insert(idx_ellipsis, SliceI::new(None, None, None).into());
}
}
} else {
for _ in 0..n_ellipsis {
indexers.push(SliceI::new(None, None, None).into());
}
}
let mut cur_dim = self.ndim() as isize;
for indexer in indexers.iter().rev() {
match indexer {
Indexer::Slice(slice) => {
cur_dim -= 1;
layout = layout.dim_narrow(cur_dim, *slice)?;
},
Indexer::Select(index) => {
cur_dim -= 1;
layout = layout.dim_select(cur_dim, *index)?;
},
Indexer::Insert => {
layout = layout.dim_insert(cur_dim)?;
},
_ => rstsr_raise!(InvalidValue, "Invalid indexer found : {:?}", indexer)?,
}
}
rstsr_assert!(cur_dim == 0, Miscellaneous, "Internal program error in indexer.")?;
return Ok(layout);
}
fn dim_split_at(&self, axis: isize) -> Result<(Layout<IxD>, Layout<IxD>)> {
let axis = if axis < 0 { self.ndim() as isize + axis } else { axis };
rstsr_pattern!(axis, 0..=self.ndim() as isize, ValueOutOfRange)?;
let axis = axis as usize;
let shape = self.shape().as_ref().to_vec();
let stride = self.stride().as_ref().to_vec();
let offset = self.offset();
let (shape1, shape2) = shape.split_at(axis);
let (stride1, stride2) = stride.split_at(axis);
let layout1 = unsafe { Layout::new_unchecked(shape1.to_vec(), stride1.to_vec(), offset) };
let layout2 = unsafe { Layout::new_unchecked(shape2.to_vec(), stride2.to_vec(), offset) };
return Ok((layout1, layout2));
}
fn dim_split_axes(&self, axes: &[isize]) -> Result<(Layout<IxD>, Layout<IxD>)> {
let axes_update = normalize_axes_index(axes.into(), self.ndim(), false, false)?
.into_iter()
.map(|axis| axis as usize)
.collect::<Vec<usize>>();
let axes_rest = (0..self.ndim()).filter(|&axis| !axes_update.contains(&axis)).collect::<Vec<_>>();
let offset = self.offset();
let shape_axes = axes_update.iter().map(|&axis| self.shape()[axis]).collect::<Vec<_>>();
let strides_axes = axes_update.iter().map(|&axis| self.stride()[axis]).collect::<Vec<_>>();
let layout_axes = Layout::new(shape_axes, strides_axes, offset)?;
let shape_rest = axes_rest.iter().map(|&axis| self.shape()[axis]).collect::<Vec<_>>();
let strides_rest = axes_rest.iter().map(|&axis| self.stride()[axis]).collect::<Vec<_>>();
let layout_rest = Layout::new(shape_rest, strides_rest, offset)?;
return Ok((layout_axes, layout_rest));
}
}
/// Build a `Slice<isize>` from one to three bounds.
///
/// - `slice!(stop)` is `Slice::new(None, stop, None)`;
/// - `slice!(start, stop)` is `Slice::new(start, stop, None)`;
/// - `slice!(start, stop, step)` is `Slice::new(start, stop, step)`.
///
/// Every argument is forwarded to `Slice::new` unchanged, so `None` leaves that
/// bound open. A trailing comma is accepted.
#[macro_export]
macro_rules! slice {
    ($stop:expr $(,)?) => {{
        use $crate::layout::slice::Slice;
        Slice::<isize>::from(Slice::new(None, $stop, None))
    }};
    ($start:expr, $stop:expr $(,)?) => {{
        use $crate::layout::slice::Slice;
        Slice::<isize>::from(Slice::new($start, $stop, None))
    }};
    ($start:expr, $stop:expr, $step:expr $(,)?) => {{
        use $crate::layout::slice::Slice;
        Slice::<isize>::from(Slice::new($start, $stop, $step))
    }};
}

/// Build a borrowed list of indexers, converting each argument with `.into()`.
///
/// The element type is inferred from the call site (e.g. `&[Indexer]` when passed to
/// `dim_slice`), so integers, ranges, `None` (new axis) and `Ellipsis` can be mixed.
/// A trailing comma is accepted.
#[macro_export]
macro_rules! s {
    [$($slc:expr),* $(,)?] => {
        [$(($slc).into()),*].as_ref()
    };
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_slice() {
        let t = 3_usize;
        let s = slice!(1, 2, t);
        assert_eq!(s.start(), Some(1));
        assert_eq!(s.stop(), Some(2));
        assert_eq!(s.step(), Some(3));
    }

    #[test]
    fn test_slice_at_dim() {
        let l = Layout::new([2, 3, 4], [1, 10, 100], 0).unwrap();

        // start 10 is clipped to the last index (2) of axis 1 for a negative step
        let s = slice!(10, 1, -1);
        let l1 = l.dim_narrow(1, s).unwrap();
        println!("{l1:?}");
        assert_eq!(l1.shape(), &[2, 1, 4]);
        assert_eq!(l1.offset(), 20);

        // index -2 on an axis of length 3 is position 1 (stride 10)
        let l2 = l.dim_select(1, -2).unwrap();
        println!("{l2:?}");
        assert_eq!(l2.shape(), &[2, 4]);
        assert_eq!(l2.offset(), 10);

        let l3 = l.dim_insert(1).unwrap();
        println!("{l3:?}");
        assert_eq!(l3.shape(), &[2, 1, 3, 4]);

        let l = Layout::new([2, 3, 4], [100, 10, 1], 0).unwrap();
        let l3 = l.dim_insert(1).unwrap();
        println!("{l3:?}");
        assert_eq!(l3.shape(), &[2, 1, 3, 4]);

        let l4 = l.dim_slice(s![Indexer::Ellipsis, 1..3, None, 2]).unwrap();
        let l4 = l4.into_dim::<Ix3>().unwrap();
        println!("{l4:?}");
        assert_eq!(l4.shape(), &[2, 2, 1]);
        assert_eq!(l4.offset(), 12);

        let l5 = l.dim_slice(s![None, 1, None, 1..3]).unwrap();
        let l5 = l5.into_dim::<Ix4>().unwrap();
        println!("{l5:?}");
        assert_eq!(l5.shape(), &[1, 1, 2, 4]);
        assert_eq!(l5.offset(), 110);
    }

    #[test]
    fn test_slice_with_stride() {
        let l = Layout::new([24], [1], 0).unwrap();
        let b = l.dim_narrow(0, slice!(5, 15, 2)).unwrap();
        assert_eq!(b, Layout::new([5], [2], 5).unwrap());
        let b = l.dim_narrow(0, slice!(5, 16, 2)).unwrap();
        assert_eq!(b, Layout::new([6], [2], 5).unwrap());
        let b = l.dim_narrow(0, slice!(15, 5, -2)).unwrap();
        assert_eq!(b, Layout::new([5], [-2], 15).unwrap());
        let b = l.dim_narrow(0, slice!(15, 4, -2)).unwrap();
        assert_eq!(b, Layout::new([6], [-2], 15).unwrap());
    }

    #[test]
    fn test_expand_dims() {
        let l = Layout::<Ix3>::new([2, 3, 4], [1, 10, 100], 0).unwrap();
        let l1 = l.dim_insert(0).unwrap();
        println!("{l1:?}");
        assert_eq!(l1.shape(), &[1, 2, 3, 4]);
        let l2 = l.dim_insert(1).unwrap();
        println!("{l2:?}");
        assert_eq!(l2.shape(), &[2, 1, 3, 4]);
        let l3 = l.dim_insert(3).unwrap();
        println!("{l3:?}");
        assert_eq!(l3.shape(), &[2, 3, 4, 1]);
        // negative positions count from the end of the output: -1 appends
        let l4 = l.dim_insert(-1).unwrap();
        println!("{l4:?}");
        assert_eq!(l4.shape(), &[2, 3, 4, 1]);
        let l5 = l.dim_insert(-4).unwrap();
        println!("{l5:?}");
        assert_eq!(l5.shape(), &[1, 2, 3, 4]);
    }
}