mod dyn_env;
mod dyn_value;
mod dyn_wrappers;
mod env_id;
use std::collections::HashMap;
use std::sync::{LazyLock, RwLock};
pub use dyn_env::DynEnv;
pub use dyn_value::DynValue;
use dyn_wrappers::{DynOrderEnforcing, DynTimeLimit};
pub use env_id::EnvId;
use crate::env::RenderMode;
use crate::error::{Error, Result};
/// Registration record for an environment: its id, the wrapper configuration
/// applied by [`make_with`], and the factory that builds a fresh instance.
///
/// The `factory` field is private, so construct a spec with [`EnvSpec::new`]
/// and adjust the public fields before handing it to [`register`].
pub struct EnvSpec {
    /// Unique registry key, e.g. `"CartPole-v1"`.
    pub id: String,
    /// When `Some(n)`, [`make_with`] wraps the environment in a time limit of `n` steps.
    pub max_episode_steps: Option<u64>,
    /// Reward threshold recorded for the environment; not used when building
    /// environments, only exposed through [`spec`].
    pub reward_threshold: Option<f64>,
    /// When `true`, [`make_with`] wraps the environment in an order-enforcing wrapper.
    pub order_enforce: bool,
    /// Builds a new, unwrapped environment for the requested render mode.
    factory: Box<dyn Fn(RenderMode) -> Result<Box<dyn DynEnv>> + Send + Sync>,
}

