1use crate::{
2 iterator_traits::{IterGetSet, ParStridedHelper, ParStridedIteratorZip, ShapeManipulator},
3 par_strided::ParStrided,
4 shape_manipulate::{par_expand, par_reshape, par_transpose},
5};
6use hpt_common::shape::shape::Shape;
7use hpt_traits::tensor::{CommonBounds, TensorInfo};
8use rayon::iter::{
9 plumbing::{bridge_unindexed, Folder, UnindexedConsumer, UnindexedProducer},
10 ParallelIterator,
11};
12use std::sync::Arc;
13
14pub mod par_strided_map_mut_simd {
16 use crate::{
17 iterator_traits::{IterGetSetSimd, ParStridedIteratorSimd, ParStridedIteratorSimdZip},
18 par_strided::par_strided_simd::ParStridedSimd,
19 };
20 use crate::{CommonBounds, TensorInfo};
21 use hpt_common::{shape::shape::Shape, utils::pointer::Pointer, utils::simd_ref::MutVec};
22 use hpt_types::dtype::TypeCommon;
23 use hpt_types::traits::VecTrait;
24 use rayon::iter::{
25 plumbing::{bridge_unindexed, Folder, UnindexedConsumer, UnindexedProducer},
26 ParallelIterator,
27 };
28 use std::sync::Arc;
    /// Mutable, SIMD-aware parallel strided iterator yielding `&'a mut T`.
    ///
    /// Wraps a [`ParStridedSimd`] cursor; the `PhantomData<&'a ()>` carries the
    /// lifetime of the borrowed tensor data without storing a reference, so the
    /// handed-out `&'a mut T` items cannot outlive the source tensor.
    pub struct ParStridedMutSimd<'a, T: TypeCommon + Send + Copy + Sync> {
        // Underlying SIMD-capable cursor owning the pointer/shape/stride state.
        pub(crate) base: ParStridedSimd<T>,
        // Ties this iterator (and its items) to lifetime 'a.
        pub(crate) phantom: std::marker::PhantomData<&'a ()>,
    }
40
41 impl<'a, T: CommonBounds> ParStridedMutSimd<'a, T> {
42 pub fn new<U: TensorInfo<T>>(tensor: U) -> Self {
56 ParStridedMutSimd {
57 base: ParStridedSimd::new(tensor),
58 phantom: std::marker::PhantomData,
59 }
60 }
61 }
62
    // Marker impls: opt this iterator into the crate's zip/SIMD combinator APIs.
    // Both bodies are empty, so all methods come from the traits' default impls.
    impl<'a, T: CommonBounds> ParStridedIteratorSimdZip for ParStridedMutSimd<'a, T> {}
    impl<'a, T: CommonBounds> ParStridedIteratorSimd for ParStridedMutSimd<'a, T> {}
65
    impl<'a, T> ParallelIterator for ParStridedMutSimd<'a, T>
    where
        T: CommonBounds,
        T::Vec: Send,
    {
        type Item = &'a mut T;

        /// Hands this producer to rayon's unindexed bridge; the actual
        /// work-splitting is implemented by the `UnindexedProducer` impl.
        fn drive_unindexed<C>(self, consumer: C) -> C::Result
        where
            C: UnindexedConsumer<Self::Item>,
        {
            bridge_unindexed(self, consumer)
        }
    }
