border_tch_agent/iqn/
explorer.rs

1//! Exploration strategies of IQN.
2use serde::{Deserialize, Serialize};
3use std::default::Default;
4use tch::Tensor;
5
6#[allow(clippy::upper_case_acronyms)]
7#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
8/// Explorers for IQN.
9pub enum IqnExplorer {
10    /// Softmax action selection.
11    Softmax(Softmax),
12    /// Epsilon-greedy action selection.
13    EpsilonGreedy(EpsilonGreedy),
14}
15
16/// Softmax explorer for IQN.
17#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
18pub struct Softmax {}
19
20#[allow(clippy::new_without_default)]
21impl Softmax {
22    /// Constructs softmax explorer.
23    pub fn new() -> Self {
24        Self {}
25    }
26
27    /// Takes an action based on the observation and the critic.
28    pub fn action(&mut self, a: &Tensor) -> Tensor {
29        a.softmax(-1, tch::Kind::Float).multinomial(1, true)
30    }
31}
32
33#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
34/// Epsilon-greedy explorer for IQN.
35pub struct EpsilonGreedy {
36    n_opts: usize,
37    eps_start: f64,
38    eps_final: f64,
39    final_step: usize,
40}
41
42impl Default for EpsilonGreedy {
43    fn default() -> Self {
44        Self {
45            n_opts: 0,
46            eps_start: 1.0,
47            eps_final: 0.02,
48            final_step: 100_000,
49        }
50    }
51}
52
53#[allow(clippy::new_without_default)]
54impl EpsilonGreedy {
55    /// Constructs epsilon-greedy explorer.
56    pub fn with_params(eps_start: f64, eps_final: f64, final_step: usize) -> IqnExplorer {
57        IqnExplorer::EpsilonGreedy(Self {
58            n_opts: 0,
59            eps_start,
60            eps_final,
61            final_step,
62        })
63    }
64
65    /// Constructs epsilon-greedy explorer.
66    ///
67    /// TODO: improve interface.
68    pub fn with_final_step(final_step: usize) -> IqnExplorer {
69        IqnExplorer::EpsilonGreedy(Self {
70            n_opts: 0,
71            eps_start: 1.0,
72            eps_final: 0.02,
73            final_step,
74        })
75    }
76
77    /// Takes an action based on the observation and the critic.
78    pub fn action(&mut self, action_value: Tensor) -> Tensor {
79        let d = (self.eps_start - self.eps_final) / (self.final_step as f64);
80        let eps = (self.eps_start - d * self.n_opts as f64).max(self.eps_final);
81        let r = fastrand::f64();
82        let is_random = r < eps;
83        self.n_opts += 1;
84
85        if is_random {
86            let batch_size = action_value.size()[0];
87            let n_actions = action_value.size()[1] as u32;
88            Tensor::from_slice(
89                (0..batch_size)
90                    .map(|_| fastrand::u32(..n_actions) as i32)
91                    .collect::<Vec<_>>()
92                    .as_slice(),
93            )
94        } else {
95            action_value.argmax(-1, true)
96        }
97    }
98}