border_tch_agent/dqn/
explorer.rs

//! Exploration strategies for DQN.
2use std::convert::TryInto;
3
4use serde::{Deserialize, Serialize};
5use tch::Tensor;
6
7/// Explorers for DQN.
8#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
9pub enum DqnExplorer {
10    /// Softmax action selection.
11    Softmax(Softmax),
12
13    /// Epsilon-greedy action selection.
14    EpsilonGreedy(EpsilonGreedy),
15}
16
17/// Softmax explorer for DQN.
18#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
19pub struct Softmax {}
20
21#[allow(clippy::new_without_default)]
22impl Softmax {
23    /// Constructs softmax explorer.
24    pub fn new() -> Self {
25        Self {}
26    }
27
28    /// Takes an action based on the observation and the critic.
29    pub fn action(&mut self, a: &Tensor) -> Tensor {
30        a.softmax(-1, tch::Kind::Float).multinomial(1, true)
31    }
32}
33
34/// Epsilon-greedy explorer for DQN.
35#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
36pub struct EpsilonGreedy {
37    pub n_opts: usize,
38    pub eps_start: f64,
39    pub eps_final: f64,
40    pub final_step: usize,
41}
42
43#[allow(clippy::new_without_default)]
44impl EpsilonGreedy {
45    /// Constructs epsilon-greedy explorer.
46    pub fn new() -> Self {
47        Self {
48            n_opts: 0,
49            eps_start: 1.0,
50            eps_final: 0.02,
51            final_step: 100_000,
52        }
53    }
54
55    /// Constructs epsilon-greedy explorer.
56    ///
57    /// TODO: improve interface.
58    pub fn with_final_step(final_step: usize) -> DqnExplorer {
59        DqnExplorer::EpsilonGreedy(Self {
60            n_opts: 0,
61            eps_start: 1.0,
62            eps_final: 0.02,
63            final_step,
64        })
65    }
66
67    /// Takes an action based on the observation and the critic.
68    pub fn action(&mut self, a: &Tensor) -> Tensor {
69        let d = (self.eps_start - self.eps_final) / (self.final_step as f64);
70        let eps = (self.eps_start - d * self.n_opts as f64).max(self.eps_final);
71        let r = fastrand::f64();
72        let is_random = r < eps;
73        self.n_opts += 1;
74
75        let best = a.argmax(-1, true);
76
77        if is_random {
78            let n_procs = a.size()[0] as u32;
79            let n_actions = a.size()[1] as u32;
80            let act = Tensor::from_slice(
81                (0..n_procs)
82                    .map(|_| fastrand::u32(..n_actions) as i32)
83                    .collect::<Vec<_>>()
84                    .as_slice(),
85            );
86            act
87        } else {
88            best
89        }
90    }
91
92    /// Takes an action based on the observation and the critic.
93    pub fn action_with_best(&mut self, a: &Tensor) -> (Tensor, bool) {
94        let d = (self.eps_start - self.eps_final) / (self.final_step as f64);
95        let eps = (self.eps_start - d * self.n_opts as f64).max(self.eps_final);
96        let r = fastrand::f64();
97        let is_random = r < eps;
98        self.n_opts += 1;
99
100        let best = a.argmax(-1, true);
101
102        if is_random {
103            let n_procs = a.size()[0] as u32;
104            let n_actions = a.size()[1] as u32;
105            let act = Tensor::from_slice(
106                (0..n_procs)
107                    .map(|_| fastrand::u32(..n_actions) as i32)
108                    .collect::<Vec<_>>()
109                    .as_slice(),
110            );
111            let diff: i64 = (&act - &best.to(tch::Device::Cpu))
112                .abs()
113                .sum(tch::Kind::Int64)
114                .try_into()
115                .unwrap();
116            (act, diff == 0)
117        } else {
118            (best, true)
119        }
120    }
121
122    /// Set the epsilon value at the final step.
123    pub fn eps_final(self, v: f64) -> Self {
124        let mut s = self;
125        s.eps_final = v;
126        s
127    }
128
129    /// Set the epsilon value at the start.
130    pub fn eps_start(self, v: f64) -> Self {
131        let mut s = self;
132        s.eps_start = v;
133        s
134    }
135}