impl EnvSpec {
    /// Creates a spec for `id` backed by `factory`.
    ///
    /// Defaults: no step limit, no reward threshold, and order enforcement
    /// enabled (the same setting every built-in registration uses). The public
    /// fields can be changed afterwards, e.g.
    /// `spec.max_episode_steps = Some(200);`.
    #[must_use]
    pub fn new<F>(id: impl Into<String>, factory: F) -> Self
    where
        F: Fn(RenderMode) -> Result<Box<dyn DynEnv>> + Send + Sync + 'static,
    {
        Self {
            id: id.into(),
            max_episode_steps: None,
            reward_threshold: None,
            order_enforce: true,
            factory: Box::new(factory),
        }
    }
}
/// Owned snapshot of an [`EnvSpec`]'s public fields (everything except the
/// factory), as returned by [`spec`].
#[derive(Debug, Clone, PartialEq)]
pub struct EnvSpecView {
    /// Registry key of the environment.
    pub id: String,
    /// Step limit that [`make_with`] applies through a time-limit wrapper, if any.
    pub max_episode_steps: Option<u64>,
    /// Reward threshold recorded in the spec, if any.
    pub reward_threshold: Option<f64>,
    /// Whether [`make_with`] adds an order-enforcing wrapper.
    pub order_enforce: bool,
}
impl From<&EnvSpec> for EnvSpecView {
fn from(spec: &EnvSpec) -> Self {
Self {
id: spec.id.clone(),
max_episode_steps: spec.max_episode_steps,
reward_threshold: spec.reward_threshold,
order_enforce: spec.order_enforce,
}
}
}
impl std::fmt::Debug for EnvSpec {
    /// Formats the data fields only. The boxed factory has no `Debug` impl,
    /// so the output is marked non-exhaustive (`..`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("EnvSpec");
        builder.field("id", &self.id);
        builder.field("max_episode_steps", &self.max_episode_steps);
        builder.field("reward_threshold", &self.reward_threshold);
        builder.field("order_enforce", &self.order_enforce);
        builder.finish_non_exhaustive()
    }
}
/// Process-wide map from environment id to its spec; starts empty and is
/// created on first access.
static REGISTRY: LazyLock<RwLock<HashMap<String, EnvSpec>>> = LazyLock::new(RwLock::default);
#[allow(clippy::unwrap_in_result)] pub fn register(spec: EnvSpec) -> Result<()> {
let mut reg = REGISTRY.write().expect("registry lock poisoned");
if reg.contains_key(&spec.id) {
return Err(Error::AlreadyRegistered { id: spec.id });
}
reg.insert(spec.id.clone(), spec);
Ok(())
}
/// Builds the environment registered under `id` with rendering disabled.
///
/// Shorthand for [`make_with`] with [`RenderMode::None`].
///
/// # Errors
///
/// Same as [`make_with`].
///
/// # Panics
///
/// Panics if the registry lock is poisoned (via [`make_with`]).
pub fn make(id: &str) -> Result<Box<dyn DynEnv>> {
    let render_mode = RenderMode::None;
    make_with(id, render_mode)
}
/// Builds the environment registered under `id` for `render_mode` and applies
/// the wrappers requested by its [`EnvSpec`].
///
/// The factory's environment is wrapped in [`DynTimeLimit`] when
/// `max_episode_steps` is set. The result is then wrapped in
/// [`DynOrderEnforcing`] when `order_enforce` is `true`, so order enforcement
/// is the outermost layer.
///
/// # Errors
///
/// Returns [`Error::NotRegistered`] if no spec exists for `id`. Otherwise
/// propagates any error returned by the spec's factory.
///
/// # Panics
///
/// Panics if the registry lock is poisoned.
// A poisoned registry lock is treated as a bug rather than a recoverable
// `Error`, hence the `expect` inside a `Result`-returning function.
#[allow(clippy::unwrap_in_result)]
pub fn make_with(id: &str, render_mode: RenderMode) -> Result<Box<dyn DynEnv>> {
    // Keep the read lock only for the lookup and the factory call; the
    // wrappers below need nothing from the registry. The factory itself still
    // runs under the lock. std's `RwLock` documents that re-acquiring it on
    // the same thread may deadlock or panic, so a factory must not call back
    // into `register`/`make`.
    let (mut env, step_limit, order_enforce) = {
        let reg = REGISTRY.read().expect("registry lock poisoned");
        let spec = reg
            .get(id)
            .ok_or_else(|| Error::NotRegistered { id: id.to_owned() })?;
        let env = (spec.factory)(render_mode)?;
        (env, spec.max_episode_steps, spec.order_enforce)
    };
    if let Some(max_steps) = step_limit {
        env = Box::new(DynTimeLimit::new(env, max_steps));
    }
    if order_enforce {
        env = Box::new(DynOrderEnforcing::new(env));
    }
    Ok(env)
}
/// Builds `num_envs` copies of the environment registered under `id`, with
/// rendering disabled, batched into a [`DynVectorEnv`].
///
/// # Errors
///
/// Same as [`make_vec_with`].
///
/// # Panics
///
/// Panics if the registry lock is poisoned (via [`make_with`]).
pub fn make_vec(id: &str, num_envs: usize) -> Result<DynVectorEnv> {
    let render_mode = RenderMode::None;
    make_vec_with(id, num_envs, render_mode)
}
pub fn make_vec_with(id: &str, num_envs: usize, render_mode: RenderMode) -> Result<DynVectorEnv> {
if num_envs == 0 {
return Err(Error::InvalidSpace {
reason: "make_vec requires at least 1 environment".to_owned(),
});
}
let envs: Result<Vec<_>> = (0..num_envs).map(|_| make_with(id, render_mode)).collect();
Ok(DynVectorEnv::new(envs?))
}
/// A batch of type-erased environments that are reset and stepped together.
/// Build one with [`make_vec`] or [`make_vec_with`].
#[derive(Debug)]
pub struct DynVectorEnv {
    /// Sub-environments. Index `i` lines up with `actions[i]` and with the
    /// `i`-th entry of every result vector.
    envs: Vec<Box<dyn DynEnv>>,
    /// `needs_reset[i]` is set when env `i`'s last step terminated or was
    /// truncated. The next [`DynVectorEnv::step`] resets that env before
    /// stepping it.
    needs_reset: Vec<bool>,
}
impl DynVectorEnv {
    /// Wraps already-built environments; none of them starts flagged for an
    /// automatic reset.
    fn new(envs: Vec<Box<dyn DynEnv>>) -> Self {
        let n = envs.len();
        Self {
            envs,
            needs_reset: vec![false; n],
        }
    }

