hpt_iterator/par_strided_map_mut.rs

use std::sync::Arc;

use hpt_common::{shape::shape::Shape, strides::strides::Strides};
use hpt_traits::tensor::{CommonBounds, TensorInfo};
use rayon::iter::{
    plumbing::{bridge_unindexed, Folder, UnindexedConsumer, UnindexedProducer},
    ParallelIterator,
};

use crate::{
    iterator_traits::IterGetSet, par_strided_mut::ParStridedMut, par_strided_zip::ParStridedZip,
};

/// A module providing the parallel, SIMD-optimized strided mutable map iterator.
pub mod par_strided_map_mut_simd {
    use std::sync::Arc;

    use crate::{CommonBounds, TensorInfo};
    use hpt_common::{shape::shape::Shape, strides::strides::Strides, utils::simd_ref::MutVec};
    use hpt_types::dtype::TypeCommon;
    use rayon::iter::{
        plumbing::{bridge_unindexed, Folder, UnindexedConsumer, UnindexedProducer},
        ParallelIterator,
    };

    use crate::{
        iterator_traits::IterGetSetSimd,
        par_strided_mut::par_strided_map_mut_simd::ParStridedMutSimd,
        par_strided_zip::par_strided_zip_simd::ParStridedZipSimd,
    };
    /// A parallel, mutable, SIMD-optimized map iterator.
    ///
    /// This struct provides mutable access to tensor elements with strided access
    /// patterns, optimized for SIMD operations.
    pub struct ParStridedMapMutSimd<'a, T>
    where
        T: TypeCommon + Send + Copy + Sync,
    {
        /// The underlying parallel SIMD-optimized strided iterator.
        pub(crate) base: ParStridedMutSimd<'a, T>,
        /// Phantom data to associate the lifetime `'a` with the struct.
        pub(crate) phantom: std::marker::PhantomData<&'a ()>,
    }

    impl<'a, T> ParStridedMapMutSimd<'a, T>
    where
        T: CommonBounds,
        T::Vec: Send,
    {
        /// Creates a new `ParStridedMapMutSimd` instance from a given result tensor.
        ///
        /// This constructor initializes the `ParStridedMapMutSimd` iterator by wrapping the provided result
        /// tensor in a `ParStridedMutSimd` instance. It sets up the necessary data structures for parallel,
        /// SIMD-optimized iteration.
        ///
        /// # Type Parameters
        ///
        /// * `U` - The type of the tensor to accumulate results into. Must implement `TensorInfo<T>`.
        ///
        /// # Arguments
        ///
        /// * `res_tensor` - The tensor implementing the `TensorInfo<T>` trait to accumulate results into.
        ///
        /// # Returns
        ///
        /// A new instance of `ParStridedMapMutSimd` initialized with the provided result tensor.
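        ///
        /// # Example
        ///
        /// A minimal usage sketch. `Tensor` and its `zeros` constructor are
        /// illustrative assumptions; any type implementing `TensorInfo<T>` works:
        ///
        /// ```ignore
        /// // Wrap an output tensor so a map operation can write into it in parallel.
        /// let out = Tensor::<f32>::zeros(&[4, 4])?;
        /// let iter = ParStridedMapMutSimd::new(out);
        /// ```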
        pub fn new<U: TensorInfo<T>>(res_tensor: U) -> Self {
            ParStridedMapMutSimd {
                base: ParStridedMutSimd::new(res_tensor),
                phantom: std::marker::PhantomData,
            }
        }
        /// Combines this `ParStridedMapMutSimd` iterator with another SIMD-optimized iterator, enabling simultaneous parallel iteration.
        ///
        /// This method performs shape broadcasting between `self` and `other` to ensure that both iterators
        /// iterate over tensors with compatible shapes. It calculates the appropriate iteration intervals based
        /// on the new broadcasted shape and configures both iterators accordingly. Finally, it returns a new
        /// `ParStridedZipSimd` instance that allows for synchronized parallel iteration over the combined iterators.
        ///
        /// **Note:** This implementation leverages Rayon for parallel execution and assumes that the iterators
        /// support SIMD optimizations.
        ///
        /// # Type Parameters
        ///
        /// * `C` - The type of the other iterator to zip with. Must implement `UnindexedProducer`, `IterGetSetSimd`, and `ParallelIterator`.
        ///
        /// # Arguments
        ///
        /// * `self` - The `ParStridedMapMutSimd` iterator instance.
        /// * `other` - The other iterator to zip with.
        ///
        /// # Returns
        ///
        /// A new `ParStridedZipSimd` instance that combines `self` and `other` for synchronized parallel iteration over both iterators.
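        ///
        /// # Example
        ///
        /// A hedged sketch of pairing an output iterator with a SIMD input
        /// iterator. `Tensor`, `zeros`, and `par_iter_simd` are assumptions for
        /// illustration, not guaranteed APIs:
        ///
        /// ```ignore
        /// let input = Tensor::<f32>::zeros(&[4, 4])?;
        /// let out = Tensor::<f32>::zeros(&[4, 4])?;
        /// // Both sides now iterate in lockstep over broadcast-compatible shapes.
        /// let zipped = ParStridedMapMutSimd::new(out).zip(input.par_iter_simd());
        /// ```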
        pub fn zip<C>(self, other: C) -> ParStridedZipSimd<'a, Self, C>
        where
            C: UnindexedProducer + 'a + IterGetSetSimd + ParallelIterator,
            <C as IterGetSetSimd>::Item: Send,
            T::Vec: Send,
        {
            ParStridedZipSimd::new(self, other)
        }
    }

    impl<'a, T> ParallelIterator for ParStridedMapMutSimd<'a, T>
    where
        T: 'a + CommonBounds,
        T::Vec: Send,
    {
        type Item = &'a mut T;

        fn drive_unindexed<C>(self, consumer: C) -> C::Result
        where
            C: UnindexedConsumer<Self::Item>,
        {
            bridge_unindexed(self, consumer)
        }
    }

    impl<'a, T> UnindexedProducer for ParStridedMapMutSimd<'a, T>
    where
        T: 'a + CommonBounds,
        T::Vec: Send,
    {
        type Item = &'a mut T;

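        // Split the underlying strided iterator and rewrap both halves, so Rayon
        // can divide the work across threads.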
        fn split(self) -> (Self, Option<Self>) {
            let (a, b) = self.base.split();
            (
                ParStridedMapMutSimd {
                    base: a,
                    phantom: std::marker::PhantomData,
                },
                b.map(|x| ParStridedMapMutSimd {
                    base: x,
                    phantom: std::marker::PhantomData,
                }),
            )
        }

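        // Returns the folder unchanged: this producer yields no items on its own;
        // element traversal is expected to go through the `IterGetSetSimd` interface.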
        fn fold_with<F>(self, folder: F) -> F
        where
            F: Folder<Self::Item>,
        {
            folder
        }
    }

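    // `IterGetSetSimd` is implemented by delegating every call to the wrapped
    // `base` iterator.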
    impl<'a, T: 'a + CommonBounds> IterGetSetSimd for ParStridedMapMutSimd<'a, T>
    where
        T::Vec: Send,
    {
        type Item = &'a mut T;

        type SimdItem = MutVec<'a, T::Vec>;

        fn set_end_index(&mut self, end_index: usize) {
            self.base.set_end_index(end_index);
        }

        fn set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>) {
            self.base.set_intervals(intervals);
        }

        fn set_strides(&mut self, strides: Strides) {
            self.base.set_strides(strides);
        }

        fn set_shape(&mut self, shape: Shape) {
            self.base.set_shape(shape);
        }

        fn set_prg(&mut self, prg: Vec<i64>) {
            self.base.set_prg(prg);
        }

        fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
            self.base.intervals()
        }

        fn strides(&self) -> &Strides {
            self.base.strides()
        }

        fn shape(&self) -> &Shape {
            self.base.shape()
        }

        fn layout(&self) -> &hpt_common::layout::layout::Layout {
            self.base.layout()
        }

        fn broadcast_set_strides(&mut self, shape: &Shape) {
            self.base.broadcast_set_strides(shape);
        }

        fn outer_loop_size(&self) -> usize {
            self.base.outer_loop_size()
        }

        fn inner_loop_size(&self) -> usize {
            self.base.inner_loop_size()
        }

        fn next(&mut self) {
            self.base.next();
        }

        fn next_simd(&mut self) {
            self.base.next_simd();
        }

        fn inner_loop_next(&mut self, index: usize) -> Self::Item {
            self.base.inner_loop_next(index)
        }

        fn inner_loop_next_simd(&mut self, index: usize) -> Self::SimdItem {
            self.base.inner_loop_next_simd(index)
        }

        fn all_last_stride_one(&self) -> bool {
            self.base.all_last_stride_one()
        }

        fn lanes(&self) -> Option<usize> {
            self.base.lanes()
        }
    }
}

/// A parallel mutable map iterator.
///
/// This struct provides mutable access to tensor elements with strided access
/// patterns.
pub struct ParStridedMapMut<'a, T>
where
    T: Copy,
{
    /// The base parallel strided mutable iterator.
    pub(crate) base: ParStridedMut<'a, T>,
    /// Phantom data to associate the lifetime `'a` with the struct.
    pub(crate) phantom: std::marker::PhantomData<&'a ()>,
}

impl<'a, T> ParStridedMapMut<'a, T>
where
    T: CommonBounds,
{
    /// Creates a new `ParStridedMapMut` instance from a given result tensor.
    ///
    /// This constructor initializes the `ParStridedMapMut` iterator by wrapping the provided result
    /// tensor in a `ParStridedMut` instance. It sets up the necessary data structures for parallel iteration.
    ///
    /// # Type Parameters
    ///
    /// * `U` - The type of the tensor to accumulate results into. Must implement `TensorInfo<T>`.
    ///
    /// # Arguments
    ///
    /// * `res_tensor` - The tensor implementing the `TensorInfo<T>` trait to accumulate results into.
    ///
    /// # Returns
    ///
    /// A new instance of `ParStridedMapMut` initialized with the provided result tensor.
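    ///
    /// # Example
    ///
    /// A minimal usage sketch. `Tensor` and `zeros` are illustrative assumptions;
    /// any type implementing `TensorInfo<T>` works:
    ///
    /// ```ignore
    /// // Wrap an output tensor so a map operation can write into it in parallel.
    /// let out = Tensor::<f32>::zeros(&[4, 4])?;
    /// let iter = ParStridedMapMut::new(out);
    /// ```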
    pub fn new<U: TensorInfo<T>>(res_tensor: U) -> Self {
        ParStridedMapMut {
            base: ParStridedMut::new(res_tensor),
            phantom: std::marker::PhantomData,
        }
    }
    /// Combines this `ParStridedMapMut` iterator with another iterator, enabling simultaneous parallel iteration.
    ///
    /// This method performs shape broadcasting between `self` and `other` to ensure that both iterators
    /// iterate over tensors with compatible shapes. It calculates the appropriate iteration intervals based
    /// on the new broadcasted shape and configures both iterators accordingly. Finally, it returns a new
    /// `ParStridedZip` instance that allows for synchronized parallel iteration over the combined iterators.
    ///
    /// **Note:** This implementation leverages Rayon for parallel execution.
    ///
    /// # Type Parameters
    ///
    /// * `C` - The type of the other iterator to zip with. Must implement `UnindexedProducer`, `IterGetSet`, and `ParallelIterator`.
    ///
    /// # Arguments
    ///
    /// * `self` - The `ParStridedMapMut` iterator instance.
    /// * `other` - The other iterator to zip with.
    ///
    /// # Returns
    ///
    /// A new `ParStridedZip` instance that combines `self` and `other` for synchronized parallel iteration over both iterators.
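    ///
    /// # Example
    ///
    /// A hedged sketch of writing a doubled input into the output tensor.
    /// `Tensor`, `zeros`, and `par_iter` are assumptions for illustration, not
    /// guaranteed APIs:
    ///
    /// ```ignore
    /// let input = Tensor::<f32>::zeros(&[4, 4])?;
    /// let out = Tensor::<f32>::zeros(&[4, 4])?;
    /// // Item types shown are illustrative; actual items depend on the zipped iterators.
    /// ParStridedMapMut::new(out)
    ///     .zip(input.par_iter())
    ///     .for_each(|(o, x)| *o = x * 2.0);
    /// ```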
    pub fn zip<C>(self, other: C) -> ParStridedZip<'a, Self, C>
    where
        C: UnindexedProducer + 'a + IterGetSet + ParallelIterator,
        <C as IterGetSet>::Item: Send,
    {
        ParStridedZip::new(self, other)
    }
}

impl<'a, T> ParallelIterator for ParStridedMapMut<'a, T>
where
    T: 'a + CommonBounds,
{
    type Item = &'a mut T;

    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: UnindexedConsumer<Self::Item>,
    {
        bridge_unindexed(self, consumer)
    }
}

impl<'a, T> UnindexedProducer for ParStridedMapMut<'a, T>
where
    T: 'a + CommonBounds,
{
    type Item = &'a mut T;

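    // Split the underlying strided iterator and rewrap both halves, so Rayon can
    // divide the work across threads.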
    fn split(self) -> (Self, Option<Self>) {
        let (a, b) = self.base.split();
        (
            ParStridedMapMut {
                base: a,
                phantom: std::marker::PhantomData,
            },
            b.map(|x| ParStridedMapMut {
                base: x,
                phantom: std::marker::PhantomData,
            }),
        )
    }

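    // Returns the folder unchanged: this producer yields no items on its own;
    // element traversal is expected to go through the `IterGetSet` interface.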
    fn fold_with<F>(self, folder: F) -> F
    where
        F: Folder<Self::Item>,
    {
        folder
    }
}

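// `IterGetSet` is implemented by delegating every call to the wrapped `base`
// iterator.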
impl<'a, T: 'a + CommonBounds> IterGetSet for ParStridedMapMut<'a, T> {
    type Item = &'a mut T;

    fn set_end_index(&mut self, end_index: usize) {
        self.base.set_end_index(end_index);
    }

    fn set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>) {
        self.base.set_intervals(intervals);
    }

    fn set_strides(&mut self, strides: Strides) {
        self.base.set_strides(strides);
    }

    fn set_shape(&mut self, shape: Shape) {
        self.base.set_shape(shape);
    }

    fn set_prg(&mut self, prg: Vec<i64>) {
        self.base.set_prg(prg);
    }

    fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
        self.base.intervals()
    }

    fn strides(&self) -> &Strides {
        self.base.strides()
    }

    fn shape(&self) -> &Shape {
        self.base.shape()
    }

    fn layout(&self) -> &hpt_common::layout::layout::Layout {
        self.base.layout()
    }

    fn broadcast_set_strides(&mut self, shape: &Shape) {
        self.base.broadcast_set_strides(shape);
    }

    fn outer_loop_size(&self) -> usize {
        self.base.outer_loop_size()
    }

    fn inner_loop_size(&self) -> usize {
        self.base.inner_loop_size()
    }

    fn next(&mut self) {
        self.base.next();
    }

    fn inner_loop_next(&mut self, index: usize) -> Self::Item {
        self.base.inner_loop_next(index)
    }
}