//! CartPole-v1 — a reference implementation of the classic control task.
//!
//! This example exists to validate the [`rl_traits`] API under a realistic
//! environment. It mirrors the semantics of Gymnasium's `CartPole-v1`:
//!
//! - 4-dimensional continuous observation space: `[x, ẋ, θ, θ̇]`
//! - Discrete action space: `0` = push left, `1` = push right
//! - +1 reward every step the pole stays up
//! - **Terminated** when the pole tips past ±12° or the cart leaves ±2.4 m
//! - **Truncated** at 500 steps via [`rl_traits::TimeLimit`]
//!
//! Run with:
//! ```text
//! cargo run --example cartpole
//! ```
use rand::rngs::SmallRng;
use rand::{Rng, SeedableRng as _};
use rl_traits::{Environment, EpisodeStatus, StepResult, TimeLimit};
20
// ── Physical constants (identical to Gymnasium's CartPole-v1) ────────────────

/// Gravitational acceleration (m/s²).
const GRAVITY: f32 = 9.8;
/// Mass of the cart (kg).
const MASS_CART: f32 = 1.0;
/// Mass of the pole (kg).
const MASS_POLE: f32 = 0.1;
const TOTAL_MASS: f32 = MASS_CART + MASS_POLE;
/// Half the pole length (m); the centre of mass sits at the midpoint.
const HALF_POLE_LEN: f32 = 0.5;
const POLE_MASS_LEN: f32 = MASS_POLE * HALF_POLE_LEN;
/// Magnitude of the force applied to the cart each step (N).
const FORCE_MAG: f32 = 10.0;
/// Simulation timestep (seconds per step).
const TAU: f32 = 0.02;

/// Episode terminates when the cart position |x| exceeds this (metres).
const X_THRESHOLD: f32 = 2.4;
/// Episode terminates when the pole angle |θ| exceeds this (12° in radians).
const THETA_THRESHOLD_RAD: f32 = 12.0 * std::f32::consts::PI / 180.0;

/// Episodes are truncated after this many steps (the CartPole-v1 limit).
const MAX_STEPS: usize = 500;

// ── State ────────────────────────────────────────────────────────────────────

/// `[cart_position, cart_velocity, pole_angle, pole_angular_velocity]`
pub type CartPoleObs = [f32; 4];
41
42// ── Environment ──────────────────────────────────────────────────────────────
43
44pub struct CartPole {
45    state: CartPoleObs,
46    rng: SmallRng,
47}
48
49impl CartPole {
50    pub fn new(seed: u64) -> Self {
51        Self {
52            state: [0.0; 4],
53            rng: SmallRng::seed_from_u64(seed),
54        }
55    }
56
57    fn is_terminal(state: &CartPoleObs) -> bool {
58        state[0].abs() > X_THRESHOLD || state[2].abs() > THETA_THRESHOLD_RAD
59    }
60}
61
62impl Environment for CartPole {
63    type Observation = CartPoleObs;
64    type Action = usize; // 0 = left, 1 = right
65    type Info = ();
66
67    fn step(&mut self, action: usize) -> StepResult<CartPoleObs, ()> {
68        let [x, x_dot, theta, theta_dot] = self.state;
69
70        let force = if action == 1 { FORCE_MAG } else { -FORCE_MAG };
71        let cos_theta = theta.cos();
72        let sin_theta = theta.sin();
73
74        // Equations of motion (Euler integration, same as Gymnasium)
75        let temp = (force + POLE_MASS_LEN * theta_dot * theta_dot * sin_theta) / TOTAL_MASS;
76        let theta_acc = (GRAVITY * sin_theta - cos_theta * temp)
77            / (HALF_POLE_LEN * (4.0 / 3.0 - MASS_POLE * cos_theta * cos_theta / TOTAL_MASS));
78        let x_acc = temp - POLE_MASS_LEN * theta_acc * cos_theta / TOTAL_MASS;
79
80        let new_x = x + TAU * x_dot;
81        let new_x_dot = x_dot + TAU * x_acc;
82        let new_theta = theta + TAU * theta_dot;
83        let new_theta_dot = theta_dot + TAU * theta_acc;
84
85        self.state = [new_x, new_x_dot, new_theta, new_theta_dot];
86
87        let status = if Self::is_terminal(&self.state) {
88            EpisodeStatus::Terminated
89        } else {
90            EpisodeStatus::Continuing
91        };
92
93        // +1 every step the pole stays up (reward is 0 on the terminal step
94        // in Gymnasium; we match that here)
95        let reward = if status == EpisodeStatus::Continuing {
96            1.0
97        } else {
98            0.0
99        };
100
101        StepResult::new(self.state, reward, status, ())
102    }
103
104    fn reset(&mut self, seed: Option<u64>) -> (CartPoleObs, ()) {
105        if let Some(s) = seed {
106            self.rng = SmallRng::seed_from_u64(s);
107        }
108        // Uniform [-0.05, 0.05] for all four state variables
109        self.state = self.rng.gen::<[f32; 4]>().map(|v| v * 0.1 - 0.05);
110        (self.state, ())
111    }
112
113    fn sample_action(&self, rng: &mut impl Rng) -> usize {
114        rng.gen_range(0..2)
115    }
116}
117
// ── Demo loop ────────────────────────────────────────────────────────────────
119
120fn run_episode(env: &mut TimeLimit<CartPole>, rng: &mut SmallRng) -> (f64, EpisodeStatus, usize) {
121    env.reset(None);
122    let mut total_reward = 0.0;
123    let mut steps = 0;
124
125    loop {
126        let action = env.sample_action(rng);
127        let result = env.step(action);
128        total_reward += result.reward;
129        steps += 1;
130        if result.is_done() {
131            return (total_reward, result.status, steps);
132        }
133    }
134}
135
136fn main() {
137    const NUM_EPISODES: usize = 10;
138    const ENV_SEED: u64 = 42;
139
140    let mut env = TimeLimit::new(CartPole::new(ENV_SEED), MAX_STEPS);
141    let mut rng = SmallRng::seed_from_u64(0);
142
143    println!("CartPole-v1 — random agent, {NUM_EPISODES} episodes\n");
144    println!("{:<8} {:>8} {:>7} {:>12}", "Episode", "Return", "Steps", "Outcome");
145    println!("{}", "-".repeat(40));
146
147    let mut total_return = 0.0;
148
149    for ep in 1..=NUM_EPISODES {
150        let (ret, status, steps) = run_episode(&mut env, &mut rng);
151        total_return += ret;
152
153        let outcome = match status {
154            EpisodeStatus::Terminated => "Terminated",
155            EpisodeStatus::Truncated => "Truncated ",
156            EpisodeStatus::Continuing => unreachable!(),
157        };
158
159        println!("{ep:<8} {ret:>8.1} {steps:>7} {outcome:>12}");
160    }
161
162    println!("{}", "-".repeat(40));
163    println!(
164        "Mean return over {NUM_EPISODES} episodes: {:.1}",
165        total_return / NUM_EPISODES as f64
166    );
167}