hpt_iterator/
strided_map.rs

1use crate::iterator_traits::StridedIterator;
2use crate::{iterator_traits::IterGetSet, strided_map_mut::StridedMapMut};
3use hpt_traits::ops::creation::TensorCreator;
4use hpt_traits::tensor::{CommonBounds, TensorInfo};
5
6/// A module for strided map simd iterator.
7pub mod strided_map_simd {
8    use crate::iterator_traits::StridedIteratorSimd;
9    use crate::{
10        iterator_traits::IterGetSetSimd, strided_map_mut::strided_map_mut_simd::StridedMapMutSimd,
11    };
12    use crate::{CommonBounds, TensorInfo};
13    use hpt_traits::ops::creation::TensorCreator;
14    use hpt_types::dtype::TypeCommon;
15
16    /// # StridedMapSimd
17    /// A structure representing a SIMD (Single Instruction, Multiple Data) version of the strided map operation.
18    ///
19    /// This structure allows applying SIMD operations on tensors or arrays with strided memory layouts,
20    /// optimizing performance for certain types of computations.
21    #[derive(Clone)]
22    pub struct StridedMapSimd<'a, I, T: 'a, F, F2>
23    where
24        I: 'a + IterGetSetSimd<Item = T>,
25    {
26        /// An iterator implementing the `IterGetSetSimd` trait, providing access to the tensor elements.
27        pub(crate) iter: I,
28        /// A function to be applied to individual elements of the tensor
29        pub(crate) f: F,
30        /// A function to be applied to SIMD items of the tensor
31        pub(crate) f2: F2,
32        /// A marker for lifetimes to ensure proper memory safety
33        pub(crate) phantom: std::marker::PhantomData<&'a ()>,
34    }
35
36    impl<'a, I: 'a + IterGetSetSimd<Item = T>, T: 'a, F, F2> StridedMapSimd<'a, I, T, F, F2> {
37        /// Collects the results of applying the SIMD operations into a new tensor of type `U`.
38        ///
39        /// This method applies two functions, `f` for individual elements and `f2` for SIMD elements,
40        /// and collects the results into a new tensor of type `U`.
41        ///
42        /// # Type Parameters
43        /// - `U`: The type of the resulting tensor, which must implement `TensorAlloc` and `TensorInfo`.
44        ///
45        /// # Returns
46        /// A new tensor of type `U` containing the results of applying the functions.
47        pub fn collect<U>(self) -> U
48        where
49            F: Fn(T) -> U::Meta + Sync + Send + 'a,
50            F2: Fn(
51                    <I as IterGetSetSimd>::SimdItem,
52                ) -> <<U as TensorCreator>::Meta as TypeCommon>::Vec
53                + Sync
54                + Send
55                + 'a,
56            U: Clone + TensorInfo<U::Meta> + TensorCreator<Output = U>,
57            <I as IterGetSetSimd>::Item: Send,
58            <U as TensorCreator>::Meta: CommonBounds,
59            <<U as TensorCreator>::Meta as TypeCommon>::Vec: Send,
60        {
61            let res = U::empty(self.iter.shape().clone()).unwrap();
62            let strided_mut = StridedMapMutSimd::new(res.clone());
63            let zip = strided_mut.zip(self.iter);
64            zip.for_each(
65                |(x, y)| {
66                    *x = (self.f)(y);
67                },
68                |(x, y)| {
69                    *x = (self.f2)(y);
70                },
71            );
72            res
73        }
74    }
75}
76
77/// # StridedMap
78///
79/// This struct is used to apply a function over a tensor or array that has a strided memory layout.
80/// The function `f` is applied to each element in the tensor as accessed through the strided iterator.
81///
82/// # Type Parameters
83/// - `'a`: The lifetime associated with the data being processed.
84/// - `I`: The iterator type that implements the `IterGetSet` trait, providing access to the elements.
85/// - `T`: The type of the elements in the tensor.
86/// - `F`: The function type that is applied to the tensor's elements.
87#[derive(Clone)]
88pub struct StridedMap<'a, I, T: 'a, F>
89where
90    I: 'a + IterGetSet<Item = T>,
91{
92    /// The iterator providing access to the tensor elements.
93    pub(crate) iter: I,
94    /// The function to be applied to each element of the tensor.
95    pub(crate) f: F,
96    /// Marker for the lifetime `'a`.
97    pub(crate) phantom: std::marker::PhantomData<&'a ()>,
98}
99impl<'a, I: 'a + IterGetSet<Item = T>, T: 'a, F> StridedMap<'a, I, T, F> {
100    /// Collects the results of applying the function `f` to the elements of the tensor into a new tensor of type `U`.
101    ///
102    /// This method applies the function `f` to each element in the tensor and collects the results in a new tensor `U`.
103    ///
104    /// # Type Parameters
105    /// - `U`: The type of the resulting tensor, which must implement the `TensorAlloc` and `TensorInfo` traits.
106    ///
107    /// # Returns
108    /// Returns a new tensor of type `U` containing the results of applying the function `f`.
109    pub fn collect<U>(self) -> U
110    where
111        F: Fn(T) -> U::Meta + Sync + Send + 'a,
112        U: Clone + TensorInfo<U::Meta> + TensorCreator<Output = U>,
113        <I as IterGetSet>::Item: Send,
114        <U as TensorCreator>::Meta: CommonBounds,
115    {
116        let res = U::empty(self.iter.shape().clone()).unwrap();
117        let strided_mut = StridedMapMut::new(res.clone());
118        let zip = strided_mut.zip(self.iter);
119        zip.for_each(|(x, y)| {
120            *x = (self.f)(y);
121        });
122        res
123    }
124}