use rand::RngCore;
use std::{any::Any, clone::Clone};
use crate::base::{error::StateSamplingError, space::StateSpace, state::State};
/// Dyn-compatible cloning hook for [`AnyStateSpace`] trait objects.
///
/// `Clone::clone` returns `Self` and therefore cannot be called through a
/// `dyn AnyStateSpace`. This trait exposes cloning through a boxed return type
/// instead, which is what lets `Box<dyn AnyStateSpace>` implement `Clone`.
pub trait DynCloneAnyStateSpace {
    /// Produces an owned, boxed copy of `self` erased to `dyn AnyStateSpace`.
    fn clone_box(&self) -> Box<dyn AnyStateSpace>;
}
/// Blanket implementation: every `'static`, `Clone` state space gets `clone_box`
/// by cloning its concrete value and boxing the copy as a trait object.
impl<S: AnyStateSpace + Clone + 'static> DynCloneAnyStateSpace for S {
    fn clone_box(&self) -> Box<dyn AnyStateSpace> {
        let duplicate: S = S::clone(self);
        Box::new(duplicate)
    }
}
/// Makes boxed, type-erased state spaces cloneable by delegating to
/// [`DynCloneAnyStateSpace::clone_box`] on the boxed value.
impl Clone for Box<dyn AnyStateSpace> {
    fn clone(&self) -> Self {
        // Dispatch explicitly on the inner trait object rather than writing
        // `self.clone_box()`. Method lookup probes `Box<dyn AnyStateSpace>` before
        // auto-deref, and the box is already `Clone` (via this impl). If it ever
        // also implemented `AnyStateSpace`, the blanket `clone_box` would be picked.
        // That impl calls `self.clone()`, which would recurse without end.
        (**self).clone_box()
    }
}
/// Type-erased (object-safe) view of a state space operating on `dyn State`.
///
/// The `DynCloneAnyStateSpace` supertrait provides `clone_box`, so boxed trait
/// objects of this trait can be cloned. Every `StateSpace + Clone + 'static` type
/// receives this trait through a blanket implementation. That implementation
/// downcasts each `dyn State` argument to the space's concrete `StateType` and
/// panics if the downcast fails.
pub trait AnyStateSpace: DynCloneAnyStateSpace {
    /// Type-erased counterpart of `StateSpace::distance` between two states.
    fn distance_dyn(&self, state1: &dyn State, state2: &dyn State) -> f64;

    /// Type-erased counterpart of `StateSpace::interpolate`.
    ///
    /// Interpolates between `from` and `to` at parameter `t` and writes the
    /// result into `state`.
    fn interpolate_dyn(&self, from: &dyn State, to: &dyn State, t: f64, state: &mut dyn State);

    /// Type-erased counterpart of `StateSpace::enforce_bounds`, modifying `state` in place.
    fn enforce_bounds_dyn(&self, state: &mut dyn State);

    /// Type-erased counterpart of `StateSpace::satisfies_bounds`.
    fn satisfies_bounds_dyn(&self, state: &dyn State) -> bool;

    /// Type-erased counterpart of `StateSpace::sample_uniform`.
    ///
    /// Draws randomness from `rng` and returns the sampled state boxed as `dyn State`.
    ///
    /// # Errors
    ///
    /// Propagates the `StateSamplingError` reported by the underlying sampler.
    fn sample_uniform_dyn(
        &self,
        rng: &mut dyn RngCore,
    ) -> Result<Box<dyn State>, StateSamplingError>;

    /// Type-erased counterpart of `StateSpace::get_longest_valid_segment_length`.
    fn get_longest_valid_segment_length_dyn(&self) -> f64;
}
/// Downcasts a type-erased state reference to the concrete state type `S`.
///
/// # Panics
///
/// Panics if `state` is not an `S`. Thanks to `#[track_caller]`, the reported
/// location is the caller's, not this helper's.
#[track_caller]
fn downcast_state<S: 'static>(state: &dyn State) -> &S {
    match (state as &dyn Any).downcast_ref::<S>() {
        Some(concrete) => concrete,
        None => state_type_mismatch::<S>(),
    }
}

/// Mutable counterpart of [`downcast_state`].
///
/// # Panics
///
/// Panics if `state` is not an `S`. The reported location is the caller's.
#[track_caller]
fn downcast_state_mut<S: 'static>(state: &mut dyn State) -> &mut S {
    match (state as &mut dyn Any).downcast_mut::<S>() {
        Some(concrete) => concrete,
        None => state_type_mismatch::<S>(),
    }
}

/// Shared panic path for a failed state downcast.
///
/// The message names the expected type, so a mismatch is diagnosable instead of
/// surfacing as a bare "called `Option::unwrap()` on a `None` value".
#[cold]
#[track_caller]
fn state_type_mismatch<S>() -> ! {
    panic!(
        "AnyStateSpace: expected a state of type `{}`, got a different `State` implementation",
        std::any::type_name::<S>()
    )
}

/// Blanket bridge: every `'static`, cloneable [`StateSpace`] can be used as a
/// `dyn AnyStateSpace`.
///
/// Each `_dyn` method downcasts its `dyn State` arguments to `T::StateType` and
/// forwards to the matching `StateSpace` method.
///
/// # Panics
///
/// The state-taking methods panic if given a state whose concrete type is not
/// `T::StateType`. The message names the expected type, and the location points
/// at the `_dyn` method that received it.
impl<T: StateSpace + Clone + 'static> AnyStateSpace for T
where
    T::StateType: 'static,
{
    fn distance_dyn(&self, state1: &dyn State, state2: &dyn State) -> f64 {
        let s1 = downcast_state::<T::StateType>(state1);
        let s2 = downcast_state::<T::StateType>(state2);
        self.distance(s1, s2)
    }

    fn interpolate_dyn(&self, from: &dyn State, to: &dyn State, t: f64, state: &mut dyn State) {
        let from_s = downcast_state::<T::StateType>(from);
        let to_s = downcast_state::<T::StateType>(to);
        let state_s = downcast_state_mut::<T::StateType>(state);
        self.interpolate(from_s, to_s, t, state_s);
    }

    fn sample_uniform_dyn(
        &self,
        mut rng: &mut dyn RngCore,
    ) -> Result<Box<dyn State>, StateSamplingError> {
        // `rand_core` already implements `RngCore` for mutable references to an
        // `RngCore`, so `&mut dyn RngCore` is itself a sized RNG. A plain reborrow
        // therefore satisfies the generic sampler; no forwarding wrapper is needed.
        let concrete_state = self.sample_uniform(&mut rng)?;
        Ok(Box::new(concrete_state))
    }

    fn enforce_bounds_dyn(&self, state: &mut dyn State) {
        let state_s = downcast_state_mut::<T::StateType>(state);
        self.enforce_bounds(state_s);
    }

    fn satisfies_bounds_dyn(&self, state: &dyn State) -> bool {
        let state_s = downcast_state::<T::StateType>(state);
        self.satisfies_bounds(state_s)
    }

    fn get_longest_valid_segment_length_dyn(&self) -> f64 {
        self.get_longest_valid_segment_length()
    }
}