//! `hpt_iterator/par_strided.rs` — parallel strided iterators over tensor elements.

1use crate::{
2    iterator_traits::{IterGetSet, ParStridedHelper, ParStridedIteratorZip, ShapeManipulator},
3    par_strided_fold::ParStridedFold,
4    par_strided_map::ParStridedMap,
5    shape_manipulate::{par_expand, par_reshape, par_transpose},
6};
7use hpt_common::{
8    axis::axis::Axis,
9    layout::layout::Layout,
10    shape::shape::Shape,
11    shape::shape_utils::{mt_intervals, try_pad_shape},
12    strides::strides::Strides,
13    strides::strides_utils::preprocess_strides,
14    utils::pointer::Pointer,
15};
16use hpt_traits::tensor::{CommonBounds, TensorInfo};
17use rayon::iter::{
18    plumbing::{bridge_unindexed, Folder, UnindexedConsumer, UnindexedProducer},
19    ParallelIterator,
20};
21use std::sync::Arc;
22
23/// A module for parallel strided iterators.
24pub mod par_strided_simd {
25    use hpt_types::vectors::traits::VecTrait;
26    use std::sync::Arc;
27
28    use crate::{CommonBounds, TensorInfo};
29    use hpt_common::{
30        axis::axis::Axis,
31        layout::layout::Layout,
32        shape::shape::Shape,
33        shape::shape_utils::{mt_intervals, try_pad_shape},
34        strides::strides::Strides,
35        strides::strides_utils::preprocess_strides,
36        utils::pointer::Pointer,
37        utils::simd_ref::MutVec,
38    };
39    use rayon::iter::{
40        plumbing::{bridge_unindexed, Folder, UnindexedConsumer, UnindexedProducer},
41        ParallelIterator,
42    };
43
44    use crate::{
45        iterator_traits::{
46            IterGetSetSimd, ParStridedHelper, ParStridedIteratorSimd, ParStridedIteratorSimdZip,
47            ShapeManipulator,
48        },
49        par_strided_map::par_strided_map_simd::ParStridedMapSimd,
50        shape_manipulate::{par_expand, par_reshape, par_transpose},
51    };
52
    /// Parallel strided iterator for SIMD operations.
    ///
    /// This struct is used to iterate over a tensor in parallel, with SIMD support.
    /// Rayon splitting happens over `intervals`: each producer owns the half-open
    /// range `[start_index, end_index)` of chunk intervals.
    #[derive(Clone)]
    pub struct ParStridedSimd<T: Send + Copy + Sync> {
        /// Pointer to the data.
        pub(crate) ptr: Pointer<T>,
        /// Layout of the tensor (shape and strides).
        pub(crate) layout: Layout,
        /// Progress of the loop: per-dimension counters advanced by `next`.
        pub(crate) prg: Vec<i64>,
        /// Chunk intervals for the outer loop, shared between split halves.
        pub(crate) intervals: Arc<Vec<(usize, usize)>>,
        /// Start index (inclusive) into `intervals` owned by this producer.
        pub(crate) start_index: usize,
        /// End index (exclusive) into `intervals` owned by this producer.
        pub(crate) end_index: usize,
        /// Stride of the last dimension, used by the inner loop.
        pub(crate) last_stride: i64,
    }
73    impl<T: CommonBounds> ParStridedSimd<T> {
74        /// Returns the shape of the iterator.
75        pub fn shape(&self) -> &Shape {
76            self.layout.shape()
77        }
78
79        /// Returns the strides of the iterator.
80        pub fn strides(&self) -> &Strides {
81            self.layout.strides()
82        }
83
84        /// Create a new parallel strided iterator for SIMD operations.
85        pub fn new<U: TensorInfo<T>>(tensor: U) -> Self {
86            let inner_loop_size = *tensor.shape().last().unwrap() as usize;
87            let outer_loop_size = tensor.size() / inner_loop_size;
88            let num_threads;
89            if outer_loop_size < rayon::current_num_threads() {
90                num_threads = outer_loop_size;
91            } else {
92                num_threads = rayon::current_num_threads();
93            }
94            let intervals = mt_intervals(outer_loop_size, num_threads);
95            let len = intervals.len();
96            ParStridedSimd {
97                ptr: tensor.ptr(),
98                layout: tensor.layout().clone(),
99                prg: vec![],
100                intervals: Arc::new(intervals),
101                start_index: 0,
102                end_index: len,
103                last_stride: *tensor.strides().last().unwrap(),
104            }
105        }
106
107        /// Map the iterator with a function.
108        pub fn strided_map_simd<'a, F, F2>(
109            self,
110            f: F,
111            vec_op: F2,
112        ) -> ParStridedMapSimd<'a, ParStridedSimd<T>, T, F, F2>
113        where
114            F: Fn((&mut T, <Self as IterGetSetSimd>::Item)) + Sync + Send + 'a,
115            <Self as IterGetSetSimd>::Item: Send,
116            F2: Send + Sync + Copy + Fn((MutVec<'_, T::Vec>, <Self as IterGetSetSimd>::SimdItem)),
117        {
118            {
119                ParStridedMapSimd {
120                    iter: self,
121                    f,
122                    f2: vec_op,
123                    phantom: std::marker::PhantomData,
124                }
125            }
126        }
127    }
128
    // Marker impls: both traits supply their behavior via default methods,
    // driven by the `IterGetSetSimd` implementation for this type.
    impl<T: CommonBounds> ParStridedIteratorSimdZip for ParStridedSimd<T> {}
    impl<T: CommonBounds> ParStridedIteratorSimd for ParStridedSimd<T> {}
131
132    impl<T: CommonBounds> IterGetSetSimd for ParStridedSimd<T>
133    where
134        T::Vec: Send,
135    {
136        type Item = T;
137
138        type SimdItem = T::Vec;
139
140        fn set_end_index(&mut self, end_index: usize) {
141            self.end_index = end_index;
142        }
143
144        fn set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>) {
145            self.intervals = intervals;
146        }
147
148        fn set_strides(&mut self, strides: Strides) {
149            self.layout.set_strides(strides);
150        }
151
152        fn set_shape(&mut self, shape: Shape) {
153            self.layout.set_shape(shape);
154        }
155
156        fn set_prg(&mut self, prg: Vec<i64>) {
157            self.prg = prg;
158        }
159
160        fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
161            &self.intervals
162        }
163
164        fn strides(&self) -> &Strides {
165            self.layout.strides()
166        }
167
168        fn shape(&self) -> &Shape {
169            self.layout.shape()
170        }
171
172        fn layout(&self) -> &Layout {
173            &self.layout
174        }
175
176        fn broadcast_set_strides(&mut self, shape: &Shape) {
177            let self_shape = try_pad_shape(self.shape(), shape.len());
178            self.set_strides(preprocess_strides(&self_shape, self.strides()).into());
179            self.last_stride = self.strides()[self.strides().len() - 1];
180        }
181
182        fn outer_loop_size(&self) -> usize {
183            self.intervals[self.start_index].1 - self.intervals[self.start_index].0
184        }
185
186        fn inner_loop_size(&self) -> usize {
187            self.shape().last().unwrap().clone() as usize
188        }
189
190        fn next(&mut self) {
191            for j in (0..(self.shape().len() as i64) - 1).rev() {
192                let j = j as usize;
193                if self.prg[j] < self.shape()[j] {
194                    self.prg[j] += 1;
195                    self.ptr.offset(self.strides()[j]);
196                    break;
197                } else {
198                    self.prg[j] = 0;
199                    self.ptr.offset(-self.strides()[j] * self.shape()[j]);
200                }
201            }
202        }
203
204        fn next_simd(&mut self) {}
205
206        fn inner_loop_next(&mut self, index: usize) -> Self::Item {
207            unsafe { *self.ptr.get_ptr().add(index * (self.last_stride as usize)) }
208        }
209
210        #[inline(always)]
211        fn inner_loop_next_simd(&mut self, index: usize) -> Self::SimdItem {
212            unsafe { T::Vec::from_ptr(self.ptr.get_ptr().add(index * T::Vec::SIZE)) }
213        }
214
215        fn all_last_stride_one(&self) -> bool {
216            self.last_stride == 1
217        }
218
219        fn lanes(&self) -> Option<usize> {
220            Some(T::Vec::SIZE)
221        }
222    }
223
    impl<T> ParallelIterator for ParStridedSimd<T>
    where
        T: CommonBounds,
        T::Vec: Send,
    {
        type Item = T;

        /// Drives an unindexed consumer by recursively splitting this
        /// producer over its chunk intervals (see `UnindexedProducer::split`).
        fn drive_unindexed<C>(self, consumer: C) -> C::Result
        where
            C: UnindexedConsumer<Self::Item>,
        {
            bridge_unindexed(self, consumer)
        }
    }
