Skip to main content

nabled_sensor/
ekf.rs

//! Extended Kalman filter: prediction and measurement-update steps over `ndarray` arrays.
2
3use nabled_core::scalar::NabledReal;
4use nabled_linalg::lu::{self, LuProviderScalar};
5use ndarray::{Array1, Array2, ArrayView1};
6
7use crate::SensorError;
8use crate::kalman::KalmanState;
9
10/// Nonlinear EKF model callbacks.
11#[derive(Clone)]
12pub struct EkModel<T> {
13    pub predict_state:    fn(&ArrayView1<'_, T>) -> Array1<T>,
14    pub predict_jacobian: fn(&ArrayView1<'_, T>) -> Array2<T>,
15    pub measure:          fn(&ArrayView1<'_, T>) -> Array1<T>,
16    pub measure_jacobian: fn(&ArrayView1<'_, T>) -> Array2<T>,
17}
18
19#[derive(Debug, Clone)]
20pub struct EkConfig<T> {
21    pub process_noise:     Array2<T>,
22    pub measurement_noise: Array2<T>,
23}
24
25pub fn ekf_predict<T: NabledReal>(
26    state: &KalmanState<T>,
27    model: &EkModel<T>,
28    config: &EkConfig<T>,
29) -> Result<KalmanState<T>, SensorError> {
30    let mut mean = Array1::zeros(state.mean.len());
31    let mut covariance = Array2::zeros(state.covariance.dim());
32    ekf_predict_into(state, model, config, &mut mean, &mut covariance)?;
33    Ok(KalmanState { mean, covariance })
34}
35
36pub fn ekf_predict_into<T: NabledReal>(
37    state: &KalmanState<T>,
38    model: &EkModel<T>,
39    config: &EkConfig<T>,
40    mean_out: &mut Array1<T>,
41    covariance_out: &mut Array2<T>,
42) -> Result<(), SensorError> {
43    let f = (model.predict_jacobian)(&state.mean.view());
44    let predicted_mean = (model.predict_state)(&state.mean.view());
45    if mean_out.len() != predicted_mean.len() || covariance_out.dim() != state.covariance.dim() {
46        return Err(SensorError::DimensionMismatch);
47    }
48    mean_out.assign(&predicted_mean);
49    *covariance_out = f.dot(&state.covariance).dot(&f.t()) + &config.process_noise;
50    Ok(())
51}
52
53pub fn ekf_update<T: NabledReal + LuProviderScalar>(
54    state: &KalmanState<T>,
55    measurement: &ArrayView1<'_, T>,
56    model: &EkModel<T>,
57    config: &EkConfig<T>,
58) -> Result<KalmanState<T>, SensorError> {
59    let mut mean = state.mean.clone();
60    let mut covariance = state.covariance.clone();
61    ekf_update_into(state, measurement, model, config, &mut mean, &mut covariance)?;
62    Ok(KalmanState { mean, covariance })
63}
64
65pub fn ekf_update_into<T: NabledReal + LuProviderScalar>(
66    state: &KalmanState<T>,
67    measurement: &ArrayView1<'_, T>,
68    model: &EkModel<T>,
69    config: &EkConfig<T>,
70    mean_out: &mut Array1<T>,
71    covariance_out: &mut Array2<T>,
72) -> Result<(), SensorError> {
73    let h = (model.measure_jacobian)(&state.mean.view());
74    let predicted = (model.measure)(&state.mean.view());
75    let innovation = measurement - &predicted;
76    let s = h.dot(&state.covariance).dot(&h.t()) + &config.measurement_noise;
77    let s_inv = lu::inverse(&s).map_err(|_| SensorError::NumericalInstability)?;
78    let k = state.covariance.dot(&h.t()).dot(&s_inv);
79    mean_out.assign(&(state.mean.clone() + k.dot(&innovation)));
80    let n = state.mean.len();
81    let identity = Array2::<T>::eye(n);
82    *covariance_out = (identity - k.dot(&h)).dot(&state.covariance);
83    Ok(())
84}
85
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;
    use ndarray::{arr1, arr2};

    use super::*;

    /// One-dimensional model with `f(x) = sin(x)` and identity measurement `h(x) = x`.
    fn scalar_model() -> EkModel<f64> {
        EkModel {
            predict_state: |x| arr1(&[x[0].sin()]),
            predict_jacobian: |x| arr2(&[[x[0].cos()]]),
            measure: |x| arr1(&[x[0]]),
            measure_jacobian: |_| arr2(&[[1.0]]),
        }
    }

    /// Builds a scalar Kalman state with the given mean and variance.
    fn scalar_state(mean: f64, variance: f64) -> KalmanState<f64> {
        KalmanState { mean: arr1(&[mean]), covariance: arr2(&[[variance]]) }
    }

    /// Builds a scalar noise configuration from process noise `q` and measurement noise `r`.
    fn scalar_config(q: f64, r: f64) -> EkConfig<f64> {
        EkConfig { process_noise: arr2(&[[q]]), measurement_noise: arr2(&[[r]]) }
    }

    #[test]
    fn ekf_update_moves_toward_measurement() {
        let prior = scalar_state(0.2, 1.0);
        let posterior =
            ekf_update(&prior, &arr1(&[1.0]).view(), &scalar_model(), &scalar_config(0.01, 0.05))
                .unwrap();
        assert!(posterior.mean[0] > prior.mean[0]);
    }

    #[test]
    fn ekf_predict_advances_nonlinear_mean() {
        let prior = scalar_state(0.2, 1.0);
        let predicted = ekf_predict(&prior, &scalar_model(), &scalar_config(0.01, 0.05)).unwrap();
        let variance = predicted.covariance[[0, 0]];
        assert_relative_eq!(predicted.mean[0], prior.mean[0].sin(), epsilon = 1e-12);
        assert!(variance.is_finite());
        assert!(variance >= 0.0);
    }

    #[test]
    fn ekf_predict_into_reuses_output_buffers() {
        let prior = scalar_state(0.1, 0.5);
        let (mut mean, mut covariance) = (arr1(&[0.0]), arr2(&[[0.0]]));
        ekf_predict_into(
            &prior,
            &scalar_model(),
            &scalar_config(0.01, 0.05),
            &mut mean,
            &mut covariance,
        )
        .unwrap();
        assert_relative_eq!(mean[0], prior.mean[0].sin(), epsilon = 1e-12);
    }

    #[test]
    fn ekf_predict_into_rejects_output_dimension_mismatch() {
        let prior = scalar_state(0.1, 0.5);
        // Two-element mean buffer for a one-dimensional state.
        let mut mean = arr1(&[0.0, 0.0]);
        let mut covariance = arr2(&[[0.0]]);
        let result = ekf_predict_into(
            &prior,
            &scalar_model(),
            &scalar_config(0.01, 0.05),
            &mut mean,
            &mut covariance,
        );
        assert_eq!(result, Err(SensorError::DimensionMismatch));
    }

    #[test]
    fn ekf_update_into_reuses_output_buffers() {
        let prior = scalar_state(0.2, 1.0);
        let mut mean = prior.mean.clone();
        let mut covariance = prior.covariance.clone();
        ekf_update_into(
            &prior,
            &arr1(&[1.0]).view(),
            &scalar_model(),
            &scalar_config(0.01, 0.05),
            &mut mean,
            &mut covariance,
        )
        .unwrap();
        assert!(mean[0] > prior.mean[0]);
    }

    #[test]
    fn ekf_update_rejects_singular_innovation_covariance() {
        // Zero prior variance and zero measurement noise make S = H·P·Hᵀ + R = 0.
        let prior = scalar_state(0.0, 0.0);
        let result =
            ekf_update(&prior, &arr1(&[1.0]).view(), &scalar_model(), &scalar_config(0.0, 0.0));
        assert_eq!(result, Err(SensorError::NumericalInstability));
    }
}