use std::collections::HashMap;
use crate::env::{Env, InfoValue, RenderFrame, RenderMode, ResetResult, StepResult};
use crate::error::{Error, Result};
#[cfg(feature = "render")]
use crate::render::{Canvas, RenderWindow, sprites::CliffWalkingSprites};
use crate::rng::{self, Rng};
use crate::space::{Discrete, Space};
/// Side length of one grid cell in pixels; the canvas is `NUM_COLS * CELL_SIZE`
/// wide and `NUM_ROWS * CELL_SIZE` tall.
#[cfg(feature = "render")]
const CELL_SIZE: u32 = 60;
/// Value handed to `RenderWindow::new` as its last argument.
// NOTE(review): presumably the target frames per second of the window — confirm
// against `RenderWindow`.
#[cfg(feature = "render")]
const RENDER_FPS: usize = 4;
/// Action code: move one row towards row 0.
const UP: i64 = 0;
/// Action code: move one column towards the last column.
const RIGHT: i64 = 1;
/// Action code: move one row towards the last row.
const DOWN: i64 = 2;
/// Action code: move one column towards column 0.
const LEFT: i64 = 3;
/// Number of rows in the grid.
const NUM_ROWS: usize = 4;
/// Number of columns in the grid.
const NUM_COLS: usize = 12;
// `NUM_STATES` is the number of grid cells (one state per cell, row-major);
// `NUM_ACTIONS` is the number of movement actions (`UP`, `RIGHT`, `DOWN`, `LEFT`).
const NUM_STATES: u64 = (NUM_ROWS * NUM_COLS) as u64; const NUM_ACTIONS: u64 = 4;
/// One outcome of taking an action in a state:
/// `(probability, next_state, reward, terminated)`.
type Transition = (f64, i64, f64, bool);
/// Construction-time options for [`CliffWalkingEnv`].
#[derive(Debug, Clone, Copy)]
pub struct CliffWalkingConfig {
/// Render mode the environment is created with; it is copied into the
/// environment by `CliffWalkingEnv::new`.
pub render_mode: RenderMode,
}
impl Default for CliffWalkingConfig {
    /// Returns a configuration with rendering switched off (`RenderMode::None`).
    fn default() -> Self {
        let render_mode = RenderMode::None;
        Self { render_mode }
    }
}
/// Cliff-walking grid world on a `NUM_ROWS` x `NUM_COLS` grid.
///
/// States are cell indices in row-major order (`row * NUM_COLS + col`). All
/// dynamics are precomputed into `transitions` when the environment is built.
pub struct CliffWalkingEnv {
/// Discrete space over the `NUM_ACTIONS` movement actions.
action_space: Discrete,
/// Discrete space over the `NUM_STATES` grid cells.
observation_space: Discrete,
/// Outcome table indexed as `transitions[state][action]`; `step` reads entry `0`
/// of the inner list.
transitions: Vec<Vec<Vec<Transition>>>,
/// Current cell index; `None` until `reset` has been called.
state: Option<i64>,
/// Action taken by the most recent `step`; cleared to `None` by `reset`.
last_action: Option<i64>,
/// Random number generator, re-created by `reset` when a seed is supplied.
rng: Rng,
/// Render mode chosen at construction time.
render_mode: RenderMode,
/// Cell index the agent is placed on by `reset`.
start_state: i64,
/// Off-screen drawing surface, created on the first pixel render.
#[cfg(feature = "render")]
canvas: Option<Canvas>,
/// On-screen window, created on the first render in `RenderMode::Human`.
#[cfg(feature = "render")]
window: Option<RenderWindow>,
/// Sprite set used for pixel rendering, created on the first pixel render.
#[cfg(feature = "render")]
sprites: Option<CliffWalkingSprites>,
}
impl std::fmt::Debug for CliffWalkingEnv {
    /// Formats only the current state and render mode; every other field is
    /// elided and the output is marked as non-exhaustive.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("CliffWalkingEnv");
        builder.field("state", &self.state);
        builder.field("render_mode", &self.render_mode);
        builder.finish_non_exhaustive()
    }
}
impl CliffWalkingEnv {
    /// Reports whether the cell at (`row`, `col`) belongs to the cliff.
    ///
    /// The cliff is the bottom row without its first column (the start cell)
    /// and its last column (the goal cell).
    fn is_cliff(row: usize, col: usize) -> bool {
        row == NUM_ROWS - 1 && (1..NUM_COLS - 1).contains(&col)
    }

