use std::fmt;
use rand::{SeedableRng, rngs::StdRng};
use rand_distr::{Distribution, Uniform};
use rlevo_core::{
action::DiscreteAction,
base::{Action, Observation, State, TensorConversionError, TensorConvertible},
environment::{ConstructableEnv, Environment, EnvironmentError, SnapshotBase},
reward::ScalarReward,
};
use serde::{Deserialize, Serialize};
/// Numerical integration scheme used to advance the cart-pole state by one time step.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum Integrator {
/// Positions are advanced with the velocities from the start of the step (the default).
#[default]
Euler,
/// Velocities are updated first; positions are then advanced with the updated velocities.
SemiImplicit,
}
/// Tunable parameters of the [`CartPole`] environment.
///
/// Use [`CartPoleConfig::default`] for the stock values or
/// [`CartPoleConfig::builder`] to override individual fields.
#[derive(Debug, Clone)]
pub struct CartPoleConfig {
/// Gravitational acceleration used in the pole's angular-acceleration term.
pub gravity: f32,
/// Mass of the cart; added to `masspole` to obtain the total mass.
pub masscart: f32,
/// Mass of the pole.
pub masspole: f32,
/// Pole length parameter used by the dynamics; the 2-D snapshot draws the pole `2 * length` long.
pub length: f32,
/// Magnitude of the force applied each step; the action selects its sign.
pub force_mag: f32,
/// Time step by which accelerations and velocities are scaled during integration.
pub tau: f32,
/// The episode terminates once `|theta|` exceeds this angle (radians).
pub theta_threshold_radians: f32,
/// The episode terminates once `|x|` exceeds this value.
pub x_threshold: f32,
/// Integration scheme applied in every step.
pub integrator: Integrator,
/// When `true`, the reward is `-1` on the terminating step and `0` otherwise;
/// when `false`, every step yields a reward of `1`.
pub sutton_barto_reward: bool,
/// Seed for the RNG that samples the initial state.
pub seed: u64,
}
impl Default for CartPoleConfig {
fn default() -> Self {
Self {
gravity: 9.8,
masscart: 1.0,
masspole: 0.1,
length: 0.5,
force_mag: 10.0,
tau: 0.02,
theta_threshold_radians: 12.0_f32.to_radians(),
x_threshold: 2.4,
integrator: Integrator::Euler,
sutton_barto_reward: false,
seed: 0,
}
}
}
/// Fluent builder for [`CartPoleConfig`].
///
/// Starts from the default configuration; each setter overrides one field.
#[derive(Debug, Default)]
pub struct CartPoleConfigBuilder {
// Configuration being assembled; returned as-is by `build`.
inner: CartPoleConfig,
}
impl CartPoleConfig {
pub fn builder() -> CartPoleConfigBuilder {
CartPoleConfigBuilder {
inner: CartPoleConfig::default(),
}
}
}
impl CartPoleConfigBuilder {
pub fn gravity(mut self, v: f32) -> Self {
self.inner.gravity = v;
self
}
pub fn masscart(mut self, v: f32) -> Self {
self.inner.masscart = v;
self
}
pub fn masspole(mut self, v: f32) -> Self {
self.inner.masspole = v;
self
}
pub fn length(mut self, v: f32) -> Self {
self.inner.length = v;
self
}
pub fn force_mag(mut self, v: f32) -> Self {
self.inner.force_mag = v;
self
}
pub fn tau(mut self, v: f32) -> Self {
self.inner.tau = v;
self
}
pub fn theta_threshold_radians(mut self, v: f32) -> Self {
self.inner.theta_threshold_radians = v;
self
}
pub fn x_threshold(mut self, v: f32) -> Self {
self.inner.x_threshold = v;
self
}
pub fn integrator(mut self, v: Integrator) -> Self {
self.inner.integrator = v;
self
}
pub fn sutton_barto_reward(mut self, v: bool) -> Self {
self.inner.sutton_barto_reward = v;
self
}
pub fn seed(mut self, v: u64) -> Self {
self.inner.seed = v;
self
}
pub fn build(self) -> CartPoleConfig {
self.inner
}
}
/// Full internal state of the cart-pole system.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct CartPoleState {
/// Cart position along the track.
pub x: f32,
/// Cart velocity.
pub x_dot: f32,
/// Pole angle in radians; `0` renders as an upright pole.
pub theta: f32,
/// Pole angular velocity.
pub theta_dot: f32,
}
impl CartPoleState {
fn new(x: f32, x_dot: f32, theta: f32, theta_dot: f32) -> Self {
Self {
x,
x_dot,
theta,
theta_dot,
}
}
}
impl State<1> for CartPoleState {
    type Observation = CartPoleObservation;

