use crate::array::Array;
use crate::error::{NumRs2Error, Result};
use num_traits::{Float, NumCast};
use scirs2_core::ndarray::distributions::uniform::SampleUniform;
use scirs2_core::random::prelude::*;
use scirs2_core::SliceRandomExt;
use scirs2_stats::{
distributions::{
lognormal::Lognormal as LogNormal, Bernoulli, Exponential, Gamma, Normal, Uniform,
},
Distribution,
};
use std::sync::{Arc, Mutex};
use std::time::{Duration, SystemTime, UNIX_EPOCH};
pub struct Generator {
rng: Arc<Mutex<StdRng>>,
}
impl Default for Generator {
fn default() -> Self {
Self::new()
}
}
impl Generator {
pub fn new() -> Self {
let now = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_else(|_| Duration::from_secs(1));
Self {
rng: Arc::new(Mutex::new(StdRng::seed_from_u64(now.as_secs()))),
}
}
pub fn with_seed(seed: u64) -> Self {
Self {
rng: Arc::new(Mutex::new(StdRng::seed_from_u64(seed))),
}
}
pub fn random<T>(&self, shape: &[usize]) -> Result<Array<T>>
where
T: Clone + SampleUniform + NumCast,
{
let size: usize = shape.iter().product();
let mut vec = Vec::with_capacity(size);
let rng = self
.rng
.lock()
.map_err(|_| NumRs2Error::InvalidOperation("Failed to acquire RNG lock".to_string()))?;
for _ in 0..size {
let uniform_dist = scirs2_stats::distributions::Uniform::new(0.0f64, 1.0f64)
.expect("random: uniform distribution [0, 1) should always be valid");
let val_f64 = uniform_dist.rvs(1).expect("uniform sampling failed")[0];
let val = NumCast::from(val_f64).ok_or_else(|| {
NumRs2Error::InvalidOperation(
"Failed to convert uniform sample to target type".to_string(),
)
})?;
vec.push(val);
}
Ok(Array::from_vec(vec).reshape(shape))
}
pub fn normal<T: Float + NumCast + Clone + std::fmt::Debug + std::fmt::Display>(
&self,
mean: T,
std: T,
shape: &[usize],
) -> Result<Array<T>> {
if std <= T::zero() {
return Err(NumRs2Error::InvalidOperation(format!(
"Standard deviation must be positive, got {}",
std
)));
}
let size: usize = shape.iter().product();
let mut vec = Vec::with_capacity(size);
let mean_f64 = mean.to_f64().ok_or_else(|| {
NumRs2Error::InvalidOperation("Failed to convert mean to f64".to_string())
})?;
let std_f64 = std.to_f64().ok_or_else(|| {
NumRs2Error::InvalidOperation("Failed to convert std to f64".to_string())
})?;
let dist = Normal::new(mean_f64, std_f64).map_err(|e| {
NumRs2Error::InvalidOperation(format!("Failed to create normal distribution: {}", e))
})?;
let rng = self
.rng
.lock()
.map_err(|_| NumRs2Error::InvalidOperation("Failed to acquire RNG lock".to_string()))?;
for _ in 0..size {
let val_f64 = dist.rvs(1).expect("distribution sampling failed")[0];
let val = T::from(val_f64).ok_or_else(|| {
NumRs2Error::InvalidOperation(
"Failed to convert normal sample to target type".to_string(),
)
})?;
vec.push(val);
}
Ok(Array::from_vec(vec).reshape(shape))
}
pub fn lognormal<T: Float + NumCast + Clone + std::fmt::Debug + std::fmt::Display>(
&self,
mean: T,
sigma: T,
shape: &[usize],
) -> Result<Array<T>> {
if sigma <= T::zero() {
return Err(NumRs2Error::InvalidOperation(format!(
"Sigma must be positive, got {}",
sigma
)));
}
let size: usize = shape.iter().product();
let mut vec = Vec::with_capacity(size);
let mean_f64 = mean.to_f64().ok_or_else(|| {
NumRs2Error::InvalidOperation("Failed to convert mean to f64".to_string())
})?;
let sigma_f64 = sigma.to_f64().ok_or_else(|| {
NumRs2Error::InvalidOperation("Failed to convert sigma to f64".to_string())
})?;
let dist = LogNormal::new(mean_f64, sigma_f64, 0.0).map_err(|e| {
NumRs2Error::InvalidOperation(format!(
"Failed to create log-normal distribution: {}",
e
))
})?;
let rng = self
.rng
.lock()
.map_err(|_| NumRs2Error::InvalidOperation("Failed to acquire RNG lock".to_string()))?;
for _ in 0..size {
let val_f64 = dist.rvs(1).expect("distribution sampling failed")[0];
let val = T::from(val_f64).ok_or_else(|| {
NumRs2Error::InvalidOperation(
"Failed to convert lognormal sample to target type".to_string(),
)
})?;
vec.push(val);
}
Ok(Array::from_vec(vec).reshape(shape))
}
pub fn uniform<T: Clone + PartialOrd + SampleUniform + Float + NumCast + std::fmt::Display>(
&self,
low: T,
high: T,
shape: &[usize],
) -> Result<Array<T>> {
let size: usize = shape.iter().product();
let mut vec = Vec::with_capacity(size);
let dist = Uniform::new(low, high).map_err(|e| {
NumRs2Error::InvalidOperation(format!("Failed to create uniform distribution: {}", e))
})?;
let rng = self
.rng
.lock()
.map_err(|_| NumRs2Error::InvalidOperation("Failed to acquire RNG lock".to_string()))?;
for _ in 0..size {
vec.push(dist.rvs(1).expect("distribution sampling failed")[0]);
}
Ok(Array::from_vec(vec).reshape(shape))
}
pub fn bernoulli<T: Float + NumCast + Clone + std::fmt::Debug + std::fmt::Display>(
&self,
p: T,
shape: &[usize],
) -> Result<Array<T>> {
if p < T::zero() || p > T::one() {
return Err(NumRs2Error::InvalidOperation(format!(
"Probability must be in [0, 1], got {}",
p
)));
}
let size: usize = shape.iter().product();
let mut vec = Vec::with_capacity(size);
let p_f64 = p.to_f64().ok_or_else(|| {
NumRs2Error::InvalidOperation("Failed to convert probability to f64".to_string())
})?;
let dist = Bernoulli::new(p_f64).map_err(|e| {
NumRs2Error::InvalidOperation(format!("Failed to create Bernoulli distribution: {}", e))
})?;
let rng = self
.rng
.lock()
.map_err(|_| NumRs2Error::InvalidOperation("Failed to acquire RNG lock".to_string()))?;
for _ in 0..size {
let val_f64 = dist.rvs(1).expect("distribution sampling failed")[0];
let val = if val_f64 > 0.5 { T::one() } else { T::zero() };
vec.push(val);
}
Ok(Array::from_vec(vec).reshape(shape))
}
pub fn gamma<T: Float + NumCast + Clone + std::fmt::Debug + std::fmt::Display>(
&self,
shape_param: T,
scale: T,
size_shape: &[usize],
) -> Result<Array<T>> {
if shape_param <= T::zero() || scale <= T::zero() {
return Err(NumRs2Error::InvalidOperation(format!(
"Shape and scale parameters must be positive, got shape={}, scale={}",
shape_param, scale
)));
}
let arr_size: usize = size_shape.iter().product();
let mut vec = Vec::with_capacity(arr_size);
let shape_f64 = shape_param.to_f64().ok_or_else(|| {
NumRs2Error::InvalidOperation("Failed to convert shape to f64".to_string())
})?;
let scale_f64 = scale.to_f64().ok_or_else(|| {
NumRs2Error::InvalidOperation("Failed to convert scale to f64".to_string())
})?;
let corrected_scale = 1.0 / scale_f64;
let dist = Gamma::new(shape_f64, corrected_scale, 0.0).map_err(|e| {
NumRs2Error::InvalidOperation(format!("Failed to create gamma distribution: {}", e))
})?;
let rng = self
.rng
.lock()
.map_err(|_| NumRs2Error::InvalidOperation("Failed to acquire RNG lock".to_string()))?;
for _ in 0..arr_size {
let val_f64 = dist.rvs(1).expect("distribution sampling failed")[0];
let val = T::from(val_f64).ok_or_else(|| {
NumRs2Error::InvalidOperation(
"Failed to convert gamma sample to target type".to_string(),
)
})?;
vec.push(val);
}
Ok(Array::from_vec(vec).reshape(size_shape))
}
pub fn exponential<T: Float + NumCast + Clone + std::fmt::Debug + std::fmt::Display>(
&self,
scale: T,
shape: &[usize],
) -> Result<Array<T>> {
if scale <= T::zero() {
return Err(NumRs2Error::InvalidOperation(format!(
"Scale parameter must be positive, got {}",
scale
)));
}
let size: usize = shape.iter().product();
let mut vec = Vec::with_capacity(size);
let scale_f64 = scale.to_f64().ok_or_else(|| {
NumRs2Error::InvalidOperation("Failed to convert scale to f64".to_string())
})?;
let rate = 1.0 / scale_f64;
let dist = Exponential::new(rate, 0.0).map_err(|e| {
NumRs2Error::InvalidOperation(format!(
"Failed to create exponential distribution: {}",
e
))
})?;
let rng = self
.rng
.lock()
.map_err(|_| NumRs2Error::InvalidOperation("Failed to acquire RNG lock".to_string()))?;
for _ in 0..size {
let val_f64 = dist.rvs(1).expect("distribution sampling failed")[0];
let val = T::from(val_f64).ok_or_else(|| {
NumRs2Error::InvalidOperation(
"Failed to convert exponential sample to target type".to_string(),
)
})?;
vec.push(val);
}
Ok(Array::from_vec(vec).reshape(shape))
}
pub fn shuffle<T: Clone>(&self, array: &mut Array<T>) -> Result<()> {
let rng = self
.rng
.lock()
.map_err(|_| NumRs2Error::InvalidOperation("Failed to acquire RNG lock".to_string()))?;
let mut data = array.to_vec();
data.shuffle(&mut thread_rng());
let shape = array.shape();
*array = Array::from_vec(data).reshape(&shape);
Ok(())
}
pub fn choice<T: Clone>(
&self,
array: &Array<T>,
size: Option<usize>,
replace: Option<bool>,
) -> Result<Array<T>> {
let data = array.to_vec();
if data.is_empty() {
return Err(NumRs2Error::InvalidOperation(
"Cannot choose from an empty array".to_string(),
));
}
let choose_size = size.unwrap_or(1);
let with_replacement = replace.unwrap_or(true);
if !with_replacement && choose_size > data.len() {
return Err(NumRs2Error::InvalidOperation(format!(
"Cannot choose {} items without replacement from array of size {}",
choose_size,
data.len()
)));
}
let mut rng = self
.rng
.lock()
.map_err(|_| NumRs2Error::InvalidOperation("Failed to acquire RNG lock".to_string()))?;
let mut result = Vec::with_capacity(choose_size);
if with_replacement {
for _ in 0..choose_size {
let idx = rng.random_range(0..data.len());
result.push(data[idx].clone());
}
} else {
let mut indices: Vec<usize> = (0..data.len()).collect();
indices.shuffle(&mut thread_rng());
for i in 0..choose_size {
result.push(data[indices[i]].clone());
}
}
if size.is_none() {
Ok(Array::from_vec(result))
} else {
Ok(Array::from_vec(result))
}
}
pub fn permutation<T: NumCast + Clone>(&self, n: usize) -> Result<Array<T>> {
let rng = self
.rng
.lock()
.map_err(|_| NumRs2Error::InvalidOperation("Failed to acquire RNG lock".to_string()))?;
let mut indices: Vec<usize> = (0..n).collect();
indices.shuffle(&mut thread_rng());
let mut result = Vec::with_capacity(n);
for idx in indices {
let val = T::from(idx).ok_or_else(|| {
NumRs2Error::InvalidOperation("Failed to convert index to target type".to_string())
})?;
result.push(val);
}
Ok(Array::from_vec(result))
}
}
lazy_static::lazy_static! {
    /// Process-wide generator used by the module-level helper functions; it is created
    /// lazily on first use and replaced wholesale by `seed`.
    static ref GLOBAL_GENERATOR: Mutex<Generator> = Mutex::new(Generator::default());
}
/// Reseeds the global generator used by the module-level helpers (`rand`, `randn`,
/// `uniform`, `shuffle`, `choice`).
///
/// If the global lock is poisoned (a thread panicked while holding it), the guard is
/// recovered and the generator is replaced anyway. Reseeding must never be skipped
/// silently, because callers rely on it for reproducibility, and whatever state the
/// panicking thread left behind is discarded by the replacement.
pub fn seed(seed: u64) {
    let mut generator = GLOBAL_GENERATOR
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    *generator = Generator::with_seed(seed);
}
/// Draws uniform samples with the global generator (see `Generator::random`).
///
/// # Errors
/// Fails if the global generator lock is poisoned or if `Generator::random` fails.
pub fn rand<T>(shape: &[usize]) -> Result<Array<T>>
where
    T: Clone + SampleUniform + NumCast,
{
    GLOBAL_GENERATOR
        .lock()
        .map_err(|_| {
            NumRs2Error::InvalidOperation("Failed to acquire global generator lock".to_string())
        })
        .and_then(|global| global.random(shape))
}
/// Draws standard-normal samples (mean 0, standard deviation 1) with the global generator.
///
/// # Errors
/// Fails if the global generator lock is poisoned or if `Generator::normal` fails.
pub fn randn<T: Float + NumCast + Clone + std::fmt::Debug + std::fmt::Display>(
    shape: &[usize],
) -> Result<Array<T>> {
    GLOBAL_GENERATOR
        .lock()
        .map_err(|_| {
            NumRs2Error::InvalidOperation("Failed to acquire global generator lock".to_string())
        })
        .and_then(|global| global.normal(T::zero(), T::one(), shape))
}
/// Draws samples between `low` and `high` with the global generator.
///
/// # Errors
/// Fails if the global generator lock is poisoned or if `Generator::uniform` fails.
pub fn uniform<T: Clone + PartialOrd + SampleUniform + Float + NumCast + std::fmt::Display>(
    low: T,
    high: T,
    shape: &[usize],
) -> Result<Array<T>> {
    GLOBAL_GENERATOR
        .lock()
        .map_err(|_| {
            NumRs2Error::InvalidOperation("Failed to acquire global generator lock".to_string())
        })
        .and_then(|global| global.uniform(low, high, shape))
}
/// Shuffles `array` in place using the global generator.
///
/// # Errors
/// Fails if the global generator lock is poisoned or if `Generator::shuffle` fails.
pub fn shuffle<T: Clone>(array: &mut Array<T>) -> Result<()> {
    GLOBAL_GENERATOR
        .lock()
        .map_err(|_| {
            NumRs2Error::InvalidOperation("Failed to acquire global generator lock".to_string())
        })
        .and_then(|global| global.shuffle(array))
}
/// Randomly picks elements from `array` using the global generator.
///
/// `size` and `replace` are forwarded unchanged to `Generator::choice`.
///
/// # Errors
/// Fails if the global generator lock is poisoned or if `Generator::choice` fails.
pub fn choice<T: Clone>(
    array: &Array<T>,
    size: Option<usize>,
    replace: Option<bool>,
) -> Result<Array<T>> {
    GLOBAL_GENERATOR
        .lock()
        .map_err(|_| {
            NumRs2Error::InvalidOperation("Failed to acquire global generator lock".to_string())
        })
        .and_then(|global| global.choice(array, size, replace))
}