use super::linalg::{add, eye, inverse, matmul, matmul_rect, sub, transpose_rect};
use crate::error::{SeqError, SeqResult};
/// Output of [`ExtendedKalmanFilter::run`]: one filtered estimate per observation step.
#[derive(Debug, Clone, PartialEq)]
pub struct ExtendedKalmanResult {
    /// `means[t]` is the posterior state mean after assimilating observation `t`
    /// (length `dim_x`).
    pub means: Vec<Vec<f64>>,
    /// `covs[t]` is the posterior state covariance after assimilating observation `t`,
    /// flattened to `dim_x * dim_x` entries (layout presumably row-major, matching the
    /// `linalg` helpers — confirm).
    pub covs: Vec<Vec<f64>>,
}
/// Extended Kalman filter described by user-supplied model closures and flat matrices.
///
/// Matrices are stored as flat `Vec<f64>`s; the sizes below are the ones
/// [`ExtendedKalmanFilter::run`] multiplies them with.
pub struct ExtendedKalmanFilter<'a> {
    /// Length of the state vector.
    pub dim_x: usize,
    /// Length of a single observation vector.
    pub dim_z: usize,
    /// State transition model, applied to the previous posterior mean.
    pub f: Box<dyn Fn(&[f64]) -> Vec<f64> + 'a>,
    /// Observation model, applied to the predicted mean.
    pub h: Box<dyn Fn(&[f64]) -> Vec<f64> + 'a>,
    /// Jacobian of `f` (`dim_x * dim_x` entries), evaluated at the previous posterior mean.
    pub f_jacobian: Box<dyn Fn(&[f64]) -> Vec<f64> + 'a>,
    /// Jacobian of `h` (`dim_z * dim_x` entries), evaluated at the predicted mean.
    pub h_jacobian: Box<dyn Fn(&[f64]) -> Vec<f64> + 'a>,
    /// Process-noise covariance (`dim_x * dim_x`), added to the predicted covariance.
    pub q: Vec<f64>,
    /// Observation-noise covariance (`dim_z * dim_z`), added to the innovation covariance.
    pub r: Vec<f64>,
    /// Initial state mean (`dim_x`).
    pub x0: Vec<f64>,
    /// Initial state covariance (`dim_x * dim_x`).
    pub p0: Vec<f64>,
}
impl<'a> ExtendedKalmanFilter<'a> {
    /// Runs the filter over a flattened sequence of observations.
    ///
    /// `z` holds `T` consecutive observation vectors of length `dim_z`, so `z.len()`
    /// must be a non-zero multiple of `dim_z`. For each observation the filter:
    ///
    /// 1. predicts with `f`, linearised by `f_jacobian` at the previous posterior mean;
    /// 2. corrects with `h`, linearised by `h_jacobian` at the predicted mean.
    ///
    /// It returns the posterior mean and covariance after every step.
    ///
    /// # Errors
    ///
    /// * [`SeqError::EmptyInput`] if `z` is empty.
    /// * [`SeqError::DimensionMismatch`] in any of these cases (`a` is the offending
    ///   length, `b` the expected one):
    ///   * `dim_z` is zero, or `z.len()` is not a multiple of `dim_z`;
    ///   * `x0`, `p0`, `q` or `r` do not have the sizes implied by `dim_x` / `dim_z`;
    ///   * a model function or Jacobian returns a vector of the wrong length.
    /// * Any error from `inverse` when the innovation covariance cannot be inverted.
    pub fn run(&self, z: &[f64]) -> SeqResult<ExtendedKalmanResult> {
        /// Fails with `DimensionMismatch { a: actual, b: expected }` unless the lengths agree.
        fn ensure_len(actual: usize, expected: usize) -> SeqResult<()> {
            if actual == expected {
                Ok(())
            } else {
                Err(SeqError::DimensionMismatch {
                    a: actual,
                    b: expected,
                })
            }
        }

        if z.is_empty() {
            return Err(SeqError::EmptyInput);
        }
        let nx = self.dim_x;
        let nz = self.dim_z;
        // Guard `nz == 0` explicitly: the `%` would otherwise panic on division by zero.
        if nz == 0 || z.len() % nz != 0 {
            return Err(SeqError::DimensionMismatch { a: z.len(), b: nz });
        }
        // Validate parameter sizes up front so a misconfigured filter fails with an
        // error instead of an index panic or a silently mis-sized matrix product.
        ensure_len(self.x0.len(), nx)?;
        ensure_len(self.p0.len(), nx * nx)?;
        ensure_len(self.q.len(), nx * nx)?;
        ensure_len(self.r.len(), nz * nz)?;

        let steps = z.len() / nz;
        let i_eye = eye(nx);
        let mut means = Vec::with_capacity(steps);
        let mut covs = Vec::with_capacity(steps);
        let mut x = self.x0.clone();
        let mut p = self.p0.clone();

        // `z.len()` is a multiple of `nz`, so `chunks_exact` leaves no remainder.
        for z_t in z.chunks_exact(nz) {
            // Predict: x⁻ = f(x), P⁻ = F P Fᵀ + Q with F linearised at the previous posterior.
            let x_pred = (self.f)(&x);
            ensure_len(x_pred.len(), nx)?;
            let f_j = (self.f_jacobian)(&x);
            ensure_len(f_j.len(), nx * nx)?;
            let f_j_t = transpose_rect(&f_j, nx, nx);
            let fp = matmul(&f_j, &p, nx);
            let fpft = matmul_rect(&fp, &f_j_t, nx, nx, nx);
            let p_pred = add(&fpft, &self.q);

            // Innovation y = z_t - h(x⁻).
            let h_x = (self.h)(&x_pred);
            ensure_len(h_x.len(), nz)?;
            let y: Vec<f64> = z_t.iter().zip(&h_x).map(|(obs, pred)| obs - pred).collect();

            // Innovation covariance S = H P⁻ Hᵀ + R, with H linearised at the prediction.
            let h_j = (self.h_jacobian)(&x_pred);
            ensure_len(h_j.len(), nz * nx)?;
            let h_j_t = transpose_rect(&h_j, nz, nx);
            let hp = matmul_rect(&h_j, &p_pred, nz, nx, nx);
            let hpht = matmul_rect(&hp, &h_j_t, nz, nx, nz);
            let s = add(&hpht, &self.r);
            let s_inv = inverse(&s, nz)?;

            // Gain K = P⁻ Hᵀ S⁻¹ and mean update x = x⁻ + K y.
            let pht = matmul_rect(&p_pred, &h_j_t, nx, nx, nz);
            let k_gain = matmul_rect(&pht, &s_inv, nx, nz, nz);
            let ky = matmul_rect(&k_gain, &y, nx, nz, 1);
            x = x_pred.iter().zip(&ky).map(|(mean, corr)| mean + corr).collect();

            // Covariance update in the standard (non-Joseph) form P = (I - K H) P⁻.
            let kh = matmul_rect(&k_gain, &h_j, nx, nz, nx);
            let i_kh = sub(&i_eye, &kh);
            p = matmul_rect(&i_kh, &p_pred, nx, nx, nx);

            // `x` and `p` seed the next step, so the history stores copies.
            means.push(x.clone());
            covs.push(p.clone());
        }
        Ok(ExtendedKalmanResult { means, covs })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Scalar random-walk model (`x_t = x_{t-1}`, `z_t = x_t`). With identity dynamics
    /// and observation, the EKF must reduce exactly to the linear Kalman filter.
    fn scalar_random_walk() -> ExtendedKalmanFilter<'static> {
        ExtendedKalmanFilter {
            dim_x: 1,
            dim_z: 1,
            f: Box::new(|x: &[f64]| vec![x[0]]),
            h: Box::new(|x: &[f64]| vec![x[0]]),
            f_jacobian: Box::new(|_x: &[f64]| vec![1.0]),
            h_jacobian: Box::new(|_x: &[f64]| vec![1.0]),
            q: vec![0.01],
            r: vec![0.05],
            x0: vec![0.0],
            p0: vec![1.0],
        }
    }

