use std::collections::HashMap;
use std::sync::{LazyLock, RwLock};
use crate::env::RenderMode;
use crate::env::{RenderFrame, ResetResult, StepResult};
use crate::error::{Error, Result};
/// Registration record for an environment: its id, some metadata, and a
/// factory that builds a fresh boxed instance for a given [`RenderMode`].
///
/// Build one with [`EnvSpec::new`] (optionally chaining
/// [`EnvSpec::with_max_episode_steps`] / [`EnvSpec::with_reward_threshold`])
/// and hand it to [`register`].
pub struct EnvSpec {
    /// Unique registry key, e.g. `"CartPole-v1"`.
    pub id: String,
    /// Episode step limit recorded for this environment.
    ///
    /// NOTE(review): this is metadata only — `make`/`make_with` call the
    /// factory directly and do not enforce this limit.
    pub max_episode_steps: Option<u64>,
    /// Reward threshold recorded for this environment (metadata; not read by
    /// `make`/`make_with`).
    pub reward_threshold: Option<f64>,
    /// Builds a new environment instance for the requested render mode.
    factory: Box<dyn Fn(RenderMode) -> Result<Box<dyn DynEnv>> + Send + Sync>,
}

impl EnvSpec {
    /// Creates a spec for `id` with no step limit and no reward threshold.
    ///
    /// `factory` is private, so code outside this module needs this
    /// constructor to produce a spec it can pass to [`register`].
    #[must_use]
    pub fn new<F>(id: impl Into<String>, factory: F) -> Self
    where
        F: Fn(RenderMode) -> Result<Box<dyn DynEnv>> + Send + Sync + 'static,
    {
        Self {
            id: id.into(),
            max_episode_steps: None,
            reward_threshold: None,
            factory: Box::new(factory),
        }
    }

    /// Returns this spec with `max_episode_steps` set to `steps`.
    #[must_use]
    pub fn with_max_episode_steps(mut self, steps: u64) -> Self {
        self.max_episode_steps = Some(steps);
        self
    }

    /// Returns this spec with `reward_threshold` set to `threshold`.
    #[must_use]
    pub fn with_reward_threshold(mut self, threshold: f64) -> Self {
        self.reward_threshold = Some(threshold);
        self
    }
}
/// Debug output lists the metadata fields only; the factory is a
/// `dyn Fn` trait object with no `Debug` impl, so it is elided (`..`).
impl std::fmt::Debug for EnvSpec {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive destructuring: adding a field to `EnvSpec` makes this
        // line fail to compile until the Debug output is reconsidered.
        let Self {
            id,
            max_episode_steps,
            reward_threshold,
            factory: _,
        } = self;
        let mut out = f.debug_struct("EnvSpec");
        out.field("id", id);
        out.field("max_episode_steps", max_episode_steps);
        out.field("reward_threshold", reward_threshold);
        out.finish_non_exhaustive()
    }
}
/// Type-erased observation/action value exchanged through [`DynEnv`].
///
/// Either a vector of `f32` components or a single `i64` index.
#[derive(Debug, Clone, PartialEq)]
pub enum DynValue {
    /// Vector-valued data.
    Continuous(Vec<f32>),
    /// A single integer value.
    Discrete(i64),
}
/// Object-safe environment interface whose observations and actions are
/// carried as [`DynValue`], letting differently typed environments share the
/// single `Box<dyn DynEnv>` type returned by the registry.
pub trait DynEnv: std::fmt::Debug {
    /// Applies `action` for one step and returns the resulting step record.
    ///
    /// # Errors
    /// Fails if the action cannot be used by the environment or the step fails.
    fn step_dyn(&mut self, action: &DynValue) -> Result<StepResult<DynValue>>;

    /// Resets the environment; `seed` is passed through to the implementation.
    ///
    /// # Errors
    /// Fails if the underlying reset fails.
    fn reset_dyn(&mut self, seed: Option<u64>) -> Result<ResetResult<DynValue>>;

    /// Produces a render frame.
    ///
    /// # Errors
    /// Fails if the underlying render fails.
    fn render_dyn(&mut self) -> Result<RenderFrame>;