    /// Builds a new environment and precomputes its transition table.
    ///
    /// For every `(state, action)` pair exactly one outcome with probability
    /// `1.0` is stored:
    ///
    /// * the move is clamped to the grid, so walking into a wall leaves the
    ///   agent where it is;
    /// * landing on a cliff cell sends the agent back to the start cell with a
    ///   reward of `-100.0` and does not terminate the episode;
    /// * any other move yields a reward of `-1.0` and terminates the episode
    ///   exactly when the destination is the goal cell (bottom-right corner).
    ///
    /// The returned environment has no current state; `reset` must be called
    /// before `step` or a non-trivial `render`.
    // The grid constants are tiny, so the integer casts below can neither wrap
    // nor truncate. `config` is taken by value to keep the public signature.
    #[allow(
        clippy::cast_possible_wrap,
        clippy::cast_possible_truncation,
        clippy::needless_pass_by_value
    )]
    #[must_use]
    pub fn new(config: CliffWalkingConfig) -> Self {
        // Start is the bottom-left cell, goal is the bottom-right cell.
        let start_state = ((NUM_ROWS - 1) * NUM_COLS) as i64;
        let goal_state = (NUM_ROWS * NUM_COLS - 1) as i64;

        let mut transitions: Vec<Vec<Vec<Transition>>> =
            vec![vec![Vec::new(); NUM_ACTIONS as usize]; NUM_STATES as usize];

        for (state, per_action) in transitions.iter_mut().enumerate() {
            let (row, col) = (state / NUM_COLS, state % NUM_COLS);
            // Action codes are the contiguous integers 0..NUM_ACTIONS.
            for (action, outcomes) in (0_i64..).zip(per_action.iter_mut()) {
                // Saturating / `min` arithmetic keeps the agent on the grid.
                let (next_row, next_col) = match action {
                    UP => (row.saturating_sub(1), col),
                    RIGHT => (row, (col + 1).min(NUM_COLS - 1)),
                    DOWN => ((row + 1).min(NUM_ROWS - 1), col),
                    LEFT => (row, col.saturating_sub(1)),
                    _ => (row, col),
                };
                let outcome: Transition = if Self::is_cliff(next_row, next_col) {
                    (1.0, start_state, -100.0, false)
                } else {
                    let next_state = (next_row * NUM_COLS + next_col) as i64;
                    (1.0, next_state, -1.0, next_state == goal_state)
                };
                outcomes.push(outcome);
            }
        }

        Self {
            observation_space: Discrete::new(NUM_STATES),
            action_space: Discrete::new(NUM_ACTIONS),
            transitions,
            state: None,
            last_action: None,
            rng: rng::create_rng(None),
            render_mode: config.render_mode,
            start_state,
            #[cfg(feature = "render")]
            canvas: None,
            #[cfg(feature = "render")]
            window: None,
            #[cfg(feature = "render")]
            sprites: None,
        }
    }

    /// Draws the grid onto the internal canvas and presents it according to the
    /// render mode.
    ///
    /// Every cell gets a checkerboard background; cliff cells, cells directly
    /// above the cliff, the start cell and the goal cell get their overlays, and
    /// the agent sprite faces the direction of the last action (`DOWN` when no
    /// action has been taken yet).
    ///
    /// * `RenderMode::Human` shows the canvas in a lazily created window and
    ///   returns `RenderFrame::None` (also when the window has been closed).
    /// * `RenderMode::RgbArray` returns the canvas pixels as an RGB buffer.
    /// * Any other mode returns `RenderFrame::None`.
    ///
    /// # Errors
    ///
    /// Returns `Error::ResetNeeded` when the environment has not been reset, and
    /// propagates any error from showing the window.
    ///
    /// # Panics
    ///
    /// Panics if the render window cannot be created.
    // NOTE(review): if `RenderWindow::new` returns this crate's `Result`, the
    // `expect` below could become `?` — confirm its error type first.
    // The casts convert small, non-negative grid coordinates to pixel
    // coordinates and cannot overflow.
    #[cfg(feature = "render")]
    #[allow(
        clippy::cast_possible_truncation,
        clippy::cast_sign_loss,
        clippy::cast_possible_wrap,
        clippy::many_single_char_names,
        clippy::match_same_arms
    )]
    fn render_pixels(&mut self) -> Result<RenderFrame> {
        let Some(agent_state) = self.state else {
            return Err(Error::ResetNeeded { method: "render" });
        };
        let agent_state = agent_state as usize;

        let width = (NUM_COLS as u32) * CELL_SIZE;
        let height = (NUM_ROWS as u32) * CELL_SIZE;
        let start_state = (NUM_ROWS - 1) * NUM_COLS;
        let goal_state = NUM_ROWS * NUM_COLS - 1;

        // Sprite index for the agent; unknown codes fall back to the
        // down-facing sprite.
        let elf_idx = match self.last_action.unwrap_or(DOWN) {
            UP => 0,
            RIGHT => 1,
            DOWN => 2,
            LEFT => 3,
            _ => 2,
        };
        // The agent sprite is drawn a tenth of a cell above the cell origin.
        let elf_lift = (CELL_SIZE as f32 * 0.1) as i32;

        let sprites = self
            .sprites
            .get_or_insert_with(|| CliffWalkingSprites::new(CELL_SIZE, CELL_SIZE));
        let canvas = self
            .canvas
            .get_or_insert_with(|| Canvas::new(width, height));
        canvas.clear(tiny_skia::Color::WHITE);

        for state in 0..NUM_ROWS * NUM_COLS {
            let (row, col) = (state / NUM_COLS, state % NUM_COLS);
            let px = (col as u32 * CELL_SIZE) as i32;
            let py = (row as u32 * CELL_SIZE) as i32;
            // 0 or 1, alternating like a checkerboard.
            let parity = (row % 2) ^ (col % 2);

            canvas.blit(px, py, &sprites.bg[parity]);
            if Self::is_cliff(row, col) {
                canvas.blit(px, py, &sprites.cliff);
            }
            if row < NUM_ROWS - 1 && Self::is_cliff(row + 1, col) {
                canvas.blit(px, py, &sprites.near_cliff[parity]);
            }
            if state == start_state {
                canvas.blit(px, py, &sprites.stool);
            }
            if state == goal_state {
                canvas.blit(px, py, &sprites.cookie);
            }
            if state == agent_state {
                canvas.blit(px, py - elf_lift, &sprites.elf[elf_idx]);
            }
        }

        match self.render_mode {
            RenderMode::Human => {
                let window = self.window.get_or_insert_with(|| {
                    RenderWindow::new(
                        "CliffWalking \u{2014} gmgn",
                        width as usize,
                        height as usize,
                        RENDER_FPS,
                    )
                    .expect("failed to create render window")
                });
                if !window.is_open() {
                    return Ok(RenderFrame::None);
                }
                window.show(canvas)?;
                Ok(RenderFrame::None)
            }
            RenderMode::RgbArray => Ok(RenderFrame::RgbArray {
                width,
                height,
                data: canvas.pixels_rgb(),
            }),
            _ => Ok(RenderFrame::None),
        }
    }
}
impl Env for CliffWalkingEnv {
    type Obs = i64;
    type Act = i64;
    type ObsSpace = Discrete;
    type ActSpace = Discrete;

