use crate::distributions::{Distribution, DistributionTrait, DistributionUtils};
use crate::error::{RusTorchError, RusTorchResult};
use crate::tensor::Tensor;
use num_traits::Float;
use std::marker::PhantomData;
/// A continuous uniform distribution over the half-open interval `[low, high)`.
///
/// `low` and `high` are tensors; `Uniform::new` derives the batch shape by
/// broadcasting their shapes with `Distribution::broadcast_shapes`. The event
/// shape is empty, i.e. each draw is a scalar.
#[derive(Debug, Clone)]
pub struct Uniform<T: Float> {
/// Lower bound of the support (inclusive: `log_prob` accepts `value >= low`).
pub low: Tensor<T>,
/// Upper bound of the support (exclusive: `log_prob` requires `value < high`).
pub high: Tensor<T>,
/// Shared distribution metadata; `Uniform::new` builds it from the broadcast
/// batch shape, an empty event shape, and the `validate_args` flag.
pub base: Distribution,
/// Zero-sized marker for the element type `T`.
_phantom: PhantomData<T>,
}
/// Maps a flat row-major index into a tensor of shape `target` to the flat
/// row-major index of the element it reads from a tensor of shape `source`.
///
/// Follows NumPy-style broadcasting: shapes are aligned on their trailing
/// axes, and `source` axes of size 1 (or leading axes `source` lacks) are
/// repeated. `source` must already be known to broadcast to `target`.
fn broadcast_offset(source: &[usize], target: &[usize], flat: usize) -> usize {
    let mut remaining = flat;
    let mut offset = 0;
    let mut stride = 1;
    let mut source_dims = source.iter().rev();
    for &dim in target.iter().rev() {
        let source_dim = match source_dims.next() {
            Some(&d) => d,
            // `source` has no more axes: the remaining target axes only repeat it.
            None => break,
        };
        let coord = remaining % dim;
        remaining /= dim;
        // A size-1 source axis is broadcast, so its coordinate is always 0.
        if source_dim != 1 {
            offset += coord * stride;
        }
        stride *= source_dim;
    }
    offset
}

impl<T: Float + 'static> Uniform<T>
where
    T: rand::distributions::uniform::SampleUniform + num_traits::FromPrimitive + std::fmt::Display,
{
    /// Creates a uniform distribution on `[low, high)`.
    ///
    /// The shapes of `low` and `high` are broadcast against each other; the
    /// result becomes the distribution's batch shape.
    ///
    /// # Errors
    ///
    /// Returns an error if the shapes of `low` and `high` cannot be broadcast
    /// together. When `validate_args` is true it also errors if any broadcast
    /// pair fails `low < high` (NaN bounds included) or if a bound tensor is
    /// not contiguous.
    pub fn new(low: Tensor<T>, high: Tensor<T>, validate_args: bool) -> RusTorchResult<Self> {
        // Resolve the broadcast shape first so incompatible bounds are reported
        // as such instead of being compared element-by-element.
        let batch_shape = Distribution::broadcast_shapes(low.shape(), high.shape())?;
        let uniform = Self {
            low,
            high,
            base: Distribution::new(batch_shape, vec![], validate_args),
            _phantom: PhantomData,
        };
        if validate_args {
            // `i` is the row-major position within the broadcast batch shape.
            for (i, &(l, h)) in uniform.broadcast_bounds()?.iter().enumerate() {
                // Test `l < h` rather than rejecting `l >= h`: every comparison
                // with NaN is false, so this form also rejects NaN bounds.
                if l < h {
                    continue;
                }
                return Err(RusTorchError::invalid_parameter(&format!(
                    "low must be less than high, got low[{}] = {}, high[{}] = {}",
                    i, l, i, h
                )));
            }
        }
        Ok(uniform)
    }

    /// Creates a distribution from scalar (shape `[]`) bounds.
    ///
    /// # Errors
    ///
    /// Same as [`Uniform::new`].
    pub fn from_scalars(low: T, high: T, validate_args: bool) -> RusTorchResult<Self> {
        let low_tensor = Tensor::from_vec(vec![low], vec![]);
        let high_tensor = Tensor::from_vec(vec![high], vec![]);
        Self::new(low_tensor, high_tensor, validate_args)
    }

    /// The standard uniform distribution on `[0, 1)`.
    pub fn standard(validate_args: bool) -> RusTorchResult<Self> {
        Self::from_scalars(T::zero(), T::one(), validate_args)
    }

    /// A distribution on `[-half_width, half_width)`.
    ///
    /// With `validate_args`, a non-positive or NaN `half_width` is rejected
    /// because the bounds would not satisfy `low < high`.
    pub fn symmetric(half_width: T, validate_args: bool) -> RusTorchResult<Self> {
        let neg_width = T::zero() - half_width;
        Self::from_scalars(neg_width, half_width, validate_args)
    }

    /// Borrows a tensor's elements as a contiguous row-major slice.
    ///
    /// # Errors
    ///
    /// Returns an error (instead of panicking) when `as_slice` yields no slice,
    /// i.e. the tensor's storage is not contiguous.
    fn contiguous<'a>(tensor: &'a Tensor<T>, name: &str) -> RusTorchResult<&'a [T]> {
        tensor.data.as_slice().ok_or_else(|| {
            RusTorchError::invalid_parameter(&format!("{} tensor must be contiguous", name))
        })
    }

    /// Returns the `(low, high)` pair for every position of the batch shape,
    /// in row-major order, with both bounds broadcast to that shape.
    fn broadcast_bounds(&self) -> RusTorchResult<Vec<(T, T)>> {
        let low = Self::contiguous(&self.low, "low")?;
        let high = Self::contiguous(&self.high, "high")?;
        let batch_shape = &self.base.batch_shape;
        let count: usize = batch_shape.iter().product();
        Ok((0..count)
            .map(|flat| {
                (
                    low[broadcast_offset(self.low.shape(), batch_shape, flat)],
                    high[broadcast_offset(self.high.shape(), batch_shape, flat)],
                )
            })
            .collect())
    }

    /// Width `high - low` of the support for every batch position (row-major).
    fn range(&self) -> RusTorchResult<Vec<T>> {
        Ok(self
            .broadcast_bounds()?
            .into_iter()
            .map(|(l, h)| h - l)
            .collect())
    }

    /// Applies `f(value, low, high)` element-wise over the broadcast of
    /// `value` against the batch shape, returning a tensor of that shape.
    ///
    /// # Errors
    ///
    /// Fails if `value` is not contiguous or its shape does not broadcast with
    /// the batch shape.
    fn map_with_bounds<F>(&self, value: &Tensor<T>, f: F) -> RusTorchResult<Tensor<T>>
    where
        F: Fn(T, T, T) -> T,
    {
        let values = Self::contiguous(value, "value")?;
        let bounds = self.broadcast_bounds()?;
        let value_shape = value.shape();
        let batch_shape = &self.base.batch_shape;
        let out_shape = Distribution::broadcast_shapes(value_shape, batch_shape)?;
        let count: usize = out_shape.iter().product();
        let data: Vec<T> = (0..count)
            .map(|flat| {
                let v = values[broadcast_offset(value_shape, &out_shape, flat)];
                let (l, h) = bounds[broadcast_offset(batch_shape, &out_shape, flat)];
                f(v, l, h)
            })
            .collect();
        Ok(Tensor::from_vec(data, out_shape))
    }
}

