arrayfire/core/
data.rs

1use super::array::Array;
2use super::defines::{AfError, BorderType};
3use super::dim4::Dim4;
4use super::error::HANDLE_ERROR;
5use super::util::{af_array, c32, c64, dim_t, u64_t, HasAfEnum};
6
7use libc::{c_double, c_int, c_uint};
8use std::option::Option;
9use std::vec::Vec;
10
// Foreign declarations of the native routines wrapped by this module.
// Each routine reports a `c_int` status code; the wrappers in this file convert
// it with `AfError::from` and hand the result to `HANDLE_ERROR`.
extern "C" {
    // Constant fill from a double value; `afdtype` carries the numeric
    // element-type code.
    fn af_constant(
        out: *mut af_array,
        val: c_double,
        ndims: c_uint,
        dims: *const dim_t,
        afdtype: c_uint,
    ) -> c_int;

    // Constant fill taking separate real and imaginary parts.
    fn af_constant_complex(
        out: *mut af_array,
        real: c_double,
        imag: c_double,
        ndims: c_uint,
        dims: *const dim_t,
        afdtype: c_uint,
    ) -> c_int;

    // Constant fill used by the `i64` generator (no dtype argument).
    fn af_constant_long(out: *mut af_array, val: dim_t, ndims: c_uint, dims: *const dim_t)
        -> c_int;

    // Constant fill used by the `u64` generator (no dtype argument).
    fn af_constant_ulong(
        out: *mut af_array,
        val: u64_t,
        ndims: c_uint,
        dims: *const dim_t,
    ) -> c_int;

    // Backs `range`.
    fn af_range(
        out: *mut af_array,
        ndims: c_uint,
        dims: *const dim_t,
        seq_dim: c_int,
        afdtype: c_uint,
    ) -> c_int;

    // Backs `iota`; takes both the sequence dimensions and the tile dimensions.
    fn af_iota(
        out: *mut af_array,
        ndims: c_uint,
        dims: *const dim_t,
        t_ndims: c_uint,
        tdims: *const dim_t,
        afdtype: c_uint,
    ) -> c_int;

    // Identity, diagonal and join routines.
    fn af_identity(out: *mut af_array, ndims: c_uint, dims: *const dim_t, afdtype: c_uint)
        -> c_int;
    fn af_diag_create(out: *mut af_array, arr: af_array, num: c_int) -> c_int;
    fn af_diag_extract(out: *mut af_array, arr: af_array, num: c_int) -> c_int;
    fn af_join(out: *mut af_array, dim: c_int, first: af_array, second: af_array) -> c_int;
    fn af_join_many(
        out: *mut af_array,
        dim: c_int,
        n_arrays: c_uint,
        inpts: *const af_array,
    ) -> c_int;

    // Data movement routines: tile, reorder, shift, reshape.
    fn af_tile(
        out: *mut af_array,
        arr: af_array,
        x: c_uint,
        y: c_uint,
        z: c_uint,
        w: c_uint,
    ) -> c_int;
    fn af_reorder(
        o: *mut af_array,
        a: af_array,
        x: c_uint,
        y: c_uint,
        z: c_uint,
        w: c_uint,
    ) -> c_int;
    fn af_shift(o: *mut af_array, a: af_array, x: c_int, y: c_int, z: c_int, w: c_int) -> c_int;
    fn af_moddims(out: *mut af_array, arr: af_array, ndims: c_uint, dims: *const dim_t) -> c_int;

    // Flatten, flip and triangular extraction.
    fn af_flat(out: *mut af_array, arr: af_array) -> c_int;
    fn af_flip(out: *mut af_array, arr: af_array, dim: c_uint) -> c_int;
    fn af_lower(out: *mut af_array, arr: af_array, is_unit_diag: bool) -> c_int;
    fn af_upper(out: *mut af_array, arr: af_array, is_unit_diag: bool) -> c_int;

    // Element-wise selection; the `_scalar_l` / `_scalar_r` variants take a
    // double in place of the left / right operand.
    fn af_select(out: *mut af_array, cond: af_array, a: af_array, b: af_array) -> c_int;
    fn af_select_scalar_l(out: *mut af_array, cond: af_array, a: c_double, b: af_array) -> c_int;
    fn af_select_scalar_r(out: *mut af_array, cond: af_array, a: af_array, b: c_double) -> c_int;

    // In-place replacement.
    // NOTE(review): the wrappers below pass `a.get()` cast to `*mut af_array`
    // for the first parameter of these two routines, i.e. the handle value
    // itself rather than the address of a handle. Confirm the declared
    // parameter type against the native signature.
    fn af_replace(a: *mut af_array, cond: af_array, b: af_array) -> c_int;
    fn af_replace_scalar(a: *mut af_array, cond: af_array, b: c_double) -> c_int;

    // Border padding; begin and end sizes are passed as separate
    // (count, pointer) pairs.
    fn af_pad(
        out: *mut af_array,
        input: af_array,
        begin_ndims: c_uint,
        begin_dims: *const dim_t,
        end_ndims: c_uint,
        end_dims: *const dim_t,
        pad_fill_type: c_uint,
    ) -> c_int;
}
109
/// Type trait for scalar values that can fill a constant
/// [Array](./struct.Array.html) of a given size
///
/// Within this module, ConstGenerator is implemented for the following types.
///
/// - f32
/// - f64
/// - num::Complex\<f32\>
/// - num::Complex\<f64\>
/// - bool
/// - i32
/// - u32
/// - u8
/// - i64
/// - u64
/// - i16
/// - u16
///
pub trait ConstGenerator: HasAfEnum {
    /// The element type of the Array<T> object returned by the `generate` function
    type OutType: HasAfEnum;

