use std::f64;
use rand::Rng;
use rand::distributions::{Sample, IndependentSample};
use function::{beta, gamma};
use statistics::*;
use distribution::{Univariate, Continuous, Distribution};
use result::Result;
use error::StatsError;
/// Parameter set of a Student's t-distribution in location/scale form.
///
/// Instances are meant to be built through `StudentsT::new`, which rejects
/// `NaN` arguments and non-positive `scale` / `freedom` values.
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct StudentsT {
/// Location parameter; reported as the mean, median and mode.
location: f64,
/// Scale parameter; validated to be strictly positive on construction.
scale: f64,
/// Degrees of freedom; validated to be strictly positive on construction.
freedom: f64,
}
impl StudentsT {
    /// Builds a `StudentsT` from a location, a scale and a number of degrees
    /// of freedom.
    ///
    /// # Errors
    ///
    /// Returns `StatsError::BadParams` when any argument is `NaN`, or when
    /// `scale` or `freedom` is less than or equal to zero.
    pub fn new(location: f64, scale: f64, freedom: f64) -> Result<StudentsT> {
        let any_nan = [location, scale, freedom].iter().any(|p| p.is_nan());
        let non_positive = scale <= 0.0 || freedom <= 0.0;
        if any_nan || non_positive {
            return Err(StatsError::BadParams);
        }
        Ok(StudentsT {
            location: location,
            scale: scale,
            freedom: freedom,
        })
    }

    /// Returns the location parameter the distribution was built with.
    pub fn location(&self) -> f64 {
        self.location
    }

    /// Returns the scale parameter the distribution was built with.
    pub fn scale(&self) -> f64 {
        self.scale
    }

