use crate::{Distribution, StandardNormal};
use core::fmt;
use num_traits::Float;
use rand::{Rng, RngExt};
/// A skew normal distribution described by a location, a scale and a shape parameter.
///
/// Values are built with `SkewNormal::new`, which rejects a `scale` that is
/// non-finite or not greater than zero and a `shape` that is non-finite.
/// The `location` parameter is not validated.
///
/// NOTE(review): with the `serde` feature the derived `Deserialize` presumably
/// fills the fields directly without going through `new` — confirm whether
/// deserialized values need to be validated.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct SkewNormal<F>
where
F: Float,
StandardNormal: Distribution<F>,
{
/// Offset added to every sample after it has been multiplied by `scale`.
location: F,
/// Factor every standardized sample is multiplied by; checked by `new` to be
/// finite and greater than zero.
scale: F,
/// Shape parameter; checked by `new` to be finite. Sampling special-cases
/// the values `0`, `1` and `-1`.
shape: F,
}
/// Reasons why `SkewNormal::new` can reject its parameters.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Error {
    /// The scale parameter is non-finite or is not greater than zero.
    ScaleTooSmall,
    /// The shape parameter is non-finite.
    BadShape,
}

impl fmt::Display for Error {
    /// Writes the fixed, human-readable description of the error variant.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match *self {
            Self::ScaleTooSmall => {
                "scale parameter is either non-finite or it is less or equal to zero in skew normal distribution"
            }
            Self::BadShape => "shape parameter is non-finite in skew normal distribution",
        };
        f.write_str(message)
    }
}
// `std::error::Error` is only implemented when the `std` feature is enabled;
// the impl is empty, so all trait methods keep their default behavior.
#[cfg(feature = "std")]
impl std::error::Error for Error {}
impl<F> SkewNormal<F>
where
F: Float,
StandardNormal: Distribution<F>,
{
#[inline]
pub fn new(location: F, scale: F, shape: F) -> Result<SkewNormal<F>, Error> {
if !scale.is_finite() || !(scale > F::zero()) {
return Err(Error::ScaleTooSmall);
}
if !shape.is_finite() {
return Err(Error::BadShape);
}
Ok(SkewNormal {
location,
scale,
shape,
})
}
pub fn location(&self) -> F {
self.location
}
pub fn scale(&self) -> F {
self.scale
}
pub fn shape(&self) -> F {
self.shape
}
}
impl<F> Distribution<F> for SkewNormal<F>
where
    F: Float,
    StandardNormal: Distribution<F>,
{
    /// Draws one sample from the distribution.
    ///
    /// One `StandardNormal` value is drawn when `shape` is zero and two
    /// otherwise. The standardized result is then multiplied by `scale` and
    /// shifted by `location`.
    ///
    /// # Panics
    ///
    /// Panics if `SQRT_2` cannot be converted into `F`; this conversion is
    /// only attempted when `shape` is none of `0`, `1` and `-1`.
    fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> F {
        // Maps a standardized value onto this distribution's scale and location.
        let linear_map = |x: F| -> F { x * self.scale + self.location };

        let u_1: F = rng.sample(StandardNormal);
        if self.shape == F::zero() {
            // Zero shape: the single draw is used as is.
            return linear_map(u_1);
        }

        let u_2: F = rng.sample(StandardNormal);
        // `u` is the larger and `v` the smaller of the two draws.
        let (u, v) = (u_1.max(u_2), u_1.min(u_2));

        if self.shape == -F::one() {
            linear_map(v)
        } else if self.shape == F::one() {
            linear_map(u)
        } else {
            let sqrt_2 = F::from(core::f64::consts::SQRT_2).unwrap();
            let shape_squared = self.shape * self.shape;
            let normalized = if shape_squared.is_finite() {
                // Regular case; the arithmetic is kept exactly as before so
                // that previously generated sample streams do not change.
                ((F::one() + self.shape) * u + (F::one() - self.shape) * v)
                    / ((F::one() + shape_squared).sqrt() * sqrt_2)
            } else {
                // `shape * shape` overflowed although `shape` itself is finite.
                // The regular formula would divide by infinity and collapse
                // every sample onto `location`. For such magnitudes
                // sqrt(1 + shape^2) equals |shape| in working precision, so
                // the same quotient is evaluated with the division by |shape|
                // applied term by term:
                //   ((1 + s) u + (1 - s) v) / (|s| sqrt(2))
                //     = (sign(s) (u - v) + (u + v) / |s|) / sqrt(2)
                (self.shape.signum() * (u - v) + (u + v) / self.shape.abs()) / sqrt_2
            };
            linear_map(normalized)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Draws four samples from `distr` with a fixed seed and compares them
    /// against `expected`. `zero` only provides the initial buffer content.
    fn test_samples<F: Float + fmt::Debug, D: Distribution<F>>(distr: D, zero: F, expected: &[F]) {
        let mut rng = crate::test::rng(213);
        let mut drawn = [zero; 4];
        drawn
            .iter_mut()
            .for_each(|slot| *slot = rng.sample(&distr));
        assert_eq!(drawn, expected);
    }

    /// A NaN scale must be rejected by the constructor.
    #[test]
    #[should_panic]
    fn invalid_scale_nan() {
        SkewNormal::new(0.0, f64::NAN, 0.0).unwrap();
    }

    /// A zero scale must be rejected by the constructor.
    #[test]
    #[should_panic]
    fn invalid_scale_zero() {
        SkewNormal::new(0.0, 0.0, 0.0).unwrap();
    }

    /// A negative scale must be rejected by the constructor.
    #[test]
    #[should_panic]
    fn invalid_scale_negative() {
        SkewNormal::new(0.0, -1.0, 0.0).unwrap();
    }

    /// An infinite scale must be rejected by the constructor.
    #[test]
    #[should_panic]
    fn invalid_scale_infinite() {
        SkewNormal::new(0.0, f64::INFINITY, 0.0).unwrap();
    }

    /// A NaN shape must be rejected by the constructor.
    #[test]
    #[should_panic]
    fn invalid_shape_nan() {
        SkewNormal::new(0.0, 1.0, f64::NAN).unwrap();
    }

    /// An infinite shape must be rejected by the constructor.
    #[test]
    #[should_panic]
    fn invalid_shape_infinite() {
        SkewNormal::new(0.0, 1.0, f64::INFINITY).unwrap();
    }

    /// The location is not validated, so NaN is accepted.
    #[test]
    fn valid_location_nan() {
        SkewNormal::new(f64::NAN, 1.0, 0.0).unwrap();
    }

    /// Pins the exact values produced for a fixed seed.
    #[test]
    fn skew_normal_value_stability() {
        // Zero shape, single precision.
        test_samples(
            SkewNormal::new(0.0, 1.0, 0.0).unwrap(),
            0f32,
            &[-0.11844189, 0.781378, 0.06563994, -1.1932899],
        );
        // Zero shape, double precision.
        test_samples(
            SkewNormal::new(0.0, 1.0, 0.0).unwrap(),
            0f64,
            &[
                -0.11844188827977231,
                0.7813779637772346,
                0.06563993969580051,
                -1.1932899004186373,
            ],
        );
        // An infinite location carries over to every sample.
        test_samples(
            SkewNormal::new(f64::INFINITY, 1.0, 0.0).unwrap(),
            0f64,
            &[f64::INFINITY; 4],
        );
        test_samples(
            SkewNormal::new(f64::NEG_INFINITY, 1.0, 0.0).unwrap(),
            0f64,
            &[f64::NEG_INFINITY; 4],
        );
    }

    /// A NaN location turns every sample into NaN.
    #[test]
    fn skew_normal_value_location_nan() {
        let distr = SkewNormal::new(f64::NAN, 1.0, 0.0).unwrap();
        let mut rng = crate::test::rng(213);
        let mut samples = [0.0_f64; 4];
        for slot in samples.iter_mut() {
            *slot = rng.sample(distr);
        }
        assert!(samples.iter().all(|sample| sample.is_nan()));
    }

    /// Two results built from the same parameters compare equal.
    #[test]
    fn skew_normal_distributions_can_be_compared() {
        let first = SkewNormal::new(1.0, 2.0, 3.0);
        let second = SkewNormal::new(1.0, 2.0, 3.0);
        assert_eq!(first, second);
    }
}