    /// The state is a flat vector of four scalars.
    fn shape() -> [usize; 1] {
        [4]
    }

    /// Number of scalar components in the state.
    fn numel(&self) -> usize {
        4
    }

    /// A state is valid when every component is finite (neither NaN nor infinite).
    fn is_valid(&self) -> bool {
        [self.x, self.x_dot, self.theta, self.theta_dot]
            .iter()
            .all(|component| component.is_finite())
    }

    /// Copies the four state components, unchanged, into an observation.
    fn observe(&self) -> CartPoleObservation {
        let Self {
            x,
            x_dot,
            theta,
            theta_dot,
        } = *self;
        CartPoleObservation {
            cart_pos: x,
            cart_vel: x_dot,
            pole_angle: theta,
            pole_ang_vel: theta_dot,
        }
    }
}
/// Observation emitted by [`CartPole`]: a direct copy of the four state components.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct CartPoleObservation {
/// Cart position (copied from `CartPoleState::x`).
pub cart_pos: f32,
/// Cart velocity (copied from `CartPoleState::x_dot`).
pub cart_vel: f32,
/// Pole angle in radians (copied from `CartPoleState::theta`).
pub pole_angle: f32,
/// Pole angular velocity (copied from `CartPoleState::theta_dot`).
pub pole_ang_vel: f32,
}
impl CartPoleObservation {
    /// Flattens the observation into
    /// `[cart_pos, cart_vel, pole_angle, pole_ang_vel]`.
    pub fn to_array(&self) -> [f32; 4] {
        let Self {
            cart_pos,
            cart_vel,
            pole_angle,
            pole_ang_vel,
        } = *self;
        [cart_pos, cart_vel, pole_angle, pole_ang_vel]
    }
}
impl Observation<1> for CartPoleObservation {
/// The observation is a flat vector of four scalars.
fn shape() -> [usize; 1] {
[4]
}
}
/// The two discrete actions available to the agent.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CartPoleAction {
/// Apply a force of `-force_mag` to the cart (index `0`).
Left,
/// Apply a force of `+force_mag` to the cart (index `1`).
Right,
}
impl Action<1> for CartPoleAction {
/// Shape of the tensor encoding: one slot per action (the action is encoded
/// as a length-2 one-hot vector by its `TensorConvertible` impl).
fn shape() -> [usize; 1] {
[2]
}
/// Every enum variant is a legal action, so this is always `true`.
fn is_valid(&self) -> bool {
true
}
}
impl DiscreteAction<1> for CartPoleAction {
const ACTION_COUNT: usize = 2;
fn from_index(index: usize) -> Self {
match index {
0 => Self::Left,
1 => Self::Right,
_ => panic!("CartPoleAction index out of range: {index}"),
}
}
fn to_index(&self) -> usize {
match self {
Self::Left => 0,
Self::Right => 1,
}
}
}
/// Cart-pole balancing environment.
#[derive(Debug)]
pub struct CartPole {
// Current physical state; replaced on every `reset` and `step`.
state: CartPoleState,
// Parameters the environment was constructed with.
config: CartPoleConfig,
// RNG used to sample the initial state; seeded from `config.seed`.
rng: StdRng,
// Number of steps taken since the last reset.
steps: usize,
}
impl CartPole {
    /// Creates an environment from `config`.
    ///
    /// The RNG is seeded from `config.seed`; the state starts at all zeros and
    /// the step counter at `0`.
    pub fn with_config(config: CartPoleConfig) -> Self {
        Self {
            // Read the seed before `config` is moved into the struct below.
            rng: StdRng::seed_from_u64(config.seed),
            state: CartPoleState::new(0.0, 0.0, 0.0, 0.0),
            steps: 0,
            config,
        }
    }

    /// Number of steps taken since the last reset.
    pub fn steps(&self) -> usize {
        self.steps
    }

