hpt-iterator 0.1.2

An internal library that implements iterators for hpt
Documentation
use std::sync::Arc;

use hpt_common::{shape::shape::Shape, strides::strides::Strides};
use rayon::iter::{
    plumbing::{bridge_unindexed, Folder, UnindexedConsumer, UnindexedProducer},
    ParallelIterator,
};

use crate::iterator_traits::IterGetSetSimd;

/// Adaptor that pairs a base iterator with a closure invoked on the base's
/// SIMD items while the iterator is being folded.
pub(crate) struct WithSimd<I, F> {
    /// The wrapped iterator that supplies both scalar and SIMD items.
    pub(crate) base: I,
    /// Closure called with each SIMD item obtained from `base`.
    pub(crate) vec_op: F,
}

/// Producer side of the rayon plumbing for `WithSimd`.
impl<I, F> UnindexedProducer for WithSimd<I, F>
where
    I: UnindexedProducer + IterGetSetSimd + ParallelIterator,
    F: Fn(<I as IterGetSetSimd>::SimdItem) + Sync + Send + Copy,
{
    type Item = <I as IterGetSetSimd>::Item;

    /// Splits the wrapped iterator and re-wraps every resulting part with a
    /// copy of `vec_op`.
    fn split(self) -> (Self, Option<Self>) {
        let WithSimd { base, vec_op } = self;
        let (head, tail) = base.split();
        let tail = tail.map(|base| WithSimd { base, vec_op });
        (WithSimd { base: head, vec_op }, tail)
    }

    /// Walks `outer_loop_size()` rows of `inner_loop_size() + 1` elements each.
    ///
    /// When `all_last_stride_one()` holds and `lanes()` reports a vector
    /// width, every full vector of a row is passed to `vec_op` (those elements
    /// are *not* fed to `folder`), and only the scalar leftovers at the end of
    /// the row are consumed by `folder`. Otherwise every element is consumed
    /// by `folder` one at a time.
    fn fold_with<FOLD>(mut self, mut folder: FOLD) -> FOLD
    where
        FOLD: Folder<Self::Item>,
    {
        let rows = self.outer_loop_size();
        // NOTE(review): the `+ 1` suggests the base reports the inner length
        // minus one — confirm against the base iterator.
        let row_len = self.inner_loop_size() + 1;
        // `F: Copy`, so take a local copy that can be called while `self` is
        // mutably borrowed by the `inner_loop_next*` calls.
        let apply = self.vec_op;

        let lanes = match (self.all_last_stride_one(), self.lanes()) {
            (true, Some(lanes)) => lanes,
            _ => {
                // Scalar fallback: hand every element to the folder.
                for _ in 0..rows {
                    for i in 0..row_len {
                        folder = folder.consume(self.inner_loop_next(i));
                    }
                    self.next();
                }
                return folder;
            }
        };

        // Elements past `tail_start` do not fill a whole vector.
        let tail_len = row_len % lanes;
        let tail_start = row_len - tail_len;
        let vectors = tail_start / lanes;

        if vectors % 4 == 0 {
            // The vector count is a multiple of four: process four per step.
            let groups = vectors / 4;
            for _ in 0..rows {
                for g in 0..groups {
                    let first = g * 4;
                    apply(self.inner_loop_next_simd(first));
                    apply(self.inner_loop_next_simd(first + 1));
                    apply(self.inner_loop_next_simd(first + 2));
                    apply(self.inner_loop_next_simd(first + 3));
                }
                // Empty range when the row is an exact multiple of `lanes`.
                for i in tail_start..row_len {
                    folder = folder.consume(self.inner_loop_next(i));
                }
                self.next();
            }
        } else {
            for _ in 0..rows {
                for v in 0..vectors {
                    let simd_item = self.inner_loop_next_simd(v);
                    apply(simd_item);
                }
                // Empty range when the row is an exact multiple of `lanes`.
                for i in tail_start..row_len {
                    folder = folder.consume(self.inner_loop_next(i));
                }
                self.next();
            }
        }
        folder
    }
}

