1use crate::prelude_dev::*;
2
3#[duplicate_item(
14 op op_f TensorOpAPI ;
15 [atan2 ] [atan2_f ] [TensorATan2API ];
16 [copysign ] [copysign_f ] [TensorCopySignAPI ];
17 [equal ] [equal_f ] [TensorEqualAPI ];
18 [floor_divide ] [floor_divide_f ] [TensorFloorDivideAPI ];
19 [greater ] [greater_f ] [TensorGreaterAPI ];
20 [greater_equal] [greater_equal_f] [TensorGreaterEqualAPI];
21 [hypot ] [hypot_f ] [TensorHypotAPI ];
22 [less ] [less_f ] [TensorLessAPI ];
23 [less_equal ] [less_equal_f ] [TensorLessEqualAPI ];
24 [log_add_exp ] [log_add_exp_f ] [TensorLogAddExpAPI ];
25 [maximum ] [maximum_f ] [TensorMaximumAPI ];
26 [minimum ] [minimum_f ] [TensorMinimumAPI ];
27 [not_equal ] [not_equal_f ] [TensorNotEqualAPI ];
28 [pow ] [pow_f ] [TensorPowAPI ];
29 [nextafter ] [nextafter_f ] [TensorNextAfterAPI ];
30)]
31pub trait TensorOpAPI<TRB> {
32 type Output;
33 fn op_f(self, b: TRB) -> Result<Self::Output>;
34 fn op(self, b: TRB) -> Self::Output
35 where
36 Self: Sized,
37 {
38 self.op_f(b).rstsr_unwrap()
39 }
40}
41
42#[duplicate_item(
43 op_f TensorOpAPI DeviceOpAPI ;
44 [atan2_f ] [TensorATan2API ] [DeviceATan2API ];
45 [copysign_f ] [TensorCopySignAPI ] [DeviceCopySignAPI ];
46 [equal_f ] [TensorEqualAPI ] [DeviceEqualAPI ];
47 [floor_divide_f ] [TensorFloorDivideAPI ] [DeviceFloorDivideAPI ];
48 [greater_f ] [TensorGreaterAPI ] [DeviceGreaterAPI ];
49 [greater_equal_f] [TensorGreaterEqualAPI] [DeviceGreaterEqualAPI];
50 [hypot_f ] [TensorHypotAPI ] [DeviceHypotAPI ];
51 [less_f ] [TensorLessAPI ] [DeviceLessAPI ];
52 [less_equal_f ] [TensorLessEqualAPI ] [DeviceLessEqualAPI ];
53 [log_add_exp_f ] [TensorLogAddExpAPI ] [DeviceLogAddExpAPI ];
54 [maximum_f ] [TensorMaximumAPI ] [DeviceMaximumAPI ];
55 [minimum_f ] [TensorMinimumAPI ] [DeviceMinimumAPI ];
56 [not_equal_f ] [TensorNotEqualAPI ] [DeviceNotEqualAPI ];
57 [pow_f ] [TensorPowAPI ] [DevicePowAPI ];
58 [nextafter_f ] [TensorNextAfterAPI ] [DeviceNextAfterAPI ];
59)]
60mod impl_trait_binary {
61 use super::*;
62
63 impl<RA, TA, DA, RB, TB, DB, B> TensorOpAPI<&TensorAny<RB, TB, B, DB>> for &TensorAny<RA, TA, B, DA>
64 where
65 RA: DataAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
66 RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
67 DA: DimAPI + DimMaxAPI<DB>,
68 DB: DimAPI,
69 DA::Max: DimAPI,
70 B: DeviceOpAPI<TA, TB, DA::Max>,
71 B: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<B::TOut> + DeviceCreationAnyAPI<B::TOut>,
72 {
73 type Output = Tensor<B::TOut, B, DA::Max>;
74
75 fn op_f(self, b: &TensorAny<RB, TB, B, DB>) -> Result<Self::Output> {
76 rstsr_assert!(self.device().same_device(b.device()), DeviceMismatch)?;
78
79 let la = self.layout();
81 let lb = b.layout();
82 let default_order = self.device().default_order();
83 let (la_b, lb_b) = broadcast_layout(la, lb, default_order)?;
84 let lc_from_a = layout_for_array_copy(&la_b, TensorIterOrder::default())?;
85 let lc_from_b = layout_for_array_copy(&lb_b, TensorIterOrder::default())?;
86 let lc = if lc_from_a == lc_from_b {
87 lc_from_a
88 } else {
89 match self.device().default_order() {
90 RowMajor => la_b.shape().c(),
91 ColMajor => la_b.shape().f(),
92 }
93 };
94
95 let device = self.device();
97 let mut storage_c = device.uninit_impl(lc.bounds_index()?.1)?;
98 device.op_mutc_refa_refb(storage_c.raw_mut(), &lc, self.raw(), &la_b, b.raw(), &lb_b)?;
99 let storage_c = unsafe { B::assume_init_impl(storage_c) }?;
100 Tensor::new_f(storage_c, lc)
101 }
102 }
103
104 #[duplicate_item(
105 ImplType TrA TrB ;
106 [TA, DA, TB, DB, B, R: DataAPI<Data = <B as DeviceRawAPI<TA>>::Raw>] [&TensorAny<R, TA, B, DA> ] [TensorView<'_, TB, B, DB>];
107 [TA, DA, TB, DB, B, R: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>] [TensorView<'_, TA, B, DA>] [&TensorAny<R, TB, B, DB> ];
108 [TA, DA, TB, DB, B ] [TensorView<'_, TA, B, DA>] [TensorView<'_, TB, B, DB>];
109 )]
110 impl<ImplType> TensorOpAPI<TrB> for TrA
111 where
112 DA: DimAPI + DimMaxAPI<DB>,
113 DB: DimAPI,
114 DA::Max: DimAPI,
115 B: DeviceOpAPI<TA, TB, DA::Max>,
116 B: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<B::TOut> + DeviceCreationAnyAPI<B::TOut>,
117 {
118 type Output = Tensor<B::TOut, B, DA::Max>;
119
120 fn op_f(self, b: TrB) -> Result<Self::Output> {
121 TensorOpAPI::op_f(&self.view(), &b.view())
122 }
123 }
124
125 impl<RA, TA, DA, TB, B> TensorOpAPI<TB> for &TensorAny<RA, TA, B, DA>
126 where
127 RA: DataAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
128 DA: DimAPI,
129 B: DeviceOpAPI<TA, TB, DA>,
130 B: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<B::TOut> + DeviceCreationAnyAPI<B::TOut>,
131 TB: num::Num,
132 {
133 type Output = Tensor<B::TOut, B, DA>;
134
135 fn op_f(self, b: TB) -> Result<Self::Output> {
136 let la = self.layout();
138 let lc = layout_for_array_copy(la, TensorIterOrder::default())?;
139
140 let device = self.device();
142 let mut storage_c = device.uninit_impl(lc.bounds_index()?.1)?;
143 device.op_mutc_refa_numb(storage_c.raw_mut(), &lc, self.raw(), la, b)?;
144 let storage_c = unsafe { B::assume_init_impl(storage_c) }?;
145 Tensor::new_f(storage_c, lc)
146 }
147 }
148
149 impl<TA, DA, TB, B> TensorOpAPI<TB> for TensorView<'_, TA, B, DA>
150 where
151 DA: DimAPI,
152 B: DeviceOpAPI<TA, TB, DA>,
153 B: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<B::TOut> + DeviceCreationAnyAPI<B::TOut>,
154 TB: num::Num,
155 {
156 type Output = Tensor<B::TOut, B, DA>;
157
158 fn op_f(self, b: TB) -> Result<Self::Output> {
159 (&self).op_f(b)
160 }
161 }
162
163 impl<RB, TA, DB, TB, B> TensorOpAPI<&TensorAny<RB, TB, B, DB>> for TA
164 where
165 RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
166 DB: DimAPI,
167 B: DeviceOpAPI<TA, TB, DB>,
168 B: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<B::TOut> + DeviceCreationAnyAPI<B::TOut>,
169 TA: num::Num,
170 {
171 type Output = Tensor<B::TOut, B, DB>;
172
173 fn op_f(self, b: &TensorAny<RB, TB, B, DB>) -> Result<Self::Output> {
174 let lb = b.layout();
176 let lc = layout_for_array_copy(lb, TensorIterOrder::default())?;
177
178 let device = b.device();
180 let mut storage_c = device.uninit_impl(lc.bounds_index()?.1)?;
181 device.op_mutc_numa_refb(storage_c.raw_mut(), &lc, self, b.raw(), lb)?;
182 let storage_c = unsafe { B::assume_init_impl(storage_c) }?;
183 Tensor::new_f(storage_c, lc)
184 }
185 }
186
187 impl<TA, DB, TB, B> TensorOpAPI<TensorView<'_, TB, B, DB>> for TA
188 where
189 DB: DimAPI,
190 B: DeviceOpAPI<TA, TB, DB>,
191 B: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<B::TOut> + DeviceCreationAnyAPI<B::TOut>,
192 TA: num::Num,
193 {
194 type Output = Tensor<B::TOut, B, DB>;
195
196 fn op_f(self, b: TensorView<'_, TB, B, DB>) -> Result<Self::Output> {
197 TensorOpAPI::op_f(self, &b.view())
198 }
199 }
200}
201
// Generates the free-function front ends for one binary operation.
//
// Arguments, in order:
//   * `$op`          - name of the panicking function / trait method
//   * `$op_f`        - name of the fallible function / trait method
//   * `$TensorOpAPI` - tensor-level trait that provides both methods
//   * `$DeviceOpAPI` - accepted by the matcher but not used in the expansion
//   * `$op2, $op2_f` - zero or more (panicking, fallible) alias name pairs;
//                      every alias forwards to the same trait methods
macro_rules! func_binary {
    ($op: ident, $op_f: ident, $TensorOpAPI: ident, $DeviceOpAPI: ident, $($op2: ident, $op2_f: ident),*) => {
        #[doc = concat!("Free-function form of `", stringify!($TensorOpAPI), "::", stringify!($op_f), "`; returns the trait method's `Result`.")]
        pub fn $op_f<TRA, TRB>(a: TRA, b: TRB) -> Result<TRA::Output>
        where
            TRA: $TensorOpAPI<TRB>,
        {
            <TRA as $TensorOpAPI<TRB>>::$op_f(a, b)
        }

        #[doc = concat!("Free-function form of `", stringify!($TensorOpAPI), "::", stringify!($op), "`.")]
        pub fn $op<TRA, TRB>(a: TRA, b: TRB) -> TRA::Output
        where
            TRA: $TensorOpAPI<TRB>,
        {
            <TRA as $TensorOpAPI<TRB>>::$op(a, b)
        }

        $(
            #[doc = concat!("Alias of `", stringify!($op_f), "`.")]
            pub fn $op2_f<TRA, TRB>(a: TRA, b: TRB) -> Result<TRA::Output>
            where
                TRA: $TensorOpAPI<TRB>,
            {
                <TRA as $TensorOpAPI<TRB>>::$op_f(a, b)
            }

            #[doc = concat!("Alias of `", stringify!($op), "`.")]
            pub fn $op2<TRA, TRB>(a: TRA, b: TRB) -> TRA::Output
            where
                TRA: $TensorOpAPI<TRB>,
            {
                <TRA as $TensorOpAPI<TRB>>::$op(a, b)
            }
        )*
    };
}
239
// Free-function front ends for every binary operation, generated by the
// `func_binary!` macro. Columns: panicking name, fallible name, tensor trait,
// device trait, then optional (alias, alias_f) pairs that forward to the same
// trait methods. The trailing `,` with nothing after it means "no aliases".
#[rustfmt::skip]
mod func_binary {
    use super::*;
    // Operations without aliases.
    func_binary!(atan2 , atan2_f , TensorATan2API , DeviceATan2API ,);
    func_binary!(copysign , copysign_f , TensorCopySignAPI , DeviceCopySignAPI ,);
    func_binary!(floor_divide , floor_divide_f , TensorFloorDivideAPI , DeviceFloorDivideAPI ,);
    func_binary!(hypot , hypot_f , TensorHypotAPI , DeviceHypotAPI ,);
    func_binary!(log_add_exp , log_add_exp_f , TensorLogAddExpAPI , DeviceLogAddExpAPI ,);
    func_binary!(pow , pow_f , TensorPowAPI , DevicePowAPI ,);
    // Operations that also get a short alias.
    func_binary!(maximum , maximum_f , TensorMaximumAPI , DeviceMaximumAPI , max, max_f);
    func_binary!(minimum , minimum_f , TensorMinimumAPI , DeviceMinimumAPI , min, min_f);
    // Comparisons: each gets a two-letter alias and a longer alias.
    func_binary!(equal , equal_f , TensorEqualAPI , DeviceEqualAPI , eq, eq_f, equal_than , equal_than_f );
    func_binary!(less , less_f , TensorLessAPI , DeviceLessAPI , lt, lt_f, less_than , less_than_f );
    func_binary!(greater , greater_f , TensorGreaterAPI , DeviceGreaterAPI , gt, gt_f, greater_than , greater_than_f );
    func_binary!(less_equal , less_equal_f , TensorLessEqualAPI , DeviceLessEqualAPI , le, le_f, less_equal_to , less_equal_to_f );
    func_binary!(greater_equal , greater_equal_f , TensorGreaterEqualAPI , DeviceGreaterEqualAPI , ge, ge_f, greater_equal_to, greater_equal_to_f);
    func_binary!(not_equal , not_equal_f , TensorNotEqualAPI , DeviceNotEqualAPI , ne, ne_f, not_equal_to , not_equal_to_f );
    func_binary!(nextafter , nextafter_f , TensorNextAfterAPI , DeviceNextAfterAPI ,);
}
259
260pub use func_binary::*;
261
#[cfg(test)]
mod test {
    use super::*;

