1use rand::rngs::SmallRng;
18use rand::{Rng, SeedableRng as _};
19use rl_traits::{EpisodeStatus, Environment, StepResult, TimeLimit};
20
// ---- Physical parameters consumed by the dynamics in `CartPole::step` ----

/// Gravitational acceleration (presumably m/s^2 — units are not stated in this file).
const GRAVITY: f32 = 9.8;

/// Mass of the cart.
const MASS_CART: f32 = 1.0;

/// Mass of the pole.
const MASS_POLE: f32 = 0.1;

/// Combined mass of cart and pole.
const TOTAL_MASS: f32 = MASS_CART + MASS_POLE;

/// Half of the pole's length.
const HALF_POLE_LEN: f32 = 0.5;

/// Pole mass times half-length, precomputed because the dynamics use the product twice.
const POLE_MASS_LEN: f32 = MASS_POLE * HALF_POLE_LEN;

/// Magnitude of the force applied to the cart on every step; the action picks the sign.
const FORCE_MAG: f32 = 10.0;

/// Integration time step: each `step` advances the state by `TAU` times its derivative.
const TAU: f32 = 0.02;

// ---- Termination limits checked by `CartPole::is_terminal` ----

/// The episode terminates once the cart position's absolute value exceeds this.
const X_THRESHOLD: f32 = 2.4;

/// The episode terminates once the pole angle's absolute value exceeds this
/// (12 degrees, converted to radians).
const THETA_THRESHOLD_RAD: f32 = 12.0 * std::f32::consts::PI / 180.0;

// ---- Episode length ----

/// Step budget handed to the `TimeLimit` wrapper in `main`.
const MAX_STEPS: usize = 500;
36
/// Observation (and full internal state) of the cart-pole system.
///
/// Components, in order: `[x, x_dot, theta, theta_dot]` — cart position, cart
/// velocity, pole angle and pole angular velocity.
pub type CartPoleObs = [f32; 4];
41
42pub struct CartPole {
45 state: CartPoleObs,
46 rng: SmallRng,
47}
48
49impl CartPole {
50 pub fn new(seed: u64) -> Self {
51 Self {
52 state: [0.0; 4],
53 rng: SmallRng::seed_from_u64(seed),
54 }
55 }
56
57 fn is_terminal(state: &CartPoleObs) -> bool {
58 state[0].abs() > X_THRESHOLD || state[2].abs() > THETA_THRESHOLD_RAD
59 }
60}
61
62impl Environment for CartPole {
63 type Observation = CartPoleObs;
64 type Action = usize; type Info = ();
66
67 fn step(&mut self, action: usize) -> StepResult<CartPoleObs, ()> {
68 let [x, x_dot, theta, theta_dot] = self.state;
69
70 let force = if action == 1 { FORCE_MAG } else { -FORCE_MAG };
71 let cos_theta = theta.cos();
72 let sin_theta = theta.sin();
73
74 let temp = (force + POLE_MASS_LEN * theta_dot * theta_dot * sin_theta) / TOTAL_MASS;
76 let theta_acc = (GRAVITY * sin_theta - cos_theta * temp)
77 / (HALF_POLE_LEN * (4.0 / 3.0 - MASS_POLE * cos_theta * cos_theta / TOTAL_MASS));
78 let x_acc = temp - POLE_MASS_LEN * theta_acc * cos_theta / TOTAL_MASS;
79
80 let new_x = x + TAU * x_dot;
81 let new_x_dot = x_dot + TAU * x_acc;
82 let new_theta = theta + TAU * theta_dot;
83 let new_theta_dot = theta_dot + TAU * theta_acc;
84
85 self.state = [new_x, new_x_dot, new_theta, new_theta_dot];
86
87 let status = if Self::is_terminal(&self.state) {
88 EpisodeStatus::Terminated
89 } else {
90 EpisodeStatus::Continuing
91 };
92
93 let reward = if status == EpisodeStatus::Continuing {
96 1.0
97 } else {
98 0.0
99 };
100
101 StepResult::new(self.state, reward, status, ())
102 }
103
104 fn reset(&mut self, seed: Option<u64>) -> (CartPoleObs, ()) {
105 if let Some(s) = seed {
106 self.rng = SmallRng::seed_from_u64(s);
107 }
108 self.state = self.rng.gen::<[f32; 4]>().map(|v| v * 0.1 - 0.05);
110 (self.state, ())
111 }
112
113 fn sample_action(&self, rng: &mut impl Rng) -> usize {
114 rng.gen_range(0..2)
115 }
116}
117
118fn run_episode(env: &mut TimeLimit<CartPole>, rng: &mut SmallRng) -> (f64, EpisodeStatus, usize) {
121 env.reset(None);
122 let mut total_reward = 0.0;
123 let mut steps = 0;
124
125 loop {
126 let action = env.sample_action(rng);
127 let result = env.step(action);
128 total_reward += result.reward;
129 steps += 1;
130 if result.is_done() {
131 return (total_reward, result.status, steps);
132 }
133 }
134}
135
136fn main() {
137 const NUM_EPISODES: usize = 10;
138 const ENV_SEED: u64 = 42;
139
140 let mut env = TimeLimit::new(CartPole::new(ENV_SEED), MAX_STEPS);
141 let mut rng = SmallRng::seed_from_u64(0);
142
143 println!("CartPole-v1 — random agent, {NUM_EPISODES} episodes\n");
144 println!("{:<8} {:>8} {:>7} {:>12}", "Episode", "Return", "Steps", "Outcome");
145 println!("{}", "-".repeat(40));
146
147 let mut total_return = 0.0;
148
149 for ep in 1..=NUM_EPISODES {
150 let (ret, status, steps) = run_episode(&mut env, &mut rng);
151 total_return += ret;
152
153 let outcome = match status {
154 EpisodeStatus::Terminated => "Terminated",
155 EpisodeStatus::Truncated => "Truncated ",
156 EpisodeStatus::Continuing => unreachable!(),
157 };
158
159 println!("{ep:<8} {ret:>8.1} {steps:>7} {outcome:>12}");
160 }
161
162 println!("{}", "-".repeat(40));
163 println!(
164 "Mean return over {NUM_EPISODES} episodes: {:.1}",
165 total_return / NUM_EPISODES as f64
166 );
167}