// Submodules that make up this module; each one is public, so its items are
// also reachable through the full `<submodule>::<item>` path.
pub mod basic;
pub mod continuous;
// NOTE(review): nothing from `continuous_simplified` is re-exported below, so it
// is only reachable through its full path — confirm that this is intended.
pub mod continuous_simplified;
pub mod discrete;
// Re-exports: the functions listed here can be imported straight from this
// module without naming the submodule that defines them.
pub use basic::{normal_, rand, randint, randint_, randn, randperm, uniform_};
pub use discrete::{bernoulli, bernoulli_, multinomial};
pub use continuous::{
beta, cauchy, chi_squared, exponential_, f_distribution, gamma, log_normal, student_t, weibull,
};
#[cfg(test)]
mod tests {
    //! Smoke tests for the sampling functions re-exported by the parent module.
    //!
    //! Every sampling call that inspects values passes an explicit seed so the
    //! checked values do not change between runs.

    use super::*;
    use torsh_tensor::creation::ones;

    /// `rand` and `uniform_` must stay inside their half-open ranges, and `rand` /
    /// `randn` must return the requested shape.
    #[test]
    fn test_basic_random_operations() -> torsh_core::Result<()> {
        let tensor = rand(&[3, 4], None, None, Some(42))?;
        assert_eq!(tensor.shape().dims(), &[3, 4]);
        let data = tensor.data()?;
        assert!(
            data.iter().all(|v| (0.0..1.0).contains(v)),
            "Values should be in [0, 1)"
        );

        let uniform = uniform_(&[10], -2.0, 2.0, Some(42))?;
        let uniform_data = uniform.data()?;
        assert!(
            uniform_data.iter().all(|v| (-2.0..2.0).contains(v)),
            "Values should be in [-2, 2)"
        );

        // Only the shape is checked here; the mean of normal samples is covered
        // by `test_statistical_properties`.
        let normal = randn(&[100], Some(0.0), Some(1.0), Some(42))?;
        assert_eq!(normal.shape().dims(), &[100]);
        Ok(())
    }

    /// `multinomial` must yield values inside the index range of its weight tensor
    /// and `bernoulli_` must yield only 0 or 1.
    #[test]
    fn test_discrete_distributions() -> torsh_core::Result<()> {
        // Three equal weights, five draws.
        let probs = ones::<f32>(&[3])?;
        let samples = multinomial(&probs, 5, true, Some(42))?;
        assert_eq!(samples.shape().dims(), &[5]);
        let sample_data = samples.data()?;
        assert!(
            sample_data.iter().all(|v| (0.0..3.0).contains(v)),
            "Samples should be valid indices"
        );

        let bernoulli_tensor = bernoulli_(&[20], 0.5, Some(42))?;
        let bernoulli_data = bernoulli_tensor.data()?;
        assert!(
            bernoulli_data.iter().all(|&v| v == 0.0 || v == 1.0),
            "Bernoulli values should be 0 or 1"
        );
        Ok(())
    }

    /// Exponential and gamma samples must be strictly positive; beta samples must
    /// lie in the closed unit interval.
    #[test]
    fn test_continuous_distributions() -> torsh_core::Result<()> {
        let exp_tensor = exponential_(&[50], 2.0, Some(42))?;
        let exp_data = exp_tensor.data()?;
        assert!(
            exp_data.iter().all(|&v| v > 0.0),
            "Exponential values should be positive"
        );

        let gamma_tensor = gamma(&[30], 2.0, Some(1.0), Some(42))?;
        let gamma_data = gamma_tensor.data()?;
        assert!(
            gamma_data.iter().all(|&v| v > 0.0),
            "Gamma values should be positive"
        );

        let beta_tensor = beta(&[25], 2.0, 2.0, Some(42))?;
        let beta_data = beta_tensor.data()?;
        assert!(
            beta_data.iter().all(|v| (0.0..=1.0).contains(v)),
            "Beta values should be in [0,1]"
        );
        Ok(())
    }

    /// `randint` / `randint_` must produce whole numbers inside the half-open range.
    #[test]
    fn test_random_integers() -> torsh_core::Result<()> {
        let int_tensor = randint(&[20], 10, Some(42))?;
        let int_data = int_tensor.data()?;
        // The range check alone would also accept arbitrary floats, so the
        // fractional part is checked as well.
        assert!(
            int_data
                .iter()
                .all(|v| (0.0..10.0).contains(v) && v.fract() == 0.0),
            "Random integers should be in [0, 10)"
        );

        let int_range = randint_(&[15], -5, 5, Some(42))?;
        let range_data = int_range.data()?;
        assert!(
            range_data
                .iter()
                .all(|v| (-5.0..5.0).contains(v) && v.fract() == 0.0),
            "Random integers should be in [-5, 5)"
        );
        Ok(())
    }