    /// Draws a fresh state with every component sampled uniformly from
    /// `[-0.05, 0.05]`.
    fn sample_init_state(&mut self) -> CartPoleState {
        let range = Uniform::new_inclusive(-0.05_f32, 0.05_f32).unwrap();
        let mut draw = || range.sample(&mut self.rng);

        // The draw order (x, x_dot, theta, theta_dot) determines which random
        // number lands in which component, so it must stay fixed.
        let x = draw();
        let x_dot = draw();
        let theta = draw();
        let theta_dot = draw();
        CartPoleState::new(x, x_dot, theta, theta_dot)
    }

    /// Returns `true` when `state` is non-finite, or the cart position or pole
    /// angle exceeds its configured threshold in magnitude.
    fn is_terminal(state: &CartPoleState, cfg: &CartPoleConfig) -> bool {
        if !state.is_valid() {
            return true;
        }
        let off_track = state.x.abs() > cfg.x_threshold;
        let pole_fell = state.theta.abs() > cfg.theta_threshold_radians;
        off_track || pole_fell
    }

    /// Advances `state` by one time step of `cfg.tau` under `action` and
    /// returns the new state; `state` itself is left untouched.
    fn step_physics(
        state: &CartPoleState,
        action: CartPoleAction,
        cfg: &CartPoleConfig,
    ) -> CartPoleState {
        // The action only selects the sign of the applied force.
        let force = match action {
            CartPoleAction::Right => cfg.force_mag,
            CartPoleAction::Left => -cfg.force_mag,
        };

        let total_mass = cfg.masscart + cfg.masspole;
        let pole_mass_length = cfg.masspole * cfg.length;
        let (sin_theta, cos_theta) = (state.theta.sin(), state.theta.cos());

        // Shared intermediate term of both accelerations.
        let temp = (force + pole_mass_length * state.theta_dot * state.theta_dot * sin_theta)
            / total_mass;
        let theta_acc = (cfg.gravity * sin_theta - cos_theta * temp)
            / (cfg.length * (4.0 / 3.0 - cfg.masspole * cos_theta * cos_theta / total_mass));
        let x_acc = temp - pole_mass_length * theta_acc * cos_theta / total_mass;

        // Both schemes update the velocities identically ...
        let x_dot_next = state.x_dot + cfg.tau * x_acc;
        let theta_dot_next = state.theta_dot + cfg.tau * theta_acc;

        // ... and differ only in which velocity drives the position update.
        let (x_rate, theta_rate) = match cfg.integrator {
            Integrator::Euler => (state.x_dot, state.theta_dot),
            Integrator::SemiImplicit => (x_dot_next, theta_dot_next),
        };

        CartPoleState {
            x: state.x + cfg.tau * x_rate,
            x_dot: x_dot_next,
            theta: state.theta + cfg.tau * theta_rate,
            theta_dot: theta_dot_next,
        }
    }
}
impl fmt::Display for CartPole {
    /// Writes a one-line summary: step count, cart position and pole angle in degrees.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let step = self.steps;
        let cart_x = self.state.x;
        let angle_deg = self.state.theta.to_degrees();
        f.write_fmt(format_args!(
            "CartPole(step={}, x={:.3}, θ={:.3}°)",
            step, cart_x, angle_deg,
        ))
    }
}
impl ConstructableEnv for CartPole {
    /// Builds a cart-pole environment with the default configuration.
    ///
    /// The `render` flag is accepted for trait compatibility and ignored.
    fn new(render: bool) -> Self {
        let _ = render;
        let config = CartPoleConfig::default();
        Self::with_config(config)
    }
}
impl Environment<1, 1, 1> for CartPole {
    type StateType = CartPoleState;
    type ObservationType = CartPoleObservation;
    type ActionType = CartPoleAction;
    type RewardType = ScalarReward;
    type SnapshotType = SnapshotBase<1, CartPoleObservation, ScalarReward>;

    /// Starts a new episode and returns a running snapshot with zero reward.
    ///
    /// The RNG is re-seeded from `config.seed` before the initial state is
    /// sampled, so every reset of the same environment yields the same initial
    /// state. The step counter is cleared.
    fn reset(&mut self) -> Result<Self::SnapshotType, EnvironmentError> {
        self.rng = StdRng::seed_from_u64(self.config.seed);
        let initial = self.sample_init_state();
        self.state = initial;
        self.steps = 0;

        let observation = self.state.observe();
        Ok(SnapshotBase::running(observation, ScalarReward(0.0)))
    }

