1pub mod univariate {
18 use num_traits::Float;
19 use rand_distr::{Distribution, StandardNormal};
20
21 pub struct Autoregressive<F, const N: usize>
22 where
23 F: Float,
24 StandardNormal: Distribution<F>,
25 {
26 c: F,
27 x: [F; N],
28 phi: [F; N],
29 noise: rand_distr::Normal<F>,
30 }
31
32 impl<F, const N: usize> Autoregressive<F, N>
33 where
34 F: Float + std::iter::Sum,
35 StandardNormal: Distribution<F>,
36 {
37 pub fn new(c: F, noise_variance: F, phi: &[F; N]) -> Self {
42 let x = [num_traits::identities::zero(); N];
43 let noise =
44 rand_distr::Normal::new(num_traits::identities::zero(), noise_variance).unwrap();
45 Self {
46 c,
47 phi: *phi,
48 x,
49 noise,
50 }
51 }
52
53 pub fn step(&mut self) -> F {
55 let mut rng = rand::rng();
56 let epsilon: F = self.noise.sample(&mut rng);
57 let new_x = self.c
58 + self
59 .x
60 .iter()
61 .zip(self.phi.iter())
62 .map(|(x, p)| *x * *p)
63 .sum::<F>()
64 + epsilon;
65 if !self.x.is_empty() {
66 self.x.rotate_right(1);
67 self.x[0] = new_x;
68 }
69 new_x
70 }
71 }
72
73 impl<F, const N: usize> Iterator for Autoregressive<F, N>
74 where
75 F: Float + std::iter::Sum,
76 StandardNormal: Distribution<F>,
77 {
78 type Item = F;
79
80 fn next(&mut self) -> Option<Self::Item> {
81 Some(self.step())
82 }
83 }
84}
85
#[cfg(test)]
mod test {
    use super::univariate::Autoregressive;

    /// Number of samples averaged per process in `bounded`.
    const NUM: usize = 1_000_000;

    /// Mean of the first `NUM` samples of an AR process with `c = 0`, unit noise variance,
    /// and lag coefficients `phi`.
    fn sample_mean<const N: usize>(phi: &[f32; N]) -> f32 {
        let ar = Autoregressive::new(0.0, 1.0, phi);
        ar.take(NUM).sum::<f32>() / (NUM as f32)
    }

    /// Stationary zero-constant processes should have a long-run average close to zero.
    #[test]
    fn bounded() {
        assert!(sample_mean(&[]).abs() < 1.0);
        assert!(sample_mean(&[0.3]).abs() < 1.0);
        assert!(sample_mean(&[0.9]).abs() < 1.0);
        assert!(sample_mean(&[0.3, 0.3]).abs() < 1.0);
        assert!(sample_mean(&[0.9, -0.8]).abs() < 1.0);
    }

    /// With zero noise variance the process is deterministic. This pins down the recurrence
    /// and the lag ordering: `phi[0]` weights the most recent value.
    #[test]
    fn recurrence_without_noise() {
        let mut ar = Autoregressive::<f64, 2>::new(1.0, 0.0, &[0.5, 0.25]);
        assert_eq!(ar.step(), 1.0); // 1 + 0.5*0   + 0.25*0
        assert_eq!(ar.step(), 1.5); // 1 + 0.5*1   + 0.25*0
        assert_eq!(ar.step(), 2.0); // 1 + 0.5*1.5 + 0.25*1
    }
}