hpt_iterator/
par_strided_mut.rs

use crate::{
    iterator_traits::{IterGetSet, ParStridedHelper, ParStridedIteratorZip, ShapeManipulator},
    par_strided::ParStrided,
    shape_manipulate::{par_expand, par_reshape, par_transpose},
};
use hpt_common::shape::shape::Shape;
use hpt_traits::tensor::{CommonBounds, TensorInfo};
use rayon::iter::{
    plumbing::{bridge_unindexed, Folder, UnindexedConsumer, UnindexedProducer},
    ParallelIterator,
};
use std::sync::Arc;

/// A module providing the SIMD-optimized parallel mutable strided iterator.
pub mod par_strided_map_mut_simd {
    use crate::{
        iterator_traits::{IterGetSetSimd, ParStridedIteratorSimd, ParStridedIteratorSimdZip},
        par_strided::par_strided_simd::ParStridedSimd,
    };
    use hpt_common::{shape::shape::Shape, utils::pointer::Pointer, utils::simd_ref::MutVec};
    use hpt_traits::{CommonBounds, TensorInfo};
    use hpt_types::dtype::TypeCommon;
    use hpt_types::traits::VecTrait;
    use rayon::iter::{
        plumbing::{bridge_unindexed, Folder, UnindexedConsumer, UnindexedProducer},
        ParallelIterator,
    };
    use std::sync::Arc;
    /// A parallel mutable SIMD-optimized strided iterator over tensor elements.
    ///
    /// This struct provides mutable access to tensor elements with strided access
    /// patterns, optimized for parallel and SIMD (Single Instruction, Multiple Data)
    /// processing. It leverages Rayon for concurrent execution and efficiently
    /// traverses and modifies tensor data according to its strides.
    pub struct ParStridedMutSimd<'a, T: TypeCommon + Send + Copy + Sync> {
        /// The base parallel SIMD-optimized strided iterator.
        pub(crate) base: ParStridedSimd<T>,
        /// Phantom data to associate the lifetime `'a` with the struct.
        pub(crate) phantom: std::marker::PhantomData<&'a ()>,
    }

    impl<'a, T: CommonBounds> ParStridedMutSimd<'a, T> {
        /// Creates a new `ParStridedMutSimd` instance from a given tensor.
        ///
        /// This constructor initializes the `ParStridedMutSimd` iterator by wrapping the provided
        /// tensor in a `ParStridedSimd` instance. It sets up the necessary data structures for
        /// parallel and SIMD-optimized iteration.
        ///
        /// # Arguments
        ///
        /// * `tensor` - The tensor implementing the `TensorInfo<T>` trait to iterate over mutably.
        ///
        /// # Returns
        ///
        /// A new instance of `ParStridedMutSimd` initialized with the provided tensor.
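        ///
        /// # Example
        ///
        /// A minimal sketch, not taken from the crate's own docs: `tensor` stands in
        /// for any value implementing `TensorInfo<f32>`; constructing one is outside
        /// the scope of this module.
        ///
        /// ```ignore
        /// use rayon::iter::ParallelIterator;
        ///
        /// // Hypothetical `tensor: impl TensorInfo<f32>`.
        /// let iter = ParStridedMutSimd::new(tensor);
        /// // Consumed like any Rayon parallel iterator.
        /// iter.for_each(|elem| *elem = 0.0);
        /// ```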
        pub fn new<U: TensorInfo<T>>(tensor: U) -> Self {
            ParStridedMutSimd {
                base: ParStridedSimd::new(tensor),
                phantom: std::marker::PhantomData,
            }
        }
    }

    impl<'a, T: CommonBounds> ParStridedIteratorSimdZip for ParStridedMutSimd<'a, T> {}
    impl<'a, T: CommonBounds> ParStridedIteratorSimd for ParStridedMutSimd<'a, T> {}

    impl<'a, T> ParallelIterator for ParStridedMutSimd<'a, T>
    where
        T: CommonBounds,
        T::Vec: Send,
    {
        type Item = &'a mut T;

        fn drive_unindexed<C>(self, consumer: C) -> C::Result
        where
            C: UnindexedConsumer<Self::Item>,
        {
            bridge_unindexed(self, consumer)
        }
    }

    impl<'a, T> UnindexedProducer for ParStridedMutSimd<'a, T>
    where
        T: CommonBounds,
        T::Vec: Send,
    {
        type Item = &'a mut T;

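        // Split the underlying producer and re-wrap both halves, letting Rayon
        // keep bisecting the iteration space across worker threads.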
        fn split(self) -> (Self, Option<Self>) {
            let (a, b) = self.base.split();
            (
                ParStridedMutSimd {
                    base: a,
                    phantom: std::marker::PhantomData,
                },
                b.map(|x| ParStridedMutSimd {
                    base: x,
                    phantom: std::marker::PhantomData,
                }),
            )
        }

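        // Returns the folder unchanged: no items are produced through Rayon's
        // fold path. Iteration is instead driven by the zip combinators layered
        // on top of `IterGetSetSimd`.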
        fn fold_with<F>(self, folder: F) -> F
        where
            F: Folder<Self::Item>,
        {
            folder
        }
    }

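    // `IterGetSetSimd` forwards almost everything to the wrapped `ParStridedSimd`;
    // only the element accessors below materialize mutable references and SIMD
    // views from the raw data pointer.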
    impl<'a, T: 'a> IterGetSetSimd for ParStridedMutSimd<'a, T>
    where
        T: CommonBounds,
        T::Vec: Send,
    {
        type Item = &'a mut T;

        type SimdItem
            = MutVec<'a, T::Vec>
        where
            Self: 'a;

        fn set_end_index(&mut self, end_index: usize) {
            self.base.set_end_index(end_index);
        }

        fn set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>) {
            self.base.set_intervals(intervals);
        }

        fn set_strides(&mut self, strides: hpt_common::strides::strides::Strides) {
            self.base.set_strides(strides);
        }

        fn set_shape(&mut self, shape: Shape) {
            self.base.set_shape(shape);
        }

        fn set_prg(&mut self, prg: Vec<i64>) {
            self.base.set_prg(prg);
        }

        fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
            self.base.intervals()
        }

        fn strides(&self) -> &hpt_common::strides::strides::Strides {
            self.base.strides()
        }

        fn shape(&self) -> &Shape {
            self.base.shape()
        }

        fn layout(&self) -> &hpt_common::layout::layout::Layout {
            self.base.layout()
        }

        fn broadcast_set_strides(&mut self, shape: &Shape) {
            self.base.broadcast_set_strides(shape);
        }

        fn outer_loop_size(&self) -> usize {
            self.base.outer_loop_size()
        }

        fn inner_loop_size(&self) -> usize {
            self.base.inner_loop_size()
        }

        fn next(&mut self) {
            self.base.next();
        }

        fn next_simd(&mut self) {
            self.base.next_simd();
        }

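        // Resolve the element at `index` in the current inner loop: offset the
        // base pointer by `index * last_stride` and hand out a mutable reference.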
        fn inner_loop_next(&mut self, index: usize) -> Self::Item {
            unsafe {
                self.base
                    .ptr
                    .get_ptr()
                    .add(index * (self.base.last_stride as usize))
                    .as_mut()
                    .unwrap()
            }
        }

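        // SIMD variant: `index` counts whole vectors, so the offset advances in
        // steps of `T::Vec::SIZE` lanes. With the `bound_check` feature enabled,
        // the `Pointer` also carries the vector length for runtime checks.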
        #[inline(always)]
        fn inner_loop_next_simd(&mut self, index: usize) -> Self::SimdItem {
            unsafe {
                let ptr = self.base.ptr.get_ptr().add(index * T::Vec::SIZE) as *mut T::Vec;
                #[cfg(feature = "bound_check")]
                return MutVec::new(Pointer::new(ptr, T::Vec::SIZE as i64));
                #[cfg(not(feature = "bound_check"))]
                return MutVec::new(Pointer::new(ptr));
            }
        }

        fn all_last_stride_one(&self) -> bool {
            self.base.all_last_stride_one()
        }

        fn lanes(&self) -> Option<usize> {
            self.base.lanes()
        }
    }
}

/// A parallel mutable strided iterator over tensor elements.
///
/// This struct provides mutable access to tensor elements with strided access patterns,
/// optimized for parallel processing. It leverages Rayon for concurrent execution,
/// allowing efficient traversal and modification of tensor data according to its strides.
pub struct ParStridedMut<'a, T> {
    /// The base parallel strided iterator.
    pub(crate) base: ParStrided<T>,
    /// Phantom data to associate the lifetime `'a` with the struct.
    pub(crate) phantom: std::marker::PhantomData<&'a ()>,
}

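// `ParStridedHelper` forwards wholesale to the wrapped `ParStrided`, keeping all
// shape and stride bookkeeping in the base iterator.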
impl<'a, T: CommonBounds> ParStridedHelper for ParStridedMut<'a, T> {
    fn _set_last_strides(&mut self, stride: i64) {
        self.base._set_last_strides(stride);
    }

    fn _set_strides(&mut self, strides: hpt_common::strides::strides::Strides) {
        self.base._set_strides(strides);
    }

    fn _set_shape(&mut self, shape: Shape) {
        self.base._set_shape(shape);
    }

    fn _layout(&self) -> &hpt_common::layout::layout::Layout {
        self.base._layout()
    }

    fn _set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>) {
        self.base._set_intervals(intervals);
    }

    fn _set_end_index(&mut self, end_index: usize) {
        self.base._set_end_index(end_index);
    }
}

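// Shape manipulation (reshape/transpose/expand) delegates to the shared `par_*`
// helpers from `shape_manipulate`.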
impl<'a, T: CommonBounds> ShapeManipulator for ParStridedMut<'a, T> {
    fn reshape<S: Into<Shape>>(self, shape: S) -> Self {
        par_reshape(self, shape)
    }

    fn transpose<AXIS: Into<hpt_common::axis::axis::Axis>>(self, axes: AXIS) -> Self {
        par_transpose(self, axes)
    }

    fn expand<S: Into<Shape>>(self, shape: S) -> Self {
        par_expand(self, shape)
    }
}

impl<'a, T: CommonBounds> ParStridedIteratorZip for ParStridedMut<'a, T> {}

impl<'a, T: CommonBounds> ParStridedMut<'a, T> {
    /// Creates a new `ParStridedMut` instance from a given tensor.
    ///
    /// This constructor initializes the `ParStridedMut` iterator by wrapping the provided tensor
    /// in a `ParStrided` instance. It sets up the necessary data structures for parallel iteration.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The tensor implementing the `TensorInfo<T>` trait to iterate over mutably.
    ///
    /// # Returns
    ///
    /// A new instance of `ParStridedMut` initialized with the provided tensor.
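    ///
    /// # Example
    ///
    /// A minimal sketch, not taken from the crate's own docs: `tensor` stands in
    /// for any value implementing `TensorInfo<f32>`.
    ///
    /// ```ignore
    /// use rayon::iter::ParallelIterator;
    ///
    /// // Hypothetical `tensor: impl TensorInfo<f32>`.
    /// let iter = ParStridedMut::new(tensor);
    /// // Consumed like any Rayon parallel iterator.
    /// iter.for_each(|elem| *elem += 1.0);
    /// ```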
    pub fn new<U: TensorInfo<T>>(tensor: U) -> Self {
        ParStridedMut {
            base: ParStrided::new(tensor),
            phantom: std::marker::PhantomData,
        }
    }
}

impl<'a, T> ParallelIterator for ParStridedMut<'a, T>
where
    T: CommonBounds,
{
    type Item = &'a mut T;

    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: UnindexedConsumer<Self::Item>,
    {
        bridge_unindexed(self, consumer)
    }
}

impl<'a, T> UnindexedProducer for ParStridedMut<'a, T>
where
    T: CommonBounds,
{
    type Item = &'a mut T;

    fn split(self) -> (Self, Option<Self>) {
        let (a, b) = self.base.split();
        (
            ParStridedMut {
                base: a,
                phantom: std::marker::PhantomData,
            },
            b.map(|x| ParStridedMut {
                base: x,
                phantom: std::marker::PhantomData,
            }),
        )
    }

    fn fold_with<F>(self, folder: F) -> F
    where
        F: Folder<Self::Item>,
    {
        folder
    }
}

impl<'a, T: 'a> IterGetSet for ParStridedMut<'a, T>
where
    T: CommonBounds,
{
    type Item = &'a mut T;

    fn set_end_index(&mut self, end_index: usize) {
        self.base.set_end_index(end_index);
    }

    fn set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>) {
        self.base.set_intervals(intervals);
    }

    fn set_strides(&mut self, strides: hpt_common::strides::strides::Strides) {
        self.base.set_strides(strides);
    }

    fn set_shape(&mut self, shape: Shape) {
        self.base.set_shape(shape);
    }

    fn set_prg(&mut self, prg: Vec<i64>) {
        self.base.set_prg(prg);
    }

    fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
        self.base.intervals()
    }

    fn strides(&self) -> &hpt_common::strides::strides::Strides {
        self.base.strides()
    }

    fn shape(&self) -> &Shape {
        self.base.shape()
    }

    fn layout(&self) -> &hpt_common::layout::layout::Layout {
        self.base.layout()
    }

    fn broadcast_set_strides(&mut self, shape: &Shape) {
        self.base.broadcast_set_strides(shape);
    }

    fn outer_loop_size(&self) -> usize {
        self.base.outer_loop_size()
    }

    fn inner_loop_size(&self) -> usize {
        self.base.inner_loop_size()
    }

    fn next(&mut self) {
        self.base.next();
    }

    fn inner_loop_next(&mut self, index: usize) -> Self::Item {
        unsafe {
            self.base
                .ptr
                .get_ptr()
                .add(index * (self.base.last_stride as usize))
                .as_mut()
                .unwrap()
        }
    }
}