#![allow(unused_parens)]
use crate::order::{ColumnMajor, Order, RowMajor};
use crate::view::{DenseView, StridedView};
use std::ops::{Bound, Range, RangeFrom, RangeFull, RangeInclusive, RangeTo, RangeToInclusive};
use std::ops::{Index, IndexMut};
use std::slice::{self, SliceIndex};
/// Result of resolving a single dimension's index against that dimension's size.
///
/// * `Range` — a half-open range of positions; the dimension is kept in the
///   resulting view.
/// * `Scalar` — a single position; the dimension is dropped from the
///   resulting view.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DimInfo {
    Range(Range<usize>),
    Scalar(usize),
}
/// An index for one dimension: a scalar position (`usize`) or a range of positions.
pub trait DimIndex: Clone {
    /// `true` when the index selects the whole dimension.
    ///
    /// In this module only `RangeFull` (`..`) sets this.
    const FULL: bool;
    /// `true` for range indices, which keep the dimension in the resulting
    /// view; `false` for scalar indices, which drop it.
    const RANGE: bool;

    /// Resolves this index against a dimension of length `limit`.
    ///
    /// # Panics
    ///
    /// The implementations in this module panic when the index does not fit
    /// within `0..limit`.
    fn dim_info(self, limit: usize) -> DimInfo;
}
/// A complete index over `N` dimensions: a single [`DimIndex`] (`N == 1`) or
/// an `N`-tuple of them, interpreted for memory order `O`.
pub trait IndexMap<const N: usize, O: Order> {
    /// `true` when every component is a full range (`..`).
    const FULL: bool;
    /// For the implementations in this module: the length of the run of
    /// components made of full ranges (`..`) followed by at most one other
    /// range. The run is counted from the front for `ColumnMajor` tuples and
    /// from the back for `RowMajor` tuples.
    ///
    /// NOTE(review): presumably the number of output dimensions that stay
    /// contiguous in memory — confirm against the view code.
    const CONT: usize;
    /// Number of range components, which is also the number of `dims`/`shape`
    /// entries that `view_info` writes.
    const RANK: usize;

    /// Resolves every component against its dimension size and records the
    /// geometry of the selected sub-view.
    ///
    /// * `dims` — receives, for each range component in order, the number of
    ///   the dimension it came from (counted starting at `dim`).
    /// * `shape` — receives, for each range component in order, the range length.
    /// * `start` — receives, for every component, its first selected position
    ///   (the scalar value or the range start).
    /// * `limits` — the size of each indexed dimension, one per component.
    /// * `dim` — the number assigned to the first component's dimension;
    ///   tuple implementations pass `dim + 1` when recursing into the tail.
    ///
    /// # Panics
    ///
    /// Panics if a component does not fit within its entry in `limits`, or if
    /// the output slices are too short.
    fn view_info(
        &self,
        dims: &mut [usize],
        shape: &mut [usize],
        start: &mut [usize],
        limits: &[usize],
        dim: usize,
    );
}
/// Indexing of a [`StridedView`] by a value of type `Self`.
///
/// In this module it is implemented for `[usize; N]`, which selects one
/// element, and for `usize` and the slice range types on views with `M = 0`
/// (`DenseView`). The latter forward to slice indexing on the view's `Deref`
/// target.
pub trait ViewIndex<T, const N: usize, const M: usize, O: Order> {
    /// What indexing yields; may be unsized (for range indices it is a slice).
    type Output: ?Sized;

    /// Returns a shared reference to the indexed part of `view`.
    fn index(
        self,
        view: &StridedView<T, N, M, O>,
    ) -> &Self::Output;

    /// Returns a mutable reference to the indexed part of `view`.
    fn index_mut(
        self,
        view: &mut StridedView<T, N, M, O>,
    ) -> &mut Self::Output;
}
/// A scalar index selects one position and contributes no dimension to the
/// resulting view (`RANGE` is `false`).
impl DimIndex for usize {
    const RANGE: bool = false;
    const FULL: bool = false;

