Skip to main content

oxigdal_temporal/change/
breakpoint.rs

//! Breakpoint Detection Module
//!
//! Implements breakpoint detection algorithms for identifying structural breaks
//! in per-pixel time series data: binary segmentation, CUSUM, and threshold
//! crossing. PELT is currently approximated by binary segmentation.
5
6use crate::error::{Result, TemporalError};
7use crate::timeseries::TimeSeriesRaster;
8use serde::{Deserialize, Serialize};
9use tracing::info;
10
11/// Breakpoint detection method
12#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
13pub enum BreakpointMethod {
14    /// Binary segmentation
15    BinarySegmentation,
16    /// PELT (Pruned Exact Linear Time)
17    PELT,
18    /// CUSUM-based breakpoint
19    CUSUM,
20    /// Simple threshold crossing
21    ThresholdCrossing,
22}
23
24/// Breakpoint detection result
25#[derive(Debug, Clone)]
26pub struct BreakpointResult {
27    /// Breakpoint locations (time indices)
28    pub breakpoints: Vec<usize>,
29    /// Breakpoint scores/confidence
30    pub scores: Vec<f64>,
31    /// Segments between breakpoints
32    pub segments: Vec<Segment>,
33}
34
/// A contiguous run of observations between two breakpoints.
#[derive(Debug, Clone)]
pub struct Segment {
    /// Index of the first observation in the segment.
    pub start: usize,
    /// Index one past the last observation in the segment (exclusive bound).
    pub end: usize,
    /// Mean of the observations in `start..end`.
    pub mean: f64,
    /// Variance of the observations in `start..end`.
    pub variance: f64,
}
47
48impl BreakpointResult {
49    /// Create new breakpoint result
50    #[must_use]
51    pub fn new(breakpoints: Vec<usize>, scores: Vec<f64>) -> Self {
52        Self {
53            breakpoints,
54            scores,
55            segments: Vec::new(),
56        }
57    }
58
59    /// Add segments
60    #[must_use]
61    pub fn with_segments(mut self, segments: Vec<Segment>) -> Self {
62        self.segments = segments;
63        self
64    }
65}
66
/// Stateless entry point for the breakpoint detection algorithms.
///
/// All functionality is exposed through associated functions, so the type
/// carries no data. The standard derives are provided for consistency with
/// the other public types in this module.
#[derive(Debug, Default, Clone, Copy)]
pub struct BreakpointDetector;
69
70impl BreakpointDetector {
71    /// Detect breakpoints in time series
72    ///
73    /// # Errors
74    /// Returns error if detection fails
75    pub fn detect(
76        ts: &TimeSeriesRaster,
77        method: BreakpointMethod,
78        params: BreakpointParams,
79    ) -> Result<Vec<BreakpointResult>> {
80        match method {
81            BreakpointMethod::BinarySegmentation => {
82                Self::binary_segmentation(ts, params.max_breakpoints, params.min_segment_length)
83            }
84            BreakpointMethod::PELT => Self::pelt(ts, params.penalty),
85            BreakpointMethod::CUSUM => Self::cusum_breakpoint(ts, params.threshold),
86            BreakpointMethod::ThresholdCrossing => Self::threshold_crossing(ts, params.threshold),
87        }
88    }
89
90    /// Binary segmentation for breakpoint detection
91    fn binary_segmentation(
92        ts: &TimeSeriesRaster,
93        max_breakpoints: usize,
94        min_segment_length: usize,
95    ) -> Result<Vec<BreakpointResult>> {
96        if ts.len() < min_segment_length * 2 {
97            return Err(TemporalError::insufficient_data(format!(
98                "Need at least {} observations",
99                min_segment_length * 2
100            )));
101        }
102
103        let (height, width, n_bands) = ts
104            .expected_shape()
105            .ok_or_else(|| TemporalError::insufficient_data("No shape information"))?;
106
107        let mut results = Vec::new();
108
109        for i in 0..height {
110            for j in 0..width {
111                for k in 0..n_bands {
112                    let values = ts.extract_pixel_timeseries(i, j, k)?;
113
114                    let mut breakpoints = Vec::new();
115                    let mut scores = Vec::new();
116                    let mut segments = vec![(0, values.len())];
117
118                    for _ in 0..max_breakpoints {
119                        let mut best_breakpoint = None;
120                        let mut best_score = f64::NEG_INFINITY;
121
122                        for &(start, end) in &segments {
123                            if end - start < min_segment_length * 2 {
124                                continue;
125                            }
126
127                            let segment = &values[start..end];
128                            if let Some((bp, score)) =
129                                Self::find_best_split(segment, min_segment_length)
130                            {
131                                let abs_bp = start + bp;
132                                if score > best_score {
133                                    best_score = score;
134                                    best_breakpoint = Some((abs_bp, start, end));
135                                }
136                            }
137                        }
138
139                        if let Some((bp, seg_start, seg_end)) = best_breakpoint {
140                            breakpoints.push(bp);
141                            scores.push(best_score);
142
143                            // Update segments
144                            segments.retain(|&(s, e)| s != seg_start || e != seg_end);
145                            segments.push((seg_start, bp));
146                            segments.push((bp, seg_end));
147                        } else {
148                            break;
149                        }
150                    }
151
152                    // Build segments with statistics
153                    segments.sort_by_key(|&(s, _)| s);
154                    let segment_stats: Vec<Segment> = segments
155                        .iter()
156                        .map(|&(start, end)| {
157                            let seg_values = &values[start..end];
158                            let mean = seg_values.iter().sum::<f64>() / seg_values.len() as f64;
159                            let variance =
160                                seg_values.iter().map(|v| (v - mean).powi(2)).sum::<f64>()
161                                    / seg_values.len() as f64;
162
163                            Segment {
164                                start,
165                                end,
166                                mean,
167                                variance,
168                            }
169                        })
170                        .collect();
171
172                    if !breakpoints.is_empty() {
173                        results.push(
174                            BreakpointResult::new(breakpoints, scores).with_segments(segment_stats),
175                        );
176                    }
177                }
178            }
179        }
180
181        info!("Completed binary segmentation breakpoint detection");
182        Ok(results)
183    }
184
185    /// Find best split point in a segment
186    fn find_best_split(segment: &[f64], min_len: usize) -> Option<(usize, f64)> {
187        if segment.len() < min_len * 2 {
188            return None;
189        }
190
191        let mut best_split = None;
192        let mut best_score = f64::NEG_INFINITY;
193
194        for i in min_len..(segment.len() - min_len) {
195            let left = &segment[..i];
196            let right = &segment[i..];
197
198            let score = Self::calculate_split_score(left, right);
199
200            if score > best_score {
201                best_score = score;
202                best_split = Some(i);
203            }
204        }
205
206        best_split.map(|split| (split, best_score))
207    }
208
209    /// Calculate split quality score
210    fn calculate_split_score(left: &[f64], right: &[f64]) -> f64 {
211        let left_mean = left.iter().sum::<f64>() / left.len() as f64;
212        let right_mean = right.iter().sum::<f64>() / right.len() as f64;
213
214        let left_var =
215            left.iter().map(|v| (v - left_mean).powi(2)).sum::<f64>() / left.len() as f64;
216        let right_var =
217            right.iter().map(|v| (v - right_mean).powi(2)).sum::<f64>() / right.len() as f64;
218
219        // Score based on mean difference and within-segment variance
220        let mean_diff = (right_mean - left_mean).abs();
221        let avg_var = (left_var + right_var) / 2.0;
222
223        if avg_var > 0.0 {
224            mean_diff / avg_var.sqrt()
225        } else {
226            mean_diff
227        }
228    }
229
230    /// PELT algorithm for optimal breakpoint detection
231    fn pelt(ts: &TimeSeriesRaster, _penalty: f64) -> Result<Vec<BreakpointResult>> {
232        // Simplified PELT implementation
233        // Full PELT is complex - use binary segmentation as approximation
234        // Note: penalty parameter reserved for future full PELT implementation
235        info!("Using binary segmentation approximation for PELT");
236        Self::binary_segmentation(ts, 10, 3)
237    }
238
239    /// CUSUM-based breakpoint detection
240    fn cusum_breakpoint(ts: &TimeSeriesRaster, threshold: f64) -> Result<Vec<BreakpointResult>> {
241        if ts.len() < 3 {
242            return Err(TemporalError::insufficient_data(
243                "Need at least 3 observations",
244            ));
245        }
246
247        let (height, width, n_bands) = ts
248            .expected_shape()
249            .ok_or_else(|| TemporalError::insufficient_data("No shape information"))?;
250
251        let mut results = Vec::new();
252
253        for i in 0..height {
254            for j in 0..width {
255                for k in 0..n_bands {
256                    let values = ts.extract_pixel_timeseries(i, j, k)?;
257                    let mean = values.iter().sum::<f64>() / values.len() as f64;
258
259                    let mut cusum = 0.0;
260                    let mut breakpoints = Vec::new();
261                    let mut scores = Vec::new();
262
263                    for (idx, &value) in values.iter().enumerate() {
264                        cusum += value - mean;
265
266                        if cusum.abs() > threshold {
267                            breakpoints.push(idx);
268                            scores.push(cusum.abs());
269                            cusum = 0.0; // Reset CUSUM
270                        }
271                    }
272
273                    if !breakpoints.is_empty() {
274                        results.push(BreakpointResult::new(breakpoints, scores));
275                    }
276                }
277            }
278        }
279
280        info!("Completed CUSUM breakpoint detection");
281        Ok(results)
282    }
283
284    /// Threshold crossing breakpoint detection
285    fn threshold_crossing(ts: &TimeSeriesRaster, threshold: f64) -> Result<Vec<BreakpointResult>> {
286        if ts.len() < 2 {
287            return Err(TemporalError::insufficient_data(
288                "Need at least 2 observations",
289            ));
290        }
291
292        let (height, width, n_bands) = ts
293            .expected_shape()
294            .ok_or_else(|| TemporalError::insufficient_data("No shape information"))?;
295
296        let mut results = Vec::new();
297
298        for i in 0..height {
299            for j in 0..width {
300                for k in 0..n_bands {
301                    let values = ts.extract_pixel_timeseries(i, j, k)?;
302
303                    let mut breakpoints = Vec::new();
304                    let mut scores = Vec::new();
305
306                    for idx in 1..values.len() {
307                        let diff = (values[idx] - values[idx - 1]).abs();
308                        if diff > threshold {
309                            breakpoints.push(idx);
310                            scores.push(diff);
311                        }
312                    }
313
314                    if !breakpoints.is_empty() {
315                        results.push(BreakpointResult::new(breakpoints, scores));
316                    }
317                }
318            }
319        }
320
321        info!("Completed threshold crossing breakpoint detection");
322        Ok(results)
323    }
324}
325
/// Tuning knobs for [`BreakpointDetector::detect`].
///
/// Each detection method reads only the fields relevant to it.
#[derive(Debug, Clone, Copy)]
pub struct BreakpointParams {
    /// Upper bound on the number of breakpoints reported per series
    /// (binary segmentation).
    pub max_breakpoints: usize,
    /// Minimum segment length (binary segmentation).
    pub min_segment_length: usize,
    /// Penalty handed to the PELT method.
    pub penalty: f64,
    /// Threshold used by the CUSUM and threshold-crossing methods.
    pub threshold: f64,
}

