use crate::env::env::{Env, StepResult};
use crate::error::{RlError, RlResult};
/// Batched outcome of stepping every environment of a [`VecEnv`] once.
///
/// As populated by [`VecEnv::step`], `rewards` and `dones` hold one entry per
/// environment (in environment order), while `obs` is a single flat buffer.
#[derive(Debug, Clone)]
pub struct VecStepResult {
/// Per-environment observations concatenated back to back in environment
/// order. For an environment that reported `done`, `VecEnv::step` stores
/// the observation returned by the follow-up `reset` here instead of the
/// observation returned by the step itself.
pub obs: Vec<f32>,
/// Reward returned by each environment's step, one entry per environment.
pub rewards: Vec<f32>,
/// `done` flag returned by each environment's step, one entry per
/// environment.
pub dones: Vec<bool>,
}
/// A homogeneous collection of environments of type `E` that are reset and
/// stepped together through flat action / observation buffers.
#[derive(Debug)]
pub struct VecEnv<E: Env> {
/// The wrapped environments. Their order defines the layout of the flat
/// buffers produced and consumed by the methods of this type.
envs: Vec<E>,
}
impl<E: Env> VecEnv<E> {
    /// Wraps an already-constructed list of environments.
    ///
    /// The environments are kept in the given order; that order determines how
    /// flat action and observation buffers are laid out.
    #[must_use]
    pub fn new(envs: Vec<E>) -> Self {
        Self { envs }
    }

    /// Returns the number of wrapped environments.
    #[must_use]
    #[inline]
    pub fn n_envs(&self) -> usize {
        self.envs.len()
    }

    /// Resets every environment in order and returns their initial
    /// observations concatenated into one flat vector.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by an environment's `reset`.
    /// Environments before the failing one have already been reset at that
    /// point; environments after it are left untouched.
    pub fn reset_all(&mut self) -> RlResult<Vec<f32>> {
        let mut flat = Vec::new();
        for env in &mut self.envs {
            let obs = env.reset()?;
            flat.extend_from_slice(&obs);
        }
        Ok(flat)
    }

    /// Steps every environment once using a flat action buffer.
    ///
    /// `actions` is divided evenly between the environments: with `n`
    /// environments, environment `i` receives the `i`-th consecutive run of
    /// `actions.len() / n` values.
    ///
    /// An environment whose step reports `done` is reset immediately. Its
    /// entry in [`VecStepResult::dones`] stays `true`, but its slot in
    /// [`VecStepResult::obs`] holds the observation returned by that reset
    /// rather than the one returned by the step.
    ///
    /// When there are no environments an empty result is returned, whatever
    /// `actions` contains.
    ///
    /// # Errors
    ///
    /// Returns [`RlError::DimensionMismatch`] when there is at least one
    /// environment and `actions` is empty or its length is not a multiple of
    /// the number of environments. The reported `expected` value is derived
    /// from the first environment's `action_dim()`.
    ///
    /// Also propagates the first error returned by an environment's `step`
    /// or `reset`; environments earlier in the list have already been
    /// advanced when that happens.
    pub fn step(&mut self, actions: &[f32]) -> RlResult<VecStepResult> {
        let n = self.envs.len();
        if n == 0 {
            return Ok(VecStepResult {
                obs: Vec::new(),
                rewards: Vec::new(),
                dones: Vec::new(),
            });
        }
        // An empty slice satisfies `len % n == 0`, but it would give a chunk
        // size of zero below, and `chunks_exact(0)` panics. Reject it up front
        // together with lengths that cannot be split evenly.
        if actions.is_empty() || actions.len() % n != 0 {
            return Err(RlError::DimensionMismatch {
                expected: n * self.envs[0].action_dim(),
                got: actions.len(),
            });
        }
        // Non-zero here: `actions` is non-empty and an exact multiple of `n`.
        let action_dim = actions.len() / n;

        // Capacity is only a sizing hint based on the action buffer length;
        // the vector grows as needed if observations are larger.
        let mut flat_obs: Vec<f32> = Vec::with_capacity(actions.len());
        let mut rewards = Vec::with_capacity(n);
        let mut dones = Vec::with_capacity(n);

        for (env, chunk) in self.envs.iter_mut().zip(actions.chunks_exact(action_dim)) {
            let StepResult { obs, reward, done } = env.step(chunk)?;
            rewards.push(reward);
            dones.push(done);
            if done {
                // Auto-reset: hand back the first observation of the next
                // episode in place of the terminal one.
                let reset_obs = env.reset()?;
                flat_obs.extend_from_slice(&reset_obs);
            } else {
                flat_obs.extend_from_slice(&obs);
            }
        }

        Ok(VecStepResult {
            obs: flat_obs,
            rewards,
            dones,
        })
    }

