use rand::RngExt;
use rand::SeedableRng;
use rand::rngs::StdRng;
use rlevo_core::action::DiscreteAction;
use rlevo_core::base::{Action, Observation, State};
use rlevo_core::environment::{ConstructableEnv, Environment, EnvironmentError, SnapshotBase};
use rlevo_core::reward::ScalarReward;
use rlevo_core::state::StateError;
use serde::{Deserialize, Serialize};
/// Number of grid rows.
const NROW: u8 = 4;
/// Number of grid columns.
const NCOL: u8 = 12;
/// Start cell `(row, col)`: the bottom-left corner of the grid.
const START: (u8, u8) = (NROW - 1, 0);
/// Terminal goal cell `(row, col)`: the bottom-right corner of the grid.
const GOAL: (u8, u8) = (NROW - 1, NCOL - 1);
/// Configuration for [`CliffWalking`].
///
/// `Default` yields a non-slippery environment whose RNG is seeded with `0`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct CliffWalkingConfig {
    /// When `true`, the executed move is the intended one with probability 1/3
    /// and each of the two perpendicular directions with probability 1/3.
    pub is_slippery: bool,
    /// Seed for the environment RNG; the RNG is only drawn from when
    /// `is_slippery` is set.
    pub seed: u64,
}

impl CliffWalkingConfig {
    /// Returns a builder initialised with the default configuration.
    pub fn builder() -> CliffWalkingConfigBuilder {
        CliffWalkingConfigBuilder::default()
    }
}

/// Fluent builder for [`CliffWalkingConfig`].
///
/// Every setter consumes and returns the builder; call
/// [`build`](CliffWalkingConfigBuilder::build) to obtain the config.
#[derive(Debug, Clone, Default)]
#[must_use = "a builder does nothing until `build` is called"]
pub struct CliffWalkingConfigBuilder {
    is_slippery: bool,
    seed: u64,
}

impl CliffWalkingConfigBuilder {
    /// Enables or disables slippery (stochastic) transitions.
    pub fn is_slippery(mut self, v: bool) -> Self {
        self.is_slippery = v;
        self
    }

    /// Sets the RNG seed.
    pub fn seed(mut self, s: u64) -> Self {
        self.seed = s;
        self
    }

    /// Finalises the builder into a [`CliffWalkingConfig`].
    #[must_use]
    pub fn build(self) -> CliffWalkingConfig {
        let Self { is_slippery, seed } = self;
        CliffWalkingConfig { is_slippery, seed }
    }
}
/// Agent position on the `NROW` x `NCOL` grid.
///
/// `row` 0 is the top row (rendered first) and `row` `NROW - 1` is the bottom
/// row, which holds the start, the cliff and the goal. Direct construction is
/// not range-checked; use `State::is_valid` or `TryFrom<u16>` for validation.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CliffWalkingState {
    pub row: u8,
    pub col: u8,
}

impl CliffWalkingState {
    /// Flattens the position to its row-major id (`row * NCOL + col`).
    ///
    /// For in-grid positions the result lies in `[0, NROW * NCOL)`.
    fn state_id(&self) -> u16 {
        // Lossless widening: the maximum (255 * NCOL + 255) fits in u16.
        u16::from(self.row) * u16::from(NCOL) + u16::from(self.col)
    }
}
impl TryFrom<u16> for CliffWalkingState {
    type Error = StateError;

    /// Decodes a row-major state id back into `(row, col)` coordinates.
    ///
    /// # Errors
    ///
    /// Returns `StateError::InvalidData` when `id >= NROW * NCOL`.
    fn try_from(id: u16) -> Result<Self, Self::Error> {
        let width = u16::from(NCOL);
        let n = u16::from(NROW) * width;
        if id < n {
            // Both quotient and remainder are below NROW / NCOL, so they fit in u8.
            let row = (id / width) as u8;
            let col = (id % width) as u8;
            Ok(Self { row, col })
        } else {
            Err(StateError::InvalidData(format!(
                "CliffWalkingState id {id} out of range [0, {n})"
            )))
        }
    }
}
impl From<CliffWalkingState> for u16 {
fn from(s: CliffWalkingState) -> u16 {
s.state_id()
}
}
impl State<1> for CliffWalkingState {
    type Observation = CliffWalkingObservation;

