1use hpt_common::{shape::shape::Shape, strides::strides::Strides};
2use hpt_traits::tensor::CommonBounds;
3use par_strided_zip_simd::ParStridedZipSimd;
4use rayon::iter::{
5 plumbing::{bridge_unindexed, Folder, UnindexedConsumer, UnindexedProducer},
6 ParallelIterator,
7};
8use std::sync::Arc;
9
10use crate::{
11 iterator_traits::{IterGetSet, IterGetSetSimd, ParStridedIteratorZip, ShapeManipulator},
12 par_strided_map::ParStridedMap,
13};
14
15pub mod par_strided_zip_simd {
17 use std::sync::Arc;
18
19 use crate::CommonBounds;
20 use hpt_common::{shape::shape::Shape, strides::strides::Strides, utils::simd_ref::MutVec};
21 use rayon::iter::{
22 plumbing::{bridge_unindexed, Folder, UnindexedConsumer, UnindexedProducer},
23 ParallelIterator,
24 };
25
26 use crate::{
27 iterator_traits::{
28 IterGetSetSimd, ParStridedIteratorSimd, ParStridedIteratorSimdZip, ShapeManipulator,
29 },
30 par_strided_map::par_strided_map_simd::ParStridedMapSimd,
31 };
32
33 #[derive(Clone)]
37 pub struct ParStridedZipSimd<'a, A: 'a, B: 'a> {
38 pub(crate) a: A,
40 pub(crate) b: B,
42 pub(crate) phantom: std::marker::PhantomData<&'a ()>,
44 }
45
46 impl<'a, A, B> IterGetSetSimd for ParStridedZipSimd<'a, A, B>
47 where
48 A: IterGetSetSimd,
49 B: IterGetSetSimd,
50 {
51 type Item = (<A as IterGetSetSimd>::Item, <B as IterGetSetSimd>::Item);
52
53 type SimdItem = (
54 <A as IterGetSetSimd>::SimdItem,
55 <B as IterGetSetSimd>::SimdItem,
56 );
57
58 fn set_end_index(&mut self, end_index: usize) {
59 self.a.set_end_index(end_index);
60 self.b.set_end_index(end_index);
61 }
62
63 fn set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>) {
64 self.a.set_intervals(intervals.clone());
65 self.b.set_intervals(intervals);
66 }
67
68 fn set_strides(&mut self, last_stride: Strides) {
69 self.a.set_strides(last_stride.clone());
70 self.b.set_strides(last_stride);
71 }
72
73 fn set_shape(&mut self, shape: Shape) {
74 self.a.set_shape(shape.clone());
75 self.b.set_shape(shape);
76 }
77
78 fn set_prg(&mut self, prg: Vec<i64>) {
79 self.a.set_prg(prg.clone());
80 self.b.set_prg(prg);
81 }
82
83 fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
84 self.a.intervals()
85 }
86
87 fn strides(&self) -> &Strides {
88 self.a.strides()
89 }
90
91 fn shape(&self) -> &Shape {
92 self.a.shape()
93 }
94
95 fn layout(&self) -> &hpt_common::layout::layout::Layout {
96 self.a.layout()
97 }
98
99 fn broadcast_set_strides(&mut self, shape: &Shape) {
100 self.a.broadcast_set_strides(shape);
101 self.b.broadcast_set_strides(shape);
102 }
103
104 fn outer_loop_size(&self) -> usize {
105 self.a.outer_loop_size()
106 }
107
108 fn inner_loop_size(&self) -> usize {
109 self.a.inner_loop_size()
110 }
111
112 fn next(&mut self) {
113 self.a.next();
114 self.b.next();
115 }
116
117 fn next_simd(&mut self) {
118 self.a.next_simd();
119 self.b.next_simd();
120 }
121
122 fn inner_loop_next(&mut self, index: usize) -> Self::Item {
123 (self.a.inner_loop_next(index), self.b.inner_loop_next(index))
124 }
125
126 fn inner_loop_next_simd(&mut self, index: usize) -> Self::SimdItem {
127 (
128 self.a.inner_loop_next_simd(index),
129 self.b.inner_loop_next_simd(index),
130 )
131 }
132
133 fn all_last_stride_one(&self) -> bool {
134 self.a.all_last_stride_one() && self.b.all_last_stride_one()
135 }
136
137 fn lanes(&self) -> Option<usize> {
138 match (self.a.lanes(), self.b.lanes()) {
139 (Some(a), Some(b)) => {
140 if a == b {
141 Some(a)
142 } else {
143 None
144 }
145 }
146 _ => None,
147 }
148 }
149 }
150 impl<'a, A, B> ParStridedZipSimd<'a, A, B>
151 where
152 A: UnindexedProducer + 'a + IterGetSetSimd + ParallelIterator,
153 B: UnindexedProducer + 'a + IterGetSetSimd + ParallelIterator,
154 <A as IterGetSetSimd>::Item: Send,
155 <B as IterGetSetSimd>::Item: Send,
156 {
157 pub fn new(a: A, b: B) -> Self {
168 ParStridedZipSimd {
169 a,
170 b,
171 phantom: std::marker::PhantomData,
172 }
173 }
174 pub fn strided_map_simd<F, F2, T>(
188 self,
189 func: F,
190 func2: F2,
191 ) -> ParStridedMapSimd<'a, Self, <Self as IterGetSetSimd>::Item, F, F2>
192 where
193 F: Fn((&mut T, <Self as IterGetSetSimd>::Item)) + Sync + Send + 'a,
194 F2: Fn((MutVec<'_, T::Vec>, <Self as IterGetSetSimd>::SimdItem)) + Sync + Send + 'a,
195 T: CommonBounds,
196 <A as IterGetSetSimd>::Item: Send,
197 <B as IterGetSetSimd>::Item: Send,
198 T::Vec: Send,
199 A: ShapeManipulator,
200 B: ShapeManipulator,
201 {
202 ParStridedMapSimd {
203 iter: self,
204 f: func,
205 f2: func2,
206 phantom: std::marker::PhantomData,
207 }
208 }
209 }
210
211 impl<'a, A, B> ParStridedIteratorSimdZip for ParStridedZipSimd<'a, A, B>
212 where
213 A: UnindexedProducer + ParallelIterator + IterGetSetSimd,
214 B: UnindexedProducer + ParallelIterator + IterGetSetSimd,
215 {
216 }
217 impl<'a, A, B> ParStridedIteratorSimd for ParStridedZipSimd<'a, A, B>
218 where
219 A: UnindexedProducer + ParallelIterator + IterGetSetSimd,
220 B: UnindexedProducer + ParallelIterator + IterGetSetSimd,
221 <A as IterGetSetSimd>::Item: Send,
222 <B as IterGetSetSimd>::Item: Send,
223 {
224 }
225
226 impl<'a, A, B> UnindexedProducer for ParStridedZipSimd<'a, A, B>
227 where
228 A: UnindexedProducer + ParallelIterator + IterGetSetSimd,
229 B: UnindexedProducer + ParallelIterator + IterGetSetSimd,
230 {
231 type Item = <Self as IterGetSetSimd>::Item;
232
233 fn split(self) -> (Self, Option<Self>) {
234 let (left_a, right_a) = self.a.split();
235 let (left_b, right_b) = self.b.split();
236 if right_a.is_none() {
237 (
238 ParStridedZipSimd {
239 a: left_a,
240 b: left_b,
241 phantom: std::marker::PhantomData,
242 },
243 None,
244 )
245 } else {
246 (
247 ParStridedZipSimd {
248 a: left_a,
249 b: left_b,
250 phantom: std::marker::PhantomData,
251 },
252 Some(ParStridedZipSimd {
253 a: right_a.unwrap(),
254 b: right_b.unwrap(),
255 phantom: std::marker::PhantomData,
256 }),
257 )
258 }
259 }
260
261 fn fold_with<F>(mut self, mut folder: F) -> F
262 where
263 F: Folder<Self::Item>,
264 {
265 let outer_loop_size = self.outer_loop_size();
266 let inner_loop_size = self.inner_loop_size() + 1;
267 for _ in 0..outer_loop_size {
268 for idx in 0..inner_loop_size {
269 folder = folder.consume(self.inner_loop_next(idx));
270 }
271 self.next();
272 }
273 folder
274 }
275 }
276
277 impl<'a, A, B> ParallelIterator for ParStridedZipSimd<'a, A, B>
278 where
279 A: UnindexedProducer + ParallelIterator + IterGetSetSimd,
280 B: UnindexedProducer + ParallelIterator + IterGetSetSimd,
281 <A as IterGetSetSimd>::Item: Send,
282 <B as IterGetSetSimd>::Item: Send,
283 {
284 type Item = (<A as IterGetSetSimd>::Item, <B as IterGetSetSimd>::Item);
285
286 fn drive_unindexed<C>(self, consumer: C) -> C::Result
287 where
288 C: UnindexedConsumer<Self::Item>,
289 {
290 bridge_unindexed(self, consumer)
291 }
292 }
293}
294
/// Parallel strided iterator that walks two strided iterators `a` and `b` in
/// lockstep, yielding `(a_item, b_item)` tuples.
#[derive(Clone)]
pub struct ParStridedZip<'a, A, B>
where
    A: 'a,
    B: 'a,
{
    /// First iterator; shape, strides, layout, intervals and loop sizes are read from it.
    pub(crate) a: A,
    /// Second iterator, advanced alongside `a`.
    pub(crate) b: B,
    /// Zero-sized marker tying the zip to the lifetime `'a`.
    pub(crate) phantom: std::marker::PhantomData<&'a ()>,
}
313
314impl<'a, A, B> IterGetSet for ParStridedZip<'a, A, B>
315where
316 A: IterGetSet,
317 B: IterGetSet,
318{
319 type Item = (<A as IterGetSet>::Item, <B as IterGetSet>::Item);
320
321 fn set_end_index(&mut self, end_index: usize) {
322 self.a.set_end_index(end_index);
323 self.b.set_end_index(end_index);
324 }
325
326 fn set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>) {
327 self.a.set_intervals(intervals.clone());
328 self.b.set_intervals(intervals);
329 }
330
331 fn set_strides(&mut self, last_stride: Strides) {
332 self.a.set_strides(last_stride.clone());
333 self.b.set_strides(last_stride);
334 }
335
336 fn set_shape(&mut self, shape: Shape) {
337 self.a.set_shape(shape.clone());
338 self.b.set_shape(shape);
339 }
340
341 fn set_prg(&mut self, prg: Vec<i64>) {
342 self.a.set_prg(prg.clone());
343 self.b.set_prg(prg);
344 }
345
346 fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
347 self.a.intervals()
348 }
349
350 fn strides(&self) -> &Strides {
351 self.a.strides()
352 }
353
354 fn shape(&self) -> &Shape {
355 self.a.shape()
356 }
357
358 fn layout(&self) -> &hpt_common::layout::layout::Layout {
359 self.a.layout()
360 }
361
362 fn broadcast_set_strides(&mut self, shape: &Shape) {
363 self.a.broadcast_set_strides(shape);
364 self.b.broadcast_set_strides(shape);
365 }
366
367 fn outer_loop_size(&self) -> usize {
368 self.a.outer_loop_size()
369 }
370
371 fn inner_loop_size(&self) -> usize {
372 self.a.inner_loop_size()
373 }
374
375 fn next(&mut self) {
376 self.a.next();
377 self.b.next();
378 }
379
380 fn inner_loop_next(&mut self, index: usize) -> Self::Item {
381 (self.a.inner_loop_next(index), self.b.inner_loop_next(index))
382 }
383}
384
385impl<'a, A, B> ParStridedZip<'a, A, B>
386where
387 A: UnindexedProducer + 'a + IterGetSet + ParallelIterator,
388 B: UnindexedProducer + 'a + IterGetSet + ParallelIterator,
389 <A as IterGetSet>::Item: Send,
390 <B as IterGetSet>::Item: Send,
391{
392 pub fn new(a: A, b: B) -> Self {
403 ParStridedZip {
404 a,
405 b,
406 phantom: std::marker::PhantomData,
407 }
408 }
409 pub fn strided_map<F, T>(
421 self,
422 func: F,
423 ) -> ParStridedMap<'a, Self, <Self as IterGetSet>::Item, F>
424 where
425 F: Fn((&mut T, <Self as IterGetSet>::Item)) + Sync + Send,
426 T: CommonBounds,
427 {
428 ParStridedMap {
429 iter: self,
430 f: func,
431 phantom: std::marker::PhantomData,
432 }
433 }
434}
435
436impl<'a, A, B> UnindexedProducer for ParStridedZip<'a, A, B>
437where
438 A: UnindexedProducer + ParallelIterator + IterGetSet,
439 B: UnindexedProducer + ParallelIterator + IterGetSet,
440{
441 type Item = <Self as IterGetSet>::Item;
442
443 fn split(self) -> (Self, Option<Self>) {
444 let (left_a, right_a) = self.a.split();
445 let (left_b, right_b) = self.b.split();
446 if right_a.is_none() {
447 (
448 ParStridedZip {
449 a: left_a,
450 b: left_b,
451 phantom: std::marker::PhantomData,
452 },
453 None,
454 )
455 } else {
456 (
457 ParStridedZip {
458 a: left_a,
459 b: left_b,
460 phantom: std::marker::PhantomData,
461 },
462 Some(ParStridedZip {
463 a: right_a.unwrap(),
464 b: right_b.unwrap(),
465 phantom: std::marker::PhantomData,
466 }),
467 )
468 }
469 }
470
471 fn fold_with<F>(mut self, mut folder: F) -> F
472 where
473 F: Folder<Self::Item>,
474 {
475 let outer_loop_size = self.outer_loop_size();
476 let inner_loop_size = self.inner_loop_size() + 1;
477 for _ in 0..outer_loop_size {
478 for idx in 0..inner_loop_size {
479 folder = folder.consume(self.inner_loop_next(idx));
480 }
481 self.next();
482 }
483 folder
484 }
485}
486
487impl<'a, A, B> ParallelIterator for ParStridedZip<'a, A, B>
488where
489 A: UnindexedProducer + ParallelIterator + IterGetSet,
490 B: UnindexedProducer + ParallelIterator + IterGetSet,
491 <A as IterGetSet>::Item: Send,
492 <B as IterGetSet>::Item: Send,
493{
494 type Item = (<A as IterGetSet>::Item, <B as IterGetSet>::Item);
495
496 fn drive_unindexed<C>(self, consumer: C) -> C::Result
497 where
498 C: UnindexedConsumer<Self::Item>,
499 {
500 bridge_unindexed(self, consumer)
501 }
502}
503
504impl<'a, A, B> ParStridedIteratorZip for ParStridedZip<'a, A, B>
505where
506 A: UnindexedProducer + ParallelIterator + IterGetSet,
507 B: UnindexedProducer + ParallelIterator + IterGetSet,
508 <A as IterGetSet>::Item: Send,
509 <B as IterGetSet>::Item: Send,
510{
511}
512
513impl<'a, A, B> ShapeManipulator for ParStridedZip<'a, A, B>
514where
515 A: UnindexedProducer + 'a + IterGetSet + ParallelIterator + ShapeManipulator,
516 B: UnindexedProducer + 'a + IterGetSet + ParallelIterator + ShapeManipulator,
517 <A as IterGetSet>::Item: Send,
518 <B as IterGetSet>::Item: Send,
519{
520 fn reshape<S: Into<Shape>>(self, shape: S) -> Self {
521 let tmp: Shape = shape.into();
522 let a = self.a.reshape(tmp.clone());
523 let b = self.b.reshape(tmp);
524 ParStridedZip::new(a, b)
525 }
526
527 fn transpose<AXIS: Into<hpt_common::axis::axis::Axis>>(self, axes: AXIS) -> Self {
528 let axes: hpt_common::axis::axis::Axis = axes.into();
529 let a = self.a.transpose(axes.clone());
530 let b = self.b.transpose(axes);
531 ParStridedZip::new(a, b)
532 }
533
534 fn expand<S: Into<Shape>>(self, shape: S) -> Self {
535 let tmp: Shape = shape.into();
536 let a = self.a.expand(tmp.clone());
537 let b = self.b.expand(tmp);
538 ParStridedZip::new(a, b)
539 }
540}
541
542impl<'a, A, B> ShapeManipulator for ParStridedZipSimd<'a, A, B>
543where
544 A: UnindexedProducer + 'a + IterGetSetSimd + ParallelIterator + ShapeManipulator,
545 B: UnindexedProducer + 'a + IterGetSetSimd + ParallelIterator + ShapeManipulator,
546 <A as IterGetSetSimd>::Item: Send,
547 <B as IterGetSetSimd>::Item: Send,
548{
549 fn reshape<S: Into<Shape>>(self, shape: S) -> Self {
550 let tmp: Shape = shape.into();
551 let a = self.a.reshape(tmp.clone());
552 let b = self.b.reshape(tmp);
553 ParStridedZipSimd::new(a, b)
554 }
555
556 fn transpose<AXIS: Into<hpt_common::axis::axis::Axis>>(self, axes: AXIS) -> Self {
557 let axes: hpt_common::axis::axis::Axis = axes.into();
558 let a = self.a.transpose(axes.clone());
559 let b = self.b.transpose(axes);
560 ParStridedZipSimd::new(a, b)
561 }
562
563 fn expand<S: Into<Shape>>(self, shape: S) -> Self {
564 let tmp: Shape = shape.into();
565 let a = self.a.expand(tmp.clone());
566 let b = self.b.expand(tmp);
567 ParStridedZipSimd::new(a, b)
568 }
569}