use std::fmt;
use rand::{SeedableRng, rngs::StdRng};
use rand_distr::{Distribution, Uniform};
use rlevo_core::{
action::DiscreteAction,
base::{Action, Observation, State, TensorConversionError, TensorConvertible},
environment::{ConstructableEnv, Environment, EnvironmentError, SnapshotBase},
reward::ScalarReward,
};
use serde::{Deserialize, Serialize};
/// Tunable parameters of the [`MountainCar`] environment.
///
/// The physics step reads every field except `seed`; `seed` is only used to
/// initialise the environment's random number generator.
#[derive(Debug, Clone)]
pub struct MountainCarConfig {
/// Velocity change contributed per step by a `Left` or `Right` push.
pub force: f32,
/// Scale of the slope term `cos(3 * position)` subtracted from the velocity each step.
pub gravity: f32,
/// Lower clamp bound for the position; the velocity is zeroed when the car sits on it.
pub min_pos: f32,
/// Upper clamp bound for the position.
pub max_pos: f32,
/// The velocity is clamped to `[-max_speed, max_speed]` on every step.
pub max_speed: f32,
/// Position the car must reach or exceed for the episode to terminate.
pub goal_position: f32,
/// Velocity the car must reach or exceed, together with `goal_position`, to terminate.
pub goal_velocity: f32,
/// Seed for the `StdRng` that samples the initial position.
pub seed: u64,
}
impl Default for MountainCarConfig {
    /// Builds the stock configuration: a track spanning `[-1.2, 0.6]`, the
    /// goal at `0.5` with no minimum goal velocity, and RNG seed `0`.
    fn default() -> Self {
        // Track geometry.
        let (min_pos, max_pos, goal_position) = (-1.2_f32, 0.6_f32, 0.5_f32);
        // Dynamics.
        let (force, gravity, max_speed) = (0.001_f32, 0.0025_f32, 0.07_f32);

        Self {
            force,
            gravity,
            min_pos,
            max_pos,
            max_speed,
            goal_position,
            goal_velocity: 0.0,
            seed: 0,
        }
    }
}
/// Simulator state of the car: where it is on the track and how fast it moves.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct MountainCarState {
/// Position along the track; the physics step clamps it to `[min_pos, max_pos]`.
pub position: f32,
/// Position change applied per step; the physics step clamps it to `[-max_speed, max_speed]`.
pub velocity: f32,
}
impl State<1> for MountainCarState {
    type Observation = MountainCarObservation;

    /// The state is a flat vector with two entries: position and velocity.
    fn shape() -> [usize; 1] {
        const COMPONENTS: usize = 2;
        [COMPONENTS]
    }

    /// Number of scalar components held by the state (always two).
    fn numel(&self) -> usize {
        [self.position, self.velocity].len()
    }

    /// A state is valid when neither component is NaN or infinite.
    fn is_valid(&self) -> bool {
        [self.position, self.velocity]
            .iter()
            .all(|component| component.is_finite())
    }

    /// Copies both fields into an observation; nothing is hidden from the agent.
    fn observe(&self) -> MountainCarObservation {
        let &Self { position, velocity } = self;
        MountainCarObservation { position, velocity }
    }
}
/// What the agent receives after a reset or step: a copy of the state's two fields.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct MountainCarObservation {
/// Position of the car, as stored in the state.
pub position: f32,
/// Velocity of the car, as stored in the state.
pub velocity: f32,
}
impl MountainCarObservation {
    /// Flattens the observation into `[position, velocity]`.
    pub fn to_array(&self) -> [f32; 2] {
        let &Self { position, velocity } = self;
        [position, velocity]
    }
}
impl Observation<1> for MountainCarObservation {
    /// Observations are rank-1 with two entries: position and velocity.
    fn shape() -> [usize; 1] {
        const COMPONENTS: usize = 2;
        [COMPONENTS]
    }
}
/// The three discrete pushes available to the agent.
///
/// Declaration order matches the indices used by the `DiscreteAction`
/// implementation: `Left` = 0, `NoAccel` = 1, `Right` = 2.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MountainCarAction {
/// Push towards smaller positions (index 0).
Left,
/// Apply no push (index 1).
NoAccel,
/// Push towards larger positions (index 2).
Right,
}
impl Action<1> for MountainCarAction {
    /// Shape of the tensor encoding of an action: one slot per variant.
    fn shape() -> [usize; 1] {
        const VARIANTS: usize = 3;
        [VARIANTS]
    }

