use crate::run_length_detector::RunLengthDetector;
use rv::prelude::*;
use std::collections::VecDeque;
#[cfg(feature = "serde1")]
use serde::{Deserialize, Serialize};
/// State of an online run-length detector.
///
/// Type parameters:
/// * `X`  - type of a single observation passed to `step`.
/// * `H`  - hazard function; called with a run-length index and returns an
///   `f64` used as the change-point weight for that run length.
/// * `Fx` - distribution over `X` whose sufficient-statistic type
///   (`Fx::Stat`) is tracked per candidate run.
/// * `Pr` - conjugate prior for `Fx`, used to score each new observation
///   against a run's sufficient statistic.
#[cfg_attr(feature = "serde1", derive(Serialize, Deserialize))]
pub struct Bocpd<X, H, Fx, Pr>
where
H: Fn(usize) -> f64,
Fx: Rv<X> + HasSuffStat<X>,
Pr: ConjugatePrior<X, Fx>,
Fx::Stat: Clone,
{
/// Hazard function. Its value `h` for index `i` splits the mass of entry
/// `i` into a continuing part (`1 - h`) and a change-point part (`h`).
hazard: H,
/// Prior whose `ln_pp` scores an observation against a sufficient statistic.
predictive_prior: Pr,
/// One sufficient statistic per step so far. A fresh empty statistic is
/// pushed to the front on every step and every entry then observes the datum.
suff_stats: VecDeque<Fx::Stat>,
/// Number of observations processed so far (incremented once per `step`).
t: usize,
/// Run-length weights returned by `step`; grows by one entry per step.
r: Vec<f64>,
/// Template statistic, cloned whenever a new entry is added to `suff_stats`.
empty_suffstat: Fx::Stat,
/// Early-exit threshold for the update loop: once the mass not yet visited
/// in a step falls below this value the loop stops. Set to `1E-3` by `new`.
cdf_threshold: f64,
}
impl<X, H, Fx, Pr> Bocpd<X, H, Fx, Pr>
where
    H: Fn(usize) -> f64,
    Fx: Rv<X> + HasSuffStat<X>,
    Pr: ConjugatePrior<X, Fx>,
    Fx::Stat: Clone,
{
    /// Builds a detector that has not yet seen any data.
    ///
    /// # Arguments
    /// * `hazard` - function evaluated with a run-length index on every
    ///   update; its result weights the change-point branch.
    /// * `fx` - consumed only to obtain an empty sufficient statistic, which
    ///   is kept as the template for every new run.
    /// * `predictive_prior` - conjugate prior used to score observations.
    ///
    /// The early-exit threshold of the update loop starts at `1E-3`.
    pub fn new(hazard: H, fx: Fx, predictive_prior: Pr) -> Self {
        // Default amount of unvisited mass below which an update stops early.
        const DEFAULT_CDF_THRESHOLD: f64 = 1E-3;

        let empty_suffstat = fx.empty_suffstat();

        Self {
            hazard,
            predictive_prior,
            suff_stats: VecDeque::new(),
            t: 0,
            r: Vec::new(),
            empty_suffstat,
            cdf_threshold: DEFAULT_CDF_THRESHOLD,
        }
    }
}
impl<X, H, Fx, Pr> RunLengthDetector<X> for Bocpd<X, H, Fx, Pr>
where
    H: Fn(usize) -> f64,
    Fx: Rv<X> + HasSuffStat<X>,
    Pr: ConjugatePrior<X, Fx>,
    Fx::Stat: Clone,
{
    /// Consumes one observation and returns the updated run-length weights.
    ///
    /// The returned slice gains one entry per call. The first call yields
    /// `[1.0]`; every later call shifts each entry's mass one slot up
    /// (scaled by the predictive probability of `data` and by `1 - hazard`),
    /// collects the hazard-weighted remainder into entry `0`, and normalizes
    /// the whole vector by the total mass.
    ///
    /// Entries are visited from the longest run length downwards. Once the
    /// mass not yet visited is smaller than `cdf_threshold`, the remaining
    /// (shorter) run lengths are dropped, i.e. set to zero, and the loop
    /// stops.
    fn step(&mut self, data: &X) -> &[f64] {
        // Slot for a run that starts at this observation.
        self.suff_stats.push_front(self.empty_suffstat.clone());

        if self.t == 0 {
            // Nothing to update yet: all mass sits on the single entry.
            self.r.push(1.0);
        } else {
            self.r.push(0.0);
            let mut r0 = 0.0;
            let mut r_sum = 0.0;
            let mut r_seen = 0.0;

            for i in (0..self.t).rev() {
                if self.r[i] == 0.0 {
                    // No mass to propagate; skip the predictive evaluation.
                    self.r[i + 1] = 0.0;
                } else {
                    let pp = self
                        .predictive_prior
                        .ln_pp(
                            data,
                            &DataOrSuffStat::SuffStat(&self.suff_stats[i]),
                        )
                        .exp();

                    r_seen += self.r[i];
                    let h = (self.hazard)(i);
                    // Growth: the run continues, so its mass moves one slot up.
                    self.r[i + 1] = self.r[i] * pp * (1.0 - h);
                    // Change point: this share is collected for entry 0.
                    r0 += self.r[i] * pp * h;
                    r_sum += self.r[i + 1];

                    if 1.0 - r_seen < self.cdf_threshold {
                        // Entries 1..=i would have been overwritten by the
                        // iterations being skipped. Without clearing them
                        // they would keep the previous step's values and be
                        // divided by `r_sum` below, so the result would no
                        // longer sum to one.
                        // NOTE(review): the dropped mass includes the previous
                        // entry 0, so a hazard below `cdf_threshold` may keep
                        // freshly started runs from gaining mass — confirm the
                        // intended interplay of the two settings.
                        self.r
                            .iter_mut()
                            .take(i + 1)
                            .skip(1)
                            .for_each(|dropped| *dropped = 0.0);
                        break;
                    }
                }
            }

            r_sum += r0;
            self.r[0] = r0;

            // Normalize so the entries sum to one.
            for i in 0..=self.t {
                self.r[i] /= r_sum;
            }
        }

        // Every tracked run, including the new empty one, absorbs the datum.
        self.suff_stats
            .iter_mut()
            .for_each(|stat| stat.observe(data));

        self.t += 1;
        &self.r
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::{ChangePointDetectionMethod, MostLikelyPathWrapper};
    use crate::{
        constant_hazard, generators, MapPathDetector, RunLengthDetector,
    };
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    /// Every vector returned by `step` must sum to one.
    #[test]
    fn each_vec_is_a_probability_dist() {
        let mut rng: StdRng = StdRng::seed_from_u64(0xABCD);
        let observations = generators::discontinuous_jump(
            &mut rng, 0.0, 1.0, 10.0, 5.0, 500, 1000,
        );

        let mut detector = Bocpd::new(
            constant_hazard(250.0),
            Gaussian::standard(),
            NormalGamma::new(0.0, 1.0, 1.0, 1.0).unwrap(),
        );

        let mut rows: Vec<Vec<f64>> = Vec::new();
        for obs in observations.iter() {
            rows.push(detector.step(obs).to_vec());
        }

        for row in &rows {
            let total: f64 = row.iter().sum();
            assert::close(total, 1.0, 1E-8);
        }
    }

    /// A generated jump is reported at the expected indices.
    #[test]
    fn detect_obvious_switch() {
        let mut rng: StdRng = StdRng::seed_from_u64(0xABCD);
        let observations = generators::discontinuous_jump(
            &mut rng, 0.0, 1.0, 10.0, 5.0, 500, 1000,
        );

        let inner = Bocpd::new(
            constant_hazard(250.0),
            Gaussian::standard(),
            NormalGamma::new_unchecked(0.0, 1.0, 1.0, 1.0),
        );
        let mut detector = MostLikelyPathWrapper::new(inner);

        let mut path_probs: Vec<Vec<f64>> = Vec::new();
        for obs in observations.iter() {
            path_probs.push(detector.step(obs).map_path_probs.clone().into());
        }

        let change_points =
            ChangePointDetectionMethod::NonIncremental.detect(&path_probs);
        assert_eq!(change_points, vec![500, 501]);
    }

    /// Count data from the coal-mining generator with a Poisson/Gamma model.
    #[test]
    fn coal_mining_data() {
        let incidents = generators::coal_mining_incidents();

        let inner = Bocpd::new(
            constant_hazard(100.0),
            Poisson::new_unchecked(123.0),
            Gamma::new_unchecked(1.0, 1.0),
        );
        let mut detector = MostLikelyPathWrapper::new(inner);

        let mut path_probs: Vec<Vec<f64>> = Vec::new();
        for count in incidents.iter() {
            path_probs
                .push(detector.step(count).map_path_probs.clone().into());
        }

        let change_points =
            ChangePointDetectionMethod::DropThreshold(0.5).detect(&path_probs);
        assert_eq!(change_points, vec![50, 107]);
    }

    /// Relative changes between consecutive rows of the bundled CSV.
    #[test]
    fn treasury_changes() {
        let raw_data: &str = include_str!("../resources/TB3MS.csv");

        // Skip the header row, drop the first 11 bytes of every other row
        // (presumably a date column plus its separator) and parse the rest.
        let rates: Vec<f64> = raw_data
            .lines()
            .skip(1)
            .map(|row| row.split_at(11).1.parse().unwrap())
            .collect();

        let inner = Bocpd::new(
            constant_hazard(250.0),
            Gaussian::standard(),
            NormalGamma::new_unchecked(0.0, 1.0, 1.0, 1E-5),
        );
        let mut detector = MostLikelyPathWrapper::new(inner);

        let mut path_probs: Vec<Vec<f64>> = Vec::new();
        for pair in rates.windows(2) {
            let relative_change = (pair[1] - pair[0]) / pair[0];
            path_probs.push(
                detector
                    .step(&relative_change)
                    .map_path_probs
                    .clone()
                    .into(),
            );
        }

        let change_points =
            ChangePointDetectionMethod::DropThreshold(0.1).detect(&path_probs);
        assert_eq!(change_points, vec![135, 295, 897, 981, 1010]);
    }
}