use burn::tensor::backend::AutodiffBackend;
use rand::rngs::SmallRng;
use rand::SeedableRng;
use rl_traits::{Environment, Experience, Policy, ReplayBuffer};

use crate::algorithms::dqn::{CircularBuffer, DqnAgent};
use crate::encoding::{DiscreteActionMapper, ObservationEncoder};
use crate::stats::{EpisodeRecord, EvalReport, StatsTracker};
use crate::training::run::TrainingRun;

/// Metrics reported after each training step.
#[derive(Debug, Clone)]
pub struct StepMetrics {
    /// Total environment steps the agent has taken across all episodes.
    pub total_steps: usize,
    /// Index of the current episode.
    pub episode: usize,
    /// Step index within the current episode.
    pub episode_step: usize,
    /// Reward received on this step.
    pub reward: f64,
    /// Cumulative reward accrued so far in the current episode.
    pub episode_reward: f64,
    /// Exploration rate used when selecting this step's action.
    pub epsilon: f64,
    /// Whether the agent performed a learning update on this step.
    pub did_update: bool,
    /// Whether the episode ended on this step.
    pub episode_done: bool,
    /// Status of the episode after this step.
    pub episode_status: rl_traits::EpisodeStatus,
}

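/// Orchestrates DQN training: owns the environment, the agent, the RNG,
/// and per-episode bookkeeping, and can optionally persist logs and
/// checkpoints through a [`TrainingRun`].
///
/// A minimal usage sketch; `CartPole`, the encoder, and the agent
/// construction below are placeholders for whatever your crate provides,
/// not APIs defined in this module:
///
/// ```ignore
/// let env = CartPole::new();
/// let agent = /* build a DqnAgent for this env */;
/// let mut trainer = DqnTrainer::new(env, agent, 42)
///     .with_checkpoint_freq(5_000)
///     .with_keep_checkpoints(3);
///
/// trainer.train(100_000);           // 100k environment steps
/// let report = trainer.eval(20);    // 20 greedy evaluation episodes
/// let agent = trainer.into_agent(); // recover the trained agent
/// ```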
pub struct DqnTrainer<E, Enc, Act, B, Buf = CircularBuffer<
    <E as Environment>::Observation,
    <E as Environment>::Action,
>>
where
    E: Environment,
    B: AutodiffBackend,
    Buf: ReplayBuffer<E::Observation, E::Action>,
{
    env: E,
    agent: DqnAgent<E, Enc, Act, B, Buf>,
    rng: SmallRng,

    /// Observation the agent will act from next; `None` until the first reset.
    current_obs: Option<E::Observation>,
    episode: usize,
    episode_step: usize,
    episode_reward: f64,

    /// Optional run directory used for episode logs and checkpoints.
    run: Option<TrainingRun>,
    stats: StatsTracker,

    /// Save a checkpoint every this many environment steps.
    checkpoint_freq: usize,
    /// Number of recent checkpoints to keep when pruning.
    keep_checkpoints: usize,

    /// Best mean evaluation reward seen so far.
    best_eval_reward: f64,
}

impl<E, Enc, Act, B, Buf> DqnTrainer<E, Enc, Act, B, Buf>
where
    E: Environment,
    E::Observation: Clone + Send + Sync + 'static,
    E::Action: Clone + Send + Sync + 'static,
    Enc: ObservationEncoder<E::Observation, B>
        + ObservationEncoder<E::Observation, B::InnerBackend>,
    Act: DiscreteActionMapper<E::Action>,
    B: AutodiffBackend,
    Buf: ReplayBuffer<E::Observation, E::Action>,
{
    /// Creates a trainer with a deterministic RNG derived from `seed`.
    pub fn new(env: E, agent: DqnAgent<E, Enc, Act, B, Buf>, seed: u64) -> Self {
        Self {
            env,
            agent,
            rng: SmallRng::seed_from_u64(seed),
            current_obs: None,
            episode: 0,
            episode_step: 0,
            episode_reward: 0.0,
            run: None,
            stats: StatsTracker::new(),
            checkpoint_freq: 10_000,
            keep_checkpoints: 5,
            best_eval_reward: f64::NEG_INFINITY,
        }
    }

    /// Attaches a [`TrainingRun`] for logging and checkpointing.
    pub fn with_run(mut self, run: TrainingRun) -> Self {
        self.run = Some(run);
        self
    }

    /// Sets how many steps elapse between checkpoints (default: 10,000).
    pub fn with_checkpoint_freq(mut self, freq: usize) -> Self {
        self.checkpoint_freq = freq;
        self
    }

    /// Sets how many recent checkpoints to retain (default: 5).
    pub fn with_keep_checkpoints(mut self, keep: usize) -> Self {
        self.keep_checkpoints = keep;
        self
    }

    /// Replaces the default [`StatsTracker`].
    pub fn with_stats(mut self, stats: StatsTracker) -> Self {
        self.stats = stats;
        self
    }

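    /// Returns an iterator that performs one training step per `next()`
    /// call and yields the resulting [`StepMetrics`]. The iterator never
    /// ends on its own, so bound it explicitly.
    ///
    /// A sketch of driving training by hand; the logging is illustrative,
    /// not something this crate provides:
    ///
    /// ```ignore
    /// for metrics in trainer.steps().take(10_000) {
    ///     if metrics.episode_done {
    ///         println!(
    ///             "episode {}: reward {:.2} in {} steps (epsilon {:.3})",
    ///             metrics.episode,
    ///             metrics.episode_reward,
    ///             metrics.episode_step,
    ///             metrics.epsilon,
    ///         );
    ///     }
    /// }
    /// ```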
    pub fn steps(&mut self) -> TrainIter<'_, E, Enc, Act, B, Buf> {
        TrainIter { trainer: self }
    }

    /// Returns a reference to the underlying agent.
    pub fn agent(&self) -> &DqnAgent<E, Enc, Act, B, Buf> {
        &self.agent
    }

    /// Consumes the trainer and returns the trained agent.
    pub fn into_agent(self) -> DqnAgent<E, Enc, Act, B, Buf> {
        self.agent
    }

    /// Returns a reference to the environment.
    pub fn env(&self) -> &E {
        &self.env
    }

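    /// Trains for `n_steps` additional environment steps. Finished episodes
    /// update the stats tracker and, if a [`TrainingRun`] is attached, are
    /// logged to disk; checkpoints are saved every `checkpoint_freq` steps.
    ///
    /// A sketch of a common train/eval cadence; the chunk sizes are
    /// arbitrary and the `Debug` print assumes [`EvalReport`] derives it:
    ///
    /// ```ignore
    /// for _ in 0..10 {
    ///     trainer.train(10_000);         // one chunk of training
    ///     let report = trainer.eval(5);  // periodic greedy evaluation
    ///     println!("{report:?}");
    /// }
    /// ```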
    pub fn train(&mut self, n_steps: usize) {
        let start_steps = self.agent.total_steps();
        let target_steps = start_steps + n_steps;

        loop {
            let metrics = self.step_once();
            let total = metrics.total_steps;

            if metrics.episode_done {
                let record = EpisodeRecord::new(
                    metrics.episode_reward,
                    metrics.episode_step,
                    metrics.episode_status.clone(),
                );
                self.stats.update(&record);
                if let Some(run) = &self.run {
                    let _ = run.log_train_episode(&record);
                }
            }

            if let Some(run) = &mut self.run {
                if total.is_multiple_of(self.checkpoint_freq) {
                    // Save both a numbered checkpoint and the `latest` alias,
                    // then prune old checkpoints and refresh run metadata.
                    let path = run.checkpoint_path(total);
                    let path_no_ext = path.with_extension("");
                    let _ = self.agent.save(&path_no_ext);
                    let _ = self.agent.save(run.latest_checkpoint_path().with_extension(""));
                    let _ = run.prune_checkpoints(self.keep_checkpoints);
                    let _ = run.update_metadata(total, self.episode);
                }
            }

            if total >= target_steps {
                break;
            }
        }
    }

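    /// Runs `n_episodes` greedy evaluation episodes, logs them against the
    /// current step count, and saves a `best` checkpoint whenever the mean
    /// episode reward improves on the best seen so far.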
    pub fn eval(&mut self, n_episodes: usize) -> EvalReport {
        let total_steps = self.agent.total_steps();
        let mut eval_stats = StatsTracker::new();
        let mut records = Vec::with_capacity(n_episodes);

        for _ in 0..n_episodes {
            let record = self.run_greedy_episode();
            eval_stats.update(&record);
            records.push(record);
        }

        let summary = eval_stats.summary();
        let mean_reward = summary.get("episode_reward").copied().unwrap_or(f64::NAN);

        if mean_reward > self.best_eval_reward {
            self.best_eval_reward = mean_reward;
            if let Some(run) = &self.run {
                let _ = self.agent.save(run.best_checkpoint_path().with_extension(""));
            }
        }

        if let Some(run) = &self.run {
            for record in &records {
                let _ = run.log_eval_episode(record, total_steps);
            }
        }

        if !records.is_empty() {
            // run_greedy_episode() reset and stepped the environment, so the
            // stored training observation is stale; start a fresh episode so
            // training resumes from a valid state.
            let (obs, _info) = self.env.reset(None);
            self.current_obs = Some(obs);
            self.episode_step = 0;
            self.episode_reward = 0.0;
        }

        EvalReport::new(total_steps, n_episodes, summary)
    }

    /// Advances the environment by one step: selects an epsilon-greedy
    /// action, feeds the resulting experience to the agent, and returns
    /// the step's metrics. Lazily resets the environment on first use and
    /// again whenever an episode ends.
    fn step_once(&mut self) -> StepMetrics {
        // First call only: seed the initial reset for reproducibility.
        if self.current_obs.is_none() {
            let (obs, _info) = self.env.reset(Some(0));
            self.current_obs = Some(obs);
            self.episode = 0;
            self.episode_step = 0;
            self.episode_reward = 0.0;
        }

        let obs = self.current_obs.clone().unwrap();

        let action = self.agent.act_epsilon_greedy(&obs, &mut self.rng);
        let epsilon = self.agent.epsilon();

        let result = self.env.step(action.clone());
        let reward = result.reward;
        let done = result.is_done();

        self.episode_reward += reward;
        self.episode_step += 1;

        let experience = Experience::new(
            obs,
            action,
            reward,
            result.observation.clone(),
            result.status.clone(),
        );
        let did_update = self.agent.observe(experience);

        let metrics = StepMetrics {
            total_steps: self.agent.total_steps(),
            episode: self.episode,
            episode_step: self.episode_step,
            reward,
            episode_reward: self.episode_reward,
            epsilon,
            did_update,
            episode_done: done,
            episode_status: result.status.clone(),
        };

        if done {
            let (next_obs, _info) = self.env.reset(None);
            self.current_obs = Some(next_obs);
            self.episode += 1;
            self.episode_step = 0;
            self.episode_reward = 0.0;
        } else {
            self.current_obs = Some(result.observation);
        }

        metrics
    }

    /// Runs one episode with the agent's greedy policy and returns its record.
    fn run_greedy_episode(&mut self) -> EpisodeRecord {
        let (mut obs, _) = self.env.reset(None);
        let mut total_reward = 0.0;
        let mut length = 0;

        loop {
            let action = self.agent.act(&obs);
            let result = self.env.step(action);
            total_reward += result.reward;
            length += 1;

            if result.is_done() {
                return EpisodeRecord::new(total_reward, length, result.status);
            }
            obs = result.observation;
        }
    }
}

/// Iterator over training steps, created by [`DqnTrainer::steps`]. Each
/// `next()` call performs one environment step and yields its metrics.
pub struct TrainIter<'a, E, Enc, Act, B, Buf = CircularBuffer<
    <E as Environment>::Observation,
    <E as Environment>::Action,
>>
where
    E: Environment,
    B: AutodiffBackend,
    Buf: ReplayBuffer<E::Observation, E::Action>,
{
    trainer: &'a mut DqnTrainer<E, Enc, Act, B, Buf>,
}

impl<'a, E, Enc, Act, B, Buf> Iterator for TrainIter<'a, E, Enc, Act, B, Buf>
where
    E: Environment,
    E::Observation: Clone + Send + Sync + 'static,
    E::Action: Clone + Send + Sync + 'static,
    Enc: ObservationEncoder<E::Observation, B>
        + ObservationEncoder<E::Observation, B::InnerBackend>,
    Act: DiscreteActionMapper<E::Action>,
    B: AutodiffBackend,
    Buf: ReplayBuffer<E::Observation, E::Action>,
{
    type Item = StepMetrics;

    fn next(&mut self) -> Option<StepMetrics> {
        // Training never finishes on its own: this always yields a step,
        // so callers should bound the iterator (e.g. with `take`).
        Some(self.trainer.step_once())
    }
}