    /// Every variant is a legal action, so this never returns `false`.
    fn is_valid(&self) -> bool {
        matches!(self, Self::Left | Self::NoAccel | Self::Right)
    }
}
impl DiscreteAction<1> for MountainCarAction {
const ACTION_COUNT: usize = 3;
fn from_index(index: usize) -> Self {
match index {
0 => Self::Left,
1 => Self::NoAccel,
2 => Self::Right,
_ => panic!("MountainCarAction index out of range: {index}"),
}
}
fn to_index(&self) -> usize {
match self {
Self::Left => 0,
Self::NoAccel => 1,
Self::Right => 2,
}
}
}
/// The Mountain Car environment: a car state advanced by a fixed physics step.
#[derive(Debug)]
pub struct MountainCar {
/// Current position and velocity of the car.
state: MountainCarState,
/// Physics parameters, goal thresholds and the RNG seed.
config: MountainCarConfig,
/// Random source used to sample the initial position.
rng: StdRng,
/// Number of `step` calls since construction or the last `reset`.
steps: usize,
}
impl MountainCar {
    /// Creates an environment from `config`, seeding its RNG with `config.seed`.
    ///
    /// Until `reset` is called the car sits at position `-0.5` with zero velocity.
    pub fn with_config(config: MountainCarConfig) -> Self {
        let initial = MountainCarState {
            position: -0.5,
            velocity: 0.0,
        };
        Self {
            state: initial,
            // Read the seed before `config` is moved into the struct.
            rng: StdRng::seed_from_u64(config.seed),
            config,
            steps: 0,
        }
    }

    /// Number of steps taken since the environment was created or last reset.
    pub fn steps(&self) -> usize {
        self.steps
    }

    /// Draws a start state: position uniform in `[-0.6, -0.4]`, velocity zero.
    fn sample_init_state(&mut self) -> MountainCarState {
        let start_range = Uniform::new_inclusive(-0.6_f32, -0.4_f32).unwrap();
        MountainCarState {
            position: start_range.sample(&mut self.rng),
            velocity: 0.0,
        }
    }

    /// Advances `state` by one step under `action` and returns the new state.
    ///
    /// The velocity gains the push (`-force`, `0` or `+force`), loses
    /// `cos(3 * position) * gravity`, and is clamped to the speed limit. The
    /// position then moves by that velocity and is clamped to the track. A car
    /// resting on the left bound has its velocity zeroed.
    fn apply_physics(
        state: MountainCarState,
        action: MountainCarAction,
        cfg: &MountainCarConfig,
    ) -> MountainCarState {
        // Left / NoAccel / Right become -1 / 0 / +1.
        let push = action.to_index() as f32 - 1.0;
        let slope = (3.0 * state.position).cos();

        let unclamped = state.velocity + push * cfg.force - slope * cfg.gravity;
        let velocity = unclamped.clamp(-cfg.max_speed, cfg.max_speed);
        let position = (state.position + velocity).clamp(cfg.min_pos, cfg.max_pos);

        // The left wall absorbs all momentum.
        let velocity = if position <= cfg.min_pos { 0.0 } else { velocity };

        MountainCarState { position, velocity }
    }