    /// Returns the degrees of freedom the distribution was built with.
    pub fn freedom(&self) -> f64 {
        self.freedom
    }
}
impl Sample<f64> for StudentsT {
    /// Draws one value by delegating to this type's `Distribution::sample`.
    ///
    /// The mutable receiver required by `Sample` is only reborrowed
    /// immutably; no state on `self` is changed here.
    fn sample<R: Rng>(&mut self, r: &mut R) -> f64 {
        <Self as super::Distribution<f64>>::sample(&*self, r)
    }
}
impl IndependentSample<f64> for StudentsT {
    /// Draws one value by delegating to this type's `Distribution::sample`.
    fn ind_sample<R: Rng>(&self, r: &mut R) -> f64 {
        <Self as super::Distribution<f64>>::sample(self, r)
    }
}
impl Distribution<f64> for StudentsT {
    /// Draws one value from the distribution using `r` as the source of
    /// randomness.
    ///
    /// A gamma variate is drawn first and is then used to rescale the
    /// spread of a single normal draw centred on `location`. The two draws
    /// happen in that order.
    fn sample<R: Rng>(&self, r: &mut R) -> f64 {
        let half_freedom = 0.5 * self.freedom;
        // NOTE(review): the trailing 0.5 is presumably the gamma rate
        // parameter -- confirm against `gamma::sample_unchecked`.
        let gamma_draw = super::gamma::sample_unchecked(r, half_freedom, 0.5);
        let spread = self.scale * (self.freedom / gamma_draw).sqrt();
        super::normal::sample_unchecked(r, self.location, spread)
    }
}
impl Univariate<f64, f64> for StudentsT {
    /// Evaluates the cumulative distribution function at `x`.
    ///
    /// When `freedom` is positive infinity the normal CDF with the same
    /// location and scale is returned. Otherwise the value is computed from
    /// the regularized incomplete beta function:
    ///
    /// ```ignore
    /// k  = (x - location) / scale
    /// h  = freedom / (freedom + k^2)
    /// ib = 0.5 * I_h(freedom / 2, 1 / 2)
    /// cdf = ib          if x <= location
    ///       1 - ib      otherwise
    /// ```
    fn cdf(&self, x: f64) -> f64 {
        if self.freedom == f64::INFINITY {
            return super::normal::cdf_unchecked(x, self.location, self.scale);
        }

        let standardized = (x - self.location) / self.scale;
        let beta_arg = self.freedom / (self.freedom + standardized * standardized);
        let lower_tail = 0.5 * beta::beta_reg(self.freedom / 2.0, 0.5, beta_arg);

        // `lower_tail` is the mass on the far side of `x`; mirror it when `x`
        // lies to the right of the centre.
        if x <= self.location {
            lower_tail
        } else {
            1.0 - lower_tail
        }
    }
}
impl Min<f64> for StudentsT {
    /// Returns the lower bound of the support, which is always negative
    /// infinity regardless of the parameters.
    fn min(&self) -> f64 {
        std::f64::NEG_INFINITY
    }
}
impl Max<f64> for StudentsT {
    /// Returns the upper bound of the support, which is always positive
    /// infinity regardless of the parameters.
    fn max(&self) -> f64 {
        std::f64::INFINITY
    }
}
impl Mean<f64> for StudentsT {
fn mean(&self) -> f64 {
assert!(self.freedom > 1.0,
format!("{}", StatsError::ArgGt("freedom", 1.0)));
self.location
}
}
impl Variance<f64> for StudentsT {
fn variance(&self) -> f64 {
assert!(self.freedom > 1.0,
format!("{}", StatsError::ArgGt("freedom", 1.0)));
if self.freedom == f64::INFINITY {
self.scale * self.scale
} else if self.freedom > 2.0 {
self.freedom * self.scale * self.scale / (self.freedom - 2.0)
} else {
f64::INFINITY
}
}
fn std_dev(&self) -> f64 {
self.variance().sqrt()
}
}
impl Entropy<f64> for StudentsT {
    /// Returns the entropy of the standard form of the distribution.
    ///
    /// # Panics
    ///
    /// Panics unless `location == 0.0` and `scale == 1.0`.
    ///
    /// # Formula
    ///
    /// ```ignore
    /// (v + 1) / 2 * (psi((v + 1) / 2) - psi(v / 2)) + ln(sqrt(v) * B(v / 2, 1 / 2))
    /// ```
    ///
    /// where `v` is `freedom`, `psi` is `gamma::digamma` and `B` is
    /// `beta::beta`.
    fn entropy(&self) -> f64 {
        let is_standard = self.location == 0.0 && self.scale == 1.0;
        assert!(is_standard,
                "Cannot calculate entropy for StudentsT distribution where location is not 0 and \
                 scale is not 1");

        let upper_half = (self.freedom + 1.0) / 2.0;
        let lower_half = self.freedom / 2.0;
        let digamma_gap = gamma::digamma(upper_half) - gamma::digamma(lower_half);
        let log_term = (self.freedom.sqrt() * beta::beta(lower_half, 0.5)).ln();

        upper_half * digamma_gap + log_term
    }
}
impl Skewness<f64> for StudentsT {
fn skewness(&self) -> f64 {
assert!(self.freedom > 3.0,
format!("{}", StatsError::ArgGt("freedom", 3.0)));
0.0
}
}
impl Median<f64> for StudentsT {
    /// Returns the median of the distribution, which equals `location`
    /// for every valid parameter set.
    fn median(&self) -> f64 {
        self.location()
    }
}
impl Mode<f64> for StudentsT {
    /// Returns the mode of the distribution, which equals `location`
    /// for every valid parameter set.
    fn mode(&self) -> f64 {
        self.location()
    }
}
impl Continuous<f64, f64> for StudentsT {
    /// Evaluates the probability density function at `x`.
    ///
    /// For `freedom >= 1e8` the normal density with the same location and
    /// scale is returned instead. Otherwise:
    ///
    /// ```ignore
    /// d = (x - location) / scale
    /// exp(ln_gamma((v + 1) / 2) - ln_gamma(v / 2))
    ///     * (1 + d^2 / v)^(-(v + 1) / 2)
    ///     / sqrt(v * pi) / scale
    /// ```
    ///
    /// where `v` is `freedom`.
    fn pdf(&self, x: f64) -> f64 {
        if self.freedom >= 1e8 {
            return super::normal::pdf_unchecked(x, self.location, self.scale);
        }

        let z = (x - self.location) / self.scale;
        let ln_gamma_ratio = gamma::ln_gamma((self.freedom + 1.0) / 2.0) -
                             gamma::ln_gamma(self.freedom / 2.0);
        let kernel = (1.0 + z * z / self.freedom).powf(-0.5 * (self.freedom + 1.0));
        let norm = (self.freedom * f64::consts::PI).sqrt();

        ln_gamma_ratio.exp() * kernel / norm / self.scale
    }