    /// Flat shape with one entry per grid cell (`NROW * NCOL`).
    fn shape() -> [usize; 1] {
        let cell_count = usize::from(NROW) * usize::from(NCOL);
        [cell_count]
    }

    /// Observes the position as its row-major state id.
    fn observe(&self) -> CliffWalkingObservation {
        let state_id = self.state_id();
        CliffWalkingObservation { state_id }
    }

    /// A position is valid when both coordinates lie inside the grid.
    fn is_valid(&self) -> bool {
        let row_in_grid = self.row < NROW;
        let col_in_grid = self.col < NCOL;
        row_in_grid && col_in_grid
    }
}
/// Observation handed to the agent: the row-major id of its current cell.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct CliffWalkingObservation {
    /// Row-major cell index (`row * NCOL + col`).
    pub state_id: u16,
}
impl Observation<1> for CliffWalkingObservation {
    /// A single flat dimension sized to the number of grid cells.
    fn shape() -> [usize; 1] {
        [usize::from(NROW) * usize::from(NCOL)]
    }
}
/// The four movement directions. Discriminants double as action indices.
///
/// `Hash` is derived so actions can key lookup tables (e.g. tabular Q-values).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CliffWalkingAction {
    /// Move towards row 0.
    Up = 0,
    /// Move towards the last column.
    Right = 1,
    /// Move towards the last row.
    Down = 2,
    /// Move towards column 0.
    Left = 3,
}
impl Action<1> for CliffWalkingAction {
    /// Actions have shape `[1]`.
    fn shape() -> [usize; 1] {
        [1usize]
    }

    /// Every direction is always legal; moves that would leave the grid are
    /// clamped by the transition dynamics rather than rejected here.
    fn is_valid(&self) -> bool {
        true
    }
}
impl DiscreteAction<1> for CliffWalkingAction {
    const ACTION_COUNT: usize = 4;

    /// Maps an index in `[0, 4)` to its action.
    ///
    /// # Panics
    ///
    /// Panics when `index >= 4`.
    fn from_index(index: usize) -> Self {
        // Ordered so that position == discriminant.
        const BY_INDEX: [CliffWalkingAction; 4] = [
            CliffWalkingAction::Up,
            CliffWalkingAction::Right,
            CliffWalkingAction::Down,
            CliffWalkingAction::Left,
        ];
        match BY_INDEX.get(index) {
            Some(&action) => action,
            None => panic!("CliffWalkingAction index {index} out of range [0, 4)"),
        }
    }

