use core::marker::PhantomData;
use nalgebra::RealField;
use num_traits::Float;
use crate::models::{ObservationModel, TransitionModel};
use crate::types::spaces::{
ComputeInnovation, Measurement, MeasurementCovariance, StateCovariance, StateVector,
};
use crate::types::transforms::{
ObservationMatrix, TransitionMatrix, compute_innovation_covariance, compute_kalman_gain,
joseph_update,
};
/// Gaussian state estimate carried through the filter: a mean vector plus its covariance.
#[derive(Debug, Clone, PartialEq)]
pub struct KalmanState<T: RealField, const N: usize> {
    /// Estimated state vector with `N` components.
    pub mean: StateVector<T, N>,
    /// Covariance of the estimate (an `N`×`N` matrix).
    pub covariance: StateCovariance<T, N>,
}
impl<T: RealField + Copy, const N: usize> KalmanState<T, N> {
    /// Creates a state from an explicit mean and covariance.
    #[inline]
    pub fn new(mean: StateVector<T, N>, covariance: StateCovariance<T, N>) -> Self {
        Self { mean, covariance }
    }

    /// Creates a state whose covariance is the `N`×`N` identity matrix.
    #[inline]
    pub fn with_identity_covariance(mean: StateVector<T, N>) -> Self {
        Self::new(mean, StateCovariance::identity())
    }

    /// Creates a state with a diagonal covariance built from `diagonal`.
    ///
    /// `diagonal` supplies the main-diagonal entries (presumably per-component
    /// variances); all off-diagonal entries come from `StateCovariance::from_diagonal`.
    #[inline]
    pub fn with_diagonal_covariance(
        mean: StateVector<T, N>,
        diagonal: &nalgebra::SVector<T, N>,
    ) -> Self {
        Self::new(mean, StateCovariance::from_diagonal(diagonal))
    }

    /// Scalar summary of the estimate's uncertainty: the trace of the covariance.
    #[inline]
    pub fn uncertainty(&self) -> T {
        self.covariance.trace()
    }

