#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
use crate::traits::SuffStat;
/// Sufficient statistic for von Mises–distributed angles.
///
/// Observed angles are summarised by their count together with the running
/// sums of their sines and cosines.
#[derive(Clone, Copy, Debug, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct VonMisesSuffStat {
    /// Number of observations currently accumulated.
    n: usize,
    /// Running sum of `sin(x)` over the accumulated observations.
    sum_sin: f64,
    /// Running sum of `cos(x)` over the accumulated observations.
    sum_cos: f64,
}
impl VonMisesSuffStat {
#[inline]
#[must_use]
pub fn new() -> Self {
VonMisesSuffStat {
n: 0,
sum_sin: 0.0,
sum_cos: 0.0,
}
}
#[inline]
#[must_use]
pub fn from_parts_unchecked(n: usize, sum_cos: f64, sum_sin: f64) -> Self {
VonMisesSuffStat {
n,
sum_sin,
sum_cos,
}
}
#[must_use]
pub fn from_data(xs: &[f64]) -> Self {
let mut stat = VonMisesSuffStat::new();
for x in xs {
stat.observe(x);
}
stat
}
#[inline]
#[must_use]
pub fn n(&self) -> usize {
self.n
}
#[inline]
#[must_use]
pub fn sum_cos(&self) -> f64 {
self.sum_cos
}
#[inline]
#[must_use]
pub fn sum_sin(&self) -> f64 {
self.sum_sin
}
}
impl Default for VonMisesSuffStat {
fn default() -> Self {
Self::new()
}
}
impl From<&Vec<f64>> for VonMisesSuffStat {
fn from(xs: &Vec<f64>) -> Self {
Self::from_data(xs)
}
}
impl From<&[f64]> for VonMisesSuffStat {
fn from(xs: &[f64]) -> Self {
Self::from_data(xs)
}
}
impl<const N: usize> From<&[f64; N]> for VonMisesSuffStat {
fn from(xs: &[f64; N]) -> Self {
Self::from_data(xs)
}
}
impl SuffStat<f64> for VonMisesSuffStat {
    /// Number of observations accumulated so far.
    fn n(&self) -> usize {
        self.n
    }

    /// Adds the angle `x` (radians, as taken by `f64::sin_cos`): its sine and
    /// cosine are added to the running sums and the count is incremented.
    fn observe(&mut self, x: &f64) {
        let (sin_x, cos_x) = x.sin_cos();
        self.sum_sin += sin_x;
        self.sum_cos += cos_x;
        self.n += 1;
    }

    /// Removes a previously observed angle `x` from the statistic.
    ///
    /// When the last observation is removed, the running sums are reset to
    /// exactly zero. Otherwise floating-point residue from the add/subtract
    /// round trips could remain, and an emptied statistic would not compare
    /// equal to `VonMisesSuffStat::new()`.
    ///
    /// # Panics
    ///
    /// Panics if the statistic is empty. Forgetting a datum that was never
    /// observed is a caller bug. Previously this underflowed `n`, which wraps
    /// to `usize::MAX` in release builds.
    fn forget(&mut self, x: &f64) {
        assert!(
            self.n > 0,
            "cannot forget an observation from an empty VonMisesSuffStat"
        );
        if self.n == 1 {
            // Last datum: restore the exact empty state.
            self.n = 0;
            self.sum_sin = 0.0;
            self.sum_cos = 0.0;
        } else {
            let (sin_x, cos_x) = x.sin_cos();
            self.sum_sin -= sin_x;
            self.sum_cos -= cos_x;
            self.n -= 1;
        }
    }

