use nalgebra::{RealField, SMatrix, SVector};
use crate::{
measurement::{LinearMeasurement, LinearisableMeasurement, Measurement},
system::{
InputSystem, LinearNoInputSystem, LinearSystem, LinearisableSystem, NoInputSystem,
NonLinearSystem, StepFunction, System,
},
};
/// Errors produced by the filter operations in this module.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub enum KfError {
    /// The innovation covariance `H P H^T + R` could not be inverted during an update,
    /// so no Kalman gain could be computed. The filter state is left untouched.
    InnovationCovarianceNotInvertible,
}
/// Accessors shared by the Kalman-style filters in this module.
///
/// `T` is the scalar type, `N` the state dimension and `S` the system model type.
pub trait KalmanFilter<T, const N: usize, S> {
    /// Current state estimate.
    fn state(&self) -> &SVector<T, N>;
    /// Current estimate covariance `P`.
    fn covariance(&self) -> &SMatrix<T, N, N>;
    /// Mutable access to the estimate covariance `P` (e.g. to reset it).
    fn covariance_mut(&mut self) -> &mut SMatrix<T, N, N>;
    /// Mutable access to the system model.
    fn system_mut(&mut self) -> &mut S;
}
/// Prediction (time-update) step that takes no control input.
pub trait KalmanPredict<T, const N: usize> {
    /// Error type reported by [`KalmanPredict::predict`].
    type Error;
    /// Advances the filter one step and returns the predicted state.
    fn predict(&mut self) -> Result<&SVector<T, N>, Self::Error>;
}
/// Prediction (time-update) step driven by a `U`-dimensional control input.
pub trait KalmanPredictInput<T, const N: usize, const U: usize> {
    /// Error type reported by [`KalmanPredictInput::predict`].
    type Error;
    /// Advances the filter one step using control input `u` and returns the predicted state.
    fn predict(&mut self, u: SVector<T, U>) -> Result<&SVector<T, N>, Self::Error>;
}
/// Measurement (correction) step using a measurement model `ME` with `M` components.
pub trait KalmanUpdate<T, const N: usize, const M: usize, ME: Measurement<T, N, M>> {
    /// Error type reported by [`KalmanUpdate::update`].
    type Error;
    /// Fuses `measurement` into the estimate and returns the corrected state.
    fn update(&mut self, measurement: &ME) -> Result<&SVector<T, N>, Self::Error>;
}
/// Generic Kalman filter: an estimate covariance `P` paired with a system model `S`
/// that holds the state vector.
///
/// `N` is the state dimension and `U` the control-input dimension (`0` is used for
/// the no-input variants below).
#[allow(non_snake_case)] // field `P` keeps the conventional matrix name
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct Kalman<T: RealField, const N: usize, const U: usize, S> {
    /// Estimate covariance.
    P: SMatrix<T, N, N>,
    /// System model; owns the state estimate exposed by `state()`.
    pub system: S,
}
/// Kalman filter over a [`LinearSystem`] with a `U`-dimensional control input.
pub type KalmanLinear<T, const N: usize, const U: usize> = Kalman<T, N, U, LinearSystem<T, N, U>>;
/// Kalman filter over a [`LinearNoInputSystem`] (no control input).
pub type KalmanLinearNoInput<T, const N: usize> = Kalman<T, N, 0, LinearNoInputSystem<T, N>>;
/// Kalman filter over a [`NonLinearSystem`] (extended Kalman filter).
pub type EKF<T, const N: usize, const U: usize> = Kalman<T, N, U, NonLinearSystem<T, N, U>>;
impl<T, const N: usize, const U: usize, S> Kalman<T, N, U, S>
where
    T: RealField,
    S: System<T, N, U>,
{
    /// Wraps any system model together with a starting estimate covariance.
    pub fn new_custom(system: S, initial_covariance: SMatrix<T, N, N>) -> Self {
        Self {
            system,
            P: initial_covariance,
        }
    }
}
impl<T, const N: usize, const U: usize, S> KalmanFilter<T, N, S> for Kalman<T, N, U, S>
where
    T: RealField,
    S: System<T, N, U>,
{
    /// The estimate lives inside the system model, so this forwards to it.
    fn state(&self) -> &SVector<T, N> {
        let Self { system, .. } = self;
        system.state()
    }

    /// Shared view of the estimate covariance.
    fn covariance(&self) -> &SMatrix<T, N, N> {
        let Self { P: covariance, .. } = self;
        covariance
    }

    /// Exclusive view of the estimate covariance.
    fn covariance_mut(&mut self) -> &mut SMatrix<T, N, N> {
        let Self { P: covariance, .. } = self;
        covariance
    }

    /// Exclusive view of the wrapped system model.
    fn system_mut(&mut self) -> &mut S {
        let Self { system, .. } = self;
        system
    }
}
impl<T, const N: usize, S> KalmanPredict<T, N> for Kalman<T, N, 0, S>
where
    T: RealField + Copy,
    S: NoInputSystem<T, N> + LinearisableSystem<T, N, 0>,
{
    type Error = KfError;

    /// Steps the system model, then propagates the covariance as
    /// `P <- transition * P * transition^T + covariance` using the system's matrices.
    ///
    /// This never returns `Err`; the `Result` only satisfies the trait signature.
    fn predict(&mut self) -> Result<&SVector<T, N>, Self::Error> {
        self.system.step();
        let transition = self.system.transition();
        let propagated = transition * self.P * self.system.transition_transpose();
        self.P = propagated + self.system.covariance();
        Ok(self.system.state())
    }
}
impl<T, const N: usize, const U: usize, S> KalmanPredictInput<T, N, U> for Kalman<T, N, U, S>
where
    T: RealField + Copy,
    S: InputSystem<T, N, U> + LinearisableSystem<T, N, U>,
{
    type Error = KfError;

    /// Steps the system model with control input `u`, then propagates the covariance as
    /// `P <- transition * P * transition^T + covariance` using the system's matrices.
    ///
    /// This never returns `Err`; the `Result` only satisfies the trait signature.
    fn predict(&mut self, u: SVector<T, U>) -> Result<&SVector<T, N>, Self::Error> {
        self.system.step(u);
        let transition = self.system.transition();
        let propagated = transition * self.P * self.system.transition_transpose();
        self.P = propagated + self.system.covariance();
        Ok(self.system.state())
    }
}
impl<
        T: RealField + Copy,
        const N: usize,
        const M: usize,
        const U: usize,
        S: System<T, N, U>,
        ME,
    > KalmanUpdate<T, N, M, ME> for Kalman<T, N, U, S>