    /// Number of sub-environments in the batch.
    #[must_use]
    pub fn num_envs(&self) -> usize {
        self.envs.len()
    }

    /// Resets every sub-environment and returns their observations and infos
    /// in sub-environment order.
    ///
    /// With `seed = Some(s)`, sub-environment `i` is reset with seed `s + i`
    /// (wrapping on overflow). With `None`, every sub-environment is reset
    /// with `None`.
    ///
    /// # Errors
    ///
    /// Returns the first error raised by a sub-environment's reset. The
    /// sub-environments before it have already been reset.
    pub fn reset(&mut self, seed: Option<u64>) -> Result<crate::vector::VecResetResult<DynValue>> {
        let mut obs = Vec::with_capacity(self.envs.len());
        let mut infos = Vec::with_capacity(self.envs.len());
        for (i, env) in self.envs.iter_mut().enumerate() {
            // `wrapping_add`: a seed near `u64::MAX` must not panic with
            // "attempt to add with overflow" in debug builds.
            let s = seed.map(|s| s.wrapping_add(i as u64));
            let r = env.reset_dyn(s)?;
            // Cleared per env so the flags stay accurate if a later reset fails.
            self.needs_reset[i] = false;
            obs.push(r.obs);
            infos.push(r.info);
        }
        Ok(crate::vector::VecResetResult { obs, infos })
    }

    /// Steps every sub-environment with its corresponding action.
    ///
    /// Autoreset: a sub-environment whose previous step terminated or was
    /// truncated is first reset (seed `None`). The action is then applied to
    /// the fresh episode, and the observation/info from that reset are
    /// discarded.
    ///
    /// NOTE(review): the action for such an env was presumably chosen from
    /// the terminal observation. Confirm this convention is intended.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidAction`] if `actions.len()` differs from
    /// [`num_envs`](Self::num_envs). Otherwise returns the first error from a
    /// sub-environment's reset or step. Earlier sub-environments have already
    /// been stepped, and their results are dropped.
    pub fn step(&mut self, actions: &[DynValue]) -> Result<crate::vector::VecStepResult<DynValue>> {
        if actions.len() != self.envs.len() {
            return Err(Error::InvalidAction {
                reason: format!(
                    "expected {} actions, got {}",
                    self.envs.len(),
                    actions.len()
                ),
            });
        }
        let n = self.envs.len();
        let mut obs = Vec::with_capacity(n);
        let mut rewards = Vec::with_capacity(n);
        let mut terminated = Vec::with_capacity(n);
        let mut truncated = Vec::with_capacity(n);
        let mut infos = Vec::with_capacity(n);
        for (i, (env, action)) in self.envs.iter_mut().zip(actions.iter()).enumerate() {
            if self.needs_reset[i] {
                env.reset_dyn(None)?;
                // The env is in a fresh episode now, even if the step below fails.
                self.needs_reset[i] = false;
            }
            let r = env.step_dyn(action)?;
            self.needs_reset[i] = r.terminated || r.truncated;
            obs.push(r.obs);
            rewards.push(r.reward);
            terminated.push(r.terminated);
            truncated.push(r.truncated);
            infos.push(r.info);
        }
        Ok(crate::vector::VecStepResult {
            obs,
            rewards,
            terminated,
            truncated,
            infos,
        })
    }

    /// Renders every sub-environment and returns the frames in order.
    ///
    /// # Errors
    ///
    /// Returns the first error raised by a sub-environment's render.
    pub fn render(&mut self) -> Result<Vec<crate::env::RenderFrame>> {
        self.envs.iter_mut().map(|e| e.render_dyn()).collect()
    }

