kalman_rust/kalman/kalman_1d.rs
1use std::fmt;
2use std::error::Error;
3use nalgebra;
4
/// Error returned when a `nalgebra` operation needed by the filter fails.
///
/// The private `typ` code selects the message produced by `Display`:
/// `1` means a matrix could not be inverted; every other value is reported
/// as an undefined error.
#[derive(Debug)]
pub struct Kalman1DError {
    typ: u16,
}

impl fmt::Display for Kalman1DError {
    /// Writes the human-readable message that corresponds to the error code.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match self.typ {
            // Code 1 is raised when the inversion FAILED, so the message must be negative.
            1 => write!(f, "Can't inverse matrix"),
            _ => write!(f, "Undefined error"),
        }
    }
}

impl Error for Kalman1DError {}
17
18
19// Identity matrix. See the ref. https://en.wikipedia.org/wiki/Identity_matrix
20const I: nalgebra::SMatrix::<f32, 2, 2> = nalgebra::SMatrix::<f32, 2, 2>::new(
21 1.0, 0.0,
22 0.0, 1.0,
23);
24
25/// Implementation of Discrete Kalman filter for case when there is only on variable X.
26#[derive(Debug, Clone)]
27pub struct Kalman1D {
28 // Single cycle time
29 dt: f32,
30 // Control input
31 u: f32,
32 // Standart deviation of acceleration
33 std_dev_a: f32,
34 // Standart deviation of measurement
35 std_dev_m: f32,
36 // Transition matrix
37 A: nalgebra::SMatrix<f32, 2, 2>,
38 // Control matrix
39 B: nalgebra::SMatrix<f32, 2, 1>,
40 // Transformation (observation) matrix
41 H: nalgebra::SMatrix<f32, 1, 2>,
42 // Process noise covariance matrix
43 Q: nalgebra::SMatrix<f32, 2, 2>,
44 // Measurement noise covariance matrix
45 R: nalgebra::SMatrix<f32, 1, 1>,
46 // Error covariance matrix
47 P: nalgebra::SMatrix<f32, 2, 2>,
48 // State vector: x, vx
49 x: nalgebra::SVector<f32, 2>,
50}
51
52impl Kalman1D {
53 /// Creates new `Kalman1D`
54 ///
55 /// Basic usage:
56 ///
57 /// ```
58 /// use kalman_rust::kalman::Kalman1D;
59 /// let dt = 0.1; // Single cycle time
60 /// let u = 2.0; // Control input
61 /// let std_dev_a = 0.25; // Standart deviation of acceleration
62 /// let std_dev_m = 1.2; // Standart deviation of measurement
63 /// let mut kalman = Kalman1D::new(dt, u, std_dev_a, std_dev_m);
64 /// ```
65 pub fn new(dt: f32, u: f32, std_dev_a: f32, std_dev_m: f32) -> Self {
66 Kalman1D {
67 dt,
68 u,
69 std_dev_a,
70 std_dev_m,
71 // Ref.: Eq.(17)
72 A: nalgebra::SMatrix::<f32, 2, 2>::new(
73 1.0, dt,
74 0.0, 1.0,
75 ),
76 // Ref.: Eq.(18)
77 B: nalgebra::SMatrix::<f32, 2, 1>::new(
78 0.5 * dt.powi(2),
79 dt,
80 ),
81 // Ref.: Eq.(20)
82 H: nalgebra::SMatrix::<f32, 1, 2>::new(
83 1.0, 0.0,
84 ),
85 // Ref.: Eq.(25)
86 Q: nalgebra::SMatrix::<f32, 2, 2>::new(
87 0.25 * dt.powi(4), 0.5 * dt.powi(3),
88 0.5 * dt.powi(3), dt.powi(2),
89 )*std_dev_a.powi(2),
90 // Ref.: Eq.(26)
91 R: nalgebra::SMatrix::<f32, 1, 1>::new(
92 std_dev_m.powi(2),
93 ),
94 P: nalgebra::SMatrix::<f32, 2, 2>::new(
95 1.0, 0.0,
96 0.0, 1.0,
97 ),
98 x: nalgebra::SVector::<f32, 2>::new(
99 0.0,
100 0.0,
101 ),
102 }
103 }
104 /// Projects the state and the error covariance ahead
105 /// Mutates the state vector and the error covariance matrix
106 ///
107 /// Basic usage:
108 ///
109 /// ```
110 /// use kalman_rust::kalman::Kalman1D;
111 /// let dt = 0.1; // Single cycle time
112 /// let u = 2.0; // Control input
113 /// let std_dev_a = 0.25; // Standart deviation of acceleration
114 /// let std_dev_m = 1.2; // Standart deviation of measurement
115 /// let mut kalman = Kalman1D::new(dt, u, std_dev_a, std_dev_m);
116 /// let measurements = vec![1.0, 2.0, 3.0, 4.0, 5.0];
117 /// for x in measurements.iter() {
118 /// // get measurement
119 /// kalman.predict();
120 /// // then do update
121 /// }
122 /// ```
123 pub fn predict(&mut self) {
124 // Ref.: Eq.(5)
125 self.x = (self.A*self.x) + (self.B*self.u);
126 // Ref.: Eq.(6)
127 self.P = self.A*self.P*self.A.transpose() + self.Q;
128 }
129 /// Computes the Kalman gain and then updates the state vector and the error covariance matrix
130 /// Mutates the state vector and the error covariance matrix.
131 ///
132 /// Basic usage:
133 ///
134 /// ```
135 /// use kalman_rust::kalman::Kalman1D;
136 /// let dt = 0.1; // Single cycle time
137 /// let u = 2.0; // Control input
138 /// let std_dev_a = 0.25; // Standart deviation of acceleration
139 /// let std_dev_m = 1.2; // Standart deviation of measurement
140 /// let mut kalman = Kalman1D::new(dt, u, std_dev_a, std_dev_m);
141 /// let measurements = vec![1.0, 2.0, 3.0, 4.0, 5.0];
142 /// for x in measurements {
143 /// kalman.predict();
144 /// kalman.update(x).unwrap(); // assuming that there is noise in measurement
145 /// }
146 /// ```
147 pub fn update(&mut self, _z: f32) -> Result<(), Kalman1DError> {
148 // Ref.: Eq.(7)
149 let gain = match (self.H*self.P*self.H.transpose() + self.R).try_inverse() {
150 Some(inv) => self.P*self.H.transpose()*inv,
151 None => return Err(Kalman1DError{typ: 1}),
152 };
153 // Ref.: Eq.(8)
154 let z = nalgebra::SMatrix::<f32, 1, 1>::new(_z);
155 let r = z - self.H*self.x;
156 // Ref.: Eq.(9)
157 self.x = self.x + gain*r;
158 // Ref.: Eq.(10)
159 self.P = (I - gain*self.H)*self.P;
160 Ok(())
161 }
162 /// Returns the current state (only X, not Vx)
163 pub fn get_state(&self) -> f32 {
164 self.x[0]
165 }
166 /// Returns the current state (both X and Vx)
167 pub fn get_vector_state(&self) -> nalgebra::SVector::<f32, 2> {
168 self.x
169 }
170}
171
/// Yields `start`, `start + step_size`, `start + 2 * step_size`, ... for as
/// long as the value stays strictly below `threshold`.
///
/// `start` itself is always yielded, even when it is not below `threshold`.
///
/// Every item after the first is computed from its index
/// (`start + i * step_size`) rather than by repeatedly adding `step_size` to
/// the previous item. Repeated addition accumulates rounding error, which can
/// shift the values and change how many items are produced before
/// `threshold` is reached.
///
/// A `step_size` that is zero or negative never reaches `threshold` when
/// `start` is below it, so the iterator is then unbounded; a NaN step yields
/// only `start`.
fn float_loop(start: f32, threshold: f32, step_size: f32) -> impl Iterator<Item = f32> {
    let following = (1u64..)
        // Lossy integer-to-float conversion is intended: only the approximate
        // multiple of `step_size` is needed.
        .map(move |index| start + index as f32 * step_size)
        .take_while(move |&value| value < threshold);
    std::iter::once(start).chain(following)
}
178
#[cfg(test)]
mod tests {
    use super::*;
    use rand::prelude::*;
    use rand_distr::StandardNormal;

