pub mod chunk_strided_slice;
pub mod strided_slice;
#[cfg(feature = "ndarray")]
use ndarray::{ArrayRef, Dimension};
use std::marker::PhantomData;
use std::ops::ControlFlow;
use std::ptr::NonNull;
use crate::utils::stride_from_shape;
use chunk_strided_slice::{IterLaneChunks, IterLaneChunksMut};
use strided_slice::{IterLanes, IterLanesMut};
pub use chunk_strided_slice::ChunkStridedSliceRef;
pub use strided_slice::StridedSliceRef;
/// Converts a flat index into one index per axis of `shape`, with the last
/// axis varying fastest.
///
/// `flat_index` may equal the total number of elements. In that case the
/// one-past-the-end position is returned: every axis sits at its last valid
/// index except the final axis, which holds its full length (for example
/// `[1, 3]` for shape `[2, 3]`).
///
/// If `shape` contains a zero-length axis there are no elements at all, so the
/// start and the end coincide and the all-zero position is returned.
///
/// # Panics
///
/// Panics if `flat_index` is greater than the product of `shape`.
#[inline]
#[track_caller]
pub(crate) fn unravel(flat_index: usize, shape: &[usize]) -> Vec<usize> {
    let n_max: usize = shape.iter().product();
    assert!(
        flat_index <= n_max,
        "Flat index is beyond the end of the array."
    );
    // An empty index space: begin == end. Returning early also keeps the
    // `n - 1` below from underflowing on a zero-length axis.
    if n_max == 0 {
        return vec![0; shape.len()];
    }
    if flat_index == n_max {
        // One-past-the-end sentinel: last valid index everywhere, except the
        // final axis which is pushed one step beyond its last index.
        let mut inds: Vec<usize> = shape.iter().map(|n| n - 1).collect();
        if let (Some(last), Some(&n_last)) = (inds.last_mut(), shape.last()) {
            *last = n_last;
        }
        return inds;
    }
    let mut inds = vec![0; shape.len()];
    let mut remaining = flat_index;
    // Peel off one axis at a time, starting from the fastest-varying one.
    for (i_dir, &n_dir) in inds.iter_mut().zip(shape).rev() {
        *i_dir = remaining % n_dir;
        remaining /= n_dir;
    }
    inds
}
/// Type-level description of borrowed element storage; it carries nothing but
/// the element type.
///
/// In this file it is implemented for `SliceLifetime<&T>` and
/// `SliceLifetime<&mut T>`, both with `Elem = T`.
///
/// # Safety
///
/// NOTE(review): the contract implementors must uphold is not written down in
/// this file. Presumably `Elem` has to be the real element type of the storage
/// the implementing tag stands for; confirm against the users of this trait.
pub unsafe trait Data: Sized {
/// The element type (`T` for both `SliceLifetime` implementations).
type Elem;
}
/// Marker sub-trait of [`Data`] with no items of its own.
///
/// In this file only `SliceLifetime<&mut T>` implements it.
///
/// # Safety
///
/// NOTE(review): presumably signals that the elements may be mutated through
/// the tagged borrow; the exact contract is not stated here, so confirm.
pub unsafe trait DataMut: Data {}
/// Tag type parameterised by a borrow type `T` (`&U` or `&mut U` in the trait
/// implementations in this file).
///
/// It stores no data: its only field is a `PhantomData<T>`.
pub struct SliceLifetime<T> {
// Keeps `T` "used" without storing a value of that type.
_member: PhantomData<T>,
}
// Shared-borrow tag: elements are of type `T`.
// SAFETY(review): the obligations of `Data` are not documented in this file;
// confirm this implementation meets them.
unsafe impl<T> Data for SliceLifetime<&T> {
type Elem = T;
}
// Exclusive-borrow tag: elements are of type `T`.
// SAFETY(review): the obligations of `Data` are not documented in this file;
// confirm this implementation meets them.
unsafe impl<T> Data for SliceLifetime<&mut T> {
type Elem = T;
}
// Only the exclusive-borrow tag is marked `DataMut`; the shared-borrow tag is not.
// SAFETY(review): the obligations of `DataMut` are not documented in this
// file; confirm this implementation meets them.
unsafe impl<T> DataMut for SliceLifetime<&mut T> {}
/// Shape and stride bookkeeping for walking an array one lane at a time.
///
/// Built by `ArrayInfo::new`, which removes the lane axis from the full shape
/// and stride and stores that axis' entries separately.
#[derive(Clone, Debug)]
pub(crate) struct ArrayInfo {
/// Extents of the axes that remain after the lane axis has been removed.
shape: Vec<usize>,
/// Strides of the remaining axes, in the same order as `shape`.
stride: Vec<isize>,
/// Extent of the removed lane axis.
lane_length: usize,
/// Stride of the removed lane axis.
lane_stride: isize,
}
impl ArrayInfo {
#[track_caller]
fn new(shape: &[usize], stride: &[isize], axis: usize) -> Self {
assert!(
axis < shape.len(),
"Specified axis exceeds shape dimensions"
);
assert_eq!(
stride.len(),
shape.len(),
"Shape and stride should have the same length."
);
let mut stride = stride.to_owned();
let mut shape = shape.to_owned();
let lane_length = shape.remove(axis);
let lane_stride = stride.remove(axis);
Self {
shape,
stride,
lane_length,
lane_stride,
}
}
#[inline(always)]
fn n_lanes(&self) -> usize {
self.shape.iter().product()
}
#[inline(always)]
fn get_position_at(&self, i: usize) -> Vec<usize> {
unravel(i, &self.shape)
}
#[inline(always)]
fn get_offset_at(&self, pos: &[usize]) -> isize {
pos.iter()
.zip(self.stride.iter())
.fold(0, |acc, (i, step)| acc + *i as isize * step)
}
#[inline(always)]
fn advance_position_and_offset(&self, pos: &mut [usize], offset: &mut isize) {
let _ = self
.stride
.iter()
.zip(self.shape.iter())
.zip(pos)
.rev()
.try_for_each(|((str, shp), pos)| {
*offset += *str;
*pos += 1;
if *pos < *shp {
return ControlFlow::Break(());
};
*pos = 0;
*offset -= *shp as isize * str;
ControlFlow::Continue(())
});
}
#[inline(always)]
fn retreat_position_and_offset(&self, pos: &mut [usize], offset: &mut isize) {
let _ = self
.stride
.iter()
.zip(self.shape.iter())
.zip(pos)
.rev()
.try_for_each(|((str, shp), pos)| {
if *pos == 0 {
*pos = *shp - 1;
*offset += *pos as isize * str;
ControlFlow::Continue(())
} else {
*pos -= 1;
*offset -= *str;
ControlFlow::Break(())
}
});
}
}
/// Builds the base pointer and lane bookkeeping for a whole slice.
///
/// Forwards to `lane_parts_from_sub_slice`, passing `shape` as both the full
/// shape and the sub-shape, so the same checks (and panics) apply.
#[track_caller]
fn lane_parts_from_slice<T>(arr: &[T], shape: &[usize], axis: usize) -> (NonNull<T>, ArrayInfo) {
lane_parts_from_sub_slice(arr, shape, shape, axis)
}
/// Validates a sub-region request on a flat slice and returns the slice's
/// base pointer together with the lane bookkeeping for that sub-region.
///
/// `arr` is interpreted as an array of `shape`. Strides are taken from
/// `stride_from_shape(shape)` (the full shape), while the extents stored in
/// the returned [`ArrayInfo`] come from `sub_shape`; the lane axis is `axis`.
///
/// NOTE(review): the pointer is derived from a shared borrow. Callers that use
/// it for mutable iteration must themselves hold exclusive access to the data.
///
/// # Panics
///
/// Panics if
/// - `arr` is empty,
/// - `arr.len()` differs from the product of `shape`,
/// - `shape` and `sub_shape` have different lengths,
/// - any entry of `sub_shape` exceeds the matching entry of `shape`,
/// - `axis` is not a valid index into `shape`.
#[track_caller]
fn lane_parts_from_sub_slice<T>(
    arr: &[T],
    shape: &[usize],
    sub_shape: &[usize],
    axis: usize,
) -> (NonNull<T>, ArrayInfo) {
    let n = arr.len();
    assert!(
        !arr.is_empty(),
        "Attempted to create a lane iterator from an empty slice."
    );
    let n_items: usize = shape.iter().product();
    assert_eq!(
        n, n_items,
        "array length must be consistent with the shape. Shape suggests {n_items}, but slice had {n} items."
    );
    assert_eq!(
        shape.len(),
        sub_shape.len(),
        "shape length, {}, and sub_shape length, {}, must be equal",
        shape.len(),
        sub_shape.len()
    );
    assert!(
        sub_shape.iter().zip(shape.iter()).all(|(n1, n2)| n1 <= n2),
        "sub_shape: {:?}, must be equal to or smaller than shape, {:?}",
        sub_shape,
        shape,
    );
    assert!(
        axis < shape.len(),
        "axis: {axis} is out of bounds for dimension size of {}",
        shape.len()
    );
    let stride = stride_from_shape(shape)
        .into_iter()
        .map(|s| s as isize)
        .collect::<Vec<_>>();
    // A slice reference is never null, so the pointer can be obtained without
    // `unsafe`; `cast` drops the length and keeps the address of element 0.
    let ptr = NonNull::from(arr).cast::<T>();
    (ptr, ArrayInfo::new(sub_shape, &stride, axis))
}
/// Validates a sub-region request on an ndarray and returns the array's base
/// pointer together with the lane bookkeeping for that sub-region.
///
/// The extents stored in the returned [`ArrayInfo`] come from `sub_shape`, the
/// strides from `arr.strides()`, and the lane axis is `axis`.
///
/// # Panics
///
/// Panics if
/// - `arr` has no elements,
/// - `axis` is not smaller than `arr.ndim()`,
/// - `sub_shape.len()` differs from `arr.ndim()`,
/// - any entry of `sub_shape` exceeds the matching entry of `arr.shape()`.
#[cfg(feature = "ndarray")]
#[track_caller]
fn lane_parts_from_ndarray<T, D: Dimension>(
    arr: &ArrayRef<T, D>,
    sub_shape: &[usize],
    axis: usize,
) -> (NonNull<T>, ArrayInfo) {
    assert_ne!(
        arr.len(),
        0,
        "Cannot create a lane iterator from an empty ndarray."
    );
    let ndim = arr.ndim();
    assert!(
        axis < ndim,
        "axis: {axis} is out of bounds for dimension size of {ndim}",
    );
    assert_eq!(
        sub_shape.len(),
        ndim,
        "shape.len(), {}, is not equal to arr.ndim(), {ndim}",
        sub_shape.len(),
    );
    let full_shape = arr.shape();
    let fits = sub_shape
        .iter()
        .zip(full_shape)
        .all(|(requested, available)| requested <= available);
    assert!(
        fits,
        "requested shape, {:?} must all be <= arr.shape(), {:?}.",
        sub_shape,
        full_shape,
    );
    let base = arr.as_ptr() as *mut T;
    // SAFETY(review): relies on ndarray's `as_ptr` never returning null;
    // confirm against the ndarray documentation.
    let ptr = unsafe { NonNull::new_unchecked(base) };
    (ptr, ArrayInfo::new(sub_shape, arr.strides(), axis))
}
/// Constructors for lane iterators over array-like storage.
///
/// Implemented in this file for plain slices (`[T]`) and, behind the `ndarray`
/// feature, for `ArrayRef<T, D>`.
///
/// NOTE(review): what the returned iterators yield is defined in the
/// `strided_slice` and `chunk_strided_slice` submodules, not in this file.
/// The descriptions below only state what the implementations here pass on.
pub trait LanesIterator {
/// Element type of the underlying storage.
type Item;
/// Creates a shared lane iterator for lanes along `axis`.
///
/// The slice implementation passes `shape` to `IterLanes::from_slice`; the
/// ndarray implementation passes it to `IterLanes::from_ndarray`.
fn iter_lanes<'a>(&'a self, shape: &[usize], axis: usize) -> IterLanes<'a, Self::Item>;
/// Mutable counterpart of [`LanesIterator::iter_lanes`].
fn iter_lanes_mut<'a>(
&'a mut self,
shape: &[usize],
axis: usize,
) -> IterLanesMut<'a, Self::Item>;
/// Creates a chunked lane iterator with chunk size `N`, paired with a plain
/// lane iterator.
///
/// NOTE(review): presumably the second iterator covers the lanes left over
/// once full chunks of `N` have been formed; confirm in
/// `chunk_strided_slice`.
fn iter_lane_chunks<'a, const N: usize>(
&'a self,
shape: &[usize],
axis: usize,
) -> (IterLaneChunks<'a, Self::Item, N>, IterLanes<'a, Self::Item>);
/// Mutable counterpart of [`LanesIterator::iter_lane_chunks`].
fn iter_lane_chunks_mut<'a, const N: usize>(
&'a mut self,
shape: &[usize],
axis: usize,
) -> (
IterLaneChunksMut<'a, Self::Item, N>,
IterLanesMut<'a, Self::Item>,
);
/// Like [`LanesIterator::iter_lanes`], but for the region described by
/// `sub_shape`.
///
/// The slice implementation uses both `shape` and `sub_shape`; the ndarray
/// implementation ignores `shape`.
fn iter_lanes_sub<'a>(
&'a self,
shape: &[usize],
sub_shape: &[usize],
axis: usize,
) -> IterLanes<'a, Self::Item>;
/// Mutable counterpart of [`LanesIterator::iter_lanes_sub`].
fn iter_lanes_sub_mut<'a>(
&'a mut self,
shape: &[usize],
sub_shape: &[usize],
axis: usize,
) -> IterLanesMut<'a, Self::Item>;
/// Like [`LanesIterator::iter_lane_chunks`], but for the region described
/// by `sub_shape`.
///
/// The ndarray implementation ignores `shape`.
fn iter_lane_chunks_sub<'a, const N: usize>(
&'a self,
shape: &[usize],
sub_shape: &[usize],
axis: usize,
) -> (IterLaneChunks<'a, Self::Item, N>, IterLanes<'a, Self::Item>);
/// Mutable counterpart of [`LanesIterator::iter_lane_chunks_sub`].
fn iter_lane_chunks_sub_mut<'a, const N: usize>(
&'a mut self,
shape: &[usize],
sub_shape: &[usize],
axis: usize,
) -> (
IterLaneChunksMut<'a, Self::Item, N>,
IterLanesMut<'a, Self::Item>,
);
/// Returns an axis index chosen by stride.
///
/// The slice implementation returns the last axis of `shape` (or `0` when
/// `shape` is empty); the ndarray implementation returns the axis with the
/// smallest absolute stride and ignores `shape`.
fn min_stride_axis(&self, shape: &[usize]) -> usize;
/// Reports whether axis `ax` is contiguous.
///
/// The slice implementation is true only for the last axis of `shape`; the
/// ndarray implementation is true only when the stride of `ax` is `1`.
fn is_ax_contiguous(&self, ax: usize, shape: &[usize]) -> bool;
}
/// Lane iteration over a flat slice whose layout is given by the `shape`
/// argument of each method. Every constructor forwards to the matching
/// `from_slice` / `from_sub_slice` function of the iterator type it returns.
impl<T> LanesIterator for [T] {
    type Item = T;