    /// Applies `action` to the current state using the precomputed transition
    /// table.
    ///
    /// The returned `info` map carries the transition probability under the key
    /// `"prob"`. `truncated` is always `false`.
    ///
    /// # Errors
    ///
    /// * `Error::ResetNeeded` when `reset` has not been called yet.
    /// * `Error::InvalidAction` when `action` is not contained in the action
    ///   space.
    // State and action were validated / produced by this environment, so the
    // index casts cannot lose information.
    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    fn step(&mut self, action: &i64) -> Result<StepResult<i64>> {
        let Some(current) = self.state else {
            return Err(Error::ResetNeeded { method: "step" });
        };
        if !self.action_space.contains(action) {
            return Err(Error::InvalidAction {
                reason: format!("action {action} not in {{0..3}}"),
            });
        }

        // Each (state, action) pair stores its outcome at index 0.
        let (prob, next_state, reward, terminated) =
            self.transitions[current as usize][*action as usize][0];
        self.state = Some(next_state);
        self.last_action = Some(*action);

        let info = HashMap::from([("prob".to_owned(), InfoValue::Float(prob))]);
        Ok(StepResult {
            obs: next_state,
            reward,
            terminated,
            truncated: false,
            info,
        })
    }

    /// Puts the agent back on the start cell and forgets the last action.
    ///
    /// When `seed` is `Some`, the random number generator is re-created from
    /// it; otherwise the existing generator is kept. The returned `info` map
    /// contains `"prob"` set to `1.0`.
    fn reset(&mut self, seed: Option<u64>) -> Result<ResetResult<i64>> {
        if let Some(seed) = seed {
            self.rng = rng::create_rng(Some(seed));
        }
        self.state = Some(self.start_state);
        self.last_action = None;

        let info = HashMap::from([("prob".to_owned(), InfoValue::Float(1.0))]);
        Ok(ResetResult {
            obs: self.start_state,
            info,
        })
    }

