Documentation
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};

use crate::traits::SuffStat;

/// `VonMises` sufficient statistic.
///
/// Holds the number of observations together with the sum of the sines and
/// the sum of the cosines of the observed values (angles). The count and
/// these two sums are what make this a sufficient statistic for the von
/// Mises distribution.
#[derive(Debug, Clone, PartialEq, Copy)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct VonMisesSuffStat {
    /// Number of observations
    n: usize,
    /// ∑ⱼ sin(xⱼ)
    sum_sin: f64,
    /// ∑ⱼ cos(xⱼ)
    sum_cos: f64,
}

impl VonMisesSuffStat {
    #[inline]
    #[must_use]
    pub fn new() -> Self {
        VonMisesSuffStat {
            n: 0,
            sum_sin: 0.0,
            sum_cos: 0.0,
        }
    }

    /// Create a sufficient statistic from components without checking whether
    /// they are valid.
    #[inline]
    #[must_use]
    pub fn from_parts_unchecked(n: usize, sum_cos: f64, sum_sin: f64) -> Self {
        VonMisesSuffStat {
            n,
            sum_sin,
            sum_cos,
        }
    }

    /// Create a sufficient statistic from a slice of data
    #[must_use]
    pub fn from_data(xs: &[f64]) -> Self {
        let mut stat = VonMisesSuffStat::new();
        for x in xs {
            stat.observe(x);
        }
        stat
    }

    /// Get the number of observations
    #[inline]
    #[must_use]
    pub fn n(&self) -> usize {
        self.n
    }

    /// Get the sum of cosines
    #[inline]
    #[must_use]
    pub fn sum_cos(&self) -> f64 {
        self.sum_cos
    }

    /// Get the sum of sines
    #[inline]
    #[must_use]
    pub fn sum_sin(&self) -> f64 {
        self.sum_sin
    }
}

impl Default for VonMisesSuffStat {
    fn default() -> Self {
        Self::new()
    }
}

impl From<&Vec<f64>> for VonMisesSuffStat {
    fn from(xs: &Vec<f64>) -> Self {
        Self::from_data(xs)
    }
}

impl From<&[f64]> for VonMisesSuffStat {
    fn from(xs: &[f64]) -> Self {
        Self::from_data(xs)
    }
}

impl<const N: usize> From<&[f64; N]> for VonMisesSuffStat {
    fn from(xs: &[f64; N]) -> Self {
        Self::from_data(xs)
    }
}

impl SuffStat<f64> for VonMisesSuffStat {
    fn n(&self) -> usize {
        self.n
    }

    /// Add the value `x` (an angle) to the statistic.
    fn observe(&mut self, x: &f64) {
        let (sin_x, cos_x) = x.sin_cos();
        self.sum_sin += sin_x;
        self.sum_cos += cos_x;
        self.n += 1;
    }

    /// Remove a previously observed value `x` from the statistic.
    ///
    /// When the last observation is removed the sums are reset to exactly
    /// zero, so floating-point residue left by repeated add/subtract cycles
    /// cannot leak into an otherwise empty statistic.
    ///
    /// # Panics
    ///
    /// Panics if the statistic holds no observations.
    fn forget(&mut self, x: &f64) {
        // `checked_sub` makes forgetting from an empty statistic panic in every
        // build profile instead of silently wrapping `n` to `usize::MAX` in
        // release builds.
        self.n = self
            .n
            .checked_sub(1)
            .expect("cannot forget an observation from an empty VonMisesSuffStat");

        if self.n == 0 {
            // An empty statistic has zero sums by definition.
            self.sum_sin = 0.0;
            self.sum_cos = 0.0;
        } else {
            let (sin_x, cos_x) = x.sin_cos();
            self.sum_sin -= sin_x;
            self.sum_cos -= cos_x;
        }
    }

