use crate::error::{StatsError, StatsResult};
use num_traits::ToPrimitive;
use serde::{Deserialize, Serialize};
use std::cmp::PartialOrd;
/// Serializable pair of bounds `(a, b)` describing a uniform distribution on `[a, b]`.
///
/// The fields are public, so a value built with a struct literal or produced by the
/// derived `Deserialize` impl is not validated; use [`UniformConfig::new`] to enforce
/// `a < b`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct UniformConfig<T> {
    /// Lower bound of the support.
    pub a: T,
    /// Upper bound of the support.
    pub b: T,
}
impl<T> UniformConfig<T>
where
    T: ToPrimitive + PartialOrd,
{
    /// Creates a configuration after checking that `a < b`.
    ///
    /// # Errors
    ///
    /// Returns [`StatsError::InvalidInput`] unless `a < b`. Because the check is a plain
    /// `<`, incomparable values (such as a floating-point NaN) are rejected as well.
    pub fn new(a: T, b: T) -> StatsResult<Self> {
        if a < b {
            Ok(Self { a, b })
        } else {
            Err(StatsError::InvalidInput {
                message: "UniformConfig::new: a must be less than b".to_string(),
            })
        }
    }
}
/// Probability density function of the continuous uniform distribution on `[a, b]`.
///
/// Returns `1 / (b - a)` when `a <= x <= b` (both endpoints included) and `0.0`
/// otherwise. A NaN `x` yields `NaN`.
///
/// # Errors
///
/// * [`StatsError::ConversionError`] if `x`, `a` or `b` cannot be converted to `f64`.
/// * [`StatsError::InvalidInput`] if `a` or `b` is NaN or infinite, or if `a >= b`.
#[inline]
pub fn uniform_pdf<T>(x: T, a: T, b: T) -> StatsResult<f64>
where
    T: ToPrimitive + PartialOrd,
{
    const CONTEXT: &str = "distributions::uniform_distribution::uniform_pdf";
    // One conversion path for every argument; `label` is the positional name that the
    // error message reports.
    let convert = |value: T, label: &str| {
        value.to_f64().ok_or_else(|| StatsError::ConversionError {
            message: format!("{CONTEXT}: Failed to convert {label} to f64"),
        })
    };
    let x_64 = convert(x, "arg1")?;
    let a_64 = convert(a, "arg2")?;
    let b_64 = convert(b, "arg3")?;

    // Every comparison with NaN is false, so NaN bounds would slip past the `a >= b`
    // check below. An infinite bound would also make the density degenerate, so both
    // cases are rejected here.
    if !a_64.is_finite() || !b_64.is_finite() {
        return Err(StatsError::InvalidInput {
            message: format!("{CONTEXT}: a and b must be finite"),
        });
    }
    if a_64 >= b_64 {
        return Err(StatsError::InvalidInput {
            message: format!("{CONTEXT}: a must be less than b"),
        });
    }
    // A NaN `x` fails both range comparisons. Without this guard it would be reported
    // as lying inside the support.
    if x_64.is_nan() {
        return Ok(f64::NAN);
    }
    Ok(if x_64 < a_64 || x_64 > b_64 {
        0.0
    } else {
        1.0 / (b_64 - a_64)
    })
}
/// Cumulative distribution function of the continuous uniform distribution on `[a, b]`.
///
/// Returns `0.0` for `x < a`, `1.0` for `x > b`, and `(x - a) / (b - a)` in between.
/// A NaN `x` fails both comparisons and propagates through the division as `NaN`.
///
/// # Errors
///
/// * [`StatsError::ConversionError`] if `x`, `a` or `b` cannot be converted to `f64`.
/// * [`StatsError::InvalidInput`] if `a` or `b` is NaN or infinite, or if `a >= b`.
#[inline]
pub fn uniform_cdf<T>(x: T, a: T, b: T) -> StatsResult<f64>
where
    T: ToPrimitive + PartialOrd,
{
    const CONTEXT: &str = "distributions::uniform_distribution::uniform_cdf";
    // One conversion path for every argument; `label` is the positional name that the
    // error message reports.
    let convert = |value: T, label: &str| {
        value.to_f64().ok_or_else(|| StatsError::ConversionError {
            message: format!("{CONTEXT}: Failed to convert {label} to f64"),
        })
    };
    let x_64 = convert(x, "arg1")?;
    let a_64 = convert(a, "arg2")?;
    let b_64 = convert(b, "arg3")?;

    // Every comparison with NaN is false, so NaN bounds would slip past the `a >= b`
    // check below. Infinite bounds would turn the interior ratio into inf/inf = NaN.
    if !a_64.is_finite() || !b_64.is_finite() {
        return Err(StatsError::InvalidInput {
            message: format!("{CONTEXT}: a and b must be finite"),
        });
    }
    if a_64 >= b_64 {
        return Err(StatsError::InvalidInput {
            message: format!("{CONTEXT}: a must be less than b"),
        });
    }
    Ok(if x_64 < a_64 {
        0.0
    } else if x_64 > b_64 {
        1.0
    } else {
        (x_64 - a_64) / (b_64 - a_64)
    })
}
/// Quantile function (inverse CDF) of the continuous uniform distribution on `[a, b]`.
///
/// Maps a probability `p` in `[0, 1]` to `a + p * (b - a)`. The result is capped at `b`,
/// so it never leaves the support.
///
/// # Errors
///
/// * [`StatsError::ConversionError`] if `p`, `a` or `b` cannot be converted to `f64`.
/// * [`StatsError::InvalidInput`] if `a` or `b` is NaN or infinite, if `a >= b`, or if
///   `p` is outside `[0, 1]` (including NaN).
#[inline]
pub fn uniform_inverse_cdf<T>(p: T, a: T, b: T) -> StatsResult<f64>
where
    T: ToPrimitive + PartialOrd,
{
    const CONTEXT: &str = "distributions::uniform_distribution::uniform_inverse_cdf";
    // One conversion path for every argument; `label` is the positional name that the
    // error message reports.
    let convert = |value: T, label: &str| {
        value.to_f64().ok_or_else(|| StatsError::ConversionError {
            message: format!("{CONTEXT}: Failed to convert {label} to f64"),
        })
    };
    let p_64 = convert(p, "arg1")?;
    let a_64 = convert(a, "arg2")?;
    let b_64 = convert(b, "arg3")?;

    // Every comparison with NaN is false, so NaN bounds would slip past the `a >= b`
    // check below. Infinite bounds are rejected for the same reason as in the other
    // uniform functions.
    if !a_64.is_finite() || !b_64.is_finite() {
        return Err(StatsError::InvalidInput {
            message: format!("{CONTEXT}: a and b must be finite"),
        });
    }
    if a_64 >= b_64 {
        return Err(StatsError::InvalidInput {
            message: format!("{CONTEXT}: a must be less than b"),
        });
    }
    // `contains` is false for NaN, so a NaN probability is rejected here too.
    if !(0.0..=1.0).contains(&p_64) {
        return Err(StatsError::InvalidInput {
            message: format!("{CONTEXT}: p must be between 0 and 1"),
        });
    }
    // `a + p * (b - a)` can round one ulp past `b`. For example, a = -0.1, b = 0.3 and
    // p = 1 yield 0.30000000000000004, so the result is capped at `b`. `p * (b - a)`
    // is non-negative, so the lower end needs no clamp.
    Ok((a_64 + p_64 * (b_64 - a_64)).min(b_64))
}
/// Mean `(a + b) / 2` of the continuous uniform distribution on `[a, b]`.
///
/// # Errors
///
/// * [`StatsError::ConversionError`] if `a` or `b` cannot be converted to `f64`.
/// * [`StatsError::InvalidInput`] if `a` or `b` is NaN or infinite, or if `a >= b`.
#[inline]
pub fn uniform_mean<T>(a: T, b: T) -> StatsResult<f64>
where
    T: ToPrimitive + PartialOrd,
{
    const CONTEXT: &str = "distributions::uniform_distribution::uniform_mean";
    // One conversion path for every argument; `label` is the positional name that the
    // error message reports.
    let convert = |value: T, label: &str| {
        value.to_f64().ok_or_else(|| StatsError::ConversionError {
            message: format!("{CONTEXT}: Failed to convert {label} to f64"),
        })
    };
    let a_64 = convert(a, "arg1")?;
    let b_64 = convert(b, "arg2")?;

    // Every comparison with NaN is false, so NaN bounds would slip past the `a >= b`
    // check below and produce a NaN mean.
    if !a_64.is_finite() || !b_64.is_finite() {
        return Err(StatsError::InvalidInput {
            message: format!("{CONTEXT}: a and b must be finite"),
        });
    }
    if a_64 >= b_64 {
        return Err(StatsError::InvalidInput {
            message: format!("{CONTEXT}: a must be less than b"),
        });
    }
    Ok((a_64 + b_64) / 2.0)
}
/// Continuous uniform distribution on the closed interval `[a, b]`.
///
/// [`Uniform::new`] and [`Uniform::fit`] enforce finite bounds with `a < b`. The fields
/// are public, so a struct literal can bypass those checks.
#[derive(Debug, Clone, Copy)]
pub struct Uniform {
    /// Lower bound of the support.
    pub a: f64,
    /// Upper bound of the support.
    pub b: f64,
}
impl Uniform {
    /// Creates a uniform distribution on `[a, b]`.
    ///
    /// # Errors
    ///
    /// Returns [`StatsError::InvalidInput`] if `a` or `b` is NaN or infinite, or if
    /// `a >= b`.
    pub fn new(a: f64, b: f64) -> StatsResult<Self> {
        // The finiteness check comes first because every comparison with NaN is false,
        // so `a >= b` alone would accept NaN bounds.
        if !a.is_finite() || !b.is_finite() {
            return Err(StatsError::InvalidInput {
                message: "Uniform::new: a and b must be finite".to_string(),
            });
        }
        if a >= b {
            return Err(StatsError::InvalidInput {
                message: "Uniform::new: a must be strictly less than b".to_string(),
            });
        }
        Ok(Self { a, b })
    }

