// arrayfire/core/random.rs

1use super::array::Array;
2use super::defines::{AfError, RandomEngineType};
3use super::dim4::Dim4;
4use super::error::HANDLE_ERROR;
5use super::util::{af_array, af_random_engine, dim_t, u64_t, FloatingPoint, HasAfEnum};
6
7use libc::{c_int, c_uint};
8
9extern "C" {
10    fn af_set_seed(seed: u64_t) -> c_int;
11    fn af_get_seed(seed: *mut u64_t) -> c_int;
12
13    fn af_randu(out: *mut af_array, ndims: c_uint, dims: *const dim_t, afdtype: c_uint) -> c_int;
14    fn af_randn(out: *mut af_array, ndims: c_uint, dims: *const dim_t, afdtype: c_uint) -> c_int;
15
16    fn af_create_random_engine(engine: *mut af_random_engine, rtype: c_uint, seed: u64_t) -> c_int;
17    fn af_retain_random_engine(
18        engine: *mut af_random_engine,
19        inputEngine: af_random_engine,
20    ) -> c_int;
21    fn af_random_engine_set_type(engine: *mut af_random_engine, rtpye: c_uint) -> c_int;
22    fn af_random_engine_get_type(rtype: *mut c_uint, engine: af_random_engine) -> c_int;
23    fn af_random_engine_set_seed(engine: *mut af_random_engine, seed: u64_t) -> c_int;
24    fn af_random_engine_get_seed(seed: *mut u64_t, engine: af_random_engine) -> c_int;
25    fn af_release_random_engine(engine: af_random_engine) -> c_int;
26
27    fn af_get_default_random_engine(engine: *mut af_random_engine) -> c_int;
28    fn af_set_default_random_engine_type(rtype: c_uint) -> c_int;
29
30    fn af_random_uniform(
31        out: *mut af_array,
32        ndims: c_uint,
33        dims: *const dim_t,
34        aftype: c_uint,
35        engine: af_random_engine,
36    ) -> c_int;
37    fn af_random_normal(
38        out: *mut af_array,
39        ndims: c_uint,
40        dims: *const dim_t,
41        aftype: c_uint,
42        engine: af_random_engine,
43    ) -> c_int;
44}
45
46/// Set seed for random number generation
47pub fn set_seed(seed: u64) {
48    unsafe {
49        let err_val = af_set_seed(seed as u64_t);
50        HANDLE_ERROR(AfError::from(err_val));
51    }
52}
53
54/// Get the seed of random number generator
55pub fn get_seed() -> u64 {
56    let mut ret_val: u64 = 0;
57    unsafe {
58        let err_val = af_get_seed(&mut ret_val as *mut u64_t);
59        HANDLE_ERROR(AfError::from(err_val));
60    }
61    ret_val
62}
63
// Generates a public random-array constructor backed by one ArrayFire FFI call.
//
// Arguments, in order: the summary doc string, the name of the generated
// function, the FFI function to invoke, then one or more trait bounds that the
// element type `T` must satisfy.
macro_rules! data_gen_def {
    ($doc_str:expr, $fn_name:ident, $ffi_name:ident, $($type_trait:ident),+) => {
        #[doc = $doc_str]
        ///
        /// # Parameters
        ///
        /// - `dims` is the output dimensions
        ///
        /// # Return Values
        ///
        /// An Array with random values.
        pub fn $fn_name<T>(dims: Dim4) -> Array<T>
        where
            $(T: $type_trait,)*
        {
            let dtype = T::get_af_dtype();
            let rank = dims.ndims() as c_uint;
            let shape = dims.get().as_ptr() as *const dim_t;
            let mut out: af_array = std::ptr::null_mut();
            let status = unsafe {
                $ffi_name(&mut out as *mut af_array, rank, shape, dtype as c_uint)
            };
            HANDLE_ERROR(AfError::from(status));
            out.into()
        }
    };
}
89
90data_gen_def!(
91    "Create random numbers from uniform distribution",
92    randu,
93    af_randu,
94    HasAfEnum
95);
96data_gen_def!(
97    "Create random numbers from normal distribution",
98    randn,
99    af_randn,
100    HasAfEnum,
101    FloatingPoint
102);
103
104/// Random number generator engine
105///
106/// This is a wrapper for ArrayFire's native random number generator engine.
107///
108/// ## Sharing Across Threads
109///
110/// While sharing this object with other threads, there is no need to wrap
111/// this in an Arc object unless only one such object is required to exist.
112/// The reason being that ArrayFire's internal details that are pointed to
113/// by the RandoMEngine handle are appropriately reference counted in thread safe
114/// manner. However, if you need to modify RandomEngine object, then please do wrap
115/// the object using a Mutex or Read-Write lock.
116pub struct RandomEngine {
117    handle: af_random_engine,
118}
119
120unsafe impl Send for RandomEngine {}
121
122/// Used for creating RandomEngine object from native resource id
123impl From<af_random_engine> for RandomEngine {
124    fn from(t: af_random_engine) -> Self {
125        Self { handle: t }
126    }
127}
128
129impl RandomEngine {
130    /// Create a new random engine object
131    ///
132    /// # Parameters
133    ///
134    /// - `rengine` can be value of [RandomEngineType](./enum.RandomEngineType.html) enum.
135    /// - `seed` is the initial seed value
136    ///
137    /// # Return Values
138    ///
139    /// A object of type RandomEngine
140    pub fn new(rengine: RandomEngineType, seed: Option<u64>) -> Self {
141        unsafe {
142            let mut temp: af_random_engine = std::ptr::null_mut();
143            let err_val = af_create_random_engine(
144                &mut temp as *mut af_random_engine,
145                rengine as c_uint,
146                match seed {
147                    Some(s) => s,
148                    None => 0,
149                } as u64_t,
150            );
151            HANDLE_ERROR(AfError::from(err_val));
152            RandomEngine { handle: temp }
153        }
154    }
155
156    /// Get random engine type
157    pub fn get_type(&self) -> RandomEngineType {
158        let mut temp: u32 = 0;
159        unsafe {
160            let err_val = af_random_engine_get_type(&mut temp as *mut c_uint, self.handle);
161            HANDLE_ERROR(AfError::from(err_val));
162        }
163        RandomEngineType::from(temp)
164    }
165
166    /// Get random engine type
167    pub fn set_type(&mut self, engine_type: RandomEngineType) {
168        unsafe {
169            let err_val = af_random_engine_set_type(
170                &mut self.handle as *mut af_random_engine,
171                engine_type as c_uint,
172            );
173            HANDLE_ERROR(AfError::from(err_val));
174        }
175    }
176
177    /// Set seed for random engine
178    pub fn set_seed(&mut self, seed: u64) {
179        unsafe {
180            let err_val =
181                af_random_engine_set_seed(&mut self.handle as *mut af_random_engine, seed as u64_t);
182            HANDLE_ERROR(AfError::from(err_val));
183        }
184    }
185
186    /// Get seed of the random engine
187    pub fn get_seed(&self) -> u64 {
188        let mut seed: u64 = 0;
189        unsafe {
190            let err_val = af_random_engine_get_seed(&mut seed as *mut u64_t, self.handle);
191            HANDLE_ERROR(AfError::from(err_val));
192        }
193        seed
194    }
195
196    /// Returns the native FFI handle for Rust object `RandomEngine`
197    pub unsafe fn get(&self) -> af_random_engine {
198        self.handle
199    }
200}
201
202/// Increment reference count of RandomEngine's native resource
203impl Clone for RandomEngine {
204    fn clone(&self) -> Self {
205        unsafe {
206            let mut temp: af_random_engine = std::ptr::null_mut();
207            let err_val = af_retain_random_engine(&mut temp as *mut af_random_engine, self.handle);
208            HANDLE_ERROR(AfError::from(err_val));
209            RandomEngine::from(temp)
210        }
211    }
212}
213
214/// Free RandomEngine's native resource
215impl Drop for RandomEngine {
216    fn drop(&mut self) {
217        unsafe {
218            let err_val = af_release_random_engine(self.handle);
219            HANDLE_ERROR(AfError::from(err_val));
220        }
221    }
222}
223
#[cfg(feature = "afserde")]
mod afserde {
    // Reimport required from super scope
    use super::{RandomEngine, RandomEngineType};

