use super::DistributionClass;
use crate::distributions::Distribution;
use num::Complex;
use RustQuant_error::RustQuantError;
/// Uniform distribution on the closed interval `[a, b]`.
///
/// The [`DistributionClass`] selects between the continuous uniform
/// distribution and the discrete uniform distribution. In the discrete case
/// the bounds are rounded to integers on construction.
pub struct Uniform {
    /// Lower bound of the support.
    a: f64,
    /// Upper bound of the support (`a <= b`, enforced by the constructor).
    b: f64,
    /// Whether this is the discrete or the continuous uniform distribution.
    class: DistributionClass,
}
impl Uniform {
    /// Creates a uniform distribution on `[a, b]`.
    ///
    /// For [`DistributionClass::Discrete`] both bounds are rounded to the
    /// nearest integer. For [`DistributionClass::Continuous`] they are stored
    /// unchanged.
    ///
    /// # Panics
    ///
    /// Panics unless `a <= b`. This includes the case where either bound is NaN.
    #[must_use]
    pub fn new(a: f64, b: f64, class: DistributionClass) -> Self {
        assert!(a <= b);

        // Rounding is monotone, so `lower <= upper` still holds afterwards.
        let (lower, upper) = match class {
            DistributionClass::Discrete => (a.round(), b.round()),
            DistributionClass::Continuous => (a, b),
        };

        Self {
            a: lower,
            b: upper,
            class,
        }
    }
}
impl Distribution for Uniform {
    /// Characteristic function `φ(t) = E[exp(i t X)]`.
    ///
    /// * Discrete: `(e^{iat} - e^{i(b+1)t}) / (n (1 - e^{it}))`, with `n = b - a + 1`.
    /// * Continuous: `(e^{ibt} - e^{iat}) / (i t (b - a))`.
    ///
    /// Both closed forms evaluate to `0 / 0` at `t = 0`. Every characteristic
    /// function satisfies `φ(0) = 1`, so that point is returned explicitly.
    fn cf(&self, t: f64) -> Complex<f64> {
        if t == 0.0 {
            return Complex::new(1.0, 0.0);
        }

        let i: Complex<f64> = Complex::i();

        match self.class {
            DistributionClass::Discrete => {
                ((i * t * self.a).exp() - (i * t * (self.b + 1.0)).exp())
                    / ((1.0 - (i * t).exp()) * (self.b - self.a + 1.0))
            }
            DistributionClass::Continuous => {
                ((i * t * self.b).exp() - (i * t * self.a).exp()) / (i * t * (self.b - self.a))
            }
        }
    }

    /// Density at `x`.
    ///
    /// Returns `1 / (b - a + 1)` (discrete) or `1 / (b - a)` (continuous) when
    /// `x` lies in `[a, b]`. Returns `0` otherwise, including for NaN.
    ///
    /// NOTE(review): the discrete branch does not check that `x` is an integer.
    /// For example, `pdf(0.5)` on `{0, 1}` returns `0.5`. The existing unit
    /// tests expect this through `pmf`.
    fn pdf(&self, x: f64) -> f64 {
        if !(self.a..=self.b).contains(&x) {
            return 0.0;
        }

        match self.class {
            DistributionClass::Discrete => (self.b - self.a + 1.0).recip(),
            DistributionClass::Continuous => (self.b - self.a).recip(),
        }
    }

    /// Probability mass at `x`. For this distribution it is identical to `pdf`.
    fn pmf(&self, x: f64) -> f64 {
        self.pdf(x)
    }

    /// Cumulative distribution function `F(x) = P(X <= x)`.
    ///
    /// * Discrete: `(floor(x) - a + 1) / (b - a + 1)` for `a <= x < b`.
    /// * Continuous: `(x - a) / (b - a)` for `a <= x < b`.
    ///
    /// The result is `0` below `a` and `1` from `b` upwards. A NaN `x` propagates to NaN.
    fn cdf(&self, x: f64) -> f64 {
        if x < self.a {
            return 0.0;
        }
        // Checking the upper bound before the formula means that a degenerate
        // continuous distribution (`a == b`) yields 1 instead of `0 / 0`.
        if x >= self.b {
            return 1.0;
        }

        // Reached only for a <= x < b, or for NaN (which flows through as NaN).
        match self.class {
            DistributionClass::Discrete => (x.floor() - self.a + 1.0) / (self.b - self.a + 1.0),
            DistributionClass::Continuous => (x - self.a) / (self.b - self.a),
        }
    }