    /// Forwards to `IterLanes::from_slice`.
    #[track_caller]
    fn iter_lanes(&self, shape: &[usize], axis: usize) -> IterLanes<'_, Self::Item> {
        IterLanes::from_slice(self, shape, axis)
    }

    /// Forwards to `IterLanesMut::from_slice`.
    #[track_caller]
    fn iter_lanes_mut(&mut self, shape: &[usize], axis: usize) -> IterLanesMut<'_, Self::Item> {
        IterLanesMut::from_slice(self, shape, axis)
    }

    /// Forwards to `IterLaneChunks::from_slice`.
    #[track_caller]
    fn iter_lane_chunks<const N: usize>(
        &self,
        shape: &[usize],
        axis: usize,
    ) -> (IterLaneChunks<'_, Self::Item, N>, IterLanes<'_, Self::Item>) {
        IterLaneChunks::from_slice(self, shape, axis)
    }

    /// Forwards to `IterLaneChunksMut::from_slice`.
    #[track_caller]
    fn iter_lane_chunks_mut<const N: usize>(
        &mut self,
        shape: &[usize],
        axis: usize,
    ) -> (
        IterLaneChunksMut<'_, Self::Item, N>,
        IterLanesMut<'_, Self::Item>,
    ) {
        IterLaneChunksMut::from_slice(self, shape, axis)
    }

    /// Forwards to `IterLanes::from_sub_slice`.
    #[track_caller]
    fn iter_lanes_sub(
        &self,
        shape: &[usize],
        sub_shape: &[usize],
        axis: usize,
    ) -> IterLanes<'_, Self::Item> {
        IterLanes::from_sub_slice(self, shape, sub_shape, axis)
    }

    /// Forwards to `IterLanesMut::from_sub_slice`.
    #[track_caller]
    fn iter_lanes_sub_mut(
        &mut self,
        shape: &[usize],
        sub_shape: &[usize],
        axis: usize,
    ) -> IterLanesMut<'_, Self::Item> {
        IterLanesMut::from_sub_slice(self, shape, sub_shape, axis)
    }

    /// Forwards to `IterLaneChunks::from_sub_slice`.
    #[track_caller]
    fn iter_lane_chunks_sub<const N: usize>(
        &self,
        shape: &[usize],
        sub_shape: &[usize],
        axis: usize,
    ) -> (IterLaneChunks<'_, Self::Item, N>, IterLanes<'_, Self::Item>) {
        IterLaneChunks::from_sub_slice(self, shape, sub_shape, axis)
    }

    /// Forwards to `IterLaneChunksMut::from_sub_slice`.
    #[track_caller]
    fn iter_lane_chunks_sub_mut<const N: usize>(
        &mut self,
        shape: &[usize],
        sub_shape: &[usize],
        axis: usize,
    ) -> (
        IterLaneChunksMut<'_, Self::Item, N>,
        IterLanesMut<'_, Self::Item>,
    ) {
        IterLaneChunksMut::from_sub_slice(self, shape, sub_shape, axis)
    }

    /// Index of the last axis of `shape`, or `0` when `shape` is empty.
    fn min_stride_axis(&self, shape: &[usize]) -> usize {
        shape.len().saturating_sub(1)
    }

    /// True only when `ax` is the last axis of `shape`.
    #[inline]
    fn is_ax_contiguous(&self, ax: usize, shape: &[usize]) -> bool {
        shape.len() == ax + 1
    }
}
/// Lane iteration over an ndarray reference. Every constructor forwards to the
/// `from_ndarray` function of the iterator type it returns; the `*_sub`
/// variants ignore `_shape` and pass `sub_shape` instead.
#[cfg(feature = "ndarray")]
impl<T, D: Dimension> LanesIterator for ArrayRef<T, D> {
    type Item = T;