    /// Smoke test: runs the predict/update cycle over a noisy synthetic track.
    ///
    /// Just an adaptation of https://machinelearningspace.com/object-tracking-python/
    #[test]
    fn test_1d_kalman() {
        let dt = 0.1;
        let u = 2.0;
        let std_dev_a = 0.25;
        let std_dev_m = 1.2;
        // Standard deviation of the zero-mean Gaussian noise added to the perfect track.
        let noise_std_dev: f32 = 50.0;

        let t = nalgebra::SVector::<f32, 1000>::from_iterator(float_loop(0.0, 100.0, dt));
        let track = t.map(|t| dt * (t * t - t));

        let mut kalman = Kalman1D::new(dt, u, std_dev_a, std_dev_m);
        // One generator for the whole run instead of re-seeding from the OS per sample.
        let mut rng = StdRng::from_os_rng();
        let mut measurement: Vec<f32> = Vec::with_capacity(track.len());
        let mut predictions: Vec<f32> = Vec::with_capacity(track.len());
        for x in track.iter() {
            // Add some noise to the perfect track. `StandardNormal` samples N(0, 1),
            // so scaling alone gives zero-mean noise with the wanted deviation.
            let v: f32 = rng.sample::<f32, StandardNormal>(StandardNormal) * noise_std_dev;
            let z = kalman.H.x * x + v;
            measurement.push(z);

            // Predict stage
            kalman.predict();
            let state = kalman.get_vector_state();
            predictions.push(state.x);

            // Update stage
            kalman
                .update(z)
                .expect("H*P*H^T + R should be invertible because R is positive");
        }

        assert_eq!(measurement.len(), track.len());
        assert_eq!(predictions.len(), track.len());
        assert!(predictions.iter().all(|p| p.is_finite()));
        assert!(kalman.get_state().is_finite());
    }
}
218}