impl Default for BreakpointParams {
    /// At most 5 breakpoints, segments of at least 3 observations, and a
    /// penalty and threshold of 1.0.
    fn default() -> Self {
        BreakpointParams {
            max_breakpoints: 5,
            min_segment_length: 3,
            penalty: 1.0,
            threshold: 1.0,
        }
    }
}
349
#[cfg(test)]
mod tests {
    use super::*;
    use crate::timeseries::{TemporalMetadata, TimeSeriesRaster};
    use chrono::{DateTime, NaiveDate};
    use scirs2_core::ndarray::Array3;

    /// Builds a 1x1x1 raster time series with one observation per day,
    /// starting on 2022-01-01, holding `values` in order.
    ///
    /// `values` must stay within January (at most 31 entries).
    fn single_pixel_series(values: &[f64]) -> TimeSeriesRaster {
        let mut ts = TimeSeriesRaster::new();

        for (offset, &value) in (0_i64..).zip(values) {
            let dt = DateTime::from_timestamp(1640995200 + offset * 86400, 0).expect("valid");
            let day = u32::try_from(offset).expect("day offset fits in u32");
            let date = NaiveDate::from_ymd_opt(2022, 1, 1 + day).expect("valid");
            let metadata = TemporalMetadata::new(dt, date);

            let data = Array3::from_elem((1, 1, 1), value);
            ts.add_raster(metadata, data).expect("should add");
        }

        ts
    }