where
    ME: Measurement<T, N, M> + LinearisableMeasurement<T, N, M>,
{
    type Error = KfError;

    /// Standard Kalman correction with the measurement's observation matrix `H` and
    /// noise covariance `R`:
    ///
    /// * innovation `y = z - predict(x)`
    /// * innovation covariance `H P H^T + R`
    /// * gain `K = P H^T (H P H^T + R)^-1`
    /// * `x <- x + K y`, `P <- (I - K H) P`, then `P` is re-symmetrised.
    ///
    /// # Errors
    /// Returns [`KfError::InnovationCovarianceNotInvertible`] when the innovation
    /// covariance cannot be inverted; state and covariance are not modified in that case.
    #[allow(non_snake_case)] // `K` keeps its textbook name
    fn update(&mut self, measurement: &ME) -> Result<&SVector<T, N>, Self::Error> {
        // Difference between what was observed and what the current state predicts.
        let predicted = measurement.predict(self.system.state());
        let innovation = measurement.measurement() - predicted;

        // Named `innovation_cov` (not `S`) so it is not confused with the impl's
        // system type parameter `S`.
        let innovation_cov = measurement.observation()
            * self.P
            * measurement.observation_transpose()
            + measurement.covariance();
        let innovation_cov_inv = innovation_cov
            .try_inverse()
            .ok_or(KfError::InnovationCovarianceNotInvertible)?;

        let K = self.P * measurement.observation_transpose() * innovation_cov_inv;
        *self.system.state_mut() += K * innovation;
        self.P = (SMatrix::identity() - K * measurement.observation()) * self.P;
        // Round-off can make P drift away from symmetry; project it back.
        self.P = self.P.symmetric_part();
        Ok(self.state())
    }
}
impl<T, const N: usize, const U: usize> Kalman<T, N, U, LinearSystem<T, N, U>>
where
    T: RealField + Copy,
{
    /// Builds a linear Kalman filter that accepts a `U`-dimensional control input.
    ///
    /// `F`, `Q`, `B` and `x_initial` are forwarded unchanged to [`LinearSystem::new`];
    /// `P_initial` becomes the starting estimate covariance.
    #[allow(non_snake_case)] // parameters keep the conventional matrix names
    pub fn new_with_input(
        F: SMatrix<T, N, N>,
        Q: SMatrix<T, N, N>,
        B: SMatrix<T, N, U>,
        x_initial: SVector<T, N>,
        P_initial: SMatrix<T, N, N>,
    ) -> Self {
        Self {
            P: P_initial,
            system: LinearSystem::new(F, Q, B, x_initial),
        }
    }
}
impl<T, const N: usize> Kalman<T, N, 0, LinearNoInputSystem<T, N>>
where
    T: RealField + Copy,
{
    /// Builds a linear Kalman filter without control input.
    ///
    /// `F`, `Q` and `x_initial` are handed to [`LinearNoInputSystem::new`];
    /// `P_initial` is used as the starting estimate covariance.
    #[allow(non_snake_case)] // parameters keep the conventional matrix names
    pub fn new(
        F: SMatrix<T, N, N>,
        Q: SMatrix<T, N, N>,
        x_initial: SVector<T, N>,
        P_initial: SMatrix<T, N, N>,
    ) -> Self {
        let system = LinearNoInputSystem::new(F, Q, x_initial);
        Self {
            system,
            P: P_initial,
        }
    }
}
impl<T, const N: usize, const U: usize> Kalman<T, N, U, NonLinearSystem<T, N, U>>
where
    T: RealField + Copy,
{
    /// Builds an extended Kalman filter from a step function and an initial estimate.
    ///
    /// `step_fn` and `x_initial` are handed to [`NonLinearSystem::new`];
    /// `P_initial` is used as the starting estimate covariance.
    #[allow(non_snake_case)] // `P_initial` keeps the conventional matrix name
    pub fn new_ekf_with_input(
        step_fn: StepFunction<T, N, U>,
        x_initial: SVector<T, N>,
        P_initial: SMatrix<T, N, N>,
    ) -> Self {
        let system = NonLinearSystem::new(step_fn, x_initial);
        Self {
            system,
            P: P_initial,
        }
    }
}
/// A [`Kalman`] filter bundled with one measurement model `ME` with `M` components.
///
/// `update` loads each new observation into the stored model (via `set_measurement`)
/// before running the correction step.
#[derive(Debug, Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "defmt", derive(defmt::Format))]
pub struct Kalman1M<T: RealField, const N: usize, const U: usize, const M: usize, S, ME> {
    /// The wrapped filter (state, covariance and system model).
    kalman: Kalman<T, N, U, S>,
    /// The measurement model reused for every `update`.
    measurement: ME,
}
/// Linear system without control input, paired with a [`LinearMeasurement`].
pub type Kalman1MLinearNoInput<T, const N: usize, const M: usize> =
    Kalman1M<T, N, 0, M, LinearNoInputSystem<T, N>, LinearMeasurement<T, N, M>>;