    /// Renders the environment according to its render mode.
    ///
    /// * `RenderMode::None` returns `RenderFrame::None`.
    /// * `RenderMode::Ansi` returns a text grid with one line per row. Each
    ///   cell is three characters wide: the agent is shown as `[X]`, the start
    ///   as `S`, the goal as `G`, cliff cells as `C` and all other cells as `.`.
    ///   When an action has been taken, a first line names its direction.
    /// * The remaining modes are drawn with `render_pixels` when the `render`
    ///   feature is enabled.
    ///
    /// # Errors
    ///
    /// * `Error::ResetNeeded` in `Ansi` mode when `reset` has not been called.
    /// * `Error::UnsupportedRenderMode` for the remaining modes when the
    ///   `render` feature is disabled.
    /// * Whatever `render_pixels` returns when the feature is enabled.
    // The state index is produced by this environment and is never negative.
    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    fn render(&mut self) -> Result<RenderFrame> {
        match self.render_mode {
            RenderMode::None => Ok(RenderFrame::None),
            RenderMode::Ansi => {
                let Some(state) = self.state else {
                    return Err(Error::ResetNeeded { method: "render" });
                };
                let agent = state as usize;
                let (agent_row, agent_col) = (agent / NUM_COLS, agent % NUM_COLS);
                // Start, cliff and goal all live on the bottom row.
                let bottom_row = NUM_ROWS - 1;
                let last_col = NUM_COLS - 1;

                let mut lines = Vec::with_capacity(NUM_ROWS + 1);
                if let Some(action) = self.last_action {
                    let dir = match action {
                        UP => "Up",
                        RIGHT => "Right",
                        DOWN => "Down",
                        LEFT => "Left",
                        _ => "?",
                    };
                    lines.push(format!(" ({dir})"));
                }

                for row in 0..NUM_ROWS {
                    let mut line = String::with_capacity(3 * NUM_COLS);
                    for col in 0..NUM_COLS {
                        // The agent marker takes precedence over the cell glyph.
                        if row == agent_row && col == agent_col {
                            line.push_str("[X]");
                            continue;
                        }
                        let glyph = match (row == bottom_row, col) {
                            (false, _) => '.',
                            (true, 0) => 'S',
                            (true, c) if c == last_col => 'G',
                            (true, _) => 'C',
                        };
                        line.push(' ');
                        line.push(glyph);
                        line.push(' ');
                    }
                    lines.push(line);
                }
                Ok(RenderFrame::Ansi(lines.join("\n")))
            }
            #[cfg(feature = "render")]
            RenderMode::Human | RenderMode::RgbArray => self.render_pixels(),
            #[cfg(not(feature = "render"))]
            _ => Err(Error::UnsupportedRenderMode {
                mode: format!("{:?}", self.render_mode),
            }),
        }
    }