    /// Closes the environment.
    fn close_dyn(&mut self);
}
use crate::env::Env;
impl<E> DynEnv for E
where
E: Env + std::fmt::Debug,
E::Obs: Into<DynValue>,
E::Act: TryFrom<DynValue, Error = Error>,
{
fn step_dyn(&mut self, action: &DynValue) -> Result<StepResult<DynValue>> {
let act = E::Act::try_from(action.clone())?;
let r = self.step(&act)?;
Ok(StepResult {
obs: r.obs.into(),
reward: r.reward,
terminated: r.terminated,
truncated: r.truncated,
info: r.info,
})
}
fn reset_dyn(&mut self, seed: Option<u64>) -> Result<ResetResult<DynValue>> {
let r = self.reset(seed)?;
Ok(ResetResult {
obs: r.obs.into(),
info: r.info,
})
}
fn render_dyn(&mut self) -> Result<RenderFrame> {
self.render()
}
fn close_dyn(&mut self) {
self.close();
}
}
impl From<Vec<f32>> for DynValue {
fn from(v: Vec<f32>) -> Self {
Self::Continuous(v)
}
}
impl From<i64> for DynValue {
fn from(v: i64) -> Self {
Self::Discrete(v)
}
}
impl TryFrom<DynValue> for Vec<f32> {
type Error = Error;
fn try_from(v: DynValue) -> Result<Self> {
match v {
DynValue::Continuous(c) => Ok(c),
other @ DynValue::Discrete(_) => Err(Error::TypeMismatch {
reason: format!("expected Continuous, got {other:?}"),
}),
}
}
}
impl TryFrom<DynValue> for i64 {
type Error = Error;
fn try_from(v: DynValue) -> Result<Self> {
match v {
DynValue::Discrete(d) => Ok(d),
other @ DynValue::Continuous(_) => Err(Error::TypeMismatch {
reason: format!("expected Discrete, got {other:?}"),
}),
}
}
}
/// Process-wide map from environment id to its [`EnvSpec`]; starts empty and
/// is initialized on first access.
static REGISTRY: LazyLock<RwLock<HashMap<String, EnvSpec>>> = LazyLock::new(RwLock::default);
#[allow(clippy::unwrap_in_result)] pub fn register(spec: EnvSpec) -> Result<()> {
let mut reg = REGISTRY.write().expect("registry lock poisoned");
if reg.contains_key(&spec.id) {
return Err(Error::AlreadyRegistered { id: spec.id });
}
reg.insert(spec.id.clone(), spec);
Ok(())
}
/// Builds the environment registered under `id` using `RenderMode::None`.
///
/// # Errors
/// Returns `Error::NotRegistered` for an unknown id, or whatever error the
/// environment's factory returns.
pub fn make(id: &str) -> Result<Box<dyn DynEnv>> {
    let render_mode = RenderMode::None;
    make_with(id, render_mode)
}
#[allow(clippy::unwrap_in_result)] pub fn make_with(id: &str, render_mode: RenderMode) -> Result<Box<dyn DynEnv>> {
let reg = REGISTRY.read().expect("registry lock poisoned");
let spec = reg
.get(id)
.ok_or_else(|| Error::NotRegistered { id: id.to_owned() })?;
(spec.factory)(render_mode)
}
/// Returns the `Debug` rendering of the spec registered under `id`, or
/// `None` if the id is unknown.
///
/// # Panics
/// Panics if the registry lock is poisoned.
#[must_use]
pub fn spec(id: &str) -> Option<String> {
    let registry = REGISTRY.read().expect("registry lock poisoned");
    let found = registry.get(id)?;
    Some(format!("{found:?}"))
}
/// Returns every registered id in ascending order.
///
/// # Panics
/// Panics if the registry lock is poisoned.
#[must_use]
pub fn list_registered() -> Vec<String> {
    let registry = REGISTRY.read().expect("registry lock poisoned");
    let mut ids = Vec::with_capacity(registry.len());
    ids.extend(registry.keys().cloned());
    // The ids are copied out, so the lock is not needed while sorting.
    drop(registry);
    // Map keys are unique, so an unstable sort yields the same order as a stable one.
    ids.sort_unstable();
    ids
}
/// Registers the builtin classic-control environments.
///
/// Safe to call repeatedly: `register` returns an error for an id that is
/// already present, and that error is deliberately ignored here.
pub fn register_builtins() {
    let builtins = [
        EnvSpec {
            id: "CartPole-v1".to_owned(),
            max_episode_steps: Some(500),
            reward_threshold: Some(475.0),
            factory: Box::new(|rm| {
                use crate::envs::classic_control::{CartPoleConfig, CartPoleEnv};
                let env = CartPoleEnv::new(CartPoleConfig {
                    render_mode: rm,
                    ..CartPoleConfig::default()
                })?;
                Ok(Box::new(env))
            }),
        },
        EnvSpec {
            id: "MountainCar-v0".to_owned(),
            max_episode_steps: Some(200),
            reward_threshold: Some(-110.0),
            factory: Box::new(|rm| {
                use crate::envs::classic_control::{MountainCarConfig, MountainCarEnv};
                let env = MountainCarEnv::new(MountainCarConfig {
                    render_mode: rm,
                    ..MountainCarConfig::default()
                })?;
                Ok(Box::new(env))
            }),
        },
        EnvSpec {
            id: "Pendulum-v1".to_owned(),
            max_episode_steps: Some(200),
            reward_threshold: None,
            factory: Box::new(|rm| {
                use crate::envs::classic_control::{PendulumConfig, PendulumEnv};
                let env = PendulumEnv::new(PendulumConfig {
                    render_mode: rm,
                    ..PendulumConfig::default()
                })?;
                Ok(Box::new(env))
            }),
        },
        EnvSpec {
            id: "MountainCarContinuous-v0".to_owned(),
            max_episode_steps: Some(999),
            reward_threshold: Some(90.0),
            factory: Box::new(|rm| {
                use crate::envs::classic_control::{
                    ContinuousMountainCarConfig, ContinuousMountainCarEnv,
                };
                let env = ContinuousMountainCarEnv::new(ContinuousMountainCarConfig {
                    render_mode: rm,
                    ..ContinuousMountainCarConfig::default()
                })?;
                Ok(Box::new(env))
            }),
        },
        EnvSpec {
            id: "Acrobot-v1".to_owned(),
            max_episode_steps: Some(500),
            reward_threshold: Some(-100.0),
            factory: Box::new(|rm| {
                use crate::envs::classic_control::{AcrobotConfig, AcrobotEnv};
                let env = AcrobotEnv::new(AcrobotConfig { render_mode: rm })?;
                Ok(Box::new(env))
            }),
        },
    ];
    for builtin in builtins {
        // Ignored on purpose: an already-registered id makes this a no-op.
        let _ = register(builtin);
    }
}
#[cfg(test)]
#[allow(clippy::panic)] // a non-Continuous observation fails the test via `panic!`
mod tests {
    use super::*;

