// oxicuda_rl/env/env.rs
1//! Core environment trait and a reference Linear-Quadratic environment.
2//!
3//! # Overview
4//!
5//! [`Env`] is the standard single-environment interface. [`LinearQuadraticEnv`]
6//! is a fully deterministic LQR environment used in unit and integration tests.
7
8use crate::error::{RlError, RlResult};
9
10// ─── StepResult ───────────────────────────────────────────────────────────────
11
/// Result returned by [`Env::step`] describing one transition.
#[derive(Debug, Clone)]
pub struct StepResult {
    /// Next observation vector (length = `obs_dim`).
    pub obs: Vec<f32>,
    /// Scalar reward received for the transition just taken.
    pub reward: f32,
    /// Whether the episode has ended; once `true`, callers are expected to
    /// call [`Env::reset`] before stepping again.
    pub done: bool,
}
22
23// ─── EnvInfo ─────────────────────────────────────────────────────────────────
24
/// Static metadata about an environment.
///
/// Returned by [`Env::info`]; all fields are fixed for the lifetime of the
/// environment instance.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EnvInfo {
    /// Length of the observation vector.
    pub obs_dim: usize,
    /// Length of the action vector.
    pub action_dim: usize,
    /// Maximum number of steps per episode (0 = unlimited).
    ///
    /// NOTE(review): `LinearQuadraticEnv` in this file uses a plain
    /// `step_count >= max_steps` check, so it treats `0` as "done after the
    /// first step", not unlimited — confirm which semantics are intended.
    pub max_steps: usize,
}
35
36// ─── Env trait ────────────────────────────────────────────────────────────────
37
/// Standard RL environment interface.
///
/// Implementors provide `reset`, `step`, and metadata queries. All fallible
/// methods return [`RlResult`] so errors propagate cleanly without panicking.
pub trait Env {
    /// Reset the environment to its initial state and return the first
    /// observation (length = [`Self::obs_dim`]).
    fn reset(&mut self) -> RlResult<Vec<f32>>;

    /// Advance the environment by one step given `action`.
    ///
    /// # Errors
    ///
    /// `action.len()` must equal [`Self::action_dim`]; otherwise
    /// [`RlError::DimensionMismatch`] is returned.
    fn step(&mut self, action: &[f32]) -> RlResult<StepResult>;

    /// Return static metadata for this environment.
    fn info(&self) -> EnvInfo;

    /// Observation vector length.
    fn obs_dim(&self) -> usize;

    /// Action vector length.
    fn action_dim(&self) -> usize;
}
62
63// ─── LinearQuadraticEnv ───────────────────────────────────────────────────────
64
/// A deterministic Linear-Quadratic Regulator (LQR) environment.
///
/// ## Dynamics
///
/// State `x ∈ ℝ^d`, action `u ∈ ℝ^d` (same dimension):
///
/// ```text
/// x_{t+1}[i] = 0.9 · x_t[i] + 0.1 · u[i]
/// ```
///
/// ## Reward
///
/// ```text
/// r = −‖x‖² − 0.1·‖u‖²
/// ```
///
/// ## Reset
///
/// State is initialised to alternating `[0.5, -0.5, 0.5, -0.5, ...]`.
///
/// ## Episode termination
///
/// The episode ends when `step_count >= max_steps` or `‖x‖ > 10.0`.
#[derive(Debug, Clone)]
pub struct LinearQuadraticEnv {
    /// Dimension of both state and action vectors.
    obs_dim: usize,
    /// Step budget per episode.
    max_steps: usize,
    /// Current state vector (length `obs_dim`).
    state: Vec<f32>,
    /// Number of steps taken since the last `reset`.
    step_count: usize,
}

impl LinearQuadraticEnv {
    /// Create a new LQR environment.
    ///
    /// # Arguments
    ///
    /// * `obs_dim`   — dimension of the state and action vectors.
    /// * `max_steps` — maximum number of steps per episode.
    ///
    /// This constructor is infallible and performs no validation. (Earlier
    /// docs advertised an `InvalidHyperparameter` error, but the `-> Self`
    /// signature can never produce one.) Degenerate arguments are still
    /// well-defined: `obs_dim == 0` yields an empty state vector, and
    /// `max_steps == 0` makes every episode report `done` on its first step.
    pub fn new(obs_dim: usize, max_steps: usize) -> Self {
        // Deterministic initial state: alternating 0.5 / -0.5.
        let state = (0..obs_dim)
            .map(|i| if i % 2 == 0 { 0.5_f32 } else { -0.5_f32 })
            .collect();
        Self {
            obs_dim,
            max_steps,
            state,
            step_count: 0,
        }
    }