    /// Returns the discrete space of grid-cell observations.
    fn observation_space(&self) -> &Discrete {
        &self.observation_space
    }

    /// Returns the discrete space of movement actions.
    fn action_space(&self) -> &Discrete {
        &self.action_space
    }

    /// Returns the render mode chosen at construction time.
    fn render_mode(&self) -> &RenderMode {
        &self.render_mode
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an environment from the default configuration.
    fn make_env() -> CliffWalkingEnv {
        let config = CliffWalkingConfig::default();
        CliffWalkingEnv::new(config)
    }

    /// `reset` reports the bottom-left cell (index 36) as the first observation.
    #[test]
    fn reset_starts_at_36() {
        let mut env = make_env();
        let outcome = env.reset(Some(0)).unwrap();
        assert_eq!(outcome.obs, 36);
    }

    /// Stepping before the first `reset` is rejected.
    #[test]
    fn step_without_reset_errors() {
        let mut env = make_env();
        let attempt = env.step(&0);
        assert!(attempt.is_err());
    }

    /// Actions outside the action space are rejected.
    #[test]
    fn step_invalid_action_errors() {
        let mut env = make_env();
        env.reset(Some(0)).unwrap();
        let attempt = env.step(&99);
        assert!(attempt.is_err());
    }

    /// Walking right from the start falls off the cliff: the agent is sent back
    /// to the start with a reward of -100 and the episode continues.
    #[test]
    fn stepping_right_into_cliff_returns_to_start() {
        let mut env = make_env();
        env.reset(Some(0)).unwrap();
        let outcome = env.step(&RIGHT).unwrap();
        assert_eq!(outcome.obs, 36);
        assert!((outcome.reward + 100.0).abs() < f64::EPSILON);
        assert!(!outcome.terminated);
    }

    /// Up once, right eleven times, then down reaches the goal (index 47).
    #[test]
    fn optimal_path_reaches_goal() {
        let mut env = make_env();
        env.reset(Some(0)).unwrap();

        let first = env.step(&UP).unwrap();
        assert_eq!(first.obs, 24);

        (0..11).for_each(|_| {
            env.step(&RIGHT).unwrap();
        });

        let last = env.step(&DOWN).unwrap();
        assert_eq!(last.obs, 47);
        assert!(last.terminated);
        assert!((last.reward + 1.0).abs() < f64::EPSILON);
    }

    /// Every (state, action) pair has exactly one outcome with probability 1.
    #[test]
    fn transitions_are_deterministic() {
        let env = make_env();
        for state in 0..48 {
            for action in 0..4 {
                let outcomes = &env.transitions[state][action];
                assert_eq!(outcomes.len(), 1);
                let (prob, ..) = outcomes[0];
                assert!((prob - 1.0).abs() < f64::EPSILON);
            }
        }
    }
}