    /// Closes every sub-environment.
    pub fn close(&mut self) {
        for env in &mut self.envs {
            env.close_dyn();
        }
    }
}
/// Returns a snapshot of the public fields of the spec registered under `id`,
/// or `None` if nothing is registered under that id.
///
/// # Panics
///
/// Panics if the registry lock is poisoned.
#[must_use]
pub fn spec(id: &str) -> Option<EnvSpecView> {
    let registry = REGISTRY.read().expect("registry lock poisoned");
    let found = registry.get(id)?;
    Some(EnvSpecView::from(found))
}
/// Returns every registered environment id, sorted lexicographically.
///
/// # Panics
///
/// Panics if the registry lock is poisoned.
#[must_use]
pub fn list_registered() -> Vec<String> {
    // Copy the keys out inside a scope so the read lock is released before sorting.
    let mut ids = {
        let registry = REGISTRY.read().expect("registry lock poisoned");
        registry.keys().map(String::clone).collect::<Vec<_>>()
    };
    // Map keys are unique, so an unstable sort gives the same order as a stable one.
    ids.sort_unstable();
    ids
}
/// Registers the built-in environments: classic control and toy text, plus
/// the Box2D environments when the `box2d` feature is enabled.
///
/// Safe to call repeatedly. Each `register` result is discarded, so ids that
/// are already registered are skipped. As a consequence, a spec registered
/// earlier under a built-in id is kept instead of the built-in one.
pub fn register_builtins() {
    // Classic control.
    let _ = register(EnvSpec {
        id: "CartPole-v1".to_owned(),
        max_episode_steps: Some(500),
        reward_threshold: Some(475.0),
        order_enforce: true,
        factory: Box::new(|rm| {
            use crate::envs::classic_control::{CartPoleConfig, CartPoleEnv};
            let env = CartPoleEnv::new(CartPoleConfig {
                render_mode: rm,
                ..CartPoleConfig::default()
            })?;
            Ok(Box::new(env))
        }),
    });
    let _ = register(EnvSpec {
        id: "MountainCar-v0".to_owned(),
        max_episode_steps: Some(200),
        reward_threshold: Some(-110.0),
        order_enforce: true,
        factory: Box::new(|rm| {
            use crate::envs::classic_control::{MountainCarConfig, MountainCarEnv};
            let env = MountainCarEnv::new(MountainCarConfig {
                render_mode: rm,
                ..MountainCarConfig::default()
            })?;
            Ok(Box::new(env))
        }),
    });
    let _ = register(EnvSpec {
        id: "Pendulum-v1".to_owned(),
        max_episode_steps: Some(200),
        reward_threshold: None,
        order_enforce: true,
        factory: Box::new(|rm| {
            use crate::envs::classic_control::{PendulumConfig, PendulumEnv};
            let env = PendulumEnv::new(PendulumConfig {
                render_mode: rm,
                ..PendulumConfig::default()
            })?;
            Ok(Box::new(env))
        }),
    });
    let _ = register(EnvSpec {
        id: "MountainCarContinuous-v0".to_owned(),
        max_episode_steps: Some(999),
        reward_threshold: Some(90.0),
        order_enforce: true,
        factory: Box::new(|rm| {
            use crate::envs::classic_control::{
                ContinuousMountainCarConfig, ContinuousMountainCarEnv,
            };
            let env = ContinuousMountainCarEnv::new(ContinuousMountainCarConfig {
                render_mode: rm,
                ..ContinuousMountainCarConfig::default()
            })?;
            Ok(Box::new(env))
        }),
    });
    let _ = register(EnvSpec {
        id: "Acrobot-v1".to_owned(),
        max_episode_steps: Some(500),
        reward_threshold: Some(-100.0),
        order_enforce: true,
        factory: Box::new(|rm| {
            use crate::envs::classic_control::{AcrobotConfig, AcrobotEnv};
            let env = AcrobotEnv::new(AcrobotConfig { render_mode: rm })?;
            Ok(Box::new(env))
        }),
    });
    // Toy text.
    let _ = register(EnvSpec {
        id: "FrozenLake-v1".to_owned(),
        max_episode_steps: Some(100),
        reward_threshold: Some(0.70),
        order_enforce: true,
        factory: Box::new(|rm| {
            use crate::envs::toy_text::{FrozenLakeConfig, FrozenLakeEnv};
            let env = FrozenLakeEnv::new(FrozenLakeConfig {
                render_mode: rm,
                ..FrozenLakeConfig::default()
            })?;
            Ok(Box::new(env))
        }),
    });
    let _ = register(EnvSpec {
        id: "FrozenLake8x8-v1".to_owned(),
        max_episode_steps: Some(200),
        reward_threshold: Some(0.85),
        order_enforce: true,
        factory: Box::new(|rm| {
            use crate::envs::toy_text::{FrozenLakeConfig, FrozenLakeEnv, MAP_8X8};
            // Same env as FrozenLake-v1, but with the `MAP_8X8` layout as owned rows.
            let env = FrozenLakeEnv::new(FrozenLakeConfig {
                desc: MAP_8X8.iter().map(|s| (*s).to_owned()).collect(),
                render_mode: rm,
                ..FrozenLakeConfig::default()
            })?;
            Ok(Box::new(env))
        }),
    });
    let _ = register(EnvSpec {
        id: "Taxi-v3".to_owned(),
        max_episode_steps: Some(200),
        reward_threshold: Some(8.0),
        order_enforce: true,
        factory: Box::new(|rm| {
            use crate::envs::toy_text::{TaxiConfig, TaxiEnv};
            // `TaxiEnv::new` is infallible here (no `?`), unlike the constructors above.
            let env = TaxiEnv::new(TaxiConfig { render_mode: rm });
            Ok(Box::new(env))
        }),
    });
    let _ = register(EnvSpec {
        id: "CliffWalking-v1".to_owned(),
        max_episode_steps: None,
        reward_threshold: None,
        order_enforce: true,
        factory: Box::new(|rm| {
            use crate::envs::toy_text::{CliffWalkingConfig, CliffWalkingEnv};
            let env = CliffWalkingEnv::new(CliffWalkingConfig {
                render_mode: rm,
                ..CliffWalkingConfig::default()
            });
            Ok(Box::new(env))
        }),
    });
    let _ = register(EnvSpec {
        id: "CliffWalkingSlippery-v1".to_owned(),
        max_episode_steps: None,
        reward_threshold: None,
        order_enforce: true,
        factory: Box::new(|rm| {
            use crate::envs::toy_text::{CliffWalkingConfig, CliffWalkingEnv};
            let env = CliffWalkingEnv::new(CliffWalkingConfig {
                render_mode: rm,
                is_slippery: true,
            });
            Ok(Box::new(env))
        }),
    });
    let _ = register(EnvSpec {
        id: "Blackjack-v1".to_owned(),
        max_episode_steps: None,
        reward_threshold: None,
        order_enforce: true,
        factory: Box::new(|rm| {
            use crate::envs::toy_text::{BlackjackConfig, BlackjackEnv};
            let env = BlackjackEnv::new(BlackjackConfig {
                sab: true,
                natural: false,
                render_mode: rm,
            });
            Ok(Box::new(env))
        }),
    });
    // Box2D environments are only compiled in with the `box2d` feature.
    #[cfg(feature = "box2d")]
    {
        let _ = register(EnvSpec {
            id: "LunarLander-v3".to_owned(),
            max_episode_steps: Some(1000),
            reward_threshold: Some(200.0),
            order_enforce: true,
            factory: Box::new(|rm| {
                use crate::envs::box2d::{LunarLanderConfig, LunarLanderEnv};
                let env = LunarLanderEnv::new(LunarLanderConfig {
                    render_mode: rm,
                    ..LunarLanderConfig::default()
                })?;
                Ok(Box::new(env))
            }),
        });
        let _ = register(EnvSpec {
            id: "BipedalWalker-v3".to_owned(),
            max_episode_steps: Some(1600),
            reward_threshold: Some(300.0),
            order_enforce: true,
            factory: Box::new(|rm| {
                use crate::envs::box2d::{BipedalWalkerConfig, BipedalWalkerEnv};
                let env = BipedalWalkerEnv::new(BipedalWalkerConfig {
                    render_mode: rm,
                    ..BipedalWalkerConfig::default()
                })?;
                Ok(Box::new(env))
            }),
        });
        let _ = register(EnvSpec {
            id: "BipedalWalkerHardcore-v3".to_owned(),
            max_episode_steps: Some(2000),
            reward_threshold: Some(300.0),
            order_enforce: true,
            factory: Box::new(|rm| {
                use crate::envs::box2d::{BipedalWalkerConfig, BipedalWalkerEnv};
                let env = BipedalWalkerEnv::new(BipedalWalkerConfig {
                    hardcore: true,
                    render_mode: rm,
                })?;
                Ok(Box::new(env))
            }),
        });
    }
}
#[cfg(test)]
#[allow(clippy::panic)]
mod tests {
    use super::*;

