Skip to main content

ember_rl/algorithms/ppo/
config.rs

1/// Hyperparameters for the PPO algorithm.
2#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
3pub struct PpoConfig {
4    // -- Rollout collection --
5    /// Steps collected per environment before each update.
6    /// Total rollout size = n_steps * n_envs.
7    pub n_steps: usize,
8
9    /// Number of parallel environments feeding this agent.
10    /// Must match the number of envs in your BevyGymPlugin / training loop.
11    pub n_envs: usize,
12
13    // -- Update --
14    /// Number of gradient epochs over each rollout. Typical: 4-10.
15    pub n_epochs: usize,
16
17    /// Minibatch size for each gradient step. Must divide n_steps * n_envs evenly.
18    pub batch_size: usize,
19
20    pub learning_rate: f64,
21
22    // -- PPO objective --
23    /// Clipping range for the probability ratio. Typical: 0.1-0.3.
24    pub clip_epsilon: f64,
25
26    /// Weight on the value function loss. Typical: 0.5.
27    pub value_loss_coef: f64,
28
29    /// Weight on the entropy bonus (encourages exploration). Typical: 0.01.
30    pub entropy_coef: f64,
31
32    // -- Returns / advantages --
33    /// Discount factor.
34    pub gamma: f64,
35
36    /// GAE smoothing parameter. 1.0 = full Monte Carlo, 0.0 = TD(0).
37    pub gae_lambda: f64,
38
39    // -- Network --
40    /// Hidden layer sizes for the shared trunk.
41    pub hidden_sizes: Vec<usize>,
42
43    /// Clip gradient norm to this value. Set to 0.0 to disable.
44    pub max_grad_norm: f64,
45}
46
47impl Default for PpoConfig {
48    fn default() -> Self {
49        Self {
50            n_steps: 128,
51            n_envs: 1,
52            n_epochs: 4,
53            batch_size: 64,
54            learning_rate: 2.5e-4,
55            clip_epsilon: 0.2,
56            value_loss_coef: 0.5,
57            entropy_coef: 0.01,
58            gamma: 0.99,
59            gae_lambda: 0.95,
60            hidden_sizes: vec![64, 64],
61            max_grad_norm: 0.5,
62        }
63    }
64}
65
66impl PpoConfig {
67    /// Total number of transitions collected before each update.
68    pub fn rollout_size(&self) -> usize {
69        self.n_steps * self.n_envs
70    }
71}