rstsr_core/tensor/operators/
op_unary_common.rs1use crate::prelude_dev::*;
2use core::mem::transmute;
3
4macro_rules! trait_unary {
22 ($op: ident, $op_f: ident, $TensorOpAPI: ident) => {
23 pub trait $TensorOpAPI {
24 type Output;
25 fn $op_f(self) -> Result<Self::Output>;
26 fn $op(self) -> Self::Output
27 where
28 Self: Sized,
29 {
30 self.$op_f().rstsr_unwrap()
31 }
32 }
33
34 pub fn $op_f<TRA, TRB>(a: TRA) -> Result<TRB>
35 where
36 TRA: $TensorOpAPI<Output = TRB>,
37 {
38 TRA::$op_f(a)
39 }
40
41 pub fn $op<TRA, TRB>(a: TRA) -> TRB
42 where
43 TRA: $TensorOpAPI<Output = TRB>,
44 {
45 TRA::$op(a)
46 }
47 };
48}
49
50#[rustfmt::skip]
51#[allow(clippy::wrong_self_convention)]
52mod trait_unary {
53 use super::*;
54 trait_unary!(acos , acos_f , TensorAcosAPI );
55 trait_unary!(acosh , acosh_f , TensorAcoshAPI );
56 trait_unary!(asin , asin_f , TensorAsinAPI );
57 trait_unary!(asinh , asinh_f , TensorAsinhAPI );
58 trait_unary!(atan , atan_f , TensorAtanAPI );
59 trait_unary!(atanh , atanh_f , TensorAtanhAPI );
60 trait_unary!(ceil , ceil_f , TensorCeilAPI );
61 trait_unary!(conj , conj_f , TensorConjAPI );
62 trait_unary!(cos , cos_f , TensorCosAPI );
63 trait_unary!(cosh , cosh_f , TensorCoshAPI );
64 trait_unary!(exp , exp_f , TensorExpAPI );
65 trait_unary!(expm1 , expm1_f , TensorExpm1API );
66 trait_unary!(floor , floor_f , TensorFloorAPI );
67 trait_unary!(inv , inv_f , TensorInvAPI );
68 trait_unary!(log , log_f , TensorLogAPI );
69 trait_unary!(log1p , log1p_f , TensorLog1pAPI );
70 trait_unary!(log2 , log2_f , TensorLog2API );
71 trait_unary!(log10 , log10_f , TensorLog10API );
72 trait_unary!(reciprocal, reciprocal_f , TensorReciprocalAPI);
73 trait_unary!(round , round_f , TensorRoundAPI );
74 trait_unary!(signbit , signbit_f , TensorSignBitAPI );
75 trait_unary!(sin , sin_f , TensorSinAPI );
76 trait_unary!(sinh , sinh_f , TensorSinhAPI );
77 trait_unary!(square , square_f , TensorSquareAPI );
78 trait_unary!(sqrt , sqrt_f , TensorSqrtAPI );
79 trait_unary!(tan , tan_f , TensorTanAPI );
80 trait_unary!(tanh , tanh_f , TensorTanhAPI );
81 trait_unary!(trunc , trunc_f , TensorTruncAPI );
82 trait_unary!(is_finite , is_finite_f , TensorIsFiniteAPI );
83 trait_unary!(is_inf , is_inf_f , TensorIsInfAPI );
84 trait_unary!(is_nan , is_nan_f , TensorIsNanAPI );
85
86 trait_unary!(abs , abs_f , TensorAbsAPI );
87 trait_unary!(real , real_f , TensorRealAPI );
88 trait_unary!(imag , imag_f , TensorImagAPI );
89 trait_unary!(sign , sign_f , TensorSignAPI );
90}
91
92pub use trait_unary::*;
93
94#[duplicate_item(
99 op_f TensorOpAPI OpAPI ;
100 [acos_f ] [TensorAcosAPI ] [OpAcosAPI ];
101 [acosh_f ] [TensorAcoshAPI ] [OpAcoshAPI ];
102 [asin_f ] [TensorAsinAPI ] [OpAsinAPI ];
103 [asinh_f ] [TensorAsinhAPI ] [OpAsinhAPI ];
104 [atan_f ] [TensorAtanAPI ] [OpAtanAPI ];
105 [atanh_f ] [TensorAtanhAPI ] [OpAtanhAPI ];
106 [ceil_f ] [TensorCeilAPI ] [OpCeilAPI ];
107 [conj_f ] [TensorConjAPI ] [OpConjAPI ];
108 [cos_f ] [TensorCosAPI ] [OpCosAPI ];
109 [cosh_f ] [TensorCoshAPI ] [OpCoshAPI ];
110 [exp_f ] [TensorExpAPI ] [OpExpAPI ];
111 [expm1_f ] [TensorExpm1API ] [OpExpm1API ];
112 [floor_f ] [TensorFloorAPI ] [OpFloorAPI ];
113 [inv_f ] [TensorInvAPI ] [OpInvAPI ];
114 [is_finite_f ] [TensorIsFiniteAPI ] [OpIsFiniteAPI ];
115 [is_inf_f ] [TensorIsInfAPI ] [OpIsInfAPI ];
116 [is_nan_f ] [TensorIsNanAPI ] [OpIsNanAPI ];
117 [log_f ] [TensorLogAPI ] [OpLogAPI ];
118 [log1p_f ] [TensorLog1pAPI ] [OpLog1pAPI ];
119 [log2_f ] [TensorLog2API ] [OpLog2API ];
120 [log10_f ] [TensorLog10API ] [OpLog10API ];
121 [reciprocal_f] [TensorReciprocalAPI] [OpReciprocalAPI];
122 [round_f ] [TensorRoundAPI ] [OpRoundAPI ];
123 [signbit_f ] [TensorSignBitAPI ] [OpSignBitAPI ];
124 [sin_f ] [TensorSinAPI ] [OpSinAPI ];
125 [sinh_f ] [TensorSinhAPI ] [OpSinhAPI ];
126 [square_f ] [TensorSquareAPI ] [OpSquareAPI ];
127 [sqrt_f ] [TensorSqrtAPI ] [OpSqrtAPI ];
128 [tan_f ] [TensorTanAPI ] [OpTanAPI ];
129 [tanh_f ] [TensorTanhAPI ] [OpTanhAPI ];
130 [trunc_f ] [TensorTruncAPI ] [OpTruncAPI ];
131 [abs_f ] [TensorAbsAPI ] [OpAbsAPI ];
132 [imag_f ] [TensorImagAPI ] [OpImagAPI ];
133 [real_f ] [TensorRealAPI ] [OpRealAPI ];
134 [sign_f ] [TensorSignAPI ] [OpSignAPI ];
135)]
136mod impl_tensor_unary_common {
137 use super::*;
138
139 impl<R, T, B, D> TensorOpAPI for &TensorAny<R, T, B, D>
141 where
142 D: DimAPI,
143 R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
144 B: DeviceAPI<T> + DeviceAPI<B::TOut>,
145 B: OpAPI<T, D> + DeviceCreationAnyAPI<B::TOut>,
146 {
147 type Output = Tensor<B::TOut, B, D>;
148 fn op_f(self) -> Result<Self::Output> {
149 let lb = self.layout();
150 let device = self.device();
152 let la = layout_for_array_copy(lb, TensorIterOrder::K)?;
153 let mut storage_a = device.uninit_impl(la.bounds_index()?.1)?;
154 device.op_muta_refb(storage_a.raw_mut(), &la, self.raw(), lb)?;
156 let storage_a = unsafe { B::assume_init_impl(storage_a) }?;
157 return Tensor::new_f(storage_a, la);
158 }
159 }
160
161 impl<T, B, D> TensorOpAPI for TensorView<'_, T, B, D>
163 where
164 D: DimAPI,
165 B: DeviceAPI<T> + DeviceAPI<B::TOut>,
166 B: OpAPI<T, D> + DeviceCreationAnyAPI<B::TOut>,
167 {
168 type Output = Tensor<B::TOut, B, D>;
169 fn op_f(self) -> Result<Self::Output> {
170 TensorOpAPI::op_f(&self)
171 }
172 }
173
174 impl<T, B, D> TensorOpAPI for Tensor<T, B, D>
176 where
177 D: DimAPI,
178 B: DeviceAPI<T>,
179 B: OpAPI<T, D, TOut = T> + DeviceCreationAnyAPI<T>,
180 {
181 type Output = Tensor<T, B, D>;
182 fn op_f(mut self) -> Result<Self::Output> {
183 let layout = self.layout().clone();
184 let device = self.device().clone();
185 let self_raw_mut = unsafe {
187 transmute::<&mut <B as DeviceRawAPI<T>>::Raw, &mut <B as DeviceRawAPI<MaybeUninit<T>>>::Raw>(
188 self.raw_mut(),
189 )
190 };
191 device.op_muta(self_raw_mut, &layout)?;
192 return Ok(self);
193 }
194 }
195}
196
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_same_type() {
        let a = arange(6.0).into_shape([2, 3]).into_owned();
        let b = sin(&a);
        println!("{b:}");
        let b = a.view().sin();
        println!("{b:}");

        // Consuming an owned tensor with a same-type op reuses its buffer.
        let ptr_a = a.raw().as_ptr();
        let b = a.sin();
        let ptr_b = b.raw().as_ptr();
        assert_eq!(ptr_a, ptr_b);
    }

    #[test]
    fn test_sign() {
        use num::complex::c64;
        let a = linspace((c64(1.0, 2.0), c64(5.0, 6.0), 6)).into_shape([2, 3]);
        let b = (&a).sign();
        let vec_b = b.reshape([6]).to_vec();
        // All six inputs are non-zero, so each `sign` value is expected to have unit norm.
        let b_abs_sum = vec_b.iter().map(|x| x.norm()).sum::<f64>();
        println!("{b:}");
        // Two-sided tolerance: a sum that falls short of 6 must fail too.
        assert!((b_abs_sum - 6.0).abs() < 1e-6);
    }

    #[test]
    fn test_abs() {
        use num::complex::c32;
        // Complex -> real `abs` changes the element type, so it goes through the borrowed
        // (allocating) path and `a` remains usable afterwards.
        let a = linspace((c32(1.0, 2.0), c32(5.0, 6.0), 6)).into_shape([2, 3]);
        let ptr_a = a.raw().as_ptr();
        let b = a.abs();
        let ptr_b = b.raw().as_ptr();
        println!("{b:}");
        println!("{ptr_a:?}");
        println!("{ptr_b:?}");
        println!("{a:}");

        // Real -> real `abs` on an owned tensor reuses its buffer.
        let a = linspace((-3.0f64, 3.0f64, 6)).into_shape([2, 3]);
        let ptr_a = a.raw().as_ptr();
        let b = a.abs();
        let ptr_b = b.raw().as_ptr();
        println!("{b:}");
        assert_eq!(ptr_a, ptr_b);
    }

    #[test]
    fn test_hetrogeneous_type() {
        use num::complex::c32;
        let a = linspace((c32(1.0, 2.0), c32(5.0, 6.0), 6)).into_shape([2, 3]);
        let b = (&a).imag();
        println!("{b:}");
    }

    #[test]
    fn test_cpuserial() {
        let a = linspace((1.0, 5.0, 5, &DeviceCpuSerial::default()));
        let b = a.sin();
        println!("{b:}");
    }
}