80
81 impl<'a, T> UnindexedProducer for ParStridedMutSimd<'a, T>
82 where
83 T: CommonBounds,
84 T::Vec: Send,
85 {
86 type Item = &'a mut T;
87
88 fn split(self) -> (Self, Option<Self>) {
89 let (a, b) = self.base.split();
90 (
91 ParStridedMutSimd {
92 base: a,
93 phantom: std::marker::PhantomData,
94 },
95 b.map(|x| ParStridedMutSimd {
96 base: x,
97 phantom: std::marker::PhantomData,
98 }),
99 )
100 }
101
102 fn fold_with<F>(self, folder: F) -> F
103 where
104 F: Folder<Self::Item>,
105 {
106 folder
107 }
108 }
109
    /// Element/SIMD access for the mutable SIMD iterator.
    ///
    /// Everything except the two `inner_loop_next*` methods is pure delegation
    /// to the wrapped [`ParStridedSimd`] cursor.
    impl<'a, T: 'a> IterGetSetSimd for ParStridedMutSimd<'a, T>
    where
        T: CommonBounds,
        T::Vec: Send,
    {
        type Item = &'a mut T;

        // A mutable view over one SIMD vector's worth of elements.
        type SimdItem
            = MutVec<'a, T::Vec>
        where
            Self: 'a;

        fn set_end_index(&mut self, end_index: usize) {
            self.base.set_end_index(end_index);
        }

        fn set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>) {
            self.base.set_intervals(intervals);
        }

        fn set_strides(&mut self, strides: hpt_common::strides::strides::Strides) {
            self.base.set_strides(strides);
        }

        fn set_shape(&mut self, shape: Shape) {
            self.base.set_shape(shape);
        }

        // prg: presumably per-dimension progress counters for the strided walk
        // — TODO confirm against ParStridedSimd::set_prg.
        fn set_prg(&mut self, prg: Vec<i64>) {
            self.base.set_prg(prg);
        }

        fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
            self.base.intervals()
        }

        fn strides(&self) -> &hpt_common::strides::strides::Strides {
            self.base.strides()
        }

        fn shape(&self) -> &Shape {
            self.base.shape()
        }

        fn layout(&self) -> &hpt_common::layout::layout::Layout {
            self.base.layout()
        }

        fn broadcast_set_strides(&mut self, shape: &Shape) {
            self.base.broadcast_set_strides(shape);
        }

        fn outer_loop_size(&self) -> usize {
            self.base.outer_loop_size()
        }

        fn inner_loop_size(&self) -> usize {
            self.base.inner_loop_size()
        }

        fn next(&mut self) {
            self.base.next();
        }

        fn next_simd(&mut self) {
            self.base.next_simd();
        }

        /// Returns a mutable reference to the element at `index` within the
        /// current inner loop, stepping by the innermost stride.
        // SAFETY (caller-upheld): index * last_stride must stay inside the
        // allocation, and no aliasing &mut T for the same element may be live.
        fn inner_loop_next(&mut self, index: usize) -> Self::Item {
            unsafe {
                self.base
                    .ptr
                    .get_ptr()
                    .add(index * (self.base.last_stride as usize))
                    .as_mut()
                    .unwrap()
            }
        }

        /// Returns a mutable SIMD-lane view at vector offset `index`.
        // The T pointer is advanced by whole vectors (index * T::Vec::SIZE
        // lanes) and reinterpreted as *mut T::Vec. NOTE(review): assumes the
        // buffer is adequately aligned for T::Vec — confirm allocator
        // guarantees. Under the bound_check feature the Pointer also carries
        // the vector length for range checking.
        #[inline(always)]
        fn inner_loop_next_simd(&mut self, index: usize) -> Self::SimdItem {
            unsafe {
                let ptr = self.base.ptr.get_ptr().add(index * T::Vec::SIZE) as *mut T::Vec;
                #[cfg(feature = "bound_check")]
                return MutVec::new(Pointer::new(ptr, T::Vec::SIZE as i64));
                #[cfg(not(feature = "bound_check"))]
                return MutVec::new(Pointer::new(ptr));
            }
        }

        fn all_last_stride_one(&self) -> bool {
            self.base.all_last_stride_one()
        }

        fn lanes(&self) -> Option<usize> {
            self.base.lanes()
        }
    }
208}
209
/// Mutable (non-SIMD) parallel strided iterator yielding `&'a mut T`.
///
/// Wraps a [`ParStrided`] cursor; the `PhantomData<&'a ()>` carries the
/// lifetime of the borrowed tensor data so handed-out items cannot outlive it.
pub struct ParStridedMut<'a, T> {
    // Underlying strided cursor owning the pointer/shape/stride state.
    pub(crate) base: ParStrided<T>,
    // Ties this iterator (and its items) to lifetime 'a.
    pub(crate) phantom: std::marker::PhantomData<&'a ()>,
}
221
/// Internal layout plumbing used by the shape-manipulation free functions;
/// every method delegates straight to the wrapped [`ParStrided`].
impl<'a, T: CommonBounds> ParStridedHelper for ParStridedMut<'a, T> {
    fn _set_last_strides(&mut self, stride: i64) {
        self.base._set_last_strides(stride);
    }

    fn _set_strides(&mut self, strides: hpt_common::strides::strides::Strides) {
        self.base._set_strides(strides);
    }

    fn _set_shape(&mut self, shape: Shape) {
        self.base._set_shape(shape);
    }

    fn _layout(&self) -> &hpt_common::layout::layout::Layout {
        self.base._layout()
    }