    /// Fold the observations summarized by `other` into `self`.
    fn merge(&mut self, other: Self) {
        // Merging an empty statistic is a no-op.
        if other.n == 0 {
            return;
        }
        self.n += other.n;
        self.sum_sin += other.sum_sin;
        self.sum_cos += other.sum_cos;
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_suffstat_has_zero_n() {
        let stat = VonMisesSuffStat::new();
        assert_eq!(stat.n(), 0);
    }

    #[test]
    fn default_equals_new() {
        assert_eq!(VonMisesSuffStat::default(), VonMisesSuffStat::new());
    }

    #[test]
    fn from_parts_unchecked_orders_cos_before_sin() {
        let stat = VonMisesSuffStat::from_parts_unchecked(2, 1.5, -0.5);
        assert_eq!(stat.n(), 2);
        assert_eq!(stat.sum_cos(), 1.5);
        assert_eq!(stat.sum_sin(), -0.5);
    }

    #[test]
    fn observe_increments_n() {
        let mut stat = VonMisesSuffStat::new();
        stat.observe(&1.0);
        assert_eq!(stat.n(), 1);
    }

    #[test]
    fn forget_decrements_n() {
        let mut stat = VonMisesSuffStat::new();
        stat.observe(&1.0);
        stat.forget(&1.0);
        assert_eq!(stat.n(), 0);
    }

    #[test]
    fn forget_restores_sums() {
        let mut stat = VonMisesSuffStat::new();
        stat.observe(&0.7);
        stat.observe(&2.3);
        stat.forget(&2.3);
        assert_eq!(stat.n(), 1);
        assert::close(stat.sum_sin(), 0.7_f64.sin(), 1e-14);
        assert::close(stat.sum_cos(), 0.7_f64.cos(), 1e-14);
    }

    #[test]
    fn merge_adds_n() {
        let mut stat1 = VonMisesSuffStat::new();
        let mut stat2 = VonMisesSuffStat::new();
        stat1.observe(&1.0);
        stat2.observe(&2.0);
        stat1.merge(stat2);
        assert_eq!(stat1.n(), 2);
    }

    #[test]
    fn merge_adds_sums() {
        let mut stat1 = VonMisesSuffStat::from_data(&[1.0]);
        let stat2 = VonMisesSuffStat::from_data(&[2.0]);
        stat1.merge(stat2);
        assert::close(stat1.sum_sin(), 1.0_f64.sin() + 2.0_f64.sin(), 1e-14);
        assert::close(stat1.sum_cos(), 1.0_f64.cos() + 2.0_f64.cos(), 1e-14);
    }

    #[test]
    fn merge_empty_stat_does_nothing() {
        let mut stat1 = VonMisesSuffStat::new();
        let stat2 = VonMisesSuffStat::new();
        stat1.observe(&1.0);
        let before = stat1;
        stat1.merge(stat2);
        assert_eq!(stat1.n(), 1);
        // The sums must be untouched as well, not just the count.
        assert_eq!(stat1, before);
    }

    #[test]
    fn from_data_empty_vec() {
        let data: Vec<f64> = vec![];
        let stat = VonMisesSuffStat::from_data(&data);
        assert_eq!(stat.n(), 0);
    }

    #[test]
    fn from_empty_vec() {
        let data: Vec<f64> = vec![];
        let stat = VonMisesSuffStat::from(&data);
        assert_eq!(stat.n(), 0);
    }

    #[test]
    fn from_empty_slice() {
        let data: &[f64] = &[];
        let stat = VonMisesSuffStat::from(data);
        assert_eq!(stat.n(), 0);
    }

    #[test]
    fn from_vec() {
        let data = vec![0.0, std::f64::consts::PI / 2.0, std::f64::consts::PI];
        let stat = VonMisesSuffStat::from(&data);
        assert_eq!(stat.n(), 3);
        assert::close(stat.sum_cos(), 0.0, 1e-14); // cos(0) + cos(π/2) + cos(π) = 1 + 0 + (-1) = 0
        assert::close(stat.sum_sin(), 1.0, 1e-14); // sin(0) + sin(π/2) + sin(π) = 0 + 1 + 0 = 1
    }

    #[test]
    fn from_slice() {
        let data = [0.0, std::f64::consts::PI / 2.0, std::f64::consts::PI];
        let stat = VonMisesSuffStat::from(data.as_slice());
        assert_eq!(stat.n(), 3);
        assert::close(stat.sum_cos(), 0.0, 1e-14); // cos(0) + cos(π/2) + cos(π) = 1 + 0 + (-1) = 0
        assert::close(stat.sum_sin(), 1.0, 1e-14); // sin(0) + sin(π/2) + sin(π) = 0 + 1 + 0 = 1
    }

    #[test]
    fn from_array() {
        let data = [0.0, std::f64::consts::PI];
        let stat = VonMisesSuffStat::from(&data);
        assert_eq!(stat.n(), 2);
        assert::close(stat.sum_cos(), 0.0, 1e-14); // cos(0) + cos(π) = 1 + (-1) = 0
        assert::close(stat.sum_sin(), 0.0, 1e-14); // sin(0) + sin(π) = 0 + 0 = 0
    }
}