mod zipmacro;
use imp_prelude::*;
use IntoDimension;
use NdIndex;
use Layout;
use layout::{CORDER, FORDER};
use layout::LayoutPriv;
macro_rules! fold_while {
($e:expr) => {
match $e {
FoldWhile::Continue(x) => x,
x => return x,
}
}
}
trait LayoutImpl {
fn layout_impl(&self) -> Layout;
}
trait Broadcast<E>
where E: IntoDimension,
{
type Output: NdProducer<Dim=E::Dim>;
fn broadcast_unwrap(self, shape: E) -> Self::Output;
private_decl!{}
}
impl<S, D> LayoutImpl for ArrayBase<S, D>
where S: Data,
D: Dimension,
{
fn layout_impl(&self) -> Layout {
Layout::new(if self.is_standard_layout() {
if self.ndim() <= 1 {
FORDER | CORDER
} else {
CORDER
}
} else if self.ndim() > 1 && self.t().is_standard_layout() {
FORDER
} else {
0
})
}
}
impl<'a, L> LayoutImpl for &'a L where L: LayoutImpl
{
fn layout_impl(&self) -> Layout {
(**self).layout_impl()
}
}
impl<'a, L> LayoutImpl for &'a mut L where L: LayoutImpl
{
fn layout_impl(&self) -> Layout {
(**self).layout_impl()
}
}
impl<'a, A, D, E> Broadcast<E> for ArrayView<'a, A, D>
where E: IntoDimension,
D: Dimension,
{
type Output = ArrayView<'a, A, E::Dim>;
fn broadcast_unwrap(self, shape: E) -> Self::Output {
let res: ArrayView<A, E::Dim> = (&self).broadcast_unwrap(shape.into_dimension());
unsafe {
ArrayView::new_(res.ptr, res.dim, res.strides)
}
}
private_impl!{}
}
trait Splittable : Sized {
fn split_at(self, Axis, Ix) -> (Self, Self);
}
impl<D> Splittable for D
where D: Dimension,
{
fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) {
let mut d1 = self;
let mut d2 = d1.clone();
let i = axis.index();
let len = d1[i];
d1[i] = index;
d2[i] = len - index;
(d1, d2)
}
}
pub trait IntoNdProducer {
type Dim: Dimension;
type Output: NdProducer<Dim=Self::Dim>;
fn into_producer(self) -> Self::Output;
}
impl<P> IntoNdProducer for P where P: NdProducer {
type Dim = P::Dim;
type Output = Self;
fn into_producer(self) -> Self::Output { self }
}
pub trait NdProducer {
type Item;
#[doc(hidden)]
type Elem;
type Dim: Dimension;
#[doc(hidden)]
fn layout(&self) -> Layout;
#[doc(hidden)]
fn raw_dim(&self) -> Self::Dim;
#[doc(hidden)]
fn equal_dim(&self, dim: &Self::Dim) -> bool {
self.raw_dim() == *dim
}
#[doc(hidden)]
fn as_ptr(&self) -> *mut Self::Elem;
#[doc(hidden)]
unsafe fn as_ref(&self, *mut Self::Elem) -> Self::Item;
#[doc(hidden)]
unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut Self::Elem;
#[doc(hidden)]
fn stride_of(&self, axis: Axis) -> isize;
#[doc(hidden)]
#[inline(always)]
fn contiguous_stride(&self) -> isize {
1
}
#[doc(hidden)]
fn split_at(self, axis: Axis, index: usize) -> (Self, Self) where Self: Sized;
private_decl!{}
}
trait ZippableTuple : Sized {
type Item;
type Ptr: OffsetTuple<Args=Self::Stride> + Copy;
type Dim: Dimension;
type Stride: Copy;
fn as_ptr(&self) -> Self::Ptr;
unsafe fn as_ref(&self, Self::Ptr) -> Self::Item;
unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr;
fn stride_of(&self, index: usize) -> Self::Stride;
fn contiguous_stride(&self) -> Self::Stride;
fn split_at(self, axis: Axis, index: usize) -> (Self, Self);
}
impl<'a, A: 'a, S, D> IntoNdProducer for &'a ArrayBase<S, D>
where D: Dimension,
S: Data<Elem=A>,
{
type Dim = D;
type Output = ArrayView<'a, A, D>;
fn into_producer(self) -> Self::Output {
self.view()
}
}
impl<'a, A: 'a, S, D> IntoNdProducer for &'a mut ArrayBase<S, D>
where D: Dimension,
S: DataMut<Elem=A>,
{
type Dim = D;
type Output = ArrayViewMut<'a, A, D>;
fn into_producer(self) -> Self::Output {
self.view_mut()
}
}
impl<'a, A: 'a> IntoNdProducer for &'a [A] {
type Dim = Ix1;
type Output = ArrayView1<'a, A>;
fn into_producer(self) -> Self::Output {
<_>::from(self)
}
}
impl<'a, A: 'a> IntoNdProducer for &'a mut [A] {
type Dim = Ix1;
type Output = ArrayViewMut1<'a, A>;
fn into_producer(self) -> Self::Output {
<_>::from(self)
}
}
impl<'a, A: 'a> IntoNdProducer for &'a Vec<A> {
type Dim = Ix1;
type Output = ArrayView1<'a, A>;
fn into_producer(self) -> Self::Output {
<_>::from(self)
}
}
impl<'a, A: 'a> IntoNdProducer for &'a mut Vec<A> {
type Dim = Ix1;
type Output = ArrayViewMut1<'a, A>;
fn into_producer(self) -> Self::Output {
<_>::from(self)
}
}
impl<'a, A, D> NdProducer for ArrayView<'a, A, D>
where D: Dimension,
{
type Item = &'a A;
type Dim = D;
type Elem = A;
private_impl!{}
#[doc(hidden)]
fn raw_dim(&self) -> Self::Dim {
self.raw_dim()
}
#[doc(hidden)]
fn equal_dim(&self, dim: &Self::Dim) -> bool {
self.dim.equal(dim)
}
#[doc(hidden)]
fn as_ptr(&self) -> *mut A {
self.as_ptr() as _
}
#[doc(hidden)]
fn layout(&self) -> Layout {
LayoutImpl::layout_impl(self)
}
#[doc(hidden)]
unsafe fn as_ref(&self, ptr: *mut A) -> Self::Item {
&*ptr
}
#[doc(hidden)]
unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut A {
self.ptr.offset(i.index_unchecked(&self.strides))
}
#[doc(hidden)]
fn stride_of(&self, axis: Axis) -> isize {
self.strides()[axis.index()]
}
#[doc(hidden)]
fn split_at(self, axis: Axis, index: usize) -> (Self, Self) {
self.split_at(axis, index)
}
}
impl<'a, A, D> NdProducer for ArrayViewMut<'a, A, D>
where D: Dimension,
{
type Item = &'a mut A;
type Dim = D;
type Elem = A;
private_impl!{}
#[doc(hidden)]
fn raw_dim(&self) -> Self::Dim {
self.raw_dim()
}
#[doc(hidden)]
fn equal_dim(&self, dim: &Self::Dim) -> bool {
self.dim.equal(dim)
}
#[doc(hidden)]
fn as_ptr(&self) -> *mut A {
self.as_ptr() as _
}
#[doc(hidden)]
fn layout(&self) -> Layout {
LayoutImpl::layout_impl(self)
}
#[doc(hidden)]
unsafe fn as_ref(&self, ptr: *mut A) -> Self::Item {
&mut *ptr
}
#[doc(hidden)]
unsafe fn uget_ptr(&self, i: &Self::Dim) -> *mut A {
self.ptr.offset(i.index_unchecked(&self.strides))
}
#[doc(hidden)]
fn stride_of(&self, axis: Axis) -> isize {
self.strides()[axis.index()]
}
#[doc(hidden)]
fn split_at(self, axis: Axis, index: usize) -> (Self, Self) {
self.split_at(axis, index)
}
}
#[derive(Debug, Clone)]
pub struct Zip<Parts, D> {
parts: Parts,
dimension: D,
layout: Layout,
}
impl<P, D> Zip<(P, ), D>
where D: Dimension,
P: NdProducer<Dim=D>
{
pub fn from<IP>(p: IP) -> Self
where IP: IntoNdProducer<Dim=D, Output=P>
{
let array = p.into_producer();
let dim = array.raw_dim();
Zip {
dimension: dim,
layout: array.layout(),
parts: (array, ),
}
}
}
impl<Parts, D> Zip<Parts, D>
where D: Dimension,
{
fn check<P>(&self, part: &P)
where P: NdProducer<Dim=D>
{
ndassert!(part.equal_dim(&self.dimension),
"Zip: Producer dimension mismatch, expected: {:?}, got: {:?}",
self.dimension, part.raw_dim());
}
pub fn size(&self) -> usize {
self.dimension.size()
}
fn len_of(&self, axis: Axis) -> usize {
self.dimension[axis.index()]
}
fn max_stride_axis(&self) -> Axis {
let i = match self.layout.flag() {
CORDER => self.dimension.slice().iter()
.position(|&len| len > 1).unwrap_or(0),
FORDER => self.dimension.slice().iter()
.rposition(|&len| len > 1).unwrap_or(self.dimension.ndim() - 1),
_ => 0,
};
Axis(i)
}
}
impl<P, D> Zip<P, D>
where D: Dimension,
{
fn apply_core<F, Acc>(&mut self, acc: Acc, function: F) -> FoldWhile<Acc>
where F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
P: ZippableTuple<Dim=D>,
{
if self.layout.is(CORDER | FORDER) {
self.apply_core_contiguous(acc, function)
} else {
self.apply_core_strided(acc, function)
}
}
fn apply_core_contiguous<F, Acc>(&mut self, mut acc: Acc, mut function: F) -> FoldWhile<Acc>
where F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
P: ZippableTuple<Dim=D>,
{
debug_assert!(self.layout.is(CORDER | FORDER));
let size = self.dimension.size();
let ptrs = self.parts.as_ptr();
let inner_strides = self.parts.contiguous_stride();
for i in 0..size {
unsafe {
let ptr_i = ptrs.stride_offset(i, inner_strides);
acc = fold_while![function(acc, self.parts.as_ref(ptr_i))];
}
}
FoldWhile::Continue(acc)
}
fn apply_core_strided<F, Acc>(&mut self, mut acc: Acc, mut function: F) -> FoldWhile<Acc>
where F: FnMut(Acc, P::Item) -> FoldWhile<Acc>,
P: ZippableTuple<Dim=D>,
{
let n = self.dimension.ndim();
if n == 0 {
panic!("Unreachable: ndim == 0 is contiguous")
}
let unroll_axis = n - 1;
let inner_len = self.dimension[unroll_axis];
self.dimension[unroll_axis] = 1;
let mut index_ = self.dimension.first_index();
let inner_strides = self.parts.stride_of(unroll_axis);
while let Some(index) = index_ {
unsafe {
let ptr = self.parts.uget_ptr(&index);
for i in 0..inner_len {
let p = ptr.stride_offset(i, inner_strides);
acc = fold_while!(function(acc, self.parts.as_ref(p)));
}
}
index_ = self.dimension.next_for(index);
}
self.dimension[unroll_axis] = inner_len;
FoldWhile::Continue(acc)
}
}
trait Offset : Copy {
unsafe fn offset(self, off: isize) -> Self;
unsafe fn stride_offset(self, index: usize, stride: isize) -> Self {
self.offset(index as isize * stride)
}
}
impl<T> Offset for *mut T {
unsafe fn offset(self, off: isize) -> Self {
self.offset(off)
}
}
trait OffsetTuple {
type Args;
unsafe fn offset(self, off: isize) -> Self;
unsafe fn stride_offset(self, index: usize, stride: Self::Args) -> Self;
}
impl<T> OffsetTuple for *mut T {
type Args = isize;
unsafe fn offset(self, off: isize) -> Self {
self.offset(off)
}
unsafe fn stride_offset(self, index: usize, stride: isize) -> Self {
self.offset(index as isize * stride)
}
}
macro_rules! sub {
($i:ident [$($x:tt)*]) => { $($x)* };
}
macro_rules! offset_impl {
($([$($param:ident)*][ $($q:ident)*],)+) => {
$(
#[allow(non_snake_case)]
impl<$($param: Offset),*> OffsetTuple for ($($param, )*) {
type Args = ($(sub!($param [isize]),)*);
unsafe fn offset(self, off: isize) -> Self {
let ($($param, )*) = self;
($($param . offset(off),)*)
}
unsafe fn stride_offset(self, index: usize, stride: Self::Args) -> Self {
let ($($param, )*) = self;
let ($($q, )*) = stride;
($(Offset::stride_offset($param, index, $q),)*)
}
}
)+
}
}
offset_impl!{
[A ][ a],
[A B][ a b],
[A B C][ a b c],
[A B C D][ a b c d],
[A B C D E][ a b c d e],
[A B C D E F][ a b c d e f],
}
macro_rules! zipt_impl {
($([$($p:ident)*][ $($q:ident)*],)+) => {
$(
#[allow(non_snake_case)]
impl<Dim: Dimension, $($p: NdProducer<Dim=Dim>),*> ZippableTuple for ($($p, )*) {
type Item = ($($p::Item, )*);
type Ptr = ($(*mut $p::Elem, )*);
type Dim = Dim;
type Stride = ($(sub!($p [isize]),)* );
fn stride_of(&self, index: usize) -> Self::Stride {
let ($(ref $p,)*) = *self;
($($p.stride_of(Axis(index)), )*)
}
fn contiguous_stride(&self) -> Self::Stride {
let ($(ref $p,)*) = *self;
($($p.contiguous_stride(), )*)
}
fn as_ptr(&self) -> Self::Ptr {
let ($(ref $p,)*) = *self;
($($p.as_ptr(), )*)
}
unsafe fn as_ref(&self, ptr: Self::Ptr) -> Self::Item {
let ($(ref $q ,)*) = *self;
let ($($p,)*) = ptr;
($($q.as_ref($p),)*)
}
unsafe fn uget_ptr(&self, i: &Self::Dim) -> Self::Ptr {
let ($(ref $p,)*) = *self;
($($p.uget_ptr(i), )*)
}
fn split_at(self, axis: Axis, index: Ix) -> (Self, Self) {
let ($($p,)*) = self;
let ($($p,)*) = (
$($p.split_at(axis, index), )*
);
(
($($p.0,)*),
($($p.1,)*)
)
}
}
)+
}
}
zipt_impl!{
[A ][ a],
[A B][ a b],
[A B C][ a b c],
[A B C D][ a b c d],
[A B C D E][ a b c d e],
[A B C D E F][ a b c d e f],
}
macro_rules! map_impl {
($([$notlast:ident $($p:ident)*],)+) => {
$(
#[allow(non_snake_case)]
impl<D: Dimension, $($p: NdProducer<Dim=D>),*> Zip<($($p,)*), D> {
pub fn apply<F>(&mut self, mut function: F)
where F: FnMut($($p::Item),*)
{
self.apply_core((), move |(), args| {
let ($($p,)*) = args;
FoldWhile::Continue(function($($p),*))
});
}
pub fn fold_while<F, Acc>(&mut self, acc: Acc, mut function: F)
-> FoldWhile<Acc>
where F: FnMut(Acc, $($p::Item),*) -> FoldWhile<Acc>
{
self.apply_core(acc, move |acc, args| {
let ($($p,)*) = args;
function(acc, $($p),*)
})
}
expand_if!(@bool [$notlast]
pub fn and<P>(self, p: P) -> Zip<($($p,)* P::Output, ), D>
where P: IntoNdProducer<Dim=D>,
{
let array = p.into_producer();
self.check(&array);
let part_layout = array.layout();
let ($($p,)*) = self.parts;
Zip {
parts: ($($p,)* array, ),
layout: self.layout.and(part_layout),
dimension: self.dimension,
}
}
pub fn and_broadcast<'a, P, D2, Elem>(self, p: P)
-> Zip<($($p,)* ArrayView<'a, Elem, D>, ), D>
where P: IntoNdProducer<Dim=D2, Output=ArrayView<'a, Elem, D2>>,
D2: Dimension,
{
let array = p.into_producer().broadcast_unwrap(self.dimension.clone());
let part_layout = array.layout();
let ($($p,)*) = self.parts;
Zip {
parts: ($($p,)* array, ),
layout: self.layout.and(part_layout),
dimension: self.dimension,
}
}
);
pub fn split(self) -> (Self, Self) {
let axis = self.max_stride_axis();
let index = self.len_of(axis) / 2;
let (p1, p2) = self.parts.split_at(axis, index);
let (d1, d2) = self.dimension.split_at(axis, index);
(Zip {
dimension: d1,
layout: self.layout,
parts: p1,
},
Zip {
dimension: d2,
layout: self.layout,
parts: p2,
})
}
}
)+
}
}
map_impl!{
[true P1],
[true P1 P2],
[true P1 P2 P3],
[true P1 P2 P3 P4],
[true P1 P2 P3 P4 P5],
[false P1 P2 P3 P4 P5 P6],
}
pub enum FoldWhile<T> {
Continue(T),
Done(T),
}
impl<T> FoldWhile<T> {
pub fn into_inner(self) -> T {
match self {
FoldWhile::Continue(x) | FoldWhile::Done(x) => x
}
}
}