// rstsr_openblas/rayon_auto_impl/op_binary_common.rs

1use crate::prelude_dev::*;
2use core::ops::Div;
3use num::complex::ComplexFloat;
4use num::{Float, Signed};
5use rstsr_dtype_traits::{DTypeIntoFloatAPI, ExtNum};
6
// TODO: add log1p (`ln(1 + x)`); it is not yet in the unary op table below.

/* #region same type */
10
11#[duplicate_item(
12     DeviceOpAPI           NumTrait       func_inner;
13    [DeviceAcosAPI      ] [ComplexFloat] [b.acos()  ];
14    [DeviceAcoshAPI     ] [ComplexFloat] [b.acosh() ];
15    [DeviceAsinAPI      ] [ComplexFloat] [b.asin()  ];
16    [DeviceAsinhAPI     ] [ComplexFloat] [b.asinh() ];
17    [DeviceAtanAPI      ] [ComplexFloat] [b.atan()  ];
18    [DeviceAtanhAPI     ] [ComplexFloat] [b.atanh() ];
19    [DeviceCeilAPI      ] [Float       ] [b.ceil()  ];
20    [DeviceConjAPI      ] [ComplexFloat] [b.conj()  ];
21    [DeviceCosAPI       ] [ComplexFloat] [b.cos()   ];
22    [DeviceCoshAPI      ] [ComplexFloat] [b.cosh()  ];
23    [DeviceExpAPI       ] [ComplexFloat] [b.exp()   ];
24    [DeviceExpm1API     ] [Float       ] [b.exp_m1()];
25    [DeviceFloorAPI     ] [Float       ] [b.floor() ];
26    [DeviceInvAPI       ] [ComplexFloat] [b.recip() ];
27    [DeviceLogAPI       ] [ComplexFloat] [b.ln()    ];
28    [DeviceLog2API      ] [ComplexFloat] [b.log2()  ];
29    [DeviceLog10API     ] [ComplexFloat] [b.log10() ];
30    [DeviceReciprocalAPI] [ComplexFloat] [b.recip() ];
31    [DeviceRoundAPI     ] [Float       ] [b.round() ];
32    [DeviceSinAPI       ] [ComplexFloat] [b.sin()   ];
33    [DeviceSinhAPI      ] [ComplexFloat] [b.sinh()  ];
34    [DeviceSqrtAPI      ] [ComplexFloat] [b.sqrt()  ];
35    [DeviceTanAPI       ] [ComplexFloat] [b.tan()   ];
36    [DeviceTanhAPI      ] [ComplexFloat] [b.tanh()  ];
37    [DeviceTruncAPI     ] [Float       ] [b.trunc() ];
38)]
39impl<T, D> DeviceOpAPI<T, D> for DeviceRayonAutoImpl
40where
41    T: Clone + Send + Sync + DTypeIntoFloatAPI<FloatType: NumTrait + Send + Sync>,
42    D: DimAPI,
43{
44    type TOut = T::FloatType;
45
46    fn op_muta_refb(
47        &self,
48        a: &mut Vec<MaybeUninit<Self::TOut>>,
49        la: &Layout<D>,
50        b: &Vec<T>,
51        lb: &Layout<D>,
52    ) -> Result<()> {
53        let mut func = |a: &mut MaybeUninit<Self::TOut>, b: &T| {
54            let b = b.clone().into_float();
55            a.write(func_inner);
56        };
57        self.op_muta_refb_func(a, la, b, lb, &mut func)
58    }
59
60    fn op_muta(&self, a: &mut Vec<MaybeUninit<Self::TOut>>, la: &Layout<D>) -> Result<()> {
61        let mut func = |a: &mut MaybeUninit<Self::TOut>| {
62            let b = unsafe { a.assume_init_read() };
63            a.write(func_inner);
64        };
65        self.op_muta_func(a, la, &mut func)
66    }
67}
68
69impl<T, D> DeviceSquareAPI<T, D> for DeviceRayonAutoImpl
70where
71    T: Clone + Send + Sync + Mul<Output = T>,
72    D: DimAPI,
73{
74    type TOut = T;
75
76    fn op_muta_refb(&self, a: &mut Vec<MaybeUninit<T>>, la: &Layout<D>, b: &Vec<T>, lb: &Layout<D>) -> Result<()> {
77        let mut func = |a: &mut MaybeUninit<T>, b: &T| {
78            a.write(b.clone() * b.clone());
79        };
80        self.op_muta_refb_func(a, la, b, lb, &mut func)
81    }
82
83    fn op_muta(&self, a: &mut Vec<MaybeUninit<T>>, la: &Layout<D>) -> Result<()> {
84        let mut func = |a: &mut MaybeUninit<T>| {
85            let b = unsafe { a.assume_init_read() };
86            a.write(b.clone() * b);
87        };
88        self.op_muta_func(a, la, &mut func)
89    }
90}
91
/* #endregion */