238
    impl<T> UnindexedProducer for ParStridedSimd<T>
    where
        T: CommonBounds,
        T::Vec: Send,
    {
        type Item = T;

        /// Splits the producer in half over its interval range.
        ///
        /// When only one interval remains, the producer finalizes itself for
        /// sequential consumption instead of splitting: it advances `ptr` to
        /// the first element of its interval, seeds `prg` with the matching
        /// per-dimension counters, and rewrites `shape` so each dimension
        /// holds its maximum index (len - 1), the form `next` expects.
        fn split(mut self) -> (Self, Option<Self>) {
            if self.end_index - self.start_index <= 1 {
                let mut curent_shape_prg: Vec<i64> = vec![0; self.shape().len()];
                // Linear element offset of this interval's first outer row.
                let mut amount =
                    self.intervals[self.start_index].0 * (*self.shape().last().unwrap() as usize);
                // Decompose the linear offset into per-dimension indices
                // (last dimension first) and move the pointer accordingly.
                for j in (0..self.shape().len()).rev() {
                    curent_shape_prg[j] = (amount as i64) % self.shape()[j];
                    amount /= self.shape()[j] as usize;
                    self.ptr += curent_shape_prg[j] * self.strides()[j];
                }
                self.prg = curent_shape_prg;
                let mut new_shape = self.shape().to_vec();
                new_shape.iter_mut().for_each(|x| {
                    *x -= 1;
                });
                self.last_stride = self.strides()[self.strides().len() - 1];
                self.set_shape(Shape::from(new_shape));
                return (self, None);
            }
            // More than one interval: split the interval range in half; the
            // extra interval (odd count) goes to the right half.
            let _left_interval = &self.intervals[self.start_index..self.end_index];
            let left = _left_interval.len() / 2;
            let right = _left_interval.len() / 2 + (_left_interval.len() % 2);
            (
                ParStridedSimd {
                    ptr: self.ptr.clone(),
                    layout: self.layout.clone(),
                    prg: vec![],
                    intervals: self.intervals.clone(),
                    start_index: self.start_index,
                    end_index: self.start_index + left,
                    last_stride: self.last_stride,
                },
                Some(ParStridedSimd {
                    ptr: self.ptr.clone(),
                    layout: self.layout.clone(),
                    prg: vec![],
                    intervals: self.intervals.clone(),
                    start_index: self.start_index + left,
                    end_index: self.start_index + left + right,
                    last_stride: self.last_stride,
                }),
            )
        }

        /// Returns the folder unchanged.
        /// NOTE(review): iteration appears to be driven by the adapters built
        /// on top of this producer rather than here — confirm against callers.
        fn fold_with<F>(self, folder: F) -> F
        where
            F: Folder<Self::Item>,
        {
            folder
        }
    }