    /// Returns the wrapped environments as a shared slice, in their stored
    /// order.
    #[must_use]
    #[inline]
    pub fn envs(&self) -> &[E] {
        &self.envs
    }

    /// Returns the wrapped environments as a mutable slice, allowing
    /// individual environments to be modified in place.
    #[inline]
    pub fn envs_mut(&mut self) -> &mut [E] {
        &mut self.envs
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::env::env::LinearQuadraticEnv;

    /// Builds a `VecEnv` holding `n` `LinearQuadraticEnv` instances, each
    /// constructed with the same `obs_dim` and `max_steps`.
    fn make_vec_env(n: usize, obs_dim: usize, max_steps: usize) -> VecEnv<LinearQuadraticEnv> {
        let mut envs = Vec::with_capacity(n);
        for _ in 0..n {
            envs.push(LinearQuadraticEnv::new(obs_dim, max_steps));
        }
        VecEnv::new(envs)
    }

    /// `reset_all` returns one flat buffer of `n_envs * obs_dim` values.
    #[test]
    fn reset_all_length() {
        let (n_envs, obs_dim) = (4, 3);
        let mut vec_env = make_vec_env(n_envs, obs_dim, 50);
        let initial = vec_env.reset_all().unwrap();
        assert_eq!(initial.len(), n_envs * obs_dim);
    }

    /// A step yields one reward and one done flag per environment plus a flat
    /// observation buffer.
    #[test]
    fn step_output_lengths() {
        let (n_envs, dim) = (4, 3);
        let mut vec_env = make_vec_env(n_envs, dim, 50);
        let _ = vec_env.reset_all().unwrap();
        let zero_actions = vec![0.0_f32; n_envs * dim];
        let outcome = vec_env.step(&zero_actions).unwrap();
        assert_eq!(outcome.obs.len(), n_envs * dim);
        assert_eq!(outcome.rewards.len(), n_envs);
        assert_eq!(outcome.dones.len(), n_envs);
    }

    /// Ten actions cannot be split evenly across four environments.
    #[test]
    fn step_dimension_mismatch() {
        let mut vec_env = make_vec_env(4, 3, 50);
        let _ = vec_env.reset_all().unwrap();
        let bad_actions = [0.0_f32; 10];
        let outcome = vec_env.step(&bad_actions);
        assert!(outcome.is_err());
    }

    /// With `max_steps` set to 1 every environment is expected to report done
    /// after a single step, and the observation buffer keeps its full length.
    #[test]
    fn auto_reset_on_done() {
        let (n_envs, dim) = (2, 2);
        let mut vec_env = make_vec_env(n_envs, dim, 1);
        let _ = vec_env.reset_all().unwrap();
        let zero_actions = [0.0_f32; 2 * 2];
        let outcome = vec_env.step(&zero_actions).unwrap();
        assert!(outcome.dones.iter().copied().all(|done| done));
        assert_eq!(outcome.obs.len(), n_envs * dim);
    }

    /// `n_envs` reports how many environments were supplied.
    #[test]
    fn n_envs_accessor() {
        let vec_env = make_vec_env(6, 4, 100);
        assert_eq!(vec_env.n_envs(), 6);
    }

    /// Stepping a `VecEnv` with no environments succeeds with empty outputs.
    #[test]
    fn empty_vec_env_step() {
        let mut vec_env = VecEnv::new(Vec::<LinearQuadraticEnv>::new());
        let outcome = vec_env.step(&[]).unwrap();
        assert!(outcome.obs.is_empty());
        assert!(outcome.rewards.is_empty());
        assert!(outcome.dones.is_empty());
    }
}