    /// Returns `true` once the car is at or beyond the goal position while
    /// moving at least as fast as the goal velocity.
    fn is_terminal(state: &MountainCarState, cfg: &MountainCarConfig) -> bool {
        let reached_goal = state.position >= cfg.goal_position;
        let fast_enough = state.velocity >= cfg.goal_velocity;
        reached_goal && fast_enough
    }
}
impl fmt::Display for MountainCar {
    /// Writes a one-line summary: step counter, position (3 decimals) and
    /// velocity (4 decimals).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (position, velocity) = (self.state.position, self.state.velocity);
        f.write_fmt(format_args!(
            "MountainCar(step={}, pos={:.3}, vel={:.4})",
            self.steps, position, velocity
        ))
    }
}
impl ConstructableEnv for MountainCar {
    /// Builds the environment with the default configuration.
    ///
    /// The `render` flag is accepted but has no effect on construction.
    fn new(render: bool) -> Self {
        // Explicitly discard the flag so it does not count as unused.
        let _ = render;
        let config = MountainCarConfig::default();
        Self::with_config(config)
    }
}
impl Environment<1, 1, 1> for MountainCar {
    type StateType = MountainCarState;
    type ObservationType = MountainCarObservation;
    type ActionType = MountainCarAction;
    type RewardType = ScalarReward;
    type SnapshotType = SnapshotBase<1, MountainCarObservation, ScalarReward>;

    /// Starts a new episode.
    ///
    /// Samples a fresh start state from the environment's RNG, zeroes the step
    /// counter and returns a running snapshot carrying the initial observation
    /// and a reward of `0.0`.
    ///
    /// The RNG is seeded once, in `with_config`, and is intentionally not
    /// re-seeded here. Re-seeding on every reset would make each episode begin
    /// from the same position, defeating the uniform start distribution. Two
    /// environments built with the same seed still produce the same sequence
    /// of start states.
    fn reset(&mut self) -> Result<Self::SnapshotType, EnvironmentError> {
        self.state = self.sample_init_state();
        self.steps = 0;
        Ok(SnapshotBase::running(
            self.state.observe(),
            ScalarReward(0.0),
        ))
    }

    /// Applies `action` for one step of the physics and reports the outcome.
    ///
    /// Every step, including the one that reaches the goal, yields a reward of
    /// `-1.0`. The snapshot is marked terminated once the goal condition holds
    /// and running otherwise.
    fn step(&mut self, action: MountainCarAction) -> Result<Self::SnapshotType, EnvironmentError> {
        self.state = Self::apply_physics(self.state, action, &self.config);
        self.steps += 1;

        let observation = self.state.observe();
        let reward = ScalarReward(-1.0);
        let snapshot = if Self::is_terminal(&self.state, &self.config) {
            SnapshotBase::terminated(observation, reward)
        } else {
            SnapshotBase::running(observation, reward)
        };
        Ok(snapshot)
    }
}
impl crate::render::AsciiRenderable for MountainCar {
    /// Draws the track as 40 cells between brackets, with `A` marking the car,
    /// followed by the position, velocity and step counter.
    fn render_ascii(&self) -> String {
        let width = 40_usize;
        let lo = self.config.min_pos;
        let range = self.config.max_pos - lo;

        // Fraction of the track covered so far, pinned to [0, 1], then mapped
        // onto a cell index in 0..width.
        let frac = ((self.state.position - lo) / range).clamp(0.0, 1.0);
        let col = (frac * (width as f32 - 1.0)) as usize;

        let track_str: String = (0..width)
            .map(|cell| if cell == col { 'A' } else { '.' })
            .collect();
        format!(
            "[{track_str}] pos={:.3} vel={:.4} step={}",
            self.state.position, self.state.velocity, self.steps
        )
    }

