1use crate::error::{RlError, RlResult};
9
/// Outcome of a single environment transition, as returned by [`Env::step`].
///
/// `PartialEq` is derived so that whole results can be compared directly
/// (e.g. with `assert_eq!`). Because the type contains `f32` values the usual
/// floating-point caveat applies: a result holding `NaN` never equals itself.
#[derive(Debug, Clone, PartialEq)]
pub struct StepResult {
    /// Observation vector produced by the step.
    pub obs: Vec<f32>,
    /// Scalar reward produced by the step.
    pub reward: f32,
    /// `true` when the environment reports that the episode has ended.
    pub done: bool,
}
22
/// Static description of an environment, as returned by [`Env::info`].
///
/// The struct is three plain `usize` values, so it is `Copy` (cheap to pass
/// by value) and `Hash` (usable as a map/set key) in addition to being
/// comparable.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct EnvInfo {
    /// Dimension of the observation vector.
    pub obs_dim: usize,
    /// Dimension of the action vector.
    pub action_dim: usize,
    /// Step limit of an episode.
    pub max_steps: usize,
}
35
36pub trait Env {
43 fn reset(&mut self) -> RlResult<Vec<f32>>;
46
47 fn step(&mut self, action: &[f32]) -> RlResult<StepResult>;
52
53 fn info(&self) -> EnvInfo;
55
56 fn obs_dim(&self) -> usize;
58
59 fn action_dim(&self) -> usize;
61}
62
/// Toy environment with linear state dynamics and a quadratic cost.
///
/// Each step the state is updated component-wise as `x <- 0.9 * x + 0.1 * u`
/// and the reward is `-|x|^2 - 0.1 * |u|^2`, evaluated on the state before
/// the update. The action has the same dimension as the state.
#[derive(Clone, Debug)]
pub struct LinearQuadraticEnv {
    /// Number of state components; also the required action length.
    obs_dim: usize,
    /// Step count at which an episode is reported as finished.
    max_steps: usize,
    /// Current state vector; it is returned as the observation.
    state: Vec<f32>,
    /// Steps taken since the last reset.
    step_count: usize,
}
95
96impl LinearQuadraticEnv {
97 pub fn new(obs_dim: usize, max_steps: usize) -> Self {
108 let state = (0..obs_dim)
110 .map(|i| if i % 2 == 0 { 0.5_f32 } else { -0.5_f32 })
111 .collect();
112 Self {
113 obs_dim,
114 max_steps,
115 state,
116 step_count: 0,
117 }
118 }
119
120 fn sq_norm(v: &[f32]) -> f32 {
122 v.iter().map(|x| x * x).sum()
123 }
124}
125
126impl Env for LinearQuadraticEnv {
127 fn reset(&mut self) -> RlResult<Vec<f32>> {
128 self.step_count = 0;
129 for (i, x) in self.state.iter_mut().enumerate() {
131 *x = if i % 2 == 0 { 0.5_f32 } else { -0.5_f32 };
132 }
133 Ok(self.state.clone())
134 }
135
136 fn step(&mut self, action: &[f32]) -> RlResult<StepResult> {
137 if action.len() != self.obs_dim {
138 return Err(RlError::DimensionMismatch {
139 expected: self.obs_dim,
140 got: action.len(),
141 });
142 }
143
144 let x_sq = Self::sq_norm(&self.state);
146 let u_sq = Self::sq_norm(action);
147 let reward = -x_sq - 0.1 * u_sq;
148
149 for (x, u) in self.state.iter_mut().zip(action.iter()) {
151 *x = 0.9 * (*x) + 0.1 * u;
152 }
153
154 self.step_count += 1;
155
156 let x_norm = Self::sq_norm(&self.state).sqrt();
158 let done = self.step_count >= self.max_steps || x_norm > 10.0;
159
160 Ok(StepResult {
161 obs: self.state.clone(),
162 reward,
163 done,
164 })
165 }
166
167 fn info(&self) -> EnvInfo {
168 EnvInfo {
169 obs_dim: self.obs_dim,
170 action_dim: self.obs_dim,
171 max_steps: self.max_steps,
172 }
173 }
174
175 #[inline]
176 fn obs_dim(&self) -> usize {
177 self.obs_dim
178 }
179
180 #[inline]
181 fn action_dim(&self) -> usize {
182 self.obs_dim
183 }
184}
185
/// Unit tests for [`LinearQuadraticEnv`].
#[cfg(test)]
mod tests {
    use super::*;

    /// `reset` yields the alternating `0.5, -0.5, ...` start state.
    #[test]
    fn lqr_reset_alternating() {
        let mut env = LinearQuadraticEnv::new(4, 10);
        let obs = env.reset().unwrap();
        assert_eq!(obs.len(), 4);
        assert!((obs[0] - 0.5).abs() < 1e-6);
        assert!((obs[1] + 0.5).abs() < 1e-6);
        assert!((obs[2] - 0.5).abs() < 1e-6);
        assert!((obs[3] + 0.5).abs() < 1e-6);
    }

    /// An action of the wrong length is rejected.
    #[test]
    fn lqr_step_dimension_mismatch() {
        let mut env = LinearQuadraticEnv::new(4, 10);
        let _ = env.reset().unwrap();
        assert!(env.step(&[0.0; 3]).is_err());
    }

    /// From the non-zero start state the quadratic cost makes the reward
    /// strictly negative, even with a zero action.
    #[test]
    fn lqr_step_reward_is_negative() {
        let mut env = LinearQuadraticEnv::new(4, 10);
        let _ = env.reset().unwrap();
        let res = env.step(&[0.0; 4]).unwrap();
        assert!(res.reward < 0.0, "reward={}", res.reward);
    }

    /// With a zero action the episode ends exactly at the step cap.
    #[test]
    fn lqr_episode_ends_at_max_steps() {
        let max = 5;
        let mut env = LinearQuadraticEnv::new(2, max);
        let _ = env.reset().unwrap();
        let mut done = false;
        for i in 0..max {
            let res = env.step(&[0.0; 2]).unwrap();
            done = res.done;
            if i < max - 1 {
                assert!(!done, "should not be done before max_steps");
            }
        }
        assert!(done, "should be done at max_steps");
    }

    /// `info` reports the constructor arguments.
    #[test]
    fn lqr_info() {
        let env = LinearQuadraticEnv::new(3, 100);
        let info = env.info();
        assert_eq!(info.obs_dim, 3);
        assert_eq!(info.action_dim, 3);
        assert_eq!(info.max_steps, 100);
    }

    /// Observation and action dimensions are both the state dimension.
    #[test]
    fn lqr_obs_action_dim() {
        let env = LinearQuadraticEnv::new(5, 10);
        assert_eq!(env.obs_dim(), 5);
        assert_eq!(env.action_dim(), 5);
    }

    /// A large constant action drives the state norm past the divergence
    /// threshold, which must end the episode strictly before the step cap.
    #[test]
    fn lqr_large_action_terminates_early() {
        let max = 1000;
        let mut env = LinearQuadraticEnv::new(2, max);
        let _ = env.reset().unwrap();

        // 1-based index of the first step that reports `done`. Errors are
        // unwrapped so a failing `step` cannot masquerade as termination.
        let first_done = (1..=max).find(|_| env.step(&[100.0, 100.0]).unwrap().done);

        match first_done {
            Some(n) => assert!(n < max, "only terminated at the step cap (step {})", n),
            None => panic!("episode never reported done"),
        }
    }
}