// hpt_iterator/strided_map.rs

1use crate::iterator_traits::StridedIterator;
2use crate::{iterator_traits::IterGetSet, strided_map_mut::StridedMapMut};
3use hpt_traits::tensor::{CommonBounds, TensorAlloc, TensorInfo};
4
5/// A module for strided map simd iterator.
6pub mod strided_map_simd {
7    use crate::iterator_traits::StridedIteratorSimd;
8    use crate::{
9        iterator_traits::IterGetSetSimd, strided_map_mut::strided_map_mut_simd::StridedMapMutSimd,
10    };
11    use hpt_traits::{CommonBounds, TensorAlloc, TensorInfo};
12    use hpt_types::dtype::TypeCommon;
13
14    /// # StridedMapSimd
15    /// A structure representing a SIMD (Single Instruction, Multiple Data) version of the strided map operation.
16    ///
17    /// This structure allows applying SIMD operations on tensors or arrays with strided memory layouts,
18    /// optimizing performance for certain types of computations.
19    #[derive(Clone)]
20    pub struct StridedMapSimd<'a, I, T: 'a, F, F2>
21    where
22        I: 'a + IterGetSetSimd<Item = T>,
23    {
24        /// An iterator implementing the `IterGetSetSimd` trait, providing access to the tensor elements.
25        pub(crate) iter: I,
26        /// A function to be applied to individual elements of the tensor
27        pub(crate) f: F,
28        /// A function to be applied to SIMD items of the tensor
29        pub(crate) f2: F2,
30        /// A marker for lifetimes to ensure proper memory safety
31        pub(crate) phantom: std::marker::PhantomData<&'a ()>,
32    }
33
34    impl<'a, I: 'a + IterGetSetSimd<Item = T>, T: 'a, F, F2> StridedMapSimd<'a, I, T, F, F2> {
35        /// Collects the results of applying the SIMD operations into a new tensor of type `U`.
36        ///
37        /// This method applies two functions, `f` for individual elements and `f2` for SIMD elements,
38        /// and collects the results into a new tensor of type `U`.
39        ///
40        /// # Type Parameters
41        /// - `U`: The type of the resulting tensor, which must implement `TensorAlloc` and `TensorInfo`.
42        ///
43        /// # Returns
44        /// A new tensor of type `U` containing the results of applying the functions.
45        pub fn collect<U>(self) -> U
46        where
47            F: Fn(T) -> U::Meta + Sync + Send + 'a,
48            F2: Fn(<I as IterGetSetSimd>::SimdItem) -> <<U as TensorAlloc>::Meta as TypeCommon>::Vec
49                + Sync
50                + Send
51                + 'a,
52            U: Clone + TensorInfo<U::Meta> + TensorAlloc,
53            <I as IterGetSetSimd>::Item: Send,
54            <U as TensorAlloc>::Meta: CommonBounds,
55            <<U as TensorAlloc>::Meta as TypeCommon>::Vec: Send,
56        {
57            let res = U::_empty(self.iter.shape().clone()).unwrap();
58            let strided_mut = StridedMapMutSimd::new(res.clone());
59            let zip = strided_mut.zip(self.iter);
60            zip.for_each(
61                |(x, y)| {
62                    *x = (self.f)(y);
63                },
64                |(x, y)| {
65                    *x = (self.f2)(y);
66                },
67            );
68            res
69        }
70    }
71}
72
73/// # StridedMap
74///
75/// This struct is used to apply a function over a tensor or array that has a strided memory layout.
76/// The function `f` is applied to each element in the tensor as accessed through the strided iterator.
77///
78/// # Type Parameters
79/// - `'a`: The lifetime associated with the data being processed.
80/// - `I`: The iterator type that implements the `IterGetSet` trait, providing access to the elements.
81/// - `T`: The type of the elements in the tensor.
82/// - `F`: The function type that is applied to the tensor's elements.
83#[derive(Clone)]
84pub struct StridedMap<'a, I, T: 'a, F>
85where
86    I: 'a + IterGetSet<Item = T>,
87{
88    /// The iterator providing access to the tensor elements.
89    pub(crate) iter: I,
90    /// The function to be applied to each element of the tensor.
91    pub(crate) f: F,
92    /// Marker for the lifetime `'a`.
93    pub(crate) phantom: std::marker::PhantomData<&'a ()>,
94}
95impl<'a, I: 'a + IterGetSet<Item = T>, T: 'a, F> StridedMap<'a, I, T, F> {
96    /// Collects the results of applying the function `f` to the elements of the tensor into a new tensor of type `U`.
97    ///
98    /// This method applies the function `f` to each element in the tensor and collects the results in a new tensor `U`.
99    ///
100    /// # Type Parameters
101    /// - `U`: The type of the resulting tensor, which must implement the `TensorAlloc` and `TensorInfo` traits.
102    ///
103    /// # Returns
104    /// Returns a new tensor of type `U` containing the results of applying the function `f`.
105    pub fn collect<U>(self) -> U
106    where
107        F: Fn(T) -> U::Meta + Sync + Send + 'a,
108        U: Clone + TensorInfo<U::Meta> + TensorAlloc,
109        <I as IterGetSet>::Item: Send,
110        <U as TensorAlloc>::Meta: CommonBounds,
111    {
112        let res = U::_empty(self.iter.shape().clone()).unwrap();
113        let strided_mut = StridedMapMut::new(res.clone());
114        let zip = strided_mut.zip(self.iter);
115        zip.for_each(|(x, y)| {
116            *x = (self.f)(y);
117        });
118        res
119    }
120}