// arrayfire/core/array.rs

1use super::defines::{AfError, Backend, DType};
2use super::dim4::Dim4;
3use super::error::HANDLE_ERROR;
4use super::util::{af_array, dim_t, void_ptr, HasAfEnum};
5
6use libc::{c_char, c_int, c_longlong, c_uint, c_void};
7use std::ffi::CString;
8use std::marker::PhantomData;
9
// Functions from array.h in the C-API of ArrayFire that are not bound here yet:
// af_write_array
// af_get_data_ref_count
// (af_copy_array is bound below and used by `Array::copy`.)
14
15extern "C" {
16    fn af_create_array(
17        out: *mut af_array,
18        data: *const c_void,
19        ndims: c_uint,
20        dims: *const dim_t,
21        aftype: c_uint,
22    ) -> c_int;
23
24    fn af_create_handle(
25        out: *mut af_array,
26        ndims: c_uint,
27        dims: *const dim_t,
28        aftype: c_uint,
29    ) -> c_int;
30
31    fn af_device_array(
32        out: *mut af_array,
33        data: *mut c_void,
34        ndims: c_uint,
35        dims: *const dim_t,
36        aftype: c_uint,
37    ) -> c_int;
38
39    fn af_get_elements(out: *mut dim_t, arr: af_array) -> c_int;
40
41    fn af_get_type(out: *mut c_uint, arr: af_array) -> c_int;
42
43    fn af_get_dims(
44        dim0: *mut c_longlong,
45        dim1: *mut c_longlong,
46        dim2: *mut c_longlong,
47        dim3: *mut c_longlong,
48        arr: af_array,
49    ) -> c_int;
50
51    fn af_get_numdims(result: *mut c_uint, arr: af_array) -> c_int;
52
53    fn af_is_empty(result: *mut bool, arr: af_array) -> c_int;
54
55    fn af_is_scalar(result: *mut bool, arr: af_array) -> c_int;
56
57    fn af_is_row(result: *mut bool, arr: af_array) -> c_int;
58
59    fn af_is_column(result: *mut bool, arr: af_array) -> c_int;
60
61    fn af_is_vector(result: *mut bool, arr: af_array) -> c_int;
62
63    fn af_is_complex(result: *mut bool, arr: af_array) -> c_int;
64
65    fn af_is_real(result: *mut bool, arr: af_array) -> c_int;
66
67    fn af_is_double(result: *mut bool, arr: af_array) -> c_int;
68
69    fn af_is_single(result: *mut bool, arr: af_array) -> c_int;
70
71    fn af_is_half(result: *mut bool, arr: af_array) -> c_int;
72
73    fn af_is_integer(result: *mut bool, arr: af_array) -> c_int;
74
75    fn af_is_bool(result: *mut bool, arr: af_array) -> c_int;
76
77    fn af_is_realfloating(result: *mut bool, arr: af_array) -> c_int;
78
79    fn af_is_floating(result: *mut bool, arr: af_array) -> c_int;
80
81    fn af_is_linear(result: *mut bool, arr: af_array) -> c_int;
82
83    fn af_is_owner(result: *mut bool, arr: af_array) -> c_int;
84
85    fn af_is_sparse(result: *mut bool, arr: af_array) -> c_int;
86
87    fn af_get_data_ptr(data: *mut c_void, arr: af_array) -> c_int;
88
89    fn af_eval(arr: af_array) -> c_int;
90
91    fn af_eval_multiple(num: c_int, arrays: *const af_array) -> c_int;
92
93    fn af_set_manual_eval_flag(flag: c_int) -> c_int;
94
95    fn af_get_manual_eval_flag(flag: *mut c_int) -> c_int;
96
97    fn af_retain_array(out: *mut af_array, arr: af_array) -> c_int;
98
99    fn af_copy_array(out: *mut af_array, arr: af_array) -> c_int;
100
101    fn af_release_array(arr: af_array) -> c_int;
102
103    //fn af_print_array(arr: af_array) -> c_int;
104
105    fn af_print_array_gen(exp: *const c_char, arr: af_array, precision: c_int) -> c_int;
106
107    fn af_cast(out: *mut af_array, arr: af_array, aftype: c_uint) -> c_int;
108
109    fn af_get_backend_id(backend: *mut c_uint, input: af_array) -> c_int;
110
111    fn af_get_device_id(device: *mut c_int, input: af_array) -> c_int;
112
113    fn af_create_strided_array(
114        arr: *mut af_array,
115        data: *const c_void,
116        offset: dim_t,
117        ndims: c_uint,
118        dims: *const dim_t,
119        strides: *const dim_t,
120        aftype: c_uint,
121        stype: c_uint,
122    ) -> c_int;
123
124    fn af_get_strides(
125        s0: *mut dim_t,
126        s1: *mut dim_t,
127        s2: *mut dim_t,
128        s3: *mut dim_t,
129        arr: af_array,
130    ) -> c_int;
131
132    fn af_get_offset(offset: *mut dim_t, arr: af_array) -> c_int;
133
134    fn af_lock_array(arr: af_array) -> c_int;
135
136    fn af_unlock_array(arr: af_array) -> c_int;
137
138    fn af_get_device_ptr(ptr: *mut void_ptr, arr: af_array) -> c_int;
139
140    fn af_get_allocated_bytes(result: *mut usize, arr: af_array) -> c_int;
141}
142
143/// A multidimensional data container
144///
145/// Currently, Array objects can store only data until four dimensions
146///
147/// ## Sharing Across Threads
148///
149/// While sharing an Array with other threads, there is no need to wrap
150/// this in an Arc object unless only one such object is required to exist.
151/// The reason being that ArrayFire's internal Array is appropriately reference
152/// counted in thread safe manner. However, if you need to modify Array object,
153/// then please do wrap the object using a Mutex or Read-Write lock.
154///
155/// Examples on how to share Array across threads is illustrated in our
156/// [book](http://arrayfire.org/arrayfire-rust/book/multi-threading.html)
157///
158/// ### NOTE
159///
160/// All operators(traits) from std::ops module implemented for Array object
161/// carry out element wise operations. For example, `*` does multiplication of
162/// elements at corresponding locations in two different Arrays.
163pub struct Array<T: HasAfEnum> {
164    handle: af_array,
165    /// The phantom marker denotes the
166    /// allocation of data on compute device
167    _marker: PhantomData<T>,
168}
169
170/// Enable safely moving Array objects across threads
171unsafe impl<T: HasAfEnum> Send for Array<T> {}
172
173unsafe impl<T: HasAfEnum> Sync for Array<T> {}
174
/// Generates a `pub fn name(&self) -> bool` predicate on `Array`.
///
/// Arguments: the doc string of the generated method, the method's name, and the
/// `af_is_*` FFI function that writes the answer through a `*mut bool`. Any error
/// code returned by the FFI call is routed through `HANDLE_ERROR`.
macro_rules! is_func {
    ($doc_str: expr, $fn_name: ident, $ffi_fn: ident) => {
        #[doc = $doc_str]
        pub fn $fn_name(&self) -> bool {
            let mut answer: bool = false;
            // SAFETY: `answer` is a live, writable bool for the whole call and
            // `self.handle` is the handle owned by this Array.
            let status = unsafe { $ffi_fn(&mut answer as *mut bool, self.handle) };
            HANDLE_ERROR(AfError::from(status));
            answer
        }
    };
}
188
189impl<T> Array<T>
190where
191    T: HasAfEnum,
192{
193    /// Constructs a new Array object
194    ///
195    /// # Examples
196    ///
197    /// An example of creating an Array from f32 array
198    ///
199    /// ```rust
200    /// use arrayfire::{Array, Dim4, print};
201    /// let values: [f32; 3] = [1.0, 2.0, 3.0];
202    /// let indices = Array::new(&values, Dim4::new(&[3, 1, 1, 1]));
203    /// print(&indices);
204    /// ```
205    /// An example of creating an Array from half::f16 array
206    ///
207    /// ```rust
208    /// use arrayfire::{Array, Dim4, is_half_available, print};
209    /// use half::f16;
210    ///
211    /// let values: [f32; 3] = [1.0, 2.0, 3.0];
212    ///
213    /// if is_half_available(0) { // Default device is 0, hence the argument
214    ///     let half_values = values.iter().map(|&x| f16::from_f32(x)).collect::<Vec<_>>();
215    ///
216    ///     let hvals = Array::new(&half_values, Dim4::new(&[3, 1, 1, 1]));
217    ///
218    ///     print(&hvals);
219    /// } else {
220    ///     println!("Half support isn't available on this device");
221    /// }
222    /// ```
223    ///
224    pub fn new(slice: &[T], dims: Dim4) -> Self {
225        let aftype = T::get_af_dtype();
226        unsafe {
227            let mut temp: af_array = std::ptr::null_mut();
228            let err_val = af_create_array(
229                &mut temp as *mut af_array,
230                slice.as_ptr() as *const c_void,
231                dims.ndims() as c_uint,
232                dims.get().as_ptr() as *const c_longlong,
233                aftype as c_uint,
234            );
235            HANDLE_ERROR(AfError::from(err_val));
236            temp.into()
237        }
238    }
239
240    /// Constructs a new Array object from strided data
241    ///
242    /// The data pointed by the slice passed to this function can possibily be offseted using an additional `offset` parameter.
243    pub fn new_strided(slice: &[T], offset: i64, dims: Dim4, strides: Dim4) -> Self {
244        let aftype = T::get_af_dtype();
245        unsafe {
246            let mut temp: af_array = std::ptr::null_mut();
247            let err_val = af_create_strided_array(
248                &mut temp as *mut af_array,
249                slice.as_ptr() as *const c_void,
250                offset as dim_t,
251                dims.ndims() as c_uint,
252                dims.get().as_ptr() as *const c_longlong,
253                strides.get().as_ptr() as *const c_longlong,
254                aftype as c_uint,
255                1 as c_uint,
256            );
257            HANDLE_ERROR(AfError::from(err_val));
258            temp.into()
259        }
260    }
261
262    /// Constructs a new Array object of specified dimensions and type
263    ///
264    /// # Examples
265    ///
266    /// ```rust
267    /// use arrayfire::{Array, Dim4};
268    /// let garbage_vals = Array::<f32>::new_empty(Dim4::new(&[3, 1, 1, 1]));
269    /// ```
270    pub fn new_empty(dims: Dim4) -> Self {
271        let aftype = T::get_af_dtype();
272        unsafe {
273            let mut temp: af_array = std::ptr::null_mut();
274            let err_val = af_create_handle(
275                &mut temp as *mut af_array,
276                dims.ndims() as c_uint,
277                dims.get().as_ptr() as *const c_longlong,
278                aftype as c_uint,
279            );
280            HANDLE_ERROR(AfError::from(err_val));
281            temp.into()
282        }
283    }
284
285    /// Constructs a new Array object from device pointer
286    ///
287    /// The example show cases the usage using CUDA API, but usage of this function will
288    /// be similar in CPU and OpenCL backends also. In the case of OpenCL backend, the pointer
289    /// would be cl_mem. A short example of how to create an Array from device pointer is
290    /// shown below but for detailed set of examples, please check out the tutorial book
291    /// pages:
292    ///  - [Interoperability with CUDA][1]
293    ///  - [Interoperability with OpenCL][2]
294    ///
295    ///  [1]: http://arrayfire.org/arrayfire-rust/book/cuda-interop.html
296    ///  [2]: http://arrayfire.org/arrayfire-rust/book/opencl-interop.html
297    ///
298    /// # Examples
299    ///
300    /// An example of creating an Array device pointer using
301    /// [rustacuda](https://github.com/bheisler/RustaCUDA) crate. The
302    /// example has to be copied to a `bin` crate with following contents in Cargo.toml
303    /// to run successfully. Note that, all required setup for rustacuda and arrayfire crate
304    /// have to completed first.
305    /// ```text
306    /// [package]
307    /// ....
308    /// [dependencies]
309    /// rustacuda = "0.1"
310    /// rustacuda_derive = "0.1"
311    /// rustacuda_core = "0.1"
312    /// arrayfire = "3.7.*"
313    /// ```
314    ///
315    /// ```rust,ignore
316    ///use arrayfire::*;
317    ///use rustacuda::*;
318    ///use rustacuda::prelude::*;
319    ///
320    ///fn main() {
321    ///    let v: Vec<_> = (0u8 .. 100).map(f32::from).collect();
322    ///
323    ///    rustacuda::init(CudaFlags::empty());
324    ///    let device = Device::get_device(0).unwrap();
325    ///    let context = Context::create_and_push(ContextFlags::MAP_HOST | ContextFlags::SCHED_AUTO,
326    ///                                           device).unwrap();
327    ///    // Approach 1
328    ///    {
329    ///        let mut buffer = memory::DeviceBuffer::from_slice(&v).unwrap();
330    ///
331    ///        let array_dptr = Array::new_from_device_ptr(
332    ///            buffer.as_device_ptr().as_raw_mut(), dim4!(10, 10));
333    ///
334    ///        af_print!("array_dptr", &array_dptr);
335    ///
336    ///        array_dptr.lock(); // Needed to avoid free as arrayfire takes ownership
337    ///    }
338    ///
339    ///    // Approach 2
340    ///    {
341    ///        let mut dptr: *mut f32 = std::ptr::null_mut();
342    ///        unsafe {
343    ///            dptr = memory::cuda_malloc::<f32>(10*10).unwrap().as_raw_mut();
344    ///        }
345    ///        let array_dptr = Array::new_from_device_ptr(dptr, dim4!(10, 10));
346    ///        // note that values might be garbage in the memory pointed out by dptr
347    ///        // in this example as it is allocated but not initialized prior to passing
348    ///        // along to arrayfire::Array::new*
349    ///
350    ///        // After ArrayFire takes over ownership of the pointer, you can use other
351    ///        // arrayfire functions as usual.
352    ///        af_print!("array_dptr", &array_dptr);
353    ///    }
354    ///}
355    /// ```
356    pub fn new_from_device_ptr(dev_ptr: *mut T, dims: Dim4) -> Self {
357        let aftype = T::get_af_dtype();
358        unsafe {
359            let mut temp: af_array = std::ptr::null_mut();
360            let err_val = af_device_array(
361                &mut temp as *mut af_array,
362                dev_ptr as *mut c_void,
363                dims.ndims() as c_uint,
364                dims.get().as_ptr() as *const dim_t,
365                aftype as c_uint,
366            );
367            HANDLE_ERROR(AfError::from(err_val));
368            temp.into()
369        }
370    }
371
372    /// Returns the backend of the Array
373    ///
374    /// # Return Values
375    ///
376    /// Returns an value of type `Backend` which indicates which backend
377    /// was active when Array was created.
378    pub fn get_backend(&self) -> Backend {
379        unsafe {
380            let mut ret_val: u32 = 0;
381            let err_val = af_get_backend_id(&mut ret_val as *mut c_uint, self.handle);
382            HANDLE_ERROR(AfError::from(err_val));
383            match (err_val, ret_val) {
384                (0, 1) => Backend::CPU,
385                (0, 2) => Backend::CUDA,
386                (0, 3) => Backend::OPENCL,
387                _ => Backend::DEFAULT,
388            }
389        }
390    }
391
392    /// Returns the device identifier(integer) on which the Array was created
393    ///
394    /// # Return Values
395    ///
396    /// Return the device id on which Array was created.
397    pub fn get_device_id(&self) -> i32 {
398        unsafe {
399            let mut ret_val: i32 = 0;
400            let err_val = af_get_device_id(&mut ret_val as *mut c_int, self.handle);
401            HANDLE_ERROR(AfError::from(err_val));
402            ret_val
403        }
404    }
405
406    /// Returns the number of elements in the Array
407    pub fn elements(&self) -> usize {
408        unsafe {
409            let mut ret_val: dim_t = 0;
410            let err_val = af_get_elements(&mut ret_val as *mut dim_t, self.handle);
411            HANDLE_ERROR(AfError::from(err_val));
412            ret_val as usize
413        }
414    }
415
416    /// Returns the Array data type
417    pub fn get_type(&self) -> DType {
418        unsafe {
419            let mut ret_val: u32 = 0;
420            let err_val = af_get_type(&mut ret_val as *mut c_uint, self.handle);
421            HANDLE_ERROR(AfError::from(err_val));
422            DType::from(ret_val)
423        }
424    }
425
426    /// Returns the dimensions of the Array
427    pub fn dims(&self) -> Dim4 {
428        unsafe {
429            let mut ret0: i64 = 0;
430            let mut ret1: i64 = 0;
431            let mut ret2: i64 = 0;
432            let mut ret3: i64 = 0;
433            let err_val = af_get_dims(
434                &mut ret0 as *mut dim_t,
435                &mut ret1 as *mut dim_t,
436                &mut ret2 as *mut dim_t,
437                &mut ret3 as *mut dim_t,
438                self.handle,
439            );
440            HANDLE_ERROR(AfError::from(err_val));
441            Dim4::new(&[ret0 as u64, ret1 as u64, ret2 as u64, ret3 as u64])
442        }
443    }
444
445    /// Returns the strides of the Array
446    pub fn strides(&self) -> Dim4 {
447        unsafe {
448            let mut ret0: i64 = 0;
449            let mut ret1: i64 = 0;
450            let mut ret2: i64 = 0;
451            let mut ret3: i64 = 0;
452            let err_val = af_get_strides(
453                &mut ret0 as *mut dim_t,
454                &mut ret1 as *mut dim_t,
455                &mut ret2 as *mut dim_t,
456                &mut ret3 as *mut dim_t,
457                self.handle,
458            );
459            HANDLE_ERROR(AfError::from(err_val));
460            Dim4::new(&[ret0 as u64, ret1 as u64, ret2 as u64, ret3 as u64])
461        }
462    }
463
464    /// Returns the number of dimensions of the Array
465    pub fn numdims(&self) -> u32 {
466        unsafe {
467            let mut ret_val: u32 = 0;
468            let err_val = af_get_numdims(&mut ret_val as *mut c_uint, self.handle);
469            HANDLE_ERROR(AfError::from(err_val));
470            ret_val
471        }
472    }
473
474    /// Returns the offset to the pointer from where data begins
475    pub fn offset(&self) -> i64 {
476        unsafe {
477            let mut ret_val: i64 = 0;
478            let err_val = af_get_offset(&mut ret_val as *mut dim_t, self.handle);
479            HANDLE_ERROR(AfError::from(err_val));
480            ret_val
481        }
482    }
483
484    /// Returns the native FFI handle for Rust object `Array`
485    pub unsafe fn get(&self) -> af_array {
486        self.handle
487    }
488
489    /// Set the native FFI handle for Rust object `Array`
490    pub fn set(&mut self, handle: af_array) {
491        self.handle = handle;
492    }
493
494    /// Copies the data from the Array to the mutable slice `data`
495    ///
496    /// # Examples
497    ///
498    /// Basic case
499    /// ```
500    /// # use arrayfire::{Array,Dim4,HasAfEnum};
501    /// let a:Vec<u8> = vec![0,1,2,3,4,5,6,7,8];
502    /// let b = Array::<u8>::new(&a,Dim4::new(&[3,3,1,1]));
503    /// let mut c = vec!(u8::default();b.elements());
504    /// b.host(&mut c);
505    /// assert_eq!(c,a);
506    /// ```
507    /// Generic case
508    /// ```
509    /// # use arrayfire::{Array,Dim4,HasAfEnum};
510    /// fn to_vec<T:HasAfEnum+Default+Clone>(array:&Array<T>) -> Vec<T> {
511    ///     let mut vec = vec!(T::default();array.elements());
512    ///     array.host(&mut vec);
513    ///     return vec;
514    /// }
515    ///
516    /// let a = Array::<u8>::new(&[0,1,2,3,4,5,6,7,8],Dim4::new(&[3,3,1,1]));
517    /// let b:Vec<u8> = vec![0,1,2,3,4,5,6,7,8];
518    /// assert_eq!(to_vec(&a),b);
519    /// ```
520    pub fn host<O: HasAfEnum>(&self, data: &mut [O]) {
521        if data.len() != self.elements() {
522            HANDLE_ERROR(AfError::ERR_SIZE);
523        }
524        unsafe {
525            let err_val = af_get_data_ptr(data.as_mut_ptr() as *mut c_void, self.handle);
526            HANDLE_ERROR(AfError::from(err_val));
527        }
528    }
529
530    /// Evaluates any pending lazy expressions that represent the data in the Array object
531    pub fn eval(&self) {
532        unsafe {
533            let err_val = af_eval(self.handle);
534            HANDLE_ERROR(AfError::from(err_val));
535        }
536    }
537
538    /// Makes an copy of the Array
539    ///
540    /// This does a deep copy of the data into a new Array
541    pub fn copy(&self) -> Self {
542        unsafe {
543            let mut temp: af_array = std::ptr::null_mut();
544            let err_val = af_copy_array(&mut temp as *mut af_array, self.handle);
545            HANDLE_ERROR(AfError::from(err_val));
546            temp.into()
547        }
548    }
549
550    is_func!("Check if Array is empty", is_empty, af_is_empty);
551    is_func!("Check if Array is scalar", is_scalar, af_is_scalar);
552    is_func!("Check if Array is a row", is_row, af_is_row);
553    is_func!("Check if Array is a column", is_column, af_is_column);
554    is_func!("Check if Array is a vector", is_vector, af_is_vector);
555
556    is_func!(
557        "Check if Array is of real (not complex) type",
558        is_real,
559        af_is_real
560    );
561    is_func!(
562        "Check if Array is of complex type",
563        is_complex,
564        af_is_complex
565    );
566
567    is_func!(
568        "Check if Array's numerical type is of double precision",
569        is_double,
570        af_is_double
571    );
572    is_func!(
573        "Check if Array's numerical type is of single precision",
574        is_single,
575        af_is_single
576    );
577    is_func!(
578        "Check if Array's numerical type is of half precision",
579        is_half,
580        af_is_half
581    );
582    is_func!(
583        "Check if Array is of integral type",
584        is_integer,
585        af_is_integer
586    );
587    is_func!("Check if Array is of boolean type", is_bool, af_is_bool);
588
589    is_func!(
590        "Check if Array is floating point real(not complex) data type",
591        is_realfloating,
592        af_is_realfloating
593    );
594    is_func!(
595        "Check if Array is floating point type, either real or complex data",
596        is_floating,
597        af_is_floating
598    );
599
600    is_func!(
601        "Check if Array's memory layout is continuous and one dimensional",
602        is_linear,
603        af_is_linear
604    );
605    is_func!("Check if Array is a sparse matrix", is_sparse, af_is_sparse);
606    is_func!(
607        "Check if Array's memory is owned by it and not a view of another Array",
608        is_owner,
609        af_is_owner
610    );
611
612    /// Cast the Array data type to `target_type`
613    pub fn cast<O: HasAfEnum>(&self) -> Array<O> {
614        let trgt_type = O::get_af_dtype();
615        unsafe {
616            let mut temp: af_array = std::ptr::null_mut();
617            let err_val = af_cast(&mut temp as *mut af_array, self.handle, trgt_type as c_uint);
618            HANDLE_ERROR(AfError::from(err_val));
619            temp.into()
620        }
621    }
622
623    /// Lock the device buffer in the memory manager
624    ///
625    /// Locked buffers are not freed by memory manager until unlock is called.
626    pub fn lock(&self) {
627        unsafe {
628            let err_val = af_lock_array(self.handle);
629            HANDLE_ERROR(AfError::from(err_val));
630        }
631    }
632
633    /// Unlock the device buffer in the memory manager
634    ///
635    /// This function will give back the control over the device pointer to the
636    /// memory manager.
637    pub fn unlock(&self) {
638        unsafe {
639            let err_val = af_unlock_array(self.handle);
640            HANDLE_ERROR(AfError::from(err_val));
641        }
642    }
643
644    /// Get the device pointer and lock the buffer in memory manager
645    ///
646    /// The device pointer is not freed by memory manager until unlock is called.
647    pub unsafe fn device_ptr(&self) -> void_ptr {
648        let mut temp: void_ptr = std::ptr::null_mut();
649        let err_val = af_get_device_ptr(&mut temp as *mut void_ptr, self.handle);
650        HANDLE_ERROR(AfError::from(err_val));
651        temp
652    }
653
654    /// Get the size of physical allocated bytes.
655    ///
656    /// This function will return the size of the parent/owner if the current Array object is an
657    /// indexed Array.
658    pub fn get_allocated_bytes(&self) -> usize {
659        unsafe {
660            let mut temp: usize = 0;
661            let err_val = af_get_allocated_bytes(&mut temp as *mut usize, self.handle);
662            HANDLE_ERROR(AfError::from(err_val));
663            temp
664        }
665    }
666}
667
668/// Used for creating Array object from native
669/// resource id, an 64 bit integer
670impl<T: HasAfEnum> Into<Array<T>> for af_array {
671    fn into(self) -> Array<T> {
672        Array {
673            handle: self,
674            _marker: PhantomData,
675        }
676    }
677}
678
679/// Returns a new Array object after incrementing the reference count of native resource
680///
681/// Cloning an Array does not do a deep copy of the underlying array data. It increments the
682/// reference count of native resource and returns you the new reference in the form a new Array
683/// object.
684///
685/// To create a deep copy use
686/// [copy()](./struct.Array.html#method.copy)
687impl<T> Clone for Array<T>
688where
689    T: HasAfEnum,
690{
691    fn clone(&self) -> Self {
692        unsafe {
693            let mut temp: af_array = std::ptr::null_mut();
694            let ret_val = af_retain_array(&mut temp as *mut af_array, self.handle);
695            match ret_val {
696                0 => temp.into(),
697                _ => panic!("Weak copy of Array failed with error code: {}", ret_val),
698            }
699        }
700    }
701}
702
703/// To free resources when Array goes out of scope
704impl<T> Drop for Array<T>
705where
706    T: HasAfEnum,
707{
708    fn drop(&mut self) {
709        unsafe {
710            let ret_val = af_release_array(self.handle);
711            match ret_val {
712                0 => (),
713                _ => panic!("Array<T> drop failed with error code: {}", ret_val),
714            }
715        }
716    }
717}
718
719/// Print data in the Array
720///
721/// # Parameters
722///
723/// - `input` is the Array to be printed
724///
725/// # Examples
726///
727/// ```rust
728/// use arrayfire::{Dim4, print, randu};
729/// println!("Create a 5-by-3 matrix of random floats on the GPU");
730/// let dims = Dim4::new(&[5, 3, 1, 1]);
731/// let a = randu::<f32>(dims);
732/// print(&a);
733/// ```
734///
735/// The sample output will look like below:
736///
737/// ```text
738/// [5 3 1 1]
739///     0.7402     0.4464     0.7762
740///     0.9210     0.6673     0.2948
741///     0.0390     0.1099     0.7140
742///     0.9690     0.4702     0.3585
743///     0.9251     0.5132     0.6814
744/// ```
745pub fn print<T: HasAfEnum>(input: &Array<T>) {
746    let emptystring = CString::new("").unwrap();
747    unsafe {
748        let err_val = af_print_array_gen(
749            emptystring.to_bytes_with_nul().as_ptr() as *const c_char,
750            input.get(),
751            4,
752        );
753        HANDLE_ERROR(AfError::from(err_val));
754    }
755}
756
757/// Generalized Array print function
758///
759/// Use this function to print Array data with arbitrary preicsion
760///
761/// # Parameters
762///
763/// - `msg` is message to be printed before printing the Array data
764/// - `input` is the Array to be printed
765/// - `precision` is data precision with which Array has to be printed
766///
767/// # Examples
768///
769/// ```rust
770/// use arrayfire::{Dim4, print_gen, randu};
771/// println!("Create a 5-by-3 matrix of random floats on the GPU");
772/// let dims = Dim4::new(&[5, 3, 1, 1]);
773/// let a = randu::<f32>(dims);
774/// print_gen(String::from("Random Array"), &a, Some(6));
775/// ```
776///
777/// The sample output will look like below:
778///
779/// ```text
780/// Random Array
781///
782/// [5 3 1 1]
783///     0.740276     0.446440     0.776202
784///     0.921094     0.667321     0.294810
785///     0.039014     0.109939     0.714090
786///     0.969058     0.470269     0.358590
787///     0.925181     0.513225     0.681451
788/// ```
789pub fn print_gen<T: HasAfEnum>(msg: String, input: &Array<T>, precision: Option<i32>) {
790    let emptystring = CString::new(msg.as_bytes()).unwrap();
791    unsafe {
792        let err_val = af_print_array_gen(
793            emptystring.to_bytes_with_nul().as_ptr() as *const c_char,
794            input.get(),
795            match precision {
796                Some(p) => p,
797                None => 4,
798            } as c_int,
799        );
800        HANDLE_ERROR(AfError::from(err_val));
801    }
802}
803
804/// evaluate multiple arrays
805///
806/// Use this function to evaluate multiple arrays in single call
807///
808/// # Parameters
809///
810/// - `inputs` are the list of arrays to be evaluated
811pub fn eval_multiple<T: HasAfEnum>(inputs: Vec<&Array<T>>) {
812    unsafe {
813        let mut v = Vec::new();
814        for i in inputs {
815            v.push(i.get());
816        }
817
818        let err_val = af_eval_multiple(v.len() as c_int, v.as_ptr() as *const af_array);
819        HANDLE_ERROR(AfError::from(err_val));
820    }
821}
822
823/// Set eval flag value
824///
825/// This function can be used to toggle on/off the manual evaluation of arrays.
826///
827/// # Parameters
828///
829/// - `flag` is a boolean value indicating manual evaluation setting
830pub fn set_manual_eval(flag: bool) {
831    unsafe {
832        let err_val = af_set_manual_eval_flag(flag as c_int);
833        HANDLE_ERROR(AfError::from(err_val));
834    }
835}
836
837/// Get eval flag value
838///
839/// This function can be used to find out if manual evaluation of arrays is
840/// turned on or off.
841///
842/// # Return Values
843///
844/// A boolean indicating manual evaluation setting.
845pub fn is_eval_manual() -> bool {
846    unsafe {
847        let mut ret_val: i32 = 0;
848        let err_val = af_get_manual_eval_flag(&mut ret_val as *mut c_int);
849        HANDLE_ERROR(AfError::from(err_val));
850        ret_val > 0
851    }
852}
853
#[cfg(feature = "afserde")]
mod afserde {
    // Reimport required from super scope
    use super::{Array, DType, Dim4, HasAfEnum};

