arrayfire/core/data.rs
1use super::array::Array;
2use super::defines::{AfError, BorderType};
3use super::dim4::Dim4;
4use super::error::HANDLE_ERROR;
5use super::util::{af_array, c32, c64, dim_t, u64_t, HasAfEnum};
6
7use libc::{c_double, c_int, c_uint};
8use std::option::Option;
9use std::vec::Vec;
10
// Raw declarations of the C functions this module calls through FFI.
//
// Every function reports its outcome as a `c_int` status code; the wrappers
// in this file convert it with `AfError::from` and hand it to `HANDLE_ERROR`.
// Functions producing a new array write its handle through the leading
// `*mut af_array` out-parameter.
extern "C" {
    // --- Constant-valued arrays ---
    fn af_constant(
        out: *mut af_array,
        val: c_double,
        ndims: c_uint,
        dims: *const dim_t,
        afdtype: c_uint,
    ) -> c_int;

    fn af_constant_complex(
        out: *mut af_array,
        real: c_double,
        imag: c_double,
        ndims: c_uint,
        dims: *const dim_t,
        afdtype: c_uint,
    ) -> c_int;

    fn af_constant_long(out: *mut af_array, val: dim_t, ndims: c_uint, dims: *const dim_t)
        -> c_int;

    fn af_constant_ulong(
        out: *mut af_array,
        val: u64_t,
        ndims: c_uint,
        dims: *const dim_t,
    ) -> c_int;

    // --- Sequences and identity ---
    fn af_range(
        out: *mut af_array,
        ndims: c_uint,
        dims: *const dim_t,
        seq_dim: c_int,
        afdtype: c_uint,
    ) -> c_int;

    fn af_iota(
        out: *mut af_array,
        ndims: c_uint,
        dims: *const dim_t,
        t_ndims: c_uint,
        tdims: *const dim_t,
        afdtype: c_uint,
    ) -> c_int;

    fn af_identity(out: *mut af_array, ndims: c_uint, dims: *const dim_t, afdtype: c_uint)
        -> c_int;

    // --- Diagonals and concatenation ---
    fn af_diag_create(out: *mut af_array, arr: af_array, num: c_int) -> c_int;
    fn af_diag_extract(out: *mut af_array, arr: af_array, num: c_int) -> c_int;
    fn af_join(out: *mut af_array, dim: c_int, first: af_array, second: af_array) -> c_int;
    fn af_join_many(
        out: *mut af_array,
        dim: c_int,
        n_arrays: c_uint,
        inpts: *const af_array,
    ) -> c_int;

    // --- Layout manipulation ---
    fn af_tile(
        out: *mut af_array,
        arr: af_array,
        x: c_uint,
        y: c_uint,
        z: c_uint,
        w: c_uint,
    ) -> c_int;
    fn af_reorder(
        o: *mut af_array,
        a: af_array,
        x: c_uint,
        y: c_uint,
        z: c_uint,
        w: c_uint,
    ) -> c_int;
    fn af_shift(o: *mut af_array, a: af_array, x: c_int, y: c_int, z: c_int, w: c_int) -> c_int;
    fn af_moddims(out: *mut af_array, arr: af_array, ndims: c_uint, dims: *const dim_t) -> c_int;

    fn af_flat(out: *mut af_array, arr: af_array) -> c_int;
    fn af_flip(out: *mut af_array, arr: af_array, dim: c_uint) -> c_int;
    fn af_lower(out: *mut af_array, arr: af_array, is_unit_diag: bool) -> c_int;
    fn af_upper(out: *mut af_array, arr: af_array, is_unit_diag: bool) -> c_int;

    // --- Conditional selection ---
    fn af_select(out: *mut af_array, cond: af_array, a: af_array, b: af_array) -> c_int;
    fn af_select_scalar_l(out: *mut af_array, cond: af_array, a: c_double, b: af_array) -> c_int;
    fn af_select_scalar_r(out: *mut af_array, cond: af_array, a: af_array, b: c_double) -> c_int;

    // --- In-place conditional replacement ---
    // The first argument is declared as `*mut af_array`; the wrappers in this
    // file pass the array's own handle cast to that pointer type.
    fn af_replace(a: *mut af_array, cond: af_array, b: af_array) -> c_int;
    fn af_replace_scalar(a: *mut af_array, cond: af_array, b: c_double) -> c_int;

    // --- Padding ---
    fn af_pad(
        out: *mut af_array,
        input: af_array,
        begin_ndims: c_uint,
        begin_dims: *const dim_t,
        end_ndims: c_uint,
        end_dims: *const dim_t,
        pad_fill_type: c_uint,
    ) -> c_int;
}
109
/// Type Trait to generate a constant [Array](./struct.Array.html) of given size
///
/// A scalar type implementing this trait can be used as the fill value of
/// [constant](./fn.constant.html).
///
/// Internally, ConstGenerator trait is implemented by following types.
///
/// - f32
/// - f64
/// - num::Complex\<f32\>
/// - num::Complex\<f64\>
/// - bool
/// - i32
/// - u32
/// - u8
/// - i64
/// - u64
/// - i16
/// - u16
///
pub trait ConstGenerator: HasAfEnum {
    /// The element type of the `Array` returned by `generate`
    type OutType: HasAfEnum;

