1use std::sync::Arc;
2
3use hpt_common::{shape::shape::Shape, strides::strides::Strides};
4use hpt_traits::tensor::{CommonBounds, TensorInfo};
5use rayon::iter::{
6 plumbing::{bridge_unindexed, Folder, UnindexedConsumer, UnindexedProducer},
7 ParallelIterator,
8};
9
10use crate::{
11 iterator_traits::IterGetSet, par_strided_mut::ParStridedMut, par_strided_zip::ParStridedZip,
12};
13
14pub mod par_strided_map_mut_simd {
16 use std::sync::Arc;
17
18 use crate::{CommonBounds, TensorInfo};
19 use hpt_common::{shape::shape::Shape, strides::strides::Strides, utils::simd_ref::MutVec};
20 use hpt_types::dtype::TypeCommon;
21 use rayon::iter::{
22 plumbing::{bridge_unindexed, Folder, UnindexedConsumer, UnindexedProducer},
23 ParallelIterator,
24 };
25
26 use crate::{
27 iterator_traits::IterGetSetSimd,
28 par_strided_mut::par_strided_map_mut_simd::ParStridedMutSimd,
29 par_strided_zip::par_strided_zip_simd::ParStridedZipSimd,
30 };
31
32 pub struct ParStridedMapMutSimd<'a, T>
36 where
37 T: TypeCommon + Send + Copy + Sync,
38 {
39 pub(crate) base: ParStridedMutSimd<'a, T>,
41 pub(crate) phantom: std::marker::PhantomData<&'a ()>,
43 }
44
45 impl<'a, T> ParStridedMapMutSimd<'a, T>
46 where
47 T: CommonBounds,
48 T::Vec: Send,
49 {
50 pub fn new<U: TensorInfo<T>>(res_tensor: U) -> Self {
68 ParStridedMapMutSimd {
69 base: ParStridedMutSimd::new(res_tensor),
70 phantom: std::marker::PhantomData,
71 }
72 }
73 pub fn zip<C>(self, other: C) -> ParStridedZipSimd<'a, Self, C>
96 where
97 C: UnindexedProducer + 'a + IterGetSetSimd + ParallelIterator,
98 <C as IterGetSetSimd>::Item: Send,
99 T::Vec: Send,
100 {
101 ParStridedZipSimd::new(self, other)
102 }
103 }
104
105 impl<'a, T> ParallelIterator for ParStridedMapMutSimd<'a, T>
106 where
107 T: 'a + CommonBounds,
108 T::Vec: Send,
109 {
110 type Item = &'a mut T;
111
112 fn drive_unindexed<C>(self, consumer: C) -> C::Result
113 where
114 C: UnindexedConsumer<Self::Item>,
115 {
116 bridge_unindexed(self, consumer)
117 }
118 }
119
120 impl<'a, T> UnindexedProducer for ParStridedMapMutSimd<'a, T>
121 where
122 T: 'a + CommonBounds,
123 T::Vec: Send,
124 {
125 type Item = &'a mut T;
126
127 fn split(self) -> (Self, Option<Self>) {
128 let (a, b) = self.base.split();
129 (
130 ParStridedMapMutSimd {
131 base: a,
132 phantom: std::marker::PhantomData,
133 },
134 b.map(|x| ParStridedMapMutSimd {
135 base: x,
136 phantom: std::marker::PhantomData,
137 }),
138 )
139 }
140
141 fn fold_with<F>(self, folder: F) -> F
142 where
143 F: Folder<Self::Item>,
144 {
145 folder
146 }
147 }
148
149 impl<'a, T: 'a + CommonBounds> IterGetSetSimd for ParStridedMapMutSimd<'a, T>
150 where
151 T::Vec: Send,
152 {
153 type Item = &'a mut T;
154
155 type SimdItem = MutVec<'a, T::Vec>;
156
157 fn set_end_index(&mut self, end_index: usize) {
158 self.base.set_end_index(end_index);
159 }
160
161 fn set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>) {
162 self.base.set_intervals(intervals);
163 }
164
165 fn set_strides(&mut self, strides: Strides) {
166 self.base.set_strides(strides);
167 }
168
169 fn set_shape(&mut self, shape: Shape) {
170 self.base.set_shape(shape);
171 }
172
173 fn set_prg(&mut self, prg: Vec<i64>) {
174 self.base.set_prg(prg);
175 }
176
177 fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
178 self.base.intervals()
179 }
180
181 fn strides(&self) -> &Strides {
182 self.base.strides()
183 }
184
185 fn shape(&self) -> &Shape {
186 self.base.shape()
187 }
188
189 fn layout(&self) -> &hpt_common::layout::layout::Layout {
190 self.base.layout()
191 }
192
193 fn broadcast_set_strides(&mut self, shape: &Shape) {
194 self.base.broadcast_set_strides(shape);
195 }
196
197 fn outer_loop_size(&self) -> usize {
198 self.base.outer_loop_size()
199 }
200
201 fn inner_loop_size(&self) -> usize {
202 self.base.inner_loop_size()
203 }
204
205 fn next(&mut self) {
206 self.base.next();
207 }
208
209 fn next_simd(&mut self) {
210 self.base.next_simd();
211 }
212
213 fn inner_loop_next(&mut self, index: usize) -> Self::Item {
214 self.base.inner_loop_next(index)
215 }
216
217 fn inner_loop_next_simd(&mut self, index: usize) -> Self::SimdItem {
218 self.base.inner_loop_next_simd(index)
219 }
220
221 fn all_last_stride_one(&self) -> bool {
222 self.base.all_last_stride_one()
223 }
224
225 fn lanes(&self) -> Option<usize> {
226 self.base.lanes()
227 }
228 }
229}
230
231pub struct ParStridedMapMut<'a, T>
235where
236 T: Copy,
237{
238 pub(crate) base: ParStridedMut<'a, T>,
240 pub(crate) phantom: std::marker::PhantomData<&'a ()>,
242}
243
244impl<'a, T> ParStridedMapMut<'a, T>
245where
246 T: CommonBounds,
247{
248 pub fn new<U: TensorInfo<T>>(res_tensor: U) -> Self {
265 ParStridedMapMut {
266 base: ParStridedMut::new(res_tensor),
267 phantom: std::marker::PhantomData,
268 }
269 }
270 pub fn zip<C>(self, other: C) -> ParStridedZip<'a, Self, C>
292 where
293 C: UnindexedProducer + 'a + IterGetSet + ParallelIterator,
294 <C as IterGetSet>::Item: Send,
295 {
296 ParStridedZip::new(self, other)
297 }
298}
299
300impl<'a, T> ParallelIterator for ParStridedMapMut<'a, T>
301where
302 T: 'a + CommonBounds,
303{
304 type Item = &'a mut T;
305
306 fn drive_unindexed<C>(self, consumer: C) -> C::Result
307 where
308 C: UnindexedConsumer<Self::Item>,
309 {
310 bridge_unindexed(self, consumer)
311 }
312}
313
314impl<'a, T> UnindexedProducer for ParStridedMapMut<'a, T>
315where
316 T: 'a + CommonBounds,
317{
318 type Item = &'a mut T;
319
320 fn split(self) -> (Self, Option<Self>) {
321 let (a, b) = self.base.split();
322 (
323 ParStridedMapMut {
324 base: a,
325 phantom: std::marker::PhantomData,
326 },
327 b.map(|x| ParStridedMapMut {
328 base: x,
329 phantom: std::marker::PhantomData,
330 }),
331 )
332 }
333
334 fn fold_with<F>(self, folder: F) -> F
335 where
336 F: Folder<Self::Item>,
337 {
338 folder
339 }
340}
341
342impl<'a, T: 'a + CommonBounds> IterGetSet for ParStridedMapMut<'a, T> {
343 type Item = &'a mut T;
344
345 fn set_end_index(&mut self, end_index: usize) {
346 self.base.set_end_index(end_index);
347 }
348
349 fn set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>) {
350 self.base.set_intervals(intervals);
351 }
352
353 fn set_strides(&mut self, strides: Strides) {
354 self.base.set_strides(strides);
355 }
356
357 fn set_shape(&mut self, shape: Shape) {
358 self.base.set_shape(shape);
359 }
360
361 fn set_prg(&mut self, prg: Vec<i64>) {
362 self.base.set_prg(prg);
363 }
364
365 fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
366 self.base.intervals()
367 }
368
369 fn strides(&self) -> &Strides {
370 self.base.strides()
371 }
372 fn shape(&self) -> &Shape {
373 self.base.shape()
374 }
375 fn layout(&self) -> &hpt_common::layout::layout::Layout {
376 self.base.layout()
377 }
378 fn broadcast_set_strides(&mut self, shape: &Shape) {
379 self.base.broadcast_set_strides(shape);
380 }
381
382 fn outer_loop_size(&self) -> usize {
383 self.base.outer_loop_size()
384 }
385
386 fn inner_loop_size(&self) -> usize {
387 self.base.inner_loop_size()
388 }
389
390 fn next(&mut self) {
391 self.base.next();
392 }
393
394 fn inner_loop_next(&mut self, index: usize) -> Self::Item {
395 self.base.inner_loop_next(index)
396 }
397}