1use super::array::Array;
2use super::defines::{AfError, RandomEngineType};
3use super::dim4::Dim4;
4use super::error::HANDLE_ERROR;
5use super::util::{af_array, af_random_engine, dim_t, u64_t, FloatingPoint, HasAfEnum};
6
7use libc::{c_int, c_uint};
8
9extern "C" {
10 fn af_set_seed(seed: u64_t) -> c_int;
11 fn af_get_seed(seed: *mut u64_t) -> c_int;
12
13 fn af_randu(out: *mut af_array, ndims: c_uint, dims: *const dim_t, afdtype: c_uint) -> c_int;
14 fn af_randn(out: *mut af_array, ndims: c_uint, dims: *const dim_t, afdtype: c_uint) -> c_int;
15
16 fn af_create_random_engine(engine: *mut af_random_engine, rtype: c_uint, seed: u64_t) -> c_int;
17 fn af_retain_random_engine(
18 engine: *mut af_random_engine,
19 inputEngine: af_random_engine,
20 ) -> c_int;
21 fn af_random_engine_set_type(engine: *mut af_random_engine, rtpye: c_uint) -> c_int;
22 fn af_random_engine_get_type(rtype: *mut c_uint, engine: af_random_engine) -> c_int;
23 fn af_random_engine_set_seed(engine: *mut af_random_engine, seed: u64_t) -> c_int;
24 fn af_random_engine_get_seed(seed: *mut u64_t, engine: af_random_engine) -> c_int;
25 fn af_release_random_engine(engine: af_random_engine) -> c_int;
26
27 fn af_get_default_random_engine(engine: *mut af_random_engine) -> c_int;
28 fn af_set_default_random_engine_type(rtype: c_uint) -> c_int;
29
30 fn af_random_uniform(
31 out: *mut af_array,
32 ndims: c_uint,
33 dims: *const dim_t,
34 aftype: c_uint,
35 engine: af_random_engine,
36 ) -> c_int;
37 fn af_random_normal(
38 out: *mut af_array,
39 ndims: c_uint,
40 dims: *const dim_t,
41 aftype: c_uint,
42 engine: af_random_engine,
43 ) -> c_int;
44}
45
46pub fn set_seed(seed: u64) {
48 unsafe {
49 let err_val = af_set_seed(seed as u64_t);
50 HANDLE_ERROR(AfError::from(err_val));
51 }
52}
53
54pub fn get_seed() -> u64 {
56 let mut ret_val: u64 = 0;
57 unsafe {
58 let err_val = af_get_seed(&mut ret_val as *mut u64_t);
59 HANDLE_ERROR(AfError::from(err_val));
60 }
61 ret_val
62}
63
64macro_rules! data_gen_def {
65 [$doc_str: expr, $fn_name:ident, $ffi_name: ident, $($type_trait: ident),+] => (
66 #[doc=$doc_str]
67 pub fn $fn_name<T>(dims: Dim4) -> Array<T>
76 where $( T: $type_trait, )* {
77 let aftype = T::get_af_dtype();
78 unsafe {
79 let mut temp: af_array = std::ptr::null_mut();
80 let err_val = $ffi_name(&mut temp as *mut af_array,
81 dims.ndims() as c_uint, dims.get().as_ptr() as *const dim_t,
82 aftype as c_uint);
83 HANDLE_ERROR(AfError::from(err_val));
84 temp.into()
85 }
86 }
87 )
88}
89
90data_gen_def!(
91 "Create random numbers from uniform distribution",
92 randu,
93 af_randu,
94 HasAfEnum
95);
96data_gen_def!(
97 "Create random numbers from normal distribution",
98 randn,
99 af_randn,
100 HasAfEnum,
101 FloatingPoint
102);
103
104pub struct RandomEngine {
117 handle: af_random_engine,
118}
119
120unsafe impl Send for RandomEngine {}
121
122impl From<af_random_engine> for RandomEngine {
124 fn from(t: af_random_engine) -> Self {
125 Self { handle: t }
126 }
127}
128
129impl RandomEngine {
130 pub fn new(rengine: RandomEngineType, seed: Option<u64>) -> Self {
141 unsafe {
142 let mut temp: af_random_engine = std::ptr::null_mut();
143 let err_val = af_create_random_engine(
144 &mut temp as *mut af_random_engine,
145 rengine as c_uint,
146 match seed {
147 Some(s) => s,
148 None => 0,
149 } as u64_t,
150 );
151 HANDLE_ERROR(AfError::from(err_val));
152 RandomEngine { handle: temp }
153 }
154 }
155
156 pub fn get_type(&self) -> RandomEngineType {
158 let mut temp: u32 = 0;
159 unsafe {
160 let err_val = af_random_engine_get_type(&mut temp as *mut c_uint, self.handle);
161 HANDLE_ERROR(AfError::from(err_val));
162 }
163 RandomEngineType::from(temp)
164 }
165
166 pub fn set_type(&mut self, engine_type: RandomEngineType) {
168 unsafe {
169 let err_val = af_random_engine_set_type(
170 &mut self.handle as *mut af_random_engine,
171 engine_type as c_uint,
172 );
173 HANDLE_ERROR(AfError::from(err_val));
174 }
175 }
176
177 pub fn set_seed(&mut self, seed: u64) {
179 unsafe {
180 let err_val =
181 af_random_engine_set_seed(&mut self.handle as *mut af_random_engine, seed as u64_t);
182 HANDLE_ERROR(AfError::from(err_val));
183 }
184 }
185
186 pub fn get_seed(&self) -> u64 {
188 let mut seed: u64 = 0;
189 unsafe {
190 let err_val = af_random_engine_get_seed(&mut seed as *mut u64_t, self.handle);
191 HANDLE_ERROR(AfError::from(err_val));
192 }
193 seed
194 }
195
196 pub unsafe fn get(&self) -> af_random_engine {
198 self.handle
199 }
200}
201
202impl Clone for RandomEngine {
204 fn clone(&self) -> Self {
205 unsafe {
206 let mut temp: af_random_engine = std::ptr::null_mut();
207 let err_val = af_retain_random_engine(&mut temp as *mut af_random_engine, self.handle);
208 HANDLE_ERROR(AfError::from(err_val));
209 RandomEngine::from(temp)
210 }
211 }
212}
213
214impl Drop for RandomEngine {
216 fn drop(&mut self) {
217 unsafe {
218 let err_val = af_release_random_engine(self.handle);
219 HANDLE_ERROR(AfError::from(err_val));
220 }
221 }
222}
223
#[cfg(feature = "afserde")]
mod afserde {
    // A `RandomEngine` is (de)serialized as its engine type and seed only; deserialization
    // creates a brand-new ArrayFire engine from those two values.
    use super::{RandomEngine, RandomEngineType};

