rstsr_openblas/rayon_auto_impl/
op_ternary_common.rs

1use crate::prelude_dev::*;
2use num::complex::ComplexFloat;
3use num::{pow::Pow, Float};
4use rstsr_dtype_traits::{DTypeIntoFloatAPI, DTypePromoteAPI, ExtFloat, ExtReal};
5
6// output with special promotion
7#[duplicate_item(
8     DeviceOpAPI             TraitT           func_inner;
9    [DeviceATan2API       ] [Float         ] [Float::atan2(a, b)            ];
10    [DeviceCopySignAPI    ] [Float         ] [Float::copysign(a, b)         ];
11    [DeviceHypotAPI       ] [Float         ] [Float::hypot(a, b)            ];
12    [DeviceNextAfterAPI   ] [ExtFloat      ] [ExtFloat::ext_nextafter(a, b) ];
13    [DeviceLogAddExpAPI   ] [ComplexFloat  ] [(a.exp() + b.exp()).ln()      ];
14)]
15impl<TA, TB, D> DeviceOpAPI<TA, TB, D> for DeviceRayonAutoImpl
16where
17    TA: Clone + Send + Sync + DTypePromoteAPI<TB, Res: DTypeIntoFloatAPI<FloatType: TraitT + Send + Sync>>,
18    TB: Clone + Send + Sync,
19    D: DimAPI,
20{
21    type TOut = <TA::Res as DTypeIntoFloatAPI>::FloatType;
22
23    fn op_mutc_refa_refb(
24        &self,
25        c: &mut Vec<MaybeUninit<Self::TOut>>,
26        lc: &Layout<D>,
27        a: &Vec<TA>,
28        la: &Layout<D>,
29        b: &Vec<TB>,
30        lb: &Layout<D>,
31    ) -> Result<()> {
32        let mut func = |c: &mut MaybeUninit<Self::TOut>, a: &TA, b: &TB| {
33            let (a, b) = TA::promote_pair(a.clone(), b.clone());
34            let (a, b) = (a.into_float(), b.into_float());
35            c.write(func_inner);
36        };
37        self.op_mutc_refa_refb_func(c, lc, a, la, b, lb, &mut func)
38    }
39
40    fn op_mutc_refa_numb(
41        &self,
42        c: &mut Vec<MaybeUninit<Self::TOut>>,
43        lc: &Layout<D>,
44        a: &Vec<TA>,
45        la: &Layout<D>,
46        b: TB,
47    ) -> Result<()> {
48        let mut func = |c: &mut MaybeUninit<Self::TOut>, a: &TA, b: &TB| {
49            let (a, b) = TA::promote_pair(a.clone(), b.clone());
50            let (a, b) = (a.into_float(), b.into_float());
51            c.write(func_inner);
52        };
53        self.op_mutc_refa_numb_func(c, lc, a, la, b, &mut func)
54    }
55
56    fn op_mutc_numa_refb(
57        &self,
58        c: &mut Vec<MaybeUninit<Self::TOut>>,
59        lc: &Layout<D>,
60        a: TA,
61        b: &Vec<TB>,
62        lb: &Layout<D>,
63    ) -> Result<()> {
64        let mut func = |c: &mut MaybeUninit<Self::TOut>, a: &TA, b: &TB| {
65            let (a, b) = TA::promote_pair(a.clone(), b.clone());
66            let (a, b) = (a.into_float(), b.into_float());
67            c.write(func_inner);
68        };
69        self.op_mutc_numa_refb_func(c, lc, a, b, lb, &mut func)
70    }
71}
72
73// general promotion
74#[duplicate_item(
75     DeviceOpAPI             TO        TraitT           func_inner;
76    [DeviceMaximumAPI     ] [TA::Res] [ExtReal       ] [ExtReal::ext_max(a, b)         ];
77    [DeviceMinimumAPI     ] [TA::Res] [ExtReal       ] [ExtReal::ext_min(a, b)         ];
78    [DeviceFloorDivideAPI ] [TA::Res] [ExtReal       ] [ExtReal::ext_floor_divide(a, b)];
79    [DeviceEqualAPI       ] [bool   ] [PartialEq     ] [a == b                         ];
80    [DeviceGreaterAPI     ] [bool   ] [PartialOrd    ] [a > b                          ];
81    [DeviceGreaterEqualAPI] [bool   ] [PartialOrd    ] [a >= b                         ];
82    [DeviceLessAPI        ] [bool   ] [PartialOrd    ] [a < b                          ];
83    [DeviceLessEqualAPI   ] [bool   ] [PartialOrd    ] [a <= b                         ];
84)]
85impl<TA, TB, D> DeviceOpAPI<TA, TB, D> for DeviceRayonAutoImpl
86where
87    TA: Clone + Send + Sync + DTypePromoteAPI<TB, Res: TraitT + Send + Sync>,
88    TB: Clone + Send + Sync,
89    D: DimAPI,
90{
91    type TOut = TO;
92
93    fn op_mutc_refa_refb(
94        &self,
95        c: &mut Vec<MaybeUninit<Self::TOut>>,
96        lc: &Layout<D>,
97        a: &Vec<TA>,
98        la: &Layout<D>,
99        b: &Vec<TB>,
100        lb: &Layout<D>,
101    ) -> Result<()> {
102        let mut func = |c: &mut MaybeUninit<Self::TOut>, a: &TA, b: &TB| {
103            let (a, b) = TA::promote_pair(a.clone(), b.clone());
104            c.write(func_inner);
105        };
106        self.op_mutc_refa_refb_func(c, lc, a, la, b, lb, &mut func)
107    }
108
109    fn op_mutc_refa_numb(
110        &self,
111        c: &mut Vec<MaybeUninit<Self::TOut>>,
112        lc: &Layout<D>,
113        a: &Vec<TA>,
114        la: &Layout<D>,
115        b: TB,
116    ) -> Result<()> {
117        let mut func = |c: &mut MaybeUninit<Self::TOut>, a: &TA, b: &TB| {
118            let (a, b) = TA::promote_pair(a.clone(), b.clone());
119            c.write(func_inner);
120        };
121        self.op_mutc_refa_numb_func(c, lc, a, la, b, &mut func)
122    }
123
124    fn op_mutc_numa_refb(
125        &self,
126        c: &mut Vec<MaybeUninit<Self::TOut>>,
127        lc: &Layout<D>,
128        a: TA,
129        b: &Vec<TB>,
130        lb: &Layout<D>,
131    ) -> Result<()> {
132        let mut func = |c: &mut MaybeUninit<Self::TOut>, a: &TA, b: &TB| {
133            let (a, b) = TA::promote_pair(a.clone(), b.clone());
134            c.write(func_inner);
135        };
136        self.op_mutc_numa_refb_func(c, lc, a, b, lb, &mut func)
137    }
138}
139
140// Special case for pow
141impl<TA, TB, D> DevicePowAPI<TA, TB, D> for DeviceRayonAutoImpl
142where
143    TA: Clone + Send + Sync,
144    TB: Clone + Send + Sync,
145    TA: Pow<TB>,
146    TA::Output: Clone + Send + Sync,
147    D: DimAPI,
148{
149    type TOut = TA::Output;
150
151    fn op_mutc_refa_refb(
152        &self,
153        c: &mut Vec<MaybeUninit<Self::TOut>>,
154        lc: &Layout<D>,
155        a: &Vec<TA>,
156        la: &Layout<D>,
157        b: &Vec<TB>,
158        lb: &Layout<D>,
159    ) -> Result<()> {
160        self.op_mutc_refa_refb_func(c, lc, a, la, b, lb, &mut |c, a, b| {
161            c.write(a.clone().pow(b.clone()));
162        })
163    }
164
165    fn op_mutc_refa_numb(
166        &self,
167        c: &mut <Self as DeviceRawAPI<MaybeUninit<Self::TOut>>>::Raw,
168        lc: &Layout<D>,
169        a: &<Self as DeviceRawAPI<TA>>::Raw,
170        la: &Layout<D>,
171        b: TB,
172    ) -> Result<()> {
173        self.op_mutc_refa_numb_func(c, lc, a, la, b, &mut |c, a, b| {
174            c.write(a.clone().pow(b.clone()));
175        })
176    }
177
178    fn op_mutc_numa_refb(
179        &self,
180        c: &mut <Self as DeviceRawAPI<MaybeUninit<Self::TOut>>>::Raw,
181        lc: &Layout<D>,
182        a: TA,
183        b: &<Self as DeviceRawAPI<TB>>::Raw,
184        lb: &Layout<D>,
185    ) -> Result<()> {
186        self.op_mutc_numa_refb_func(c, lc, a, b, lb, &mut |c, a, b| {
187            c.write(a.clone().pow(b.clone()));
188        })
189    }
190}