    /// Registers the builtins. The registry is process-global and shared by
    /// every test, and `register_builtins` tolerates repeated calls.
    fn setup() {
        register_builtins();
    }

    #[test]
    fn make_cartpole() {
        setup();
        let mut env = make("CartPole-v1").unwrap();
        let r = env.reset_dyn(Some(42)).unwrap();
        if let DynValue::Continuous(obs) = &r.obs {
            assert_eq!(obs.len(), 4);
        } else {
            panic!("expected Continuous observation");
        }
    }

    #[test]
    fn make_mountain_car() {
        setup();
        let mut env = make("MountainCar-v0").unwrap();
        let r = env.reset_dyn(Some(0)).unwrap();
        if let DynValue::Continuous(obs) = &r.obs {
            assert_eq!(obs.len(), 2);
        } else {
            panic!("expected Continuous observation");
        }
    }

    #[test]
    fn make_pendulum() {
        setup();
        let mut env = make("Pendulum-v1").unwrap();
        let r = env.reset_dyn(Some(0)).unwrap();
        if let DynValue::Continuous(obs) = &r.obs {
            assert_eq!(obs.len(), 3);
        } else {
            panic!("expected Continuous observation");
        }
    }

    #[test]
    fn make_unknown_errors() {
        setup();
        assert!(make("NonExistent-v99").is_err());
    }

