border_tch_agent/iqn/
explorer.rs1use serde::{Deserialize, Serialize};
3use std::default::Default;
4use tch::Tensor;
5
6#[allow(clippy::upper_case_acronyms)]
7#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
8pub enum IqnExplorer {
10 Softmax(Softmax),
12 EpsilonGreedy(EpsilonGreedy),
14}
15
16#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
18pub struct Softmax {}
19
20#[allow(clippy::new_without_default)]
21impl Softmax {
22 pub fn new() -> Self {
24 Self {}
25 }
26
27 pub fn action(&mut self, a: &Tensor) -> Tensor {
29 a.softmax(-1, tch::Kind::Float).multinomial(1, true)
30 }
31}
32
33#[derive(Debug, Deserialize, Serialize, PartialEq, Clone)]
34pub struct EpsilonGreedy {
36 n_opts: usize,
37 eps_start: f64,
38 eps_final: f64,
39 final_step: usize,
40}
41
42impl Default for EpsilonGreedy {
43 fn default() -> Self {
44 Self {
45 n_opts: 0,
46 eps_start: 1.0,
47 eps_final: 0.02,
48 final_step: 100_000,
49 }
50 }
51}
52
53#[allow(clippy::new_without_default)]
54impl EpsilonGreedy {
55 pub fn with_params(eps_start: f64, eps_final: f64, final_step: usize) -> IqnExplorer {
57 IqnExplorer::EpsilonGreedy(Self {
58 n_opts: 0,
59 eps_start,
60 eps_final,
61 final_step,
62 })
63 }
64
65 pub fn with_final_step(final_step: usize) -> IqnExplorer {
69 IqnExplorer::EpsilonGreedy(Self {
70 n_opts: 0,
71 eps_start: 1.0,
72 eps_final: 0.02,
73 final_step,
74 })
75 }
76
77 pub fn action(&mut self, action_value: Tensor) -> Tensor {
79 let d = (self.eps_start - self.eps_final) / (self.final_step as f64);
80 let eps = (self.eps_start - d * self.n_opts as f64).max(self.eps_final);
81 let r = fastrand::f64();
82 let is_random = r < eps;
83 self.n_opts += 1;
84
85 if is_random {
86 let batch_size = action_value.size()[0];
87 let n_actions = action_value.size()[1] as u32;
88 Tensor::from_slice(
89 (0..batch_size)
90 .map(|_| fastrand::u32(..n_actions) as i32)
91 .collect::<Vec<_>>()
92 .as_slice(),
93 )
94 } else {
95 action_value.argmax(-1, true)
96 }
97 }
98}