use crate::time::{Duration, Epoch};
use crate::alloc::{boxed::Box, vec::Vec};
use crate::errors::YakfError;
use crate::lie::base::{LieAlgebraSE3, LieGroupSE3, LieVectorSE3};
use crate::linalg::allocator::Allocator;
use crate::linalg::{Const, DefaultAllocator, DimName, OMatrix, OVector, U3, U4, U6};
/// Accessors for a time-stamped SE(3) state used by [`ESEKF`].
///
/// Every method has a default body that calls `unimplemented!()`, so a type
/// implementing this trait must override each method it expects the filter
/// to call; otherwise that call panics at run time.
pub trait ESStates {
    /// Returns a reference to the stored state.
    ///
    /// # Panics
    /// The default body always panics via `unimplemented!()`.
    fn state(&self) -> &LieGroupSE3 {
        unimplemented!()
    }

    /// Replaces the stored state.
    ///
    /// # Panics
    /// The default body always panics via `unimplemented!()`.
    // The parameter is underscore-prefixed, like `_epoch` below, because the
    // default body never reads it; implementors may name it freely.
    fn set_state(&mut self, _state: LieGroupSE3) {
        unimplemented!()
    }

    /// Returns the epoch attached to the stored state.
    ///
    /// # Panics
    /// The default body always panics via `unimplemented!()`.
    fn epoch(&self) -> Epoch {
        unimplemented!()
    }

    /// Replaces the epoch attached to the stored state.
    ///
    /// # Panics
    /// The default body always panics via `unimplemented!()`.
    fn set_epoch(&mut self, _epoch: Epoch) {
        unimplemented!()
    }
}
pub struct ESEKF<S>
where
S: ESStates,
{
pub stamp_state: S,
pmatrix: OMatrix<f64, U6, U6>,
qmatrix: OMatrix<f64, U6, U6>,
nmatrix: OMatrix<f64, Const<12>, Const<12>>,
f: Box<dyn Fn(&LieGroupSE3, &LieVectorSE3, Duration) -> OMatrix<f64, U6, U6>>,
g: Box<dyn Fn(&LieGroupSE3, &LieVectorSE3, Duration) -> OMatrix<f64, U6, U6>>,
h: Box<dyn Fn(&LieGroupSE3) -> OMatrix<f64, Const<12>, U6>>,
ob: Box<dyn Fn(&LieGroupSE3) -> OVector<f64, Const<12>>>,
}
impl<S> ESEKF<S>
where
S: ESStates,
{
#[allow(dead_code)]
pub fn build(
f: Box<dyn Fn(&LieGroupSE3, &LieVectorSE3, Duration) -> OMatrix<f64, U6, U6>>,
g: Box<dyn Fn(&LieGroupSE3, &LieVectorSE3, Duration) -> OMatrix<f64, U6, U6>>,
h: Box<dyn Fn(&LieGroupSE3) -> OMatrix<f64, Const<12>, U6>>,
ob: Box<dyn Fn(&LieGroupSE3) -> OVector<f64, Const<12>>>,
stamp_state: S,
pmatrix: OMatrix<f64, U6, U6>,
qmatrix: OMatrix<f64, U6, U6>,
nmatrix: OMatrix<f64, Const<12>, Const<12>>,
) -> Self {
Self {
stamp_state,
pmatrix,
qmatrix,
nmatrix,
f,
g,
h,
ob,
}
}
pub fn transition_f(
&self,
x_estimate: &LieGroupSE3,
u: &LieVectorSE3,
dt: Duration,
) -> OMatrix<f64, U6, U6> {
(self.f)(x_estimate, u, dt)
}
pub fn transition_g(
&self,
x_estimate: &LieGroupSE3,
u: &LieVectorSE3,
dt: Duration,
) -> OMatrix<f64, U6, U6> {
(self.g)(x_estimate, u, dt)
}
pub fn transition_h(&self, x_predict: &LieGroupSE3) -> OMatrix<f64, Const<12>, U6> {
(self.h)(x_predict)
}
pub fn propagate(&self, u: &LieVectorSE3, dt: Duration) -> LieGroupSE3 {
let u_col = u.to_vec6();
let f = self.transition_f(self.stamp_state.state(), u, dt);
let inc_col = f * u_col;
let delta_alg = LieVectorSE3::from_vec6(&inc_col).to_algebra() * dt.in_seconds();
let delta_group = LieGroupSE3::from_algebra(&delta_alg);
let mut m = self.stamp_state.state().clone();
m.increment_by_left_delta(delta_group);
m
}
pub fn measure(&self, x: &LieGroupSE3) -> OVector<f64, Const<12>> {
(self.ob)(x)
}
#[allow(dead_code)]
pub fn feed_and_update(
&mut self,
measure: OVector<f64, Const<12>>,
m_epoch: Epoch,
u: LieVectorSE3,
) -> Result<(), YakfError> {
let dt = m_epoch - self.stamp_state.epoch();
let mut x_predict = self.propagate(&u, dt);
let f = self.transition_f(self.stamp_state.state(), &u, dt);
let g = self.transition_g(self.stamp_state.state(), &u, dt);
let p_predict = f * &self.pmatrix * &f.transpose() + g * &self.qmatrix * &g.transpose();
let ob_predict = self.measure(&x_predict);
let z = measure - ob_predict;
let h = self.transition_h(&x_predict);
let zmatrix = h * p_predict * h.transpose() + self.nmatrix;
match zmatrix.try_inverse() {
Some(zm_inv) => {
let kmatrix = p_predict * h.transpose() * zm_inv;
let dx = kmatrix * z;
let dx_group = LieVectorSE3::from_vec6(&dx).to_group();
x_predict.increment_by_left_delta(dx_group);
self.stamp_state.set_state(x_predict);
self.stamp_state.set_epoch(m_epoch);
self.pmatrix = &self.pmatrix - &kmatrix * &zmatrix * &kmatrix.transpose();
Ok(())
}
None => Err(YakfError::InverseErr),
}
}
}