/// Linear system with a `U`-dimensional control input, paired with a [`LinearMeasurement`].
pub type Kalman1MLinear<T, const N: usize, const U: usize, const M: usize> =
    Kalman1M<T, N, U, M, LinearSystem<T, N, U>, LinearMeasurement<T, N, M>>;
/// Non-linear (EKF) system paired with a [`LinearMeasurement`].
pub type EKF1M<T, const N: usize, const U: usize, const M: usize> =
    Kalman1M<T, N, U, M, NonLinearSystem<T, N, U>, LinearMeasurement<T, N, M>>;
impl<T, const N: usize, const U: usize, const M: usize, S, ME> Kalman1M<T, N, U, M, S, ME>
where
    T: RealField,
    S: System<T, N, U>,
    ME: Measurement<T, N, M>,
{
    /// Combines an arbitrary system model, its starting covariance and a measurement model.
    pub fn new_custom(system: S, initial_covariance: SMatrix<T, N, N>, measurement: ME) -> Self {
        let kalman = Kalman::new_custom(system, initial_covariance);
        Self {
            measurement,
            kalman,
        }
    }
}
impl<T: RealField, const N: usize, const U: usize, const M: usize, S, ME> KalmanFilter<T, N, S>
for Kalman1M<T, N, U, M, S, ME>
where
S: System<T, N, U>,
{
fn state(&self) -> &SVector<T, N> {
self.kalman.state()
}
fn covariance(&self) -> &SMatrix<T, N, N> {
self.kalman.covariance()
}
fn covariance_mut(&mut self) -> &mut SMatrix<T, N, N> {
self.kalman.covariance_mut()
}
fn system_mut(&mut self) -> &mut S {
self.kalman.system_mut()
}
}
impl<T, const N: usize, const U: usize, const M: usize, S, ME> KalmanPredictInput<T, N, U>
for Kalman1M<T, N, U, M, S, ME>
where
T: RealField + Copy,
S: InputSystem<T, N, U> + LinearisableSystem<T, N, U>,
ME: Measurement<T, N, M>,
{
type Error = KfError;
fn predict(&mut self, u: SVector<T, U>) -> Result<&SVector<T, N>, Self::Error> {
self.kalman.predict(u)
}
}
impl<T, const N: usize, const M: usize, S, ME> KalmanPredict<T, N> for Kalman1M<T, N, 0, M, S, ME>
where
T: RealField + Copy,
S: NoInputSystem<T, N> + LinearisableSystem<T, N, 0>,
ME: Measurement<T, N, M>,
{
type Error = KfError;
fn predict(&mut self) -> Result<&SVector<T, N>, Self::Error> {
self.kalman.predict()
}
}
impl<T, const N: usize, const U: usize, const M: usize, S, ME> Kalman1M<T, N, U, M, S, ME>
where
    T: RealField + Copy,
    S: System<T, N, U>,
    ME: LinearisableMeasurement<T, N, M>,
{
    /// Stores observation `z` in the bundled measurement model and runs the Kalman
    /// correction with it, returning the corrected state.
    ///
    /// # Errors
    /// Propagates [`KfError::InnovationCovarianceNotInvertible`] from the inner update.
    /// `z` has already been stored in the measurement model when that happens.
    pub fn update(&mut self, z: SVector<T, M>) -> Result<&SVector<T, N>, KfError> {
        self.measurement.set_measurement(z);
        // The inner update already yields the corrected state on success.
        self.kalman.update(&self.measurement)
    }
}
impl<T, const N: usize, const U: usize, const M: usize> Kalman1MLinear<T, N, U, M>
where
    T: RealField + Copy,
{
    /// Builds a linear filter with control input and a linear measurement model.
    ///
    /// `F`, `Q`, `B` and `x_initial` configure the system, while `H` and `R` configure the
    /// measurement, whose stored observation starts at zero.
    ///
    /// Note: this constructor takes no separate initial covariance; `Q` is also used
    /// as the starting estimate covariance.
    #[allow(non_snake_case)] // parameters keep the conventional matrix names
    pub fn new_with_input(
        F: SMatrix<T, N, N>,
        Q: SMatrix<T, N, N>,
        B: SMatrix<T, N, U>,
        H: SMatrix<T, M, N>,
        R: SMatrix<T, M, M>,
        x_initial: SVector<T, N>,
    ) -> Self {
        let kalman = Kalman::new_with_input(F, Q, B, x_initial, Q);
        let measurement = LinearMeasurement::new(H, R, SMatrix::zeros());
        Self {
            kalman,
            measurement,
        }
    }
}
impl<T, const N: usize, const M: usize> Kalman1MLinearNoInput<T, N, M>
where
    T: RealField + Copy,
{
    /// Builds a linear filter without control input and with a linear measurement model.
    ///
    /// `F`, `Q` and `x_initial` configure the system, while `H` and `R` configure the
    /// measurement, whose stored observation starts at zero.
    ///
    /// Note: this constructor takes no separate initial covariance; `Q` is also used
    /// as the starting estimate covariance.
    #[allow(non_snake_case)] // parameters keep the conventional matrix names
    pub fn new(
        F: SMatrix<T, N, N>,
        Q: SMatrix<T, N, N>,
        H: SMatrix<T, M, N>,
        R: SMatrix<T, M, M>,
        x_initial: SVector<T, N>,
    ) -> Self {
        let kalman = Kalman::new(F, Q, x_initial, Q);
        let measurement = LinearMeasurement::new(H, R, SMatrix::zeros());
        Self {
            kalman,
            measurement,
        }
    }
}
impl<T, const N: usize, const U: usize, const M: usize>
    Kalman1M<T, N, U, M, NonLinearSystem<T, N, U>, LinearMeasurement<T, N, M>>