    /// `pow` with a length-3 exponent broadcast against a 6-element 2-D base,
    /// for integer/integer, float/float and float/integer operand types.
    #[test]
    fn test_pow() {
        #[cfg(not(feature = "col_major"))]
        {
            // u32 base, u32 exponent
            let base = arange(6u32).into_shape([2, 3]);
            let exponent = arange(3u32);
            let c = pow(&base, &exponent);
            println!("{c:?}");
            assert_eq!(c.reshape([6]).to_vec(), vec![1, 1, 4, 1, 4, 25]);

            let base = arange(6.0).into_shape([2, 3]);

            // float base, float exponent
            let exponent = arange(3.0);
            let c = pow(&base, &exponent);
            println!("{c:?}");
            assert_eq!(c.reshape([6]).to_vec(), vec![1.0, 1.0, 4.0, 1.0, 4.0, 25.0]);

            // float base, integer exponent
            let exponent = arange(3);
            let c = pow(&base, &exponent);
            println!("{c:?}");
            assert_eq!(c.reshape([6]).to_vec(), vec![1.0, 1.0, 4.0, 1.0, 4.0, 25.0]);
        }
        #[cfg(feature = "col_major")]
        {
            // u32 base, u32 exponent
            let base = arange(6u32).into_shape([3, 2]);
            let exponent = arange(3u32);
            let c = pow(&base, &exponent);
            println!("{c:?}");
            assert_eq!(c.reshape([6]).to_vec(), vec![1, 1, 4, 1, 4, 25]);

            let base = arange(6.0).into_shape([3, 2]);

            // float base, float exponent
            let exponent = arange(3.0);
            let c = pow(&base, &exponent);
            println!("{c:?}");
            assert_eq!(c.reshape([6]).to_vec(), vec![1.0, 1.0, 4.0, 1.0, 4.0, 25.0]);

            // float base, integer exponent
            let exponent = arange(3);
            let c = pow(&base, &exponent);
            println!("{c:?}");
            assert_eq!(c.reshape([6]).to_vec(), vec![1.0, 1.0, 4.0, 1.0, 4.0, 25.0]);
        }
    }

