use std::marker::PhantomData;
use rlevo_core::environment::{Environment, Snapshot};
use rlevo_core::evaluation::{BenchEnv, BenchError, BenchStep};
use rlevo_core::reward::ScalarReward;
/// Wraps an environment `E` so it can be driven through the [`BenchEnv`] interface.
///
/// The const parameters `D`, `SD` and `AD` are forwarded unchanged to [`Environment`];
/// the [`BenchEnv`] implementation for this type is only available when `E` implements
/// `Environment<D, SD, AD>` with a [`ScalarReward`] reward type.
#[derive(Debug)]
pub struct BenchAdapter<E, const D: usize, const SD: usize, const AD: usize> {
    /// The wrapped environment.
    env: E,
    // Zero-sized marker that carries no type information. The const parameters do not
    // need it: Rust does not require const generic parameters to appear in a field.
    _phantom: PhantomData<()>,
}

impl<E, const D: usize, const SD: usize, const AD: usize> BenchAdapter<E, D, SD, AD> {
    /// Wraps `env` in a new adapter.
    #[must_use]
    pub const fn new(env: E) -> Self {
        Self {
            env,
            _phantom: PhantomData,
        }
    }

    /// Returns a shared reference to the wrapped environment.
    #[must_use]
    pub const fn inner(&self) -> &E {
        &self.env
    }

    /// Returns a mutable reference to the wrapped environment.
    ///
    /// Lets callers reconfigure the environment in place without consuming the
    /// adapter via [`BenchAdapter::into_inner`].
    pub fn inner_mut(&mut self) -> &mut E {
        &mut self.env
    }

    /// Consumes the adapter and returns the wrapped environment.
    #[must_use]
    pub fn into_inner(self) -> E {
        self.env
    }
}
/// Exposes the wrapped environment through [`BenchEnv`].
///
/// Observations are cloned out of each snapshot, so `E::ObservationType` must be `Clone`.
impl<E, const D: usize, const SD: usize, const AD: usize> BenchEnv for BenchAdapter<E, D, SD, AD>
where
    E: Environment<D, SD, AD, RewardType = ScalarReward>,
    E::ObservationType: Clone,
{
    type Observation = E::ObservationType;
    type Action = E::ActionType;

    /// Resets the wrapped environment and returns a clone of the snapshot's observation.
    ///
    /// # Errors
    ///
    /// Returns [`BenchError::Reset`] wrapping the environment's error if the reset fails.
    fn reset(&mut self) -> Result<Self::Observation, BenchError> {
        self.env
            .reset()
            .map_err(BenchError::Reset)
            .map(|snapshot| snapshot.observation().clone())
    }

    /// Advances the wrapped environment by one `action` and repackages the snapshot.
    ///
    /// The reward is converted to `f64` with `f64::from`, and `done` mirrors the
    /// snapshot's `is_done()`.
    ///
    /// # Errors
    ///
    /// Returns [`BenchError::Step`] wrapping the environment's error if the step fails.
    fn step(&mut self, action: Self::Action) -> Result<BenchStep<Self::Observation>, BenchError> {
        let snapshot = self.env.step(action).map_err(BenchError::Step)?;

        // Read the snapshot fields in the same order the struct literal lists them.
        let observation = snapshot.observation().clone();
        let reward = f64::from(snapshot.reward().value());
        let done = snapshot.is_done();

        Ok(BenchStep {
            observation,
            reward,
            done,
        })
    }
}