297
    // Internal mutation hooks used by the shared shape-manipulation helpers
    // (`par_reshape` / `par_transpose` / `par_expand`).
    impl<T: CommonBounds> ParStridedHelper for ParStridedSimd<T> {
        /// Sets the cached stride of the last dimension.
        fn _set_last_strides(&mut self, stride: i64) {
            self.last_stride = stride;
        }

        /// Replaces the layout's strides.
        fn _set_strides(&mut self, strides: Strides) {
            self.layout.set_strides(strides);
        }

        /// Replaces the layout's shape.
        fn _set_shape(&mut self, shape: Shape) {
            self.layout.set_shape(shape);
        }

        /// Returns the current layout.
        fn _layout(&self) -> &Layout {
            &self.layout
        }

        /// Replaces the outer-loop chunk intervals.
        fn _set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>) {
            self.intervals = intervals;
        }

        /// Sets the exclusive end index into the interval list.
        fn _set_end_index(&mut self, end_index: usize) {
            self.end_index = end_index;
        }
    }
323
    impl<T: CommonBounds> ShapeManipulator for ParStridedSimd<T>
    where
        T::Vec: Send,
    {
        /// Reshapes the iterator; delegates to the shared `par_reshape` helper.
        fn reshape<S: Into<Shape>>(self, shape: S) -> Self {
            par_reshape(self, shape)
        }

        /// Transposes the iterator's axes; delegates to `par_transpose`.
        fn transpose<AXIS: Into<Axis>>(self, axes: AXIS) -> Self {
            par_transpose(self, axes)
        }

        /// Expands (broadcasts) the iterator to `shape`; delegates to `par_expand`.
        fn expand<S: Into<Shape>>(self, shape: S) -> Self {
            par_expand(self, shape)
        }
    }