    /// Folds `other` into `self` by adding counts and trig sums.
    /// An `other` with zero observations is ignored entirely.
    fn merge(&mut self, other: Self) {
        if other.n == 0 {
            return;
        }
        self.n += other.n;
        self.sum_sin += other.sum_sin;
        self.sum_cos += other.sum_cos;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Absolute tolerance for comparing accumulated trig sums.
    const TOL: f64 = 1e-14;

    #[test]
    fn empty_suffstat_has_zero_n() {
        let stat = VonMisesSuffStat::new();
        assert_eq!(stat.n(), 0);
    }

    #[test]
    fn default_equals_new() {
        assert_eq!(VonMisesSuffStat::default(), VonMisesSuffStat::new());
    }

    #[test]
    fn from_parts_unchecked_takes_cos_before_sin() {
        let stat = VonMisesSuffStat::from_parts_unchecked(3, 0.5, -0.25);
        assert_eq!(stat.n(), 3);
        assert::close(stat.sum_cos(), 0.5, TOL);
        assert::close(stat.sum_sin(), -0.25, TOL);
    }

    #[test]
    fn observe_increments_n() {
        let mut stat = VonMisesSuffStat::new();
        stat.observe(&1.0);
        assert_eq!(stat.n(), 1);
    }

    #[test]
    fn forget_decrements_n() {
        let mut stat = VonMisesSuffStat::new();
        stat.observe(&1.0);
        stat.forget(&1.0);
        assert_eq!(stat.n(), 0);
    }

    #[test]
    fn forget_restores_previous_sums() {
        let mut stat = VonMisesSuffStat::from_data(&[0.3]);
        stat.observe(&1.2);
        stat.forget(&1.2);
        assert_eq!(stat.n(), 1);
        assert::close(stat.sum_sin(), 0.3_f64.sin(), TOL);
        assert::close(stat.sum_cos(), 0.3_f64.cos(), TOL);
    }

    #[test]
    fn merge_adds_n() {
        let mut stat1 = VonMisesSuffStat::new();
        let mut stat2 = VonMisesSuffStat::new();
        stat1.observe(&1.0);
        stat2.observe(&2.0);
        stat1.merge(stat2);
        assert_eq!(stat1.n(), 2);
    }

    #[test]
    fn merge_adds_sums() {
        let mut stat1 = VonMisesSuffStat::from_data(&[0.0]);
        let stat2 = VonMisesSuffStat::from_data(&[PI / 2.0]);
        stat1.merge(stat2);
        assert_eq!(stat1.n(), 2);
        assert::close(stat1.sum_cos(), 1.0, TOL);
        assert::close(stat1.sum_sin(), 1.0, TOL);
    }

    #[test]
    fn merge_empty_stat_does_nothing() {
        let mut stat1 = VonMisesSuffStat::new();
        let stat2 = VonMisesSuffStat::new();
        stat1.observe(&1.0);
        stat1.merge(stat2);
        assert_eq!(stat1.n(), 1);
    }

    #[test]
    fn from_data_empty_vec() {
        let data: Vec<f64> = vec![];
        let stat = VonMisesSuffStat::from_data(&data);
        assert_eq!(stat.n(), 0);
    }

    #[test]
    fn from_empty_vec() {
        let data: Vec<f64> = vec![];
        let stat = VonMisesSuffStat::from(&data);
        assert_eq!(stat.n(), 0);
    }

    #[test]
    fn from_empty_slice() {
        let data: &[f64] = &[];
        let stat = VonMisesSuffStat::from(data);
        assert_eq!(stat.n(), 0);
    }

    #[test]
    fn from_vec() {
        let data = vec![0.0, PI / 2.0, PI];
        let stat = VonMisesSuffStat::from(&data);
        assert_eq!(stat.n(), 3);
        assert::close(stat.sum_cos(), 0.0, TOL);
        assert::close(stat.sum_sin(), 1.0, TOL);
    }

    #[test]
    fn from_slice() {
        let data = [0.0, PI / 2.0, PI];
        let stat = VonMisesSuffStat::from(data.as_slice());
        assert_eq!(stat.n(), 3);
        assert::close(stat.sum_cos(), 0.0, TOL);
        assert::close(stat.sum_sin(), 1.0, TOL);
    }

    #[test]
    fn from_array() {
        let data = [0.0, PI / 2.0, PI];
        let stat = VonMisesSuffStat::from(&data);
        assert_eq!(stat.n(), 3);
        assert::close(stat.sum_cos(), 0.0, TOL);
        assert::close(stat.sum_sin(), 1.0, TOL);
    }
}