    /// Ensures the builtin environments are registered (repeat calls are no-ops).
    fn setup() {
        register_builtins();
    }

    /// Returns the length of a `Continuous` observation, failing the test otherwise.
    fn continuous_len(obs: &DynValue) -> usize {
        match obs {
            DynValue::Continuous(values) => values.len(),
            DynValue::Discrete(_) => panic!("expected Continuous observation"),
        }
    }

    /// Makes `id`, resets it with `seed`, and returns the observation length.
    fn reset_obs_len(id: &str, seed: u64) -> usize {
        setup();
        let mut env = make(id).unwrap();
        let reset = env.reset_dyn(Some(seed)).unwrap();
        continuous_len(&reset.obs)
    }

    #[test]
    fn make_cartpole() {
        assert_eq!(reset_obs_len("CartPole-v1", 42), 4);
    }

    #[test]
    fn make_mountain_car() {
        assert_eq!(reset_obs_len("MountainCar-v0", 0), 2);
    }

    #[test]
    fn make_pendulum() {
        assert_eq!(reset_obs_len("Pendulum-v1", 0), 3);
    }

    #[test]
    fn make_unknown_errors() {
        setup();
        assert!(make("NonExistent-v99").is_err());
    }

    #[test]
    fn list_includes_builtins() {
        setup();
        let ids = list_registered();
        for expected in ["CartPole-v1", "MountainCar-v0", "Pendulum-v1"] {
            assert!(ids.iter().any(|id| id == expected), "missing {expected}");
        }
    }

    #[test]
    fn step_through_dyn_env() {
        setup();
        let mut env = make("CartPole-v1").unwrap();
        env.reset_dyn(Some(42)).unwrap();
        let stepped = env.step_dyn(&DynValue::Discrete(1)).unwrap();
        assert_eq!(continuous_len(&stepped.obs), 4);
    }
}