impl<T: Float + 'static> DistributionTrait<T> for Uniform<T>
where
    T: rand::distributions::uniform::SampleUniform + num_traits::FromPrimitive + std::fmt::Display,
{
    /// Draws samples as `low + (high - low) * u`, where `u` comes from
    /// `DistributionUtils::random_uniform` (presumably on `[0, 1)`).
    fn sample(&self, shape: Option<&[usize]>) -> RusTorchResult<Tensor<T>> {
        let sample_shape = self.base.expand_shape(shape);
        let bounds = self.broadcast_bounds()?;
        let unit = DistributionUtils::random_uniform::<T>(&sample_shape);
        let unit_data = unit
            .data
            .as_slice()
            .expect("random_uniform returns a contiguous tensor");
        // Assumes `expand_shape` places the batch shape last (TODO confirm), so
        // in row-major order the batch position cycles with period
        // `bounds.len()`.
        let data: Vec<T> = unit_data
            .iter()
            .zip(bounds.iter().cycle())
            .map(|(&u, &(l, h))| l + (h - l) * u)
            .collect();
        Ok(Tensor::from_vec(data, sample_shape))
    }

    /// `-ln(high - low)` inside `[low, high)`, negative infinity outside.
    fn log_prob(&self, value: &Tensor<T>) -> RusTorchResult<Tensor<T>> {
        self.map_with_bounds(value, |v, l, h| {
            if v >= l && v < h {
                -(h - l).ln()
            } else {
                T::neg_infinity()
            }
        })
    }

    /// `(value - low) / (high - low)`, clamped to 0 below `low` and to 1 at or
    /// above `high`.
    fn cdf(&self, value: &Tensor<T>) -> RusTorchResult<Tensor<T>> {
        self.map_with_bounds(value, |v, l, h| {
            if v < l {
                T::zero()
            } else if v >= h {
                T::one()
            } else {
                (v - l) / (h - l)
            }
        })
    }

    /// Quantile function `low + (high - low) * p`; `p` is not range-checked.
    fn icdf(&self, value: &Tensor<T>) -> RusTorchResult<Tensor<T>> {
        self.map_with_bounds(value, |p, l, h| l + (h - l) * p)
    }

    /// Midpoint `(low + high) / 2` for every batch position.
    fn mean(&self) -> RusTorchResult<Tensor<T>> {
        let half = T::from(0.5).unwrap();
        let data: Vec<T> = self
            .broadcast_bounds()?
            .into_iter()
            .map(|(l, h)| half * (l + h))
            .collect();
        Ok(Tensor::from_vec(data, self.base.batch_shape.to_vec()))
    }

    /// `(high - low)^2 / 12` for every batch position.
    fn variance(&self) -> RusTorchResult<Tensor<T>> {
        let twelve = T::from(12.0).unwrap();
        let data: Vec<T> = self.range()?.into_iter().map(|r| (r * r) / twelve).collect();
        Ok(Tensor::from_vec(data, self.base.batch_shape.to_vec()))
    }

    /// Differential entropy `ln(high - low)` for every batch position.
    fn entropy(&self) -> RusTorchResult<Tensor<T>> {
        let data: Vec<T> = self.range()?.into_iter().map(|r| r.ln()).collect();
        Ok(Tensor::from_vec(data, self.base.batch_shape.to_vec()))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_abs_diff_eq;

    /// Element `index` of a contiguous `f32` tensor.
    fn elem(tensor: &Tensor<f32>, index: usize) -> f32 {
        tensor.data.as_slice().unwrap()[index]
    }

    #[test]
    fn test_uniform_creation() {
        let lower = Tensor::from_vec(vec![0.0f32], vec![1]);
        let upper = Tensor::from_vec(vec![1.0f32], vec![1]);
        let dist = Uniform::new(lower, upper, true).unwrap();
        assert_eq!(dist.base.batch_shape, vec![1]);
        assert_eq!(dist.base.event_shape, Vec::<usize>::new());
    }

    #[test]
    fn test_standard_uniform() {
        let dist = Uniform::<f32>::standard(true).unwrap();
        let mean = dist.mean().unwrap();
        let variance = dist.variance().unwrap();
        assert_abs_diff_eq!(elem(&mean, 0), 0.5, epsilon = 1e-6);
        assert_abs_diff_eq!(elem(&variance, 0), 1.0 / 12.0, epsilon = 1e-6);
    }

    #[test]
    fn test_uniform_sampling() {
        let dist = Uniform::<f32>::from_scalars(2.0, 5.0, true).unwrap();
        let draws = dist.sample(Some(&[1000])).unwrap();
        assert_eq!(draws.shape(), &[1000]);

        let values = draws.data.as_slice().unwrap();
        assert!(values.iter().all(|&x| (2.0..5.0).contains(&x)));

        let average = values.iter().sum::<f32>() / values.len() as f32;
        assert_abs_diff_eq!(average, 3.5, epsilon = 0.15);
    }

    #[test]
    fn test_uniform_log_prob() {
        let dist = Uniform::<f32>::from_scalars(1.0, 3.0, true).unwrap();
        let points = Tensor::from_vec(vec![0.5f32, 1.5, 2.0, 3.5], vec![4]);
        let log_probs = dist.log_prob(&points).unwrap();
        let inside = -(2.0f32).ln();

        // 0.5 and 3.5 lie outside [1, 3); 1.5 and 2.0 lie inside.
        assert_eq!(elem(&log_probs, 0), f32::NEG_INFINITY);
        assert_abs_diff_eq!(elem(&log_probs, 1), inside, epsilon = 1e-6);
        assert_abs_diff_eq!(elem(&log_probs, 2), inside, epsilon = 1e-6);
        assert_eq!(elem(&log_probs, 3), f32::NEG_INFINITY);
    }

    #[test]
    fn test_uniform_cdf() {
        let dist = Uniform::<f32>::from_scalars(2.0, 6.0, true).unwrap();
        let points = Tensor::from_vec(vec![1.0f32, 2.0, 4.0, 6.0, 7.0], vec![5]);
        let cdf = dist.cdf(&points).unwrap();

        assert_eq!(elem(&cdf, 0), 0.0);
        assert_eq!(elem(&cdf, 1), 0.0);
        assert_abs_diff_eq!(elem(&cdf, 2), 0.5, epsilon = 1e-6);
        assert_eq!(elem(&cdf, 3), 1.0);
        assert_eq!(elem(&cdf, 4), 1.0);
    }

    #[test]
    fn test_uniform_icdf() {
        let dist = Uniform::<f32>::from_scalars(1.0, 5.0, true).unwrap();
        let probs = Tensor::from_vec(vec![0.0f32, 0.25, 0.5, 0.75, 1.0], vec![5]);
        let quantiles = dist.icdf(&probs).unwrap();

        // Quantiles of U[1, 5) advance by 1 as p advances by 0.25.
        for (i, expected) in [1.0f32, 2.0, 3.0, 4.0, 5.0].iter().enumerate() {
            assert_abs_diff_eq!(elem(&quantiles, i), *expected, epsilon = 1e-6);
        }
    }

    #[test]
    fn test_symmetric_uniform() {
        let dist = Uniform::<f32>::symmetric(2.5, true).unwrap();
        let mean = dist.mean().unwrap();
        assert_abs_diff_eq!(elem(&mean, 0), 0.0, epsilon = 1e-6);
    }

    #[test]
    fn test_invalid_parameters() {
        // Reversed bounds and an empty interval must both be rejected.
        for &(lo, hi) in &[(3.0f32, 2.0f32), (1.0, 1.0)] {
            assert!(Uniform::<f32>::from_scalars(lo, hi, true).is_err());
        }
    }
}