Skip to main content

rstsr_core/tensor/operators/
op_unary_common.rs

1use crate::prelude_dev::*;
2use core::mem::transmute;
3
/* Structure of implementation

Most unary functions are of the same type. However, there are some exceptions, and some of them are very commonly used functions.

- `same type`: Input and output are of the same type. They can be implemented in place.
- `boolean output`: Output is boolean. Cannot be performed in place.
- `Imag, Real, Abs`:
    - complex: generalized, not in place.
    - real: specialized, in place.
- `Sign`:
    - complex: generalized, in place.
    - real: specialized, in place.

*/
18
19/* #region tensor traits */
20
/// Declares the tensor-level API of a single unary operation.
///
/// Given the operation name `$op`, its fallible name `$op_f` and a trait name
/// `$TensorOpAPI`, the expansion contains:
///
/// - `pub trait $TensorOpAPI` with an associated `Output` type, a required
///   fallible method `$op_f` and a provided method `$op` that unwraps the
///   fallible result through `rstsr_unwrap`;
/// - a free function `$op_f` that forwards to the trait's fallible method;
/// - a free function `$op` that forwards to the trait's unwrapping method.
macro_rules! trait_unary {
    ($op: ident, $op_f: ident, $TensorOpAPI: ident) => {
        pub trait $TensorOpAPI {
            /// Type produced by applying the operation to `Self`.
            type Output;

            /// Applies the operation, reporting failure as an `Err`.
            fn $op_f(self) -> Result<Self::Output>;

            /// Applies the operation and unwraps the fallible result.
            fn $op(self) -> Self::Output
            where
                Self: Sized,
            {
                Self::$op_f(self).rstsr_unwrap()
            }
        }

        /// Function form of the fallible trait method.
        pub fn $op_f<TRA, TRB>(a: TRA) -> Result<TRB>
        where
            TRA: $TensorOpAPI<Output = TRB>,
        {
            <TRA as $TensorOpAPI>::$op_f(a)
        }

        /// Function form of the unwrapping trait method.
        pub fn $op<TRA, TRB>(a: TRA) -> TRB
        where
            TRA: $TensorOpAPI<Output = TRB>,
        {
            <TRA as $TensorOpAPI>::$op(a)
        }
    };
}
49
50#[rustfmt::skip]
51#[allow(clippy::wrong_self_convention)]
52mod trait_unary {
53    use super::*;
54    trait_unary!(acos      , acos_f      , TensorAcosAPI       );
55    trait_unary!(acosh     , acosh_f     , TensorAcoshAPI      );
56    trait_unary!(asin      , asin_f      , TensorAsinAPI       );
57    trait_unary!(asinh     , asinh_f     , TensorAsinhAPI      );
58    trait_unary!(atan      , atan_f      , TensorAtanAPI       );
59    trait_unary!(atanh     , atanh_f     , TensorAtanhAPI      );
60    trait_unary!(ceil      , ceil_f      , TensorCeilAPI       );
61    trait_unary!(conj      , conj_f      , TensorConjAPI       );
62    trait_unary!(cos       , cos_f       , TensorCosAPI        );
63    trait_unary!(cosh      , cosh_f      , TensorCoshAPI       );
64    trait_unary!(exp       , exp_f       , TensorExpAPI        );
65    trait_unary!(expm1     , expm1_f     , TensorExpm1API      );
66    trait_unary!(floor     , floor_f     , TensorFloorAPI      );
67    trait_unary!(inv       , inv_f       , TensorInvAPI        );
68    trait_unary!(log       , log_f       , TensorLogAPI        );
69    trait_unary!(log1p     , log1p_f     , TensorLog1pAPI      );
70    trait_unary!(log2      , log2_f      , TensorLog2API       );
71    trait_unary!(log10     , log10_f     , TensorLog10API      );
72    trait_unary!(reciprocal, reciprocal_f , TensorReciprocalAPI);
73    trait_unary!(round     , round_f     , TensorRoundAPI      );
74    trait_unary!(signbit   , signbit_f   , TensorSignBitAPI    );
75    trait_unary!(sin       , sin_f       , TensorSinAPI        );
76    trait_unary!(sinh      , sinh_f      , TensorSinhAPI       );
77    trait_unary!(square    , square_f    , TensorSquareAPI     );
78    trait_unary!(sqrt      , sqrt_f      , TensorSqrtAPI       );
79    trait_unary!(tan       , tan_f       , TensorTanAPI        );
80    trait_unary!(tanh      , tanh_f      , TensorTanhAPI       );
81    trait_unary!(trunc     , trunc_f     , TensorTruncAPI      );
82    trait_unary!(is_finite , is_finite_f , TensorIsFiniteAPI   );
83    trait_unary!(is_inf    , is_inf_f    , TensorIsInfAPI      );
84    trait_unary!(is_nan    , is_nan_f    , TensorIsNanAPI      );
85
86    trait_unary!(abs  , abs_f  , TensorAbsAPI  );
87    trait_unary!(real , real_f , TensorRealAPI );
88    trait_unary!(imag , imag_f , TensorImagAPI );
89    trait_unary!(sign , sign_f , TensorSignAPI );
90}
91
92pub use trait_unary::*;
93
94/* #endregion */
95
96/* #region impl tensor unary common */
97
98#[duplicate_item(
99    op_f           TensorOpAPI             OpAPI           ;
100   [acos_f      ] [TensorAcosAPI      ] [OpAcosAPI      ];
101   [acosh_f     ] [TensorAcoshAPI     ] [OpAcoshAPI     ];
102   [asin_f      ] [TensorAsinAPI      ] [OpAsinAPI      ];
103   [asinh_f     ] [TensorAsinhAPI     ] [OpAsinhAPI     ];
104   [atan_f      ] [TensorAtanAPI      ] [OpAtanAPI      ];
105   [atanh_f     ] [TensorAtanhAPI     ] [OpAtanhAPI     ];
106   [ceil_f      ] [TensorCeilAPI      ] [OpCeilAPI      ];
107   [conj_f      ] [TensorConjAPI      ] [OpConjAPI      ];
108   [cos_f       ] [TensorCosAPI       ] [OpCosAPI       ];
109   [cosh_f      ] [TensorCoshAPI      ] [OpCoshAPI      ];
110   [exp_f       ] [TensorExpAPI       ] [OpExpAPI       ];
111   [expm1_f     ] [TensorExpm1API     ] [OpExpm1API     ];
112   [floor_f     ] [TensorFloorAPI     ] [OpFloorAPI     ];
113   [inv_f       ] [TensorInvAPI       ] [OpInvAPI       ];
114   [is_finite_f ] [TensorIsFiniteAPI  ] [OpIsFiniteAPI  ];
115   [is_inf_f    ] [TensorIsInfAPI     ] [OpIsInfAPI     ];
116   [is_nan_f    ] [TensorIsNanAPI     ] [OpIsNanAPI     ];
117   [log_f       ] [TensorLogAPI       ] [OpLogAPI       ];
118   [log1p_f     ] [TensorLog1pAPI     ] [OpLog1pAPI     ];
119   [log2_f      ] [TensorLog2API      ] [OpLog2API      ];
120   [log10_f     ] [TensorLog10API     ] [OpLog10API     ];
121   [reciprocal_f] [TensorReciprocalAPI] [OpReciprocalAPI];
122   [round_f     ] [TensorRoundAPI     ] [OpRoundAPI     ];
123   [signbit_f   ] [TensorSignBitAPI   ] [OpSignBitAPI   ];
124   [sin_f       ] [TensorSinAPI       ] [OpSinAPI       ];
125   [sinh_f      ] [TensorSinhAPI      ] [OpSinhAPI      ];
126   [square_f    ] [TensorSquareAPI    ] [OpSquareAPI    ];
127   [sqrt_f      ] [TensorSqrtAPI      ] [OpSqrtAPI      ];
128   [tan_f       ] [TensorTanAPI       ] [OpTanAPI       ];
129   [tanh_f      ] [TensorTanhAPI      ] [OpTanhAPI      ];
130   [trunc_f     ] [TensorTruncAPI     ] [OpTruncAPI     ];
131   [abs_f       ] [TensorAbsAPI       ] [OpAbsAPI       ];
132   [imag_f      ] [TensorImagAPI      ] [OpImagAPI      ];
133   [real_f      ] [TensorRealAPI      ] [OpRealAPI      ];
134   [sign_f      ] [TensorSignAPI      ] [OpSignAPI      ];
135)]
136mod impl_tensor_unary_common {
137    use super::*;
138
139    // any types allowed
140    impl<R, T, B, D> TensorOpAPI for &TensorAny<R, T, B, D>
141    where
142        D: DimAPI,
143        R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
144        B: DeviceAPI<T> + DeviceAPI<B::TOut>,
145        B: OpAPI<T, D> + DeviceCreationAnyAPI<B::TOut>,
146    {
147        type Output = Tensor<B::TOut, B, D>;
148        fn op_f(self) -> Result<Self::Output> {
149            let lb = self.layout();
150            // generate empty output tensor
151            let device = self.device();
152            let la = layout_for_array_copy(lb, TensorIterOrder::K)?;
153            let mut storage_a = device.uninit_impl(la.bounds_index()?.1)?;
154            // compute and return
155            device.op_muta_refb(storage_a.raw_mut(), &la, self.raw(), lb)?;
156            let storage_a = unsafe { B::assume_init_impl(storage_a) }?;
157            return Tensor::new_f(storage_a, la);
158        }
159    }
160
161    // any types allowed
162    impl<T, B, D> TensorOpAPI for TensorView<'_, T, B, D>
163    where
164        D: DimAPI,
165        B: DeviceAPI<T> + DeviceAPI<B::TOut>,
166        B: OpAPI<T, D> + DeviceCreationAnyAPI<B::TOut>,
167    {
168        type Output = Tensor<B::TOut, B, D>;
169        fn op_f(self) -> Result<Self::Output> {
170            TensorOpAPI::op_f(&self)
171        }
172    }
173
174    // same types allowed
175    impl<T, B, D> TensorOpAPI for Tensor<T, B, D>
176    where
177        D: DimAPI,
178        B: DeviceAPI<T>,
179        B: OpAPI<T, D, TOut = T> + DeviceCreationAnyAPI<T>,
180    {
181        type Output = Tensor<T, B, D>;
182        fn op_f(mut self) -> Result<Self::Output> {
183            let layout = self.layout().clone();
184            let device = self.device().clone();
185            // generate empty output tensor
186            let self_raw_mut = unsafe {
187                transmute::<&mut <B as DeviceRawAPI<T>>::Raw, &mut <B as DeviceRawAPI<MaybeUninit<T>>>::Raw>(
188                    self.raw_mut(),
189                )
190            };
191            device.op_muta(self_raw_mut, &layout)?;
192            return Ok(self);
193        }
194    }
195}
196
197/* #endregion */
198
#[cfg(test)]
mod test {
    use super::*;

