use Integer;
use inner::{Inner, InnerMut};
use gmp_mpfr_sys::gmp::{self, randstate_t};
use std::marker::PhantomData;
use std::mem;
use std::os::raw::{c_ulong, c_void};
use std::panic::{self, AssertUnwindSafe};
use std::process;
/// A random-number generator state wrapping GMP's `randstate_t`.
///
/// The lifetime `'a` appears only in the `PhantomData<&'a RandGen>` marker;
/// `new_custom` uses it to carry the lifetime of a user-supplied `RandGen`.
pub struct RandState<'a> {
    // Raw GMP state; set up by one of the constructors and released in `Drop`.
    inner: randstate_t,
    // Zero-sized marker carrying the custom-generator lifetime.
    phantom: PhantomData<&'a RandGen>,
}
impl<'a> Default for RandState<'a> {
#[inline]
fn default() -> RandState<'a> {
RandState::new()
}
}
impl<'a> Clone for RandState<'a> {
    /// Duplicates the generator state through `gmp_randinit_set`.
    ///
    /// NOTE(review): for a state created by `new_custom`, GMP forwards the
    /// copy to the custom `iset` hook. That hook panics with
    /// "cannot clone custom Rand", and because the callbacks abort on panic,
    /// cloning such a state aborts the process.
    #[inline]
    fn clone(&self) -> RandState<'a> {
        let copy = unsafe {
            let mut copy = mem::uninitialized();
            gmp::randinit_set(&mut copy, self.inner());
            copy
        };
        RandState {
            inner: copy,
            phantom: PhantomData,
        }
    }
}
impl<'a> Drop for RandState<'a> {
    /// Hands the state back to GMP with `gmp_randclear`.
    ///
    /// For custom states, GMP's clear hook frees the boxed generator
    /// reference.
    #[inline]
    fn drop(&mut self) {
        let state = unsafe { self.inner_mut() };
        unsafe { gmp::randclear(state) };
    }
}
// SAFETY (review rationale): a custom generator can only be installed through
// `RandGen`, whose supertraits are `Send + Sync`. The visible methods that
// mutate the GMP state (`seed`, `bits`, `below`, `inner_mut`) take
// `&mut self`. `Clone` takes `&self` and only reads the source state.
// NOTE(review): this assumes GMP's built-in algorithms keep no hidden shared
// mutable data across `randstate_t` values — confirm against GMP.
unsafe impl<'a> Send for RandState<'a> {}
unsafe impl<'a> Sync for RandState<'a> {}
impl<'a> RandState<'a> {
    /// Creates a generator using GMP's default algorithm
    /// (`gmp_randinit_default`).
    #[inline]
    pub fn new() -> RandState<'a> {
        // NOTE(review): `mem::uninitialized` is deprecated in newer Rust;
        // prefer `MaybeUninit` once the crate's minimum Rust version allows.
        unsafe {
            let mut inner = mem::uninitialized();
            gmp::randinit_default(&mut inner);
            RandState {
                inner,
                phantom: PhantomData,
            }
        }
    }

    /// Creates a Mersenne Twister generator (`gmp_randinit_mt`).
    pub fn new_mersenne_twister() -> RandState<'a> {
        unsafe {
            let mut inner = mem::uninitialized();
            gmp::randinit_mt(&mut inner);
            RandState {
                inner,
                phantom: PhantomData,
            }
        }
    }

    /// Creates a linear congruential generator with `gmp_randinit_lc_2exp`.
    ///
    /// The arguments are the multiplier `a`, the addend `c`, and the
    /// exponent `bits` of the power-of-two modulus.
    pub fn new_linear_congruential(
        a: &Integer,
        c: u32,
        bits: u32,
    ) -> RandState<'a> {
        unsafe {
            let mut inner = mem::uninitialized();
            gmp::randinit_lc_2exp(&mut inner, a.inner(), c.into(), bits.into());
            RandState {
                inner,
                phantom: PhantomData,
            }
        }
    }

    /// Creates a linear congruential generator whose parameters GMP picks
    /// from a table for the requested `size` (`gmp_randinit_lc_2exp_size`).
    ///
    /// Returns `None` when GMP reports failure (a zero return value). Per the
    /// GMP manual, this happens when `size` exceeds the largest tabulated
    /// size.
    pub fn new_linear_congruential_size(size: u32) -> Option<RandState<'a>> {
        unsafe {
            let mut inner = mem::uninitialized();
            if gmp::randinit_lc_2exp_size(&mut inner, size.into()) != 0 {
                Some(RandState {
                    inner,
                    phantom: PhantomData,
                })
            } else {
                None
            }
        }
    }

    /// Creates a generator that draws its bits from a user-supplied
    /// [`RandGen`].
    ///
    /// The returned state mutably borrows `custom` for the state's whole
    /// lifetime `'c`. The generator therefore cannot be moved, dropped or
    /// used elsewhere while the state is alive.
    ///
    /// Cloning the returned state is not supported: the custom copy hook
    /// panics, and the callback wrapper turns that panic into a process abort.
    pub fn new_custom<'c, T>(custom: &'c mut T) -> RandState<'c>
    where
        T: 'c + RandGen,
    {
        // The borrow must last for `'c` (the state's own lifetime). The raw
        // pointer stored below is dereferenced by the GMP callbacks until
        // `custom_clear` runs. With an unrelated, shorter borrow, the state
        // could outlive the generator and call into freed memory.
        //
        // The fat `&mut RandGen` is boxed so that a thin pointer to it fits in
        // the seed's limb pointer. The callbacks cast it back, and
        // `custom_clear` frees the box.
        let b = Box::new(custom as &mut RandGen);
        let r_ptr = Box::into_raw(b);
        let inner = MpRandState {
            seed: gmp::mpz_t {
                alloc: 0,
                size: 0,
                d: r_ptr as *mut gmp::limb_t,
            },
            _alg: RandAlg::_DEFAULT,
            // Dispatch table GMP uses for seed/get/clear/copy on this state.
            _algdata: &CUSTOM_FUNCS as *const _ as *mut _,
        };
        RandState {
            // `MpRandState` is a hand-written mirror of GMP's `randstate_t`
            // layout; `transmute` checks that the sizes match.
            inner: unsafe { mem::transmute(inner) },
            phantom: PhantomData,
        }
    }

    /// Seeds the generator with `seed` (`gmp_randseed`).
    #[inline]
    pub fn seed(&mut self, seed: &Integer) {
        unsafe {
            gmp::randseed(self.inner_mut(), seed.inner());
        }
    }

    /// Returns a random number made of `bits` random bits
    /// (`gmp_urandomb_ui`).
    ///
    /// # Panics
    ///
    /// Panics if `bits` is greater than 32.
    #[inline]
    pub fn bits(&mut self, bits: u32) -> u32 {
        assert!(bits <= 32, "bits out of range");
        unsafe { gmp::urandomb_ui(self.inner_mut(), bits.into()) as u32 }
    }

    /// Returns a random number below `bound` (`gmp_urandomm_ui`).
    ///
    /// # Panics
    ///
    /// Panics if `bound` is zero.
    #[inline]
    pub fn below(&mut self, bound: u32) -> u32 {
        assert_ne!(bound, 0, "cannot be below zero");
        unsafe { gmp::urandomm_ui(self.inner_mut(), bound.into()) as u32 }
    }
}
/// A user-supplied source of random bits, installed with
/// `RandState::new_custom`.
///
/// The `Send + Sync` supertraits back the unconditional `Send`/`Sync` impls
/// on `RandState`.
pub trait RandGen: Send + Sync {
    /// Returns 32 random bits.
    ///
    /// The custom GMP "get" callback calls this repeatedly to fill limbs and
    /// masks off unused high bits of the final partial word.
    fn gen(&mut self) -> u32;
    /// Seeds the generator.
    ///
    /// This is reached through GMP's seed hook when `RandState::seed` is
    /// called on a custom state. The default implementation ignores the seed.
    #[inline]
    fn seed(&mut self, _seed: &Integer) {}
}
/// Stand-in for GMP's `gmp_randalg_t` field inside `MpRandState`.
///
/// Only the default (zero) variant is needed to build a custom state by hand.
/// NOTE(review): assumes GMP's enum is `int`-sized with `DEFAULT = 0` — confirm.
#[repr(C)]
enum RandAlg {
    _DEFAULT = 0,
}
/// Hand-written mirror of GMP's random-state struct. `new_custom` builds one
/// and transmutes it into `randstate_t`.
///
/// For custom states the fields are repurposed: `seed.d` holds the
/// `*mut &mut RandGen` produced by `Box::into_raw`, and `_algdata` points at
/// `CUSTOM_FUNCS`.
#[repr(C)]
struct MpRandState {
    // Seed integer; for custom states only its limb pointer `d` is used.
    seed: gmp::mpz_t,
    _alg: RandAlg,
    // Pointer to the algorithm's callback table (`Funcs` for custom states).
    _algdata: *mut c_void,
}
/// Callback table that `new_custom` installs as the state's `_algdata`.
///
/// GMP invokes these hooks to seed, draw bits, clear and copy the state.
/// NOTE(review): the field order (seed, get, clear, iset) must match GMP's
/// internal `gmp_randfnptr_t` — confirm against the linked GMP version.
#[repr(C)]
struct Funcs {
    _seed: Option<unsafe extern "C" fn(*mut randstate_t, *const gmp::mpz_t)>,
    _get: Option<
        unsafe extern "C" fn(*mut randstate_t, *mut gmp::limb_t, c_ulong),
    >,
    _clear: Option<unsafe extern "C" fn(*mut randstate_t)>,
    _iset: Option<unsafe extern "C" fn(*mut randstate_t, *const randstate_t)>,
}
// Defines `unsafe extern "C"` functions that must never unwind into C code.
// Each body runs inside `catch_unwind`; if it panics, the process is aborted
// instead of letting the unwind cross the FFI boundary.
macro_rules! c_callback {
    { $(fn $func:ident($($param:tt)*) $body:block)* } => {
        $(
            unsafe extern "C" fn $func($($param)*) {
                let outcome = panic::catch_unwind(AssertUnwindSafe(|| $body));
                match outcome {
                    Ok(value) => value,
                    Err(_) => process::abort(),
                }
            }
        )*
    };
}
c_callback! {
    // GMP seed hook for custom states: forwards the seed to the boxed
    // `RandGen`, reinterpreting the `mpz_t` pointer as an `&Integer`.
    fn custom_seed(s: *mut randstate_t, seed: *const gmp::mpz_t) {
        let s_ptr = s as *mut MpRandState;
        // `new_custom` stored the boxed `&mut RandGen` pointer in `seed.d`.
        let r_ptr = (*s_ptr).seed.d as *mut &mut RandGen;
        (*r_ptr).seed(&(*(seed as *const Integer)));
    }
    // GMP "get bits" hook: writes `bits` random bits into the limb array
    // starting at `limb`, drawing 32 bits at a time from `RandGen::gen`.
    // The order of the `gen()` calls determines which generator outputs land
    // in which bits, so it must not be rearranged.
    fn custom_get(
        s: *mut randstate_t,
        limb: *mut gmp::limb_t,
        bits: c_ulong,
    ) {
        let s_ptr = s as *mut MpRandState;
        let r_ptr = (*s_ptr).seed.d as *mut &mut RandGen;
        let gen = || (*r_ptr).gen();
        #[cfg(gmp_limb_bits_64)]
        {
            // Number of whole 64-bit limbs, plus leftover bits for a partial limb.
            let (limbs, rest) = (bits / 64, bits % 64);
            // The limb count (including a possible partial limb) must survive a
            // round-trip through `isize`, the type `offset` takes. A failure here
            // panics, which `c_callback!` turns into an abort.
            assert_eq!((limbs + 1) as isize as c_ulong, limbs + 1, "overflow");
            let limbs = limbs as isize;
            for i in 0..limbs {
                // First call fills the low 32 bits, second call the high 32 bits.
                let n = u64::from(gen()) | u64::from(gen()) << 32;
                *(limb.offset(i)) = n as gmp::limb_t;
            }
            if rest >= 32 {
                // Partial limb with a full low word; the high word is drawn only
                // when needed and masked to `rest - 32` bits.
                let mut n = u64::from(gen());
                if rest > 32 {
                    let mask = !(!0 << (rest - 32));
                    n |= u64::from(gen() & mask) << 32;
                }
                *(limb.offset(limbs)) = n as gmp::limb_t;
            } else if rest > 0 {
                // Partial limb of fewer than 32 bits, masked to `rest` bits.
                let mask = !(!0 << rest);
                let n = u64::from(gen() & mask);
                *(limb.offset(limbs)) = n as gmp::limb_t;
            }
        }
        #[cfg(gmp_limb_bits_32)]
        {
            // Each 32-bit limb takes exactly one `gen()` call.
            let (limbs, rest) = (bits / 32, bits % 32);
            assert_eq!((limbs + 1) as isize as c_ulong, limbs + 1, "overflow");
            let limbs = limbs as isize;
            for i in 0..limbs {
                *(limb.offset(i)) = gen() as gmp::limb_t;
            }
            if rest > 0 {
                // Final partial limb, masked to `rest` bits.
                let mask = !(!0 << rest);
                *(limb.offset(limbs)) = (gen() & mask) as gmp::limb_t;
            }
        }
    }
    // GMP clear hook: reclaims and drops the box created by `Box::into_raw`
    // in `new_custom`. The box holds only a `&mut RandGen`, so the borrowed
    // generator itself is not dropped.
    fn custom_clear(s: *mut randstate_t) {
        let s_ptr = s as *mut MpRandState;
        let r_ptr = (*s_ptr).seed.d as *mut &mut RandGen;
        drop(Box::from_raw(r_ptr));
    }
    // GMP copy hook: custom states cannot be copied. The panic is caught by
    // the `c_callback!` wrapper, which then aborts the process.
    fn custom_iset(_s: *mut randstate_t, _src: *const randstate_t) {
        panic!("cannot clone custom Rand");
    }
}
/// Callback table shared by every custom `RandState`.
///
/// `new_custom` stores a raw pointer to this table inside the GMP state, and
/// GMP dereferences it for as long as the state lives. Declaring it `static`
/// (rather than `const`) gives the table a single fixed address with
/// `'static` storage. With `const`, `&CUSTOM_FUNCS` yields a pointer into a
/// temporary, and its validity would depend on the compiler's implicit rvalue
/// promotion.
static CUSTOM_FUNCS: Funcs = Funcs {
    _seed: Some(custom_seed),
    _get: Some(custom_get),
    _clear: Some(custom_clear),
    _iset: Some(custom_iset),
};
impl<'a> Inner for RandState<'a> {
type Output = randstate_t;
#[inline]
fn inner(&self) -> &randstate_t {
&self.inner
}
}
impl<'a> InnerMut for RandState<'a> {
#[inline]
unsafe fn inner_mut(&mut self) -> &mut randstate_t {
&mut self.inner
}
}