pub mod bernoulli;
pub mod beta;
pub mod categorical;
pub mod distribution;
pub mod exponential;
pub mod gamma;
pub mod normal;
pub mod uniform;
use crate::error::{RusTorchError, RusTorchResult};
pub use bernoulli::Bernoulli;
pub use beta::Beta;
pub use categorical::Categorical;
pub use distribution::Distribution;
pub use exponential::Exponential;
pub use gamma::Gamma;
pub use normal::Normal;
pub use uniform::Uniform;
use crate::tensor::Tensor;
use num_traits::Float;
/// Common interface for the probability distributions in this module.
///
/// Every query returns a [`Tensor`] of element type `T`; failures are reported
/// through [`RusTorchResult`] rather than by panicking.
pub trait DistributionTrait<T: Float + 'static> {
    /// Draws samples arranged in the given `shape`.
    ///
    /// What `None` means is up to each implementor (presumably a single draw
    /// with the distribution's own batch shape — confirm per implementation).
    fn sample(&self, shape: Option<&[usize]>) -> RusTorchResult<Tensor<T>>;

    /// Draws `n` samples; shorthand for `self.sample(Some(&[n]))`.
    fn sample_n(&self, n: usize) -> RusTorchResult<Tensor<T>> {
        self.sample(Some(&[n]))
    }

    /// Log-probability of `value` under this distribution.
    fn log_prob(&self, value: &Tensor<T>) -> RusTorchResult<Tensor<T>>;

    /// Probability of `value`, computed element-wise as `exp(log_prob(value))`.
    ///
    /// The result has the same shape as the tensor returned by `log_prob`.
    ///
    /// # Errors
    /// Propagates any error returned by [`DistributionTrait::log_prob`].
    fn prob(&self, value: &Tensor<T>) -> RusTorchResult<Tensor<T>> {
        let log_p = self.log_prob(value)?;
        // Iterate in logical order instead of `as_slice().unwrap()`, which
        // panics whenever the backing array is not contiguous.
        let prob_data: Vec<T> = log_p.data.iter().map(|&x| x.exp()).collect();
        Ok(Tensor::from_vec(prob_data, log_p.shape().to_vec()))
    }

    /// Cumulative distribution function evaluated at `value`.
    fn cdf(&self, value: &Tensor<T>) -> RusTorchResult<Tensor<T>>;

    /// Inverse of the cumulative distribution function evaluated at `value`.
    fn icdf(&self, value: &Tensor<T>) -> RusTorchResult<Tensor<T>>;

    /// Mean of the distribution.
    fn mean(&self) -> RusTorchResult<Tensor<T>>;

    /// Variance of the distribution.
    fn variance(&self) -> RusTorchResult<Tensor<T>>;

    /// Standard deviation, computed element-wise as `sqrt(variance())`.
    ///
    /// # Errors
    /// Propagates any error returned by [`DistributionTrait::variance`].
    fn stddev(&self) -> RusTorchResult<Tensor<T>> {
        let var = self.variance()?;
        // Same layout-agnostic iteration as `prob`; avoids a panic on
        // non-contiguous storage.
        let stddev_data: Vec<T> = var.data.iter().map(|&x| x.sqrt()).collect();
        Ok(Tensor::from_vec(stddev_data, var.shape().to_vec()))
    }

    /// Entropy of the distribution.
    fn entropy(&self) -> RusTorchResult<Tensor<T>>;
}
/// Stateless namespace for the parameter-validation and random-number helpers
/// shared by the distributions in this module.
#[derive(Debug, Clone, Copy)]
pub struct DistributionUtils;

impl DistributionUtils {
    /// Returns the first element of `param`, in logical (row-major) order, for
    /// which `is_invalid` returns `true`.
    ///
    /// Uses `iter()` rather than `as_slice()` so non-contiguous tensors are
    /// handled instead of causing a panic.
    fn first_invalid<T: Float>(param: &Tensor<T>, is_invalid: impl Fn(T) -> bool) -> Option<T> {
        param.data.iter().copied().find(|&val| is_invalid(val))
    }

