// hpt_iterator/iterator_traits.rs

1use std::sync::Arc;
2
3use hpt_common::{
4    axis::axis::Axis,
5    layout::layout::Layout,
6    shape::shape::Shape,
7    shape::shape_utils::{mt_intervals, predict_broadcast_shape},
8    strides::strides::Strides,
9};
10use hpt_traits::CommonBounds;
11use rayon::iter::{plumbing::UnindexedProducer, ParallelIterator};
12
13use crate::{
14    par_strided_zip::{par_strided_zip_simd::ParStridedZipSimd, ParStridedZip},
15    strided_map::StridedMap,
16    strided_zip::{strided_zip_simd::StridedZipSimd, StridedZip},
17    with_simd::WithSimd,
18};
19
20/// A trait for getting and setting values from an iterator.
21pub trait IterGetSet {
22    /// The type of the iterator's elements.
23    type Item;
24    /// set the end index of the iterator, this is used when rayon perform data splitting
25    fn set_end_index(&mut self, end_index: usize);
26    /// set the chunk intervals of the iterator, we chunk the outer loop
27    fn set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>);
28    /// set the strides for the iterator, we call this method normally when we do broadcasting
29    fn set_strides(&mut self, strides: Strides);
30    /// set the shape for the iterator, we call this method normally when we do broadcasting
31    fn set_shape(&mut self, shape: Shape);
32    /// set the loop progress for the iterator
33    fn set_prg(&mut self, prg: Vec<i64>);
34    /// get the intervals of the iterator
35    fn intervals(&self) -> &Arc<Vec<(usize, usize)>>;
36    /// get the strides of the iterator
37    fn strides(&self) -> &Strides;
38    /// get the shape of the iterator
39    fn shape(&self) -> &Shape;
40    /// get the layout of the iterator
41    fn layout(&self) -> &Layout;
42    /// set the strides for all the iterators
43    fn broadcast_set_strides(&mut self, shape: &Shape);
44    /// get the outer loop size
45    fn outer_loop_size(&self) -> usize;
46    /// get the inner loop size
47    fn inner_loop_size(&self) -> usize;
48    /// update the loop progress
49    fn next(&mut self);
50    /// get the next element of the inner loop
51    fn inner_loop_next(&mut self, index: usize) -> Self::Item;
52}
53
54/// A trait for getting and setting values from an simd iterator
55pub trait IterGetSetSimd {
56    /// The type of the iterator's elements.
57    type Item;
58    /// The type of the iterator's simd elements.
59    type SimdItem;
60    /// set the end index of the iterator, this is used when rayon perform data splitting
61    fn set_end_index(&mut self, end_index: usize);
62    /// set the chunk intervals of the iterator, we chunk the outer loop
63    fn set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>);
64    /// set the strides for the iterator, we call this method normally when we do broadcasting
65    fn set_strides(&mut self, last_stride: Strides);
66    /// set the shape for the iterator, we call this method normally when we do broadcasting
67    fn set_shape(&mut self, shape: Shape);
68    /// set the loop progress for the iterator
69    fn set_prg(&mut self, prg: Vec<i64>);
70    /// get the intervals of the iterator
71    fn intervals(&self) -> &Arc<Vec<(usize, usize)>>;
72    /// get the strides of the iterator
73    fn strides(&self) -> &Strides;
74    /// get the shape of the iterator
75    fn shape(&self) -> &Shape;
76    /// get the layout of the iterator
77    fn layout(&self) -> &Layout;
78    /// set the strides for all the iterators
79    fn broadcast_set_strides(&mut self, shape: &Shape);
80    /// get the outer loop size
81    fn outer_loop_size(&self) -> usize;
82    /// get the inner loop size
83    fn inner_loop_size(&self) -> usize;
84    /// update the loop progress, this is called when we don't do simd iteration
85    fn next(&mut self);
86    /// update the loop progress, this is called when we do simd iteration
87    fn next_simd(&mut self);
88    /// get the next element of the inner loop
89    fn inner_loop_next(&mut self, index: usize) -> Self::Item;
90    /// get the next vector of the inner loop, this is called when we do simd iteration
91    fn inner_loop_next_simd(&mut self, index: usize) -> Self::SimdItem;
92    /// check if all iterators' last stride is one, only when all iterators' last stride is one, we can do simd iteration
93    fn all_last_stride_one(&self) -> bool;
94    /// get the simd vector size, if any of the iterator returned different vector size, it will return None
95    fn lanes(&self) -> Option<usize>;
96}
97
98/// A trait for performing shape manipulation on an iterator.
99pub trait ShapeManipulator {
100    /// reshape the iterator, we can change the iteration behavior by changing the shape
101    fn reshape<S: Into<Shape>>(self, shape: S) -> Self;
102    /// transpose the iterator, we can change the iteration behavior by changing the axes
103    fn transpose<AXIS: Into<Axis>>(self, axes: AXIS) -> Self;
104    /// expand the iterator, we can change the iteration behavior by changing the shape
105    fn expand<S: Into<Shape>>(self, shape: S) -> Self;
106}
107
108/// A trait for performing single thread iteration over an iterator.
109pub trait StridedIterator: IterGetSet
110where
111    Self: Sized,
112{
113    /// perform scalar iteration, this method is for single thread iterator
114    fn for_each<F>(mut self, func: F)
115    where
116        F: Fn(Self::Item),
117    {
118        let outer_loop_size = self.outer_loop_size();
119        let inner_loop_size = self.inner_loop_size(); // we don't need to add 1 as we didn't subtract shape by 1
120        self.set_prg(vec![0; self.shape().len()]);
121        for _ in 0..outer_loop_size {
122            for idx in 0..inner_loop_size {
123                func(self.inner_loop_next(idx));
124            }
125            self.next();
126        }
127    }
128    /// perform scalar iteration with init, this method is for single thread iterator
129    fn for_each_init<F, INIT, T>(mut self, init: INIT, func: F)
130    where
131        F: Fn(&mut T, Self::Item),
132        INIT: Fn() -> T,
133    {
134        let outer_loop_size = self.outer_loop_size();
135        let inner_loop_size = self.inner_loop_size();
136        self.set_prg(vec![0; self.shape().len()]);
137        let mut init = init();
138        for _ in 0..outer_loop_size {
139            for idx in 0..inner_loop_size {
140                func(&mut init, self.inner_loop_next(idx));
141            }
142            self.next();
143        }
144    }
145}
146
147/// A trait to zip two iterators together.
148pub trait StridedIteratorZip: Sized {
149    /// Combines this iterator with another iterator, enabling simultaneous iteration.
150    ///
151    /// This method zips together `self` and `other` into a `StridedZip` iterator, allowing for synchronized
152    ///
153    /// iteration over both iterators. This is particularly useful for operations that require processing
154    ///
155    /// elements from two tensors in parallel, such as element-wise arithmetic operations.
156    ///
157    /// # Arguments
158    ///
159    /// * `other` - The other iterator to zip with. It must implement the `IterGetSet` trait, and
160    ///             its associated `Item` type must be `Send`.
161    ///
162    /// # Returns
163    ///
164    /// A `StridedZip` instance that encapsulates both `self` and `other`, allowing for synchronized
165    ///
166    /// iteration over their elements.
167    ///
168    /// # Panics
169    ///
170    /// This method will panic if the shapes of `self` and `other` cannot be broadcasted together.
171    #[track_caller]
172    fn zip<'a, C>(self, other: C) -> StridedZip<'a, Self, C>
173    where
174        C: IterGetSet + ShapeManipulator,
175        Self: IterGetSet + ShapeManipulator,
176        <C as IterGetSet>::Item: Send,
177        <Self as IterGetSet>::Item: Send,
178    {
179        let new_shape = predict_broadcast_shape(&self.shape(), &other.shape())
180            .expect("Cannot broadcast shapes");
181
182        let mut a = self.reshape(new_shape.clone());
183        let mut b = other.reshape(new_shape.clone());
184
185        a.set_shape(new_shape.clone());
186        b.set_shape(new_shape.clone());
187        StridedZip::new(a, b)
188    }
189}
190
191/// A trait to zip two parallel iterators together.
192pub trait ParStridedIteratorZip: Sized + IterGetSet {
193    /// Combines this iterator with another iterator, enabling simultaneous parallel iteration.
194    ///
195    /// This method performs shape broadcasting between `self` and `other` to ensure that both iterators
196    /// iterate over tensors with compatible shapes. It adjusts the strides and shapes of both iterators
197    /// to match the broadcasted shape and then returns a `ParStridedZip` that allows for synchronized
198    /// parallel iteration over both iterators.
199    ///
200    /// # Arguments
201    ///
202    /// * `other` - The other iterator to zip with. It must implement the `IterGetSet`, `UnindexedProducer`,
203    ///             and `ParallelIterator` traits, and its associated `Item` type must be `Send`.
204    ///
205    /// # Returns
206    ///
207    /// A `ParStridedZip` instance that zips together `self` and `other`, enabling synchronized
208    /// parallel iteration over their elements.
209    ///
210    /// # Panics
211    ///
212    /// This method will panic if the shapes of `self` and `other` cannot be broadcasted together.
213    /// Ensure that the shapes are compatible before calling this method.
214    #[track_caller]
215    fn zip<'a, C>(mut self, mut other: C) -> ParStridedZip<'a, Self, C>
216    where
217        C: UnindexedProducer + 'a + IterGetSet + ParallelIterator + ShapeManipulator,
218        <C as IterGetSet>::Item: Send,
219        Self: UnindexedProducer + ParallelIterator + ShapeManipulator,
220        <Self as IterGetSet>::Item: Send,
221    {
222        let new_shape = predict_broadcast_shape(&self.shape(), &other.shape())
223            .expect("Cannot broadcast shapes");
224
225        let inner_loop_size = new_shape[new_shape.len() - 1] as usize;
226        let outer_loop_size = (new_shape.size() as usize) / inner_loop_size;
227
228        let num_threads;
229        if outer_loop_size < rayon::current_num_threads() {
230            num_threads = outer_loop_size;
231        } else {
232            num_threads = rayon::current_num_threads();
233        }
234        let intervals = Arc::new(mt_intervals(outer_loop_size, num_threads));
235        let len = intervals.len();
236        self.set_intervals(intervals.clone());
237        self.set_end_index(len);
238        other.set_intervals(intervals.clone());
239        other.set_end_index(len);
240
241        let mut a = self.reshape(new_shape.clone());
242        let mut b = other.reshape(new_shape.clone());
243
244        a.set_shape(new_shape.clone());
245        b.set_shape(new_shape.clone());
246
247        ParStridedZip::new(a, b)
248    }
249}
250
251/// A trait to zip two parallel iterators together.
252pub trait ParStridedIteratorSimdZip: Sized + IterGetSetSimd {
253    /// Combines this `ParStridedZipSimd` iterator with another SIMD-optimized iterator, enabling simultaneous parallel iteration.
254    ///
255    /// This method performs shape broadcasting between `self` and `other` to ensure that both iterators
256    /// iterate over tensors with compatible shapes. It calculates the appropriate iteration intervals based
257    /// on the new broadcasted shape and configures both iterators accordingly. Finally, it returns a new
258    /// `ParStridedZipSimd` instance that allows for synchronized parallel iteration over the combined iterators.
259    ///
260    /// # Arguments
261    ///
262    /// * `other` - The third iterator to zip with. It must implement the `IterGetSetSimd`, `UnindexedProducer`,
263    ///             `ShapeManipulator`, and `ParallelIterator` traits,
264    ///             and its associated `Item` type must be `Send`.
265    ///
266    /// # Returns
267    ///
268    /// A new `ParStridedZipSimd` instance that combines `self` and `other` for synchronized parallel iteration over all three iterators.
269    ///
270    /// # Panics
271    ///
272    /// This method will panic if the shapes of `self` and `other` cannot be broadcasted together.
273    /// Ensure that the shapes are compatible before calling this method.
274    #[track_caller]
275    fn zip<'a, C>(mut self, mut other: C) -> ParStridedZipSimd<'a, Self, C>
276    where
277        C: UnindexedProducer + 'a + IterGetSetSimd + ParallelIterator + ShapeManipulator,
278        <C as IterGetSetSimd>::Item: Send,
279        Self: UnindexedProducer + ParallelIterator + ShapeManipulator,
280        <Self as IterGetSetSimd>::Item: Send,
281    {
282        let new_shape = predict_broadcast_shape(&self.shape(), &other.shape())
283            .expect("Cannot broadcast shapes");
284
285        let inner_loop_size = new_shape[new_shape.len() - 1] as usize;
286        let outer_loop_size = (new_shape.size() as usize) / inner_loop_size;
287
288        let num_threads;
289        if outer_loop_size < rayon::current_num_threads() {
290            num_threads = outer_loop_size;
291        } else {
292            num_threads = rayon::current_num_threads();
293        }
294        let intervals = Arc::new(mt_intervals(outer_loop_size, num_threads));
295        let len = intervals.len();
296        self.set_intervals(intervals.clone());
297        self.set_end_index(len);
298        other.set_intervals(intervals.clone());
299        other.set_end_index(len);
300
301        let mut a = self.reshape(new_shape.clone());
302        let mut b = other.reshape(new_shape.clone());
303
304        a.set_shape(new_shape.clone());
305        b.set_shape(new_shape.clone());
306
307        ParStridedZipSimd::new(a, b)
308    }
309}
310
311/// A trait to zip two simd iterators together.
312pub trait StridedSimdIteratorZip: Sized {
313    /// Combines this iterator with another SIMD-optimized iterator, enabling simultaneous iteration.
314    ///
315    /// This method performs shape broadcasting between `self` and `other` to ensure that both iterators
316    /// iterate over tensors with compatible shapes. It adjusts the strides and shapes of both iterators
317    /// to match the broadcasted shape and then returns a `StridedZipSimd` that allows for synchronized
318    /// iteration over both iterators.
319    ///
320    /// # Arguments
321    ///
322    /// * `other` - The other iterator to zip with. It must implement the `IterGetSetSimd`, `UnindexedProducer`,
323    ///             and `ParallelIterator` traits, and its associated `Item` type must be `Send`.
324    ///
325    /// # Returns
326    ///
327    /// A `StridedZipSimd` instance that zips together `self` and `other`, enabling synchronized
328    /// iteration over their elements.
329    #[track_caller]
330    fn zip<'a, C>(mut self, mut other: C) -> StridedZipSimd<'a, Self, C>
331    where
332        C: 'a + IterGetSetSimd,
333        <C as IterGetSetSimd>::Item: Send,
334        Self: IterGetSetSimd,
335        <Self as IterGetSetSimd>::Item: Send,
336    {
337        let new_shape =
338            predict_broadcast_shape(self.shape(), other.shape()).expect("Cannot broadcast shapes");
339
340        other.broadcast_set_strides(&new_shape);
341        self.broadcast_set_strides(&new_shape);
342
343        other.set_shape(new_shape.clone());
344        self.set_shape(new_shape.clone());
345
346        StridedZipSimd::new(self, other)
347    }
348}
349
350/// A trait for performing single thread simd iteration over an iterator.
351pub trait StridedIteratorSimd
352where
353    Self: Sized + IterGetSetSimd,
354{
355    /// perform simd iteration, this method is for single thread simd iterator
356    fn for_each<F, F2>(mut self, op: F, vec_op: F2)
357    where
358        F: Fn(Self::Item),
359        F2: Fn(Self::SimdItem),
360    {
361        let outer_loop_size = self.outer_loop_size();
362        let inner_loop_size = self.inner_loop_size(); // we don't need to add 1 as we didn't subtract shape by 1
363        self.set_prg(vec![0; self.shape().len()]);
364        match (self.all_last_stride_one(), self.lanes()) {
365            (true, Some(vec_size)) => {
366                let remain = inner_loop_size % vec_size;
367                let inner = inner_loop_size - remain;
368                let n = inner / vec_size;
369                let unroll = n % 4;
370                if remain > 0 {
371                    if unroll == 0 {
372                        for _ in 0..outer_loop_size {
373                            for idx in 0..n / 4 {
374                                vec_op(self.inner_loop_next_simd(idx * 4));
375                                vec_op(self.inner_loop_next_simd(idx * 4 + 1));
376                                vec_op(self.inner_loop_next_simd(idx * 4 + 2));
377                                vec_op(self.inner_loop_next_simd(idx * 4 + 3));
378                            }
379                            for idx in inner..inner_loop_size {
380                                op(self.inner_loop_next(idx));
381                            }
382                            self.next();
383                        }
384                    } else {
385                        for _ in 0..outer_loop_size {
386                            for idx in 0..n {
387                                vec_op(self.inner_loop_next_simd(idx));
388                            }
389                            for idx in inner..inner_loop_size {
390                                op(self.inner_loop_next(idx));
391                            }
392                            self.next();
393                        }
394                    }
395                } else {
396                    if unroll == 0 {
397                        for _ in 0..outer_loop_size {
398                            for idx in 0..n / 4 {
399                                vec_op(self.inner_loop_next_simd(idx * 4));
400                                vec_op(self.inner_loop_next_simd(idx * 4 + 1));
401                                vec_op(self.inner_loop_next_simd(idx * 4 + 2));
402                                vec_op(self.inner_loop_next_simd(idx * 4 + 3));
403                            }
404                            self.next();
405                        }
406                    } else {
407                        for _ in 0..outer_loop_size {
408                            for idx in 0..n {
409                                vec_op(self.inner_loop_next_simd(idx));
410                            }
411                            self.next();
412                        }
413                    }
414                }
415            }
416            _ => {
417                for _ in 0..outer_loop_size {
418                    for idx in 0..inner_loop_size {
419                        op(self.inner_loop_next(idx));
420                    }
421                    self.next();
422                }
423            }
424        }
425    }
426    /// perform simd iteration with init, this method is for single thread simd iterator
427    fn for_each_init<F, INIT, T>(mut self, init: INIT, func: F)
428    where
429        F: Fn(&mut T, Self::Item),
430        INIT: Fn() -> T,
431    {
432        let outer_loop_size = self.outer_loop_size();
433        let inner_loop_size = self.inner_loop_size();
434        let mut init = init();
435        for _ in 0..outer_loop_size {
436            for idx in 0..inner_loop_size {
437                func(&mut init, self.inner_loop_next(idx));
438            }
439            self.next();
440        }
441    }
442}
443
444/// A trait for performing single thread simd iteration over an iterator.
445pub trait ParStridedIteratorSimd
446where
447    Self: Sized + UnindexedProducer + IterGetSetSimd + ParallelIterator,
448{
449    /// perform simd iteration, this method is for single thread simd iterator
450    fn for_each<F, F2>(self, op: F, vec_op: F2)
451    where
452        F: Fn(<Self as IterGetSetSimd>::Item) + Sync,
453        F2: Fn(<Self as IterGetSetSimd>::SimdItem) + Sync + Send + Copy,
454        <Self as IterGetSetSimd>::SimdItem: Send,
455        <Self as IterGetSetSimd>::Item: Send,
456    {
457        let with_simd = WithSimd { base: self, vec_op };
458        with_simd.for_each(|x| {
459            op(x);
460        });
461    }
462}
463
464/// A trait to map a function on the elements of an iterator.
465pub trait StridedIteratorMap: Sized {
466    /// Transforms the strided iterators by applying a provided function to their items.
467    ///
468    /// This method allows for element-wise operations on the zipped iterators by applying `func` to each item.
469    ///
470    /// # Type Parameters
471    ///
472    /// * `'a` - The lifetime associated with the iterators.
473    /// * `F` - The function to apply to each item.
474    /// * `U` - The output type after applying the function.
475    ///
476    /// # Arguments
477    ///
478    /// * `f` - A function that takes an item from the zipped iterator and returns a transformed value.
479    ///
480    /// # Returns
481    ///
482    /// A `StridedMap` instance that applies the provided function during iteration.
483    fn map<'a, T, F, U>(self, f: F) -> StridedMap<'a, Self, T, F>
484    where
485        F: Fn(T) -> U + Sync + Send + 'a,
486        U: CommonBounds,
487        Self: IterGetSet<Item = T>,
488    {
489        StridedMap {
490            iter: self,
491            f,
492            phantom: std::marker::PhantomData,
493        }
494    }
495}
496
497pub(crate) trait StridedHelper {
498    fn _set_last_strides(&mut self, stride: i64);
499    fn _set_strides(&mut self, strides: Strides);
500    fn _set_shape(&mut self, shape: Shape);
501    fn _layout(&self) -> &Layout;
502}
503
504pub(crate) trait ParStridedHelper {
505    fn _set_last_strides(&mut self, stride: i64);
506    fn _set_strides(&mut self, strides: Strides);
507    fn _set_shape(&mut self, shape: Shape);
508    fn _layout(&self) -> &Layout;
509    fn _set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>);
510    fn _set_end_index(&mut self, end_index: usize);
511}