use rand::Rng;
use rand_distr::Distribution as Distribution2;
use rand_distr::FisherF as FisherF2;
use crate::distributions::Distribution;
/// F distribution (Fisher–Snedecor) described by two strictly positive
/// parameters, `d1` and `d2`.
///
/// Values can only be built through [`FisherF::new`], which validates both
/// parameters, so a constructed `FisherF` always holds `d1 > 0` and `d2 > 0`
/// (and neither is NaN).
#[derive(Debug, Clone, PartialEq)]
pub struct FisherF {
    /// First parameter; guaranteed `> 0` by the constructor.
    d1: f64,
    /// Second parameter; guaranteed `> 0` by the constructor.
    d2: f64,
}

impl FisherF {
    /// Builds an F distribution with parameters `d1` and `d2`.
    ///
    /// # Errors
    ///
    /// Returns a descriptive `String` when `d1` is not strictly positive, or
    /// otherwise when `d2` is not strictly positive. NaN is rejected as well:
    /// it is never greater than zero, and letting it through would only defer
    /// the failure to the first use of the distribution.
    pub fn new(d1: f64, d2: f64) -> Result<FisherF, String> {
        // `v <= 0.0` alone is false for NaN, so NaN must be tested explicitly.
        if d1.is_nan() || d1 <= 0.0 {
            return Err(format!(
                "FisherF: illegal d1 `{}` should be greater than 0",
                d1
            ));
        }
        if d2.is_nan() || d2 <= 0.0 {
            return Err(format!(
                "FisherF: illegal d2 `{}` should be greater than 0",
                d2
            ));
        }
        Ok(FisherF { d1, d2 })
    }
}
/// Continuous distribution over `f64` backed by `rand_distr::FisherF` for
/// sampling and by a closed-form expression for the log-density.
impl<R: Rng + ?Sized> Distribution<R> for FisherF {
    type Domain = f64;

    /// Draws one value by delegating to `rand_distr::FisherF` built from the
    /// stored `d1` and `d2`.
    ///
    /// # Panics
    ///
    /// Panics if `rand_distr` rejects the stored parameters.
    fn sample(&self, rng: &mut R) -> f64 {
        FisherF2::new(self.d1, self.d2)
            .expect("FisherF: rand_distr rejected d1/d2 that FisherF::new accepted")
            .sample(rng)
    }

    /// Natural log of the density at `x`.
    ///
    /// Returns `f64::NEG_INFINITY` for `x <= 0` and for `x == +inf`; otherwise
    /// evaluates
    ///
    /// ```text
    /// (d1/2)·ln(d1/d2) + (d1/2 − 1)·ln(x)
    ///     − ((d1+d2)/2)·ln(1 + (d1/d2)·x) − ln B(d1/2, d2/2)
    /// ```
    ///
    /// A NaN `x` propagates to a NaN result.
    fn log_prob(&self, x: &f64) -> f64 {
        let x = *x;
        // At +inf the closed form below degenerates to `inf - inf` (or
        // `0 * inf`), i.e. NaN; the density tends to zero there, so report
        // -inf directly.
        if x <= 0.0 || x == f64::INFINITY {
            return f64::NEG_INFINITY;
        }

        let half_d1 = self.d1 / 2.0;
        let half_d2 = self.d2 / 2.0;
        let half_sum = (self.d1 + self.d2) / 2.0;
        let ratio = self.d1 / self.d2;

        // ln B(a, b) = lnΓ(a) + lnΓ(b) − lnΓ(a + b)
        let log_beta = libm::lgamma(half_d1) + libm::lgamma(half_d2) - libm::lgamma(half_sum);

        // `ln_1p` keeps precision when `ratio * x` is tiny, where `1.0 + t`
        // would round to exactly 1.0.
        half_d1 * ratio.ln() + (half_d1 - 1.0) * x.ln()
            - half_sum * (ratio * x).ln_1p()
            - log_beta
    }

    /// Always `false`: the domain is the continuous `f64`.
    fn is_discrete(&self) -> bool {
        false
    }
}
/// Human-readable rendering of the form `FisherF { d1 = <d1>, d2 = <d2> }`.
impl std::fmt::Display for FisherF {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Pull both parameters out once so the format call reads cleanly.
        let Self { d1, d2 } = self;
        write!(f, "FisherF {{ d1 = {}, d2 = {} }}", d1, d2)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::ThreadRng;
    use rand::thread_rng;

    /// The empirical mean of many draws must lie within five standard errors
    /// of the analytic mean `d2 / (d2 - 2)`.
    #[test]
    fn fisher_f_sample() {
        let mut rng = thread_rng();
        let d1 = 10.0f64;
        let d2 = 20.0f64;
        let dist = FisherF::new(d1, d2).unwrap();
        println!("dist = {}", dist);

        let trials = 100_000;
        let sum: f64 = (0..trials).map(|_| dist.sample(&mut rng)).sum();
        let n = trials as f64;
        let sample_mean = sum / n;

        // Analytic mean and variance of the distribution (needs d2 > 4).
        let true_mean = d2 / (d2 - 2.0);
        let true_var =
            2.0 * d2 * d2 * (d1 + d2 - 2.0) / (d1 * (d2 - 2.0).powi(2) * (d2 - 4.0));
        let tolerance = 5.0 * true_var.sqrt() / n.sqrt();

        assert!((sample_mean - true_mean).abs() < tolerance);
    }

    /// Log-density is finite inside the support, `-inf` at zero, and the
    /// distribution reports itself as continuous.
    #[test]
    fn fisher_f_log_prob() {
        let dist = FisherF::new(2.0, 2.0).unwrap();

        let inside = Distribution::<ThreadRng>::log_prob(&dist, &1.0);
        assert!(inside.is_finite());

        let at_zero = Distribution::<ThreadRng>::log_prob(&dist, &0.0);
        assert_eq!(at_zero, f64::NEG_INFINITY);

        let discrete = Distribution::<ThreadRng>::is_discrete(&dist);
        assert!(!discrete);
    }

    /// The `Display` output mentions the type name.
    #[test]
    fn fisher_f_display() {
        let s = FisherF::new(5.0, 10.0).unwrap().to_string();
        assert!(s.contains("FisherF"), "missing type name: {}", s);
    }

    /// A zero `d1` is rejected, so unwrapping the result panics.
    #[test]
    #[should_panic]
    fn fisher_f_zero_d1() {
        let _ = FisherF::new(0.0, 1.0).unwrap();
    }

    /// A zero `d2` is rejected, so unwrapping the result panics.
    #[test]
    #[should_panic]
    fn fisher_f_zero_d2() {
        let _ = FisherF::new(1.0, 0.0).unwrap();
    }
}