#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
use crate::consts::TWO_PI;
use crate::data::CdvmSuffStat;
use crate::impl_display;
use crate::misc::func::LogSumExp;
use crate::misc::ln_pflip;
use crate::traits::{
HasDensity, HasSuffStat, Mean, Mode, Parameterized, Sampleable, Support,
};
use rand::Rng;
use std::fmt;
/// Discrete distribution over the categories `0..modulus` whose log-density at
/// `x` is `k * cos(2π (x - mu) / modulus)` minus a cached log-normalizer.
///
/// The last two fields are derived from `modulus`, `mu` and `k`; they are
/// stored so that density evaluations do not have to recompute them.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub struct Cdvm {
/// Number of categories; `Cdvm::new` requires at least 2.
modulus: usize,
/// Location parameter, measured in category units; `Cdvm::new` requires it
/// to be finite.
mu: f64,
/// Multiplier on the cosine kernel; `Cdvm::new` requires it to be finite
/// and non-negative.
k: f64,
/// Cached log-sum-exp of the kernel over every category in `0..modulus`.
log_norm_const: f64,
/// Cached `TWO_PI / modulus`.
twopi_over_m: f64,
}
/// Plain-data view of the three values that define a `Cdvm`.
///
/// This is the `Parameters` type emitted and consumed by the `Parameterized`
/// impl for `Cdvm`. The fields are public and unvalidated here; validation
/// happens when they are passed to `Cdvm::new`.
#[derive(Debug, Clone, PartialEq)]
pub struct CdvmParameters {
/// Number of categories.
pub modulus: usize,
/// Location parameter.
pub mu: f64,
/// Multiplier on the cosine kernel.
pub k: f64,
}
/// Reasons a set of `Cdvm` parameters can be rejected by `Cdvm::new`,
/// `Cdvm::set_mu` or `Cdvm::set_k`. Each variant carries the offending value.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde1", serde(rename_all = "snake_case"))]
pub enum CdvmError {
/// `mu` was NaN or infinite.
MuNotFinite { mu: f64 },
/// `k` was NaN or infinite.
KNotFinite { k: f64 },
/// `k` was finite but below zero.
KNegative { k: f64 },
/// `modulus` was smaller than 2.
InvalidCategories { modulus: usize },
}
impl Cdvm {
    /// Builds a validated `Cdvm` over the categories `0..modulus`.
    ///
    /// # Errors
    ///
    /// The checks run in the order below and the first failure is returned:
    /// - [`CdvmError::MuNotFinite`] if `mu` is NaN or infinite,
    /// - [`CdvmError::KNotFinite`] if `k` is NaN or infinite,
    /// - [`CdvmError::KNegative`] if `k < 0`,
    /// - [`CdvmError::InvalidCategories`] if `modulus < 2`.
    pub fn new(modulus: usize, mu: f64, k: f64) -> Result<Self, CdvmError> {
        if !mu.is_finite() {
            Err(CdvmError::MuNotFinite { mu })
        } else if !k.is_finite() {
            Err(CdvmError::KNotFinite { k })
        } else if k < 0.0 {
            Err(CdvmError::KNegative { k })
        } else if modulus < 2 {
            Err(CdvmError::InvalidCategories { modulus })
        } else {
            Ok(Cdvm::new_unchecked(modulus, mu, k))
        }
    }

    /// Test helper: rebuilds the distribution from its own parameters and
    /// checks that every stored value, including the two cached ones, matches.
    #[cfg(test)]
    fn is_consistent(&self) -> bool {
        let fresh = Cdvm::new(self.modulus, self.mu, self.k).unwrap();
        let params_match = self.mu() == fresh.mu()
            && self.k() == fresh.k()
            && self.modulus() == fresh.modulus();
        let caches_match = self.log_norm_const() == fresh.log_norm_const()
            && self.twopi_over_m() == fresh.twopi_over_m();
        params_match && caches_match
    }

    /// Builds a `Cdvm` without validating any argument.
    ///
    /// The log-normalizer and `TWO_PI / modulus` are still computed and cached.
    #[inline]
    #[must_use]
    pub fn new_unchecked(modulus: usize, mu: f64, k: f64) -> Self {
        let twopi_over_m = TWO_PI / modulus as f64;
        let log_norm_const = Cdvm::compute_log_norm_const(modulus, mu, k);
        Cdvm {
            modulus,
            mu,
            k,
            log_norm_const,
            twopi_over_m,
        }
    }

