Skip to main content

auto_regressive/
lib.rs

1mod auto_correlation;
2mod yule_walker;
3
/// An autoregressive (AR) model fitted to a signal with the Yule-Walker equations.
///
/// Instances are created with [`AutoRegressiveModel::new_with_order`].
#[derive(Debug, Clone, PartialEq)]
pub struct AutoRegressiveModel {
    /// AR coefficients; `coefficients[0]` weights the most recent sample.
    coefficients: Vec<f64>,
    /// Residual variance reported by the Yule-Walker solver, divided by the signal length.
    noise_variance: f64,
    /// Residual variance reported by the solver divided by the zero-lag auto-correlation.
    noise_ratio: f64,
}
10
11
12impl AutoRegressiveModel {
13    pub fn new_with_order(signal: &[f64], order: usize) -> Self {
14        assert!(signal.len() > 0);
15        assert!(order < signal.len());
16        assert!(order > 0);
17
18        let auto_correlation = auto_correlation::auto_correlation_fft(signal);
19        let result = yule_walker::yule_walker_from_auto_correlation(&auto_correlation[0..order + 1]);
20        Self {
21            coefficients: result.coefficients,
22            noise_variance: result.noise_variance / signal.len() as f64,
23            noise_ratio: result.noise_variance / auto_correlation[0],
24        }
25    }
26
27    pub fn coefficients(&self) -> &Vec<f64> {
28        &self.coefficients
29    }
30
31    pub fn noise_variance(&self) -> f64 {
32        self.noise_variance
33    }
34
35    pub fn noise_ratio(&self) -> f64 {
36        self.noise_ratio
37    }
38
39    pub fn predict(&self, length: usize, init: &[f64]) -> Vec<f64> {
40        let mut init_generated = Vec::with_capacity(self.coefficients.len());
41        if init.len() < self.coefficients.len() {
42            for _ in 0..(self.coefficients.len() - init.len()) {
43                init_generated.push(0.0);
44            }
45            for i in 0..init.len() {
46                init_generated.push(init[i]);
47            }
48        } else {
49            for i in 0..self.coefficients.len() {
50                init_generated.push(init[i]);
51            }
52        }
53        
54        for i in 0..length {
55            let mut sum = 0.0;
56            for j in 0..self.coefficients.len() {
57                sum += self.coefficients[j] * init_generated[i + self.coefficients.len() - j - 1];
58            }
59            init_generated.push(sum);
60        }
61        // TODO: inefficient to copy everything    
62        init_generated[self.coefficients.len()..].to_vec()
63    }
64}
65
// `unused` is allowed because several imports (`PI`, `rustfft`, `plotters`) and the
// `plot_series` helper exist only for the commented-out spectral-density plotting below.
#[allow(unused)]
#[cfg(test)]
mod tests {
    use std::f64::consts::PI;

    use super::*;
    use float_eq::assert_float_eq;
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};
    use rand_distr::{Distribution, Normal};
    use rustfft::{num_complex::Complex, FftPlanner};
    use plotters::prelude::*;

    /// Fits a noise-free AR(2) recurrence and checks both the estimated coefficients and that
    /// `predict` reproduces the recurrence from its two seed samples.
    #[test]
    fn test_basic_model() {
        let mut signal = vec![2.0, -1.0];
        let coefficients = vec![-0.5, -1.0];
        let len = signal.len();
        for i in len..10000 {
            signal.push(coefficients[0] * signal[i - 1] + coefficients[1] * signal[i - 2]);
        }

        let mut model = AutoRegressiveModel::new_with_order(&signal, 2);
        assert_float_eq!(model.coefficients()[0], coefficients[0], abs <= 1e-3);
        assert_float_eq!(model.coefficients()[1], coefficients[1], abs <= 1e-3);

        model.coefficients = coefficients.clone();
        let predicted = model.predict(1000 - len, &signal[0..len]);
        assert_eq!(predicted.len(), 1000 - len);
        for i in 0..predicted.len() {
            assert_float_eq!(predicted[i], signal[i + len], abs <= 1e-3);
        }
    }

    /// Debug helper: renders `series` as a line chart into a PNG in the working directory.
    fn plot_series(series: &[f64]) {
        let max = series.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let min = series.iter().cloned().fold(f64::INFINITY, f64::min).min(0.0);

        let mut plot = BitMapBackend::new("test_auto_regressive_model_known_input.png", (800, 600)).into_drawing_area();
        plot.fill(&WHITE).unwrap();
        let mut chart = ChartBuilder::on(&plot)
            .caption("Power Spectral Density", ("sans-serif", 20).into_font())
            .margin(5)
            .x_label_area_size(40)
            .y_label_area_size(40)
            .build_cartesian_2d(0..series.len(), min..max).unwrap();
        chart.configure_mesh().draw().unwrap();
        chart.draw_series(LineSeries::new(series.iter().enumerate().map(|v| (v.0, *v.1)), BLACK)).unwrap();
    }

    /// Checks the closed-form Yule-Walker solution for a tiny hand-computable signal.
    #[test]
    fn test_auto_regressive_model() {
        let signal = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let model = AutoRegressiveModel::new_with_order(&signal, 2);
        let expected_coefficients = vec![232.0 / 285.0, -34.0 / 285.0];
        // NOTE(review): this is the undivided Yule-Walker residual, while `new_with_order`
        // divides the solver's residual by `signal.len()` — confirm the two agree.
        let expected_noise_variance = 55.0 - 232.0 / 285.0 * 40.0 + 34.0 / 285.0 * 26.0;

        assert_eq!(model.coefficients().len(), expected_coefficients.len());
        for (actual, expected) in model.coefficients().iter().zip(expected_coefficients.iter()) {
            assert_float_eq!(*actual, *expected, abs <= 1e-9);
        }

        assert_float_eq!(model.noise_variance, expected_noise_variance, abs <= 1e-9);
    }

    /// Recovers the coefficients of a simulated AR(2) process driven by Gaussian noise.
    #[test]
    fn test_auto_regressive_model_known_input() {
        let size = 100000;
        let p = 2;
        // A fixed seed keeps the test deterministic; with an unseeded RNG the estimates can
        // occasionally drift past the 1e-2 tolerance, which makes the test flaky.
        let mut rng = StdRng::seed_from_u64(0x5EED);
        let normal = Normal::new(0.0, 0.5).unwrap();
        let signal: Vec<f64> = (0..size).map(|_| normal.sample(&mut rng)).collect();
        let a1 = 0.5;
        let a2 = -0.3;
        let mut ar_signal = vec![0.0; signal.len()];
        ar_signal[0] = signal[0];
        ar_signal[1] = signal[1];
        for i in p..signal.len() {
            ar_signal[i] = a1 * ar_signal[i - 1] + a2 * ar_signal[i - 2] + signal[i];
        }

        let model = AutoRegressiveModel::new_with_order(&ar_signal, p);
        println!("{:?}", model);
        assert_float_eq!(model.coefficients()[0], a1, abs <= 1e-2);
        assert_float_eq!(model.coefficients()[1], a2, abs <= 1e-2);

        // Manual debugging aid: uncomment to plot the power spectral density of `ar_signal`.
        // let mut fftplanner = FftPlanner::new();
        // let fft = fftplanner.plan_fft_forward(size);
        // let mut buffer = ar_signal.iter().map(|&value| Complex { re: value, im: 0.0 }).collect::<Vec<Complex<f64>>>();
        // fft.process(&mut buffer);
        // let power_spectral_density = buffer.iter().map(|value| value.norm_sqr()).collect::<Vec<f64>>();

        //plot_series(&power_spectral_density);
    }
}