use crate::base::{
error::StateSpaceError,
space::{AnyStateSpace, CompoundStateSpace, RealVectorStateSpace, SO2StateSpace, StateSpace},
state::SE2State,
};
#[derive(Clone)]
pub struct SE2StateSpace(pub CompoundStateSpace);
impl SE2StateSpace {
pub fn new(
weight: f64,
bounds_option: Option<Vec<(f64, f64)>>,
) -> Result<Self, StateSpaceError> {
let (r2, so2) = match bounds_option {
Some(bounds) => {
if bounds.len() != 3 {
return Err(StateSpaceError::DimensionMismatch {
expected: 3,
found: bounds.len(),
});
} else {
(
RealVectorStateSpace::new(2, Some(vec![bounds[0], bounds[1]]))?,
SO2StateSpace::new(Some(bounds[2]))?,
)
}
}
None => (
RealVectorStateSpace::new(2, None)?,
SO2StateSpace::new(None)?,
),
};
let compound_space =
CompoundStateSpace::new(vec![Box::new(r2), Box::new(so2)], vec![1.0, weight]);
Ok(SE2StateSpace(compound_space))
}
}
impl StateSpace for SE2StateSpace {
type StateType = SE2State;
fn distance(&self, state1: &Self::StateType, state2: &Self::StateType) -> f64 {
self.0.distance_dyn(&state1.0, &state2.0)
}
fn interpolate(
&self,
from: &Self::StateType,
to: &Self::StateType,
t: f64,
state: &mut Self::StateType,
) {
self.0.interpolate_dyn(&from.0, &to.0, t, &mut state.0);
}
fn enforce_bounds(&self, state: &mut Self::StateType) {
self.0.enforce_bounds_dyn(&mut state.0);
}
fn satisfies_bounds(&self, state: &Self::StateType) -> bool {
self.0.satisfies_bounds_dyn(&state.0)
}
fn sample_uniform(
&self,
rng: &mut impl rand::Rng,
) -> Result<Self::StateType, crate::base::error::StateSamplingError> {
let compound_state = self.0.sample_uniform(rng)?;
Ok(SE2State(compound_state))
}
fn get_longest_valid_segment_length(&self) -> f64 {
self.0.get_longest_valid_segment_length_dyn()
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::base::state::SE2State;
use rand::rng;
use std::f64::consts::PI;
#[test]
fn test_se2_space_creation() {
assert!(SE2StateSpace::new(0.5, None).is_ok());
let bounds = vec![(-1.0, 1.0), (-1.0, 1.0), (-PI, PI)];
assert!(SE2StateSpace::new(1.0, Some(bounds)).is_ok());
}
#[test]
fn test_se2_space_creation_invalid_bounds() {
let bounds = vec![(-1.0, 1.0)];
let result = SE2StateSpace::new(1.0, Some(bounds));
assert!(result.is_err());
match result {
Err(StateSpaceError::DimensionMismatch { expected, found }) => {
assert_eq!(expected, 3);
assert_eq!(found, 1);
}
_ => panic!("Expected DimensionMismatch error"),
}
}
#[test]
fn test_distance() {
let space = SE2StateSpace::new(0.5, None).unwrap();
let state1 = SE2State::new(0.0, 0.0, 0.0);
let state2 = SE2State::new(3.0, 4.0, PI);
let expected_dist_r2: f64 = 5.0;
let expected_dist_so2 = PI;
let expected_total_dist =
(expected_dist_r2.powi(2) + (0.5 * expected_dist_so2).powi(2)).sqrt();
assert!((space.distance(&state1, &state2) - expected_total_dist).abs() < 1e-9);
}
#[test]
fn test_interpolate() {
let space = SE2StateSpace::new(1.0, None).unwrap();
let state1 = SE2State::new(0.0, 0.0, 0.0);
let state2 = SE2State::new(10.0, -10.0, PI / 2.0);
let mut interpolated_state = SE2State::new(0.0, 0.0, 0.0);
space.interpolate(&state1, &state2, 0.5, &mut interpolated_state);
assert_eq!(interpolated_state.get_x(), 5.0);
assert_eq!(interpolated_state.get_y(), -5.0);
assert!((interpolated_state.get_yaw() - PI / 4.0).abs() < 1e-9);
}
#[test]
fn test_bounds() {
let bounds = vec![(-1.0, 1.0), (-2.0, 2.0), (-PI / 2.0, PI / 2.0)];
let space = SE2StateSpace::new(1.0, Some(bounds)).unwrap();
let out_of_bounds_x = SE2State::new(2.0, 0.0, 0.0);
assert!(!space.satisfies_bounds(&out_of_bounds_x));
let out_of_bounds_y = SE2State::new(0.0, -3.0, 0.0);
assert!(!space.satisfies_bounds(&out_of_bounds_y));
let out_of_bounds_yaw = SE2State::new(0.0, 0.0, 0.8 * PI);
assert!(!space.satisfies_bounds(&out_of_bounds_yaw));
let mut out_of_bounds_state = SE2State::new(2.0, -3.0, 0.8 * PI);
let in_bounds_state = SE2State::new(0.5, 1.5, 0.0);
assert!(!space.satisfies_bounds(&out_of_bounds_state));
assert!(space.satisfies_bounds(&in_bounds_state));
space.enforce_bounds(&mut out_of_bounds_state);
assert_eq!(out_of_bounds_state.get_x(), 1.0);
assert_eq!(out_of_bounds_state.get_y(), -2.0);
assert!((out_of_bounds_state.get_yaw() - (PI / 2.0)).abs() < 1e-9);
}
#[test]
fn test_sample_uniform() {
let bounds = vec![(-1.0, 1.0), (5.0, 10.0), (0.0, PI)];
let space = SE2StateSpace::new(1.0, Some(bounds)).unwrap();
let mut rng = rng();
for _ in 0..100 {
let sample = space.sample_uniform(&mut rng).unwrap();
assert!(space.satisfies_bounds(&sample));
}
}
}