augurs_changepoint/
lib.rs

1#![doc = include_str!("../README.md")]
2use std::{fmt, num::NonZeroUsize};
3
4pub use changepoint::rv::{dist, process::gaussian::kernel};
5use changepoint::{
6    rv::{
7        dist::{Gaussian, NormalGamma, NormalInvGamma},
8        process::gaussian::kernel::{
9            AddKernel, ConstantKernel, Kernel, ProductKernel, RBFKernel, WhiteKernel,
10        },
11        traits::{ConjugatePrior, HasSuffStat, Rv},
12    },
13    utils::map_changepoints,
14    Argpcp, BocpdLike, BocpdTruncated,
15};
16use itertools::Itertools;
17
/// Common interface shared by the changepoint detectors in this crate.
pub trait Detector {
    /// Look for changepoints in the series `y`.
    ///
    /// Returns the indices of the detected changepoints. The receiver is
    /// `&mut self`, so implementations are free to update internal state
    /// while processing `y`.
    fn detect_changepoints(&mut self, y: &[f64]) -> Vec<usize>;
}
23
24/// A changepoint detector using Bayesian Online Changepoint Detection.
25///
26/// Based on [this paper][paper], using the implementation from the
27/// [changepoint] crate.
28///
29/// [changepoint]: https://crates.io/crates/changepoint
30/// [paper]: https://arxiv.org/abs/0710.3742
31#[derive(Debug, Clone)]
32pub struct BocpdDetector<Dist, Prior>
33where
34    Dist: Rv<f64> + HasSuffStat<f64>,
35    Prior: ConjugatePrior<f64, Dist> + Clone,
36    Dist::Stat: Clone + fmt::Debug,
37{
38    detector: BocpdTruncated<f64, Dist, Prior>,
39}
40
41impl<Dist, Prior> Detector for BocpdDetector<Dist, Prior>
42where
43    Dist: Rv<f64> + HasSuffStat<f64>,
44    Prior: ConjugatePrior<f64, Dist, Posterior = Prior> + Clone,
45    Dist::Stat: Clone + fmt::Debug,
46{
47    fn detect_changepoints(&mut self, y: &[f64]) -> Vec<usize> {
48        let run_lengths = y
49            .iter()
50            .map(|d| self.detector.step(d).to_vec())
51            .collect_vec();
52        map_changepoints(&run_lengths)
53    }
54}
55
56/// A [`BocpdDetector`] for Normal data with a Normal Gamma prior.
57pub type NormalGammaDetector = BocpdDetector<Gaussian, NormalGamma>;
58/// A [`BocpdDetector`] for Normal data with a Normal inverse-Gamma prior.
59pub type NormalInvGammaDetector = BocpdDetector<Gaussian, NormalInvGamma>;
60
61impl NormalGammaDetector {
62    /// Create a detector for Normal data using the given hazard_lambda and prior.
63    pub fn normal_gamma(hazard_lambda: f64, prior: NormalGamma) -> Self {
64        Self {
65            detector: BocpdTruncated::new(hazard_lambda, prior),
66        }
67    }
68}
69
70impl NormalInvGammaDetector {
71    /// Create a detector for Normal data using the given hazard_lambda and prior.
72    pub fn normal_inv_gamma(hazard_lambda: f64, prior: NormalInvGamma) -> Self {
73        Self {
74            detector: BocpdTruncated::new(hazard_lambda, prior),
75        }
76    }
77}
78
79impl Default for NormalGammaDetector {
80    fn default() -> Self {
81        Self::normal_gamma(250.0, NormalGamma::new_unchecked(0.0, 1.0, 1.0, 1.0))
82    }
83}
84
85impl Default for NormalInvGammaDetector {
86    fn default() -> Self {
87        Self::normal_inv_gamma(250.0, NormalInvGamma::new_unchecked(0.0, 1.0, 1.0, 1.0))
88    }
89}
90
91type DefaultKernel = AddKernel<ProductKernel<ConstantKernel, RBFKernel>, WhiteKernel>;
92/// An [`ArgpcpDetector`] with a sensible default choice of kernel.
93pub type DefaultArgpcpDetector = ArgpcpDetector<DefaultKernel>;
94
95/// A changepoint detector using autoregressive Gaussian Processes.
96///
97/// Based on [Ryan Turner's thesis][thesis], using the implementation from
98/// the [changepoint] crate.
99///
100/// [thesis]: https://www.repository.cam.ac.uk/bitstream/handle/1810/242181/thesis.pdf?sequence=1&isAllowed=y
101/// [changepoint]: https://crates.io/crates/changepoint
102#[derive(Debug, Clone)]
103pub struct ArgpcpDetector<K>
104where
105    K: Kernel,
106{
107    detector: Argpcp<K>,
108}
109
110impl<K> From<Argpcp<K>> for ArgpcpDetector<K>
111where
112    K: Kernel,
113{
114    fn from(detector: Argpcp<K>) -> Self {
115        Self { detector }
116    }
117}
118
119impl DefaultArgpcpDetector {
120    /// Get a builder to configure the parameters of the detector.
121    pub fn builder() -> DefaultArgpcpDetectorBuilder {
122        DefaultArgpcpDetectorBuilder::default()
123    }
124}
125
126impl<K> Detector for ArgpcpDetector<K>
127where
128    K: Kernel,
129{
130    fn detect_changepoints(&mut self, y: &[f64]) -> Vec<usize> {
131        let run_lengths = y
132            .iter()
133            .map(|d| self.detector.step(d).to_vec())
134            .collect_vec();
135        map_changepoints(&run_lengths)
136    }
137}
138
139impl Default for DefaultArgpcpDetector {
140    fn default() -> Self {
141        DefaultArgpcpDetectorBuilder::default().build()
142    }
143}
144
/// Builder for a [`DefaultArgpcpDetector`].
///
/// Start from [`Default::default`], override any parameters with the setter
/// methods, then call `build`. Each setter takes the builder by value and
/// returns it, so calls can be chained. The setters store the value as given
/// and perform no validation.
#[derive(Debug, Clone)]
pub struct DefaultArgpcpDetectorBuilder {
    /// Value given to the constant kernel.
    constant_value: f64,
    /// Length scale given to the RBF kernel.
    length_scale: f64,
    /// Noise level given to the white kernel.
    noise_level: f64,
    /// Maximum lag handed to the underlying `Argpcp` detector.
    max_lag: NonZeroUsize,
    /// `alpha0` argument handed to the underlying `Argpcp` detector.
    alpha0: f64,
    /// `beta0` argument handed to the underlying `Argpcp` detector.
    beta0: f64,
    /// `logistic_hazard_h` argument handed to the underlying `Argpcp` detector.
    logistic_hazard_h: f64,
    /// `logistic_hazard_a` argument handed to the underlying `Argpcp` detector.
    logistic_hazard_a: f64,
    /// `logistic_hazard_b` argument handed to the underlying `Argpcp` detector.
    logistic_hazard_b: f64,
}

