use crate::error::{RlError, RlResult};
/// Outcome of a single call to an environment's `step` method.
#[derive(Clone, Debug)]
pub struct StepResult {
    /// Observation vector reported for the step.
    pub obs: Vec<f32>,
    /// Scalar reward assigned to the step.
    pub reward: f32,
    /// `true` once the episode is over.
    pub done: bool,
}
/// Static description of an environment: its dimensions and step budget.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct EnvInfo {
    /// Length of the observation vector.
    pub obs_dim: usize,
    /// Length of the action vector.
    pub action_dim: usize,
    /// Maximum number of steps in one episode.
    pub max_steps: usize,
}
/// Interface for environments that are driven with `f32` action vectors and
/// report `f32` observation vectors.
pub trait Env {
    /// Resets the environment and returns an observation vector.
    ///
    /// # Errors
    ///
    /// Implementations report failures through [`RlResult`].
    fn reset(&mut self) -> RlResult<Vec<f32>>;

    /// Applies `action` and returns the resulting [`StepResult`].
    ///
    /// # Errors
    ///
    /// Implementations report failures through [`RlResult`].
    fn step(&mut self, action: &[f32]) -> RlResult<StepResult>;

    /// Returns the environment's [`EnvInfo`] summary.
    fn info(&self) -> EnvInfo;

    /// Returns the observation dimension.
    fn obs_dim(&self) -> usize;

    /// Returns the action dimension.
    fn action_dim(&self) -> usize;
}
/// Deterministic toy environment with a linear state update and a quadratic
/// cost, used through the `Env` implementation in this file.
#[derive(Clone, Debug)]
pub struct LinearQuadraticEnv {
    /// Number of state components.
    obs_dim: usize,
    /// Step budget stored at construction.
    max_steps: usize,
    /// Current state vector (`obs_dim` entries).
    state: Vec<f32>,
    /// Step counter; starts at zero.
    step_count: usize,
}

impl LinearQuadraticEnv {
    /// Creates an environment with `obs_dim` state components and a budget of
    /// `max_steps` steps.
    ///
    /// The state starts at the alternating pattern `0.5, -0.5, 0.5, ...`
    /// (even indices positive, odd indices negative).
    pub fn new(obs_dim: usize, max_steps: usize) -> Self {
        // Cycling over the two start values reproduces the even/odd pattern.
        let state: Vec<f32> = [0.5_f32, -0.5_f32]
            .iter()
            .copied()
            .cycle()
            .take(obs_dim)
            .collect();

        Self {
            step_count: 0,
            state,
            max_steps,
            obs_dim,
        }
    }

    /// Returns the sum of the squared components of `v`.
    fn sq_norm(v: &[f32]) -> f32 {
        v.iter().map(|&component| component * component).sum()
    }
}
/// [`Env`] implementation: state update `x <- 0.9 * x + 0.1 * u` with reward
/// `-|x|^2 - 0.1 * |u|^2`.
impl Env for LinearQuadraticEnv {
    /// Starts a new episode.
    ///
    /// Zeroes the step counter, restores the alternating `0.5, -0.5, ...`
    /// state and returns a copy of it. This implementation never fails.
    fn reset(&mut self) -> RlResult<Vec<f32>> {
        self.step_count = 0;
        for (i, x) in self.state.iter_mut().enumerate() {
            *x = if i % 2 == 0 { 0.5_f32 } else { -0.5_f32 };
        }
        Ok(self.state.clone())
    }

    /// Advances the environment by one step using `action`.
    ///
    /// The reward is computed from the state *before* the update. The
    /// returned observation is the state *after* the update. `done` is set
    /// when the step budget is used up, when the state norm exceeds `10.0`,
    /// or when the state norm is NaN.
    ///
    /// # Errors
    ///
    /// Returns [`RlError::DimensionMismatch`] when `action.len()` differs
    /// from the state dimension; the environment is left untouched.
    fn step(&mut self, action: &[f32]) -> RlResult<StepResult> {
        if action.len() != self.obs_dim {
            return Err(RlError::DimensionMismatch {
                expected: self.obs_dim,
                got: action.len(),
            });
        }

        // Quadratic cost on the current state and on the control effort.
        let x_sq = Self::sq_norm(&self.state);
        let u_sq = Self::sq_norm(action);
        let reward = -x_sq - 0.1 * u_sq;

        for (x, u) in self.state.iter_mut().zip(action.iter()) {
            *x = 0.9 * (*x) + 0.1 * u;
        }
        self.step_count += 1;

        let x_norm = Self::sq_norm(&self.state).sqrt();
        // NaN compares false against any threshold, so a state poisoned by a
        // NaN action (or an inf - inf update) would otherwise never count as
        // diverged; treat it as divergence explicitly.
        let diverged = x_norm.is_nan() || x_norm > 10.0;
        let done = self.step_count >= self.max_steps || diverged;

        Ok(StepResult {
            obs: self.state.clone(),
            reward,
            done,
        })
    }

