1use crate::{
2 iterator_traits::{IterGetSet, ParStridedHelper, ParStridedIteratorZip, ShapeManipulator},
3 par_strided_fold::ParStridedFold,
4 par_strided_map::ParStridedMap,
5 shape_manipulate::{par_expand, par_reshape, par_transpose},
6};
7use hpt_common::{
8 axis::axis::Axis,
9 layout::layout::Layout,
10 shape::shape::Shape,
11 shape::shape_utils::{mt_intervals, try_pad_shape},
12 strides::strides::Strides,
13 strides::strides_utils::preprocess_strides,
14 utils::pointer::Pointer,
15};
16use hpt_traits::tensor::{CommonBounds, TensorInfo};
17use rayon::iter::{
18 plumbing::{bridge_unindexed, Folder, UnindexedConsumer, UnindexedProducer},
19 ParallelIterator,
20};
21use std::sync::Arc;
22
23pub mod par_strided_simd {
25 use hpt_types::vectors::traits::VecTrait;
26 use std::sync::Arc;
27
28 use hpt_common::{
29 axis::axis::Axis,
30 layout::layout::Layout,
31 shape::shape::Shape,
32 shape::shape_utils::{mt_intervals, try_pad_shape},
33 strides::strides::Strides,
34 strides::strides_utils::preprocess_strides,
35 utils::pointer::Pointer,
36 utils::simd_ref::MutVec,
37 };
38 use hpt_traits::{CommonBounds, TensorInfo};
39 use rayon::iter::{
40 plumbing::{bridge_unindexed, Folder, UnindexedConsumer, UnindexedProducer},
41 ParallelIterator,
42 };
43
44 use crate::{
45 iterator_traits::{
46 IterGetSetSimd, ParStridedHelper, ParStridedIteratorSimd, ParStridedIteratorSimdZip,
47 ShapeManipulator,
48 },
49 par_strided_map::par_strided_map_simd::ParStridedMapSimd,
50 shape_manipulate::{par_expand, par_reshape, par_transpose},
51 };
52
    /// A parallel strided iterator over tensor elements with SIMD support.
    ///
    /// Work is divided into contiguous intervals of the outer loop; rayon
    /// splits instances of this type between worker threads.
    #[derive(Clone)]
    pub struct ParStridedSimd<T: Send + Copy + Sync> {
        // Pointer to the current element being visited.
        pub(crate) ptr: Pointer<T>,
        // Shape and strides describing how memory is traversed.
        pub(crate) layout: Layout,
        // Per-dimension progress counters (odometer state used by `next`).
        pub(crate) prg: Vec<i64>,
        // Outer-loop work intervals, shared (not cloned) between splits.
        pub(crate) intervals: Arc<Vec<(usize, usize)>>,
        // First interval index owned by this producer.
        pub(crate) start_index: usize,
        // One-past-last interval index owned by this producer.
        pub(crate) end_index: usize,
        // Stride of the innermost dimension, cached for the inner loop.
        pub(crate) last_stride: i64,
    }
73 impl<T: CommonBounds> ParStridedSimd<T> {
74 pub fn shape(&self) -> &Shape {
76 self.layout.shape()
77 }
78
79 pub fn strides(&self) -> &Strides {
81 self.layout.strides()
82 }
83
84 pub fn new<U: TensorInfo<T>>(tensor: U) -> Self {
86 let inner_loop_size = *tensor.shape().last().unwrap() as usize;
87 let outer_loop_size = tensor.size() / inner_loop_size;
88 let num_threads;
89 if outer_loop_size < rayon::current_num_threads() {
90 num_threads = outer_loop_size;
91 } else {
92 num_threads = rayon::current_num_threads();
93 }
94 let intervals = mt_intervals(outer_loop_size, num_threads);
95 let len = intervals.len();
96 ParStridedSimd {
97 ptr: tensor.ptr(),
98 layout: tensor.layout().clone(),
99 prg: vec![],
100 intervals: Arc::new(intervals),
101 start_index: 0,
102 end_index: len,
103 last_stride: *tensor.strides().last().unwrap(),
104 }
105 }
106
107 pub fn strided_map_simd<'a, F, F2>(
109 self,
110 f: F,
111 vec_op: F2,
112 ) -> ParStridedMapSimd<'a, ParStridedSimd<T>, T, F, F2>
113 where
114 F: Fn((&mut T, <Self as IterGetSetSimd>::Item)) + Sync + Send + 'a,
115 <Self as IterGetSetSimd>::Item: Send,
116 F2: Send + Sync + Copy + Fn((MutVec<'_, T::Vec>, <Self as IterGetSetSimd>::SimdItem)),
117 {
118 {
119 ParStridedMapSimd {
120 iter: self,
121 f,
122 f2: vec_op,
123 phantom: std::marker::PhantomData,
124 }
125 }
126 }
127 }
128
    // Marker impls: zip and SIMD iteration capabilities come entirely from
    // the traits' provided methods; no members are needed here.
    impl<T: CommonBounds> ParStridedIteratorSimdZip for ParStridedSimd<T> {}
    impl<T: CommonBounds> ParStridedIteratorSimd for ParStridedSimd<T> {}
131
132 impl<T: CommonBounds> IterGetSetSimd for ParStridedSimd<T>
133 where
134 T::Vec: Send,
135 {
136 type Item = T;
137
138 type SimdItem = T::Vec;
139
140 fn set_end_index(&mut self, end_index: usize) {
141 self.end_index = end_index;
142 }
143
144 fn set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>) {
145 self.intervals = intervals;
146 }
147
148 fn set_strides(&mut self, strides: Strides) {
149 self.layout.set_strides(strides);
150 }
151
152 fn set_shape(&mut self, shape: Shape) {
153 self.layout.set_shape(shape);
154 }
155
156 fn set_prg(&mut self, prg: Vec<i64>) {
157 self.prg = prg;
158 }
159
160 fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
161 &self.intervals
162 }
163
164 fn strides(&self) -> &Strides {
165 self.layout.strides()
166 }
167
168 fn shape(&self) -> &Shape {
169 self.layout.shape()
170 }
171
172 fn layout(&self) -> &Layout {
173 &self.layout
174 }
175
176 fn broadcast_set_strides(&mut self, shape: &Shape) {
177 let self_shape = try_pad_shape(self.shape(), shape.len());
178 self.set_strides(preprocess_strides(&self_shape, self.strides()).into());
179 self.last_stride = self.strides()[self.strides().len() - 1];
180 }
181
182 fn outer_loop_size(&self) -> usize {
183 self.intervals[self.start_index].1 - self.intervals[self.start_index].0
184 }
185
186 fn inner_loop_size(&self) -> usize {
187 self.shape().last().unwrap().clone() as usize
188 }
189
190 fn next(&mut self) {
191 for j in (0..(self.shape().len() as i64) - 1).rev() {
192 let j = j as usize;
193 if self.prg[j] < self.shape()[j] {
194 self.prg[j] += 1;
195 self.ptr.offset(self.strides()[j]);
196 break;
197 } else {
198 self.prg[j] = 0;
199 self.ptr.offset(-self.strides()[j] * self.shape()[j]);
200 }
201 }
202 }
203
204 fn next_simd(&mut self) {}
205
206 fn inner_loop_next(&mut self, index: usize) -> Self::Item {
207 unsafe { *self.ptr.get_ptr().add(index * (self.last_stride as usize)) }
208 }
209
210 #[inline(always)]
211 fn inner_loop_next_simd(&mut self, index: usize) -> Self::SimdItem {
212 unsafe { T::Vec::from_ptr(self.ptr.get_ptr().add(index * T::Vec::SIZE)) }
213 }
214
215 fn all_last_stride_one(&self) -> bool {
216 self.last_stride == 1
217 }
218
219 fn lanes(&self) -> Option<usize> {
220 Some(T::Vec::SIZE)
221 }
222 }
223
    // Bridges this type into rayon via `bridge_unindexed`, which repeatedly
    // calls `UnindexedProducer::split` to balance work across threads.
    impl<T> ParallelIterator for ParStridedSimd<T>
    where
        T: CommonBounds,
        T::Vec: Send,
    {
        type Item = T;

        fn drive_unindexed<C>(self, consumer: C) -> C::Result
        where
            C: UnindexedConsumer<Self::Item>,
        {
            bridge_unindexed(self, consumer)
        }
    }
