1use std::collections::HashMap;
6
7use rurel::mdp::{Agent, State};
8use rurel::strategy::explore::RandomExploration;
9use rurel::strategy::learn::QLearning;
10use rurel::strategy::terminate::FixedIterations;
11use rurel::AgentTrainer;
12
/// A cell position on a grid, stored together with the grid's extents.
///
/// Values can be cloned, compared for equality and hashed.
#[derive(Clone, PartialEq, Eq, Hash)]
struct MyState {
    /// Horizontal coordinate of the cell.
    x: i32,
    /// Vertical coordinate of the cell.
    y: i32,
    /// Modulus applied to `x` when the agent moves (the grid's width).
    maxx: i32,
    /// Modulus applied to `y` when the agent moves (the grid's height).
    maxy: i32,
}
20
/// The moves available to the agent.
///
/// Values can be cloned, compared for equality and hashed.
#[derive(Clone, PartialEq, Eq, Hash)]
enum MyAction {
    /// Shift the position by `dx` horizontally and `dy` vertically.
    Move {
        /// Offset added to the `x` coordinate.
        dx: i32,
        /// Offset added to the `y` coordinate.
        dy: i32,
    },
}
25
26impl State for MyState {
27 type A = MyAction;
28
29 fn reward(&self) -> f64 {
30 let (tx, ty) = (10, 10);
31 let d = (((tx - self.x).pow(2) + (ty - self.y).pow(2)) as f64).sqrt();
32 -d
33 }
34
35 fn actions(&self) -> Vec<MyAction> {
36 vec![
37 MyAction::Move { dx: -1, dy: 0 },
38 MyAction::Move { dx: 1, dy: 0 },
39 MyAction::Move { dx: 0, dy: -1 },
40 MyAction::Move { dx: 0, dy: 1 },
41 ]
42 }
43}
44
45struct MyAgent {
46 state: MyState,
47}
48
49impl Agent<MyState> for MyAgent {
50 fn current_state(&self) -> &MyState {
51 &self.state
52 }
53
54 fn take_action(&mut self, action: &MyAction) {
55 match action {
56 &MyAction::Move { dx, dy } => {
57 self.state = MyState {
58 x: (((self.state.x + dx) % self.state.maxx) + self.state.maxx)
59 % self.state.maxx,
60 y: (((self.state.y + dy) % self.state.maxy) + self.state.maxy)
61 % self.state.maxy,
62 ..self.state.clone()
63 };
64 }
65 }
66 }
67}
68
69fn main() {
70 let initial_state = MyState {
71 x: 0,
72 y: 0,
73 maxx: 21,
74 maxy: 21,
75 };
76 let mut trainer = AgentTrainer::new();
77 let mut agent = MyAgent {
78 state: initial_state.clone(),
79 };
80 trainer.train(
81 &mut agent,
82 &QLearning::new(0.2, 0.01, 2.),
83 &mut FixedIterations::new(100000),
84 &RandomExploration::new(),
85 );
86 for j in 0..21 {
87 for i in 0..21 {
88 let entry: &HashMap<MyAction, f64> = trainer
89 .expected_values(&MyState {
90 x: i,
91 y: j,
92 ..initial_state
93 })
94 .unwrap();
95 let best_action = entry
96 .iter()
97 .max_by(|(_, v1), (_, v2)| v1.partial_cmp(v2).unwrap())
98 .map(|(v, _)| v)
99 .unwrap();
100 match best_action {
101 MyAction::Move { dx: -1, dy: 0 } => print!("<"),
102 MyAction::Move { dx: 1, dy: 0 } => print!(">"),
103 MyAction::Move { dx: 0, dy: -1 } => print!("^"),
104 MyAction::Move { dx: 0, dy: 1 } => print!("v"),
105 _ => unreachable!(),
106 };
107 }
108 println!();
109 }
110}