use crate::env::{Env, RenderFrame};
use crate::error::{Error, Result};
use crate::vector::{AutoresetMode, VecResetResult, VecStepResult, VectorEnv};
/// Vectorized environment that drives several sub-environments of the same type
/// one after another on the calling thread.
///
/// Every operation on the vector fans out to each sub-environment in index order.
#[derive(Debug)]
pub struct SyncVectorEnv<E>
where
    E: Env,
{
    /// The wrapped sub-environments; construction rejects an empty list.
    envs: Vec<E>,
    /// Autoreset policy consulted by `step` for sub-environments whose episode ended.
    autoreset_mode: AutoresetMode,
    /// One flag per sub-environment, set when its most recent step terminated or truncated.
    needs_reset: Vec<bool>,
}
impl<E: Env> SyncVectorEnv<E> {
    /// Creates a vector env over `envs` that autoresets with [`AutoresetMode::NextStep`].
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidSpace`] if `envs` is empty.
    pub fn new(envs: Vec<E>) -> Result<Self> {
        Self::with_autoreset(envs, AutoresetMode::NextStep)
    }

    /// Creates a vector env over `envs` using the given autoreset `mode`.
    ///
    /// Every sub-environment starts with its "needs reset" flag cleared.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidSpace`] if `envs` is empty.
    pub fn with_autoreset(envs: Vec<E>, mode: AutoresetMode) -> Result<Self> {
        if envs.is_empty() {
            return Err(Error::InvalidSpace {
                reason: "SyncVectorEnv requires at least one sub-environment".to_owned(),
            });
        }
        let needs_reset = vec![false; envs.len()];
        Ok(Self {
            envs,
            autoreset_mode: mode,
            needs_reset,
        })
    }

    /// Number of wrapped sub-environments (never zero, see the constructors).
    #[must_use]
    pub const fn num_envs(&self) -> usize {
        self.envs.len()
    }

    /// Autoreset policy this vector env was built with.
    #[must_use]
    pub const fn autoreset_mode(&self) -> AutoresetMode {
        self.autoreset_mode
    }

    /// Shared reference to sub-environment `index`, or `None` when out of range.
    #[must_use]
    pub fn get_env(&self, index: usize) -> Option<&E> {
        self.envs.get(index)
    }

    /// Mutable reference to sub-environment `index`, or `None` when out of range.
    #[must_use]
    pub fn get_env_mut(&mut self, index: usize) -> Option<&mut E> {
        self.envs.get_mut(index)
    }
}
impl<E: Env> VectorEnv for SyncVectorEnv<E> {
    type Obs = E::Obs;
    type Act = E::Act;

    /// Returns the number of wrapped sub-environments.
    fn num_envs(&self) -> usize {
        self.envs.len()
    }

    /// Resets every sub-environment in index order and collects their observations and infos.
    ///
    /// With `Some(seed)`, sub-environment `i` is reset with `seed + i`, computed with wrapping
    /// arithmetic so that seeds near `u64::MAX` cannot overflow (a plain `+` panics in debug
    /// builds). With `None`, every sub-environment is reset with `None`.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by a sub-environment's `reset`. Sub-environments
    /// before the failing one have already been reset and have their "needs reset" flag cleared.
    fn reset(&mut self, seed: Option<u64>) -> Result<VecResetResult<E::Obs>> {
        let n = self.envs.len();
        let mut obs = Vec::with_capacity(n);
        let mut infos = Vec::with_capacity(n);
        for (i, (env, needs_reset)) in self
            .envs
            .iter_mut()
            .zip(self.needs_reset.iter_mut())
            .enumerate()
        {
            let env_seed = seed.map(|s| s.wrapping_add(i as u64));
            let r = env.reset(env_seed)?;
            // Clear per env (not after the loop) so a later failure does not leave freshly
            // reset envs queued for a second, unseeded autoreset on the next `step`.
            *needs_reset = false;
            obs.push(r.obs);
            infos.push(r.info);
        }
        Ok(VecResetResult { obs, infos })
    }

    /// Steps every sub-environment with its corresponding action, in index order.
    ///
    /// In [`AutoresetMode::NextStep`] mode, a sub-environment whose previous step terminated or
    /// truncated is first reset with `None` (the reset's observation and info are discarded)
    /// and is then stepped with this call's action. In any other mode no automatic reset is
    /// performed here.
    ///
    /// # Errors
    ///
    /// Returns [`Error::InvalidAction`] if `actions.len()` differs from the number of
    /// sub-environments. Otherwise propagates the first error from a sub-environment's
    /// `reset` or `step`; sub-environments before the failing one have already advanced.
    fn step(&mut self, actions: &[E::Act]) -> Result<VecStepResult<E::Obs>> {
        let n = self.envs.len();
        if actions.len() != n {
            return Err(Error::InvalidAction {
                reason: format!("expected {} actions, got {}", n, actions.len()),
            });
        }
        let autoreset_next_step = self.autoreset_mode == AutoresetMode::NextStep;
        let mut obs = Vec::with_capacity(n);
        let mut rewards = Vec::with_capacity(n);
        let mut terminated = Vec::with_capacity(n);
        let mut truncated = Vec::with_capacity(n);
        let mut infos = Vec::with_capacity(n);
        for ((env, action), needs_reset) in self
            .envs
            .iter_mut()
            .zip(actions)
            .zip(self.needs_reset.iter_mut())
        {
            if autoreset_next_step && *needs_reset {
                env.reset(None)?;
                // The env is fresh now; clear the flag so a failing `step` below cannot
                // trigger a second reset on the next call.
                *needs_reset = false;
            }
            let r = env.step(action)?;
            *needs_reset = r.terminated || r.truncated;
            obs.push(r.obs);
            rewards.push(r.reward);
            terminated.push(r.terminated);
            truncated.push(r.truncated);
            infos.push(r.info);
        }
        Ok(VecStepResult {
            obs,
            rewards,
            terminated,
            truncated,
            infos,
        })
    }