    fn _set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>) {
        self.base._set_intervals(intervals);
    }

    fn _set_end_index(&mut self, end_index: usize) {
        self.base._set_end_index(end_index);
    }
}
247
/// Shape manipulation for the mutable iterator; each method forwards to the
/// shared `par_*` free functions (which operate via `ParStridedHelper`).
impl<'a, T: CommonBounds> ShapeManipulator for ParStridedMut<'a, T> {
    fn reshape<S: Into<Shape>>(self, shape: S) -> Self {
        par_reshape(self, shape)
    }

    fn transpose<AXIS: Into<hpt_common::axis::axis::Axis>>(self, axes: AXIS) -> Self {
        par_transpose(self, axes)
    }

    fn expand<S: Into<Shape>>(self, shape: S) -> Self {
        par_expand(self, shape)
    }
}
261
// Marker impl: opts this iterator into the crate's zip combinators
// (empty body — all methods come from the trait's defaults).
impl<'a, T: CommonBounds> ParStridedIteratorZip for ParStridedMut<'a, T> {}
263
264impl<'a, T: CommonBounds> ParStridedMut<'a, T> {
265 pub fn new<U: TensorInfo<T>>(tensor: U) -> Self {
278 ParStridedMut {
279 base: ParStrided::new(tensor),
280 phantom: std::marker::PhantomData,
281 }
282 }
283}
284
impl<'a, T> ParallelIterator for ParStridedMut<'a, T>
where
    T: CommonBounds,
{
    type Item = &'a mut T;

    /// Hands this producer to rayon's unindexed bridge; the actual
    /// work-splitting is implemented by the `UnindexedProducer` impl.
    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: UnindexedConsumer<Self::Item>,
    {
        bridge_unindexed(self, consumer)
    }
}
298
299impl<'a, T> UnindexedProducer for ParStridedMut<'a, T>
300where
301 T: CommonBounds,
302{
303 type Item = &'a mut T;
304
305 fn split(self) -> (Self, Option<Self>) {
306 let (a, b) = self.base.split();
307 (
308 ParStridedMut {
309 base: a,
310 phantom: std::marker::PhantomData,
311 },
312 b.map(|x| ParStridedMut {
313 base: x,
314 phantom: std::marker::PhantomData,
315 }),
316 )
317 }
318
319 fn fold_with<F>(self, folder: F) -> F
320 where
321 F: Folder<Self::Item>,
322 {
323 folder
324 }
325}
326
/// Element access for the mutable iterator.
///
/// Everything except `inner_loop_next` is pure delegation to the wrapped
/// [`ParStrided`] cursor.
impl<'a, T: 'a> IterGetSet for ParStridedMut<'a, T>
where
    T: CommonBounds,
{
    type Item = &'a mut T;

    fn set_end_index(&mut self, end_index: usize) {
        self.base.set_end_index(end_index);
    }

    fn set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>) {
        self.base.set_intervals(intervals);
    }

    fn set_strides(&mut self, strides: hpt_common::strides::strides::Strides) {
        self.base.set_strides(strides);
    }

    fn set_shape(&mut self, shape: Shape) {
        self.base.set_shape(shape);
    }

    // prg: presumably per-dimension progress counters for the strided walk
    // — TODO confirm against ParStrided::set_prg.
    fn set_prg(&mut self, prg: Vec<i64>) {
        self.base.set_prg(prg);
    }

    fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
        self.base.intervals()
    }

    fn strides(&self) -> &hpt_common::strides::strides::Strides {
        self.base.strides()
    }

    fn shape(&self) -> &Shape {
        self.base.shape()
    }

    fn layout(&self) -> &hpt_common::layout::layout::Layout {
        self.base.layout()
    }

    fn broadcast_set_strides(&mut self, shape: &Shape) {
        self.base.broadcast_set_strides(shape);
    }

    fn outer_loop_size(&self) -> usize {
        self.base.outer_loop_size()
    }

    fn inner_loop_size(&self) -> usize {
        self.base.inner_loop_size()
    }

    fn next(&mut self) {
        self.base.next();
    }

    /// Returns a mutable reference to the element at `index` within the
    /// current inner loop, stepping by the innermost stride.
    // SAFETY (caller-upheld): index * last_stride must stay inside the
    // allocation, and no aliasing &mut T for the same element may be live.
    fn inner_loop_next(&mut self, index: usize) -> Self::Item {
        unsafe {
            self.base
                .ptr
                .get_ptr()
                .add(index * (self.base.last_stride as usize))
                .as_mut()
                .unwrap()
        }
    }
}