use rand::Rng;
use crate::base::{
error::StateSamplingError,
space::{AnyStateSpace, StateSpace},
state::CompoundState,
};
/// A state space assembled from several component subspaces, each paired with a weight.
///
/// `subspaces[i]` and `weights[i]` belong together. `CompoundStateSpace::new` asserts that
/// both vectors have the same length, but the fields are public, so code that mutates them
/// directly is responsible for keeping the lengths in sync.
#[derive(Clone)]
pub struct CompoundStateSpace {
/// The component spaces, stored as boxed `AnyStateSpace` trait objects.
pub subspaces: Vec<Box<dyn AnyStateSpace>>,
/// Per-subspace factors; the `StateSpace` impl multiplies each component's distance
/// (and longest valid segment length) by the weight at the same index.
pub weights: Vec<f64>,
}
impl CompoundStateSpace {
pub fn new(subspaces: Vec<Box<dyn AnyStateSpace>>, weights: Vec<f64>) -> Self {
assert_eq!(
subspaces.len(),
weights.len(),
"Number of subspaces must match number of weights."
);
Self { subspaces, weights }
}
}
impl StateSpace for CompoundStateSpace {
    type StateType = CompoundState;

    /// Returns the weighted Euclidean combination of the per-subspace distances.
    ///
    /// Each component distance is multiplied by its weight, squared and accumulated;
    /// the result is the square root of that sum. Components are looked up by the
    /// subspace index, so a state with fewer components than there are subspaces
    /// causes an out-of-bounds panic.
    fn distance(&self, state1: &Self::StateType, state2: &Self::StateType) -> f64 {
        self.subspaces
            .iter()
            .enumerate()
            .fold(0.0_f64, |sum_sq, (idx, subspace)| {
                let raw =
                    subspace.distance_dyn(&*state1.components[idx], &*state2.components[idx]);
                sum_sq + (raw * self.weights[idx]).powi(2)
            })
            .sqrt()
    }

    /// Interpolates every component independently using the same parameter `t`.
    ///
    /// The component at index `i` of `out_state` is written by subspace `i`, given the
    /// matching components of `from` and `to`.
    fn interpolate(
        &self,
        from: &Self::StateType,
        to: &Self::StateType,
        t: f64,
        out_state: &mut Self::StateType,
    ) {
        for (idx, subspace) in self.subspaces.iter().enumerate() {
            subspace.interpolate_dyn(
                &*from.components[idx],
                &*to.components[idx],
                t,
                &mut *out_state.components[idx],
            );
        }
    }

    /// Draws one sample from each subspace, in order, and bundles them into a
    /// [`CompoundState`].
    ///
    /// # Errors
    ///
    /// Returns the first error reported by a subspace; later subspaces are not sampled.
    fn sample_uniform(&self, rng: &mut impl Rng) -> Result<Self::StateType, StateSamplingError> {
        let mut components = Vec::with_capacity(self.subspaces.len());
        for subspace in self.subspaces.iter() {
            components.push(subspace.sample_uniform_dyn(rng)?);
        }
        Ok(CompoundState { components })
    }

    /// Asks each subspace to enforce its bounds on the matching component of `state`.
    fn enforce_bounds(&self, state: &mut Self::StateType) {
        for (idx, subspace) in self.subspaces.iter().enumerate() {
            subspace.enforce_bounds_dyn(&mut *state.components[idx]);
        }
    }

    /// Returns `true` only if every subspace accepts its matching component.
    ///
    /// Checking stops at the first component that is out of bounds.
    fn satisfies_bounds(&self, state: &Self::StateType) -> bool {
        self.subspaces
            .iter()
            .enumerate()
            .all(|(idx, subspace)| subspace.satisfies_bounds_dyn(&*state.components[idx]))
    }

    /// Combines the subspaces' longest valid segment lengths the same way `distance`
    /// combines distances: weight, square, sum, then take the square root.
    fn get_longest_valid_segment_length(&self) -> f64 {
        self.subspaces
            .iter()
            .enumerate()
            .fold(0.0_f64, |sum_sq, (idx, subspace)| {
                let segment_len = subspace.get_longest_valid_segment_length_dyn();
                sum_sq + (segment_len * self.weights[idx]).powi(2)
            })
            .sqrt()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::base::{
        space::{RealVectorStateSpace, SO2StateSpace},
        state::{RealVectorState, SO2State},
    };
    use std::f64::consts::PI;

    /// With a single subspace the compound distance is the component distance times
    /// the weight; the test expects 10.0 for the points (0, 0) and (3, 4) at weight 2.0.
    #[test]
    fn test_single_subspace() {
        let plane = RealVectorStateSpace::new(2, Some(vec![(-5.0, 5.0), (-5.0, 5.0)])).unwrap();
        let space = CompoundStateSpace::new(vec![Box::new(plane)], vec![2.0]);

        let origin = CompoundState {
            components: vec![Box::new(RealVectorState::new(vec![0.0, 0.0]))],
        };
        let target = CompoundState {
            components: vec![Box::new(RealVectorState::new(vec![3.0, 4.0]))],
        };

        let measured = space.distance(&origin, &target);
        assert!((measured - 10.0).abs() < 1e-9);
    }

    /// Three subspaces (real, SO(2), real) with weights 1.0, 0.5 and 1.0: checks the
    /// combined distance and that sampling yields one component per subspace.
    #[test]
    fn test_multiple_subspaces() {
        let first_line = RealVectorStateSpace::new(1, Some(vec![(-10.0, 10.0)])).unwrap();
        let angle = SO2StateSpace::new(None).unwrap();
        let second_line = RealVectorStateSpace::new(1, Some(vec![(-10.0, 10.0)])).unwrap();
        let space = CompoundStateSpace::new(
            vec![Box::new(first_line), Box::new(angle), Box::new(second_line)],
            vec![1.0, 0.5, 1.0],
        );

        let start = CompoundState {
            components: vec![
                Box::new(RealVectorState::new(vec![1.0])),
                Box::new(SO2State::new(0.0)),
                Box::new(RealVectorState::new(vec![5.0])),
            ],
        };
        let goal = CompoundState {
            components: vec![
                Box::new(RealVectorState::new(vec![2.0])),
                Box::new(SO2State::new(PI)),
                Box::new(RealVectorState::new(vec![1.0])),
            ],
        };

        // Expected per-component distances (1, PI, 4), each already scaled by its weight.
        let weighted_parts = [1.0f64 * 1.0, PI * 0.5, 4.0f64 * 1.0];
        let expected_dist = weighted_parts
            .iter()
            .fold(0.0_f64, |sum_sq, part| sum_sq + part.powi(2))
            .sqrt();
        let measured = space.distance(&start, &goal);
        assert!((measured - expected_dist).abs() < 1e-9);

        let mut rng = rand::rng();
        let drawn = space.sample_uniform(&mut rng);
        assert!(drawn.is_ok());
        let drawn_state = drawn.unwrap();
        assert_eq!(drawn_state.components.len(), 3);
    }

    /// One subspace with two weights must trip the length assertion in `new`.
    #[test]
    #[should_panic(expected = "Number of subspaces must match number of weights.")]
    fn test_mismatched_subspaces_and_weights() {
        let square = RealVectorStateSpace::new(2, Some(vec![(-1.0, 1.0), (-1.0, 1.0)])).unwrap();
        CompoundStateSpace::new(vec![Box::new(square)], vec![1.0, 1.0]);
    }
}