1#![doc = include_str!("../README.md")]
2use std::{fmt, num::NonZeroUsize};
3
4pub use changepoint::rv::{dist, process::gaussian::kernel};
5use changepoint::{
6 rv::{
7 dist::{Gaussian, NormalGamma, NormalInvGamma},
8 process::gaussian::kernel::{
9 AddKernel, ConstantKernel, Kernel, ProductKernel, RBFKernel, WhiteKernel,
10 },
11 traits::{ConjugatePrior, HasSuffStat, Rv},
12 },
13 utils::map_changepoints,
14 Argpcp, BocpdLike, BocpdTruncated,
15};
16use itertools::Itertools;
17
/// A changepoint detector that consumes a series of `f64` observations.
pub trait Detector {
    /// Runs the detector over `y` and returns the detected changepoints.
    ///
    /// The result is a list of `usize` positions (presumably indices into `y`).
    /// Implementations take `&mut self`, so they may update internal state while
    /// processing the series.
    fn detect_changepoints(&mut self, y: &[f64]) -> Vec<usize>;
}
23
/// A [`Detector`] backed by a [`BocpdTruncated`] online changepoint detector over
/// `f64` observations.
///
/// `Dist` is the observation distribution and `Prior` is its conjugate prior. The inner
/// detector is private; build one through the constructors on the concrete aliases
/// (e.g. `NormalGammaDetector::normal_gamma`) or their `Default` impls.
#[derive(Debug, Clone)]
pub struct BocpdDetector<Dist, Prior>
where
    Dist: Rv<f64> + HasSuffStat<f64>,
    Prior: ConjugatePrior<f64, Dist> + Clone,
    // NOTE(review): these bounds presumably mirror what `BocpdTruncated` and the
    // derived impls need; confirm before loosening them.
    Dist::Stat: Clone + fmt::Debug,
{
    /// The wrapped changepoint detector; it is stepped once per observation.
    detector: BocpdTruncated<f64, Dist, Prior>,
}
40
41impl<Dist, Prior> Detector for BocpdDetector<Dist, Prior>
42where
43 Dist: Rv<f64> + HasSuffStat<f64>,
44 Prior: ConjugatePrior<f64, Dist, Posterior = Prior> + Clone,
45 Dist::Stat: Clone + fmt::Debug,
46{
47 fn detect_changepoints(&mut self, y: &[f64]) -> Vec<usize> {
48 let run_lengths = y
49 .iter()
50 .map(|d| self.detector.step(d).to_vec())
51 .collect_vec();
52 map_changepoints(&run_lengths)
53 }
54}
55
/// A [`BocpdDetector`] with a [`Gaussian`] observation model and a [`NormalGamma`] prior.
pub type NormalGammaDetector = BocpdDetector<Gaussian, NormalGamma>;
/// A [`BocpdDetector`] with a [`Gaussian`] observation model and a [`NormalInvGamma`] prior.
pub type NormalInvGammaDetector = BocpdDetector<Gaussian, NormalInvGamma>;
60
61impl NormalGammaDetector {
62 pub fn normal_gamma(hazard_lambda: f64, prior: NormalGamma) -> Self {
64 Self {
65 detector: BocpdTruncated::new(hazard_lambda, prior),
66 }
67 }
68}
69
70impl NormalInvGammaDetector {
71 pub fn normal_inv_gamma(hazard_lambda: f64, prior: NormalInvGamma) -> Self {
73 Self {
74 detector: BocpdTruncated::new(hazard_lambda, prior),
75 }
76 }
77}
78
79impl Default for NormalGammaDetector {
80 fn default() -> Self {
81 Self::normal_gamma(250.0, NormalGamma::new_unchecked(0.0, 1.0, 1.0, 1.0))
82 }
83}
84
85impl Default for NormalInvGammaDetector {
86 fn default() -> Self {
87 Self::normal_inv_gamma(250.0, NormalInvGamma::new_unchecked(0.0, 1.0, 1.0, 1.0))
88 }
89}
90
/// Kernel used by [`DefaultArgpcpDetector`]: `(constant * RBF) + white noise`.
type DefaultKernel = AddKernel<ProductKernel<ConstantKernel, RBFKernel>, WhiteKernel>;
/// An [`ArgpcpDetector`] using a constant-scaled RBF kernel plus a white-noise kernel.
///
/// Configure its hyperparameters with [`DefaultArgpcpDetectorBuilder`].
pub type DefaultArgpcpDetector = ArgpcpDetector<DefaultKernel>;
94
/// A [`Detector`] backed by an [`Argpcp`] changepoint detector using kernel `K`.
///
/// Build one from an existing [`Argpcp`] via `From`, or use [`DefaultArgpcpDetector`]
/// and its builder for the default kernel.
#[derive(Debug, Clone)]
pub struct ArgpcpDetector<K>
where
    K: Kernel,
{
    /// The wrapped changepoint detector; it is stepped once per observation.
    detector: Argpcp<K>,
}
109
110impl<K> From<Argpcp<K>> for ArgpcpDetector<K>
111where
112 K: Kernel,
113{
114 fn from(detector: Argpcp<K>) -> Self {
115 Self { detector }
116 }
117}
118
119impl DefaultArgpcpDetector {
120 pub fn builder() -> DefaultArgpcpDetectorBuilder {
122 DefaultArgpcpDetectorBuilder::default()
123 }
124}
125
126impl<K> Detector for ArgpcpDetector<K>
127where
128 K: Kernel,
129{
130 fn detect_changepoints(&mut self, y: &[f64]) -> Vec<usize> {
131 let run_lengths = y
132 .iter()
133 .map(|d| self.detector.step(d).to_vec())
134 .collect_vec();
135 map_changepoints(&run_lengths)
136 }
137}
138
139impl Default for DefaultArgpcpDetector {
140 fn default() -> Self {
141 DefaultArgpcpDetectorBuilder::default().build()
142 }
143}
144
/// Builder for [`DefaultArgpcpDetector`].
///
/// Start from `Default::default()` (or `DefaultArgpcpDetector::builder()`), override
/// any hyperparameters with the chained setters, then call `build`.
#[derive(Debug, Clone)]
pub struct DefaultArgpcpDetectorBuilder {
    // Field order is kept stable: it determines the derived `Debug` output.
    /// Value of the constant kernel that multiplies the RBF kernel.
    constant_value: f64,
    /// Length scale of the RBF kernel.
    length_scale: f64,
    /// Noise level of the white-noise kernel added to the product kernel.
    noise_level: f64,
    /// Maximum lag, converted to the type `Argpcp::new` expects when building
    /// (presumably the autoregressive order; confirm against the `changepoint` docs).
    max_lag: NonZeroUsize,
    /// `alpha0` hyperparameter forwarded to `Argpcp::new`.
    alpha0: f64,
    /// `beta0` hyperparameter forwarded to `Argpcp::new`.
    beta0: f64,
    /// Logistic hazard `h` parameter forwarded to `Argpcp::new`.
    logistic_hazard_h: f64,
    /// Logistic hazard `a` parameter forwarded to `Argpcp::new`.
    logistic_hazard_a: f64,
    /// Logistic hazard `b` parameter forwarded to `Argpcp::new`.
    logistic_hazard_b: f64,
}