    /// Wraps the ASCII rendering in a single styled line.
    fn render_styled(&self) -> crate::render::StyledFrame {
        let styled = style_mountain_car_line(&self.render_ascii());
        crate::render::StyledFrame {
            lines: vec![styled],
        }
    }
}
/// Splits an ASCII-rendered line into styled spans.
///
/// Everything up to and including the first `]` is treated as the track: the
/// text on either side of the first `A` gets the wall colour and the `A`
/// itself gets the agent colour and modifier. Any text after the `]` is
/// appended as a raw span. If the line has no `]`, or the track has no `A`,
/// the whole line is returned unstyled.
fn style_mountain_car_line(line: &str) -> crate::render::StyledLine {
    use crate::render::palette::{AGENT_FG, AGENT_MODIFIER, WALL_FG};
    use crate::render::{SpanStyle, StyledLine, StyledSpan};

    let wall_style = SpanStyle::default().fg(WALL_FG);
    let agent_style = SpanStyle::default()
        .fg(AGENT_FG)
        .with_modifier(AGENT_MODIFIER);

    // `]` is a single byte, so splitting just after it keeps it in the track half.
    let (track_segment, suffix) = match line.find(']') {
        Some(close_idx) => line.split_at(close_idx + 1),
        None => return StyledLine::unstyled(line),
    };
    let (before_agent, after_agent) = match track_segment.split_once('A') {
        Some(halves) => halves,
        None => return StyledLine::unstyled(line),
    };

    let mut spans = Vec::with_capacity(4);
    spans.push(StyledSpan::new(before_agent.to_string(), wall_style));
    spans.push(StyledSpan::new("A", agent_style));
    spans.push(StyledSpan::new(after_agent.to_string(), wall_style));
    if !suffix.is_empty() {
        spans.push(StyledSpan::raw(suffix.to_string()));
    }
    StyledLine::from_spans(spans)
}
impl<B: burn::tensor::backend::Backend> TensorConvertible<1, B> for MountainCarObservation {
    /// Encodes the observation as a rank-1 tensor `[position, velocity]` on `device`.
    fn to_tensor(&self, device: &<B as burn::tensor::backend::BackendTypes>::Device) -> burn::tensor::Tensor<B, 1> {
        let values = self.to_array();
        burn::tensor::Tensor::from_floats(values, device)
    }

    /// Decodes a tensor of shape `[2]` back into an observation.
    ///
    /// # Errors
    ///
    /// Returns a `TensorConversionError` if the tensor's shape is not `[2]`
    /// or its data cannot be read as `f32`.
    fn from_tensor(tensor: burn::tensor::Tensor<B, 1>) -> Result<Self, TensorConversionError> {
        let dims = tensor.dims();
        if !matches!(dims.as_slice(), [2]) {
            return Err(TensorConversionError {
                message: format!("expected shape [2], got {dims:?}"),
            });
        }

        let values = match tensor.into_data().into_vec::<f32>() {
            Ok(values) => values,
            Err(e) => {
                return Err(TensorConversionError {
                    message: e.to_string(),
                })
            }
        };

        // The shape check above guarantees exactly two elements.
        let (position, velocity) = (values[0], values[1]);
        Ok(Self { position, velocity })
    }
}
impl<B: burn::tensor::backend::Backend> TensorConvertible<1, B> for MountainCarAction {
    /// Encodes the action as a one-hot rank-1 tensor of length 3 on `device`.
    fn to_tensor(&self, device: &<B as burn::tensor::backend::BackendTypes>::Device) -> burn::tensor::Tensor<B, 1> {
        let mut one_hot = [0.0_f32; 3];
        one_hot[self.to_index()] = 1.0;
        burn::tensor::Tensor::from_floats(one_hot, device)
    }

