// rstsr_core/tensor/operators/op_binary_arithmetic.rs

1use crate::prelude_dev::*;
2use core::mem::transmute;
3
4/* #region binary operation function and traits */
5
// Binary arithmetic trait family, stamped out by `duplicate_item` once per
// operator (add/sub/mul/div/rem plus the bit and shift ops). `TrB` is the
// right-hand operand type; implementors are tensor-like left operands.
#[duplicate_item(
    op       op_f       TensorOpAPI     ;
   [add   ] [add_f   ] [TensorAddAPI   ];
   [sub   ] [sub_f   ] [TensorSubAPI   ];
   [mul   ] [mul_f   ] [TensorMulAPI   ];
   [div   ] [div_f   ] [TensorDivAPI   ];
   [rem   ] [rem_f   ] [TensorRemAPI   ];
   [bitor ] [bitor_f ] [TensorBitOrAPI ];
   [bitand] [bitand_f] [TensorBitAndAPI];
   [bitxor] [bitxor_f] [TensorBitXorAPI];
   [shl   ] [shl_f   ] [TensorShlAPI   ];
   [shr   ] [shr_f   ] [TensorShrAPI   ];
)]
pub trait TensorOpAPI<TrB> {
    // Result type of `a op b` (an owned `Tensor` in all impls below).
    type Output;
    // Fallible form; concrete impls surface e.g. device mismatch or
    // non-broadcastable layouts through `Result`.
    fn op_f(a: Self, b: TrB) -> Result<Self::Output>;
    // Panicking convenience wrapper over `op_f`.
    fn op(a: Self, b: TrB) -> Self::Output
    where
        Self: Sized,
    {
        Self::op_f(a, b).rstsr_unwrap()
    }
}
29
30#[duplicate_item(
31    op       op_f       TensorOpAPI     ;
32   [add   ] [add_f   ] [TensorAddAPI   ];
33   [sub   ] [sub_f   ] [TensorSubAPI   ];
34   [mul   ] [mul_f   ] [TensorMulAPI   ];
35   [div   ] [div_f   ] [TensorDivAPI   ];
36   [rem   ] [rem_f   ] [TensorRemAPI   ];
37   [bitor ] [bitor_f ] [TensorBitOrAPI ];
38   [bitand] [bitand_f] [TensorBitAndAPI];
39   [bitxor] [bitxor_f] [TensorBitXorAPI];
40   [shl   ] [shl_f   ] [TensorShlAPI   ];
41   [shr   ] [shr_f   ] [TensorShrAPI   ];
42)]
43pub fn op_f<TrA, TrB>(a: TrA, b: TrB) -> Result<TrA::Output>
44where
45    TrA: TensorOpAPI<TrB>,
46{
47    TrA::op_f(a, b)
48}
49
50#[duplicate_item(
51    op       op_f       TensorOpAPI     ;
52   [add   ] [add_f   ] [TensorAddAPI   ];
53   [sub   ] [sub_f   ] [TensorSubAPI   ];
54   [mul   ] [mul_f   ] [TensorMulAPI   ];
55   [div   ] [div_f   ] [TensorDivAPI   ];
56   [rem   ] [rem_f   ] [TensorRemAPI   ];
57   [bitor ] [bitor_f ] [TensorBitOrAPI ];
58   [bitand] [bitand_f] [TensorBitAndAPI];
59   [bitxor] [bitxor_f] [TensorBitXorAPI];
60   [shl   ] [shl_f   ] [TensorShlAPI   ];
61   [shr   ] [shr_f   ] [TensorShrAPI   ];
62)]
63pub fn op<TrA, TrB>(a: TrA, b: TrB) -> TrA::Output
64where
65    TrA: TensorOpAPI<TrB>,
66{
67    TrA::op(a, b)
68}
69
// Inherent convenience methods on `TensorBase`, so users can write
// `a.add_f(&b)` / `a.add(&b)` instead of calling trait functions directly.
// Both borrow `self`, so dispatch goes through the `&TensorBase` impls.
#[duplicate_item(
    op       op_f       TensorOpAPI     ;
   [add   ] [add_f   ] [TensorAddAPI   ];
   [sub   ] [sub_f   ] [TensorSubAPI   ];
   [mul   ] [mul_f   ] [TensorMulAPI   ];
   [div   ] [div_f   ] [TensorDivAPI   ];
   [rem   ] [rem_f   ] [TensorRemAPI   ];
   [bitor ] [bitor_f ] [TensorBitOrAPI ];
   [bitand] [bitand_f] [TensorBitAndAPI];
   [bitxor] [bitxor_f] [TensorBitXorAPI];
   [shl   ] [shl_f   ] [TensorShlAPI   ];
   [shr   ] [shr_f   ] [TensorShrAPI   ];
)]
impl<S, D> TensorBase<S, D>
where
    D: DimAPI,
{
    // Fallible binary operation with `self` borrowed as the left operand.
    pub fn op_f<TrB>(&self, b: TrB) -> Result<<&Self as TensorOpAPI<TrB>>::Output>
    where
        for<'a> &'a Self: TensorOpAPI<TrB>,
    {
        <&Self as TensorOpAPI<TrB>>::op_f(self, b)
    }

    // Panicking variant of the method above.
    pub fn op<TrB>(&self, b: TrB) -> <&Self as TensorOpAPI<TrB>>::Output
    where
        for<'a> &'a Self: TensorOpAPI<TrB>,
    {
        <&Self as TensorOpAPI<TrB>>::op(self, b)
    }
}
101
102/* #endregion */
103
104/* #region binary core ops implementation */
105
// Wires the `core::ops` operator traits (`Add`, `Sub`, ...) to the
// `TensorOpAPI` family, so `&a + &b`, `a + b`, view/cow combinations etc.
// work with operator syntax.
// NOTE(review): the `rem` row is commented out, so `%` is not routed through
// this forwarding — presumably handled elsewhere or intentionally omitted;
// confirm.
#[duplicate_item(
    op       DeviceOpAPI       TensorOpAPI       Op     ;
   [add   ] [DeviceAddAPI   ] [TensorAddAPI   ] [Add   ];
   [sub   ] [DeviceSubAPI   ] [TensorSubAPI   ] [Sub   ];
   [mul   ] [DeviceMulAPI   ] [TensorMulAPI   ] [Mul   ];
   [div   ] [DeviceDivAPI   ] [TensorDivAPI   ] [Div   ];
// [rem   ] [DeviceRemAPI   ] [TensorRemAPI   ] [Rem   ];
   [bitor ] [DeviceBitOrAPI ] [TensorBitOrAPI ] [BitOr ];
   [bitand] [DeviceBitAndAPI] [TensorBitAndAPI] [BitAnd];
   [bitxor] [DeviceBitXorAPI] [TensorBitXorAPI] [BitXor];
   [shl   ] [DeviceShlAPI   ] [TensorShlAPI   ] [Shl   ];
   [shr   ] [DeviceShrAPI   ] [TensorShrAPI   ] [Shr   ];
)]
mod impl_core_ops {
    use super::*;

    // `&tensor op b`: any tensor behind a reference delegates to `TensorOpAPI`.
    impl<SA, DA, TrB> Op<TrB> for &TensorBase<SA, DA>
    where
        DA: DimAPI,
        Self: TensorOpAPI<TrB>,
    {
        type Output = <Self as TensorOpAPI<TrB>>::Output;
        fn op(self, b: TrB) -> Self::Output {
            TensorOpAPI::op(self, b)
        }
    }

    // By-value left operands (view, owned, cow): same delegation, one impl
    // each via the nested duplicate table.
    #[duplicate_item(
        TrA; [TensorView<'_, TA, B, DA>]; [Tensor<TA, B, DA>]; [TensorCow<'_, TA, B, DA>];
    )]
    impl<TA, DA, B, TrB> Op<TrB> for TrA
    where
        DA: DimAPI,
        B: DeviceAPI<TA>,
        Self: TensorOpAPI<TrB>,
    {
        type Output = <Self as TensorOpAPI<TrB>>::Output;
        fn op(self, b: TrB) -> Self::Output {
            TensorOpAPI::op(self, b)
        }
    }
}
148
149/* #endregion */
150
151/* #region binary implementation */
152
// By-reference implementations: `&a op &b` (and view combinations) always
// allocate a fresh output tensor; both operands are only read.
#[duplicate_item(
    op_f       DeviceOpAPI       TensorOpAPI       Op     ;
   [add_f   ] [DeviceAddAPI   ] [TensorAddAPI   ] [Add   ];
   [sub_f   ] [DeviceSubAPI   ] [TensorSubAPI   ] [Sub   ];
   [mul_f   ] [DeviceMulAPI   ] [TensorMulAPI   ] [Mul   ];
   [div_f   ] [DeviceDivAPI   ] [TensorDivAPI   ] [Div   ];
   [rem_f   ] [DeviceRemAPI   ] [TensorRemAPI   ] [Rem   ];
   [bitor_f ] [DeviceBitOrAPI ] [TensorBitOrAPI ] [BitOr ];
   [bitand_f] [DeviceBitAndAPI] [TensorBitAndAPI] [BitAnd];
   [bitxor_f] [DeviceBitXorAPI] [TensorBitXorAPI] [BitXor];
   [shl_f   ] [DeviceShlAPI   ] [TensorShlAPI   ] [Shl   ];
   [shr_f   ] [DeviceShrAPI   ] [TensorShrAPI   ] [Shr   ];
)]
mod impl_binary_arithmetic_ref {
    use super::*;