    use serde::de::Deserializer;
    use serde::ser::Serializer;
    use serde::{Deserialize, Serialize};

    /// Plain-data serialized form of a `RandomEngine`.
    #[derive(Debug, Serialize, Deserialize)]
    struct RandEngine {
        engine_type: RandomEngineType,
        seed: u64,
    }

    impl Serialize for RandomEngine {
        /// Serializes the engine's current type and seed.
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            RandEngine {
                engine_type: self.get_type(),
                seed: self.get_seed(),
            }
            .serialize(serializer)
        }
    }

    impl<'de> Deserialize<'de> for RandomEngine {
        /// Reads a type and seed, then creates a fresh engine from them.
        ///
        /// NOTE(review): failures inside `RandomEngine::new` are reported through `HANDLE_ERROR`
        /// rather than as a `D::Error`.
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            RandEngine::deserialize(deserializer)
                .map(|r| RandomEngine::new(r.engine_type, Some(r.seed)))
        }
    }
}
267
268pub fn get_default_random_engine() -> RandomEngine {
270 unsafe {
271 let mut temp: af_random_engine = std::ptr::null_mut();
272 let mut err_val = af_get_default_random_engine(&mut temp as *mut af_random_engine);
273 HANDLE_ERROR(AfError::from(err_val));
274 let mut handle: af_random_engine = std::ptr::null_mut();
275 err_val = af_retain_random_engine(&mut handle as *mut af_random_engine, temp);
276 HANDLE_ERROR(AfError::from(err_val));
277 RandomEngine { handle: handle }
278 }
279}
280
281pub fn set_default_random_engine_type(rtype: RandomEngineType) {
287 unsafe {
288 let err_val = af_set_default_random_engine_type(rtype as c_uint);
289 HANDLE_ERROR(AfError::from(err_val));
290 }
291}
292
293pub fn random_uniform<T>(dims: Dim4, engine: &RandomEngine) -> Array<T>
304where
305 T: HasAfEnum,
306{
307 let aftype = T::get_af_dtype();
308 unsafe {
309 let mut temp: af_array = std::ptr::null_mut();
310 let err_val = af_random_uniform(
311 &mut temp as *mut af_array,
312 dims.ndims() as c_uint,
313 dims.get().as_ptr() as *const dim_t,
314 aftype as c_uint,
315 engine.get(),
316 );
317 HANDLE_ERROR(AfError::from(err_val));
318 temp.into()
319 }
320}
321
322pub fn random_normal<T>(dims: Dim4, engine: &RandomEngine) -> Array<T>
333where
334 T: HasAfEnum + FloatingPoint,
335{
336 let aftype = T::get_af_dtype();
337 unsafe {
338 let mut temp: af_array = std::ptr::null_mut();
339 let err_val = af_random_normal(
340 &mut temp as *mut af_array,
341 dims.ndims() as c_uint,
342 dims.get().as_ptr() as *const dim_t,
343 aftype as c_uint,
344 engine.get(),
345 );
346 HANDLE_ERROR(AfError::from(err_val));
347 temp.into()
348 }
349}
350
#[cfg(test)]
mod tests {
    #[cfg(feature = "afserde")]
    mod serde_tests {
        use super::super::RandomEngine;
        use crate::core::defines::RandomEngineType;

        /// Round-trips a `RandomEngine` through bincode and checks that the engine type and
        /// seed survive. Serialization failures are reported directly instead of being turned
        /// into an empty buffer that only fails later during deserialization.
        #[test]
        fn random_engine_serde_bincode() {
            let input = RandomEngine::new(RandomEngineType::THREEFRY_2X32_16, Some(2047));
            let encoded = bincode::serialize(&input).expect("Failed to serialize RandomEngine");

            let decoded: RandomEngine =
                bincode::deserialize(&encoded).expect("Failed to deserialize RandomEngine");

            assert_eq!(input.get_seed(), decoded.get_seed());
            assert_eq!(input.get_type(), decoded.get_type());
        }
    }
}