    /// Renders every sub-environment in index order.
    ///
    /// # Errors
    ///
    /// Stops at and returns the first sub-environment render error.
    fn render(&mut self) -> Result<Vec<RenderFrame>> {
        self.envs.iter_mut().map(Env::render).collect()
    }

    /// Closes every sub-environment in index order.
    fn close(&mut self) {
        for env in &mut self.envs {
            env.close();
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::envs::classic_control::{CartPoleConfig, CartPoleEnv};

    /// Builds a vector of `n` default-configured CartPole sub-environments.
    fn make_vec(n: usize) -> SyncVectorEnv<CartPoleEnv> {
        let envs: Vec<_> = (0..n)
            .map(|_| CartPoleEnv::new(CartPoleConfig::default()).unwrap())
            .collect();
        SyncVectorEnv::new(envs).unwrap()
    }

    #[test]
    fn reset_returns_n_observations() {
        let mut v = make_vec(4);
        let r = v.reset(Some(42)).unwrap();
        assert_eq!(r.obs.len(), 4);
        assert_eq!(r.infos.len(), 4);
    }

    #[test]
    fn step_returns_n_results() {
        let mut v = make_vec(3);
        v.reset(Some(0)).unwrap();
        let actions = vec![0_i64, 1, 0];
        let r = v.step(&actions).unwrap();
        assert_eq!(r.obs.len(), 3);
        assert_eq!(r.rewards.len(), 3);
        assert_eq!(r.terminated.len(), 3);
        assert_eq!(r.truncated.len(), 3);
        assert_eq!(r.infos.len(), 3);
    }

    #[test]
    fn step_wrong_action_count_errors() {
        let mut v = make_vec(3);
        v.reset(Some(0)).unwrap();
        // Both too few and too many actions must be rejected.
        assert!(v.step(&[0_i64, 1]).is_err());
        assert!(v.step(&[0_i64, 1, 0, 1]).is_err());
    }

    #[test]
    fn empty_envs_errors() {
        let result = SyncVectorEnv::<CartPoleEnv>::new(vec![]);
        assert!(result.is_err());
    }

    #[test]
    fn with_autoreset_empty_envs_errors() {
        let result =
            SyncVectorEnv::<CartPoleEnv>::with_autoreset(vec![], AutoresetMode::NextStep);
        assert!(result.is_err());
    }

    #[test]
    fn default_autoreset_mode_is_next_step() {
        let v = make_vec(2);
        assert!(v.autoreset_mode() == AutoresetMode::NextStep);
    }

    #[test]
    fn get_env_respects_bounds() {
        let mut v = make_vec(2);
        assert!(v.get_env(1).is_some());
        assert!(v.get_env(2).is_none());
        assert!(v.get_env_mut(0).is_some());
        assert!(v.get_env_mut(2).is_none());
    }

    #[test]
    fn deterministic_with_seed() {
        let mut v1 = make_vec(2);
        let mut v2 = make_vec(2);
        let r1 = v1.reset(Some(99)).unwrap();
        let r2 = v2.reset(Some(99)).unwrap();
        assert_eq!(r1.obs, r2.obs);
    }

    #[test]
    fn autoreset_on_next_step() {
        let mut v = make_vec(1);
        v.reset(Some(0)).unwrap();
        let done = (0..500).any(|_| v.step(&[1]).unwrap().terminated[0]);
        assert!(done, "should terminate within 500 steps");
        // NextStep mode resets the finished sub-env before applying the action. A freshly
        // reset CartPole starts near upright, so one step from it should not terminate.
        let r = v.step(&[0]).unwrap();
        assert_eq!(r.obs.len(), 1);
        assert!(
            !r.terminated[0],
            "sub-env should have been reset before stepping"
        );
    }

    #[test]
    fn num_envs_correct() {
        let v = make_vec(5);
        assert_eq!(v.num_envs(), 5);
    }

    #[test]
    fn render_returns_frames() {
        let mut v = make_vec(2);
        v.reset(Some(0)).unwrap();
        let frames = v.render().unwrap();
        assert_eq!(frames.len(), 2);
    }
}