where
    T: RealField + Copy,
{
    /// Builds an extended Kalman filter with a linear measurement model.
    ///
    /// `step_fn`, `x_initial` and `P_initial` configure the filter, while `H` and `R`
    /// configure the measurement, whose stored observation starts at zero.
    #[allow(non_snake_case)] // parameters keep the conventional matrix names
    pub fn new_ekf_with_input(
        step_fn: StepFunction<T, N, U>,
        H: SMatrix<T, M, N>,
        R: SMatrix<T, M, M>,
        x_initial: SVector<T, N>,
        P_initial: SMatrix<T, N, N>,
    ) -> Self {
        let kalman = Kalman::new_ekf_with_input(step_fn, x_initial, P_initial);
        let measurement = LinearMeasurement::new(H, R, SMatrix::zeros());
        Self {
            kalman,
            measurement,
        }
    }
}
#[cfg(test)]
mod test {
    use nalgebra::Matrix1;

    use crate::{
        kalman::{KalmanFilter, KalmanUpdate},
        measurement::LinearMeasurement,
    };

    /// A well-conditioned scalar update (P = 100, H = 1, R = 1, z = 0) must succeed
    /// and produce the textbook posterior: the state stays 0 and P becomes 100 / 101.
    #[test]
    fn does_not_panic() {
        let mut k = super::KalmanLinear::new_with_input(
            Matrix1::new(1.0),
            Matrix1::new(0.0),
            Matrix1::new(0.0),
            Matrix1::new(0.0),
            Matrix1::new(100.0),
        );
        k.update(&LinearMeasurement::new(
            Matrix1::identity(),
            Matrix1::identity(),
            Matrix1::new(0.0),
        ))
        .unwrap();

        assert_eq!(k.state()[0], 0.0);
        assert!((k.covariance()[(0, 0)] - 100.0 / 101.0).abs() < 1e-12);
    }

    /// With zero prior covariance and zero measurement noise the innovation covariance
    /// is singular, so `update` must fail with `InnovationCovarianceNotInvertible`.
    /// Pinning the panic message ensures an unrelated panic cannot satisfy this test.
    #[test]
    #[should_panic(expected = "InnovationCovarianceNotInvertible")]
    fn does_panic() {
        let mut k = super::KalmanLinear::new_with_input(
            Matrix1::new(1.0),
            Matrix1::new(0.0),
            Matrix1::new(0.0),
            Matrix1::new(0.0),
            Matrix1::zeros(),
        );
        k.update(&LinearMeasurement::new(
            Matrix1::identity(),
            Matrix1::zeros(),
            Matrix1::new(0.0),
        ))
        .unwrap();
    }
}