use rand::Rng;
use std::f64::consts::PI;
use crate::base::{
error::{StateSamplingError, StateSpaceError},
space::StateSpace,
state::SO2State,
};
/// State space of planar rotations (SO(2)), i.e. angles on the unit circle.
///
/// Construct it with `SO2StateSpace::new`, which validates the bounds and clamps
/// them into `[-PI, PI]`.
#[derive(Clone, Debug)]
pub struct SO2StateSpace {
    /// `(lower, upper)` angular bounds.
    ///
    /// NOTE(review): this field is public, so callers can overwrite it and bypass
    /// the validation performed by `new` (e.g. leave `lower >= upper`).
    pub bounds: (f64, f64),
    /// Fraction of the maximum extent used as the longest valid segment length;
    /// kept within `[0, 1]` by `set_longest_valid_segment_fraction`.
    longest_valid_segment_fraction: f64,
}
impl SO2StateSpace {
    /// Creates a new SO(2) state space.
    ///
    /// `bounds_option` defaults to `(-PI, PI)` when `None`. The supplied bounds are
    /// clamped into `[-PI, PI]`, and the longest-valid-segment fraction starts at `0.05`.
    ///
    /// # Errors
    ///
    /// Returns `StateSpaceError::InvalidBound` (carrying the bounds as passed in) when
    /// either bound is NaN, or when the clamped interval is empty, i.e. its lower end is
    /// not strictly below its upper end. This covers inputs such as `(1.0, 0.0)`, and
    /// also ranges that lie entirely outside `[-PI, PI]` such as `(4.0, 5.0)`.
    pub fn new(bounds_option: Option<(f64, f64)>) -> Result<Self, StateSpaceError> {
        let (lower, upper) = bounds_option.unwrap_or((-PI, PI));

        // Clamp first, then validate the interval that will actually be stored:
        // checking only the raw input lets e.g. (4.0, 5.0) through and yields the
        // inverted interval (4.0, PI), which makes uniform sampling panic later.
        let clamped_lower = lower.max(-PI);
        let clamped_upper = upper.min(PI);

        // `f64::max`/`f64::min` silently discard NaN, so it must be rejected explicitly.
        if lower.is_nan() || upper.is_nan() || clamped_lower >= clamped_upper {
            return Err(StateSpaceError::InvalidBound { lower, upper });
        }

        Ok(Self {
            bounds: (clamped_lower, clamped_upper),
            longest_valid_segment_fraction: 0.05,
        })
    }

    /// Returns the largest distance between two states in this space, which is `PI`
    /// (half a turn) for the wrapped angular distance.
    pub fn get_maximum_extent(&self) -> f64 {
        PI
    }

    /// Sets the fraction of the maximum extent used as the longest valid segment.
    ///
    /// Values `<= 0.0` are stored as `0.0` and values `> 1.0` as `1.0`. NaN compares
    /// false against both limits and is therefore stored as `1.0`.
    pub fn set_longest_valid_segment_fraction(&mut self, fraction: f64) {
        self.longest_valid_segment_fraction = if fraction <= 0.0 {
            0.0
        } else if fraction <= 1.0 {
            fraction
        } else {
            1.0
        };
    }
}
impl StateSpace for SO2StateSpace {
    type StateType = SO2State;

    /// Shortest angular distance between `state1` and `state2`.
    ///
    /// The raw difference is shifted by `PI`, wrapped with `rem_euclid(2 * PI)` and
    /// shifted back, so the returned magnitude lies in `[0, PI]`.
    fn distance(&self, state1: &Self::StateType, state2: &Self::StateType) -> f64 {
        let wrapped = (state1.value - state2.value + PI).rem_euclid(2.0 * PI) - PI;
        wrapped.abs()
    }

    /// Writes into `out_state` the angle lying a fraction `t` of the way from `from`
    /// to `to`, travelling along the shorter arc. The result is passed through
    /// `normalise`.
    fn interpolate(
        &self,
        from: &Self::StateType,
        to: &Self::StateType,
        t: f64,
        out_state: &mut Self::StateType,
    ) {
        let target = to.clone().normalise().value;
        let origin = from.clone().normalise().value;
        // When the raw difference exceeds half a turn in either direction, going
        // the other way round the circle is shorter: shift it by one full turn.
        let delta = match target - origin {
            d if d > PI => d - 2.0 * PI,
            d if d < -PI => d + 2.0 * PI,
            d => d,
        };
        out_state.value = from.value + delta * t;
        out_state.value = out_state.normalise().value;
    }

    /// Calls `normalise` on `state`; if it then lies outside `self.bounds`, snaps it
    /// to whichever bound is closer by `distance` (ties go to the upper bound).
    fn enforce_bounds(&self, state: &mut Self::StateType) {
        state.normalise();
        if !self.satisfies_bounds(state) {
            let (min_b, max_b) = self.bounds;
            let to_lower = self.distance(&SO2State { value: min_b }, state);
            let to_upper = self.distance(&SO2State { value: max_b }, state);
            state.value = if to_lower < to_upper { min_b } else { max_b };
        }
    }

    /// True when the normalised angle of `state` lies in the closed interval
    /// `[bounds.0, bounds.1]`; `state` itself is not modified.
    fn satisfies_bounds(&self, state: &Self::StateType) -> bool {
        let angle = state.clone().normalise().value;
        let (lower, upper) = self.bounds;
        (lower..=upper).contains(&angle)
    }

    /// Draws an angle uniformly from the half-open interval `[bounds.0, bounds.1)`.
    ///
    /// Never returns `Err` as written. NOTE(review): rand's `random_range` panics on
    /// an empty range, so this panics if `bounds.0 >= bounds.1` (reachable because
    /// `bounds` is a public field).
    fn sample_uniform(&self, rng: &mut impl Rng) -> Result<SO2State, StateSamplingError> {
        let (lower, upper) = self.bounds;
        let value = rng.random_range(lower..upper);
        Ok(SO2State { value })
    }

    /// Longest valid segment: the maximum extent scaled by the configured fraction.
    fn get_longest_valid_segment_length(&self) -> f64 {
        let extent = self.get_maximum_extent();
        extent * self.longest_valid_segment_fraction
    }
}