use std::convert::TryFrom;
use std::fmt::Debug;
use crate::feature::Component;
use lace_data::Datum;
use lace_data::SparseContainer;
use lace_stats::prior::csd::CsdHyper;
use lace_stats::prior::nix::NixHyper;
use lace_stats::prior::pg::PgHyper;
use lace_stats::rv::data::{
    BernoulliSuffStat, CategoricalDatum, CategoricalSuffStat, GaussianSuffStat,
    PoissonSuffStat,
};
use lace_stats::rv::dist::{
    Bernoulli, Beta, Categorical, Gamma, Gaussian, NormalInvChiSquared,
    Poisson, SymmetricDirichlet,
};
use lace_stats::rv::traits::{ConjugatePrior, HasSuffStat, Mode, Rv};
use lace_stats::UpdatePrior;
use serde::de::DeserializeOwned;
use serde::Serialize;
/// Log-likelihood scoring of the data held in a [`SparseContainer`].
pub trait AccumScore<X: Clone + Default>: Rv<X> + Sync {
    /// Scores the data in `container` under `self`, writing into `scores`.
    ///
    /// The work is delegated to `lace_data::AccumScore::accum_score` on the
    /// container, with `self.ln_f` as the per-datum log-likelihood. How the
    /// values land in `scores` (presumably added into the slot of each
    /// datum's row) is defined by `lace_data` — confirm there.
    fn accum_score(&self, scores: &mut [f64], container: &SparseContainer<X>) {
        use lace_data::AccumScore as _;
        let ln_f = |datum: &X| self.ln_f(datum);
        container.accum_score(scores, &ln_f)
    }
}
// Each supported likelihood uses the provided `accum_score`, which scores
// data through the distribution's own `Rv::ln_f`.
impl AccumScore<bool> for Bernoulli {}
impl AccumScore<f64> for Gaussian {}
impl AccumScore<u32> for Poisson {}
impl<X> AccumScore<X> for Categorical where
    X: CategoricalDatum + Default
{
}
/// Bound alias for the datum type `X` that likelihoods and priors in this
/// module are parameterized over.
///
/// A datum must be convertible from the generic [`Datum`], serializable in
/// both directions, defaultable, cloneable, debug-printable, and `Sync`.
pub trait LaceDatum:
    Clone
    + Debug
    + Default
    + Sync
    + Serialize
    + DeserializeOwned
    + TryFrom<Datum>
{
}
// Blanket impl: every type meeting the bounds is a `LaceDatum`
// automatically, so the trait is never implemented by hand.
impl<T> LaceDatum for T where
    T: Clone
        + Debug
        + Default
        + Sync
        + Serialize
        + DeserializeOwned
        + TryFrom<Datum>
{
}
/// Bound alias for statistic types: cloneable, comparable, debug-printable,
/// `Sync`, and serializable in both directions.
///
/// NOTE(review): presumably meant for sufficient statistics; it is not
/// referenced in the code visible here — confirm usage elsewhere.
pub trait LaceStat:
    Clone + Debug + PartialEq + Sync + Serialize + DeserializeOwned
{
}
// Every type satisfying the bounds gets `LaceStat` for free.
impl<S> LaceStat for S where
    S: Clone + Debug + PartialEq + Sync + Serialize + DeserializeOwned
{
}
/// Component likelihood over data of type `X`.
///
/// Bundles density evaluation ([`Rv`]), a mode ([`Mode`]), batch scoring of
/// a container ([`AccumScore`]), sufficient statistics ([`HasSuffStat`]),
/// serialization in both directions, and conversion into a [`Component`].
pub trait LaceLikelihood<X: LaceDatum>:
    AccumScore<X>
    + Clone
    + Debug
    + DeserializeOwned
    + HasSuffStat<X>
    + Into<Component>
    + Mode<X>
    + PartialEq
    + Rv<X>
    + Serialize
    + Sync
{
    /// Log-density evaluated at the distribution's mode.
    ///
    /// Returns `None` whenever `self.mode()` returns `None`.
    fn ln_f_max(&self) -> Option<f64> {
        let mode = self.mode()?;
        Some(self.ln_f(&mode))
    }
}
// Blanket impl: any type with the required capabilities, whose sufficient
// statistic is also thread-safe, cloneable, debuggable and serializable,
// is a `LaceLikelihood` automatically.
impl<X, Fx> LaceLikelihood<X> for Fx
where
    X: LaceDatum,
    Fx: AccumScore<X>
        + Clone
        + Debug
        + DeserializeOwned
        + HasSuffStat<X>
        + Into<Component>
        + Mode<X>
        + PartialEq
        + Rv<X>
        + Serialize
        + Sync,
    Fx::Stat: Clone + Debug + Sync + Serialize + DeserializeOwned,
{
}
/// Conjugate prior for likelihood `Fx` over data `X`, where `H` is the
/// hyperprior type handed to [`UpdatePrior`].
///
/// On top of the conjugate and update machinery of its supertraits, a prior
/// can build an empty sufficient statistic and a placeholder component, and
/// can score a column given its sufficient statistics.
pub trait LacePrior<X: LaceDatum, Fx: LaceLikelihood<X>, H>:
    Clone
    + Debug
    + Sync
    + Serialize
    + DeserializeOwned
    + ConjugatePrior<X, Fx>
    + UpdatePrior<X, Fx, H>
{
    /// An empty sufficient statistic of the type `Fx` uses.
    fn empty_suffstat(&self) -> Fx::Stat;

    /// A placeholder component that, as the name says, need not be valid.
    ///
    /// NOTE(review): presumably callers overwrite it before real use —
    /// confirm against call sites.
    fn invalid_temp_component(&self) -> Fx;

    /// Scores a column from an iterator of sufficient statistics (presumably
    /// one per component), returning a single `f64`.
    ///
    /// The implementations in this module return the sum, over `stats`, of
    /// each statistic's log marginal likelihood under `self`.
    fn score_column<I: Iterator<Item = Fx::Stat>>(&self, stats: I) -> f64;
}
impl LacePrior<u8, Categorical, CsdHyper> for SymmetricDirichlet {
    /// An empty categorical statistic sized for the prior's `k` categories.
    fn empty_suffstat(&self) -> CategoricalSuffStat {
        let n_cats = self.k();
        CategoricalSuffStat::new(n_cats)
    }