    /// Quantile function, the generalised inverse of the CDF.
    ///
    /// * Discrete: the smallest support point `k` with `F(k) >= p`. This is
    ///   `a + ceil(p n) - 1`, clamped to `[a, b]`, where `n = b - a + 1`.
    /// * Continuous: `a + p (b - a)`.
    ///
    /// # Panics
    ///
    /// Panics if `p` is not in `[0, 1]`.
    fn inv_cdf(&self, p: f64) -> f64 {
        assert!((0.0..=1.0).contains(&p));

        match self.class {
            DistributionClass::Discrete => {
                let n = self.b - self.a + 1.0;
                // Clamping maps p = 0 to `a` and absorbs rounding error in `p * n`.
                (self.a + (p * n).ceil() - 1.0).clamp(self.a, self.b)
            }
            DistributionClass::Continuous => self.a + p * (self.b - self.a),
        }
    }

    /// Mean: `(a + b) / 2`.
    fn mean(&self) -> f64 {
        0.5 * (self.a + self.b)
    }

    /// Median: `(a + b) / 2`.
    ///
    /// In the discrete case with an even number of support points, this is the
    /// midpoint of the two central values.
    fn median(&self) -> f64 {
        0.5 * (self.a + self.b)
    }

    /// Mode.
    ///
    /// Every point of `[a, b]` is a mode of the continuous distribution, and
    /// the midpoint is returned. The discrete case is not implemented and panics.
    fn mode(&self) -> f64 {
        match self.class {
            DistributionClass::Discrete => todo!(),
            DistributionClass::Continuous => 0.5 * (self.a + self.b),
        }
    }

    /// Variance.
    ///
    /// * Discrete: `(n^2 - 1) / 12`, with `n = b - a + 1`.
    /// * Continuous: `(b - a)^2 / 12`.
    fn variance(&self) -> f64 {
        match self.class {
            DistributionClass::Discrete => ((self.b - self.a + 1.0).powi(2) - 1.0) / 12.0,
            DistributionClass::Continuous => (self.b - self.a).powi(2) / 12.0,
        }
    }

    /// Skewness. It is zero because the distribution is symmetric about its midpoint.
    fn skewness(&self) -> f64 {
        0.0
    }

    /// Excess kurtosis.
    ///
    /// * Discrete: `-6 (n^2 + 1) / (5 (n^2 - 1))`, with `n = b - a + 1`.
    ///   This is not finite when `a == b`, because then `n = 1`.
    /// * Continuous: `-6 / 5`.
    fn kurtosis(&self) -> f64 {
        match self.class {
            DistributionClass::Discrete => {
                let n = self.b - self.a + 1.0;
                -(6.0 * (n * n + 1.0)) / (5.0 * (n * n - 1.0))
            }
            DistributionClass::Continuous => -6.0 / 5.0,
        }
    }

    /// Entropy in nats.
    ///
    /// This is `ln(b - a + 1)` for the discrete case. For the continuous case it is
    /// the differential entropy `ln(b - a)`.
    fn entropy(&self) -> f64 {
        match self.class {
            DistributionClass::Discrete => (self.b - self.a + 1.0).ln(),
            DistributionClass::Continuous => (self.b - self.a).ln(),
        }
    }