impl DefaultArgpcpDetectorBuilder {
    /// Set the value for the constant kernel.
    pub fn constant_value(mut self, cv: f64) -> Self {
        self.constant_value = cv;
        self
    }

    /// Set the length scale for the RBF kernel.
    pub fn length_scale(mut self, ls: f64) -> Self {
        self.length_scale = ls;
        self
    }

    /// Set the noise level for the white kernel.
    pub fn noise_level(mut self, nl: f64) -> Self {
        self.noise_level = nl;
        self
    }

    /// Set the maximum lag passed to the underlying detector.
    pub fn max_lag(mut self, ml: NonZeroUsize) -> Self {
        self.max_lag = ml;
        self
    }

    /// Set the `alpha0` parameter passed to the underlying detector.
    pub fn alpha0(mut self, alpha0: f64) -> Self {
        self.alpha0 = alpha0;
        self
    }

    /// Set the `beta0` parameter passed to the underlying detector.
    pub fn beta0(mut self, beta0: f64) -> Self {
        self.beta0 = beta0;
        self
    }

    /// Set the `h` parameter of the logistic hazard function.
    pub fn logistic_hazard_h(mut self, h: f64) -> Self {
        self.logistic_hazard_h = h;
        self
    }

    /// Set the `a` parameter of the logistic hazard function.
    pub fn logistic_hazard_a(mut self, a: f64) -> Self {
        self.logistic_hazard_a = a;
        self
    }

    /// Set the `b` parameter of the logistic hazard function.
    pub fn logistic_hazard_b(mut self, b: f64) -> Self {
        self.logistic_hazard_b = b;
        self
    }
}

impl Default for DefaultArgpcpDetectorBuilder {
    /// The parameter values used when nothing is overridden.
    fn default() -> Self {
        Self {
            constant_value: 0.5,
            length_scale: 10.0,
            noise_level: 0.01,
            max_lag: NonZeroUsize::new(3).expect("3 is non-zero"),
            alpha0: 2.0,
            beta0: 1.0,
            logistic_hazard_h: -5.0,
            logistic_hazard_a: 1.0,
            logistic_hazard_b: 1.0,
        }
    }
}
182
183impl DefaultArgpcpDetectorBuilder {
184    /// Build this [`DefaultArgpcpDetector`].
185    pub fn build(self) -> DefaultArgpcpDetector {
186        DefaultArgpcpDetector {
187            detector: Argpcp::new(
188                ConstantKernel::new_unchecked(self.constant_value)
189                    * RBFKernel::new_unchecked(self.length_scale)
190                    + WhiteKernel::new_unchecked(self.noise_level),
191                self.max_lag.into(),
192                self.alpha0,
193                self.beta0,
194                self.logistic_hazard_h,
195                self.logistic_hazard_a,
196                self.logistic_hazard_b,
197            ),
198        }
199    }
200}