238
    impl<T> UnindexedProducer for ParStridedSimd<T>
    where
        T: CommonBounds,
        T::Vec: Send,
    {
        type Item = T;

        // While more than one interval remains, halve the interval range
        // between two producers. Once a single interval is left, position the
        // pointer and progress counters at the interval's first element and
        // return the producer unsplit.
        fn split(mut self) -> (Self, Option<Self>) {
            if self.end_index - self.start_index <= 1 {
                let mut curent_shape_prg: Vec<i64> = vec![0; self.shape().len()];
                // Linear element offset of the interval start (outer index
                // times the innermost dimension length).
                let mut amount =
                    self.intervals[self.start_index].0 * (*self.shape().last().unwrap() as usize);
                // Decompose `amount` into per-dimension indices (innermost
                // first) and advance the pointer accordingly.
                for j in (0..self.shape().len()).rev() {
                    curent_shape_prg[j] = (amount as i64) % self.shape()[j];
                    amount /= self.shape()[j] as usize;
                    self.ptr += curent_shape_prg[j] * self.strides()[j];
                }
                self.prg = curent_shape_prg;
                // Store per-dimension maxima (shape - 1); `next` compares the
                // odometer against these with `<`.
                let mut new_shape = self.shape().to_vec();
                new_shape.iter_mut().for_each(|x| {
                    *x -= 1;
                });
                self.last_stride = self.strides()[self.strides().len() - 1];
                self.set_shape(Shape::from(new_shape));
                return (self, None);
            }
            // Split the remaining intervals roughly in half; the right half
            // takes the extra interval when the count is odd.
            let _left_interval = &self.intervals[self.start_index..self.end_index];
            let left = _left_interval.len() / 2;
            let right = _left_interval.len() / 2 + (_left_interval.len() % 2);
            (
                ParStridedSimd {
                    ptr: self.ptr.clone(),
                    layout: self.layout.clone(),
                    prg: vec![],
                    intervals: self.intervals.clone(),
                    start_index: self.start_index,
                    end_index: self.start_index + left,
                    last_stride: self.last_stride,
                },
                Some(ParStridedSimd {
                    ptr: self.ptr.clone(),
                    layout: self.layout.clone(),
                    prg: vec![],
                    intervals: self.intervals.clone(),
                    start_index: self.start_index + left,
                    end_index: self.start_index + left + right,
                    last_stride: self.last_stride,
                }),
            )
        }

        // NOTE(review): returns the folder without feeding it any items —
        // element consumption appears to happen through the IterGetSetSimd
        // paths instead of rayon's fold; confirm this is intentional.
        fn fold_with<F>(self, folder: F) -> F
        where
            F: Folder<Self::Item>,
        {
            folder
        }
    }