    /// Exercises `sin` through the free function, a view and an owned tensor,
    /// and checks that the owned form reuses the input buffer.
    #[test]
    fn test_same_type() {
        let a = arange(6.0).into_shape([2, 3]).into_owned();
        let b = sin(&a);
        println!("{b:}");
        let b = a.view().sin();
        println!("{b:}");

        // Consuming `a` must run in place: same buffer address before and after.
        let ptr_a = a.raw().as_ptr();
        let b = a.sin();
        let ptr_b = b.raw().as_ptr();
        assert_eq!(ptr_a, ptr_b);
    }

    /// The sign of every non-zero complex number has unit norm, so the norms of
    /// the six results must add up to 6.
    #[test]
    fn test_sign() {
        use num::complex::c64;
        let a = linspace((c64(1.0, 2.0), c64(5.0, 6.0), 6)).into_shape([2, 3]);
        let b = (&a).sign();
        let vec_b = b.reshape([6]).to_vec();
        let b_abs_sum = vec_b.iter().map(|x| x.norm()).sum::<f64>();
        println!("{b:}");
        // Compare the absolute difference: a one-sided `sum - 6.0 < eps` check
        // would also accept any sum smaller than 6 (for example all zeros).
        assert!((b_abs_sum - 6.0).abs() < 1e-6);
    }