    // `&a op &b` — the workhorse; the other by-reference combinations below
    // forward here after borrowing/viewing their operands.
    #[doc(hidden)]
    impl<RA, RB, TA, TB, TC, DA, DB, DC, B> TensorOpAPI<&TensorAny<RB, TB, B, DB>> for &TensorAny<RA, TA, B, DA>
    where
        // tensor types
        RA: DataAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
        RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
        // data constraints
        DA: DimAPI,
        DB: DimAPI,
        DC: DimAPI,
        B: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<TC> + DeviceRawAPI<MaybeUninit<TC>>,
        B: DeviceCreationAnyAPI<TC>,
        // broadcast constraints
        DA: DimMaxAPI<DB, Max = DC>,
        // operation constraints
        TA: Op<TB, Output = TC>,
        B: DeviceOpAPI<TA, TB, TC, DC>,
    {
        type Output = Tensor<TC, B, DC>;
        fn op_f(a: Self, b: &TensorAny<RB, TB, B, DB>) -> Result<Self::Output> {
            // get tensor views
            let a = a.view();
            let b = b.view();
            // check device and layout
            rstsr_assert!(a.device().same_device(b.device()), DeviceMismatch)?;
            let la = a.layout();
            let lb = b.layout();
            let default_order = a.device().default_order();
            let (la_b, lb_b) = broadcast_layout(la, lb, default_order)?;
            // generate output layout: prefer the copy layout both broadcast
            // operands agree on; otherwise fall back to a contiguous layout
            // in the device's default order
            let lc_from_a = layout_for_array_copy(&la_b, TensorIterOrder::default())?;
            let lc_from_b = layout_for_array_copy(&lb_b, TensorIterOrder::default())?;
            let lc = if lc_from_a == lc_from_b {
                lc_from_a
            } else {
                match a.device().default_order() {
                    RowMajor => la_b.shape().c(),
                    ColMajor => la_b.shape().f(),
                }
            };
            // generate empty c (uninitialized storage, filled by the device op below)
            let device = a.device();
            let mut storage_c = device.uninit_impl(lc.bounds_index()?.1)?;
            // add provided by device
            device.op_mutc_refa_refb(storage_c.raw_mut(), &lc, a.raw(), &la_b, b.raw(), &lb_b)?;
            // return tensor
            let storage_c = unsafe { B::assume_init_impl(storage_c) }?;
            Tensor::new_f(storage_c, lc)
        }
    }

    // Remaining by-reference combinations (&tensor op view, view op &tensor,
    // view op view) forward to the impl above via the `a_inner` / `b_inner`
    // substitution columns.
    #[doc(hidden)]
    #[duplicate_item(
        RType                                             TrA                         TrB                         a_inner   b_inner ;
       [R: DataAPI<Data = <B as DeviceRawAPI<TA>>::Raw>] [&TensorAny<R, TA, B, DA> ] [TensorView<'_, TB, B, DB>] [ a     ] [&b     ];
       [R: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>] [TensorView<'_, TA, B, DA>] [&TensorAny<R, TB, B, DB> ] [&a     ] [ b     ];
       [                                               ] [TensorView<'_, TA, B, DA>] [TensorView<'_, TB, B, DB>] [&a     ] [&b     ];
    )]
    impl<TA, TB, TC, DA, DB, DC, B, RType> TensorOpAPI<TrB> for TrA
    where
        // data constraints
        DA: DimAPI,
        DB: DimAPI,
        DC: DimAPI,
        B: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<TC>,
        B: DeviceCreationAnyAPI<TC>,
        // broadcast constraints
        DA: DimMaxAPI<DB, Max = DC>,
        // operation constraints
        TA: Op<TB, Output = TC>,
        B: DeviceOpAPI<TA, TB, TC, DC>,
    {
        type Output = Tensor<TC, B, DC>;
        fn op_f(a: Self, b: TrB) -> Result<Self::Output> {
            TensorOpAPI::op_f(a_inner, b_inner)
        }
    }
}
247
// Consuming variants: when an operand is an owned `Tensor` (or an owned
// `TensorCow`), try to reuse its buffer as the output instead of allocating.
// Falls back to the by-reference path whenever the owned operand's layout
// cannot serve as the output layout.
#[duplicate_item(
    op_f       DeviceOpAPI      TensorOpAPI        Op       DeviceLConsumeAPI         DeviceRConsumeAPI       ;
   [add_f   ] [DeviceAddAPI   ] [TensorAddAPI   ] [Add   ] [DeviceLConsumeAddAPI   ] [DeviceRConsumeAddAPI   ];
   [sub_f   ] [DeviceSubAPI   ] [TensorSubAPI   ] [Sub   ] [DeviceLConsumeSubAPI   ] [DeviceRConsumeSubAPI   ];
   [mul_f   ] [DeviceMulAPI   ] [TensorMulAPI   ] [Mul   ] [DeviceLConsumeMulAPI   ] [DeviceRConsumeMulAPI   ];
   [div_f   ] [DeviceDivAPI   ] [TensorDivAPI   ] [Div   ] [DeviceLConsumeDivAPI   ] [DeviceRConsumeDivAPI   ];
   [rem_f   ] [DeviceRemAPI   ] [TensorRemAPI   ] [Rem   ] [DeviceLConsumeRemAPI   ] [DeviceRConsumeRemAPI   ];
   [bitor_f ] [DeviceBitOrAPI ] [TensorBitOrAPI ] [BitOr ] [DeviceLConsumeBitOrAPI ] [DeviceRConsumeBitOrAPI ];
   [bitand_f] [DeviceBitAndAPI] [TensorBitAndAPI] [BitAnd] [DeviceLConsumeBitAndAPI] [DeviceRConsumeBitAndAPI];
   [bitxor_f] [DeviceBitXorAPI] [TensorBitXorAPI] [BitXor] [DeviceLConsumeBitXorAPI] [DeviceRConsumeBitXorAPI];
   [shl_f   ] [DeviceShlAPI   ] [TensorShlAPI   ] [Shl   ] [DeviceLConsumeShlAPI   ] [DeviceRConsumeShlAPI   ];
   [shr_f   ] [DeviceShrAPI   ] [TensorShrAPI   ] [Shr   ] [DeviceLConsumeShrAPI   ] [DeviceRConsumeShrAPI   ];
)]
mod impl_binary_lr_consume {
    use super::*;

    // owned a op &b: reuse a's buffer when its own layout already equals the
    // broadcast output layout; otherwise fall back to `(&a) op b`.
    #[doc(hidden)]
    impl<RB, TA, TB, DA, DB, DC, B> TensorOpAPI<&TensorAny<RB, TB, B, DB>> for Tensor<TA, B, DA>
    where
        // tensor types
        RB: DataAPI<Data = <B as DeviceRawAPI<TB>>::Raw>,
        // data constraints
        DA: DimAPI,
        DB: DimAPI,
        DC: DimAPI,
        B: DeviceAPI<TA> + DeviceAPI<TB>,
        B: DeviceCreationAnyAPI<TA>,
        // broadcast constraints
        DA: DimMaxAPI<DB, Max = DC>,
        DC: DimIntoAPI<DA>,
        DA: DimIntoAPI<DC>,
        // operation constraints
        TA: Op<TB, Output = TA>,
        B: DeviceOpAPI<TA, TB, TA, DC>,
        B: DeviceLConsumeAPI<TA, TB, DA>,
    {
        type Output = Tensor<TA, B, DC>;
        fn op_f(a: Self, b: &TensorAny<RB, TB, B, DB>) -> Result<Self::Output> {
            rstsr_assert!(a.device().same_device(b.device()), DeviceMismatch)?;
            let device = a.device().clone();
            let la = a.layout();
            let lb = b.layout();
            let default_order = a.device().default_order();
            // broadcast b towards a's layout (a must keep its layout to be reused)
            let broadcast_result = broadcast_layout_to_first(la, lb, default_order);
            if a.layout().is_broadcasted() || broadcast_result.is_err() {
                // not broadcastable for output a
                TensorOpAPI::op_f(&a, b)
            } else {
                // check broadcast layouts
                let (la_b, lb_b) = broadcast_result?;
                if la_b != *la {
                    // output shape of c is not the same as input owned a
                    TensorOpAPI::op_f(&a, b)
                } else {
                    // reuse a as c
                    let (mut storage_a, _) = a.into_raw_parts();
                    device.op_muta_refb(storage_a.raw_mut(), &la_b, b.raw(), &lb_b)?;
                    // la_b equals a's original layout (checked above), so the
                    // unchecked constructor pairs a valid layout with its storage
                    let c = unsafe { Tensor::new_unchecked(storage_a, la_b) };
                    c.into_dim_f::<DC>()
                }
            }
        }
    }

