// mdarray 0.2.0
//
// Multidimensional array for Rust.
// Documentation
use crate::layout::{Layout, StridedLayout};
use crate::order::Order;
use crate::view::StridedView;
use std::iter::FusedIterator;
use std::marker::PhantomData;

// Generates a strided element iterator type named `$name`.
//
// Macro parameters:
// - `$name`: name of the generated iterator struct.
// - `$as_ptr`: view method returning the base data pointer (`as_ptr` or `as_mut_ptr`).
// - `$raw_mut`: raw-pointer qualifier for the cursor fields (`const` or `mut`).
// - `{$($mut)?}`: optional `mut`, applied to the view borrow taken by `new` and to the
//   references yielded by `next`.
//
// The view's dimensions are split into `N - M` inner dimensions, whose elements are
// adjacent in memory and are walked with a unit pointer step, and `M` outer dimensions,
// which are walked using the layout strides. `O::select` chooses between the two
// conventions: inner = first `N - M` dimensions and outer = last `M`, or
// outer = first `M` dimensions and inner = last `N - M`.
macro_rules! impl_iter {
    ($name:tt, $as_ptr:tt, $raw_mut:tt, {$($mut:tt)?}) => {
        /// Iterator over the elements of a strided array view.
        ///
        /// Each contiguous inner block is yielded in memory order. When a block is
        /// finished, the outer index at position 0 is advanced, carrying into the higher
        /// outer positions once it wraps around.
        pub struct $name<'a, T, const N: usize, const M: usize, O: Order> {
            /// Layout of the iterated view; source of the outer shape and strides.
            layout: &'a StridedLayout<N, M, O>,
            /// Next element to yield.
            start: *$raw_mut T,
            /// One past the last element of the current inner block. `start == end`
            /// marks the iterator as exhausted.
            end: *$raw_mut T,
            /// Current index along each outer dimension (position 0 varies fastest).
            indices: [usize; M],
            /// Number of elements in one contiguous inner block.
            inner_size: usize,
            _marker: PhantomData<&'a $($mut)? T>,
        }

        impl<'a, T, const N: usize, const M: usize, O: Order> $name<'a, T, N, M, O> {
            /// Creates an iterator positioned at the first element of `view`.
            ///
            /// If any dimension of `view` has extent zero, the iterator is created
            /// already exhausted and yields nothing.
            pub fn new(view: &'a $($mut)? StridedView<T, N, M, O>) -> Self {
                // Number of elements in each contiguous inner block.
                let inner_size: usize = O::select(
                    view.shape()[..N - M].iter().product(),
                    view.shape()[M..].iter().product(),
                );

                // A zero extent anywhere means there are no elements at all. This must be
                // checked on the full shape: an empty *outer* dimension leaves
                // `inner_size` non-zero (and with `M == N` it is always 1), which would
                // otherwise make the iterator read past the view's data instead of
                // yielding nothing.
                let is_empty = view.shape().iter().any(|&extent| extent == 0);

                let layout = view.layout();
                let start = view.$as_ptr();

                // `start == end` is the exhausted state checked by `next` and `size_hint`.
                let end = if is_empty {
                    start
                } else {
                    // SAFETY: the view is non-empty, so its first inner block of
                    // `inner_size` adjacent elements starts at the base pointer, and the
                    // one-past-the-end pointer of that block is in bounds.
                    unsafe { start.add(inner_size) }
                };

                Self {
                    layout,
                    start,
                    end,
                    indices: [0; M],
                    inner_size,
                    _marker: PhantomData,
                }
            }

            /// Extent of the outer dimension at iteration position `dim`
            /// (position 0 is the fastest-varying outer dimension).
            fn outer_size(&self, dim: usize) -> usize {
                O::select(
                    self.layout.shape()[N - M + dim],
                    self.layout.shape()[M - 1 - dim],
                )
            }

            /// Stride, in elements, of the outer dimension at iteration position `dim`.
            fn outer_stride(&self, dim: usize) -> isize {
                O::select(
                    self.layout.strides()[dim],
                    self.layout.strides()[M - 1 - dim],
                )
            }
        }

        impl<'a, T, const N: usize, const M: usize, O: Order> ExactSizeIterator
            for $name<'a, T, N, M, O>
        {
        }

        // Once exhausted, `start == end` is never changed again, so `next` keeps
        // returning `None`.
        impl<'a, T, const N: usize, const M: usize, O: Order> FusedIterator
            for $name<'a, T, N, M, O>
        {
        }

        impl<'a, T, const N: usize, const M: usize, O: Order> Iterator for $name<'a, T, N, M, O> {
            type Item = &'a $($mut)? T;

            #[inline(always)]
            fn next(&mut self) -> Option<Self::Item> {
                if self.start == self.end {
                    None
                } else {
                    // SAFETY: `start != end`, so `start` points at an element of the
                    // current inner block of the (non-empty) view. Every pointer move
                    // below keeps `indices[i]` within `0..outer_size(i)`, so the cursor
                    // either lands on the first element of another block of the view or
                    // returns to the base pointer when iteration finishes.
                    unsafe {
                        let current = self.start;

                        self.start = self.start.offset(1);

                        // End of the current inner block (with `M == N` every block is a
                        // single element): step the outer indices.
                        if M > 0 && (M == N || self.start == self.end) {
                            // Rewind to the beginning of the block just finished.
                            self.start = self.start.sub(self.inner_size);
                            self.indices[0] += 1;

                            if self.indices[0] == self.outer_size(0) {
                                // Position 0 wrapped: move back to index 0 along it and
                                // tentatively mark the iterator exhausted.
                                self.start = self.start.offset(
                                    (1 - self.outer_size(0) as isize)
                                        * self.outer_stride(0),
                                );
                                self.end = self.start;
                                self.indices[0] = 0;

                                // Carry into the higher outer positions.
                                for i in 1..M {
                                    self.indices[i] += 1;

                                    if self.indices[i] == self.outer_size(i) {
                                        self.start = self.start.offset(
                                            (1 - self.outer_size(i) as isize)
                                                * self.outer_stride(i),
                                        );
                                        self.end = self.start;
                                        self.indices[i] = 0;
                                    } else {
                                        // Found a position that did not wrap: open the
                                        // next inner block there.
                                        self.start = self.start.offset(self.outer_stride(i));
                                        self.end = self.start.add(self.inner_size);
                                        break;
                                    }
                                }
                            } else {
                                self.start = self.start.offset(self.outer_stride(0));
                                self.end = self.start.add(self.inner_size);
                            }
                        }

                        Some(&$($mut)? *current)
                    }
                }
            }

            /// Returns the exact number of remaining elements.
            fn size_hint(&self) -> (usize, Option<usize>) {
                if self.start == self.end {
                    (0, Some(0))
                } else {
                    // SAFETY: while not exhausted, `start` and `end` point into (or one
                    // past) the same inner block, with `start < end`.
                    let mut len = unsafe { self.end.offset_from(self.start) as usize };
                    // Elements covered by one step along outer position `i`.
                    let mut prod = self.inner_size;

                    // Add the blocks still to come at each outer position. While not
                    // exhausted every outer extent is >= 1 and `indices[i] < outer_size(i)`,
                    // so the subtraction cannot underflow.
                    for i in 0..M {
                        len += (self.outer_size(i) - self.indices[i] - 1) * prod;
                        prod *= self.outer_size(i);
                    }

                    (len, Some(len))
                }
            }
        }
    };
}

// Shared iterator: borrows the view immutably and yields `&T`.
impl_iter!(Iter, as_ptr, const, {});
// Exclusive iterator: borrows the view mutably and yields `&mut T`.
impl_iter!(IterMut, as_mut_ptr, mut, {mut});