hpt_iterator/strided_map.rs
1use crate::iterator_traits::StridedIterator;
2use crate::{iterator_traits::IterGetSet, strided_map_mut::StridedMapMut};
3use hpt_traits::tensor::{CommonBounds, TensorAlloc, TensorInfo};
4
5/// A module for strided map simd iterator.
6pub mod strided_map_simd {
7 use crate::iterator_traits::StridedIteratorSimd;
8 use crate::{
9 iterator_traits::IterGetSetSimd, strided_map_mut::strided_map_mut_simd::StridedMapMutSimd,
10 };
11 use hpt_traits::{CommonBounds, TensorAlloc, TensorInfo};
12 use hpt_types::dtype::TypeCommon;
13
14 /// # StridedMapSimd
15 /// A structure representing a SIMD (Single Instruction, Multiple Data) version of the strided map operation.
16 ///
17 /// This structure allows applying SIMD operations on tensors or arrays with strided memory layouts,
18 /// optimizing performance for certain types of computations.
19 #[derive(Clone)]
20 pub struct StridedMapSimd<'a, I, T: 'a, F, F2>
21 where
22 I: 'a + IterGetSetSimd<Item = T>,
23 {
24 /// An iterator implementing the `IterGetSetSimd` trait, providing access to the tensor elements.
25 pub(crate) iter: I,
26 /// A function to be applied to individual elements of the tensor
27 pub(crate) f: F,
28 /// A function to be applied to SIMD items of the tensor
29 pub(crate) f2: F2,
30 /// A marker for lifetimes to ensure proper memory safety
31 pub(crate) phantom: std::marker::PhantomData<&'a ()>,
32 }
33
34 impl<'a, I: 'a + IterGetSetSimd<Item = T>, T: 'a, F, F2> StridedMapSimd<'a, I, T, F, F2> {
35 /// Collects the results of applying the SIMD operations into a new tensor of type `U`.
36 ///
37 /// This method applies two functions, `f` for individual elements and `f2` for SIMD elements,
38 /// and collects the results into a new tensor of type `U`.
39 ///
40 /// # Type Parameters
41 /// - `U`: The type of the resulting tensor, which must implement `TensorAlloc` and `TensorInfo`.
42 ///
43 /// # Returns
44 /// A new tensor of type `U` containing the results of applying the functions.
45 pub fn collect<U>(self) -> U
46 where
47 F: Fn(T) -> U::Meta + Sync + Send + 'a,
48 F2: Fn(<I as IterGetSetSimd>::SimdItem) -> <<U as TensorAlloc>::Meta as TypeCommon>::Vec
49 + Sync
50 + Send
51 + 'a,
52 U: Clone + TensorInfo<U::Meta> + TensorAlloc,
53 <I as IterGetSetSimd>::Item: Send,
54 <U as TensorAlloc>::Meta: CommonBounds,
55 <<U as TensorAlloc>::Meta as TypeCommon>::Vec: Send,
56 {
57 let res = U::_empty(self.iter.shape().clone()).unwrap();
58 let strided_mut = StridedMapMutSimd::new(res.clone());
59 let zip = strided_mut.zip(self.iter);
60 zip.for_each(
61 |(x, y)| {
62 *x = (self.f)(y);
63 },
64 |(x, y)| {
65 *x = (self.f2)(y);
66 },
67 );
68 res
69 }
70 }
71}
72
73/// # StridedMap
74///
75/// This struct is used to apply a function over a tensor or array that has a strided memory layout.
76/// The function `f` is applied to each element in the tensor as accessed through the strided iterator.
77///
78/// # Type Parameters
79/// - `'a`: The lifetime associated with the data being processed.
80/// - `I`: The iterator type that implements the `IterGetSet` trait, providing access to the elements.
81/// - `T`: The type of the elements in the tensor.
82/// - `F`: The function type that is applied to the tensor's elements.
83#[derive(Clone)]
84pub struct StridedMap<'a, I, T: 'a, F>
85where
86 I: 'a + IterGetSet<Item = T>,
87{
88 /// The iterator providing access to the tensor elements.
89 pub(crate) iter: I,
90 /// The function to be applied to each element of the tensor.
91 pub(crate) f: F,
92 /// Marker for the lifetime `'a`.
93 pub(crate) phantom: std::marker::PhantomData<&'a ()>,
94}
95impl<'a, I: 'a + IterGetSet<Item = T>, T: 'a, F> StridedMap<'a, I, T, F> {
96 /// Collects the results of applying the function `f` to the elements of the tensor into a new tensor of type `U`.
97 ///
98 /// This method applies the function `f` to each element in the tensor and collects the results in a new tensor `U`.
99 ///
100 /// # Type Parameters
101 /// - `U`: The type of the resulting tensor, which must implement the `TensorAlloc` and `TensorInfo` traits.
102 ///
103 /// # Returns
104 /// Returns a new tensor of type `U` containing the results of applying the function `f`.
105 pub fn collect<U>(self) -> U
106 where
107 F: Fn(T) -> U::Meta + Sync + Send + 'a,
108 U: Clone + TensorInfo<U::Meta> + TensorAlloc,
109 <I as IterGetSet>::Item: Send,
110 <U as TensorAlloc>::Meta: CommonBounds,
111 {
112 let res = U::_empty(self.iter.shape().clone()).unwrap();
113 let strided_mut = StridedMapMut::new(res.clone());
114 let zip = strided_mut.zip(self.iter);
115 zip.for_each(|(x, y)| {
116 *x = (self.f)(y);
117 });
118 res
119 }
120}