    // &a op owned b: symmetric case — reuse b's buffer when possible.
    #[doc(hidden)]
    impl<RA, TA, TB, DA, DB, DC, B> TensorOpAPI<Tensor<TB, B, DB>> for &TensorAny<RA, TA, B, DA>
    where
        // tensor types
        RA: DataAPI<Data = <B as DeviceRawAPI<TA>>::Raw>,
        // data constraints
        DA: DimAPI,
        DB: DimAPI,
        DC: DimAPI,
        B: DeviceAPI<TA> + DeviceAPI<TB>,
        B: DeviceCreationAnyAPI<TB>,
        // broadcast constraints
        DA: DimMaxAPI<DB, Max = DC>,
        DB: DimMaxAPI<DA, Max = DC>,
        DC: DimIntoAPI<DB>,
        DB: DimIntoAPI<DC>,
        // operation constraints
        TA: Op<TB, Output = TB>,
        B: DeviceOpAPI<TA, TB, TB, DC>,
        B: DeviceRConsumeAPI<TA, TB, DB>,
    {
        type Output = Tensor<TB, B, DC>;
        fn op_f(a: Self, b: Tensor<TB, B, DB>) -> Result<Self::Output> {
            rstsr_assert!(a.device().same_device(b.device()), DeviceMismatch)?;
            let device = b.device().clone();
            let la = a.layout();
            let lb = b.layout();
            let default_order = b.device().default_order();
            let broadcast_result = broadcast_layout_to_first(lb, la, default_order);
            if b.layout().is_broadcasted() || broadcast_result.is_err() {
                // not broadcastable for output b
                TensorOpAPI::op_f(a, &b)
            } else {
                // check broadcast layouts
                let (lb_b, la_b) = broadcast_result?;
                if lb_b != *lb {
                    // output shape of c is not the same as input owned b
                    TensorOpAPI::op_f(a, &b)
                } else {
                    // reuse b as c
                    let (mut storage_b, _) = b.into_raw_parts();
                    device.op_muta_refb(storage_b.raw_mut(), &lb_b, a.raw(), &la_b)?;
                    // lb_b equals b's original layout (checked above)
                    let c = unsafe { Tensor::new_unchecked(storage_b, lb_b) };
                    c.into_dim_f::<DC>()
                }
            }
        }
    }

    // owned a op view b: borrow the view and take the L-consume path above.
    #[doc(hidden)]
    impl<'b, TA, TB, DA, DB, DC, B> TensorOpAPI<TensorView<'b, TB, B, DB>> for Tensor<TA, B, DA>
    where
        // data constraints
        DA: DimAPI,
        DB: DimAPI,
        DC: DimAPI,
        B: DeviceAPI<TA> + DeviceAPI<TB>,
        B: DeviceCreationAnyAPI<TA>,
        // broadcast constraints
        DA: DimMaxAPI<DB, Max = DC>,
        DC: DimIntoAPI<DA>,
        DA: DimIntoAPI<DC>,
        // operation constraints
        TA: Op<TB, Output = TA>,
        B: DeviceOpAPI<TA, TB, TA, DC>,
        B: DeviceLConsumeAPI<TA, TB, DA>,
    {
        type Output = Tensor<TA, B, DC>;
        fn op_f(a: Self, b: TensorView<'b, TB, B, DB>) -> Result<Self::Output> {
            TensorOpAPI::op_f(a, &b)
        }
    }

    // view a op owned b: borrow the view and take the R-consume path above.
    #[doc(hidden)]
    impl<TA, TB, DA, DB, DC, B> TensorOpAPI<Tensor<TB, B, DB>> for TensorView<'_, TA, B, DA>
    where
        // data constraints
        DA: DimAPI,
        DB: DimAPI,
        DC: DimAPI,
        B: DeviceAPI<TA> + DeviceAPI<TB>,
        B: DeviceCreationAnyAPI<TB>,
        // broadcast constraints
        DA: DimMaxAPI<DB, Max = DC>,
        DB: DimMaxAPI<DA, Max = DC>,
        DC: DimIntoAPI<DB>,
        DB: DimIntoAPI<DC>,
        // operation constraints
        TA: Op<TB, Output = TB>,
        B: DeviceOpAPI<TA, TB, TB, DC>,
        B: DeviceRConsumeAPI<TA, TB, DB>,
    {
        type Output = Tensor<TB, B, DC>;
        fn op_f(a: Self, b: Tensor<TB, B, DB>) -> Result<Self::Output> {
            TensorOpAPI::op_f(&a, b)
        }
    }

    // owned a op owned b: prefer reusing a's buffer, then b's, else allocate.
    #[doc(hidden)]
    impl<T, DA, DB, DC, B> TensorOpAPI<Tensor<T, B, DB>> for Tensor<T, B, DA>
    where
        // data constraints
        DA: DimAPI,
        DB: DimAPI,
        DC: DimAPI,
        B: DeviceAPI<T>,
        B: DeviceCreationAnyAPI<T>,
        // broadcast constraints
        DA: DimMaxAPI<DB, Max = DC> + DimIntoAPI<DC>,
        DB: DimMaxAPI<DA, Max = DC> + DimIntoAPI<DC>,
        DC: DimIntoAPI<DA> + DimIntoAPI<DB>,
        // operation constraints
        T: Op<T, Output = T>,
        B: DeviceOpAPI<T, T, T, DC>,
        B: DeviceLConsumeAPI<T, T, DA>,
        B: DeviceRConsumeAPI<T, T, DB>,
    {
        type Output = Tensor<T, B, DC>;
        fn op_f(a: Self, b: Tensor<T, B, DB>) -> Result<Self::Output> {
            rstsr_assert!(a.device().same_device(b.device()), DeviceMismatch)?;
            let la = a.layout();
            let lb = b.layout();
            let default_order = a.device().default_order();
            // first try to consume a (left operand)
            let broadcast_result = broadcast_layout_to_first(la, lb, default_order);
            if !a.layout().is_broadcasted() && broadcast_result.is_ok() {
                let (la_b, _) = broadcast_result?;
                if la_b == *la {
                    return TensorOpAPI::op_f(a, &b);
                }
            }
            // then try to consume b (right operand)
            let broadcast_result = broadcast_layout_to_first(lb, la, default_order);
            if !b.layout().is_broadcasted() && broadcast_result.is_ok() {
                let (lb_b, _) = broadcast_result?;
                if lb_b == *lb {
                    return TensorOpAPI::op_f(&a, b);
                }
            }
            // neither buffer is reusable: allocate via the by-reference path
            return TensorOpAPI::op_f(&a, &b);
        }
    }

    // For TensorCow, currently use the most strict implementation, which
    // requires all types involved to be the same.

