1use nabled_core::scalar::NabledReal;
4use nabled_linalg::lu::{self, LuProviderScalar};
5use ndarray::{Array1, Array2, ArrayView1};
6
7use crate::SensorError;
8use crate::kalman::KalmanState;
9
/// Function-pointer bundle describing the nonlinear system used by the
/// extended Kalman filter routines in this module.
///
/// Every callback receives a view of the current state mean.
#[derive(Clone)]
pub struct EkModel<T> {
    /// State transition function; its output becomes the predicted mean.
    pub predict_state: fn(&ArrayView1<'_, T>) -> Array1<T>,
    /// Jacobian `F` of the transition, used as `F P Fᵀ` when propagating the
    /// covariance in the prediction step.
    pub predict_jacobian: fn(&ArrayView1<'_, T>) -> Array2<T>,
    /// Measurement function; its output is subtracted from the actual
    /// measurement to form the innovation.
    pub measure: fn(&ArrayView1<'_, T>) -> Array1<T>,
    /// Jacobian `H` of the measurement function, used for the innovation
    /// covariance and the Kalman gain in the update step.
    pub measure_jacobian: fn(&ArrayView1<'_, T>) -> Array2<T>,
}
18
/// Noise covariances consumed by the extended Kalman filter routines.
#[derive(Debug, Clone)]
pub struct EkConfig<T> {
    /// Process noise `Q`, added to `F P Fᵀ` during prediction.
    pub process_noise: Array2<T>,
    /// Measurement noise `R`, added to `H P Hᵀ` during the update.
    pub measurement_noise: Array2<T>,
}
24
25pub fn ekf_predict<T: NabledReal>(
26 state: &KalmanState<T>,
27 model: &EkModel<T>,
28 config: &EkConfig<T>,
29) -> Result<KalmanState<T>, SensorError> {
30 let mut mean = Array1::zeros(state.mean.len());
31 let mut covariance = Array2::zeros(state.covariance.dim());
32 ekf_predict_into(state, model, config, &mut mean, &mut covariance)?;
33 Ok(KalmanState { mean, covariance })
34}
35
36pub fn ekf_predict_into<T: NabledReal>(
37 state: &KalmanState<T>,
38 model: &EkModel<T>,
39 config: &EkConfig<T>,
40 mean_out: &mut Array1<T>,
41 covariance_out: &mut Array2<T>,
42) -> Result<(), SensorError> {
43 let f = (model.predict_jacobian)(&state.mean.view());
44 let predicted_mean = (model.predict_state)(&state.mean.view());
45 if mean_out.len() != predicted_mean.len() || covariance_out.dim() != state.covariance.dim() {
46 return Err(SensorError::DimensionMismatch);
47 }
48 mean_out.assign(&predicted_mean);
49 *covariance_out = f.dot(&state.covariance).dot(&f.t()) + &config.process_noise;
50 Ok(())
51}
52
53pub fn ekf_update<T: NabledReal + LuProviderScalar>(
54 state: &KalmanState<T>,
55 measurement: &ArrayView1<'_, T>,
56 model: &EkModel<T>,
57 config: &EkConfig<T>,
58) -> Result<KalmanState<T>, SensorError> {
59 let mut mean = state.mean.clone();
60 let mut covariance = state.covariance.clone();
61 ekf_update_into(state, measurement, model, config, &mut mean, &mut covariance)?;
62 Ok(KalmanState { mean, covariance })
63}
64
65pub fn ekf_update_into<T: NabledReal + LuProviderScalar>(
66 state: &KalmanState<T>,
67 measurement: &ArrayView1<'_, T>,
68 model: &EkModel<T>,
69 config: &EkConfig<T>,
70 mean_out: &mut Array1<T>,
71 covariance_out: &mut Array2<T>,
72) -> Result<(), SensorError> {
73 let h = (model.measure_jacobian)(&state.mean.view());
74 let predicted = (model.measure)(&state.mean.view());
75 let innovation = measurement - &predicted;
76 let s = h.dot(&state.covariance).dot(&h.t()) + &config.measurement_noise;
77 let s_inv = lu::inverse(&s).map_err(|_| SensorError::NumericalInstability)?;
78 let k = state.covariance.dot(&h.t()).dot(&s_inv);
79 mean_out.assign(&(state.mean.clone() + k.dot(&innovation)));
80 let n = state.mean.len();
81 let identity = Array2::<T>::eye(n);
82 *covariance_out = (identity - k.dot(&h)).dot(&state.covariance);
83 Ok(())
84}
85
#[cfg(test)]
mod tests {
    use approx::assert_relative_eq;
    use ndarray::{arr1, arr2};

    use super::*;

    /// One-dimensional model: transition `x -> sin(x)`, identity measurement.
    fn scalar_model() -> EkModel<f64> {
        fn transition(x: &ArrayView1<'_, f64>) -> Array1<f64> {
            arr1(&[x[0].sin()])
        }
        fn transition_jacobian(x: &ArrayView1<'_, f64>) -> Array2<f64> {
            arr2(&[[x[0].cos()]])
        }
        fn observe(x: &ArrayView1<'_, f64>) -> Array1<f64> {
            arr1(&[x[0]])
        }
        fn observe_jacobian(_: &ArrayView1<'_, f64>) -> Array2<f64> {
            arr2(&[[1.0]])
        }

        EkModel {
            predict_state: transition,
            predict_jacobian: transition_jacobian,
            measure: observe,
            measure_jacobian: observe_jacobian,
        }
    }

    /// Builds a one-dimensional state with the given mean and variance.
    fn scalar_state(mean: f64, variance: f64) -> KalmanState<f64> {
        KalmanState { mean: arr1(&[mean]), covariance: arr2(&[[variance]]) }
    }

    /// Builds a one-dimensional noise configuration.
    fn scalar_config(process: f64, measurement: f64) -> EkConfig<f64> {
        EkConfig { process_noise: arr2(&[[process]]), measurement_noise: arr2(&[[measurement]]) }
    }

    #[test]
    fn ekf_update_moves_toward_measurement() {
        let prior = scalar_state(0.2, 1.0);
        let z = arr1(&[1.0]);

        let posterior =
            ekf_update(&prior, &z.view(), &scalar_model(), &scalar_config(0.01, 0.05)).unwrap();

        assert!(posterior.mean[0] > prior.mean[0]);
    }

    #[test]
    fn ekf_predict_advances_nonlinear_mean() {
        let prior = scalar_state(0.2, 1.0);

        let predicted = ekf_predict(&prior, &scalar_model(), &scalar_config(0.01, 0.05)).unwrap();

        assert_relative_eq!(predicted.mean[0], prior.mean[0].sin(), epsilon = 1e-12);
        let variance = predicted.covariance[[0, 0]];
        assert!(variance.is_finite());
        assert!(variance >= 0.0);
    }

    #[test]
    fn ekf_predict_into_reuses_output_buffers() {
        let prior = scalar_state(0.1, 0.5);
        let mut mean_buf = Array1::<f64>::zeros(1);
        let mut cov_buf = Array2::<f64>::zeros((1, 1));

        ekf_predict_into(
            &prior,
            &scalar_model(),
            &scalar_config(0.01, 0.05),
            &mut mean_buf,
            &mut cov_buf,
        )
        .unwrap();

        assert_relative_eq!(mean_buf[0], prior.mean[0].sin(), epsilon = 1e-12);
    }

    #[test]
    fn ekf_predict_into_rejects_output_dimension_mismatch() {
        let prior = scalar_state(0.1, 0.5);
        // Deliberately too long for a one-dimensional state.
        let mut mean_buf = Array1::<f64>::zeros(2);
        let mut cov_buf = Array2::<f64>::zeros((1, 1));

        let outcome = ekf_predict_into(
            &prior,
            &scalar_model(),
            &scalar_config(0.01, 0.05),
            &mut mean_buf,
            &mut cov_buf,
        );

        assert_eq!(outcome, Err(SensorError::DimensionMismatch));
    }

    #[test]
    fn ekf_update_into_reuses_output_buffers() {
        let prior = scalar_state(0.2, 1.0);
        let z = arr1(&[1.0]);
        let mut mean_buf = prior.mean.clone();
        let mut cov_buf = prior.covariance.clone();

        ekf_update_into(
            &prior,
            &z.view(),
            &scalar_model(),
            &scalar_config(0.01, 0.05),
            &mut mean_buf,
            &mut cov_buf,
        )
        .unwrap();

        assert!(mean_buf[0] > prior.mean[0]);
    }

    #[test]
    fn ekf_update_rejects_singular_innovation_covariance() {
        // Zero prior variance plus zero measurement noise gives S = 0.
        let prior = scalar_state(0.0, 0.0);
        let z = arr1(&[1.0]);

        let outcome = ekf_update(&prior, &z.view(), &scalar_model(), &scalar_config(0.0, 0.0));

        assert_eq!(outcome, Err(SensorError::NumericalInstability));
    }
}