/// Makes `WithSimd` usable as a rayon `ParallelIterator` yielding the base
/// iterator's scalar `Item` type.
impl<I, F> ParallelIterator for WithSimd<I, F>
where
    I: UnindexedProducer + IterGetSetSimd + ParallelIterator,
    F: Fn(<I as IterGetSetSimd>::SimdItem) + Sync + Send + Copy,
    <I as IterGetSetSimd>::Item: Send,
{
    type Item = <I as IterGetSetSimd>::Item;

    /// Drives the iteration by handing `self` (as an `UnindexedProducer`) and
    /// `consumer` to rayon's `bridge_unindexed`.
    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: UnindexedConsumer<Self::Item>,
    {
        bridge_unindexed(self, consumer)
    }
}

/// `IterGetSetSimd` for `WithSimd`: shape, stride and stepping calls are
/// forwarded to `base`; the interval-related methods are unsupported and
/// panic.
impl<I, F> IterGetSetSimd for WithSimd<I, F>
where
    I: UnindexedProducer + IterGetSetSimd + ParallelIterator,
    F: Fn(<I as IterGetSetSimd>::SimdItem) + Sync + Send + Copy,
{
    type Item = <I as IterGetSetSimd>::Item;

    type SimdItem = <I as IterGetSetSimd>::SimdItem;

    /// Unsupported on this adaptor.
    ///
    /// # Panics
    ///
    /// Always panics.
    fn set_end_index(&mut self, _: usize) {
        // The message previously named `set_intervals`, which pointed at the
        // wrong method when this one was hit.
        panic!("single thread strided zip does not support set_end_index");
    }

    /// Unsupported on this adaptor.
    ///
    /// # Panics
    ///
    /// Always panics.
    fn set_intervals(&mut self, _: Arc<Vec<(usize, usize)>>) {
        panic!("single thread strided zip does not support set_intervals");
    }

    /// Forwards the new strides to the base iterator.
    fn set_strides(&mut self, last_stride: Strides) {
        // `last_stride` is already owned, so it is moved rather than cloned.
        self.base.set_strides(last_stride);
    }

    /// Forwards the new shape to the base iterator.
    fn set_shape(&mut self, shape: Shape) {
        self.base.set_shape(shape);
    }

    /// Forwards the new progress vector to the base iterator.
    fn set_prg(&mut self, prg: Vec<i64>) {
        self.base.set_prg(prg);
    }

    /// Unsupported on this adaptor.
    ///
    /// # Panics
    ///
    /// Always panics.
    fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
        panic!("single thread strided zip does not support intervals");
    }

    /// Returns the base iterator's strides.
    fn strides(&self) -> &Strides {
        self.base.strides()
    }

    /// Returns the base iterator's shape.
    fn shape(&self) -> &Shape {
        self.base.shape()
    }

    /// Returns the base iterator's layout.
    fn layout(&self) -> &hpt_common::layout::layout::Layout {
        self.base.layout()
    }

    /// Forwards the broadcast request for `shape` to the base iterator.
    fn broadcast_set_strides(&mut self, shape: &Shape) {
        self.base.broadcast_set_strides(shape);
    }

    /// Returns the base iterator's outer loop size.
    fn outer_loop_size(&self) -> usize {
        self.base.outer_loop_size()
    }

    /// Returns the base iterator's inner loop size.
    fn inner_loop_size(&self) -> usize {
        self.base.inner_loop_size()
    }

    /// Advances the base iterator by one outer step.
    fn next(&mut self) {
        self.base.next();
    }

    /// Not implemented yet.
    ///
    /// # Panics
    ///
    /// Always panics via `todo!()`.
    fn next_simd(&mut self) {
        todo!()
    }

    /// Returns the base iterator's scalar item at `index` within the inner loop.
    fn inner_loop_next(&mut self, index: usize) -> Self::Item {
        self.base.inner_loop_next(index)
    }

    /// Returns the base iterator's SIMD item at `index` within the inner loop.
    fn inner_loop_next_simd(&mut self, index: usize) -> Self::SimdItem {
        self.base.inner_loop_next_simd(index)
    }

    /// Reports the base iterator's `all_last_stride_one` result.
    fn all_last_stride_one(&self) -> bool {
        self.base.all_last_stride_one()
    }

    /// Returns the base iterator's `lanes` result.
    fn lanes(&self) -> Option<usize> {
        self.base.lanes()
    }
}