    // cow a op {ref, view, owned} b: dispatch on whether the cow is owned.
    #[doc(hidden)]
    #[duplicate_item(
        RType                                            TrB                      ;
       [R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>] [&TensorAny<R, T, B, DB> ];
       [                                              ] [TensorView<'_, T, B, DB>];
       [                                              ] [Tensor<T, B, DB>        ];
    )]
    impl<T, DA, DB, DC, B, RType> TensorOpAPI<TrB> for TensorCow<'_, T, B, DA>
    where
        // data constraints
        DA: DimAPI,
        DB: DimAPI,
        DC: DimAPI,
        B: DeviceAPI<T>,
        B: DeviceCreationAnyAPI<T>,
        // broadcast constraints
        DA: DimMaxAPI<DB, Max = DC> + DimIntoAPI<DC>,
        DB: DimMaxAPI<DA, Max = DC> + DimIntoAPI<DC>,
        DC: DimIntoAPI<DA> + DimIntoAPI<DB>,
        // operation constraints
        T: Op<T, Output = T>,
        B: DeviceOpAPI<T, T, T, DC>,
        B: DeviceLConsumeAPI<T, T, DA>,
        B: DeviceRConsumeAPI<T, T, DB>,
        // cow constraints
        T: Clone,
        <B as DeviceRawAPI<T>>::Raw: Clone,
        B: OpAssignAPI<T, DA>,
    {
        type Output = Tensor<T, B, DC>;
        fn op_f(a: Self, b: TrB) -> Result<Self::Output> {
            match a.is_owned() {
                true => TensorOpAPI::op_f(a.into_owned(), b),
                false => TensorOpAPI::op_f(a.view(), b),
            }
        }
    }

    // {ref, view, owned} a op cow b: dispatch on whether the cow is owned.
    #[doc(hidden)]
    #[duplicate_item(
        RType                                            TrA                      ;
       [R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>] [&TensorAny<R, T, B, DA> ];
       [                                              ] [TensorView<'_, T, B, DA>];
       [                                              ] [Tensor<T, B, DA>        ];
    )]
    impl<T, DA, DB, DC, B, RType> TensorOpAPI<TensorCow<'_, T, B, DB>> for TrA
    where
        // data constraints
        DA: DimAPI,
        DB: DimAPI,
        DC: DimAPI,
        B: DeviceAPI<T>,
        B: DeviceCreationAnyAPI<T>,
        // broadcast constraints
        DA: DimMaxAPI<DB, Max = DC> + DimIntoAPI<DC>,
        DB: DimMaxAPI<DA, Max = DC> + DimIntoAPI<DC>,
        DC: DimIntoAPI<DA> + DimIntoAPI<DB>,
        // operation constraints
        T: Op<T, Output = T>,
        B: DeviceOpAPI<T, T, T, DC>,
        B: DeviceLConsumeAPI<T, T, DA>,
        B: DeviceRConsumeAPI<T, T, DB>,
        // cow constraints
        T: Clone,
        <B as DeviceRawAPI<T>>::Raw: Clone,
        B: OpAssignAPI<T, DB>,
    {
        type Output = Tensor<T, B, DC>;
        fn op_f(a: Self, b: TensorCow<'_, T, B, DB>) -> Result<Self::Output> {
            match b.is_owned() {
                true => TensorOpAPI::op_f(a, b.into_owned()),
                false => TensorOpAPI::op_f(a, b.view()),
            }
        }
    }

    // cow a op cow b: four-way dispatch on the ownership of each side.
    impl<T, DA, DB, DC, B> TensorOpAPI<TensorCow<'_, T, B, DB>> for TensorCow<'_, T, B, DA>
    where
        // data constraints
        DA: DimAPI,
        DB: DimAPI,
        DC: DimAPI,
        B: DeviceAPI<T>,
        B: DeviceCreationAnyAPI<T>,
        // broadcast constraints
        DA: DimMaxAPI<DB, Max = DC> + DimIntoAPI<DC>,
        DB: DimMaxAPI<DA, Max = DC> + DimIntoAPI<DC>,
        DC: DimIntoAPI<DA> + DimIntoAPI<DB>,
        // operation constraints
        T: Op<T, Output = T>,
        B: DeviceOpAPI<T, T, T, DC>,
        B: DeviceLConsumeAPI<T, T, DA>,
        B: DeviceRConsumeAPI<T, T, DB>,
        // cow constraints
        T: Clone,
        <B as DeviceRawAPI<T>>::Raw: Clone,
        B: OpAssignAPI<T, DA> + OpAssignAPI<T, DB>,
    {
        type Output = Tensor<T, B, DC>;
        fn op_f(a: Self, b: TensorCow<'_, T, B, DB>) -> Result<Self::Output> {
            match (a.is_owned(), b.is_owned()) {
                (true, true) => TensorOpAPI::op_f(a.into_owned(), b.into_owned()),
                (true, false) => TensorOpAPI::op_f(a.into_owned(), b.view()),
                (false, true) => TensorOpAPI::op_f(a.view(), b.into_owned()),
                (false, false) => TensorOpAPI::op_f(a.view(), b.view()),
            }
        }
    }
}
566
567/* #endregion */
568
569/* #region binary with output implementation */
570
// Fallible "write into provided output" variants: compute `c = a op b`
// directly into `c`'s buffer, with `a` and `b` broadcast to `c`'s layout.
#[duplicate_item(
    op                   op_f                   DeviceOpAPI       Op     ;
   [add_with_output   ] [add_with_output_f   ] [DeviceAddAPI   ] [Add   ];
   [sub_with_output   ] [sub_with_output_f   ] [DeviceSubAPI   ] [Sub   ];
   [mul_with_output   ] [mul_with_output_f   ] [DeviceMulAPI   ] [Mul   ];
   [div_with_output   ] [div_with_output_f   ] [DeviceDivAPI   ] [Div   ];
   [rem_with_output   ] [rem_with_output_f   ] [DeviceRemAPI   ] [Rem   ];
   [bitor_with_output ] [bitor_with_output_f ] [DeviceBitOrAPI ] [BitOr ];
   [bitand_with_output] [bitand_with_output_f] [DeviceBitAndAPI] [BitAnd];
   [bitxor_with_output] [bitxor_with_output_f] [DeviceBitXorAPI] [BitXor];
   [shl_with_output   ] [shl_with_output_f   ] [DeviceShlAPI   ] [Shl   ];
   [shr_with_output   ] [shr_with_output_f   ] [DeviceShrAPI   ] [Shr   ];
)]
pub fn op_f<TrA, TrB, TrC, TA, TB, TC, DA, DB, DC, B>(a: TrA, b: TrB, mut c: TrC) -> Result<()>
where
    // tensor types
    TrA: TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
    TrB: TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
    TrC: TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
    // data constraints
    DA: DimAPI,
    DB: DimAPI,
    DC: DimAPI,
    B: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<TC>,
    // broadcast constraints
    DC: DimMaxAPI<DA, Max = DC> + DimMaxAPI<DB, Max = DC>,
    // operation constraints
    TA: Op<TB, Output = TC>,
    B: DeviceOpAPI<TA, TB, TC, DC>,
{
    // get tensor views
    let a = a.view();
    let b = b.view();
    let mut c = c.view_mut();
    // check device
    rstsr_assert!(c.device().same_device(a.device()), DeviceMismatch)?;
    rstsr_assert!(c.device().same_device(b.device()), DeviceMismatch)?;
    let lc = c.layout();
    let la = a.layout();
    let lb = b.layout();
    let default_order = c.device().default_order();
    // all layouts should be broadcastable to lc
    // we can first generate broadcasted shape, then check this
    // (broadcasting must not change c's own layout — asserted below)
    let (lc_b, la_b) = broadcast_layout_to_first(lc, la, default_order)?;
    rstsr_assert_eq!(lc_b, *lc, InvalidLayout)?;
    let (lc_b, lb_b) = broadcast_layout_to_first(lc, lb, default_order)?;
    rstsr_assert_eq!(lc_b, *lc, InvalidLayout)?;
    // op provided by device
    let device = c.device().clone();
    // REVIEWME: transmute &Raw<T> to &MaybeUninit<Raw<T>>
    // NOTE(review): this assumes the backend's raw storage for `TC` and for
    // `MaybeUninit<TC>` are layout-compatible. `MaybeUninit<T>` itself has the
    // same size/alignment as `T`, but the container types must also match —
    // confirm this holds for every backend's `Raw` type.
    let c_raw_mut = unsafe {
        transmute::<&mut <B as DeviceRawAPI<TC>>::Raw, &mut <B as DeviceRawAPI<MaybeUninit<TC>>>::Raw>(c.raw_mut())
    };
    device.op_mutc_refa_refb(c_raw_mut, &lc_b, a.raw(), &la_b, b.raw(), &lb_b)
}
626
627#[duplicate_item(
628        op                   op_f                   DeviceOpAPI       Op     ;
629       [add_with_output   ] [add_with_output_f   ] [DeviceAddAPI   ] [Add   ];
630       [sub_with_output   ] [sub_with_output_f   ] [DeviceSubAPI   ] [Sub   ];
631       [mul_with_output   ] [mul_with_output_f   ] [DeviceMulAPI   ] [Mul   ];
632       [div_with_output   ] [div_with_output_f   ] [DeviceDivAPI   ] [Div   ];
633       [rem_with_output   ] [rem_with_output_f   ] [DeviceRemAPI   ] [Rem   ];
634       [bitor_with_output ] [bitor_with_output_f ] [DeviceBitOrAPI ] [BitOr ];
635       [bitand_with_output] [bitand_with_output_f] [DeviceBitAndAPI] [BitAnd];
636       [bitxor_with_output] [bitxor_with_output_f] [DeviceBitXorAPI] [BitXor];
637       [shl_with_output   ] [shl_with_output_f   ] [DeviceShlAPI   ] [Shl   ];
638       [shr_with_output   ] [shr_with_output_f   ] [DeviceShrAPI   ] [Shr   ];
639    )]
640pub fn op<TrA, TrB, TrC, TA, TB, TC, DA, DB, DC, B>(a: TrA, b: TrB, c: TrC)
641where
642    // tensor types
643    TrA: TensorViewAPI<Type = TA, Backend = B, Dim = DA>,
644    TrB: TensorViewAPI<Type = TB, Backend = B, Dim = DB>,
645    TrC: TensorViewMutAPI<Type = TC, Backend = B, Dim = DC>,
646    // data constraints
647    DA: DimAPI,
648    DB: DimAPI,
649    DC: DimAPI,
650    B: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<TC>,
651    // broadcast constraints
652    DC: DimMaxAPI<DA, Max = DC> + DimMaxAPI<DB, Max = DC>,
653    // operation constraints
654    TA: Op<TB, Output = TC>,
655    B: DeviceOpAPI<TA, TB, TC, DC>,
656{
657    op_f(a, b, c).rstsr_unwrap()
658}
659
660/* #endregion */
661
662/* #region binary with scalar, num a op tsr b */
663
// Implements `scalar op tensor` (scalar on the LEFT) for one concrete scalar
// type `$ty`, lifting the scalar into the element type via `T: From<$ty>`.
// Covers right-hand sides of `&TensorAny`, `TensorView`, owned `Tensor`
// (buffer reused in place) and `TensorCow`, each paired with both the
// `$TensorOpAPI` impl and the matching `core::ops` `$Op` impl.
macro_rules! impl_arithmetic_scalar_lhs {
    ($ty: ty, $op: ident, $op_f: ident, $Op: ident, $DeviceOpAPI: ident, $TensorOpAPI: ident, $DeviceRConsumeOpAPI: ident) => {
        // scalar op &tensor: allocate a fresh output tensor
        #[doc(hidden)]
        impl<T, R, D, B> $TensorOpAPI<&TensorAny<R, T, B, D>> for $ty
        where
            T: From<$ty> + $Op<T, Output = T>,
            R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
            D: DimAPI,
            B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
            B: $DeviceOpAPI<T, T, T, D>,
        {
            type Output = Tensor<T, B, D>;
            fn $op_f(a: Self, b: &TensorAny<R, T, B, D>) -> Result<Self::Output> {
                // lift the scalar into the element type
                let a = T::from(a);
                let device = b.device();
                let lb = b.layout();
                let lc = layout_for_array_copy(lb, TensorIterOrder::default())?;
                // uninitialized output, written by the device kernel below
                let mut storage_c = device.uninit_impl(lc.bounds_index()?.1)?;
                device.op_mutc_numa_refb(storage_c.raw_mut(), &lc, a, b.raw(), lb)?;
                let storage_c = unsafe { B::assume_init_impl(storage_c) }?;
                Tensor::new_f(storage_c, lc)
            }
        }

        // core::ops counterpart: `scalar op &tensor`
        #[doc(hidden)]
        impl<T, R, D, B> $Op<&TensorAny<R, T, B, D>> for $ty
        where
            T: From<$ty> + $Op<T, Output = T>,
            R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
            D: DimAPI,
            B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
            B: $DeviceOpAPI<T, T, T, D>,
        {
            type Output = Tensor<T, B, D>;
            fn $op(self, rhs: &TensorAny<R, T, B, D>) -> Self::Output {
                $TensorOpAPI::$op_f(self, rhs).rstsr_unwrap()
            }
        }

        // scalar op view: same allocation path as the &tensor case
        #[doc(hidden)]
        impl<T, B, D> $TensorOpAPI<TensorView<'_, T, B, D>> for $ty
        where
            T: From<$ty> + $Op<T, Output = T>,
            D: DimAPI,
            B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
            B: $DeviceOpAPI<T, T, T, D>,
        {
            type Output = Tensor<T, B, D>;
            fn $op_f(a: Self, b: TensorView<'_, T, B, D>) -> Result<Self::Output> {
                let a = T::from(a);
                let device = b.device();
                let lb = b.layout();
                let lc = layout_for_array_copy(lb, TensorIterOrder::default())?;
                let mut storage_c = device.uninit_impl(lc.bounds_index()?.1)?;
                device.op_mutc_numa_refb(storage_c.raw_mut(), &lc, a, b.raw(), lb)?;
                let storage_c = unsafe { B::assume_init_impl(storage_c) }?;
                Tensor::new_f(storage_c, lc)
            }
        }

        // core::ops counterpart: `scalar op view`
        #[doc(hidden)]
        impl<T, B, D> $Op<TensorView<'_, T, B, D>> for $ty
        where
            T: From<$ty> + $Op<T, Output = T>,
            D: DimAPI,
            B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
            B: $DeviceOpAPI<T, T, T, D>,
        {
            type Output = Tensor<T, B, D>;
            fn $op(self, rhs: TensorView<'_, T, B, D>) -> Self::Output {
                $TensorOpAPI::$op_f(self, rhs).rstsr_unwrap()
            }
        }

        // scalar op owned tensor: mutate the tensor's buffer in place
        #[doc(hidden)]
        impl<T, B, D> $TensorOpAPI<Tensor<T, B, D>> for $ty
        where
            T: From<$ty> + $Op<T, Output = T>,
            D: DimAPI,
            B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
            B: $DeviceRConsumeOpAPI<T, T, D>,
        {
            type Output = Tensor<T, B, D>;
            fn $op_f(a: Self, mut b: Tensor<T, B, D>) -> Result<Self::Output> {
                let a = T::from(a);
                let device = b.device().clone();
                let lb = b.layout().clone();
                device.op_muta_numb(b.raw_mut(), &lb, a)?;
                return Ok(b);
            }
        }

        // core::ops counterpart: `scalar op owned tensor`
        #[doc(hidden)]
        impl<T, B, D> $Op<Tensor<T, B, D>> for $ty
        where
            T: From<$ty> + $Op<T, Output = T>,
            D: DimAPI,
            B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
            B: $DeviceRConsumeOpAPI<T, T, D>,
        {
            type Output = Tensor<T, B, D>;
            fn $op(self, rhs: Tensor<T, B, D>) -> Self::Output {
                $TensorOpAPI::$op_f(self, rhs).rstsr_unwrap()
            }
        }

        // scalar op cow: dispatch on ownership, reusing the impls above
        #[doc(hidden)]
        impl<T, B, D> $TensorOpAPI<TensorCow<'_, T, B, D>> for $ty
        where
            T: From<$ty> + $Op<T, Output = T>,
            D: DimAPI,
            B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
            B: $DeviceRConsumeOpAPI<T, T, D> + $DeviceOpAPI<T, T, T, D>,
            // cow constraints
            T: Clone,
            <B as DeviceRawAPI<T>>::Raw: Clone,
            B: OpAssignAPI<T, D>,
        {
            type Output = Tensor<T, B, D>;
            fn $op_f(a: Self, b: TensorCow<'_, T, B, D>) -> Result<Self::Output> {
                match b.is_owned() {
                    true => $TensorOpAPI::$op_f(a, b.into_owned()),
                    false => $TensorOpAPI::$op_f(a, b.view()),
                }
            }
        }

        // core::ops counterpart: `scalar op cow`
        #[doc(hidden)]
        impl<T, B, D> $Op<TensorCow<'_, T, B, D>> for $ty
        where
            T: From<$ty> + $Op<T, Output = T>,
            D: DimAPI,
            B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
            B: $DeviceRConsumeOpAPI<T, T, D> + $DeviceOpAPI<T, T, T, D>,
            // cow constraints
            T: Clone,
            <B as DeviceRawAPI<T>>::Raw: Clone,
            B: OpAssignAPI<T, D>,
        {
            type Output = Tensor<T, B, D>;
            fn $op(self, rhs: TensorCow<'_, T, B, D>) -> Self::Output {
                $TensorOpAPI::$op_f(self, rhs).rstsr_unwrap()
            }
        }
    };
}
810
/// Instantiates every scalar-LHS binary operator (`scalar op tensor`) for the
/// scalar type `$ty`: the arithmetic ops (`+`, `-`, `*`, `/`, `%`), the
/// bitwise ops (`|`, `&`, `^`) and both shifts (`<<`, `>>`).
///
/// Each line expands `impl_arithmetic_scalar_lhs!`, defined earlier in this
/// file, for one operator triple (std-ops trait, device API, tensor API).
macro_rules! impl_arithmetic_scalar_lhs_all {
    ($ty: ty) => {
        impl_arithmetic_scalar_lhs!($ty, add, add_f, Add, DeviceAddAPI, TensorAddAPI, DeviceRConsumeAddAPI);
        impl_arithmetic_scalar_lhs!($ty, sub, sub_f, Sub, DeviceSubAPI, TensorSubAPI, DeviceRConsumeSubAPI);
        impl_arithmetic_scalar_lhs!($ty, mul, mul_f, Mul, DeviceMulAPI, TensorMulAPI, DeviceRConsumeMulAPI);
        impl_arithmetic_scalar_lhs!($ty, div, div_f, Div, DeviceDivAPI, TensorDivAPI, DeviceRConsumeDivAPI);
        impl_arithmetic_scalar_lhs!($ty, rem, rem_f, Rem, DeviceRemAPI, TensorRemAPI, DeviceRConsumeRemAPI);
        impl_arithmetic_scalar_lhs!($ty, bitor, bitor_f, BitOr, DeviceBitOrAPI, TensorBitOrAPI, DeviceRConsumeBitOrAPI);
        impl_arithmetic_scalar_lhs!($ty, bitand, bitand_f, BitAnd, DeviceBitAndAPI, TensorBitAndAPI, DeviceRConsumeBitAndAPI);
        impl_arithmetic_scalar_lhs!($ty, bitxor, bitxor_f, BitXor, DeviceBitXorAPI, TensorBitXorAPI, DeviceRConsumeBitXorAPI);
        impl_arithmetic_scalar_lhs!($ty, shl, shl_f, Shl, DeviceShlAPI, TensorShlAPI, DeviceRConsumeShlAPI);
        impl_arithmetic_scalar_lhs!($ty, shr, shr_f, Shr, DeviceShrAPI, TensorShrAPI, DeviceRConsumeShrAPI);
    };
}
826
/// Instantiates the scalar-LHS operators that make sense for `bool`:
/// only the bitwise ops `|`, `&` and `^`.
macro_rules! impl_arithmetic_scalar_lhs_bool {
    ($ty: ty) => {
        impl_arithmetic_scalar_lhs!($ty, bitor, bitor_f, BitOr, DeviceBitOrAPI, TensorBitOrAPI, DeviceRConsumeBitOrAPI);
        impl_arithmetic_scalar_lhs!($ty, bitand, bitand_f, BitAnd, DeviceBitAndAPI, TensorBitAndAPI, DeviceRConsumeBitAndAPI);
        impl_arithmetic_scalar_lhs!($ty, bitxor, bitxor_f, BitXor, DeviceBitXorAPI, TensorBitXorAPI, DeviceRConsumeBitXorAPI);
    };
}
835
/// Instantiates the scalar-LHS operators provided for floating-point and
/// complex scalar types: `+`, `-`, `*`, `/` (no `%`, no bitwise, no shifts).
macro_rules! impl_arithmetic_scalar_lhs_float {
    ($ty: ty) => {
        impl_arithmetic_scalar_lhs!($ty, add, add_f, Add, DeviceAddAPI, TensorAddAPI, DeviceRConsumeAddAPI);
        impl_arithmetic_scalar_lhs!($ty, sub, sub_f, Sub, DeviceSubAPI, TensorSubAPI, DeviceRConsumeSubAPI);
        impl_arithmetic_scalar_lhs!($ty, mul, mul_f, Mul, DeviceMulAPI, TensorMulAPI, DeviceRConsumeMulAPI);
        impl_arithmetic_scalar_lhs!($ty, div, div_f, Div, DeviceDivAPI, TensorDivAPI, DeviceRConsumeDivAPI);
    };
}
845
846mod impl_arithmetic_scalar_lhs {
847    use super::*;
848    use half::{bf16, f16};
849    use num::complex::Complex;
850    impl_arithmetic_scalar_lhs_all!(i8);
851    impl_arithmetic_scalar_lhs_all!(u8);
852    impl_arithmetic_scalar_lhs_all!(i16);
853    impl_arithmetic_scalar_lhs_all!(u16);
854    impl_arithmetic_scalar_lhs_all!(i32);
855    impl_arithmetic_scalar_lhs_all!(u32);
856    impl_arithmetic_scalar_lhs_all!(i64);
857    impl_arithmetic_scalar_lhs_all!(u64);
858    impl_arithmetic_scalar_lhs_all!(i128);
859    impl_arithmetic_scalar_lhs_all!(u128);
860    impl_arithmetic_scalar_lhs_all!(isize);
861    impl_arithmetic_scalar_lhs_all!(usize);
862
863    impl_arithmetic_scalar_lhs_bool!(bool);
864
865    impl_arithmetic_scalar_lhs_float!(bf16);
866    impl_arithmetic_scalar_lhs_float!(f16);
867    impl_arithmetic_scalar_lhs_float!(f32);
868    impl_arithmetic_scalar_lhs_float!(f64);
869    impl_arithmetic_scalar_lhs_float!(Complex<bf16>);
870    impl_arithmetic_scalar_lhs_float!(Complex<f16>);
871    impl_arithmetic_scalar_lhs_float!(Complex<f32>);
872    impl_arithmetic_scalar_lhs_float!(Complex<f64>);
873}
874
875/* #endregion */
876
877/* #region binary with scalar, tsr a op num b */
878
879// for this case, core::ops::* is not required to be re-implemented
880// see macro_rule `impl_core_ops`
881
// Tensor-scalar binary operation (`tsr a op num b`).
//
// `#[duplicate_item]` stamps this module out once per operator row below; each
// copy provides `TensorOpAPI<TB>` impls for four ownership flavors of the
// left-hand tensor:
// - `&TensorAny`  : borrows `a`, allocates a fresh output tensor;
// - `TensorView`  : same as `&TensorAny`, but the view is taken by value;
// - `Tensor`      : consumes `a` and applies the op in place, reusing its buffer;
// - `TensorCow`   : dispatches to the owned or the view impl at runtime.
#[duplicate_item(
    op_f       Op       DeviceOpAPI       TensorOpAPI       DeviceLConsumeOpAPI     ;
   [add_f   ] [Add   ] [DeviceAddAPI   ] [TensorAddAPI   ] [DeviceLConsumeAddAPI   ];
   [sub_f   ] [Sub   ] [DeviceSubAPI   ] [TensorSubAPI   ] [DeviceLConsumeSubAPI   ];
   [mul_f   ] [Mul   ] [DeviceMulAPI   ] [TensorMulAPI   ] [DeviceLConsumeMulAPI   ];
   [div_f   ] [Div   ] [DeviceDivAPI   ] [TensorDivAPI   ] [DeviceLConsumeDivAPI   ];
   [rem_f   ] [Rem   ] [DeviceRemAPI   ] [TensorRemAPI   ] [DeviceLConsumeRemAPI   ];
   [bitor_f ] [BitOr ] [DeviceBitOrAPI ] [TensorBitOrAPI ] [DeviceLConsumeBitOrAPI ];
   [bitand_f] [BitAnd] [DeviceBitAndAPI] [TensorBitAndAPI] [DeviceLConsumeBitAndAPI];
   [bitxor_f] [BitXor] [DeviceBitXorAPI] [TensorBitXorAPI] [DeviceLConsumeBitXorAPI];
   [shl_f   ] [Shl   ] [DeviceShlAPI   ] [TensorShlAPI   ] [DeviceLConsumeShlAPI   ];
   [shr_f   ] [Shr   ] [DeviceShrAPI   ] [TensorShrAPI   ] [DeviceLConsumeShrAPI   ];
)]
mod impl_arithmetic_scalar_rhs {
    use super::*;

