hpt_iterator/
par_strided_zip.rs

1use hpt_common::{shape::shape::Shape, strides::strides::Strides};
2use hpt_traits::tensor::CommonBounds;
3use par_strided_zip_simd::ParStridedZipSimd;
4use rayon::iter::{
5    plumbing::{bridge_unindexed, Folder, UnindexedConsumer, UnindexedProducer},
6    ParallelIterator,
7};
8use std::sync::Arc;
9
10use crate::{
11    iterator_traits::{IterGetSet, IterGetSetSimd, ParStridedIteratorZip, ShapeManipulator},
12    par_strided_map::ParStridedMap,
13};
14
15/// A module for parallel strided simd zip iterator.
16pub mod par_strided_zip_simd {
17    use std::sync::Arc;
18
19    use hpt_common::{shape::shape::Shape, strides::strides::Strides, utils::simd_ref::MutVec};
20    use hpt_traits::CommonBounds;
21    use rayon::iter::{
22        plumbing::{bridge_unindexed, Folder, UnindexedConsumer, UnindexedProducer},
23        ParallelIterator,
24    };
25
26    use crate::{
27        iterator_traits::{
28            IterGetSetSimd, ParStridedIteratorSimd, ParStridedIteratorSimdZip, ShapeManipulator,
29        },
30        par_strided_map::par_strided_map_simd::ParStridedMapSimd,
31    };
32
33    // A parallel SIMD-optimized zipped iterator combining two iterators over tensor elements.
34    ///
35    /// This struct allows for synchronized parallel iteration over two SIMD-optimized iterators,
36    #[derive(Clone)]
37    pub struct ParStridedZipSimd<'a, A: 'a, B: 'a> {
38        /// The first iterator to be zipped.
39        pub(crate) a: A,
40        /// The second iterator to be zipped.
41        pub(crate) b: B,
42        /// Phantom data to associate the lifetime `'a` with the struct.
43        pub(crate) phantom: std::marker::PhantomData<&'a ()>,
44    }
45
46    impl<'a, A, B> IterGetSetSimd for ParStridedZipSimd<'a, A, B>
47    where
48        A: IterGetSetSimd,
49        B: IterGetSetSimd,
50    {
51        type Item = (<A as IterGetSetSimd>::Item, <B as IterGetSetSimd>::Item);
52
53        type SimdItem = (
54            <A as IterGetSetSimd>::SimdItem,
55            <B as IterGetSetSimd>::SimdItem,
56        );
57
58        fn set_end_index(&mut self, end_index: usize) {
59            self.a.set_end_index(end_index);
60            self.b.set_end_index(end_index);
61        }
62
63        fn set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>) {
64            self.a.set_intervals(intervals.clone());
65            self.b.set_intervals(intervals);
66        }
67
68        fn set_strides(&mut self, last_stride: Strides) {
69            self.a.set_strides(last_stride.clone());
70            self.b.set_strides(last_stride);
71        }
72
73        fn set_shape(&mut self, shape: Shape) {
74            self.a.set_shape(shape.clone());
75            self.b.set_shape(shape);
76        }
77
78        fn set_prg(&mut self, prg: Vec<i64>) {
79            self.a.set_prg(prg.clone());
80            self.b.set_prg(prg);
81        }
82
83        fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
84            self.a.intervals()
85        }
86
87        fn strides(&self) -> &Strides {
88            self.a.strides()
89        }
90
91        fn shape(&self) -> &Shape {
92            self.a.shape()
93        }
94
95        fn layout(&self) -> &hpt_common::layout::layout::Layout {
96            self.a.layout()
97        }
98
99        fn broadcast_set_strides(&mut self, shape: &Shape) {
100            self.a.broadcast_set_strides(shape);
101            self.b.broadcast_set_strides(shape);
102        }
103
104        fn outer_loop_size(&self) -> usize {
105            self.a.outer_loop_size()
106        }
107
108        fn inner_loop_size(&self) -> usize {
109            self.a.inner_loop_size()
110        }
111
112        fn next(&mut self) {
113            self.a.next();
114            self.b.next();
115        }
116
117        fn next_simd(&mut self) {
118            self.a.next_simd();
119            self.b.next_simd();
120        }
121
122        fn inner_loop_next(&mut self, index: usize) -> Self::Item {
123            (self.a.inner_loop_next(index), self.b.inner_loop_next(index))
124        }
125
126        fn inner_loop_next_simd(&mut self, index: usize) -> Self::SimdItem {
127            (
128                self.a.inner_loop_next_simd(index),
129                self.b.inner_loop_next_simd(index),
130            )
131        }
132
133        fn all_last_stride_one(&self) -> bool {
134            self.a.all_last_stride_one() && self.b.all_last_stride_one()
135        }
136
137        fn lanes(&self) -> Option<usize> {
138            match (self.a.lanes(), self.b.lanes()) {
139                (Some(a), Some(b)) => {
140                    if a == b {
141                        Some(a)
142                    } else {
143                        None
144                    }
145                }
146                _ => None,
147            }
148        }
149    }
150    impl<'a, A, B> ParStridedZipSimd<'a, A, B>
151    where
152        A: UnindexedProducer + 'a + IterGetSetSimd + ParallelIterator,
153        B: UnindexedProducer + 'a + IterGetSetSimd + ParallelIterator,
154        <A as IterGetSetSimd>::Item: Send,
155        <B as IterGetSetSimd>::Item: Send,
156    {
157        /// Creates a new `ParStridedZipSimd` instance by zipping two SIMD-optimized iterators.
158        ///
159        /// # Arguments
160        ///
161        /// * `a` - The first iterator to zip.
162        /// * `b` - The second iterator to zip.
163        ///
164        /// # Returns
165        ///
166        /// A new `ParStridedZipSimd` instance that combines both iterators for synchronized parallel iteration.
167        pub fn new(a: A, b: B) -> Self {
168            ParStridedZipSimd {
169                a,
170                b,
171                phantom: std::marker::PhantomData,
172            }
173        }
174        /// Transforms the zipped iterators by applying provided functions to their items.
175        ///
176        /// This method allows for element-wise operations on the zipped iterators by applying `func` to each pair of items.
177        ///
178        /// # Arguments
179        ///
180        /// * `func` - A function that takes a mutable reference to a target and an item from the first iterator.
181        /// * `func2` - A function that takes a mutable reference to a SIMD vector and an item from the second iterator.
182        ///
183        /// # Returns
184        ///
185        /// A `ParStridedMapSimd` instance that applies the provided functions during iteration.
186        ///
187        pub fn strided_map_simd<F, F2, T>(
188            self,
189            func: F,
190            func2: F2,
191        ) -> ParStridedMapSimd<'a, Self, <Self as IterGetSetSimd>::Item, F, F2>
192        where
193            F: Fn((&mut T, <Self as IterGetSetSimd>::Item)) + Sync + Send + 'a,
194            F2: Fn((MutVec<'_, T::Vec>, <Self as IterGetSetSimd>::SimdItem)) + Sync + Send + 'a,
195            T: CommonBounds,
196            <A as IterGetSetSimd>::Item: Send,
197            <B as IterGetSetSimd>::Item: Send,
198            T::Vec: Send,
199            A: ShapeManipulator,
200            B: ShapeManipulator,
201        {
202            ParStridedMapSimd {
203                iter: self,
204                f: func,
205                f2: func2,
206                phantom: std::marker::PhantomData,
207            }
208        }
209    }
210
211    impl<'a, A, B> ParStridedIteratorSimdZip for ParStridedZipSimd<'a, A, B>
212    where
213        A: UnindexedProducer + ParallelIterator + IterGetSetSimd,
214        B: UnindexedProducer + ParallelIterator + IterGetSetSimd,
215    {
216    }
217    impl<'a, A, B> ParStridedIteratorSimd for ParStridedZipSimd<'a, A, B>
218    where
219        A: UnindexedProducer + ParallelIterator + IterGetSetSimd,
220        B: UnindexedProducer + ParallelIterator + IterGetSetSimd,
221        <A as IterGetSetSimd>::Item: Send,
222        <B as IterGetSetSimd>::Item: Send,
223    {
224    }
225
226    impl<'a, A, B> UnindexedProducer for ParStridedZipSimd<'a, A, B>
227    where
228        A: UnindexedProducer + ParallelIterator + IterGetSetSimd,
229        B: UnindexedProducer + ParallelIterator + IterGetSetSimd,
230    {
231        type Item = <Self as IterGetSetSimd>::Item;
232
233        fn split(self) -> (Self, Option<Self>) {
234            let (left_a, right_a) = self.a.split();
235            let (left_b, right_b) = self.b.split();
236            if right_a.is_none() {
237                (
238                    ParStridedZipSimd {
239                        a: left_a,
240                        b: left_b,
241                        phantom: std::marker::PhantomData,
242                    },
243                    None,
244                )
245            } else {
246                (
247                    ParStridedZipSimd {
248                        a: left_a,
249                        b: left_b,
250                        phantom: std::marker::PhantomData,
251                    },
252                    Some(ParStridedZipSimd {
253                        a: right_a.unwrap(),
254                        b: right_b.unwrap(),
255                        phantom: std::marker::PhantomData,
256                    }),
257                )
258            }
259        }
260
261        fn fold_with<F>(mut self, mut folder: F) -> F
262        where
263            F: Folder<Self::Item>,
264        {
265            let outer_loop_size = self.outer_loop_size();
266            let inner_loop_size = self.inner_loop_size() + 1;
267            for _ in 0..outer_loop_size {
268                for idx in 0..inner_loop_size {
269                    folder = folder.consume(self.inner_loop_next(idx));
270                }
271                self.next();
272            }
273            folder
274        }
275    }
276
277    impl<'a, A, B> ParallelIterator for ParStridedZipSimd<'a, A, B>
278    where
279        A: UnindexedProducer + ParallelIterator + IterGetSetSimd,
280        B: UnindexedProducer + ParallelIterator + IterGetSetSimd,
281        <A as IterGetSetSimd>::Item: Send,
282        <B as IterGetSetSimd>::Item: Send,
283    {
284        type Item = (<A as IterGetSetSimd>::Item, <B as IterGetSetSimd>::Item);
285
286        fn drive_unindexed<C>(self, consumer: C) -> C::Result
287        where
288            C: UnindexedConsumer<Self::Item>,
289        {
290            bridge_unindexed(self, consumer)
291        }
292    }
293}
294
/// A parallel iterator that zips two strided iterators over tensor elements.
///
/// The two inner iterators are advanced together, and each step yields the pair
/// `(item_from_a, item_from_b)`.
///
/// # Type Parameters
///
/// * `'a` - The lifetime associated with the iterators.
/// * `A` - The type of the first iterator.
/// * `B` - The type of the second iterator.
#[derive(Clone)]
pub struct ParStridedZip<'a, A: 'a, B: 'a> {
    /// The first iterator to be zipped.
    pub(crate) a: A,
    /// The second iterator to be zipped.
    pub(crate) b: B,
    /// Phantom data to associate the lifetime `'a` with the struct.
    pub(crate) phantom: std::marker::PhantomData<&'a ()>,
}
313
314impl<'a, A, B> IterGetSet for ParStridedZip<'a, A, B>
315where
316    A: IterGetSet,
317    B: IterGetSet,
318{
319    type Item = (<A as IterGetSet>::Item, <B as IterGetSet>::Item);
320
321    fn set_end_index(&mut self, end_index: usize) {
322        self.a.set_end_index(end_index);
323        self.b.set_end_index(end_index);
324    }
325
326    fn set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>) {
327        self.a.set_intervals(intervals.clone());
328        self.b.set_intervals(intervals);
329    }
330
331    fn set_strides(&mut self, last_stride: Strides) {
332        self.a.set_strides(last_stride.clone());
333        self.b.set_strides(last_stride);
334    }
335
336    fn set_shape(&mut self, shape: Shape) {
337        self.a.set_shape(shape.clone());
338        self.b.set_shape(shape);
339    }
340
341    fn set_prg(&mut self, prg: Vec<i64>) {
342        self.a.set_prg(prg.clone());
343        self.b.set_prg(prg);
344    }
345
346    fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
347        self.a.intervals()
348    }
349
350    fn strides(&self) -> &Strides {
351        self.a.strides()
352    }
353
354    fn shape(&self) -> &Shape {
355        self.a.shape()
356    }
357
358    fn layout(&self) -> &hpt_common::layout::layout::Layout {
359        self.a.layout()
360    }
361
362    fn broadcast_set_strides(&mut self, shape: &Shape) {
363        self.a.broadcast_set_strides(shape);
364        self.b.broadcast_set_strides(shape);
365    }
366
367    fn outer_loop_size(&self) -> usize {
368        self.a.outer_loop_size()
369    }
370
371    fn inner_loop_size(&self) -> usize {
372        self.a.inner_loop_size()
373    }
374
375    fn next(&mut self) {
376        self.a.next();
377        self.b.next();
378    }
379
380    fn inner_loop_next(&mut self, index: usize) -> Self::Item {
381        (self.a.inner_loop_next(index), self.b.inner_loop_next(index))
382    }
383}
384
385impl<'a, A, B> ParStridedZip<'a, A, B>
386where
387    A: UnindexedProducer + 'a + IterGetSet + ParallelIterator,
388    B: UnindexedProducer + 'a + IterGetSet + ParallelIterator,
389    <A as IterGetSet>::Item: Send,
390    <B as IterGetSet>::Item: Send,
391{
392    /// Creates a new `ParStridedZip` instance by zipping two iterators.
393    ///
394    /// # Arguments
395    ///
396    /// * `a` - The first iterator to zip.
397    /// * `b` - The second iterator to zip.
398    ///
399    /// # Returns
400    ///
401    /// A new `ParStridedZip` instance that combines both iterators for synchronized parallel iteration.
402    pub fn new(a: A, b: B) -> Self {
403        ParStridedZip {
404            a,
405            b,
406            phantom: std::marker::PhantomData,
407        }
408    }
409    /// Transforms the zipped iterators by applying a provided function to their items.
410    ///
411    /// This method allows for element-wise operations on the zipped iterators by applying `func` to each item.
412    ///
413    /// # Arguments
414    ///
415    /// * `func` - A function that takes an item from the zipped iterator and returns a transformed value.
416    ///
417    /// # Returns
418    ///
419    /// A `ParStridedMap` instance that applies the provided function during iteration.
420    pub fn strided_map<F, T>(
421        self,
422        func: F,
423    ) -> ParStridedMap<'a, Self, <Self as IterGetSet>::Item, F>
424    where
425        F: Fn((&mut T, <Self as IterGetSet>::Item)) + Sync + Send,
426        T: CommonBounds,
427    {
428        ParStridedMap {
429            iter: self,
430            f: func,
431            phantom: std::marker::PhantomData,
432        }
433    }
434}
435
436impl<'a, A, B> UnindexedProducer for ParStridedZip<'a, A, B>
437where
438    A: UnindexedProducer + ParallelIterator + IterGetSet,
439    B: UnindexedProducer + ParallelIterator + IterGetSet,
440{
441    type Item = <Self as IterGetSet>::Item;
442
443    fn split(self) -> (Self, Option<Self>) {
444        let (left_a, right_a) = self.a.split();
445        let (left_b, right_b) = self.b.split();
446        if right_a.is_none() {
447            (
448                ParStridedZip {
449                    a: left_a,
450                    b: left_b,
451                    phantom: std::marker::PhantomData,
452                },
453                None,
454            )
455        } else {
456            (
457                ParStridedZip {
458                    a: left_a,
459                    b: left_b,
460                    phantom: std::marker::PhantomData,
461                },
462                Some(ParStridedZip {
463                    a: right_a.unwrap(),
464                    b: right_b.unwrap(),
465                    phantom: std::marker::PhantomData,
466                }),
467            )
468        }
469    }
470
471    fn fold_with<F>(mut self, mut folder: F) -> F
472    where
473        F: Folder<Self::Item>,
474    {
475        let outer_loop_size = self.outer_loop_size();
476        let inner_loop_size = self.inner_loop_size() + 1;
477        for _ in 0..outer_loop_size {
478            for idx in 0..inner_loop_size {
479                folder = folder.consume(self.inner_loop_next(idx));
480            }
481            self.next();
482        }
483        folder
484    }
485}
486
487impl<'a, A, B> ParallelIterator for ParStridedZip<'a, A, B>
488where
489    A: UnindexedProducer + ParallelIterator + IterGetSet,
490    B: UnindexedProducer + ParallelIterator + IterGetSet,
491    <A as IterGetSet>::Item: Send,
492    <B as IterGetSet>::Item: Send,
493{
494    type Item = (<A as IterGetSet>::Item, <B as IterGetSet>::Item);
495
496    fn drive_unindexed<C>(self, consumer: C) -> C::Result
497    where
498        C: UnindexedConsumer<Self::Item>,
499    {
500        bridge_unindexed(self, consumer)
501    }
502}
503
504impl<'a, A, B> ParStridedIteratorZip for ParStridedZip<'a, A, B>
505where
506    A: UnindexedProducer + ParallelIterator + IterGetSet,
507    B: UnindexedProducer + ParallelIterator + IterGetSet,
508    <A as IterGetSet>::Item: Send,
509    <B as IterGetSet>::Item: Send,
510{
511}
512
513impl<'a, A, B> ShapeManipulator for ParStridedZip<'a, A, B>
514where
515    A: UnindexedProducer + 'a + IterGetSet + ParallelIterator + ShapeManipulator,
516    B: UnindexedProducer + 'a + IterGetSet + ParallelIterator + ShapeManipulator,
517    <A as IterGetSet>::Item: Send,
518    <B as IterGetSet>::Item: Send,
519{
520    fn reshape<S: Into<Shape>>(self, shape: S) -> Self {
521        let tmp: Shape = shape.into();
522        let a = self.a.reshape(tmp.clone());
523        let b = self.b.reshape(tmp);
524        ParStridedZip::new(a, b)
525    }
526
527    fn transpose<AXIS: Into<hpt_common::axis::axis::Axis>>(self, axes: AXIS) -> Self {
528        let axes: hpt_common::axis::axis::Axis = axes.into();
529        let a = self.a.transpose(axes.clone());
530        let b = self.b.transpose(axes);
531        ParStridedZip::new(a, b)
532    }
533
534    fn expand<S: Into<Shape>>(self, shape: S) -> Self {
535        let tmp: Shape = shape.into();
536        let a = self.a.expand(tmp.clone());
537        let b = self.b.expand(tmp);
538        ParStridedZip::new(a, b)
539    }
540}
541
542impl<'a, A, B> ShapeManipulator for ParStridedZipSimd<'a, A, B>
543where
544    A: UnindexedProducer + 'a + IterGetSetSimd + ParallelIterator + ShapeManipulator,
545    B: UnindexedProducer + 'a + IterGetSetSimd + ParallelIterator + ShapeManipulator,
546    <A as IterGetSetSimd>::Item: Send,
547    <B as IterGetSetSimd>::Item: Send,
548{
549    fn reshape<S: Into<Shape>>(self, shape: S) -> Self {
550        let tmp: Shape = shape.into();
551        let a = self.a.reshape(tmp.clone());
552        let b = self.b.reshape(tmp);
553        ParStridedZipSimd::new(a, b)
554    }
555
556    fn transpose<AXIS: Into<hpt_common::axis::axis::Axis>>(self, axes: AXIS) -> Self {
557        let axes: hpt_common::axis::axis::Axis = axes.into();
558        let a = self.a.transpose(axes.clone());
559        let b = self.b.transpose(axes);
560        ParStridedZipSimd::new(a, b)
561    }
562
563    fn expand<S: Into<Shape>>(self, shape: S) -> Self {
564        let tmp: Shape = shape.into();
565        let a = self.a.expand(tmp.clone());
566        let b = self.b.expand(tmp);
567        ParStridedZipSimd::new(a, b)
568    }
569}