340}
341
/// A parallel strided iterator over tensor elements.
///
/// This struct provides mutable access to tensor elements with strided access patterns optimized for parallel processing.
/// Rayon splitting happens over `intervals`: each producer owns the half-open
/// range `[start_index, end_index)` of chunk intervals.
#[derive(Clone)]
pub struct ParStrided<T> {
    /// A pointer to the tensor's data.
    pub(crate) ptr: Pointer<T>,
    /// The layout of the tensor, including shape and strides.
    pub(crate) layout: Layout,
    /// Progress of the loop: per-dimension counters advanced by `next`.
    pub(crate) prg: Vec<i64>,
    /// Chunk intervals for the outer loop, shared between split halves.
    pub(crate) intervals: Arc<Vec<(usize, usize)>>,
    /// Start index (inclusive) into `intervals` owned by this producer.
    pub(crate) start_index: usize,
    /// End index (exclusive) into `intervals` owned by this producer.
    pub(crate) end_index: usize,
    /// Stride of the last dimension, used by the inner loop.
    pub(crate) last_stride: i64,
}
362
363impl<T: CommonBounds> ParStrided<T> {
364    /// Retrieves the shape of the tensor.
365    ///
366    /// # Returns
367    ///
368    /// A reference to the `Shape` struct representing the tensor's dimensions.
369    pub fn shape(&self) -> &Shape {
370        self.layout.shape()
371    }
372    /// Retrieves the strides of the tensor.
373    ///
374    /// # Returns
375    ///
376    /// A reference to the `Strides` struct representing the tensor's stride information.
377    pub fn strides(&self) -> &Strides {
378        self.layout.strides()
379    }
380    /// Creates a new `ParStrided` instance from a given tensor.
381    ///
382    /// This constructor initializes the `ParStrided` iterator by determining the appropriate number of threads
383    /// based on the tensor's outer loop size and Rayon’s current number of threads. It then divides the
384    /// iteration workload into intervals for parallel execution.
385    ///
386    /// # Arguments
387    ///
388    /// * `tensor` - The tensor implementing the `TensorInfo<T>` trait to iterate over mutably.
389    ///
390    /// # Returns
391    ///
392    /// A new instance of `ParStrided` initialized with the provided tensor.
393    pub fn new<U: TensorInfo<T>>(tensor: U) -> Self {
394        let inner_loop_size = tensor.shape()[tensor.shape().len() - 1] as usize;
395        let outer_loop_size = tensor.size() / inner_loop_size;
396        let num_threads;
397        if outer_loop_size < rayon::current_num_threads() {
398            num_threads = outer_loop_size;
399        } else {
400            num_threads = rayon::current_num_threads();
401        }
402        let intervals = mt_intervals(outer_loop_size, num_threads);
403        let len = intervals.len();
404        ParStrided {
405            ptr: tensor.ptr(),
406            layout: tensor.layout().clone(),
407            prg: vec![],
408            intervals: Arc::new(intervals),
409            start_index: 0,
410            end_index: len,
411            last_stride: tensor.strides()[tensor.strides().len() - 1],
412        }
413    }
414    /// Performs a parallel fold (reduce) operation over the tensor elements.
415    ///
416    /// This method applies a folding function `fold_op` to accumulate tensor elements into an initial
417    /// identity value `identity`. It leverages parallel iteration to perform the fold operation efficiently.
418    ///
419    /// # Type Parameters
420    ///
421    /// * `ID` - The type of the accumulator.
422    /// * `F` - The folding function.
423    ///
424    /// # Arguments
425    ///
426    /// * `identity` - The initial value for the accumulator.
427    /// * `fold_op` - A function that takes the current accumulator and an element, returning the updated accumulator.
428    ///
429    /// # Returns
430    ///
431    /// A `ParStridedFold` instance that represents the fold operation.
432    pub fn par_strided_fold<ID, F>(self, identity: ID, fold_op: F) -> ParStridedFold<Self, ID, F>
433    where
434        F: Fn(ID, T) -> ID + Sync + Send + Copy,
435        ID: Sync + Send + Copy,
436    {
437        ParStridedFold {
438            iter: self,
439            identity,
440            fold_op,
441        }
442    }
443    /// Transforms the zipped iterators by applying a provided function to their items.
444    ///
445    /// This method allows for element-wise operations on the zipped iterators by applying `func` to each item.
446    ///
447    /// # Type Parameters
448    ///
449    /// * `'a` - The lifetime associated with the iterators.
450    /// * `F` - The function to apply to each item.
451    /// * `U` - The output type after applying the function.
452    ///
453    /// # Arguments
454    ///
455    /// * `f` - A function that takes an item from the zipped iterator and returns a transformed value.
456    ///
457    /// # Returns
458    ///
459    /// A `ParStridedMap` instance that applies the provided function during iteration.
460    pub fn strided_map<'a, F, U>(self, f: F) -> ParStridedMap<'a, ParStrided<T>, T, F>
461    where
462        F: Fn((&mut U, T)) + Sync + Send + 'a,
463        U: CommonBounds,
464    {
465        ParStridedMap {
466            iter: self,
467            f,
468            phantom: std::marker::PhantomData,
469        }
470    }
471}
472
// Marker impl: zipping behavior comes from the trait's default methods,
// driven by the `IterGetSet` implementation below.
impl<T: CommonBounds> ParStridedIteratorZip for ParStrided<T> {}
474
475impl<T: CommonBounds> IterGetSet for ParStrided<T> {
476    type Item = T;
477
478    fn set_end_index(&mut self, end_index: usize) {
479        self.end_index = end_index;
480    }
481
482    fn set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>) {
483        self.intervals = intervals;
484    }
485
486    fn set_strides(&mut self, strides: Strides) {
487        self.layout.set_strides(strides);
488    }
489
490    fn set_shape(&mut self, shape: Shape) {
491        self.layout.set_shape(shape);
492    }
493
494    fn set_prg(&mut self, prg: Vec<i64>) {
495        self.prg = prg;
496    }
497
498    fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
499        &self.intervals
500    }
501
502    fn strides(&self) -> &Strides {
503        self.layout.strides()
504    }
505
506    fn shape(&self) -> &Shape {
507        self.layout.shape()
508    }
509
510    fn layout(&self) -> &Layout {
511        &self.layout
512    }
513
514    fn broadcast_set_strides(&mut self, shape: &Shape) {
515        let self_shape = try_pad_shape(self.shape(), shape.len());
516        self.set_strides(preprocess_strides(&self_shape, self.strides()).into());
517        self.last_stride = self.strides()[self.strides().len() - 1];
518    }
519
520    fn outer_loop_size(&self) -> usize {
521        self.intervals[self.start_index].1 - self.intervals[self.start_index].0
522    }
523
524    fn inner_loop_size(&self) -> usize {
525        self.shape().last().unwrap().clone() as usize
526    }
527
528    fn next(&mut self) {
529        for j in (0..(self.shape().len() as i64) - 1).rev() {
530            let j = j as usize;
531            if self.prg[j] < self.shape()[j] {
532                self.prg[j] += 1;
533                self.ptr.offset(self.strides()[j]);
534                break;
535            } else {
536                self.prg[j] = 0;
537                self.ptr.offset(-self.strides()[j] * self.shape()[j]);
538            }
539        }
540    }
541
542    fn inner_loop_next(&mut self, index: usize) -> Self::Item {
543        unsafe { *self.ptr.get_ptr().add(index * (self.last_stride as usize)) }
544    }
545}
546
impl<T> ParallelIterator for ParStrided<T>
where
    T: CommonBounds,
    T::Vec: Send,
{
    type Item = T;

    /// Drives an unindexed consumer by recursively splitting this producer
    /// over its chunk intervals (see `UnindexedProducer::split`).
    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: UnindexedConsumer<Self::Item>,
    {
        bridge_unindexed(self, consumer)
    }
}
561
impl<T> UnindexedProducer for ParStrided<T>
where
    T: CommonBounds,
    T::Vec: Send,
{
    type Item = T;

    /// Splits the producer in half over its interval range.
    ///
    /// When only one interval remains, the producer finalizes itself for
    /// sequential consumption instead of splitting: it advances `ptr` to the
    /// first element of its interval, seeds `prg` with the matching
    /// per-dimension counters, and rewrites `shape` so each dimension holds
    /// its maximum index (len - 1), the form `next` expects.
    fn split(mut self) -> (Self, Option<Self>) {
        if self.end_index - self.start_index <= 1 {
            let mut curent_shape_prg: Vec<i64> = vec![0; self.shape().len()];
            // Linear element offset of this interval's first outer row.
            let mut amount =
                self.intervals[self.start_index].0 * (*self.shape().last().unwrap() as usize);
            let mut index = 0;
            // Decompose the linear offset into per-dimension indices
            // (last dimension first), accumulating the pointer offset.
            for j in (0..self.shape().len()).rev() {
                curent_shape_prg[j] = (amount as i64) % self.shape()[j];
                amount /= self.shape()[j] as usize;
                index += curent_shape_prg[j] * self.strides()[j];
            }
            self.ptr.offset(index);
            self.prg = curent_shape_prg;
            let mut new_shape = self.shape().to_vec();
            new_shape.iter_mut().for_each(|x| {
                *x -= 1;
            });
            self.last_stride = self.strides()[self.strides().len() - 1];
            self.set_shape(Shape::from(new_shape));
            return (self, None);
        }
        // More than one interval: split the interval range in half; the
        // extra interval (odd count) goes to the right half.
        let _left_interval = &self.intervals[self.start_index..self.end_index];
        let left = _left_interval.len() / 2;
        let right = _left_interval.len() / 2 + (_left_interval.len() % 2);
        (
            ParStrided {
                ptr: self.ptr.clone(),
                layout: self.layout.clone(),
                prg: vec![],
                intervals: self.intervals.clone(),
                start_index: self.start_index,
                end_index: self.start_index + left,
                last_stride: self.last_stride,
            },
            Some(ParStrided {
                ptr: self.ptr.clone(),
                layout: self.layout.clone(),
                prg: vec![],
                intervals: self.intervals.clone(),
                start_index: self.start_index + left,
                end_index: self.start_index + left + right,
                last_stride: self.last_stride,
            }),
        )
    }

    /// Returns the folder unchanged.
    /// NOTE(review): iteration appears to be driven by the adapters built on
    /// top of this producer rather than here — confirm against callers.
    fn fold_with<F>(self, folder: F) -> F
    where
        F: Folder<Self::Item>,
    {
        folder
    }
}
622
// Internal mutation hooks used by the shared shape-manipulation helpers
// (`par_reshape` / `par_transpose` / `par_expand`).
impl<T> ParStridedHelper for ParStrided<T> {
    /// Sets the cached stride of the last dimension.
    fn _set_last_strides(&mut self, last_stride: i64) {
        self.last_stride = last_stride;
    }