    use serde::de::Deserializer;
    use serde::ser::Serializer;
    use serde::{Deserialize, Serialize};

    /// Serializable snapshot of a `RandomEngine`: its engine type and seed.
    #[derive(Debug, Serialize, Deserialize)]
    struct RandEngine {
        engine_type: RandomEngineType,
        seed: u64,
    }

    /// Serialize implementation of RandomEngine
    ///
    /// Only the engine type and seed are captured.
    impl Serialize for RandomEngine {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            let r = RandEngine {
                engine_type: self.get_type(),
                seed: self.get_seed(),
            };
            r.serialize(serializer)
        }
    }

    /// Deserialize implementation of RandomEngine
    ///
    /// Creates a new native engine with `RandomEngine::new` from the stored
    /// engine type and seed.
    impl<'de> Deserialize<'de> for RandomEngine {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            RandEngine::deserialize(deserializer)
                .map(|r| RandomEngine::new(r.engine_type, Some(r.seed)))
        }
    }
}
267
268/// Get default random engine
269pub fn get_default_random_engine() -> RandomEngine {
270    unsafe {
271        let mut temp: af_random_engine = std::ptr::null_mut();
272        let mut err_val = af_get_default_random_engine(&mut temp as *mut af_random_engine);
273        HANDLE_ERROR(AfError::from(err_val));
274        let mut handle: af_random_engine = std::ptr::null_mut();
275        err_val = af_retain_random_engine(&mut handle as *mut af_random_engine, temp);
276        HANDLE_ERROR(AfError::from(err_val));
277        RandomEngine { handle: handle }
278    }
279}
280
281/// Set the random engine type for default random number generator
282///
283/// # Parameters
284///
285/// - `rtype` can take one of the values of enum [RandomEngineType](./enum.RandomEngineType.html)
286pub fn set_default_random_engine_type(rtype: RandomEngineType) {
287    unsafe {
288        let err_val = af_set_default_random_engine_type(rtype as c_uint);
289        HANDLE_ERROR(AfError::from(err_val));
290    }
291}
292
293/// Generate array of uniform numbers using a random engine
294///
295/// # Parameters
296///
297/// - `dims` is output array dimensions
298/// - `engine` is an object of type [RandomEngine](./struct.RandomEngine.html)
299///
300/// # Return Values
301///
302/// An Array with uniform numbers generated using random engine
303pub fn random_uniform<T>(dims: Dim4, engine: &RandomEngine) -> Array<T>
304where
305    T: HasAfEnum,
306{
307    let aftype = T::get_af_dtype();
308    unsafe {
309        let mut temp: af_array = std::ptr::null_mut();
310        let err_val = af_random_uniform(
311            &mut temp as *mut af_array,
312            dims.ndims() as c_uint,
313            dims.get().as_ptr() as *const dim_t,
314            aftype as c_uint,
315            engine.get(),
316        );
317        HANDLE_ERROR(AfError::from(err_val));
318        temp.into()
319    }
320}
321
322/// Generate array of normal numbers using a random engine
323///
324/// # Parameters
325///
326/// - `dims` is output array dimensions
327/// - `engine` is an object of type [RandomEngine](./struct.RandomEngine.html)
328///
329/// # Return Values
330///
331/// An Array with normal numbers generated using random engine
332pub fn random_normal<T>(dims: Dim4, engine: &RandomEngine) -> Array<T>
333where
334    T: HasAfEnum + FloatingPoint,
335{
336    let aftype = T::get_af_dtype();
337    unsafe {
338        let mut temp: af_array = std::ptr::null_mut();
339        let err_val = af_random_normal(
340            &mut temp as *mut af_array,
341            dims.ndims() as c_uint,
342            dims.get().as_ptr() as *const dim_t,
343            aftype as c_uint,
344            engine.get(),
345        );
346        HANDLE_ERROR(AfError::from(err_val));
347        temp.into()
348    }
349}
350
#[cfg(test)]
mod tests {
    #[cfg(feature = "afserde")]
    mod serde_tests {
        use super::super::RandomEngine;
        use crate::core::defines::RandomEngineType;

        /// Round-trips a RandomEngine through bincode and checks that the
        /// seed and engine type survive.
        #[test]
        fn random_engine_serde_bincode() {
            let input = RandomEngine::new(RandomEngineType::THREEFRY_2X32_16, Some(2047));
            // Fail on the real cause instead of decoding an empty buffer,
            // which would surface as an unrelated deserialization error.
            let encoded = bincode::serialize(&input).expect("RandomEngine should serialize");

            let decoded: RandomEngine =
                bincode::deserialize(&encoded).expect("RandomEngine should deserialize");

            assert_eq!(input.get_seed(), decoded.get_seed());
            assert_eq!(input.get_type(), decoded.get_type());
        }
    }
}