use crate::env::{Env, RenderFrame, RenderMode, ResetResult, StepResult};
use crate::error::{Error, Result};
use crate::space::{BoundedSpace, Discrete, Space};
/// Wrapper that exposes a continuous, box-bounded action space as one flat
/// [`Discrete`] action space.
///
/// Each action dimension of the wrapped environment is cut into `bins`
/// equal-width intervals, and every interval is represented by its midpoint.
/// A discrete action is the flat index of one combination of per-dimension
/// bins, so the wrapper offers `bins^n_dims` actions in total.
#[derive(Debug)]
pub struct DiscretizeAction<E: Env<Act = Vec<f32>, ActSpace = BoundedSpace>> {
    /// The wrapped continuous-action environment.
    env: E,
    /// Number of bins per action dimension.
    bins: usize,
    /// Number of dimensions of the wrapped action space.
    n_dims: usize,
    /// `bin_centers[d][b]` is the continuous value used for bin `b` of dimension `d`.
    bin_centers: Vec<Vec<f32>>,
    /// Flat discrete action space advertised to callers.
    action_space: Discrete,
}
impl<E> DiscretizeAction<E>
where
    E: Env<Act = Vec<f32>, ActSpace = BoundedSpace>,
{
    /// Wraps `env`, splitting every action dimension into `bins` equal-width bins.
    ///
    /// Bin `b` (0-based) of a dimension bounded by `[lo, hi]` is represented by
    /// its midpoint `lo + (b + 0.5) * (hi - lo) / bins`. The resulting discrete
    /// action space has `bins^n_dims` actions, where `n_dims` is the number of
    /// lower bounds of the wrapped action space.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidSpace`] when
    /// - `bins` is zero,
    /// - any bound of the wrapped action space is NaN or infinite,
    /// - the width `hi - lo` of a dimension overflows `f32`, or
    /// - `bins^n_dims` does not fit in `usize`.
    pub fn new(env: E, bins: usize) -> Result<Self> {
        if bins == 0 {
            return Err(Error::InvalidSpace {
                reason: "DiscretizeAction requires bins >= 1".to_owned(),
            });
        }

        let space = env.action_space();
        let n_dims = space.low.len();

        for (i, (&lo, &hi)) in space.low.iter().zip(space.high.iter()).enumerate() {
            if !lo.is_finite() || !hi.is_finite() {
                return Err(Error::InvalidSpace {
                    reason: format!(
                        "DiscretizeAction requires finite bounds, dim {i}: [{lo}, {hi}]"
                    ),
                });
            }
            // Two finite bounds can still be further apart than `f32::MAX`
            // (e.g. [-3e38, 3e38]); the bin width would then be infinite and
            // every bin centre non-finite.
            if !(hi - lo).is_finite() {
                return Err(Error::InvalidSpace {
                    reason: format!(
                        "DiscretizeAction requires a finite range, dim {i}: [{lo}, {hi}]"
                    ),
                });
            }
        }

        let bin_centers: Vec<Vec<f32>> = space
            .low
            .iter()
            .zip(space.high.iter())
            .map(|(&lo, &hi)| {
                let step = (hi - lo) / bins as f32;
                (0..bins)
                    .map(|b| step.mul_add(b as f32 + 0.5, lo))
                    .collect()
            })
            .collect();

        // Checked repeated multiplication instead of `checked_pow(n_dims as u32)`:
        // no truncating cast of the exponent is needed, and because `bins >= 1`
        // the running product never shrinks, so an intermediate overflow occurs
        // exactly when the final result overflows.
        let total = (0..n_dims)
            .try_fold(1_usize, |acc, _| acc.checked_mul(bins))
            .ok_or_else(|| Error::InvalidSpace {
                reason: format!("bins^n_dims overflow: {bins}^{n_dims}"),
            })?;
        let action_space = Discrete::new(total as u64);

        Ok(Self {
            env,
            bins,
            n_dims,
            bin_centers,
            action_space,
        })
    }

    /// Returns a shared reference to the wrapped environment.
    #[must_use]
    pub const fn inner(&self) -> &E {
        &self.env
    }

    /// Returns a mutable reference to the wrapped environment.
    #[must_use]
    pub const fn inner_mut(&mut self) -> &mut E {
        &mut self.env
    }

    /// Consumes the wrapper and returns the wrapped environment.
    #[must_use]
    pub fn into_inner(self) -> E {
        self.env
    }

    /// Converts a flat discrete index into the continuous action it stands for.
    ///
    /// The index is read as an `n_dims`-digit number in base `bins` whose last
    /// dimension varies fastest; each digit selects the bin centre of its
    /// dimension.
    ///
    /// The index is not validated here; `step` checks it against the discrete
    /// action space before decoding.
    // The casts are acceptable because `step` only passes indices that the
    // action space accepted, and the size of that space was computed as a `usize`.
    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    fn decode_action(&self, flat_index: i64) -> Vec<f32> {
        let mut remaining = flat_index as usize;
        let mut action = vec![0.0_f32; self.n_dims];
        // Peel off base-`bins` digits starting with the fastest-varying (last)
        // dimension. `remaining % self.bins` is always a valid bin index.
        for (d, slot) in action.iter_mut().enumerate().rev() {
            *slot = self.bin_centers[d][remaining % self.bins];
            remaining /= self.bins;
        }
        action
    }
}
impl<E> Env for DiscretizeAction<E>
where
E: Env<Act = Vec<f32>, ActSpace = BoundedSpace>,
{
type Obs = E::Obs;
type Act = i64;
type ObsSpace = E::ObsSpace;
type ActSpace = Discrete;
fn step(&mut self, action: &i64) -> Result<StepResult<Self::Obs>> {
if !self.action_space.contains(action) {
return Err(Error::InvalidAction {
reason: format!(
"discrete action {action} not in {{0..{}}}",
self.action_space.n
),
});
}
let continuous = self.decode_action(*action);
self.env.step(&continuous)
}
fn reset(&mut self, seed: Option<u64>) -> Result<ResetResult<Self::Obs>> {
self.env.reset(seed)
}
fn render(&mut self) -> Result<RenderFrame> {
self.env.render()
}
fn close(&mut self) {
self.env.close();
}
fn observation_space(&self) -> &Self::ObsSpace {
self.env.observation_space()
}
fn action_space(&self) -> &Discrete {
&self.action_space
}
fn render_mode(&self) -> &RenderMode {
self.env.render_mode()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::envs::classic_control::{
        ContinuousMountainCarConfig, ContinuousMountainCarEnv, PendulumConfig, PendulumEnv,
    };
    use crate::rng::create_rng;

    /// Builds a default-configured Pendulum wrapped with `bins` bins per dimension.
    fn pendulum(bins: usize) -> DiscretizeAction<PendulumEnv> {
        let inner = PendulumEnv::new(PendulumConfig::default()).unwrap();
        DiscretizeAction::new(inner, bins).unwrap()
    }

    /// Wrapping Pendulum yields one action per bin, and stepping stays inside
    /// the observation space.
    #[test]
    fn discretize_pendulum() {
        let mut wrapped = pendulum(10);
        assert_eq!(wrapped.action_space().n, 10);

        let initial = wrapped.reset(Some(42)).unwrap();
        assert!(wrapped.observation_space().contains(&initial.obs));

        let transition = wrapped.step(&3).unwrap();
        assert!(wrapped.observation_space().contains(&transition.obs));
    }

    /// Indices outside `0..bins` are rejected on both sides of the range.
    #[test]
    fn invalid_action_errors() {
        let mut wrapped = pendulum(5);
        wrapped.reset(Some(0)).unwrap();

        assert!(wrapped.step(&5).is_err());
        assert!(wrapped.step(&-1).is_err());
    }

    /// Constructing the wrapper with zero bins fails.
    #[test]
    fn zero_bins_errors() {
        let inner = PendulumEnv::new(PendulumConfig::default()).unwrap();
        assert!(DiscretizeAction::new(inner, 0).is_err());
    }

    /// The first and last bins decode to the expected midpoints.
    ///
    /// The expected values correspond to a `[-2, 2]` range split into four
    /// bins; this assumes those are Pendulum's default action bounds.
    #[test]
    fn bin_centers_correct() {
        let wrapped = pendulum(4);
        for (index, expected) in [(0_i64, -1.5_f32), (3, 1.5)] {
            let decoded = wrapped.decode_action(index);
            assert!((decoded[0] - expected).abs() < 1e-5);
        }
    }

    /// Wraps a second environment type and steps it once.
    ///
    /// NOTE: despite the name, five bins producing five actions means the
    /// wrapped action space here is one-dimensional.
    #[test]
    fn multi_dim_discretization() {
        let inner = ContinuousMountainCarEnv::new(ContinuousMountainCarConfig::default()).unwrap();
        let mut wrapped = DiscretizeAction::new(inner, 5).unwrap();
        assert_eq!(wrapped.action_space().n, 5);

        wrapped.reset(Some(0)).unwrap();
        let transition = wrapped.step(&2).unwrap();
        assert!(wrapped.observation_space().contains(&transition.obs));
    }

    /// Every action sampled from the discrete space is accepted by `step`.
    #[test]
    fn sample_and_step() {
        let mut wrapped = pendulum(7);
        wrapped.reset(Some(99)).unwrap();

        let mut rng = create_rng(Some(42));
        for _ in 0..20 {
            let sampled = wrapped.action_space().sample(&mut rng);
            let transition = wrapped.step(&sampled).unwrap();
            assert!(wrapped.observation_space().contains(&transition.obs));
        }
    }
}