    /// Returns `DimInfo::Scalar(self)`.
    ///
    /// # Panics
    ///
    /// Panics if `self >= limit`.
    fn dim_info(self, limit: usize) -> DimInfo {
        if self < limit {
            DimInfo::Scalar(self)
        } else {
            panic_bounds_check(self, limit)
        }
    }
}
macro_rules! impl_dim_index {
($type:ty, $full:tt) => {
impl DimIndex for $type {
const FULL: bool = $full;
const RANGE: bool = true;
fn dim_info(self, limit: usize) -> DimInfo {
DimInfo::Range(slice::range(self, ..limit))
}
}
};
}
impl_dim_index!((Bound<usize>, Bound<usize>), false);
impl_dim_index!(Range<usize>, false);
impl_dim_index!(RangeFrom<usize>, false);
impl_dim_index!(RangeFull, true);
impl_dim_index!(RangeInclusive<usize>, false);
impl_dim_index!(RangeTo<usize>, false);
impl_dim_index!(RangeToInclusive<usize>, false);
/// A single [`DimIndex`] indexes exactly one dimension, whatever the order `O`.
impl<O: Order, X: DimIndex> IndexMap<1, O> for X {
    const FULL: bool = X::FULL;
    const CONT: usize = if X::RANGE { 1 } else { 0 };
    const RANK: usize = if X::RANGE { 1 } else { 0 };

    /// Writes the start position into `start[0]`. For a range index, it also
    /// writes `dim` into `dims[0]` and the range length into `shape[0]`.
    fn view_info(
        &self,
        dims: &mut [usize],
        shape: &mut [usize],
        start: &mut [usize],
        limits: &[usize],
        dim: usize,
    ) {
        let first = match self.clone().dim_info(limits[0]) {
            DimInfo::Scalar(pos) => pos,
            DimInfo::Range(span) => {
                dims[0] = dim;
                shape[0] = span.end - span.start;
                span.start
            }
        };
        start[0] = first;
    }
}
/// Implements [`IndexMap`] for `$n`-tuples of [`DimIndex`] types, once for
/// [`ColumnMajor`] and once for [`RowMajor`].
///
/// Arguments:
/// * `$n` — tuple arity.
/// * `($x, ..)` — type parameters of the first `$n - 1` elements.
/// * `($y, ..)` — type parameters of the last `$n - 1` elements.
/// * `$last` — type parameter of the last element.
/// * `($vars, ..)` — field indices `1..$n`, used to build the tail tuple.
///
/// For `$n == 2`, `($x)` and `($y)` are a single parenthesized type rather
/// than a tuple, so they resolve to the one-dimensional blanket impl. This is
/// presumably why the file enables `#![allow(unused_parens)]`.
///
/// The associated constants are defined recursively:
/// * for `ColumnMajor`, by splitting off the last element (`$last`), so
///   `CONT` counts from the front;
/// * for `RowMajor`, by splitting off the first element (`X`), so `CONT`
///   counts from the back.
///
/// `view_info` does not depend on the order and is generated by the internal
/// `@view_info` rule.
macro_rules! impl_index_map {
    // Internal rule: emits `view_info` for `IndexMap<$n, $order>`. It resolves
    // the first element (`self.0`), then recurses into the tail tuple. The tail
    // is built by cloning fields `$vars`, and its types `($y, ..)` implement
    // `IndexMap<{$n - 1}, $order>`.
    (@view_info $n:tt, $order:ty, ($($y:tt),+), ($($vars:tt),+)) => {
        fn view_info(
            &self,
            dims: &mut [usize],
            shape: &mut [usize],
            start: &mut [usize],
            limits: &[usize],
            dim: usize,
        ) {
            start[0] = match self.0.clone().dim_info(limits[0]) {
                DimInfo::Range(r) => {
                    // The first element keeps its dimension: it owns output
                    // slot 0, and the tail fills the slots after it.
                    <($($y),+) as IndexMap<{$n - 1}, $order>>::view_info(
                        &($(self.$vars.clone()),+),
                        &mut dims[1..],
                        &mut shape[1..],
                        &mut start[1..],
                        &limits[1..],
                        dim + 1,
                    );
                    dims[0] = dim;
                    shape[0] = r.end - r.start;
                    r.start
                }
                DimInfo::Scalar(s) => {
                    // A scalar drops its dimension, so the tail fills the
                    // output slots starting at slot 0.
                    <($($y),+) as IndexMap<{$n - 1}, $order>>::view_info(
                        &($(self.$vars.clone()),+),
                        dims,
                        shape,
                        &mut start[1..],
                        &limits[1..],
                        dim + 1,
                    );
                    s
                }
            };
        }
    };
    ($n:tt, ($($x:tt),+), ($($y:tt),+), $last:tt, ($($vars:tt),+)) => {
        impl<$($x: DimIndex),+, $last: DimIndex> IndexMap<$n, ColumnMajor> for ($($x),+, $last) {
            const FULL: bool = <($($x),+) as IndexMap<{$n - 1}, ColumnMajor>>::FULL && $last::FULL;
            const CONT: usize = <($($x),+) as IndexMap<{$n - 1}, ColumnMajor>>::CONT
                + (<($($x),+) as IndexMap<{$n - 1}, ColumnMajor>>::FULL && $last::RANGE) as usize;
            const RANK: usize =
                <($($x),+) as IndexMap::<{$n - 1}, ColumnMajor>>::RANK + $last::RANGE as usize;

            impl_index_map!(@view_info $n, ColumnMajor, ($($y),+), ($($vars),+));
        }

        impl<X: DimIndex, $($y: DimIndex),+> IndexMap<$n, RowMajor> for (X, $($y),+) {
            const FULL: bool = <($($y),+) as IndexMap<{$n - 1}, RowMajor>>::FULL && X::FULL;
            const CONT: usize = <($($y),+) as IndexMap<{$n - 1}, RowMajor>>::CONT
                + (<($($y),+) as IndexMap<{$n - 1}, RowMajor>>::FULL && X::RANGE) as usize;
            const RANK: usize =
                <($($y),+) as IndexMap::<{$n - 1}, RowMajor>>::RANK + X::RANGE as usize;

            impl_index_map!(@view_info $n, RowMajor, ($($y),+), ($($vars),+));
        }
    };
}
impl_index_map!(2, (X), (Y), Y, (1));
impl_index_map!(3, (X, Y), (Y, Z), Z, (1, 2));
impl_index_map!(4, (X, Y, Z), (Y, Z, W), W, (1, 2, 3));
impl_index_map!(5, (X, Y, Z, W), (Y, Z, W, U), U, (1, 2, 3, 4));
impl_index_map!(6, (X, Y, Z, W, U), (Y, Z, W, U, V), V, (1, 2, 3, 4, 5));
macro_rules! impl_view_index {
($type:ty) => {
impl<T, const N: usize, O: Order> ViewIndex<T, N, 0, O> for $type {
type Output = <$type as SliceIndex<[T]>>::Output;
fn index(self, view: &DenseView<T, N, O>) -> &Self::Output {
(**view).index(self)
}
fn index_mut(self, view: &mut DenseView<T, N, O>) -> &mut Self::Output {
(**view).index_mut(self)
}
}
};
}
impl_view_index!((Bound<usize>, Bound<usize>));
impl_view_index!(usize);
impl_view_index!(Range<usize>);
impl_view_index!(RangeFrom<usize>);
impl_view_index!(RangeInclusive<usize>);
impl_view_index!(RangeFull);
impl_view_index!(RangeTo<usize>);
impl_view_index!(RangeToInclusive<usize>);
/// Indexes a single element by one coordinate per dimension.
///
/// # Panics
///
/// Panics if any coordinate is not less than the size of its dimension.
impl<T, const N: usize, const M: usize, O: Order> ViewIndex<T, N, M, O> for [usize; N] {
    type Output = T;

