hpt_iterator/strided_map_mut.rs

1use hpt_common::{shape::shape::Shape, strides::strides::Strides};
2use hpt_traits::tensor::{CommonBounds, TensorInfo};
3use hpt_types::dtype::TypeCommon;
4use std::sync::Arc;
5
6use crate::{
7    iterator_traits::{IterGetSet, StridedIterator},
8    par_strided_mut::ParStridedMut,
9    strided_zip::StridedZip,
10};
11
12/// A module for mutable mapped strided iterator.
13pub mod strided_map_mut_simd {
14    use std::sync::Arc;
15
16    use crate::{CommonBounds, TensorInfo};
17    use hpt_common::{shape::shape::Shape, strides::strides::Strides};
18    use hpt_types::dtype::TypeCommon;
19
20    use crate::{
21        iterator_traits::{IterGetSetSimd, StridedIteratorSimd},
22        par_strided_mut::par_strided_map_mut_simd::ParStridedMutSimd,
23        strided_zip::strided_zip_simd::StridedZipSimd,
24    };
25
26    /// A SIMD-optimized mutable mapped strided iterator over tensor elements.
27    ///
28    /// This struct provides mutable access to tensor elements with SIMD optimizations,
29    /// allowing for efficient parallel processing of tensor data.
30    pub struct StridedMapMutSimd<'a, T>
31    where
32        T: Copy + TypeCommon + Send + Sync,
33    {
34        /// The underlying parallel SIMD-optimized strided iterator.
35        pub(crate) base: ParStridedMutSimd<'a, T>,
36        /// Phantom data to associate the lifetime `'a` with the struct.
37        pub(crate) phantom: std::marker::PhantomData<&'a ()>,
38    }
39    impl<'a, T> StridedMapMutSimd<'a, T>
40    where
41        T: CommonBounds,
42        T::Vec: Send,
43    {
44        /// Creates a new `StridedMapMutSimd` instance from a given tensor.
45        ///
46        /// # Arguments
47        ///
48        /// * `res_tensor` - The tensor implementing the `TensorInfo<T>` trait to iterate over mutably.
49        ///
50        /// # Returns
51        ///
52        /// A new instance of `StridedMapMutSimd` initialized with the provided tensor.
53        pub fn new<U: TensorInfo<T>>(res_tensor: U) -> Self {
54            StridedMapMutSimd {
55                base: ParStridedMutSimd::new(res_tensor),
56                phantom: std::marker::PhantomData,
57            }
58        }
59        /// Combines this `StridedMapMutSimd` iterator with another SIMD-optimized iterator, enabling simultaneous iteration.
60        ///
61        /// This method zips together `self` and `other` into a `StridedZipSimd` iterator, allowing for synchronized
62        /// iteration over both iterators. This is particularly useful for operations that require processing
63        /// elements from two tensors in parallel.
64        ///
65        /// # Arguments
66        ///
67        /// * `other` - The other iterator to zip with. It must implement the `IterGetSetSimd` trait, and
68        ///             its associated `Item` type must be `Send`.
69        ///
70        /// # Returns
71        ///
72        /// A `StridedZipSimd` instance that zips together `self` and `other`, enabling synchronized
73        /// iteration over their elements.
74        pub(crate) fn zip<C>(self, other: C) -> StridedZipSimd<'a, Self, C>
75        where
76            C: 'a + IterGetSetSimd,
77            <C as IterGetSetSimd>::Item: Send,
78        {
79            StridedZipSimd::new(self, other)
80        }
81    }
82    impl<'a, T> StridedIteratorSimd for StridedMapMutSimd<'a, T> where T: 'a + CommonBounds {}
83    impl<'a, T: 'a + CommonBounds> IterGetSetSimd for StridedMapMutSimd<'a, T>
84    where
85        T::Vec: Send,
86    {
87        type Item = &'a mut T;
88        type SimdItem = &'a mut T::Vec;
89
90        fn set_end_index(&mut self, _: usize) {}
91
92        fn set_intervals(&mut self, _: Arc<Vec<(usize, usize)>>) {}
93
94        fn set_strides(&mut self, strides: Strides) {
95            self.base.set_strides(strides);
96        }
97
98        fn set_shape(&mut self, shape: Shape) {
99            self.base.set_shape(shape);
100        }
101
102        fn set_prg(&mut self, prg: Vec<i64>) {
103            self.base.set_prg(prg);
104        }
105
106        fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
107            self.base.intervals()
108        }
109
110        fn strides(&self) -> &Strides {
111            self.base.strides()
112        }
113
114        fn shape(&self) -> &Shape {
115            self.base.shape()
116        }
117
118        fn layout(&self) -> &hpt_common::layout::layout::Layout {
119            self.base.layout()
120        }
121
122        fn broadcast_set_strides(&mut self, shape: &Shape) {
123            self.base.broadcast_set_strides(shape);
124        }
125
126        fn outer_loop_size(&self) -> usize {
127            self.base.outer_loop_size()
128        }
129
130        fn inner_loop_size(&self) -> usize {
131            self.base.inner_loop_size()
132        }
133
134        fn next(&mut self) {
135            self.base.next();
136        }
137
138        fn next_simd(&mut self) {
139            todo!()
140        }
141
142        fn inner_loop_next(&mut self, index: usize) -> Self::Item {
143            self.base.inner_loop_next(index)
144        }
145
146        fn inner_loop_next_simd(&mut self, _: usize) -> Self::SimdItem {
147            todo!()
148        }
149
150        fn all_last_stride_one(&self) -> bool {
151            todo!()
152        }
153
154        fn lanes(&self) -> Option<usize> {
155            todo!()
156        }
157    }
158}
159
160/// A `non` SIMD-optimized mutable mapped strided iterator over tensor elements.
161///
162/// This struct provides mutable access to tensor elements,
163pub struct StridedMapMut<'a, T>
164where
165    T: Copy + TypeCommon,
166{
167    /// The underlying parallel mutable strided iterator.
168    pub(crate) base: ParStridedMut<'a, T>,
169    /// Phantom data to associate the lifetime `'a` with the struct.
170    pub(crate) phantom: std::marker::PhantomData<&'a ()>,
171}
172
173impl<'a, T> StridedMapMut<'a, T>
174where
175    T: CommonBounds,
176    T::Vec: Send,
177{
178    /// Creates a new `StridedMapMut` instance from a given tensor.
179    ///
180    /// # Arguments
181    ///
182    /// * `res_tensor` - The tensor implementing the `TensorInfo<T>` trait to iterate over mutably.
183    ///
184    /// # Returns
185    ///
186    /// A new instance of `StridedMapMut` initialized with the provided tensor.
187    pub fn new<U: TensorInfo<T>>(res_tensor: U) -> Self {
188        StridedMapMut {
189            base: ParStridedMut::new(res_tensor),
190            phantom: std::marker::PhantomData,
191        }
192    }
193
194    /// Combines this `StridedMapMut` iterator with another iterator, enabling simultaneous iteration.
195    ///
196    /// This method zips together `self` and `other` into a `StridedZip` iterator, allowing for synchronized
197    /// iteration over both iterators. This is particularly useful for operations that require processing
198    /// elements from two tensors in parallel, such as element-wise arithmetic operations.
199    ///
200    /// # Arguments
201    ///
202    /// * `other` - The other iterator to zip with. It must implement the `IterGetSet` trait, and
203    ///             its associated `Item` type must be `Send`.
204    ///
205    /// # Returns
206    ///
207    /// A `StridedZip` instance that encapsulates both `self` and `other`, allowing for synchronized
208    /// iteration over their elements.
209    pub fn zip<C>(self, other: C) -> StridedZip<'a, Self, C>
210    where
211        C: 'a + IterGetSet,
212        <C as IterGetSet>::Item: Send,
213    {
214        StridedZip::new(self, other)
215    }
216}
217
218impl<'a, T> StridedIterator for StridedMapMut<'a, T> where T: 'a + CommonBounds {}
219
220impl<'a, T: 'a + CommonBounds> IterGetSet for StridedMapMut<'a, T>
221where
222    T::Vec: Send,
223{
224    type Item = &'a mut T;
225
226    fn set_end_index(&mut self, _: usize) {}
227
228    fn set_intervals(&mut self, _: Arc<Vec<(usize, usize)>>) {}
229
230    fn set_strides(&mut self, strides: Strides) {
231        self.base.set_strides(strides);
232    }
233
234    fn set_shape(&mut self, shape: Shape) {
235        self.base.set_shape(shape);
236    }
237
238    fn set_prg(&mut self, prg: Vec<i64>) {
239        self.base.set_prg(prg);
240    }
241
242    fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
243        self.base.intervals()
244    }
245
246    fn strides(&self) -> &Strides {
247        self.base.strides()
248    }
249
250    fn shape(&self) -> &Shape {
251        self.base.shape()
252    }
253
254    fn layout(&self) -> &hpt_common::layout::layout::Layout {
255        self.base.layout()
256    }
257
258    fn broadcast_set_strides(&mut self, shape: &Shape) {
259        self.base.broadcast_set_strides(shape);
260    }
261
262    fn outer_loop_size(&self) -> usize {
263        self.base.outer_loop_size()
264    }
265
266    fn inner_loop_size(&self) -> usize {
267        self.base.inner_loop_size()
268    }
269
270    fn next(&mut self) {
271        self.base.next();
272    }
273
274    fn inner_loop_next(&mut self, index: usize) -> Self::Item {
275        self.base.inner_loop_next(index)
276    }
277}