pub(crate) use _random::module_def;
#[pymodule]
mod _random {
use crate::common::lock::PyMutex;
use crate::vm::{
PyObjectRef, PyPayload, PyRef, PyResult, VirtualMachine,
builtins::{PyInt, PyTupleRef},
convert::ToPyException,
function::OptionalOption,
types::{Constructor, Initializer},
};
use itertools::Itertools;
use malachite_bigint::{BigInt, BigUint, Sign};
use mt19937::MT19937;
use num_traits::{Signed, Zero};
use rand_core::{RngCore, SeedableRng};
use rustpython_vm::types::DefaultConstructor;
// Payload type behind the Python-visible class named `Random` (see the
// `name = "Random"` argument of `#[pyclass]`). `#[pyattr]` presumably exposes
// the class as an attribute of the enclosing `_random` module -- confirm
// against the `pymodule` macro documentation.
//
// `Default` is derived, so a default-built instance holds whatever
// `MT19937::default()` produces.
#[pyattr]
#[pyclass(name = "Random")]
#[derive(Debug, PyPayload, Default)]
struct PyRandom {
// Mersenne Twister generator. It sits behind a `PyMutex` because every
// method of this type takes `&self` yet has to read or replace the state.
rng: PyMutex<MT19937>,
}
// No items are overridden: construction relies entirely on what the
// `DefaultConstructor` trait provides (presumably building the payload through
// its `Default` impl -- confirm in rustpython_vm). Seeding is performed
// separately by `Initializer::init`.
impl DefaultConstructor for PyRandom {}
// Initializer hook for `PyRandom`: it accepts the same optional argument as
// `seed` and does nothing except (re)seed the freshly constructed object.
impl Initializer for PyRandom {
    type Args = OptionalOption;

    fn init(zelf: PyRef<Self>, x: Self::Args, vm: &VirtualMachine) -> PyResult<()> {
        // Delegate to the inherent `seed` method; any error it produces is
        // returned to the caller unchanged.
        Self::seed(&zelf, x, vm)
    }
}
#[pyclass(flags(BASETYPE), with(Constructor, Initializer))]
impl PyRandom {
#[pymethod]
fn random(&self) -> f64 {
    // Lock the generator only for the duration of the call and hand it to
    // `mt19937::gen_res53`, which yields the next `f64` (by its name a value
    // with 53 bits of resolution -- confirm against the mt19937 crate docs).
    mt19937::gen_res53(&mut *self.rng.lock())
}
#[pymethod]
fn seed(&self, n: OptionalOption<PyObjectRef>, vm: &VirtualMachine) -> PyResult<()> {
*self.rng.lock() = match n.flatten() {
Some(n) => {
let (_, mut key) = match n.downcast::<PyInt>() {
Ok(n) => n.as_bigint().abs(),
Err(obj) => BigInt::from(obj.hash(vm)?).abs(),
}
.to_u32_digits();
if cfg!(target_endian = "big") {
key.reverse();
}
let key = if key.is_empty() { &[0] } else { key.as_slice() };
MT19937::new_with_slice_seed(key)
}
None => MT19937::try_from_os_rng()
.map_err(|e| std::io::Error::from(e).to_pyexception(vm))?,
};
Ok(())
}
#[pymethod]
fn getrandbits(&self, k: PyObjectRef, vm: &VirtualMachine) -> PyResult<BigInt> {
    // Return a non-negative integer assembled from `k` random bits.
    //
    // Errors:
    // * whatever `try_index` raises when `k` cannot be used as an index,
    // * ValueError when `k` is negative,
    // * OverflowError when `k` does not fit in an `isize`.
    let index = k.try_index(vm)?;
    if index.as_bigint().is_negative() {
        return Err(vm.new_value_error("number of bits must be non-negative"));
    }
    let bits: isize = index
        .try_to_primitive(vm)
        .map_err(|_| vm.new_overflow_error("getrandbits: number of bits too large"))?;
    if bits == 0 {
        return Ok(BigInt::zero());
    }

    // ceil(bits / 32) generator outputs are needed; `bits >= 1` here.
    let word_count = (bits - 1) / 32 + 1;
    let mut rng = self.rng.lock();
    // Words are produced least-significant first. Every word before the last
    // one is used whole; the last one keeps only its top `left` bits.
    let words = (0..word_count)
        .map(|position| {
            let raw = rng.next_u32();
            let left = bits - 32 * position;
            if left < 32 { raw >> (32 - left) } else { raw }
        })
        .collect::<Vec<_>>();

    let magnitude = BigUint::new(words);
    let sign = if magnitude.is_zero() {
        Sign::NoSign
    } else {
        Sign::Plus
    };
    Ok(BigInt::from_biguint(sign, magnitude))
}
#[pymethod]
fn getstate(&self, vm: &VirtualMachine) -> PyTupleRef {
    // Snapshot the generator as a tuple of Python ints: every word of the
    // internal state, in order, followed by the current index as the final
    // element. `setstate` consumes the same layout.
    let rng = self.rng.lock();
    let to_py = |word: u32| -> PyObjectRef { vm.ctx.new_int(word).into() };

    let mut items: Vec<PyObjectRef> = rng.get_state().iter().copied().map(&to_py).collect();
    items.push(to_py(rng.get_index() as u32));
    vm.new_tuple(items)
}
#[pymethod]
fn setstate(&self, state: PyTupleRef, vm: &VirtualMachine) -> PyResult<()> {
    // Restore the generator from a tuple of `mt19937::N` state words followed
    // by one index element.
    //
    // Errors:
    // * ValueError("state vector is the wrong size") on a length mismatch,
    // * ValueError("invalid state") when the index exceeds `mt19937::N`,
    // * whatever `try_to_value` raises for an element that does not convert.
    //
    // Everything is validated before the lock is taken, so a rejected tuple
    // leaves the current generator untouched.
    let items = state.as_slice();
    if items.len() != mt19937::N + 1 {
        return Err(vm.new_value_error("state vector is the wrong size"));
    }
    let (raw_words, raw_index) = items.split_at(mt19937::N);

    // The index is converted and range-checked before any state word.
    let position: usize = raw_index[0].try_to_value(vm)?;
    if position > mt19937::N {
        return Err(vm.new_value_error("invalid state"));
    }

    // `process_results` stops at the first conversion error and `?` returns
    // it. On success exactly `mt19937::N` values were produced, so
    // `collect_array` cannot come back empty and the `unwrap` cannot fire.
    let words: [u32; mt19937::N] = raw_words
        .iter()
        .map(|obj| obj.try_to_value(vm))
        .process_results(|converted| converted.collect_array())?
        .unwrap();

    let mut rng = self.rng.lock();
    rng.set_state(&words);
    rng.set_index(position);
    Ok(())
}
}
}