// Allow the Greek identifiers α and γ used in the temporal difference update below.
#![allow(mixed_script_confusables)]
#![allow(confusable_idents)]
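
//! A grid world reinforcement learning example in the style of the classic
//! cliff walking task: the agent starts beside a cliff and learns a route to
//! the goal, paying -1 per step and -100 (plus a reset to the start) for
//! falling off. A single temporal difference update rule implements both
//! SARSA and Q-learning.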
use easy_ml::matrices::Matrix;
use rand::{Rng, SeedableRng};
use std::fmt;

/// A single tile in the grid world.
#[derive(Clone, Copy, Debug, Eq, PartialEq, PartialOrd, Ord)]
enum Cell {
    Path,
    Goal,
    Cliff,
}

/// The four moves the agent can take.
#[derive(Clone, Copy, Debug, Eq, PartialEq, PartialOrd, Ord)]
enum Direction {
    North,
    East,
    South,
    West,
}

/// The number of actions available in every state.
const DIRECTIONS: usize = 4;

impl Direction {
    /// The index of this direction within each state's group of Q values.
    fn order(&self) -> usize {
        match self {
            Direction::North => 0,
            Direction::East => 1,
            Direction::South => 2,
            Direction::West => 3,
        }
    }

    /// Every action, listed in `order()` order.
    fn actions() -> [Direction; DIRECTIONS] {
        [
            Direction::North,
            Direction::East,
            Direction::South,
            Direction::West,
        ]
    }
}

impl Cell {
    fn to_str(&self) -> &'static str {
        match self {
            Cell::Path => "_",
            Cell::Goal => "G",
            Cell::Cliff => "^",
        }
    }
}

impl fmt::Display for Cell {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}", self.to_str())
    }
}

/// An (x, y) position in the grid, where x is the column and y is the row.
type Position = (usize, usize);

struct GridWorld {
    start: Position,
    tiles: Matrix<Cell>,
    agent: Position,
    /// A flat Q table holding one expected reward per direction per tile.
    expected_rewards: Vec<f64>,
    steps: u64,
    reward: f64,
}

/// A policy chooses an action given the estimated value of each action.
trait Policy {
    fn choose(&mut self, choices: &[(Direction, f64); DIRECTIONS]) -> Direction;
}

/// Always picks the action with the highest estimated value, breaking ties
/// in favour of the earliest action in the list.
struct Greedy;

impl Policy for Greedy {
    fn choose(&mut self, choices: &[(Direction, f64); DIRECTIONS]) -> Direction {
        let mut best_q = -f64::INFINITY;
        let mut best_direction = Direction::North;
        for &(d, q) in choices {
            if q > best_q {
                best_direction = d;
                best_q = q;
            }
        }
        best_direction
    }
}

/// Picks a uniformly random action with probability `exploration_rate`,
/// otherwise picks greedily.
struct EpsilonGreedy {
    rng: rand_chacha::ChaCha8Rng,
    exploration_rate: f64,
}

impl Policy for EpsilonGreedy {
    fn choose(&mut self, choices: &[(Direction, f64); DIRECTIONS]) -> Direction {
        let random: f64 = self.rng.random();
        if random < self.exploration_rate {
            choices[self.rng.random_range(0..choices.len())].0
        } else {
            Greedy.choose(choices)
        }
    }
}

// Implementing Policy for mutable references lets the same policy (and its
// RNG state) be reused across episodes even though `q_sarsa` takes the
// policy by value.
impl<P: Policy> Policy for &mut P {
    fn choose(&mut self, choices: &[(Direction, f64); DIRECTIONS]) -> Direction {
        P::choose(self, choices)
    }
}

impl fmt::Display for GridWorld {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        #[rustfmt::skip]
        self.tiles.row_major_iter().with_index().try_for_each(|((r, c), cell)| {
            write!(
                f,
                "{}{}",
                if (c, r) == self.agent { "A" } else { cell.to_str() },
                if c == self.tiles.columns() - 1 { "\n" } else { ", " }
            )
        })?;
        Ok(())
    }
}
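
// For the 4 x 12 grid constructed in `main`, the Display implementation
// above renders the starting state as:
//
// _, _, _, _, _, _, _, _, _, _, _, _
// _, _, _, _, _, _, _, _, _, _, _, _
// _, _, _, _, _, _, _, _, _, _, _, _
// A, ^, ^, ^, ^, ^, ^, ^, ^, ^, ^, G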

impl GridWorld {
    /// Returns the position the agent would move to, or None if the move is
    /// blocked by the edge of the grid.
    fn step(&self, current: Position, direction: Direction) -> Option<Position> {
        let (x1, y1) = current;
        #[rustfmt::skip]
        let (x2, y2) = match direction {
            Direction::North => (
                x1,
                y1.saturating_sub(1)
            ),
            Direction::East => (
                std::cmp::min(x1.saturating_add(1), self.tiles.columns() - 1),
                y1,
            ),
            Direction::South => (
                x1,
                std::cmp::min(y1.saturating_add(1), self.tiles.rows() - 1),
            ),
            Direction::West => (
                x1.saturating_sub(1),
                y1
            ),
        };
        if x1 == x2 && y1 == y2 {
            None
        } else {
            Some((x2, y2))
        }
    }

    /// Moves the agent and returns the reward: -100 for stepping off the
    /// cliff (which also sends the agent back to the start), 0 for reaching
    /// the goal, and -1 for every other move.
    fn take_action(&mut self, direction: Direction) -> f64 {
        if let Some((x, y)) = self.step(self.agent, direction) {
            self.agent = (x, y);
            match self.tiles.get(y, x) {
                Cell::Cliff => {
                    self.agent = self.start;
                    -100.0
                }
                Cell::Path => -1.0,
                Cell::Goal => 0.0,
            }
        } else {
            // Bumping into the edge of the grid still costs a step.
            -1.0
        }
    }
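
    /// Runs one episode of temporal difference control, updating the Q table
    /// after every step with
    ///
    ///   Q(s, a) ← Q(s, a) + α·(r + γ·V(s') − Q(s, a))
    ///
    /// where V(s') blends the on-policy and greedy estimates of the next
    /// state's value:
    ///
    ///   V(s') = q_sarsa·Q(s', a') + (1 − q_sarsa)·max Q(s', ·)
    ///
    /// so `q_sarsa` = 1.0 recovers SARSA and `q_sarsa` = 0.0 recovers
    /// Q-learning.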
    fn q_sarsa(
        &mut self,
        step_size: f64,
        mut policy: impl Policy,
        discount_factor: f64,
        q_sarsa: f64,
    ) {
        let (α, γ) = (step_size, discount_factor);
        let actions = Direction::actions();
        let mut state = self.agent;
        let mut action = policy.choose(&actions.map(|d| (d, self.q(state, d))));
        while self.tiles.get(self.agent.1, self.agent.0) != Cell::Goal {
            let reward = self.take_action(action);
            self.reward += reward;
            let new_state = self.agent;
            let new_action = policy.choose(&actions.map(|d| (d, self.q(new_state, d))));
            let greedy_action = Greedy.choose(&actions.map(|d| (d, self.q(new_state, d))));
            // Interpolate between the value of the action the policy will
            // actually take next (SARSA) and the value of the greedy action
            // (Q-learning).
            let expected_q_value = q_sarsa * self.q(new_state, new_action)
                + ((1.0 - q_sarsa) * self.q(new_state, greedy_action));
            *self.q_mut(state, action) = self.q(state, action)
                + α * (reward + (γ * expected_q_value) - self.q(state, action));
            state = new_state;
            action = new_action;
            self.steps += 1;
        }
    }

    /// Q-learning: the off-policy special case of `q_sarsa`.
    #[allow(dead_code)]
    fn q_learning(&mut self, step_size: f64, policy: impl Policy, discount_factor: f64) {
        self.q_sarsa(step_size, policy, discount_factor, 0.0);
    }

    /// SARSA: the on-policy special case of `q_sarsa`.
    fn sarsa(&mut self, step_size: f64, policy: impl Policy, discount_factor: f64) {
        self.q_sarsa(step_size, policy, discount_factor, 1.0);
    }

    /// Looks up the expected reward for taking this direction from this
    /// position.
    fn q(&self, position: Position, direction: Direction) -> f64 {
        self.expected_rewards[index(position, direction, self.tiles.columns(), DIRECTIONS)]
    }

    fn q_mut(&mut self, position: Position, direction: Direction) -> &mut f64 {
        &mut self.expected_rewards[index(position, direction, self.tiles.columns(), DIRECTIONS)]
    }

    /// Returns the agent to the start and clears the per-episode counters.
    fn reset(&mut self) {
        self.steps = 0;
        self.reward = 0.0;
        self.agent = self.start;
    }
}

/// Flattens a (position, direction) pair into an index into the Q table:
/// cells are laid out in row-major order, with `directions` consecutive
/// slots per cell.
fn index(position: Position, direction: Direction, width: usize, directions: usize) -> usize {
    direction.order() + (directions * (position.0 + (position.1 * width)))
}

fn main() {
    let mut grid_world = GridWorld {
        tiles: {
            use Cell::Cliff as C;
            use Cell::Goal as G;
            use Cell::Path as P;
            // A 4 x 12 grid with the cliff along the bottom row between the
            // start and the goal.
            Matrix::from(vec![
                vec![P, P, P, P, P, P, P, P, P, P, P, P],
                vec![P, P, P, P, P, P, P, P, P, P, P, P],
                vec![P, P, P, P, P, P, P, P, P, P, P, P],
                vec![P, C, C, C, C, C, C, C, C, C, C, G],
            ])
        },
        start: (0, 3),
        agent: (0, 3),
        // One Q value per direction for each of the 4 x 12 cells.
        expected_rewards: vec![0.0; DIRECTIONS * 4 * 12],
        steps: 0,
        reward: 0.0,
    };
    let episodes = 100;
    // A fixed seed keeps the episode-by-episode output reproducible.
    let mut policy = EpsilonGreedy {
        rng: rand_chacha::ChaCha8Rng::seed_from_u64(16),
        exploration_rate: 0.1,
    };
    let mut total_steps = 0;
    for n in 0..episodes {
        grid_world.reset();
        grid_world.sarsa(0.5, &mut policy, 0.9);
        total_steps += grid_world.steps;
        println!(
            "Steps to complete episode {:?}:\t{:?}/{:?}\t\tSum of rewards during episode: {:?}",
            n, grid_world.steps, total_steps, grid_world.reward
        );
    }
}
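
// A few small sanity checks of the helpers above (not part of the learning
// algorithm itself); they pin down the Q table layout and the greedy
// tie-breaking behaviour that the rest of this example relies on.
#[cfg(test)]
mod tests {
    use super::*;

    // Q values are grouped per cell in row-major order, with one slot per
    // direction inside each group.
    #[test]
    fn index_is_row_major_with_a_direction_stride() {
        let width = 12;
        assert_eq!(index((0, 0), Direction::North, width, DIRECTIONS), 0);
        assert_eq!(index((0, 0), Direction::West, width, DIRECTIONS), 3);
        assert_eq!(index((1, 0), Direction::North, width, DIRECTIONS), DIRECTIONS);
        assert_eq!(index((0, 1), Direction::North, width, DIRECTIONS), DIRECTIONS * width);
    }

    // Greedy picks the highest valued action, preferring the earliest on ties.
    #[test]
    fn greedy_breaks_ties_in_order() {
        let choices = [
            (Direction::North, -1.0),
            (Direction::East, 2.0),
            (Direction::South, 0.5),
            (Direction::West, 2.0),
        ];
        assert_eq!(Greedy.choose(&choices), Direction::East);
    }
}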