extern crate libc;
use array::Array;
use dim4::Dim4;
use defines::{AfError, RandomEngineType};
use error::HANDLE_ERROR;
use self::libc::{uint8_t, c_int, c_uint};
use util::{DimT, HasAfEnum, MutAfArray, MutRandEngine, RandEngine, Uintl};
// Raw FFI bindings to the ArrayFire C random-number API wrapped by this module.
// Every binding returns an ArrayFire status code (`c_int`); the safe wrappers below
// convert it with `AfError::from` and hand it to `HANDLE_ERROR`.
// `MutRandEngine` / `MutAfArray` parameters are pointers to a caller-owned `i64` handle
// slot (written on creation, or updated in place); `RandEngine` parameters are handles
// passed by value.
// NOTE(review): every binding below appears to be used by this module, so this
// `dead_code` allow may be unnecessary — confirm before removing it.
#[allow(dead_code)]
extern {
// Library-wide seed (no engine argument).
fn af_set_seed(seed: Uintl) -> c_int;
fn af_get_seed(seed: *mut Uintl) -> c_int;
// Array generators that take no engine argument; `afdtype` is the element dtype code.
fn af_randu(out: MutAfArray, ndims: c_uint, dims: *const DimT, afdtype: uint8_t) -> c_int;
fn af_randn(out: MutAfArray, ndims: c_uint, dims: *const DimT, afdtype: uint8_t) -> c_int;
// Random-engine lifecycle: create / retain (duplicate handle) / release.
fn af_create_random_engine(engine: MutRandEngine, rtype: uint8_t, seed: Uintl) -> c_int;
fn af_retain_random_engine(engine: MutRandEngine, inputEngine: RandEngine) -> c_int;
// Engine properties. (`rtpye` is a typo for `rtype`; FFI parameter names have no effect.)
fn af_random_engine_set_type(engine: MutRandEngine, rtpye: uint8_t) -> c_int;
// NOTE(review): the C header presumably declares this out-parameter as a pointer to the
// int-sized `af_random_engine_type` enum; declaring it `*mut uint8_t` is only sound if
// callers pass storage at least as large as a C int — confirm at every call site.
fn af_random_engine_get_type(rtype: *mut uint8_t, engine: RandEngine) -> c_int;
fn af_random_engine_set_seed(engine: MutRandEngine, seed: Uintl) -> c_int;
fn af_random_engine_get_seed(seed: *mut Uintl, engine: RandEngine) -> c_int;
fn af_release_random_engine(engine: RandEngine) -> c_int;
// Default-engine accessors.
fn af_get_default_random_engine(engine: MutRandEngine) -> c_int;
fn af_set_default_random_engine_type(rtype: uint8_t) -> c_int;
// Array generators driven by an explicit engine handle.
fn af_random_uniform(out: MutAfArray, ndims: c_uint, dims: *const DimT,
aftype: uint8_t, engine: RandEngine) -> c_int;
fn af_random_normal(out: MutAfArray, ndims: c_uint, dims: *const DimT,
aftype: uint8_t, engine: RandEngine) -> c_int;
}
/// Sets the library-wide random seed through ArrayFire's `af_set_seed`.
///
/// Any error status returned by ArrayFire is forwarded to `HANDLE_ERROR`.
pub fn set_seed(seed: u64) {
    // SAFETY: the seed is passed by value; no pointers cross the FFI boundary.
    let status = unsafe { af_set_seed(seed as Uintl) };
    HANDLE_ERROR(AfError::from(status));
}
/// Returns the library-wide random seed reported by ArrayFire's `af_get_seed`.
///
/// Any error status returned by ArrayFire is forwarded to `HANDLE_ERROR`.
pub fn get_seed() -> u64 {
    let mut seed: u64 = 0;
    // SAFETY: `seed` is a live, aligned u64 local for the whole duration of the call.
    let status = unsafe { af_get_seed(&mut seed as *mut Uintl) };
    HANDLE_ERROR(AfError::from(status));
    seed
}
macro_rules! data_gen_def {
($doc_str: expr, $fn_name:ident, $ffi_name: ident) => (
#[doc=$doc_str]
#[allow(unused_mut)]
pub fn $fn_name<T: HasAfEnum>(dims: Dim4) -> Array {
unsafe {
let aftype = T::get_af_dtype();
let mut temp: i64 = 0;
let err_val = $ffi_name(&mut temp as MutAfArray,
dims.ndims() as c_uint, dims.get().as_ptr() as *const DimT,
aftype as uint8_t);
HANDLE_ERROR(AfError::from(err_val));
Array::from(temp)
}
}
)
}
data_gen_def!("Create random numbers from uniform distribution", randu, af_randu);
data_gen_def!("Create random numbers from normal distribution", randn, af_randn);
/// Owning wrapper around a raw ArrayFire random-engine handle.
///
/// The handle is duplicated with `af_retain_random_engine` on clone and released with
/// `af_release_random_engine` on drop.
pub struct RandomEngine {
    /// Raw ArrayFire engine handle owned by this value.
    handle: i64,
}
/// Wraps an already-owned raw engine handle; the wrapper takes over releasing it.
impl From<i64> for RandomEngine {
    fn from(handle: i64) -> RandomEngine {
        RandomEngine { handle }
    }
}
impl RandomEngine {
pub fn new(rengine: RandomEngineType, seed: Option<u64>) -> RandomEngine {
unsafe {
let mut temp: i64 = 0;
let err_val = af_create_random_engine(&mut temp as MutRandEngine, rengine as uint8_t,
match seed {Some(s) => s, None => 0} as Uintl);
HANDLE_ERROR(AfError::from(err_val));
RandomEngine::from(temp)
}
}
pub fn get_type(&self) -> RandomEngineType {
unsafe {
let mut temp: u8 = 0;
let err_val = af_random_engine_get_type(&mut temp as *mut uint8_t,
self.handle as RandEngine);
HANDLE_ERROR(AfError::from(err_val));
RandomEngineType::from(temp as i32)
}
}
pub fn set_type(&mut self, engine_type: RandomEngineType) {
unsafe {
let err_val = af_random_engine_set_type(&mut self.handle as MutRandEngine,
engine_type as uint8_t);
HANDLE_ERROR(AfError::from(err_val));
}
}
pub fn set_seed(&mut self, seed: u64) {
unsafe {
let err_val = af_random_engine_set_seed(&mut self.handle as MutRandEngine,
seed as Uintl);
HANDLE_ERROR(AfError::from(err_val));
}
}
pub fn get_seed(&self) -> u64 {
unsafe {
let mut seed: u64 = 0;
let err_val = af_random_engine_get_seed(&mut seed as *mut Uintl, self.handle as RandEngine);
HANDLE_ERROR(AfError::from(err_val));
seed
}
}
pub fn get(&self) -> i64 {
self.handle
}
}
impl Clone for RandomEngine {
    /// Produces a second handle obtained from `af_retain_random_engine`.
    ///
    /// Presumably both handles refer to the same underlying engine state
    /// (reference-counted inside ArrayFire) — confirm against the ArrayFire docs.
    /// Each wrapper releases its own handle when dropped.
    fn clone(&self) -> RandomEngine {
        let mut retained: i64 = 0;
        // SAFETY: `retained` is a live local that receives the retained handle.
        let status = unsafe {
            af_retain_random_engine(&mut retained as MutRandEngine, self.handle as RandEngine)
        };
        HANDLE_ERROR(AfError::from(status));
        RandomEngine::from(retained)
    }
}
impl Drop for RandomEngine {
    /// Releases the owned handle with `af_release_random_engine`.
    ///
    /// Any error status returned by ArrayFire is forwarded to `HANDLE_ERROR`.
    fn drop(&mut self) {
        // SAFETY: the handle is passed by value; this wrapper owns it and gives it up here.
        let status = unsafe { af_release_random_engine(self.handle as RandEngine) };
        HANDLE_ERROR(AfError::from(status));
    }
}
/// Returns a `RandomEngine` for ArrayFire's default random engine.
///
/// The raw default handle is retained before being wrapped. The returned value
/// therefore owns its own retained handle, which is released on drop, rather than
/// the library's default handle itself.
pub fn get_default_random_engine() -> RandomEngine {
    let mut default_handle: i64 = 0;
    // SAFETY: `default_handle` is a live local that receives the default engine handle.
    let status = unsafe { af_get_default_random_engine(&mut default_handle as MutRandEngine) };
    HANDLE_ERROR(AfError::from(status));

    let mut retained: i64 = 0;
    // SAFETY: `retained` is a live local; `default_handle` was just obtained above.
    let status = unsafe {
        af_retain_random_engine(&mut retained as MutRandEngine, default_handle as RandEngine)
    };
    HANDLE_ERROR(AfError::from(status));
    RandomEngine::from(retained)
}
/// Changes the generator type of ArrayFire's default random engine to `rtype`.
///
/// Any error status returned by ArrayFire is forwarded to `HANDLE_ERROR`.
pub fn set_default_random_engine_type(rtype: RandomEngineType) {
    let code = rtype as uint8_t;
    // SAFETY: the type code is passed by value; no pointers cross the FFI boundary.
    let status = unsafe { af_set_default_random_engine_type(code) };
    HANDLE_ERROR(AfError::from(status));
}
/// Creates an `Array` of shape `dims` and element type `T`, filled by the uniform
/// generator of `engine` (via `af_random_uniform`).
///
/// `engine` is taken by value, so it is dropped (and its handle released) when this
/// function returns. Pass `engine.clone()` to keep using the engine afterwards.
pub fn random_uniform<T: HasAfEnum>(dims: Dim4, engine: RandomEngine) -> Array {
    let dtype = T::get_af_dtype() as uint8_t;
    let rank = dims.ndims() as c_uint;
    let mut out: i64 = 0;
    // SAFETY: `out` is a live local receiving the array handle; the dims pointer refers
    // to `dims`, which outlives the call; `engine` holds a valid handle.
    let status = unsafe {
        af_random_uniform(&mut out as MutAfArray, rank, dims.get().as_ptr() as *const DimT,
                          dtype, engine.get() as RandEngine)
    };
    HANDLE_ERROR(AfError::from(status));
    Array::from(out)
}
/// Creates an `Array` of shape `dims` and element type `T`, filled by the normal
/// generator of `engine` (via `af_random_normal`).
///
/// `engine` is taken by value, so it is dropped (and its handle released) when this
/// function returns. Pass `engine.clone()` to keep using the engine afterwards.
pub fn random_normal<T: HasAfEnum>(dims: Dim4, engine: RandomEngine) -> Array {
    let dtype = T::get_af_dtype() as uint8_t;
    let rank = dims.ndims() as c_uint;
    let mut out: i64 = 0;
    // SAFETY: `out` is a live local receiving the array handle; the dims pointer refers
    // to `dims`, which outlives the call; `engine` holds a valid handle.
    let status = unsafe {
        af_random_normal(&mut out as MutAfArray, rank, dims.get().as_ptr() as *const DimT,
                         dtype, engine.get() as RandEngine)
    };
    HANDLE_ERROR(AfError::from(status));
    Array::from(out)
}