hpt_iterator/
par_strided_map.rs

1use hpt_common::shape::shape::Shape;
2use hpt_traits::{
3    ops::creation::TensorCreator,
4    tensor::{CommonBounds, TensorInfo},
5};
6use rayon::iter::{plumbing::UnindexedProducer, ParallelIterator};
7
8use crate::{
9    iterator_traits::{IterGetSet, ShapeManipulator},
10    par_strided_map_mut::ParStridedMapMut,
11    par_strided_mut::par_strided_map_mut_simd::ParStridedMutSimd,
12};
13
14/// A module for parallel strided map iterator.
15pub mod par_strided_map_simd {
16    use crate::{CommonBounds, TensorInfo};
17    use hpt_common::utils::simd_ref::MutVec;
18    use hpt_traits::ops::creation::TensorCreator;
19    use hpt_types::dtype::TypeCommon;
20    use rayon::iter::{plumbing::UnindexedProducer, ParallelIterator};
21
22    use crate::{
23        iterator_traits::{IterGetSetSimd, ParStridedIteratorSimdZip, ShapeManipulator},
24        par_strided_mut::par_strided_map_mut_simd::ParStridedMutSimd,
25        with_simd::WithSimd,
26    };
27    /// A parallel SIMD-optimized map iterator over tensor elements.
28    ///
29    /// This struct allows for applying two separate functions (`f` and `f2`) to elements of a tensor
30    /// in a SIMD-optimized and parallel manner.
31    #[derive(Clone)]
32    pub struct ParStridedMapSimd<'a, I, T: 'a, F, F2>
33    where
34        I: UnindexedProducer<Item = T>
35            + 'a
36            + IterGetSetSimd<Item = T>
37            + ParallelIterator
38            + ShapeManipulator,
39    {
40        /// The underlying parallel SIMD-optimized strided iterator.
41        pub(crate) iter: I,
42        /// The first function to apply to each item.
43        pub(crate) f: F,
44        /// The second function to apply to SIMD items.
45        pub(crate) f2: F2,
46        /// Phantom data to associate the lifetime `'a` with the struct.
47        pub(crate) phantom: std::marker::PhantomData<&'a ()>,
48    }
49
50    impl<
51            'a,
52            I: UnindexedProducer<Item = T>
53                + 'a
54                + IterGetSetSimd<Item = T>
55                + ParallelIterator
56                + ShapeManipulator,
57            T: 'a,
58            F,
59            F2,
60        > ParStridedMapSimd<'a, I, T, F, F2>
61    {
62        /// Collects the results of the map operation into a new tensor.
63        ///
64        /// This method applies the provided functions `f` and `f2` to each element and SIMD item
65        /// respectively, accumulating the results into a new tensor of type `U`.
66        ///
67        /// # Type Parameters
68        ///
69        /// * `U` - The type of the tensor to collect the results into. Must implement `TensorAlloc`, `Clone`,
70        ///         and `TensorInfo`.
71        ///
72        /// # Arguments
73        ///
74        /// * `self` - The `ParStridedMapSimd` iterator instance.
75        ///
76        /// # Returns
77        ///
78        /// A new tensor of type `U` containing the results of the map operation.
79        pub fn collect<U>(self) -> U
80        where
81            F: Fn((&mut <U as TensorCreator>::Meta, <I as IterGetSetSimd>::Item))
82                + Sync
83                + Send
84                + 'a,
85            U: Clone + TensorInfo<U::Meta> + TensorCreator<Output = U>,
86            <I as IterGetSetSimd>::Item: Send,
87            <U as TensorCreator>::Meta: CommonBounds,
88            F2: Send
89                + Sync
90                + Copy
91                + Fn(
92                    (
93                        MutVec<'_, <<U as TensorCreator>::Meta as TypeCommon>::Vec>,
94                        <I as IterGetSetSimd>::SimdItem,
95                    ),
96                ),
97        {
98            let res = U::empty(self.iter.shape().clone()).unwrap();
99            let par_strided = ParStridedMutSimd::new(res.clone());
100            let zip = par_strided.zip(self.iter);
101            let with_simd = WithSimd {
102                base: zip,
103                vec_op: self.f2,
104            };
105            with_simd.for_each(|x| {
106                (self.f)(x);
107            });
108            res
109        }
110    }
111}
112
113/// A parallel map iterator over tensor elements.
114///
115/// This struct allows for applying a single function (`f`) to elements of a tensor
116#[derive(Clone)]
117pub struct ParStridedMap<'a, I, T: 'a, F>
118where
119    I: UnindexedProducer<Item = T> + 'a + IterGetSet<Item = T> + ParallelIterator,
120{
121    /// The underlying parallel strided iterator.
122    pub(crate) iter: I,
123    /// The function to apply to each item.
124    pub(crate) f: F,
125    /// Phantom data to associate the lifetime `'a` with the struct.
126    pub(crate) phantom: std::marker::PhantomData<&'a ()>,
127}
128
129impl<
130        'a,
131        I: UnindexedProducer<Item = T> + 'a + IterGetSet<Item = T> + ParallelIterator,
132        T: 'a,
133        F,
134    > ParStridedMap<'a, I, T, F>
135{
136    /// Collects the results of the map operation into a new tensor.
137    ///
138    /// This method applies the provided function `f` to each element of the tensor,
139    /// accumulating the results into a new tensor of type `U`.
140    ///
141    /// # Type Parameters
142    ///
143    /// * `U` - The type of the tensor to collect the results into. Must implement `TensorAlloc`, `Clone`,
144    ///         and `TensorInfo`.
145    ///
146    /// # Arguments
147    ///
148    /// * `self` - The `ParStridedMap` iterator instance.
149    ///
150    /// # Returns
151    ///
152    /// A new tensor of type `U` containing the results of the map operation.
153    pub fn collect<U>(self) -> U
154    where
155        F: Fn((&mut U::Meta, T)) + Sync + Send,
156        U: Clone + TensorInfo<U::Meta> + TensorCreator<Output = U>,
157        <I as IterGetSet>::Item: Send,
158        <U as TensorCreator>::Meta: CommonBounds,
159    {
160        let res = U::empty(self.iter.shape().clone()).unwrap();
161        let strided_mut = ParStridedMapMut::new(res.clone());
162        let zip = strided_mut.zip(self.iter);
163        zip.for_each(|(x, y)| {
164            (self.f)((x, y));
165        });
166        res
167    }
168}
169
170impl<'a, T: CommonBounds> ShapeManipulator for ParStridedMutSimd<'a, T> {
171    fn reshape<S: Into<hpt_common::shape::shape::Shape>>(self, shape: S) -> Self {
172        let shape: Shape = shape.into();
173        let new_base = self.base.reshape(shape);
174        ParStridedMutSimd {
175            base: new_base,
176            phantom: self.phantom,
177        }
178    }
179
180    fn transpose<AXIS: Into<hpt_common::axis::axis::Axis>>(self, axes: AXIS) -> Self {
181        let axes = axes.into();
182        let new_base = self.base.transpose(axes);
183        ParStridedMutSimd {
184            base: new_base,
185            phantom: self.phantom,
186        }
187    }
188
189    fn expand<S: Into<hpt_common::shape::shape::Shape>>(self, shape: S) -> Self {
190        let shape: Shape = shape.into();
191        let new_base = self.base.expand(shape);
192        ParStridedMutSimd {
193            base: new_base,
194            phantom: self.phantom,
195        }
196    }
197}