#[macro_use]
mod zipmacro;
mod ndproducer;
#[cfg(feature = "rayon")]
use std::mem::MaybeUninit;
use crate::imp_prelude::*;
use crate::partial::Partial;
use crate::AssignElem;
use crate::IntoDimension;
use crate::Layout;
use crate::dimension;
use crate::indexes::{indices, Indices};
use crate::split_at::{SplitAt, SplitPreference};
pub use self::ndproducer::{IntoNdProducer, NdProducer, Offset};
// Unwraps a `FoldWhile` expression inside a function that itself returns `FoldWhile`:
// `Continue(x)` evaluates to `x`; any other variant is returned from the enclosing
// function unchanged, which ends the traversal early.
macro_rules! fold_while {
($e:expr) => {
match $e {
FoldWhile::Continue(x) => x,
// Anything that is not `Continue` short-circuits the caller.
x => return x,
}
};
}
/// Module-private helper trait for producers that can be broadcast to a shape `E`.
trait Broadcast<E>
where E: IntoDimension
{
/// Producer type returned by the broadcast; its dimension type is `E::Dim`.
type Output: NdProducer<Dim = E::Dim>;
/// Broadcasts `self` to `shape` and returns the resulting producer.
///
/// NOTE(review): the `_unwrap` suffix together with `#[track_caller]` suggests this
/// panics when the broadcast is not possible — confirm against the implementations.
#[track_caller]
fn broadcast_unwrap(self, shape: E) -> Self::Output;
// NOTE(review): presumably seals the trait against outside implementations —
// confirm in the definition of `private_decl!`.
private_decl! {}
}
/// Classifies the memory layout described by `dim` (axis lengths) and `strides`.
///
/// The result is chosen as follows:
///
/// * C-contiguous (per `dimension::is_layout_c`): `Layout::one_dimensional()` when the
///   shape has at most one axis, or at most one axis longer than 1; otherwise `Layout::c()`.
/// * Otherwise, with more than one axis: `Layout::f()` if `dimension::is_layout_f` holds;
///   else `Layout::fpref()` when the first axis is longer than 1 with stride 1;
///   else `Layout::cpref()` when the last axis is longer than 1 with stride 1.
/// * Everything else: `Layout::none()`.
fn array_layout<D: Dimension>(dim: &D, strides: &D) -> Layout
{
    let ndim = dim.ndim();

    if dimension::is_layout_c(dim, strides) {
        // At most one axis actually varies, so the data is effectively a single run.
        let single_run = ndim <= 1 || dim.slice().iter().filter(|&&len| len > 1).count() <= 1;
        return if single_run {
            Layout::one_dimensional()
        } else {
            Layout::c()
        };
    }

    // Not C-contiguous: with zero or one axes there is nothing further to classify.
    if ndim <= 1 {
        return Layout::none();
    }

    if dimension::is_layout_f(dim, strides) {
        return Layout::f();
    }

    // Not contiguous either way; report a preference when an outermost axis is
    // non-trivial and has unit stride.
    let last = ndim - 1;
    if dim[0] > 1 && strides[0] == 1 {
        Layout::fpref()
    } else if dim[last] > 1 && strides[last] == 1 {
        Layout::cpref()
    } else {
        Layout::none()
    }
}
impl<A, D> LayoutRef<A, D>
where D: Dimension
{
/// Computes the `Layout` of this reference by passing its dimension and strides
/// to `array_layout`.
pub(crate) fn layout_impl(&self) -> Layout
{
array_layout(self._dim(), self._strides())
}
}
// `Broadcast` implementation for array views: the output is a view with the same
// lifetime `'a` and element type, but with dimension type `E::Dim`.
impl<'a, A, D, E> Broadcast<E> for ArrayView<'a, A, D>
where
E: IntoDimension,
D: Dimension,
{
type Output = ArrayView<'a, A, E::Dim>;
fn broadcast_unwrap(self, shape: E) -> Self::Output
{
// The explicit `(*self)` goes through the view's deref target.
// NOTE(review): presumably this is what selects a `broadcast_unwrap` defined on that
// target rather than recursing into this trait method — confirm.
#[allow(clippy::needless_borrow)]
let res: ArrayView<'_, A, E::Dim> = (*self).broadcast_unwrap(shape.into_dimension());
// `res` is typed with an anonymous lifetime; rebuild the view from its pointer,
// dimension and strides so the result has lifetime `'a`.
// SAFETY (assumed — TODO confirm): `self` is an `ArrayView<'a, _, _>` taken by value, so
// the data `res` points into is expected to be valid for `'a`.
unsafe { ArrayView::new(res.parts.ptr, res.parts.dim, res.parts.strides) }
}
private_impl! {}
}
/// Module-private abstraction over a tuple of producers that are traversed together.
///
/// The implementations generated by `zipt_impl!` forward each method to every producer
/// in the tuple and collect the results into a tuple of the same arity.
trait ZippableTuple: Sized
{
/// Tuple of the producers' `Item` types.
type Item;
/// Tuple of the producers' pointer types; it can be offset by a `Self::Stride`.
type Ptr: OffsetTuple<Args = Self::Stride> + Copy;
/// Dimension type shared by all producers in the tuple.
type Dim: Dimension;
/// Tuple of the producers' stride types.
type Stride: Copy;
/// Returns the tuple of each producer's `as_ptr()`.
fn as_ptr(&self) -> Self::Ptr;
/// Turns a tuple of pointers into the tuple of items, one producer at a time.
///
/// # Safety
///
/// NOTE(review): presumably each pointer must be valid for the matching producer's
/// `as_ref` — confirm against `NdProducer`.
unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item;
/// Returns the tuple of each producer's pointer for index `i`.
///
/// # Safety
///
/// NOTE(review): presumably `i` must be in bounds for every producer — confirm against
/// `NdProducer::uget_ptr`.
unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr;
/// Returns the tuple of each producer's stride along axis number `index`.
fn stride_of(&self, index: usize) -> Self::Stride;
/// Returns the tuple of each producer's `contiguous_stride()`.
fn contiguous_stride(&self) -> Self::Stride;
/// Splits every producer at `index` along `axis`, returning the tuple of first halves
/// and the tuple of second halves.
fn split_at(self, axis: Axis, index: usize) -> (Self, Self);
}
/// Lock-step traversal of several producers that share the dimension type `D`.
///
/// A `Zip` starts from one producer (`Zip::from`) and grows through `and` /
/// `and_broadcast`; nothing is visited until a consuming method such as `for_each`
/// or `fold` is called.
#[derive(Debug, Clone)]
#[must_use = "zipping producers is lazy and does nothing unless consumed"]
pub struct Zip<Parts, D>
{
// Tuple of the zipped producers.
parts: Parts,
// Shape taken from the first producer; later producers are checked or broadcast against it.
dimension: D,
// Layout of the first producer, intersected with the layout of every producer added later.
layout: Layout,
// Sum of `Layout::tendency()` over all producers; the traversal code treats a
// negative sum as a preference for the F-order path.
layout_tendency: i32,
}
impl<P, D> Zip<(P,), D>
where
    D: Dimension,
    P: NdProducer<Dim = D>,
{
    /// Starts a `Zip` from a single producer.
    ///
    /// The producer's dimension becomes the dimension of the zip, and its layout
    /// and layout tendency seed the zip's own.
    pub fn from<IP>(p: IP) -> Self
    where IP: IntoNdProducer<Dim = D, Output = P, Item = P::Item>
    {
        let producer = p.into_producer();
        let dimension = producer.raw_dim();
        let layout = producer.layout();
        let layout_tendency = layout.tendency();
        Zip {
            parts: (producer,),
            dimension,
            layout,
            layout_tendency,
        }
    }
}
impl<P, D> Zip<(Indices<D>, P), D>
where
    D: Dimension + Copy,
    P: NdProducer<Dim = D>,
{
    /// Starts a `Zip` whose first part is the `indices` producer for the shape of `p`
    /// and whose second part is `p` itself.
    pub fn indexed<IP>(p: IP) -> Self
    where IP: IntoNdProducer<Dim = D, Output = P, Item = P::Item>
    {
        let producer = p.into_producer();
        let shape = producer.raw_dim();
        // The index producer goes first so it is the first argument of every closure call.
        let index_part = indices(shape);
        Zip::from(index_part).and(producer)
    }
}
/// Asserts, through `ndassert!`, that `part` has the dimension `dimension`.
///
/// The message reports the expected dimension and the dimension of `part`.
#[inline]
fn zip_dimension_check<D, P>(dimension: &D, part: &P)
where
D: Dimension,
P: NdProducer<Dim = D>,
{
ndassert!(
part.equal_dim(dimension),
"Zip: Producer dimension mismatch, expected: {:?}, got: {:?}",
dimension,
part.raw_dim()
);
}
impl<Parts, D> Zip<Parts, D>
where D: Dimension
{
    /// Returns the number of element tuples in the zip, i.e. the size of its dimension.
    pub fn size(&self) -> usize
    {
        let shape = &self.dimension;
        shape.size()
    }

    /// Returns the length of the zip along `axis`.
    #[track_caller]
    fn len_of(&self, axis: Axis) -> usize
    {
        let position = axis.index();
        self.dimension[position]
    }

    /// Returns `true` when the zip should be handled in F order: the layout is not
    /// flagged `CORDER`, and it is either flagged `FORDER` or the summed tendency
    /// is negative.
    fn prefer_f(&self) -> bool
    {
        if self.layout.is(Layout::CORDER) {
            return false;
        }
        self.layout.is(Layout::FORDER) || self.layout_tendency < 0
    }

    /// Picks the axis used for splitting.
    ///
    /// With F preference this is the last axis longer than 1 (falling back to the last
    /// axis); otherwise it is the first axis longer than 1 (falling back to axis 0).
    fn max_stride_axis(&self) -> Axis
    {
        let lens = self.dimension.slice();
        let chosen = if self.prefer_f() {
            match lens.iter().rposition(|&len| len > 1) {
                Some(found) => found,
                None => self.dimension.ndim() - 1,
            }
        } else {
            match lens.iter().position(|&len| len > 1) {
                Some(found) => found,
                None => 0,
            }
        };
        Axis(chosen)
    }
}
// Traversal machinery shared by every `Zip` arity; `P` is the tuple of producers.
impl<P, D> Zip<P, D>
where D: Dimension
{
/// Folds `function` over every element tuple of the zip, starting from `acc`.
///
/// Dispatches on dimensionality and layout: a zero-dimensional zip makes a single
/// call, a zip whose layout passes the `CORDER | FORDER` test is walked as one flat
/// run, and everything else takes the strided path. In the looping paths the
/// traversal stops as soon as `function` returns something other than
/// `FoldWhile::Continue`, and that value is returned.
fn for_each_core<F, Acc>(&mut self, acc: Acc, mut function: F) -> FoldWhile<Acc>
where
F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
P: ZippableTuple<Dim = D>,
{
if self.dimension.ndim() == 0 {
// Zero axes: one call on the producers' base pointers.
function(acc, unsafe { self.parts.as_ref(self.parts.as_ptr()) })
} else if self.layout.is(Layout::CORDER | Layout::FORDER) {
self.for_each_core_contiguous(acc, function)
} else {
self.for_each_core_strided(acc, function)
}
}
/// Walks all `self.dimension.size()` elements as a single run, stepping every
/// producer by its `contiguous_stride()`.
fn for_each_core_contiguous<F, Acc>(&mut self, acc: Acc, mut function: F) -> FoldWhile<Acc>
where
F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
P: ZippableTuple<Dim = D>,
{
debug_assert!(self.layout.is(Layout::CORDER | Layout::FORDER));
let size = self.dimension.size();
let ptrs = self.parts.as_ptr();
let inner_strides = self.parts.contiguous_stride();
// SAFETY (assumed — TODO confirm): relies on the layout check above meaning that the
// `size` elements of every producer are reachable from `ptrs` by `contiguous_stride()`.
unsafe { self.inner(acc, ptrs, inner_strides, size, &mut function) }
}
/// Innermost loop: calls `function` on the items at `ptr` offset by `i * strides`
/// for `i` in `0..len`, threading the accumulator through and returning early on
/// the first non-`Continue` result.
///
/// # Safety
///
/// NOTE(review): inferred from usage — every offset pointer produced here must be
/// valid for `ZippableTuple::as_ref`; confirm against the producers' contracts.
unsafe fn inner<F, Acc>(
&self, mut acc: Acc, ptr: P::Ptr, strides: P::Stride, len: usize, function: &mut F,
) -> FoldWhile<Acc>
where
F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
P: ZippableTuple,
{
let mut i = 0;
while i < len {
let p = ptr.stride_offset(strides, i);
acc = fold_while!(function(acc, self.parts.as_ref(p)));
i += 1;
}
FoldWhile::Continue(acc)
}
/// Strided traversal: uses the C-order walk for one-dimensional zips or a
/// non-negative layout tendency, and the F-order walk otherwise.
///
/// Panics if the zip is zero-dimensional; `for_each_core` handles that case before
/// dispatching here.
fn for_each_core_strided<F, Acc>(&mut self, acc: Acc, function: F) -> FoldWhile<Acc>
where
F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
P: ZippableTuple<Dim = D>,
{
let n = self.dimension.ndim();
if n == 0 {
panic!("Unreachable: ndim == 0 is contiguous")
}
if n == 1 || self.layout_tendency >= 0 {
self.for_each_core_strided_c(acc, function)
} else {
self.for_each_core_strided_f(acc, function)
}
}
/// Strided walk with the last axis as the inner loop.
///
/// Overwrites the last entry of `self.dimension` with 1 and does not restore it; the
/// consuming methods visible in this file (`for_each`, `fold`, ...) take the `Zip` by
/// value, so the modified dimension is not reused afterwards.
fn for_each_core_strided_c<F, Acc>(&mut self, mut acc: Acc, mut function: F) -> FoldWhile<Acc>
where
F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
P: ZippableTuple<Dim = D>,
{
let n = self.dimension.ndim();
let unroll_axis = n - 1;
let inner_len = self.dimension[unroll_axis];
// Collapse the inner axis so the index iteration below only enumerates the outer
// axes; the inner axis is covered by `inner` using `inner_len`.
self.dimension[unroll_axis] = 1;
let mut index_ = self.dimension.first_index();
let inner_strides = self.parts.stride_of(unroll_axis);
while let Some(index) = index_ {
// SAFETY (assumed — TODO confirm): `index` comes from `first_index`/`next_for` on the
// zip's own dimension, and `inner_len` is the original length of the collapsed axis.
unsafe {
let ptr = self.parts.uget_ptr(&index);
acc = fold_while![self.inner(acc, ptr, inner_strides, inner_len, &mut function)];
}
index_ = self.dimension.next_for(index);
}
FoldWhile::Continue(acc)
}
/// Strided walk with the first axis as the inner loop.
///
/// Overwrites the first entry of `self.dimension` with 1 and does not restore it,
/// in the same way `for_each_core_strided_c` treats the last axis.
fn for_each_core_strided_f<F, Acc>(&mut self, mut acc: Acc, mut function: F) -> FoldWhile<Acc>
where
F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
P: ZippableTuple<Dim = D>,
{
let unroll_axis = 0;
let inner_len = self.dimension[unroll_axis];
// Collapse the inner axis; it is covered by `inner` using `inner_len`.
self.dimension[unroll_axis] = 1;
let index_ = self.dimension.first_index();
let inner_strides = self.parts.stride_of(unroll_axis);
// `first_index` returning `None` means there is nothing to visit.
if let Some(mut index) = index_ {
loop {
// SAFETY (assumed — TODO confirm): same reasoning as in `for_each_core_strided_c`.
unsafe {
let ptr = self.parts.uget_ptr(&index);
acc = fold_while![self.inner(acc, ptr, inner_strides, inner_len, &mut function)];
}
// `next_for_f` advances `index` in place and reports whether another index exists.
if !self.dimension.next_for_f(&mut index) {
break;
}
}
}
FoldWhile::Continue(acc)
}
/// Creates an uninitialized array with the zip's dimension, in F order when
/// `prefer_f()` holds and in C order otherwise.
#[cfg(feature = "rayon")]
pub(crate) fn uninitialized_for_current_layout<T>(&self) -> Array<MaybeUninit<T>, D>
{
let is_f = self.prefer_f();
Array::uninit(self.dimension.clone().set_f(is_f))
}
}
impl<D, P1, P2> Zip<(P1, P2), D>
where
    D: Dimension,
    P1: NdProducer<Dim = D>,
    // The original where-clause repeated the `P1` bound here, leaving `P2` unbounded.
    P2: NdProducer<Dim = D>,
{
    /// Debug-asserts that this binary zip will be traversed in C order, then returns
    /// the zip unchanged.
    ///
    /// The check passes when the layout is flagged `CORDER`, when the layout tendency
    /// is non-negative, or when at most one axis is longer than 1. It compiles to
    /// nothing when debug assertions are disabled.
    #[inline]
    pub(crate) fn debug_assert_c_order(self) -> Self
    {
        debug_assert!(self.layout.is(Layout::CORDER) || self.layout_tendency >= 0 ||
                      self.dimension.slice().iter().filter(|&&d| d > 1).count() <= 1,
                      "Assertion failed: traversal is not c-order or 1D for \
                      layout {:?}, tendency {}, dimension {:?}",
                      self.layout, self.layout_tendency, self.dimension);
        self
    }
}
/// A pointer, or tuple of pointers, that can be moved by a multiple of a stride.
trait OffsetTuple
{
    /// Stride argument type matching `Self`.
    type Args;

