Skip to main content

oxicuda_rl/env/
vectorized.rs

//! Vectorized environment wrapper.
//!
//! [`VecEnv`] runs multiple [`Env`] instances in lock-step, auto-resetting any
//! environment that reaches a terminal state.  All observations are returned as
//! a flattened `Vec<f32>` of length `n_envs × obs_dim`.
6
7use crate::env::env::{Env, StepResult};
8use crate::error::{RlError, RlResult};
9
10// ─── VecStepResult ────────────────────────────────────────────────────────────
11
/// Batched outcome of a single [`VecEnv::step`] call.
///
/// Every field is laid out per environment, in the order in which the
/// environments are stored inside the [`VecEnv`].
#[derive(Clone, Debug)]
pub struct VecStepResult {
    /// Observations of all environments concatenated back to back
    /// (`n_envs × obs_dim` values in total).
    ///
    /// If an environment finished its episode on this step, its slot holds the
    /// first observation of the **next** episode instead of the terminal one
    /// (auto-reset semantics).
    pub obs: Vec<f32>,
    /// One scalar reward per environment (`n_envs` values).
    pub rewards: Vec<f32>,
    /// One episode-finished flag per environment (`n_envs` values).
    pub dones: Vec<bool>,
}
25
26// ─── VecEnv ───────────────────────────────────────────────────────────────────
27
28/// Synchronous vectorized environment.
29///
30/// Wraps a `Vec<E>` of homogeneous environments.  All environments must share
31/// the same `obs_dim` and `action_dim`; this is validated lazily on the first
32/// `step` call to avoid redundant checks.
33///
34/// # Auto-reset
35///
36/// When `step` detects that environment `i` is done, it immediately calls
37/// `reset()` on it and places the resulting observation in the corresponding
38/// slot of `VecStepResult::obs`.
39///
40/// # Example
41///
42/// ```rust
43/// use oxicuda_rl::env::env::LinearQuadraticEnv;
44/// use oxicuda_rl::env::vectorized::VecEnv;
45///
46/// let envs: Vec<_> = (0..4).map(|_| LinearQuadraticEnv::new(3, 50)).collect();
47/// let mut ve = VecEnv::new(envs);
48/// let obs = ve.reset_all().unwrap();
49/// assert_eq!(obs.len(), 4 * 3);
50/// ```
51#[derive(Debug)]
52pub struct VecEnv<E: Env> {
53    envs: Vec<E>,
54}
55
56impl<E: Env> VecEnv<E> {
57    /// Create a new [`VecEnv`] from a non-empty vector of environments.
58    ///
59    /// # Panics
60    ///
61    /// Does **not** panic; returns an instance even with an empty `envs` slice
62    /// (though subsequent calls will fail with [`RlError::DimensionMismatch`]).
63    #[must_use]
64    pub fn new(envs: Vec<E>) -> Self {
65        Self { envs }
66    }
67
68    /// Number of parallel environments.
69    #[must_use]
70    #[inline]
71    pub fn n_envs(&self) -> usize {
72        self.envs.len()
73    }
74
75    /// Reset **all** environments and return flattened observations.
76    ///
77    /// Returns a `Vec<f32>` of length `n_envs × obs_dim`.
78    ///
79    /// # Errors
80    ///
81    /// Propagates any [`RlError`] from individual `reset()` calls.
82    pub fn reset_all(&mut self) -> RlResult<Vec<f32>> {
83        let mut flat = Vec::new();
84        for env in &mut self.envs {
85            let obs = env.reset()?;
86            flat.extend_from_slice(&obs);
87        }
88        Ok(flat)
89    }
90
91    /// Step all environments simultaneously.
92    ///
93    /// `actions` must have length `n_envs × action_dim`; the slice is split
94    /// into per-environment chunks before dispatch.
95    ///
96    /// # Errors
97    ///
98    /// * [`RlError::DimensionMismatch`] — `actions.len()` is not a multiple of
99    ///   `n_envs`, or a chunk length does not match the environment's
100    ///   `action_dim`.
101    /// * Any error propagated from individual `step()` or `reset()` calls.
102    pub fn step(&mut self, actions: &[f32]) -> RlResult<VecStepResult> {
103        let n = self.envs.len();
104        if n == 0 {
105            return Ok(VecStepResult {
106                obs: Vec::new(),
107                rewards: Vec::new(),
108                dones: Vec::new(),
109            });
110        }
111
112        // Infer per-environment action chunk size.
113        if actions.len() % n != 0 {
114            return Err(RlError::DimensionMismatch {
115                expected: n * self.envs[0].action_dim(),
116                got: actions.len(),
117            });
118        }
119        let action_dim = actions.len() / n;
120
121        let mut flat_obs: Vec<f32> = Vec::with_capacity(actions.len());
122        let mut rewards = Vec::with_capacity(n);
123        let mut dones = Vec::with_capacity(n);
124
125        for (env, chunk) in self.envs.iter_mut().zip(actions.chunks_exact(action_dim)) {
126            let StepResult { obs, reward, done } = env.step(chunk)?;
127
128            rewards.push(reward);
129            dones.push(done);
130
131            if done {
132                // Auto-reset: use the first observation of the new episode.
133                let reset_obs = env.reset()?;
134                flat_obs.extend_from_slice(&reset_obs);
135            } else {
136                flat_obs.extend_from_slice(&obs);
137            }
138        }
139
140        Ok(VecStepResult {
141            obs: flat_obs,
142            rewards,
143            dones,
144        })
145    }
146
147    /// Immutable access to the underlying environment slice.
148    #[must_use]
149    #[inline]
150    pub fn envs(&self) -> &[E] {
151        &self.envs
152    }
153
154    /// Mutable access to the underlying environment slice.
155    #[inline]
156    pub fn envs_mut(&mut self) -> &mut [E] {
157        &mut self.envs
158    }
159}
160
161// ─── Tests ───────────────────────────────────────────────────────────────────
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::env::env::LinearQuadraticEnv;

    /// Build a `VecEnv` holding `n` identically configured
    /// `LinearQuadraticEnv`s.
    fn make_vec_env(n: usize, obs_dim: usize, max_steps: usize) -> VecEnv<LinearQuadraticEnv> {
        let mut envs = Vec::with_capacity(n);
        for _ in 0..n {
            envs.push(LinearQuadraticEnv::new(obs_dim, max_steps));
        }
        VecEnv::new(envs)
    }

    #[test]
    fn reset_all_length() {
        let (n_envs, obs_dim) = (4, 3);
        let mut ve = make_vec_env(n_envs, obs_dim, 50);
        let flat = ve.reset_all().unwrap();
        assert_eq!(flat.len(), n_envs * obs_dim);
    }

    #[test]
    fn step_output_lengths() {
        let mut ve = make_vec_env(4, 3, 50);
        let _ = ve.reset_all().unwrap();
        let zeros = [0.0_f32; 4 * 3];
        let VecStepResult {
            obs,
            rewards,
            dones,
        } = ve.step(&zeros).unwrap();
        assert_eq!(obs.len(), 4 * 3);
        assert_eq!(rewards.len(), 4);
        assert_eq!(dones.len(), 4);
    }

    #[test]
    fn step_dimension_mismatch() {
        let mut ve = make_vec_env(4, 3, 50);
        let _ = ve.reset_all().unwrap();
        // 10 values cannot be split evenly across 4 environments.
        let outcome = ve.step(&[0.0; 10]);
        assert!(outcome.is_err());
    }

    #[test]
    fn auto_reset_on_done() {
        // With max_steps = 1 every environment is expected to finish on its
        // first step, which exercises the auto-reset path.
        let mut ve = make_vec_env(2, 2, 1);
        let _ = ve.reset_all().unwrap();
        let zeros = [0.0_f32; 2 * 2];
        let res = ve.step(&zeros).unwrap();
        // No environment may report "not done".
        assert!(!res.dones.contains(&false));
        // The reset observations must still fill the whole buffer.
        assert_eq!(res.obs.len(), 2 * 2);
    }

    #[test]
    fn n_envs_accessor() {
        let ve = make_vec_env(6, 4, 100);
        assert_eq!(ve.n_envs(), 6);
    }

    #[test]
    fn empty_vec_env_step() {
        let mut ve = VecEnv::new(Vec::<LinearQuadraticEnv>::new());
        let VecStepResult {
            obs,
            rewards,
            dones,
        } = ve.step(&[]).unwrap();
        assert!(obs.is_empty());
        assert!(rewards.is_empty());
        assert!(dones.is_empty());
    }
}
228}