hrv_algos/preprocessing/
outliers.rs

1use anyhow::{anyhow, Result};
2use nalgebra::{DVector, DVectorView};
3use rayon::iter::{IndexedParallelIterator, IntoParallelRefIterator, ParallelIterator};
4#[cfg(feature = "serde")]
5use serde::{Deserialize, Serialize};
6
7/// Trait for outlier detection classifiers.
8///
9/// Outlier classifiers are used to detect, classify and filter outliers in a time series.
10pub trait OutlierClassifier {
11    /// Adds a slice of data to the classifier for outlier detection.
12    ///
13    /// # Arguments
14    ///
15    /// * `data` - A slice of f64 values to add to the classifier.
16    ///
17    /// # Returns
18    ///
19    /// A `Result` indicating success or failure of data (re-) classification.
20    fn add_data(&mut self, data: &[f64]) -> Result<()>;
21
22    /// Access to the classified time-series data.
23    ///
24    /// # Returns
25    ///
26    /// A slice of f64 values representing the data that was classified.
27    fn get_data(&self) -> &[f64];
28
29    /// Access to the classification of the time-series data.
30    ///
31    /// # Returns
32    ///
33    /// A slice of `OutlierType` values representing the classification of the data.
34    /// The slice has the same length as the data slice.
35    fn get_classification(&self) -> &[OutlierType];
36
37    /// Returns the filtered data without any outliers.
38    ///
39    /// # Returns
40    ///
41    /// A vector of f64 values representing the data without any outliers.
42    /// The data is filtered based on the classification of the time-series data.
43    fn get_filtered_data(&self) -> Vec<f64> {
44        self.get_data()
45            .par_iter()
46            .zip(self.get_classification())
47            .filter_map(
48                |(&val, class)| {
49                    if class.is_outlier() {
50                        None
51                    } else {
52                        Some(val)
53                    }
54                },
55            )
56            .collect()
57    }
58}
59
/// Interface for outlier detection criteria used with moving window filters.
///
/// A criterion judges a single value against a window of samples handed in by the caller.
pub trait OutlierCriterion {
    /// Determines if a test value is acceptable based on a window of values.
    ///
    /// # Arguments
    ///
    /// * `testvalue` - The value to test for acceptance.
    /// * `window` - A slice of values to use for comparison.
    ///
    /// # Returns
    ///
    /// `true` if the test value is acceptable, `false` if it should be treated as an outlier.
    fn is_acceptable(&self, testvalue: f64, window: &[f64]) -> bool;
}
74
/// Trait for implementing interpolation strategies.
pub trait Interpolator {
    /// Interpolates a value based on a window of values.
    ///
    /// # Arguments
    ///
    /// * `window` - A slice of values to use for interpolation.
    /// * `idx` - The index with respect to the window of the value to interpolate.
    ///
    /// # Returns
    ///
    /// The interpolated value, or an error when the implementation cannot interpolate
    /// for the given `window` / `idx` (the exact conditions are up to the implementation).
    fn interpolate(&self, window: &[f64], idx: usize) -> Result<f64>;
}
85
/// Classification assigned to a single sample of a time series.
///
/// A sample is either not an outlier (`None`) or flagged as an ectopic, long, short,
/// missed, extra or otherwise unspecified beat.
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, PartialEq, Clone, Copy)]
pub enum OutlierType {
    /// No outlier detected
    None,
    /// Ectopic beat
    Ectopic,
    /// Long beat
    Long,
    /// Short beat
    Short,
    /// Missed beat
    Missed,
    /// Extra beat
    Extra,
    /// Unspecified outlier type
    Other,
}