297
    // Internal mutation hooks used by the shared shape-manipulation helpers
    // (`par_reshape` / `par_transpose` / `par_expand`).
    impl<T: CommonBounds> ParStridedHelper for ParStridedSimd<T> {
        fn _set_last_strides(&mut self, stride: i64) {
            self.last_stride = stride;
        }

        fn _set_strides(&mut self, strides: Strides) {
            self.layout.set_strides(strides);
        }

        fn _set_shape(&mut self, shape: Shape) {
            self.layout.set_shape(shape);
        }

        fn _layout(&self) -> &Layout {
            &self.layout
        }

        fn _set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>) {
            self.intervals = intervals;
        }

        fn _set_end_index(&mut self, end_index: usize) {
            self.end_index = end_index;
        }
    }
323
    // Shape manipulation delegates to the shared parallel helpers.
    impl<T: CommonBounds> ShapeManipulator for ParStridedSimd<T>
    where
        T::Vec: Send,
    {
        // Reinterprets the iterator with a new shape of equal element count.
        fn reshape<S: Into<Shape>>(self, shape: S) -> Self {
            par_reshape(self, shape)
        }

        // Permutes the dimensions according to `axes`.
        fn transpose<AXIS: Into<Axis>>(self, axes: AXIS) -> Self {
            par_transpose(self, axes)
        }

        // Broadcasts size-1 dimensions up to the target shape.
        fn expand<S: Into<Shape>>(self, shape: S) -> Self {
            par_expand(self, shape)
        }
    }