/* #region boolean output */
95
96#[duplicate_item(
97     DeviceOpAPI         NumTrait       func                         ;
98    [DeviceSignBitAPI ] [Signed      ] [|a, b| { a.write(b.is_positive()); } ];
99    [DeviceIsFiniteAPI] [ComplexFloat] [|a, b| { a.write(b.is_finite()  ); } ];
100    [DeviceIsInfAPI   ] [ComplexFloat] [|a, b| { a.write(b.is_infinite()); } ];
101    [DeviceIsNanAPI   ] [ComplexFloat] [|a, b| { a.write(b.is_nan()     ); } ];
102)]
103impl<T, D> DeviceOpAPI<T, D> for DeviceRayonAutoImpl
104where
105    T: Clone + NumTrait + Send + Sync,
106    D: DimAPI,
107{
108    type TOut = bool;
109
110    fn op_muta_refb(&self, a: &mut Vec<MaybeUninit<bool>>, la: &Layout<D>, b: &Vec<T>, lb: &Layout<D>) -> Result<()> {
111        self.op_muta_refb_func(a, la, b, lb, &mut func)
112    }
113
114    fn op_muta(&self, _a: &mut Vec<MaybeUninit<bool>>, _la: &Layout<D>) -> Result<()> {
115        let type_b = core::any::type_name::<T>();
116        unreachable!("{:?} is not supported in this function.", type_b);
117    }
118}
119
/* #endregion */

/* #region complex specific implementation */
123
124impl<T, D> DeviceAbsAPI<T, D> for DeviceRayonAutoImpl
125where
126    T: ExtNum + Send + Sync,
127    T::AbsOut: Send + Sync,
128    D: DimAPI,
129{
130    type TOut = T::AbsOut;
131
132    fn op_muta_refb(
133        &self,
134        a: &mut Vec<MaybeUninit<T::AbsOut>>,
135        la: &Layout<D>,
136        b: &Vec<T>,
137        lb: &Layout<D>,
138    ) -> Result<()> {
139        self.op_muta_refb_func(a, la, b, lb, &mut |a, b| {
140            a.write(b.clone().ext_abs());
141        })
142    }
143
144    fn op_muta(&self, a: &mut Vec<MaybeUninit<T::AbsOut>>, la: &Layout<D>) -> Result<()> {
145        if T::ABS_UNCHANGED {
146            return Ok(());
147        } else if T::ABS_SAME_TYPE {
148            return self.op_muta_func(a, la, &mut |a| unsafe {
149                a.write(a.assume_init_read().ext_abs());
150            });
151        } else {
152            let type_b = core::any::type_name::<T>();
153            unreachable!("{:?} is not supported in this function.", type_b);
154        }
155    }
156}
157
158impl<T, D> DeviceImagAPI<T, D> for DeviceRayonAutoImpl
159where
160    T: ExtNum + Send + Sync,
161    T::AbsOut: Send + Sync,
162    D: DimAPI,
163{
164    type TOut = T::AbsOut;
165
166    fn op_muta_refb(
167        &self,
168        a: &mut Vec<MaybeUninit<T::AbsOut>>,
169        la: &Layout<D>,
170        b: &Vec<T>,
171        lb: &Layout<D>,
172    ) -> Result<()> {
173        self.op_muta_refb_func(a, la, b, lb, &mut |a, b| {
174            a.write(b.clone().ext_imag());
175        })
176    }
177
178    fn op_muta(&self, a: &mut Vec<MaybeUninit<T::AbsOut>>, la: &Layout<D>) -> Result<()> {
179        if T::ABS_SAME_TYPE {
180            return self.op_muta_func(a, la, &mut |a| unsafe {
181                a.write(a.assume_init_read().ext_imag());
182            });
183        } else {
184            let type_b = core::any::type_name::<T>();
185            unreachable!("{:?} is not supported in this function.", type_b);
186        }
187    }
188}
189
190impl<T, D> DeviceRealAPI<T, D> for DeviceRayonAutoImpl
191where
192    T: ExtNum + Send + Sync,
193    T::AbsOut: Send + Sync,
194    D: DimAPI,
195{
196    type TOut = T::AbsOut;
197
198    fn op_muta_refb(
199        &self,
200        a: &mut Vec<MaybeUninit<T::AbsOut>>,
201        la: &Layout<D>,
202        b: &Vec<T>,
203        lb: &Layout<D>,
204    ) -> Result<()> {
205        self.op_muta_refb_func(a, la, b, lb, &mut |a, b| {
206            a.write(b.clone().ext_real());
207        })
208    }
209
210    fn op_muta(&self, _a: &mut Vec<MaybeUninit<T::AbsOut>>, _la: &Layout<D>) -> Result<()> {
211        if T::ABS_SAME_TYPE {
212            return Ok(());
213        } else {
214            let type_b = core::any::type_name::<T>();
215            unreachable!("{:?} is not supported in this function.", type_b);
216        }
217    }
218}
219
220impl<T, D> DeviceSignAPI<T, D> for DeviceRayonAutoImpl
221where
222    T: Clone + Send + Sync + ComplexFloat + Div<T::Real, Output = T>,
223    D: DimAPI,
224{
225    type TOut = T;
226
227    fn op_muta_refb(&self, a: &mut Vec<MaybeUninit<T>>, la: &Layout<D>, b: &Vec<T>, lb: &Layout<D>) -> Result<()> {
228        self.op_muta_refb_func(a, la, b, lb, &mut |a, b| {
229            a.write(*b / b.abs());
230        })
231    }
232
233    fn op_muta(&self, a: &mut Vec<MaybeUninit<T>>, la: &Layout<D>) -> Result<()> {
234        self.op_muta_func(a, la, &mut |a| unsafe {
235            a.write(a.assume_init_read() / a.assume_init_read().abs());
236        })
237    }
238}
239
/* #endregion */