    // Borrowed tensor: compute into a freshly allocated output laid out in the
    // default iteration order.
    #[doc(hidden)]
    impl<T, TB, R, D, B> TensorOpAPI<TB> for &TensorAny<R, T, B, D>
    where
        T: From<TB> + Op<T, Output = T>,
        R: DataAPI<Data = <B as DeviceRawAPI<T>>::Raw>,
        D: DimAPI,
        B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
        B: DeviceOpAPI<T, T, T, D>,
        // this constraint prohibits conflicting impl to TensorBase<RB, D>
        TB: num::Num,
    {
        type Output = Tensor<T, B, D>;
        fn op_f(a: Self, b: TB) -> Result<Self::Output> {
            // convert the scalar once; the device kernel applies it elementwise
            let b = T::from(b);
            let device = a.device();
            let la = a.layout();
            let lc = layout_for_array_copy(la, TensorIterOrder::default())?;
            let mut storage_c = device.uninit_impl(lc.bounds_index()?.1)?;
            device.op_mutc_refa_numb(storage_c.raw_mut(), &lc, a.raw(), la, b)?;
            // assume-init after the kernel ran; relies on `op_mutc_refa_numb`
            // writing every element covered by `lc` — TODO(review) confirm
            let storage_c = unsafe { B::assume_init_impl(storage_c) }?;
            Tensor::new_f(storage_c, lc)
        }
    }

