arrayfire/core/array.rs
1use super::defines::{AfError, Backend, DType};
2use super::dim4::Dim4;
3use super::error::HANDLE_ERROR;
4use super::util::{af_array, dim_t, void_ptr, HasAfEnum};
5
6use libc::{c_char, c_int, c_longlong, c_uint, c_void};
7use std::ffi::CString;
8use std::marker::PhantomData;
9
// Functions from array.h in the C-API of ArrayFire that are not wrapped yet
// (af_copy_array is declared below and wrapped by `Array::copy`):
// af_write_array
// af_get_data_ref_count
13// af_get_data_ref_count
14
// Raw bindings to ArrayFire's C API (array.h and related headers). Every function
// returns an `af_err` code as `c_int`; callers convert it with `AfError::from` and
// pass it to `HANDLE_ERROR`.
extern "C" {
    // --- Array construction ---
    fn af_create_array(
        out: *mut af_array,
        data: *const c_void,
        ndims: c_uint,
        dims: *const dim_t,
        aftype: c_uint,
    ) -> c_int;

    fn af_create_handle(
        out: *mut af_array,
        ndims: c_uint,
        dims: *const dim_t,
        aftype: c_uint,
    ) -> c_int;

    fn af_device_array(
        out: *mut af_array,
        data: *mut c_void,
        ndims: c_uint,
        dims: *const dim_t,
        aftype: c_uint,
    ) -> c_int;

    // --- Metadata queries ---
    fn af_get_elements(out: *mut dim_t, arr: af_array) -> c_int;

    fn af_get_type(out: *mut c_uint, arr: af_array) -> c_int;

    fn af_get_dims(
        dim0: *mut c_longlong,
        dim1: *mut c_longlong,
        dim2: *mut c_longlong,
        dim3: *mut c_longlong,
        arr: af_array,
    ) -> c_int;

    fn af_get_numdims(result: *mut c_uint, arr: af_array) -> c_int;

    // --- Boolean predicates (wrapped by the `is_func!` macro) ---
    fn af_is_empty(result: *mut bool, arr: af_array) -> c_int;

    fn af_is_scalar(result: *mut bool, arr: af_array) -> c_int;

    fn af_is_row(result: *mut bool, arr: af_array) -> c_int;

    fn af_is_column(result: *mut bool, arr: af_array) -> c_int;

    fn af_is_vector(result: *mut bool, arr: af_array) -> c_int;

    fn af_is_complex(result: *mut bool, arr: af_array) -> c_int;

    fn af_is_real(result: *mut bool, arr: af_array) -> c_int;

    fn af_is_double(result: *mut bool, arr: af_array) -> c_int;

    fn af_is_single(result: *mut bool, arr: af_array) -> c_int;

    fn af_is_half(result: *mut bool, arr: af_array) -> c_int;

    fn af_is_integer(result: *mut bool, arr: af_array) -> c_int;

    fn af_is_bool(result: *mut bool, arr: af_array) -> c_int;

    fn af_is_realfloating(result: *mut bool, arr: af_array) -> c_int;

    fn af_is_floating(result: *mut bool, arr: af_array) -> c_int;

    fn af_is_linear(result: *mut bool, arr: af_array) -> c_int;

    fn af_is_owner(result: *mut bool, arr: af_array) -> c_int;

    fn af_is_sparse(result: *mut bool, arr: af_array) -> c_int;

    // --- Data transfer and evaluation ---
    // `data` must point to a host buffer large enough for the whole array.
    fn af_get_data_ptr(data: *mut c_void, arr: af_array) -> c_int;

    fn af_eval(arr: af_array) -> c_int;

    fn af_eval_multiple(num: c_int, arrays: *const af_array) -> c_int;

    fn af_set_manual_eval_flag(flag: c_int) -> c_int;

    fn af_get_manual_eval_flag(flag: *mut c_int) -> c_int;

    // --- Reference counting / copying ---
    fn af_retain_array(out: *mut af_array, arr: af_array) -> c_int;

    fn af_copy_array(out: *mut af_array, arr: af_array) -> c_int;

    fn af_release_array(arr: af_array) -> c_int;

    // Not used: `print`/`print_gen` go through af_print_array_gen instead.
    //fn af_print_array(arr: af_array) -> c_int;

    fn af_print_array_gen(exp: *const c_char, arr: af_array, precision: c_int) -> c_int;

    fn af_cast(out: *mut af_array, arr: af_array, aftype: c_uint) -> c_int;

    // --- Backend / device information ---
    fn af_get_backend_id(backend: *mut c_uint, input: af_array) -> c_int;

    fn af_get_device_id(device: *mut c_int, input: af_array) -> c_int;

    // --- Strided arrays ---
    fn af_create_strided_array(
        arr: *mut af_array,
        data: *const c_void,
        offset: dim_t,
        ndims: c_uint,
        dims: *const dim_t,
        strides: *const dim_t,
        aftype: c_uint,
        stype: c_uint,
    ) -> c_int;

    fn af_get_strides(
        s0: *mut dim_t,
        s1: *mut dim_t,
        s2: *mut dim_t,
        s3: *mut dim_t,
        arr: af_array,
    ) -> c_int;

    fn af_get_offset(offset: *mut dim_t, arr: af_array) -> c_int;

    // --- Memory manager interaction ---
    fn af_lock_array(arr: af_array) -> c_int;

    fn af_unlock_array(arr: af_array) -> c_int;

    fn af_get_device_ptr(ptr: *mut void_ptr, arr: af_array) -> c_int;

    fn af_get_allocated_bytes(result: *mut usize, arr: af_array) -> c_int;
}
142
143/// A multidimensional data container
144///
145/// Currently, Array objects can store only data until four dimensions
146///
147/// ## Sharing Across Threads
148///
149/// While sharing an Array with other threads, there is no need to wrap
150/// this in an Arc object unless only one such object is required to exist.
151/// The reason being that ArrayFire's internal Array is appropriately reference
152/// counted in thread safe manner. However, if you need to modify Array object,
153/// then please do wrap the object using a Mutex or Read-Write lock.
154///
155/// Examples on how to share Array across threads is illustrated in our
156/// [book](http://arrayfire.org/arrayfire-rust/book/multi-threading.html)
157///
158/// ### NOTE
159///
160/// All operators(traits) from std::ops module implemented for Array object
161/// carry out element wise operations. For example, `*` does multiplication of
162/// elements at corresponding locations in two different Arrays.
163pub struct Array<T: HasAfEnum> {
164 handle: af_array,
165 /// The phantom marker denotes the
166 /// allocation of data on compute device
167 _marker: PhantomData<T>,
168}
169
170/// Enable safely moving Array objects across threads
171unsafe impl<T: HasAfEnum> Send for Array<T> {}
172
173unsafe impl<T: HasAfEnum> Sync for Array<T> {}
174
175macro_rules! is_func {
176 ($doc_str: expr, $fn_name: ident, $ffi_fn: ident) => (
177 #[doc=$doc_str]
178 pub fn $fn_name(&self) -> bool {
179 unsafe {
180 let mut ret_val: bool = false;
181 let err_val = $ffi_fn(&mut ret_val as *mut bool, self.handle);
182 HANDLE_ERROR(AfError::from(err_val));
183 ret_val
184 }
185 }
186 )
187}
188
189impl<T> Array<T>
190where
191 T: HasAfEnum,
192{
193 /// Constructs a new Array object
194 ///
195 /// # Examples
196 ///
197 /// An example of creating an Array from f32 array
198 ///
199 /// ```rust
200 /// use arrayfire::{Array, Dim4, print};
201 /// let values: [f32; 3] = [1.0, 2.0, 3.0];
202 /// let indices = Array::new(&values, Dim4::new(&[3, 1, 1, 1]));
203 /// print(&indices);
204 /// ```
205 /// An example of creating an Array from half::f16 array
206 ///
207 /// ```rust
208 /// use arrayfire::{Array, Dim4, is_half_available, print};
209 /// use half::f16;
210 ///
211 /// let values: [f32; 3] = [1.0, 2.0, 3.0];
212 ///
213 /// if is_half_available(0) { // Default device is 0, hence the argument
214 /// let half_values = values.iter().map(|&x| f16::from_f32(x)).collect::<Vec<_>>();
215 ///
216 /// let hvals = Array::new(&half_values, Dim4::new(&[3, 1, 1, 1]));
217 ///
218 /// print(&hvals);
219 /// } else {
220 /// println!("Half support isn't available on this device");
221 /// }
222 /// ```
223 ///
224 pub fn new(slice: &[T], dims: Dim4) -> Self {
225 let aftype = T::get_af_dtype();
226 unsafe {
227 let mut temp: af_array = std::ptr::null_mut();
228 let err_val = af_create_array(
229 &mut temp as *mut af_array,
230 slice.as_ptr() as *const c_void,
231 dims.ndims() as c_uint,
232 dims.get().as_ptr() as *const c_longlong,
233 aftype as c_uint,
234 );
235 HANDLE_ERROR(AfError::from(err_val));
236 temp.into()
237 }
238 }
239
240 /// Constructs a new Array object from strided data
241 ///
242 /// The data pointed by the slice passed to this function can possibily be offseted using an additional `offset` parameter.
243 pub fn new_strided(slice: &[T], offset: i64, dims: Dim4, strides: Dim4) -> Self {
244 let aftype = T::get_af_dtype();
245 unsafe {
246 let mut temp: af_array = std::ptr::null_mut();
247 let err_val = af_create_strided_array(
248 &mut temp as *mut af_array,
249 slice.as_ptr() as *const c_void,
250 offset as dim_t,
251 dims.ndims() as c_uint,
252 dims.get().as_ptr() as *const c_longlong,
253 strides.get().as_ptr() as *const c_longlong,
254 aftype as c_uint,
255 1 as c_uint,
256 );
257 HANDLE_ERROR(AfError::from(err_val));
258 temp.into()
259 }
260 }
261
262 /// Constructs a new Array object of specified dimensions and type
263 ///
264 /// # Examples
265 ///
266 /// ```rust
267 /// use arrayfire::{Array, Dim4};
268 /// let garbage_vals = Array::<f32>::new_empty(Dim4::new(&[3, 1, 1, 1]));
269 /// ```
270 pub fn new_empty(dims: Dim4) -> Self {
271 let aftype = T::get_af_dtype();
272 unsafe {
273 let mut temp: af_array = std::ptr::null_mut();
274 let err_val = af_create_handle(
275 &mut temp as *mut af_array,
276 dims.ndims() as c_uint,
277 dims.get().as_ptr() as *const c_longlong,
278 aftype as c_uint,
279 );
280 HANDLE_ERROR(AfError::from(err_val));
281 temp.into()
282 }
283 }
284
285 /// Constructs a new Array object from device pointer
286 ///
287 /// The example show cases the usage using CUDA API, but usage of this function will
288 /// be similar in CPU and OpenCL backends also. In the case of OpenCL backend, the pointer
289 /// would be cl_mem. A short example of how to create an Array from device pointer is
290 /// shown below but for detailed set of examples, please check out the tutorial book
291 /// pages:
292 /// - [Interoperability with CUDA][1]
293 /// - [Interoperability with OpenCL][2]
294 ///
295 /// [1]: http://arrayfire.org/arrayfire-rust/book/cuda-interop.html
296 /// [2]: http://arrayfire.org/arrayfire-rust/book/opencl-interop.html
297 ///
298 /// # Examples
299 ///
300 /// An example of creating an Array device pointer using
301 /// [rustacuda](https://github.com/bheisler/RustaCUDA) crate. The
302 /// example has to be copied to a `bin` crate with following contents in Cargo.toml
303 /// to run successfully. Note that, all required setup for rustacuda and arrayfire crate
304 /// have to completed first.
305 /// ```text
306 /// [package]
307 /// ....
308 /// [dependencies]
309 /// rustacuda = "0.1"
310 /// rustacuda_derive = "0.1"
311 /// rustacuda_core = "0.1"
312 /// arrayfire = "3.7.*"
313 /// ```
314 ///
315 /// ```rust,ignore
316 ///use arrayfire::*;
317 ///use rustacuda::*;
318 ///use rustacuda::prelude::*;
319 ///
320 ///fn main() {
321 /// let v: Vec<_> = (0u8 .. 100).map(f32::from).collect();
322 ///
323 /// rustacuda::init(CudaFlags::empty());
324 /// let device = Device::get_device(0).unwrap();
325 /// let context = Context::create_and_push(ContextFlags::MAP_HOST | ContextFlags::SCHED_AUTO,
326 /// device).unwrap();
327 /// // Approach 1
328 /// {
329 /// let mut buffer = memory::DeviceBuffer::from_slice(&v).unwrap();
330 ///
331 /// let array_dptr = Array::new_from_device_ptr(
332 /// buffer.as_device_ptr().as_raw_mut(), dim4!(10, 10));
333 ///
334 /// af_print!("array_dptr", &array_dptr);
335 ///
336 /// array_dptr.lock(); // Needed to avoid free as arrayfire takes ownership
337 /// }
338 ///
339 /// // Approach 2
340 /// {
341 /// let mut dptr: *mut f32 = std::ptr::null_mut();
342 /// unsafe {
343 /// dptr = memory::cuda_malloc::<f32>(10*10).unwrap().as_raw_mut();
344 /// }
345 /// let array_dptr = Array::new_from_device_ptr(dptr, dim4!(10, 10));
346 /// // note that values might be garbage in the memory pointed out by dptr
347 /// // in this example as it is allocated but not initialized prior to passing
348 /// // along to arrayfire::Array::new*
349 ///
350 /// // After ArrayFire takes over ownership of the pointer, you can use other
351 /// // arrayfire functions as usual.
352 /// af_print!("array_dptr", &array_dptr);
353 /// }
354 ///}
355 /// ```
356 pub fn new_from_device_ptr(dev_ptr: *mut T, dims: Dim4) -> Self {
357 let aftype = T::get_af_dtype();
358 unsafe {
359 let mut temp: af_array = std::ptr::null_mut();
360 let err_val = af_device_array(
361 &mut temp as *mut af_array,
362 dev_ptr as *mut c_void,
363 dims.ndims() as c_uint,
364 dims.get().as_ptr() as *const dim_t,
365 aftype as c_uint,
366 );
367 HANDLE_ERROR(AfError::from(err_val));
368 temp.into()
369 }
370 }
371
372 /// Returns the backend of the Array
373 ///
374 /// # Return Values
375 ///
376 /// Returns an value of type `Backend` which indicates which backend
377 /// was active when Array was created.
378 pub fn get_backend(&self) -> Backend {
379 unsafe {
380 let mut ret_val: u32 = 0;
381 let err_val = af_get_backend_id(&mut ret_val as *mut c_uint, self.handle);
382 HANDLE_ERROR(AfError::from(err_val));
383 match (err_val, ret_val) {
384 (0, 1) => Backend::CPU,
385 (0, 2) => Backend::CUDA,
386 (0, 3) => Backend::OPENCL,
387 _ => Backend::DEFAULT,
388 }
389 }
390 }
391
392 /// Returns the device identifier(integer) on which the Array was created
393 ///
394 /// # Return Values
395 ///
396 /// Return the device id on which Array was created.
397 pub fn get_device_id(&self) -> i32 {
398 unsafe {
399 let mut ret_val: i32 = 0;
400 let err_val = af_get_device_id(&mut ret_val as *mut c_int, self.handle);
401 HANDLE_ERROR(AfError::from(err_val));
402 ret_val
403 }
404 }
405
406 /// Returns the number of elements in the Array
407 pub fn elements(&self) -> usize {
408 unsafe {
409 let mut ret_val: dim_t = 0;
410 let err_val = af_get_elements(&mut ret_val as *mut dim_t, self.handle);
411 HANDLE_ERROR(AfError::from(err_val));
412 ret_val as usize
413 }
414 }
415
416 /// Returns the Array data type
417 pub fn get_type(&self) -> DType {
418 unsafe {
419 let mut ret_val: u32 = 0;
420 let err_val = af_get_type(&mut ret_val as *mut c_uint, self.handle);
421 HANDLE_ERROR(AfError::from(err_val));
422 DType::from(ret_val)
423 }
424 }
425
426 /// Returns the dimensions of the Array
427 pub fn dims(&self) -> Dim4 {
428 unsafe {
429 let mut ret0: i64 = 0;
430 let mut ret1: i64 = 0;
431 let mut ret2: i64 = 0;
432 let mut ret3: i64 = 0;
433 let err_val = af_get_dims(
434 &mut ret0 as *mut dim_t,
435 &mut ret1 as *mut dim_t,
436 &mut ret2 as *mut dim_t,
437 &mut ret3 as *mut dim_t,
438 self.handle,
439 );
440 HANDLE_ERROR(AfError::from(err_val));
441 Dim4::new(&[ret0 as u64, ret1 as u64, ret2 as u64, ret3 as u64])
442 }
443 }
444
445 /// Returns the strides of the Array
446 pub fn strides(&self) -> Dim4 {
447 unsafe {
448 let mut ret0: i64 = 0;
449 let mut ret1: i64 = 0;
450 let mut ret2: i64 = 0;
451 let mut ret3: i64 = 0;
452 let err_val = af_get_strides(
453 &mut ret0 as *mut dim_t,
454 &mut ret1 as *mut dim_t,
455 &mut ret2 as *mut dim_t,
456 &mut ret3 as *mut dim_t,
457 self.handle,
458 );
459 HANDLE_ERROR(AfError::from(err_val));
460 Dim4::new(&[ret0 as u64, ret1 as u64, ret2 as u64, ret3 as u64])
461 }
462 }
463
464 /// Returns the number of dimensions of the Array
465 pub fn numdims(&self) -> u32 {
466 unsafe {
467 let mut ret_val: u32 = 0;
468 let err_val = af_get_numdims(&mut ret_val as *mut c_uint, self.handle);
469 HANDLE_ERROR(AfError::from(err_val));
470 ret_val
471 }
472 }
473
474 /// Returns the offset to the pointer from where data begins
475 pub fn offset(&self) -> i64 {
476 unsafe {
477 let mut ret_val: i64 = 0;
478 let err_val = af_get_offset(&mut ret_val as *mut dim_t, self.handle);
479 HANDLE_ERROR(AfError::from(err_val));
480 ret_val
481 }
482 }
483
484 /// Returns the native FFI handle for Rust object `Array`
485 pub unsafe fn get(&self) -> af_array {
486 self.handle
487 }
488
489 /// Set the native FFI handle for Rust object `Array`
490 pub fn set(&mut self, handle: af_array) {
491 self.handle = handle;
492 }
493
494 /// Copies the data from the Array to the mutable slice `data`
495 ///
496 /// # Examples
497 ///
498 /// Basic case
499 /// ```
500 /// # use arrayfire::{Array,Dim4,HasAfEnum};
501 /// let a:Vec<u8> = vec![0,1,2,3,4,5,6,7,8];
502 /// let b = Array::<u8>::new(&a,Dim4::new(&[3,3,1,1]));
503 /// let mut c = vec!(u8::default();b.elements());
504 /// b.host(&mut c);
505 /// assert_eq!(c,a);
506 /// ```
507 /// Generic case
508 /// ```
509 /// # use arrayfire::{Array,Dim4,HasAfEnum};
510 /// fn to_vec<T:HasAfEnum+Default+Clone>(array:&Array<T>) -> Vec<T> {
511 /// let mut vec = vec!(T::default();array.elements());
512 /// array.host(&mut vec);
513 /// return vec;
514 /// }
515 ///
516 /// let a = Array::<u8>::new(&[0,1,2,3,4,5,6,7,8],Dim4::new(&[3,3,1,1]));
517 /// let b:Vec<u8> = vec![0,1,2,3,4,5,6,7,8];
518 /// assert_eq!(to_vec(&a),b);
519 /// ```
520 pub fn host<O: HasAfEnum>(&self, data: &mut [O]) {
521 if data.len() != self.elements() {
522 HANDLE_ERROR(AfError::ERR_SIZE);
523 }
524 unsafe {
525 let err_val = af_get_data_ptr(data.as_mut_ptr() as *mut c_void, self.handle);
526 HANDLE_ERROR(AfError::from(err_val));
527 }
528 }
529
530 /// Evaluates any pending lazy expressions that represent the data in the Array object
531 pub fn eval(&self) {
532 unsafe {
533 let err_val = af_eval(self.handle);
534 HANDLE_ERROR(AfError::from(err_val));
535 }
536 }
537
538 /// Makes an copy of the Array
539 ///
540 /// This does a deep copy of the data into a new Array
541 pub fn copy(&self) -> Self {
542 unsafe {
543 let mut temp: af_array = std::ptr::null_mut();
544 let err_val = af_copy_array(&mut temp as *mut af_array, self.handle);
545 HANDLE_ERROR(AfError::from(err_val));
546 temp.into()
547 }
548 }
549
550 is_func!("Check if Array is empty", is_empty, af_is_empty);
551 is_func!("Check if Array is scalar", is_scalar, af_is_scalar);
552 is_func!("Check if Array is a row", is_row, af_is_row);
553 is_func!("Check if Array is a column", is_column, af_is_column);
554 is_func!("Check if Array is a vector", is_vector, af_is_vector);
555
556 is_func!(
557 "Check if Array is of real (not complex) type",
558 is_real,
559 af_is_real
560 );
561 is_func!(
562 "Check if Array is of complex type",
563 is_complex,
564 af_is_complex
565 );
566
567 is_func!(
568 "Check if Array's numerical type is of double precision",
569 is_double,
570 af_is_double
571 );
572 is_func!(
573 "Check if Array's numerical type is of single precision",
574 is_single,
575 af_is_single
576 );
577 is_func!(
578 "Check if Array's numerical type is of half precision",
579 is_half,
580 af_is_half
581 );
582 is_func!(
583 "Check if Array is of integral type",
584 is_integer,
585 af_is_integer
586 );
587 is_func!("Check if Array is of boolean type", is_bool, af_is_bool);
588
589 is_func!(
590 "Check if Array is floating point real(not complex) data type",
591 is_realfloating,
592 af_is_realfloating
593 );
594 is_func!(
595 "Check if Array is floating point type, either real or complex data",
596 is_floating,
597 af_is_floating
598 );
599
600 is_func!(
601 "Check if Array's memory layout is continuous and one dimensional",
602 is_linear,
603 af_is_linear
604 );
605 is_func!("Check if Array is a sparse matrix", is_sparse, af_is_sparse);
606 is_func!(
607 "Check if Array's memory is owned by it and not a view of another Array",
608 is_owner,
609 af_is_owner
610 );
611
612 /// Cast the Array data type to `target_type`
613 pub fn cast<O: HasAfEnum>(&self) -> Array<O> {
614 let trgt_type = O::get_af_dtype();
615 unsafe {
616 let mut temp: af_array = std::ptr::null_mut();
617 let err_val = af_cast(&mut temp as *mut af_array, self.handle, trgt_type as c_uint);
618 HANDLE_ERROR(AfError::from(err_val));
619 temp.into()
620 }
621 }
622
623 /// Lock the device buffer in the memory manager
624 ///
625 /// Locked buffers are not freed by memory manager until unlock is called.
626 pub fn lock(&self) {
627 unsafe {
628 let err_val = af_lock_array(self.handle);
629 HANDLE_ERROR(AfError::from(err_val));
630 }
631 }
632
633 /// Unlock the device buffer in the memory manager
634 ///
635 /// This function will give back the control over the device pointer to the
636 /// memory manager.
637 pub fn unlock(&self) {
638 unsafe {
639 let err_val = af_unlock_array(self.handle);
640 HANDLE_ERROR(AfError::from(err_val));
641 }
642 }
643
644 /// Get the device pointer and lock the buffer in memory manager
645 ///
646 /// The device pointer is not freed by memory manager until unlock is called.
647 pub unsafe fn device_ptr(&self) -> void_ptr {
648 let mut temp: void_ptr = std::ptr::null_mut();
649 let err_val = af_get_device_ptr(&mut temp as *mut void_ptr, self.handle);
650 HANDLE_ERROR(AfError::from(err_val));
651 temp
652 }
653
654 /// Get the size of physical allocated bytes.
655 ///
656 /// This function will return the size of the parent/owner if the current Array object is an
657 /// indexed Array.
658 pub fn get_allocated_bytes(&self) -> usize {
659 unsafe {
660 let mut temp: usize = 0;
661 let err_val = af_get_allocated_bytes(&mut temp as *mut usize, self.handle);
662 HANDLE_ERROR(AfError::from(err_val));
663 temp
664 }
665 }
666}
667
668/// Used for creating Array object from native
669/// resource id, an 64 bit integer
670impl<T: HasAfEnum> Into<Array<T>> for af_array {
671 fn into(self) -> Array<T> {
672 Array {
673 handle: self,
674 _marker: PhantomData,
675 }
676 }
677}
678
679/// Returns a new Array object after incrementing the reference count of native resource
680///
681/// Cloning an Array does not do a deep copy of the underlying array data. It increments the
682/// reference count of native resource and returns you the new reference in the form a new Array
683/// object.
684///
685/// To create a deep copy use
686/// [copy()](./struct.Array.html#method.copy)
687impl<T> Clone for Array<T>
688where
689 T: HasAfEnum,
690{
691 fn clone(&self) -> Self {
692 unsafe {
693 let mut temp: af_array = std::ptr::null_mut();
694 let ret_val = af_retain_array(&mut temp as *mut af_array, self.handle);
695 match ret_val {
696 0 => temp.into(),
697 _ => panic!("Weak copy of Array failed with error code: {}", ret_val),
698 }
699 }
700 }
701}
702
703/// To free resources when Array goes out of scope
704impl<T> Drop for Array<T>
705where
706 T: HasAfEnum,
707{
708 fn drop(&mut self) {
709 unsafe {
710 let ret_val = af_release_array(self.handle);
711 match ret_val {
712 0 => (),
713 _ => panic!("Array<T> drop failed with error code: {}", ret_val),
714 }
715 }
716 }
717}
718
719/// Print data in the Array
720///
721/// # Parameters
722///
723/// - `input` is the Array to be printed
724///
725/// # Examples
726///
727/// ```rust
728/// use arrayfire::{Dim4, print, randu};
729/// println!("Create a 5-by-3 matrix of random floats on the GPU");
730/// let dims = Dim4::new(&[5, 3, 1, 1]);
731/// let a = randu::<f32>(dims);
732/// print(&a);
733/// ```
734///
735/// The sample output will look like below:
736///
737/// ```text
738/// [5 3 1 1]
739/// 0.7402 0.4464 0.7762
740/// 0.9210 0.6673 0.2948
741/// 0.0390 0.1099 0.7140
742/// 0.9690 0.4702 0.3585
743/// 0.9251 0.5132 0.6814
744/// ```
745pub fn print<T: HasAfEnum>(input: &Array<T>) {
746 let emptystring = CString::new("").unwrap();
747 unsafe {
748 let err_val = af_print_array_gen(
749 emptystring.to_bytes_with_nul().as_ptr() as *const c_char,
750 input.get(),
751 4,
752 );
753 HANDLE_ERROR(AfError::from(err_val));
754 }
755}
756
757/// Generalized Array print function
758///
759/// Use this function to print Array data with arbitrary preicsion
760///
761/// # Parameters
762///
763/// - `msg` is message to be printed before printing the Array data
764/// - `input` is the Array to be printed
765/// - `precision` is data precision with which Array has to be printed
766///
767/// # Examples
768///
769/// ```rust
770/// use arrayfire::{Dim4, print_gen, randu};
771/// println!("Create a 5-by-3 matrix of random floats on the GPU");
772/// let dims = Dim4::new(&[5, 3, 1, 1]);
773/// let a = randu::<f32>(dims);
774/// print_gen(String::from("Random Array"), &a, Some(6));
775/// ```
776///
777/// The sample output will look like below:
778///
779/// ```text
780/// Random Array
781///
782/// [5 3 1 1]
783/// 0.740276 0.446440 0.776202
784/// 0.921094 0.667321 0.294810
785/// 0.039014 0.109939 0.714090
786/// 0.969058 0.470269 0.358590
787/// 0.925181 0.513225 0.681451
788/// ```
789pub fn print_gen<T: HasAfEnum>(msg: String, input: &Array<T>, precision: Option<i32>) {
790 let emptystring = CString::new(msg.as_bytes()).unwrap();
791 unsafe {
792 let err_val = af_print_array_gen(
793 emptystring.to_bytes_with_nul().as_ptr() as *const c_char,
794 input.get(),
795 match precision {
796 Some(p) => p,
797 None => 4,
798 } as c_int,
799 );
800 HANDLE_ERROR(AfError::from(err_val));
801 }
802}
803
804/// evaluate multiple arrays
805///
806/// Use this function to evaluate multiple arrays in single call
807///
808/// # Parameters
809///
810/// - `inputs` are the list of arrays to be evaluated
811pub fn eval_multiple<T: HasAfEnum>(inputs: Vec<&Array<T>>) {
812 unsafe {
813 let mut v = Vec::new();
814 for i in inputs {
815 v.push(i.get());
816 }
817
818 let err_val = af_eval_multiple(v.len() as c_int, v.as_ptr() as *const af_array);
819 HANDLE_ERROR(AfError::from(err_val));
820 }
821}
822
823/// Set eval flag value
824///
825/// This function can be used to toggle on/off the manual evaluation of arrays.
826///
827/// # Parameters
828///
829/// - `flag` is a boolean value indicating manual evaluation setting
830pub fn set_manual_eval(flag: bool) {
831 unsafe {
832 let err_val = af_set_manual_eval_flag(flag as c_int);
833 HANDLE_ERROR(AfError::from(err_val));
834 }
835}
836
837/// Get eval flag value
838///
839/// This function can be used to find out if manual evaluation of arrays is
840/// turned on or off.
841///
842/// # Return Values
843///
844/// A boolean indicating manual evaluation setting.
845pub fn is_eval_manual() -> bool {
846 unsafe {
847 let mut ret_val: i32 = 0;
848 let err_val = af_get_manual_eval_flag(&mut ret_val as *mut c_int);
849 HANDLE_ERROR(AfError::from(err_val));
850 ret_val > 0
851 }
852}
853
854#[cfg(feature = "afserde")]
855mod afserde {
856 // Reimport required from super scope
857 use super::{Array, DType, Dim4, HasAfEnum};
858
859 use serde::de::{Deserializer, Error, Unexpected};
860 use serde::ser::Serializer;
861 use serde::{Deserialize, Serialize};
862
863 #[derive(Debug, Serialize, Deserialize)]
864 struct ArrayOnHost<T: HasAfEnum + std::fmt::Debug> {
865 dtype: DType,
866 shape: Dim4,
867 data: Vec<T>,
868 }
869
870 /// Serialize Implementation of Array
871 impl<T> Serialize for Array<T>
872 where
873 T: std::default::Default + std::clone::Clone + Serialize + HasAfEnum + std::fmt::Debug,
874 {
875 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
876 where
877 S: Serializer,
878 {
879 let mut vec = vec![T::default(); self.elements()];
880 self.host(&mut vec);
881 let arr_on_host = ArrayOnHost {
882 dtype: self.get_type(),
883 shape: self.dims().clone(),
884 data: vec,
885 };
886 arr_on_host.serialize(serializer)
887 }
888 }
889
890 /// Deserialize Implementation of Array
891 impl<'de, T> Deserialize<'de> for Array<T>
892 where
893 T: Deserialize<'de> + HasAfEnum + std::fmt::Debug,
894 {
895 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
896 where
897 D: Deserializer<'de>,
898 {
899 match ArrayOnHost::<T>::deserialize(deserializer) {
900 Ok(arr_on_host) => {
901 let read_dtype = arr_on_host.dtype;
902 let expected_dtype = T::get_af_dtype();
903 if expected_dtype != read_dtype {
904 let error_msg = format!(
905 "data type is {:?}, deserialized type is {:?}",
906 expected_dtype, read_dtype
907 );
908 return Err(Error::invalid_value(Unexpected::Enum, &error_msg.as_str()));
909 }
910 Ok(Array::<T>::new(
911 &arr_on_host.data,
912 arr_on_host.shape.clone(),
913 ))
914 }
915 Err(err) => Err(err),
916 }
917 }
918 }
919}
920
921#[cfg(test)]
922mod tests {
923 use super::super::array::print;
924 use super::super::data::constant;
925 use super::super::device::{info, set_device, sync};
926 use crate::dim4;
927 use std::sync::{mpsc, Arc, RwLock};
928 use std::thread;
929
930 #[test]
931 fn thread_move_array() {
932 // ANCHOR: move_array_to_thread
933 set_device(0);
934 info();
935 let mut a = constant(1, dim4!(3, 3));
936
937 let handle = thread::spawn(move || {
938 //set_device to appropriate device id is required in each thread
939 set_device(0);
940
941 println!("\nFrom thread {:?}", thread::current().id());
942
943 a += constant(2, dim4!(3, 3));
944 print(&a);
945 });
946
947 //Need to join other threads as main thread holds arrayfire context
948 handle.join().unwrap();
949 // ANCHOR_END: move_array_to_thread
950 }
951
952 #[test]
953 fn thread_borrow_array() {
954 set_device(0);
955 info();
956 let a = constant(1i32, dim4!(3, 3));
957
958 let handle = thread::spawn(move || {
959 set_device(0); //set_device to appropriate device id is required in each thread
960 println!("\nFrom thread {:?}", thread::current().id());
961 print(&a);
962 });
963 //Need to join other threads as main thread holds arrayfire context
964 handle.join().unwrap();
965 }
966
967 // ANCHOR: multiple_threads_enum_def
968 #[derive(Debug, Copy, Clone)]
969 enum Op {
970 Add,
971 Sub,
972 Div,
973 Mul,
974 }
975 // ANCHOR_END: multiple_threads_enum_def
976
977 #[test]
978 fn read_from_multiple_threads() {
979 // ANCHOR: read_from_multiple_threads
980 let ops: Vec<_> = vec![Op::Add, Op::Sub, Op::Div, Op::Mul, Op::Add, Op::Div];
981
982 // Set active GPU/device on main thread on which
983 // subsequent Array objects are created
984 set_device(0);
985
986 // ArrayFire Array's are internally maintained via atomic reference counting
987 // Thus, they need no Arc wrapping while moving to another thread.
988 // Just call clone method on the object and share the resulting clone object
989 let a = constant(1.0f32, dim4!(3, 3));
990 let b = constant(2.0f32, dim4!(3, 3));
991
992 let threads: Vec<_> = ops
993 .into_iter()
994 .map(|op| {
995 let x = a.clone();
996 let y = b.clone();
997 thread::spawn(move || {
998 set_device(0); //Both of objects are created on device 0 earlier
999 match op {
1000 Op::Add => {
1001 let _c = x + y;
1002 }
1003 Op::Sub => {
1004 let _c = x - y;
1005 }
1006 Op::Div => {
1007 let _c = x / y;
1008 }
1009 Op::Mul => {
1010 let _c = x * y;
1011 }
1012 }
1013 sync(0);
1014 thread::sleep(std::time::Duration::new(1, 0));
1015 })
1016 })
1017 .collect();
1018 for child in threads {
1019 let _ = child.join();
1020 }
1021 // ANCHOR_END: read_from_multiple_threads
1022 }
1023
1024 #[test]
1025 fn access_using_rwlock() {
1026 // ANCHOR: access_using_rwlock
1027 let ops: Vec<_> = vec![Op::Add, Op::Sub, Op::Div, Op::Mul, Op::Add, Op::Div];
1028
1029 // Set active GPU/device on main thread on which
1030 // subsequent Array objects are created
1031 set_device(0);
1032
1033 let c = constant(0.0f32, dim4!(3, 3));
1034 let a = constant(1.0f32, dim4!(3, 3));
1035 let b = constant(2.0f32, dim4!(3, 3));
1036
1037 // Move ownership to RwLock and wrap in Arc since same object is to be modified
1038 let c_lock = Arc::new(RwLock::new(c));
1039
1040 // a and b are internally reference counted by ArrayFire. Unless there
1041 // is prior known need that they may be modified, you can simply clone
1042 // the objects pass them to threads
1043
1044 let threads: Vec<_> = ops
1045 .into_iter()
1046 .map(|op| {
1047 let x = a.clone();
1048 let y = b.clone();
1049
1050 let wlock = c_lock.clone();
1051 thread::spawn(move || {
1052 //Both of objects are created on device 0 in main thread
1053 //Every thread needs to set the device that it is going to
1054 //work on. Note that all Array objects must have been created
1055 //on same device as of date this is written on.
1056 set_device(0);
1057 if let Ok(mut c_guard) = wlock.write() {
1058 match op {
1059 Op::Add => {
1060 *c_guard += x + y;
1061 }
1062 Op::Sub => {
1063 *c_guard += x - y;
1064 }
1065 Op::Div => {
1066 *c_guard += x / y;
1067 }
1068 Op::Mul => {
1069 *c_guard += x * y;
1070 }
1071 }
1072 }
1073 })
1074 })
1075 .collect();
1076
1077 for child in threads {
1078 let _ = child.join();
1079 }
1080
1081 //let read_guard = c_lock.read().unwrap();
1082 //af_print!("C after threads joined", *read_guard);
1083 //C after threads joined
1084 //[3 3 1 1]
1085 // 8.0000 8.0000 8.0000
1086 // 8.0000 8.0000 8.0000
1087 // 8.0000 8.0000 8.0000
1088 // ANCHOR_END: access_using_rwlock
1089 }
1090
1091 #[test]
1092 fn accum_using_channel() {
1093 // ANCHOR: accum_using_channel
1094 let ops: Vec<_> = vec![Op::Add, Op::Sub, Op::Div, Op::Mul, Op::Add, Op::Div];
1095 let ops_len: usize = ops.len();
1096
1097 // Set active GPU/device on main thread on which
1098 // subsequent Array objects are created
1099 set_device(0);
1100
1101 let mut c = constant(0.0f32, dim4!(3, 3));
1102 let a = constant(1.0f32, dim4!(3, 3));
1103 let b = constant(2.0f32, dim4!(3, 3));
1104
1105 let (tx, rx) = mpsc::channel();
1106
1107 let threads: Vec<_> = ops
1108 .into_iter()
1109 .map(|op| {
1110 // a and b are internally reference counted by ArrayFire. Unless there
1111 // is prior known need that they may be modified, you can simply clone
1112 // the objects pass them to threads
1113 let x = a.clone();
1114 let y = b.clone();
1115
1116 let tx_clone = tx.clone();
1117
1118 thread::spawn(move || {
1119 //Both of objects are created on device 0 in main thread
1120 //Every thread needs to set the device that it is going to
1121 //work on. Note that all Array objects must have been created
1122 //on same device as of date this is written on.
1123 set_device(0);
1124
1125 let c = match op {
1126 Op::Add => x + y,
1127 Op::Sub => x - y,
1128 Op::Div => x / y,
1129 Op::Mul => x * y,
1130 };
1131 tx_clone.send(c).unwrap();
1132 })
1133 })
1134 .collect();
1135
1136 for _i in 0..ops_len {
1137 c += rx.recv().unwrap();
1138 }
1139
1140 //Need to join other threads as main thread holds arrayfire context
1141 for child in threads {
1142 let _ = child.join();
1143 }
1144
1145 //af_print!("C after accumulating results", &c);
1146 //[3 3 1 1]
1147 // 8.0000 8.0000 8.0000
1148 // 8.0000 8.0000 8.0000
1149 // 8.0000 8.0000 8.0000
1150 // ANCHOR_END: accum_using_channel
1151 }
1152
1153 #[cfg(feature = "afserde")]
1154 mod serde_tests {
1155 use super::super::Array;
1156 use crate::algorithm::sum_all;
1157 use crate::randu;
1158
1159 #[test]
1160 fn array_serde_json() {
1161 let input = randu!(u8; 2, 2);
1162 let serd = match serde_json::to_string(&input) {
1163 Ok(serialized_str) => serialized_str,
1164 Err(e) => e.to_string(),
1165 };
1166
1167 let deserd: Array<u8> = serde_json::from_str(&serd).unwrap();
1168
1169 assert_eq!(sum_all(&(input - deserd)), (0u32, 0u32));
1170 }
1171
1172 #[test]
1173 fn array_serde_bincode() {
1174 let input = randu!(u8; 2, 2);
1175 let encoded = match bincode::serialize(&input) {
1176 Ok(encoded) => encoded,
1177 Err(_) => vec![],
1178 };
1179
1180 let decoded: Array<u8> = bincode::deserialize(&encoded).unwrap();
1181
1182 assert_eq!(sum_all(&(input - decoded)), (0u32, 0u32));
1183 }
1184 }
1185}