    /// Checks that every element of `p` lies in the closed interval `[0, 1]`.
    ///
    /// # Errors
    /// Returns an invalid-parameter error reporting the first element outside
    /// `[0, 1]`. NaN is treated as out of range.
    pub fn validate_probability<T: Float + std::fmt::Display>(p: &Tensor<T>) -> RusTorchResult<()> {
        // NaN compares false against both bounds, so it must be tested
        // explicitly or it would be accepted as a valid probability.
        match Self::first_invalid(p, |val| val.is_nan() || val < T::zero() || val > T::one()) {
            Some(val) => Err(RusTorchError::invalid_parameter(&format!(
                "Probability must be in [0, 1], got {}",
                val
            ))),
            None => Ok(()),
        }
    }

    /// Checks that every element of `param` is strictly greater than zero.
    ///
    /// `name` is only used to build the error message.
    ///
    /// # Errors
    /// Returns an invalid-parameter error reporting the first element that is
    /// `<= 0` or NaN.
    pub fn validate_positive<T: Float + std::fmt::Display>(
        param: &Tensor<T>,
        name: &str,
    ) -> RusTorchResult<()> {
        // Explicit NaN check: `NaN <= 0` is false and would otherwise pass.
        match Self::first_invalid(param, |val| val.is_nan() || val <= T::zero()) {
            Some(val) => Err(RusTorchError::invalid_parameter(&format!(
                "{} must be positive, got {}",
                name, val
            ))),
            None => Ok(()),
        }
    }

    /// Checks that every element of `param` is greater than or equal to zero.
    ///
    /// `name` is only used to build the error message.
    ///
    /// # Errors
    /// Returns an invalid-parameter error reporting the first element that is
    /// `< 0` or NaN.
    pub fn validate_non_negative<T: Float + std::fmt::Display>(
        param: &Tensor<T>,
        name: &str,
    ) -> RusTorchResult<()> {
        // Explicit NaN check: `NaN < 0` is false and would otherwise pass.
        match Self::first_invalid(param, |val| val.is_nan() || val < T::zero()) {
            Some(val) => Err(RusTorchError::invalid_parameter(&format!(
                "{} must be non-negative, got {}",
                name, val
            ))),
            None => Ok(()),
        }
    }

    /// Returns a tensor of the given `shape` filled with samples drawn from the
    /// half-open range `[0, 1)` using the thread-local RNG.
    pub fn random_uniform<T>(shape: &[usize]) -> Tensor<T>
    where
        T: Float + 'static + rand::distributions::uniform::SampleUniform,
    {
        use rand::{thread_rng, Rng};
        let mut rng = thread_rng();
        let size: usize = shape.iter().product();
        let data: Vec<T> = (0..size)
            .map(|_| rng.gen_range(T::zero()..T::one()))
            .collect();
        Tensor::from_vec(data, shape.to_vec())
    }

    /// Returns a tensor of the given `shape` filled with standard-normal
    /// samples produced by the Box–Muller transform.
    ///
    /// Each pair of uniform draws yields two normal values; for an odd element
    /// count, the final pair's second value is discarded.
    pub fn random_normal<T>(shape: &[usize]) -> Tensor<T>
    where
        T: Float + 'static + rand::distributions::uniform::SampleUniform,
    {
        use rand::{thread_rng, Rng};
        let mut rng = thread_rng();
        let size: usize = shape.iter().product();
        let mut data = Vec::with_capacity(size);
        let pi = T::from(std::f64::consts::PI).unwrap();
        let two = T::from(2.0).unwrap();
        // Lower bound keeps `u1` away from 0 so `ln(u1)` stays finite.
        let u1_min = T::from(1e-10).unwrap();
        for _ in 0..size.div_ceil(2) {
            let u1: T = rng.gen_range(u1_min..T::one());
            let u2: T = rng.gen_range(T::zero()..T::one());
            // Radius and angle are shared by both outputs of the transform.
            let radius = (-two * u1.ln()).sqrt();
            let theta = two * pi * u2;
            data.push(radius * theta.cos());
            if data.len() < size {
                data.push(radius * theta.sin());
            }
        }
        Tensor::from_vec(data, shape.to_vec())
    }

