rstsr_core/tensor/operators/op_binary_common.rs

1use crate::prelude_dev::*;
2
/* Structure of implementation

Exceptional functions:
- floor_divide: integer and floating-point semantics differ, so two implementations are needed
- pow: the two operands may have different types

*/
10
11/* #region tensor traits */
12
13#[duplicate_item(
14    op              op_f              TensorOpAPI           ;
15   [atan2        ] [atan2_f        ] [TensorATan2API       ];
16   [copysign     ] [copysign_f     ] [TensorCopySignAPI    ];
17   [equal        ] [equal_f        ] [TensorEqualAPI       ];
18   [floor_divide ] [floor_divide_f ] [TensorFloorDivideAPI ];
19   [greater      ] [greater_f      ] [TensorGreaterAPI     ];
20   [greater_equal] [greater_equal_f] [TensorGreaterEqualAPI];
21   [hypot        ] [hypot_f        ] [TensorHypotAPI       ];
22   [less         ] [less_f         ] [TensorLessAPI        ];
23   [less_equal   ] [less_equal_f   ] [TensorLessEqualAPI   ];
24   [log_add_exp  ] [log_add_exp_f  ] [TensorLogAddExpAPI   ];
25   [maximum      ] [maximum_f      ] [TensorMaximumAPI     ];
26   [minimum      ] [minimum_f      ] [TensorMinimumAPI     ];
27   [not_equal    ] [not_equal_f    ] [TensorNotEqualAPI    ];
28   [pow          ] [pow_f          ] [TensorPowAPI         ];
29   [nextafter    ] [nextafter_f    ] [TensorNextAfterAPI   ];
30)]
31pub trait TensorOpAPI<TRB> {
32    type Output;
33    fn op_f(self, b: TRB) -> Result<Self::Output>;
34    fn op(self, b: TRB) -> Self::Output
35    where
36        Self: Sized,
37    {
38        self.op_f(b).rstsr_unwrap()
39    }
40}
41
42#[duplicate_item(
43    op_f              TensorOpAPI             DeviceOpAPI           ;
44   [atan2_f        ] [TensorATan2API       ] [DeviceATan2API       ];
45   [copysign_f     ] [TensorCopySignAPI    ] [DeviceCopySignAPI    ];
46   [equal_f        ] [TensorEqualAPI       ] [DeviceEqualAPI       ];
47   [floor_divide_f ] [TensorFloorDivideAPI ] [DeviceFloorDivideAPI ];
48   [greater_f      ] [TensorGreaterAPI     ] [DeviceGreaterAPI     ];
49   [greater_equal_f] [TensorGreaterEqualAPI] [DeviceGreaterEqualAPI];
50   [hypot_f        ] [TensorHypotAPI       ] [DeviceHypotAPI       ];
51   [less_f         ] [TensorLessAPI        ] [DeviceLessAPI        ];
52   [less_equal_f   ] [TensorLessEqualAPI   ] [DeviceLessEqualAPI   ];
53   [log_add_exp_f  ] [TensorLogAddExpAPI   ] [DeviceLogAddExpAPI   ];
54   [maximum_f      ] [TensorMaximumAPI     ] [DeviceMaximumAPI     ];
55   [minimum_f      ] [TensorMinimumAPI     ] [DeviceMinimumAPI     ];
56   [not_equal_f    ] [TensorNotEqualAPI    ] [DeviceNotEqualAPI    ];
57   [pow_f          ] [TensorPowAPI         ] [DevicePowAPI         ];
58   [nextafter_f    ] [TensorNextAfterAPI   ] [DeviceNextAfterAPI   ];
59)]
60mod impl_trait_binary {
61    use super::*;
62
63    impl<RA, TA, DA, RB, TB, DB, B> TensorOpAPI<&TensorAny<RB, TB, B, DB>> for &TensorAny<RA, TA, B, DA>
64    where
65        RA: DataAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
66        RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
67        DA: DimAPI + DimMaxAPI<DB>,
68        DB: DimAPI,
69        DA::Max: DimAPI,
70        B: DeviceOpAPI<TA, TB, DA::Max>,
71        B: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<B::TOut> + DeviceCreationAnyAPI<B::TOut>,
72    {
73        type Output = Tensor<B::TOut, B, DA::Max>;
74
75        fn op_f(self, b: &TensorAny<RB, TB, B, DB>) -> Result<Self::Output> {
76            // check device
77            rstsr_assert!(self.device().same_device(b.device()), DeviceMismatch)?;
78
79            // check and broadcast layout
80            let la = self.layout();
81            let lb = b.layout();
82            let default_order = self.device().default_order();
83            let (la_b, lb_b) = broadcast_layout(la, lb, default_order)?;
84            let lc_from_a = layout_for_array_copy(&la_b, TensorIterOrder::default())?;
85            let lc_from_b = layout_for_array_copy(&lb_b, TensorIterOrder::default())?;
86            let lc = if lc_from_a == lc_from_b {
87                lc_from_a
88            } else {
89                match self.device().default_order() {
90                    RowMajor => la_b.shape().c(),
91                    ColMajor => la_b.shape().f(),
92                }
93            };
94
95            // perform operation and return
96            let device = self.device();
97            let mut storage_c = device.uninit_impl(lc.bounds_index()?.1)?;
98            device.op_mutc_refa_refb(storage_c.raw_mut(), &lc, self.raw(), &la_b, b.raw(), &lb_b)?;
99            let storage_c = unsafe { B::assume_init_impl(storage_c) }?;
100            Tensor::new_f(storage_c, lc)
101        }
102    }
103
104    #[duplicate_item(
105        ImplType                                                             TrA                         TrB                       ;
106       [TA, DA, TB, DB, B, R: DataAPI<Data = <B as DeviceRawAPI<TA>>::Raw>] [&TensorAny<R, TA, B, DA> ] [TensorView<'_, TB, B, DB>];
107       [TA, DA, TB, DB, B, R: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>] [TensorView<'_, TA, B, DA>] [&TensorAny<R, TB, B, DB> ];
108       [TA, DA, TB, DB, B                                                 ] [TensorView<'_, TA, B, DA>] [TensorView<'_, TB, B, DB>];
109    )]
110    impl<ImplType> TensorOpAPI<TrB> for TrA
111    where
112        DA: DimAPI + DimMaxAPI<DB>,
113        DB: DimAPI,
114        DA::Max: DimAPI,
115        B: DeviceOpAPI<TA, TB, DA::Max>,
116        B: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<B::TOut> + DeviceCreationAnyAPI<B::TOut>,
117    {
118        type Output = Tensor<B::TOut, B, DA::Max>;
119
120        fn op_f(self, b: TrB) -> Result<Self::Output> {
121            TensorOpAPI::op_f(&self.view(), &b.view())
122        }
123    }
124
125    impl<RA, TA, DA, TB, B> TensorOpAPI<TB> for &TensorAny<RA, TA, B, DA>
126    where
127        RA: DataAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
128        DA: DimAPI,
129        B: DeviceOpAPI<TA, TB, DA>,
130        B: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<B::TOut> + DeviceCreationAnyAPI<B::TOut>,
131        TB: num::Num,
132    {
133        type Output = Tensor<B::TOut, B, DA>;
134
135        fn op_f(self, b: TB) -> Result<Self::Output> {
136            // check and broadcast layout
137            let la = self.layout();
138            let lc = layout_for_array_copy(la, TensorIterOrder::default())?;
139
140            // perform operation and return
141            let device = self.device();
142            let mut storage_c = device.uninit_impl(lc.bounds_index()?.1)?;
143            device.op_mutc_refa_numb(storage_c.raw_mut(), &lc, self.raw(), la, b)?;
144            let storage_c = unsafe { B::assume_init_impl(storage_c) }?;
145            Tensor::new_f(storage_c, lc)
146        }
147    }
148
149    impl<TA, DA, TB, B> TensorOpAPI<TB> for TensorView<'_, TA, B, DA>
150    where
151        DA: DimAPI,
152        B: DeviceOpAPI<TA, TB, DA>,
153        B: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<B::TOut> + DeviceCreationAnyAPI<B::TOut>,
154        TB: num::Num,
155    {
156        type Output = Tensor<B::TOut, B, DA>;
157
158        fn op_f(self, b: TB) -> Result<Self::Output> {
159            (&self).op_f(b)
160        }
161    }
162
163    impl<RB, TA, DB, TB, B> TensorOpAPI<&TensorAny<RB, TB, B, DB>> for TA
164    where
165        RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
166        DB: DimAPI,
167        B: DeviceOpAPI<TA, TB, DB>,
168        B: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<B::TOut> + DeviceCreationAnyAPI<B::TOut>,
169        TA: num::Num,
170    {
171        type Output = Tensor<B::TOut, B, DB>;
172
173        fn op_f(self, b: &TensorAny<RB, TB, B, DB>) -> Result<Self::Output> {
174            // check and broadcast layout
175            let lb = b.layout();
176            let lc = layout_for_array_copy(lb, TensorIterOrder::default())?;
177
178            // perform operation and return
179            let device = b.device();
180            let mut storage_c = device.uninit_impl(lc.bounds_index()?.1)?;
181            device.op_mutc_numa_refb(storage_c.raw_mut(), &lc, self, b.raw(), lb)?;
182            let storage_c = unsafe { B::assume_init_impl(storage_c) }?;
183            Tensor::new_f(storage_c, lc)
184        }
185    }
186
187    impl<TA, DB, TB, B> TensorOpAPI<TensorView<'_, TB, B, DB>> for TA
188    where
189        DB: DimAPI,
190        B: DeviceOpAPI<TA, TB, DB>,
191        B: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<B::TOut> + DeviceCreationAnyAPI<B::TOut>,
192        TA: num::Num,
193    {
194        type Output = Tensor<B::TOut, B, DB>;
195
196        fn op_f(self, b: TensorView<'_, TB, B, DB>) -> Result<Self::Output> {
197            TensorOpAPI::op_f(self, &b.view())
198        }
199    }
200}
201
202/* #endregion */
203
204/* #region function impl */
205
/// Generates documented free-function wrappers around a binary tensor-operation trait.
///
/// Arguments, in order:
/// - `$op`, `$op_f`: names of the trait's convenience and fallible methods; the generated
///   free functions reuse these names.
/// - `$TensorOpAPI`: the trait that provides `$op` and `$op_f`.
/// - `$DeviceOpAPI`: accepted for call-site symmetry only; the expansion does not use it.
/// - trailing `$op2, $op2_f` pairs: alias names that forward to the same trait methods.
///
/// The comma after `$DeviceOpAPI` is required by the pattern even when no aliases follow,
/// e.g. `func_binary!(atan2, atan2_f, TensorATan2API, DeviceATan2API,);`.
macro_rules! func_binary {
    ($op: ident, $op_f: ident, $TensorOpAPI: ident, $DeviceOpAPI: ident, $($op2: ident, $op2_f: ident),*) => {
        #[doc = concat!(
            "Fallible function form of [`", stringify!($TensorOpAPI), "::", stringify!($op_f), "`].\n\n",
            "Forwards `a` and `b` to the trait method and returns its result unchanged.\n\n",
            "# Errors\n\n",
            "Returns the error reported by the `", stringify!($TensorOpAPI),
            "` implementation selected for the operand types."
        )]
        pub fn $op_f<TRA, TRB>(a: TRA, b: TRB) -> Result<TRA::Output>
        where
            TRA: $TensorOpAPI<TRB>,
        {
            a.$op_f(b)
        }

        #[doc = concat!(
            "Function form of [`", stringify!($TensorOpAPI), "::", stringify!($op), "`].\n\n",
            "Use [`", stringify!($op_f), "`] to receive failures as a `Result` instead."
        )]
        pub fn $op<TRA, TRB>(a: TRA, b: TRB) -> TRA::Output
        where
            TRA: $TensorOpAPI<TRB>,
        {
            a.$op(b)
        }

        $(
            #[doc = concat!("Alias of [`", stringify!($op_f), "`].")]
            pub fn $op2_f<TRA, TRB>(a: TRA, b: TRB) -> Result<TRA::Output>
            where
                TRA: $TensorOpAPI<TRB>,
            {
                a.$op_f(b)
            }

            #[doc = concat!("Alias of [`", stringify!($op), "`].")]
            pub fn $op2<TRA, TRB>(a: TRA, b: TRB) -> TRA::Output
            where
                TRA: $TensorOpAPI<TRB>,
            {
                a.$op(b)
            }
        )*
    };
}
239
240#[rustfmt::skip]
241mod func_binary {
242    use super::*;
243    func_binary!(atan2         , atan2_f           , TensorATan2API            , DeviceATan2API            ,);
244    func_binary!(copysign      , copysign_f        , TensorCopySignAPI         , DeviceCopySignAPI         ,);
245    func_binary!(floor_divide  , floor_divide_f    , TensorFloorDivideAPI      , DeviceFloorDivideAPI      ,);
246    func_binary!(hypot         , hypot_f           , TensorHypotAPI            , DeviceHypotAPI            ,);
247    func_binary!(log_add_exp   , log_add_exp_f     , TensorLogAddExpAPI        , DeviceLogAddExpAPI        ,);
248    func_binary!(pow           , pow_f             , TensorPowAPI              , DevicePowAPI              ,);
249    func_binary!(maximum       , maximum_f         , TensorMaximumAPI          , DeviceMaximumAPI          , max, max_f);
250    func_binary!(minimum       , minimum_f         , TensorMinimumAPI          , DeviceMinimumAPI          , min, min_f);
251    func_binary!(equal         , equal_f           , TensorEqualAPI            , DeviceEqualAPI            , eq, eq_f, equal_than      , equal_than_f      );
252    func_binary!(less          , less_f            , TensorLessAPI             , DeviceLessAPI             , lt, lt_f, less_than       , less_than_f       );
253    func_binary!(greater       , greater_f         , TensorGreaterAPI          , DeviceGreaterAPI          , gt, gt_f, greater_than    , greater_than_f    );
254    func_binary!(less_equal    , less_equal_f      , TensorLessEqualAPI        , DeviceLessEqualAPI        , le, le_f, less_equal_to   , less_equal_to_f   );
255    func_binary!(greater_equal , greater_equal_f   , TensorGreaterEqualAPI     , DeviceGreaterEqualAPI     , ge, ge_f, greater_equal_to, greater_equal_to_f);
256    func_binary!(not_equal     , not_equal_f       , TensorNotEqualAPI         , DeviceNotEqualAPI         , ne, ne_f, not_equal_to    , not_equal_to_f    );
257    func_binary!(nextafter     , nextafter_f       , TensorNextAfterAPI        , DeviceNextAfterAPI        ,);
258}
259
260pub use func_binary::*;
261
262/* #endregion */
263
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_pow() {
        // `arange(6)` is shaped so that a length-3 vector broadcasts against it under the
        // configured memory order; the flattened results are identical in both orders.
        let shape = if cfg!(feature = "col_major") { [3, 2] } else { [2, 3] };

        let base = arange(6u32).into_shape(shape);
        let exponent = arange(3u32);
        let out = pow(&base, &exponent);
        println!("{out:?}");
        assert_eq!(out.reshape([6]).to_vec(), vec![1, 1, 4, 1, 4, 25]);

        let base = arange(6.0).into_shape(shape);
        let expected = vec![1.0, 1.0, 4.0, 1.0, 4.0, 25.0];

        // floating-point exponent
        let exponent = arange(3.0);
        let out = pow(&base, &exponent);
        println!("{out:?}");
        assert_eq!(out.reshape([6]).to_vec(), expected);

        // integer exponent on a floating-point base
        let exponent = arange(3);
        let out = pow(&base, &exponent);
        println!("{out:?}");
        assert_eq!(out.reshape([6]).to_vec(), expected);
    }

    #[test]
    fn test_floor_divide() {
        // [2, 3] (row-major) or [3, 2] (col-major) against [3]
        let shape = if cfg!(feature = "col_major") { [3, 2] } else { [2, 3] };

        // u32 numerator with i32 divisor yields i64
        let num = arange(6u32).into_shape(shape);
        let den = asarray(vec![1_i32, 2, 2]);
        let quot = num.floor_divide(&den);
        println!("{quot:?}");
        assert_eq!(quot.reshape([6]).to_vec(), vec![0_i64, 0, 1, 3, 2, 2]);

        let num = arange(6.0).into_shape(shape);

        let den = asarray(vec![1.0, 2.0, 2.0]);
        let quot = num.floor_divide(&den);
        println!("{quot:?}");
        assert_eq!(quot.reshape([6]).to_vec(), vec![0.0, 0.0, 1.0, 3.0, 2.0, 2.0]);

        // a zero divisor goes through the fallible path; the outcome is only printed
        let den = asarray(vec![0.0, 2.0, 2.0]);
        let quot = num.floor_divide_f(&den);
        println!("{quot:?}");
    }

    #[test]
    fn test_ge_gt() {
        let lhs = asarray(vec![1., 2., 3., 4., 5., 6.]);
        let rhs = asarray(vec![1., 3., 2., 5., 5., 2.]);

        let gt_mask = gt(lhs.view(), &rhs);
        assert_eq!(gt_mask.raw(), &[false, false, true, false, false, true]);

        let ge_mask = ge(lhs.view(), &rhs);
        assert_eq!(ge_mask.raw(), &[true, false, true, false, true, true]);

        // summing the boolean mask counts its `true` entries
        let n_true = ge_mask.sum();
        assert_eq!(n_true, 4);
    }

    #[test]
    fn test_refa_numb() {
        let a = asarray(vec![1., 3., 2., 5., 5., 2.]);

        // tensor <op> scalar
        let mask = a.greater_equal(3.0);
        assert_eq!(mask.raw(), &[false, true, false, true, true, false]);
        let squared = a.pow(2);
        assert_eq!(squared.raw(), &[1.0, 9.0, 4.0, 25.0, 25.0, 4.0]);

        // scalar <op> tensor view
        let powers = 2.0.pow(a.view());
        assert_eq!(powers.raw(), &[2.0, 8.0, 4.0, 32.0, 32.0, 4.0]);
    }
}