    #[test]
    fn test_binary_segmentation() {
        // Ten observations at 10.0 followed by ten at 50.0: step at index 10.
        let values: Vec<f64> = (0..20)
            .map(|i| if i < 10 { 10.0 } else { 50.0 })
            .collect();
        let ts = single_pixel_series(&values);

        let params = BreakpointParams::default();
        let results = BreakpointDetector::detect(&ts, BreakpointMethod::BinarySegmentation, params)
            .expect("should detect");

        // One pixel, one band: exactly one result.
        assert_eq!(results.len(), 1);
        let result = &results[0];

        // The strongest (first) breakpoint is the step itself; both halves are
        // constant, so its score is the raw mean difference.
        assert_eq!(result.breakpoints.first(), Some(&10));
        assert!((result.scores[0] - 40.0).abs() < 1e-9);

        // Structural invariants of the result.
        assert_eq!(result.breakpoints.len(), result.scores.len());
        assert_eq!(result.segments.len(), result.breakpoints.len() + 1);
        assert_eq!(result.segments.first().map(|s| s.start), Some(0));
        assert_eq!(result.segments.last().map(|s| s.end), Some(20));
    }

    #[test]
    fn test_threshold_crossing() {
        // A single spike at index 5 produces a jump up and a jump back down.
        let values: Vec<f64> = (0..10)
            .map(|i| if i == 5 { 100.0 } else { 10.0 })
            .collect();
        let ts = single_pixel_series(&values);

        let params = BreakpointParams {
            threshold: 20.0,
            ..Default::default()
        };

        let results = BreakpointDetector::detect(&ts, BreakpointMethod::ThresholdCrossing, params)
            .expect("should detect");

        assert_eq!(results.len(), 1);
        let result = &results[0];
        assert_eq!(result.breakpoints, vec![5, 6]);
        assert_eq!(result.scores, vec![90.0, 90.0]);
        assert!(result.segments.is_empty());
    }
}