    /// Applies `action` for one time step and reports the resulting snapshot.
    ///
    /// Reward: with `sutton_barto_reward` enabled, `-1` on the terminating step
    /// and `0` otherwise; with it disabled, `1` on every step (including the
    /// terminating one).
    fn step(&mut self, action: CartPoleAction) -> Result<Self::SnapshotType, EnvironmentError> {
        self.state = Self::step_physics(&self.state, action, &self.config);
        self.steps += 1;

        let done = Self::is_terminal(&self.state, &self.config);
        let reward = ScalarReward(match (self.config.sutton_barto_reward, done) {
            (true, true) => -1.0,
            (true, false) => 0.0,
            (false, _) => 1.0,
        });

        let observation = self.state.observe();
        Ok(if done {
            SnapshotBase::terminated(observation, reward)
        } else {
            SnapshotBase::running(observation, reward)
        })
    }
}
/// Width, in character cells, of the pole rows and the track row of the ASCII scene.
const RENDER_WIDTH: usize = 50;
/// Number of text rows used to draw the pole above the track.
const RENDER_POLE_H: usize = 6;
/// Fraction of the angle threshold at or above which the pole is drawn in the hazard style.
const RENDER_DANGER_TIER: f32 = 0.66;
impl crate::render::AsciiRenderable for CartPole {
    /// Renders the scene as plain text by stripping the styling from
    /// [`Self::render_styled`].
    fn render_ascii(&self) -> String {
        let frame = self.render_styled();
        frame.plain_text()
    }

    /// Renders the scene as a styled frame: pole rows, track row and status row.
    fn render_styled(&self) -> crate::render::StyledFrame {
        let lines = self.scene_lines();
        crate::render::StyledFrame { lines }
    }
}
impl CartPole {
    /// Builds the styled text scene: `RENDER_POLE_H` pole rows (top first),
    /// one track row with the cart glyph, and one unstyled status row.
    ///
    /// The cart column is the cart position mapped from
    /// `[-x_threshold, x_threshold]` onto `[0, RENDER_WIDTH - 1]`. The pole
    /// leans by up to `RENDER_POLE_H` columns at its tip, scaled by the pole
    /// angle relative to the angle threshold, and switches to the hazard style
    /// once the angle reaches `RENDER_DANGER_TIER` of that threshold.
    fn scene_lines(&self) -> Vec<crate::render::StyledLine> {
        use crate::render::palette::{
            AGENT_FG, AGENT_MODIFIER, GOAL_FG, GOAL_MODIFIER, HAZARD_FG, HAZARD_MODIFIER, WALL_FG,
        };
        use crate::render::{SpanStyle, StyledLine, StyledSpan};

        let width = RENDER_WIDTH;

        // Guard the divisor the same way the angle threshold is guarded below:
        // with a zero threshold, `0.0 / 0.0` would be NaN and the cart would be
        // drawn at column 0 instead of the centre.
        let x_limit = self.config.x_threshold.max(f32::EPSILON);
        let cart_frac = (self.state.x / x_limit * 0.5 + 0.5).clamp(0.0, 1.0);
        let cart_col = (cart_frac * (width as f32 - 1.0)) as usize;

        let theta = self.state.theta;
        let threshold = self.config.theta_threshold_radians.max(f32::EPSILON);
        // `danger` is unsigned (used for colouring); `tip_frac` keeps the sign
        // so the pole leans towards the correct side.
        let danger = (theta.abs() / threshold).clamp(0.0, 1.0);
        let tip_frac = (theta / threshold).clamp(-1.0, 1.0);
        let max_tip = RENDER_POLE_H as f32;

        let pole_style = if danger >= RENDER_DANGER_TIER {
            SpanStyle::default()
                .fg(HAZARD_FG)
                .with_modifier(HAZARD_MODIFIER)
        } else {
            SpanStyle::default().fg(GOAL_FG).with_modifier(GOAL_MODIFIER)
        };

        // A small dead zone around upright keeps the glyph from flickering
        // between '/' and '\\' for tiny angles.
        let glyph = if tip_frac > 0.08 {
            '/'
        } else if tip_frac < -0.08 {
            '\\'
        } else {
            '|'
        };

        let wall_style = SpanStyle::default().fg(WALL_FG);
        let agent_style = SpanStyle::default()
            .fg(AGENT_FG)
            .with_modifier(AGENT_MODIFIER);

        let mut lines = Vec::with_capacity(RENDER_POLE_H + 2);

        // Pole rows, drawn top-down: the horizontal offset grows linearly with
        // the height above the cart.
        for r in 0..RENDER_POLE_H {
            let height_from_base = (RENDER_POLE_H - r) as f32;
            let frac = height_from_base / RENDER_POLE_H as f32;
            let offset = (frac * tip_frac * max_tip).round() as i32;
            // Clamp so the glyph never leaves the frame.
            let col = (cart_col as i32 + offset).clamp(0, width as i32 - 1) as usize;
            lines.push(StyledLine::from_spans([
                StyledSpan::raw(" ".repeat(col)),
                StyledSpan::new(glyph.to_string(), pole_style),
                StyledSpan::raw(" ".repeat(width - col - 1)),
            ]));
        }

        // Track row with the cart glyph at `cart_col`.
        lines.push(StyledLine::from_spans([
            StyledSpan::new("-".repeat(cart_col), wall_style),
            StyledSpan::new("#", agent_style),
            StyledSpan::new("-".repeat(width - cart_col - 1), wall_style),
        ]));

        // Status row.
        lines.push(StyledLine::unstyled(format!(
            " θ={:+.1}°  x={:+.2}  step={}",
            theta.to_degrees(),
            self.state.x,
            self.steps
        )));
        lines
    }
}
impl<B: burn::tensor::backend::Backend> TensorConvertible<1, B> for CartPoleObservation {
    /// Packs the observation into a rank-1 tensor of four floats, in
    /// [`CartPoleObservation::to_array`] order.
    fn to_tensor(&self, device: &<B as burn::tensor::backend::BackendTypes>::Device) -> burn::tensor::Tensor<B, 1> {
        let values = self.to_array();
        burn::tensor::Tensor::from_floats(values, device)
    }