    // By-value view: identical to the `&TensorAny` impl above (the body is
    // duplicated because the receiver type differs).
    #[doc(hidden)]
    impl<T, TB, D, B> TensorOpAPI<TB> for TensorView<'_, T, B, D>
    where
        T: From<TB> + Op<T, Output = T>,
        D: DimAPI,
        B: DeviceAPI<T> + DeviceCreationAnyAPI<T>,
        B: DeviceOpAPI<T, T, T, D>,
        // this constraint prohibits conflicting impl to TensorBase<RB, D>
        TB: num::Num,
    {
        type Output = Tensor<T, B, D>;
        fn op_f(a: Self, b: TB) -> Result<Self::Output> {
            let b = T::from(b);
            let device = a.device();
            let la = a.layout();
            let lc = layout_for_array_copy(la, TensorIterOrder::default())?;
            let mut storage_c = device.uninit_impl(lc.bounds_index()?.1)?;
            device.op_mutc_refa_numb(storage_c.raw_mut(), &lc, a.raw(), la, b)?;
            let storage_c = unsafe { B::assume_init_impl(storage_c) }?;
            Tensor::new_f(storage_c, lc)
        }
    }

    // Owned tensor: mutate `a`'s buffer in place and return `a` itself, so no
    // new allocation happens (the consume path the tests check via pointers).
    #[doc(hidden)]
    impl<T, TB, D, B> TensorOpAPI<TB> for Tensor<T, B, D>
    where
        T: From<TB> + Op<T, Output = T>,
        D: DimAPI,
        B: DeviceAPI<T>,
        B: DeviceLConsumeOpAPI<T, T, D>,
        // this constraint prohibits conflicting impl to TensorBase<RB, D>
        TB: num::Num,
    {
        type Output = Tensor<T, B, D>;
        fn op_f(mut a: Self, b: TB) -> Result<Self::Output> {
            let b = T::from(b);
            // clone device handle and layout so `a.raw_mut()` can be borrowed
            let device = a.device().clone();
            let la = a.layout().clone();
            device.op_muta_numb(a.raw_mut(), &la, b)?;
            return Ok(a);
        }
    }

    // Copy-on-write tensor: reuse the buffer when owned, otherwise fall back
    // to the allocating view path.
    #[doc(hidden)]
    impl<T, TB, D, B> TensorOpAPI<TB> for TensorCow<'_, T, B, D>
    where
        T: From<TB> + Op<T, Output = T>,
        D: DimAPI,
        B: DeviceAPI<T>,
        B: DeviceLConsumeOpAPI<T, T, D> + DeviceOpAPI<T, T, T, D>,
        // this constraint prohibits conflicting impl to TensorBase<RB, D>
        TB: num::Num,
        // cow constraints
        T: Clone,
        <B as DeviceRawAPI<T>>::Raw: Clone,
        B: DeviceCreationAnyAPI<T> + OpAssignAPI<T, D>,
    {
        type Output = Tensor<T, B, D>;
        fn op_f(a: Self, b: TB) -> Result<Self::Output> {
            match a.is_owned() {
                true => TensorOpAPI::op_f(a.into_owned(), b),
                false => TensorOpAPI::op_f(a.view(), b),
            }
        }
    }
}
988
989/* #endregion */
990
991/* #region test */
992
// Tensor-tensor operator tests. Row-major and col-major variants are gated on
// the `col_major` feature; the `*_consume` tests additionally assert (via raw
// pointers) whether the output buffer was reused from a consumed operand.
#[cfg(test)]
mod test {
    use super::*;

