hpt_iterator/strided_map.rs
1use crate::iterator_traits::StridedIterator;
2use crate::{iterator_traits::IterGetSet, strided_map_mut::StridedMapMut};
3use hpt_traits::ops::creation::TensorCreator;
4use hpt_traits::tensor::{CommonBounds, TensorInfo};
5
6/// A module for strided map simd iterator.
7pub mod strided_map_simd {
8 use crate::iterator_traits::StridedIteratorSimd;
9 use crate::{
10 iterator_traits::IterGetSetSimd, strided_map_mut::strided_map_mut_simd::StridedMapMutSimd,
11 };
12 use crate::{CommonBounds, TensorInfo};
13 use hpt_traits::ops::creation::TensorCreator;
14 use hpt_types::dtype::TypeCommon;
15
16 /// # StridedMapSimd
17 /// A structure representing a SIMD (Single Instruction, Multiple Data) version of the strided map operation.
18 ///
19 /// This structure allows applying SIMD operations on tensors or arrays with strided memory layouts,
20 /// optimizing performance for certain types of computations.
21 #[derive(Clone)]
22 pub struct StridedMapSimd<'a, I, T: 'a, F, F2>
23 where
24 I: 'a + IterGetSetSimd<Item = T>,
25 {
26 /// An iterator implementing the `IterGetSetSimd` trait, providing access to the tensor elements.
27 pub(crate) iter: I,
28 /// A function to be applied to individual elements of the tensor
29 pub(crate) f: F,
30 /// A function to be applied to SIMD items of the tensor
31 pub(crate) f2: F2,
32 /// A marker for lifetimes to ensure proper memory safety
33 pub(crate) phantom: std::marker::PhantomData<&'a ()>,
34 }
35
36 impl<'a, I: 'a + IterGetSetSimd<Item = T>, T: 'a, F, F2> StridedMapSimd<'a, I, T, F, F2> {
37 /// Collects the results of applying the SIMD operations into a new tensor of type `U`.
38 ///
39 /// This method applies two functions, `f` for individual elements and `f2` for SIMD elements,
40 /// and collects the results into a new tensor of type `U`.
41 ///
42 /// # Type Parameters
43 /// - `U`: The type of the resulting tensor, which must implement `TensorAlloc` and `TensorInfo`.
44 ///
45 /// # Returns
46 /// A new tensor of type `U` containing the results of applying the functions.
47 pub fn collect<U>(self) -> U
48 where
49 F: Fn(T) -> U::Meta + Sync + Send + 'a,
50 F2: Fn(
51 <I as IterGetSetSimd>::SimdItem,
52 ) -> <<U as TensorCreator>::Meta as TypeCommon>::Vec
53 + Sync
54 + Send
55 + 'a,
56 U: Clone + TensorInfo<U::Meta> + TensorCreator<Output = U>,
57 <I as IterGetSetSimd>::Item: Send,
58 <U as TensorCreator>::Meta: CommonBounds,
59 <<U as TensorCreator>::Meta as TypeCommon>::Vec: Send,
60 {
61 let res = U::empty(self.iter.shape().clone()).unwrap();
62 let strided_mut = StridedMapMutSimd::new(res.clone());
63 let zip = strided_mut.zip(self.iter);
64 zip.for_each(
65 |(x, y)| {
66 *x = (self.f)(y);
67 },
68 |(x, y)| {
69 *x = (self.f2)(y);
70 },
71 );
72 res
73 }
74 }
75}
76
77/// # StridedMap
78///
79/// This struct is used to apply a function over a tensor or array that has a strided memory layout.
80/// The function `f` is applied to each element in the tensor as accessed through the strided iterator.
81///
82/// # Type Parameters
83/// - `'a`: The lifetime associated with the data being processed.
84/// - `I`: The iterator type that implements the `IterGetSet` trait, providing access to the elements.
85/// - `T`: The type of the elements in the tensor.
86/// - `F`: The function type that is applied to the tensor's elements.
87#[derive(Clone)]
88pub struct StridedMap<'a, I, T: 'a, F>
89where
90 I: 'a + IterGetSet<Item = T>,
91{
92 /// The iterator providing access to the tensor elements.
93 pub(crate) iter: I,
94 /// The function to be applied to each element of the tensor.
95 pub(crate) f: F,
96 /// Marker for the lifetime `'a`.
97 pub(crate) phantom: std::marker::PhantomData<&'a ()>,
98}
99impl<'a, I: 'a + IterGetSet<Item = T>, T: 'a, F> StridedMap<'a, I, T, F> {
100 /// Collects the results of applying the function `f` to the elements of the tensor into a new tensor of type `U`.
101 ///
102 /// This method applies the function `f` to each element in the tensor and collects the results in a new tensor `U`.
103 ///
104 /// # Type Parameters
105 /// - `U`: The type of the resulting tensor, which must implement the `TensorAlloc` and `TensorInfo` traits.
106 ///
107 /// # Returns
108 /// Returns a new tensor of type `U` containing the results of applying the function `f`.
109 pub fn collect<U>(self) -> U
110 where
111 F: Fn(T) -> U::Meta + Sync + Send + 'a,
112 U: Clone + TensorInfo<U::Meta> + TensorCreator<Output = U>,
113 <I as IterGetSet>::Item: Send,
114 <U as TensorCreator>::Meta: CommonBounds,
115 {
116 let res = U::empty(self.iter.shape().clone()).unwrap();
117 let strided_mut = StridedMapMut::new(res.clone());
118 let zip = strided_mut.zip(self.iter);
119 zip.for_each(|(x, y)| {
120 *x = (self.f)(y);
121 });
122 res
123 }
124}