// ember_rl/algorithms/dqn/config.rs

/// Configuration for a DQN agent.
///
/// All hyperparameters live here. Pass this to `DqnAgent::new()`.
/// The defaults reflect standard DQN practice suitable for moderately
/// complex environments. Simple environments like CartPole will want
/// smaller buffer/warmup values and faster epsilon decay.
///
/// Every field is public and nothing in this type validates the values,
/// so relationships such as `min_replay_size >= batch_size` are the
/// caller's responsibility.
#[derive(Debug, Clone, PartialEq)]
pub struct DqnConfig {
    /// Discount factor γ. Controls how much future rewards are valued.
    /// Typical values: 0.95–0.999. Default: 0.99
    pub gamma: f64,

    /// Learning rate for the Adam optimiser.
    /// Default: 1e-4
    pub learning_rate: f64,

    /// Number of experiences sampled per gradient update.
    /// Default: 32
    pub batch_size: usize,

    /// Maximum number of experiences in the replay buffer.
    /// Oldest are overwritten when full.
    /// Default: 100_000
    pub buffer_capacity: usize,

    /// Minimum number of experiences collected before training begins.
    /// During warm-up, actions are sampled randomly.
    /// Must be >= `batch_size` (not checked by this struct). Default: 10_000
    pub min_replay_size: usize,

    /// Number of steps between hard target network updates.
    ///
    /// The target network is a frozen copy of the online network used to
    /// compute stable TD targets. Updating it too frequently causes
    /// instability; too rarely slows learning. Default: 1_000
    pub target_update_freq: usize,

    /// Hidden layer sizes for the Q-network.
    ///
    /// The network architecture is:
    /// `obs_size -> hidden[0] -> hidden[1] -> ... -> num_actions`
    /// All hidden layers use ReLU activations.
    /// Default: [128, 128]
    pub hidden_sizes: Vec<usize>,

    /// Starting epsilon for ε-greedy exploration.
    /// At step 0, actions are random with this probability.
    /// Default: 1.0
    pub epsilon_start: f64,

    /// Final epsilon after decay is complete.
    /// Default: 0.05
    pub epsilon_end: f64,

    /// Number of steps over which epsilon decays linearly from
    /// `epsilon_start` to `epsilon_end`. Default: 50_000
    pub epsilon_decay_steps: usize,
}
59
60impl Default for DqnConfig {
61    fn default() -> Self {
62        Self {
63            gamma: 0.99,
64            learning_rate: 1e-4,
65            batch_size: 32,
66            buffer_capacity: 100_000,
67            min_replay_size: 10_000,
68            target_update_freq: 1_000,
69            hidden_sizes: vec![128, 128],
70            epsilon_start: 1.0,
71            epsilon_end: 0.05,
72            epsilon_decay_steps: 50_000,
73        }
74    }
75}
76
77impl DqnConfig {
78    /// Compute the current epsilon given the number of elapsed steps.
79    ///
80    /// Decays linearly from `epsilon_start` to `epsilon_end` over
81    /// `epsilon_decay_steps`, then stays flat.
82    pub fn epsilon_at(&self, step: usize) -> f64 {
83        if step >= self.epsilon_decay_steps {
84            return self.epsilon_end;
85        }
86        let progress = step as f64 / self.epsilon_decay_steps as f64;
87        self.epsilon_start + progress * (self.epsilon_end - self.epsilon_start)
88    }
89}
90
/// Unit tests for the epsilon schedule exposed by `DqnConfig::epsilon_at`.
#[cfg(test)]
mod tests {
    use super::*;

    /// Step 0 returns the starting epsilon exactly.
    #[test]
    fn epsilon_at_zero() {
        let config = DqnConfig::default();
        assert_eq!(config.epsilon_at(0), config.epsilon_start);
    }

    /// The last step of the schedule lands exactly on the final epsilon.
    #[test]
    fn epsilon_at_end() {
        let config = DqnConfig::default();
        assert_eq!(config.epsilon_at(config.epsilon_decay_steps), config.epsilon_end);
    }

    /// Steps beyond the schedule stay at the final epsilon.
    #[test]
    fn epsilon_past_end_is_clamped() {
        let config = DqnConfig::default();
        assert_eq!(config.epsilon_at(config.epsilon_decay_steps * 10), config.epsilon_end);
    }

    /// Halfway through the schedule epsilon is the mean of start and end.
    #[test]
    fn epsilon_midpoint() {
        let config = DqnConfig::default();
        let mid = config.epsilon_decay_steps / 2;
        let expected = (config.epsilon_start + config.epsilon_end) / 2.0;
        let actual = config.epsilon_at(mid);
        assert!((actual - expected).abs() < 1e-6);
    }

    /// A zero-length schedule yields the final epsilon from step 0 and
    /// never divides by zero.
    #[test]
    fn epsilon_with_zero_decay_steps_is_end() {
        let config = DqnConfig {
            epsilon_decay_steps: 0,
            ..DqnConfig::default()
        };
        assert_eq!(config.epsilon_at(0), config.epsilon_end);
        assert_eq!(config.epsilon_at(1), config.epsilon_end);
    }
}