use learnwell::environment::Environment;
use show_image::{ImageInfo, ImageView};
/// Complete observable state of one taxi episode.
///
/// Field order is significant for the derived `Hash`, so keep it stable.
#[derive(Clone, PartialEq, Eq, Hash)]
pub struct TaxiState {
    /// Current cell of the taxi.
    taxi: Point,
    /// Cell the passenger has to be delivered to.
    dropoff: Point,
    /// Current cell of the passenger; it follows the taxi while `in_taxi` is set.
    passenger: Point,
    /// Whether the passenger is currently riding in the taxi.
    in_taxi: bool,
}
impl Default for TaxiState {
    /// Creates a fresh episode.
    ///
    /// The passenger and the dropoff are two distinct landmark cells drawn at random.
    /// The taxi starts on a uniformly random cell of the 5x5 grid, and the passenger
    /// starts outside the taxi.
    fn default() -> Self {
        // Landmark cells as (row, column) where passengers wait or get dropped off.
        let landmarks = [(0, 0), (0, 4), (4, 0), (4, 3)];
        let pick_landmark = || {
            let (y, x) = landmarks[fastrand::usize(0..landmarks.len())];
            Point { x, y }
        };

        let passenger = pick_landmark();
        // Rejection-sample until the dropoff differs from the passenger's start cell.
        let dropoff = loop {
            let candidate = pick_landmark();
            if candidate != passenger {
                break candidate;
            }
        };

        // The column is drawn before the row.
        let taxi_x = fastrand::usize(0..5);
        let taxi_y = fastrand::usize(0..5);

        TaxiState {
            taxi: Point { x: taxi_x, y: taxi_y },
            dropoff,
            passenger,
            in_taxi: false,
        }
    }
}
/// A cell of the 5x5 taxi grid.
///
/// `x` is the column and `y` the row. The `Up` action decreases `y` and `Right`
/// increases `x`. `Copy` is derived because the type is just two integers.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Point {
    x: usize,
    y: usize,
}
/// Actions the taxi agent can take each step.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum TaxiAction {
    /// Move one row up (decreases `y`).
    Up,
    /// Move one row down (increases `y`).
    Down,
    /// Move one column left (decreases `x`).
    Left,
    /// Move one column right (increases `x`).
    Right,
    /// Drop the passenger off at the taxi's current cell.
    Dropoff,
    /// Pick the passenger up from the taxi's current cell.
    Pickup,
}
/// learnwell environment wrapping a single taxi episode plus training statistics.
#[derive(Default)]
pub struct TaxiEnvironment {
    /// State of the episode in progress.
    state: TaxiState,
    /// Number of delivered episodes since the last statistics report.
    found: usize,
    /// Number of actions taken since the last statistics report.
    steps: usize,
    /// RGB8 frame buffer filled by `save_image` and exposed by `get_image`.
    pixels: Vec<u8>,
}
impl Environment<TaxiState, TaxiAction> for TaxiEnvironment {
    /// Returns a snapshot of the current episode state.
    fn state(&self) -> TaxiState {
        self.state.clone()
    }

    /// Starts a new episode with a freshly randomised `TaxiState`.
    ///
    /// Every `REPORT_EVERY` epochs (including epoch 0), prints `found` as a percentage
    /// of `REPORT_EVERY` and the integer-average step count, then clears both counters.
    /// NOTE(review): the percentage assumes one reset per episode; confirm with the
    /// training loop.
    fn reset(&mut self, epoch: usize) {
        const REPORT_EVERY: usize = 20;
        self.state = TaxiState::default();
        if epoch % REPORT_EVERY == 0 {
            println!(
                "{epoch}: found {:.2}% avg steps:{}",
                100. * self.found as f32 / REPORT_EVERY as f32,
                self.steps / REPORT_EVERY
            );
            self.found = 0;
            self.steps = 0;
        }
    }

    /// All six actions, in a fixed order.
    fn all_actions(&self) -> Vec<TaxiAction> {
        vec![
            TaxiAction::Up,
            TaxiAction::Down,
            TaxiAction::Left,
            TaxiAction::Right,
            TaxiAction::Dropoff,
            TaxiAction::Pickup,
        ]
    }