impl DefaultArgpcpDetectorBuilder {
    /// Sets the value of the constant kernel that multiplies the RBF kernel.
    #[must_use]
    pub fn constant_value(mut self, cv: f64) -> Self {
        self.constant_value = cv;
        self
    }

    /// Sets the length scale of the RBF kernel.
    #[must_use]
    pub fn length_scale(mut self, ls: f64) -> Self {
        self.length_scale = ls;
        self
    }

    /// Sets the noise level of the white-noise kernel.
    #[must_use]
    pub fn noise_level(mut self, nl: f64) -> Self {
        self.noise_level = nl;
        self
    }

    /// Sets the maximum lag; `NonZeroUsize` rules out a lag of zero at the type level.
    #[must_use]
    pub fn max_lag(mut self, ml: NonZeroUsize) -> Self {
        self.max_lag = ml;
        self
    }

    /// Sets the `alpha0` hyperparameter.
    #[must_use]
    pub fn alpha0(mut self, alpha0: f64) -> Self {
        self.alpha0 = alpha0;
        self
    }

    /// Sets the `beta0` hyperparameter.
    #[must_use]
    pub fn beta0(mut self, beta0: f64) -> Self {
        self.beta0 = beta0;
        self
    }

    /// Sets the logistic hazard `h` parameter.
    #[must_use]
    pub fn logistic_hazard_h(mut self, h: f64) -> Self {
        self.logistic_hazard_h = h;
        self
    }

    /// Sets the logistic hazard `a` parameter.
    #[must_use]
    pub fn logistic_hazard_a(mut self, a: f64) -> Self {
        self.logistic_hazard_a = a;
        self
    }

    /// Sets the logistic hazard `b` parameter.
    #[must_use]
    pub fn logistic_hazard_b(mut self, b: f64) -> Self {
        self.logistic_hazard_b = b;
        self
    }
}

impl Default for DefaultArgpcpDetectorBuilder {
    /// Returns a builder holding the crate's default hyperparameters.
    fn default() -> Self {
        Self {
            constant_value: 0.5,
            length_scale: 10.0,
            noise_level: 0.01,
            max_lag: NonZeroUsize::new(3).expect("3 is non-zero"),
            alpha0: 2.0,
            beta0: 1.0,
            logistic_hazard_h: -5.0,
            logistic_hazard_a: 1.0,
            logistic_hazard_b: 1.0,
        }
    }
}
182
183impl DefaultArgpcpDetectorBuilder {
184 pub fn build(self) -> DefaultArgpcpDetector {
186 DefaultArgpcpDetector {
187 detector: Argpcp::new(
188 ConstantKernel::new_unchecked(self.constant_value)
189 * RBFKernel::new_unchecked(self.length_scale)
190 + WhiteKernel::new_unchecked(self.noise_level),
191 self.max_lag.into(),
192 self.alpha0,
193 self.beta0,
194 self.logistic_hazard_h,
195 self.logistic_hazard_a,
196 self.logistic_hazard_b,
197 ),
198 }
199 }
200}