use crate::error::{StatsError, StatsResult};
use scirs2_core::ndarray::{Array1, Array2, ArrayView1};
use scirs2_core::numeric::{Float, NumCast, Zero};
use scirs2_core::random::prelude::*;
use scirs2_core::random::uniform::SampleUniform;
use scirs2_core::random::{rngs::StdRng, SeedableRng};
use scirs2_core::random::{Distribution, StandardNormal};
/// Draws `size` independent samples from `distribution`.
///
/// When `seed` is `Some`, its little-endian bytes fill the first eight bytes of
/// an otherwise zeroed 32-byte `StdRng` seed, so equal seeds give equal output.
/// When `seed` is `None`, the 32-byte seed is drawn from
/// `scirs2_core::random::thread_rng()`.
///
/// # Errors
/// Returns [`StatsError::InvalidArgument`] if `size` is zero.
#[allow(dead_code)]
pub fn random_sample<T, D>(
    size: usize,
    distribution: &D,
    seed: Option<u64>,
) -> StatsResult<Array1<T>>
where
    T: Copy + Zero,
    D: Distribution<T>,
{
    if size == 0 {
        return Err(StatsError::InvalidArgument(
            "Size must be positive".to_string(),
        ));
    }
    let mut rng: StdRng = if let Some(seed_value) = seed {
        let mut key = [0u8; 32];
        key[..8].copy_from_slice(&seed_value.to_le_bytes());
        SeedableRng::from_seed(key)
    } else {
        let mut entropy = scirs2_core::random::thread_rng();
        let key: [u8; 32] = entropy.random();
        SeedableRng::from_seed(key)
    };
    // Samples are produced strictly in index order, so a fixed seed always
    // maps to the same array.
    let draws = std::iter::repeat_with(|| distribution.sample(&mut rng)).take(size);
    Ok(Array1::from_iter(draws))
}
/// Draws `size` floating-point samples uniformly from `[low, high)`.
///
/// Sampling and seeding are delegated to [`random_sample`].
///
/// # Errors
/// Returns [`StatsError::InvalidArgument`] when `size` is zero or when
/// `low >= high`. Returns [`StatsError::ComputationError`] if the underlying
/// uniform distribution cannot be constructed.
#[allow(dead_code)]
pub fn uniform<F>(low: F, high: F, size: usize, seed: Option<u64>) -> StatsResult<Array1<F>>
where
    F: Float + NumCast + Zero + SampleUniform + std::fmt::Display,
{
    let problem = if size == 0 {
        Some("Size must be positive")
    } else if low >= high {
        Some("Upper bound must be greater than lower bound")
    } else {
        None
    };
    if let Some(message) = problem {
        return Err(StatsError::InvalidArgument(message.to_string()));
    }
    scirs2_core::random::Uniform::new(low, high)
        .map_err(|err| {
            StatsError::ComputationError(format!("Failed to create uniform distribution: {}", err))
        })
        .and_then(|dist| random_sample(size, &dist, seed))
}
/// Draws `size` integers uniformly from the half-open range `[low, high)`.
///
/// Sampling and seeding are delegated to [`random_sample`].
///
/// # Errors
/// Returns [`StatsError::InvalidArgument`] when `size` is zero or when
/// `high <= low`. Returns [`StatsError::ComputationError`] if the underlying
/// uniform distribution cannot be constructed.
#[allow(dead_code)]
pub fn randint(low: i64, high: i64, size: usize, seed: Option<u64>) -> StatsResult<Array1<i64>> {
    let reject = |reason: &str| -> StatsResult<Array1<i64>> {
        Err(StatsError::InvalidArgument(reason.to_string()))
    };
    if size == 0 {
        return reject("Size must be positive");
    }
    if high <= low {
        return reject("Upper bound must be greater than lower bound");
    }
    // `high` is exclusive. `high - 1` cannot overflow here because
    // `high > low >= i64::MIN`.
    let last_allowed = high - 1;
    let dist = scirs2_core::random::Uniform::new_inclusive(low, last_allowed).map_err(|err| {
        StatsError::ComputationError(format!("Failed to create uniform distribution: {}", err))
    })?;
    random_sample(size, &dist, seed)
}
#[allow(dead_code)]
pub fn randn(size: usize, seed: Option<u64>) -> StatsResult<Array1<f64>> {
if size == 0 {
return Err(StatsError::InvalidArgument(
"Size must be positive".to_string(),
));
}
let mut rng: StdRng = match seed {
Some(seed_value) => {
let mut seed_bytes = [0u8; 32];
seed_bytes[..8].copy_from_slice(&seed_value.to_le_bytes());
SeedableRng::from_seed(seed_bytes)
}
None => {
{
let mut system_rng = scirs2_core::random::thread_rng();
let seed: [u8; 32] = system_rng.random();
SeedableRng::from_seed(seed)
}
}
};
let distribution = StandardNormal;
let mut result = Array1::zeros(size);
for i in 0..size {
result[i] = distribution.sample(&mut rng);
}
Ok(result)
}
/// Draws `size` elements from `a`, either uniformly or according to the
/// probabilities in `p`, with or without replacement.
///
/// # Arguments
/// * `a` - population to draw from; must be non-empty.
/// * `size` - number of elements to draw. Must be positive and, when `replace`
///   is `false`, no larger than `a.len()`.
/// * `replace` - whether an element may be drawn more than once.
/// * `p` - optional selection probabilities, one per element of `a`. They must
///   be finite, non-negative and sum to 1 within `1e-10`. Elements with zero
///   probability are never drawn.
/// * `seed` - optional seed for reproducible output. `None` seeds from
///   `scirs2_core::random::thread_rng()`.
///
/// # Errors
/// * [`StatsError::InvalidArgument`] is returned in any of these cases:
///   an empty `a`; `size == 0`; `size > a.len()` without replacement;
///   weights that are negative or do not sum to 1 (including NaN or infinite
///   weights); sampling without replacement that runs out of elements with
///   non-zero probability.
/// * [`StatsError::DimensionMismatch`] if `p.len() != a.len()`.
/// * [`StatsError::ComputationError`] if the uniform index distribution cannot
///   be constructed.
#[allow(dead_code)]
pub fn choice<T>(
    a: &ArrayView1<T>,
    size: usize,
    replace: bool,
    p: Option<&ArrayView1<f64>>,
    seed: Option<u64>,
) -> StatsResult<Array1<T>>
where
    T: Copy,
{
    let n = a.len();
    if n == 0 {
        return Err(StatsError::InvalidArgument(
            "Input array cannot be empty".to_string(),
        ));
    }
    if size == 0 {
        return Err(StatsError::InvalidArgument(
            "Size must be positive".to_string(),
        ));
    }
    if !replace && size > n {
        return Err(StatsError::InvalidArgument(
            "Cannot take a larger sample than population when 'replace=false'".to_string(),
        ));
    }
    let mut rng: StdRng = match seed {
        Some(seed_value) => {
            let mut seed_bytes = [0u8; 32];
            seed_bytes[..8].copy_from_slice(&seed_value.to_le_bytes());
            SeedableRng::from_seed(seed_bytes)
        }
        None => {
            let mut system_rng = scirs2_core::random::thread_rng();
            let seed: [u8; 32] = system_rng.random();
            SeedableRng::from_seed(seed)
        }
    };
    let mut result = Vec::with_capacity(size);
    if let Some(weights) = p {
        if weights.len() != n {
            return Err(StatsError::DimensionMismatch(
                "Length of weights must match length of array".to_string(),
            ));
        }
        let sum: f64 = weights.iter().sum();
        // `(NaN - 1.0).abs() > tol` is false. Without the finiteness test, a NaN
        // weight would pass validation and corrupt the sampling below.
        if !sum.is_finite() || (sum - 1.0).abs() > 1e-10 {
            return Err(StatsError::InvalidArgument(
                "Weights must sum to 1.0".to_string(),
            ));
        }
        let mut cumulative = Vec::with_capacity(n);
        let mut cum_sum = 0.0;
        // Index of the last element with a strictly positive weight. One exists
        // because the weights are non-negative and sum to ~1.
        let mut last_positive = 0;
        for (idx, &w) in weights.iter().enumerate() {
            if w < 0.0 {
                return Err(StatsError::InvalidArgument(
                    "Weights must be non-negative".to_string(),
                ));
            }
            if w > 0.0 {
                last_positive = idx;
            }
            cum_sum += w;
            cumulative.push(cum_sum);
        }
        if replace {
            // Element k owns the interval [cumulative[k-1], cumulative[k]).
            // Zero-weight elements therefore own an empty interval and are never
            // drawn. Searching only the prefix before `last_positive` also maps
            // an `r` beyond a slightly short final cumulative value (rounding)
            // to the last drawable element.
            let boundaries = &cumulative[..last_positive];
            for _ in 0..size {
                let r: f64 = rng.random();
                result.push(a[boundaries.partition_point(|&c| c <= r)]);
            }
        } else {
            // `indices[..n - i]` holds the elements not yet drawn. Each drawn
            // element is swapped to the end of that active prefix.
            let mut indices: Vec<usize> = (0..n).collect();
            for i in 0..size {
                let active = &indices[..n - i];
                let total_weight: f64 = active.iter().map(|&k| weights[k]).sum();
                // Without this guard the normalisation below divides by zero
                // and a zero-probability element would be returned.
                if total_weight <= 0.0 {
                    return Err(StatsError::InvalidArgument(
                        "Fewer non-zero weights than requested sample size when 'replace=false'"
                            .to_string(),
                    ));
                }
                let r: f64 = rng.random();
                let mut cum_weight = 0.0;
                let mut last_drawable = 0;
                let mut selected = None;
                for (pos, &k) in active.iter().enumerate() {
                    let w = weights[k] / total_weight;
                    cum_weight += w;
                    if w > 0.0 {
                        if r <= cum_weight {
                            selected = Some(pos);
                            break;
                        }
                        last_drawable = pos;
                    }
                }
                // Rounding can leave the final cumulative weight just below `r`.
                // In that case, fall back to the last element that can actually
                // be drawn.
                let selected = selected.unwrap_or(last_drawable);
                result.push(a[indices[selected]]);
                indices.swap(selected, n - i - 1);
            }
        }
    } else if replace {
        let uniform = scirs2_core::random::Uniform::new(0, n).map_err(|e| {
            StatsError::ComputationError(format!("Failed to create uniform distribution: {}", e))
        })?;
        for _ in 0..size {
            result.push(a[uniform.sample(&mut rng)]);
        }
    } else {
        // Partial Fisher–Yates shuffle: the first `size` slots become the sample.
        let mut indices: Vec<usize> = (0..n).collect();
        for i in 0..size {
            let j = rng.random_range(i..n);
            indices.swap(i, j);
            result.push(a[indices[i]]);
        }
    }
    Ok(Array1::from(result))
}
/// Returns a uniformly random permutation of the elements of `x`.
///
/// The input view is copied and shuffled with Fisher–Yates. `x` itself is not
/// modified.
///
/// # Errors
/// Returns [`StatsError::InvalidArgument`] if `x` is empty.
#[allow(dead_code)]
pub fn permutation<T>(x: &ArrayView1<T>, seed: Option<u64>) -> StatsResult<Array1<T>>
where
    T: Copy,
{
    let n = x.len();
    if n == 0 {
        return Err(StatsError::InvalidArgument(
            "Input array cannot be empty".to_string(),
        ));
    }
    let mut rng: StdRng = match seed {
        Some(seed_value) => {
            let mut seed_bytes = [0u8; 32];
            seed_bytes[..8].copy_from_slice(&seed_value.to_le_bytes());
            SeedableRng::from_seed(seed_bytes)
        }
        None => {
            let mut system_rng = scirs2_core::random::thread_rng();
            let seed: [u8; 32] = system_rng.random();
            SeedableRng::from_seed(seed)
        }
    };
    let mut result = Array1::from_iter(x.iter().copied());
    for i in (1..n).rev() {
        // Fisher–Yates: position `i` swaps with an index drawn from `0..=i`,
        // including `i` itself. Drawing from `0..i` would be Sattolo's
        // algorithm, which only yields single-cycle permutations, so no element
        // could ever stay in place.
        let j = rng.random_range(0..i + 1);
        result.swap(i, j);
    }
    Ok(result)
}
/// Returns a uniformly random permutation of the integers `0..n`.
///
/// # Errors
/// Returns [`StatsError::InvalidArgument`] if `n` is zero.
#[allow(dead_code)]
pub fn permutation_int(n: usize, seed: Option<u64>) -> StatsResult<Array1<usize>> {
    if n == 0 {
        return Err(StatsError::InvalidArgument(
            "Length must be positive".to_string(),
        ));
    }
    let mut rng: StdRng = match seed {
        Some(seed_value) => {
            let mut seed_bytes = [0u8; 32];
            seed_bytes[..8].copy_from_slice(&seed_value.to_le_bytes());
            SeedableRng::from_seed(seed_bytes)
        }
        None => {
            let mut system_rng = scirs2_core::random::thread_rng();
            let seed: [u8; 32] = system_rng.random();
            SeedableRng::from_seed(seed)
        }
    };
    let mut result = Array1::from_iter(0..n);
    for i in (1..n).rev() {
        // Fisher–Yates: the swap partner is drawn from `0..=i`, including `i`.
        // Excluding `i` (Sattolo's algorithm) would only produce cyclic
        // permutations, e.g. never the identity.
        let j = rng.random_range(0..i + 1);
        result.swap(i, j);
    }
    Ok(result)
}
/// Builds an `n_rows x n_cols` matrix of 0/1 entries. Each entry is
/// independently set to 1 when a uniform `f64` draw is below `density`.
///
/// # Errors
/// Returns [`StatsError::InvalidArgument`] if either dimension is zero, or if
/// `density` is not within `[0, 1]` (a NaN density is rejected too).
#[allow(dead_code)]
pub fn random_binary_matrix(
    n_rows: usize,
    n_cols: usize,
    density: f64,
    seed: Option<u64>,
) -> StatsResult<Array2<i32>> {
    let shape_ok = n_rows > 0 && n_cols > 0;
    if !shape_ok {
        return Err(StatsError::InvalidArgument(
            "Dimensions must be positive".to_string(),
        ));
    }
    let density_ok = (0.0..=1.0).contains(&density);
    if !density_ok {
        return Err(StatsError::InvalidArgument(
            "Density must be between 0 and 1".to_string(),
        ));
    }
    let mut rng: StdRng = if let Some(seed_value) = seed {
        let mut key = [0u8; 32];
        key[..8].copy_from_slice(&seed_value.to_le_bytes());
        SeedableRng::from_seed(key)
    } else {
        let mut entropy = scirs2_core::random::thread_rng();
        let key: [u8; 32] = entropy.random();
        SeedableRng::from_seed(key)
    };
    let mut matrix: Array2<i32> = Array2::zeros((n_rows, n_cols));
    // `iter_mut` visits cells in logical (row-major) order, so one number is
    // drawn per cell, row by row.
    for cell in matrix.iter_mut() {
        if rng.random::<f64>() < density {
            *cell = 1;
        }
    }
    Ok(matrix)
}
/// Draws `n_samples_` bootstrap resamples of `x`.
///
/// Returns an `(n_samples_, x.len())` matrix. Each row holds `x.len()` elements
/// drawn from `x` uniformly at random with replacement.
///
/// # Errors
/// * [`StatsError::InvalidArgument`] if `x` is empty or `n_samples_` is zero.
/// * [`StatsError::ComputationError`] if the uniform index distribution cannot
///   be constructed.
#[allow(dead_code)]
pub fn bootstrap_sample<T>(
    x: &ArrayView1<T>,
    n_samples_: usize,
    seed: Option<u64>,
) -> StatsResult<Array2<T>>
where
    T: Copy + scirs2_core::numeric::Zero,
{
    let n = x.len();
    if n == 0 {
        return Err(StatsError::InvalidArgument(
            "Input array cannot be empty".to_string(),
        ));
    }
    if n_samples_ == 0 {
        return Err(StatsError::InvalidArgument(
            "Number of samples must be positive".to_string(),
        ));
    }
    let mut rng: StdRng = match seed {
        Some(seed_value) => {
            let mut seed_bytes = [0u8; 32];
            seed_bytes[..8].copy_from_slice(&seed_value.to_le_bytes());
            SeedableRng::from_seed(seed_bytes)
        }
        None => {
            let mut system_rng = scirs2_core::random::thread_rng();
            let seed: [u8; 32] = system_rng.random();
            SeedableRng::from_seed(seed)
        }
    };
    // Propagate, rather than panic, on a construction failure, consistent
    // with `uniform`/`randint`.
    let uniform = scirs2_core::random::Uniform::new(0, n).map_err(|e| {
        StatsError::ComputationError(format!("Failed to create uniform distribution: {}", e))
    })?;
    let mut result: Array2<T> = Array2::zeros((n_samples_, n));
    for i in 0..n_samples_ {
        for j in 0..n {
            let idx = uniform.sample(&mut rng);
            result[[i, j]] = x[idx];
        }
    }
    Ok(result)
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use scirs2_core::ndarray::array;

    /// Seeded draws stay inside the distribution's support; size 0 is rejected.
    #[test]
    fn test_random_sample() {
        let unit = scirs2_core::random::Uniform::new(0.0, 1.0).expect("Operation failed");
        let draws = random_sample(100, &unit, Some(42)).expect("Operation failed");
        assert_eq!(draws.len(), 100);
        assert!(draws.iter().all(|&v| (0.0..1.0).contains(&v)));
        let empty_request =
            random_sample::<f64, scirs2_core::random::Uniform<f64>>(0, &unit, None);
        assert!(empty_request.is_err());
    }

    /// Values fall in `[low, high)`; empty/inverted ranges and size 0 fail.
    #[test]
    fn test_uniform() {
        let draws = uniform(10.0, 20.0, 50, Some(42)).expect("Operation failed");
        assert_eq!(draws.len(), 50);
        assert!(draws.iter().all(|&v| (10.0..20.0).contains(&v)));
        for (lo, hi, count) in [(5.0, 5.0, 10), (10.0, 5.0, 10), (0.0, 1.0, 0)] {
            assert!(uniform(lo, hi, count, None).is_err());
        }
    }

    /// Integers fall in `[low, high)`; empty/inverted ranges and size 0 fail.
    #[test]
    fn test_randint() {
        let draws = randint(1, 101, 100, Some(42)).expect("Operation failed");
        assert_eq!(draws.len(), 100);
        assert!(draws.iter().all(|&v| (1..=100).contains(&v)));
        for (lo, hi, count) in [(5, 5, 10), (10, 5, 10), (0, 10, 0)] {
            assert!(randint(lo, hi, count, None).is_err());
        }
    }

    /// Sample mean and variance of 1000 normal draws are near 0 and 1.
    #[test]
    fn test_randn() {
        let draws = randn(1000, Some(42)).expect("Operation failed");
        assert_eq!(draws.len(), 1000);
        let count: f64 = 1000.0;
        let mean = draws.iter().sum::<f64>() / count;
        let variance = draws.iter().map(|&v| (v - mean).powi(2)).sum::<f64>() / count;
        assert!(mean.abs() < 0.1);
        assert_relative_eq!(variance, 1.0, epsilon = 0.2);
        assert!(randn(0, None).is_err());
    }

    /// Covers uniform/weighted draws, with/without replacement, and bad inputs.
    #[test]
    fn test_choice() {
        let population = array![10, 20, 30, 40, 50];
        let view = population.view();

        let with_replacement = choice(&view, 10, true, None, Some(42)).expect("Operation failed");
        assert_eq!(with_replacement.len(), 10);
        assert!(with_replacement
            .iter()
            .all(|c| population.iter().any(|x| x == c)));

        let distinct = choice(&view, 3, false, None, Some(123)).expect("Operation failed");
        assert_eq!(distinct.len(), 3);
        for (i, first) in distinct.iter().enumerate() {
            for second in distinct.iter().skip(i + 1) {
                assert_ne!(first, second);
            }
        }

        let weights = array![0.1, 0.2, 0.3, 0.2, 0.2];
        let weighted =
            choice(&view, 5, true, Some(&weights.view()), Some(42)).expect("Operation failed");
        assert_eq!(weighted.len(), 5);

        assert!(choice(&view, 0, true, None, None).is_err());
        assert!(choice(&view, 10, false, None, None).is_err());
        let short_weights = array![0.5, 0.5];
        assert!(choice(&view, 2, true, Some(&short_weights.view()), None).is_err());
        let negative_weights = array![-0.1, 0.2, 0.3, 0.3, 0.3];
        assert!(choice(&view, 2, true, Some(&negative_weights.view()), None).is_err());
        let empty: Array1<i32> = array![];
        assert!(choice(&empty.view(), 1, true, None, None).is_err());
    }

    /// Every input element appears in the shuffled output; empty input fails.
    #[test]
    fn test_permutation() {
        let values = array![1, 2, 3, 4, 5];
        let shuffled = permutation(&values.view(), Some(42)).expect("Operation failed");
        assert_eq!(shuffled.len(), values.len());
        assert!(values.iter().all(|v| shuffled.iter().any(|s| s == v)));
        let empty: Array1<i32> = array![];
        assert!(permutation(&empty.view(), None).is_err());
    }

    /// Output contains each of `0..n`; `n == 0` fails.
    #[test]
    fn test_permutation_int() {
        let shuffled = permutation_int(10, Some(42)).expect("Operation failed");
        assert_eq!(shuffled.len(), 10);
        assert!((0..10).all(|k| shuffled.iter().any(|&s| s == k)));
        assert!(permutation_int(0, None).is_err());
    }

    /// Entries are 0/1 with roughly the requested density; bad args fail.
    #[test]
    fn test_random_binary_matrix() {
        let grid = random_binary_matrix(5, 5, 0.5, Some(42)).expect("Operation failed");
        assert_eq!(grid.shape(), &[5, 5]);
        assert!(grid.iter().all(|&cell| cell == 0 || cell == 1));
        let ones = grid.iter().filter(|&&cell| cell == 1).count();
        let observed_density = ones as f64 / 25.0;
        assert!(observed_density > 0.2 && observed_density < 0.8);
        for (rows, cols, density) in [(0, 5, 0.5), (5, 0, 0.5), (5, 5, -0.1), (5, 5, 1.1)] {
            assert!(random_binary_matrix(rows, cols, density, None).is_err());
        }
    }

    /// Output shape is `(n_samples, len)`; zero samples or empty input fail.
    #[test]
    fn test_bootstrap_sample() {
        let data = array![1.0, 2.0, 3.0, 4.0, 5.0];
        let resamples = bootstrap_sample(&data.view(), 10, Some(42)).expect("Operation failed");
        assert_eq!(resamples.shape(), &[10, 5]);
        assert!(bootstrap_sample(&data.view(), 0, None).is_err());
        let empty: Array1<f64> = array![];
        assert!(bootstrap_sample(&empty.view(), 10, None).is_err());
    }
}