    /// A `k`-category placeholder built with `new_unchecked` from a vector
    /// of zeros; no validation is performed.
    fn invalid_temp_component(&self) -> Categorical {
        let zeros = vec![0.0; self.k()];
        Categorical::new_unchecked(zeros)
    }

    /// Sum over `stats` of the Dirichlet-categorical log marginal likelihood
    ///
    /// `ln_gamma(k*a) - ln_gamma(k*a + n) + sum_j ln_gamma(a + n_j)
    ///  - k*ln_gamma(a)`
    ///
    /// where `a` is the symmetric concentration, `n` the statistic's count
    /// and `n_j` its per-category counts. Terms that do not depend on the
    /// statistic are computed once, up front.
    fn score_column<I: Iterator<Item = CategoricalSuffStat>>(
        &self,
        stats: I,
    ) -> f64 {
        use special::Gamma as _;

        let alpha = self.alpha();
        let k = self.k() as f64;
        let total_alpha = alpha * k;
        let ln_gamma_total_alpha = total_alpha.ln_gamma().0;
        let k_ln_gamma_alpha = alpha.ln_gamma().0 * k;

        stats
            .map(|stat| {
                let ln_gamma_post_total =
                    (total_alpha + stat.n() as f64).ln_gamma().0;
                let mut ln_gamma_post_cats = 0.0;
                for &count in stat.counts().iter() {
                    ln_gamma_post_cats += (alpha + count).ln_gamma().0;
                }
                ln_gamma_total_alpha - ln_gamma_post_total + ln_gamma_post_cats
                    - k_ln_gamma_alpha
            })
            .sum::<f64>()
    }
}
/// Log normalizer of a Gamma(shape, rate) prior after updating it with the
/// Poisson statistic `stat`:
/// `ln_gamma(shape_n) - shape_n * ln(rate_n)`, where
/// `shape_n = shape + stat.sum()` and `rate_n = rate + stat.n()`.
#[inline]
fn poisson_zn(shape: f64, rate: f64, stat: &PoissonSuffStat) -> f64 {
    use special::Gamma as _;
    let post_shape = shape + stat.sum();
    let post_rate = rate + stat.n() as f64;
    // Fused multiply-add: post_shape * (-ln post_rate) + ln_gamma(post_shape)
    // with a single rounding step.
    post_shape.mul_add(-post_rate.ln(), post_shape.ln_gamma().0)
}
impl LacePrior<u32, Poisson, PgHyper> for Gamma {
    /// A Poisson statistic with no observations.
    fn empty_suffstat(&self) -> PoissonSuffStat {
        PoissonSuffStat::new()
    }