    /// `randperm(10, ..)` must contain each of `0..10` exactly once.
    #[test]
    fn test_random_permutation() -> torsh_core::Result<()> {
        let perm = randperm(10, Some(42))?;
        assert_eq!(perm.shape().dims(), &[10]);
        let perm_data = perm.data()?;

        // Sorting a permutation of 0..n must give back exactly 0..n.
        // `Iterator::eq` compares lengths as well as elements, so a missing or
        // extra entry fails the check too.
        let mut values: Vec<usize> = perm_data.iter().map(|&x| x as usize).collect();
        values.sort_unstable();
        assert!(
            values.iter().copied().eq(0..10usize),
            "Permutation should contain all values 0..n-1 exactly once"
        );
        Ok(())
    }

    /// The same seed must reproduce the same samples; a different seed must not.
    #[test]
    fn test_seeded_reproducibility() -> torsh_core::Result<()> {
        let tensor1 = rand(&[5], None, None, Some(123))?;
        let tensor2 = rand(&[5], None, None, Some(123))?;
        let data1 = tensor1.data()?;
        let data2 = tensor2.data()?;
        // `Iterator::eq` fails on a length mismatch, which a `zip`-based
        // comparison would silently truncate away.
        assert!(
            data1.iter().eq(data2.iter()),
            "Same seed should produce identical results"
        );

        // Same shape as `tensor1` so that every generated value takes part in
        // the comparison below.
        let tensor3 = rand(&[5], None, None, Some(456))?;
        let data3 = tensor3.data()?;
        let same_count = data1
            .iter()
            .zip(data3.iter())
            .filter(|&(a, b)| (a - b).abs() < 1e-6)
            .count();
        // A couple of near-coincidences are tolerated.
        assert!(
            same_count < 3,
            "Different seeds should produce different results"
        );
        Ok(())
    }

    /// Each call below passes an argument combination that must be rejected with `Err`.
    #[test]
    fn test_parameter_validation() -> torsh_core::Result<()> {
        assert!(
            uniform_(&[5], 2.0, 1.0, None).is_err(),
            "uniform_ must reject a lower bound above the upper bound"
        );
        assert!(
            normal_(&[5], 0.0, -1.0, None).is_err(),
            "normal_ must reject a negative parameter"
        );
        assert!(
            randint_(&[5], 5, 5, None).is_err(),
            "randint_ must reject an empty range"
        );
        assert!(
            exponential_(&[5], -1.0, None).is_err(),
            "exponential_ must reject a negative parameter"
        );
        assert!(
            gamma(&[5], -1.0, None, None).is_err(),
            "gamma must reject a negative parameter"
        );
        assert!(
            beta(&[5], -1.0, 1.0, None).is_err(),
            "beta must reject a negative parameter"
        );
        assert!(
            chi_squared(&[5], -1.0, None).is_err(),
            "chi_squared must reject a negative parameter"
        );
        Ok(())
    }

    /// Sample means over 10 000 draws must land near the expected means.
    #[test]
    fn test_statistical_properties() -> torsh_core::Result<()> {
        let n = 10_000;

        let uniform = uniform_(&[n], 0.0, 1.0, Some(42))?;
        let uniform_data = uniform.data()?;
        let mean: f32 = uniform_data.iter().sum::<f32>() / n as f32;
        assert!(
            (mean - 0.5).abs() < 0.05,
            "Uniform mean should be approximately 0.5"
        );

        let normal = normal_(&[n], 0.0, 1.0, Some(42))?;
        let normal_data = normal.data()?;
        let normal_mean: f32 = normal_data.iter().sum::<f32>() / n as f32;
        assert!(
            normal_mean.abs() < 0.1,
            "Normal mean should be approximately 0"
        );

        let exp = exponential_(&[n], 2.0, Some(42))?;
        let exp_data = exp.data()?;
        let exp_mean: f32 = exp_data.iter().sum::<f32>() / n as f32;
        // 1 / lambda, with lambda = 2.0 as passed to `exponential_` above.
        let expected_mean = 1.0 / 2.0;
        assert!(
            (exp_mean - expected_mean).abs() < 0.1,
            "Exponential mean should be approximately 1/lambda"
        );
        Ok(())
    }
}