use burn::tensor::backend::AutodiffBackend;
use rl_traits::{Environment, Experience, ReplayBuffer};

use crate::algorithms::dqn::{CircularBuffer, DqnAgent};
use crate::encoding::{DiscreteActionMapper, ObservationEncoder};
use crate::stats::{EpisodeRecord, EvalReport, StatsTracker};
use crate::traits::ActMode;
use crate::training::run::TrainingRun;
use crate::training::session::TrainingSession;

/// Per-step training metrics, yielded by [`TrainIter`].
#[derive(Debug, Clone)]
pub struct StepMetrics {
    /// Total environment steps taken across the whole run.
    pub total_steps: usize,

    /// Zero-based index of the current episode.
    pub episode: usize,

    /// Step count within the current episode.
    pub episode_step: usize,

    /// Reward the environment returned for this step.
    pub reward: f64,

    /// Cumulative reward accrued over the current episode.
    pub episode_reward: f64,

    /// Exploration rate used when selecting this step's action.
    pub epsilon: f64,

    /// Whether this step ended the episode.
    pub episode_done: bool,

    /// Status the environment reported for this step.
    pub episode_status: rl_traits::EpisodeStatus,
}

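/// Drives DQN training: owns the [`Environment`], wraps the agent in a
/// [`TrainingSession`], and keeps per-episode bookkeeping between steps.
///
/// A minimal usage sketch. `GridWorld`, `GridEncoder`, `GridActions`, and
/// `make_agent` below are hypothetical placeholders, not items from this
/// crate; only the `DqnTrainer` calls are real:
///
/// ```ignore
/// let env = GridWorld::new();
/// let agent = make_agent(GridEncoder::default(), GridActions);
/// let mut trainer = DqnTrainer::new(env, agent)
///     .with_max_steps(100_000)
///     .with_checkpoint_freq(5_000);
/// trainer.train();
/// let report = trainer.eval(10);
/// let agent = trainer.into_agent();
/// ```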
pub struct DqnTrainer<E, Enc, Act, B, Buf = CircularBuffer<
    <E as Environment>::Observation,
    <E as Environment>::Action,
>>
where
    E: Environment,
    B: AutodiffBackend,
    Buf: ReplayBuffer<E::Observation, E::Action>,
{
    env: E,
    session: TrainingSession<E, DqnAgent<E, Enc, Act, B, Buf>>,

    current_obs: Option<E::Observation>,
    episode: usize,
    episode_step: usize,
    episode_reward: f64,
}

impl<E, Enc, Act, B, Buf> DqnTrainer<E, Enc, Act, B, Buf>
where
    E: Environment,
    E::Observation: Clone + Send + Sync + 'static,
    E::Action: Clone + Send + Sync + 'static,
    Enc: ObservationEncoder<E::Observation, B>
        + ObservationEncoder<E::Observation, B::InnerBackend>,
    Act: DiscreteActionMapper<E::Action>,
    B: AutodiffBackend,
    Buf: ReplayBuffer<E::Observation, E::Action>,
{
    /// Creates a trainer that owns `env` and wraps `agent` in a fresh
    /// [`TrainingSession`].
    pub fn new(env: E, agent: DqnAgent<E, Enc, Act, B, Buf>) -> Self {
        Self {
            env,
            session: TrainingSession::new(agent),
            current_obs: None,
            episode: 0,
            episode_step: 0,
            episode_reward: 0.0,
        }
    }

    /// Attaches a [`TrainingRun`] to the underlying session.
    pub fn with_run(mut self, run: TrainingRun) -> Self {
        self.session = self.session.with_run(run);
        self
    }

    /// Caps the session at `n` total training steps.
    pub fn with_max_steps(mut self, n: usize) -> Self {
        self.session = self.session.with_max_steps(n);
        self
    }

    /// Sets how often (in steps) the session checkpoints.
    pub fn with_checkpoint_freq(mut self, freq: usize) -> Self {
        self.session = self.session.with_checkpoint_freq(freq);
        self
    }

    /// Sets how many checkpoints the session keeps.
    pub fn with_keep_checkpoints(mut self, keep: usize) -> Self {
        self.session = self.session.with_keep_checkpoints(keep);
        self
    }

    /// Supplies a custom [`StatsTracker`] to the session.
    pub fn with_stats(mut self, stats: StatsTracker) -> Self {
        self.session = self.session.with_stats(stats);
        self
    }

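    /// Returns an iterator that performs one training step per `next()`
    /// call and yields the resulting [`StepMetrics`]. A driving sketch
    /// (the logging is illustrative, not part of this crate):
    ///
    /// ```ignore
    /// for m in trainer.steps().take(10_000) {
    ///     if m.episode_done {
    ///         println!("episode {}: reward {:.2}", m.episode, m.episode_reward);
    ///     }
    /// }
    /// ```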
    pub fn steps(&mut self) -> TrainIter<'_, E, Enc, Act, B, Buf> {
        TrainIter { trainer: self }
    }

    /// Runs the training loop until the session reports completion.
    pub fn train(&mut self) {
        loop {
            self.step_once();
            if self.session.is_done() {
                break;
            }
        }
    }

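    /// Runs `n_episodes` greedy (exploit-only) episodes, aggregates their
    /// records in a fresh [`StatsTracker`], hands the mean episode reward
    /// to the session's `maybe_save_best`, and returns an [`EvalReport`]
    /// built from the summary.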
    pub fn eval(&mut self, n_episodes: usize) -> EvalReport {
        let mut eval_stats = StatsTracker::new();

        for _ in 0..n_episodes {
            let record = self.run_greedy_episode();
            eval_stats.update(&record);
            self.session.on_eval_episode(&record);
        }

        let summary = eval_stats.summary();
        let mean_reward = summary.get("episode_reward").copied().unwrap_or(f64::NAN);
        self.session.maybe_save_best(mean_reward);

        let total_steps = self.session.total_steps();
        EvalReport::new(total_steps, n_episodes, summary)
    }

    /// Read-only access to the underlying [`TrainingSession`].
    pub fn session(&self) -> &TrainingSession<E, DqnAgent<E, Enc, Act, B, Buf>> {
        &self.session
    }

    /// Consumes the trainer and returns the trained agent.
    pub fn into_agent(self) -> DqnAgent<E, Enc, Act, B, Buf> {
        self.session.into_agent()
    }

    /// Read-only access to the environment.
    pub fn env(&self) -> &E {
        &self.env
    }

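    /// Advances training by one environment step: resets lazily on the
    /// first call, selects an exploratory action, records the resulting
    /// [`Experience`] with the session, and rolls the episode state over
    /// when the step is terminal.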
    fn step_once(&mut self) -> StepMetrics {
        // Lazy first-time initialization: perform the initial reset and
        // zero the episode counters.
        if self.current_obs.is_none() {
            let (obs, _) = self.env.reset(Some(0));
            self.current_obs = Some(obs);
            self.episode = 0;
            self.episode_step = 0;
            self.episode_reward = 0.0;
            self.session.on_episode_start();
        }

        let obs = self.current_obs.clone().unwrap();
        let epsilon = self.session.agent().epsilon();
        let action = self.session.act(&obs, ActMode::Explore);

        let result = self.env.step(action.clone());
        let reward = result.reward;
        let done = result.is_done();

        self.episode_reward += reward;
        self.episode_step += 1;

        self.session.observe(Experience::new(
            obs,
            action,
            reward,
            result.observation.clone(),
            result.status.clone(),
        ));

        // Snapshot metrics before any end-of-episode reset below.
        let metrics = StepMetrics {
            total_steps: self.session.total_steps(),
            episode: self.episode,
            episode_step: self.episode_step,
            reward,
            episode_reward: self.episode_reward,
            epsilon,
            episode_done: done,
            episode_status: result.status.clone(),
        };

        if done {
            // Close out the episode, then reset for the next one.
            self.session.on_episode(
                self.episode_reward,
                self.episode_step,
                result.status,
                self.env.episode_extras(),
            );
            let (next_obs, _) = self.env.reset(None);
            self.current_obs = Some(next_obs);
            self.episode += 1;
            self.episode_step = 0;
            self.episode_reward = 0.0;
        } else {
            self.current_obs = Some(result.observation);
        }

        metrics
    }

    /// Plays one episode with exploration disabled ([`ActMode::Exploit`])
    /// and returns its record; loops until the environment reports a
    /// terminal step.
    fn run_greedy_episode(&mut self) -> EpisodeRecord {
        let (mut obs, _) = self.env.reset(None);
        let mut total_reward = 0.0;
        let mut length = 0;

        loop {
            let action = self.session.act(&obs, ActMode::Exploit);
            let result = self.env.step(action);
            total_reward += result.reward;
            length += 1;

            if result.is_done() {
                return EpisodeRecord::new(total_reward, length, result.status);
            }
            obs = result.observation;
        }
    }
}

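/// Borrowing iterator over training steps, created by
/// [`DqnTrainer::steps`]; each `next` call runs one training step and
/// yields its [`StepMetrics`].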
pub struct TrainIter<'a, E, Enc, Act, B, Buf = CircularBuffer<
    <E as Environment>::Observation,
    <E as Environment>::Action,
>>
where
    E: Environment,
    B: AutodiffBackend,
    Buf: ReplayBuffer<E::Observation, E::Action>,
{
    trainer: &'a mut DqnTrainer<E, Enc, Act, B, Buf>,
}

impl<'a, E, Enc, Act, B, Buf> Iterator for TrainIter<'a, E, Enc, Act, B, Buf>
where
    E: Environment,
    E::Observation: Clone + Send + Sync + 'static,
    E::Action: Clone + Send + Sync + 'static,
    Enc: ObservationEncoder<E::Observation, B>
        + ObservationEncoder<E::Observation, B::InnerBackend>,
    Act: DiscreteActionMapper<E::Action>,
    B: AutodiffBackend,
    Buf: ReplayBuffer<E::Observation, E::Action>,
{
    type Item = StepMetrics;

    fn next(&mut self) -> Option<StepMetrics> {
        // Stop yielding once the session is done so the iterator honours
        // the same step budget as `train`.
        if self.trainer.session.is_done() {
            return None;
        }
        Some(self.trainer.step_once())
    }
}