    /// Copies the first `P` components of the mean into an array.
    ///
    /// Callers use this to pull out the position part of the state, which assumes
    /// the state layout stores position in its leading components.
    ///
    /// # Panics
    ///
    /// Panics if `P > N`, i.e. more components are requested than the state holds.
    #[inline]
    pub fn position<const P: usize>(&self) -> [T; P] {
        // Fail loudly with a clear message instead of an out-of-bounds index deep
        // inside the vector type. Both sides are constants, so this folds away.
        assert!(
            P <= N,
            "position::<{}>() requested more components than the {}-dimensional state holds",
            P,
            N
        );
        let mut pos = [T::zero(); P];
        for (i, p) in pos.iter_mut().enumerate() {
            *p = *self.mean.index(i);
        }
        pos
    }
}
/// Linear Kalman filter combining a transition (motion) model with an observation
/// (sensor) model.
///
/// `N` is the state dimension and `M` the measurement dimension. The trait bounds that
/// tie `Trans`/`Obs` to these dimensions live on the `impl` blocks, not on the struct,
/// so that merely naming the type does not force callers to restate them.
#[derive(Debug, Clone)]
pub struct KalmanFilter<T, Trans, Obs, const N: usize, const M: usize> {
    /// Model providing the transition matrix and process noise.
    pub transition: Trans,
    /// Model providing the observation matrix and measurement noise.
    pub observation: Obs,
    // Records the scalar type `T`, which no field stores directly.
    _marker: PhantomData<T>,
}
impl<T, Trans, Obs, const N: usize, const M: usize> KalmanFilter<T, Trans, Obs, N, M>
where
T: RealField + Float + Copy,
Trans: TransitionModel<T, N>,
Obs: ObservationModel<T, N, M>,
{
#[inline]
pub fn new(transition: Trans, observation: Obs) -> Self {
Self {
transition,
observation,
_marker: PhantomData,
}
}
pub fn predict(&self, state: &KalmanState<T, N>, dt: T) -> KalmanState<T, N> {
let f = self.transition.transition_matrix(dt);
let q = self.transition.process_noise(dt);
let predicted_mean = f.apply_state(&state.mean);
let predicted_cov = f.propagate_covariance(&state.covariance).add(&q);
KalmanState {
mean: predicted_mean,
covariance: predicted_cov,
}
}
pub fn update(
&self,
state: &KalmanState<T, N>,
measurement: &Measurement<T, M>,
) -> Option<KalmanState<T, N>> {
let h = self.observation.observation_matrix();
let r = self.observation.measurement_noise();
self.update_with_matrices(state, measurement, &h, &r)
}
pub fn update_with_matrices(
&self,
state: &KalmanState<T, N>,
measurement: &Measurement<T, M>,
obs_matrix: &ObservationMatrix<T, M, N>,
meas_noise: &MeasurementCovariance<T, M>,
) -> Option<KalmanState<T, N>> {
let predicted_meas = obs_matrix.observe(&state.mean);
let innovation = measurement.innovation(predicted_meas);
let innovation_cov =
compute_innovation_covariance(&state.covariance, obs_matrix, meas_noise);
let kalman_gain = compute_kalman_gain(&state.covariance, obs_matrix, &innovation_cov)?;
let correction = kalman_gain.correct(&innovation);
let updated_mean =
StateVector::from_svector(state.mean.as_svector() + correction.as_svector());
let updated_cov = joseph_update(&state.covariance, &kalman_gain, obs_matrix, meas_noise);
Some(KalmanState {
mean: updated_mean,
covariance: updated_cov,
})
}
pub fn step(
&self,
state: &KalmanState<T, N>,
dt: T,
measurement: &Measurement<T, M>,
) -> Option<KalmanState<T, N>> {
let predicted = self.predict(state, dt);
self.update(&predicted, measurement)
}
pub fn measurement_likelihood(
&self,
state: &KalmanState<T, N>,
measurement: &Measurement<T, M>,
) -> Option<T> {
let h = self.observation.observation_matrix();
let r = self.observation.measurement_noise();
let predicted_meas = h.observe(&state.mean);
let innovation = measurement.innovation(predicted_meas);
let innovation_cov = compute_innovation_covariance(&state.covariance, &h, &r);
let innovation_cov_typed =
crate::types::spaces::Covariance::from_matrix(*innovation_cov.as_matrix());
crate::types::gaussian::gaussian_likelihood(&innovation, &innovation_cov_typed)
}
pub fn mahalanobis_distance_squared(
&self,
state: &KalmanState<T, N>,
measurement: &Measurement<T, M>,
) -> Option<T> {
let h = self.observation.observation_matrix();
let r = self.observation.measurement_noise();
let predicted_meas = h.observe(&state.mean);
let innovation = measurement.innovation(predicted_meas);
let innovation_cov = compute_innovation_covariance(&state.covariance, &h, &r);
let s_inv = innovation_cov.as_matrix().try_inverse()?;
let y = innovation.as_svector();
let d_sq = (y.transpose() * s_inv * y)[(0, 0)];
Some(d_sq)
}
}
pub fn predict<T: RealField + Copy, const N: usize>(
state: &KalmanState<T, N>,
transition: &TransitionMatrix<T, N>,
process_noise: &StateCovariance<T, N>,
) -> KalmanState<T, N> {
let predicted_mean = transition.apply_state(&state.mean);
let predicted_cov = transition
.propagate_covariance(&state.covariance)
.add(process_noise);
KalmanState {
mean: predicted_mean,
covariance: predicted_cov,
}
}
pub fn update<T: RealField + Copy, const N: usize, const M: usize>(
state: &KalmanState<T, N>,
measurement: &Measurement<T, M>,
obs_matrix: &ObservationMatrix<T, M, N>,
meas_noise: &MeasurementCovariance<T, M>,
) -> Option<KalmanState<T, N>> {
let predicted_meas = obs_matrix.observe(&state.mean);
let innovation = measurement.innovation(predicted_meas);
let innovation_cov = compute_innovation_covariance(&state.covariance, obs_matrix, meas_noise);
let kalman_gain = compute_kalman_gain(&state.covariance, obs_matrix, &innovation_cov)?;
let correction = kalman_gain.correct(&innovation);
let updated_mean = StateVector::from_svector(state.mean.as_svector() + correction.as_svector());
let updated_cov = joseph_update(&state.covariance, &kalman_gain, obs_matrix, meas_noise);
Some(KalmanState {
mean: updated_mean,
covariance: updated_cov,
})
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::{ConstantVelocity2D, PositionSensor2D};

    /// Tolerance for comparisons that should differ only by rounding.
    const EPS: f64 = 1e-10;

    /// Constant-velocity model (0.1, 0.99) paired with a position sensor (1.0, 0.95).
    fn make_filter() -> KalmanFilter<f64, ConstantVelocity2D<f64>, PositionSensor2D<f64>, 4, 2>
    {
        KalmanFilter::new(
            ConstantVelocity2D::new(0.1, 0.99),
            PositionSensor2D::new(1.0, 0.95),
        )
    }

    /// A state with the given mean and an identity covariance.
    fn unit_state(mean: [f64; 4]) -> KalmanState<f64, 4> {
        KalmanState::new(StateVector::from_array(mean), StateCovariance::identity())
    }

    /// `true` when `actual` lies strictly within `tol` of `expected`.
    fn near(actual: f64, expected: f64, tol: f64) -> bool {
        (actual - expected).abs() < tol
    }

    #[test]
    fn test_kalman_state_creation() {
        let state = unit_state([0.0, 0.0, 1.0, 0.0]);
        assert!(near(*state.mean.index(2), 1.0, EPS));
        // Trace of the 4x4 identity covariance.
        assert!(near(state.uncertainty(), 4.0, EPS));
    }

    #[test]
    fn test_kalman_predict() {
        let filter = make_filter();
        let state = unit_state([0.0, 0.0, 10.0, 0.0]);
        let predicted = filter.predict(&state, 1.0);
        let expected = [10.0, 0.0, 10.0, 0.0];
        for (i, &want) in expected.iter().enumerate() {
            assert!(near(*predicted.mean.index(i), want, EPS), "component {}", i);
        }
        assert!(predicted.uncertainty() > state.uncertainty());
    }

    #[test]
    fn test_kalman_update() {
        let filter = make_filter();
        let state = KalmanState::new(
            StateVector::from_array([0.0; 4]),
            StateCovariance::from_matrix(nalgebra::SMatrix::<f64, 4, 4>::identity().scale(100.0)),
        );
        let updated = filter
            .update(&state, &Measurement::from_array([10.0, 5.0]))
            .expect("update should succeed");
        assert!(*updated.mean.index(0) > 5.0);
        assert!(*updated.mean.index(1) > 2.0);
        assert!(updated.uncertainty() < state.uncertainty());
    }

    #[test]
    fn test_kalman_step() {
        let filter = make_filter();
        let state = unit_state([0.0, 0.0, 10.0, 5.0]);
        let measurement = Measurement::from_array([10.0, 5.0]);
        let updated = filter
            .step(&state, 1.0, &measurement)
            .expect("step should succeed");
        assert!(near(*updated.mean.index(0), 10.0, 1.0));
        assert!(near(*updated.mean.index(1), 5.0, 1.0));
    }

    #[test]
    fn test_measurement_likelihood() {
        let filter = make_filter();
        let state = unit_state([10.0, 5.0, 0.0, 0.0]);
        let likelihood = |z: [f64; 2]| {
            filter
                .measurement_likelihood(&state, &Measurement::from_array(z))
                .expect("likelihood should be defined")
        };
        assert!(likelihood([10.0, 5.0]) > likelihood([100.0, 100.0]));
    }

    #[test]
    fn test_mahalanobis_distance() {
        let filter = make_filter();
        let state = unit_state([0.0, 0.0, 0.0, 0.0]);
        let distance_sq = |z: [f64; 2]| {
            filter
                .mahalanobis_distance_squared(&state, &Measurement::from_array(z))
                .expect("innovation covariance should be invertible")
        };
        let close_dist = distance_sq([0.0, 0.0]);
        let far_dist = distance_sq([10.0, 10.0]);
        assert!(close_dist < far_dist);
        assert!(close_dist < 0.1);
    }

    #[test]
    fn test_standalone_functions() {
        let f = TransitionMatrix::from_matrix(nalgebra::matrix![
            1.0, 0.0, 1.0, 0.0_f64;
            0.0, 1.0, 0.0, 1.0;
            0.0, 0.0, 1.0, 0.0;
            0.0, 0.0, 0.0, 1.0
        ]);
        let q = StateCovariance::from_matrix(nalgebra::SMatrix::<f64, 4, 4>::identity().scale(0.1));
        let predicted = predict(&unit_state([0.0, 0.0, 5.0, 3.0]), &f, &q);
        assert!(near(*predicted.mean.index(0), 5.0, EPS));
        assert!(near(*predicted.mean.index(1), 3.0, EPS));

        let h = ObservationMatrix::from_matrix(nalgebra::matrix![
            1.0, 0.0, 0.0, 0.0_f64;
            0.0, 1.0, 0.0, 0.0
        ]);
        let r = MeasurementCovariance::from_matrix(nalgebra::SMatrix::<f64, 2, 2>::identity());
        let updated = update(&predicted, &Measurement::from_array([5.0, 3.0]), &h, &r)
            .expect("standalone update should succeed");
        assert!(near(*updated.mean.index(0), 5.0, 1.0));
        assert!(near(*updated.mean.index(1), 3.0, 1.0));
    }

    #[test]
    fn test_position_extraction() {
        let state = unit_state([10.0, 20.0, 1.0, 2.0]);
        let [x, y]: [f64; 2] = state.position();
        assert!(near(x, 10.0, EPS));
        assert!(near(y, 20.0, EPS));
    }

    #[test]
    fn test_type_safety_compiles() {
        let _filter: KalmanFilter<f64, ConstantVelocity2D<f64>, PositionSensor2D<f64>, 4, 2> =
            KalmanFilter::new(
                ConstantVelocity2D::new(1.0, 0.99),
                PositionSensor2D::new(1.0, 0.95),
            );
    }
}