hpt_iterator/
par_strided_map.rs1use hpt_common::shape::shape::Shape;
2use hpt_traits::tensor::{CommonBounds, TensorAlloc, TensorInfo};
3use rayon::iter::{plumbing::UnindexedProducer, ParallelIterator};
4
5use crate::{
6 iterator_traits::{IterGetSet, ShapeManipulator},
7 par_strided_map_mut::ParStridedMapMut,
8 par_strided_mut::par_strided_map_mut_simd::ParStridedMutSimd,
9};
10
11pub mod par_strided_map_simd {
13 use hpt_common::utils::simd_ref::MutVec;
14 use hpt_traits::{CommonBounds, TensorAlloc, TensorInfo};
15 use hpt_types::dtype::TypeCommon;
16 use rayon::iter::{plumbing::UnindexedProducer, ParallelIterator};
17
18 use crate::{
19 iterator_traits::{IterGetSetSimd, ParStridedIteratorSimdZip, ShapeManipulator},
20 par_strided_mut::par_strided_map_mut_simd::ParStridedMutSimd,
21 with_simd::WithSimd,
22 };
23 #[derive(Clone)]
28 pub struct ParStridedMapSimd<'a, I, T: 'a, F, F2>
29 where
30 I: UnindexedProducer<Item = T>
31 + 'a
32 + IterGetSetSimd<Item = T>
33 + ParallelIterator
34 + ShapeManipulator,
35 {
36 pub(crate) iter: I,
38 pub(crate) f: F,
40 pub(crate) f2: F2,
42 pub(crate) phantom: std::marker::PhantomData<&'a ()>,
44 }
45
46 impl<
47 'a,
48 I: UnindexedProducer<Item = T>
49 + 'a
50 + IterGetSetSimd<Item = T>
51 + ParallelIterator
52 + ShapeManipulator,
53 T: 'a,
54 F,
55 F2,
56 > ParStridedMapSimd<'a, I, T, F, F2>
57 {
58 pub fn collect<U>(self) -> U
76 where
77 F: Fn((&mut <U as TensorAlloc>::Meta, <I as IterGetSetSimd>::Item)) + Sync + Send + 'a,
78 U: Clone + TensorInfo<U::Meta> + TensorAlloc,
79 <I as IterGetSetSimd>::Item: Send,
80 <U as TensorAlloc>::Meta: CommonBounds,
81 F2: Send
82 + Sync
83 + Copy
84 + Fn(
85 (
86 MutVec<'_, <<U as TensorAlloc>::Meta as TypeCommon>::Vec>,
87 <I as IterGetSetSimd>::SimdItem,
88 ),
89 ),
90 {
91 let res = U::_empty(self.iter.shape().clone()).unwrap();
92 let par_strided = ParStridedMutSimd::new(res.clone());
93 let zip = par_strided.zip(self.iter);
94 let with_simd = WithSimd {
95 base: zip,
96 vec_op: self.f2,
97 };
98 with_simd.for_each(|x| {
99 (self.f)(x);
100 });
101 res
102 }
103 }
104}
105
106#[derive(Clone)]
110pub struct ParStridedMap<'a, I, T: 'a, F>
111where
112 I: UnindexedProducer<Item = T> + 'a + IterGetSet<Item = T> + ParallelIterator,
113{
114 pub(crate) iter: I,
116 pub(crate) f: F,
118 pub(crate) phantom: std::marker::PhantomData<&'a ()>,
120}
121
122impl<
123 'a,
124 I: UnindexedProducer<Item = T> + 'a + IterGetSet<Item = T> + ParallelIterator,
125 T: 'a,
126 F,
127 > ParStridedMap<'a, I, T, F>
128{
129 pub fn collect<U>(self) -> U
147 where
148 F: Fn((&mut U::Meta, T)) + Sync + Send,
149 U: Clone + TensorInfo<U::Meta> + TensorAlloc,
150 <I as IterGetSet>::Item: Send,
151 <U as TensorAlloc>::Meta: CommonBounds,
152 {
153 let res = U::_empty(self.iter.shape().clone()).unwrap();
154 let strided_mut = ParStridedMapMut::new(res.clone());
155 let zip = strided_mut.zip(self.iter);
156 zip.for_each(|(x, y)| {
157 (self.f)((x, y));
158 });
159 res
160 }
161}
162
163impl<'a, T: CommonBounds> ShapeManipulator for ParStridedMutSimd<'a, T> {
164 fn reshape<S: Into<hpt_common::shape::shape::Shape>>(self, shape: S) -> Self {
165 let shape: Shape = shape.into();
166 let new_base = self.base.reshape(shape);
167 ParStridedMutSimd {
168 base: new_base,
169 phantom: self.phantom,
170 }
171 }
172
173 fn transpose<AXIS: Into<hpt_common::axis::axis::Axis>>(self, axes: AXIS) -> Self {
174 let axes = axes.into();
175 let new_base = self.base.transpose(axes);
176 ParStridedMutSimd {
177 base: new_base,
178 phantom: self.phantom,
179 }
180 }
181
182 fn expand<S: Into<hpt_common::shape::shape::Shape>>(self, shape: S) -> Self {
183 let shape: Shape = shape.into();
184 let new_base = self.base.expand(shape);
185 ParStridedMutSimd {
186 base: new_base,
187 phantom: self.phantom,
188 }
189 }
190}