    use serde::de::{Deserializer, Error, Unexpected};
    use serde::ser::Serializer;
    use serde::{Deserialize, Serialize};

    /// Host-side form of an `Array` used as its serialized representation:
    /// element type, shape and the element values copied from the device.
    #[derive(Debug, Serialize, Deserialize)]
    struct ArrayOnHost<T: HasAfEnum + std::fmt::Debug> {
        dtype: DType,
        shape: Dim4,
        data: Vec<T>,
    }

    /// Serialize Implementation of Array
    impl<T> Serialize for Array<T>
    where
        T: std::default::Default + std::clone::Clone + Serialize + HasAfEnum + std::fmt::Debug,
    {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            let mut data = vec![T::default(); self.elements()];
            self.host(&mut data);
            ArrayOnHost {
                dtype: self.get_type(),
                shape: self.dims(),
                data,
            }
            .serialize(serializer)
        }
    }

    /// Deserialize Implementation of Array
    ///
    /// Rejects input whose stored data type differs from `T`, and input whose
    /// element count does not match its shape (which would otherwise make
    /// `Array::new` read past the end of the deserialized data).
    impl<'de, T> Deserialize<'de> for Array<T>
    where
        T: Deserialize<'de> + HasAfEnum + std::fmt::Debug,
    {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            let ArrayOnHost { dtype, shape, data } = ArrayOnHost::<T>::deserialize(deserializer)?;

            let expected_dtype = T::get_af_dtype();
            if expected_dtype != dtype {
                let error_msg = format!(
                    "data type is {:?}, deserialized type is {:?}",
                    expected_dtype, dtype
                );
                return Err(Error::invalid_value(Unexpected::Enum, &error_msg.as_str()));
            }

            // The serialized input is untrusted: the data length must agree with
            // the shape before it is handed to ArrayFire.
            let expected_len = shape.elements();
            if data.len() as u64 != expected_len {
                let error_msg = format!("{} elements for shape {:?}", expected_len, shape);
                return Err(Error::invalid_length(data.len(), &error_msg.as_str()));
            }

            Ok(Array::<T>::new(&data, shape))
        }
    }
}
920
#[cfg(test)]
mod tests {
    use super::super::array::print;
    use super::super::data::constant;
    use super::super::device::{info, set_device, sync};
    use super::Array;
    use crate::dim4;
    use std::sync::{mpsc, Arc, RwLock};
    use std::thread;