    /// Unnormalized log-density at `x`: `k * cos(two_pi_over_m * (x - mu))`.
    fn cdvm_kernel(two_pi_over_m: f64, mu: f64, k: f64, x: usize) -> f64 {
        let angle = two_pi_over_m * (x as f64 - mu);
        k * angle.cos()
    }

    /// Log-sum-exp of the kernel over every category in `0..modulus`.
    fn compute_log_norm_const(modulus: usize, mu: f64, k: f64) -> f64 {
        let step = TWO_PI / modulus as f64;
        let kernel_values =
            (0..modulus).map(|x| Cdvm::cdvm_kernel(step, mu, k, x));
        kernel_values.logsumexp()
    }

    /// Number of categories.
    #[must_use]
    pub fn modulus(&self) -> usize {
        self.modulus
    }

    /// Location parameter.
    #[must_use]
    pub fn mu(&self) -> f64 {
        self.mu
    }

    /// Multiplier on the cosine kernel.
    #[must_use]
    pub fn k(&self) -> f64 {
        self.k
    }

    /// Cached value of `TWO_PI / modulus`.
    #[must_use]
    pub fn twopi_over_m(&self) -> f64 {
        self.twopi_over_m
    }

    /// Cached log-normalizer.
    fn log_norm_const(&self) -> f64 {
        self.log_norm_const
    }

    /// Replaces `mu` and refreshes the cached log-normalizer.
    ///
    /// # Errors
    ///
    /// Returns [`CdvmError::MuNotFinite`] and leaves `self` untouched if `mu`
    /// is NaN or infinite.
    pub fn set_mu(&mut self, mu: f64) -> Result<(), CdvmError> {
        if mu.is_finite() {
            self.set_mu_unchecked(mu);
            Ok(())
        } else {
            Err(CdvmError::MuNotFinite { mu })
        }
    }

    /// Replaces `mu` without validation and refreshes the cached
    /// log-normalizer.
    pub fn set_mu_unchecked(&mut self, mu: f64) {
        self.log_norm_const =
            Cdvm::compute_log_norm_const(self.modulus, mu, self.k);
        self.mu = mu;
    }

    /// Replaces `k` and refreshes the cached log-normalizer.
    ///
    /// # Errors
    ///
    /// Returns [`CdvmError::KNotFinite`] if `k` is NaN or infinite, otherwise
    /// [`CdvmError::KNegative`] if `k < 0`. `self` is untouched on error.
    pub fn set_k(&mut self, k: f64) -> Result<(), CdvmError> {
        if !k.is_finite() {
            Err(CdvmError::KNotFinite { k })
        } else if k < 0.0 {
            Err(CdvmError::KNegative { k })
        } else {
            self.set_k_unchecked(k);
            Ok(())
        }
    }

    /// Replaces `k` without validation and refreshes the cached
    /// log-normalizer.
    pub fn set_k_unchecked(&mut self, k: f64) {
        self.log_norm_const =
            Cdvm::compute_log_norm_const(self.modulus, self.mu, k);
        self.k = k;
    }
}
impl Parameterized for Cdvm {
    type Parameters = CdvmParameters;

    /// Copies `modulus`, `mu` and `k` into a [`CdvmParameters`].
    fn emit_params(&self) -> Self::Parameters {
        let (modulus, mu, k) = (self.modulus, self.mu, self.k);
        CdvmParameters { modulus, mu, k }
    }

