//! Trait definitions
use crate::data::DataOrSuffStat;
use rand::Rng;

/// Random variable
/// Contains the minimal functionality that a random object must have to be
/// useful: a function defining the un-normalized density/mass at a point,
/// and functions to draw samples from the distribution.
pub trait Rv<X> {
    /// Probability function
    /// # Example
    /// ```
    /// use rv::dist::Gaussian;
    /// use rv::traits::Rv;
    /// let g = Gaussian::standard();
    /// assert!(g.f(&0.0_f64) > g.f(&0.1_f64));
    /// assert!(g.f(&0.0_f64) > g.f(&-0.1_f64));
    /// ```
    fn f(&self, x: &X) -> f64 {

    /// Probability function
    /// # Example
    /// ```
    /// use rv::dist::Gaussian;
    /// use rv::traits::Rv;
    /// let g = Gaussian::standard();
    /// assert!(g.ln_f(&0.0_f64) > g.ln_f(&0.1_f64));
    /// assert!(g.ln_f(&0.0_f64) > g.ln_f(&-0.1_f64));
    /// ```
    fn ln_f(&self, x: &X) -> f64;

    /// Single draw from the `Rv`
    /// # Example
    /// Flip a coin
    /// ```
    /// use rv::dist::Bernoulli;
    /// use rv::traits::Rv;
    /// let b = Bernoulli::uniform();
    /// let mut rng = rand::thread_rng();
    /// let x: bool = b.draw(&mut rng); // could be true, could be false.
    /// ```
    fn draw<R: Rng>(&self, rng: &mut R) -> X;

    /// Multiple draws of the `Rv`
    /// # Example
    /// Flip a lot of coins
    /// ```
    /// use rv::dist::Bernoulli;
    /// use rv::traits::Rv;
    /// let b = Bernoulli::uniform();
    /// let mut rng = rand::thread_rng();
    /// let xs: Vec<bool> = b.sample(22, &mut rng);
    /// assert_eq!(xs.len(), 22);
    /// ```
    /// Estimate Gaussian mean
    /// ```
    /// use rv::dist::Gaussian;
    /// use rv::traits::Rv;
    /// let gauss = Gaussian::standard();
    /// let mut rng = rand::thread_rng();
    /// let xs: Vec<f64> = gauss.sample(100_000, &mut rng);
    /// assert::close(xs.iter().sum::<f64>()/100_000.0, 0.0, 1e-2);
    /// ```
    fn sample<R: Rng>(&self, n: usize, mut rng: &mut R) -> Vec<X> {
        (0..n).map(|_| self.draw(&mut rng)).collect()

    /// Create a never-ending iterator of samples
    /// # Example
    /// Estimate the mean of a Gamma distribution
    /// ```
    /// use rv::traits::Rv;
    /// use rv::dist::Gamma;
    /// let mut rng = rand::thread_rng();
    /// let gamma = Gamma::new(2.0, 1.0).unwrap();
    /// let n = 1_000_000_usize;
    /// let mean = <Gamma as Rv<f64>>::sample_stream(&gamma, &mut rng)
    ///     .take(n)
    ///     .sum::<f64>() / n as f64;;
    /// assert::close(mean, 2.0, 1e-2);
    /// ```
    fn sample_stream<'r, R: Rng>(
        &'r self,
        mut rng: &'r mut R,
    ) -> Box<dyn Iterator<Item = X> + 'r> {
        Box::new(std::iter::repeat_with(move || self.draw(&mut rng)))

// Auto impl for deref types
impl<Fx, X> Rv<X> for Fx
    Fx: std::ops::Deref,
    Fx::Target: Rv<X>,
    fn ln_f(&self, x: &X) -> f64 {

    fn f(&self, x: &X) -> f64 {

    fn draw<R: Rng>(&self, mut rng: &mut R) -> X {
        self.deref().draw(&mut rng)

    fn sample<R: Rng>(&self, n: usize, mut rng: &mut R) -> Vec<X> {
        self.deref().sample(n, &mut rng)

/// Identifies the support of the Rv
///
/// The support is the set of values `x` at which the distribution is defined.
pub trait Support<X> {
    /// Returns `true` if `x` is in the support of the `Rv`
    ///
    /// # Example
    /// ```
    /// use rv::dist::Uniform;
    /// use rv::traits::Support;
    /// // Create uniform with support on the interval [0, 1]
    /// let u = Uniform::new(0.0, 1.0).unwrap();
    /// assert!(u.supports(&0.5_f64));
    /// assert!(!u.supports(&-0.1_f64));
    /// assert!(!u.supports(&1.1_f64));
    /// ```
    fn supports(&self, x: &X) -> bool;
}

impl<Fx, X> Support<X> for Fx
    Fx: std::ops::Deref,
    Fx::Target: Support<X>,
    fn supports(&self, x: &X) -> bool {

/// Is a continuous probability distributions
/// This trait uses the `Rv<X>` and `Support<X>` implementations to implement
/// itself.
pub trait ContinuousDistr<X>: Rv<X> + Support<X> {
    /// The value of the Probability Density Function (PDF) at `x`
    /// # Panics
    /// If `x` is not in the support.
    /// # Example
    /// Compute the Gaussian PDF, f(x)
    /// ```
    /// use rv::dist::Gaussian;
    /// use rv::traits::ContinuousDistr;
    /// let g = Gaussian::standard();
    /// let f_mean = g.pdf(&0.0_f64);
    /// let f_low = g.pdf(&-1.0_f64);
    /// let f_high = g.pdf(&1.0_f64);
    /// assert!(f_mean > f_low);
    /// assert!(f_mean > f_high);
    /// assert!((f_low - f_high).abs() < 1E-12);
    /// ```
    fn pdf(&self, x: &X) -> f64 {

    /// The value of the log Probability Density Function (PDF) at `x`
    /// # Panics
    /// If `x` is not in the support.
    /// # Example
    /// Compute the natural logarithm of the Gaussian PDF, ln(f(x))
    /// ```
    /// use rv::dist::Gaussian;
    /// use rv::traits::ContinuousDistr;
    /// let g = Gaussian::standard();
    /// let lnf_mean = g.ln_pdf(&0.0_f64);
    /// let lnf_low = g.ln_pdf(&-1.0_f64);
    /// let lnf_high = g.ln_pdf(&1.0_f64);
    /// assert!(lnf_mean > lnf_low);
    /// assert!(lnf_mean > lnf_high);
    /// assert!((lnf_low - lnf_high).abs() < 1E-12);
    /// ```
    fn ln_pdf(&self, x: &X) -> f64 {
        if !self.supports(&x) {
            panic!("x not in support");

impl<Fx, X> ContinuousDistr<X> for Fx
    Fx: std::ops::Deref,
    Fx::Target: ContinuousDistr<X>,

/// Has a cumulative distribution function (CDF)
pub trait Cdf<X>: Rv<X> {
    /// The value of the Cumulative Density Function at `x`
    /// # Example
    /// The proportion of probability in (-∞, μ) in N(μ, σ) is 50%
    /// ```
    /// use rv::dist::Gaussian;
    /// use rv::traits::Cdf;
    /// let g = Gaussian::new(1.0, 1.5).unwrap();
    /// assert!((g.cdf(&1.0_f64) - 0.5).abs() < 1E-12);
    /// ```
    fn cdf(&self, x: &X) -> f64;

    /// Survival function, `1 - CDF(x)`
    fn sf(&self, x: &X) -> f64 {
        1.0 - self.cdf(x)

impl<Fx, X> Cdf<X> for Fx
    Fx: std::ops::Deref,
    Fx::Target: Cdf<X>,
    fn cdf(&self, x: &X) -> f64 {

    fn sf(&self, x: &X) -> f64 {

/// Has an inverse-CDF / quantile function
pub trait InverseCdf<X>: Rv<X> + Support<X> {
    /// The value of the `x` at the given probability in the CDF
    /// # Example
    /// The CDF identity: p = CDF(x) => x = CDF<sup>-1</sup>(p)
    /// ```
    /// use rv::dist::Gaussian;
    /// use rv::traits::Cdf;
    /// use rv::traits::InverseCdf;
    /// let g = Gaussian::standard();
    /// let x: f64 = 1.2;
    /// let p: f64 = g.cdf(&x);
    /// let y: f64 = g.invcdf(p);
    /// // x and y should be about the same
    /// assert!((x - y).abs() < 1E-12);
    /// ```
    fn invcdf(&self, p: f64) -> X;

    /// Alias for `invcdf`
    fn quantile(&self, p: f64) -> X {

    /// Interval containing `p` proportion for the probability
    /// # Example
    /// Confidence interval
    /// ```
    /// use rv::dist::Gaussian;
    /// use rv::traits::InverseCdf;
    /// let g = Gaussian::new(100.0, 15.0).unwrap();
    /// let ci: (f64, f64) = g.interval(0.68268949213708585);  // one stddev
    /// assert!( (ci.0 - 85.0).abs() < 1E-12);
    /// assert!( (ci.1 - 115.0).abs() < 1E-12);
    /// ```
    fn interval(&self, p: f64) -> (X, X) {
        let pt = (1.0 - p) / 2.0;
        (self.quantile(pt), self.quantile(p + pt))

impl<Fx, X> InverseCdf<X> for Fx
    Fx: std::ops::Deref,
    Fx::Target: InverseCdf<X>,
    fn invcdf(&self, p: f64) -> X {

    fn quantile(&self, p: f64) -> X {

    fn interval(&self, p: f64) -> (X, X) {

/// Is a discrete probability distribution
pub trait DiscreteDistr<X>: Rv<X> + Support<X> {
    /// Probability mass function (PMF) at `x`
    /// # Panics
    /// If `x` is not supported
    /// # Example
    /// The probability of a fair coin coming up heads in 0.5
    /// ```
    /// use rv::dist::Bernoulli;
    /// use rv::traits::DiscreteDistr;
    /// // Fair coin (p = 0.5)
    /// let b = Bernoulli::uniform();
    /// assert!( (b.pmf(&true) - 0.5).abs() < 1E-12);
    /// ```
    fn pmf(&self, x: &X) -> f64 {

    /// Natural logarithm of the probability mass function (PMF)
    /// # Panics
    /// If `x` is not supported
    /// # Example
    /// The probability of a fair coin coming up heads in 0.5
    /// ```
    /// use rv::dist::Bernoulli;
    /// use rv::traits::DiscreteDistr;
    /// // Fair coin (p = 0.5)
    /// let b = Bernoulli::uniform();
    /// assert!( (b.ln_pmf(&true) - 0.5_f64.ln()).abs() < 1E-12);
    /// ```
    fn ln_pmf(&self, x: &X) -> f64 {
        if !self.supports(&x) {
            panic!("x not in support");

impl<Fx, X> DiscreteDistr<X> for Fx
    Fx: std::ops::Deref,
    Fx::Target: DiscreteDistr<X>,

/// Defines the distribution mean
pub trait Mean<X> {
    /// Returns `None` if the mean is undefined
    fn mean(&self) -> Option<X>;
}

impl<Fx, X> Mean<X> for Fx
    Fx: std::ops::Deref,
    Fx::Target: Mean<X>,
    fn mean(&self) -> Option<X> {

/// Defines the distribution median
pub trait Median<X> {
    /// Returns `None` if the median is undefined
    fn median(&self) -> Option<X>;
}

impl<Fx, X> Median<X> for Fx
    Fx: std::ops::Deref,
    Fx::Target: Median<X>,
    fn median(&self) -> Option<X> {

/// Defines the distribution mode
pub trait Mode<X> {
    /// Returns `None` if the mode is undefined or is not a single value
    fn mode(&self) -> Option<X>;
}

impl<Fx, X> Mode<X> for Fx
    Fx: std::ops::Deref,
    Fx::Target: Mode<X>,
    fn mode(&self) -> Option<X> {

/// Defines the distribution variance
pub trait Variance<X> {
    /// Returns `None` if the variance is undefined
    fn variance(&self) -> Option<X>;
}

impl<Fx, X> Variance<X> for Fx
    Fx: std::ops::Deref,
    Fx::Target: Variance<X>,
    fn variance(&self) -> Option<X> {

/// Defines the distribution entropy
pub trait Entropy {
    /// The entropy, *H(X)*
    fn entropy(&self) -> f64;
}

impl<Fx> Entropy for Fx
    Fx: std::ops::Deref,
    Fx::Target: Entropy,
    fn entropy(&self) -> f64 {

/// Defines the distribution skewness
pub trait Skewness {
    /// Returns `None` if the skewness is undefined
    fn skewness(&self) -> Option<f64>;
}

impl<Fx> Skewness for Fx
    Fx: std::ops::Deref,
    Fx::Target: Skewness,
    fn skewness(&self) -> Option<f64> {

/// Defines the distribution kurtosis
///
/// NOTE(review): confirm whether implementors return excess kurtosis
/// (kurtosis - 3) or raw kurtosis; this trait does not constrain it.
pub trait Kurtosis {
    /// Returns `None` if the kurtosis is undefined
    fn kurtosis(&self) -> Option<f64>;
}

impl<Fx> Kurtosis for Fx
    Fx: std::ops::Deref,
    Fx::Target: Kurtosis,
    fn kurtosis(&self) -> Option<f64> {

/// KL divergences
pub trait KlDivergence {
    /// The KL divergence, KL(P|Q) between this distribution, P, and another, Q
    ///
    /// # Example
    /// ```
    /// use rv::dist::Gaussian;
    /// use rv::traits::KlDivergence;
    /// let g1 = Gaussian::new(1.0, 1.0).unwrap();
    /// let g2 = Gaussian::new(-1.0, 2.0).unwrap();
    /// let kl_self = g1.kl(&g1);
    /// let kl_other = g1.kl(&g2);
    /// // KL(P|P) = 0
    /// assert!( kl_self < 1E-12 );
    /// // KL(P|Q) > 0 if P ≠ Q
    /// assert!( kl_self < kl_other );
    /// ```
    fn kl(&self, other: &Self) -> f64;

    /// Symmetrised divergence, KL(P|Q) + KL(Q|P)
    ///
    /// # Example
    /// ```
    /// use rv::dist::Gaussian;
    /// use rv::traits::KlDivergence;
    /// let g1 = Gaussian::new(1.0, 1.0).unwrap();
    /// let g2 = Gaussian::new(-1.0, 2.0).unwrap();
    /// let kl_12 = g1.kl(&g2);
    /// let kl_21 = g2.kl(&g1);
    /// let kl_sym = g1.kl_sym(&g2);
    /// assert!( (kl_12 + kl_21 - kl_sym).abs() < 1E-10 );
    /// ```
    fn kl_sym(&self, other: &Self) -> f64 {
        self.kl(other) + other.kl(self)
    }
}

impl<Fx> KlDivergence for Fx
    Fx: std::ops::Deref,
    Fx::Target: KlDivergence,
    fn kl(&self, other: &Self) -> f64 {

    fn kl_sym(&self, other: &Self) -> f64 {

/// The data for this distribution can be summarized by a statistic
pub trait HasSuffStat<X> {
    type Stat: SuffStat<X>;
    fn empty_suffstat(&self) -> Self::Stat;

impl<Fx, X> HasSuffStat<X> for Fx
    Fx: std::ops::Deref,
    Fx::Target: HasSuffStat<X>,
    type Stat = <<Fx as std::ops::Deref>::Target as HasSuffStat<X>>::Stat;

    fn empty_suffstat(&self) -> Self::Stat {

/// Is a [sufficient statistic](https://en.wikipedia.org/wiki/Sufficient_statistic)
/// for a distribution.
///
/// # Examples
/// Basic suffstat usage.
/// ```
/// use rv::data::BernoulliSuffStat;
/// use rv::traits::SuffStat;
/// // Bernoulli sufficient statistics are the number of observations, n, and
/// // the number of successes, k.
/// let mut stat = BernoulliSuffStat::new();
/// assert!(stat.n() == 0 && stat.k() == 0);
/// stat.observe(&true);  // observe `true`
/// assert!(stat.n() == 1 && stat.k() == 1);
/// stat.observe(&false);  // observe `false`
/// assert!(stat.n() == 2 && stat.k() == 1);
/// stat.forget_many(&[false, true]);  // forget `true` and `false`
/// assert!(stat.n() == 0 && stat.k() == 0);
/// ```
/// Conjugate analysis of coin flips using Bernoulli with a Beta prior on the
/// success probability.
/// ```
/// use rv::traits::SuffStat;
/// use rv::traits::ConjugatePrior;
/// use rv::data::BernoulliSuffStat;
/// use rv::dist::{Bernoulli, Beta};
/// let flips = vec![true, false, false];
/// // Pack the data into a sufficient statistic that holds the number of
/// // trials and the number of successes
/// let mut stat = BernoulliSuffStat::new();
/// stat.observe_many(&flips);
/// let prior = Beta::jeffreys();
/// // If we observe more false than true, the posterior predictive
/// // probability of true decreases.
/// let pp_no_obs = prior.pp(&true, &(&BernoulliSuffStat::new()).into());
/// let pp_obs = prior.pp(&true, &(&flips).into());
/// assert!(pp_obs < pp_no_obs);
/// ```
pub trait SuffStat<X> {
    /// Returns the number of observations
    fn n(&self) -> usize;

    /// Assimilate the datum `x` into the statistic
    fn observe(&mut self, x: &X);

    /// Remove the datum `x` from the statistic
    fn forget(&mut self, x: &X);

    /// Assimilate several observations, one `observe` call per element
    fn observe_many(&mut self, xs: &[X]) {
        xs.iter().for_each(|x| self.observe(x));
    }

    /// Forget several observations, one `forget` call per element
    fn forget_many(&mut self, xs: &[X]) {
        xs.iter().for_each(|x| self.forget(x));
    }
}

impl<S, X> SuffStat<X> for S
    S: std::ops::DerefMut,
    S::Target: SuffStat<X>,
    fn n(&self) -> usize {

    fn observe(&mut self, x: &X) {

    fn forget(&mut self, x: &X) {

    fn observe_many(&mut self, xs: &[X]) {

    fn forget_many(&mut self, xs: &[X]) {

/// A prior on `Fx` that induces a posterior that is the same form as the prior
/// # Example
/// Conjugate analysis of coin flips using Bernoulli with a Beta prior on the
/// success probability.
/// ```
/// use rv::traits::ConjugatePrior;
/// use rv::dist::{Bernoulli, Beta};
/// let flips = vec![true, false, false];
/// let prior = Beta::jeffreys();
/// // If we observe more false than true, the posterior predictive
/// // probability of true decreases.
/// let pp_no_obs = prior.pp(&true, &(&vec![]).into());
/// let pp_obs = prior.pp(&true, &(&flips).into());
/// assert!(pp_obs < pp_no_obs);
/// ```
pub trait ConjugatePrior<X, Fx>: Rv<Fx>
    Fx: Rv<X> + HasSuffStat<X>,
    type Posterior: Rv<Fx>;

    /// Computes the posterior distribution from the data
    fn posterior(&self, x: &DataOrSuffStat<X, Fx>) -> Self::Posterior;

    /// Log marginal likelihood
    fn ln_m(&self, x: &DataOrSuffStat<X, Fx>) -> f64;

    /// Log posterior predictive of y given x
    fn ln_pp(&self, y: &X, x: &DataOrSuffStat<X, Fx>) -> f64;

    /// Marginal likelihood of x
    fn m(&self, x: &DataOrSuffStat<X, Fx>) -> f64 {

    /// Posterior Predictive distribution
    fn pp(&self, y: &X, x: &DataOrSuffStat<X, Fx>) -> f64 {
        self.ln_pp(&y, x).exp()

/// Get the quad bounds of a univariate real distribution
///
/// NOTE(review): "quad" presumably refers to numerical quadrature, i.e. a
/// finite range over which the distribution can be integrated numerically.
/// Confirm against implementors.
pub trait QuadBounds {
    /// Returns the bounds as a pair of `f64` — presumably `(lower, upper)`;
    /// TODO confirm ordering with implementors.
    fn quad_bounds(&self) -> (f64, f64);