    /// Copies `arr` back to the host and asserts that every element equals `expected`.
    ///
    /// Called outside the ANCHOR regions so the extracted examples stay focused on the
    /// threading pattern while the tests still verify the values those examples describe.
    fn assert_all_eq(arr: &Array<f32>, expected: f32) {
        let mut host_data = vec![0.0f32; arr.elements()];
        arr.host(&mut host_data);
        assert!(
            host_data.iter().all(|&v| v == expected),
            "expected every element to be {}, got {:?}",
            expected,
            host_data
        );
    }

    #[test]
    fn thread_move_array() {
        // ANCHOR: move_array_to_thread
        set_device(0);
        info();
        let mut a = constant(1, dim4!(3, 3));

        let handle = thread::spawn(move || {
            //set_device to appropriate device id is required in each thread
            set_device(0);

            println!("\nFrom thread {:?}", thread::current().id());

            a += constant(2, dim4!(3, 3));
            print(&a);
        });

        //Need to join other threads as main thread holds arrayfire context
        handle.join().unwrap();
        // ANCHOR_END: move_array_to_thread
    }

    #[test]
    fn thread_borrow_array() {
        set_device(0);
        info();
        let a = constant(1i32, dim4!(3, 3));

        // The closure is `move`, so `a` is handed over to the spawned thread, which
        // only reads it.
        let handle = thread::spawn(move || {
            set_device(0); //set_device to appropriate device id is required in each thread
            println!("\nFrom thread {:?}", thread::current().id());
            print(&a);
        });
        //Need to join other threads as main thread holds arrayfire context
        handle.join().unwrap();
    }

