hpt_iterator/
par_strided_map.rs1use hpt_common::shape::shape::Shape;
2use hpt_traits::{
3 ops::creation::TensorCreator,
4 tensor::{CommonBounds, TensorInfo},
5};
6use rayon::iter::{plumbing::UnindexedProducer, ParallelIterator};
7
8use crate::{
9 iterator_traits::{IterGetSet, ShapeManipulator},
10 par_strided_map_mut::ParStridedMapMut,
11 par_strided_mut::par_strided_map_mut_simd::ParStridedMutSimd,
12};
13
14pub mod par_strided_map_simd {
16 use crate::{CommonBounds, TensorInfo};
17 use hpt_common::utils::simd_ref::MutVec;
18 use hpt_traits::ops::creation::TensorCreator;
19 use hpt_types::dtype::TypeCommon;
20 use rayon::iter::{plumbing::UnindexedProducer, ParallelIterator};
21
22 use crate::{
23 iterator_traits::{IterGetSetSimd, ParStridedIteratorSimdZip, ShapeManipulator},
24 par_strided_mut::par_strided_map_mut_simd::ParStridedMutSimd,
25 with_simd::WithSimd,
26 };
27 #[derive(Clone)]
32 pub struct ParStridedMapSimd<'a, I, T: 'a, F, F2>
33 where
34 I: UnindexedProducer<Item = T>
35 + 'a
36 + IterGetSetSimd<Item = T>
37 + ParallelIterator
38 + ShapeManipulator,
39 {
40 pub(crate) iter: I,
42 pub(crate) f: F,
44 pub(crate) f2: F2,
46 pub(crate) phantom: std::marker::PhantomData<&'a ()>,
48 }
49
50 impl<
51 'a,
52 I: UnindexedProducer<Item = T>
53 + 'a
54 + IterGetSetSimd<Item = T>
55 + ParallelIterator
56 + ShapeManipulator,
57 T: 'a,
58 F,
59 F2,
60 > ParStridedMapSimd<'a, I, T, F, F2>
61 {
62 pub fn collect<U>(self) -> U
80 where
81 F: Fn((&mut <U as TensorCreator>::Meta, <I as IterGetSetSimd>::Item))
82 + Sync
83 + Send
84 + 'a,
85 U: Clone + TensorInfo<U::Meta> + TensorCreator<Output = U>,
86 <I as IterGetSetSimd>::Item: Send,
87 <U as TensorCreator>::Meta: CommonBounds,
88 F2: Send
89 + Sync
90 + Copy
91 + Fn(
92 (
93 MutVec<'_, <<U as TensorCreator>::Meta as TypeCommon>::Vec>,
94 <I as IterGetSetSimd>::SimdItem,
95 ),
96 ),
97 {
98 let res = U::empty(self.iter.shape().clone()).unwrap();
99 let par_strided = ParStridedMutSimd::new(res.clone());
100 let zip = par_strided.zip(self.iter);
101 let with_simd = WithSimd {
102 base: zip,
103 vec_op: self.f2,
104 };
105 with_simd.for_each(|x| {
106 (self.f)(x);
107 });
108 res
109 }
110 }
111}
112
113#[derive(Clone)]
117pub struct ParStridedMap<'a, I, T: 'a, F>
118where
119 I: UnindexedProducer<Item = T> + 'a + IterGetSet<Item = T> + ParallelIterator,
120{
121 pub(crate) iter: I,
123 pub(crate) f: F,
125 pub(crate) phantom: std::marker::PhantomData<&'a ()>,
127}
128
129impl<
130 'a,
131 I: UnindexedProducer<Item = T> + 'a + IterGetSet<Item = T> + ParallelIterator,
132 T: 'a,
133 F,
134 > ParStridedMap<'a, I, T, F>
135{
136 pub fn collect<U>(self) -> U
154 where
155 F: Fn((&mut U::Meta, T)) + Sync + Send,
156 U: Clone + TensorInfo<U::Meta> + TensorCreator<Output = U>,
157 <I as IterGetSet>::Item: Send,
158 <U as TensorCreator>::Meta: CommonBounds,
159 {
160 let res = U::empty(self.iter.shape().clone()).unwrap();
161 let strided_mut = ParStridedMapMut::new(res.clone());
162 let zip = strided_mut.zip(self.iter);
163 zip.for_each(|(x, y)| {
164 (self.f)((x, y));
165 });
166 res
167 }
168}
169
170impl<'a, T: CommonBounds> ShapeManipulator for ParStridedMutSimd<'a, T> {
171 fn reshape<S: Into<hpt_common::shape::shape::Shape>>(self, shape: S) -> Self {
172 let shape: Shape = shape.into();
173 let new_base = self.base.reshape(shape);
174 ParStridedMutSimd {
175 base: new_base,
176 phantom: self.phantom,
177 }
178 }
179
180 fn transpose<AXIS: Into<hpt_common::axis::axis::Axis>>(self, axes: AXIS) -> Self {
181 let axes = axes.into();
182 let new_base = self.base.transpose(axes);
183 ParStridedMutSimd {
184 base: new_base,
185 phantom: self.phantom,
186 }
187 }
188
189 fn expand<S: Into<hpt_common::shape::shape::Shape>>(self, shape: S) -> Self {
190 let shape: Shape = shape.into();
191 let new_base = self.base.expand(shape);
192 ParStridedMutSimd {
193 base: new_base,
194 phantom: self.phantom,
195 }
196 }
197}