340}
341
/// A parallel strided iterator over tensor elements (scalar-only variant of
/// `par_strided_simd::ParStridedSimd`).
///
/// Work is divided into contiguous intervals of the outer loop; rayon splits
/// instances of this type between worker threads.
#[derive(Clone)]
pub struct ParStrided<T> {
    // Pointer to the current element being visited.
    pub(crate) ptr: Pointer<T>,
    // Shape and strides describing how memory is traversed.
    pub(crate) layout: Layout,
    // Per-dimension progress counters (odometer state used by `next`).
    pub(crate) prg: Vec<i64>,
    // Outer-loop work intervals, shared (not cloned) between splits.
    pub(crate) intervals: Arc<Vec<(usize, usize)>>,
    // First interval index owned by this producer.
    pub(crate) start_index: usize,
    // One-past-last interval index owned by this producer.
    pub(crate) end_index: usize,
    // Stride of the innermost dimension, cached for the inner loop.
    pub(crate) last_stride: i64,
}
362
363impl<T: CommonBounds> ParStrided<T> {
364 pub fn shape(&self) -> &Shape {
370 self.layout.shape()
371 }
372 pub fn strides(&self) -> &Strides {
378 self.layout.strides()
379 }
380 pub fn new<U: TensorInfo<T>>(tensor: U) -> Self {
394 let inner_loop_size = tensor.shape()[tensor.shape().len() - 1] as usize;
395 let outer_loop_size = tensor.size() / inner_loop_size;
396 let num_threads;
397 if outer_loop_size < rayon::current_num_threads() {
398 num_threads = outer_loop_size;
399 } else {
400 num_threads = rayon::current_num_threads();
401 }
402 let intervals = mt_intervals(outer_loop_size, num_threads);
403 let len = intervals.len();
404 ParStrided {
405 ptr: tensor.ptr(),
406 layout: tensor.layout().clone(),
407 prg: vec![],
408 intervals: Arc::new(intervals),
409 start_index: 0,
410 end_index: len,
411 last_stride: tensor.strides()[tensor.strides().len() - 1],
412 }
413 }
414 pub fn par_strided_fold<ID, F>(self, identity: ID, fold_op: F) -> ParStridedFold<Self, ID, F>
433 where
434 F: Fn(ID, T) -> ID + Sync + Send + Copy,
435 ID: Sync + Send + Copy,
436 {
437 ParStridedFold {
438 iter: self,
439 identity,
440 fold_op,
441 }
442 }
443 pub fn strided_map<'a, F, U>(self, f: F) -> ParStridedMap<'a, ParStrided<T>, T, F>
461 where
462 F: Fn((&mut U, T)) + Sync + Send + 'a,
463 U: CommonBounds,
464 {
465 ParStridedMap {
466 iter: self,
467 f,
468 phantom: std::marker::PhantomData,
469 }
470 }
471}
472
// Marker impl: zip capability comes entirely from the trait's provided
// methods; no members are needed here.
impl<T: CommonBounds> ParStridedIteratorZip for ParStrided<T> {}
474
475impl<T: CommonBounds> IterGetSet for ParStrided<T> {
476 type Item = T;
477
478 fn set_end_index(&mut self, end_index: usize) {
479 self.end_index = end_index;
480 }
481
482 fn set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>) {
483 self.intervals = intervals;
484 }
485
486 fn set_strides(&mut self, strides: Strides) {
487 self.layout.set_strides(strides);
488 }
489
490 fn set_shape(&mut self, shape: Shape) {
491 self.layout.set_shape(shape);
492 }
493
494 fn set_prg(&mut self, prg: Vec<i64>) {
495 self.prg = prg;
496 }
497
498 fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
499 &self.intervals
500 }
501
502 fn strides(&self) -> &Strides {
503 self.layout.strides()
504 }
505
506 fn shape(&self) -> &Shape {
507 self.layout.shape()
508 }
509
510 fn layout(&self) -> &Layout {
511 &self.layout
512 }
513
514 fn broadcast_set_strides(&mut self, shape: &Shape) {
515 let self_shape = try_pad_shape(self.shape(), shape.len());
516 self.set_strides(preprocess_strides(&self_shape, self.strides()).into());
517 self.last_stride = self.strides()[self.strides().len() - 1];
518 }
519
520 fn outer_loop_size(&self) -> usize {
521 self.intervals[self.start_index].1 - self.intervals[self.start_index].0
522 }
523
524 fn inner_loop_size(&self) -> usize {
525 self.shape().last().unwrap().clone() as usize
526 }
527
528 fn next(&mut self) {
529 for j in (0..(self.shape().len() as i64) - 1).rev() {
530 let j = j as usize;
531 if self.prg[j] < self.shape()[j] {
532 self.prg[j] += 1;
533 self.ptr.offset(self.strides()[j]);
534 break;
535 } else {
536 self.prg[j] = 0;
537 self.ptr.offset(-self.strides()[j] * self.shape()[j]);
538 }
539 }
540 }
541
542 fn inner_loop_next(&mut self, index: usize) -> Self::Item {
543 unsafe { *self.ptr.get_ptr().add(index * (self.last_stride as usize)) }
544 }
545}
546
// Bridges this type into rayon via `bridge_unindexed`, which repeatedly
// calls `UnindexedProducer::split` to balance work across threads.
impl<T> ParallelIterator for ParStrided<T>
where
    T: CommonBounds,
    T::Vec: Send,
{
    type Item = T;

    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: UnindexedConsumer<Self::Item>,
    {
        bridge_unindexed(self, consumer)
    }
}
561
impl<T> UnindexedProducer for ParStrided<T>
where
    T: CommonBounds,
    T::Vec: Send,
{
    type Item = T;

    // While more than one interval remains, halve the interval range between
    // two producers. Once a single interval is left, position the pointer and
    // progress counters at the interval's first element and return the
    // producer unsplit.
    fn split(mut self) -> (Self, Option<Self>) {
        if self.end_index - self.start_index <= 1 {
            let mut curent_shape_prg: Vec<i64> = vec![0; self.shape().len()];
            // Linear element offset of the interval start (outer index times
            // the innermost dimension length).
            let mut amount =
                self.intervals[self.start_index].0 * (*self.shape().last().unwrap() as usize);
            let mut index = 0;
            // Decompose `amount` into per-dimension indices (innermost first)
            // and accumulate the matching pointer offset.
            for j in (0..self.shape().len()).rev() {
                curent_shape_prg[j] = (amount as i64) % self.shape()[j];
                amount /= self.shape()[j] as usize;
                index += curent_shape_prg[j] * self.strides()[j];
            }
            self.ptr.offset(index);
            self.prg = curent_shape_prg;
            // Store per-dimension maxima (shape - 1); `next` compares the
            // odometer against these with `<`.
            let mut new_shape = self.shape().to_vec();
            new_shape.iter_mut().for_each(|x| {
                *x -= 1;
            });
            self.last_stride = self.strides()[self.strides().len() - 1];
            self.set_shape(Shape::from(new_shape));
            return (self, None);
        }
        // Split the remaining intervals roughly in half; the right half takes
        // the extra interval when the count is odd.
        let _left_interval = &self.intervals[self.start_index..self.end_index];
        let left = _left_interval.len() / 2;
        let right = _left_interval.len() / 2 + (_left_interval.len() % 2);
        (
            ParStrided {
                ptr: self.ptr.clone(),
                layout: self.layout.clone(),
                prg: vec![],
                intervals: self.intervals.clone(),
                start_index: self.start_index,
                end_index: self.start_index + left,
                last_stride: self.last_stride,
            },
            Some(ParStrided {
                ptr: self.ptr.clone(),
                layout: self.layout.clone(),
                prg: vec![],
                intervals: self.intervals.clone(),
                start_index: self.start_index + left,
                end_index: self.start_index + left + right,
                last_stride: self.last_stride,
            }),
        )
    }

    // NOTE(review): returns the folder without feeding it any items —
    // element consumption appears to happen through the IterGetSet paths
    // instead of rayon's fold; confirm this is intentional.
    fn fold_with<F>(self, folder: F) -> F
    where
        F: Folder<Self::Item>,
    {
        folder
    }
}
622
// Internal mutation hooks used by the shared shape-manipulation helpers
// (`par_reshape` / `par_transpose` / `par_expand`).
impl<T> ParStridedHelper for ParStrided<T> {
    fn _set_last_strides(&mut self, last_stride: i64) {
        self.last_stride = last_stride;
    }