    // ANCHOR: multiple_threads_enum_def
    #[derive(Debug, Copy, Clone)]
    enum Op {
        Add,
        Sub,
        Div,
        Mul,
    }
    // ANCHOR_END: multiple_threads_enum_def

    #[test]
    fn read_from_multiple_threads() {
        // ANCHOR: read_from_multiple_threads
        let ops: Vec<_> = vec![Op::Add, Op::Sub, Op::Div, Op::Mul, Op::Add, Op::Div];

        // Set active GPU/device on main thread on which
        // subsequent Array objects are created
        set_device(0);

        // ArrayFire Arrays are internally maintained via atomic reference counting.
        // Thus, they need no Arc wrapping while moving to another thread.
        // Just call the clone method on the object and share the resulting clone.
        let a = constant(1.0f32, dim4!(3, 3));
        let b = constant(2.0f32, dim4!(3, 3));

        let threads: Vec<_> = ops
            .into_iter()
            .map(|op| {
                let x = a.clone();
                let y = b.clone();
                thread::spawn(move || {
                    set_device(0); //Both objects were created on device 0 earlier
                    match op {
                        Op::Add => {
                            let _c = x + y;
                        }
                        Op::Sub => {
                            let _c = x - y;
                        }
                        Op::Div => {
                            let _c = x / y;
                        }
                        Op::Mul => {
                            let _c = x * y;
                        }
                    }
                    sync(0);
                    thread::sleep(std::time::Duration::new(1, 0));
                })
            })
            .collect();
        // Propagate any worker panic so the test fails instead of silently passing.
        for child in threads {
            child.join().unwrap();
        }
        // ANCHOR_END: read_from_multiple_threads
    }