    // `+` on references/views: contiguous, broadcast, transposed and
    // negative-stride layouts, with exact expected values.
    #[test]
    #[cfg(not(feature = "col_major"))]
    fn test_add_row_major() {
        // contiguous
        let a = linspace((1.0, 5.0, 5));
        let b = linspace((2.0, 10.0, 5));
        let c = &a + &b;
        let c_ref = vec![3., 6., 9., 12., 15.].into();
        assert!(allclose_f64(&c, &c_ref));

        let a = linspace((1.0, 5.0, 5));
        let b = linspace((2.0, 10.0, 5));
        let c = add(&a, &b);
        let c_ref = vec![3., 6., 9., 12., 15.].into();
        assert!(allclose_f64(&c, &c_ref));

        // broadcast
        // [2, 3] + [3]
        let a = linspace((1.0, 6.0, 6)).into_shape_assume_contig([2, 3]);
        let b = linspace((2.0, 6.0, 3));
        let c = &a + &b;
        let c_ref = vec![3., 6., 9., 6., 9., 12.].into();
        assert!(allclose_f64(&c, &c_ref));

        // broadcast
        // [1, 2, 3] + [5, 1, 2, 1]
        // a = np.linspace(1, 6, 6).reshape(1, 2, 3)
        // b = np.linspace(1, 10, 10).reshape(5, 1, 2, 1)
        let a = linspace((1.0, 6.0, 6));
        let a = a.into_shape_assume_contig([1, 2, 3]);
        let b = linspace((1.0, 10.0, 10));
        let b = b.into_shape_assume_contig([5, 1, 2, 1]);
        let c = &a + &b;
        let c_ref = vec![
            2., 3., 4., 6., 7., 8., 4., 5., 6., 8., 9., 10., 6., 7., 8., 10., 11., 12., 8., 9., 10., 12., 13., 14.,
            10., 11., 12., 14., 15., 16.,
        ];
        let c_ref = c_ref.into();
        assert!(allclose_f64(&c, &c_ref));

        // transposed
        let a = linspace((1.0, 9.0, 9));
        let a = a.into_shape_assume_contig([3, 3]);
        let b = linspace((2.0, 18.0, 9));
        let b = b.into_shape_assume_contig([3, 3]).into_reverse_axes();
        let c = &a + &b;
        let c_ref = vec![3., 10., 17., 8., 15., 22., 13., 20., 27.].into();
        assert!(allclose_f64(&c, &c_ref));

        // negative strides
        let a = linspace((1.0, 5.0, 5));
        let b = linspace((2.0, 10.0, 5));
        let a = a.flip(0);
        let c = &a + &b;
        let c_ref = vec![7., 8., 9., 10., 11.].into();
        assert!(allclose_f64(&c, &c_ref));

        let a = linspace((1.0, 5.0, 5));
        let b = linspace((2.0, 10.0, 5));
        let b = b.flip(0);
        let c = &a + &b;
        let c_ref = vec![11., 10., 9., 8., 7.].into();
        assert!(allclose_f64(&c, &c_ref));

        // view
        let a = linspace((1.0, 5.0, 5));
        let b = linspace((2.0, 10.0, 5));
        let c = a.view() + &b;
        let c_ref = vec![3., 6., 9., 12., 15.].into();
        assert!(allclose_f64(&c, &c_ref));

        let a = linspace((1.0, 5.0, 5));
        let b = linspace((2.0, 10.0, 5));
        let c = &a + b.view();
        let c_ref = vec![3., 6., 9., 12., 15.].into();
        assert!(allclose_f64(&c, &c_ref));
    }

    // Same coverage as `test_add_row_major`, with shapes transposed for
    // column-major layout; results are compared against the raw buffer.
    #[test]
    #[cfg(feature = "col_major")]
    fn test_add_col_major() {
        // contiguous
        let a = linspace((1.0, 5.0, 5));
        let b = linspace((2.0, 10.0, 5));
        let c = &a + &b;
        let c_ref = vec![3., 6., 9., 12., 15.].into();
        assert!(allclose_f64(&c, &c_ref));

        let a = linspace((1.0, 5.0, 5));
        let b = linspace((2.0, 10.0, 5));
        let c = add(&a, &b);
        let c_ref = vec![3., 6., 9., 12., 15.].into();
        assert!(allclose_f64(&c, &c_ref));

        // broadcast
        // [3, 2] + [3]
        let a = linspace((1.0, 6.0, 6)).into_shape_assume_contig([3, 2]);
        let b = linspace((2.0, 6.0, 3));
        let c = &a + &b;
        let c_ref = vec![3., 6., 9., 6., 9., 12.];
        assert!(allclose_f64(&c.raw().into(), &c_ref.into()));

        // broadcast
        // [3, 2, 1] + [1, 2, 1, 5]
        let a = linspace((1.0, 6.0, 6));
        let a = a.into_shape_assume_contig([3, 2, 1]);
        let b = linspace((1.0, 10.0, 10));
        let b = b.into_shape_assume_contig([1, 2, 1, 5]);
        let c = &a + &b;
        let c_ref = vec![
            2., 3., 4., 6., 7., 8., 4., 5., 6., 8., 9., 10., 6., 7., 8., 10., 11., 12., 8., 9., 10., 12., 13., 14.,
            10., 11., 12., 14., 15., 16.,
        ];
        assert!(allclose_f64(&c.raw().into(), &c_ref.into()));

        // transposed
        let a = linspace((1.0, 9.0, 9));
        let a = a.into_shape_assume_contig([3, 3]);
        let b = linspace((2.0, 18.0, 9));
        let b = b.into_shape_assume_contig([3, 3]).into_reverse_axes();
        let c = &a + &b;
        let c_ref = vec![3., 10., 17., 8., 15., 22., 13., 20., 27.];
        assert!(allclose_f64(&c.raw().into(), &c_ref.into()));

        // negative strides
        let a = linspace((1.0, 5.0, 5));
        let b = linspace((2.0, 10.0, 5));
        let a = a.flip(0);
        let c = &a + &b;
        let c_ref = vec![7., 8., 9., 10., 11.];
        assert!(allclose_f64(&c.raw().into(), &c_ref.into()));

        let a = linspace((1.0, 5.0, 5));
        let b = linspace((2.0, 10.0, 5));
        let b = b.flip(0);
        let c = &a + &b;
        let c_ref = vec![11., 10., 9., 8., 7.];
        assert!(allclose_f64(&c.raw().into(), &c_ref.into()));

        // view
        let a = linspace((1.0, 5.0, 5));
        let b = linspace((2.0, 10.0, 5));
        let c = a.view() + &b;
        let c_ref = vec![3., 6., 9., 12., 15.];
        assert!(allclose_f64(&c.raw().into(), &c_ref.into()));

        let a = linspace((1.0, 5.0, 5));
        let b = linspace((2.0, 10.0, 5));
        let c = &a + b.view();
        let c_ref = vec![3., 6., 9., 12., 15.];
        assert!(allclose_f64(&c.raw().into(), &c_ref.into()));
    }

    #[test]
    fn test_sub() {
        // contiguous
        let a = linspace((1.0, 5.0, 5));
        let b = linspace((2.0, 10.0, 5));
        let c = &a - &b;
        let c_ref = vec![-1., -2., -3., -4., -5.].into();
        assert!(allclose_f64(&c, &c_ref));
    }

    #[test]
    fn test_mul() {
        // contiguous
        let a = linspace((1.0, 5.0, 5));
        let b = linspace((2.0, 10.0, 5));
        let c = &a * &b;
        let c_ref = vec![2., 8., 18., 32., 50.].into();
        assert!(allclose_f64(&c, &c_ref));
    }

    // Consuming `+`: when an owned operand's layout can host the broadcast
    // result, its buffer is reused (pointer equality); otherwise a fresh
    // buffer is allocated (pointer inequality).
    #[test]
    #[cfg(not(feature = "col_major"))]
    fn test_add_consume_row_major() {
        // a + &b, same shape
        let a = linspace((1.0, 5.0, 5));
        let b = linspace((2.0, 10.0, 5));
        let a_ptr = a.raw().as_ptr();
        let c = a + &b;
        let c_ptr = c.raw().as_ptr();
        let c_ref = vec![3., 6., 9., 12., 15.].into();
        assert!(allclose_f64(&c, &c_ref));
        assert_eq!(a_ptr, c_ptr);
        // a + &b, broadcastable
        let a = linspace((1.0, 10.0, 10)).into_shape_assume_contig([2, 5]);
        let b = linspace((2.0, 10.0, 5));
        let a_ptr = a.raw().as_ptr();
        let c = a + &b;
        let c_ptr = c.raw().as_ptr();
        let c_ref = vec![3., 6., 9., 12., 15., 8., 11., 14., 17., 20.].into();
        assert!(allclose_f64(&c, &c_ref));
        assert_eq!(a_ptr, c_ptr);
        // a + &b, non-broadcastable
        let a = linspace((2.0, 10.0, 5));
        let b = linspace((1.0, 10.0, 10)).into_shape_assume_contig([2, 5]);
        let a_ptr = a.raw().as_ptr();
        let c = a + &b;
        let c_ptr = c.raw().as_ptr();
        let c_ref = vec![3., 6., 9., 12., 15., 8., 11., 14., 17., 20.].into();
        assert!(allclose_f64(&c, &c_ref));
        assert_ne!(a_ptr, c_ptr);
        // &a + b
        let a = linspace((1.0, 5.0, 5));
        let b = linspace((2.0, 10.0, 5));
        let b_ptr = b.raw().as_ptr();
        let c = &a + b;
        let c_ptr = c.raw().as_ptr();
        let c_ref = vec![3., 6., 9., 12., 15.].into();
        assert!(allclose_f64(&c, &c_ref));
        assert_eq!(b_ptr, c_ptr);
        // &a + b, non-broadcastable
        let a = linspace((1.0, 10.0, 10)).into_shape_assume_contig([2, 5]);
        let b = linspace((2.0, 10.0, 5));
        let b_ptr = b.raw().as_ptr();
        let c = &a + b;
        let c_ptr = c.raw().as_ptr();
        let c_ref = vec![3., 6., 9., 12., 15., 8., 11., 14., 17., 20.].into();
        assert!(allclose_f64(&c, &c_ref));
        assert_ne!(b_ptr, c_ptr);
        // a + b, same shape
        let a = linspace((1.0, 5.0, 5));
        let b = linspace((2.0, 10.0, 5));
        let a_ptr = a.raw().as_ptr();
        let c = a + b;
        let c_ptr = c.raw().as_ptr();
        let c_ref = vec![3., 6., 9., 12., 15.].into();
        assert!(allclose_f64(&c, &c_ref));
        assert_eq!(a_ptr, c_ptr);
    }

