// hpt_iterator/par_strided_fold.rs

1use std::fmt::Display;
2
3use hpt_traits::CommonBounds;
4use rayon::iter::{
5    plumbing::{bridge_unindexed, UnindexedConsumer, UnindexedProducer},
6    ParallelIterator,
7};
8
9use crate::iterator_traits::IterGetSet;
10
/// Parallel, strided fold (reduction) over the elements of a tensor.
///
/// The struct bundles everything a fold needs:
/// - `iter`: the strided source iterator that yields the tensor elements;
/// - `identity`: the seed value passed to `fold_op`;
/// - `fold_op`: a closure taking an `ID` and one element and returning a new `ID`.
///
/// Execution goes through Rayon's `ParallelIterator` / `UnindexedProducer`
/// machinery, which may split `iter` into independent pieces that are
/// processed in parallel.
pub struct ParStridedFold<I, ID, F> {
    /// Strided source iterator over the tensor elements.
    pub(crate) iter: I,
    /// Seed value handed to `fold_op`.
    pub(crate) identity: ID,
    /// Combining closure of shape `(ID, Item) -> ID`.
    pub(crate) fold_op: F,
}
21
22impl<I, ID, F> ParallelIterator for ParStridedFold<I, ID, F>
23where
24    I: ParallelIterator + UnindexedProducer + IterGetSet,
25    F: Fn(ID, <I as IterGetSet>::Item) -> ID + Sync + Send + Copy,
26    ID: CommonBounds,
27    <I as IterGetSet>::Item: Display,
28{
29    type Item = ID;
30
31    fn drive_unindexed<C>(self, consumer: C) -> C::Result
32    where
33        C: UnindexedConsumer<Self::Item>,
34    {
35        bridge_unindexed(self, consumer)
36    }
37}
38
39impl<I, ID, F> UnindexedProducer for ParStridedFold<I, ID, F>
40where
41    I: ParallelIterator + UnindexedProducer + IterGetSet,
42    F: Fn(ID, <I as IterGetSet>::Item) -> ID + Sync + Send + Copy,
43    ID: CommonBounds,
44    <I as IterGetSet>::Item: Display,
45{
46    type Item = ID;
47
48    fn split(self) -> (Self, Option<Self>) {
49        let (a, b) = self.iter.split();
50        (
51            ParStridedFold {
52                iter: a,
53                identity: self.identity,
54                fold_op: self.fold_op,
55            },
56            b.map(|b| ParStridedFold {
57                iter: b,
58                identity: self.identity,
59                fold_op: self.fold_op,
60            }),
61        )
62    }
63
64    fn fold_with<FD>(mut self, mut folder: FD) -> FD
65    where
66        FD: rayon::iter::plumbing::Folder<Self::Item>,
67    {
68        let init = self.identity;
69        let outer_loop_size = self.iter.outer_loop_size();
70        let inner_loop_size = self.iter.inner_loop_size() + 1; // parallel iterator will auto subtract 1
71        for _ in 0..outer_loop_size {
72            for i in 0..inner_loop_size {
73                let item = self.iter.inner_loop_next(i);
74                let val = (self.fold_op)(init, item);
75                folder = folder.consume(val);
76            }
77            self.iter.next();
78        }
79        folder
80    }
81}