    #[test]
    fn access_using_rwlock() {
        // ANCHOR: access_using_rwlock
        let ops: Vec<_> = vec![Op::Add, Op::Sub, Op::Div, Op::Mul, Op::Add, Op::Div];

        // Set active GPU/device on main thread on which
        // subsequent Array objects are created
        set_device(0);

        let c = constant(0.0f32, dim4!(3, 3));
        let a = constant(1.0f32, dim4!(3, 3));
        let b = constant(2.0f32, dim4!(3, 3));

        // Move ownership to RwLock and wrap in Arc since same object is to be modified
        let c_lock = Arc::new(RwLock::new(c));

        // a and b are internally reference counted by ArrayFire. Unless there
        // is a known need to modify them, you can simply clone
        // the objects and pass them to threads

        let threads: Vec<_> = ops
            .into_iter()
            .map(|op| {
                let x = a.clone();
                let y = b.clone();

                let wlock = c_lock.clone();
                thread::spawn(move || {
                    //Both objects are created on device 0 in main thread
                    //Every thread needs to set the device that it is going to
                    //work on. Note that all Array objects must have been created
                    //on same device as of date this is written on.
                    set_device(0);
                    if let Ok(mut c_guard) = wlock.write() {
                        match op {
                            Op::Add => {
                                *c_guard += x + y;
                            }
                            Op::Sub => {
                                *c_guard += x - y;
                            }
                            Op::Div => {
                                *c_guard += x / y;
                            }
                            Op::Mul => {
                                *c_guard += x * y;
                            }
                        }
                    }
                })
            })
            .collect();

        // Propagate any worker panic so the test fails instead of silently passing.
        for child in threads {
            child.join().unwrap();
        }

        //let read_guard = c_lock.read().unwrap();
        //af_print!("C after threads joined", *read_guard);
        //C after threads joined
        //[3 3 1 1]
        //    8.0000     8.0000     8.0000
        //    8.0000     8.0000     8.0000
        //    8.0000     8.0000     8.0000
        // ANCHOR_END: access_using_rwlock

        // 0 + 3 - 1 + 0.5 + 2 + 3 + 0.5 = 8, exact in f32 regardless of thread order.
        let read_guard = c_lock.read().unwrap();
        assert_all_eq(&read_guard, 8.0);
    }