    /// Runs `abs` on complex input (where `a` stays usable afterwards) and on
    /// real input (where the buffer is reused).
    #[test]
    fn test_abs() {
        use num::complex::c32;
        let a = linspace((c32(1.0, 2.0), c32(5.0, 6.0), 6)).into_shape([2, 3]);
        let ptr_a = a.raw().as_ptr();
        let b = a.abs();
        let ptr_b = b.raw().as_ptr();
        println!("{b:}");
        println!("{ptr_a:?}");
        println!("{ptr_b:?}");
        // for complex case, only abs(&a) is valid
        println!("{a:}");

        let a = linspace((-3.0f64, 3.0f64, 6)).into_shape([2, 3]);
        let ptr_a = a.raw().as_ptr();
        let b = a.abs();
        let ptr_b = b.raw().as_ptr();
        println!("{b:}");
        assert_eq!(ptr_a, ptr_b);
        // for f64 case, `a.abs()` will try to consume variable `a`
        // println!("{:?}", a);
    }

    /// Smoke test: `imag` on a borrowed complex tensor.
    #[test]
    fn test_hetrogeneous_type() {
        use num::complex::c32;
        let a = linspace((c32(1.0, 2.0), c32(5.0, 6.0), 6)).into_shape([2, 3]);
        let b = (&a).imag();
        println!("{b:}");
    }

    /// Smoke test: `sin` on a tensor created on `DeviceCpuSerial`.
    #[test]
    fn test_cpuserial() {
        let a = linspace((1.0, 5.0, 5, &DeviceCpuSerial::default()));
        let b = a.sin();
        println!("{b:}");
    }
}