    /// Applies `action` and returns its reward.
    ///
    /// Rewards:
    /// - `-1` for a legal move.
    /// - `-10` for driving into the grid edge or an interior wall (the taxi stays put),
    ///   or for a pickup/dropoff attempted in the wrong cell or state.
    /// - `+10` for picking the passenger up on their cell.
    /// - `+20` for dropping the passenger off on the dropoff cell.
    ///
    /// While the passenger rides, their position is synced to the taxi's.
    fn take_action_get_reward(&mut self, action: &TaxiAction) -> f64 {
        // True when moving from (`row`, `col`) with `action` would leave the 5x5 grid
        // or cross an interior wall. Non-movement actions are never blocked.
        fn blocked(row: usize, col: usize, action: &TaxiAction) -> bool {
            matches!(
                (row, col, action),
                (0, _, TaxiAction::Up)
                    | (4, _, TaxiAction::Down)
                    | (_, 0, TaxiAction::Left)
                    | (_, 4, TaxiAction::Right)
                    // Walls right of columns 0 and 2 in rows 3-4, and right of column 1 in row 0.
                    | (3 | 4, 0 | 2, TaxiAction::Right)
                    | (0, 1, TaxiAction::Right)
                    // The same walls approached from their right-hand side.
                    | (3 | 4, 1 | 3, TaxiAction::Left)
                    | (0, 2, TaxiAction::Left)
            )
        }

        self.steps += 1;
        let state = &mut self.state;
        let reward = if blocked(state.taxi.y, state.taxi.x, action) {
            -10.
        } else {
            match action {
                TaxiAction::Up => {
                    state.taxi.y -= 1;
                    -1.
                }
                TaxiAction::Down => {
                    state.taxi.y += 1;
                    -1.
                }
                TaxiAction::Left => {
                    state.taxi.x -= 1;
                    -1.
                }
                TaxiAction::Right => {
                    state.taxi.x += 1;
                    -1.
                }
                TaxiAction::Dropoff if state.in_taxi && state.taxi == state.dropoff => {
                    state.in_taxi = false;
                    20.
                }
                TaxiAction::Pickup if !state.in_taxi && state.taxi == state.passenger => {
                    state.in_taxi = true;
                    10.
                }
                // Pickup/dropoff in the wrong cell or with the wrong occupancy.
                TaxiAction::Dropoff | TaxiAction::Pickup => -10.,
            }
        };
        if state.in_taxi {
            state.passenger.x = state.taxi.x;
            state.passenger.y = state.taxi.y;
        }
        reward
    }

    /// Ends the episode once the passenger is delivered or the step budget is exceeded.
    ///
    /// Completion is checked first. An episode that finishes on the same step the
    /// budget runs out still counts towards `found`.
    fn should_stop(&mut self, step: usize) -> bool {
        const MAX_STEPS: usize = 100;
        if self.is_finished() {
            self.found += 1;
            true
        } else {
            step > MAX_STEPS
        }
    }

    /// Renders the current state and returns a 10x5 RGB8 view of the frame buffer.
    fn get_image(&mut self) -> show_image::ImageView {
        self.save_image();
        ImageView::new(ImageInfo::rgb8(10, 5), &self.pixels)
    }
}
impl TaxiEnvironment {
fn is_finished(&self) -> bool {
let state = &self.state;
state.passenger.x == state.dropoff.x
&& state.passenger.y == state.dropoff.y
&& !state.in_taxi
}
fn save_image(&mut self) {
let mut pixels = [0u8; 150];
let pass = &self.state.passenger;
let drop = &self.state.dropoff;
let taxi = &self.state.taxi;
pixels[0 * 3 * 2 + 3 * 3 + 2] = 50;
pixels[3 * 5 * 3 * 2 + 1 * 3 + 2] = 50;
pixels[3 * 5 * 3 * 2 + 5 * 3 + 2] = 50;
pixels[4 * 5 * 3 * 2 + 1 * 3 + 2] = 50;
pixels[4 * 5 * 3 * 2 + 5 * 3 + 2] = 50;
pixels[pass.y * 5 * 3 * 2 + pass.x * 3 * 2 + 1] = 255; pixels[drop.y * 5 * 3 * 2 + drop.x * 3 * 2 + 0] = 255;
if self.state.in_taxi {
pixels[taxi.y * 5 * 3 * 2 + taxi.x * 3 * 2 + 0] = 255;
pixels[taxi.y * 5 * 3 * 2 + taxi.x * 3 * 2 + 1] = 255; } else {
pixels[taxi.y * 5 * 3 * 2 + taxi.x * 3 * 2 + 0] = 255;
pixels[taxi.y * 5 * 3 * 2 + taxi.x * 3 * 2 + 2] = 255; }
self.pixels = pixels.to_vec();
}
}