    /// Create an Array of `dims` size from scalar value `self`.
    ///
    /// # Parameters
    ///
    /// - `dims` are the dimensions of the output constant [Array](./struct.Array.html)
    ///
    /// # Return Values
    ///
    /// An [Array](./struct.Array.html) with element type `Self::OutType`
    fn generate(&self, dims: Dim4) -> Array<Self::OutType>;
}
138
139impl ConstGenerator for i64 {
140    type OutType = i64;
141
142    fn generate(&self, dims: Dim4) -> Array<Self::OutType> {
143        unsafe {
144            let mut temp: af_array = std::ptr::null_mut();
145            let err_val = af_constant_long(
146                &mut temp as *mut af_array,
147                *self,
148                dims.ndims() as c_uint,
149                dims.get().as_ptr() as *const dim_t,
150            );
151            HANDLE_ERROR(AfError::from(err_val));
152            temp.into()
153        }
154    }
155}
156
157impl ConstGenerator for u64 {
158    type OutType = u64;
159
160    fn generate(&self, dims: Dim4) -> Array<Self::OutType> {
161        unsafe {
162            let mut temp: af_array = std::ptr::null_mut();
163            let err_val = af_constant_ulong(
164                &mut temp as *mut af_array,
165                *self,
166                dims.ndims() as c_uint,
167                dims.get().as_ptr() as *const dim_t,
168            );
169            HANDLE_ERROR(AfError::from(err_val));
170            temp.into()
171        }
172    }
173}
174
175impl ConstGenerator for c32 {
176    type OutType = c32;
177
178    fn generate(&self, dims: Dim4) -> Array<Self::OutType> {
179        unsafe {
180            let mut temp: af_array = std::ptr::null_mut();
181            let err_val = af_constant_complex(
182                &mut temp as *mut af_array,
183                (*self).re as c_double,
184                (*self).im as c_double,
185                dims.ndims() as c_uint,
186                dims.get().as_ptr() as *const dim_t,
187                1,
188            );
189            HANDLE_ERROR(AfError::from(err_val));
190            temp.into()
191        }
192    }
193}
194
195impl ConstGenerator for c64 {
196    type OutType = c64;
197
198    fn generate(&self, dims: Dim4) -> Array<Self::OutType> {
199        unsafe {
200            let mut temp: af_array = std::ptr::null_mut();
201            let err_val = af_constant_complex(
202                &mut temp as *mut af_array,
203                (*self).re as c_double,
204                (*self).im as c_double,
205                dims.ndims() as c_uint,
206                dims.get().as_ptr() as *const dim_t,
207                3,
208            );
209            HANDLE_ERROR(AfError::from(err_val));
210            temp.into()
211        }
212    }
213}
214
215impl ConstGenerator for bool {
216    type OutType = bool;
217
218    fn generate(&self, dims: Dim4) -> Array<Self::OutType> {
219        unsafe {
220            let mut temp: af_array = std::ptr::null_mut();
221            let err_val = af_constant(
222                &mut temp as *mut af_array,
223                *self as c_int as c_double,
224                dims.ndims() as c_uint,
225                dims.get().as_ptr() as *const dim_t,
226                4,
227            );
228            HANDLE_ERROR(AfError::from(err_val));
229            temp.into()
230        }
231    }
232}
233
234macro_rules! cnst {
235    ($rust_type:ty, $ffi_type:expr) => {
236        impl ConstGenerator for $rust_type {
237            type OutType = $rust_type;
238
239            fn generate(&self, dims: Dim4) -> Array<Self::OutType> {
240                unsafe {
241                    let mut temp: af_array = std::ptr::null_mut();
242                    let err_val = af_constant(
243                        &mut temp as *mut af_array,
244                        *self as c_double,
245                        dims.ndims() as c_uint,
246                        dims.get().as_ptr() as *const dim_t,
247                        $ffi_type,
248                    );
249                    HANDLE_ERROR(AfError::from(err_val));
250                    temp.into()
251                }
252            }
253        }
254    };
255}
256
257cnst!(f32, 0);
258cnst!(f64, 2);
259cnst!(i32, 5);
260cnst!(u32, 6);
261cnst!(u8, 7);
262cnst!(i16, 10);
263cnst!(u16, 11);
264
265/// Create an Array with constant value
266///
267/// The trait ConstGenerator has been defined internally for the following types:
268///
269/// - i64
270/// - u64
271/// - num::Complex\<f32\> a.k.a c32
272/// - num::Complex\<f64\> a.k.a c64
273/// - f32
274/// - f64
275/// - i32
276/// - u32
277/// - u8
278/// - i16
279/// - u16
280///
281/// # Parameters
282///
283/// - `cnst` is the constant value to be filled in the Array
284/// - `dims` is the size of the constant Array
285///
286/// # Return Values
287///
288/// An Array of given dimensions with constant value
289pub fn constant<T>(cnst: T, dims: Dim4) -> Array<T>
290where
291    T: ConstGenerator<OutType = T>,
292{
293    cnst.generate(dims)
294}
295
296/// Create a Range of values
297///
298/// Creates an array with [0, n] values along the `seq_dim` which is tiled across other dimensions.
299///
300/// # Parameters
301///
302/// - `dims` is the size of Array
303/// - `seq_dim` is the dimension along which range values are populated, all values along other
304/// dimensions are just repeated
305///
306/// # Return Values
307/// Array
308pub fn range<T: HasAfEnum>(dims: Dim4, seq_dim: i32) -> Array<T> {
309    let aftype = T::get_af_dtype();
310    unsafe {
311        let mut temp: af_array = std::ptr::null_mut();
312        let err_val = af_range(
313            &mut temp as *mut af_array,
314            dims.ndims() as c_uint,
315            dims.get().as_ptr() as *const dim_t,
316            seq_dim as c_int,
317            aftype as c_uint,
318        );
319        HANDLE_ERROR(AfError::from(err_val));
320        temp.into()
321    }
322}
323
324/// Create a range of values
325///
326/// Create an sequence [0, dims.elements() - 1] and modify to specified dimensions dims and then tile it according to tile_dims.
327///
328/// # Parameters
329///
330/// - `dims` is the dimensions of the sequence to be generated
331/// - `tdims` is the number of repitions of the unit dimensions
332///
333/// # Return Values
334///
335/// Array
336pub fn iota<T: HasAfEnum>(dims: Dim4, tdims: Dim4) -> Array<T> {
337    let aftype = T::get_af_dtype();
338    unsafe {
339        let mut temp: af_array = std::ptr::null_mut();
340        let err_val = af_iota(
341            &mut temp as *mut af_array,
342            dims.ndims() as c_uint,
343            dims.get().as_ptr() as *const dim_t,
344            tdims.ndims() as c_uint,
345            tdims.get().as_ptr() as *const dim_t,
346            aftype as c_uint,
347        );
348        HANDLE_ERROR(AfError::from(err_val));
349        temp.into()
350    }
351}
352
353/// Create an identity array with 1's in diagonal
354///
355/// # Parameters
356///
357/// - `dims` is the output Array dimensions
358///
359/// # Return Values
360///
361/// Identity matrix
362pub fn identity<T: HasAfEnum>(dims: Dim4) -> Array<T> {
363    let aftype = T::get_af_dtype();
364    unsafe {
365        let mut temp: af_array = std::ptr::null_mut();
366        let err_val = af_identity(
367            &mut temp as *mut af_array,
368            dims.ndims() as c_uint,
369            dims.get().as_ptr() as *const dim_t,
370            aftype as c_uint,
371        );
372        HANDLE_ERROR(AfError::from(err_val));
373        temp.into()
374    }
375}
376
377/// Create a diagonal matrix
378///
379/// # Parameters
380///
381/// - `input` is the input Array
382/// - `dim` is the diagonal index relative to principal diagonal where values from input Array are
383/// to be placed
384///
385/// # Return Values
386///
387/// An Array with values as a diagonal Matrix
388pub fn diag_create<T>(input: &Array<T>, dim: i32) -> Array<T>
389where
390    T: HasAfEnum,
391{
392    unsafe {
393        let mut temp: af_array = std::ptr::null_mut();
394        let err_val = af_diag_create(&mut temp as *mut af_array, input.get(), dim);
395        HANDLE_ERROR(AfError::from(err_val));
396        temp.into()
397    }
398}
399
400/// Extract diagonal from a given Matrix
401///
402/// # Parameters
403///
404/// - `input` is the input Matrix
405/// - `dim` is the index of the diagonal that has to be extracted from the input Matrix
406///
407/// # Return Values
408///
409/// An Array with values of the diagonal from input Array
410pub fn diag_extract<T>(input: &Array<T>, dim: i32) -> Array<T>
411where
412    T: HasAfEnum,
413{
414    unsafe {
415        let mut temp: af_array = std::ptr::null_mut();
416        let err_val = af_diag_extract(&mut temp as *mut af_array, input.get(), dim);
417        HANDLE_ERROR(AfError::from(err_val));
418        temp.into()
419    }
420}
421
422/// Join two arrays
423///
424/// # Parameters
425///
426/// - `dim` is the dimension along which the concatenation has to be done
427/// - `first` is the Array that comes first in the concatenation
428/// - `second` is the Array that comes last in the concatenation
429///
430/// # Return Values
431///
432/// Concatenated Array
433pub fn join<T>(dim: i32, first: &Array<T>, second: &Array<T>) -> Array<T>
434where
435    T: HasAfEnum,
436{
437    unsafe {
438        let mut temp: af_array = std::ptr::null_mut();
439        let err_val = af_join(&mut temp as *mut af_array, dim, first.get(), second.get());
440        HANDLE_ERROR(AfError::from(err_val));
441        temp.into()
442    }
443}
444
445/// Join multiple arrays
446///
447/// # Parameters
448///
449/// - `dim` is the dimension along which the concatenation has to be done
450/// - `inputs` is the vector of Arrays that has to be concatenated
451///
452/// # Return Values
453///
454/// Concatenated Array
455pub fn join_many<T>(dim: i32, inputs: Vec<&Array<T>>) -> Array<T>
456where
457    T: HasAfEnum,
458{
459    unsafe {
460        let mut v = Vec::new();
461        for i in inputs {
462            v.push(i.get());
463        }
464        let mut temp: af_array = std::ptr::null_mut();
465        let err_val = af_join_many(
466            &mut temp as *mut af_array,
467            dim,
468            v.len() as u32,
469            v.as_ptr() as *const af_array,
470        );
471        HANDLE_ERROR(AfError::from(err_val));
472        temp.into()
473    }
474}
475
476/// Tile the input array along specified dimension
477///
478/// Tile essentially creates copies of data along each dimension.
479/// The number of copies created is provided by the user on per
480/// axis basis using [Dim4](./struct.dim4.html)
481///
482///# Parameters
483///
484/// - `input` is the input Array
485/// - `dims` is the target(output) dimensions
486///
487///# Return Values
488///
489/// Tiled input array as per the tiling dimensions provided
490pub fn tile<T>(input: &Array<T>, dims: Dim4) -> Array<T>
491where
492    T: HasAfEnum,
493{
494    unsafe {
495        let mut temp: af_array = std::ptr::null_mut();
496        let err_val = af_tile(
497            &mut temp as *mut af_array,
498            input.get() as af_array,
499            dims[0] as c_uint,
500            dims[1] as c_uint,
501            dims[2] as c_uint,
502            dims[3] as c_uint,
503        );
504        HANDLE_ERROR(AfError::from(err_val));
505        temp.into()
506    }
507}
508
509/// Reorder the array according to the new specified axes
510///
511/// Exchanges data within an array such that the requested change in axes is
512/// satisfied. The linear ordering of data within the array is preserved.
513///
514/// The default order of axes in ArrayFire is [0 1 2 3] i.e. axis with smallest
515/// distance between adjacent elements followed by next smallest distance axis and
516/// so on. See [examples](#examples) to have a basic idea of how data is re-ordered.
517///
518///# Parameters
519///
520/// - `input` is the input Array
521/// - `new_axis0` is the new first axis for output
522/// - `new_axis1` is the new second axis for output
523/// - `next_axes` is the new axes order for output
524///
525///# Return Values
526///
527/// Array with data reordered as per the new axes order
528///
529///# Examples
530///
531/// ```rust
532/// use arrayfire::{Array, Dim4, print, randu, reorder_v2};
533/// let a  = randu::<f32>(Dim4::new(&[5, 3, 1, 1]));
534/// let b  = reorder_v2(&a, 1, 0, None);
535/// print(&a);
536///
537/// // [5 3 1 1]
538/// //  0.8104     0.2990     0.3014
539/// //  0.6913     0.2802     0.6938
540/// //  0.7821     0.1480     0.3513
541/// //  0.3054     0.1330     0.7176
542/// //  0.1673     0.4696     0.1181
543///
544/// print(&b);
545/// // [3 5 1 1]
546/// //     0.8104     0.6913     0.7821     0.3054     0.1673
547/// //     0.2990     0.2802     0.1480     0.1330     0.4696
548/// //     0.3014     0.6938     0.3513     0.7176     0.1181
549///
550/// let c  = reorder_v2(&a, 2, 0, Some(vec![1]));
551/// print(&c);
552///
553/// // [1 5 3 1]
554/// //  0.8104     0.6913     0.7821     0.3054     0.1673
555/// //
556/// //  0.2990     0.2802     0.1480     0.1330     0.4696
557/// //
558/// //  0.3014     0.6938     0.3513     0.7176     0.1181
559/// ```
560pub fn reorder_v2<T>(
561    input: &Array<T>,
562    new_axis0: u64,
563    new_axis1: u64,
564    next_axes: Option<Vec<u64>>,
565) -> Array<T>
566where
567    T: HasAfEnum,
568{
569    let mut new_axes = [0, 1, 2, 3];
570    new_axes[0] = new_axis0;
571    new_axes[1] = new_axis1;
572    match next_axes {
573        Some(left_over_new_axes) => {
574            // At the moment of writing this comment, ArrayFire could
575            // handle only a maximum of 4 dimensions. Hence, excluding
576            // the two explicit axes arguments to this function, a maximum
577            // of only two more axes can be provided. Hence the below condition.
578            assert!(left_over_new_axes.len() <= 2);
579
580            for a_idx in 0..left_over_new_axes.len() {
581                new_axes[2 + a_idx] = left_over_new_axes[a_idx];
582            }
583        }
584        None => {
585            for a_idx in 2..4 {
586                new_axes[a_idx] = a_idx as u64;
587            }
588        }
589    };
590
591    unsafe {
592        let mut temp: af_array = std::ptr::null_mut();
593        let err_val = af_reorder(
594            &mut temp as *mut af_array,
595            input.get() as af_array,
596            new_axes[0] as c_uint,
597            new_axes[1] as c_uint,
598            new_axes[2] as c_uint,
599            new_axes[3] as c_uint,
600        );
601        HANDLE_ERROR(AfError::from(err_val));
602        temp.into()
603    }
604}
605
606/// Reorder the array in specified order
607///
608/// The default order of axes in ArrayFire is axis with smallest distance
609/// between adjacent elements towards an axis with highest distance between
610/// adjacent elements.
611///
612///# Parameters
613///
614/// - `input` is the input Array
615/// - `dims` is the target(output) dimensions
616///
617///# Return Values
618///
619/// Array with data reordered as per the new axes order
620#[deprecated(since = "3.6.3", note = "Please use new reorder API")]
621pub fn reorder<T>(input: &Array<T>, dims: Dim4) -> Array<T>
622where
623    T: HasAfEnum,
624{
625    reorder_v2(input, dims[0], dims[1], Some(vec![dims[2], dims[3]]))
626}
627
628///"Circular shift of values along specified dimension
629///
630///# Parameters
631///
632/// - `input` is the input Array
633/// - `offsets` is 4-value tuple that specifies the shift along respective dimension
634///
635///# Return Values
636///
637/// An Array with shifted data.
638///
639///# Examples
640///
641/// ```rust
642/// use arrayfire::{Array, Dim4, print, randu, shift};
643/// let a  = randu::<f32>(Dim4::new(&[5, 1, 1, 1]));
644/// let _a = shift(&a, &[-1i32, 1 , 1, 1]); //shift data one step backward
645/// let a_ = shift(&a, &[ 1i32, 1 , 1, 1]); //shift data one step forward
646/// print(& a);
647/// print(&_a);
648/// print(&a_);
649/// ```
650pub fn shift<T>(input: &Array<T>, offsets: &[i32; 4]) -> Array<T>
651where
652    T: HasAfEnum,
653{
654    unsafe {
655        let mut temp: af_array = std::ptr::null_mut();
656        let err_val = af_shift(
657            &mut temp as *mut af_array,
658            input.get(),
659            offsets[0],
660            offsets[1],
661            offsets[2],
662            offsets[3],
663        );
664        HANDLE_ERROR(AfError::from(err_val));
665        temp.into()
666    }
667}
668
669/// Change the shape of the Array
670///
671/// # Parameters
672///
673/// - `input` is the input Array
674/// - `dims` is the new dimensions to which the input Array is reshaped to
675///
676/// # Return Values
677/// Reshaped Array
678pub fn moddims<T>(input: &Array<T>, dims: Dim4) -> Array<T>
679where
680    T: HasAfEnum,
681{
682    unsafe {
683        let mut temp: af_array = std::ptr::null_mut();
684        let err_val = af_moddims(
685            &mut temp as *mut af_array,
686            input.get(),
687            dims.ndims() as c_uint,
688            dims.get().as_ptr() as *const dim_t,
689        );
690        HANDLE_ERROR(AfError::from(err_val));
691        temp.into()
692    }
693}
694
695/// Flatten the multidimensional Array to an 1D Array
696pub fn flat<T>(input: &Array<T>) -> Array<T>
697where
698    T: HasAfEnum,
699{
700    unsafe {
701        let mut temp: af_array = std::ptr::null_mut();
702        let err_val = af_flat(&mut temp as *mut af_array, input.get());
703        HANDLE_ERROR(AfError::from(err_val));
704        temp.into()
705    }
706}
707
708/// Flip the Array
709///
710/// # Parameters
711///
712/// - `input` is the Array to be flipped
713/// - `dim` is the dimension along which the flip has to happen
714///
715/// # Return Values
716///
717/// Flipped Array
718pub fn flip<T>(input: &Array<T>, dim: u32) -> Array<T>
719where
720    T: HasAfEnum,
721{
722    unsafe {
723        let mut temp: af_array = std::ptr::null_mut();
724        let err_val = af_flip(&mut temp as *mut af_array, input.get(), dim);
725        HANDLE_ERROR(AfError::from(err_val));
726        temp.into()
727    }
728}
729
730/// Create lower triangular matrix
731///
732/// # Parameters
733///
734/// - `input` is the input Array
735/// - `is_unit_diag` dictates if the output Array should have 1's along diagonal
736///
737/// # Return Values
738/// Array
739pub fn lower<T>(input: &Array<T>, is_unit_diag: bool) -> Array<T>
740where
741    T: HasAfEnum,
742{
743    unsafe {
744        let mut temp: af_array = std::ptr::null_mut();
745        let err_val = af_lower(&mut temp as *mut af_array, input.get(), is_unit_diag);
746        HANDLE_ERROR(AfError::from(err_val));
747        temp.into()
748    }
749}
750
751/// Create upper triangular matrix
752///
753/// # Parameters
754///
755/// - `input` is the input Array
756/// - `is_unit_diag` dictates if the output Array should have 1's along diagonal
757///
758/// # Return Values
759/// Array
760pub fn upper<T>(input: &Array<T>, is_unit_diag: bool) -> Array<T>
761where
762    T: HasAfEnum,
763{
764    unsafe {
765        let mut temp: af_array = std::ptr::null_mut();
766        let err_val = af_upper(&mut temp as *mut af_array, input.get(), is_unit_diag);
767        HANDLE_ERROR(AfError::from(err_val));
768        temp.into()
769    }
770}
771
772/// Element wise conditional operator for Arrays
773///
774/// This function does the C-equivalent of the following statement, except that the operation
775/// happens on a GPU for all elements simultaneously.
776///
777/// ```text
778/// c = cond ? a : b; /// where cond, a & b are all objects of type Array
779/// ```
780///
781/// # Parameters
782///
783/// - `a` is the Array whose element will be assigned to output if corresponding element in `cond` Array is
784/// `True`
785/// - `cond` is the Array with boolean values
786/// - `b` is the Array whose element will be assigned to output if corresponding element in `cond` Array is
787/// `False`
788///
789/// # Return Values
790///
791/// An Array
792pub fn select<T>(a: &Array<T>, cond: &Array<bool>, b: &Array<T>) -> Array<T>
793where
794    T: HasAfEnum,
795{
796    unsafe {
797        let mut temp: af_array = std::ptr::null_mut();
798        let err_val = af_select(&mut temp as *mut af_array, cond.get(), a.get(), b.get());
799        HANDLE_ERROR(AfError::from(err_val));
800        temp.into()
801    }
802}
803
804/// Element wise conditional operator for Arrays
805///
806/// This function does the C-equivalent of the following statement, except that the operation
807/// happens on a GPU for all elements simultaneously.
808///
809/// ```text
810/// c = cond ? a : b; /// where  a is a scalar(f64) and b is Array
811/// ```
812///
813/// # Parameters
814///
815/// - `a` is the scalar that is assigned to output if corresponding element in `cond` Array is
816/// `True`
817/// - `cond` is the Array with conditional values
818/// - `b` is the Array whose element will be assigned to output if corresponding element in `cond` Array is
819/// `False`
820///
821/// # Return Values
822///
823/// An Array
824pub fn selectl<T>(a: f64, cond: &Array<bool>, b: &Array<T>) -> Array<T>
825where
826    T: HasAfEnum,
827{
828    unsafe {
829        let mut temp: af_array = std::ptr::null_mut();
830        let err_val = af_select_scalar_l(&mut temp as *mut af_array, cond.get(), a, b.get());
831        HANDLE_ERROR(AfError::from(err_val));
832        temp.into()
833    }
834}
835
836/// Element wise conditional operator for Arrays
837///
838/// This function does the C-equivalent of the following statement, except that the operation
839/// happens on a GPU for all elements simultaneously.
840///
841/// ```text
842/// c = cond ? a : b; /// where a is Array and b is a scalar(f64)
843/// ```
844///
845/// # Parameters
846///
847/// - `a` is the Array whose element will be assigned to output if corresponding element in `cond` Array is
848/// `True`
849/// - `cond` is the Array with conditional values
850/// - `b` is the scalar that is assigned to output if corresponding element in `cond` Array is
851/// `False`
852///
853/// # Return Values
854///
855/// An Array
856pub fn selectr<T>(a: &Array<T>, cond: &Array<bool>, b: f64) -> Array<T>
857where
858    T: HasAfEnum,
859{
860    unsafe {
861        let mut temp: af_array = std::ptr::null_mut();
862        let err_val = af_select_scalar_r(&mut temp as *mut af_array, cond.get(), a.get(), b);
863        HANDLE_ERROR(AfError::from(err_val));
864        temp.into()
865    }
866}
867
868/// Inplace replace in Array based on a condition
869///
870/// This function does the C-equivalent of the following statement, except that the operation
871/// happens on a GPU for all elements simultaneously.
872///
873/// ```text
874/// a = cond ? a : b; /// where cond, a & b are all objects of type Array
875/// ```
876///
877/// # Parameters
878///
879/// - `a` is the Array whose element will be replaced with element from `b` if corresponding element in `cond` Array is `True`
880/// - `cond` is the Array with conditional values
881/// - `b` is the Array whose element will replace the element in output if corresponding element in `cond` Array is
882/// `False`
883///
884/// # Return Values
885///
886/// None
887pub fn replace<T>(a: &mut Array<T>, cond: &Array<bool>, b: &Array<T>)
888where
889    T: HasAfEnum,
890{
891    unsafe {
892        let err_val = af_replace(a.get() as *mut af_array, cond.get(), b.get());
893        HANDLE_ERROR(AfError::from(err_val));
894    }
895}
896
897/// Inplace replace in Array based on a condition
898///
899/// This function does the C-equivalent of the following statement, except that the operation
900/// happens on a GPU for all elements simultaneously.
901///
902/// ```text
903/// a = cond ? a : b; /// where cond, a are Arrays and b is scalar(f64)
904/// ```
905///
906/// # Parameters
907///
908/// - `a` is the Array whose element will be replaced with element from `b` if corresponding element in `cond` Array is `True`
909/// - `cond` is the Array with conditional values
910/// - `b` is the scalar that will replace the element in output if corresponding element in `cond` Array is
911/// `False`
912///
913/// # Return Values
914///
915/// None
916pub fn replace_scalar<T>(a: &mut Array<T>, cond: &Array<bool>, b: f64)
917where
918    T: HasAfEnum,
919{
920    unsafe {
921        let err_val = af_replace_scalar(a.get() as *mut af_array, cond.get(), b);
922        HANDLE_ERROR(AfError::from(err_val));
923    }
924}
925
926/// Pad input Array along borders
927///
928/// # Parameters
929///
930/// - `input` is the input array to be padded
931/// - `begin` is padding size before first element along a given dimension
932/// - `end` is padding size after the last element along a given dimension
933/// - `fill_type` indicates what values should be used to fill padded regions
934///
935/// # Return Values
936///
937/// Padded Array
938pub fn pad<T: HasAfEnum>(
939    input: &Array<T>,
940    begin: Dim4,
941    end: Dim4,
942    fill_type: BorderType,
943) -> Array<T> {
944    unsafe {
945        let mut temp: af_array = std::ptr::null_mut();
946        let err_val = af_pad(
947            &mut temp as *mut af_array,
948            input.get(),
949            4,
950            begin.get().as_ptr() as *const dim_t,
951            4,
952            end.get().as_ptr() as *const dim_t,
953            fill_type as c_uint,
954        );
955        HANDLE_ERROR(AfError::from(err_val));
956        temp.into()
957    }
958}
959
#[cfg(test)]
mod tests {
    use super::super::defines::BorderType;
    use super::super::device::set_device;
    use super::super::random::randu;
    use super::{pad, reorder_v2};

    use crate::dim4;

    /// Smoke test: exercises `reorder_v2` with zero, one and two optional axes.
    #[test]
    fn check_reorder_api() {
        set_device(0);
        let input = randu::<f32>(dim4!(4, 5, 2, 3));

        let _transposed = reorder_v2(&input, 1, 0, None);
        let _swap_0_2 = reorder_v2(&input, 2, 1, Some(vec![0]));
        let _swap_1_2 = reorder_v2(&input, 0, 2, Some(vec![1]));
        let _swap_0_3 = reorder_v2(&input, 3, 1, Some(vec![2, 0]));
    }

    /// Smoke test: pads a 3x3 array by two elements at the end of the first two axes.
    #[test]
    fn check_pad_api() {
        set_device(0);
        let input = randu::<f32>(dim4![3, 3]);
        let no_leading_pad = dim4!(0, 0, 0, 0);
        let trailing_pad = dim4!(2, 2, 0, 0);
        let _padded = pad(&input, no_leading_pad, trailing_pad, BorderType::ZERO);
    }
}