    #[test]
    fn list_includes_builtins() {
        setup();
        let ids = list_registered();
        assert!(ids.contains(&"CartPole-v1".to_owned()));
        assert!(ids.contains(&"MountainCar-v0".to_owned()));
        assert!(ids.contains(&"Pendulum-v1".to_owned()));
        assert!(ids.contains(&"FrozenLake-v1".to_owned()));
        assert!(ids.contains(&"FrozenLake8x8-v1".to_owned()));
        assert!(ids.contains(&"Taxi-v3".to_owned()));
        assert!(ids.contains(&"CliffWalking-v1".to_owned()));
        assert!(ids.contains(&"Blackjack-v1".to_owned()));
        #[cfg(feature = "box2d")]
        {
            assert!(ids.contains(&"LunarLander-v3".to_owned()));
            assert!(ids.contains(&"BipedalWalker-v3".to_owned()));
            assert!(ids.contains(&"BipedalWalkerHardcore-v3".to_owned()));
        }
    }

    #[test]
    fn list_is_sorted_and_unique() {
        setup();
        let ids = list_registered();
        assert!(ids.windows(2).all(|w| w[0] < w[1]));
    }

    #[test]
    fn spec_reports_builtin_fields() {
        setup();
        let view = spec("CartPole-v1").unwrap();
        assert_eq!(
            view,
            EnvSpecView {
                id: "CartPole-v1".to_owned(),
                max_episode_steps: Some(500),
                reward_threshold: Some(475.0),
                order_enforce: true,
            }
        );
    }

    #[test]
    fn spec_unknown_is_none() {
        setup();
        assert!(spec("NonExistent-v99").is_none());
    }

    #[cfg(feature = "box2d")]
    #[test]
    fn make_lunar_lander() {
        setup();
        let mut env = make("LunarLander-v3").unwrap();
        let r = env.reset_dyn(Some(42)).unwrap();
        if let DynValue::Continuous(obs) = &r.obs {
            assert_eq!(obs.len(), 8);
        } else {
            panic!("expected Continuous observation, got {:?}", r.obs);
        }
    }