    #[test]
    fn ekf_linear_matches_kf() {
        let ekf = scalar_random_walk();
        let z = vec![1.0, 1.05, 0.95];
        let res = ekf.run(&z).expect("ok");
        assert_eq!(res.means.len(), 3);
        assert_eq!(res.covs.len(), 3);
        assert!((res.means[res.means.len() - 1][0] - 1.0).abs() < 0.2);

        // Reference: the closed-form scalar Kalman filter recursion.
        let (q, r) = (ekf.q[0], ekf.r[0]);
        let (mut x, mut p) = (ekf.x0[0], ekf.p0[0]);
        for (t, &obs) in z.iter().enumerate() {
            let p_pred = p + q;
            let gain = p_pred / (p_pred + r);
            x += gain * (obs - x);
            p = (1.0 - gain) * p_pred;
            assert!(
                (res.means[t][0] - x).abs() < 1e-9,
                "mean mismatch at step {t}"
            );
            assert!(
                (res.covs[t][0] - p).abs() < 1e-9,
                "cov mismatch at step {t}"
            );
        }
    }

    #[test]
    fn ekf_rejects_empty_input() {
        let empty: Vec<f64> = Vec::new();
        assert!(matches!(
            scalar_random_walk().run(&empty),
            Err(SeqError::EmptyInput)
        ));
    }

    #[test]
    fn ekf_rejects_partial_observation() {
        let ekf = ExtendedKalmanFilter {
            dim_z: 2,
            ..scalar_random_walk()
        };
        // Three values cannot be split into observations of length two.
        assert!(matches!(
            ekf.run(&[1.0, 2.0, 3.0]),
            Err(SeqError::DimensionMismatch { a: 3, b: 2 })
        ));
    }
}