    /// Replaces the layout's strides.
    fn _set_strides(&mut self, strides: Strides) {
        self.layout.set_strides(strides);
    }

    /// Replaces the layout's shape.
    fn _set_shape(&mut self, shape: Shape) {
        self.layout.set_shape(shape);
    }

    /// Returns the current layout.
    fn _layout(&self) -> &Layout {
        &self.layout
    }

    /// Replaces the outer-loop chunk intervals.
    fn _set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>) {
        self.intervals = intervals;
    }

    /// Sets the exclusive end index into the interval list.
    fn _set_end_index(&mut self, end_index: usize) {
        self.end_index = end_index;
    }
}
648
impl<T: CommonBounds> ShapeManipulator for ParStrided<T> {
    /// Reshapes the iterator; delegates to the shared `par_reshape` helper.
    /// `#[track_caller]` makes any panic inside report the caller's location.
    #[track_caller]
    fn reshape<S: Into<Shape>>(self, shape: S) -> Self {
        par_reshape(self, shape)
    }

    /// Transposes the iterator's axes; delegates to `par_transpose`.
    fn transpose<AXIS: Into<Axis>>(self, axes: AXIS) -> Self {
        par_transpose(self, axes)
    }

    /// Expands (broadcasts) the iterator to `shape`; delegates to `par_expand`.
    fn expand<S: Into<Shape>>(self, shape: S) -> Self {
        par_expand(self, shape)
    }
}