    /// Compute the squared L2 norm of `v`.
    fn sq_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum()
    }
}
125
126impl Env for LinearQuadraticEnv {
127    fn reset(&mut self) -> RlResult<Vec<f32>> {
128        self.step_count = 0;
129        // Deterministic reset: alternating 0.5 / -0.5.
130        for (i, x) in self.state.iter_mut().enumerate() {
131            *x = if i % 2 == 0 { 0.5_f32 } else { -0.5_f32 };
132        }
133        Ok(self.state.clone())
134    }
135
136    fn step(&mut self, action: &[f32]) -> RlResult<StepResult> {
137        if action.len() != self.obs_dim {
138            return Err(RlError::DimensionMismatch {
139                expected: self.obs_dim,
140                got: action.len(),
141            });
142        }
143
144        // Compute reward before state transition (uses current state).
145        let x_sq = Self::sq_norm(&self.state);
146        let u_sq = Self::sq_norm(action);
147        let reward = -x_sq - 0.1 * u_sq;
148
149        // Dynamics: x_{t+1}[i] = 0.9 * x[i] + 0.1 * u[i]
150        for (x, u) in self.state.iter_mut().zip(action.iter()) {
151            *x = 0.9 * (*x) + 0.1 * u;
152        }
153
154        self.step_count += 1;
155
156        // Done conditions.
157        let x_norm = Self::sq_norm(&self.state).sqrt();
158        let done = self.step_count >= self.max_steps || x_norm > 10.0;
159
160        Ok(StepResult {
161            obs: self.state.clone(),
162            reward,
163            done,
164        })
165    }
166
167    fn info(&self) -> EnvInfo {
168        EnvInfo {
169            obs_dim: self.obs_dim,
170            action_dim: self.obs_dim,
171            max_steps: self.max_steps,
172        }
173    }
174
175    #[inline]
176    fn obs_dim(&self) -> usize {
177        self.obs_dim
178    }
179
180    #[inline]
181    fn action_dim(&self) -> usize {
182        self.obs_dim
183    }
184}
185
186// ─── Tests ───────────────────────────────────────────────────────────────────
187
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn lqr_reset_alternating() {
        let mut env = LinearQuadraticEnv::new(4, 10);
        let obs = env.reset().unwrap();
        assert_eq!(obs.len(), 4);
        // Reset pattern alternates 0.5 / -0.5 by index.
        let expected = [0.5_f32, -0.5, 0.5, -0.5];
        for (got, want) in obs.iter().zip(expected.iter()) {
            assert!((got - want).abs() < 1e-6);
        }
    }

    #[test]
    fn lqr_step_dimension_mismatch() {
        let mut env = LinearQuadraticEnv::new(4, 10);
        env.reset().unwrap();
        // A 3-element action into a 4-dimensional env must be rejected.
        assert!(env.step(&[0.0; 3]).is_err());
    }

    #[test]
    fn lqr_step_reward_is_negative() {
        let mut env = LinearQuadraticEnv::new(4, 10);
        env.reset().unwrap();
        // State is non-zero after reset, so the quadratic cost is negative.
        let StepResult { reward, .. } = env.step(&[0.0; 4]).unwrap();
        assert!(reward <= 0.0, "reward={}", reward);
    }

    #[test]
    fn lqr_episode_ends_at_max_steps() {
        const MAX: usize = 5;
        let mut env = LinearQuadraticEnv::new(2, MAX);
        env.reset().unwrap();
        for step in 1..=MAX {
            let res = env.step(&[0.0; 2]).unwrap();
            if step < MAX {
                assert!(!res.done, "should not be done before max_steps");
            } else {
                assert!(res.done, "should be done at max_steps");
            }
        }
    }

    #[test]
    fn lqr_info() {
        let info = LinearQuadraticEnv::new(3, 100).info();
        assert_eq!(
            info,
            EnvInfo {
                obs_dim: 3,
                action_dim: 3,
                max_steps: 100,
            }
        );
    }

    #[test]
    fn lqr_obs_action_dim() {
        let env = LinearQuadraticEnv::new(5, 10);
        assert_eq!((env.obs_dim(), env.action_dim()), (5, 5));
    }

    #[test]
    fn lqr_large_action_terminates_early() {
        let mut env = LinearQuadraticEnv::new(2, 1000);
        env.reset().unwrap();
        // A massive action drives the state norm above 10 well before the
        // step budget; treat an unexpected Err as termination too.
        let mut terminated = false;
        for _ in 0..1000 {
            match env.step(&[100.0, 100.0]) {
                Ok(res) if !res.done => continue,
                _ => {
                    terminated = true;
                    break;
                }
            }
        }
        assert!(terminated);
    }
}
260}