hpt_iterator/
strided_zip.rs

1use hpt_common::{shape::shape::Shape, strides::strides::Strides};
2use std::sync::Arc;
3
4use crate::iterator_traits::{IterGetSet, StridedIterator, StridedIteratorMap, StridedIteratorZip};
5
6/// A module for zipped strided simd iterator.
7pub mod strided_zip_simd {
8    use hpt_common::{shape::shape::Shape, strides::strides::Strides};
9
10    use crate::iterator_traits::{IterGetSetSimd, StridedIteratorSimd, StridedSimdIteratorZip};
11    use std::sync::Arc;
12
13    /// A single thread SIMD-optimized zipped iterator combining two iterators over tensor elements.
14    ///
15    /// # Example
16    /// use hpt::tensor::Tensor;
17    /// use hpt::StridedIteratorSimd;
18    /// use hpt::TensorIterator;
19    /// let a = Tensor::<f64>::new([0.0, 1.0, 2.0, 3.0]);
20    /// a.iter_simd().zip(a.iter_simd()).for_each(
21    ///     |(x, y)| {
22    ///         println!("{} {}", x, y);
23    ///     },
24    ///     |(x, y)| {
25    ///         println!("{:?} {:?}", x, y);
26    ///     },
27    /// );
28    /// ```
29    #[derive(Clone)]
30    pub struct StridedZipSimd<'a, A: 'a, B: 'a> {
31        /// The first iterator to be zipped.
32        pub(crate) a: A,
33        /// The second iterator to be zipped.
34        pub(crate) b: B,
35        /// Phantom data to associate the lifetime `'a` with the struct.
36        pub(crate) phantom: std::marker::PhantomData<&'a ()>,
37    }
38
39    impl<'a, A, B> IterGetSetSimd for StridedZipSimd<'a, A, B>
40    where
41        A: IterGetSetSimd,
42        B: IterGetSetSimd,
43    {
44        type Item = (<A as IterGetSetSimd>::Item, <B as IterGetSetSimd>::Item);
45
46        type SimdItem = (
47            <A as IterGetSetSimd>::SimdItem,
48            <B as IterGetSetSimd>::SimdItem,
49        );
50
51        fn set_end_index(&mut self, _: usize) {
52            panic!("single thread strided zip does not support set_intervals");
53        }
54
55        fn set_intervals(&mut self, _: Arc<Vec<(usize, usize)>>) {
56            panic!("single thread strided zip does not support set_intervals");
57        }
58
59        fn set_strides(&mut self, last_stride: Strides) {
60            self.a.set_strides(last_stride.clone());
61            self.b.set_strides(last_stride);
62        }
63
64        fn set_shape(&mut self, shape: Shape) {
65            self.a.set_shape(shape.clone());
66            self.b.set_shape(shape);
67        }
68
69        fn set_prg(&mut self, prg: Vec<i64>) {
70            self.a.set_prg(prg.clone());
71            self.b.set_prg(prg);
72        }
73
74        fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
75            panic!("single thread strided zip does not support intervals");
76        }
77
78        fn strides(&self) -> &Strides {
79            self.a.strides()
80        }
81
82        fn shape(&self) -> &Shape {
83            self.a.shape()
84        }
85
86        fn layout(&self) -> &hpt_common::layout::layout::Layout {
87            self.a.layout()
88        }
89
90        fn broadcast_set_strides(&mut self, shape: &Shape) {
91            self.a.broadcast_set_strides(shape);
92            self.b.broadcast_set_strides(shape);
93        }
94
95        fn outer_loop_size(&self) -> usize {
96            self.a.outer_loop_size()
97        }
98
99        fn inner_loop_size(&self) -> usize {
100            self.a.inner_loop_size()
101        }
102
103        fn next(&mut self) {
104            self.a.next();
105            self.b.next();
106        }
107        fn next_simd(&mut self) {
108            todo!()
109        }
110        #[inline(always)]
111        fn inner_loop_next(&mut self, index: usize) -> Self::Item {
112            (self.a.inner_loop_next(index), self.b.inner_loop_next(index))
113        }
114        fn inner_loop_next_simd(&mut self, index: usize) -> Self::SimdItem {
115            (
116                self.a.inner_loop_next_simd(index),
117                self.b.inner_loop_next_simd(index),
118            )
119        }
120        fn all_last_stride_one(&self) -> bool {
121            self.a.all_last_stride_one() && self.b.all_last_stride_one()
122        }
123
124        fn lanes(&self) -> Option<usize> {
125            match (self.a.lanes(), self.b.lanes()) {
126                (Some(a), Some(b)) => {
127                    if a == b {
128                        Some(a)
129                    } else {
130                        None
131                    }
132                }
133                _ => None,
134            }
135        }
136    }
137
138    impl<'a, A, B> StridedZipSimd<'a, A, B>
139    where
140        A: 'a + IterGetSetSimd,
141        B: 'a + IterGetSetSimd,
142        <A as IterGetSetSimd>::Item: Send,
143        <B as IterGetSetSimd>::Item: Send,
144    {
145        /// Creates a new `StridedZipSimd` instance by zipping two SIMD-optimized iterators.
146        ///
147        /// # Arguments
148        ///
149        /// * `a` - The first iterator to zip.
150        /// * `b` - The second iterator to zip.
151        ///
152        /// # Returns
153        ///
154        /// A new `StridedZipSimd` instance that combines both iterators for synchronized iteration.
155        pub fn new(a: A, b: B) -> Self {
156            StridedZipSimd {
157                a,
158                b,
159                phantom: std::marker::PhantomData,
160            }
161        }
162    }
163
164    impl<'a, A, B> StridedIteratorSimd for StridedZipSimd<'a, A, B>
165    where
166        A: IterGetSetSimd,
167        B: IterGetSetSimd,
168    {
169    }
170    impl<'a, A, B> StridedSimdIteratorZip for StridedZipSimd<'a, A, B>
171    where
172        A: IterGetSetSimd,
173        B: IterGetSetSimd,
174    {
175    }
176}
177
/// Single-thread, scalar (non-SIMD) iterator that drives two strided iterators together
/// and yields their items as pairs.
#[derive(Clone)]
pub struct StridedZip<'a, A, B>
where
    A: 'a,
    B: 'a,
{
    /// Left-hand iterator; supplies the first element of each pair.
    pub(crate) a: A,
    /// Right-hand iterator; supplies the second element of each pair.
    pub(crate) b: B,
    /// Zero-sized marker binding the struct to the lifetime `'a`.
    pub(crate) phantom: std::marker::PhantomData<&'a ()>,
}
188
189impl<'a, A, B> IterGetSet for StridedZip<'a, A, B>
190where
191    A: IterGetSet,
192    B: IterGetSet,
193{
194    type Item = (<A as IterGetSet>::Item, <B as IterGetSet>::Item);
195
196    fn set_end_index(&mut self, _: usize) {
197        panic!("single thread strided zip does not support set_intervals");
198    }
199
200    fn set_intervals(&mut self, _: Arc<Vec<(usize, usize)>>) {
201        panic!("single thread strided zip does not support set_intervals");
202    }
203
204    fn set_strides(&mut self, last_stride: Strides) {
205        self.a.set_strides(last_stride.clone());
206        self.b.set_strides(last_stride);
207    }
208
209    fn set_shape(&mut self, shape: Shape) {
210        self.a.set_shape(shape.clone());
211        self.b.set_shape(shape);
212    }
213
214    fn set_prg(&mut self, prg: Vec<i64>) {
215        self.a.set_prg(prg.clone());
216        self.b.set_prg(prg);
217    }
218
219    fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
220        panic!("single thread strided zip does not support intervals");
221    }
222
223    fn strides(&self) -> &Strides {
224        self.a.strides()
225    }
226
227    fn shape(&self) -> &Shape {
228        self.a.shape()
229    }
230
231    fn layout(&self) -> &hpt_common::layout::layout::Layout {
232        self.a.layout()
233    }
234
235    fn broadcast_set_strides(&mut self, shape: &Shape) {
236        self.a.broadcast_set_strides(shape);
237        self.b.broadcast_set_strides(shape);
238    }
239
240    fn outer_loop_size(&self) -> usize {
241        self.a.outer_loop_size()
242    }
243
244    fn inner_loop_size(&self) -> usize {
245        self.a.inner_loop_size()
246    }
247
248    fn next(&mut self) {
249        self.a.next();
250        self.b.next();
251    }
252
253    fn inner_loop_next(&mut self, index: usize) -> Self::Item {
254        (self.a.inner_loop_next(index), self.b.inner_loop_next(index))
255    }
256}
257
258impl<'a, A, B> StridedZip<'a, A, B>
259where
260    A: 'a + IterGetSet,
261    B: 'a + IterGetSet,
262    <A as IterGetSet>::Item: Send,
263    <B as IterGetSet>::Item: Send,
264{
265    /// Creates a new `StridedZip` instance by zipping two iterators.
266    ///
267    /// # Arguments
268    ///
269    /// * `a` - The first iterator to zip.
270    /// * `b` - The second iterator to zip.
271    ///
272    /// # Returns
273    ///
274    /// A new `StridedZip` instance that combines both iterators for synchronized iteration.
275    pub fn new(a: A, b: B) -> Self {
276        StridedZip {
277            a,
278            b,
279            phantom: std::marker::PhantomData,
280        }
281    }
282}
283
284impl<'a, A, B> StridedIteratorZip for StridedZip<'a, A, B> {}
285impl<'a, A, B> StridedIteratorMap for StridedZip<'a, A, B> {}
286impl<'a, A, B> StridedIterator for StridedZip<'a, A, B>
287where
288    A: IterGetSet,
289    B: IterGetSet,
290{
291}