#![forbid(unsafe_code)]
#[cfg(feature = "std")]
#[allow(unused)]
use colored::Colorize;
use minikalman::regular::builder::KalmanFilterBuilder;
use rand_distr::{Distribution, Normal};
use minikalman::prelude::*;
/// Number of state variables: position `s`, velocity `v` and acceleration `a`.
const NUM_STATES: usize = 3;
/// Number of control inputs fed into the filter.
const NUM_CONTROLS: usize = 1;
/// Number of observed quantities (only the position is measured).
const NUM_OBSERVATIONS: usize = 1;
/// Runs a three-state (position, velocity, acceleration) Kalman filter over
/// simulated noisy position measurements of a falling object. It then checks
/// that the final acceleration estimate lies strictly between 9 and 10 m/s².
fn main() {
    // Number of simulated time steps, i.e. number of measurements processed.
    const NUM_STEPS: usize = 100;

    let builder = KalmanFilterBuilder::<NUM_STATES, f32>::default();
    let mut filter = builder.build();
    let mut control = builder.controls().build::<NUM_CONTROLS>();
    let mut measurement = builder.observations().build::<NUM_OBSERVATIONS>();

    // Filter: initial state estimate, state transition and estimate covariance.
    initialize_state_vector(filter.state_vector_mut());
    initialize_state_transition_matrix(filter.state_transition_mut());
    initialize_state_covariance_matrix(filter.estimate_covariance_mut());

    // Control input: vector, input matrix and its process noise.
    initialize_control_vector(control.control_vector_mut());
    initialize_control_matrix(control.control_matrix_mut());
    initialize_control_covariance_matrix(control.process_noise_covariance_mut());

    // Observation model and measurement noise.
    initialize_position_measurement_transformation_matrix(measurement.observation_matrix_mut());
    initialize_position_measurement_process_noise_matrix(
        measurement.measurement_noise_covariance_mut(),
    );

    // Simulated true positions and the additive noise applied to each one.
    let measurements = generate_values(NUM_STEPS);
    let measurement_noise = generate_error(NUM_STEPS);

    for (t, (m, err)) in measurements
        .iter()
        .copied()
        .zip(measurement_noise)
        .enumerate()
    {
        // Time update, then apply the control input.
        filter.predict();
        filter.control(&mut control);
        print_state_prediction(t, filter.state_vector());

        // Measurement update with the noisy position reading.
        measurement
            .measurement_vector_mut()
            .apply(|z| z[0] = m + err);
        print_measurement(t, m, err);
        filter.correct(&mut measurement);
        print_state_correction(t, filter.state_vector());
    }

    // The state vector is a column (NUM_STATES rows x 1 column), so the
    // acceleration (gravity) estimate sits at row 2, column 0.
    let g_estimated = filter.state_vector().as_matrix().get_at(2, 0);
    assert!(
        g_estimated > 9.0 && g_estimated < 10.0,
        "estimated gravity {} m/s² is outside (9, 10)",
        g_estimated
    );
}
/// Simulates `n` positions of an object falling from rest.
///
/// The object falls under a constant acceleration of 9.81 m/s². Positions are
/// sampled every 1 s; the first returned value is the position after the
/// first step.
fn generate_values(n: usize) -> Vec<f32> {
    const GRAVITY: f32 = 9.81;
    const DELTA_T: f32 = 1.0;

    let mut positions = Vec::with_capacity(n);
    let (mut position, mut velocity) = (0.0_f32, 0.0_f32);
    for _ in 0..n {
        // Kinematic update. The operation order matches the recurrence exactly,
        // so f32 rounding is reproduced.
        position = position + velocity * DELTA_T + GRAVITY * 0.5 * DELTA_T * DELTA_T;
        velocity += GRAVITY * DELTA_T;
        positions.push(position);
    }
    positions
}
/// Draws `n` samples of zero-mean Gaussian measurement noise.
///
/// Samples come from a normal distribution with standard deviation 0.1 and are
/// then scaled by `0.5 * 0.5`. The effective standard deviation is 0.025.
fn generate_error(n: usize) -> Vec<f32> {
    let distribution = Normal::<f32>::new(0.0, 0.1).unwrap();
    let mut rng = rand::thread_rng();

    let mut errors: Vec<f32> = Vec::with_capacity(n);
    for _ in 0..n {
        errors.push(0.5 * 0.5 * distribution.sample(&mut rng));
    }
    errors
}
/// Sets the initial state estimate `x = [s, v, a]`.
///
/// The object starts at the origin with zero velocity. The acceleration guess
/// is 6 m/s², away from the 9.81 m/s² used to simulate the measurements.
fn initialize_state_vector(filter: &mut impl StateVectorMut<NUM_STATES, f32>) {
    filter.as_matrix_mut().apply(|x| {
        x[0] = 0.0; // position s
        x[1] = 0.0; // velocity v
        x[2] = 6.0; // acceleration a (the quantity being estimated)
    });
}
fn initialize_state_transition_matrix(filter: &mut impl StateTransitionMatrixMut<NUM_STATES, f32>) {
filter.as_matrix_mut().apply(|a| {
const T: f32 = 1 as _;
a.set_at(0, 0, 1 as _); a.set_at(0, 1, T as _); a.set_at(0, 2, 0.5 * T * T);
a.set_at(1, 0, 0 as _); a.set_at(1, 1, 1 as _); a.set_at(1, 2, T as _);
a.set_at(2, 0, 0 as _); a.set_at(2, 1, 0 as _); a.set_at(2, 2, 1 as _); });
}
/// Sets the initial estimate covariance matrix `P`.
///
/// The components of the initial state are treated as uncorrelated, so only
/// the diagonal is non-zero: `var(s) = 0.1`, `var(v) = 1`, `var(a) = 1`.
fn initialize_state_covariance_matrix(filter: &mut impl EstimateCovarianceMatrix<NUM_STATES, f32>) {
    filter.as_matrix_mut().apply(|p| {
        // Diagonal: per-component variances.
        p.set_at(0, 0, 0.1);
        p.set_at(1, 1, 1.0);
        p.set_at(2, 2, 1.0);

        // Off-diagonal covariances. P must be symmetric, so both triangles are
        // written explicitly rather than relying on the buffer's initial contents.
        for (row, col) in [(0, 1), (0, 2), (1, 2)] {
            p.set_at(row, col, 0.0);
            p.set_at(col, row, 0.0);
        }
    });
}
/// Sets the single control input `u` to zero.
fn initialize_control_vector(filter: &mut impl ControlVectorMut<NUM_CONTROLS, f32>) {
    filter.as_matrix_mut().apply(|u| u[0] = 0.0);
}
/// Sets the control input matrix `B = [0, 0, 1]ᵀ`.
///
/// The control input acts only on the acceleration component.
fn initialize_control_matrix(filter: &mut impl ControlMatrixMut<NUM_STATES, NUM_CONTROLS, f32>) {
    filter.as_matrix_mut().apply(|b| {
        b[0] = 0.0; // position: not driven by the control input
        b[1] = 0.0; // velocity: not driven by the control input
        b[2] = 1.0; // acceleration: driven directly
    });
}
/// Sets the process noise covariance of the control input to 1.
fn initialize_control_covariance_matrix(
    filter: &mut impl ControlProcessNoiseCovarianceMatrixMut<NUM_CONTROLS, f32>,
) {
    filter.as_matrix_mut().apply(|q| q[0] = 1.0);
}
/// Sets the observation matrix `H = [1, 0, 0]`.
///
/// Only the position component of the state is observed.
fn initialize_position_measurement_transformation_matrix(
    measurement: &mut impl ObservationMatrixMut<NUM_OBSERVATIONS, NUM_STATES, f32>,
) {
    measurement.as_matrix_mut().apply(|h| {
        // Position maps one-to-one onto the measurement.
        h.set_at(0, 0, 1.0);
        // Velocity and acceleration are not measured directly.
        h.set_at(0, 1, 0.0);
        h.set_at(0, 2, 0.0);
    });
}
/// Sets the measurement noise covariance `R` of the position reading to 0.5.
///
/// NOTE(review): `generate_error` injects noise with a standard deviation of
/// 0.025 (variance 0.000625), far below this `R`. Presumably this is a
/// deliberate tuning choice; confirm.
fn initialize_position_measurement_process_noise_matrix(
    measurement: &mut impl MeasurementNoiseCovarianceMatrix<NUM_OBSERVATIONS, f32>,
) {
    measurement
        .as_matrix_mut()
        .apply(|r| r.set_at(0, 0, 0.5));
}
/// Prints the predicted state `[s, v, a]` for time step `t`.
///
/// Output is produced only when the `std` feature is enabled; otherwise the
/// function does nothing.
///
/// # Panics
///
/// With the `std` feature enabled, panics if `x` has fewer than three elements.
#[allow(unused)] // `t` and `x` go unused when the `std` feature is disabled.
fn print_state_prediction<T>(t: usize, x: T)
where
    T: AsRef<[f32]>,
{
    let x = x.as_ref();
    #[cfg(feature = "std")]
    println!(
        "At t = {}, predicted state: s = {}, v = {}, a = {}",
        t.to_string().bright_white(),
        format!("{} m", x[0]).magenta(),
        format!("{} m/s", x[1]).magenta(),
        format!("{} m/s²", x[2]).magenta(),
    );
}
/// Prints the corrected state `[s, v, a]` for time step `t`.
///
/// Output is produced only when the `std` feature is enabled; otherwise the
/// function does nothing.
///
/// # Panics
///
/// With the `std` feature enabled, panics if `x` has fewer than three elements.
#[allow(unused)] // `t` and `x` go unused when the `std` feature is disabled.
fn print_state_correction<T>(t: usize, x: T)
where
    T: AsRef<[f32]>,
{
    let x = x.as_ref();
    #[cfg(feature = "std")]
    println!(
        "At t = {}, corrected state: s = {}, v = {}, a = {}",
        t.to_string().bright_white(),
        format!("{} m", x[0]).yellow(),
        format!("{} m/s", x[1]).yellow(),
        format!("{} m/s²", x[2]).yellow(),
    );
}
/// Prints the true position `real` and the added noise `error` for time step `t`.
///
/// Output is produced only when the `std` feature is enabled; otherwise the
/// function does nothing.
#[allow(unused)] // All parameters go unused when the `std` feature is disabled.
fn print_measurement(t: usize, real: f32, error: f32) {
    #[cfg(feature = "std")]
    println!(
        "At t = {}, measurement: s = {}, noise ε = {}",
        t.to_string().bright_white(),
        format!("{} m", real).green(),
        format!("{} m", error).blue()
    );
}