    /// Placeholder component: a Poisson with rate 1.
    fn invalid_temp_component(&self) -> Poisson {
        Poisson::new_unchecked(1.0)
    }

    /// Sum over `stats` of the Poisson-Gamma log marginal likelihood,
    /// `zn - z0 - sum(ln x!)`. Here `zn` is the log normalizer after the
    /// update and `z0` is the prior's own log normalizer.
    fn score_column<I: Iterator<Item = PoissonSuffStat>>(
        &self,
        stats: I,
    ) -> f64 {
        let (shape, rate) = (self.shape(), self.rate());
        // An empty statistic adds 0 to both parameters, so its "posterior"
        // normalizer is exactly the prior's.
        let z0 = poisson_zn(shape, rate, &PoissonSuffStat::new());
        stats
            .map(|stat| {
                poisson_zn(shape, rate, &stat) - z0 - stat.sum_ln_fact()
            })
            .sum::<f64>()
    }
}
impl LacePrior<bool, Bernoulli, ()> for Beta {
    /// A Bernoulli statistic with no observations.
    fn empty_suffstat(&self) -> BernoulliSuffStat {
        BernoulliSuffStat::new()
    }

    /// Placeholder component: `Bernoulli::uniform()`.
    fn invalid_temp_component(&self) -> Bernoulli {
        Bernoulli::uniform()
    }

    /// Sum over `stats` of the conjugate log marginal likelihood
    /// (`ln_m_with_cache`). A single `ln_m_cache` is shared by every
    /// statistic.
    fn score_column<I: Iterator<Item = BernoulliSuffStat>>(
        &self,
        stats: I,
    ) -> f64 {
        use lace_stats::rv::data::DataOrSuffStat;
        // Fully qualified so the `bool`/`Bernoulli` impl is picked without
        // ambiguity.
        let cache = <Beta as ConjugatePrior<bool, Bernoulli>>::ln_m_cache(self);
        let score = |stat: BernoulliSuffStat| {
            let x = DataOrSuffStat::SuffStat::<bool, Bernoulli>(&stat);
            self.ln_m_with_cache(&cache, &x)
        };
        stats.map(score).sum::<f64>()
    }
}
impl LacePrior<f64, Gaussian, NixHyper> for NormalInvChiSquared {
    /// A Gaussian statistic with no observations.
    fn empty_suffstat(&self) -> GaussianSuffStat {
        GaussianSuffStat::new()
    }

    /// Placeholder component: `Gaussian::standard()`.
    fn invalid_temp_component(&self) -> Gaussian {
        Gaussian::standard()
    }

    /// Sum over `stats` of the conjugate log marginal likelihood
    /// (`ln_m_with_cache`). A single `ln_m_cache` is shared by every
    /// statistic.
    fn score_column<I: Iterator<Item = GaussianSuffStat>>(
        &self,
        stats: I,
    ) -> f64 {
        use lace_stats::rv::data::DataOrSuffStat;
        let cache = self.ln_m_cache();
        let score = |stat: GaussianSuffStat| {
            let x = DataOrSuffStat::SuffStat(&stat);
            self.ln_m_with_cache(&cache, &x)
        };
        stats.map(score).sum::<f64>()
    }
}