    /// Returns the action's discriminant as its index.
    fn to_index(&self) -> usize {
        (*self) as usize
    }
}
impl CliffWalkingAction {
fn perpendiculars(self) -> [CliffWalkingAction; 2] {
match self {
CliffWalkingAction::Up | CliffWalkingAction::Down => {
[CliffWalkingAction::Left, CliffWalkingAction::Right]
}
CliffWalkingAction::Left | CliffWalkingAction::Right => {
[CliffWalkingAction::Up, CliffWalkingAction::Down]
}
}
}
}
/// Moves one cell from `(row, col)` in direction `action`, clamping at the
/// grid edges so that off-grid moves leave the position unchanged.
///
/// Does not know about the cliff; callers handle that separately.
fn apply_action(row: u8, col: u8, action: CliffWalkingAction) -> (u8, u8) {
    let last_row = NROW - 1;
    let last_col = NCOL - 1;
    match action {
        CliffWalkingAction::Up => (row.saturating_sub(1), col),
        CliffWalkingAction::Left => (row, col.saturating_sub(1)),
        CliffWalkingAction::Down => ((row + 1).min(last_row), col),
        CliffWalkingAction::Right => (row, (col + 1).min(last_col)),
    }
}
fn is_cliff(row: u8, col: u8) -> bool {
row == 3 && (1..=10).contains(&col)
}
/// The CliffWalking gridworld environment (`NROW` x `NCOL` cells).
///
/// The agent starts at `START` (bottom-left) and must reach `GOAL`
/// (bottom-right). Every step costs -1. Stepping onto a cliff cell costs -100
/// and teleports the agent back to `START` without ending the episode.
#[derive(Debug)]
pub struct CliffWalking {
    // Current agent position.
    state: CliffWalkingState,
    // Configuration the environment was built with (slipperiness, seed).
    config: CliffWalkingConfig,
    // Source of slip outcomes; seeded from `config.seed` in `with_config`.
    rng: StdRng,
}
impl CliffWalking {
pub fn with_config(config: CliffWalkingConfig) -> Self {
let rng = StdRng::seed_from_u64(config.seed);
Self {
state: CliffWalkingState {
row: START.0,
col: START.1,
},
config,
rng,
}
}
fn resolve_action(&mut self, action: CliffWalkingAction) -> CliffWalkingAction {
if !self.config.is_slippery {
return action;
}
let r = self.rng.random_range(0u32..3);
if r == 0 {
action
} else {
action.perpendiculars()[(r - 1) as usize]
}
}
}
impl ConstructableEnv for CliffWalking {
    /// Builds an environment from `CliffWalkingConfig::default()`
    /// (non-slippery, seed 0). The `render` flag is ignored.
    fn new(_render: bool) -> Self {
        let config = CliffWalkingConfig::default();
        Self::with_config(config)
    }
}
impl Environment<1, 1, 1> for CliffWalking {
type StateType = CliffWalkingState;
type ObservationType = CliffWalkingObservation;
type ActionType = CliffWalkingAction;
type RewardType = ScalarReward;
type SnapshotType = SnapshotBase<1, CliffWalkingObservation, ScalarReward>;
fn reset(&mut self) -> Result<Self::SnapshotType, EnvironmentError> {
self.state = CliffWalkingState {
row: START.0,
col: START.1,
};
Ok(SnapshotBase::running(
self.state.observe(),
ScalarReward(0.0),
))
}
fn step(&mut self, action: CliffWalkingAction) -> Result<Self::SnapshotType, EnvironmentError> {
let effective = self.resolve_action(action);
let (nr, nc) = apply_action(self.state.row, self.state.col, effective);
if is_cliff(nr, nc) {
self.state.row = START.0;
self.state.col = START.1;
return Ok(SnapshotBase::running(
self.state.observe(),
ScalarReward(-100.0),
));
}
self.state.row = nr;
self.state.col = nc;
if (nr, nc) == GOAL {
Ok(SnapshotBase::terminated(
self.state.observe(),
ScalarReward(-1.0),
))
} else {
Ok(SnapshotBase::running(
self.state.observe(),
ScalarReward(-1.0),
))
}
}
}
impl From<(u8, u8)> for CliffWalkingState {
fn from((row, col): (u8, u8)) -> Self {
CliffWalkingState { row, col }
}
}
impl crate::render::AsciiRenderable for CliffWalking {
    /// Renders the grid as plain text: one line per row, each cell written as
    /// its glyph followed by a space (`@` agent, `G` goal, `S` start,
    /// `C` cliff, `.` empty), with every line ending in `'\n'`.
    fn render_ascii(&self) -> String {
        // Each row is NCOL "<glyph> " pairs plus the trailing newline; all
        // glyphs are ASCII, so this is the exact byte length.
        let row_len = usize::from(NCOL) * 2 + 1;
        let mut out = String::with_capacity(row_len * usize::from(NROW));
        for row in 0..NROW {
            for col in 0..NCOL {
                out.push(cell_char(row, col, self.state.row, self.state.col));
                out.push(' ');
            }
            out.push('\n');
        }
        out
    }