    /// Create an Array of `dims` size from scalar value `self`.
    ///
    /// # Parameters
    ///
    /// - `dims` are the dimensions of the output constant [Array](./struct.Array.html)
    ///
    /// # Return Values
    ///
    /// An Array whose element type is `Self::OutType`
    fn generate(&self, dims: Dim4) -> Array<Self::OutType>;
}
138
139impl ConstGenerator for i64 {
140 type OutType = i64;
141
142 fn generate(&self, dims: Dim4) -> Array<Self::OutType> {
143 unsafe {
144 let mut temp: af_array = std::ptr::null_mut();
145 let err_val = af_constant_long(
146 &mut temp as *mut af_array,
147 *self,
148 dims.ndims() as c_uint,
149 dims.get().as_ptr() as *const dim_t,
150 );
151 HANDLE_ERROR(AfError::from(err_val));
152 temp.into()
153 }
154 }
155}
156
157impl ConstGenerator for u64 {
158 type OutType = u64;
159
160 fn generate(&self, dims: Dim4) -> Array<Self::OutType> {
161 unsafe {
162 let mut temp: af_array = std::ptr::null_mut();
163 let err_val = af_constant_ulong(
164 &mut temp as *mut af_array,
165 *self,
166 dims.ndims() as c_uint,
167 dims.get().as_ptr() as *const dim_t,
168 );
169 HANDLE_ERROR(AfError::from(err_val));
170 temp.into()
171 }
172 }
173}
174
175impl ConstGenerator for c32 {
176 type OutType = c32;
177
178 fn generate(&self, dims: Dim4) -> Array<Self::OutType> {
179 unsafe {
180 let mut temp: af_array = std::ptr::null_mut();
181 let err_val = af_constant_complex(
182 &mut temp as *mut af_array,
183 (*self).re as c_double,
184 (*self).im as c_double,
185 dims.ndims() as c_uint,
186 dims.get().as_ptr() as *const dim_t,
187 1,
188 );
189 HANDLE_ERROR(AfError::from(err_val));
190 temp.into()
191 }
192 }
193}
194
195impl ConstGenerator for c64 {
196 type OutType = c64;
197
198 fn generate(&self, dims: Dim4) -> Array<Self::OutType> {
199 unsafe {
200 let mut temp: af_array = std::ptr::null_mut();
201 let err_val = af_constant_complex(
202 &mut temp as *mut af_array,
203 (*self).re as c_double,
204 (*self).im as c_double,
205 dims.ndims() as c_uint,
206 dims.get().as_ptr() as *const dim_t,
207 3,
208 );
209 HANDLE_ERROR(AfError::from(err_val));
210 temp.into()
211 }
212 }
213}
214
215impl ConstGenerator for bool {
216 type OutType = bool;
217
218 fn generate(&self, dims: Dim4) -> Array<Self::OutType> {
219 unsafe {
220 let mut temp: af_array = std::ptr::null_mut();
221 let err_val = af_constant(
222 &mut temp as *mut af_array,
223 *self as c_int as c_double,
224 dims.ndims() as c_uint,
225 dims.get().as_ptr() as *const dim_t,
226 4,
227 );
228 HANDLE_ERROR(AfError::from(err_val));
229 temp.into()
230 }
231 }
232}
233
234macro_rules! cnst {
235 ($rust_type:ty, $ffi_type:expr) => {
236 impl ConstGenerator for $rust_type {
237 type OutType = $rust_type;
238
239 fn generate(&self, dims: Dim4) -> Array<Self::OutType> {
240 unsafe {
241 let mut temp: af_array = std::ptr::null_mut();
242 let err_val = af_constant(
243 &mut temp as *mut af_array,
244 *self as c_double,
245 dims.ndims() as c_uint,
246 dims.get().as_ptr() as *const dim_t,
247 $ffi_type,
248 );
249 HANDLE_ERROR(AfError::from(err_val));
250 temp.into()
251 }
252 }
253 }
254 };
255}
256
257cnst!(f32, 0);
258cnst!(f64, 2);
259cnst!(i32, 5);
260cnst!(u32, 6);
261cnst!(u8, 7);
262cnst!(i16, 10);
263cnst!(u16, 11);
264
265/// Create an Array with constant value
266///
267/// The trait ConstGenerator has been defined internally for the following types:
268///
269/// - i64
270/// - u64
271/// - num::Complex\<f32\> a.k.a c32
272/// - num::Complex\<f64\> a.k.a c64
273/// - f32
274/// - f64
275/// - i32
276/// - u32
277/// - u8
278/// - i16
279/// - u16
280///
281/// # Parameters
282///
283/// - `cnst` is the constant value to be filled in the Array
284/// - `dims` is the size of the constant Array
285///
286/// # Return Values
287///
288/// An Array of given dimensions with constant value
289pub fn constant<T>(cnst: T, dims: Dim4) -> Array<T>
290where
291 T: ConstGenerator<OutType = T>,
292{
293 cnst.generate(dims)
294}
295
296/// Create a Range of values
297///
298/// Creates an array with [0, n] values along the `seq_dim` which is tiled across other dimensions.
299///
300/// # Parameters
301///
302/// - `dims` is the size of Array
303/// - `seq_dim` is the dimension along which range values are populated, all values along other
304/// dimensions are just repeated
305///
306/// # Return Values
307/// Array
308pub fn range<T: HasAfEnum>(dims: Dim4, seq_dim: i32) -> Array<T> {
309 let aftype = T::get_af_dtype();
310 unsafe {
311 let mut temp: af_array = std::ptr::null_mut();
312 let err_val = af_range(
313 &mut temp as *mut af_array,
314 dims.ndims() as c_uint,
315 dims.get().as_ptr() as *const dim_t,
316 seq_dim as c_int,
317 aftype as c_uint,
318 );
319 HANDLE_ERROR(AfError::from(err_val));
320 temp.into()
321 }
322}
323
324/// Create a range of values
325///
326/// Create an sequence [0, dims.elements() - 1] and modify to specified dimensions dims and then tile it according to tile_dims.
327///
328/// # Parameters
329///
330/// - `dims` is the dimensions of the sequence to be generated
331/// - `tdims` is the number of repitions of the unit dimensions
332///
333/// # Return Values
334///
335/// Array
336pub fn iota<T: HasAfEnum>(dims: Dim4, tdims: Dim4) -> Array<T> {
337 let aftype = T::get_af_dtype();
338 unsafe {
339 let mut temp: af_array = std::ptr::null_mut();
340 let err_val = af_iota(
341 &mut temp as *mut af_array,
342 dims.ndims() as c_uint,
343 dims.get().as_ptr() as *const dim_t,
344 tdims.ndims() as c_uint,
345 tdims.get().as_ptr() as *const dim_t,
346 aftype as c_uint,
347 );
348 HANDLE_ERROR(AfError::from(err_val));
349 temp.into()
350 }
351}
352
353/// Create an identity array with 1's in diagonal
354///
355/// # Parameters
356///
357/// - `dims` is the output Array dimensions
358///
359/// # Return Values
360///
361/// Identity matrix
362pub fn identity<T: HasAfEnum>(dims: Dim4) -> Array<T> {
363 let aftype = T::get_af_dtype();
364 unsafe {
365 let mut temp: af_array = std::ptr::null_mut();
366 let err_val = af_identity(
367 &mut temp as *mut af_array,
368 dims.ndims() as c_uint,
369 dims.get().as_ptr() as *const dim_t,
370 aftype as c_uint,
371 );
372 HANDLE_ERROR(AfError::from(err_val));
373 temp.into()
374 }
375}
376
377/// Create a diagonal matrix
378///
379/// # Parameters
380///
381/// - `input` is the input Array
382/// - `dim` is the diagonal index relative to principal diagonal where values from input Array are
383/// to be placed
384///
385/// # Return Values
386///
387/// An Array with values as a diagonal Matrix
388pub fn diag_create<T>(input: &Array<T>, dim: i32) -> Array<T>
389where
390 T: HasAfEnum,
391{
392 unsafe {
393 let mut temp: af_array = std::ptr::null_mut();
394 let err_val = af_diag_create(&mut temp as *mut af_array, input.get(), dim);
395 HANDLE_ERROR(AfError::from(err_val));
396 temp.into()
397 }
398}
399
400/// Extract diagonal from a given Matrix
401///
402/// # Parameters
403///
404/// - `input` is the input Matrix
405/// - `dim` is the index of the diagonal that has to be extracted from the input Matrix
406///
407/// # Return Values
408///
409/// An Array with values of the diagonal from input Array
410pub fn diag_extract<T>(input: &Array<T>, dim: i32) -> Array<T>
411where
412 T: HasAfEnum,
413{
414 unsafe {
415 let mut temp: af_array = std::ptr::null_mut();
416 let err_val = af_diag_extract(&mut temp as *mut af_array, input.get(), dim);
417 HANDLE_ERROR(AfError::from(err_val));
418 temp.into()
419 }
420}
421
422/// Join two arrays
423///
424/// # Parameters
425///
426/// - `dim` is the dimension along which the concatenation has to be done
427/// - `first` is the Array that comes first in the concatenation
428/// - `second` is the Array that comes last in the concatenation
429///
430/// # Return Values
431///
432/// Concatenated Array
433pub fn join<T>(dim: i32, first: &Array<T>, second: &Array<T>) -> Array<T>
434where
435 T: HasAfEnum,
436{
437 unsafe {
438 let mut temp: af_array = std::ptr::null_mut();
439 let err_val = af_join(&mut temp as *mut af_array, dim, first.get(), second.get());
440 HANDLE_ERROR(AfError::from(err_val));
441 temp.into()
442 }
443}
444
445/// Join multiple arrays
446///
447/// # Parameters
448///
449/// - `dim` is the dimension along which the concatenation has to be done
450/// - `inputs` is the vector of Arrays that has to be concatenated
451///
452/// # Return Values
453///
454/// Concatenated Array
455pub fn join_many<T>(dim: i32, inputs: Vec<&Array<T>>) -> Array<T>
456where
457 T: HasAfEnum,
458{
459 unsafe {
460 let mut v = Vec::new();
461 for i in inputs {
462 v.push(i.get());
463 }
464 let mut temp: af_array = std::ptr::null_mut();
465 let err_val = af_join_many(
466 &mut temp as *mut af_array,
467 dim,
468 v.len() as u32,
469 v.as_ptr() as *const af_array,
470 );
471 HANDLE_ERROR(AfError::from(err_val));
472 temp.into()
473 }
474}
475
476/// Tile the input array along specified dimension
477///
478/// Tile essentially creates copies of data along each dimension.
479/// The number of copies created is provided by the user on per
480/// axis basis using [Dim4](./struct.dim4.html)
481///
482///# Parameters
483///
484/// - `input` is the input Array
485/// - `dims` is the target(output) dimensions
486///
487///# Return Values
488///
489/// Tiled input array as per the tiling dimensions provided
490pub fn tile<T>(input: &Array<T>, dims: Dim4) -> Array<T>
491where
492 T: HasAfEnum,
493{
494 unsafe {
495 let mut temp: af_array = std::ptr::null_mut();
496 let err_val = af_tile(
497 &mut temp as *mut af_array,
498 input.get() as af_array,
499 dims[0] as c_uint,
500 dims[1] as c_uint,
501 dims[2] as c_uint,
502 dims[3] as c_uint,
503 );
504 HANDLE_ERROR(AfError::from(err_val));
505 temp.into()
506 }
507}
508
509/// Reorder the array according to the new specified axes
510///
511/// Exchanges data within an array such that the requested change in axes is
512/// satisfied. The linear ordering of data within the array is preserved.
513///
514/// The default order of axes in ArrayFire is [0 1 2 3] i.e. axis with smallest
515/// distance between adjacent elements followed by next smallest distance axis and
516/// so on. See [examples](#examples) to have a basic idea of how data is re-ordered.
517///
518///# Parameters
519///
520/// - `input` is the input Array
521/// - `new_axis0` is the new first axis for output
522/// - `new_axis1` is the new second axis for output
523/// - `next_axes` is the new axes order for output
524///
525///# Return Values
526///
527/// Array with data reordered as per the new axes order
528///
529///# Examples
530///
531/// ```rust
532/// use arrayfire::{Array, Dim4, print, randu, reorder_v2};
533/// let a = randu::<f32>(Dim4::new(&[5, 3, 1, 1]));
534/// let b = reorder_v2(&a, 1, 0, None);
535/// print(&a);
536///
537/// // [5 3 1 1]
538/// // 0.8104 0.2990 0.3014
539/// // 0.6913 0.2802 0.6938
540/// // 0.7821 0.1480 0.3513
541/// // 0.3054 0.1330 0.7176
542/// // 0.1673 0.4696 0.1181
543///
544/// print(&b);
545/// // [3 5 1 1]
546/// // 0.8104 0.6913 0.7821 0.3054 0.1673
547/// // 0.2990 0.2802 0.1480 0.1330 0.4696
548/// // 0.3014 0.6938 0.3513 0.7176 0.1181
549///
550/// let c = reorder_v2(&a, 2, 0, Some(vec![1]));
551/// print(&c);
552///
553/// // [1 5 3 1]
554/// // 0.8104 0.6913 0.7821 0.3054 0.1673
555/// //
556/// // 0.2990 0.2802 0.1480 0.1330 0.4696
557/// //
558/// // 0.3014 0.6938 0.3513 0.7176 0.1181
559/// ```
560pub fn reorder_v2<T>(
561 input: &Array<T>,
562 new_axis0: u64,
563 new_axis1: u64,
564 next_axes: Option<Vec<u64>>,
565) -> Array<T>
566where
567 T: HasAfEnum,
568{
569 let mut new_axes = [0, 1, 2, 3];
570 new_axes[0] = new_axis0;
571 new_axes[1] = new_axis1;
572 match next_axes {
573 Some(left_over_new_axes) => {
574 // At the moment of writing this comment, ArrayFire could
575 // handle only a maximum of 4 dimensions. Hence, excluding
576 // the two explicit axes arguments to this function, a maximum
577 // of only two more axes can be provided. Hence the below condition.
578 assert!(left_over_new_axes.len() <= 2);
579
580 for a_idx in 0..left_over_new_axes.len() {
581 new_axes[2 + a_idx] = left_over_new_axes[a_idx];
582 }
583 }
584 None => {
585 for a_idx in 2..4 {
586 new_axes[a_idx] = a_idx as u64;
587 }
588 }
589 };
590
591 unsafe {
592 let mut temp: af_array = std::ptr::null_mut();
593 let err_val = af_reorder(
594 &mut temp as *mut af_array,
595 input.get() as af_array,
596 new_axes[0] as c_uint,
597 new_axes[1] as c_uint,
598 new_axes[2] as c_uint,
599 new_axes[3] as c_uint,
600 );
601 HANDLE_ERROR(AfError::from(err_val));
602 temp.into()
603 }
604}
605
606/// Reorder the array in specified order
607///
608/// The default order of axes in ArrayFire is axis with smallest distance
609/// between adjacent elements towards an axis with highest distance between
610/// adjacent elements.
611///
612///# Parameters
613///
614/// - `input` is the input Array
615/// - `dims` is the target(output) dimensions
616///
617///# Return Values
618///
619/// Array with data reordered as per the new axes order
620#[deprecated(since = "3.6.3", note = "Please use new reorder API")]
621pub fn reorder<T>(input: &Array<T>, dims: Dim4) -> Array<T>
622where
623 T: HasAfEnum,
624{
625 reorder_v2(input, dims[0], dims[1], Some(vec![dims[2], dims[3]]))
626}
627
628///"Circular shift of values along specified dimension
629///
630///# Parameters
631///
632/// - `input` is the input Array
633/// - `offsets` is 4-value tuple that specifies the shift along respective dimension
634///
635///# Return Values
636///
637/// An Array with shifted data.
638///
639///# Examples
640///
641/// ```rust
642/// use arrayfire::{Array, Dim4, print, randu, shift};
643/// let a = randu::<f32>(Dim4::new(&[5, 1, 1, 1]));
644/// let _a = shift(&a, &[-1i32, 1 , 1, 1]); //shift data one step backward
645/// let a_ = shift(&a, &[ 1i32, 1 , 1, 1]); //shift data one step forward
646/// print(& a);
647/// print(&_a);
648/// print(&a_);
649/// ```
650pub fn shift<T>(input: &Array<T>, offsets: &[i32; 4]) -> Array<T>
651where
652 T: HasAfEnum,
653{
654 unsafe {
655 let mut temp: af_array = std::ptr::null_mut();
656 let err_val = af_shift(
657 &mut temp as *mut af_array,
658 input.get(),
659 offsets[0],
660 offsets[1],
661 offsets[2],
662 offsets[3],
663 );
664 HANDLE_ERROR(AfError::from(err_val));
665 temp.into()
666 }
667}
668
669/// Change the shape of the Array
670///
671/// # Parameters
672///
673/// - `input` is the input Array
674/// - `dims` is the new dimensions to which the input Array is reshaped to
675///
676/// # Return Values
677/// Reshaped Array
678pub fn moddims<T>(input: &Array<T>, dims: Dim4) -> Array<T>
679where
680 T: HasAfEnum,
681{
682 unsafe {
683 let mut temp: af_array = std::ptr::null_mut();
684 let err_val = af_moddims(
685 &mut temp as *mut af_array,
686 input.get(),
687 dims.ndims() as c_uint,
688 dims.get().as_ptr() as *const dim_t,
689 );
690 HANDLE_ERROR(AfError::from(err_val));
691 temp.into()
692 }
693}
694
695/// Flatten the multidimensional Array to an 1D Array
696pub fn flat<T>(input: &Array<T>) -> Array<T>
697where
698 T: HasAfEnum,
699{
700 unsafe {
701 let mut temp: af_array = std::ptr::null_mut();
702 let err_val = af_flat(&mut temp as *mut af_array, input.get());
703 HANDLE_ERROR(AfError::from(err_val));
704 temp.into()
705 }
706}
707
708/// Flip the Array
709///
710/// # Parameters
711///
712/// - `input` is the Array to be flipped
713/// - `dim` is the dimension along which the flip has to happen
714///
715/// # Return Values
716///
717/// Flipped Array
718pub fn flip<T>(input: &Array<T>, dim: u32) -> Array<T>
719where
720 T: HasAfEnum,
721{
722 unsafe {
723 let mut temp: af_array = std::ptr::null_mut();
724 let err_val = af_flip(&mut temp as *mut af_array, input.get(), dim);
725 HANDLE_ERROR(AfError::from(err_val));
726 temp.into()
727 }
728}
729
730/// Create lower triangular matrix
731///
732/// # Parameters
733///
734/// - `input` is the input Array
735/// - `is_unit_diag` dictates if the output Array should have 1's along diagonal
736///
737/// # Return Values
738/// Array
739pub fn lower<T>(input: &Array<T>, is_unit_diag: bool) -> Array<T>
740where
741 T: HasAfEnum,
742{
743 unsafe {
744 let mut temp: af_array = std::ptr::null_mut();
745 let err_val = af_lower(&mut temp as *mut af_array, input.get(), is_unit_diag);
746 HANDLE_ERROR(AfError::from(err_val));
747 temp.into()
748 }
749}
750
751/// Create upper triangular matrix
752///
753/// # Parameters
754///
755/// - `input` is the input Array
756/// - `is_unit_diag` dictates if the output Array should have 1's along diagonal
757///
758/// # Return Values
759/// Array
760pub fn upper<T>(input: &Array<T>, is_unit_diag: bool) -> Array<T>
761where
762 T: HasAfEnum,
763{
764 unsafe {
765 let mut temp: af_array = std::ptr::null_mut();
766 let err_val = af_upper(&mut temp as *mut af_array, input.get(), is_unit_diag);
767 HANDLE_ERROR(AfError::from(err_val));
768 temp.into()
769 }
770}
771
772/// Element wise conditional operator for Arrays
773///
774/// This function does the C-equivalent of the following statement, except that the operation
775/// happens on a GPU for all elements simultaneously.
776///
777/// ```text
778/// c = cond ? a : b; /// where cond, a & b are all objects of type Array
779/// ```
780///
781/// # Parameters
782///
783/// - `a` is the Array whose element will be assigned to output if corresponding element in `cond` Array is
784/// `True`
785/// - `cond` is the Array with boolean values
786/// - `b` is the Array whose element will be assigned to output if corresponding element in `cond` Array is
787/// `False`
788///
789/// # Return Values
790///
791/// An Array
792pub fn select<T>(a: &Array<T>, cond: &Array<bool>, b: &Array<T>) -> Array<T>
793where
794 T: HasAfEnum,
795{
796 unsafe {
797 let mut temp: af_array = std::ptr::null_mut();
798 let err_val = af_select(&mut temp as *mut af_array, cond.get(), a.get(), b.get());
799 HANDLE_ERROR(AfError::from(err_val));
800 temp.into()
801 }
802}
803
804/// Element wise conditional operator for Arrays
805///
806/// This function does the C-equivalent of the following statement, except that the operation
807/// happens on a GPU for all elements simultaneously.
808///
809/// ```text
810/// c = cond ? a : b; /// where a is a scalar(f64) and b is Array
811/// ```
812///
813/// # Parameters
814///
815/// - `a` is the scalar that is assigned to output if corresponding element in `cond` Array is
816/// `True`
817/// - `cond` is the Array with conditional values
818/// - `b` is the Array whose element will be assigned to output if corresponding element in `cond` Array is
819/// `False`
820///
821/// # Return Values
822///
823/// An Array
824pub fn selectl<T>(a: f64, cond: &Array<bool>, b: &Array<T>) -> Array<T>
825where
826 T: HasAfEnum,
827{
828 unsafe {
829 let mut temp: af_array = std::ptr::null_mut();
830 let err_val = af_select_scalar_l(&mut temp as *mut af_array, cond.get(), a, b.get());
831 HANDLE_ERROR(AfError::from(err_val));
832 temp.into()
833 }
834}
835
836/// Element wise conditional operator for Arrays
837///
838/// This function does the C-equivalent of the following statement, except that the operation
839/// happens on a GPU for all elements simultaneously.
840///
841/// ```text
842/// c = cond ? a : b; /// where a is Array and b is a scalar(f64)
843/// ```
844///
845/// # Parameters
846///
847/// - `a` is the Array whose element will be assigned to output if corresponding element in `cond` Array is
848/// `True`
849/// - `cond` is the Array with conditional values
850/// - `b` is the scalar that is assigned to output if corresponding element in `cond` Array is
851/// `False`
852///
853/// # Return Values
854///
855/// An Array
856pub fn selectr<T>(a: &Array<T>, cond: &Array<bool>, b: f64) -> Array<T>
857where
858 T: HasAfEnum,
859{
860 unsafe {
861 let mut temp: af_array = std::ptr::null_mut();
862 let err_val = af_select_scalar_r(&mut temp as *mut af_array, cond.get(), a.get(), b);
863 HANDLE_ERROR(AfError::from(err_val));
864 temp.into()
865 }
866}
867
868/// Inplace replace in Array based on a condition
869///
870/// This function does the C-equivalent of the following statement, except that the operation
871/// happens on a GPU for all elements simultaneously.
872///
873/// ```text
874/// a = cond ? a : b; /// where cond, a & b are all objects of type Array
875/// ```
876///
877/// # Parameters
878///
879/// - `a` is the Array whose element will be replaced with element from `b` if corresponding element in `cond` Array is `True`
880/// - `cond` is the Array with conditional values
881/// - `b` is the Array whose element will replace the element in output if corresponding element in `cond` Array is
882/// `False`
883///
884/// # Return Values
885///
886/// None
887pub fn replace<T>(a: &mut Array<T>, cond: &Array<bool>, b: &Array<T>)
888where
889 T: HasAfEnum,
890{
891 unsafe {
892 let err_val = af_replace(a.get() as *mut af_array, cond.get(), b.get());
893 HANDLE_ERROR(AfError::from(err_val));
894 }
895}
896
897/// Inplace replace in Array based on a condition
898///
899/// This function does the C-equivalent of the following statement, except that the operation
900/// happens on a GPU for all elements simultaneously.
901///
902/// ```text
903/// a = cond ? a : b; /// where cond, a are Arrays and b is scalar(f64)
904/// ```
905///
906/// # Parameters
907///
908/// - `a` is the Array whose element will be replaced with element from `b` if corresponding element in `cond` Array is `True`
909/// - `cond` is the Array with conditional values
910/// - `b` is the scalar that will replace the element in output if corresponding element in `cond` Array is
911/// `False`
912///
913/// # Return Values
914///
915/// None
916pub fn replace_scalar<T>(a: &mut Array<T>, cond: &Array<bool>, b: f64)
917where
918 T: HasAfEnum,
919{
920 unsafe {
921 let err_val = af_replace_scalar(a.get() as *mut af_array, cond.get(), b);
922 HANDLE_ERROR(AfError::from(err_val));
923 }
924}
925
926/// Pad input Array along borders
927///
928/// # Parameters
929///
930/// - `input` is the input array to be padded
931/// - `begin` is padding size before first element along a given dimension
932/// - `end` is padding size after the last element along a given dimension
933/// - `fill_type` indicates what values should be used to fill padded regions
934///
935/// # Return Values
936///
937/// Padded Array
938pub fn pad<T: HasAfEnum>(
939 input: &Array<T>,
940 begin: Dim4,
941 end: Dim4,
942 fill_type: BorderType,
943) -> Array<T> {
944 unsafe {
945 let mut temp: af_array = std::ptr::null_mut();
946 let err_val = af_pad(
947 &mut temp as *mut af_array,
948 input.get(),
949 4,
950 begin.get().as_ptr() as *const dim_t,
951 4,
952 end.get().as_ptr() as *const dim_t,
953 fill_type as c_uint,
954 );
955 HANDLE_ERROR(AfError::from(err_val));
956 temp.into()
957 }
958}
959
#[cfg(test)]
mod tests {
    use super::super::defines::BorderType;
    use super::super::device::set_device;
    use super::super::random::randu;
    use super::{pad, reorder_v2};

    use crate::dim4;

    // Smoke test: each axis permutation below must be accepted by `reorder_v2`.
    #[test]
    fn check_reorder_api() {
        set_device(0);
        let source = randu::<f32>(dim4!(4, 5, 2, 3));

        let _transposed = reorder_v2(&source, 1, 0, None);
        let _swap_0_2 = reorder_v2(&source, 2, 1, Some(vec![0]));
        let _swap_1_2 = reorder_v2(&source, 0, 2, Some(vec![1]));
        let _swap_0_3 = reorder_v2(&source, 3, 1, Some(vec![2, 0]));
    }

    // Smoke test: zero-fill padding of two elements after the end of the
    // first two dimensions must be accepted by `pad`.
    #[test]
    fn check_pad_api() {
        set_device(0);
        let source = randu::<f32>(dim4![3, 3]);
        let pad_before = dim4!(0, 0, 0, 0);
        let pad_after = dim4!(2, 2, 0, 0);
        let _padded = pad(&source, pad_before, pad_after, BorderType::ZERO);
    }
}