    #[test]
    fn accum_using_channel() {
        // ANCHOR: accum_using_channel
        let ops: Vec<_> = vec![Op::Add, Op::Sub, Op::Div, Op::Mul, Op::Add, Op::Div];
        let ops_len: usize = ops.len();

        // Set active GPU/device on main thread on which
        // subsequent Array objects are created
        set_device(0);

        let mut c = constant(0.0f32, dim4!(3, 3));
        let a = constant(1.0f32, dim4!(3, 3));
        let b = constant(2.0f32, dim4!(3, 3));

        let (tx, rx) = mpsc::channel();

        let threads: Vec<_> = ops
            .into_iter()
            .map(|op| {
                // a and b are internally reference counted by ArrayFire. Unless there
                // is a known need to modify them, you can simply clone
                // the objects and pass them to threads
                let x = a.clone();
                let y = b.clone();

                let tx_clone = tx.clone();

                thread::spawn(move || {
                    //Both objects are created on device 0 in main thread
                    //Every thread needs to set the device that it is going to
                    //work on. Note that all Array objects must have been created
                    //on same device as of date this is written on.
                    set_device(0);

                    let c = match op {
                        Op::Add => x + y,
                        Op::Sub => x - y,
                        Op::Div => x / y,
                        Op::Mul => x * y,
                    };
                    tx_clone.send(c).unwrap();
                })
            })
            .collect();

        // Drop the main thread's sender so that, if a worker dies before sending,
        // `recv` reports disconnection instead of blocking forever.
        drop(tx);

        for _i in 0..ops_len {
            c += rx.recv().unwrap();
        }

        //Need to join other threads as main thread holds arrayfire context
        for child in threads {
            child.join().unwrap();
        }

        //af_print!("C after accumulating results", &c);
        //[3 3 1 1]
        //    8.0000     8.0000     8.0000
        //    8.0000     8.0000     8.0000
        //    8.0000     8.0000     8.0000
        // ANCHOR_END: accum_using_channel

        // 0 + 3 - 1 + 0.5 + 2 + 3 + 0.5 = 8, exact in f32 regardless of arrival order.
        assert_all_eq(&c, 8.0);
    }