    /// Renders the same glyphs as `render_ascii`, styled per cell kind.
    ///
    /// Runs of adjacent cells that share a style are merged into one span.
    fn render_styled(&self) -> crate::render::StyledFrame {
        use crate::render::palette::{
            AGENT_FG, AGENT_MODIFIER, GOAL_FG, GOAL_MODIFIER, HAZARD_FG, HAZARD_MODIFIER,
        };
        use crate::render::{Color, Modifier, SpanStyle, StyledFrame, StyledLine, StyledSpan};
        let mut lines = Vec::with_capacity(usize::from(NROW));
        for row in 0..NROW {
            let mut spans: Vec<StyledSpan> = Vec::new();
            let mut current_style = SpanStyle::default();
            let mut current_text = String::new();
            for col in 0..NCOL {
                let ch = cell_char(row, col, self.state.row, self.state.col);
                let style = match ch {
                    '@' => SpanStyle::default()
                        .fg(AGENT_FG)
                        .with_modifier(AGENT_MODIFIER),
                    'C' => SpanStyle::default()
                        .fg(HAZARD_FG)
                        .with_modifier(HAZARD_MODIFIER),
                    'G' => SpanStyle::default()
                        .fg(GOAL_FG)
                        .with_modifier(GOAL_MODIFIER),
                    'S' => SpanStyle::default()
                        .fg(Color::Yellow)
                        .with_modifier(Modifier::BOLD),
                    _ => SpanStyle::default(),
                };
                // Flush the pending run when the style changes; an empty
                // buffer (first cell of the row) has nothing to flush.
                if style != current_style && !current_text.is_empty() {
                    spans.push(StyledSpan::new(
                        std::mem::take(&mut current_text),
                        current_style,
                    ));
                }
                current_style = style;
                current_text.push(ch);
                current_text.push(' ');
            }
            if !current_text.is_empty() {
                spans.push(StyledSpan::new(current_text, current_style));
            }
            lines.push(StyledLine::from_spans(spans));
        }
        StyledFrame { lines }
    }
}
fn cell_char(row: u8, col: u8, agent_row: u8, agent_col: u8) -> char {
if row == agent_row && col == agent_col {
'@'
} else if (row, col) == GOAL {
'G'
} else if (row, col) == START {
'S'
} else if is_cliff(row, col) {
'C'
} else {
'.'
}
}
// Unit tests covering state encoding, transition dynamics, the reward
// structure, slippery-transition statistics, seeded determinism and rendering.
#[cfg(test)]
mod tests {
    use super::*;
    use rlevo_core::action::DiscreteAction;
    use rlevo_core::base::Observation;
    use rlevo_core::environment::Snapshot;

    // Deterministic environment: non-slippery, seed 0.
    fn make_env() -> CliffWalking {
        CliffWalking::with_config(CliffWalkingConfig::default())
    }

    #[test]
    fn action_count() {
        assert_eq!(CliffWalkingAction::ACTION_COUNT, 4);
    }

    // Every id in [0, NROW * NCOL) must survive id -> state -> id.
    #[test]
    fn state_id_encoding() {
        let total = NROW as u16 * NCOL as u16;
        for id in 0..total {
            let state = CliffWalkingState::try_from(id).unwrap();
            assert_eq!(u16::from(state), id, "round-trip failed for id {id}");
        }
    }

    // 4 x 12 grid => 48 flat observation slots.
    #[test]
    fn obs_shape() {
        assert_eq!(CliffWalkingObservation::shape(), [48]);
    }

