1use nabled_linalg::lu::LuProviderScalar;
4use ndarray::{Array1, Array2, ArrayView1, ArrayView2};
5
6use crate::SensorError;
7
/// State estimate carried between the Kalman filter steps in this module
/// ([`predict`], [`update`] and their `_into` variants).
#[derive(Debug, Clone, PartialEq)]
pub struct KalmanState<T> {
    /// Estimated state vector; its length `n` defines the state dimension.
    pub mean: Array1<T>,
    /// Covariance of `mean`. The filter steps multiply it with `n × n` / `m × n`
    /// matrices, so it must be `n × n` where `n == mean.len()`.
    pub covariance: Array2<T>,
}
14
15pub fn predict<T: LuProviderScalar>(
17 state: &KalmanState<T>,
18 f: &ArrayView2<'_, T>,
19 q: &ArrayView2<'_, T>,
20) -> Result<KalmanState<T>, SensorError> {
21 if state.mean.len() != f.nrows() || f.nrows() != f.ncols() {
22 return Err(SensorError::DimensionMismatch);
23 }
24 let mean = f.dot(&state.mean);
25 let covariance = f.dot(&state.covariance).dot(&f.t()) + q;
26 Ok(KalmanState { mean, covariance })
27}
28
29pub fn predict_into<T: LuProviderScalar>(
31 state: &mut KalmanState<T>,
32 f: &ArrayView2<'_, T>,
33 q: &ArrayView2<'_, T>,
34) -> Result<(), SensorError> {
35 *state = predict(state, f, q)?;
36 Ok(())
37}
38
39pub fn update<T: LuProviderScalar>(
41 state: &KalmanState<T>,
42 z: &ArrayView1<'_, T>,
43 h: &ArrayView2<'_, T>,
44 r: &ArrayView2<'_, T>,
45) -> Result<KalmanState<T>, SensorError> {
46 let innovation = z - &h.dot(&state.mean);
47 let s = h.dot(&state.covariance).dot(&h.t()) + r;
48 let s_inv = nabled_linalg::lu::inverse(&s).map_err(|_| SensorError::NumericalInstability)?;
49 let k = state.covariance.dot(&h.t()).dot(&s_inv);
50 let mean = &state.mean + k.dot(&innovation);
51 let n = state.mean.len();
52 let identity = Array2::<T>::eye(n);
53 let covariance = (identity - k.dot(h)).dot(&state.covariance);
54 Ok(KalmanState { mean, covariance })
55}
56
57pub fn update_into<T: LuProviderScalar>(
59 state: &mut KalmanState<T>,
60 z: &ArrayView1<'_, T>,
61 h: &ArrayView2<'_, T>,
62 r: &ArrayView2<'_, T>,
63) -> Result<(), SensorError> {
64 *state = update(state, z, h, r)?;
65 Ok(())
66}
67
#[cfg(test)]
mod tests {
    use ndarray::arr2;

    use super::*;

    /// Builds a one-dimensional state with the given mean and variance.
    fn scalar_state(mean: f64, variance: f64) -> KalmanState<f64> {
        KalmanState { mean: ndarray::arr1(&[mean]), covariance: arr2(&[[variance]]) }
    }

    #[test]
    fn kalman_fuses_measurement() {
        let state = scalar_state(0.0, 1.0);
        let f = arr2(&[[1.0]]);
        let q = arr2(&[[0.01]]);
        let predicted = predict(&state, &f.view(), &q.view()).unwrap();
        // Scalar prediction: P' = F P F + Q = 1 * 1 * 1 + 0.01.
        assert!((predicted.covariance[[0, 0]] - 1.01).abs() < 1e-12);

        let h = arr2(&[[1.0]]);
        let r = arr2(&[[0.1]]);
        let z = ndarray::arr1(&[1.0]);
        let updated = update(&predicted, &z.view(), &h.view(), &r.view()).unwrap();
        // Scalar closed form with H = 1: K = P' / (P' + R), x = K z, P = (1 - K) P'.
        let gain = 1.01 / (1.01 + 0.1);
        assert!((updated.mean[0] - gain).abs() < 1e-9);
        assert!((updated.covariance[[0, 0]] - (1.0 - gain) * 1.01).abs() < 1e-9);
        assert!(updated.covariance[[0, 0]] < predicted.covariance[[0, 0]]);
    }

    #[test]
    fn predict_rejects_state_transition_dimension_mismatch() {
        let state = KalmanState {
            mean: ndarray::arr1(&[0.0_f64, 1.0]),
            covariance: arr2(&[[1.0, 0.0], [0.0, 1.0]]),
        };
        let f = arr2(&[[1.0]]);
        let q = arr2(&[[0.01]]);
        assert_eq!(predict(&state, &f.view(), &q.view()), Err(SensorError::DimensionMismatch));
    }

    #[test]
    fn predict_and_update_into_reuse_buffers() {
        let f = arr2(&[[1.0]]);
        let q = arr2(&[[0.01]]);
        let h = arr2(&[[1.0]]);
        let r = arr2(&[[0.1]]);
        let z = ndarray::arr1(&[1.0]);

        let mut state = scalar_state(0.0, 1.0);
        predict_into(&mut state, &f.view(), &q.view()).unwrap();
        update_into(&mut state, &z.view(), &h.view(), &r.view()).unwrap();
        assert!(state.mean[0] > 0.5);

        // The in-place variants must agree exactly with the value-returning ones.
        let predicted = predict(&scalar_state(0.0, 1.0), &f.view(), &q.view()).unwrap();
        let expected = update(&predicted, &z.view(), &h.view(), &r.view()).unwrap();
        assert_eq!(state, expected);
    }

    #[test]
    fn update_into_leaves_state_untouched_on_error() {
        let mut state = scalar_state(0.0, 0.0);
        let before = state.clone();
        let h = arr2(&[[1.0]]);
        let r = arr2(&[[0.0]]);
        let z = ndarray::arr1(&[1.0]);
        assert_eq!(
            update_into(&mut state, &z.view(), &h.view(), &r.view()),
            Err(SensorError::NumericalInstability)
        );
        assert_eq!(state, before);
    }

    #[test]
    fn update_rejects_singular_innovation_covariance() {
        let state = scalar_state(0.0, 0.0);
        let h = arr2(&[[1.0]]);
        let r = arr2(&[[0.0]]);
        let z = ndarray::arr1(&[1.0]);
        assert_eq!(
            update(&state, &z.view(), &h.view(), &r.view()),
            Err(SensorError::NumericalInstability)
        );
    }
}