//! PPO agent implementation (`ember_rl/algorithms/ppo/agent.rs`).

1use std::collections::HashMap;
2use std::collections::VecDeque;
3use std::path::Path;
4
5use burn::prelude::*;
6use burn::tensor::backend::AutodiffBackend;
7use burn::module::AutodiffModule;
8use burn::optim::{Adam, AdamConfig, GradientsParams, Optimizer};
9use burn::optim::adaptor::OptimizerAdaptor;
10use burn::record::CompactRecorder;
11use rand::{SeedableRng};
12use rand::rngs::SmallRng;
13use rl_traits::{Environment, Experience};
14
15use crate::encoding::{DiscreteActionMapper, ObservationEncoder};
16use crate::stats::{Aggregator, Mean};
17use crate::traits::{ActMode, Checkpointable, LearningAgent};
18use super::config::PpoConfig;
19use super::network::ActorCriticNetwork;
20use super::rollout::{RolloutBuffer, Transition};
21
/// A PPO agent with discrete actions.
///
/// Implements clipped-surrogate PPO with GAE advantage estimation and an
/// actor-critic network. Generic over environment, encoder, action mapper,
/// and Burn backend -- the same pattern as `DqnAgent`.
///
/// **On-policy**: experience is collected into a rollout buffer, used for
/// `n_epochs` gradient passes, then discarded. The buffer size is
/// `n_steps * n_envs`; an update fires automatically when it fills.
///
/// # Parallel environments
///
/// Set `PpoConfig::n_envs` to match the number of environments feeding
/// this agent (e.g. `bevy-gym`'s `NUM_ENVS`). All envs contribute to the
/// same rollout buffer; the update fires after `n_steps` ticks.
pub struct PpoAgent<E, Enc, Act, B>
where
    E: Environment,
    B: AutodiffBackend,
{
    // Combined actor-critic network: shared trunk, policy logits + value head.
    model: ActorCriticNetwork<B>,
    // Adam optimiser wrapping the network's parameters.
    optimiser: OptimizerAdaptor<Adam, ActorCriticNetwork<B>, B>,

    // On-policy buffer of transitions; cleared at the end of every update.
    rollout: RolloutBuffer<E::Observation>,

    // Converts environment observations into input tensors.
    encoder: Enc,
    // Converts between environment actions and discrete indices.
    action_mapper: Act,

    config: PpoConfig,
    device: B::Device,
    // Lifetime count of observe() calls.
    total_steps: usize,

    // FIFO cache populated by act() and consumed by observe().
    // Each act() call pushes (log_prob, value); observe() pops one entry.
    // This works because bevy-gym (and DqnTrainer) call act then observe
    // in matched pairs within the same tick / loop iteration.
    pending: VecDeque<(f32, f32)>,

    // Seeded RNG used for shuffling minibatches during updates.
    update_rng: SmallRng,

    // Running loss/entropy/KL stats.
    // NOTE(review): these are reset in on_episode_start(), not after each
    // PPO update, so reported values accumulate across updates within an
    // episode — confirm which granularity is intended.
    ep_policy_loss: Mean,
    ep_value_loss: Mean,
    ep_entropy: Mean,
    ep_approx_kl: Mean,

    _env: std::marker::PhantomData<E>,
}
70
71impl<E, Enc, Act, B> PpoAgent<E, Enc, Act, B>
72where
73    E: Environment,
74    E::Observation: Clone + Send + Sync + 'static,
75    E::Action: Clone + Send + Sync + 'static,
76    Enc: ObservationEncoder<E::Observation, B>
77        + ObservationEncoder<E::Observation, B::InnerBackend>,
78    Act: DiscreteActionMapper<E::Action>,
79    B: AutodiffBackend,
80{
81    pub fn new(encoder: Enc, action_mapper: Act, config: PpoConfig, device: B::Device, seed: u64) -> Self {
82        let obs_size = <Enc as ObservationEncoder<E::Observation, B>>::obs_size(&encoder);
83        let n_actions = action_mapper.num_actions();
84
85        let mut layer_sizes = vec![obs_size];
86        layer_sizes.extend_from_slice(&config.hidden_sizes);
87
88        let model = ActorCriticNetwork::new(&layer_sizes, n_actions, &device);
89        let optimiser = AdamConfig::new().with_epsilon(1e-5).init::<B, ActorCriticNetwork<B>>();
90
91        let rollout_size = config.rollout_size();
92
93        Self {
94            model,
95            optimiser,
96            rollout: RolloutBuffer::new(rollout_size),
97            encoder,
98            action_mapper,
99            config,
100            device,
101            total_steps: 0,
102            pending: VecDeque::new(),
103            update_rng: SmallRng::seed_from_u64(seed),
104            ep_policy_loss: Mean::default(),
105            ep_value_loss: Mean::default(),
106            ep_entropy: Mean::default(),
107            ep_approx_kl: Mean::default(),
108            _env: std::marker::PhantomData,
109        }
110    }
111
112    // ---- Forward pass (no grad) ----
113
114    fn forward_inference(
115        &self,
116        obs: &E::Observation,
117    ) -> (usize, f32, f32) {
118        // Use valid() to avoid tracking gradients during collection.
119        let obs_t = ObservationEncoder::<E::Observation, B::InnerBackend>::encode(
120            &self.encoder, obs, &self.device,
121        ).unsqueeze_dim(0);
122
123        let model_valid = self.model.valid();
124        let (logits, value) = model_valid.forward(obs_t);
125
126        // Softmax -> categorical sample
127        let probs = burn::tensor::activation::softmax(logits.clone(), 1); // [1, n_actions]
128        let probs_vec: Vec<f32> = probs
129            .squeeze::<1>()
130            .into_data()
131            .to_vec::<f32>()
132            .unwrap();
133
134        // Sample action proportional to probabilities
135        let action_idx = sample_categorical(&probs_vec, &mut rand::thread_rng());
136
137        let log_prob = probs_vec[action_idx].ln();
138        let v: f32 = value.into_data().to_vec::<f32>().unwrap()[0];
139
140        (action_idx, log_prob, v)
141    }
142
143    fn forward_inference_greedy(&self, obs: &E::Observation) -> usize {
144        let obs_t = ObservationEncoder::<E::Observation, B::InnerBackend>::encode(
145            &self.encoder, obs, &self.device,
146        ).unsqueeze_dim(0);
147
148        let (logits, _) = self.model.valid().forward(obs_t);
149        logits
150            .argmax(1)
151            .into_data()
152            .to_vec::<i64>()
153            .unwrap()[0] as usize
154    }
155
156    // ---- Bootstrap value for GAE ----
157
158    fn compute_bootstrap_value(&self, next_obs: &E::Observation) -> f32 {
159        let obs_t = ObservationEncoder::<E::Observation, B::InnerBackend>::encode(
160            &self.encoder, next_obs, &self.device,
161        ).unsqueeze_dim(0);
162        let (_, value) = self.model.valid().forward(obs_t);
163        value.into_data().to_vec::<f32>().unwrap()[0]
164    }
165
166    // ---- PPO update ----
167
168    fn run_update(&mut self, bootstrap_value: f32) {
169        let gamma = self.config.gamma as f32;
170        let lambda = self.config.gae_lambda as f32;
171        self.rollout.compute_gae(bootstrap_value, gamma, lambda);
172        self.rollout.normalize_advantages();
173
174        for _ in 0..self.config.n_epochs {
175            let batches = self.rollout.minibatches(self.config.batch_size, &mut self.update_rng);
176            for batch in batches {
177                let (pl, vl, ent, kl) = self.update_step(&batch);
178                self.ep_policy_loss.update(pl);
179                self.ep_value_loss.update(vl);
180                self.ep_entropy.update(ent);
181                self.ep_approx_kl.update(kl);
182            }
183        }
184
185        self.rollout.clear();
186    }
187
188    fn update_step(
189        &mut self,
190        batch: &super::rollout::Batch<E::Observation>,
191    ) -> (f64, f64, f64, f64) {
192        let bs = batch.obs.len();
193        let obs_t = self.encoder.encode_batch(&batch.obs, &self.device);
194
195        let (logits, values) = self.model.forward(obs_t);
196
197        // Probabilities and log-probabilities
198        let probs = burn::tensor::activation::softmax(logits, 1);  // [bs, n_actions]
199        let log_probs_all = probs.clone().log();                    // [bs, n_actions]
200
201        // Gather log_prob for the taken actions
202        let action_idx_t = Tensor::<B, 1, Int>::from_ints(
203            batch.actions.iter().map(|&a| a as i32).collect::<Vec<_>>().as_slice(),
204            &self.device,
205        );
206        let new_log_probs = log_probs_all.clone()
207            .gather(1, action_idx_t.reshape([bs, 1]))
208            .squeeze::<1>(); // [bs]
209
210        // Old log probs (from collection time)
211        let old_log_probs_t: Tensor<B, 1> = Tensor::from_floats(
212            batch.old_log_probs.as_slice(), &self.device,
213        );
214
215        // Advantages and returns
216        let advantages_t: Tensor<B, 1> = Tensor::from_floats(
217            batch.advantages.as_slice(), &self.device,
218        );
219        let returns_t: Tensor<B, 1> = Tensor::from_floats(
220            batch.returns.as_slice(), &self.device,
221        );
222
223        // PPO clipped surrogate loss
224        let ratio = (new_log_probs.clone() - old_log_probs_t.clone().detach()).exp();
225        let clip_eps = self.config.clip_epsilon as f32;
226        let surr1 = ratio.clone() * advantages_t.clone().detach();
227        let surr2 = ratio.clone().clamp(1.0 - clip_eps, 1.0 + clip_eps)
228            * advantages_t.detach();
229        let policy_loss = -surr1.min_pair(surr2).mean();
230
231        // Value loss (MSE)
232        let diff = values - returns_t.detach();
233        let value_loss = (diff.clone() * diff).mean()
234            * self.config.value_loss_coef as f32;
235
236        // Entropy bonus
237        let entropy = -(probs.clone() * log_probs_all.clone()).sum_dim(1).mean();
238        let entropy_loss = entropy.clone() * (-self.config.entropy_coef as f32);
239
240        let total_loss = policy_loss.clone() + value_loss.clone() + entropy_loss;
241
242        // Scalar stats (detach before backward)
243        let pl_val = policy_loss.clone().into_scalar().elem::<f64>();
244        let vl_val = value_loss.clone().into_scalar().elem::<f64>();
245        let ent_val = entropy.clone().into_scalar().elem::<f64>();
246
247        // Approximate KL (for monitoring)
248        let approx_kl = (old_log_probs_t.detach() - new_log_probs.detach())
249            .mean()
250            .into_scalar()
251            .elem::<f64>();
252
253        let grads = total_loss.backward();
254        let grads = GradientsParams::from_grads(grads, &self.model);
255        self.model = self.optimiser.step(self.config.learning_rate, self.model.clone(), grads);
256
257        (pl_val, vl_val, ent_val, approx_kl)
258    }
259}
260
261// ---- Checkpointable ----
262
263impl<E, Enc, Act, B> Checkpointable for PpoAgent<E, Enc, Act, B>
264where
265    E: Environment,
266    E::Observation: Clone + Send + Sync + 'static,
267    E::Action: Clone + Send + Sync + 'static,
268    Enc: ObservationEncoder<E::Observation, B> + ObservationEncoder<E::Observation, B::InnerBackend>,
269    Act: DiscreteActionMapper<E::Action>,
270    B: AutodiffBackend,
271{
272    fn save(&self, path: &Path) -> anyhow::Result<()> {
273        self.model
274            .clone()
275            .save_file(path.to_path_buf(), &CompactRecorder::new())
276            .map(|_| ())
277            .map_err(|e| anyhow::anyhow!(e))
278    }
279
280    fn load(mut self, path: &Path) -> anyhow::Result<Self> {
281        self.model = self.model
282            .load_file(path.to_path_buf(), &CompactRecorder::new(), &self.device)
283            .map_err(|e| anyhow::anyhow!(e))?;
284        Ok(self)
285    }
286}
287
288// ---- LearningAgent ----
289
290impl<E, Enc, Act, B> LearningAgent<E> for PpoAgent<E, Enc, Act, B>
291where
292    E: Environment,
293    E::Observation: Clone + Send + Sync + 'static,
294    E::Action: Clone + Send + Sync + 'static,
295    Enc: ObservationEncoder<E::Observation, B> + ObservationEncoder<E::Observation, B::InnerBackend>,
296    Act: DiscreteActionMapper<E::Action>,
297    B: AutodiffBackend,
298{
299    fn act(&mut self, obs: &E::Observation, mode: ActMode) -> E::Action {
300        match mode {
301            ActMode::Exploit => {
302                let idx = self.forward_inference_greedy(obs);
303                // Push dummy pending entry so observe() can safely pop
304                self.pending.push_back((0.0, 0.0));
305                self.action_mapper.index_to_action(idx)
306            }
307            ActMode::Explore => {
308                let (idx, log_prob, value) = self.forward_inference(obs);
309                self.pending.push_back((log_prob, value));
310                self.action_mapper.index_to_action(idx)
311            }
312        }
313    }
314
315    fn observe(&mut self, experience: Experience<E::Observation, E::Action>) {
316        self.total_steps += 1;
317
318        let (log_prob, value) = self.pending.pop_front().unwrap_or((0.0, 0.0));
319        let action = self.action_mapper.action_to_index(&experience.action);
320        let done = !matches!(experience.status, rl_traits::EpisodeStatus::Continuing);
321
322        self.rollout.push(Transition {
323            obs: experience.observation,
324            action,
325            reward: experience.reward as f32,
326            done,
327            value,
328            log_prob,
329        });
330
331        if self.rollout.is_full() {
332            // Bootstrap: if the last step ended an episode, future value is 0;
333            // otherwise estimate it from the next observation.
334            let bootstrap = if self.rollout.last_done() {
335                0.0
336            } else {
337                self.compute_bootstrap_value(&experience.next_observation)
338            };
339            self.run_update(bootstrap);
340        }
341    }
342
343    fn total_steps(&self) -> usize {
344        self.total_steps
345    }
346
347    fn episode_extras(&self) -> HashMap<String, f64> {
348        [
349            ("policy_loss".to_string(), self.ep_policy_loss.value()),
350            ("value_loss".to_string(),  self.ep_value_loss.value()),
351            ("entropy".to_string(),     self.ep_entropy.value()),
352            ("approx_kl".to_string(),   self.ep_approx_kl.value()),
353        ]
354        .into_iter()
355        .filter(|(_, v)| v.is_finite())
356        .collect()
357    }
358
359    fn on_episode_start(&mut self) {
360        self.ep_policy_loss.reset();
361        self.ep_value_loss.reset();
362        self.ep_entropy.reset();
363        self.ep_approx_kl.reset();
364    }
365}
366
367// ---- Categorical sampling ----
368
369fn sample_categorical(probs: &[f32], rng: &mut impl rand::Rng) -> usize {
370    let r: f32 = rng.gen();
371    let mut cumsum = 0.0f32;
372    for (i, &p) in probs.iter().enumerate() {
373        cumsum += p;
374        if r < cumsum {
375            return i;
376        }
377    }
378    probs.len() - 1
379}