    // From START (3, 0), moving Right lands on (3, 1), the first cliff cell:
    // the agent must be sent back to START with -100 and the episode continues.
    #[test]
    fn cliff_step_teleports_and_costs_100() {
        let mut env = make_env();
        env.reset().unwrap();
        let snap = env.step(CliffWalkingAction::Right).unwrap();
        let r: f32 = (*snap.reward()).into();
        assert_eq!(r, -100.0);
        assert!(!snap.is_done(), "cliff must not terminate episode");
        assert_eq!(
            env.state.state_id(),
            CliffWalkingState::from((3u8, 0u8)).state_id()
        );
    }

    // (2, 11) is directly above GOAL (3, 11); moving Down terminates with -1.
    #[test]
    fn goal_step_terminates_with_minus_one() {
        let mut env = make_env();
        env.reset().unwrap();
        env.state = CliffWalkingState { row: 2, col: 11 };
        let snap = env.step(CliffWalkingAction::Down).unwrap();
        let r: f32 = (*snap.reward()).into();
        assert_eq!(r, -1.0);
        assert!(snap.is_terminated());
    }

    // After Up the agent is at (2, 0); Left from column 0 must not move it.
    #[test]
    fn off_grid_stays_in_place() {
        let mut env = make_env();
        env.reset().unwrap();
        env.step(CliffWalkingAction::Up).unwrap();
        let snap_before = env.state.state_id();
        env.step(CliffWalkingAction::Left).unwrap();
        assert_eq!(
            env.state.state_id(),
            snap_before,
            "off-grid move must be no-op"
        );
    }

    // Up, 11 x Right along row 2, then Down into GOAL: 13 steps at -1 each.
    #[test]
    fn shortest_path_minus_13() {
        let mut env = make_env();
        env.reset().unwrap();
        let mut total = 0.0_f32;
        let snap = env.step(CliffWalkingAction::Up).unwrap();
        {
            let r: f32 = (*snap.reward()).into();
            total += r;
        }
        for _ in 0..11 {
            let snap = env.step(CliffWalkingAction::Right).unwrap();
            let r: f32 = (*snap.reward()).into();
            total += r;
        }
        let snap = env.step(CliffWalkingAction::Down).unwrap();
        {
            let r: f32 = (*snap.reward()).into();
            total += r;
        }
        assert!(snap.is_done(), "goal must terminate episode");
        assert!(
            (total - (-13.0)).abs() < 1e-5,
            "optimal path must yield -13, got {total}"
        );
    }

    // From the interior cell (1, 6), Right should land on (1, 7), (0, 6) or
    // (2, 6) with probability 1/3 each. The tolerance is three standard errors
    // of a Bernoulli(1/3) proportion estimated over `n` trials.
    #[test]
    fn slippery_distribution_matches_1_3() {
        let cfg = CliffWalkingConfig::builder()
            .is_slippery(true)
            .seed(7)
            .build();
        let mut env = CliffWalking::with_config(cfg);
        env.reset().unwrap();
        let n = 12_000u32;
        let (mut right_count, mut up_count, mut down_count) = (0u32, 0u32, 0u32);
        for _ in 0..n {
            env.state = CliffWalkingState { row: 1, col: 6 };
            env.step(CliffWalkingAction::Right).unwrap();
            match (env.state.row, env.state.col) {
                (1, 7) => right_count += 1,
                (0, 6) => up_count += 1,
                (2, 6) => down_count += 1,
                _ => {}
            }
        }
        let p_right = right_count as f32 / n as f32;
        let p_up = up_count as f32 / n as f32;
        let p_down = down_count as f32 / n as f32;
        let tol = 3.0 * (1.0 / 3.0 * 2.0 / 3.0 / n as f32).sqrt();
        assert!(
            (p_right - 1.0 / 3.0).abs() < tol,
            "intended slip p={p_right}"
        );
        assert!((p_up - 1.0 / 3.0).abs() < tol, "perp-up p={p_up}");
        assert!((p_down - 1.0 / 3.0).abs() < tol, "perp-down p={p_down}");
    }