    // Column-major twin of `test_add_consume_row_major`.
    #[test]
    #[cfg(feature = "col_major")]
    fn test_add_consume_col_major() {
        // a + &b, same shape
        let a = linspace((1.0, 5.0, 5));
        let b = linspace((2.0, 10.0, 5));
        let a_ptr = a.raw().as_ptr();
        let c = a + &b;
        let c_ptr = c.raw().as_ptr();
        let c_ref = vec![3., 6., 9., 12., 15.];
        assert!(allclose_f64(&c.raw().into(), &c_ref.into()));
        assert_eq!(a_ptr, c_ptr);
        // a + &b, broadcastable
        let a = linspace((1.0, 10.0, 10)).into_shape_assume_contig([5, 2]);
        let b = linspace((2.0, 10.0, 5));
        let a_ptr = a.raw().as_ptr();
        let c = a + &b;
        let c_ptr = c.raw().as_ptr();
        let c_ref = vec![3., 6., 9., 12., 15., 8., 11., 14., 17., 20.];
        assert!(allclose_f64(&c.raw().into(), &c_ref.into()));
        assert_eq!(a_ptr, c_ptr);
        // a + &b, non-broadcastable
        let a = linspace((2.0, 10.0, 5));
        let b = linspace((1.0, 10.0, 10)).into_shape_assume_contig([5, 2]);
        let a_ptr = a.raw().as_ptr();
        let c = a + &b;
        let c_ptr = c.raw().as_ptr();
        let c_ref = vec![3., 6., 9., 12., 15., 8., 11., 14., 17., 20.];
        assert!(allclose_f64(&c.raw().into(), &c_ref.into()));
        assert_ne!(a_ptr, c_ptr);
        // &a + b
        let a = linspace((1.0, 5.0, 5));
        let b = linspace((2.0, 10.0, 5));
        let b_ptr = b.raw().as_ptr();
        let c = &a + b;
        let c_ptr = c.raw().as_ptr();
        let c_ref = vec![3., 6., 9., 12., 15.];
        assert!(allclose_f64(&c.raw().into(), &c_ref.into()));
        assert_eq!(b_ptr, c_ptr);
        // &a + b, non-broadcastable
        let a = linspace((1.0, 10.0, 10)).into_shape_assume_contig([5, 2]);
        let b = linspace((2.0, 10.0, 5));
        let b_ptr = b.raw().as_ptr();
        let c = &a + b;
        let c_ptr = c.raw().as_ptr();
        let c_ref = vec![3., 6., 9., 12., 15., 8., 11., 14., 17., 20.];
        assert!(allclose_f64(&c.raw().into(), &c_ref.into()));
        assert_ne!(b_ptr, c_ptr);
        // a + b, same shape
        let a = linspace((1.0, 5.0, 5));
        let b = linspace((2.0, 10.0, 5));
        let a_ptr = a.raw().as_ptr();
        let c = a + b;
        let c_ptr = c.raw().as_ptr();
        let c_ref = vec![3., 6., 9., 12., 15.];
        assert!(allclose_f64(&c.raw().into(), &c_ref.into()));
        assert_eq!(a_ptr, c_ptr);
    }

    // Consuming `-`: the consumed operand's buffer must be reused, whichever
    // side is consumed.
    #[test]
    fn test_sub_consume() {
        // &a - b
        let a = linspace((1.0, 5.0, 5));
        let b = linspace((2.0, 10.0, 5));
        let b_ptr = b.raw().as_ptr();
        let c = &a - b;
        let c_ptr = c.raw().as_ptr();
        let c_ref = vec![-1., -2., -3., -4., -5.].into();
        assert!(allclose_f64(&c, &c_ref));
        assert_eq!(b_ptr, c_ptr);
        // a - &b
        let a = linspace((1.0, 5.0, 5));
        let b = linspace((2.0, 10.0, 5));
        let a_ptr = a.raw().as_ptr();
        let c = a - b.view();
        let c_ptr = c.raw().as_ptr();
        let c_ref = vec![-1., -2., -3., -4., -5.].into();
        assert!(allclose_f64(&c, &c_ref));
        assert_eq!(a_ptr, c_ptr);
        // a - b
        let a = linspace((1.0, 5.0, 5));
        let b = linspace((2.0, 10.0, 5));
        let a_ptr = a.raw().as_ptr();
        let c = a - b;
        let c_ptr = c.raw().as_ptr();
        let c_ref = vec![-1., -2., -3., -4., -5.].into();
        assert!(allclose_f64(&c, &c_ref));
        assert_eq!(a_ptr, c_ptr);
    }
}
1319
// Smoke test for the `*_with_output` variant that writes into a caller-owned
// mutable view.
#[cfg(test)]
mod test_with_output {
    use super::*;

    // NOTE(review): this test only checks that `add_with_output` runs without
    // panicking and prints the result; it asserts nothing about the values.
    // Consider adding an `allclose_f64` check against the expected sums.
    #[test]
    fn test_op_binary_with_output() {
        #[cfg(not(feature = "col_major"))]
        {
            let a = linspace((1.0, 10.0, 10)).into_shape_assume_contig([2, 5]);
            let b = linspace((2.0, 10.0, 5)).into_layout([5].c());
            let mut c = linspace((1.0, 10.0, 10)).into_shape_assume_contig([2, 5]);
            let c_view = c.view_mut();
            add_with_output(&a, b, c_view);
            println!("{c:?}");
        }
        #[cfg(feature = "col_major")]
        {
            // col-major uses the transposed shape [5, 2]
            let a = linspace((1.0, 10.0, 10)).into_shape_assume_contig([5, 2]);
            let b = linspace((2.0, 10.0, 5)).into_layout([5].c());
            let mut c = linspace((1.0, 10.0, 10)).into_shape_assume_contig([5, 2]);
            let c_view = c.view_mut();
            add_with_output(&a, b, c_view);
            println!("{c:?}");
        }
    }
}
1346
// Tensor-scalar operator tests: scalar on the left (`num op tsr`), scalar on
// the right (`tsr op num`), buffer reuse for consumed tensors, and the
// `TensorCow` dispatch paths.
#[cfg(test)]
mod tests_with_scalar {
    use super::*;

    // NOTE(review): the name `test_add` is misleading — the assertions below
    // exercise `-` (both operand orders) and scalar `*`, not `+`.
    #[test]
    fn test_add() {
        // b - &a
        let a = linspace((1.0, 5.0, 5));
        let b = 1;
        let c = b - &a;
        let c_ref = vec![0., -1., -2., -3., -4.].into();
        assert!(allclose_f64(&c, &c_ref));

        // &a - b
        let a = linspace((1.0, 5.0, 5));
        let b = 1;
        let c = &a - b;
        let c_ref = vec![0., 1., 2., 3., 4.].into();
        assert!(allclose_f64(&c, &c_ref));

        // b * a
        // consuming the tensor must reuse its buffer (pointer equality below)
        let a = linspace((1.0, 5.0, 5));
        let a_ptr = a.raw().as_ptr();
        let b = 2;
        let c: Tensor<_> = -b * a;
        let c_ref = vec![-2., -4., -6., -8., -10.].into();
        assert!(allclose_f64(&c, &c_ref));
        let c_ptr = c.raw().as_ptr();
        assert_eq!(a_ptr, c_ptr);
    }

    // Chained scalar op followed by in-place element update.
    #[test]
    fn test_scalar_consequent() {
        let a = linspace((1.0, 5.0, 5));
        let mut c = linspace((1.0, 5.0, 5));
        // TODO: currently `let b = 2 * a` will give compiler error
        // type must be known at this point
        // I'm not sure why this happens, maybe rust's type inference problem?
        let b = a * 2;
        *&mut c.i_mut(1) += b.i(1);
        println!("{c:?}");
    }

    // `TensorCow` operands: a borrowed cow must produce a new buffer, while
    // an owned cow's buffer is reused by the consuming path.
    #[test]
    fn test_cow() {
        let a = linspace((1.0, 24.0, 24)).into_shape((2, 3, 4));
        // reshape of a contiguous tensor stays a view (not owned) ...
        let a_cow_view = a.reshape((2, 3, 4));
        // ... while reshaping a swapped-axes view forces a copy (owned)
        let a_cow_owned = a.view().into_swapaxes(-1, -2).change_shape((2, 3, 4));
        let ptr_a_cow_owned = a_cow_owned.raw().as_ptr();
        assert!(!a_cow_view.is_owned());
        assert!(a_cow_owned.is_owned());

        let b = a.reshape((2, 3, 4)) + a_cow_view;
        let ptr_b = b.raw().as_ptr();
        println!("{b:?}");
        assert_ne!(ptr_a_cow_owned, ptr_b);

        let b = a.reshape((2, 3, 4)) + a_cow_owned;
        let ptr_b = b.raw().as_ptr();
        println!("{b:?}");
        assert_eq!(ptr_a_cow_owned, ptr_b);

        // cow op scalar path (result only printed, not asserted)
        let b = a.reshape((2, 3, 4)) * 2.0;
        println!("{b:?}");
    }
}
1413
1414/* #endregion */