    /// Evaluates the natural logarithm of the probability density function
    /// at `x`.
    ///
    /// For `freedom >= 1e8` the normal log-density with the same location
    /// and scale is returned instead. Otherwise:
    ///
    /// ```ignore
    /// d = (x - location) / scale
    /// ln_gamma((v + 1) / 2) - 0.5 * (v + 1) * ln(1 + d^2 / v)
    ///     - ln_gamma(v / 2) - 0.5 * ln(v * pi) - ln(scale)
    /// ```
    ///
    /// where `v` is `freedom`.
    fn ln_pdf(&self, x: f64) -> f64 {
        if self.freedom >= 1e8 {
            return super::normal::ln_pdf_unchecked(x, self.location, self.scale);
        }

        let z = (x - self.location) / self.scale;
        let ln_gamma_upper = gamma::ln_gamma((self.freedom + 1.0) / 2.0);
        let ln_kernel = 0.5 * ((self.freedom + 1.0) * (1.0 + z * z / self.freedom).ln());
        let ln_gamma_lower = gamma::ln_gamma(self.freedom / 2.0);
        let ln_norm = 0.5 * (self.freedom * f64::consts::PI).ln();

        ln_gamma_upper - ln_kernel - ln_gamma_lower - ln_norm - self.scale.ln()
    }
}
// Unit tests for `StudentsT`: construction, summary statistics, and the
// pdf / ln_pdf / cdf evaluations. rustfmt is skipped so the long tables of
// one-line cases keep their layout.
#[cfg_attr(rustfmt, rustfmt_skip)]
#[cfg(test)]
mod test {
use std::f64;
use std::panic;
use statistics::*;
use distribution::{Univariate, Continuous, StudentsT};
// Builds a `StudentsT`, asserting that construction succeeds.
fn try_create(location: f64, scale: f64, freedom: f64) -> StudentsT {
let n = StudentsT::new(location, scale, freedom);
assert!(n.is_ok());
n.unwrap()
}
// Checks that a successfully built distribution echoes back the exact
// parameters it was given through its getters.
fn create_case(location: f64, scale: f64, freedom: f64) {
let n = try_create(location, scale, freedom);
assert_eq!(n.location(), location);
assert_eq!(n.scale(), scale);
assert_eq!(n.freedom(), freedom);
}
// Checks that `StudentsT::new` rejects the given parameters.
fn bad_create_case(location: f64, scale: f64, freedom: f64) {
let n = StudentsT::new(location, scale, freedom);
assert!(n.is_err());
}
// Builds a distribution from the parameters and returns `eval` applied to it.
fn get_value<F>(location: f64, scale: f64, freedom: f64, eval: F) -> f64
where F: Fn(StudentsT) -> f64
{
let n = try_create(location, scale, freedom);
eval(n)
}
// Asserts that `eval` on the built distribution equals `expected` exactly.
fn test_case<F>(location: f64, scale: f64, freedom: f64, expected: f64, eval: F)
where F: Fn(StudentsT) -> f64
{
let x = get_value(location, scale, freedom, eval);
assert_eq!(expected, x);
}
// Asserts that `eval` on the built distribution matches `expected` via
// `assert_almost_eq!` with accuracy `acc`.
// NOTE(review): the macro is defined elsewhere; `acc` is presumably an
// absolute tolerance -- confirm.
fn test_almost<F>(location: f64, scale: f64, freedom: f64, expected: f64, acc: f64, eval: F)
where F: Fn(StudentsT) -> f64
{
let x = get_value(location, scale, freedom, eval);
assert_almost_eq!(expected, x, acc);
}
// Asserts that building the distribution and applying `eval` panics.
// The call runs under `catch_unwind`, hence the `UnwindSafe` bound on `F`.
fn test_panic<F>(location: f64, scale: f64, freedom: f64, eval: F)
where F : Fn(StudentsT) -> f64,
F : panic::UnwindSafe
{
let result = panic::catch_unwind(|| {
get_value(location, scale, freedom, eval)
});
assert!(result.is_err());
}
// Valid parameter sets, including an infinite number of degrees of freedom.
#[test]
fn test_create() {
create_case(0.0, 0.1, 1.0);
create_case(0.0, 1.0, 1.0);
create_case(-5.0, 1.0, 3.0);
create_case(10.0, 10.0, f64::INFINITY);
}
// NaN in any position, a negative scale, or negative freedom must be rejected.
#[test]
fn test_bad_create() {
bad_create_case(f64::NAN, 1.0, 1.0);
bad_create_case(0.0, f64::NAN, 1.0);
bad_create_case(0.0, 1.0, f64::NAN);
bad_create_case(0.0, -10.0, 1.0);
bad_create_case(0.0, 10.0, -1.0);
}
// The mean panics for freedom == 1.0 and equals the location otherwise.
#[test]
fn test_mean() {
test_panic(0.0, 1.0, 1.0, |x| x.mean());
test_panic(0.0, 0.1, 1.0, |x| x.mean());
test_case(0.0, 1.0, 3.0, 0.0, |x| x.mean());
test_panic(0.0, 10.0, 1.0, |x| x.mean());
test_case(0.0, 10.0, 2.0, 0.0, |x| x.mean());
test_case(0.0, 10.0, f64::INFINITY, 0.0, |x| x.mean());
test_panic(10.0, 1.0, 1.0, |x| x.mean());
test_case(-5.0, 100.0, 1.5, -5.0, |x| x.mean());
test_panic(0.0, f64::INFINITY, 1.0, |x| x.mean());
}
// The variance panics for freedom == 1.0, is infinite for the tested
// freedom values 1.5 and 2.0, and finite for the tested values above 2.
#[test]
fn test_variance() {
test_panic(0.0, 1.0, 1.0, |x| x.variance());
test_panic(0.0, 0.1, 1.0, |x| x.variance());
test_case(0.0, 1.0, 3.0, 3.0, |x| x.variance());
test_panic(0.0, 10.0, 1.0, |x| x.variance());
test_case(0.0, 10.0, 2.0, f64::INFINITY, |x| x.variance());
test_case(0.0, 10.0, 2.5, 500.0, |x| x.variance());
test_panic(10.0, 1.0, 1.0, |x| x.variance());
test_case(10.0, 1.0, 2.5, 5.0, |x| x.variance());
test_case(-5.0, 100.0, 1.5, f64::INFINITY, |x| x.variance());
test_panic(0.0, f64::INFINITY, 1.0, |x| x.variance());
}
// Same parameter grid as the variance test, with square-rooted expectations
// and one extra case for infinite freedom.
#[test]
fn test_std_dev() {
test_panic(0.0, 1.0, 1.0, |x| x.std_dev());
test_panic(0.0, 0.1, 1.0, |x| x.std_dev());
test_case(0.0, 1.0, 3.0, 1.7320508075688772935274463415059, |x| x.std_dev());
test_panic(0.0, 10.0, 1.0, |x| x.std_dev());
test_case(0.0, 10.0, 2.0, f64::INFINITY, |x| x.std_dev());
test_case(0.0, 10.0, 2.5, 22.360679774997896964091736687313, |x| x.std_dev());
test_case(0.0, 10.0, f64::INFINITY, 10.0, |x| x.std_dev());
test_panic(10.0, 1.0, 1.0, |x| x.std_dev());
test_case(10.0, 1.0, 2.5, 2.2360679774997896964091736687313, |x| x.std_dev());
test_case(-5.0, 100.0, 1.5, f64::INFINITY, |x| x.std_dev());
test_panic(0.0, f64::INFINITY, 1.0, |x| x.std_dev());
}
// The mode always equals the location, whatever the scale and freedom.
#[test]
fn test_mode() {
test_case(0.0, 1.0, 1.0, 0.0, |x| x.mode());
test_case(0.0, 0.1, 1.0, 0.0, |x| x.mode());
test_case(0.0, 1.0, 3.0, 0.0, |x| x.mode());
test_case(0.0, 10.0, 1.0, 0.0, |x| x.mode());
test_case(0.0, 10.0, 2.0, 0.0, |x| x.mode());
test_case(0.0, 10.0, 2.5, 0.0, |x| x.mode());
test_case(0.0, 10.0, f64::INFINITY, 0.0, |x| x.mode());
test_case(10.0, 1.0, 1.0, 10.0, |x| x.mode());
test_case(10.0, 1.0, 2.5, 10.0, |x| x.mode());
test_case(-5.0, 100.0, 1.5, -5.0, |x| x.mode());
test_case(0.0, f64::INFINITY, 1.0, 0.0, |x| x.mode());
}
// The median always equals the location, whatever the scale and freedom.
#[test]
fn test_median() {
test_case(0.0, 1.0, 1.0, 0.0, |x| x.median());
test_case(0.0, 0.1, 1.0, 0.0, |x| x.median());
test_case(0.0, 1.0, 3.0, 0.0, |x| x.median());
test_case(0.0, 10.0, 1.0, 0.0, |x| x.median());
test_case(0.0, 10.0, 2.0, 0.0, |x| x.median());
test_case(0.0, 10.0, 2.5, 0.0, |x| x.median());
test_case(0.0, 10.0, f64::INFINITY, 0.0, |x| x.median());
test_case(10.0, 1.0, 1.0, 10.0, |x| x.median());
test_case(10.0, 1.0, 2.5, 10.0, |x| x.median());
test_case(-5.0, 100.0, 1.5, -5.0, |x| x.median());
test_case(0.0, f64::INFINITY, 1.0, 0.0, |x| x.median());
}
// The support is the whole real line for every parameter set tried.
#[test]
fn test_min_max() {
test_case(0.0, 1.0, 1.0, f64::NEG_INFINITY, |x| x.min());
test_case(2.5, 100.0, 1.5, f64::NEG_INFINITY, |x| x.min());
test_case(10.0, f64::INFINITY, 3.5, f64::NEG_INFINITY, |x| x.min());
test_case(0.0, 1.0, 1.0, f64::INFINITY, |x| x.max());
test_case(2.5, 100.0, 1.5, f64::INFINITY, |x| x.max());
test_case(10.0, f64::INFINITY, 5.5, f64::INFINITY, |x| x.max());
}
// Density of the standard form for freedom 1, 2 and infinity; the mirrored
// +/- inputs check symmetry about the location.
#[test]
fn test_pdf() {
test_almost(0.0, 1.0, 1.0, 0.318309886183791, 1e-15, |x| x.pdf(0.0));
test_almost(0.0, 1.0, 1.0, 0.159154943091895, 1e-15, |x| x.pdf(1.0));
test_almost(0.0, 1.0, 1.0, 0.159154943091895, 1e-15, |x| x.pdf(-1.0));
test_almost(0.0, 1.0, 1.0, 0.063661977236758, 1e-15, |x| x.pdf(2.0));
test_almost(0.0, 1.0, 1.0, 0.063661977236758, 1e-15, |x| x.pdf(-2.0));
test_almost(0.0, 1.0, 2.0, 0.353553390593274, 1e-15, |x| x.pdf(0.0));
test_almost(0.0, 1.0, 2.0, 0.192450089729875, 1e-15, |x| x.pdf(1.0));
test_almost(0.0, 1.0, 2.0, 0.192450089729875, 1e-15, |x| x.pdf(-1.0));
test_almost(0.0, 1.0, 2.0, 0.068041381743977, 1e-15, |x| x.pdf(2.0));
test_almost(0.0, 1.0, 2.0, 0.068041381743977, 1e-15, |x| x.pdf(-2.0));
test_almost(0.0, 1.0, f64::INFINITY, 0.398942280401433, 1e-15, |x| x.pdf(0.0));
test_almost(0.0, 1.0, f64::INFINITY, 0.241970724519143, 1e-15, |x| x.pdf(1.0));
test_almost(0.0, 1.0, f64::INFINITY, 0.053990966513188, 1e-15, |x| x.pdf(2.0));
}
// Log-density on the same grid as the pdf test, with a looser accuracy.
#[test]
fn test_ln_pdf() {
test_almost(0.0, 1.0, 1.0, -1.144729885849399, 1e-14, |x| x.ln_pdf(0.0));
test_almost(0.0, 1.0, 1.0, -1.837877066409348, 1e-14, |x| x.ln_pdf(1.0));
test_almost(0.0, 1.0, 1.0, -1.837877066409348, 1e-14, |x| x.ln_pdf(-1.0));
test_almost(0.0, 1.0, 1.0, -2.754167798283503, 1e-14, |x| x.ln_pdf(2.0));
test_almost(0.0, 1.0, 1.0, -2.754167798283503, 1e-14, |x| x.ln_pdf(-2.0));
test_almost(0.0, 1.0, 2.0, -1.039720770839917, 1e-14, |x| x.ln_pdf(0.0));
test_almost(0.0, 1.0, 2.0, -1.647918433002166, 1e-14, |x| x.ln_pdf(1.0));
test_almost(0.0, 1.0, 2.0, -1.647918433002166, 1e-14, |x| x.ln_pdf(-1.0));
test_almost(0.0, 1.0, 2.0, -2.687639203842085, 1e-14, |x| x.ln_pdf(2.0));
test_almost(0.0, 1.0, 2.0, -2.687639203842085, 1e-14, |x| x.ln_pdf(-2.0));
test_almost(0.0, 1.0, f64::INFINITY, -0.918938533204672, 1e-14, |x| x.ln_pdf(0.0));
test_almost(0.0, 1.0, f64::INFINITY, -1.418938533204674, 1e-14, |x| x.ln_pdf(1.0));
test_almost(0.0, 1.0, f64::INFINITY, -2.918938533204674, 1e-14, |x| x.ln_pdf(2.0));
}
// CDF of the standard form; the value at the location must be exactly 0.5,
// and the infinite-freedom cases use a looser accuracy than the others.
#[test]
fn test_cdf() {
test_case(0.0, 1.0, 1.0, 0.5, |x| x.cdf(0.0));
test_almost(0.0, 1.0, 1.0, 0.75, 1e-15, |x| x.cdf(1.0));
test_almost(0.0, 1.0, 1.0, 0.25, 1e-15, |x| x.cdf(-1.0));
test_almost(0.0, 1.0, 1.0, 0.852416382349567, 1e-15, |x| x.cdf(2.0));
test_almost(0.0, 1.0, 1.0, 0.147583617650433, 1e-15, |x| x.cdf(-2.0));
test_case(0.0, 1.0, 2.0, 0.5, |x| x.cdf(0.0));
test_almost(0.0, 1.0, 2.0, 0.788675134594813, 1e-15, |x| x.cdf(1.0));
test_almost(0.0, 1.0, 2.0, 0.211324865405187, 1e-15, |x| x.cdf(-1.0));
test_almost(0.0, 1.0, 2.0, 0.908248290463863, 1e-15, |x| x.cdf(2.0));
test_almost(0.0, 1.0, 2.0, 0.091751709536137, 1e-15, |x| x.cdf(-2.0));
test_case(0.0, 1.0, f64::INFINITY, 0.5, |x| x.cdf(0.0));
test_almost(0.0, 1.0, f64::INFINITY, 0.841344746068543, 1e-10, |x| x.cdf(1.0));
test_almost(0.0, 1.0, f64::INFINITY, 0.977249868051821, 1e-11, |x| x.cdf(2.0));
}
}