    /// Rebuilds a `Cdvm` through the validating constructor.
    ///
    /// # Panics
    ///
    /// Panics if `params` would be rejected by [`Cdvm::new`].
    fn from_params(params: Self::Parameters) -> Self {
        let CdvmParameters { modulus, mu, k } = params;
        Self::new(modulus, mu, k).unwrap()
    }
}
impl PartialEq for Cdvm {
    /// Two distributions are equal when `modulus`, `mu` and `k` all match.
    ///
    /// The cached fields are ignored because they are functions of those
    /// three values.
    fn eq(&self, other: &Cdvm) -> bool {
        let lhs = (self.modulus, self.mu, self.k);
        let rhs = (other.modulus, other.mu, other.k);
        lhs == rhs
    }
}
impl From<&Cdvm> for String {
fn from(cdvm: &Cdvm) -> String {
format!(
"CDVM(modulus: {}, μ: {}, κ: {})",
cdvm.modulus, cdvm.mu, cdvm.k
)
}
}
impl Mean<f64> for Cdvm {
    /// Reports the stored location parameter `mu`, exactly as given (it is
    /// not wrapped into `0..modulus`).
    fn mean(&self) -> Option<f64> {
        let location = self.mu;
        Some(location)
    }
}
impl Mode<usize> for Cdvm {
    /// Returns the category in `0..modulus` that lies closest to `mu` on the
    /// circle of `modulus` categories.
    ///
    /// `mu` only has to be finite, so it may be negative or at/above
    /// `modulus`. The log-density is periodic in `x - mu` with period
    /// `modulus`, so the rounded location is wrapped back into the support
    /// instead of being cast directly (a direct cast saturates negatives to 0
    /// and can return a value `>= modulus`).
    ///
    /// When `mu` is exactly half-way between two categories, both are equally
    /// likely and the one `f64::round` picks is returned.
    // NOTE(review): with `k == 0` every category is equally likely; a value is
    // still returned here, as before. Confirm whether `None` is preferred.
    fn mode(&self) -> Option<usize> {
        let m = self.modulus as f64;
        // `rem_euclid` maps the rounded location into `[0, m)`.
        let wrapped = self.mu.round().rem_euclid(m);
        let idx = wrapped as usize;
        // Round-off (or a degenerate unchecked `modulus`) could still leave
        // `idx` outside the support; fold that case onto category 0.
        Some(if idx < self.modulus { idx } else { 0 })
    }
}
// Invokes the crate's `impl_display!` macro for `Cdvm`. Presumably this
// generates `fmt::Display` by delegating to `String::from(&Cdvm)` — confirm
// against the macro definition.
impl_display!(Cdvm);
// Empty impl: `CdvmError` uses the default `std::error::Error` methods, which
// rely on its `Debug` derive and its `Display` impl.
impl std::error::Error for CdvmError {}
#[cfg_attr(coverage_nightly, coverage(off))]
impl fmt::Display for CdvmError {
    /// Writes a one-line message naming the offending value and the rule it
    /// broke.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            Self::MuNotFinite { mu } => write!(f, "mu ({mu}) must be finite"),
            Self::KNotFinite { k } => write!(f, "k ({k}) must be finite"),
            Self::KNegative { k } => write!(f, "k ({k}) must be non-negative"),
            Self::InvalidCategories { modulus } => write!(
                f,
                "number of categories ({modulus}) must be at least 2"
            ),
        }
    }
}
impl HasDensity<usize> for Cdvm {
    /// Log-probability of category `x`: the cosine kernel at `x` minus the
    /// cached log-normalizer.
    ///
    /// `x` is not checked against the support; the kernel is evaluated for any
    /// `usize`.
    fn ln_f(&self, x: &usize) -> f64 {
        let unnormalized =
            Cdvm::cdvm_kernel(self.twopi_over_m, self.mu, self.k, *x);
        unnormalized - self.log_norm_const
    }
}
impl Support<usize> for Cdvm {
    /// A value is supported when it is one of the categories `0..modulus`.
    fn supports(&self, x: &usize) -> bool {
        (0..self.modulus).contains(x)
    }
}
impl Sampleable<usize> for Cdvm {
    /// Draws one category by handing the log-probability of every category in
    /// `0..modulus` to `ln_pflip`.
    fn draw<R: Rng>(&self, rng: &mut R) -> usize {
        let ln_weights =
            (0..self.modulus).map(|category| self.ln_f(&category));
        // The `true` flag presumably tells `ln_pflip` the weights are already
        // normalized — confirm against its signature.
        ln_pflip(ln_weights, true, rng)
    }
}
impl HasSuffStat<usize> for Cdvm {
    type Stat = CdvmSuffStat;

    /// Creates an empty statistic for this distribution's `modulus`.
    fn empty_suffstat(&self) -> Self::Stat {
        let categories = self.modulus;
        CdvmSuffStat::new(categories)
    }

