use nalgebra::DMatrix;
/// Process-noise settings, used optionally by both [`EKFConfig`] and [`UKFConfig`].
///
/// NOTE(review): how `q_matrix` is applied is decided by the consuming filter, not here;
/// presumably it is added to the propagated covariance each step — confirm.
#[derive(Clone, Debug, PartialEq)]
pub struct ProcessNoiseConfig {
    /// Process-noise matrix Q.
    ///
    /// Presumably square with the filter's state dimension — this type does not check it.
    pub q_matrix: DMatrix<f64>,
    /// Whether the noise should be scaled with the propagation interval `dt`.
    ///
    /// NOTE(review): the exact scaling law lives in the filter code — confirm there.
    pub scale_with_dt: bool,
}
/// Settings for the EKF estimator.
#[derive(Debug, Clone)]
pub struct EKFConfig {
    /// Optional process-noise model; `None` by default.
    pub process_noise: Option<ProcessNoiseConfig>,
    /// Whether the filter keeps records of its steps; `true` by default.
    ///
    /// NOTE(review): the content of a "record" is defined by the filter — confirm there.
    pub store_records: bool,
}

impl Default for EKFConfig {
    /// No process-noise model, with record storage switched on.
    fn default() -> Self {
        let store_records = true;
        EKFConfig {
            store_records,
            process_noise: None,
        }
    }
}
/// Settings for the UKF estimator.
#[derive(Debug, Clone)]
pub struct UKFConfig {
    /// Optional process-noise model; `None` by default.
    pub process_noise: Option<ProcessNoiseConfig>,
    /// Dimension of the estimated state; `6` by default.
    ///
    /// NOTE(review): presumably must match the state/covariance handed to the filter —
    /// this type does not check it.
    pub state_dim: usize,
    /// Unscented-transform parameter alpha; `1e-3` by default.
    ///
    /// NOTE(review): conventionally this controls sigma-point spread — confirm in the filter.
    pub alpha: f64,
    /// Unscented-transform parameter beta; `2.0` by default.
    pub beta: f64,
    /// Unscented-transform parameter kappa; `0.0` by default.
    pub kappa: f64,
    /// Whether the filter keeps records of its steps; `true` by default.
    pub store_records: bool,
}

impl Default for UKFConfig {
    /// Six-dimensional state, no process noise, record storage on, and
    /// `(alpha, beta, kappa) = (1e-3, 2.0, 0.0)`.
    fn default() -> Self {
        let (alpha, beta, kappa) = (1e-3, 2.0, 0.0);
        UKFConfig {
            state_dim: 6,
            alpha,
            beta,
            kappa,
            process_noise: None,
            store_records: true,
        }
    }
}
/// Formulation used by the BLS solver; selected through [`BLSConfig::solver_method`].
///
/// Derives `PartialEq`/`Eq`/`Hash` so callers can compare methods with `==` or use them as
/// map keys instead of reaching for `matches!`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum BLSSolverMethod {
    /// Solve using the stacked observation matrix.
    ///
    /// NOTE(review): the numerical method behind this variant lives in the solver — confirm.
    StackedObservationMatrix,
    /// Solve via the normal equations; this is the [`BLSConfig`] default.
    NormalEquations,
}
/// Consider-parameter settings, used optionally by [`BLSConfig`].
///
/// NOTE(review): presumably the state is split into `n_solve` estimated ("solve-for")
/// parameters plus consider parameters whose uncertainty is `consider_covariance` —
/// confirm against the solver.
#[derive(Clone, Debug, PartialEq)]
pub struct ConsiderParameterConfig {
    /// Number of solve-for parameters.
    pub n_solve: usize,
    /// Covariance of the consider parameters.
    ///
    /// Presumably square, sized to the number of consider parameters — not checked here.
    pub consider_covariance: DMatrix<f64>,
}
/// Settings for the BLS estimator.
#[derive(Debug, Clone)]
pub struct BLSConfig {
    /// Solver formulation; [`BLSSolverMethod::NormalEquations`] by default.
    pub solver_method: BLSSolverMethod,
    /// Maximum number of iterations; `10` by default.
    pub max_iterations: usize,
    /// Optional convergence threshold on the state correction; `Some(1e-8)` by default.
    ///
    /// NOTE(review): how `None` is treated is decided by the solver — confirm there.
    pub state_correction_threshold: Option<f64>,
    /// Optional convergence threshold on the cost; `None` by default.
    pub cost_convergence_threshold: Option<f64>,
    /// Optional consider-parameter settings; `None` by default.
    pub consider_params: Option<ConsiderParameterConfig>,
    /// Whether per-iteration records are kept; `true` by default.
    pub store_iteration_records: bool,
    /// Whether observation residuals are kept; `true` by default.
    pub store_observation_residuals: bool,
}

impl Default for BLSConfig {
    /// Normal-equations solver, at most 10 iterations, state-correction threshold `1e-8`,
    /// no cost threshold, no consider parameters, and both storage flags on.
    fn default() -> Self {
        let state_correction_threshold = Some(1e-8);
        BLSConfig {
            max_iterations: 10,
            solver_method: BLSSolverMethod::NormalEquations,
            state_correction_threshold,
            cost_convergence_threshold: None,
            consider_params: None,
            store_iteration_records: true,
            store_observation_residuals: true,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use nalgebra::DMatrix;

    // Pins every documented EKF default.
    #[test]
    fn test_ekf_config_default() {
        let config = EKFConfig::default();
        assert!(config.process_noise.is_none());
        assert!(config.store_records);
    }

    // Pins every documented UKF default, including the unscented-transform parameters.
    #[test]
    fn test_ukf_config_default() {
        let config = UKFConfig::default();
        assert!(config.process_noise.is_none());
        assert_eq!(config.state_dim, 6);
        assert_eq!(config.alpha, 1e-3);
        assert_eq!(config.beta, 2.0);
        assert_eq!(config.kappa, 0.0);
        assert!(config.store_records);
    }

    // Pins every documented BLS default.
    #[test]
    fn test_bls_config_default() {
        let config = BLSConfig::default();
        assert_eq!(config.max_iterations, 10);
        assert_eq!(config.state_correction_threshold, Some(1e-8));
        assert_eq!(config.cost_convergence_threshold, None);
        assert!(config.consider_params.is_none());
        assert!(config.store_iteration_records);
        assert!(config.store_observation_residuals);
        assert!(matches!(
            config.solver_method,
            BLSSolverMethod::NormalEquations
        ));
    }

    // Clone keeps the variant; the derived Debug prints the bare variant name.
    #[test]
    fn test_bls_solver_method_clone_debug() {
        let method = BLSSolverMethod::StackedObservationMatrix;
        let cloned = method.clone();
        assert!(matches!(cloned, BLSSolverMethod::StackedObservationMatrix));
        assert_eq!(format!("{:?}", method), "StackedObservationMatrix");
    }

    // Fields are stored as given; the covariance is moved in without an extra clone.
    #[test]
    fn test_consider_parameter_config() {
        let cov: DMatrix<f64> = DMatrix::identity(2, 2) * 100.0;
        let config = ConsiderParameterConfig {
            n_solve: 6,
            consider_covariance: cov,
        };
        assert_eq!(config.n_solve, 6);
        assert_eq!(config.consider_covariance.nrows(), 2);
        assert_eq!(config.consider_covariance.ncols(), 2);
    }
}