border_tch_agent/dqn/
explorer.rs1use std::convert::TryInto;
3
4use serde::{Deserialize, Serialize};
5use tch::Tensor;
6
7#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
9pub enum DqnExplorer {
10 Softmax(Softmax),
12
13 EpsilonGreedy(EpsilonGreedy),
15}
16
17#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
19pub struct Softmax {}
20
21#[allow(clippy::new_without_default)]
22impl Softmax {
23 pub fn new() -> Self {
25 Self {}
26 }
27
28 pub fn action(&mut self, a: &Tensor) -> Tensor {
30 a.softmax(-1, tch::Kind::Float).multinomial(1, true)
31 }
32}
33
34#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
36pub struct EpsilonGreedy {
37 pub n_opts: usize,
38 pub eps_start: f64,
39 pub eps_final: f64,
40 pub final_step: usize,
41}
42
43#[allow(clippy::new_without_default)]
44impl EpsilonGreedy {
45 pub fn new() -> Self {
47 Self {
48 n_opts: 0,
49 eps_start: 1.0,
50 eps_final: 0.02,
51 final_step: 100_000,
52 }
53 }
54
55 pub fn with_final_step(final_step: usize) -> DqnExplorer {
59 DqnExplorer::EpsilonGreedy(Self {
60 n_opts: 0,
61 eps_start: 1.0,
62 eps_final: 0.02,
63 final_step,
64 })
65 }
66
67 pub fn action(&mut self, a: &Tensor) -> Tensor {
69 let d = (self.eps_start - self.eps_final) / (self.final_step as f64);
70 let eps = (self.eps_start - d * self.n_opts as f64).max(self.eps_final);
71 let r = fastrand::f64();
72 let is_random = r < eps;
73 self.n_opts += 1;
74
75 let best = a.argmax(-1, true);
76
77 if is_random {
78 let n_procs = a.size()[0] as u32;
79 let n_actions = a.size()[1] as u32;
80 let act = Tensor::from_slice(
81 (0..n_procs)
82 .map(|_| fastrand::u32(..n_actions) as i32)
83 .collect::<Vec<_>>()
84 .as_slice(),
85 );
86 act
87 } else {
88 best
89 }
90 }
91
92 pub fn action_with_best(&mut self, a: &Tensor) -> (Tensor, bool) {
94 let d = (self.eps_start - self.eps_final) / (self.final_step as f64);
95 let eps = (self.eps_start - d * self.n_opts as f64).max(self.eps_final);
96 let r = fastrand::f64();
97 let is_random = r < eps;
98 self.n_opts += 1;
99
100 let best = a.argmax(-1, true);
101
102 if is_random {
103 let n_procs = a.size()[0] as u32;
104 let n_actions = a.size()[1] as u32;
105 let act = Tensor::from_slice(
106 (0..n_procs)
107 .map(|_| fastrand::u32(..n_actions) as i32)
108 .collect::<Vec<_>>()
109 .as_slice(),
110 );
111 let diff: i64 = (&act - &best.to(tch::Device::Cpu))
112 .abs()
113 .sum(tch::Kind::Int64)
114 .try_into()
115 .unwrap();
116 (act, diff == 0)
117 } else {
118 (best, true)
119 }
120 }
121
122 pub fn eps_final(self, v: f64) -> Self {
124 let mut s = self;
125 s.eps_final = v;
126 s
127 }
128
129 pub fn eps_start(self, v: f64) -> Self {
131 let mut s = self;
132 s.eps_start = v;
133 s
134 }
135}