    fn _set_strides(&mut self, strides: Strides) {
        self.layout.set_strides(strides);
    }

    fn _set_shape(&mut self, shape: Shape) {
        self.layout.set_shape(shape);
    }

    fn _layout(&self) -> &Layout {
        &self.layout
    }

    fn _set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>) {
        self.intervals = intervals;
    }

    fn _set_end_index(&mut self, end_index: usize) {
        self.end_index = end_index;
    }
}
648
// Shape manipulation delegates to the shared parallel helpers.
impl<T: CommonBounds> ShapeManipulator for ParStrided<T> {
    // Reinterprets the iterator with a new shape of equal element count.
    // NOTE(review): `#[track_caller]` appears only here, not on the SIMD
    // twin's `reshape` — confirm whether the asymmetry is intentional.
    #[track_caller]
    fn reshape<S: Into<Shape>>(self, shape: S) -> Self {
        par_reshape(self, shape)
    }

    // Permutes the dimensions according to `axes`.
    fn transpose<AXIS: Into<Axis>>(self, axes: AXIS) -> Self {
        par_transpose(self, axes)
    }

    // Broadcasts size-1 dimensions up to the target shape.
    fn expand<S: Into<Shape>>(self, shape: S) -> Self {
        par_expand(self, shape)
    }
}