impl OutlierType {
    /// Reports whether this classification marks the sample as an outlier.
    ///
    /// # Returns
    ///
    /// `false` for `OutlierType::None`, `true` for every other variant.
    pub fn is_outlier(&self) -> bool {
        // Exhaustive on purpose: a newly added variant must be classified here explicitly.
        match self {
            Self::None => false,
            Self::Ectopic
            | Self::Long
            | Self::Short
            | Self::Missed
            | Self::Extra
            | Self::Other => true,
        }
    }
}
113
114/// Enum representing different interpolation methods.
115pub enum InterpolationMethod {
116    /// No interpolation
117    None,
118    /// Linear interpolation
119    Linear,
120    /// Custom interpolation strategy
121    Custom(Box<dyn Interpolator>),
122}
123
124impl Interpolator for InterpolationMethod {
125    fn interpolate(&self, window: &[f64], idx: usize) -> Result<f64> {
126        match self {
127            InterpolationMethod::None => Err(anyhow!("No interpolation method specified")),
128            InterpolationMethod::Linear => LinearInterpolation.interpolate(window, idx),
129            InterpolationMethod::Custom(interpolator) => interpolator.interpolate(window, idx),
130        }
131    }
132}
133
134/// Criterion for outlier detection based on a ratio of the value to the mean of a window.
135///
136/// The criterion is acceptable if the value is within the mean +/- ratio * mean.
137///
138/// # Arguments
139///
140/// * `ratio` - The ratio to use for comparison.
141pub struct ValueRatioCriterion {
142    pub ratio: f64,
143}
144
145impl OutlierCriterion for ValueRatioCriterion {
146    fn is_acceptable(&self, testvalue: f64, window: &[f64]) -> bool {
147        let data = DVectorView::from(window);
148        let mean = data.mean();
149        let delta = (testvalue - mean).abs();
150        let limit = mean * self.ratio;
151        delta <= limit
152    }
153}
154
155/// Criterion for outlier detection based on a ratio of the value to the standard deviation of a window.
156///
157/// The criterion is acceptable if the value is within the mean +/- ratio * standard deviation.
158///
159/// # Arguments
160///
161/// * `ratio` - The ratio to use for comparison.
162pub struct StdDevCriterion {
163    pub ratio: f64,
164}
165
166impl OutlierCriterion for StdDevCriterion {
167    fn is_acceptable(&self, testvalue: f64, window: &[f64]) -> bool {
168        let data = DVectorView::from(window);
169        let mean = data.mean();
170        let std_dev = data.variance().sqrt();
171        testvalue >= mean - std_dev * self.ratio && testvalue <= mean + std_dev * self.ratio
172    }
173}
174
175/// Criterion for outlier detection based on symmetric limits around the mean of a window.
176///
177/// The criterion is acceptable if the value is within the mean +/- limit.
178///
179/// # Arguments
180///
181/// * `limit` - The limit to use for comparison.
182pub struct SymmetricLimitsCriterion {
183    pub limit: f64,
184}
185
186impl OutlierCriterion for SymmetricLimitsCriterion {
187    fn is_acceptable(&self, testvalue: f64, window: &[f64]) -> bool {
188        let data = DVectorView::from(window);
189        let mean = data.mean();
190        testvalue >= mean - self.limit && testvalue <= mean + self.limit
191    }
192}
193
194/// Criterion for outlier detection based on upper and lower limits.
195///
196/// The criterion is acceptable if the value is within the lower and upper limits.
197///
198/// # Arguments
199///
200/// * `lower` - The lower limit.
201/// * `upper` - The upper limit.
202pub struct LimitsCriterion {
203    pub lower: f64,
204    pub upper: f64,
205}
206
207impl OutlierCriterion for LimitsCriterion {
208    fn is_acceptable(&self, testvalue: f64, _window: &[f64]) -> bool {
209        testvalue >= self.lower && testvalue <= self.upper
210    }
211}
212
213/// Linear interpolation between neighbors
214///
215/// The interpolated value is based on the linear interpolation of the two neighbors.
216/// If the index is at the beginning or end of the window, the value is extrapolated.
217///
218/// # Examples
219///
220/// ```
221/// use hrv_algos::preprocessing::outliers::LinearInterpolation;
222/// use hrv_algos::preprocessing::outliers::Interpolator;
223/// let window = vec![1.0, 2.0, 5.0];
224/// let interpolator = LinearInterpolation;
225/// // the interpolated value is the average of the neighbors
226/// assert_eq!(interpolator.interpolate(&window, 1).unwrap(), 3.0);
227/// //rhe extrapolated value is 2.0 - (5.0 - 2.0) = -1.0
228/// assert_eq!(interpolator.interpolate(&window, 0).unwrap(), -1.0);
229/// // the extrapolated value is 2.0 + (2.0 - 1.0) = 3.0
230/// assert_eq!(interpolator.interpolate(&window, 2).unwrap(), 3.0);
231/// ```
232pub struct LinearInterpolation;
233
234impl Interpolator for LinearInterpolation {
235    fn interpolate(&self, window: &[f64], idx: usize) -> Result<f64> {
236        if window.len() < 3 {
237            return Err(anyhow!("Window size must be at least 3"));
238        }
239        if idx >= window.len() {
240            return Err(anyhow!("Index out of bounds"));
241        }
242
243        match idx {
244            0 => Ok(2.0 * window[1] - window[2]),
245            i if i == window.len() - 1 => {
246                Ok(2.0 * window[window.len() - 2] - window[window.len() - 3])
247            }
248            i => Ok((window[i - 1] + window[i + 1]) / 2.0),
249        }
250    }
251}
252
253pub struct MovingWindowFilter {
254    rr_intervals: Vec<f64>,
255    rr_classification: Vec<OutlierType>,
256    criterion: Box<dyn OutlierCriterion>,
257    window_size: usize,
258}
259
260impl MovingWindowFilter {
261    pub fn new(criterion: Box<dyn OutlierCriterion>, window_size: usize) -> Self {
262        Self {
263            rr_intervals: Vec::new(),
264            rr_classification: Vec::new(),
265            criterion,
266            window_size: window_size.max(1),
267        }
268    }
269
270    pub fn update_classification(&mut self) -> Result<()> {
271        if self.rr_intervals.len() < self.window_size {
272            return Err(anyhow::anyhow!(
273                "Window size must be less than the signal length."
274            ));
275        }
276        let half_window = self.window_size / 2;
277        let siglen = self.rr_intervals.len();
278        self.rr_classification = (0..self.rr_intervals.len())
279            .map(|idx| {
280                let (window, _window_idx) = if idx < half_window {
281                    (&self.rr_intervals[0..self.window_size], idx)
282                } else if idx >= siglen - half_window {
283                    (
284                        &self.rr_intervals[siglen - self.window_size..siglen],
285                        self.window_size - (siglen - idx),
286                    )
287                } else {
288                    (
289                        &self.rr_intervals[idx - half_window..idx + half_window],
290                        half_window,
291                    )
292                };
293                if self.criterion.is_acceptable(self.rr_intervals[idx], window) {
294                    OutlierType::None
295                } else {
296                    OutlierType::Other
297                }
298            })
299            .collect();
300        Ok(())
301    }
302}
303
304impl OutlierClassifier for MovingWindowFilter {
305    fn add_data(&mut self, data: &[f64]) -> Result<()> {
306        self.rr_intervals.extend_from_slice(data);
307        self.update_classification()
308    }
309    fn get_data(&self) -> &[f64] {
310        &self.rr_intervals
311    }
312    fn get_classification(&self) -> &[OutlierType] {
313        &self.rr_classification
314    }
315}
316
317/// Computes a rolling quantile over a 1D time series.
318///
319/// # Arguments
320///
321/// * `signal` - Slice of input data to process.
322/// * `window_size` - Size of the rolling window. Considers both sides of current index.
323/// * `quantile` - Desired quantile in [0.0, 1.0].
324///
325/// # Returns
326///
327/// A vector where each element is the quantile value in the local window around that index.
328/// Returns an error if `quantile` is not in the range 0.0..=1.0.
329///
330/// # Errors
331///
332/// If `quantile` is not between 0.0 and 1.0, returns an error.
333fn rolling_quantile(signal: &[f64], window_size: usize, quantile: f64) -> Result<Vec<f64>> {
334    if !(0.0..=1.0).contains(&quantile) {
335        return Err(anyhow!("Quantile must be between 0 and 1"));
336    }
337    // ensure window size handles even and odd window sizes
338    let back_window = window_size / 2;
339    let fwd_window = window_size - back_window;
340
341    signal
342        .par_iter()
343        .enumerate()
344        .map(|(idx, _)| {
345            let start = idx.saturating_sub(back_window);
346            let end = signal.len().min(idx + fwd_window);
347            let mut window = signal[start..end].to_vec();
348            window.sort_by(|a, b| a.partial_cmp(b).unwrap());
349            let quantile_idx = ((window.len() - 1) as f64 * quantile).round() as usize;
350            Ok(window[quantile_idx])
351        })
352        .collect()
353}
354
/// Detects artefacts (ectopic, long, short, missed, extra beats) in an RR interval series
/// using rolling quantiles and threshold-based classification.
///
/// The algorithm is based on the Systole Python package by Legrand and Allen (2022).
/// Link: https://github.com/embodied-computation-group/systole
///
/// # References
///
///  - Legrand, N. & Allen, M., (2022). Systole: A python package for cardiac signal synchrony and analysis. Journal of Open Source Software, 7(69), 3832, https://doi.org/10.21105/joss.03832
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[derive(Debug, Clone)]
pub struct MovingQuantileFilter {
    /// RR intervals added so far.
    rr_intervals: Vec<f64>,
    /// Classification of `rr_intervals`.
    rr_classification: Vec<OutlierType>,
    /// Slope of the decision boundary used for the ectopic beat test.
    slope: f64,
    /// Intercept of the decision boundary used for the ectopic beat test.
    intercept: f64,
    /// Scaling factor applied to the rolling quantile spread when computing thresholds.
    quantile_scale: f64,
    /// Window size of the rolling median of the RR intervals.
    median_window: usize,
    /// Window size of the rolling quantiles used for the thresholds.
    threshold_window: usize,
}
375
376impl MovingQuantileFilter {
377    /// Creates a new `MovingQuantileFilter` instance.
378    /// The default values for slope, intercept, and quantile scale are 0.13, 0.17, and 5.2, respectively.
379    /// The default window sizes for median and threshold calculations are 11 and 91, respectively.
380    ///
381    /// # Arguments
382    ///
383    /// * `slope` - The slope value for the threshold calculation. Default is 0.13.
384    /// * `intercept` - The intercept value for the threshold calculation. Default is 0.17.
385    /// * `quantile_scale` - The scaling factor for the threshold calculation. Default is 5.2.
386    ///
387    /// # Returns
388    ///
389    /// A new `MovingQuantileFilter` instance.
390    ///
391    pub fn new(slope: Option<f64>, intercept: Option<f64>, quantile_scale: Option<f64>) -> Self {
392        Self {
393            rr_intervals: Vec::new(),
394            rr_classification: Vec::new(),
395            slope: slope.unwrap_or(0.13),
396            intercept: intercept.unwrap_or(0.17),
397            quantile_scale: quantile_scale.unwrap_or(5.2),
398            median_window: 11,
399            threshold_window: 91,
400        }
401    }
402
403    /// Sets the slope value for the threshold comparison.
404    ///
405    /// # Arguments
406    ///
407    /// * `slope` - The slope value for the threshold calculation.
408    ///
409    /// # Returns
410    ///
411    /// A `Result` indicating success or failure of reclassification.
412    pub fn set_slope(&mut self, slope: f64) -> Result<()> {
413        self.slope = slope;
414        self.rr_classification.clear();
415        self.update_classification()
416    }
417
418    /// Sets the intercept value for the threshold comparison.
419    ///
420    /// # Arguments
421    ///
422    /// * `intercept` - The intercept value for the threshold calculation.
423    ///
424    /// # Returns
425    ///
426    /// A `Result` indicating success or failure of reclassification.
427    pub fn set_intercept(&mut self, intercept: f64) -> Result<()> {
428        self.intercept = intercept;
429        self.rr_classification.clear();
430        self.update_classification()
431    }
432
433    /// Sets the quantile scale value for the threshold comparison.
434    ///
435    /// # Arguments
436    ///
437    /// * `quantile_scale` - The scaling factor for the threshold calculation.
438    ///
439    /// # Returns
440    ///
441    /// A `Result` indicating success or failure of reclassification.
442    pub fn set_quantile_scale(&mut self, quantile_scale: f64) -> Result<()> {
443        self.quantile_scale = quantile_scale;
444        self.rr_classification.clear();
445        self.update_classification()
446    }
447
    /// Returns the slope currently used for the threshold comparison.
    pub fn get_slope(&self) -> f64 {
        self.slope
    }

    /// Returns the intercept currently used for the threshold comparison.
    pub fn get_intercept(&self) -> f64 {
        self.intercept
    }
    /// Returns the scaling factor currently used for the threshold calculation.
    pub fn get_quantile_scale(&self) -> f64 {
        self.quantile_scale
    }
458
459    fn update_classification(&mut self) -> Result<()> {
460        // classification uses a 91 item rolling quantile
461        // take 91 last rr, update the last 46 elements classification
462        let win_start = self
463            .rr_classification
464            .len()
465            .saturating_sub(self.threshold_window);
466        let cutoff = self
467            .rr_classification
468            .len()
469            .saturating_sub(self.threshold_window / 2);
470        let added_rr = self
471            .rr_intervals
472            .len()
473            .saturating_sub(self.rr_classification.len());
474        let data = &self.rr_intervals[win_start..];
475        if data.is_empty() {
476            return Ok(());
477        }
478        let new_class = if data.len() == 1 {
479            vec![OutlierType::None]
480        } else {
481            self.classify_rr_values(data)?
482        };
483
484        let mut added_classes = new_class[new_class.len().saturating_sub(added_rr)..].to_vec();
485        self.rr_classification.append(&mut added_classes);
486        let to_update = self.rr_classification.len().saturating_sub(cutoff);
487        let new_class_skip = new_class.len().saturating_sub(to_update);
488        //  update the last 46 elements classification
489        for (a, b) in self
490            .rr_classification
491            .iter_mut()
492            .skip(cutoff)
493            .zip(new_class.iter().skip(new_class_skip))
494        {
495            *a = *b;
496        }
497        Ok(())
498    }
499
500    /// Calculates the threshold for RR interval artefact detection.
501    ///
502    /// The threshold is calculated as the difference between the 75th and 25th percentiles
503    /// of the RR interval series.
504    ///
505    /// # Arguments
506    ///
507    /// * `signal` - Slice of RR intervals in milliseconds.
508    /// * `quantile_scale` - Scaling factor for threshold calculations.
509    fn calc_rr_threshold(&self, signal: &[f64]) -> Result<DVector<f64>> {
510        let first_quantile: DVector<f64> =
511            rolling_quantile(signal, self.threshold_window, 0.25)?.into();
512        let third_quantile: DVector<f64> =
513            rolling_quantile(signal, self.threshold_window, 0.75)?.into();
514        let threshold = (third_quantile - first_quantile) * (self.quantile_scale / 2.0);
515        Ok(threshold)
516    }
517
    /// Classifies RR intervals as ectopic, long, short, missed, or extra beats.
    ///
    /// The classification is based on the RR interval series and the calculated thresholds.
    /// When several categories apply to one sample, the first match in the order
    /// ectopic, missed, long, extra, short wins.
    ///
    /// # Arguments
    ///
    /// * `rr` - Slice of RR intervals in milliseconds.
    ///
    /// # Returns
    ///
    /// A vector of `OutlierType` values indicating the presence of outliers.
    /// The vector has one entry per element of `rr`.
    ///
    /// # Errors
    ///
    /// Returns an error if `rr` has fewer than 2 elements or if a threshold or rolling
    /// quantile computation fails.
    fn classify_rr_values(&self, rr: &[f64]) -> Result<Vec<OutlierType>> {
        if rr.len() < 2 {
            return Err(anyhow!("RR intervals must have at least 2 elements"));
        }

        // Successive differences of the RR intervals. The mean difference is inserted at
        // the front so that `drr` has the same length as `rr`.
        let drr = {
            let drr: DVector<f64> =
                DVector::from_iterator(rr.len() - 1, rr.windows(2).map(|w| w[1] - w[0]));
            let mean_drr = drr.mean();
            drr.insert_row(0, mean_drr)
        };

        // Rolling quantile calculations (q1 and q3)
        let first_threshold = self.calc_rr_threshold(drr.as_slice())?;
        // Calculate median RR (mRR)
        let med_rr = DVector::from(rolling_quantile(rr, self.median_window, 0.5)?);
        // Deviation of every interval from the rolling median; negative deviations are doubled.
        let med_rr_deviation: DVector<f64> = DVector::from_iterator(
            rr.len(),
            rr.iter().zip(med_rr.iter()).map(|(&rr, &med)| {
                let val = rr - med;
                if val < 0.0 {
                    val * 2.0
                } else {
                    val
                }
            }),
        );

        // Calculate second threshold (th2)
        let second_threshold = self.calc_rr_threshold(med_rr_deviation.as_slice())?;

        // Median deviation expressed in units of the second threshold.
        let normalized_mrr: DVector<f64> = med_rr_deviation.component_div(&second_threshold);

        // Decision
        let mean_rr = rr.iter().copied().sum::<f64>() / rr.len() as f64;
        let result = rr
            .par_iter()
            .enumerate()
            .map(|(idx, &rr_val)| {
                let drr_val = drr[idx];
                let nmrr = normalized_mrr[idx];
                let th1_val = first_threshold[idx];
                let th2_val = second_threshold[idx];

                // Difference to the previous interval in units of the first threshold.
                let s11 = drr_val / th1_val;

                // NOTE(review): s12 and s22 are built from the raw differences in `drr`,
                // while s11 is divided by th1. The reference algorithm appears to use
                // normalized differences for all three - confirm against Systole.
                // Neighboring differences around the sample; 0.0 at both ends of the series.
                let s12 = if idx == 0 || idx == rr.len() - 1 {
                    0.0
                } else {
                    let ma = drr[idx - 1].max(drr[idx + 1]);
                    let mi = drr[idx - 1].min(drr[idx + 1]);
                    if drr_val < 0.0 {
                        mi
                    } else {
                        ma
                    }
                };

                // The two following differences; 0.0 for the last two samples.
                let s22 = if idx >= rr.len() - 2 {
                    0.0
                } else {
                    let ma = drr[idx + 1].max(drr[idx + 2]);
                    let mi = drr[idx + 1].min(drr[idx + 2]);
                    if drr_val >= 0.0 {
                        mi
                    } else {
                        ma
                    }
                };

                let ectopic = (s11 > 1.0 && s12 < (-self.slope * s11 - self.intercept))
                    || (s11 < -1.0 && s12 > (-self.slope * s11 + self.intercept));

                if ectopic {
                    return OutlierType::Ectopic;
                }
                // `ectopic` is always false from here on because of the early return above.
                let long = ((s11 > 1.0 && s22 < -1.0) || (nmrr.abs() > 3.0 && rr_val > mean_rr))
                    && !ectopic;

                let short = ((s11 < -1.0 && s22 > 1.0) || (nmrr.abs() > 3.0 && rr_val <= mean_rr))
                    && !ectopic;

                // A long interval counts as a missed beat if half of it is close to the median.
                let missed = long && ((rr_val / 2.0 - med_rr[idx]).abs() < th2_val);
                if missed {
                    return OutlierType::Missed;
                }
                if long {
                    return OutlierType::Long;
                }
                // A short interval counts as an extra beat if it and the next interval
                // together are close to the median. For the last sample the next interval
                // is taken as 0.0.
                let extra = short
                    && ((rr_val + rr.get(idx + 1).unwrap_or(&0.0) - med_rr[idx]).abs() < th2_val);
                if extra {
                    return OutlierType::Extra;
                }
                if short {
                    return OutlierType::Short;
                }
                OutlierType::None
            })
            .collect();

        Ok(result)
    }
632}
633
634impl OutlierClassifier for MovingQuantileFilter {
635    fn add_data(&mut self, data: &[f64]) -> Result<()> {
636        self.rr_intervals.extend_from_slice(data);
637        self.update_classification()
638    }
639
640    fn get_data(&self) -> &[f64] {
641        &self.rr_intervals
642    }
643
644    fn get_classification(&self) -> &[OutlierType] {
645        &self.rr_classification
646    }
647}
648
649#[cfg(test)]
650mod tests {
651    use rand::{Rng, SeedableRng};
652
653    use super::*;
654
    /// Generates a pseudorandom signal of the given length.
    /// The signal is generated from a fixed seed so that it is reproducible.
    /// Every sample lies in the half-open range `[990.0, 1010.0)`.
658    ///
659    /// # Arguments
660    ///
661    /// * `len` - The length of the signal to generate.
662    ///
663    /// # Returns
664    ///
665    /// A vector of f64 values representing the generated signal.
666    ///
667    /// # Examples
668    ///
669    /// ```rust
670    /// use hrv_algos::preprocessing::outliers::get_signal;
671    /// let signal = get_signal(100);
672    /// assert_eq!(signal.len(), 100);
673    /// assert!(signal.iter().all(|&val| val > 990.0 && val < 1010.0));
674    /// ```
675    fn get_signal(len: usize) -> Vec<f64> {
676        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
677        (0..len)
678            .map(|_| 1000.0 + rng.gen_range(-10.0..10.0))
679            .collect()
680    }
681
682    #[test]
683    fn test_interpolation_method_enum() {
684        let window = vec![1.0, 2.0, 3.0];
685        let interpolator = InterpolationMethod::None;
686        assert!(interpolator.interpolate(&window, 1).is_err());
687        let interpolator = InterpolationMethod::Linear;
688        assert_eq!(interpolator.interpolate(&window, 1).unwrap(), 2.0);
689        let interpolator = InterpolationMethod::Custom(Box::new(LinearInterpolation));
690        assert_eq!(interpolator.interpolate(&window, 1).unwrap(), 2.0);
691    }
692
693    #[test]
694    fn test_symmetric_limits_acceptable() {
695        let criterion = SymmetricLimitsCriterion { limit: 0.2 };
696        let window = vec![1.0, 1.1, 0.9, 1.1, 0.9];
697        assert!(criterion.is_acceptable(1.15, &window));
698        assert!(!criterion.is_acceptable(1.3, &window));
699    }
700
701    #[test]
702    fn test_limits_acceptable() {
703        let criterion = LimitsCriterion {
704            lower: 0.8,
705            upper: 1.2,
706        };
707        let window = vec![1.0, 1.1, 0.9, 1.1, 0.9];
708        assert!(criterion.is_acceptable(1.1, &window));
709        assert!(!criterion.is_acceptable(1.3, &window));
710    }
711    #[test]
712    fn test_value_ratio_acceptable() {
713        let criterion = ValueRatioCriterion { ratio: 0.01 };
714        let window = get_signal(5);
715        let mean = window.iter().sum::<f64>() / window.len() as f64;
716        let thr = mean * 0.01;
717        let eps = f64::EPSILON * 1e3;
718        assert!(criterion.is_acceptable(mean + thr - eps, &window));
719        assert!(criterion.is_acceptable(mean - thr + eps, &window));
720        assert!(!criterion.is_acceptable(mean + thr + eps, &window));
721        assert!(!criterion.is_acceptable(mean - thr - eps, &window));
722    }
723
724    #[test]
725    fn test_std_dev_acceptable() {
726        let criterion = StdDevCriterion { ratio: 2.0 };
727        let window = get_signal(5);
728        let data = DVectorView::from(window.as_slice());
729        let mean = data.mean();
730        let std_dev = data.variance().sqrt();
731        let eps = f64::EPSILON * 1e3;
732        assert!(criterion.is_acceptable(mean + 2.0 * std_dev - eps, &window));
733        assert!(criterion.is_acceptable(mean - 2.0 * std_dev + eps, &window));
734        assert!(!criterion.is_acceptable(mean - 2.0 * std_dev - eps, &window));
735        assert!(!criterion.is_acceptable(mean + 2.0 * std_dev + eps, &window));
736    }
737    #[test]
738    fn test_moving_window_filter_with_symmetric_limits() {
739        let signal = get_signal(16);
740        let criterion = SymmetricLimitsCriterion { limit: 10. };
741        let mut filter = MovingWindowFilter::new(Box::new(criterion), 3);
742        assert!(filter.add_data(&signal).is_ok());
743        let classes = filter.get_classification();
744        assert!(classes.iter().all(|&outlier| !outlier.is_outlier()));
745    }
746
747    #[test]
748    fn test_moving_window_filter_with_limits() {
749        let signal = get_signal(16);
750        let criterion = LimitsCriterion {
751            lower: 990.0,
752            upper: 1010.0,
753        };
754        let mut filter = MovingWindowFilter::new(Box::new(criterion), 3);
755        assert!(filter.add_data(&signal).is_ok());
756        let classes = filter.get_classification();
757        assert!(classes.iter().all(|&outlier| !outlier.is_outlier()));
758    }
759
760    #[test]
761    fn test_moving_window_filter_with_outliers() {
762        let mut signal = get_signal(16);
763        let criterion = SymmetricLimitsCriterion { limit: 15. };
764        signal[3] = 980.0;
765        signal[4] = 1020.0;
766        let mut filter = MovingWindowFilter::new(Box::new(criterion), 5);
767        assert!(filter.add_data(&signal).is_ok());
768        let classes = filter.get_classification();
769        assert!(classes.iter().enumerate().all(|(idx, &outlier)| {
770            if idx == 3 || idx == 4 {
771                outlier.is_outlier()
772            } else {
773                !outlier.is_outlier()
774            }
775        }));
776    }
777
778    #[test]
779    fn test_rr_outliers() {
780        let mut signal = get_signal(128);
781        signal[50] = 1100.0;
782        let mut filter = MovingQuantileFilter::new(None, None, None);
783        assert!(filter.add_data(&signal).is_ok());
784        let classes = filter.get_classification();
785        assert!(classes.iter().enumerate().all(|(idx, &outlier)| {
786            if idx == 50 {
787                outlier.is_outlier()
788            } else {
789                !outlier.is_outlier()
790            }
791        }));
792    }
793
794    #[test]
795    fn test_moving_window_filter_even() {
796        let signal = get_signal(16);
797        let criterion = ValueRatioCriterion { ratio: 0.2 };
798        let mut filter = MovingWindowFilter::new(Box::new(criterion), 4);
799        assert!(filter.add_data(&signal).is_ok());
800        let classes = filter.get_classification();
801        assert_eq!(filter.get_data().len(), classes.len());
802        assert_eq!(signal.len(), classes.len());
803        assert!(classes.iter().all(|&outlier| !outlier.is_outlier()));
804    }
805
806    #[test]
807    fn test_moving_window_filter_with_std_dev() {
808        let signal = get_signal(16);
809        let criterion = StdDevCriterion { ratio: 2.0 };
810        let mut filter = MovingWindowFilter::new(Box::new(criterion), 3);
811        assert!(filter.add_data(&signal).is_ok());
812        let classes = filter.get_classification();
813        assert_eq!(filter.get_data().len(), classes.len());
814        assert_eq!(signal.len(), classes.len());
815        assert!(classes.iter().all(|&outlier| !outlier.is_outlier()));
816    }
817
818    #[test]
819    fn test_moving_window_filter_with_small_window() {
820        let signal = get_signal(3);
821        let criterion = ValueRatioCriterion { ratio: 0.2 };
822        let mut filter = MovingWindowFilter::new(Box::new(criterion), 5);
823        assert!(filter.add_data(&signal).is_err());
824    }
825    #[test]
826    fn test_linear_interpolation() {
827        let window = vec![1.0, 2.0, 3.0];
828        let interpolator = LinearInterpolation;
829        assert_eq!(interpolator.interpolate(&window, 1).unwrap(), 2.0);
830        assert_eq!(interpolator.interpolate(&window, 0).unwrap(), 1.0);
831        assert_eq!(interpolator.interpolate(&window, 2).unwrap(), 3.0);
832    }
833
834    #[test]
835    fn test_linear_interpolation_out_of_bounds() {
836        let window = get_signal(4);
837        let interpolation = LinearInterpolation;
838        assert!(interpolation.interpolate(&window, 5).is_err());
839    }
840
841    #[test]
842    fn test_linear_interpolation_window_too_small() {
843        let window = get_signal(2);
844        let interpolation = LinearInterpolation;
845        assert!(interpolation.interpolate(&window, 0).is_err());
846    }
847    #[test]
848    fn test_moving_quantile_filter() {
849        let signal = get_signal(128);
850        let mut filter = MovingQuantileFilter::new(None, None, None);
851        assert!(filter.add_data(&signal).is_ok());
852        let classification = filter.get_classification();
853        assert_eq!(classification.len(), signal.len());
854        assert_eq!(filter.get_data().len(), signal.len());
855        assert!(classification
856            .iter()
857            .all(|&outlier| { !outlier.is_outlier() }));
858    }
859
860    #[test]
861    fn test_moving_quantile_filter_set_intercept_inf() {
862        let signal = get_signal(128);
863        let mut filter = MovingQuantileFilter::new(None, None, None);
864        assert!(filter.add_data(&signal).is_ok());
865        assert_eq!(filter.get_classification().len(), filter.get_data().len());
866        assert!(filter
867            .get_classification()
868            .iter()
869            .all(|&outlier| { !outlier.is_outlier() }));
870        assert_eq!(filter.get_intercept(), 0.17);
871        assert!(filter.set_intercept(0.0).is_ok());
872        assert_eq!(filter.get_intercept(), 0.0);
873        assert_eq!(filter.get_classification().len(), filter.get_data().len());
874        assert_eq!(filter.get_slope(), 0.13);
875        assert!(filter.set_slope(0.0).is_ok());
876        assert_eq!(filter.get_slope(), 0.0);
877        assert_eq!(filter.get_classification().len(), filter.get_data().len());
878
879        assert_eq!(filter.get_quantile_scale(), 5.2);
880        assert!(filter.set_quantile_scale(0.0).is_ok());
881        assert_eq!(filter.get_quantile_scale(), 0.0);
882        assert_eq!(filter.get_classification().len(), filter.get_data().len());
883
884        assert!(filter
885            .get_classification()
886            .iter()
887            .take(filter.get_classification().len() - 1)
888            .any(|&outlier| { outlier.is_outlier() }));
889    }
890
891    #[test]
892    fn test_moving_quantile_filter_add_data() {
893        let signal: Vec<f64> = get_signal(1010);
894        let first_signal = &signal[0..1000];
895        let mut second_signal = signal[1000..].to_vec();
896        second_signal[5] = 1500.0;
897        let mut filter = MovingQuantileFilter::new(None, None, None);
898        assert!(filter.add_data(first_signal).is_ok());
899        assert_eq!(filter.get_classification().len(), filter.get_data().len());
900        assert!(filter
901            .get_classification()
902            .iter()
903            .all(|&outlier| { !outlier.is_outlier() }));
904        assert!(filter.add_data(&second_signal).is_ok());
905        assert_eq!(filter.get_classification().len(), filter.get_data().len());
906        assert!(filter
907            .get_classification()
908            .iter()
909            .take(filter.get_classification().len() - 56)
910            .all(|&outlier| { !outlier.is_outlier() }));
911        assert!(filter.get_classification()[1005].is_outlier());
912    }
913    #[test]
914    fn test_moving_quantile_filter_add_empty_data() {
915        let mut filter = MovingQuantileFilter::new(None, None, None);
916        assert!(filter.add_data(&[]).is_ok());
917        assert_eq!(filter.get_classification().len(), filter.get_data().len());
918        assert_eq!(filter.get_classification().len(), 0);
919    }
920    #[test]
921    fn test_moving_quantile_filter_add_single_data() {
922        let signal = get_signal(1);
923        let mut filter = MovingQuantileFilter::new(None, None, None);
924        assert!(filter.add_data(&signal).is_ok());
925        assert_eq!(filter.get_classification().len(), filter.get_data().len());
926        filter.get_classification().iter().for_each(|&outlier| {
927            assert!(!outlier.is_outlier());
928        });
929        assert!(filter.add_data(&[1000.0]).is_ok());
930        assert_eq!(filter.get_classification().len(), filter.get_data().len());
931        assert_eq!(filter.get_classification().len(), 2);
932        filter.get_classification().iter().for_each(|&outlier| {
933            assert!(!outlier.is_outlier());
934        });
935    }
936    #[test]
937    fn test_moving_quantile_filter_add_data_missed() {
938        let signal = get_signal(1010);
939        let mut new_data: Vec<f64> = signal[1000..].to_vec();
940        new_data[5] = 2000.0;
941        let signal = &signal[0..1000];
942        let mut filter = MovingQuantileFilter::new(None, None, None);
943        assert!(filter.add_data(signal).is_ok());
944        assert_eq!(filter.get_classification().len(), filter.get_data().len());
945        assert!(filter
946            .get_classification()
947            .iter()
948            .all(|&outlier| { !outlier.is_outlier() }));
949        assert!(filter.add_data(&new_data).is_ok());
950        assert_eq!(filter.get_classification().len(), filter.get_data().len());
951        assert!(filter
952            .get_classification()
953            .iter()
954            .take(signal.len())
955            .all(|&outlier| { !outlier.is_outlier() }));
956        assert!(filter.get_classification()[1005].is_outlier());
957        assert!(matches!(
958            filter.get_classification()[1005],
959            OutlierType::Missed
960        ));
961        let filtered = filter.get_filtered_data();
962        assert!(filtered.iter().all(|&val| (val - 1000.0).abs() <= 20.0));
963    }
964
965    #[test]
966    fn test_moving_quantile_filter_add_data_extra() {
967        let signal = get_signal(1010);
968        let mut new_data: Vec<f64> = signal[1000..].to_vec();
969        new_data[5] = 20.0;
970        let signal = &signal[0..1000];
971        let mut filter = MovingQuantileFilter::new(None, None, None);
972        assert!(filter.add_data(signal).is_ok());
973        assert_eq!(filter.get_classification().len(), filter.get_data().len());
974        assert!(filter
975            .get_classification()
976            .iter()
977            .all(|&outlier| { !outlier.is_outlier() }));
978        assert!(filter.add_data(&new_data).is_ok());
979        assert_eq!(filter.get_classification().len(), filter.get_data().len());
980        assert!(filter
981            .get_classification()
982            .iter()
983            .take(filter.get_classification().len() - 56)
984            .all(|&outlier| { !outlier.is_outlier() }));
985        assert!(filter.get_classification()[1005].is_outlier());
986        assert!(matches!(
987            filter.get_classification()[1005],
988            OutlierType::Extra
989        ));
990        let filtered = filter.get_filtered_data();
991        assert!(filtered.iter().all(|&val| (val - 1000.0).abs() <= 20.0));
992    }
993}