oxicuda_rl/env/vectorized.rs
1//! Vectorized environment wrapper.
2//!
3//! [`VecEnv`] runs multiple [`Env`] instances in lock-step, auto-resetting any
4//! environment that reaches a terminal state. All observations are returned as
5//! a flattened `Vec<f32>` of length `n_envs × obs_dim`.
6
7use crate::env::env::{Env, StepResult};
8use crate::error::{RlError, RlResult};
9
10// ─── VecStepResult ────────────────────────────────────────────────────────────
11
/// Outcome of a single batched [`VecEnv::step`] call.
///
/// Every vector is laid out in environment order: entry `i` (or, for `obs`,
/// the `i`-th per-environment chunk) belongs to the `i`-th wrapped
/// environment.
#[derive(Clone, Debug)]
pub struct VecStepResult {
    /// Concatenated observations of all environments (`n_envs × obs_dim`
    /// values).
    ///
    /// Auto-reset semantics: for an environment whose `done` flag is set, its
    /// slot holds the first observation of the **new** episode rather than the
    /// terminal observation.
    pub obs: Vec<f32>,
    /// Scalar reward produced by each environment (`n_envs` values).
    pub rewards: Vec<f32>,
    /// Whether each environment finished its episode on this step
    /// (`n_envs` values).
    pub dones: Vec<bool>,
}
24}
25
26// ─── VecEnv ───────────────────────────────────────────────────────────────────
27
28/// Synchronous vectorized environment.
29///
30/// Wraps a `Vec<E>` of homogeneous environments. All environments must share
31/// the same `obs_dim` and `action_dim`; this is validated lazily on the first
32/// `step` call to avoid redundant checks.
33///
34/// # Auto-reset
35///
36/// When `step` detects that environment `i` is done, it immediately calls
37/// `reset()` on it and places the resulting observation in the corresponding
38/// slot of `VecStepResult::obs`.
39///
40/// # Example
41///
42/// ```rust
43/// use oxicuda_rl::env::env::LinearQuadraticEnv;
44/// use oxicuda_rl::env::vectorized::VecEnv;
45///
46/// let envs: Vec<_> = (0..4).map(|_| LinearQuadraticEnv::new(3, 50)).collect();
47/// let mut ve = VecEnv::new(envs);
48/// let obs = ve.reset_all().unwrap();
49/// assert_eq!(obs.len(), 4 * 3);
50/// ```
51#[derive(Debug)]
52pub struct VecEnv<E: Env> {
53 envs: Vec<E>,
54}
55
56impl<E: Env> VecEnv<E> {
57 /// Create a new [`VecEnv`] from a non-empty vector of environments.
58 ///
59 /// # Panics
60 ///
61 /// Does **not** panic; returns an instance even with an empty `envs` slice
62 /// (though subsequent calls will fail with [`RlError::DimensionMismatch`]).
63 #[must_use]
64 pub fn new(envs: Vec<E>) -> Self {
65 Self { envs }
66 }
67
68 /// Number of parallel environments.
69 #[must_use]
70 #[inline]
71 pub fn n_envs(&self) -> usize {
72 self.envs.len()
73 }
74
75 /// Reset **all** environments and return flattened observations.
76 ///
77 /// Returns a `Vec<f32>` of length `n_envs × obs_dim`.
78 ///
79 /// # Errors
80 ///
81 /// Propagates any [`RlError`] from individual `reset()` calls.
82 pub fn reset_all(&mut self) -> RlResult<Vec<f32>> {
83 let mut flat = Vec::new();
84 for env in &mut self.envs {
85 let obs = env.reset()?;
86 flat.extend_from_slice(&obs);
87 }
88 Ok(flat)
89 }
90
91 /// Step all environments simultaneously.
92 ///
93 /// `actions` must have length `n_envs × action_dim`; the slice is split
94 /// into per-environment chunks before dispatch.
95 ///
96 /// # Errors
97 ///
98 /// * [`RlError::DimensionMismatch`] — `actions.len()` is not a multiple of
99 /// `n_envs`, or a chunk length does not match the environment's
100 /// `action_dim`.
101 /// * Any error propagated from individual `step()` or `reset()` calls.
102 pub fn step(&mut self, actions: &[f32]) -> RlResult<VecStepResult> {
103 let n = self.envs.len();
104 if n == 0 {
105 return Ok(VecStepResult {
106 obs: Vec::new(),
107 rewards: Vec::new(),
108 dones: Vec::new(),
109 });
110 }
111
112 // Infer per-environment action chunk size.
113 if actions.len() % n != 0 {
114 return Err(RlError::DimensionMismatch {
115 expected: n * self.envs[0].action_dim(),
116 got: actions.len(),
117 });
118 }
119 let action_dim = actions.len() / n;
120
121 let mut flat_obs: Vec<f32> = Vec::with_capacity(actions.len());
122 let mut rewards = Vec::with_capacity(n);
123 let mut dones = Vec::with_capacity(n);
124
125 for (env, chunk) in self.envs.iter_mut().zip(actions.chunks_exact(action_dim)) {
126 let StepResult { obs, reward, done } = env.step(chunk)?;
127
128 rewards.push(reward);
129 dones.push(done);
130
131 if done {
132 // Auto-reset: use the first observation of the new episode.
133 let reset_obs = env.reset()?;
134 flat_obs.extend_from_slice(&reset_obs);
135 } else {
136 flat_obs.extend_from_slice(&obs);
137 }
138 }
139
140 Ok(VecStepResult {
141 obs: flat_obs,
142 rewards,
143 dones,
144 })
145 }
146
147 /// Immutable access to the underlying environment slice.
148 #[must_use]
149 #[inline]
150 pub fn envs(&self) -> &[E] {
151 &self.envs
152 }
153
154 /// Mutable access to the underlying environment slice.
155 #[inline]
156 pub fn envs_mut(&mut self) -> &mut [E] {
157 &mut self.envs
158 }
159}
160
161// ─── Tests ───────────────────────────────────────────────────────────────────
162
163#[cfg(test)]
164mod tests {
165 use super::*;
166 use crate::env::env::LinearQuadraticEnv;
167
168 fn make_vec_env(n: usize, obs_dim: usize, max_steps: usize) -> VecEnv<LinearQuadraticEnv> {
169 let envs = (0..n)
170 .map(|_| LinearQuadraticEnv::new(obs_dim, max_steps))
171 .collect();
172 VecEnv::new(envs)
173 }
174
175 #[test]
176 fn reset_all_length() {
177 let mut ve = make_vec_env(4, 3, 50);
178 let obs = ve.reset_all().unwrap();
179 assert_eq!(obs.len(), 4 * 3);
180 }
181
182 #[test]
183 fn step_output_lengths() {
184 let mut ve = make_vec_env(4, 3, 50);
185 let _ = ve.reset_all().unwrap();
186 let actions = vec![0.0_f32; 4 * 3];
187 let res = ve.step(&actions).unwrap();
188 assert_eq!(res.obs.len(), 4 * 3);
189 assert_eq!(res.rewards.len(), 4);
190 assert_eq!(res.dones.len(), 4);
191 }
192
193 #[test]
194 fn step_dimension_mismatch() {
195 let mut ve = make_vec_env(4, 3, 50);
196 let _ = ve.reset_all().unwrap();
197 // Wrong total length.
198 assert!(ve.step(&[0.0; 10]).is_err());
199 }
200
201 #[test]
202 fn auto_reset_on_done() {
203 // max_steps=1 so every step triggers a done and auto-reset.
204 let mut ve = make_vec_env(2, 2, 1);
205 let _ = ve.reset_all().unwrap();
206 let res = ve.step(&[0.0_f32; 2 * 2]).unwrap();
207 // All dones should be true.
208 assert!(res.dones.iter().all(|&d| d));
209 // obs should still have the reset observations (length correct).
210 assert_eq!(res.obs.len(), 2 * 2);
211 }
212
213 #[test]
214 fn n_envs_accessor() {
215 let ve = make_vec_env(6, 4, 100);
216 assert_eq!(ve.n_envs(), 6);
217 }
218
219 #[test]
220 fn empty_vec_env_step() {
221 let envs: Vec<LinearQuadraticEnv> = Vec::new();
222 let mut ve = VecEnv::new(envs);
223 let res = ve.step(&[]).unwrap();
224 assert!(res.obs.is_empty());
225 assert!(res.rewards.is_empty());
226 assert!(res.dones.is_empty());
227 }
228}