hpt_iterator/
strided.rs

1use crate::{
2    iterator_traits::{
3        IterGetSet, ShapeManipulator, StridedHelper, StridedIterator, StridedIteratorMap,
4        StridedIteratorZip,
5    },
6    shape_manipulate::{expand, reshape, transpose},
7};
8use hpt_common::{
9    axis::axis::Axis, layout::layout::Layout, shape::shape::Shape,
10    shape::shape_utils::try_pad_shape, strides::strides::Strides,
11    strides::strides_utils::preprocess_strides, utils::pointer::Pointer,
12};
13use hpt_traits::tensor::{CommonBounds, TensorInfo};
14use std::sync::Arc;
15
16/// A module for single-threaded strided simd iterator.
17pub mod strided_simd {
18    use hpt_common::{
19        axis::axis::Axis, layout::layout::Layout, shape::shape::Shape,
20        shape::shape_utils::try_pad_shape, strides::strides::Strides,
21        strides::strides_utils::preprocess_strides, utils::pointer::Pointer,
22    };
23    use hpt_traits::{CommonBounds, TensorInfo};
24    use hpt_types::dtype::TypeCommon;
25    use hpt_types::vectors::traits::VecTrait;
26    use std::sync::Arc;
27
28    use crate::iterator_traits::{
29        IterGetSetSimd, ShapeManipulator, StridedIteratorMap, StridedIteratorSimd,
30        StridedSimdIteratorZip,
31    };
32
33    use super::{expand, reshape, transpose, StridedHelper};
34
35    /// A single thread SIMD-optimized strided iterator
36    #[derive(Clone)]
37    pub struct StridedSimd<T: TypeCommon> {
38        /// A pointer to the tensor's data.
39        pub(crate) ptr: Pointer<T>,
40        /// The layout of the tensor, including shape and strides.
41        pub(crate) layout: Layout,
42        /// The loop progress of the iterator.
43        pub(crate) prg: Vec<i64>,
44        /// The stride for the last dimension.
45        pub(crate) last_stride: i64,
46    }
47
48    impl<T: CommonBounds> StridedSimd<T> {
49        /// Retrieves the shape of the tensor.
50        ///
51        /// # Returns
52        ///
53        /// A reference to the `Shape` struct representing the tensor's dimensions.
54        pub fn shape(&self) -> &Shape {
55            self.layout.shape()
56        }
57        /// Retrieves the strides of the tensor.
58        ///
59        /// # Returns
60        ///
61        /// A reference to the `Strides` struct representing the tensor's stride information.
62        pub fn strides(&self) -> &Strides {
63            self.layout.strides()
64        }
65        /// Creates a new `StridedSimd` instance from a given tensor.
66        ///
67        /// # Arguments
68        ///
69        /// * `tensor` - The tensor implementing the `TensorInfo<T>` trait to iterate over mutably.
70        ///
71        /// # Returns
72        ///
73        /// A new instance of `StridedSimd` initialized with the provided tensor.
74        pub fn new<U: TensorInfo<T>>(tensor: U) -> Self {
75            StridedSimd {
76                ptr: tensor.ptr(),
77                layout: tensor.layout().clone(),
78                prg: vec![],
79                last_stride: *tensor.strides().last().unwrap_or(&0),
80            }
81        }
82    }
83
84    impl<T: CommonBounds> IterGetSetSimd for StridedSimd<T> {
85        type Item = T;
86        type SimdItem = T::Vec;
87
88        fn set_end_index(&mut self, _: usize) {
89            panic!("single thread iterator does not support set_end_index");
90        }
91
92        fn set_intervals(&mut self, _: Arc<Vec<(usize, usize)>>) {
93            panic!("single thread iterator does not support set_intervals");
94        }
95
96        fn set_strides(&mut self, strides: Strides) {
97            self.layout.set_strides(strides);
98        }
99
100        fn set_shape(&mut self, shape: Shape) {
101            self.layout.set_shape(shape);
102        }
103
104        fn set_prg(&mut self, prg: Vec<i64>) {
105            self.prg = prg;
106        }
107
108        fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
109            panic!("single thread iterator does not support intervals");
110        }
111
112        fn strides(&self) -> &Strides {
113            self.layout.strides()
114        }
115
116        fn shape(&self) -> &Shape {
117            self.layout.shape()
118        }
119
120        fn layout(&self) -> &Layout {
121            &self.layout
122        }
123
124        fn broadcast_set_strides(&mut self, shape: &Shape) {
125            let self_shape = try_pad_shape(self.shape(), shape.len());
126            self.set_strides(preprocess_strides(&self_shape, self.strides()).into());
127            self.last_stride = self.strides()[self.strides().len() - 1];
128        }
129
130        fn outer_loop_size(&self) -> usize {
131            (self.shape().iter().product::<i64>() as usize) / self.inner_loop_size()
132        }
133        fn inner_loop_size(&self) -> usize {
134            self.shape().last().unwrap().clone() as usize
135        }
136
137        fn next(&mut self) {
138            for j in (0..(self.shape().len() as i64) - 1).rev() {
139                let j = j as usize;
140                if self.prg[j] < self.shape()[j] - 1 {
141                    self.prg[j] += 1;
142                    self.ptr.offset(self.strides()[j]);
143                    break;
144                } else {
145                    self.prg[j] = 0;
146                    self.ptr.offset(-self.strides()[j] * (self.shape()[j] - 1));
147                }
148            }
149        }
150        fn next_simd(&mut self) {
151            todo!()
152        }
153        #[inline(always)]
154        fn inner_loop_next(&mut self, index: usize) -> Self::Item {
155            unsafe {
156                *self
157                    .ptr
158                    .ptr
159                    .offset((index as isize) * (self.last_stride as isize))
160            }
161        }
162        fn inner_loop_next_simd(&mut self, index: usize) -> Self::SimdItem {
163            unsafe { Self::SimdItem::from_ptr(self.ptr.get_ptr().add(index * T::Vec::SIZE)) }
164        }
165        fn all_last_stride_one(&self) -> bool {
166            self.last_stride == 1
167        }
168
169        fn lanes(&self) -> Option<usize> {
170            Some(T::Vec::SIZE)
171        }
172    }
173
174    impl<T: CommonBounds> ShapeManipulator for StridedSimd<T> {
175        #[track_caller]
176        fn reshape<S: Into<Shape>>(self, shape: S) -> Self {
177            reshape(self, shape)
178        }
179
180        fn transpose<AXIS: Into<Axis>>(self, axes: AXIS) -> Self {
181            transpose(self, axes)
182        }
183
184        fn expand<S: Into<Shape>>(self, shape: S) -> Self {
185            expand(self, shape)
186        }
187    }
188
189    impl<T: CommonBounds> StridedHelper for StridedSimd<T> {
190        fn _set_last_strides(&mut self, stride: i64) {
191            self.last_stride = stride;
192        }
193        fn _set_strides(&mut self, strides: Strides) {
194            self.layout.set_strides(strides);
195        }
196        fn _set_shape(&mut self, shape: Shape) {
197            self.layout.set_shape(shape);
198        }
199        fn _layout(&self) -> &Layout {
200            &self.layout
201        }
202    }
203    impl<T: CommonBounds> StridedIteratorMap for StridedSimd<T> {}
204    impl<T: CommonBounds> StridedSimdIteratorZip for StridedSimd<T> {}
205    impl<T> StridedIteratorSimd for StridedSimd<T> where T: CommonBounds {}
206}
207
208/// A single-threaded strided iterator over tensor elements.
209#[derive(Clone)]
210pub struct Strided<T> {
211    /// A pointer points to the tensor's data.
212    pub(crate) ptr: Pointer<T>,
213    /// The layout of the tensor, including shape and strides.
214    pub(crate) layout: Layout,
215    /// The loop progress of the iterator.
216    pub(crate) prg: Vec<i64>,
217    /// The stride for the last dimension.
218    pub(crate) last_stride: i64,
219}
220
221impl<T: CommonBounds> Strided<T> {
222    /// Retrieves the shape of the tensor.
223    ///
224    /// # Returns
225    ///
226    /// A reference to the `Shape` struct representing the tensor's dimensions.
227    pub fn shape(&self) -> &Shape {
228        self.layout.shape()
229    }
230    /// Retrieves the strides of the tensor.
231    ///
232    /// # Returns
233    ///
234    /// A reference to the `Strides` struct representing the tensor's stride information.
235    pub fn strides(&self) -> &Strides {
236        self.layout.strides()
237    }
238    /// Creates a new `Strided` instance from a given tensor.
239    ///
240    /// # Arguments
241    ///
242    /// * `tensor` - The tensor implementing the `TensorInfo<T>` trait to iterate over mutably.
243    ///
244    /// # Returns
245    ///
246    /// A new instance of `Strided` initialized with the provided tensor.
247    pub fn new<U: TensorInfo<T>>(tensor: U) -> Self {
248        Strided {
249            ptr: tensor.ptr(),
250            layout: tensor.layout().clone(),
251            prg: vec![],
252            last_stride: *tensor.strides().last().unwrap_or(&0),
253        }
254    }
255}
256
257impl<T> StridedIteratorMap for Strided<T> {}
258impl<T> StridedIteratorZip for Strided<T> {}
259
260impl<T: CommonBounds> IterGetSet for Strided<T> {
261    type Item = T;
262
263    fn set_end_index(&mut self, _: usize) {
264        panic!("single thread iterator does not support set_end_index");
265    }
266
267    fn set_intervals(&mut self, _: Arc<Vec<(usize, usize)>>) {
268        panic!("single thread iterator does not support set_intervals");
269    }
270
271    fn set_strides(&mut self, strides: Strides) {
272        self.layout.set_strides(strides);
273    }
274
275    fn set_shape(&mut self, shape: Shape) {
276        self.layout.set_shape(shape);
277    }
278
279    fn set_prg(&mut self, prg: Vec<i64>) {
280        self.prg = prg;
281    }
282
283    fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
284        panic!("single thread iterator does not support intervals");
285    }
286    fn strides(&self) -> &Strides {
287        self.layout.strides()
288    }
289
290    fn shape(&self) -> &Shape {
291        self.layout.shape()
292    }
293
294    fn layout(&self) -> &Layout {
295        &self.layout
296    }
297
298    fn broadcast_set_strides(&mut self, shape: &Shape) {
299        let self_shape = try_pad_shape(self.shape(), shape.len());
300        self.set_strides(preprocess_strides(&self_shape, self.strides()).into());
301        self.last_stride = self.strides()[self.strides().len() - 1];
302    }
303
304    fn outer_loop_size(&self) -> usize {
305        (self.shape().iter().product::<i64>() as usize) / self.inner_loop_size()
306    }
307
308    fn inner_loop_size(&self) -> usize {
309        self.shape().last().unwrap().clone() as usize
310    }
311
312    fn next(&mut self) {
313        for j in (0..(self.shape().len() as i64) - 1).rev() {
314            let j = j as usize;
315            if self.prg[j] < self.shape()[j] - 1 {
316                self.prg[j] += 1;
317                self.ptr.offset(self.strides()[j]);
318                break;
319            } else {
320                self.prg[j] = 0;
321                self.ptr.offset(-self.strides()[j] * (self.shape()[j] - 1));
322            }
323        }
324    }
325
326    fn inner_loop_next(&mut self, index: usize) -> Self::Item {
327        unsafe { *self.ptr.get_ptr().add(index * (self.last_stride as usize)) }
328    }
329}
330
331impl<T: CommonBounds> ShapeManipulator for Strided<T> {
332    #[track_caller]
333    fn reshape<S: Into<Shape>>(self, shape: S) -> Self {
334        reshape(self, shape)
335    }
336
337    fn transpose<AXIS: Into<Axis>>(self, axes: AXIS) -> Self {
338        transpose(self, axes)
339    }
340
341    fn expand<S: Into<Shape>>(self, shape: S) -> Self {
342        expand(self, shape)
343    }
344}
345
346impl<T: CommonBounds> StridedIterator for Strided<T> {}
347
348impl<T> StridedHelper for Strided<T> {
349    fn _set_last_strides(&mut self, stride: i64) {
350        self.last_stride = stride;
351    }
352    fn _set_strides(&mut self, strides: Strides) {
353        self.layout.set_strides(strides);
354    }
355    fn _set_shape(&mut self, shape: Shape) {
356        self.layout.set_shape(shape);
357    }
358    fn _layout(&self) -> &Layout {
359        &self.layout
360    }
361}