    /// Log-likelihood of the data summarized by `stat`.
    ///
    /// Computes `k * (sum_cos * cos(θ) + sum_sin * sin(θ)) - n * ln Z`, where
    /// `θ = mu * twopi_over_m` and `ln Z` is the cached log-normalizer. This
    /// is the angle-difference expansion of the per-point kernel, summed over
    /// the data.
    fn ln_f_stat(&self, stat: &Self::Stat) -> f64 {
        let theta = self.mu * self.twopi_over_m();
        let (sin_theta, cos_theta) = theta.sin_cos();
        // sum_cos * cos(θ) + sum_sin * sin(θ), fused as in the kernel sum.
        let projection =
            stat.sum_cos().mul_add(cos_theta, stat.sum_sin() * sin_theta);
        let norm_penalty = stat.n() as f64 * self.log_norm_const();
        self.k.mul_add(projection, -norm_penalty)
    }
}
#[cfg(test)]
mod tests {
    use crate::misc::x2_test;
    use super::*;
    use proptest::prelude::*;
    use rand::{SeedableRng, rngs::SmallRng};

    /// Absolute tolerance used when comparing log-densities.
    const TOL: f64 = 1E-12;

    /// `Cdvm::new` accepts valid input and reports the right error variant.
    #[test]
    fn new_should_validate_parameters() {
        assert!(Cdvm::new(3, 1.0, 1.5).is_ok());
        assert!(matches!(
            Cdvm::new(1, 1.0, 1.5),
            Err(CdvmError::InvalidCategories { modulus: 1 })
        ));
        assert!(matches!(
            Cdvm::new(3, 1.0, -1.5),
            Err(CdvmError::KNegative { k: -1.5 })
        ));
    }

    /// The support is exactly `0..modulus`.
    #[test]
    fn supports_correct_range() {
        let cdvm = Cdvm::new(4, 1.0, 1.5).unwrap();
        assert!(cdvm.supports(&0));
        assert!(cdvm.supports(&1));
        assert!(cdvm.supports(&2));
        assert!(cdvm.supports(&3));
        assert!(!cdvm.supports(&4));
    }

    proptest! {
        /// Reflecting both `mu` and `x` about the modulus leaves the
        /// log-density unchanged.
        #[test]
        fn ln_f_symmetry(
            m in 3..100_usize,
            mu in 0.0..100_f64,
            k in 0.1..50.0_f64,
            x in 0..100_usize
        ) {
            let mu = mu % (m as f64);
            let cdvm1 = Cdvm::new(m, mu, k).unwrap();
            let cdvm2 = Cdvm::new(m, (m as f64) - mu, k).unwrap();
            let x1 = x % m;
            let x2 = m - x1;
            let lnf1 = cdvm1.ln_f(&x1);
            let lnf2 = cdvm2.ln_f(&x2);
            prop_assert!((lnf1 - lnf2).abs() < TOL,
                "ln_f not symmetric for m={}, mu={}, k={}, x={}, lnf1={}, lnf2={}", m, mu, k, x, lnf1, lnf2);
        }
    }

    proptest! {
        /// The probabilities over `0..m` sum to one (log-sum-exp is zero).
        #[test]
        fn density_is_normalized(
            m in 3..100_usize,
            mu in 0.0..100_f64,
            k in 0.1..50.0_f64,
        ) {
            let cdvm = Cdvm::new(m, mu, k).unwrap();
            let logsum = (0..m).map(|x| cdvm.ln_f(&x)).logsumexp();
            prop_assert!((logsum).abs() < TOL,
                "density not normalized for m={}, mu={}, k={}, logsum={}", m, mu, k, logsum);
        }
    }

    proptest! {
        /// Shifting `x` by one full period does not change the log-density.
        #[test]
        fn wrap_around_invariance(
            m in 3..100_usize,
            mu in 0.0..100_f64,
            k in 0.1..50.0_f64,
            x in 0..100_usize,
        ) {
            let mu = mu % (m as f64);
            let x = x % m;
            let cdvm = Cdvm::new(m, mu, k).unwrap();
            prop_assert!((cdvm.ln_f(&x) - cdvm.ln_f(&(x + m))).abs() < TOL,
                "ln_f not invariant to wrap-around for m={}, mu={}, k={}, x={}", m, mu, k, x);
        }
    }

    /// Emitting parameters and rebuilding yields an equal distribution.
    #[test]
    fn parameterized_trait() {
        let original = Cdvm::new(3, 1.0, 1.5).unwrap();
        let params = original.emit_params();
        let reconstructed = Cdvm::from_params(params);
        assert_eq!(original, reconstructed);
    }