    // Two slippery runs with the same seed must accumulate the same return.
    // NOTE(review): only the total reward is compared, not the trajectories.
    #[test]
    fn determinism() {
        let cfg = CliffWalkingConfig::builder()
            .is_slippery(true)
            .seed(3)
            .build();
        let run = || {
            let mut env = CliffWalking::with_config(cfg.clone());
            env.reset().unwrap();
            let mut total = 0.0_f32;
            for _ in 0..20 {
                let snap = env.step(CliffWalkingAction::Right).unwrap();
                let r: f32 = (*snap.reward()).into();
                total += r;
                if snap.is_done() {
                    break;
                }
            }
            total
        };
        assert!((run() - run()).abs() < 1e-5, "determinism check failed");
    }

    // The styled frame's plain text must equal the ASCII render minus the
    // trailing newline.
    #[test]
    fn render_styled_matches_ascii() {
        use crate::render::AsciiRenderable;
        let mut env = CliffWalking::with_config(CliffWalkingConfig::default());
        env.reset().unwrap();
        let plain = env.render_ascii();
        let styled = env.render_styled();
        let plain_no_trailing: String = plain.lines().collect::<Vec<_>>().join("\n");
        assert_eq!(styled.plain_text(), plain_no_trailing);
    }

    // Agent, goal and cliff spans must carry the shared palette colours.
    #[test]
    fn render_styled_uses_palette_consts() {
        use crate::render::AsciiRenderable;
        use crate::render::palette::{AGENT_FG, GOAL_FG, HAZARD_FG};
        let mut env = CliffWalking::with_config(CliffWalkingConfig::default());
        env.reset().unwrap();
        let styled = env.render_styled();
        let mut found_agent = false;
        let mut found_goal = false;
        let mut found_cliff = false;
        for line in &styled.lines {
            for span in &line.spans {
                if span.text.starts_with('@') {
                    assert_eq!(span.style.fg, Some(AGENT_FG));
                    found_agent = true;
                }
                if span.text.starts_with('G') {
                    assert_eq!(span.style.fg, Some(GOAL_FG));
                    found_goal = true;
                }
                if span.text.starts_with('C') {
                    assert_eq!(span.style.fg, Some(HAZARD_FG));
                    found_cliff = true;
                }
            }
        }
        assert!(found_agent, "agent glyph @ not found in styled output");
        assert!(found_goal, "goal glyph G not found in styled output");
        assert!(found_cliff, "cliff glyph C not found in styled output");
    }

    // Each ASCII line (12 cells x 2 chars) must fit a standard 80-column terminal.
    #[test]
    fn render_ascii_within_width_budget() {
        use crate::render::AsciiRenderable;
        let mut env = CliffWalking::with_config(CliffWalkingConfig::default());
        env.reset().unwrap();
        for line in env.render_ascii().lines() {
            assert!(
                line.chars().count() <= 80,
                "line exceeds 80 cols: {line:?} ({} chars)",
                line.chars().count()
            );
        }
    }
}
impl rlevo_core::render::payload::TabularPayloadSource for CliffWalking {
fn tabular_snapshot(&self) -> rlevo_core::render::payload::TabularSnapshot {
use rlevo_core::render::payload::{
TabularCell, TabularGrid, TabularLayout, TabularMarker, TabularMarkerKind,
TabularSnapshot,
};
let height: u8 = 4;
let mut cells = Vec::with_capacity(usize::from(height) * usize::from(NCOL));
for row in 0..height {
for col in 0..NCOL {
let cell = if (row, col) == START {
TabularCell::Start
} else if (row, col) == GOAL {
TabularCell::Goal
} else if is_cliff(row, col) {
TabularCell::Hazard
} else {
TabularCell::Empty
};
cells.push(cell);
}
}
TabularSnapshot {
layout: TabularLayout::Grid(TabularGrid {
width: u16::from(NCOL),
height: u16::from(height),
cells,
markers: vec![TabularMarker {
x: u16::from(self.state.col),
y: u16::from(self.state.row),
kind: TabularMarkerKind::Agent,
}],
}),
}
}
}