use scirs2_core::random::prelude::*;
use scirs2_core::random::{FisherF, RandBeta};
use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::Tensor;
/// Builds a tensor of the given `shape` filled with independent draws from an
/// exponential distribution with rate `lambd`.
///
/// `shape.iter().product()` samples are drawn from the thread-local RNG;
/// `_generator` is accepted for API compatibility but currently ignored.
///
/// # Errors
/// Returns `TorshError::InvalidArgument` when `lambd <= 0.0`, when the
/// distribution constructor rejects `lambd`, or when the tensor cannot be built.
pub fn exponential_(shape: &[usize], lambd: f32, _generator: Option<u64>) -> TorshResult<Tensor> {
    if lambd <= 0.0 {
        return Err(TorshError::InvalidArgument(
            "exponential: lambda must be greater than 0".to_string(),
        ));
    }

    let dist = Exponential::new(lambd).map_err(|err| {
        TorshError::InvalidArgument(format!(
            "exponential: failed to create distribution: {}",
            err
        ))
    })?;

    let numel = shape.iter().product::<usize>();
    let mut rng = thread_rng();
    let draws: Vec<_> = (0..numel).map(|_| rng.sample(&dist)).collect();
    Tensor::from_vec(draws, shape)
}
/// Builds a tensor of the given `shape` filled with independent draws from a
/// Gamma distribution parameterised by `concentration` (shape) and `rate`.
///
/// # Arguments
/// * `shape` - output tensor shape; `shape.iter().product()` samples are drawn.
/// * `concentration` - Gamma shape parameter; must be greater than 0.
/// * `rate` - inverse scale; `None` means `1.0`. Must be greater than 0.
/// * `_generator` - currently ignored; sampling always uses the thread-local RNG.
///
/// # Errors
/// Returns `TorshError::InvalidArgument` when `concentration <= 0.0`,
/// `rate <= 0.0`, when the distribution constructor rejects the parameters,
/// or when the tensor cannot be built.
pub fn gamma(
    shape: &[usize],
    concentration: f32,
    rate: Option<f32>,
    _generator: Option<u64>,
) -> TorshResult<Tensor> {
    if concentration <= 0.0 {
        return Err(TorshError::InvalidArgument(
            "gamma: concentration must be greater than 0".to_string(),
        ));
    }
    let actual_rate = rate.unwrap_or(1.0);
    if actual_rate <= 0.0 {
        return Err(TorshError::InvalidArgument(
            "gamma: rate must be greater than 0".to_string(),
        ));
    }

    // `Gamma::new` is parameterised by (shape, scale), not (shape, rate).
    // Passing the rate straight through would give samples with mean
    // `concentration * rate` instead of `concentration / rate`.
    let scale = 1.0 / actual_rate;
    let gamma_dist = Gamma::new(concentration, scale).map_err(|e| {
        TorshError::InvalidArgument(format!("gamma: failed to create distribution: {}", e))
    })?;

    let size: usize = shape.iter().product();
    let mut rng = thread_rng();
    let values: Vec<_> = (0..size).map(|_| rng.sample(&gamma_dist)).collect();
    Tensor::from_vec(values, shape)
}
/// Builds a tensor of the given `shape` filled with independent draws from a
/// Beta(`alpha`, `beta`) distribution.
///
/// `shape.iter().product()` samples are drawn from the thread-local RNG;
/// `_generator` is accepted for API compatibility but currently ignored.
///
/// # Errors
/// Returns `TorshError::InvalidArgument` when either parameter is `<= 0.0`,
/// when the distribution constructor rejects them, or when the tensor cannot
/// be built.
pub fn beta(
    shape: &[usize],
    alpha: f32,
    beta: f32,
    _generator: Option<u64>,
) -> TorshResult<Tensor> {
    if alpha <= 0.0 || beta <= 0.0 {
        return Err(TorshError::InvalidArgument(
            "beta: alpha and beta must be greater than 0".to_string(),
        ));
    }

    let dist = RandBeta::new(alpha, beta).map_err(|err| {
        TorshError::InvalidArgument(format!("beta: failed to create distribution: {}", err))
    })?;

    let count = shape.iter().product::<usize>();
    let mut rng = thread_rng();
    let draws: Vec<_> = std::iter::repeat_with(|| rng.sample(&dist))
        .take(count)
        .collect();
    Tensor::from_vec(draws, shape)
}
/// Builds a tensor of the given `shape` filled with independent draws from a
/// chi-squared distribution with `df` degrees of freedom.
///
/// `shape.iter().product()` samples are drawn from the thread-local RNG;
/// `_generator` is accepted for API compatibility but currently ignored.
///
/// # Errors
/// Returns `TorshError::InvalidArgument` when `df <= 0.0`, when the
/// distribution constructor rejects `df`, or when the tensor cannot be built.
pub fn chi_squared(shape: &[usize], df: f32, _generator: Option<u64>) -> TorshResult<Tensor> {
    if df <= 0.0 {
        return Err(TorshError::InvalidArgument(
            "chi_squared: degrees of freedom must be greater than 0".to_string(),
        ));
    }

    let dist = ChiSquared::new(df).map_err(|err| {
        TorshError::InvalidArgument(format!(
            "chi_squared: failed to create distribution: {}",
            err
        ))
    })?;

    let numel = shape.iter().product::<usize>();
    let mut rng = thread_rng();
    let draws: Vec<_> = (0..numel).map(|_| rng.sample(&dist)).collect();
    Tensor::from_vec(draws, shape)
}
/// Builds a tensor of the given `shape` filled with independent draws from a
/// Student's t distribution with `df` degrees of freedom.
///
/// `shape.iter().product()` samples are drawn from the thread-local RNG;
/// `_generator` is accepted for API compatibility but currently ignored.
///
/// # Errors
/// Returns `TorshError::InvalidArgument` when `df <= 0.0`, when the
/// distribution constructor rejects `df`, or when the tensor cannot be built.
pub fn student_t(shape: &[usize], df: f32, _generator: Option<u64>) -> TorshResult<Tensor> {
    if df <= 0.0 {
        return Err(TorshError::InvalidArgument(
            "student_t: degrees of freedom must be greater than 0".to_string(),
        ));
    }

    let dist = StudentT::new(df).map_err(|err| {
        TorshError::InvalidArgument(format!(
            "student_t: failed to create distribution: {}",
            err
        ))
    })?;

    let count = shape.iter().product::<usize>();
    let mut rng = thread_rng();
    let draws: Vec<_> = std::iter::repeat_with(|| rng.sample(&dist))
        .take(count)
        .collect();
    Tensor::from_vec(draws, shape)
}
/// Builds a tensor of the given `shape` filled with independent draws from an
/// F (Fisher–Snedecor) distribution with `dfnum` numerator and `dfden`
/// denominator degrees of freedom.
///
/// `shape.iter().product()` samples are drawn from the thread-local RNG;
/// `_generator` is accepted for API compatibility but currently ignored.
///
/// # Errors
/// Returns `TorshError::InvalidArgument` when either degrees-of-freedom value
/// is `<= 0.0`, when the distribution constructor rejects them, or when the
/// tensor cannot be built.
pub fn f_distribution(
    shape: &[usize],
    dfnum: f32,
    dfden: f32,
    _generator: Option<u64>,
) -> TorshResult<Tensor> {
    // `<=` (not `!(> 0)`) so NaN is left for the constructor to reject, as before.
    if [dfnum, dfden].iter().any(|&dof| dof <= 0.0) {
        return Err(TorshError::InvalidArgument(
            "f_distribution: degrees of freedom must be greater than 0".to_string(),
        ));
    }

    let dist = FisherF::new(dfnum, dfden).map_err(|err| {
        TorshError::InvalidArgument(format!(
            "f_distribution: failed to create distribution: {}",
            err
        ))
    })?;

    let numel = shape.iter().product::<usize>();
    let mut rng = thread_rng();
    let draws: Vec<_> = (0..numel).map(|_| rng.sample(&dist)).collect();
    Tensor::from_vec(draws, shape)
}
/// Builds a tensor of the given `shape` filled with independent draws from a
/// log-normal distribution.
///
/// `loc` and `scale` are passed to `LogNormal::new(loc, scale)`; with the
/// `rand_distr`-style API these are the mean and standard deviation of the
/// underlying normal (TODO confirm against the scirs2 re-export).
/// `shape.iter().product()` samples are drawn from the thread-local RNG;
/// `_generator` is accepted for API compatibility but currently ignored.
///
/// # Errors
/// Returns `TorshError::InvalidArgument` when `scale <= 0.0`, when the
/// distribution constructor rejects the parameters, or when the tensor cannot
/// be built.
pub fn log_normal(
    shape: &[usize],
    loc: f32,
    scale: f32,
    _generator: Option<u64>,
) -> TorshResult<Tensor> {
    if scale <= 0.0 {
        return Err(TorshError::InvalidArgument(
            "log_normal: scale must be greater than 0".to_string(),
        ));
    }

    let dist = LogNormal::new(loc, scale).map_err(|err| {
        TorshError::InvalidArgument(format!(
            "log_normal: failed to create distribution: {}",
            err
        ))
    })?;

    let count = shape.iter().product::<usize>();
    let mut rng = thread_rng();
    let draws: Vec<_> = std::iter::repeat_with(|| rng.sample(&dist))
        .take(count)
        .collect();
    Tensor::from_vec(draws, shape)
}
/// Builds a tensor of the given `shape` filled with independent draws from a
/// Weibull distribution.
///
/// Note the parameter order: the FIRST argument `shape_param` is the Weibull
/// shape `k`, while the second argument `shape` is the output tensor shape.
/// `Weibull::new` takes `(scale, shape)`, hence the swapped order at the call.
/// `shape.iter().product()` samples are drawn from the thread-local RNG;
/// `_generator` is accepted for API compatibility but currently ignored.
///
/// # Errors
/// Returns `TorshError::InvalidArgument` when `shape_param` or `scale` is
/// `<= 0.0`, when the distribution constructor rejects them, or when the
/// tensor cannot be built.
pub fn weibull(
    shape_param: f32,
    shape: &[usize],
    scale: f32,
    _generator: Option<u64>,
) -> TorshResult<Tensor> {
    if shape_param <= 0.0 || scale <= 0.0 {
        return Err(TorshError::InvalidArgument(
            "weibull: shape and scale must be greater than 0".to_string(),
        ));
    }

    let dist = Weibull::new(scale, shape_param).map_err(|err| {
        TorshError::InvalidArgument(format!("weibull: failed to create distribution: {}", err))
    })?;

    let numel = shape.iter().product::<usize>();
    let mut rng = thread_rng();
    let draws: Vec<_> = (0..numel).map(|_| rng.sample(&dist)).collect();
    Tensor::from_vec(draws, shape)
}
/// Builds a tensor of the given `shape` filled with independent draws from a
/// Cauchy distribution with location `median` and scale `sigma`.
///
/// `shape.iter().product()` samples are drawn from the thread-local RNG;
/// `_generator` is accepted for API compatibility but currently ignored.
///
/// # Errors
/// Returns `TorshError::InvalidArgument` when `sigma <= 0.0`, when the
/// distribution constructor rejects the parameters, or when the tensor cannot
/// be built.
pub fn cauchy(
    shape: &[usize],
    median: f32,
    sigma: f32,
    _generator: Option<u64>,
) -> TorshResult<Tensor> {
    if sigma <= 0.0 {
        return Err(TorshError::InvalidArgument(
            "cauchy: sigma must be greater than 0".to_string(),
        ));
    }

    let dist = Cauchy::new(median, sigma).map_err(|err| {
        TorshError::InvalidArgument(format!("cauchy: failed to create distribution: {}", err))
    })?;

    let count = shape.iter().product::<usize>();
    let mut rng = thread_rng();
    let draws: Vec<_> = std::iter::repeat_with(|| rng.sample(&dist))
        .take(count)
        .collect();
    Tensor::from_vec(draws, shape)
}
/// Draws `num_samples` independent samples from a Dirichlet distribution with
/// concentration parameters `alpha`.
///
/// Each row is produced by drawing `Gamma(alpha_i, 1)` for every component and
/// dividing by the row sum, so every row sums to 1 (up to `f32` rounding).
///
/// # Arguments
/// * `alpha` - concentration parameters; must be non-empty, finite and > 0.
/// * `num_samples` - number of rows to draw; must be greater than 0.
/// * `generator` - optional seed. `Some(seed)` draws from
///   `scirs2_core::random::seeded_rng(seed)`, so equal seeds give equal output.
///   `None` uses the thread-local RNG.
///
/// # Returns
/// A tensor of shape `[num_samples, alpha.len()]`.
///
/// # Errors
/// Returns `TorshError::InvalidArgument` in any of these cases:
/// * `alpha` is empty, or contains a NaN, non-positive, or infinite value;
/// * `num_samples` is 0;
/// * a Gamma distribution cannot be constructed;
/// * a row repeatedly fails to normalise because all of its Gamma draws
///   underflow to 0 in `f32`. This happens for extremely small alpha.
pub fn dirichlet(alpha: &[f32], num_samples: usize, generator: Option<u64>) -> TorshResult<Tensor> {
    if alpha.is_empty() {
        return Err(TorshError::InvalidArgument(
            "dirichlet: alpha must have at least one element".to_string(),
        ));
    }
    for &a in alpha {
        // Reject NaN explicitly: `NaN <= 0.0` is false and would slip through.
        if a.is_nan() || a <= 0.0 {
            return Err(TorshError::InvalidArgument(format!(
                "dirichlet: all alpha values must be positive, got {}",
                a
            )));
        }
        // An infinite concentration yields infinite Gamma draws and inf/inf = NaN rows.
        if a.is_infinite() {
            return Err(TorshError::InvalidArgument(format!(
                "dirichlet: all alpha values must be finite, got {}",
                a
            )));
        }
    }
    if num_samples == 0 {
        return Err(TorshError::InvalidArgument(
            "dirichlet: num_samples must be greater than 0".to_string(),
        ));
    }

    // Gamma(alpha_i, 1) depends only on alpha_i, so build each distribution once
    // instead of once per component per row.
    let gamma_dists = alpha
        .iter()
        .map(|&alpha_i| {
            Gamma::new(alpha_i, 1.0).map_err(|e| {
                TorshError::InvalidArgument(format!(
                    "dirichlet: failed to create Gamma distribution: {}",
                    e
                ))
            })
        })
        .collect::<TorshResult<Vec<_>>>()?;

    let k = alpha.len();
    // Only the RNG differs between the two paths; the sampling loop is shared.
    let samples = match generator {
        Some(seed) => {
            let mut rng = scirs2_core::random::seeded_rng(seed);
            dirichlet_rows(&gamma_dists, num_samples, |dist| rng.sample(dist))?
        }
        None => {
            let mut rng = thread_rng();
            dirichlet_rows(&gamma_dists, num_samples, |dist| rng.sample(dist))?
        }
    };
    Tensor::from_vec(samples, &[num_samples, k])
}