    /// `floor_divide` with a length-3 divisor broadcast against a 6-element
    /// 2-D dividend, plus a fallible call whose divisor contains a zero.
    #[test]
    fn test_floor_divide() {
        #[cfg(not(feature = "col_major"))]
        {
            // u32 dividend, i32 divisor; the expected values are spelled as i64
            let dividend = arange(6u32).into_shape([2, 3]);
            let divisor = asarray(vec![1_i32, 2, 2]);
            let c = dividend.floor_divide(&divisor);
            println!("{c:?}");
            assert_eq!(c.reshape([6]).to_vec(), vec![0_i64, 0, 1, 3, 2, 2]);

            let dividend = arange(6.0).into_shape([2, 3]);

            // float dividend, float divisor
            let divisor = asarray(vec![1.0, 2.0, 2.0]);
            let c = dividend.floor_divide(&divisor);
            println!("{c:?}");
            assert_eq!(c.reshape([6]).to_vec(), vec![0.0, 0.0, 1.0, 3.0, 2.0, 2.0]);

            // zero in the divisor: the result is only printed, not asserted
            let divisor = asarray(vec![0.0, 2.0, 2.0]);
            let c = dividend.floor_divide_f(&divisor);
            println!("{c:?}");
        }
        #[cfg(feature = "col_major")]
        {
            // u32 dividend, i32 divisor; the expected values are spelled as i64
            let dividend = arange(6u32).into_shape([3, 2]);
            let divisor = asarray(vec![1_i32, 2, 2]);
            let c = dividend.floor_divide(&divisor);
            println!("{c:?}");
            assert_eq!(c.reshape([6]).to_vec(), vec![0_i64, 0, 1, 3, 2, 2]);

            let dividend = arange(6.0).into_shape([3, 2]);

            // float dividend, float divisor
            let divisor = asarray(vec![1.0, 2.0, 2.0]);
            let c = dividend.floor_divide(&divisor);
            println!("{c:?}");
            assert_eq!(c.reshape([6]).to_vec(), vec![0.0, 0.0, 1.0, 3.0, 2.0, 2.0]);

            // zero in the divisor: the result is only printed, not asserted
            let divisor = asarray(vec![0.0, 2.0, 2.0]);
            let c = dividend.floor_divide_f(&divisor);
            println!("{c:?}");
        }
    }