    /// Returns a single sample drawn from the half-open range `[0, 1)`.
    pub fn random_uniform_scalar<T>() -> T
    where
        T: Float + 'static + rand::distributions::uniform::SampleUniform,
    {
        use rand::{thread_rng, Rng};
        let mut rng = thread_rng();
        rng.gen_range(T::zero()..T::one())
    }

    /// Returns a single standard-normal sample using the cosine branch of
    /// the Box–Muller transform.
    pub fn random_normal_scalar<T>() -> T
    where
        T: Float + 'static + rand::distributions::uniform::SampleUniform,
    {
        use rand::{thread_rng, Rng};
        let mut rng = thread_rng();
        let pi = T::from(std::f64::consts::PI).unwrap();
        let two = T::from(2.0).unwrap();
        // Lower bound keeps `u1` away from 0 so `ln(u1)` stays finite.
        let u1: T = rng.gen_range(T::from(1e-10).unwrap()..T::one());
        let u2: T = rng.gen_range(T::zero()..T::one());
        (-two * u1.ln()).sqrt() * (two * pi * u2).cos()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Each validator accepts in-range tensors and rejects ones containing an
    /// out-of-range element.
    #[test]
    fn test_validation_utils() {
        let valid_prob = Tensor::from_vec(vec![0.0f32, 0.5, 1.0], vec![3]);
        assert!(DistributionUtils::validate_probability(&valid_prob).is_ok());
        let invalid_prob = Tensor::from_vec(vec![0.0f32, 0.5, 1.5], vec![3]);
        assert!(DistributionUtils::validate_probability(&invalid_prob).is_err());

        let positive = Tensor::from_vec(vec![0.1f32, 1.0, 2.0], vec![3]);
        assert!(DistributionUtils::validate_positive(&positive, "param").is_ok());
        let non_positive = Tensor::from_vec(vec![0.0f32, 1.0, 2.0], vec![3]);
        assert!(DistributionUtils::validate_positive(&non_positive, "param").is_err());

        // Zero is allowed for non-negative parameters; negatives are not.
        let non_negative = Tensor::from_vec(vec![0.0f32, 1.0, 2.0], vec![3]);
        assert!(DistributionUtils::validate_non_negative(&non_negative, "param").is_ok());
        let negative = Tensor::from_vec(vec![-0.1f32, 1.0, 2.0], vec![3]);
        assert!(DistributionUtils::validate_non_negative(&negative, "param").is_err());
    }

    /// Random generators produce the requested shape and plausible values.
    #[test]
    fn test_random_generators() {
        let uniform = DistributionUtils::random_uniform::<f32>(&[10, 5]);
        assert_eq!(uniform.shape(), &[10, 5]);
        let data = uniform.data.as_slice().unwrap();
        assert!(data.iter().all(|val| (0.0..=1.0).contains(val)));

        let normal = DistributionUtils::random_normal::<f32>(&[100]);
        assert_eq!(normal.shape(), &[100]);
        let data = normal.data.as_slice().unwrap();
        assert!(data.iter().all(|val| val.is_finite()));
        let mean: f32 = data.iter().sum::<f32>() / data.len() as f32;
        // 100 standard-normal draws: std. error of the mean is 0.1, so 0.5 is a
        // loose 5-sigma bound.
        assert!(mean.abs() < 0.5);

        // Odd element count exercises the branch that drops the last z1.
        let odd = DistributionUtils::random_normal::<f64>(&[7]);
        assert_eq!(odd.shape(), &[7]);

        let u = DistributionUtils::random_uniform_scalar::<f64>();
        assert!((0.0..=1.0).contains(&u));
        assert!(DistributionUtils::random_normal_scalar::<f64>().is_finite());
    }
}