    /// Reports the dimensions and step budget; the action dimension equals
    /// the observation dimension for this environment.
    fn info(&self) -> EnvInfo {
        EnvInfo {
            obs_dim: self.obs_dim,
            action_dim: self.obs_dim,
            max_steps: self.max_steps,
        }
    }

    /// Returns the number of state components.
    #[inline]
    fn obs_dim(&self) -> usize {
        self.obs_dim
    }

    /// Returns the expected action length, which equals the state dimension.
    #[inline]
    fn action_dim(&self) -> usize {
        self.obs_dim
    }
}
/// Unit tests for [`LinearQuadraticEnv`] and its [`Env`] implementation.
#[cfg(test)]
mod tests {
    use super::*;

    /// `reset` returns the alternating `0.5, -0.5, ...` start state.
    #[test]
    fn lqr_reset_alternating() {
        let mut env = LinearQuadraticEnv::new(4, 10);
        let obs = env.reset().unwrap();
        assert_eq!(obs.len(), 4);
        assert!((obs[0] - 0.5).abs() < 1e-6);
        assert!((obs[1] + 0.5).abs() < 1e-6);
        assert!((obs[2] - 0.5).abs() < 1e-6);
        assert!((obs[3] + 0.5).abs() < 1e-6);
    }

    /// An action of the wrong length is rejected with an error.
    #[test]
    fn lqr_step_dimension_mismatch() {
        let mut env = LinearQuadraticEnv::new(4, 10);
        let _ = env.reset().unwrap();
        assert!(env.step(&[0.0; 3]).is_err());
    }

    /// The reward is a negated quadratic cost and therefore never positive.
    #[test]
    fn lqr_step_reward_is_negative() {
        let mut env = LinearQuadraticEnv::new(4, 10);
        let _ = env.reset().unwrap();
        let res = env.step(&[0.0; 4]).unwrap();
        assert!(res.reward <= 0.0, "reward={}", res.reward);
    }

    /// With zero actions the episode ends exactly when the step budget is used up.
    #[test]
    fn lqr_episode_ends_at_max_steps() {
        let max = 5;
        let mut env = LinearQuadraticEnv::new(2, max);
        let _ = env.reset().unwrap();
        let mut done = false;
        for i in 0..max {
            let res = env.step(&[0.0; 2]).unwrap();
            done = res.done;
            if i < max - 1 {
                assert!(!done, "should not be done before max_steps");
            }
        }
        assert!(done, "should be done at max_steps");
    }

    /// `info` mirrors the constructor arguments.
    #[test]
    fn lqr_info() {
        let env = LinearQuadraticEnv::new(3, 100);
        let info = env.info();
        assert_eq!(info.obs_dim, 3);
        assert_eq!(info.action_dim, 3);
        assert_eq!(info.max_steps, 100);
    }

    /// Observation and action dimensions both equal the constructor's `obs_dim`.
    #[test]
    fn lqr_obs_action_dim() {
        let env = LinearQuadraticEnv::new(5, 10);
        assert_eq!(env.obs_dim(), 5);
        assert_eq!(env.action_dim(), 5);
    }

    /// Large actions push the state norm past the divergence threshold
    /// before the step budget runs out.
    #[test]
    fn lqr_large_action_terminates_early() {
        let max = 1000;
        let mut env = LinearQuadraticEnv::new(2, max);
        let _ = env.reset().unwrap();
        // Take strictly fewer than `max` steps: on the final budgeted step
        // `done` is set by the step budget alone, which would make this
        // check pass even if divergence detection were broken. A failing
        // `step` must fail the test instead of counting as "done".
        let terminated_early = (0..max - 1).any(|_| {
            env.step(&[100.0, 100.0])
                .expect("step with a correctly sized action must succeed")
                .done
        });
        assert!(
            terminated_early,
            "episode should end before the step budget is exhausted"
        );
    }
}