    #[cfg(feature = "afserde")]
    mod serde_tests {
        use super::super::Array;
        use crate::algorithm::sum_all;
        use crate::randu;

        #[test]
        fn array_serde_json() {
            let input = randu!(u8; 2, 2);
            let serd = serde_json::to_string(&input).expect("JSON serialization of Array failed");

            let deserd: Array<u8> = serde_json::from_str(&serd).unwrap();

            assert_eq!(sum_all(&(input - deserd)), (0u32, 0u32));
        }

        #[test]
        fn array_serde_rejects_mismatched_dtype() {
            let input = randu!(u8; 2, 2);
            let serd = serde_json::to_string(&input).expect("JSON serialization of Array failed");

            // The payload records the u8 dtype, so reading it back as f32 must fail
            // instead of producing an f32 array.
            let result: Result<Array<f32>, _> = serde_json::from_str(&serd);
            assert!(result.is_err());
        }

        #[test]
        fn array_serde_bincode() {
            let input = randu!(u8; 2, 2);
            let encoded =
                bincode::serialize(&input).expect("bincode serialization of Array failed");

            let decoded: Array<u8> = bincode::deserialize(&encoded).unwrap();

            assert_eq!(sum_all(&(input - decoded)), (0u32, 0u32));
        }
    }
}