arrayfire/core/
arith.rs

1use super::array::Array;
2use super::data::{constant, tile, ConstGenerator};
3use super::defines::AfError;
4use super::dim4::Dim4;
5use super::error::HANDLE_ERROR;
6use super::util::{af_array, HasAfEnum, ImplicitPromote, IntegralType};
7use num::Zero;
8
9use libc::c_int;
10use num::Complex;
11use std::mem;
12use std::ops::Neg;
13use std::ops::{Add, BitAnd, BitOr, BitXor, Div, Mul, Not, Rem, Shl, Shr, Sub};
14
15extern "C" {
16    fn af_add(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
17    fn af_sub(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
18    fn af_mul(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
19    fn af_div(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
20
21    fn af_lt(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
22    fn af_gt(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
23    fn af_le(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
24    fn af_ge(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
25    fn af_eq(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
26    fn af_or(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
27
28    fn af_neq(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
29    fn af_and(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
30    fn af_rem(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
31    fn af_mod(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
32
33    fn af_bitand(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
34    fn af_bitor(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
35    fn af_bitxor(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
36    fn af_bitshiftl(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
37    fn af_bitshiftr(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
38    fn af_minof(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
39    fn af_maxof(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
40    fn af_clamp(
41        out: *mut af_array,
42        inp: af_array,
43        lo: af_array,
44        hi: af_array,
45        batch: bool,
46    ) -> c_int;
47
48    fn af_not(out: *mut af_array, arr: af_array) -> c_int;
49    fn af_abs(out: *mut af_array, arr: af_array) -> c_int;
50    fn af_arg(out: *mut af_array, arr: af_array) -> c_int;
51    fn af_sign(out: *mut af_array, arr: af_array) -> c_int;
52    fn af_ceil(out: *mut af_array, arr: af_array) -> c_int;
53    fn af_round(out: *mut af_array, arr: af_array) -> c_int;
54    fn af_trunc(out: *mut af_array, arr: af_array) -> c_int;
55    fn af_floor(out: *mut af_array, arr: af_array) -> c_int;
56
57    fn af_hypot(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
58
59    fn af_sin(out: *mut af_array, arr: af_array) -> c_int;
60    fn af_cos(out: *mut af_array, arr: af_array) -> c_int;
61    fn af_tan(out: *mut af_array, arr: af_array) -> c_int;
62    fn af_asin(out: *mut af_array, arr: af_array) -> c_int;
63    fn af_acos(out: *mut af_array, arr: af_array) -> c_int;
64    fn af_atan(out: *mut af_array, arr: af_array) -> c_int;
65
66    fn af_atan2(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
67    fn af_cplx2(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
68    fn af_root(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
69    fn af_pow(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
70
71    fn af_cplx(out: *mut af_array, arr: af_array) -> c_int;
72    fn af_real(out: *mut af_array, arr: af_array) -> c_int;
73    fn af_imag(out: *mut af_array, arr: af_array) -> c_int;
74    fn af_conjg(out: *mut af_array, arr: af_array) -> c_int;
75    fn af_sinh(out: *mut af_array, arr: af_array) -> c_int;
76    fn af_cosh(out: *mut af_array, arr: af_array) -> c_int;
77    fn af_tanh(out: *mut af_array, arr: af_array) -> c_int;
78    fn af_asinh(out: *mut af_array, arr: af_array) -> c_int;
79    fn af_acosh(out: *mut af_array, arr: af_array) -> c_int;
80    fn af_atanh(out: *mut af_array, arr: af_array) -> c_int;
81    fn af_pow2(out: *mut af_array, arr: af_array) -> c_int;
82    fn af_exp(out: *mut af_array, arr: af_array) -> c_int;
83    fn af_sigmoid(out: *mut af_array, arr: af_array) -> c_int;
84    fn af_expm1(out: *mut af_array, arr: af_array) -> c_int;
85    fn af_erf(out: *mut af_array, arr: af_array) -> c_int;
86    fn af_erfc(out: *mut af_array, arr: af_array) -> c_int;
87    fn af_log(out: *mut af_array, arr: af_array) -> c_int;
88    fn af_log1p(out: *mut af_array, arr: af_array) -> c_int;
89    fn af_log10(out: *mut af_array, arr: af_array) -> c_int;
90    fn af_log2(out: *mut af_array, arr: af_array) -> c_int;
91    fn af_sqrt(out: *mut af_array, arr: af_array) -> c_int;
92    fn af_rsqrt(out: *mut af_array, arr: af_array) -> c_int;
93    fn af_cbrt(out: *mut af_array, arr: af_array) -> c_int;
94    fn af_factorial(out: *mut af_array, arr: af_array) -> c_int;
95    fn af_tgamma(out: *mut af_array, arr: af_array) -> c_int;
96    fn af_lgamma(out: *mut af_array, arr: af_array) -> c_int;
97    fn af_iszero(out: *mut af_array, arr: af_array) -> c_int;
98    fn af_isinf(out: *mut af_array, arr: af_array) -> c_int;
99    fn af_isnan(out: *mut af_array, arr: af_array) -> c_int;
100    fn af_bitnot(out: *mut af_array, arr: af_array) -> c_int;
101}
102
103/// Enables use of `!` on objects of type [Array](./struct.Array.html)
104impl<'f, T> Not for &'f Array<T>
105where
106    T: HasAfEnum,
107{
108    type Output = Array<T>;
109
110    fn not(self) -> Self::Output {
111        unsafe {
112            let mut temp: af_array = std::ptr::null_mut();
113            let err_val = af_not(&mut temp as *mut af_array, self.get());
114            HANDLE_ERROR(AfError::from(err_val));
115            temp.into()
116        }
117    }
118}
119
// Generates a public element-wise unary wrapper named `$fn_name` around the FFI routine
// `$ffi_fn`. `$out_type` names the associated type of `HasAfEnum` that is used as the element
// type of the returned Array, and `$doc_str` becomes the first line of the generated docs.
macro_rules! unary_func {
    ($doc_str:expr, $fn_name:ident, $ffi_fn:ident, $out_type:ident) => {
        #[doc = $doc_str]
        ///
        /// This is an element wise unary operation.
        pub fn $fn_name<T>(input: &Array<T>) -> Array<T::$out_type>
        where
            T: HasAfEnum,
            T::$out_type: HasAfEnum,
        {
            unsafe {
                // Start from a null handle; the FFI call fills it in.
                let mut out: af_array = std::ptr::null_mut();
                let status = $ffi_fn(&mut out, input.get());
                HANDLE_ERROR(AfError::from(status));
                out.into()
            }
        }
    };
}
136
137unary_func!("Computes absolute value", abs, af_abs, AbsOutType);
138unary_func!("Computes phase value", arg, af_arg, ArgOutType);
139
140unary_func!(
141    "Truncate the values in an Array",
142    trunc,
143    af_trunc,
144    AbsOutType
145);
146unary_func!(
147    "Computes the sign of input Array values",
148    sign,
149    af_sign,
150    AbsOutType
151);
152unary_func!("Round the values in an Array", round, af_round, AbsOutType);
153unary_func!("Floor the values in an Array", floor, af_floor, AbsOutType);
154unary_func!("Ceil the values in an Array", ceil, af_ceil, AbsOutType);
155
156unary_func!("Compute sigmoid function", sigmoid, af_sigmoid, AbsOutType);
157unary_func!(
158    "Compute e raised to the power of value -1",
159    expm1,
160    af_expm1,
161    AbsOutType
162);
163unary_func!("Compute error function value", erf, af_erf, AbsOutType);
164unary_func!(
165    "Compute the complementary error function value",
166    erfc,
167    af_erfc,
168    AbsOutType
169);
170
171unary_func!("Compute logarithm base 10", log10, af_log10, AbsOutType);
172unary_func!(
173    "Compute the logarithm of input Array + 1",
174    log1p,
175    af_log1p,
176    AbsOutType
177);
178unary_func!("Compute logarithm base 2", log2, af_log2, AbsOutType);
179
180unary_func!("Compute the cube root", cbrt, af_cbrt, AbsOutType);
181unary_func!("Compute gamma function", tgamma, af_tgamma, AbsOutType);
182unary_func!(
183    "Compute the logarithm of absolute values of gamma function",
184    lgamma,
185    af_lgamma,
186    AbsOutType
187);
188
189unary_func!("Compute acosh", acosh, af_acosh, UnaryOutType);
190unary_func!("Compute acos", acos, af_acos, UnaryOutType);
191unary_func!("Compute asin", asin, af_asin, UnaryOutType);
192unary_func!("Compute asinh", asinh, af_asinh, UnaryOutType);
193unary_func!("Compute atan", atan, af_atan, UnaryOutType);
194unary_func!("Compute atanh", atanh, af_atanh, UnaryOutType);
195unary_func!("Compute cos", cos, af_cos, UnaryOutType);
196unary_func!("Compute cosh", cosh, af_cosh, UnaryOutType);
197unary_func!(
198    "Compute e raised to the power of value",
199    exp,
200    af_exp,
201    UnaryOutType
202);
203unary_func!("Compute the natural logarithm", log, af_log, UnaryOutType);
204unary_func!("Compute sin", sin, af_sin, UnaryOutType);
205unary_func!("Compute sinh", sinh, af_sinh, UnaryOutType);
206unary_func!("Compute the square root", sqrt, af_sqrt, UnaryOutType);
207unary_func!(
208    "Compute the reciprocal square root",
209    rsqrt,
210    af_rsqrt,
211    UnaryOutType
212);
213unary_func!("Compute tan", tan, af_tan, UnaryOutType);
214unary_func!("Compute tanh", tanh, af_tanh, UnaryOutType);
215
216unary_func!(
217    "Extract real values from a complex Array",
218    real,
219    af_real,
220    AbsOutType
221);
222unary_func!(
223    "Extract imaginary values from a complex Array",
224    imag,
225    af_imag,
226    AbsOutType
227);
228unary_func!(
229    "Create a complex Array from real Array",
230    cplx,
231    af_cplx,
232    ComplexOutType
233);
234unary_func!(
235    "Compute the complex conjugate",
236    conjg,
237    af_conjg,
238    ComplexOutType
239);
240unary_func!(
241    "Compute two raised to the power of value",
242    pow2,
243    af_pow2,
244    UnaryOutType
245);
246unary_func!(
247    "Compute the factorial",
248    factorial,
249    af_factorial,
250    UnaryOutType
251);
252
253macro_rules! unary_boolean_func {
254    [$doc_str: expr, $fn_name: ident, $ffi_fn: ident] => (
255        #[doc=$doc_str]
256        ///
257        /// This is an element wise unary operation.
258        pub fn $fn_name<T: HasAfEnum>(input: &Array<T>) -> Array<bool> {
259            unsafe {
260                let mut temp: af_array = std::ptr::null_mut();
261                let err_val = $ffi_fn(&mut temp as *mut af_array, input.get());
262                HANDLE_ERROR(AfError::from(err_val));
263                temp.into()
264            }
265        }
266    )
267}
268
269unary_boolean_func!("Check if values are zero", iszero, af_iszero);
270unary_boolean_func!("Check if values are infinity", isinf, af_isinf);
271unary_boolean_func!("Check if values are NaN", isnan, af_isnan);
272
// Generates a public element-wise binary wrapper named `$fn_name` around the FFI routine
// `$ffi_fn`. Both operands must already be Arrays; the element type of the result is
// `A::Output` as given by `ImplicitPromote`.
macro_rules! binary_func {
    ($doc_str:expr, $fn_name:ident, $ffi_fn:ident) => {
        #[doc = $doc_str]
        ///
        /// This is an element wise binary operation.
        ///
        /// # Important Notes
        ///
        /// - If shape/dimensions of `lhs` and `rhs` are same, the value of `batch` parameter
        ///   has no effect.
        ///
        /// - If shape/dimensions of `lhs` and `rhs` are different, the value of `batch` has
        ///   to be set to `true`. In this case, the shapes of `lhs` and `rhs` have to satisfy the
        ///   following criteria:
        ///   - Same number of elements in `lhs` and `rhs` along a given dimension/axis
        ///   - Only one element in `lhs` or `rhs` along a given dimension/axis
        pub fn $fn_name<A, B>(lhs: &Array<A>, rhs: &Array<B>, batch: bool) -> Array<A::Output>
        where
            A: ImplicitPromote<B>,
            B: ImplicitPromote<A>,
        {
            unsafe {
                // Start from a null handle; the FFI call fills it in.
                let mut out: af_array = std::ptr::null_mut();
                let status = $ffi_fn(&mut out, lhs.get(), rhs.get(), batch);
                HANDLE_ERROR(AfError::from(status));
                let result: Array<A::Output> = out.into();
                result
            }
        }
    };
}
305
306binary_func!(
307    "Elementwise AND(bit) operation of two Arrays",
308    bitand,
309    af_bitand
310);
311binary_func!(
312    "Elementwise OR(bit) operation of two Arrays",
313    bitor,
314    af_bitor
315);
316binary_func!(
317    "Elementwise XOR(bit) operation of two Arrays",
318    bitxor,
319    af_bitxor
320);
321binary_func!(
322    "Elementwise not equals comparison of two Arrays",
323    neq,
324    af_neq
325);
326binary_func!(
327    "Elementwise logical and operation of two Arrays",
328    and,
329    af_and
330);
331binary_func!("Elementwise logical or operation of two Arrays", or, af_or);
332binary_func!(
333    "Elementwise minimum operation of two Arrays",
334    minof,
335    af_minof
336);
337binary_func!(
338    "Elementwise maximum operation of two Arrays",
339    maxof,
340    af_maxof
341);
342binary_func!(
343    "Compute length of hypotenuse of two Arrays",
344    hypot,
345    af_hypot
346);
347
348/// Type Trait to convert to an [Array](./struct.Array.html)
349///
350/// Generic functions that overload the binary operations such as add, div, mul, rem, ge etc. are
351/// bound by this trait to allow combinations of scalar values and Array objects as parameters
352/// to those functions.
353///
354/// Internally, Convertable trait is implemented by following types.
355///
356/// - f32
357/// - f64
358/// - num::Complex\<f32\>
359/// - num::Complex\<f64\>
360/// - bool
361/// - i32
362/// - u32
363/// - u8
364/// - i64
365/// - u64
366/// - i16
367/// - u16
368///
369pub trait Convertable {
370    /// This type alias always points to `Self` which is the
371    /// type of [Array](./struct.Array.html) returned by the
372    /// trait method [convert](./trait.Convertable.html#tymethod.convert).
373    type OutType: HasAfEnum;
374
375    /// Get an Array of implementors type
376    fn convert(&self) -> Array<Self::OutType>;
377}
378
379impl<T> Convertable for T
380where
381    T: Clone + ConstGenerator<OutType = T>,
382{
383    type OutType = T;
384
385    fn convert(&self) -> Array<Self::OutType> {
386        constant(self.clone(), Dim4::new(&[1, 1, 1, 1]))
387    }
388}
389
390impl<T: HasAfEnum> Convertable for Array<T> {
391    type OutType = T;
392
393    fn convert(&self) -> Array<Self::OutType> {
394        self.clone()
395    }
396}
397
// Generates two functions for one binary FFI routine `$ffi_name`:
//
// - a private `$help_name` that takes two Arrays and calls the FFI routine directly, and
// - a public `$fn_name` that accepts any pair of `Convertable` arguments, converts them to
//   Arrays, tiles a scalar operand to the other operand's dimensions and then delegates to
//   the helper.
macro_rules! overloaded_binary_func {
    ($doc_str:expr, $fn_name:ident, $help_name:ident, $ffi_name:ident) => {
        fn $help_name<A, B>(lhs: &Array<A>, rhs: &Array<B>, batch: bool) -> Array<A::Output>
        where
            A: ImplicitPromote<B>,
            B: ImplicitPromote<A>,
        {
            unsafe {
                // Start from a null handle; the FFI call fills it in.
                let mut out: af_array = std::ptr::null_mut();
                let status = $ffi_name(&mut out, lhs.get(), rhs.get(), batch);
                HANDLE_ERROR(AfError::from(status));
                out.into()
            }
        }

        #[doc = $doc_str]
        ///
        /// This is a binary elementwise operation.
        ///
        /// # Parameters
        ///
        /// - `arg1` is an argument that implements an internal trait `Convertable`.
        /// - `arg2` is an argument that implements an internal trait `Convertable`.
        /// - `batch` is a boolean that indicates if the current operation is a batch operation.
        ///
        /// Both parameters `arg1` and `arg2` can be either an Array or a value of rust integral
        /// type.
        ///
        /// # Return Values
        ///
        /// An Array with results of the binary operation.
        ///
        /// # Important Notes
        ///
        /// - If shape/dimensions of `arg1` and `arg2` are same, the value of `batch` parameter
        ///   has no effect.
        ///
        /// - If shape/dimensions of `arg1` and `arg2` are different, the value of `batch` has
        ///   to be set to `true`. In this case, the shapes of `arg1` and `arg2` have to satisfy the
        ///   following criteria:
        ///   - Same number of elements in `arg1` and `arg2` along a given dimension/axis
        ///   - Only one element in `arg1` or `arg2` along a given dimension/axis
        ///
        /// - The trait `Convertable` essentially translates to a scalar native type on rust or Array.
        pub fn $fn_name<T, U>(
            arg1: &T,
            arg2: &U,
            batch: bool,
        ) -> Array<
            <<T as Convertable>::OutType as ImplicitPromote<<U as Convertable>::OutType>>::Output,
        >
        where
            T: Convertable,
            U: Convertable,
            <T as Convertable>::OutType: ImplicitPromote<<U as Convertable>::OutType>,
            <U as Convertable>::OutType: ImplicitPromote<<T as Convertable>::OutType>,
        {
            let lhs = arg1.convert(); // Array of T's element type
            let rhs = arg2.convert(); // Array of U's element type
            let lhs_is_scalar = lhs.is_scalar();
            let rhs_is_scalar = rhs.is_scalar();

            if lhs_is_scalar && !rhs_is_scalar {
                // Expand the scalar left operand to the shape of the right operand.
                let tiled = tile(&lhs, rhs.dims());
                $help_name(&tiled, &rhs, batch)
            } else if !lhs_is_scalar && rhs_is_scalar {
                // Expand the scalar right operand to the shape of the left operand.
                let tiled = tile(&rhs, lhs.dims());
                $help_name(&lhs, &tiled, batch)
            } else {
                // Both scalar or both non-scalar: pass the operands through unchanged.
                $help_name(&lhs, &rhs, batch)
            }
        }
    };
}
473
474overloaded_binary_func!("Addition of two Arrays", add, add_helper, af_add);
475overloaded_binary_func!("Subtraction of two Arrays", sub, sub_helper, af_sub);
476overloaded_binary_func!("Multiplication of two Arrays", mul, mul_helper, af_mul);
477overloaded_binary_func!("Division of two Arrays", div, div_helper, af_div);
478overloaded_binary_func!("Compute remainder from two Arrays", rem, rem_helper, af_rem);
479overloaded_binary_func!("Compute left shift", shiftl, shiftl_helper, af_bitshiftl);
480overloaded_binary_func!("Compute right shift", shiftr, shiftr_helper, af_bitshiftr);
481overloaded_binary_func!(
482    "Compute modulo of two Arrays",
483    modulo,
484    modulo_helper,
485    af_mod
486);
487overloaded_binary_func!(
488    "Calculate atan2 of two Arrays",
489    atan2,
490    atan2_helper,
491    af_atan2
492);
493overloaded_binary_func!(
494    "Create complex array from two Arrays",
495    cplx2,
496    cplx2_helper,
497    af_cplx2
498);
499overloaded_binary_func!("Compute root", root, root_helper, af_root);
500overloaded_binary_func!("Computer power", pow, pow_helper, af_pow);
501
// Generates two functions for one comparison FFI routine `$ffi_name`:
//
// - a private `$help_name` that takes two Arrays and calls the FFI routine directly, and
// - a public `$fn_name` that accepts any pair of `Convertable` arguments, converts them to
//   Arrays, tiles a scalar operand to the other operand's dimensions and then delegates to
//   the helper.
//
// Both generated functions return `Array<bool>`.
macro_rules! overloaded_compare_func {
    ($doc_str:expr, $fn_name:ident, $help_name:ident, $ffi_name:ident) => {
        fn $help_name<A, B>(lhs: &Array<A>, rhs: &Array<B>, batch: bool) -> Array<bool>
        where
            A: ImplicitPromote<B>,
            B: ImplicitPromote<A>,
        {
            unsafe {
                // Start from a null handle; the FFI call fills it in.
                let mut out: af_array = std::ptr::null_mut();
                let status = $ffi_name(&mut out, lhs.get(), rhs.get(), batch);
                HANDLE_ERROR(AfError::from(status));
                out.into()
            }
        }

        #[doc = $doc_str]
        ///
        /// This is a comparison operation.
        ///
        /// # Parameters
        ///
        /// - `arg1` is an argument that implements an internal trait `Convertable`.
        /// - `arg2` is an argument that implements an internal trait `Convertable`.
        /// - `batch` is a boolean that indicates if the current operation is a batch operation.
        ///
        /// Both parameters `arg1` and `arg2` can be either an Array or a value of rust integral
        /// type.
        ///
        /// # Return Values
        ///
        /// An Array with results of the comparison operation a.k.a an Array of boolean values.
        ///
        /// # Important Notes
        ///
        /// - If shape/dimensions of `arg1` and `arg2` are same, the value of `batch` parameter
        ///   has no effect.
        ///
        /// - If shape/dimensions of `arg1` and `arg2` are different, the value of `batch` has
        ///   to be set to `true`. In this case, the shapes of `arg1` and `arg2` have to satisfy the
        ///   following criteria:
        ///   - Same number of elements in `arg1` and `arg2` along a given dimension/axis
        ///   - Only one element in `arg1` or `arg2` along a given dimension/axis
        ///
        /// - The trait `Convertable` essentially translates to a scalar native type on rust or Array.
        pub fn $fn_name<T, U>(arg1: &T, arg2: &U, batch: bool) -> Array<bool>
        where
            T: Convertable,
            U: Convertable,
            <T as Convertable>::OutType: ImplicitPromote<<U as Convertable>::OutType>,
            <U as Convertable>::OutType: ImplicitPromote<<T as Convertable>::OutType>,
        {
            let lhs = arg1.convert(); // Array of T's element type
            let rhs = arg2.convert(); // Array of U's element type
            let lhs_is_scalar = lhs.is_scalar();
            let rhs_is_scalar = rhs.is_scalar();

            if lhs_is_scalar && !rhs_is_scalar {
                // Expand the scalar left operand to the shape of the right operand.
                let tiled = tile(&lhs, rhs.dims());
                $help_name(&tiled, &rhs, batch)
            } else if !lhs_is_scalar && rhs_is_scalar {
                // Expand the scalar right operand to the shape of the left operand.
                let tiled = tile(&rhs, lhs.dims());
                $help_name(&lhs, &tiled, batch)
            } else {
                // Both scalar or both non-scalar: pass the operands through unchanged.
                $help_name(&lhs, &rhs, batch)
            }
        }
    };
}
575
576overloaded_compare_func!(
577    "Perform `less than` comparison operation",
578    lt,
579    lt_helper,
580    af_lt
581);
582overloaded_compare_func!(
583    "Perform `greater than` comparison operation",
584    gt,
585    gt_helper,
586    af_gt
587);
588overloaded_compare_func!(
589    "Perform `less than equals` comparison operation",
590    le,
591    le_helper,
592    af_le
593);
594overloaded_compare_func!(
595    "Perform `greater than equals` comparison operation",
596    ge,
597    ge_helper,
598    af_ge
599);
600overloaded_compare_func!(
601    "Perform `equals` comparison operation",
602    eq,
603    eq_helper,
604    af_eq
605);
606
607fn clamp_helper<X, Y>(
608    inp: &Array<X>,
609    lo: &Array<Y>,
610    hi: &Array<Y>,
611    batch: bool,
612) -> Array<<X as ImplicitPromote<Y>>::Output>
613where
614    X: ImplicitPromote<Y>,
615    Y: ImplicitPromote<X>,
616{
617    unsafe {
618        let mut temp: af_array = std::ptr::null_mut();
619        let err_val = af_clamp(
620            &mut temp as *mut af_array,
621            inp.get(),
622            lo.get(),
623            hi.get(),
624            batch,
625        );
626        HANDLE_ERROR(AfError::from(err_val));
627        temp.into()
628    }
629}
630
631/// Clamp the values of Array
632///
633/// # Parameters
634///
635/// - `arg1`is an argument that implements an internal trait `Convertable`.
636/// - `arg2`is an argument that implements an internal trait `Convertable`.
637/// - `batch` is an boolean that indicates if the current operation is an batch operation.
638///
639/// Both parameters `arg1` and `arg2` can be either an Array or a value of rust integral
640/// type.
641///
642/// # Return Values
643///
644/// An Array with results of the binary operation.
645///
646/// # Important Notes
647///
648/// - If shape/dimensions of `arg1` and `arg2` are same, the value of `batch` parameter
649///   has no effect.
650///
651/// - If shape/dimensions of `arg1` and `arg2` are different, the value of `batch` has
652///   to be set to `true`. In this case, the shapes of `arg1` and `arg2` have to satisfy the
653///   following criteria:
654///   - Same number of elements in `arg1` and `arg2` along a given dimension/axis
655///   - Only one element in `arg1` or `arg2` along a given dimension/axis
656///
657/// - The trait `Convertable` essentially translates to a scalar native type on rust or Array.
658pub fn clamp<T, C>(
659    input: &Array<T>,
660    arg1: &C,
661    arg2: &C,
662    batch: bool,
663) -> Array<<T as ImplicitPromote<<C as Convertable>::OutType>>::Output>
664where
665    T: ImplicitPromote<<C as Convertable>::OutType>,
666    C: Convertable,
667    <C as Convertable>::OutType: ImplicitPromote<T>,
668{
669    let lo = arg1.convert(); // Convert to Array<T>
670    let hi = arg2.convert(); // Convert to Array<T>
671    match (lo.is_scalar(), hi.is_scalar()) {
672        (true, false) => {
673            let l = tile(&lo, hi.dims());
674            clamp_helper(&input, &l, &hi, batch)
675        }
676        (false, true) => {
677            let r = tile(&hi, lo.dims());
678            clamp_helper(&input, &lo, &r, batch)
679        }
680        (true, true) => {
681            let l = tile(&lo, input.dims());
682            let r = tile(&hi, input.dims());
683            clamp_helper(&input, &l, &r, batch)
684        }
685        _ => clamp_helper(&input, &lo, &hi, batch),
686    }
687}
688
// Implements the operator trait `$op_name` (method `$fn_name`) for an Array on the left and a
// scalar on the right, for both `&Array<T>` and `Array<T>` receivers. Each impl delegates to
// the free function `$fn_name` of this module with `batch` set to `false`.
macro_rules! arith_rhs_scalar_func {
    ($op_name:ident, $fn_name: ident) => {
        // Implement (&Array<T> op_name rust_type)
        impl<'f, T, U> $op_name<U> for &'f Array<T>
        where
            T: ImplicitPromote<U>,
            // `Clone + ConstGenerator` is what makes `U` usable as a `Convertable` scalar.
            U: ImplicitPromote<T> + Clone + ConstGenerator<OutType = U>,
        {
            type Output = Array<<T as ImplicitPromote<U>>::Output>;

            fn $fn_name(self, rhs: U) -> Self::Output {
                // `rhs` is owned here, so it is borrowed directly; no extra clone is needed.
                $fn_name(self, &rhs, false)
            }
        }

        // Implement (Array<T> op_name rust_type)
        impl<T, U> $op_name<U> for Array<T>
        where
            T: ImplicitPromote<U>,
            // `Clone + ConstGenerator` is what makes `U` usable as a `Convertable` scalar.
            U: ImplicitPromote<T> + Clone + ConstGenerator<OutType = U>,
        {
            type Output = Array<<T as ImplicitPromote<U>>::Output>;

            fn $fn_name(self, rhs: U) -> Self::Output {
                // `rhs` is owned here, so it is borrowed directly; no extra clone is needed.
                $fn_name(&self, &rhs, false)
            }
        }
    };
}
720
// Implements the operator trait `$op_name` (method `$fn_name`) for the concrete scalar type
// `$rust_type` on the left and an Array on the right, for both `&Array<T>` and `Array<T>`
// operands. Each impl delegates to the free function `$fn_name` of this module with `batch`
// set to `false`.
macro_rules! arith_lhs_scalar_func {
    ($rust_type:ty, $op_name:ident, $fn_name:ident) => {
        // scalar `op` &Array<T>
        impl<'a, T> $op_name<&'a Array<T>> for $rust_type
        where
            T: ImplicitPromote<$rust_type>,
            $rust_type: ImplicitPromote<T>,
        {
            type Output = Array<<$rust_type as ImplicitPromote<T>>::Output>;

            fn $fn_name(self, rhs: &'a Array<T>) -> Self::Output {
                let scalar = &self;
                $fn_name(scalar, rhs, false)
            }
        }

        // scalar `op` Array<T>
        impl<T> $op_name<Array<T>> for $rust_type
        where
            T: ImplicitPromote<$rust_type>,
            $rust_type: ImplicitPromote<T>,
        {
            type Output = Array<<$rust_type as ImplicitPromote<T>>::Output>;

            fn $fn_name(self, rhs: Array<T>) -> Self::Output {
                let scalar = &self;
                $fn_name(scalar, &rhs, false)
            }
        }
    };
}
750
751arith_rhs_scalar_func!(Add, add);
752arith_rhs_scalar_func!(Sub, sub);
753arith_rhs_scalar_func!(Mul, mul);
754arith_rhs_scalar_func!(Div, div);
755
756macro_rules! arith_scalar_spec {
757    ($ty_name:ty) => {
758        arith_lhs_scalar_func!($ty_name, Add, add);
759        arith_lhs_scalar_func!($ty_name, Sub, sub);
760        arith_lhs_scalar_func!($ty_name, Mul, mul);
761        arith_lhs_scalar_func!($ty_name, Div, div);
762    };
763}
764
765arith_scalar_spec!(Complex<f64>);
766arith_scalar_spec!(Complex<f32>);
767arith_scalar_spec!(f64);
768arith_scalar_spec!(f32);
769arith_scalar_spec!(u64);
770arith_scalar_spec!(i64);
771arith_scalar_spec!(u32);
772arith_scalar_spec!(i32);
773arith_scalar_spec!(u8);
774
// Implements the operator trait `$op_name` (method `$fn_name`) for every owned/borrowed
// combination of two Arrays. Each impl forwards to the free function `$delegate` with `batch`
// set to `false`.
macro_rules! arith_func {
    ($op_name:ident, $fn_name:ident, $delegate:ident) => {
        // Array<A> `op` Array<B>
        impl<A, B> $op_name<Array<B>> for Array<A>
        where
            A: ImplicitPromote<B>,
            B: ImplicitPromote<A>,
        {
            type Output = Array<A::Output>;

            fn $fn_name(self, rhs: Array<B>) -> Self::Output {
                $delegate(&self, &rhs, false)
            }
        }

        // Array<A> `op` &Array<B>
        impl<'r, A, B> $op_name<&'r Array<B>> for Array<A>
        where
            A: ImplicitPromote<B>,
            B: ImplicitPromote<A>,
        {
            type Output = Array<A::Output>;

            fn $fn_name(self, rhs: &'r Array<B>) -> Self::Output {
                $delegate(&self, rhs, false)
            }
        }

        // &Array<A> `op` Array<B>
        impl<'l, A, B> $op_name<Array<B>> for &'l Array<A>
        where
            A: ImplicitPromote<B>,
            B: ImplicitPromote<A>,
        {
            type Output = Array<A::Output>;

            fn $fn_name(self, rhs: Array<B>) -> Self::Output {
                $delegate(self, &rhs, false)
            }
        }

        // &Array<A> `op` &Array<B>
        impl<'r, 'l, A, B> $op_name<&'r Array<B>> for &'l Array<A>
        where
            A: ImplicitPromote<B>,
            B: ImplicitPromote<A>,
        {
            type Output = Array<A::Output>;

            fn $fn_name(self, rhs: &'r Array<B>) -> Self::Output {
                $delegate(self, rhs, false)
            }
        }
    };
}
826
827arith_func!(Add, add, add);
828arith_func!(Sub, sub, sub);
829arith_func!(Mul, mul, mul);
830arith_func!(Div, div, div);
831arith_func!(Rem, rem, rem);
832arith_func!(Shl, shl, shiftl);
833arith_func!(Shr, shr, shiftr);
834arith_func!(BitAnd, bitand, bitand);
835arith_func!(BitOr, bitor, bitor);
836arith_func!(BitXor, bitxor, bitxor);
837
// Implements the shift trait `$trait_name` (method `$op_name`) for an Array shifted by a
// scalar of type `$rust_type`, for both owned and borrowed Arrays. The scalar is expanded with
// `constant` to the dimensions of the Array, then the Array-by-Array operator is applied.
macro_rules! bitshift_scalar_func {
    ($rust_type:ty, $trait_name:ident, $op_name:ident) => {
        // Array<T> `shift` scalar
        impl<T> $trait_name<$rust_type> for Array<T>
        where
            T: ImplicitPromote<$rust_type>,
            $rust_type: ImplicitPromote<T>,
        {
            type Output = Array<<T as ImplicitPromote<$rust_type>>::Output>;

            fn $op_name(self, rhs: $rust_type) -> Self::Output {
                let shift_amounts = constant(rhs, self.dims());
                self.$op_name(shift_amounts)
            }
        }

        // &Array<T> `shift` scalar
        impl<'a, T> $trait_name<$rust_type> for &'a Array<T>
        where
            T: ImplicitPromote<$rust_type>,
            $rust_type: ImplicitPromote<T>,
        {
            type Output = Array<<T as ImplicitPromote<$rust_type>>::Output>;

            fn $op_name(self, rhs: $rust_type) -> Self::Output {
                let shift_amounts = constant(rhs, self.dims());
                self.$op_name(shift_amounts)
            }
        }
    };
}
866
867macro_rules! shift_spec {
868    ($trait_name: ident, $op_name: ident) => {
869        bitshift_scalar_func!(u64, $trait_name, $op_name);
870        bitshift_scalar_func!(u32, $trait_name, $op_name);
871        bitshift_scalar_func!(u16, $trait_name, $op_name);
872        bitshift_scalar_func!(u8, $trait_name, $op_name);
873    };
874}
875
876shift_spec!(Shl, shl);
877shift_spec!(Shr, shr);
878
#[cfg(op_assign)]
mod op_assign {
    // Compound-assignment operator implementations (`+=`, `<<=`, `&=`, ...) for
    // `Array`. The whole module is compiled only when the `op_assign` cfg flag
    // is set.

    use super::*;
    use crate::core::{assign_gen, Array, Indexer, Seq};
    use std::ops::{AddAssign, DivAssign, MulAssign, RemAssign, SubAssign};
    use std::ops::{BitAndAssign, BitOrAssign, BitXorAssign, ShlAssign, ShrAssign};

    // Implements the compound-assignment trait `$op_name` (method `$fn_name`)
    // for `Array<A>` with an `Array<B>` right-hand side.
    //
    // The generated method evaluates `$func(self, &rhs, false)`, casts the
    // result to `A`, and writes it into `self` through `assign_gen`. The
    // indexer holds a default `Seq` for every dimension of `self`; presumably
    // that default selects the full extent of each dimension — confirm against
    // `Seq::default`.
    //
    // A single macro serves the arithmetic, shift and bitwise operators alike,
    // because the generated code is the same for all of them.
    macro_rules! array_assign_func {
        ($op_name:ident, $fn_name:ident, $func: ident) => {
            impl<A, B> $op_name<Array<B>> for Array<A>
            where
                A: ImplicitPromote<B>,
                B: ImplicitPromote<A>,
            {
                fn $fn_name(&mut self, rhs: Array<B>) {
                    let tmp_seq = Seq::<f32>::default();
                    let mut idxrs = Indexer::default();
                    for n in 0..self.numdims() {
                        idxrs.set_index(&tmp_seq, n, Some(false));
                    }
                    // Cast to `A` so the value assigned into `self` has the
                    // element type of the destination array.
                    let opres = $func(self as &Array<A>, &rhs, false).cast::<A>();
                    assign_gen(self, &idxrs, &opres);
                }
            }
        };
    }

    // Arithmetic operators.
    array_assign_func!(AddAssign, add_assign, add);
    array_assign_func!(SubAssign, sub_assign, sub);
    array_assign_func!(MulAssign, mul_assign, mul);
    array_assign_func!(DivAssign, div_assign, div);
    array_assign_func!(RemAssign, rem_assign, rem);

    // Shift operators with an array right-hand side.
    array_assign_func!(ShlAssign, shl_assign, shiftl);
    array_assign_func!(ShrAssign, shr_assign, shiftr);

    // Bitwise operators.
    array_assign_func!(BitAndAssign, bitand_assign, bitand);
    array_assign_func!(BitOrAssign, bitor_assign, bitor);
    array_assign_func!(BitXorAssign, bitxor_assign, bitxor);

    // Implements the shift-assignment trait `$trait_name` (method `$op_name`)
    // for `Array<T>` with a scalar right-hand side of type `$rust_type`.
    //
    // The bound `Output = T` means the result of `$func` is already an
    // `Array<T>`, so it is swapped into `*self` directly; the previous array
    // ends up in `temp` and is dropped at the end of the method.
    macro_rules! shift_assign_func {
        ($rust_type:ty, $trait_name:ident, $op_name:ident, $func:ident) => {
            impl<T> $trait_name<$rust_type> for Array<T>
            where
                $rust_type: ImplicitPromote<T>,
                T: ImplicitPromote<$rust_type, Output = T>,
            {
                fn $op_name(&mut self, rhs: $rust_type) {
                    let mut temp = $func(self, &rhs, false);
                    mem::swap(self, &mut temp);
                }
            }
        };
    }

    // Expands `shift_assign_func!` for each unsigned integer width accepted as
    // a scalar shift amount.
    macro_rules! shift_assign_spec {
        ($trait_name: ident, $op_name: ident, $func:ident) => {
            shift_assign_func!(u64, $trait_name, $op_name, $func);
            shift_assign_func!(u32, $trait_name, $op_name, $func);
            shift_assign_func!(u16, $trait_name, $op_name, $func);
            shift_assign_func!(u8, $trait_name, $op_name, $func);
        };
    }

    shift_assign_spec!(ShlAssign, shl_assign, shiftl);
    shift_assign_spec!(ShrAssign, shr_assign, shiftr);
}
966
967///Implement negation trait for Array
968impl<T> Neg for Array<T>
969where
970    T: Zero + ConstGenerator<OutType = T>,
971{
972    type Output = Array<T>;
973
974    fn neg(self) -> Self::Output {
975        let cnst = constant(T::zero(), self.dims());
976        sub(&cnst, &self, true)
977    }
978}
979
980/// Perform bitwise complement on all values of Array
981pub fn bitnot<T: HasAfEnum>(input: &Array<T>) -> Array<T>
982where
983    T: HasAfEnum + IntegralType,
984{
985    unsafe {
986        let mut temp: af_array = std::ptr::null_mut();
987        let err_val = af_bitnot(&mut temp as *mut af_array, input.get());
988        HANDLE_ERROR(AfError::from(err_val));
989        temp.into()
990    }
991}