use std::collections::{HashMap, HashSet};
use apex_manifolds::{LieGroup, Tangent};
use faer::Mat;
use nalgebra::DVector;
#[derive(Clone, Debug)]
pub struct Variable<M: LieGroup> {
pub value: M,
pub fixed_indices: HashSet<usize>,
pub bounds: HashMap<usize, (f64, f64)>,
pub covariance: Option<faer::Mat<f64>>,
}
impl<M> Variable<M>
where
M: LieGroup + Clone + 'static,
M::TangentVector: Tangent<M>,
{
pub fn new(value: M) -> Self {
Variable {
value,
fixed_indices: HashSet::new(),
bounds: HashMap::new(),
covariance: None,
}
}
pub fn set_value(&mut self, value: M) {
self.value = value;
}
pub fn get_size(&self) -> usize {
self.value.tangent_dim()
}
pub fn plus(&self, tangent: &M::TangentVector) -> M {
self.value.plus(tangent, None, None)
}
pub fn minus(&self, other: &Self) -> M::TangentVector {
self.value.minus(&other.value, None, None)
}
pub fn get_covariance(&self) -> Option<&Mat<f64>> {
self.covariance.as_ref()
}
pub fn set_covariance(&mut self, cov: Mat<f64>) {
self.covariance = Some(cov);
}
pub fn clear_covariance(&mut self) {
self.covariance = None;
}
}
use apex_manifolds::rn::Rn;
impl Variable<Rn> {
pub fn to_vector(&self) -> DVector<f64> {
self.value.data().clone()
}
pub fn from_vector(values: DVector<f64>) -> Self {
Self::new(Rn::new(values))
}
pub fn update_variable(&mut self, mut tangent_delta: DVector<f64>) {
for (&idx, &(lower, upper)) in &self.bounds {
tangent_delta[idx] = tangent_delta[idx].max(lower).min(upper);
}
for &index_to_fix in &self.fixed_indices {
tangent_delta[index_to_fix] = self.value.data()[index_to_fix];
}
self.value = Rn::new(tangent_delta);
}
}
#[cfg(test)]
mod tests {
use super::*;
use apex_manifolds::{rn::Rn, se2::SE2, se3::SE3, so2::SO2, so3::SO3};
use nalgebra::{DVector, Quaternion, Vector3};
use std;
#[test]
fn test_variable_creation_rn() {
let vec_data = DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
let rn_value = Rn::new(vec_data);
let variable = Variable::new(rn_value);
assert_eq!(variable.get_size(), 5);
assert!(variable.fixed_indices.is_empty());
assert!(variable.bounds.is_empty());
}
#[test]
fn test_variable_creation_se2() {
let se2 = SE2::from_xy_angle(1.0, 2.0, 0.5);
let variable = Variable::new(se2);
assert_eq!(variable.get_size(), SE2::DOF);
assert!(variable.fixed_indices.is_empty());
assert!(variable.bounds.is_empty());
}
#[test]
fn test_variable_creation_se3() {
let se3 = SE3::from_translation_quaternion(
Vector3::new(1.0, 2.0, 3.0),
Quaternion::new(1.0, 0.0, 0.0, 0.0),
);
let variable = Variable::new(se3);
assert_eq!(variable.get_size(), SE3::DOF);
assert!(variable.fixed_indices.is_empty());
assert!(variable.bounds.is_empty());
}
#[test]
fn test_variable_creation_so2() {
let so2 = SO2::from_angle(0.5);
let variable = Variable::new(so2);
assert_eq!(variable.get_size(), SO2::DOF);
assert!(variable.fixed_indices.is_empty());
assert!(variable.bounds.is_empty());
}
#[test]
fn test_variable_creation_so3() {
let so3 = SO3::from_euler_angles(0.1, 0.2, 0.3);
let variable = Variable::new(so3);
assert_eq!(variable.get_size(), SO3::DOF);
assert!(variable.fixed_indices.is_empty());
assert!(variable.bounds.is_empty());
}
#[test]
fn test_variable_set_value() {
let initial_vec = DVector::from_vec(vec![1.0, 2.0, 3.0]);
let mut variable = Variable::new(Rn::new(initial_vec));
let new_vec = DVector::from_vec(vec![4.0, 5.0, 6.0, 7.0]);
variable.set_value(Rn::new(new_vec));
assert_eq!(variable.get_size(), 4);
let se2_initial = SE2::from_xy_angle(0.0, 0.0, 0.0);
let mut se2_variable = Variable::new(se2_initial);
let se2_new = SE2::from_xy_angle(1.0, 2.0, std::f64::consts::PI / 4.0);
se2_variable.set_value(se2_new);
assert_eq!(se2_variable.get_size(), SE2::DOF);
}
#[test]
fn test_variable_plus_minus_operations() {
let se2_1 = SE2::from_xy_angle(2.0, 3.0, std::f64::consts::PI / 2.0);
let se2_2 = SE2::from_xy_angle(1.0, 1.0, std::f64::consts::PI / 4.0);
let var1 = Variable::new(se2_1);
let var2 = Variable::new(se2_2);
let diff_tangent = var1.minus(&var2);
let var2_updated = var2.plus(&diff_tangent);
let final_diff = var1.minus(&Variable::new(var2_updated));
assert!(DVector::from(final_diff).norm() < 1e-10);
}
#[test]
fn test_variable_rn_plus_minus_operations() {
let rn_1 = Rn::new(DVector::from_vec(vec![1.0, 2.0, 3.0]));
let rn_2 = Rn::new(DVector::from_vec(vec![4.0, 5.0, 6.0]));
let var1 = Variable::new(rn_1);
let var2 = Variable::new(rn_2);
let diff_tangent = var1.minus(&var2);
assert_eq!(
diff_tangent.to_vector(),
DVector::from_vec(vec![-3.0, -3.0, -3.0])
);
let var2_updated = var2.plus(&diff_tangent);
assert_eq!(var2_updated.data(), &DVector::from_vec(vec![1.0, 2.0, 3.0]));
let final_diff = var1.minus(&Variable::new(var2_updated));
assert!(final_diff.to_vector().norm() < 1e-10);
}
#[test]
fn test_variable_update_with_bounds() {
let vec_data = DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
let mut variable = Variable::new(Rn::new(vec_data));
variable.bounds.insert(0, (-1.0, 1.0));
variable.bounds.insert(2, (0.0, 5.0));
let new_values = DVector::from_vec(vec![-5.0, 10.0, -3.0, 20.0, 30.0, 40.0]);
variable.update_variable(new_values);
let result_vec = variable.to_vector();
assert!(result_vec.len() == 6);
}
#[test]
fn test_variable_update_with_fixed_indices() {
let vec_data = DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]);
let mut variable = Variable::new(Rn::new(vec_data.clone()));
variable.fixed_indices.insert(1);
variable.fixed_indices.insert(4);
let delta_values = DVector::from_vec(vec![9.0, 18.0, 27.0, 36.0, 45.0, 54.0, 63.0, 72.0]);
variable.update_variable(delta_values);
let result_vec = variable.to_vector();
assert_eq!(result_vec[1], 2.0);
assert_eq!(result_vec[4], 5.0);
assert!(result_vec.len() == 8);
}
#[test]
fn test_variable_combined_bounds_and_fixed() {
let vec_data = DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]);
let mut variable = Variable::new(Rn::new(vec_data.clone()));
variable.bounds.insert(0, (-2.0, 2.0));
variable.bounds.insert(3, (-1.0, 1.0));
variable.fixed_indices.insert(1);
variable.fixed_indices.insert(5);
let delta_values = DVector::from_vec(vec![-5.0, 100.0, 30.0, 10.0, 50.0, 600.0, 70.0]);
variable.update_variable(delta_values);
let result = variable.to_vector();
assert_eq!(result[1], 2.0);
assert_eq!(result[5], 6.0);
assert!(result.len() == 7);
}
#[test]
fn test_variable_type_safety() {
let se2_var = Variable::new(SE2::from_xy_angle(1.0, 2.0, 0.5));
let se3_var = Variable::new(SE3::from_translation_quaternion(
Vector3::new(1.0, 2.0, 3.0),
Quaternion::new(1.0, 0.0, 0.0, 0.0),
));
let so2_var = Variable::new(SO2::from_angle(0.5));
let so3_var = Variable::new(SO3::from_euler_angles(0.1, 0.2, 0.3));
let rn_var = Variable::new(Rn::new(DVector::from_vec(vec![1.0, 2.0, 3.0])));
assert_eq!(se2_var.get_size(), SE2::DOF);
assert_eq!(se3_var.get_size(), SE3::DOF);
assert_eq!(so2_var.get_size(), SO2::DOF);
assert_eq!(so3_var.get_size(), SO3::DOF);
assert_eq!(rn_var.get_size(), 3);
}
#[test]
fn test_variable_vector_conversion_roundtrip() {
let original_data = DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0]);
let rn_var = Variable::new(Rn::new(original_data.clone()));
let vec_repr = rn_var.to_vector();
assert_eq!(vec_repr, original_data);
let reconstructed_var = Variable::<Rn>::from_vector(vec_repr);
assert_eq!(reconstructed_var.to_vector(), original_data);
}
#[test]
fn test_variable_manifold_operations_consistency() {
let rn_initial = Rn::new(DVector::from_vec(vec![1.0, 2.0, 3.0]));
let mut rn_var = Variable::new(rn_initial);
let rn_new_values = DVector::from_vec(vec![2.0, 3.0, 4.0]);
rn_var.update_variable(rn_new_values);
let rn_result = rn_var.to_vector();
assert_eq!(rn_result, DVector::from_vec(vec![2.0, 3.0, 4.0]));
let se2_1 = SE2::from_xy_angle(2.0, 3.0, std::f64::consts::PI / 2.0);
let se2_2 = SE2::from_xy_angle(1.0, 1.0, std::f64::consts::PI / 4.0);
let var1 = Variable::new(se2_1);
let var2 = Variable::new(se2_2);
let diff_tangent = var1.minus(&var2);
let var2_updated = var2.plus(&diff_tangent);
let final_diff = var1.minus(&Variable::new(var2_updated));
assert!(DVector::from(final_diff).norm() < 1e-10);
}
#[test]
fn test_variable_constraints_interaction() {
let rn_data = DVector::from_vec(vec![0.0, 0.0, 0.0, 0.0, 0.0]);
let mut rn_var = Variable::new(Rn::new(rn_data));
rn_var.bounds.insert(0, (-1.0, 1.0));
rn_var.bounds.insert(2, (-10.0, 10.0));
rn_var.fixed_indices.insert(1);
rn_var.fixed_indices.insert(4);
let large_delta = DVector::from_vec(vec![5.0, 100.0, 15.0, 20.0, 200.0]);
rn_var.update_variable(large_delta);
let result = rn_var.to_vector();
assert_eq!(result[0], 1.0);
assert_eq!(result[1], 0.0);
assert_eq!(result[2], 10.0);
assert_eq!(result[3], 20.0);
assert_eq!(result[4], 0.0);
}
}