hpt_iterator/
strided_mut.rs

1use crate::{
2    iterator_traits::{IterGetSet, StridedIterator, StridedIteratorZip},
3    strided::Strided,
4    strided_zip::StridedZip,
5};
6use hpt_common::{shape::shape::Shape, shape::shape_utils::predict_broadcast_shape};
7use hpt_traits::tensor::{CommonBounds, TensorInfo};
8use std::sync::Arc;
9
10/// Module containing SIMD-optimized implementations for strided mutability.
11pub mod simd_imports {
12    use crate::{
13        iterator_traits::{IterGetSetSimd, StridedIteratorSimd, StridedSimdIteratorZip},
14        strided::strided_simd::StridedSimd,
15    };
16    use crate::{CommonBounds, TensorInfo};
17    use hpt_common::shape::shape::Shape;
18    use hpt_types::dtype::TypeCommon;
19    use hpt_types::vectors::traits::VecTrait;
20    use std::sync::Arc;
21
22    /// A SIMD-optimized mutable strided iterator over tensor elements.
23    ///
24    /// This struct provides mutable access to tensor elements with SIMD optimizations.
25    pub struct StridedMutSimd<'a, T: TypeCommon> {
26        /// The underlying SIMD-optimized strided iterator.
27        pub(crate) base: StridedSimd<T>,
28        /// The stride for the last dimension, used for inner loop element access.
29        pub(crate) last_stride: i64,
30        /// Phantom data to associate the lifetime `'a` with the struct.
31        pub(crate) phantom: std::marker::PhantomData<&'a ()>,
32    }
33
34    impl<'a, T: CommonBounds> StridedMutSimd<'a, T> {
35        /// Creates a new `StridedMutSimd` instance from a tensor.
36        ///
37        /// This constructor initializes a `StridedMutSimd` iterator by creating a base `StridedSimd`
38        /// from the provided tensor. It also retrieves the last stride from the base iterator to
39        /// configure the strided access pattern. The `PhantomData` marker is used to associate
40        /// the iterator with the tensor's data type `T` without holding any actual data.
41        ///
42        /// # Arguments
43        ///
44        /// * `tensor` - An instance that implements the `TensorInfo<T>` trait, representing the tensor
45        ///   to be iterated over. This tensor provides the necessary information about the tensor's shape,
46        ///   strides, and data layout.
47        ///
48        /// # Returns
49        ///
50        /// A new instance of `StridedMutSimd`
51        pub fn new<U: TensorInfo<T>>(tensor: U) -> Self {
52            let base = StridedSimd::new(tensor);
53            let last_stride = base.last_stride;
54            StridedMutSimd {
55                base,
56                last_stride,
57                phantom: std::marker::PhantomData,
58            }
59        }
60    }
61
62    impl<'a, T: 'a> IterGetSetSimd for StridedMutSimd<'a, T>
63    where
64        T: CommonBounds,
65    {
66        type Item = &'a mut T;
67        type SimdItem = &'a mut T::Vec;
68
69        fn set_end_index(&mut self, end_index: usize) {
70            self.base.set_end_index(end_index);
71        }
72
73        fn set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>) {
74            self.base.set_intervals(intervals);
75        }
76
77        fn set_strides(&mut self, strides: hpt_common::strides::strides::Strides) {
78            self.base.set_strides(strides);
79        }
80
81        fn set_shape(&mut self, shape: Shape) {
82            self.base.set_shape(shape);
83        }
84
85        fn set_prg(&mut self, prg: Vec<i64>) {
86            self.base.set_prg(prg);
87        }
88
89        fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
90            self.base.intervals()
91        }
92
93        fn strides(&self) -> &hpt_common::strides::strides::Strides {
94            self.base.strides()
95        }
96
97        fn shape(&self) -> &Shape {
98            self.base.shape()
99        }
100
101        fn layout(&self) -> &hpt_common::layout::layout::Layout {
102            self.base.layout()
103        }
104
105        fn broadcast_set_strides(&mut self, shape: &Shape) {
106            self.base.broadcast_set_strides(shape);
107        }
108
109        fn outer_loop_size(&self) -> usize {
110            self.base.outer_loop_size()
111        }
112        fn inner_loop_size(&self) -> usize {
113            self.base.inner_loop_size()
114        }
115
116        fn next(&mut self) {
117            self.base.next();
118        }
119        fn next_simd(&mut self) {
120            todo!()
121        }
122        #[inline(always)]
123        fn inner_loop_next(&mut self, index: usize) -> Self::Item {
124            unsafe {
125                &mut *self
126                    .base
127                    .ptr
128                    .ptr
129                    .offset((index as isize) * (self.last_stride as isize))
130            }
131        }
132        fn inner_loop_next_simd(&mut self, index: usize) -> Self::SimdItem {
133            let vector = unsafe { self.base.ptr.ptr.add(index * T::Vec::SIZE) };
134            unsafe { std::mem::transmute(vector) }
135        }
136        fn all_last_stride_one(&self) -> bool {
137            self.base.all_last_stride_one()
138        }
139
140        fn lanes(&self) -> Option<usize> {
141            self.base.lanes()
142        }
143    }
144    impl<'a, T> StridedIteratorSimd for StridedMutSimd<'a, T> where T: CommonBounds {}
145    impl<'a, T> StridedSimdIteratorZip for StridedMutSimd<'a, T> where T: CommonBounds {}
146}
147
148/// A mutable strided iterator over tensor elements.
149///
150/// This struct provides mutable access to tensor elements with strided access patterns in `single thread`.
151pub struct StridedMut<'a, T> {
152    /// The underlying `single thread` strided iterator handling the iteration logic.
153    pub(crate) base: Strided<T>,
154    /// Phantom data to associate the lifetime `'a` with the struct.
155    pub(crate) phantom: std::marker::PhantomData<&'a ()>,
156}
157
158impl<'a, T: CommonBounds> StridedMut<'a, T> {
159    /// Creates a new `StridedMut` instance from a given tensor.
160    ///
161    /// # Arguments
162    ///
163    /// * `tensor` - The tensor implementing the `TensorInfo<T>` trait to iterate over.
164    ///
165    /// # Returns
166    ///
167    /// A new instance of `StridedMut` initialized with the provided tensor.
168    pub fn new<U: TensorInfo<T>>(tensor: U) -> Self {
169        StridedMut {
170            base: Strided::new(tensor),
171            phantom: std::marker::PhantomData,
172        }
173    }
174    /// Combines this `StridedMut` iterator with another iterator, enabling simultaneous iteration.
175    ///
176    /// This method performs shape broadcasting between `self` and `other` to ensure that both iterators
177    /// iterate over tensors with compatible shapes. It adjusts the strides and shapes of both iterators
178    /// to match the broadcasted shape and then returns a `StridedZip` that allows for synchronized
179    /// iteration over both iterators.
180    ///
181    /// # Arguments
182    ///
183    /// * `other` - The other iterator to zip with. It must implement the `IterGetSet` trait, and
184    ///             its associated `Item` type must be `Send`.
185    ///
186    /// # Returns
187    ///
188    /// A `StridedZip` instance that zips together `self` and `other`, enabling synchronized
189    /// iteration over their elements.
190    ///
191    /// # Panics
192    ///
193    /// This method will panic if the shapes of `self` and `other` cannot be broadcasted together.
194    /// Ensure that the shapes are compatible before calling this method.
195    #[track_caller]
196    pub fn zip<C>(mut self, mut other: C) -> StridedZip<'a, Self, C>
197    where
198        C: 'a + IterGetSet,
199        <C as IterGetSet>::Item: Send,
200    {
201        let new_shape = match predict_broadcast_shape(self.shape(), other.shape()) {
202            Ok(s) => s,
203            Err(err) => {
204                panic!("{}", err);
205            }
206        };
207
208        other.broadcast_set_strides(&new_shape);
209        self.broadcast_set_strides(&new_shape);
210
211        other.set_shape(new_shape.clone());
212        self.set_shape(new_shape.clone());
213
214        StridedZip::new(self, other)
215    }
216}
217
218impl<'a, T: CommonBounds> StridedIterator for StridedMut<'a, T> {}
219impl<'a, T: CommonBounds> StridedIteratorZip for StridedMut<'a, T> {}
220
221impl<'a, T: 'a> IterGetSet for StridedMut<'a, T>
222where
223    T: CommonBounds,
224{
225    type Item = &'a mut T;
226
227    fn set_end_index(&mut self, end_index: usize) {
228        self.base.set_end_index(end_index);
229    }
230
231    fn set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>) {
232        self.base.set_intervals(intervals);
233    }
234
235    fn set_strides(&mut self, strides: hpt_common::strides::strides::Strides) {
236        self.base.set_strides(strides);
237    }
238
239    fn set_shape(&mut self, shape: Shape) {
240        self.base.set_shape(shape);
241    }
242
243    fn set_prg(&mut self, prg: Vec<i64>) {
244        self.base.set_prg(prg);
245    }
246
247    fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
248        self.base.intervals()
249    }
250
251    fn strides(&self) -> &hpt_common::strides::strides::Strides {
252        self.base.strides()
253    }
254
255    fn shape(&self) -> &Shape {
256        self.base.shape()
257    }
258
259    fn layout(&self) -> &hpt_common::layout::layout::Layout {
260        self.base.layout()
261    }
262
263    fn broadcast_set_strides(&mut self, shape: &Shape) {
264        self.base.broadcast_set_strides(shape);
265    }
266
267    fn outer_loop_size(&self) -> usize {
268        self.base.outer_loop_size()
269    }
270
271    fn inner_loop_size(&self) -> usize {
272        self.base.inner_loop_size()
273    }
274
275    fn next(&mut self) {
276        self.base.next();
277    }
278
279    fn inner_loop_next(&mut self, index: usize) -> Self::Item {
280        unsafe {
281            self.base
282                .ptr
283                .get_ptr()
284                .add(index * (self.base.last_stride as usize))
285                .as_mut()
286                .unwrap()
287        }
288    }
289}