1use hpt_common::{shape::shape::Shape, strides::strides::Strides};
2use std::sync::Arc;
3
4use crate::iterator_traits::{IterGetSet, StridedIterator, StridedIteratorMap, StridedIteratorZip};
5
6pub mod strided_zip_simd {
8 use hpt_common::{shape::shape::Shape, strides::strides::Strides};
9
10 use crate::iterator_traits::{IterGetSetSimd, StridedIteratorSimd, StridedSimdIteratorZip};
11 use std::sync::Arc;
12
/// A pair of iterators, `a` and `b`, bundled so they can be driven together.
///
/// The struct itself places no trait requirements on `A` or `B` beyond outliving `'a`.
/// `Clone` is derived, so the pair can be cloned whenever both halves can.
#[derive(Clone)]
pub struct StridedZipSimd<'a, A, B>
where
    A: 'a,
    B: 'a,
{
    /// First iterator of the pair.
    pub(crate) a: A,
    /// Second iterator of the pair.
    pub(crate) b: B,
    /// Zero-sized marker that ties the pair to the lifetime `'a`.
    pub(crate) phantom: std::marker::PhantomData<&'a ()>,
}
38
39 impl<'a, A, B> IterGetSetSimd for StridedZipSimd<'a, A, B>
40 where
41 A: IterGetSetSimd,
42 B: IterGetSetSimd,
43 {
44 type Item = (<A as IterGetSetSimd>::Item, <B as IterGetSetSimd>::Item);
45
46 type SimdItem = (
47 <A as IterGetSetSimd>::SimdItem,
48 <B as IterGetSetSimd>::SimdItem,
49 );
50
51 fn set_end_index(&mut self, _: usize) {
52 panic!("single thread strided zip does not support set_intervals");
53 }
54
55 fn set_intervals(&mut self, _: Arc<Vec<(usize, usize)>>) {
56 panic!("single thread strided zip does not support set_intervals");
57 }
58
59 fn set_strides(&mut self, last_stride: Strides) {
60 self.a.set_strides(last_stride.clone());
61 self.b.set_strides(last_stride);
62 }
63
64 fn set_shape(&mut self, shape: Shape) {
65 self.a.set_shape(shape.clone());
66 self.b.set_shape(shape);
67 }
68
69 fn set_prg(&mut self, prg: Vec<i64>) {
70 self.a.set_prg(prg.clone());
71 self.b.set_prg(prg);
72 }
73
74 fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
75 panic!("single thread strided zip does not support intervals");
76 }
77
78 fn strides(&self) -> &Strides {
79 self.a.strides()
80 }
81
82 fn shape(&self) -> &Shape {
83 self.a.shape()
84 }
85
86 fn layout(&self) -> &hpt_common::layout::layout::Layout {
87 self.a.layout()
88 }
89
90 fn broadcast_set_strides(&mut self, shape: &Shape) {
91 self.a.broadcast_set_strides(shape);
92 self.b.broadcast_set_strides(shape);
93 }
94
95 fn outer_loop_size(&self) -> usize {
96 self.a.outer_loop_size()
97 }
98
99 fn inner_loop_size(&self) -> usize {
100 self.a.inner_loop_size()
101 }
102
103 fn next(&mut self) {
104 self.a.next();
105 self.b.next();
106 }
107 fn next_simd(&mut self) {
108 todo!()
109 }
110 #[inline(always)]
111 fn inner_loop_next(&mut self, index: usize) -> Self::Item {
112 (self.a.inner_loop_next(index), self.b.inner_loop_next(index))
113 }
114 fn inner_loop_next_simd(&mut self, index: usize) -> Self::SimdItem {
115 (
116 self.a.inner_loop_next_simd(index),
117 self.b.inner_loop_next_simd(index),
118 )
119 }
120 fn all_last_stride_one(&self) -> bool {
121 self.a.all_last_stride_one() && self.b.all_last_stride_one()
122 }
123
124 fn lanes(&self) -> Option<usize> {
125 match (self.a.lanes(), self.b.lanes()) {
126 (Some(a), Some(b)) => {
127 if a == b {
128 Some(a)
129 } else {
130 None
131 }
132 }
133 _ => None,
134 }
135 }
136 }
137
138 impl<'a, A, B> StridedZipSimd<'a, A, B>
139 where
140 A: 'a + IterGetSetSimd,
141 B: 'a + IterGetSetSimd,
142 <A as IterGetSetSimd>::Item: Send,
143 <B as IterGetSetSimd>::Item: Send,
144 {
145 pub fn new(a: A, b: B) -> Self {
156 StridedZipSimd {
157 a,
158 b,
159 phantom: std::marker::PhantomData,
160 }
161 }
162 }
163
164 impl<'a, A, B> StridedIteratorSimd for StridedZipSimd<'a, A, B>
165 where
166 A: IterGetSetSimd,
167 B: IterGetSetSimd,
168 {
169 }
170 impl<'a, A, B> StridedSimdIteratorZip for StridedZipSimd<'a, A, B>
171 where
172 A: IterGetSetSimd,
173 B: IterGetSetSimd,
174 {
175 }
176}
177
/// A pair of iterators, `a` and `b`, bundled so they can be driven together.
///
/// The struct itself places no trait requirements on `A` or `B` beyond outliving `'a`.
/// `Clone` is derived, so the pair can be cloned whenever both halves can.
#[derive(Clone)]
pub struct StridedZip<'a, A, B>
where
    A: 'a,
    B: 'a,
{
    /// First iterator of the pair.
    pub(crate) a: A,
    /// Second iterator of the pair.
    pub(crate) b: B,
    /// Zero-sized marker that ties the pair to the lifetime `'a`.
    pub(crate) phantom: std::marker::PhantomData<&'a ()>,
}
188
189impl<'a, A, B> IterGetSet for StridedZip<'a, A, B>
190where
191 A: IterGetSet,
192 B: IterGetSet,
193{
194 type Item = (<A as IterGetSet>::Item, <B as IterGetSet>::Item);
195
196 fn set_end_index(&mut self, _: usize) {
197 panic!("single thread strided zip does not support set_intervals");
198 }
199
200 fn set_intervals(&mut self, _: Arc<Vec<(usize, usize)>>) {
201 panic!("single thread strided zip does not support set_intervals");
202 }
203
204 fn set_strides(&mut self, last_stride: Strides) {
205 self.a.set_strides(last_stride.clone());
206 self.b.set_strides(last_stride);
207 }
208
209 fn set_shape(&mut self, shape: Shape) {
210 self.a.set_shape(shape.clone());
211 self.b.set_shape(shape);
212 }
213
214 fn set_prg(&mut self, prg: Vec<i64>) {
215 self.a.set_prg(prg.clone());
216 self.b.set_prg(prg);
217 }
218
219 fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
220 panic!("single thread strided zip does not support intervals");
221 }
222
223 fn strides(&self) -> &Strides {
224 self.a.strides()
225 }
226
227 fn shape(&self) -> &Shape {
228 self.a.shape()
229 }
230
231 fn layout(&self) -> &hpt_common::layout::layout::Layout {
232 self.a.layout()
233 }
234
235 fn broadcast_set_strides(&mut self, shape: &Shape) {
236 self.a.broadcast_set_strides(shape);
237 self.b.broadcast_set_strides(shape);
238 }
239
240 fn outer_loop_size(&self) -> usize {
241 self.a.outer_loop_size()
242 }
243
244 fn inner_loop_size(&self) -> usize {
245 self.a.inner_loop_size()
246 }
247
248 fn next(&mut self) {
249 self.a.next();
250 self.b.next();
251 }
252
253 fn inner_loop_next(&mut self, index: usize) -> Self::Item {
254 (self.a.inner_loop_next(index), self.b.inner_loop_next(index))
255 }
256}
257
258impl<'a, A, B> StridedZip<'a, A, B>
259where
260 A: 'a + IterGetSet,
261 B: 'a + IterGetSet,
262 <A as IterGetSet>::Item: Send,
263 <B as IterGetSet>::Item: Send,
264{
265 pub fn new(a: A, b: B) -> Self {
276 StridedZip {
277 a,
278 b,
279 phantom: std::marker::PhantomData,
280 }
281 }
282}
283
284impl<'a, A, B> StridedIteratorZip for StridedZip<'a, A, B> {}
285impl<'a, A, B> StridedIteratorMap for StridedZip<'a, A, B> {}
286impl<'a, A, B> StridedIterator for StridedZip<'a, A, B>
287where
288 A: IterGetSet,
289 B: IterGetSet,
290{
291}