    /// Decodes a tensor of shape `[3]` into the action with the largest entry.
    ///
    /// When several entries share the maximum, the one with the highest index
    /// is chosen.
    ///
    /// # Errors
    ///
    /// Returns a `TensorConversionError` if the tensor's shape is not `[3]`,
    /// if its data cannot be read as `f32`, or if any entry is NaN. A NaN
    /// entry has no defined ordering, so no argmax can be selected.
    fn from_tensor(tensor: burn::tensor::Tensor<B, 1>) -> Result<Self, TensorConversionError> {
        let dims = tensor.dims();
        if dims.as_slice() != [3] {
            return Err(TensorConversionError {
                message: format!("expected shape [3], got {dims:?}"),
            });
        }
        let v = tensor
            .into_data()
            .into_vec::<f32>()
            .map_err(|e| TensorConversionError {
                message: e.to_string(),
            })?;

        // Reject NaN up front: comparing against it would otherwise panic
        // inside the argmax below instead of surfacing as an error.
        if v.iter().any(|x| x.is_nan()) {
            return Err(TensorConversionError {
                message: format!("action tensor contains NaN: {v:?}"),
            });
        }

        let idx = v
            .iter()
            .enumerate()
            // With NaN excluded `partial_cmp` always yields an ordering; the
            // fallback only exists to avoid an `unwrap`.
            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
            .map_or(0, |(i, _)| i);
        Ok(Self::from_index(idx))
    }
}
// Unit tests for the Mountain Car environment: reset/step behaviour, the
// physics edge case at the left wall, determinism under a fixed seed, and
// the ASCII/styled renderers.
#[cfg(test)]
mod tests {
use super::*;
use rlevo_core::environment::Snapshot;
// Builds an environment with the default configuration.
fn default_env() -> MountainCar {
MountainCar::with_config(MountainCarConfig::default())
}
// After a reset the episode is running, the position lies in the sampling
// range [-0.6, -0.4] and the velocity is zero.
#[test]
fn reset_initialises_correctly() {
use rlevo_core::environment::EpisodeStatus;
let mut env = default_env();
let snap = env.reset().unwrap();
assert_eq!(snap.status(), EpisodeStatus::Running);
let obs = snap.observation();
assert!(
obs.position >= -0.6 && obs.position <= -0.4,
"position {}",
obs.position
);
assert_eq!(obs.velocity, 0.0);
}
// Observations are reported as two-element vectors.
#[test]
fn observation_shape() {
assert_eq!(MountainCarObservation::shape(), [2]);
}
// There are three actions, and the first and last indices map to Left and Right.
#[test]
fn action_count() {
assert_eq!(MountainCarAction::ACTION_COUNT, 3);
assert_eq!(MountainCarAction::from_index(0), MountainCarAction::Left);
assert_eq!(MountainCarAction::from_index(2), MountainCarAction::Right);
}
// A car driven past the left bound is clamped onto it and loses its velocity.
#[test]
fn left_wall_kills_velocity() {
let cfg = MountainCarConfig::default();
let state = MountainCarState {
position: -1.19,
velocity: -0.05,
};
let next = MountainCar::apply_physics(state, MountainCarAction::Left, &cfg);
assert_eq!(next.position, cfg.min_pos);
assert_eq!(next.velocity, 0.0);
}
// Stepping from a state already beyond the goal with positive velocity
// terminates the episode.
#[test]
fn goal_terminates() {
let mut env = default_env();
env.reset().unwrap();
env.state = MountainCarState {
position: 0.55,
velocity: 0.01,
};
let snap = env.step(MountainCarAction::Right).unwrap();
assert!(snap.is_terminated());
}
// A regular step yields a reward of -1.
#[test]
fn reward_is_minus_one_per_step() {
let mut env = default_env();
env.reset().unwrap();
let snap = env.step(MountainCarAction::NoAccel).unwrap();
assert_eq!(*snap.reward(), ScalarReward(-1.0));
}
// Two environments built with the same seed produce identical observations
// for the same action sequence.
#[test]
fn determinism() {
let mut a = MountainCar::with_config(MountainCarConfig {
seed: 7,
..Default::default()
});
let mut b = MountainCar::with_config(MountainCarConfig {
seed: 7,
..Default::default()
});
a.reset().unwrap();
b.reset().unwrap();
for action in [
MountainCarAction::Right,
MountainCarAction::Left,
MountainCarAction::NoAccel,
] {
let sa = a.step(action).unwrap();
let sb = b.step(action).unwrap();
assert_eq!(sa.observation().to_array(), sb.observation().to_array());
}
}
// The styled frame has exactly one line whose plain text equals the ASCII rendering.
#[test]
fn render_styled_matches_ascii() {
use crate::render::AsciiRenderable;
let mut env = default_env();
env.reset().unwrap();
let plain = env.render_ascii();
let styled = env.render_styled();
assert_eq!(styled.lines.len(), 1);
assert_eq!(styled.plain_text(), plain);
}
// The agent glyph and the track use the shared palette constants rather
// than ad-hoc colours.
#[test]
fn render_styled_uses_palette_consts() {
use crate::render::AsciiRenderable;
use crate::render::palette::{AGENT_FG, AGENT_MODIFIER, WALL_FG};
let mut env = default_env();
env.reset().unwrap();
let styled = env.render_styled();
let line = &styled.lines[0];
let agent_span = line
.spans
.iter()
.find(|s| s.text == "A")
.expect("agent glyph span present");
assert_eq!(agent_span.style.fg, Some(AGENT_FG));
assert!(agent_span.style.modifier.contains(AGENT_MODIFIER));
let bracket_span = line
.spans
.iter()
.find(|s| s.text.starts_with('['))
.expect("track-opening span present");
assert_eq!(bracket_span.style.fg, Some(WALL_FG));
}
// Every rendered line fits within 80 characters.
#[test]
fn render_ascii_within_width_budget() {
use crate::render::AsciiRenderable;
let mut env = default_env();
env.reset().unwrap();
for line in env.render_ascii().lines() {
assert!(
line.chars().count() <= 80,
"line exceeds 80 cols: {line:?} ({} chars)",
line.chars().count()
);
}
}
}
impl rlevo_core::render::payload::Classic2DPayloadSource for MountainCar {
    /// Builds a 2D snapshot with two bodies: the terrain as an open polyline
    /// and the car as a closed square centred on the terrain at its position.
    ///
    /// The terrain height at `x` is `sin(3 * x)`, sampled at 49 evenly spaced
    /// points between the configured position bounds.
    fn classic2d_snapshot(&self) -> rlevo_core::render::payload::Classic2DSnapshot {
        use rlevo_core::render::payload::{Classic2DBody, Classic2DRole, Classic2DSnapshot, Point2};

        const SAMPLES: usize = 48;
        // Half the side length of the square that represents the car.
        const HALF_SIDE: f32 = 0.04;

        let height = |x: f32| (3.0 * x).sin();
        let (lo, hi) = (self.config.min_pos, self.config.max_pos);

        let mut terrain: Vec<Point2> = Vec::with_capacity(SAMPLES + 1);
        for i in 0..=SAMPLES {
            let x = lo + (hi - lo) * (i as f32 / SAMPLES as f32);
            terrain.push(Point2::new(x, height(x)));
        }

        let px = self.state.position;
        let py = height(px);
        // Corner offsets in drawing order: bottom-left, bottom-right, top-right, top-left.
        let corner_offsets = [
            (-HALF_SIDE, -HALF_SIDE),
            (HALF_SIDE, -HALF_SIDE),
            (HALF_SIDE, HALF_SIDE),
            (-HALF_SIDE, HALF_SIDE),
        ];
        let car = Vec::from(corner_offsets.map(|(dx, dy)| Point2::new(px + dx, py + dy)));

        let track_body = Classic2DBody { points: terrain, role: Classic2DRole::Track, closed: false };
        let car_body = Classic2DBody { points: car, role: Classic2DRole::Car, closed: true };

        Classic2DSnapshot {
            bodies: vec![track_body, car_body],
            bounds: (Point2::new(lo - 0.1, -1.1), Point2::new(hi + 0.1, 1.1)),
        }
    }
}