    /// Moment generating function `M(t) = E[exp(t X)]`.
    ///
    /// It is evaluated as:
    /// * Discrete: `e^{ta} (e^{tn} - 1) / (n (e^t - 1))`, with `n = b - a + 1`.
    /// * Continuous: `e^{ta} (e^{t(b-a)} - 1) / (t (b - a))`.
    ///
    /// These are algebraically equal to the textbook forms
    /// `(e^{ta} - e^{t(b+1)}) / (n (1 - e^t))` and `(e^{tb} - e^{ta}) / (t (b - a))`.
    /// Using `exp_m1` avoids catastrophic cancellation for small `|t|`. Both forms
    /// are `0 / 0` at `t = 0`, so `M(0) = 1` is returned directly.
    fn mgf(&self, t: f64) -> f64 {
        if t == 0.0 {
            return 1.0;
        }

        match self.class {
            DistributionClass::Discrete => {
                let n = self.b - self.a + 1.0;
                (t * self.a).exp() * (t * n).exp_m1() / (n * t.exp_m1())
            }
            DistributionClass::Continuous => {
                let width = self.b - self.a;
                (t * self.a).exp() * (t * width).exp_m1() / (t * width)
            }
        }
    }

    /// Draws `n` independent variates using the thread-local RNG.
    ///
    /// * Discrete: integers drawn uniformly from the inclusive range `a..=b`.
    ///   The bounds are already integral because the constructor rounds them.
    /// * Continuous: values drawn from rand's `Uniform::new(a, b)`.
    ///
    /// # Panics
    ///
    /// Panics if `n == 0`.
    ///
    /// NOTE(review): rand's `Uniform::new` is documented to require `low < high`.
    /// A degenerate continuous distribution (`a == b`) is therefore expected to
    /// panic here. Confirm this against the rand version in use.
    fn sample(&self, n: usize) -> Result<Vec<f64>, RustQuantError> {
        use rand::thread_rng;
        use rand_distr::{Distribution, Uniform};

        assert!(n > 0);

        let mut rng = thread_rng();

        let variates: Vec<f64> = match self.class {
            DistributionClass::Discrete => {
                // Use an inclusive integer range so that `b` can be drawn and
                // negative supports are preserved. The previous
                // `[a, b) as usize` approach excluded `b` and mapped every
                // negative value to 0. `a` and `b` are integral, so the casts
                // are exact within the i64 range.
                let dist: Uniform<i64> = Uniform::new_inclusive(self.a as i64, self.b as i64);
                (0..n).map(|_| dist.sample(&mut rng) as f64).collect()
            }
            DistributionClass::Continuous => {
                let dist: Uniform<f64> = Uniform::new(self.a, self.b);
                (0..n).map(|_| dist.sample(&mut rng)).collect()
            }
        };

        Ok(variates)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use RustQuant_utils::{assert_approx_equal, RUSTQUANT_EPSILON as EPS};

    /// Continuous `U(0, 1)`: CF at `t = 1`, plus PMF and CDF at `x = 0.5`.
    #[test]
    fn test_uniform_distribution_continuous() {
        let uniform = Uniform::new(0.0, 1.0, DistributionClass::Continuous);

        // φ(1) = sin(1) + i (1 - cos(1)).
        let phi = uniform.cf(1.0);
        assert_approx_equal!(phi.re, 0.841_470_984_807_896_5, EPS);
        assert_approx_equal!(phi.im, 0.459_697_694_131_860_23, EPS);

        let mass = uniform.pmf(0.5);
        assert_approx_equal!(mass, 1.0, EPS);

        let cumulative = uniform.cdf(0.5);
        assert_approx_equal!(cumulative, 0.5, EPS);
    }

    /// Discrete `U{0, 1}`: CF at `t = 1`, plus PMF and CDF at `x = 0.5`.
    #[test]
    fn test_uniform_distribution_discrete() {
        let uniform = Uniform::new(0.0, 1.0, DistributionClass::Discrete);

        // φ(1) = (1 + e^{i}) / 2 = (1 + cos(1)) / 2 + i sin(1) / 2.
        let phi = uniform.cf(1.0);
        assert_approx_equal!(phi.re, 0.770_151_152_934_069_9, EPS);
        assert_approx_equal!(phi.im, 0.420_735_492_403_948_36, EPS);

        let mass = uniform.pmf(0.5);
        assert_approx_equal!(mass, 0.5, EPS);

        let cumulative = uniform.cdf(0.5);
        assert_approx_equal!(cumulative, 0.5, EPS);
    }
}