    /// Forwards to `IterLanes::from_ndarray` with `shape`.
    #[track_caller]
    fn iter_lanes(&self, shape: &[usize], axis: usize) -> IterLanes<'_, Self::Item> {
        IterLanes::from_ndarray(self, shape, axis)
    }

    /// Forwards to `IterLanesMut::from_ndarray` with `shape`.
    #[track_caller]
    fn iter_lanes_mut(&mut self, shape: &[usize], axis: usize) -> IterLanesMut<'_, Self::Item> {
        IterLanesMut::from_ndarray(self, shape, axis)
    }

    /// Forwards to `IterLaneChunks::from_ndarray` with `shape`.
    #[track_caller]
    fn iter_lane_chunks<const N: usize>(
        &self,
        shape: &[usize],
        axis: usize,
    ) -> (IterLaneChunks<'_, Self::Item, N>, IterLanes<'_, Self::Item>) {
        IterLaneChunks::from_ndarray(self, shape, axis)
    }

    /// Forwards to `IterLaneChunksMut::from_ndarray` with `shape`.
    #[track_caller]
    fn iter_lane_chunks_mut<const N: usize>(
        &mut self,
        shape: &[usize],
        axis: usize,
    ) -> (
        IterLaneChunksMut<'_, Self::Item, N>,
        IterLanesMut<'_, Self::Item>,
    ) {
        IterLaneChunksMut::from_ndarray(self, shape, axis)
    }

    /// Forwards to `IterLanes::from_ndarray` with `sub_shape`.
    #[track_caller]
    fn iter_lanes_sub(
        &self,
        _shape: &[usize],
        sub_shape: &[usize],
        axis: usize,
    ) -> IterLanes<'_, Self::Item> {
        IterLanes::from_ndarray(self, sub_shape, axis)
    }

    /// Forwards to `IterLanesMut::from_ndarray` with `sub_shape`.
    #[track_caller]
    fn iter_lanes_sub_mut(
        &mut self,
        _shape: &[usize],
        sub_shape: &[usize],
        axis: usize,
    ) -> IterLanesMut<'_, Self::Item> {
        IterLanesMut::from_ndarray(self, sub_shape, axis)
    }

    /// Forwards to `IterLaneChunks::from_ndarray` with `sub_shape`.
    #[track_caller]
    fn iter_lane_chunks_sub<const N: usize>(
        &self,
        _shape: &[usize],
        sub_shape: &[usize],
        axis: usize,
    ) -> (IterLaneChunks<'_, Self::Item, N>, IterLanes<'_, Self::Item>) {
        IterLaneChunks::from_ndarray(self, sub_shape, axis)
    }

    /// Forwards to `IterLaneChunksMut::from_ndarray` with `sub_shape`.
    #[track_caller]
    fn iter_lane_chunks_sub_mut<const N: usize>(
        &mut self,
        _shape: &[usize],
        sub_shape: &[usize],
        axis: usize,
    ) -> (
        IterLaneChunksMut<'_, Self::Item, N>,
        IterLanesMut<'_, Self::Item>,
    ) {
        IterLaneChunksMut::from_ndarray(self, sub_shape, axis)
    }

    /// Axis whose stride has the smallest absolute value; the earliest such
    /// axis wins ties. Returns `0` when the array has no axes.
    fn min_stride_axis(&self, _shape: &[usize]) -> usize {
        let mut candidates = self.strides().iter().cloned().enumerate();
        let Some(mut best) = candidates.next() else {
            return 0;
        };
        for candidate in candidates {
            // Strictly smaller only, so the first minimum is kept.
            if candidate.1.abs() < best.1.abs() {
                best = candidate;
            }
        }
        best.0
    }

    /// True only when `ax` is a valid axis whose stride is exactly `1`.
    #[inline]
    fn is_ax_contiguous(&self, ax: usize, _shape: &[usize]) -> bool {
        self.strides().get(ax).is_some_and(|step| *step == 1)
    }
}
/// Parallel counterparts of the lane-iterator constructors, available with the
/// `rayon` feature.
#[cfg(feature = "rayon")]
pub mod parallel {
    use super::chunk_strided_slice::parallel::{ParIterLaneChunks, ParIterLaneChunksMut};
    use super::strided_slice::parallel::{ParIterLanes, ParIterLanesMut};
    use super::*;