    /// Rebuilds an observation from a rank-1 tensor of shape `[4]`.
    ///
    /// # Errors
    ///
    /// Returns a [`TensorConversionError`] when the shape is not `[4]` or the
    /// tensor data cannot be extracted as `f32`.
    fn from_tensor(tensor: burn::tensor::Tensor<B, 1>) -> Result<Self, TensorConversionError> {
        let dims = tensor.dims();
        if dims.as_slice() != [4] {
            let message = format!("expected shape [4], got {dims:?}");
            return Err(TensorConversionError { message });
        }

        let values = match tensor.into_data().into_vec::<f32>() {
            Ok(values) => values,
            Err(e) => {
                return Err(TensorConversionError {
                    message: e.to_string(),
                })
            }
        };

        // The shape check above guarantees four elements.
        let (cart_pos, cart_vel, pole_angle, pole_ang_vel) =
            (values[0], values[1], values[2], values[3]);
        Ok(Self {
            cart_pos,
            cart_vel,
            pole_angle,
            pole_ang_vel,
        })
    }
}
impl<B: burn::tensor::backend::Backend> TensorConvertible<1, B> for CartPoleAction {
    /// Encodes the action as a length-2 one-hot vector: slot `0` is `Left`,
    /// slot `1` is `Right`.
    fn to_tensor(&self, device: &<B as burn::tensor::backend::BackendTypes>::Device) -> burn::tensor::Tensor<B, 1> {
        let mut one_hot = [0.0_f32; 2];
        one_hot[self.to_index()] = 1.0;
        burn::tensor::Tensor::from_floats(one_hot, device)
    }

