// hpt_iterator/par_strided_map.rs

1use hpt_common::shape::shape::Shape;
2use hpt_traits::tensor::{CommonBounds, TensorAlloc, TensorInfo};
3use rayon::iter::{plumbing::UnindexedProducer, ParallelIterator};
4
5use crate::{
6    iterator_traits::{IterGetSet, ShapeManipulator},
7    par_strided_map_mut::ParStridedMapMut,
8    par_strided_mut::par_strided_map_mut_simd::ParStridedMutSimd,
9};
10
11/// A module for parallel strided map iterator.
12pub mod par_strided_map_simd {
13    use hpt_common::utils::simd_ref::MutVec;
14    use hpt_traits::{CommonBounds, TensorAlloc, TensorInfo};
15    use hpt_types::dtype::TypeCommon;
16    use rayon::iter::{plumbing::UnindexedProducer, ParallelIterator};
17
18    use crate::{
19        iterator_traits::{IterGetSetSimd, ParStridedIteratorSimdZip, ShapeManipulator},
20        par_strided_mut::par_strided_map_mut_simd::ParStridedMutSimd,
21        with_simd::WithSimd,
22    };
23    /// A parallel SIMD-optimized map iterator over tensor elements.
24    ///
25    /// This struct allows for applying two separate functions (`f` and `f2`) to elements of a tensor
26    /// in a SIMD-optimized and parallel manner.
27    #[derive(Clone)]
28    pub struct ParStridedMapSimd<'a, I, T: 'a, F, F2>
29    where
30        I: UnindexedProducer<Item = T>
31            + 'a
32            + IterGetSetSimd<Item = T>
33            + ParallelIterator
34            + ShapeManipulator,
35    {
36        /// The underlying parallel SIMD-optimized strided iterator.
37        pub(crate) iter: I,
38        /// The first function to apply to each item.
39        pub(crate) f: F,
40        /// The second function to apply to SIMD items.
41        pub(crate) f2: F2,
42        /// Phantom data to associate the lifetime `'a` with the struct.
43        pub(crate) phantom: std::marker::PhantomData<&'a ()>,
44    }
45
46    impl<
47            'a,
48            I: UnindexedProducer<Item = T>
49                + 'a
50                + IterGetSetSimd<Item = T>
51                + ParallelIterator
52                + ShapeManipulator,
53            T: 'a,
54            F,
55            F2,
56        > ParStridedMapSimd<'a, I, T, F, F2>
57    {
58        /// Collects the results of the map operation into a new tensor.
59        ///
60        /// This method applies the provided functions `f` and `f2` to each element and SIMD item
61        /// respectively, accumulating the results into a new tensor of type `U`.
62        ///
63        /// # Type Parameters
64        ///
65        /// * `U` - The type of the tensor to collect the results into. Must implement `TensorAlloc`, `Clone`,
66        ///         and `TensorInfo`.
67        ///
68        /// # Arguments
69        ///
70        /// * `self` - The `ParStridedMapSimd` iterator instance.
71        ///
72        /// # Returns
73        ///
74        /// A new tensor of type `U` containing the results of the map operation.
75        pub fn collect<U>(self) -> U
76        where
77            F: Fn((&mut <U as TensorAlloc>::Meta, <I as IterGetSetSimd>::Item)) + Sync + Send + 'a,
78            U: Clone + TensorInfo<U::Meta> + TensorAlloc,
79            <I as IterGetSetSimd>::Item: Send,
80            <U as TensorAlloc>::Meta: CommonBounds,
81            F2: Send
82                + Sync
83                + Copy
84                + Fn(
85                    (
86                        MutVec<'_, <<U as TensorAlloc>::Meta as TypeCommon>::Vec>,
87                        <I as IterGetSetSimd>::SimdItem,
88                    ),
89                ),
90        {
91            let res = U::_empty(self.iter.shape().clone()).unwrap();
92            let par_strided = ParStridedMutSimd::new(res.clone());
93            let zip = par_strided.zip(self.iter);
94            let with_simd = WithSimd {
95                base: zip,
96                vec_op: self.f2,
97            };
98            with_simd.for_each(|x| {
99                (self.f)(x);
100            });
101            res
102        }
103    }
104}
105
106/// A parallel map iterator over tensor elements.
107///
108/// This struct allows for applying a single function (`f`) to elements of a tensor
109#[derive(Clone)]
110pub struct ParStridedMap<'a, I, T: 'a, F>
111where
112    I: UnindexedProducer<Item = T> + 'a + IterGetSet<Item = T> + ParallelIterator,
113{
114    /// The underlying parallel strided iterator.
115    pub(crate) iter: I,
116    /// The function to apply to each item.
117    pub(crate) f: F,
118    /// Phantom data to associate the lifetime `'a` with the struct.
119    pub(crate) phantom: std::marker::PhantomData<&'a ()>,
120}
121
122impl<
123        'a,
124        I: UnindexedProducer<Item = T> + 'a + IterGetSet<Item = T> + ParallelIterator,
125        T: 'a,
126        F,
127    > ParStridedMap<'a, I, T, F>
128{
129    /// Collects the results of the map operation into a new tensor.
130    ///
131    /// This method applies the provided function `f` to each element of the tensor,
132    /// accumulating the results into a new tensor of type `U`.
133    ///
134    /// # Type Parameters
135    ///
136    /// * `U` - The type of the tensor to collect the results into. Must implement `TensorAlloc`, `Clone`,
137    ///         and `TensorInfo`.
138    ///
139    /// # Arguments
140    ///
141    /// * `self` - The `ParStridedMap` iterator instance.
142    ///
143    /// # Returns
144    ///
145    /// A new tensor of type `U` containing the results of the map operation.
146    pub fn collect<U>(self) -> U
147    where
148        F: Fn((&mut U::Meta, T)) + Sync + Send,
149        U: Clone + TensorInfo<U::Meta> + TensorAlloc,
150        <I as IterGetSet>::Item: Send,
151        <U as TensorAlloc>::Meta: CommonBounds,
152    {
153        let res = U::_empty(self.iter.shape().clone()).unwrap();
154        let strided_mut = ParStridedMapMut::new(res.clone());
155        let zip = strided_mut.zip(self.iter);
156        zip.for_each(|(x, y)| {
157            (self.f)((x, y));
158        });
159        res
160    }
161}
162
163impl<'a, T: CommonBounds> ShapeManipulator for ParStridedMutSimd<'a, T> {
164    fn reshape<S: Into<hpt_common::shape::shape::Shape>>(self, shape: S) -> Self {
165        let shape: Shape = shape.into();
166        let new_base = self.base.reshape(shape);
167        ParStridedMutSimd {
168            base: new_base,
169            phantom: self.phantom,
170        }
171    }
172
173    fn transpose<AXIS: Into<hpt_common::axis::axis::Axis>>(self, axes: AXIS) -> Self {
174        let axes = axes.into();
175        let new_base = self.base.transpose(axes);
176        ParStridedMutSimd {
177            base: new_base,
178            phantom: self.phantom,
179        }
180    }
181
182    fn expand<S: Into<hpt_common::shape::shape::Shape>>(self, shape: S) -> Self {
183        let shape: Shape = shape.into();
184        let new_base = self.base.expand(shape);
185        ParStridedMutSimd {
186            base: new_base,
187            phantom: self.phantom,
188        }
189    }
190}