use crate::DType;
use crate::stats::error::{StatsError, StatsResult};
use crate::stats::{DiscreteDistribution, Distribution};
use numr::algorithm::special::SpecialFunctions;
use numr::error::Result;
use numr::ops::{ScalarOps, TensorOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
/// Discrete uniform distribution over the integers `low..=high`, each point
/// carrying probability `1 / n`.
#[derive(Debug, Clone, Copy)]
pub struct DiscreteUniform {
    /// Smallest value in the support (inclusive).
    low: i64,
    /// Largest value in the support (inclusive).
    high: i64,
    /// Number of support points, `high - low + 1`; computed once in the constructor.
    n: u64,
}
impl DiscreteUniform {
pub fn new(low: i64, high: i64) -> StatsResult<Self> {
if high < low {
return Err(StatsError::InvalidParameter {
name: "high".to_string(),
value: high as f64,
reason: format!("upper bound must be >= lower bound ({})", low),
});
}
let n = (high - low + 1) as u64;
Ok(Self { low, high, n })
}
pub fn randint(n: u64) -> StatsResult<Self> {
if n == 0 {
return Err(StatsError::InvalidParameter {
name: "n".to_string(),
value: 0.0,
reason: "n must be positive".to_string(),
});
}
Self::new(0, (n - 1) as i64)
}
pub fn low(&self) -> i64 {
self.low
}
pub fn high(&self) -> i64 {
self.high
}
pub fn n(&self) -> u64 {
self.n
}
fn in_support(&self, k: i64) -> bool {
k >= self.low && k <= self.high
}
}
impl Distribution for DiscreteUniform {
    /// Mean `(low + high) / 2`.
    fn mean(&self) -> f64 {
        // Widen to i128: `low + high` overflows i64 when both bounds sit near
        // the same extreme (e.g. `i64::MAX - 1 ..= i64::MAX`).
        (self.low as i128 + self.high as i128) as f64 / 2.0
    }

    /// Variance `(n^2 - 1) / 12`.
    fn var(&self) -> f64 {
        let n = self.n as f64;
        (n * n - 1.0) / 12.0
    }

    /// Entropy `ln(n)` in nats.
    fn entropy(&self) -> f64 {
        (self.n as f64).ln()
    }

    /// Median, reported as the midpoint `(low + high) / 2` (the same value as
    /// [`Distribution::mean`]); it is a half-integer when `n` is even.
    fn median(&self) -> f64 {
        self.mean()
    }

    /// Every support point is equally likely, so any of them is a mode; this
    /// reports the midpoint, which is not a support point when `n` is even.
    fn mode(&self) -> f64 {
        self.mean()
    }

    /// The distribution is symmetric about its midpoint.
    fn skewness(&self) -> f64 {
        0.0
    }

    /// Excess kurtosis `-6 (n^2 + 1) / (5 (n^2 - 1))`. For `n == 1` the
    /// denominator is zero and the result is `-inf`.
    fn kurtosis(&self) -> f64 {
        let n = self.n as f64;
        let n2 = n * n;
        -6.0 * (n2 + 1.0) / (5.0 * (n2 - 1.0))
    }
}
impl DiscreteDistribution for DiscreteUniform {
    /// Probability mass: `1 / n` for `k` in `low..=high`, `0` otherwise.
    fn pmf(&self, k: u64) -> f64 {
        // `k as i64` would wrap values above i64::MAX to negatives that can land
        // inside a support with negative `low`; such `k` always exceed `high`.
        if matches!(i64::try_from(k), Ok(k) if self.in_support(k)) {
            1.0 / self.n as f64
        } else {
            0.0
        }
    }

    /// Log probability mass: `-ln(n)` inside the support, `-inf` outside.
    fn log_pmf(&self, k: u64) -> f64 {
        if matches!(i64::try_from(k), Ok(k) if self.in_support(k)) {
            -(self.n as f64).ln()
        } else {
            f64::NEG_INFINITY
        }
    }

    /// Cumulative probability `P(X <= k)`.
    fn cdf(&self, k: u64) -> f64 {
        let k = match i64::try_from(k) {
            Ok(k) => k,
            // Anything above i64::MAX is above every possible `high`.
            Err(_) => return 1.0,
        };
        if k < self.low {
            0.0
        } else if k >= self.high {
            1.0
        } else {
            // low <= k < high: `abs_diff` is the exact gap even when `k - low`
            // would overflow i64, and `gap + 1 < n` so the add cannot overflow.
            (k.abs_diff(self.low) + 1) as f64 / self.n as f64
        }
    }

    /// Survival function `P(X > k) = 1 - cdf(k)`.
    fn sf(&self, k: u64) -> f64 {
        1.0 - self.cdf(k)
    }

    /// Quantile: the smallest support value `k` with `cdf(k) >= p`
    /// (`ppf(0) == low`, `ppf(1) == high`).
    ///
    /// # Errors
    ///
    /// Returns [`StatsError::InvalidParameter`] (named `"p"`) if `p` is outside
    /// `[0, 1]` (including NaN), or if the quantile is negative and therefore
    /// not representable as `u64`.
    fn ppf(&self, p: f64) -> StatsResult<u64> {
        if !(0.0..=1.0).contains(&p) {
            return Err(StatsError::InvalidParameter {
                name: "p".to_string(),
                value: p,
                reason: "probability must be in [0, 1]".to_string(),
            });
        }
        let k = if p == 0.0 {
            self.low
        } else if p == 1.0 {
            self.high
        } else {
            // cdf(k) >= p  <=>  k >= low - 1 + n * p. Sum in i128 so extreme
            // bounds cannot overflow, then clamp back into the support, which
            // also absorbs rounding of `n as f64`.
            let offset = (self.n as f64 * p).ceil() as i128 - 1;
            (self.low as i128 + offset).clamp(self.low as i128, self.high as i128) as i64
        };
        // A negative quantile (possible when low < 0) previously wrapped to a
        // huge u64; surface it as an error instead.
        u64::try_from(k).map_err(|_| StatsError::InvalidParameter {
            name: "p".to_string(),
            value: p,
            reason: format!("quantile {} is negative and cannot be represented as u64", k),
        })
    }

    /// Elementwise PMF: `1 / n` where `floor(k)` lies in `[low, high]`, `0`
    /// elsewhere (non-integer `k` are floored, matching `cdf_tensor`).
    fn pmf_tensor<R: Runtime<DType = DType>, C>(
        &self,
        k: &Tensor<R>,
        client: &C,
    ) -> Result<Tensor<R>>
    where
        C: TensorOps<R> + ScalarOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
    {
        let indicator = self.support_indicator_tensor(k, client)?;
        client.mul_scalar(&indicator, 1.0 / (self.n as f64))
    }

    /// Elementwise log-PMF: `-ln(n)` where `floor(k)` lies in `[low, high]`,
    /// `-inf` elsewhere.
    fn log_pmf_tensor<R: Runtime<DType = DType>, C>(
        &self,
        k: &Tensor<R>,
        client: &C,
    ) -> Result<Tensor<R>>
    where
        C: TensorOps<R> + ScalarOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
    {
        // log(1) = 0 and log(0) = -inf, so shifting by -ln(n) gives exactly the
        // in-support constant and keeps -inf outside the support.
        let indicator = self.support_indicator_tensor(k, client)?;
        let log_indicator = client.log(&indicator)?;
        client.add_scalar(&log_indicator, -(self.n as f64).ln())
    }

    /// Elementwise CDF `clamp((floor(k) - low + 1) / n, 0, 1)`.
    fn cdf_tensor<R: Runtime<DType = DType>, C>(
        &self,
        k: &Tensor<R>,
        client: &C,
    ) -> Result<Tensor<R>>
    where
        C: TensorOps<R> + ScalarOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
    {
        let low_f = self.low as f64;
        let n_f = self.n as f64;
        let k_floor = client.floor(k)?;
        let adjusted = client.sub_scalar(&k_floor, low_f - 1.0)?;
        let cdf_val = client.div_scalar(&adjusted, n_f)?;
        client.clamp(&cdf_val, 0.0, 1.0)
    }

    /// Elementwise survival function `1 - cdf(k)`.
    fn sf_tensor<R: Runtime<DType = DType>, C>(
        &self,
        k: &Tensor<R>,
        client: &C,
    ) -> Result<Tensor<R>>
    where
        C: TensorOps<R> + ScalarOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
    {
        let cdf = self.cdf_tensor(k, client)?;
        let neg_cdf = client.mul_scalar(&cdf, -1.0)?;
        client.add_scalar(&neg_cdf, 1.0)
    }

    /// Elementwise quantile `clamp(low - 1 + ceil(n * p), low, high)`. Unlike
    /// the scalar `ppf`, `p` outside `[0, 1]` is not rejected; the result is
    /// only clamped into the support.
    fn ppf_tensor<R: Runtime<DType = DType>, C>(
        &self,
        p: &Tensor<R>,
        client: &C,
    ) -> Result<Tensor<R>>
    where
        C: TensorOps<R> + ScalarOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
    {
        let low_f = self.low as f64;
        let high_f = self.high as f64;
        let n_f = self.n as f64;
        let n_times_p = client.mul_scalar(p, n_f)?;
        let ceiled = client.ceil(&n_times_p)?;
        let ppf_val = client.add_scalar(&ceiled, low_f - 1.0)?;
        client.clamp(&ppf_val, low_f, high_f)
    }
}

impl DiscreteUniform {
    /// Elementwise support indicator: `1.0` where `floor(k)` lies in
    /// `[low, high]`, `0.0` elsewhere.
    ///
    /// NOTE(review): bounds are converted to f64, so supports with
    /// `|low|` or `|high|` above 2^53 are only approximate at their edges.
    fn support_indicator_tensor<R: Runtime<DType = DType>, C>(
        &self,
        k: &Tensor<R>,
        client: &C,
    ) -> Result<Tensor<R>>
    where
        C: TensorOps<R> + ScalarOps<R> + SpecialFunctions<R> + RuntimeClient<R>,
    {
        let k_floor = client.floor(k)?;
        // floor(k) - low + 1 is >= 1 from `low` upward and <= 0 below it, so
        // clamping to [0, 1] gives a 0/1 step at the lower bound.
        let from_low = client.sub_scalar(&k_floor, self.low as f64 - 1.0)?;
        let above_low = client.clamp(&from_low, 0.0, 1.0)?;
        // high + 1 - floor(k) is >= 1 up to `high` and <= 0 beyond it.
        let negated = client.mul_scalar(&k_floor, -1.0)?;
        let to_high = client.add_scalar(&negated, self.high as f64 + 1.0)?;
        let below_high = client.clamp(&to_high, 0.0, 1.0)?;
        // Both steps are 0/1, so their product is the interval indicator.
        client.mul(&above_low, &below_high)
    }
}
impl DiscreteUniform {
pub fn coin() -> Self {
Self::new(0, 1).unwrap()
}
pub fn die(n: u64) -> StatsResult<Self> {
if n == 0 {
return Err(StatsError::InvalidParameter {
name: "n".to_string(),
value: 0.0,
reason: "number of sides must be positive".to_string(),
});
}
Self::new(1, n as i64)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance shared by every floating-point comparison below.
    const TOL: f64 = 1e-10;

    /// Asserts that `actual` lies strictly within [`TOL`] of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < TOL);
    }

    /// The standard six-sided die, `1..=6`, used by most tests.
    fn six_sided() -> DiscreteUniform {
        DiscreteUniform::new(1, 6).unwrap()
    }

    #[test]
    fn test_discrete_uniform_creation() {
        assert!(DiscreteUniform::new(1, 6).is_ok());
        assert!(DiscreteUniform::new(5, 5).is_ok());
        assert!(DiscreteUniform::new(6, 1).is_err());
    }

    #[test]
    fn test_discrete_uniform_pmf() {
        let die = six_sided();
        for face in 1..=6_u64 {
            assert_close(die.pmf(face), 1.0 / 6.0);
        }
        for outside in [0_u64, 7] {
            assert_close(die.pmf(outside), 0.0);
        }
    }

    #[test]
    fn test_discrete_uniform_cdf() {
        let die = six_sided();
        let cases = [
            (1_u64, 1.0 / 6.0),
            (2, 2.0 / 6.0),
            (3, 3.0 / 6.0),
            (6, 1.0),
            (0, 0.0),
        ];
        for (k, expected) in cases {
            assert_close(die.cdf(k), expected);
        }
    }

    #[test]
    fn test_discrete_uniform_mean() {
        assert_close(six_sided().mean(), 3.5);
        assert_close(DiscreteUniform::new(0, 10).unwrap().mean(), 5.0);
    }

    #[test]
    fn test_discrete_uniform_variance() {
        assert_close(six_sided().var(), 35.0 / 12.0);
    }

    #[test]
    fn test_discrete_uniform_ppf() {
        let die = six_sided();
        assert_eq!(die.ppf(0.0).unwrap(), 1);
        assert_eq!(die.ppf(1.0).unwrap(), 6);
        for k in 1..=6_u64 {
            let recovered = die.ppf(die.cdf(k)).unwrap();
            // Allows one step upward, presumably for floating-point rounding in
            // the cdf -> ppf round trip.
            assert!((k..=k + 1).contains(&recovered));
        }
    }

    #[test]
    fn test_discrete_uniform_entropy() {
        assert_close(six_sided().entropy(), 6.0_f64.ln());
    }

    #[test]
    fn test_discrete_uniform_skewness() {
        assert_close(six_sided().skewness(), 0.0);
    }

    #[test]
    fn test_discrete_uniform_die() {
        let d6 = DiscreteUniform::die(6).unwrap();
        assert_eq!((d6.low(), d6.high(), d6.n()), (1, 6, 6));
    }

    #[test]
    fn test_discrete_uniform_coin() {
        let coin = DiscreteUniform::coin();
        assert_eq!((coin.low(), coin.high()), (0, 1));
        for side in [0_u64, 1] {
            assert_close(coin.pmf(side), 0.5);
        }
    }

    #[test]
    fn test_discrete_uniform_randint() {
        let d = DiscreteUniform::randint(10).unwrap();
        assert_eq!((d.low(), d.high(), d.n()), (0, 9, 10));
    }

    #[test]
    fn test_discrete_uniform_single_value() {
        let point = DiscreteUniform::new(5, 5).unwrap();
        assert_close(point.pmf(5), 1.0);
        assert_close(point.mean(), 5.0);
        assert_close(point.var(), 0.0);
    }
}