    /// Constructors for the parallel lane iterators, mirroring the methods of
    /// [`LanesIterator`] one for one.
    ///
    /// NOTE(review): what the returned iterators yield is defined in the
    /// `parallel` submodules of `strided_slice` and `chunk_strided_slice`.
    pub trait LanesParallelIterator: LanesIterator {
        /// Parallel version of `iter_lanes`.
        fn par_iter_lanes<'a>(
            &'a self,
            shape: &[usize],
            axis: usize,
        ) -> ParIterLanes<'a, Self::Item>;

        /// Parallel version of `iter_lanes_mut`.
        fn par_iter_lanes_mut<'a>(
            &'a mut self,
            shape: &[usize],
            axis: usize,
        ) -> ParIterLanesMut<'a, Self::Item>;

        /// Parallel version of `iter_lane_chunks`.
        fn par_iter_lane_chunks<'a, const N: usize>(
            &'a self,
            shape: &[usize],
            axis: usize,
        ) -> (
            ParIterLaneChunks<'a, Self::Item, N>,
            ParIterLanes<'a, Self::Item>,
        );

        /// Parallel version of `iter_lane_chunks_mut`.
        fn par_iter_lane_chunks_mut<'a, const N: usize>(
            &'a mut self,
            shape: &[usize],
            axis: usize,
        ) -> (
            ParIterLaneChunksMut<'a, Self::Item, N>,
            ParIterLanesMut<'a, Self::Item>,
        );

        /// Parallel version of `iter_lanes_sub`.
        fn par_iter_lanes_sub<'a>(
            &'a self,
            shape: &[usize],
            sub_shape: &[usize],
            axis: usize,
        ) -> ParIterLanes<'a, Self::Item>;

        /// Parallel version of `iter_lanes_sub_mut`.
        fn par_iter_lanes_sub_mut<'a>(
            &'a mut self,
            shape: &[usize],
            sub_shape: &[usize],
            axis: usize,
        ) -> ParIterLanesMut<'a, Self::Item>;

        /// Parallel version of `iter_lane_chunks_sub`.
        fn par_iter_lane_chunks_sub<'a, const N: usize>(
            &'a self,
            shape: &[usize],
            sub_shape: &[usize],
            axis: usize,
        ) -> (
            ParIterLaneChunks<'a, Self::Item, N>,
            ParIterLanes<'a, Self::Item>,
        );

        /// Parallel version of `iter_lane_chunks_sub_mut`.
        fn par_iter_lane_chunks_sub_mut<'a, const N: usize>(
            &'a mut self,
            shape: &[usize],
            sub_shape: &[usize],
            axis: usize,
        ) -> (
            ParIterLaneChunksMut<'a, Self::Item, N>,
            ParIterLanesMut<'a, Self::Item>,
        );
    }

    /// Slice implementation: each method forwards to the matching
    /// `from_slice` / `from_sub_slice` function of the type it returns.
    impl<T> LanesParallelIterator for [T] {
        #[track_caller]
        fn par_iter_lanes(&self, shape: &[usize], axis: usize) -> ParIterLanes<'_, Self::Item> {
            ParIterLanes::from_slice(self, shape, axis)
        }

        #[track_caller]
        fn par_iter_lanes_mut(
            &mut self,
            shape: &[usize],
            axis: usize,
        ) -> ParIterLanesMut<'_, Self::Item> {
            ParIterLanesMut::from_slice(self, shape, axis)
        }

        #[track_caller]
        fn par_iter_lane_chunks<const N: usize>(
            &self,
            shape: &[usize],
            axis: usize,
        ) -> (
            ParIterLaneChunks<'_, Self::Item, N>,
            ParIterLanes<'_, Self::Item>,
        ) {
            ParIterLaneChunks::from_slice(self, shape, axis)
        }

        #[track_caller]
        fn par_iter_lane_chunks_mut<const N: usize>(
            &mut self,
            shape: &[usize],
            axis: usize,
        ) -> (
            ParIterLaneChunksMut<'_, Self::Item, N>,
            ParIterLanesMut<'_, Self::Item>,
        ) {
            ParIterLaneChunksMut::from_slice(self, shape, axis)
        }

        #[track_caller]
        fn par_iter_lanes_sub(
            &self,
            shape: &[usize],
            sub_shape: &[usize],
            axis: usize,
        ) -> ParIterLanes<'_, Self::Item> {
            ParIterLanes::from_sub_slice(self, shape, sub_shape, axis)
        }

        #[track_caller]
        fn par_iter_lanes_sub_mut(
            &mut self,
            shape: &[usize],
            sub_shape: &[usize],
            axis: usize,
        ) -> ParIterLanesMut<'_, Self::Item> {
            ParIterLanesMut::from_sub_slice(self, shape, sub_shape, axis)
        }

        #[track_caller]
        fn par_iter_lane_chunks_sub<const N: usize>(
            &self,
            shape: &[usize],
            sub_shape: &[usize],
            axis: usize,
        ) -> (
            ParIterLaneChunks<'_, Self::Item, N>,
            ParIterLanes<'_, Self::Item>,
        ) {
            ParIterLaneChunks::from_sub_slice(self, shape, sub_shape, axis)
        }

        #[track_caller]
        fn par_iter_lane_chunks_sub_mut<const N: usize>(
            &mut self,
            shape: &[usize],
            sub_shape: &[usize],
            axis: usize,
        ) -> (
            ParIterLaneChunksMut<'_, Self::Item, N>,
            ParIterLanesMut<'_, Self::Item>,
        ) {
            ParIterLaneChunksMut::from_sub_slice(self, shape, sub_shape, axis)
        }
    }

    /// ndarray implementation: each method forwards to the `from_ndarray`
    /// function of the type it returns; the `*_sub` variants ignore `_shape`
    /// and pass `sub_shape` instead.
    #[cfg(feature = "ndarray")]
    impl<T, D: ::ndarray::Dimension> LanesParallelIterator for ArrayRef<T, D> {
        #[track_caller]
        fn par_iter_lanes(&self, shape: &[usize], axis: usize) -> ParIterLanes<'_, Self::Item> {
            ParIterLanes::from_ndarray(self, shape, axis)
        }

        #[track_caller]
        fn par_iter_lanes_mut(
            &mut self,
            shape: &[usize],
            axis: usize,
        ) -> ParIterLanesMut<'_, Self::Item> {
            ParIterLanesMut::from_ndarray(self, shape, axis)
        }

        #[track_caller]
        fn par_iter_lane_chunks<const N: usize>(
            &self,
            shape: &[usize],
            axis: usize,
        ) -> (
            ParIterLaneChunks<'_, Self::Item, N>,
            ParIterLanes<'_, Self::Item>,
        ) {
            ParIterLaneChunks::from_ndarray(self, shape, axis)
        }

        #[track_caller]
        fn par_iter_lane_chunks_mut<const N: usize>(
            &mut self,
            shape: &[usize],
            axis: usize,
        ) -> (
            ParIterLaneChunksMut<'_, Self::Item, N>,
            ParIterLanesMut<'_, Self::Item>,
        ) {
            ParIterLaneChunksMut::from_ndarray(self, shape, axis)
        }

        #[track_caller]
        fn par_iter_lanes_sub(
            &self,
            _shape: &[usize],
            sub_shape: &[usize],
            axis: usize,
        ) -> ParIterLanes<'_, Self::Item> {
            ParIterLanes::from_ndarray(self, sub_shape, axis)
        }

        #[track_caller]
        fn par_iter_lanes_sub_mut(
            &mut self,
            _shape: &[usize],
            sub_shape: &[usize],
            axis: usize,
        ) -> ParIterLanesMut<'_, Self::Item> {
            ParIterLanesMut::from_ndarray(self, sub_shape, axis)
        }

        #[track_caller]
        fn par_iter_lane_chunks_sub<const N: usize>(
            &self,
            _shape: &[usize],
            sub_shape: &[usize],
            axis: usize,
        ) -> (
            ParIterLaneChunks<'_, Self::Item, N>,
            ParIterLanes<'_, Self::Item>,
        ) {
            ParIterLaneChunks::from_ndarray(self, sub_shape, axis)
        }

        #[track_caller]
        fn par_iter_lane_chunks_sub_mut<const N: usize>(
            &mut self,
            _shape: &[usize],
            sub_shape: &[usize],
            axis: usize,
        ) -> (
            ParIterLaneChunksMut<'_, Self::Item, N>,
            ParIterLanesMut<'_, Self::Item>,
        ) {
            ParIterLaneChunksMut::from_ndarray(self, sub_shape, axis)
        }
    }
}