    /// `gt` / `ge` aliases on a view and a reference, then a sum over the
    /// boolean result.
    #[test]
    fn test_ge_gt() {
        let lhs = asarray(vec![1., 2., 3., 4., 5., 6.]);
        let rhs = asarray(vec![1., 3., 2., 5., 5., 2.]);

        let strictly_greater = gt(lhs.view(), &rhs);
        assert_eq!(strictly_greater.raw(), &[false, false, true, false, false, true]);
        let greater_or_equal = ge(lhs.view(), &rhs);
        assert_eq!(greater_or_equal.raw(), &[true, false, true, false, true, true]);

        // four of the six comparisons hold
        let true_count = greater_or_equal.sum();
        assert_eq!(true_count, 4);
    }

    /// Tensor/scalar and scalar/tensor operand combinations.
    #[test]
    fn test_refa_numb() {
        let values = asarray(vec![1., 3., 2., 5., 5., 2.]);

        // tensor (op) scalar
        let mask = values.greater_equal(3.0);
        assert_eq!(mask.raw(), &[false, true, false, true, true, false]);
        let squared = values.pow(2);
        assert_eq!(squared.raw(), &[1.0, 9.0, 4.0, 25.0, 25.0, 4.0]);

        // scalar (op) tensor view
        let powers_of_two = 2.0.pow(values.view());
        assert_eq!(powers_of_two.raw(), &[2.0, 8.0, 4.0, 32.0, 32.0, 4.0]);
    }
}