Skip to main content

nabled_sensor/
kalman.rs

1//! Linear Kalman filter.
2
3use nabled_linalg::lu::LuProviderScalar;
4use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
5
6use crate::SensorError;
7
8/// Kalman filter state (mean and covariance).
9#[derive(Debug, Clone, PartialEq)]
10pub struct KalmanState<T> {
11    pub mean:       Array1<T>,
12    pub covariance: Array2<T>,
13}
14
15/// Predict step: x = F x, P = F P F' + Q.
16pub fn predict<T: LuProviderScalar>(
17    state: &KalmanState<T>,
18    f: &ArrayView2<'_, T>,
19    q: &ArrayView2<'_, T>,
20) -> Result<KalmanState<T>, SensorError> {
21    if state.mean.len() != f.nrows() || f.nrows() != f.ncols() {
22        return Err(SensorError::DimensionMismatch);
23    }
24    let mean = f.dot(&state.mean);
25    let covariance = f.dot(&state.covariance).dot(&f.t()) + q;
26    Ok(KalmanState { mean, covariance })
27}
28
29/// Predict into existing state.
30pub fn predict_into<T: LuProviderScalar>(
31    state: &mut KalmanState<T>,
32    f: &ArrayView2<'_, T>,
33    q: &ArrayView2<'_, T>,
34) -> Result<(), SensorError> {
35    *state = predict(state, f, q)?;
36    Ok(())
37}
38
39/// Update step with measurement z.
40pub fn update<T: LuProviderScalar>(
41    state: &KalmanState<T>,
42    z: &ArrayView1<'_, T>,
43    h: &ArrayView2<'_, T>,
44    r: &ArrayView2<'_, T>,
45) -> Result<KalmanState<T>, SensorError> {
46    let innovation = z - &h.dot(&state.mean);
47    let s = h.dot(&state.covariance).dot(&h.t()) + r;
48    let s_inv = nabled_linalg::lu::inverse(&s).map_err(|_| SensorError::NumericalInstability)?;
49    let k = state.covariance.dot(&h.t()).dot(&s_inv);
50    let mean = &state.mean + k.dot(&innovation);
51    let n = state.mean.len();
52    let identity = Array2::<T>::eye(n);
53    let covariance = (identity - k.dot(h)).dot(&state.covariance);
54    Ok(KalmanState { mean, covariance })
55}
56
57/// Update into existing state.
58pub fn update_into<T: LuProviderScalar>(
59    state: &mut KalmanState<T>,
60    z: &ArrayView1<'_, T>,
61    h: &ArrayView2<'_, T>,
62    r: &ArrayView2<'_, T>,
63) -> Result<(), SensorError> {
64    *state = update(state, z, h, r)?;
65    Ok(())
66}
67
#[cfg(test)]
mod tests {
    use ndarray::{arr1, arr2};

    use super::*;

    /// One-dimensional state with the given mean and variance.
    fn scalar_state(mean: f64, variance: f64) -> KalmanState<f64> {
        KalmanState { mean: arr1(&[mean]), covariance: arr2(&[[variance]]) }
    }

    #[test]
    fn kalman_fuses_measurement() {
        let (f, q) = (arr2(&[[1.0]]), arr2(&[[0.01]]));
        let (h, r) = (arr2(&[[1.0]]), arr2(&[[0.1]]));
        let z = arr1(&[1.0]);

        let prior = predict(&scalar_state(0.0, 1.0), &f.view(), &q.view()).unwrap();
        let posterior = update(&prior, &z.view(), &h.view(), &r.view()).unwrap();

        // The estimate moves most of the way toward the measurement...
        assert!(posterior.mean[0] > 0.5);
        // ...and ends up more certain than the initial unit variance.
        assert!(posterior.covariance[[0, 0]] < 1.0);
    }

    #[test]
    fn predict_rejects_state_transition_dimension_mismatch() {
        let state = KalmanState {
            mean:       arr1(&[0.0_f64, 1.0]),
            covariance: arr2(&[[1.0, 0.0], [0.0, 1.0]]),
        };
        // A 1 x 1 transition cannot act on a 2-dimensional state.
        let f = arr2(&[[1.0]]);
        let q = arr2(&[[0.01]]);

        let outcome = predict(&state, &f.view(), &q.view());
        assert_eq!(outcome, Err(SensorError::DimensionMismatch));
    }

    #[test]
    fn predict_and_update_into_reuse_buffers() {
        let (f, q) = (arr2(&[[1.0]]), arr2(&[[0.01]]));
        let (h, r) = (arr2(&[[1.0]]), arr2(&[[0.1]]));
        let z = arr1(&[1.0]);

        let mut state = scalar_state(0.0, 1.0);
        predict_into(&mut state, &f.view(), &q.view()).unwrap();
        update_into(&mut state, &z.view(), &h.view(), &r.view()).unwrap();

        assert!(state.mean[0] > 0.5);
    }

    #[test]
    fn update_rejects_singular_innovation_covariance() {
        // Zero prior variance plus zero measurement noise gives S = H P H^T + R = 0.
        let degenerate = scalar_state(0.0, 0.0);
        let h = arr2(&[[1.0]]);
        let r = arr2(&[[0.0]]);
        let z = arr1(&[1.0]);

        let outcome = update(&degenerate, &z.view(), &h.view(), &r.view());
        assert_eq!(outcome, Err(SensorError::NumericalInstability));
    }
}