    proptest! {
        /// Summing `ln_f` over the data agrees with `ln_f_stat` on the
        /// sufficient statistic built from the same data.
        #[test]
        fn ln_f_matches_ln_f_stat(
            m in 3..100_usize,
            mu in 0.0..100_f64,
            k in 0.1..50.0_f64,
            xs in prop::collection::vec(0..100_usize, 1..20),
        ) {
            let mu = mu % (m as f64);
            let xs: Vec<usize> = xs.into_iter().map(|x| x % m).collect();
            let cdvm = Cdvm::new(m, mu, k).unwrap();
            let ln_f_sum: f64 = xs.iter().map(|x| cdvm.ln_f(x)).sum();
            let stat = CdvmSuffStat::from_data(m, &xs);
            let ln_f_stat = cdvm.ln_f_stat(&stat);
            // `prop_assert!` (not `assert!`) so that proptest can shrink the
            // failing input instead of aborting on a panic.
            prop_assert!((ln_f_sum - ln_f_stat).abs() < TOL,
                "ln_f_sum ({}) != ln_f_stat ({}) for m={}, mu={}, k={}, xs={:?}",
                ln_f_sum, ln_f_stat, m, mu, k, xs);
        }
    }

    proptest! {
        /// `set_k` keeps the cached values in sync with the parameters.
        #[test]
        fn set_k_maintains_consistency(
            m in 3..100_usize,
            mu in 0.0..100_f64,
            k1 in 0.1..50.0_f64,
            k2 in 0.1..50.0_f64,
        ) {
            let mu = mu % (m as f64);
            let mut cdvm = Cdvm::new(m, mu, k1).unwrap();
            cdvm.set_k(k2).unwrap();
            prop_assert!(cdvm.is_consistent(),
                "CDVM not consistent after set_k: m={}, mu={}, k1={}, k2={}", m, mu, k1, k2);
        }
    }

    proptest! {
        /// `set_mu` keeps the cached values in sync with the parameters.
        #[test]
        fn set_mu_maintains_consistency(
            m in 3..100_usize,
            mu1 in 0.0..100_f64,
            mu2 in 0.0..100_f64,
            k in 0.1..50.0_f64,
        ) {
            let mu1 = mu1 % (m as f64);
            let mu2 = mu2 % (m as f64);
            let mut cdvm = Cdvm::new(m, mu1, k).unwrap();
            cdvm.set_mu(mu2).unwrap();
            prop_assert!(cdvm.is_consistent(),
                "CDVM not consistent after set_mu: m={}, mu1={}, mu2={}, k={}", m, mu1, mu2, k);
        }
    }

    /// The probabilities of all ten categories sum to one.
    #[test]
    fn f_is_probability_measure() {
        let dist = Cdvm::new_unchecked(10, 5.0, 0.5);
        assert::close((0..10).map(|i| dist.f(&i)).sum::<f64>(), 1.0, 1e-10);
    }

    /// Empirical category frequencies from `draw` match `f` under a
    /// chi-square goodness-of-fit test.
    #[test]
    fn ln_f_agrees_with_draw() {
        // The RNG is seeded from the OS, and a single test at the 5% level
        // rejects a correct sampler about one run in twenty. Repeat with
        // fresh samples and require only one acceptance; a sampler that is
        // actually wrong is rejected on every attempt at this sample size.
        const N_TRIES: usize = 5;
        const X2_PVAL: f64 = 0.05;

        let mut rng = SmallRng::from_os_rng();
        let dist = Cdvm::new_unchecked(10, 5.0, 0.5);
        let ps: Vec<f64> = (0..10).map(|i| dist.f(&i)).collect();

        let any_pass = (0..N_TRIES).any(|_| {
            let sample = dist.sample(100_000, &mut rng);
            let observed_counts =
                sample.into_iter().fold(vec![0; 10], |mut acc, x| {
                    acc[x] += 1;
                    acc
                });
            let (_, p) = x2_test(&observed_counts, &ps);
            p > X2_PVAL
        });
        assert!(any_pass);
    }

    /// Round-tripping through `emit_params` / `from_params` is the identity.
    #[test]
    fn emit_and_from_params_are_identity() {
        let dist_a = Cdvm::new(10, 5.0, 6.0).unwrap();
        let dist_b = Cdvm::from_params(dist_a.emit_params());
        assert_eq!(dist_a, dist_b);
    }
}