// ember_rl/training/runner.rs

use burn::tensor::backend::AutodiffBackend;
use rand::rngs::SmallRng;
use rand::SeedableRng;
use rl_traits::{Environment, Experience, Policy, ReplayBuffer};

use crate::algorithms::dqn::{CircularBuffer, DqnAgent};
use crate::encoding::{DiscreteActionMapper, ObservationEncoder};
use crate::stats::{EpisodeRecord, EvalReport, StatsTracker};
use crate::training::run::TrainingRun;

/// Metrics emitted after every environment step.
///
/// The trainer yields one of these per call to `Iterator::next()`.
/// Use these to log progress, plot curves, or decide when to stop.
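///
/// For example, to log the exploration rate at a fixed interval (the
/// 1_000-step interval here is arbitrary):
///
/// ```rust,ignore
/// for m in trainer.steps().take(100_000) {
///     if m.did_update && m.total_steps % 1_000 == 0 {
///         println!("step {}: ε = {:.3}", m.total_steps, m.epsilon);
///     }
/// }
/// ```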
#[derive(Debug, Clone)]
pub struct StepMetrics {
    /// Total environment steps taken so far.
    pub total_steps: usize,

    /// Which episode this step belongs to.
    pub episode: usize,

    /// Step index within the current episode.
    pub episode_step: usize,

    /// Reward received this step.
    pub reward: f64,

    /// Cumulative reward in the current episode so far.
    pub episode_reward: f64,

    /// Current ε (exploration rate) at the time of this step.
    pub epsilon: f64,

    /// Whether a gradient update was performed this step.
    pub did_update: bool,

    /// Whether this step ended the episode.
    pub episode_done: bool,

    /// How the episode ended (only meaningful when `episode_done` is `true`).
    pub episode_status: rl_traits::EpisodeStatus,
}

/// The imperative training driver.
///
/// Drives the interaction between an environment and a DQN agent,
/// exposing the loop as an iterator that yields `StepMetrics` after every step,
/// as well as higher-level `train()` and `eval()` methods.
///
/// # Usage — iterator style (manual control)
///
/// ```rust,ignore
/// let mut trainer = DqnTrainer::new(env, agent, seed);
///
/// for step in trainer.steps().take(50_000) {
///     if step.episode_done {
///         println!("Episode {} reward: {}", step.episode, step.episode_reward);
///     }
/// }
/// ```
///
/// # Usage — higher-level style (with `TrainingRun`)
///
/// ```rust,ignore
/// let run = TrainingRun::create("cartpole", "v1")?;
/// let mut trainer = DqnTrainer::new(env, agent, seed).with_run(run);
/// trainer.train(200_000);
/// let report = trainer.eval(20);
/// report.print();
/// ```
pub struct DqnTrainer<E, Enc, Act, B, Buf = CircularBuffer<
    <E as Environment>::Observation,
    <E as Environment>::Action,
>>
where
    E: Environment,
    B: AutodiffBackend,
    Buf: ReplayBuffer<E::Observation, E::Action>,
{
    env: E,
    agent: DqnAgent<E, Enc, Act, B, Buf>,
    rng: SmallRng,

    // Episode state
    current_obs: Option<E::Observation>,
    episode: usize,
    episode_step: usize,
    episode_reward: f64,

    // Optional run tracking
    run: Option<TrainingRun>,
    stats: StatsTracker,

    // Checkpoint policy
    checkpoint_freq: usize,
    keep_checkpoints: usize,

    // For "save best" during eval
    best_eval_reward: f64,
}

impl<E, Enc, Act, B, Buf> DqnTrainer<E, Enc, Act, B, Buf>
where
    E: Environment,
    E::Observation: Clone + Send + Sync + 'static,
    E::Action: Clone + Send + Sync + 'static,
    Enc: ObservationEncoder<E::Observation, B>
        + ObservationEncoder<E::Observation, B::InnerBackend>,
    Act: DiscreteActionMapper<E::Action>,
    B: AutodiffBackend,
    Buf: ReplayBuffer<E::Observation, E::Action>,
{
    pub fn new(env: E, agent: DqnAgent<E, Enc, Act, B, Buf>, seed: u64) -> Self {
        Self {
            env,
            agent,
            rng: SmallRng::seed_from_u64(seed),
            current_obs: None,
            episode: 0,
            episode_step: 0,
            episode_reward: 0.0,
            run: None,
            stats: StatsTracker::new(),
            checkpoint_freq: 10_000,
            keep_checkpoints: 5,
            best_eval_reward: f64::NEG_INFINITY,
        }
    }

    /// Attach a `TrainingRun` for checkpoint saving and stats persistence.
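    ///
    /// A typical builder chain (values here are illustrative):
    ///
    /// ```rust,ignore
    /// let run = TrainingRun::create("cartpole", "v1")?;
    /// let mut trainer = DqnTrainer::new(env, agent, 42)
    ///     .with_run(run)
    ///     .with_checkpoint_freq(5_000)
    ///     .with_keep_checkpoints(3);
    /// ```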
    pub fn with_run(mut self, run: TrainingRun) -> Self {
        self.run = Some(run);
        self
    }

    /// How often (in steps) to save a numbered checkpoint. Default: 10_000.
    pub fn with_checkpoint_freq(mut self, freq: usize) -> Self {
        self.checkpoint_freq = freq;
        self
    }

    /// How many numbered checkpoints to keep on disk. Default: 5.
    pub fn with_keep_checkpoints(mut self, keep: usize) -> Self {
        self.keep_checkpoints = keep;
        self
    }

    /// Replace the default stats tracker with a custom one.
    pub fn with_stats(mut self, stats: StatsTracker) -> Self {
        self.stats = stats;
        self
    }

    /// Returns an iterator that yields `StepMetrics` after each environment step.
    ///
    /// The iterator is infinite — stop it with `.take(n)` or `break`.
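    ///
    /// For example, to stop early once a reward target is reached (the
    /// 475.0 threshold is arbitrary):
    ///
    /// ```rust,ignore
    /// for m in trainer.steps().take(50_000) {
    ///     if m.episode_done && m.episode_reward >= 475.0 {
    ///         break;
    ///     }
    /// }
    /// ```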
    pub fn steps(&mut self) -> TrainIter<'_, E, Enc, Act, B, Buf> {
        TrainIter { trainer: self }
    }

    /// Access the agent for evaluation or inspection.
    pub fn agent(&self) -> &DqnAgent<E, Enc, Act, B, Buf> {
        &self.agent
    }

    /// Consume the trainer and return the inner agent.
    ///
    /// Useful for converting to a `DqnPolicy` after training:
    /// ```rust,ignore
    /// let policy = trainer.into_agent().into_policy();
    /// ```
    pub fn into_agent(self) -> DqnAgent<E, Enc, Act, B, Buf> {
        self.agent
    }

    /// Access the environment directly.
    pub fn env(&self) -> &E {
        &self.env
    }

    /// Run `n_steps` of training.
    ///
    /// If a `TrainingRun` is attached, saves checkpoints at `checkpoint_freq`
    /// intervals and writes episode records to `train_episodes.jsonl`.
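    ///
    /// A common pattern is to interleave training with periodic evaluation:
    ///
    /// ```rust,ignore
    /// for _ in 0..10 {
    ///     trainer.train(20_000);
    ///     trainer.eval(10).print();
    /// }
    /// ```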
    pub fn train(&mut self, n_steps: usize) {
        let start_steps = self.agent.total_steps();
        let target_steps = start_steps + n_steps;

        loop {
            let metrics = self.step_once();
            let total = metrics.total_steps;

            if metrics.episode_done {
                let record = EpisodeRecord::new(
                    metrics.episode_reward,
                    metrics.episode_step,
                    metrics.episode_status.clone(),
                );
                self.stats.update(&record);
                if let Some(run) = &self.run {
                    let _ = run.log_train_episode(&record);
                }
            }

            // Periodic checkpoint
            if let Some(run) = &mut self.run {
                if total.is_multiple_of(self.checkpoint_freq) {
                    let path = run.checkpoint_path(total);
                    // Strip the .mpk extension — DqnAgent::save appends it
                    let path_no_ext = path.with_extension("");
                    let _ = self.agent.save(&path_no_ext);
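                    // Also refresh the rolling "latest" checkpoint alongside the numbered one.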
                    let _ = self.agent.save(run.latest_checkpoint_path().with_extension(""));
                    let _ = run.prune_checkpoints(self.keep_checkpoints);
                    let _ = run.update_metadata(total, self.episode);
                }
            }

            if total >= target_steps {
                break;
            }
        }
    }

    /// Run `n_episodes` of greedy evaluation and return an `EvalReport`.
    ///
    /// Exploration is disabled (ε = 0). If a `TrainingRun` is attached,
    /// each episode record is written to `eval_episodes.jsonl`.
    /// If the mean reward improves, saves a `best.mpk` checkpoint.
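    ///
    /// ```rust,ignore
    /// let report = trainer.eval(20);
    /// report.print();
    /// ```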
    pub fn eval(&mut self, n_episodes: usize) -> EvalReport {
        let total_steps = self.agent.total_steps();
        let mut eval_stats = StatsTracker::new();
        let mut records = Vec::with_capacity(n_episodes);

        for _ in 0..n_episodes {
            let record = self.run_greedy_episode();
            eval_stats.update(&record);
            records.push(record);
        }

        let summary = eval_stats.summary();
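        // Mean episode reward across the eval episodes (NaN if the summary is empty).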
        let mean_reward = summary.get("episode_reward").copied().unwrap_or(f64::NAN);

        // Save best checkpoint
        if mean_reward > self.best_eval_reward {
            self.best_eval_reward = mean_reward;
            if let Some(run) = &self.run {
                let _ = self.agent.save(run.best_checkpoint_path().with_extension(""));
            }
        }

        // Log eval episodes
        if let Some(run) = &self.run {
            for record in &records {
                let _ = run.log_eval_episode(record, total_steps);
            }
        }

        EvalReport::new(total_steps, n_episodes, summary)
    }

    // ── Private helpers ───────────────────────────────────────────────────────

    /// Perform one step. Called by `TrainIter::next()`.
    fn step_once(&mut self) -> StepMetrics {
        // Initialise the first episode
        if self.current_obs.is_none() {
            let (obs, _info) = self.env.reset(Some(0));
            self.current_obs = Some(obs);
            self.episode = 0;
            self.episode_step = 0;
            self.episode_reward = 0.0;
        }

        let obs = self.current_obs.clone().unwrap();

        // ε-greedy action selection
        let action = self.agent.act_epsilon_greedy(&obs, &mut self.rng);
        let epsilon = self.agent.epsilon();

        // Step environment
        let result = self.env.step(action.clone());
        let reward = result.reward;
        let done = result.is_done();

        self.episode_reward += reward;
        self.episode_step += 1;

        // Store experience
        let experience = Experience::new(
            obs,
            action,
            reward,
            result.observation.clone(),
            result.status.clone(),
        );
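        // `observe` buffers the experience and may trigger a gradient update;
        // it returns whether an update ran this step.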
        let did_update = self.agent.observe(experience);

        let metrics = StepMetrics {
            total_steps: self.agent.total_steps(),
            episode: self.episode,
            episode_step: self.episode_step,
            reward,
            episode_reward: self.episode_reward,
            epsilon,
            did_update,
            episode_done: done,
            episode_status: result.status.clone(),
        };

        // Handle episode boundary
        if done {
            let (next_obs, _info) = self.env.reset(None);
            self.current_obs = Some(next_obs);
            self.episode += 1;
            self.episode_step = 0;
            self.episode_reward = 0.0;
        } else {
            self.current_obs = Some(result.observation);
        }

        metrics
    }

    /// Run one full episode greedily (no exploration). Returns the episode record.
    fn run_greedy_episode(&mut self) -> EpisodeRecord {
        let (mut obs, _) = self.env.reset(None);
        let mut total_reward = 0.0;
        let mut length = 0;

        loop {
            let action = self.agent.act(&obs);
            let result = self.env.step(action);
            total_reward += result.reward;
            length += 1;

            if result.is_done() {
                return EpisodeRecord::new(total_reward, length, result.status);
            }
            obs = result.observation;
        }
    }
}

// ── TrainIter ─────────────────────────────────────────────────────────────────

/// The iterator returned by `DqnTrainer::steps()`.
pub struct TrainIter<'a, E, Enc, Act, B, Buf = CircularBuffer<
    <E as Environment>::Observation,
    <E as Environment>::Action,
>>
where
    E: Environment,
    B: AutodiffBackend,
    Buf: ReplayBuffer<E::Observation, E::Action>,
{
    trainer: &'a mut DqnTrainer<E, Enc, Act, B, Buf>,
}

impl<'a, E, Enc, Act, B, Buf> Iterator for TrainIter<'a, E, Enc, Act, B, Buf>
where
    E: Environment,
    E::Observation: Clone + Send + Sync + 'static,
    E::Action: Clone + Send + Sync + 'static,
    Enc: ObservationEncoder<E::Observation, B>
        + ObservationEncoder<E::Observation, B::InnerBackend>,
    Act: DiscreteActionMapper<E::Action>,
    B: AutodiffBackend,
    Buf: ReplayBuffer<E::Observation, E::Action>,
{
    type Item = StepMetrics;

    fn next(&mut self) -> Option<StepMetrics> {
        // The iterator is infinite — training stops when the caller stops
        // consuming it (e.g. via `.take(n)` or a manual break).
        Some(self.trainer.step_once())
    }
}