ember_rl/algorithms/dqn/config.rs
/// Configuration for a DQN agent.
///
/// All hyperparameters live here. Pass this to `DqnAgent::new()`.
/// The defaults reflect standard DQN practice suitable for moderately
/// complex environments. Simple environments such as CartPole typically
/// want smaller buffer/warm-up values and faster epsilon decay.
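///
/// # Examples
///
/// A sketch of overriding the defaults for a simple environment such as
/// CartPole; the specific values are illustrative rather than tuned, and
/// the import path is assumed from this file's location:
///
/// ```
/// use ember_rl::algorithms::dqn::config::DqnConfig; // module path assumed
///
/// let config = DqnConfig {
///     buffer_capacity: 10_000,
///     min_replay_size: 1_000,
///     epsilon_decay_steps: 5_000,
///     ..DqnConfig::default()
/// };
/// assert!(config.min_replay_size >= config.batch_size);
/// ```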
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct DqnConfig {
    /// Discount factor γ. Controls how much future rewards are valued.
    /// Typical values: 0.95–0.999. Default: 0.99
    pub gamma: f64,

    /// Learning rate for the Adam optimiser.
    /// Default: 1e-4
    pub learning_rate: f64,

    /// Number of experiences sampled per gradient update.
    /// Default: 32
    pub batch_size: usize,

    /// Maximum number of experiences in the replay buffer.
    /// Oldest are overwritten when full.
    /// Default: 100_000
    pub buffer_capacity: usize,

    /// Minimum number of experiences collected before training begins.
    /// During warm-up, actions are sampled randomly.
    /// Must be >= batch_size. Default: 10_000
    pub min_replay_size: usize,

    /// Number of steps between hard target network updates.
    ///
    /// The target network is a frozen copy of the online network used to
    /// compute stable TD targets. Updating it too frequently causes
    /// instability; too rarely slows learning. Default: 1_000
    pub target_update_freq: usize,

    /// Hidden layer sizes for the Q-network.
    ///
    /// The network architecture is:
    /// `obs_size -> hidden[0] -> hidden[1] -> ... -> num_actions`
    /// All hidden layers use ReLU activations.
    /// Default: [128, 128]
    pub hidden_sizes: Vec<usize>,

    /// Starting epsilon for ε-greedy exploration.
    /// At step 0, actions are random with this probability.
    /// Default: 1.0
    pub epsilon_start: f64,

    /// Final epsilon after decay is complete.
    /// Default: 0.05
    pub epsilon_end: f64,

    /// Number of steps over which epsilon decays linearly from
    /// `epsilon_start` to `epsilon_end`. Default: 50_000
    pub epsilon_decay_steps: usize,
}

impl Default for DqnConfig {
    fn default() -> Self {
        Self {
            gamma: 0.99,
            learning_rate: 1e-4,
            batch_size: 32,
            buffer_capacity: 100_000,
            min_replay_size: 10_000,
            target_update_freq: 1_000,
            hidden_sizes: vec![128, 128],
            epsilon_start: 1.0,
            epsilon_end: 0.05,
            epsilon_decay_steps: 50_000,
        }
    }
}

impl DqnConfig {
    /// Compute the current epsilon given the number of elapsed steps.
    ///
    /// Decays linearly from `epsilon_start` to `epsilon_end` over
    /// `epsilon_decay_steps`, then stays flat.
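    ///
    /// # Examples
    ///
    /// A small sketch using the default schedule (1.0 down to 0.05 over
    /// 50_000 steps); the import path is assumed from this file's location:
    ///
    /// ```
    /// use ember_rl::algorithms::dqn::config::DqnConfig; // module path assumed
    ///
    /// let config = DqnConfig::default();
    /// // Halfway through the schedule: 1.0 + 0.5 * (0.05 - 1.0) = 0.525.
    /// assert!((config.epsilon_at(25_000) - 0.525).abs() < 1e-9);
    /// ```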
    pub fn epsilon_at(&self, step: usize) -> f64 {
        if step >= self.epsilon_decay_steps {
            return self.epsilon_end;
        }
        let progress = step as f64 / self.epsilon_decay_steps as f64;
        self.epsilon_start + progress * (self.epsilon_end - self.epsilon_start)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn epsilon_at_zero() {
        let config = DqnConfig::default();
        assert_eq!(config.epsilon_at(0), config.epsilon_start);
    }

    #[test]
    fn epsilon_at_end() {
        let config = DqnConfig::default();
        assert_eq!(config.epsilon_at(config.epsilon_decay_steps), config.epsilon_end);
    }

    #[test]
    fn epsilon_past_end_is_clamped() {
        let config = DqnConfig::default();
        assert_eq!(config.epsilon_at(config.epsilon_decay_steps * 10), config.epsilon_end);
    }

    #[test]
    fn epsilon_midpoint() {
        let config = DqnConfig::default();
        let mid = config.epsilon_decay_steps / 2;
        let expected = (config.epsilon_start + config.epsilon_end) / 2.0;
        let actual = config.epsilon_at(mid);
        assert!((actual - expected).abs() < 1e-6);
    }
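
    // A sketch of a round-trip check for the serde derives on `DqnConfig`;
    // it assumes `serde_json` is available as a dev-dependency.
    #[test]
    fn config_survives_serde_round_trip() {
        let config = DqnConfig::default();
        let json = serde_json::to_string(&config).expect("serialize");
        let restored: DqnConfig = serde_json::from_str(&json).expect("deserialize");
        assert_eq!(restored.batch_size, config.batch_size);
        assert_eq!(restored.hidden_sizes, config.hidden_sizes);
        assert!((restored.gamma - config.gamma).abs() < f64::EPSILON);
    }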
}