/// Produces a row-major `[num_samples, dists.len()]` buffer of normalised draws.
///
/// `draw` samples one value from a component distribution. Components are drawn
/// in order, row by row, so a seeded RNG sees the same call sequence as a plain
/// nested loop.
///
/// A row whose draws sum to 0 cannot be normalised: every component underflowed
/// to 0 in `f32`. A row whose draws sum to a non-finite value cannot be
/// normalised either. Such a row is redrawn up to `MAX_REDRAWS` times instead of
/// emitting NaN. This conditions on representable rows, which slightly biases
/// the output only for extreme alpha. If the redraws run out, an error is
/// returned.
fn dirichlet_rows<D>(
    dists: &[D],
    num_samples: usize,
    mut draw: impl FnMut(&D) -> f32,
) -> TorshResult<Vec<f32>> {
    const MAX_REDRAWS: usize = 100;

    let k = dists.len();
    let mut samples = Vec::with_capacity(num_samples * k);
    // Reused across rows to avoid one allocation per sample.
    let mut row = Vec::with_capacity(k);
    for _ in 0..num_samples {
        let mut redraws = 0;
        let sum = loop {
            row.clear();
            let mut sum = 0.0f32;
            for dist in dists {
                let g = draw(dist);
                row.push(g);
                sum += g;
            }
            if sum > 0.0 && sum.is_finite() {
                break sum;
            }
            redraws += 1;
            if redraws > MAX_REDRAWS {
                return Err(TorshError::InvalidArgument(format!(
                    "dirichlet: Gamma draws summed to {} after {} redraws; alpha values are too extreme for f32 sampling",
                    sum, MAX_REDRAWS
                )));
            }
        };
        samples.extend(row.iter().map(|&g| g / sum));
    }
    Ok(samples)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Number of rows drawn by the mean-based statistical tests.
    const STAT_SAMPLES: usize = 100;

    /// Mean of `column` in a row-major buffer with `k` columns.
    fn column_mean(data: &[f32], k: usize, column: usize) -> f32 {
        let rows = data.len() / k;
        data.iter().skip(column).step_by(k).sum::<f32>() / rows as f32
    }

    #[test]
    fn test_dirichlet_basic() -> TorshResult<()> {
        let alpha: [f32; 3] = [1.0, 1.0, 1.0];
        let num_samples = 1000;
        let samples = dirichlet(&alpha, num_samples, Some(42))?;
        assert_eq!(samples.shape().dims(), &[num_samples, alpha.len()]);

        let data = samples.to_vec()?;
        for (i, row) in data.chunks_exact(alpha.len()).enumerate() {
            for &v in row {
                assert!(
                    v.is_finite() && (0.0..=1.0).contains(&v),
                    "Row {} has out-of-range value {}",
                    i,
                    v
                );
            }
            let row_sum: f32 = row.iter().sum();
            assert!(
                (row_sum - 1.0).abs() < 1e-4,
                "Row {} sum is {}, expected 1.0",
                i,
                row_sum
            );
        }
        Ok(())
    }

    #[test]
    fn test_dirichlet_validation() {
        assert!(dirichlet(&[], 10, None).is_err());
        assert!(dirichlet(&[-1.0, 1.0], 10, None).is_err());
        assert!(dirichlet(&[0.0, 1.0], 10, None).is_err());
        assert!(dirichlet(&[f32::NAN, 1.0], 10, None).is_err());
        assert!(dirichlet(&[1.0, 1.0], 0, None).is_err());
    }

    #[test]
    fn test_dirichlet_reproducibility() -> TorshResult<()> {
        let alpha: [f32; 3] = [2.0, 3.0, 5.0];
        let num_samples = 100;
        let seed = Some(12345);
        let data1 = dirichlet(&alpha, num_samples, seed)?.to_vec()?;
        let data2 = dirichlet(&alpha, num_samples, seed)?.to_vec()?;
        assert_eq!(data1.len(), data2.len());
        for (v1, v2) in data1.iter().zip(data2.iter()) {
            assert!(
                (v1 - v2).abs() < 1e-6,
                "Reproducibility failed: {} vs {}",
                v1,
                v2
            );
        }
        Ok(())
    }

    #[test]
    fn test_dirichlet_concentration() -> TorshResult<()> {
        // Large, equal concentrations pull every component towards 1/k.
        let high_alpha: [f32; 3] = [100.0, 100.0, 100.0];
        let k = high_alpha.len();
        let data = dirichlet(&high_alpha, STAT_SAMPLES, Some(42))?.to_vec()?;
        let expected_mean = 1.0 / k as f32;
        let mean = column_mean(&data, k, 0);
        assert!(
            (mean - expected_mean).abs() < 0.05,
            "Mean {} too far from expected {}",
            mean,
            expected_mean
        );
        Ok(())
    }

    #[test]
    fn test_dirichlet_asymmetric() -> TorshResult<()> {
        // The component with the larger alpha should receive more mass on average.
        let alpha: [f32; 3] = [10.0, 1.0, 1.0];
        let k = alpha.len();
        let data = dirichlet(&alpha, STAT_SAMPLES, Some(42))?.to_vec()?;
        let mean0 = column_mean(&data, k, 0);
        let mean1 = column_mean(&data, k, 1);
        assert!(
            mean0 > mean1,
            "First dimension mean {} should be greater than second {}",
            mean0,
            mean1
        );
        Ok(())
    }
}