    /// Returns `self` moved by `index` steps of `stride`.
    ///
    /// # Safety
    ///
    /// For the raw-pointer implementation the result is computed with
    /// `pointer::offset`, so that method's requirements apply.
    unsafe fn stride_offset(self, stride: Self::Args, index: usize) -> Self;
}

impl<T> OffsetTuple for *mut T
{
    type Args = isize;

    unsafe fn stride_offset(self, stride: Self::Args, index: usize) -> Self
    {
        // Displacement in elements of `T`: step count times the signed stride.
        let displacement = index as isize * stride;
        self.offset(displacement)
    }
}
// Implements `OffsetTuple` for tuples of `Offset` values.
//
// Each invocation entry is `[Types...][bindings...]`: the uppercase identifiers are
// used both as generic parameter names and as bindings for the tuple elements, the
// lowercase identifiers as bindings for the matching strides.
macro_rules! offset_impl {
($([$($param:ident)*][ $($q:ident)*],)+) => {
$(
#[allow(non_snake_case)]
impl<$($param: Offset),*> OffsetTuple for ($($param, )*) {
type Args = ($($param::Stride,)*);
unsafe fn stride_offset(self, stride: Self::Args, index: usize) -> Self {
// Destructure the pointers and the strides, then offset element by element.
let ($($param, )*) = self;
let ($($q, )*) = stride;
($(Offset::stride_offset($param, $q, index),)*)
}
}
)+
};
}
// Generate the implementations for tuples of one to six elements.
offset_impl! {
[A ][ a],
[A B][ a b],
[A B C][ a b c],
[A B C D][ a b c d],
[A B C D E][ a b c d e],
[A B C D E F][ a b c d e f],
}
// Implements `ZippableTuple` for tuples of `NdProducer`s that share one dimension type.
//
// Each invocation entry is `[Types...][bindings...]`: the uppercase identifiers name
// the producer types (and double as value bindings), the lowercase identifiers are
// extra bindings used where two sets of names are needed at once.
macro_rules! zipt_impl {
($([$($p:ident)*][ $($q:ident)*],)+) => {
$(
#[allow(non_snake_case)]
impl<Dim: Dimension, $($p: NdProducer<Dim=Dim>),*> ZippableTuple for ($($p, )*) {
type Item = ($($p::Item, )*);
type Ptr = ($($p::Ptr, )*);
type Dim = Dim;
type Stride = ($($p::Stride,)* );
// Stride of every producer along axis number `index`.
fn stride_of(&self, index: usize) -> Self::Stride {
let ($(ref $p,)*) = *self;
($($p.stride_of(Axis(index)), )*)
}
// `contiguous_stride()` of every producer.
fn contiguous_stride(&self) -> Self::Stride {
let ($(ref $p,)*) = *self;
($($p.contiguous_stride(), )*)
}
// `as_ptr()` of every producer.
fn as_ptr(&self) -> Self::Ptr {
let ($(ref $p,)*) = *self;
($($p.as_ptr(), )*)
}
// Converts each pointer into an item using the matching producer.
unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item {
// Lowercase names bind the producers, uppercase names bind the pointers.
let ($(ref $q ,)*) = *self;
let ($($p,)*) = ptr;
($($q.as_ref($p),)*)
}
// Pointer of every producer for index `i`.
unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr {
let ($(ref $p,)*) = *self;
($($p.uget_ptr(i), )*)
}
// Splits every producer at `index` along `axis` and regroups the halves.
fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) {
let ($($p,)*) = self;
// Shadow each producer binding with its `split_at` result; `.0` and `.1` below
// pick that result's two parts.
let ($($p,)*) = (
$($p.split_at(axis, index), )*
);
(
($($p.0,)*),
($($p.1,)*)
)
}
}
)+
};
}
// Generate the implementations for tuples of one to six producers.
zipt_impl! {
[A ][ a],
[A B][ a b],
[A B C][ a b c],
[A B C D][ a b c d],
[A B C D E][ a b c d e],
[A B C D E F][ a b c d e f],
}
// Generates the consuming API of `Zip` for one tuple arity per invocation entry.
//
// Each entry is `[flag P1 P2 ...]`. The flag is forwarded to `expand_if!(@bool ...)`,
// which wraps the methods that add one more producer; the invocation below passes
// `true` for one to five producers and `false` for six.
// NOTE(review): presumably `expand_if!` emits its body only when the flag is `true`,
// which is what caps a `Zip` at six producers — confirm in the macro's definition.
macro_rules! map_impl {
($([$notlast:ident $($p:ident)*],)+) => {
$(
#[allow(non_snake_case)]
impl<D, $($p),*> Zip<($($p,)*), D>
where D: Dimension,
$($p: NdProducer<Dim=D> ,)*
{
/// Calls `function` once for every element tuple of the zip.
pub fn for_each<F>(mut self, mut function: F)
where F: FnMut($($p::Item),*)
{
self.for_each_core((), move |(), args| {
let ($($p,)*) = args;
FoldWhile::Continue(function($($p),*))
});
}
/// Folds `function` over every element tuple, starting from `acc`, and returns
/// the final accumulator.
pub fn fold<F, Acc>(mut self, acc: Acc, mut function: F) -> Acc
where
F: FnMut(Acc, $($p::Item),*) -> Acc,
{
self.for_each_core(acc, move |acc, args| {
let ($($p,)*) = args;
FoldWhile::Continue(function(acc, $($p),*))
}).into_inner()
}
/// Like `fold`, but `function` decides after each element whether to go on
/// (`FoldWhile::Continue`) or stop (`FoldWhile::Done`); the last `FoldWhile`
/// value is returned.
pub fn fold_while<F, Acc>(mut self, acc: Acc, mut function: F)
-> FoldWhile<Acc>
where F: FnMut(Acc, $($p::Item),*) -> FoldWhile<Acc>
{
self.for_each_core(acc, move |acc, args| {
let ($($p,)*) = args;
function(acc, $($p),*)
})
}
/// Returns `true` unless `predicate` returned `false` for some element tuple;
/// the traversal is told to stop at the first `false`.
pub fn all<F>(mut self, mut predicate: F) -> bool
where F: FnMut($($p::Item),*) -> bool
{
!self.for_each_core((), move |_, args| {
let ($($p,)*) = args;
if predicate($($p),*) {
FoldWhile::Continue(())
} else {
FoldWhile::Done(())
}
}).is_done()
}
/// Returns `true` if `predicate` returned `true` for some element tuple;
/// the traversal is told to stop at the first `true`.
pub fn any<F>(mut self, mut predicate: F) -> bool
where F: FnMut($($p::Item),*) -> bool
{
self.for_each_core((), move |_, args| {
let ($($p,)*) = args;
if predicate($($p),*) {
FoldWhile::Done(())
} else {
FoldWhile::Continue(())
}
}).is_done()
}
expand_if!(@bool [$notlast]
/// Appends the producer `p` to the zip.
///
/// The dimension of `p` is checked against the zip's dimension with
/// `zip_dimension_check`.
#[track_caller]
pub fn and<P>(self, p: P) -> Zip<($($p,)* P::Output, ), D>
where P: IntoNdProducer<Dim=D>,
{
let part = p.into_producer();
zip_dimension_check(&self.dimension, &part);
self.build_and(part)
}
/// Appends the producer `p`; the dimension check of `and` is only performed
/// when debug assertions are enabled.
///
/// # Safety
///
/// NOTE(review): presumably the caller must guarantee that `p` has the same
/// dimension as the zip — confirm with the callers.
#[allow(unused)]
pub(crate) unsafe fn and_unchecked<P>(self, p: P) -> Zip<($($p,)* P::Output, ), D>
where P: IntoNdProducer<Dim=D>,
{
#[cfg(debug_assertions)]
{
self.and(p)
}
#[cfg(not(debug_assertions))]
{
self.build_and(p.into_producer())
}
}
/// Appends the array view produced by `p` after broadcasting it to the zip's
/// dimension with `broadcast_unwrap`.
#[track_caller]
pub fn and_broadcast<'a, P, D2, Elem>(self, p: P)
-> Zip<($($p,)* ArrayView<'a, Elem, D>, ), D>
where P: IntoNdProducer<Dim=D2, Output=ArrayView<'a, Elem, D2>, Item=&'a Elem>,
D2: Dimension,
{
let part = p.into_producer().broadcast_unwrap(self.dimension.clone());
self.build_and(part)
}
// Appends `part` without any dimension check: the layout is intersected with the
// part's layout and the part's tendency is added to the running sum.
fn build_and<P>(self, part: P) -> Zip<($($p,)* P, ), D>
where P: NdProducer<Dim=D>,
{
let part_layout = part.layout();
let ($($p,)*) = self.parts;
Zip {
parts: ($($p,)* part, ),
layout: self.layout.intersect(part_layout),
dimension: self.dimension,
layout_tendency: self.layout_tendency + part_layout.tendency(),
}
}
/// Applies `f` to every element tuple and collects the results into a new
/// `Array` with the zip's dimension.
pub fn map_collect<R>(self, f: impl FnMut($($p::Item,)* ) -> R) -> Array<R, D> {
self.map_collect_owned(f)
}
// Same as `map_collect`, generic over the owned storage `S` of the result.
pub(crate) fn map_collect_owned<S, R>(self, f: impl FnMut($($p::Item,)* ) -> R)
-> ArrayBase<S, D>
where
S: DataOwned<Elem = R>,
{
// Allocate the output in F order when `prefer_f()` holds, in C order otherwise.
let shape = self.dimension.clone().set_f(self.prefer_f());
let output = <ArrayBase<S, D>>::build_uninit(shape, |output| {
// Zip the inputs with a raw mutable view of the output and write every result
// through it; `release_ownership` is then called on the returned `Partial`.
// NOTE(review): presumably that hands responsibility for the written elements
// over to `output` — confirm in `Partial`.
unsafe {
let output_view = output.into_raw_view_mut().cast::<R>();
self.and(output_view)
.collect_with_partial(f)
.release_ownership();
}
});
// SAFETY (assumed — TODO confirm): relies on the traversal above having written
// every element of `output`.
unsafe {
output.assume_init()
}
}
/// Applies `f` to every element tuple and stores each result into the matching
/// element of `into` through `AssignElem::assign_elem`.
pub fn map_assign_into<R, Q>(self, into: Q, mut f: impl FnMut($($p::Item,)* ) -> R)
where Q: IntoNdProducer<Dim=D>,
Q::Item: AssignElem<R>
{
self.and(into)
.for_each(move |$($p, )* output_| {
output_.assign_elem(f($($p ),*));
});
}
);
/// Splits the zip in two via `SplitPreference::split`.
///
/// Debug-asserts that the zip holds more than one element.
pub fn split(self) -> (Self, Self) {
debug_assert_ne!(self.size(), 0, "Attempt to split empty zip");
debug_assert_ne!(self.size(), 1, "Attempt to split zip with 1 elem");
SplitPreference::split(self)
}
}
expand_if!(@bool [$notlast]
// Zips whose last producer yields raw `*mut R` output pointers; used for collecting.
#[allow(non_snake_case)]
impl<D, PLast, R, $($p),*> Zip<($($p,)* PLast), D>
where D: Dimension,
$($p: NdProducer<Dim=D> ,)*
PLast: NdProducer<Dim = D, Item = *mut R, Ptr = *mut R, Stride = isize>,
{
// Writes `f(inputs...)` through each output pointer of the last producer and
// returns a `Partial` created from the output's base pointer. The `Partial`'s
// `len` is incremented per written element only when `R` needs dropping.
//
// Safety — NOTE(review): presumably the output pointers must be valid for writes
// and refer to uninitialized (or safely overwritable) memory; confirm with callers.
pub(crate) unsafe fn collect_with_partial<F>(self, mut f: F) -> Partial<R>
where F: FnMut($($p::Item,)* ) -> R
{
// Borrow the last producer of the tuple: the output.
let (.., ref output) = &self.parts;
// In debug builds, check that the output layout passes the `CORDER | FORDER` test
// and that its tendency does not have the opposite sign of the zip's tendency.
if cfg!(debug_assertions) {
let out_layout = output.layout();
assert!(out_layout.is(Layout::CORDER | Layout::FORDER));
assert!(
(self.layout_tendency <= 0 && out_layout.tendency() <= 0) ||
(self.layout_tendency >= 0 && out_layout.tendency() >= 0),
"layout tendency violation for self layout {:?}, output layout {:?},\
output shape {:?}",
self.layout, out_layout, output.raw_dim());
}
let mut partial = Partial::new(output.as_ptr());
// Count written elements through a borrow of `partial.len` so the count is
// updated in place while the closure runs.
let partial_len = &mut partial.len;
self.for_each(move |$($p,)* output_elem: *mut R| {
output_elem.write(f($($p),*));
if std::mem::needs_drop::<R>() {
*partial_len += 1;
}
});
partial
}
}
);
impl<D, $($p),*> SplitPreference for Zip<($($p,)*), D>
where D: Dimension,
$($p: NdProducer<Dim=D> ,)*
{
// A zip can be split when it holds more than one element.
fn can_split(&self) -> bool { self.size() > 1 }
// Split at the midpoint of the axis chosen by `max_stride_axis`.
fn split_preference(&self) -> (Axis, usize) {
let axis = self.max_stride_axis();
let index = self.len_of(axis) / 2;
(axis, index)
}
}
impl<D, $($p),*> SplitAt for Zip<($($p,)*), D>
where D: Dimension,
$($p: NdProducer<Dim=D> ,)*
{
// Splits the producers and the dimension at `index` along `axis`; both halves keep
// the layout and layout tendency of the original zip.
fn split_at(self, axis: Axis, index: usize) -> (Self, Self) {
let (p1, p2) = self.parts.split_at(axis, index);
let (d1, d2) = self.dimension.split_at(axis, index);
(Zip {
dimension: d1,
layout: self.layout,
parts: p1,
layout_tendency: self.layout_tendency,
},
Zip {
dimension: d2,
layout: self.layout,
parts: p2,
layout_tendency: self.layout_tendency,
})
}
}
)+
};
}
// Arities one to five get the producer-adding methods (`true`); arity six does not (`false`).
map_impl! {
[true P1],
[true P1 P2],
[true P1 P2 P3],
[true P1 P2 P3 P4],
[true P1 P2 P3 P4 P5],
[false P1 P2 P3 P4 P5 P6],
}
/// Accumulator wrapper that tells a fold whether to keep going or to stop.
#[derive(Debug, Copy, Clone)]
pub enum FoldWhile<T>
{
    /// Keep folding with this value.
    Continue(T),
    /// Stop folding and yield this value.
    Done(T),
}

impl<T> FoldWhile<T>
{
    /// Returns the wrapped value, whichever variant holds it.
    pub fn into_inner(self) -> T
    {
        match self {
            FoldWhile::Continue(value) => value,
            FoldWhile::Done(value) => value,
        }
    }

    /// Returns `true` for `FoldWhile::Done` and `false` for `FoldWhile::Continue`.
    pub fn is_done(&self) -> bool
    {
        matches!(self, FoldWhile::Done(_))
    }
}