1use super::array::Array;
2use super::data::{constant, tile, ConstGenerator};
3use super::defines::AfError;
4use super::dim4::Dim4;
5use super::error::HANDLE_ERROR;
6use super::util::{af_array, HasAfEnum, ImplicitPromote, IntegralType};
7use num::Zero;
8
9use libc::c_int;
10use num::Complex;
11use std::mem;
12use std::ops::Neg;
13use std::ops::{Add, BitAnd, BitOr, BitXor, Div, Mul, Not, Rem, Shl, Shr, Sub};
14
// Foreign function declarations used by the wrappers in this file.
// Each routine receives an out-pointer for the result handle and returns a
// `c_int` status code, which the wrappers convert with `AfError::from` and
// pass to `HANDLE_ERROR`.
// NOTE(review): presumably these symbols come from the ArrayFire C library the
// crate links against — confirm.
extern "C" {
    // Routines taking two array handles plus a `batch` flag.
    fn af_add(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
    fn af_sub(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
    fn af_mul(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
    fn af_div(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;

    fn af_lt(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
    fn af_gt(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
    fn af_le(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
    fn af_ge(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
    fn af_eq(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
    fn af_or(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;

    fn af_neq(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
    fn af_and(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
    fn af_rem(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
    fn af_mod(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;

    fn af_bitand(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
    fn af_bitor(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
    fn af_bitxor(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
    fn af_bitshiftl(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
    fn af_bitshiftr(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
    fn af_minof(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
    fn af_maxof(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
    // The only three-operand routine: an input plus `lo` and `hi` handles.
    fn af_clamp(
        out: *mut af_array,
        inp: af_array,
        lo: af_array,
        hi: af_array,
        batch: bool,
    ) -> c_int;

    // Routines taking a single array handle.
    fn af_not(out: *mut af_array, arr: af_array) -> c_int;
    fn af_abs(out: *mut af_array, arr: af_array) -> c_int;
    fn af_arg(out: *mut af_array, arr: af_array) -> c_int;
    fn af_sign(out: *mut af_array, arr: af_array) -> c_int;
    fn af_ceil(out: *mut af_array, arr: af_array) -> c_int;
    fn af_round(out: *mut af_array, arr: af_array) -> c_int;
    fn af_trunc(out: *mut af_array, arr: af_array) -> c_int;
    fn af_floor(out: *mut af_array, arr: af_array) -> c_int;

    fn af_hypot(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;

    fn af_sin(out: *mut af_array, arr: af_array) -> c_int;
    fn af_cos(out: *mut af_array, arr: af_array) -> c_int;
    fn af_tan(out: *mut af_array, arr: af_array) -> c_int;
    fn af_asin(out: *mut af_array, arr: af_array) -> c_int;
    fn af_acos(out: *mut af_array, arr: af_array) -> c_int;
    fn af_atan(out: *mut af_array, arr: af_array) -> c_int;

    fn af_atan2(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
    fn af_cplx2(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
    fn af_root(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;
    fn af_pow(out: *mut af_array, lhs: af_array, rhs: af_array, batch: bool) -> c_int;

    fn af_cplx(out: *mut af_array, arr: af_array) -> c_int;
    fn af_real(out: *mut af_array, arr: af_array) -> c_int;
    fn af_imag(out: *mut af_array, arr: af_array) -> c_int;
    fn af_conjg(out: *mut af_array, arr: af_array) -> c_int;
    fn af_sinh(out: *mut af_array, arr: af_array) -> c_int;
    fn af_cosh(out: *mut af_array, arr: af_array) -> c_int;
    fn af_tanh(out: *mut af_array, arr: af_array) -> c_int;
    fn af_asinh(out: *mut af_array, arr: af_array) -> c_int;
    fn af_acosh(out: *mut af_array, arr: af_array) -> c_int;
    fn af_atanh(out: *mut af_array, arr: af_array) -> c_int;
    fn af_pow2(out: *mut af_array, arr: af_array) -> c_int;
    fn af_exp(out: *mut af_array, arr: af_array) -> c_int;
    fn af_sigmoid(out: *mut af_array, arr: af_array) -> c_int;
    fn af_expm1(out: *mut af_array, arr: af_array) -> c_int;
    fn af_erf(out: *mut af_array, arr: af_array) -> c_int;
    fn af_erfc(out: *mut af_array, arr: af_array) -> c_int;
    fn af_log(out: *mut af_array, arr: af_array) -> c_int;
    fn af_log1p(out: *mut af_array, arr: af_array) -> c_int;
    fn af_log10(out: *mut af_array, arr: af_array) -> c_int;
    fn af_log2(out: *mut af_array, arr: af_array) -> c_int;
    fn af_sqrt(out: *mut af_array, arr: af_array) -> c_int;
    fn af_rsqrt(out: *mut af_array, arr: af_array) -> c_int;
    fn af_cbrt(out: *mut af_array, arr: af_array) -> c_int;
    fn af_factorial(out: *mut af_array, arr: af_array) -> c_int;
    fn af_tgamma(out: *mut af_array, arr: af_array) -> c_int;
    fn af_lgamma(out: *mut af_array, arr: af_array) -> c_int;
    fn af_iszero(out: *mut af_array, arr: af_array) -> c_int;
    fn af_isinf(out: *mut af_array, arr: af_array) -> c_int;
    fn af_isnan(out: *mut af_array, arr: af_array) -> c_int;
    fn af_bitnot(out: *mut af_array, arr: af_array) -> c_int;
}
102
103impl<'f, T> Not for &'f Array<T>
105where
106 T: HasAfEnum,
107{
108 type Output = Array<T>;
109
110 fn not(self) -> Self::Output {
111 unsafe {
112 let mut temp: af_array = std::ptr::null_mut();
113 let err_val = af_not(&mut temp as *mut af_array, self.get());
114 HANDLE_ERROR(AfError::from(err_val));
115 temp.into()
116 }
117 }
118}
119
// Generates a public single-argument wrapper `$fn_name` around the FFI routine
// `$ffi_fn`.
//
// `$doc_str` becomes the wrapper's documentation and `$out_type` names the
// `HasAfEnum` associated type used as the element type of the returned array.
macro_rules! unary_func {
    ($doc_str:expr, $fn_name:ident, $ffi_fn:ident, $out_type:ident) => {
        #[doc=$doc_str]
        pub fn $fn_name<T: HasAfEnum>(input: &Array<T>) -> Array<T::$out_type>
        where
            T::$out_type: HasAfEnum,
        {
            unsafe {
                // The FFI routine writes the result handle through this pointer.
                let mut result: af_array = std::ptr::null_mut();
                let status = $ffi_fn(&mut result, input.get());
                HANDLE_ERROR(AfError::from(status));
                result.into()
            }
        }
    };
}
136
137unary_func!("Computes absolute value", abs, af_abs, AbsOutType);
138unary_func!("Computes phase value", arg, af_arg, ArgOutType);
139
140unary_func!(
141 "Truncate the values in an Array",
142 trunc,
143 af_trunc,
144 AbsOutType
145);
146unary_func!(
147 "Computes the sign of input Array values",
148 sign,
149 af_sign,
150 AbsOutType
151);
152unary_func!("Round the values in an Array", round, af_round, AbsOutType);
153unary_func!("Floor the values in an Array", floor, af_floor, AbsOutType);
154unary_func!("Ceil the values in an Array", ceil, af_ceil, AbsOutType);
155
156unary_func!("Compute sigmoid function", sigmoid, af_sigmoid, AbsOutType);
157unary_func!(
158 "Compute e raised to the power of value -1",
159 expm1,
160 af_expm1,
161 AbsOutType
162);
163unary_func!("Compute error function value", erf, af_erf, AbsOutType);
164unary_func!(
165 "Compute the complementary error function value",
166 erfc,
167 af_erfc,
168 AbsOutType
169);
170
171unary_func!("Compute logarithm base 10", log10, af_log10, AbsOutType);
172unary_func!(
173 "Compute the logarithm of input Array + 1",
174 log1p,
175 af_log1p,
176 AbsOutType
177);
178unary_func!("Compute logarithm base 2", log2, af_log2, AbsOutType);
179
180unary_func!("Compute the cube root", cbrt, af_cbrt, AbsOutType);
181unary_func!("Compute gamma function", tgamma, af_tgamma, AbsOutType);
182unary_func!(
183 "Compute the logarithm of absolute values of gamma function",
184 lgamma,
185 af_lgamma,
186 AbsOutType
187);
188
189unary_func!("Compute acosh", acosh, af_acosh, UnaryOutType);
190unary_func!("Compute acos", acos, af_acos, UnaryOutType);
191unary_func!("Compute asin", asin, af_asin, UnaryOutType);
192unary_func!("Compute asinh", asinh, af_asinh, UnaryOutType);
193unary_func!("Compute atan", atan, af_atan, UnaryOutType);
194unary_func!("Compute atanh", atanh, af_atanh, UnaryOutType);
195unary_func!("Compute cos", cos, af_cos, UnaryOutType);
196unary_func!("Compute cosh", cosh, af_cosh, UnaryOutType);
197unary_func!(
198 "Compute e raised to the power of value",
199 exp,
200 af_exp,
201 UnaryOutType
202);
203unary_func!("Compute the natural logarithm", log, af_log, UnaryOutType);
204unary_func!("Compute sin", sin, af_sin, UnaryOutType);
205unary_func!("Compute sinh", sinh, af_sinh, UnaryOutType);
206unary_func!("Compute the square root", sqrt, af_sqrt, UnaryOutType);
207unary_func!(
208 "Compute the reciprocal square root",
209 rsqrt,
210 af_rsqrt,
211 UnaryOutType
212);
213unary_func!("Compute tan", tan, af_tan, UnaryOutType);
214unary_func!("Compute tanh", tanh, af_tanh, UnaryOutType);
215
216unary_func!(
217 "Extract real values from a complex Array",
218 real,
219 af_real,
220 AbsOutType
221);
222unary_func!(
223 "Extract imaginary values from a complex Array",
224 imag,
225 af_imag,
226 AbsOutType
227);
228unary_func!(
229 "Create a complex Array from real Array",
230 cplx,
231 af_cplx,
232 ComplexOutType
233);
234unary_func!(
235 "Compute the complex conjugate",
236 conjg,
237 af_conjg,
238 ComplexOutType
239);
240unary_func!(
241 "Compute two raised to the power of value",
242 pow2,
243 af_pow2,
244 UnaryOutType
245);
246unary_func!(
247 "Compute the factorial",
248 factorial,
249 af_factorial,
250 UnaryOutType
251);
252
253macro_rules! unary_boolean_func {
254 [$doc_str: expr, $fn_name: ident, $ffi_fn: ident] => (
255 #[doc=$doc_str]
256 pub fn $fn_name<T: HasAfEnum>(input: &Array<T>) -> Array<bool> {
259 unsafe {
260 let mut temp: af_array = std::ptr::null_mut();
261 let err_val = $ffi_fn(&mut temp as *mut af_array, input.get());
262 HANDLE_ERROR(AfError::from(err_val));
263 temp.into()
264 }
265 }
266 )
267}
268
269unary_boolean_func!("Check if values are zero", iszero, af_iszero);
270unary_boolean_func!("Check if values are infinity", isinf, af_isinf);
271unary_boolean_func!("Check if values are NaN", isnan, af_isnan);
272
// Generates a public two-array wrapper `$fn_name` around the FFI routine
// `$ffi_fn`.
//
// Both operands must be mutually promotable (`ImplicitPromote`); the returned
// array is typed with `A::Output`. `batch` is forwarded to the FFI routine
// unchanged.
macro_rules! binary_func {
    ($doc_str:expr, $fn_name:ident, $ffi_fn:ident) => {
        #[doc=$doc_str]
        pub fn $fn_name<A, B>(lhs: &Array<A>, rhs: &Array<B>, batch: bool) -> Array<A::Output>
        where
            A: ImplicitPromote<B>,
            B: ImplicitPromote<A>,
        {
            unsafe {
                // The FFI routine writes the result handle through this pointer.
                let mut result: af_array = std::ptr::null_mut();
                let status = $ffi_fn(&mut result, lhs.get(), rhs.get(), batch);
                HANDLE_ERROR(AfError::from(status));
                let wrapped: Array<A::Output> = result.into();
                wrapped
            }
        }
    };
}
305
306binary_func!(
307 "Elementwise AND(bit) operation of two Arrays",
308 bitand,
309 af_bitand
310);
311binary_func!(
312 "Elementwise OR(bit) operation of two Arrays",
313 bitor,
314 af_bitor
315);
316binary_func!(
317 "Elementwise XOR(bit) operation of two Arrays",
318 bitxor,
319 af_bitxor
320);
321binary_func!(
322 "Elementwise not equals comparison of two Arrays",
323 neq,
324 af_neq
325);
326binary_func!(
327 "Elementwise logical and operation of two Arrays",
328 and,
329 af_and
330);
331binary_func!("Elementwise logical or operation of two Arrays", or, af_or);
332binary_func!(
333 "Elementwise minimum operation of two Arrays",
334 minof,
335 af_minof
336);
337binary_func!(
338 "Elementwise maximum operation of two Arrays",
339 maxof,
340 af_maxof
341);
342binary_func!(
343 "Compute length of hypotenuse of two Arrays",
344 hypot,
345 af_hypot
346);
347
/// Conversion of an operand into an `Array`.
///
/// The overloaded arithmetic and comparison functions in this module accept
/// any `Convertable` operand and call `convert` on it before handing the
/// resulting arrays to their FFI helper.
pub trait Convertable {
    /// Element type of the array produced by `convert`.
    type OutType: HasAfEnum;

    /// Returns the value as an `Array<Self::OutType>`.
    fn convert(&self) -> Array<Self::OutType>;
}
378
379impl<T> Convertable for T
380where
381 T: Clone + ConstGenerator<OutType = T>,
382{
383 type OutType = T;
384
385 fn convert(&self) -> Array<Self::OutType> {
386 constant(self.clone(), Dim4::new(&[1, 1, 1, 1]))
387 }
388}
389
390impl<T: HasAfEnum> Convertable for Array<T> {
391 type OutType = T;
392
393 fn convert(&self) -> Array<Self::OutType> {
394 self.clone()
395 }
396}
397
// Generates a pair of functions for a two-operand FFI routine `$ffi_name`:
//
// * `$help_name` — private; forwards two array handles and `batch` to the FFI
//   routine and wraps the result as `Array<A::Output>`.
// * `$fn_name` — public, documented with `$doc_str`; accepts any two
//   `Convertable` operands, converts them to arrays and, when exactly one of
//   them reports `is_scalar()`, tiles it to the other operand's dimensions
//   before calling the helper.
macro_rules! overloaded_binary_func {
    ($doc_str:expr, $fn_name:ident, $help_name:ident, $ffi_name:ident) => {
        fn $help_name<A, B>(lhs: &Array<A>, rhs: &Array<B>, batch: bool) -> Array<A::Output>
        where
            A: ImplicitPromote<B>,
            B: ImplicitPromote<A>,
        {
            unsafe {
                // The FFI routine writes the result handle through this pointer.
                let mut result: af_array = std::ptr::null_mut();
                let status = $ffi_name(&mut result, lhs.get(), rhs.get(), batch);
                HANDLE_ERROR(AfError::from(status));
                result.into()
            }
        }

        #[doc=$doc_str]
        pub fn $fn_name<T, U>(
            arg1: &T,
            arg2: &U,
            batch: bool,
        ) -> Array<
            <<T as Convertable>::OutType as ImplicitPromote<<U as Convertable>::OutType>>::Output,
        >
        where
            T: Convertable,
            U: Convertable,
            <T as Convertable>::OutType: ImplicitPromote<<U as Convertable>::OutType>,
            <U as Convertable>::OutType: ImplicitPromote<<T as Convertable>::OutType>,
        {
            let lhs = arg1.convert();
            let rhs = arg2.convert();
            let lhs_is_scalar = lhs.is_scalar();
            let rhs_is_scalar = rhs.is_scalar();

            if lhs_is_scalar && !rhs_is_scalar {
                // Only the left operand is scalar: expand it to the right's shape.
                let expanded = tile(&lhs, rhs.dims());
                $help_name(&expanded, &rhs, batch)
            } else if rhs_is_scalar && !lhs_is_scalar {
                // Only the right operand is scalar: expand it to the left's shape.
                let expanded = tile(&rhs, lhs.dims());
                $help_name(&lhs, &expanded, batch)
            } else {
                // Both scalar or both non-scalar: pass through untouched.
                $help_name(&lhs, &rhs, batch)
            }
        }
    };
}
473
474overloaded_binary_func!("Addition of two Arrays", add, add_helper, af_add);
475overloaded_binary_func!("Subtraction of two Arrays", sub, sub_helper, af_sub);
476overloaded_binary_func!("Multiplication of two Arrays", mul, mul_helper, af_mul);
477overloaded_binary_func!("Division of two Arrays", div, div_helper, af_div);
478overloaded_binary_func!("Compute remainder from two Arrays", rem, rem_helper, af_rem);
479overloaded_binary_func!("Compute left shift", shiftl, shiftl_helper, af_bitshiftl);
480overloaded_binary_func!("Compute right shift", shiftr, shiftr_helper, af_bitshiftr);
481overloaded_binary_func!(
482 "Compute modulo of two Arrays",
483 modulo,
484 modulo_helper,
485 af_mod
486);
487overloaded_binary_func!(
488 "Calculate atan2 of two Arrays",
489 atan2,
490 atan2_helper,
491 af_atan2
492);
493overloaded_binary_func!(
494 "Create complex array from two Arrays",
495 cplx2,
496 cplx2_helper,
497 af_cplx2
498);
499overloaded_binary_func!("Compute root", root, root_helper, af_root);
500overloaded_binary_func!("Computer power", pow, pow_helper, af_pow);
501
// Generates a pair of functions for a two-operand comparison routine
// `$ffi_name`; both return `Array<bool>`:
//
// * `$help_name` — private; forwards two array handles and `batch` to the FFI
//   routine.
// * `$fn_name` — public, documented with `$doc_str`; accepts any two
//   `Convertable` operands, converts them to arrays and, when exactly one of
//   them reports `is_scalar()`, tiles it to the other operand's dimensions
//   before calling the helper.
macro_rules! overloaded_compare_func {
    ($doc_str:expr, $fn_name:ident, $help_name:ident, $ffi_name:ident) => {
        fn $help_name<A, B>(lhs: &Array<A>, rhs: &Array<B>, batch: bool) -> Array<bool>
        where
            A: ImplicitPromote<B>,
            B: ImplicitPromote<A>,
        {
            unsafe {
                // The FFI routine writes the result handle through this pointer.
                let mut result: af_array = std::ptr::null_mut();
                let status = $ffi_name(&mut result, lhs.get(), rhs.get(), batch);
                HANDLE_ERROR(AfError::from(status));
                result.into()
            }
        }

        #[doc=$doc_str]
        pub fn $fn_name<T, U>(arg1: &T, arg2: &U, batch: bool) -> Array<bool>
        where
            T: Convertable,
            U: Convertable,
            <T as Convertable>::OutType: ImplicitPromote<<U as Convertable>::OutType>,
            <U as Convertable>::OutType: ImplicitPromote<<T as Convertable>::OutType>,
        {
            let lhs = arg1.convert();
            let rhs = arg2.convert();
            let lhs_is_scalar = lhs.is_scalar();
            let rhs_is_scalar = rhs.is_scalar();

            if lhs_is_scalar && !rhs_is_scalar {
                // Only the left operand is scalar: expand it to the right's shape.
                let expanded = tile(&lhs, rhs.dims());
                $help_name(&expanded, &rhs, batch)
            } else if rhs_is_scalar && !lhs_is_scalar {
                // Only the right operand is scalar: expand it to the left's shape.
                let expanded = tile(&rhs, lhs.dims());
                $help_name(&lhs, &expanded, batch)
            } else {
                // Both scalar or both non-scalar: pass through untouched.
                $help_name(&lhs, &rhs, batch)
            }
        }
    };
}
575
576overloaded_compare_func!(
577 "Perform `less than` comparison operation",
578 lt,
579 lt_helper,
580 af_lt
581);
582overloaded_compare_func!(
583 "Perform `greater than` comparison operation",
584 gt,
585 gt_helper,
586 af_gt
587);
588overloaded_compare_func!(
589 "Perform `less than equals` comparison operation",
590 le,
591 le_helper,
592 af_le
593);
594overloaded_compare_func!(
595 "Perform `greater than equals` comparison operation",
596 ge,
597 ge_helper,
598 af_ge
599);
600overloaded_compare_func!(
601 "Perform `equals` comparison operation",
602 eq,
603 eq_helper,
604 af_eq
605);
606
607fn clamp_helper<X, Y>(
608 inp: &Array<X>,
609 lo: &Array<Y>,
610 hi: &Array<Y>,
611 batch: bool,
612) -> Array<<X as ImplicitPromote<Y>>::Output>
613where
614 X: ImplicitPromote<Y>,
615 Y: ImplicitPromote<X>,
616{
617 unsafe {
618 let mut temp: af_array = std::ptr::null_mut();
619 let err_val = af_clamp(
620 &mut temp as *mut af_array,
621 inp.get(),
622 lo.get(),
623 hi.get(),
624 batch,
625 );
626 HANDLE_ERROR(AfError::from(err_val));
627 temp.into()
628 }
629}
630
631pub fn clamp<T, C>(
659 input: &Array<T>,
660 arg1: &C,
661 arg2: &C,
662 batch: bool,
663) -> Array<<T as ImplicitPromote<<C as Convertable>::OutType>>::Output>
664where
665 T: ImplicitPromote<<C as Convertable>::OutType>,
666 C: Convertable,
667 <C as Convertable>::OutType: ImplicitPromote<T>,
668{
669 let lo = arg1.convert(); let hi = arg2.convert(); match (lo.is_scalar(), hi.is_scalar()) {
672 (true, false) => {
673 let l = tile(&lo, hi.dims());
674 clamp_helper(&input, &l, &hi, batch)
675 }
676 (false, true) => {
677 let r = tile(&hi, lo.dims());
678 clamp_helper(&input, &lo, &r, batch)
679 }
680 (true, true) => {
681 let l = tile(&lo, input.dims());
682 let r = tile(&hi, input.dims());
683 clamp_helper(&input, &l, &r, batch)
684 }
685 _ => clamp_helper(&input, &lo, &hi, batch),
686 }
687}
688
// Implements the operator trait `$op_name` (method `$fn_name`) for
// `&Array<T> <op> U` and `Array<T> <op> U`, where `U` is a scalar-like type
// accepted by `constant`.
//
// Both impls delegate to the free function `$fn_name` of this module with
// `batch = false`. The scalar is passed by reference directly: it is already
// owned by the method and the free function clones it during conversion, so
// no additional clone is made here. The `Clone` bound stays because the
// blanket `Convertable` impl requires it.
macro_rules! arith_rhs_scalar_func {
    ($op_name:ident, $fn_name: ident) => {
        impl<'f, T, U> $op_name<U> for &'f Array<T>
        where
            T: ImplicitPromote<U>,
            U: ImplicitPromote<T> + Clone + ConstGenerator<OutType = U>,
        {
            type Output = Array<<T as ImplicitPromote<U>>::Output>;

            fn $fn_name(self, rhs: U) -> Self::Output {
                $fn_name(self, &rhs, false)
            }
        }

        impl<T, U> $op_name<U> for Array<T>
        where
            T: ImplicitPromote<U>,
            U: ImplicitPromote<T> + Clone + ConstGenerator<OutType = U>,
        {
            type Output = Array<<T as ImplicitPromote<U>>::Output>;

            fn $fn_name(self, rhs: U) -> Self::Output {
                $fn_name(&self, &rhs, false)
            }
        }
    };
}
720
// Implements the operator trait `$op_name` (method `$fn_name`) with the
// concrete scalar type `$rust_type` on the left-hand side, for both an owned
// and a borrowed array on the right-hand side.
//
// Both impls delegate to the free function `$fn_name` of this module with
// `batch = false`; the result element type is the promotion of `$rust_type`
// with `T`.
macro_rules! arith_lhs_scalar_func {
    ($rust_type:ty, $op_name:ident, $fn_name:ident) => {
        // scalar <op> Array<T>
        impl<T> $op_name<Array<T>> for $rust_type
        where
            T: ImplicitPromote<$rust_type>,
            $rust_type: ImplicitPromote<T>,
        {
            type Output = Array<<$rust_type as ImplicitPromote<T>>::Output>;

            fn $fn_name(self, rhs: Array<T>) -> Self::Output {
                let scalar = &self;
                $fn_name(scalar, &rhs, false)
            }
        }

        // scalar <op> &Array<T>
        impl<'a, T> $op_name<&'a Array<T>> for $rust_type
        where
            T: ImplicitPromote<$rust_type>,
            $rust_type: ImplicitPromote<T>,
        {
            type Output = Array<<$rust_type as ImplicitPromote<T>>::Output>;

            fn $fn_name(self, rhs: &'a Array<T>) -> Self::Output {
                let scalar = &self;
                $fn_name(scalar, rhs, false)
            }
        }
    };
}
750
751arith_rhs_scalar_func!(Add, add);
752arith_rhs_scalar_func!(Sub, sub);
753arith_rhs_scalar_func!(Mul, mul);
754arith_rhs_scalar_func!(Div, div);
755
756macro_rules! arith_scalar_spec {
757 ($ty_name:ty) => {
758 arith_lhs_scalar_func!($ty_name, Add, add);
759 arith_lhs_scalar_func!($ty_name, Sub, sub);
760 arith_lhs_scalar_func!($ty_name, Mul, mul);
761 arith_lhs_scalar_func!($ty_name, Div, div);
762 };
763}
764
765arith_scalar_spec!(Complex<f64>);
766arith_scalar_spec!(Complex<f32>);
767arith_scalar_spec!(f64);
768arith_scalar_spec!(f32);
769arith_scalar_spec!(u64);
770arith_scalar_spec!(i64);
771arith_scalar_spec!(u32);
772arith_scalar_spec!(i32);
773arith_scalar_spec!(u8);
774
775macro_rules! arith_func {
776 ($op_name:ident, $fn_name:ident, $delegate:ident) => {
777 impl<A, B> $op_name<Array<B>> for Array<A>
778 where
779 A: ImplicitPromote<B>,
780 B: ImplicitPromote<A>,
781 {
782 type Output = Array<<A as ImplicitPromote<B>>::Output>;
783
784 fn $fn_name(self, rhs: Array<B>) -> Self::Output {
785 $delegate(&self, &rhs, false)
786 }
787 }
788
789 impl<'a, A, B> $op_name<&'a Array<B>> for Array<A>
790 where
791 A: ImplicitPromote<B>,
792 B: ImplicitPromote<A>,
793 {
794 type Output = Array<<A as ImplicitPromote<B>>::Output>;
795
796 fn $fn_name(self, rhs: &'a Array<B>) -> Self::Output {
797 $delegate(&self, rhs, false)
798 }
799 }
800
801 impl<'a, A, B> $op_name<Array<B>> for &'a Array<A>
802 where
803 A: ImplicitPromote<B>,
804 B: ImplicitPromote<A>,
805 {
806 type Output = Array<<A as ImplicitPromote<B>>::Output>;
807
808 fn $fn_name(self, rhs: Array<B>) -> Self::Output {
809 $delegate(self, &rhs, false)
810 }
811 }
812
813 impl<'a, 'b, A, B> $op_name<&'a Array<B>> for &'b Array<A>
814 where
815 A: ImplicitPromote<B>,
816 B: ImplicitPromote<A>,
817 {
818 type Output = Array<<A as ImplicitPromote<B>>::Output>;
819
820 fn $fn_name(self, rhs: &'a Array<B>) -> Self::Output {
821 $delegate(self, rhs, false)
822 }
823 }
824 };
825}
826
827arith_func!(Add, add, add);
828arith_func!(Sub, sub, sub);
829arith_func!(Mul, mul, mul);
830arith_func!(Div, div, div);
831arith_func!(Rem, rem, rem);
832arith_func!(Shl, shl, shiftl);
833arith_func!(Shr, shr, shiftr);
834arith_func!(BitAnd, bitand, bitand);
835arith_func!(BitOr, bitor, bitor);
836arith_func!(BitXor, bitxor, bitxor);
837
// Implements the shift trait `$trait_name` (method `$op_name`) for an owned or
// borrowed array with a scalar shift amount of type `$rust_type`.
//
// The scalar is turned into a `constant` array with the dimensions of the
// array operand, and the array-by-array operator is then applied.
macro_rules! bitshift_scalar_func {
    ($rust_type:ty, $trait_name:ident, $op_name:ident) => {
        // &Array <shift> scalar
        impl<'a, T> $trait_name<$rust_type> for &'a Array<T>
        where
            T: ImplicitPromote<$rust_type>,
            $rust_type: ImplicitPromote<T>,
        {
            type Output = Array<<T as ImplicitPromote<$rust_type>>::Output>;

            fn $op_name(self, rhs: $rust_type) -> Self::Output {
                let shift_amounts = constant(rhs, self.dims());
                self.$op_name(shift_amounts)
            }
        }

        // Array <shift> scalar
        impl<T> $trait_name<$rust_type> for Array<T>
        where
            T: ImplicitPromote<$rust_type>,
            $rust_type: ImplicitPromote<T>,
        {
            type Output = Array<<T as ImplicitPromote<$rust_type>>::Output>;

            fn $op_name(self, rhs: $rust_type) -> Self::Output {
                // Built before the call because the call consumes `self`.
                let shift_amounts = constant(rhs, self.dims());
                self.$op_name(shift_amounts)
            }
        }
    };
}
866
867macro_rules! shift_spec {
868 ($trait_name: ident, $op_name: ident) => {
869 bitshift_scalar_func!(u64, $trait_name, $op_name);
870 bitshift_scalar_func!(u32, $trait_name, $op_name);
871 bitshift_scalar_func!(u16, $trait_name, $op_name);
872 bitshift_scalar_func!(u8, $trait_name, $op_name);
873 };
874}
875
876shift_spec!(Shl, shl);
877shift_spec!(Shr, shr);
878
// Compound-assignment operators (`+=`, `-=`, `<<=`, `&=`, ...) for `Array`.
// Compiled only when the custom `op_assign` cfg flag is set.
// NOTE(review): presumably the flag is emitted by the crate's build script —
// confirm.
#[cfg(op_assign)]
mod op_assign {

    use super::*;
    use crate::core::{assign_gen, Array, Indexer, Seq};
    use std::ops::{AddAssign, DivAssign, MulAssign, RemAssign, SubAssign};
    use std::ops::{BitAndAssign, BitOrAssign, BitXorAssign, ShlAssign, ShrAssign};

    // Implements `$op_name` (method `$fn_name`) for `Array<A> <op>= Array<B>`.
    //
    // The result of `$func(self, rhs, false)` is cast back to `A` and written
    // into `self` with `assign_gen`, using an `Indexer` that holds the default
    // `Seq` for every dimension of `self`.
    // NOTE(review): presumably the default `Seq` selects the whole axis, making
    // this a full overwrite — confirm.
    macro_rules! arith_assign_func {
        ($op_name:ident, $fn_name:ident, $func: ident) => {
            impl<A, B> $op_name<Array<B>> for Array<A>
            where
                A: ImplicitPromote<B>,
                B: ImplicitPromote<A>,
            {
                fn $fn_name(&mut self, rhs: Array<B>) {
                    let tmp_seq = Seq::<f32>::default();
                    let mut idxrs = Indexer::default();
                    // One index entry per dimension of the destination.
                    for n in 0..self.numdims() {
                        idxrs.set_index(&tmp_seq, n, Some(false));
                    }
                    // Compute with the promoted type, then cast back to `A`.
                    let opres = $func(self as &Array<A>, &rhs, false).cast::<A>();
                    assign_gen(self, &idxrs, &opres);
                }
            }
        };
    }

    arith_assign_func!(AddAssign, add_assign, add);
    arith_assign_func!(SubAssign, sub_assign, sub);
    arith_assign_func!(MulAssign, mul_assign, mul);
    arith_assign_func!(DivAssign, div_assign, div);
    arith_assign_func!(RemAssign, rem_assign, rem);
    arith_assign_func!(ShlAssign, shl_assign, shiftl);
    arith_assign_func!(ShrAssign, shr_assign, shiftr);

    // Implements `$trait_name` (method `$op_name`) for `Array<T> <op>= scalar`
    // with the concrete scalar type `$rust_type`.
    //
    // The bound `Output = T` means the result of `$func` already has the
    // array's own element type, so it is swapped into `self` directly.
    macro_rules! shift_assign_func {
        ($rust_type:ty, $trait_name:ident, $op_name:ident, $func:ident) => {
            impl<T> $trait_name<$rust_type> for Array<T>
            where
                $rust_type: ImplicitPromote<T>,
                T: ImplicitPromote<$rust_type, Output = T>,
            {
                fn $op_name(&mut self, rhs: $rust_type) {
                    let mut temp = $func(self, &rhs, false);
                    // `self` takes the new array; the old one is left in
                    // `temp` and dropped at the end of the method.
                    mem::swap(self, &mut temp);
                }
            }
        };
    }

    // Instantiates `shift_assign_func!` for `u64`, `u32`, `u16` and `u8`.
    macro_rules! shift_assign_spec {
        ($trait_name: ident, $op_name: ident, $func:ident) => {
            shift_assign_func!(u64, $trait_name, $op_name, $func);
            shift_assign_func!(u32, $trait_name, $op_name, $func);
            shift_assign_func!(u16, $trait_name, $op_name, $func);
            shift_assign_func!(u8, $trait_name, $op_name, $func);
        };
    }

    shift_assign_spec!(ShlAssign, shl_assign, shiftl);
    shift_assign_spec!(ShrAssign, shr_assign, shiftr);

    // Same expansion as `arith_assign_func!`, instantiated below for the
    // bitwise assignment operators.
    macro_rules! bit_assign_func {
        ($op_name:ident, $fn_name:ident, $func: ident) => {
            impl<A, B> $op_name<Array<B>> for Array<A>
            where
                A: ImplicitPromote<B>,
                B: ImplicitPromote<A>,
            {
                fn $fn_name(&mut self, rhs: Array<B>) {
                    let tmp_seq = Seq::<f32>::default();
                    let mut idxrs = Indexer::default();
                    // One index entry per dimension of the destination.
                    for n in 0..self.numdims() {
                        idxrs.set_index(&tmp_seq, n, Some(false));
                    }
                    // Compute with the promoted type, then cast back to `A`.
                    let opres = $func(self as &Array<A>, &rhs, false).cast::<A>();
                    assign_gen(self, &idxrs, &opres);
                }
            }
        };
    }

    bit_assign_func!(BitAndAssign, bitand_assign, bitand);
    bit_assign_func!(BitOrAssign, bitor_assign, bitor);
    bit_assign_func!(BitXorAssign, bitxor_assign, bitxor);
}
966
967impl<T> Neg for Array<T>
969where
970 T: Zero + ConstGenerator<OutType = T>,
971{
972 type Output = Array<T>;
973
974 fn neg(self) -> Self::Output {
975 let cnst = constant(T::zero(), self.dims());
976 sub(&cnst, &self, true)
977 }
978}
979
980pub fn bitnot<T: HasAfEnum>(input: &Array<T>) -> Array<T>
982where
983 T: HasAfEnum + IntegralType,
984{
985 unsafe {
986 let mut temp: af_array = std::ptr::null_mut();
987 let err_val = af_bitnot(&mut temp as *mut af_array, input.get());
988 HANDLE_ERROR(AfError::from(err_val));
989 temp.into()
990 }
991}