hpt_iterator/iterator_traits.rs
1use std::sync::Arc;
2
3use hpt_common::{
4 axis::axis::Axis,
5 layout::layout::Layout,
6 shape::shape::Shape,
7 shape::shape_utils::{mt_intervals, predict_broadcast_shape},
8 strides::strides::Strides,
9};
10use hpt_traits::CommonBounds;
11use rayon::iter::{plumbing::UnindexedProducer, ParallelIterator};
12
13use crate::{
14 par_strided_zip::{par_strided_zip_simd::ParStridedZipSimd, ParStridedZip},
15 strided_map::StridedMap,
16 strided_zip::{strided_zip_simd::StridedZipSimd, StridedZip},
17 with_simd::WithSimd,
18};
19
20/// A trait for getting and setting values from an iterator.
21pub trait IterGetSet {
22 /// The type of the iterator's elements.
23 type Item;
24 /// set the end index of the iterator, this is used when rayon perform data splitting
25 fn set_end_index(&mut self, end_index: usize);
26 /// set the chunk intervals of the iterator, we chunk the outer loop
27 fn set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>);
28 /// set the strides for the iterator, we call this method normally when we do broadcasting
29 fn set_strides(&mut self, strides: Strides);
30 /// set the shape for the iterator, we call this method normally when we do broadcasting
31 fn set_shape(&mut self, shape: Shape);
32 /// set the loop progress for the iterator
33 fn set_prg(&mut self, prg: Vec<i64>);
34 /// get the intervals of the iterator
35 fn intervals(&self) -> &Arc<Vec<(usize, usize)>>;
36 /// get the strides of the iterator
37 fn strides(&self) -> &Strides;
38 /// get the shape of the iterator
39 fn shape(&self) -> &Shape;
40 /// get the layout of the iterator
41 fn layout(&self) -> &Layout;
42 /// set the strides for all the iterators
43 fn broadcast_set_strides(&mut self, shape: &Shape);
44 /// get the outer loop size
45 fn outer_loop_size(&self) -> usize;
46 /// get the inner loop size
47 fn inner_loop_size(&self) -> usize;
48 /// update the loop progress
49 fn next(&mut self);
50 /// get the next element of the inner loop
51 fn inner_loop_next(&mut self, index: usize) -> Self::Item;
52}
53
54/// A trait for getting and setting values from an simd iterator
55pub trait IterGetSetSimd {
56 /// The type of the iterator's elements.
57 type Item;
58 /// The type of the iterator's simd elements.
59 type SimdItem;
60 /// set the end index of the iterator, this is used when rayon perform data splitting
61 fn set_end_index(&mut self, end_index: usize);
62 /// set the chunk intervals of the iterator, we chunk the outer loop
63 fn set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>);
64 /// set the strides for the iterator, we call this method normally when we do broadcasting
65 fn set_strides(&mut self, last_stride: Strides);
66 /// set the shape for the iterator, we call this method normally when we do broadcasting
67 fn set_shape(&mut self, shape: Shape);
68 /// set the loop progress for the iterator
69 fn set_prg(&mut self, prg: Vec<i64>);
70 /// get the intervals of the iterator
71 fn intervals(&self) -> &Arc<Vec<(usize, usize)>>;
72 /// get the strides of the iterator
73 fn strides(&self) -> &Strides;
74 /// get the shape of the iterator
75 fn shape(&self) -> &Shape;
76 /// get the layout of the iterator
77 fn layout(&self) -> &Layout;
78 /// set the strides for all the iterators
79 fn broadcast_set_strides(&mut self, shape: &Shape);
80 /// get the outer loop size
81 fn outer_loop_size(&self) -> usize;
82 /// get the inner loop size
83 fn inner_loop_size(&self) -> usize;
84 /// update the loop progress, this is called when we don't do simd iteration
85 fn next(&mut self);
86 /// update the loop progress, this is called when we do simd iteration
87 fn next_simd(&mut self);
88 /// get the next element of the inner loop
89 fn inner_loop_next(&mut self, index: usize) -> Self::Item;
90 /// get the next vector of the inner loop, this is called when we do simd iteration
91 fn inner_loop_next_simd(&mut self, index: usize) -> Self::SimdItem;
92 /// check if all iterators' last stride is one, only when all iterators' last stride is one, we can do simd iteration
93 fn all_last_stride_one(&self) -> bool;
94 /// get the simd vector size, if any of the iterator returned different vector size, it will return None
95 fn lanes(&self) -> Option<usize>;
96}
97
98/// A trait for performing shape manipulation on an iterator.
99pub trait ShapeManipulator {
100 /// reshape the iterator, we can change the iteration behavior by changing the shape
101 fn reshape<S: Into<Shape>>(self, shape: S) -> Self;
102 /// transpose the iterator, we can change the iteration behavior by changing the axes
103 fn transpose<AXIS: Into<Axis>>(self, axes: AXIS) -> Self;
104 /// expand the iterator, we can change the iteration behavior by changing the shape
105 fn expand<S: Into<Shape>>(self, shape: S) -> Self;
106}
107
108/// A trait for performing single thread iteration over an iterator.
109pub trait StridedIterator: IterGetSet
110where
111 Self: Sized,
112{
113 /// perform scalar iteration, this method is for single thread iterator
114 fn for_each<F>(mut self, func: F)
115 where
116 F: Fn(Self::Item),
117 {
118 let outer_loop_size = self.outer_loop_size();
119 let inner_loop_size = self.inner_loop_size(); // we don't need to add 1 as we didn't subtract shape by 1
120 self.set_prg(vec![0; self.shape().len()]);
121 for _ in 0..outer_loop_size {
122 for idx in 0..inner_loop_size {
123 func(self.inner_loop_next(idx));
124 }
125 self.next();
126 }
127 }
128 /// perform scalar iteration with init, this method is for single thread iterator
129 fn for_each_init<F, INIT, T>(mut self, init: INIT, func: F)
130 where
131 F: Fn(&mut T, Self::Item),
132 INIT: Fn() -> T,
133 {
134 let outer_loop_size = self.outer_loop_size();
135 let inner_loop_size = self.inner_loop_size();
136 self.set_prg(vec![0; self.shape().len()]);
137 let mut init = init();
138 for _ in 0..outer_loop_size {
139 for idx in 0..inner_loop_size {
140 func(&mut init, self.inner_loop_next(idx));
141 }
142 self.next();
143 }
144 }
145}
146
147/// A trait to zip two iterators together.
148pub trait StridedIteratorZip: Sized {
149 /// Combines this iterator with another iterator, enabling simultaneous iteration.
150 ///
151 /// This method zips together `self` and `other` into a `StridedZip` iterator, allowing for synchronized
152 ///
153 /// iteration over both iterators. This is particularly useful for operations that require processing
154 ///
155 /// elements from two tensors in parallel, such as element-wise arithmetic operations.
156 ///
157 /// # Arguments
158 ///
159 /// * `other` - The other iterator to zip with. It must implement the `IterGetSet` trait, and
160 /// its associated `Item` type must be `Send`.
161 ///
162 /// # Returns
163 ///
164 /// A `StridedZip` instance that encapsulates both `self` and `other`, allowing for synchronized
165 ///
166 /// iteration over their elements.
167 ///
168 /// # Panics
169 ///
170 /// This method will panic if the shapes of `self` and `other` cannot be broadcasted together.
171 #[track_caller]
172 fn zip<'a, C>(self, other: C) -> StridedZip<'a, Self, C>
173 where
174 C: IterGetSet + ShapeManipulator,
175 Self: IterGetSet + ShapeManipulator,
176 <C as IterGetSet>::Item: Send,
177 <Self as IterGetSet>::Item: Send,
178 {
179 let new_shape = predict_broadcast_shape(&self.shape(), &other.shape())
180 .expect("Cannot broadcast shapes");
181
182 let mut a = self.reshape(new_shape.clone());
183 let mut b = other.reshape(new_shape.clone());
184
185 a.set_shape(new_shape.clone());
186 b.set_shape(new_shape.clone());
187 StridedZip::new(a, b)
188 }
189}
190
191/// A trait to zip two parallel iterators together.
192pub trait ParStridedIteratorZip: Sized + IterGetSet {
193 /// Combines this iterator with another iterator, enabling simultaneous parallel iteration.
194 ///
195 /// This method performs shape broadcasting between `self` and `other` to ensure that both iterators
196 /// iterate over tensors with compatible shapes. It adjusts the strides and shapes of both iterators
197 /// to match the broadcasted shape and then returns a `ParStridedZip` that allows for synchronized
198 /// parallel iteration over both iterators.
199 ///
200 /// # Arguments
201 ///
202 /// * `other` - The other iterator to zip with. It must implement the `IterGetSet`, `UnindexedProducer`,
203 /// and `ParallelIterator` traits, and its associated `Item` type must be `Send`.
204 ///
205 /// # Returns
206 ///
207 /// A `ParStridedZip` instance that zips together `self` and `other`, enabling synchronized
208 /// parallel iteration over their elements.
209 ///
210 /// # Panics
211 ///
212 /// This method will panic if the shapes of `self` and `other` cannot be broadcasted together.
213 /// Ensure that the shapes are compatible before calling this method.
214 #[track_caller]
215 fn zip<'a, C>(mut self, mut other: C) -> ParStridedZip<'a, Self, C>
216 where
217 C: UnindexedProducer + 'a + IterGetSet + ParallelIterator + ShapeManipulator,
218 <C as IterGetSet>::Item: Send,
219 Self: UnindexedProducer + ParallelIterator + ShapeManipulator,
220 <Self as IterGetSet>::Item: Send,
221 {
222 let new_shape = predict_broadcast_shape(&self.shape(), &other.shape())
223 .expect("Cannot broadcast shapes");
224
225 let inner_loop_size = new_shape[new_shape.len() - 1] as usize;
226 let outer_loop_size = (new_shape.size() as usize) / inner_loop_size;
227
228 let num_threads;
229 if outer_loop_size < rayon::current_num_threads() {
230 num_threads = outer_loop_size;
231 } else {
232 num_threads = rayon::current_num_threads();
233 }
234 let intervals = Arc::new(mt_intervals(outer_loop_size, num_threads));
235 let len = intervals.len();
236 self.set_intervals(intervals.clone());
237 self.set_end_index(len);
238 other.set_intervals(intervals.clone());
239 other.set_end_index(len);
240
241 let mut a = self.reshape(new_shape.clone());
242 let mut b = other.reshape(new_shape.clone());
243
244 a.set_shape(new_shape.clone());
245 b.set_shape(new_shape.clone());
246
247 ParStridedZip::new(a, b)
248 }
249}
250
251/// A trait to zip two parallel iterators together.
252pub trait ParStridedIteratorSimdZip: Sized + IterGetSetSimd {
253 /// Combines this `ParStridedZipSimd` iterator with another SIMD-optimized iterator, enabling simultaneous parallel iteration.
254 ///
255 /// This method performs shape broadcasting between `self` and `other` to ensure that both iterators
256 /// iterate over tensors with compatible shapes. It calculates the appropriate iteration intervals based
257 /// on the new broadcasted shape and configures both iterators accordingly. Finally, it returns a new
258 /// `ParStridedZipSimd` instance that allows for synchronized parallel iteration over the combined iterators.
259 ///
260 /// # Arguments
261 ///
262 /// * `other` - The third iterator to zip with. It must implement the `IterGetSetSimd`, `UnindexedProducer`,
263 /// `ShapeManipulator`, and `ParallelIterator` traits,
264 /// and its associated `Item` type must be `Send`.
265 ///
266 /// # Returns
267 ///
268 /// A new `ParStridedZipSimd` instance that combines `self` and `other` for synchronized parallel iteration over all three iterators.
269 ///
270 /// # Panics
271 ///
272 /// This method will panic if the shapes of `self` and `other` cannot be broadcasted together.
273 /// Ensure that the shapes are compatible before calling this method.
274 #[track_caller]
275 fn zip<'a, C>(mut self, mut other: C) -> ParStridedZipSimd<'a, Self, C>
276 where
277 C: UnindexedProducer + 'a + IterGetSetSimd + ParallelIterator + ShapeManipulator,
278 <C as IterGetSetSimd>::Item: Send,
279 Self: UnindexedProducer + ParallelIterator + ShapeManipulator,
280 <Self as IterGetSetSimd>::Item: Send,
281 {
282 let new_shape = predict_broadcast_shape(&self.shape(), &other.shape())
283 .expect("Cannot broadcast shapes");
284
285 let inner_loop_size = new_shape[new_shape.len() - 1] as usize;
286 let outer_loop_size = (new_shape.size() as usize) / inner_loop_size;
287
288 let num_threads;
289 if outer_loop_size < rayon::current_num_threads() {
290 num_threads = outer_loop_size;
291 } else {
292 num_threads = rayon::current_num_threads();
293 }
294 let intervals = Arc::new(mt_intervals(outer_loop_size, num_threads));
295 let len = intervals.len();
296 self.set_intervals(intervals.clone());
297 self.set_end_index(len);
298 other.set_intervals(intervals.clone());
299 other.set_end_index(len);
300
301 let mut a = self.reshape(new_shape.clone());
302 let mut b = other.reshape(new_shape.clone());
303
304 a.set_shape(new_shape.clone());
305 b.set_shape(new_shape.clone());
306
307 ParStridedZipSimd::new(a, b)
308 }
309}
310
311/// A trait to zip two simd iterators together.
312pub trait StridedSimdIteratorZip: Sized {
313 /// Combines this iterator with another SIMD-optimized iterator, enabling simultaneous iteration.
314 ///
315 /// This method performs shape broadcasting between `self` and `other` to ensure that both iterators
316 /// iterate over tensors with compatible shapes. It adjusts the strides and shapes of both iterators
317 /// to match the broadcasted shape and then returns a `StridedZipSimd` that allows for synchronized
318 /// iteration over both iterators.
319 ///
320 /// # Arguments
321 ///
322 /// * `other` - The other iterator to zip with. It must implement the `IterGetSetSimd`, `UnindexedProducer`,
323 /// and `ParallelIterator` traits, and its associated `Item` type must be `Send`.
324 ///
325 /// # Returns
326 ///
327 /// A `StridedZipSimd` instance that zips together `self` and `other`, enabling synchronized
328 /// iteration over their elements.
329 #[track_caller]
330 fn zip<'a, C>(mut self, mut other: C) -> StridedZipSimd<'a, Self, C>
331 where
332 C: 'a + IterGetSetSimd,
333 <C as IterGetSetSimd>::Item: Send,
334 Self: IterGetSetSimd,
335 <Self as IterGetSetSimd>::Item: Send,
336 {
337 let new_shape =
338 predict_broadcast_shape(self.shape(), other.shape()).expect("Cannot broadcast shapes");
339
340 other.broadcast_set_strides(&new_shape);
341 self.broadcast_set_strides(&new_shape);
342
343 other.set_shape(new_shape.clone());
344 self.set_shape(new_shape.clone());
345
346 StridedZipSimd::new(self, other)
347 }
348}
349
350/// A trait for performing single thread simd iteration over an iterator.
351pub trait StridedIteratorSimd
352where
353 Self: Sized + IterGetSetSimd,
354{
355 /// perform simd iteration, this method is for single thread simd iterator
356 fn for_each<F, F2>(mut self, op: F, vec_op: F2)
357 where
358 F: Fn(Self::Item),
359 F2: Fn(Self::SimdItem),
360 {
361 let outer_loop_size = self.outer_loop_size();
362 let inner_loop_size = self.inner_loop_size(); // we don't need to add 1 as we didn't subtract shape by 1
363 self.set_prg(vec![0; self.shape().len()]);
364 match (self.all_last_stride_one(), self.lanes()) {
365 (true, Some(vec_size)) => {
366 let remain = inner_loop_size % vec_size;
367 let inner = inner_loop_size - remain;
368 let n = inner / vec_size;
369 let unroll = n % 4;
370 if remain > 0 {
371 if unroll == 0 {
372 for _ in 0..outer_loop_size {
373 for idx in 0..n / 4 {
374 vec_op(self.inner_loop_next_simd(idx * 4));
375 vec_op(self.inner_loop_next_simd(idx * 4 + 1));
376 vec_op(self.inner_loop_next_simd(idx * 4 + 2));
377 vec_op(self.inner_loop_next_simd(idx * 4 + 3));
378 }
379 for idx in inner..inner_loop_size {
380 op(self.inner_loop_next(idx));
381 }
382 self.next();
383 }
384 } else {
385 for _ in 0..outer_loop_size {
386 for idx in 0..n {
387 vec_op(self.inner_loop_next_simd(idx));
388 }
389 for idx in inner..inner_loop_size {
390 op(self.inner_loop_next(idx));
391 }
392 self.next();
393 }
394 }
395 } else {
396 if unroll == 0 {
397 for _ in 0..outer_loop_size {
398 for idx in 0..n / 4 {
399 vec_op(self.inner_loop_next_simd(idx * 4));
400 vec_op(self.inner_loop_next_simd(idx * 4 + 1));
401 vec_op(self.inner_loop_next_simd(idx * 4 + 2));
402 vec_op(self.inner_loop_next_simd(idx * 4 + 3));
403 }
404 self.next();
405 }
406 } else {
407 for _ in 0..outer_loop_size {
408 for idx in 0..n {
409 vec_op(self.inner_loop_next_simd(idx));
410 }
411 self.next();
412 }
413 }
414 }
415 }
416 _ => {
417 for _ in 0..outer_loop_size {
418 for idx in 0..inner_loop_size {
419 op(self.inner_loop_next(idx));
420 }
421 self.next();
422 }
423 }
424 }
425 }
426 /// perform simd iteration with init, this method is for single thread simd iterator
427 fn for_each_init<F, INIT, T>(mut self, init: INIT, func: F)
428 where
429 F: Fn(&mut T, Self::Item),
430 INIT: Fn() -> T,
431 {
432 let outer_loop_size = self.outer_loop_size();
433 let inner_loop_size = self.inner_loop_size();
434 let mut init = init();
435 for _ in 0..outer_loop_size {
436 for idx in 0..inner_loop_size {
437 func(&mut init, self.inner_loop_next(idx));
438 }
439 self.next();
440 }
441 }
442}
443
444/// A trait for performing single thread simd iteration over an iterator.
445pub trait ParStridedIteratorSimd
446where
447 Self: Sized + UnindexedProducer + IterGetSetSimd + ParallelIterator,
448{
449 /// perform simd iteration, this method is for single thread simd iterator
450 fn for_each<F, F2>(self, op: F, vec_op: F2)
451 where
452 F: Fn(<Self as IterGetSetSimd>::Item) + Sync,
453 F2: Fn(<Self as IterGetSetSimd>::SimdItem) + Sync + Send + Copy,
454 <Self as IterGetSetSimd>::SimdItem: Send,
455 <Self as IterGetSetSimd>::Item: Send,
456 {
457 let with_simd = WithSimd { base: self, vec_op };
458 with_simd.for_each(|x| {
459 op(x);
460 });
461 }
462}
463
464/// A trait to map a function on the elements of an iterator.
465pub trait StridedIteratorMap: Sized {
466 /// Transforms the strided iterators by applying a provided function to their items.
467 ///
468 /// This method allows for element-wise operations on the zipped iterators by applying `func` to each item.
469 ///
470 /// # Type Parameters
471 ///
472 /// * `'a` - The lifetime associated with the iterators.
473 /// * `F` - The function to apply to each item.
474 /// * `U` - The output type after applying the function.
475 ///
476 /// # Arguments
477 ///
478 /// * `f` - A function that takes an item from the zipped iterator and returns a transformed value.
479 ///
480 /// # Returns
481 ///
482 /// A `StridedMap` instance that applies the provided function during iteration.
483 fn map<'a, T, F, U>(self, f: F) -> StridedMap<'a, Self, T, F>
484 where
485 F: Fn(T) -> U + Sync + Send + 'a,
486 U: CommonBounds,
487 Self: IterGetSet<Item = T>,
488 {
489 StridedMap {
490 iter: self,
491 f,
492 phantom: std::marker::PhantomData,
493 }
494 }
495}
496
497pub(crate) trait StridedHelper {
498 fn _set_last_strides(&mut self, stride: i64);
499 fn _set_strides(&mut self, strides: Strides);
500 fn _set_shape(&mut self, shape: Shape);
501 fn _layout(&self) -> &Layout;
502}
503
504pub(crate) trait ParStridedHelper {
505 fn _set_last_strides(&mut self, stride: i64);
506 fn _set_strides(&mut self, strides: Strides);
507 fn _set_shape(&mut self, shape: Shape);
508 fn _layout(&self) -> &Layout;
509 fn _set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>);
510 fn _set_end_index(&mut self, end_index: usize);
511}