    #[cfg(feature = "box2d")]
    #[test]
    fn make_bipedal_walker() {
        setup();
        let mut env = make("BipedalWalker-v3").unwrap();
        let r = env.reset_dyn(Some(42)).unwrap();
        if let DynValue::Continuous(obs) = &r.obs {
            assert_eq!(obs.len(), 24);
        } else {
            panic!("expected Continuous observation, got {:?}", r.obs);
        }
    }

    #[cfg(feature = "box2d")]
    #[test]
    fn make_bipedal_walker_hardcore() {
        setup();
        let mut env = make("BipedalWalkerHardcore-v3").unwrap();
        let r = env.reset_dyn(Some(42)).unwrap();
        if let DynValue::Continuous(obs) = &r.obs {
            assert_eq!(obs.len(), 24);
        } else {
            panic!("expected Continuous observation, got {:?}", r.obs);
        }
    }

    #[test]
    fn step_through_dyn_env() {
        setup();
        let mut env = make("CartPole-v1").unwrap();
        env.reset_dyn(Some(42)).unwrap();
        let r = env.step_dyn(&DynValue::Discrete(1)).unwrap();
        if let DynValue::Continuous(obs) = &r.obs {
            assert_eq!(obs.len(), 4);
        } else {
            panic!("expected Continuous observation");
        }
    }

    #[test]
    fn make_frozen_lake() {
        setup();
        let mut env = make("FrozenLake-v1").unwrap();
        let r = env.reset_dyn(Some(0)).unwrap();
        assert!(matches!(r.obs, DynValue::Discrete(_)));
    }

    #[test]
    fn make_taxi() {
        setup();
        let mut env = make("Taxi-v3").unwrap();
        let r = env.reset_dyn(Some(0)).unwrap();
        assert!(matches!(r.obs, DynValue::Discrete(_)));
    }

    #[test]
    fn make_blackjack() {
        setup();
        let mut env = make("Blackjack-v1").unwrap();
        let r = env.reset_dyn(Some(42)).unwrap();
        if let DynValue::Tuple(elems) = &r.obs {
            assert_eq!(elems.len(), 3);
        } else {
            panic!("expected Tuple observation, got {:?}", r.obs);
        }
    }

    #[test]
    fn make_vec_zero_envs_errors() {
        setup();
        assert!(make_vec("CartPole-v1", 0).is_err());
    }

    #[test]
    fn make_vec_unknown_errors() {
        setup();
        assert!(make_vec("NonExistent-v99", 2).is_err());
    }

    #[test]
    fn vector_cartpole_reset_and_step() {
        setup();
        let mut venv = make_vec("CartPole-v1", 3).unwrap();
        assert_eq!(venv.num_envs(), 3);

        let r = venv.reset(Some(7)).unwrap();
        assert_eq!(r.obs.len(), 3);
        assert_eq!(r.infos.len(), 3);
        for o in &r.obs {
            assert!(matches!(o, DynValue::Continuous(v) if v.len() == 4));
        }

        let actions: Vec<DynValue> = (0..3).map(|_| DynValue::Discrete(1)).collect();
        let s = venv.step(&actions).unwrap();
        assert_eq!(s.obs.len(), 3);
        assert_eq!(s.rewards.len(), 3);
        assert_eq!(s.terminated.len(), 3);
        assert_eq!(s.truncated.len(), 3);
        assert_eq!(s.infos.len(), 3);

        venv.close();
    }

    #[test]
    fn vector_step_rejects_wrong_action_count() {
        setup();
        let mut venv = make_vec("CartPole-v1", 3).unwrap();
        venv.reset(Some(0)).unwrap();
        let actions: Vec<DynValue> = (0..2).map(|_| DynValue::Discrete(0)).collect();
        assert!(venv.step(&actions).is_err());
    }
}