    /// Fits the bounds to the sample minimum and maximum, which are the
    /// maximum-likelihood estimates for a uniform distribution.
    ///
    /// # Errors
    ///
    /// Returns [`StatsError::InvalidInput`] in three cases:
    ///
    /// * `data` is empty.
    /// * `data` contains a NaN or infinite value.
    /// * All values are equal, so the fitted `a == b` is rejected by [`Uniform::new`].
    pub fn fit(data: &[f64]) -> StatsResult<Self> {
        if data.is_empty() {
            return Err(StatsError::InvalidInput {
                message: "Uniform::fit: data must not be empty".to_string(),
            });
        }
        // `f64::min` / `f64::max` return the non-NaN operand. Without this check a NaN
        // sample would be silently ignored rather than reported.
        if data.iter().any(|v| !v.is_finite()) {
            return Err(StatsError::InvalidInput {
                message: "Uniform::fit: data must contain only finite values".to_string(),
            });
        }
        // Compute the minimum and maximum in a single pass.
        let (a, b) = data
            .iter()
            .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
                (lo.min(v), hi.max(v))
            });
        Self::new(a, b)
    }
}
impl crate::distributions::traits::Distribution for Uniform {
    fn name(&self) -> &str {
        "Uniform"
    }
    fn num_params(&self) -> usize {
        2
    }
    /// Density at `x`; delegates to [`uniform_pdf`], including its parameter checks.
    fn pdf(&self, x: f64) -> StatsResult<f64> {
        uniform_pdf(x, self.a, self.b)
    }
    /// Natural logarithm of the density at `x`.
    ///
    /// Returns `-ln(b - a)` inside `[a, b]` and negative infinity outside it. It reports
    /// exactly the errors [`uniform_pdf`] reports for the same parameters.
    fn logpdf(&self, x: f64) -> StatsResult<f64> {
        // The fields are public, so `self` may hold bounds that `Uniform::new` would
        // reject. Going through `uniform_pdf` validates them and decides support,
        // keeping `logpdf` consistent with `pdf` instead of returning `ln` of a
        // non-positive width.
        let density = uniform_pdf(x, self.a, self.b)?;
        Ok(if density > 0.0 {
            // `-ln(b - a)` avoids the extra rounding of `ln(1 / (b - a))`.
            -((self.b - self.a).ln())
        } else {
            // A zero density gives negative infinity; a NaN density stays NaN.
            density.ln()
        })
    }
    /// Cumulative probability at `x`; delegates to [`uniform_cdf`].
    fn cdf(&self, x: f64) -> StatsResult<f64> {
        uniform_cdf(x, self.a, self.b)
    }
    /// Quantile for probability `p`; delegates to [`uniform_inverse_cdf`].
    fn inverse_cdf(&self, p: f64) -> StatsResult<f64> {
        uniform_inverse_cdf(p, self.a, self.b)
    }
    /// `(a + b) / 2`, computed directly from the fields without validation.
    fn mean(&self) -> f64 {
        (self.a + self.b) / 2.0
    }
    /// `(b - a)^2 / 12`, computed directly from the fields without validation.
    fn variance(&self) -> f64 {
        (self.b - self.a).powi(2) / 12.0
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons.
    const TOL: f64 = 1e-10;

    /// Panics unless `actual` lies strictly within `TOL` of `expected`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < TOL,
            "expected {expected}, got {actual}"
        );
    }

    /// Panics unless `result` is an `InvalidInput` error.
    fn assert_invalid_input<V>(result: StatsResult<V>) {
        assert!(matches!(result, Err(StatsError::InvalidInput { .. })));
    }

    #[test]
    fn test_uniform_pdf_inside_range() {
        // (x, a, b, expected density)
        let cases = [
            (0.0, 0.0, 1.0, 1.0),
            (0.5, 0.0, 1.0, 1.0),
            (1.0, 0.0, 1.0, 1.0),
            (2.0, 2.0, 4.0, 0.5),
            (3.0, 2.0, 4.0, 0.5),
            (4.0, 2.0, 4.0, 0.5),
        ];
        for (x, lo, hi, density) in cases {
            assert_close(uniform_pdf(x, lo, hi).unwrap(), density);
        }
    }

    #[test]
    fn test_uniform_pdf_outside_range() {
        for x in [-1.0, 2.0] {
            assert_close(uniform_pdf(x, 0.0, 1.0).unwrap(), 0.0);
        }
    }

    #[test]
    fn test_uniform_pdf_invalid_range() {
        let outcome = uniform_pdf(0.5, 1.0, 0.0);
        assert!(outcome.is_err(), "Should return error for a >= b");
    }

    #[test]
    fn test_uniform_cdf_inside_range() {
        // (x, a, b, expected probability)
        let cases = [
            (0.0, 0.0, 1.0, 0.0),
            (0.25, 0.0, 1.0, 0.25),
            (0.5, 0.0, 1.0, 0.5),
            (0.75, 0.0, 1.0, 0.75),
            (1.0, 0.0, 1.0, 1.0),
            (2.0, 2.0, 4.0, 0.0),
            (3.0, 2.0, 4.0, 0.5),
            (4.0, 2.0, 4.0, 1.0),
        ];
        for (x, lo, hi, prob) in cases {
            assert_close(uniform_cdf(x, lo, hi).unwrap(), prob);
        }
    }

    #[test]
    fn test_uniform_cdf_outside_range() {
        assert_close(uniform_cdf(-1.0, 0.0, 1.0).unwrap(), 0.0);
        assert_close(uniform_cdf(2.0, 0.0, 1.0).unwrap(), 1.0);
    }

    #[test]
    fn test_uniform_cdf_inverse_cdf_relationship() {
        let (lo, hi) = (2.0, 5.0);
        for p in [0.0, 0.1, 0.25, 0.5, 0.75, 0.9, 1.0] {
            let quantile = uniform_inverse_cdf(p, lo, hi).unwrap();
            let round_trip = uniform_cdf(quantile, lo, hi).unwrap();
            assert!(
                (p - round_trip).abs() < TOL,
                "CDF(inverse_CDF(p)) should equal p"
            );
            // The reverse direction is only checked strictly inside (0, 1).
            let interior = p > 0.0 && p < 1.0;
            if !interior {
                continue;
            }
            let x = lo + p * (hi - lo);
            let prob = uniform_cdf(x, lo, hi).unwrap();
            let recovered = uniform_inverse_cdf(prob, lo, hi).unwrap();
            assert!(
                (x - recovered).abs() < TOL,
                "inverse_CDF(CDF(x)) should equal x"
            );
        }
    }

    #[test]
    fn test_uniform_config_new_a_less_than_b() {
        assert!(UniformConfig::new(0.0, 1.0).is_ok());
    }

    #[test]
    fn test_uniform_config_new_a_equal_b() {
        assert_invalid_input(UniformConfig::new(1.0, 1.0));
    }

    #[test]
    fn test_uniform_config_new_a_greater_than_b() {
        assert_invalid_input(UniformConfig::new(2.0, 1.0));
    }

    #[test]
    fn test_uniform_pdf_at_boundary_a() {
        assert_close(uniform_pdf(0.0, 0.0, 1.0).unwrap(), 1.0);
    }

    #[test]
    fn test_uniform_pdf_at_boundary_b() {
        assert_close(uniform_pdf(1.0, 0.0, 1.0).unwrap(), 1.0);
    }

    #[test]
    fn test_uniform_inverse_cdf_p_negative() {
        assert_invalid_input(uniform_inverse_cdf(-0.1, 0.0, 1.0));
    }

    #[test]
    fn test_uniform_inverse_cdf_p_greater_than_one() {
        assert_invalid_input(uniform_inverse_cdf(1.5, 0.0, 1.0));
    }

    #[test]
    fn test_uniform_inverse_cdf_p_zero() {
        assert_close(uniform_inverse_cdf(0.0, 2.0, 5.0).unwrap(), 2.0);
    }

    #[test]
    fn test_uniform_inverse_cdf_p_one() {
        assert_close(uniform_inverse_cdf(1.0, 2.0, 5.0).unwrap(), 5.0);
    }

    #[test]
    fn test_uniform_inverse_cdf_a_greater_than_b() {
        assert_invalid_input(uniform_inverse_cdf(0.5, 2.0, 1.0));
    }

    #[test]
    fn test_uniform_mean_a_greater_than_b() {
        assert_invalid_input(uniform_mean(2.0, 1.0));
    }

    #[test]
    fn test_uniform_mean_a_equal_b() {
        assert_invalid_input(uniform_mean(1.0, 1.0));
    }

    #[test]
    fn test_uniform_cdf_a_greater_than_b() {
        assert_invalid_input(uniform_cdf(0.5, 2.0, 1.0));
    }

    #[test]
    fn test_uniform_cdf_a_equal_b() {
        assert_invalid_input(uniform_cdf(0.5, 1.0, 1.0));
    }

    #[test]
    fn test_uniform_cdf_x_exactly_at_a() {
        assert_close(uniform_cdf(0.0, 0.0, 1.0).unwrap(), 0.0);
    }

    #[test]
    fn test_uniform_cdf_x_exactly_at_b() {
        assert_close(uniform_cdf(1.0, 0.0, 1.0).unwrap(), 1.0);
    }

    #[test]
    fn test_uniform_pdf_x_between_a_and_b() {
        assert_close(uniform_pdf(0.5, 0.0, 1.0).unwrap(), 1.0);
    }
}