1use std::sync::Arc;
2
3use hpt_common::{shape::shape::Shape, strides::strides::Strides};
4use hpt_traits::tensor::{CommonBounds, TensorInfo};
5use rayon::iter::{
6    plumbing::{bridge_unindexed, Folder, UnindexedConsumer, UnindexedProducer},
7    ParallelIterator,
8};
9
10use crate::{
11    iterator_traits::IterGetSet, par_strided_mut::ParStridedMut, par_strided_zip::ParStridedZip,
12};
13
14pub mod par_strided_map_mut_simd {
16    use std::sync::Arc;
17
18    use hpt_common::{shape::shape::Shape, strides::strides::Strides, utils::simd_ref::MutVec};
19    use hpt_traits::{CommonBounds, TensorInfo};
20    use hpt_types::dtype::TypeCommon;
21    use rayon::iter::{
22        plumbing::{bridge_unindexed, Folder, UnindexedConsumer, UnindexedProducer},
23        ParallelIterator,
24    };
25
26    use crate::{
27        iterator_traits::IterGetSetSimd,
28        par_strided_mut::par_strided_map_mut_simd::ParStridedMutSimd,
29        par_strided_zip::par_strided_zip_simd::ParStridedZipSimd,
30    };
31
32    pub struct ParStridedMapMutSimd<'a, T>
36    where
37        T: TypeCommon + Send + Copy + Sync,
38    {
39        pub(crate) base: ParStridedMutSimd<'a, T>,
41        pub(crate) phantom: std::marker::PhantomData<&'a ()>,
43    }
44
45    impl<'a, T> ParStridedMapMutSimd<'a, T>
46    where
47        T: CommonBounds,
48        T::Vec: Send,
49    {
50        pub fn new<U: TensorInfo<T>>(res_tensor: U) -> Self {
68            ParStridedMapMutSimd {
69                base: ParStridedMutSimd::new(res_tensor),
70                phantom: std::marker::PhantomData,
71            }
72        }
73        pub fn zip<C>(self, other: C) -> ParStridedZipSimd<'a, Self, C>
96        where
97            C: UnindexedProducer + 'a + IterGetSetSimd + ParallelIterator,
98            <C as IterGetSetSimd>::Item: Send,
99            T::Vec: Send,
100        {
101            ParStridedZipSimd::new(self, other)
102        }
103    }
104
105    impl<'a, T> ParallelIterator for ParStridedMapMutSimd<'a, T>
106    where
107        T: 'a + CommonBounds,
108        T::Vec: Send,
109    {
110        type Item = &'a mut T;
111
112        fn drive_unindexed<C>(self, consumer: C) -> C::Result
113        where
114            C: UnindexedConsumer<Self::Item>,
115        {
116            bridge_unindexed(self, consumer)
117        }
118    }
119
120    impl<'a, T> UnindexedProducer for ParStridedMapMutSimd<'a, T>
121    where
122        T: 'a + CommonBounds,
123        T::Vec: Send,
124    {
125        type Item = &'a mut T;
126
127        fn split(self) -> (Self, Option<Self>) {
128            let (a, b) = self.base.split();
129            (
130                ParStridedMapMutSimd {
131                    base: a,
132                    phantom: std::marker::PhantomData,
133                },
134                b.map(|x| ParStridedMapMutSimd {
135                    base: x,
136                    phantom: std::marker::PhantomData,
137                }),
138            )
139        }
140
141        fn fold_with<F>(self, folder: F) -> F
142        where
143            F: Folder<Self::Item>,
144        {
145            folder
146        }
147    }
148
149    impl<'a, T: 'a + CommonBounds> IterGetSetSimd for ParStridedMapMutSimd<'a, T>
150    where
151        T::Vec: Send,
152    {
153        type Item = &'a mut T;
154
155        type SimdItem = MutVec<'a, T::Vec>;
156
157        fn set_end_index(&mut self, end_index: usize) {
158            self.base.set_end_index(end_index);
159        }
160
161        fn set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>) {
162            self.base.set_intervals(intervals);
163        }
164
165        fn set_strides(&mut self, strides: Strides) {
166            self.base.set_strides(strides);
167        }
168
169        fn set_shape(&mut self, shape: Shape) {
170            self.base.set_shape(shape);
171        }
172
173        fn set_prg(&mut self, prg: Vec<i64>) {
174            self.base.set_prg(prg);
175        }
176
177        fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
178            self.base.intervals()
179        }
180
181        fn strides(&self) -> &Strides {
182            self.base.strides()
183        }
184
185        fn shape(&self) -> &Shape {
186            self.base.shape()
187        }
188
189        fn layout(&self) -> &hpt_common::layout::layout::Layout {
190            self.base.layout()
191        }
192
193        fn broadcast_set_strides(&mut self, shape: &Shape) {
194            self.base.broadcast_set_strides(shape);
195        }
196
197        fn outer_loop_size(&self) -> usize {
198            self.base.outer_loop_size()
199        }
200
201        fn inner_loop_size(&self) -> usize {
202            self.base.inner_loop_size()
203        }
204
205        fn next(&mut self) {
206            self.base.next();
207        }
208
209        fn next_simd(&mut self) {
210            self.base.next_simd();
211        }
212
213        fn inner_loop_next(&mut self, index: usize) -> Self::Item {
214            self.base.inner_loop_next(index)
215        }
216
217        fn inner_loop_next_simd(&mut self, index: usize) -> Self::SimdItem {
218            self.base.inner_loop_next_simd(index)
219        }
220
221        fn all_last_stride_one(&self) -> bool {
222            self.base.all_last_stride_one()
223        }
224
225        fn lanes(&self) -> Option<usize> {
226            self.base.lanes()
227        }
228    }
229}
230
231pub struct ParStridedMapMut<'a, T>
235where
236    T: Copy,
237{
238    pub(crate) base: ParStridedMut<'a, T>,
240    pub(crate) phantom: std::marker::PhantomData<&'a ()>,
242}
243
244impl<'a, T> ParStridedMapMut<'a, T>
245where
246    T: CommonBounds,
247{
248    pub fn new<U: TensorInfo<T>>(res_tensor: U) -> Self {
265        ParStridedMapMut {
266            base: ParStridedMut::new(res_tensor),
267            phantom: std::marker::PhantomData,
268        }
269    }
270    pub fn zip<C>(self, other: C) -> ParStridedZip<'a, Self, C>
292    where
293        C: UnindexedProducer + 'a + IterGetSet + ParallelIterator,
294        <C as IterGetSet>::Item: Send,
295    {
296        ParStridedZip::new(self, other)
297    }
298}
299
300impl<'a, T> ParallelIterator for ParStridedMapMut<'a, T>
301where
302    T: 'a + CommonBounds,
303{
304    type Item = &'a mut T;
305
306    fn drive_unindexed<C>(self, consumer: C) -> C::Result
307    where
308        C: UnindexedConsumer<Self::Item>,
309    {
310        bridge_unindexed(self, consumer)
311    }
312}
313
314impl<'a, T> UnindexedProducer for ParStridedMapMut<'a, T>
315where
316    T: 'a + CommonBounds,
317{
318    type Item = &'a mut T;
319
320    fn split(self) -> (Self, Option<Self>) {
321        let (a, b) = self.base.split();
322        (
323            ParStridedMapMut {
324                base: a,
325                phantom: std::marker::PhantomData,
326            },
327            b.map(|x| ParStridedMapMut {
328                base: x,
329                phantom: std::marker::PhantomData,
330            }),
331        )
332    }
333
334    fn fold_with<F>(self, folder: F) -> F
335    where
336        F: Folder<Self::Item>,
337    {
338        folder
339    }
340}
341
342impl<'a, T: 'a + CommonBounds> IterGetSet for ParStridedMapMut<'a, T> {
343    type Item = &'a mut T;
344
345    fn set_end_index(&mut self, end_index: usize) {
346        self.base.set_end_index(end_index);
347    }
348
349    fn set_intervals(&mut self, intervals: Arc<Vec<(usize, usize)>>) {
350        self.base.set_intervals(intervals);
351    }
352
353    fn set_strides(&mut self, strides: Strides) {
354        self.base.set_strides(strides);
355    }
356
357    fn set_shape(&mut self, shape: Shape) {
358        self.base.set_shape(shape);
359    }
360
361    fn set_prg(&mut self, prg: Vec<i64>) {
362        self.base.set_prg(prg);
363    }
364
365    fn intervals(&self) -> &Arc<Vec<(usize, usize)>> {
366        self.base.intervals()
367    }
368
369    fn strides(&self) -> &Strides {
370        self.base.strides()
371    }
372    fn shape(&self) -> &Shape {
373        self.base.shape()
374    }
375    fn layout(&self) -> &hpt_common::layout::layout::Layout {
376        self.base.layout()
377    }
378    fn broadcast_set_strides(&mut self, shape: &Shape) {
379        self.base.broadcast_set_strides(shape);
380    }
381
382    fn outer_loop_size(&self) -> usize {
383        self.base.outer_loop_size()
384    }
385
386    fn inner_loop_size(&self) -> usize {
387        self.base.inner_loop_size()
388    }
389
390    fn next(&mut self) {
391        self.base.next();
392    }
393
394    fn inner_loop_next(&mut self, index: usize) -> Self::Item {
395        self.base.inner_loop_next(index)
396    }
397}