    fn index(self, view: &StridedView<T, N, M, O>) -> &Self::Output {
        let offset = checked_element_offset(&self, view);
        // SAFETY: `checked_element_offset` has verified every coordinate
        // against `view.size`. The offset therefore addresses an element of
        // the view, provided the view's pointer, sizes and strides describe
        // valid memory. That is assumed to be an invariant of `StridedView`
        // and is not checked here.
        unsafe { &*view.as_ptr().offset(offset) }
    }

    fn index_mut(self, view: &mut StridedView<T, N, M, O>) -> &mut Self::Output {
        let offset = checked_element_offset(&self, view);
        // SAFETY: same reasoning as in `index`. Exclusivity of the returned
        // reference follows from the `&mut` borrow of `view`.
        unsafe { &mut *view.as_mut_ptr().offset(offset) }
    }
}

/// Returns the element offset of `coords` within `view`, i.e. the sum over
/// all dimensions of `coords[d] * view.stride(d)`.
///
/// # Panics
///
/// Panics via `panic_bounds_check` if `coords[d] >= view.size(d)` for some `d`.
fn checked_element_offset<T, const N: usize, const M: usize, O: Order>(
    coords: &[usize; N],
    view: &StridedView<T, N, M, O>,
) -> isize {
    let mut offset = 0;
    for (dim, &coord) in coords.iter().enumerate() {
        if coord >= view.size(dim) {
            panic_bounds_check(coord, view.size(dim))
        }
        offset += coord as isize * view.stride(dim);
    }
    offset
}
/// Panics with the same message format as the standard library's slice
/// bounds check.
///
/// The function is kept out of line (`inline(never)`) and marked `cold`, so
/// callers' hot paths only carry a call. `track_caller` attributes the panic
/// location to the calling function rather than to this helper.
#[cold]
#[inline(never)]
#[track_caller]
fn panic_bounds_check(index: usize, len: usize) -> ! {
    panic!(
        "index out of bounds: the len is {} but the index is {}",
        len, index
    )
}