    /// Decodes an action from a rank-1 tensor of shape `[2]` by taking the
    /// index of the largest entry (the later index wins ties).
    ///
    /// # Errors
    ///
    /// Returns a [`TensorConversionError`] when the shape is not `[2]`, the
    /// tensor data cannot be extracted as `f32`, or any entry is NaN (a NaN
    /// has no defined ordering, so no arg-max exists).
    fn from_tensor(tensor: burn::tensor::Tensor<B, 1>) -> Result<Self, TensorConversionError> {
        let dims = tensor.dims();
        if dims.as_slice() != [2] {
            return Err(TensorConversionError {
                message: format!("expected shape [2], got {dims:?}"),
            });
        }
        let v = tensor
            .into_data()
            .into_vec::<f32>()
            .map_err(|e| TensorConversionError {
                message: e.to_string(),
            })?;

        // Reject NaN up front instead of panicking inside the comparator.
        if v.iter().any(|x| x.is_nan()) {
            return Err(TensorConversionError {
                message: "action tensor contains NaN".to_string(),
            });
        }

        let idx = v
            .iter()
            .enumerate()
            // `partial_cmp` cannot return `None` once NaN is excluded; the
            // fallback only keeps the comparator total without a panic path.
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map(|(i, _)| i)
            .unwrap_or(0);
        Ok(Self::from_index(idx))
    }
}
#[cfg(test)]
impl CartPoleState {
    /// Test-only helper: flattens the state into `[x, x_dot, theta, theta_dot]`.
    fn to_array(self) -> [f32; 4] {
        let Self {
            x,
            x_dot,
            theta,
            theta_dot,
        } = self;
        [x, x_dot, theta, theta_dot]
    }
}
// Unit tests for the cart-pole environment: reset/step semantics, reward
// modes, determinism, integrator behaviour, the config builder and rendering.
#[cfg(test)]
mod tests {
use super::*;
use rlevo_core::environment::Snapshot;
// Environment built from the default configuration (seed 0).
fn default_env() -> CartPole {
CartPole::with_config(CartPoleConfig::default())
}
// The 2-D snapshot must contain track, cart and pole bodies, and the pole must
// be a vertical two-point segment pointing up when theta is zero.
#[test]
fn classic2d_snapshot_has_track_cart_pole_and_upright_pole_at_theta_zero() {
use rlevo_core::render::payload::{Classic2DPayloadSource, Classic2DRole};
let mut env = default_env();
env.state.x = 0.0;
env.state.theta = 0.0; let snap = env.classic2d_snapshot();
let roles: Vec<_> = snap.bodies.iter().map(|b| b.role).collect();
assert!(roles.contains(&Classic2DRole::Track));
assert!(roles.contains(&Classic2DRole::Cart));
let pole = snap
.bodies
.iter()
.find(|b| b.role == Classic2DRole::Pole)
.expect("pole body present");
assert_eq!(pole.points.len(), 2);
let (base, tip) = (pole.points[0], pole.points[1]);
assert!((tip.x - base.x).abs() < 1e-5, "upright pole must be vertical");
assert!(tip.y > base.y, "pole tip must be above the hinge when upright");
}
// After reset the episode is running and every observation component lies
// within the initial sampling range of +/-0.05.
#[test]
fn reset_returns_running_obs_in_range() {
use rlevo_core::environment::EpisodeStatus;
let mut env = default_env();
let snap = env.reset().unwrap();
assert_eq!(snap.status(), EpisodeStatus::Running);
assert!(!snap.is_done());
let obs = snap.observation();
for v in obs.to_array() {
assert!(v.abs() <= 0.05 + f32::EPSILON, "init obs {v} out of range");
}
}
// The observation is a flat vector of four values.
#[test]
fn observation_shape() {
assert_eq!(CartPoleObservation::shape(), [4]);
}
// Two actions, with index 0 <-> Left and index 1 <-> Right in both directions.
#[test]
fn action_count() {
assert_eq!(CartPoleAction::ACTION_COUNT, 2);
assert_eq!(CartPoleAction::from_index(0), CartPoleAction::Left);
assert_eq!(CartPoleAction::from_index(1), CartPoleAction::Right);
assert_eq!(CartPoleAction::Left.to_index(), 0);
assert_eq!(CartPoleAction::Right.to_index(), 1);
}
// A pole angle of 0.3 rad is beyond the default 12-degree limit, so the next
// step must terminate (not truncate) the episode.
#[test]
fn terminates_on_large_angle() {
use rlevo_core::environment::EpisodeStatus;
let mut env = default_env();
env.reset().unwrap();
env.state.theta = 0.3;
let snap = env.step(CartPoleAction::Left).unwrap();
assert_eq!(snap.status(), EpisodeStatus::Terminated);
assert!(snap.is_terminated());
assert!(!snap.is_truncated());
}
// A cart position of 2.5 is beyond the default 2.4 limit, so the next step
// must terminate the episode.
#[test]
fn terminates_on_large_position() {
let mut env = default_env();
env.reset().unwrap();
env.state.x = 2.5;
let snap = env.step(CartPoleAction::Left).unwrap();
assert!(snap.is_terminated());
}
// In the default reward mode a non-terminal step yields a reward of 1.
#[test]
fn default_reward_is_one_per_step() {
let mut env = default_env();
env.reset().unwrap();
let snap = env.step(CartPoleAction::Right).unwrap();
if !snap.is_done() {
assert_eq!(*snap.reward(), ScalarReward(1.0));
}
}
// With the Sutton-Barto reward enabled, the terminating step yields -1.
#[test]
fn sutton_barto_reward_switch() {
let config = CartPoleConfig {
sutton_barto_reward: true,
..Default::default()
};
let mut env = CartPole::with_config(config);
env.reset().unwrap();
env.state.theta = 0.3;
let snap = env.step(CartPoleAction::Left).unwrap();
assert!(snap.is_done());
assert_eq!(*snap.reward(), ScalarReward(-1.0));
}
// With the Sutton-Barto reward enabled, a non-terminal step yields 0.
#[test]
fn sutton_barto_zero_for_non_terminal_step() {
let config = CartPoleConfig {
sutton_barto_reward: true,
..Default::default()
};
let mut env = CartPole::with_config(config);
env.reset().unwrap();
let snap = env.step(CartPoleAction::Right).unwrap();
if !snap.is_done() {
assert_eq!(*snap.reward(), ScalarReward(0.0));
}
}
// Two environments with the same seed and the same action sequence must
// produce identical observations at every step.
#[test]
fn determinism() {
let mut env_a = CartPole::with_config(CartPoleConfig {
seed: 42,
..Default::default()
});
let mut env_b = CartPole::with_config(CartPoleConfig {
seed: 42,
..Default::default()
});
env_a.reset().unwrap();
env_b.reset().unwrap();
let actions = [
CartPoleAction::Right,
CartPoleAction::Left,
CartPoleAction::Right,
];
for action in actions {
let sa = env_a.step(action).unwrap();
let sb = env_b.step(action).unwrap();
assert_eq!(sa.observation().to_array(), sb.observation().to_array());
}
}
// Starting from the same seeded initial state, the two integrators must
// produce different states within 100 steps.
#[test]
fn euler_and_semi_implicit_diverge_after_many_steps() {
let euler_cfg = CartPoleConfig {
integrator: Integrator::Euler,
seed: 1,
..Default::default()
};
let si_cfg = CartPoleConfig {
integrator: Integrator::SemiImplicit,
seed: 1,
..Default::default()
};
let mut euler_env = CartPole::with_config(euler_cfg);
let mut si_env = CartPole::with_config(si_cfg);
euler_env.reset().unwrap();
si_env.reset().unwrap();
for _ in 0..100 {
let _ = euler_env.step(CartPoleAction::Right);
let _ = si_env.step(CartPoleAction::Right);
// Success: the test passes as soon as the states differ.
if euler_env.state != si_env.state {
return; }
}
// Only reached when the states never differed; the assertion below then
// fails with a descriptive message.
let diff: f32 = euler_env
.state
.to_array()
.iter()
.zip(si_env.state.to_array().iter())
.map(|(a, b)| (a - b).abs())
.sum();
assert!(
diff > 0.0,
"Euler and SemiImplicit produced identical states"
);
}
// Values set through the builder must appear in the built configuration.
#[test]
fn config_builder_roundtrip() {
let cfg = CartPoleConfig::builder()
.gravity(9.81)
.masscart(2.0)
.seed(99)
.build();
assert!((cfg.gravity - 9.81).abs() < 1e-6);
assert!((cfg.masscart - 2.0).abs() < 1e-6);
assert_eq!(cfg.seed, 99);
}
// The styled frame has one row per pole row plus track and status rows, and
// its plain text equals the ASCII rendering.
#[test]
fn render_styled_matches_ascii() {
use crate::render::AsciiRenderable;
let mut env = CartPole::new(false);
env.reset().unwrap();
let plain = env.render_ascii();
let styled = env.render_styled();
assert_eq!(styled.lines.len(), RENDER_POLE_H + 2);
assert_eq!(styled.plain_text(), plain);
assert!(
plain.lines().last().unwrap().contains("step="),
"status row missing: {plain:?}"
);
}
// Cart, track and pole spans must use the shared palette constants; right
// after reset the pole is expected in the goal colour.
#[test]
fn render_styled_uses_palette_consts() {
use crate::render::AsciiRenderable;
use crate::render::palette::{AGENT_FG, AGENT_MODIFIER, GOAL_FG, WALL_FG};
let mut env = CartPole::new(false);
env.reset().unwrap();
let styled = env.render_styled();
let spans = || styled.lines.iter().flat_map(|l| l.spans.iter());
let cart_span = spans()
.find(|s| s.text == "#")
.expect("cart glyph span present");
assert_eq!(cart_span.style.fg, Some(AGENT_FG));
assert!(cart_span.style.modifier.contains(AGENT_MODIFIER));
let track_span = spans()
.find(|s| s.text.starts_with('-'))
.expect("track dash span present");
assert_eq!(track_span.style.fg, Some(WALL_FG));
let pole_span = spans()
.find(|s| matches!(s.text.as_str(), "|" | "/" | "\\"))
.expect("pole glyph span present");
assert_eq!(pole_span.style.fg, Some(GOAL_FG));
}
// Every rendered line must fit into 80 columns.
#[test]
fn render_ascii_within_width_budget() {
use crate::render::AsciiRenderable;
let mut env = CartPole::new(false);
env.reset().unwrap();
for line in env.render_ascii().lines() {
assert!(
line.chars().count() <= 80,
"line exceeds 80 cols: {line:?} ({} chars)",
line.chars().count()
);
}
}
}
impl rlevo_core::render::payload::Classic2DPayloadSource for CartPole {
    /// Describes the current scene as three 2-D bodies — track, cart and pole,
    /// in that order — plus the bounds that enclose them.
    ///
    /// The track spans `[-x_threshold, x_threshold]` at height `0`, the cart is
    /// a `0.4 x 0.25` rectangle resting on the track, and the pole is a segment
    /// of length `2 * length` hinged at the top centre of the cart.
    fn classic2d_snapshot(&self) -> rlevo_core::render::payload::Classic2DSnapshot {
        use rlevo_core::render::payload::{Classic2DBody, Classic2DRole, Classic2DSnapshot, Point2};

        let (cart_w, cart_h) = (0.4_f32, 0.25_f32);
        let x = self.state.x;
        let theta = self.state.theta;
        let xt = self.config.x_threshold;
        let pole_len = 2.0 * self.config.length;

        // Cart rectangle, centred on (x, cy) and listed counter-clockwise
        // from the bottom-left corner.
        let half_w = cart_w * 0.5;
        let half_h = cart_h * 0.5;
        let cy = half_h;
        let (left, right) = (x - half_w, x + half_w);
        let (bottom, top) = (cy - half_h, cy + half_h);
        let cart_outline = vec![
            Point2::new(left, bottom),
            Point2::new(right, bottom),
            Point2::new(right, top),
            Point2::new(left, top),
        ];

        // Pole: hinged on top of the cart; theta = 0 points straight up.
        let hinge_y = cart_h;
        let hinge = Point2::new(x, hinge_y);
        let tip = Point2::new(x + pole_len * theta.sin(), hinge_y + pole_len * theta.cos());

        let track = Classic2DBody {
            points: vec![Point2::new(-xt, 0.0), Point2::new(xt, 0.0)],
            role: Classic2DRole::Track,
            closed: false,
        };
        let cart = Classic2DBody {
            points: cart_outline,
            role: Classic2DRole::Cart,
            closed: true,
        };
        let pole = Classic2DBody {
            points: vec![hinge, tip],
            role: Classic2DRole::Pole,
            closed: false,
        };

        // Bounds leave a small margin around the track and above the pole.
        let lower = Point2::new(-xt - 0.2, -0.4);
        let upper = Point2::new(xt + 0.2, pole_len + 0.4);

        Classic2DSnapshot {
            bodies: vec![track, cart, pole],
            bounds: (lower, upper),
        }
    }
}