use std::collections::HashMap;
use std::collections::VecDeque;
use std::path::Path;

use burn::prelude::*;
use burn::tensor::backend::AutodiffBackend;
use burn::module::AutodiffModule;
use burn::optim::{Adam, AdamConfig, GradientsParams, Optimizer};
use burn::optim::adaptor::OptimizerAdaptor;
use burn::record::CompactRecorder;
use rand::SeedableRng;
use rand::rngs::SmallRng;
use rl_traits::{Environment, Experience};

use crate::encoding::{DiscreteActionMapper, ObservationEncoder};
use crate::stats::{Aggregator, Mean};
use crate::traits::{ActMode, Checkpointable, LearningAgent};
use super::config::PpoConfig;
use super::network::ActorCriticNetwork;
use super::rollout::{RolloutBuffer, Transition};

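/// On-policy PPO agent over a discrete action space, built on a shared
/// actor-critic network.
///
/// `act` samples (or greedily selects) an action and queues the step's
/// log-probability and value estimate; `observe` completes the transition,
/// pushes it into the rollout buffer, and triggers a PPO update once the
/// buffer is full.
///
/// A minimal construction sketch; `MyEnv`, `MyEncoder`, `MyMapper`, and
/// `my_config` are illustrative placeholders, not items from this crate:
///
/// ```ignore
/// let agent: PpoAgent<MyEnv, MyEncoder, MyMapper, MyBackend> =
///     PpoAgent::new(my_encoder, my_mapper, my_config, device, 42);
/// ```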
pub struct PpoAgent<E, Enc, Act, B>
where
    E: Environment,
    B: AutodiffBackend,
{
    model: ActorCriticNetwork<B>,
    optimiser: OptimizerAdaptor<Adam, ActorCriticNetwork<B>, B>,

    rollout: RolloutBuffer<E::Observation>,

    encoder: Enc,
    action_mapper: Act,

    config: PpoConfig,
    device: B::Device,
    total_steps: usize,

    // (log_prob, value) pairs recorded by `act`, consumed in FIFO order by
    // the matching `observe` call.
    pending: VecDeque<(f32, f32)>,

    update_rng: SmallRng,

    // Per-episode running means of the PPO diagnostics.
    ep_policy_loss: Mean,
    ep_value_loss: Mean,
    ep_entropy: Mean,
    ep_approx_kl: Mean,

    _env: std::marker::PhantomData<E>,
}

impl<E, Enc, Act, B> PpoAgent<E, Enc, Act, B>
where
    E: Environment,
    E::Observation: Clone + Send + Sync + 'static,
    E::Action: Clone + Send + Sync + 'static,
    Enc: ObservationEncoder<E::Observation, B>
        + ObservationEncoder<E::Observation, B::InnerBackend>,
    Act: DiscreteActionMapper<E::Action>,
    B: AutodiffBackend,
{
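    /// Builds the actor-critic network (input sized by the encoder, output
    /// head sized by the action mapper) and an Adam optimiser with eps = 1e-5.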
    pub fn new(
        encoder: Enc,
        action_mapper: Act,
        config: PpoConfig,
        device: B::Device,
        seed: u64,
    ) -> Self {
        let obs_size = <Enc as ObservationEncoder<E::Observation, B>>::obs_size(&encoder);
        let n_actions = action_mapper.num_actions();

        let mut layer_sizes = vec![obs_size];
        layer_sizes.extend_from_slice(&config.hidden_sizes);

        let model = ActorCriticNetwork::new(&layer_sizes, n_actions, &device);
        let optimiser = AdamConfig::new()
            .with_epsilon(1e-5)
            .init::<B, ActorCriticNetwork<B>>();

        let rollout_size = config.rollout_size();

        Self {
            model,
            optimiser,
            rollout: RolloutBuffer::new(rollout_size),
            encoder,
            action_mapper,
            config,
            device,
            total_steps: 0,
            pending: VecDeque::new(),
            update_rng: SmallRng::seed_from_u64(seed),
            ep_policy_loss: Mean::default(),
            ep_value_loss: Mean::default(),
            ep_entropy: Mean::default(),
            ep_approx_kl: Mean::default(),
            _env: std::marker::PhantomData,
        }
    }

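    /// Stochastic policy query on the non-autodiff backend. Returns
    /// `(action_index, log_prob, value_estimate)` for storage in the rollout.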
    fn forward_inference(&self, obs: &E::Observation) -> (usize, f32, f32) {
        // Encode on the inner (non-autodiff) backend and add a batch dim.
        let obs_t = ObservationEncoder::<E::Observation, B::InnerBackend>::encode(
            &self.encoder, obs, &self.device,
        ).unsqueeze_dim(0);

        let model_valid = self.model.valid();
        let (logits, value) = model_valid.forward(obs_t);

        let probs = burn::tensor::activation::softmax(logits, 1);
        let probs_vec: Vec<f32> = probs
            .squeeze::<1>(0)
            .into_data()
            .to_vec::<f32>()
            .unwrap();

        // Sample an action from the categorical distribution over logits.
        let action_idx = sample_categorical(&probs_vec, &mut rand::thread_rng());

        let log_prob = probs_vec[action_idx].ln();
        let v: f32 = value.into_data().to_vec::<f32>().unwrap()[0];

        (action_idx, log_prob, v)
    }

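    /// Greedy policy query used for evaluation: argmax over the logits.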
    fn forward_inference_greedy(&self, obs: &E::Observation) -> usize {
        let obs_t = ObservationEncoder::<E::Observation, B::InnerBackend>::encode(
            &self.encoder, obs, &self.device,
        ).unsqueeze_dim(0);

        let (logits, _) = self.model.valid().forward(obs_t);
        logits
            .argmax(1)
            .into_data()
            .to_vec::<i64>()
            .unwrap()[0] as usize
    }

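    /// Critic estimate for the observation following the rollout's last step,
    /// used to bootstrap GAE when the buffer fills mid-episode.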
    fn compute_bootstrap_value(&self, next_obs: &E::Observation) -> f32 {
        let obs_t = ObservationEncoder::<E::Observation, B::InnerBackend>::encode(
            &self.encoder, next_obs, &self.device,
        ).unsqueeze_dim(0);
        let (_, value) = self.model.valid().forward(obs_t);
        value.into_data().to_vec::<f32>().unwrap()[0]
    }

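    /// Computes GAE advantages for the full rollout, normalizes them, runs
    /// `n_epochs` passes of shuffled minibatch updates, and clears the buffer.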
    fn run_update(&mut self, bootstrap_value: f32) {
        let gamma = self.config.gamma as f32;
        let lambda = self.config.gae_lambda as f32;
        self.rollout.compute_gae(bootstrap_value, gamma, lambda);
        self.rollout.normalize_advantages();

        for _ in 0..self.config.n_epochs {
            let batches = self.rollout.minibatches(self.config.batch_size, &mut self.update_rng);
            for batch in batches {
                let (pl, vl, ent, kl) = self.update_step(&batch);
                self.ep_policy_loss.update(pl);
                self.ep_value_loss.update(vl);
                self.ep_entropy.update(ent);
                self.ep_approx_kl.update(kl);
            }
        }

        // On-policy: discard the rollout once it has been consumed.
        self.rollout.clear();
    }

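    /// One clipped-surrogate PPO step on a minibatch. With probability ratio
    /// `r = exp(log_prob_new - log_prob_old)` and advantage `A`, the policy
    /// loss is `-mean(min(r * A, clamp(r, 1 - eps, 1 + eps) * A))`; a scaled
    /// MSE value loss and an entropy bonus complete the total loss.
    /// Returns `(policy_loss, value_loss, entropy, approx_kl)`.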
    fn update_step(
        &mut self,
        batch: &super::rollout::Batch<E::Observation>,
    ) -> (f64, f64, f64, f64) {
        let bs = batch.obs.len();
        let obs_t = self.encoder.encode_batch(&batch.obs, &self.device);

        let (logits, values) = self.model.forward(obs_t);

        let probs = burn::tensor::activation::softmax(logits, 1);
        let log_probs_all = probs.clone().log();

        // Log-probabilities of the actions actually taken in this batch.
        let action_idx_t = Tensor::<B, 1, Int>::from_ints(
            batch.actions.iter().map(|&a| a as i32).collect::<Vec<_>>().as_slice(),
            &self.device,
        );
        let new_log_probs = log_probs_all.clone()
            .gather(1, action_idx_t.reshape([bs, 1]))
            .squeeze::<1>(1);

        let old_log_probs_t: Tensor<B, 1> = Tensor::from_floats(
            batch.old_log_probs.as_slice(), &self.device,
        );
        let advantages_t: Tensor<B, 1> = Tensor::from_floats(
            batch.advantages.as_slice(), &self.device,
        );
        let returns_t: Tensor<B, 1> = Tensor::from_floats(
            batch.returns.as_slice(), &self.device,
        );

        // Clipped surrogate objective.
        let ratio = (new_log_probs.clone() - old_log_probs_t.clone().detach()).exp();
        let clip_eps = self.config.clip_epsilon as f32;
        let surr1 = ratio.clone() * advantages_t.clone().detach();
        let surr2 = ratio.clone().clamp(1.0 - clip_eps, 1.0 + clip_eps)
            * advantages_t.detach();
        let policy_loss = -surr1.min_pair(surr2).mean();

        // Scaled MSE between predicted values and GAE returns.
        let diff = values - returns_t.detach();
        let value_loss = (diff.clone() * diff).mean()
            * self.config.value_loss_coef as f32;

        // Entropy bonus (subtracted from the loss) encourages exploration.
        let entropy = -(probs.clone() * log_probs_all.clone()).sum_dim(1).mean();
        let entropy_loss = entropy.clone() * (-self.config.entropy_coef as f32);

        let total_loss = policy_loss.clone() + value_loss.clone() + entropy_loss;

        let pl_val = policy_loss.clone().into_scalar().elem::<f64>();
        let vl_val = value_loss.clone().into_scalar().elem::<f64>();
        let ent_val = entropy.clone().into_scalar().elem::<f64>();

        // First-order KL approximation: a cheap diagnostic of update size.
        let approx_kl = (old_log_probs_t.detach() - new_log_probs.detach())
            .mean()
            .into_scalar()
            .elem::<f64>();

        let grads = total_loss.backward();
        let grads = GradientsParams::from_grads(grads, &self.model);
        self.model = self.optimiser.step(self.config.learning_rate, self.model.clone(), grads);

        (pl_val, vl_val, ent_val, approx_kl)
    }
}

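/// Checkpointing persists only the network weights via `CompactRecorder`;
/// optimiser state is not saved and is rebuilt fresh after `load`.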
impl<E, Enc, Act, B> Checkpointable for PpoAgent<E, Enc, Act, B>
where
    E: Environment,
    E::Observation: Clone + Send + Sync + 'static,
    E::Action: Clone + Send + Sync + 'static,
    Enc: ObservationEncoder<E::Observation, B>
        + ObservationEncoder<E::Observation, B::InnerBackend>,
    Act: DiscreteActionMapper<E::Action>,
    B: AutodiffBackend,
{
    fn save(&self, path: &Path) -> anyhow::Result<()> {
        self.model
            .clone()
            .save_file(path.to_path_buf(), &CompactRecorder::new())
            .map(|_| ())
            .map_err(|e| anyhow::anyhow!(e))
    }

    fn load(mut self, path: &Path) -> anyhow::Result<Self> {
        self.model = self.model
            .load_file(path.to_path_buf(), &CompactRecorder::new(), &self.device)
            .map_err(|e| anyhow::anyhow!(e))?;
        Ok(self)
    }
}

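/// `LearningAgent` wiring: `act` queues per-step policy outputs that the
/// matching `observe` call consumes when it records the transition.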
impl<E, Enc, Act, B> LearningAgent<E> for PpoAgent<E, Enc, Act, B>
where
    E: Environment,
    E::Observation: Clone + Send + Sync + 'static,
    E::Action: Clone + Send + Sync + 'static,
    Enc: ObservationEncoder<E::Observation, B>
        + ObservationEncoder<E::Observation, B::InnerBackend>,
    Act: DiscreteActionMapper<E::Action>,
    B: AutodiffBackend,
{
    fn act(&mut self, obs: &E::Observation, mode: ActMode) -> E::Action {
        match mode {
            ActMode::Exploit => {
                let idx = self.forward_inference_greedy(obs);
                // Greedy steps have no sampled log-prob; push placeholder
                // stats so `observe`'s pop stays aligned with this call.
                self.pending.push_back((0.0, 0.0));
                self.action_mapper.index_to_action(idx)
            }
            ActMode::Explore => {
                let (idx, log_prob, value) = self.forward_inference(obs);
                self.pending.push_back((log_prob, value));
                self.action_mapper.index_to_action(idx)
            }
        }
    }

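    /// Stores the completed transition; when the rollout buffer fills,
    /// bootstraps the tail value (zero on episode end) and runs the update.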
    fn observe(&mut self, experience: Experience<E::Observation, E::Action>) {
        self.total_steps += 1;

        let (log_prob, value) = self.pending.pop_front().unwrap_or((0.0, 0.0));
        let action = self.action_mapper.action_to_index(&experience.action);
        let done = !matches!(experience.status, rl_traits::EpisodeStatus::Continuing);

        self.rollout.push(Transition {
            obs: experience.observation,
            action,
            reward: experience.reward as f32,
            done,
            value,
            log_prob,
        });

        if self.rollout.is_full() {
            // Terminal states bootstrap with zero; otherwise use the critic's
            // estimate of the next observation's value.
            let bootstrap = if self.rollout.last_done() {
                0.0
            } else {
                self.compute_bootstrap_value(&experience.next_observation)
            };
            self.run_update(bootstrap);
        }
    }

    fn total_steps(&self) -> usize {
        self.total_steps
    }

    fn episode_extras(&self) -> HashMap<String, f64> {
        [
            ("policy_loss".to_string(), self.ep_policy_loss.value()),
            ("value_loss".to_string(), self.ep_value_loss.value()),
            ("entropy".to_string(), self.ep_entropy.value()),
            ("approx_kl".to_string(), self.ep_approx_kl.value()),
        ]
        .into_iter()
        // Drop NaN/inf entries, e.g. before the first update of an episode.
        .filter(|(_, v)| v.is_finite())
        .collect()
    }

    fn on_episode_start(&mut self) {
        self.ep_policy_loss.reset();
        self.ep_value_loss.reset();
        self.ep_entropy.reset();
        self.ep_approx_kl.reset();
    }
}

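/// Inverse-CDF sampling from a categorical distribution: draw `r` in [0, 1),
/// accumulate probabilities, and return the first index whose cumulative mass
/// exceeds `r`. Returning the last index guards against floating-point
/// rounding when the probabilities sum to slightly under 1.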
fn sample_categorical(probs: &[f32], rng: &mut impl rand::Rng) -> usize {
    let r: f32 = rng.gen();
    let mut cumsum = 0.0f32;
    for (i, &p) in probs.iter().enumerate() {
        cumsum += p;
        if r < cumsum {
            return i;
        }
    }
    probs.len() - 1
}