
ember_rl/training/runner.rs

use burn::tensor::backend::AutodiffBackend;
use rl_traits::{Environment, Experience, ReplayBuffer};

use crate::algorithms::dqn::{CircularBuffer, DqnAgent};
use crate::encoding::{DiscreteActionMapper, ObservationEncoder};
use crate::stats::{EpisodeRecord, EvalReport, StatsTracker};
use crate::traits::ActMode;
use crate::training::run::TrainingRun;
use crate::training::session::TrainingSession;

/// Metrics emitted after every environment step.
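///
/// A typical consumer filters on `episode_done`; a minimal sketch, assuming
/// a configured `trainer` (see [`DqnTrainer::steps`]):
///
/// ```rust,ignore
/// for step in trainer.steps().take(10_000) {
///     if step.episode_done {
///         println!(
///             "episode {} ended after {} steps: reward {:.2}, ε {:.3}",
///             step.episode, step.episode_step, step.episode_reward, step.epsilon,
///         );
///     }
/// }
/// ```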
#[derive(Debug, Clone)]
pub struct StepMetrics {
    /// Total environment steps taken so far.
    pub total_steps: usize,

    /// Which episode this step belongs to.
    pub episode: usize,

    /// Step index within the current episode.
    pub episode_step: usize,

    /// Reward received this step.
    pub reward: f64,

    /// Cumulative reward in the current episode so far.
    pub episode_reward: f64,

    /// Current ε (exploration rate).
    pub epsilon: f64,

    /// Whether this step ended the episode.
    pub episode_done: bool,

    /// How the episode ended (only meaningful when `episode_done` is `true`).
    pub episode_status: rl_traits::EpisodeStatus,
}

/// The imperative training driver for DQN.
///
/// A thin convenience wrapper around [`TrainingSession`] that adds an
/// environment and drives the training loop. Use [`TrainingSession`] directly
/// when your loop is owned externally (e.g. Bevy's ECS).
///
/// # Usage — iterator style
///
/// ```rust,ignore
/// let mut trainer = DqnTrainer::new(env, agent);
///
/// for step in trainer.steps().take(50_000) {
///     if step.episode_done {
///         println!("Episode {} reward: {}", step.episode, step.episode_reward);
///     }
/// }
/// ```
///
/// # Usage — imperative style with run tracking
///
/// ```rust,ignore
/// let mut trainer = DqnTrainer::new(env, agent)
///     .with_run(TrainingRun::create("cartpole", "v1")?)
///     .with_max_steps(200_000);
///
/// trainer.train();
/// let report = trainer.eval(20);
/// report.print();
/// ```
pub struct DqnTrainer<E, Enc, Act, B, Buf = CircularBuffer<
    <E as Environment>::Observation,
    <E as Environment>::Action,
>>
where
    E: Environment,
    B: AutodiffBackend,
    Buf: ReplayBuffer<E::Observation, E::Action>,
{
    env: E,
    session: TrainingSession<E, DqnAgent<E, Enc, Act, B, Buf>>,

    // Episode state
    current_obs: Option<E::Observation>,
    episode: usize,
    episode_step: usize,
    episode_reward: f64,
}

impl<E, Enc, Act, B, Buf> DqnTrainer<E, Enc, Act, B, Buf>
where
    E: Environment,
    E::Observation: Clone + Send + Sync + 'static,
    E::Action: Clone + Send + Sync + 'static,
    Enc: ObservationEncoder<E::Observation, B>
        + ObservationEncoder<E::Observation, B::InnerBackend>,
    Act: DiscreteActionMapper<E::Action>,
    B: AutodiffBackend,
    Buf: ReplayBuffer<E::Observation, E::Action>,
{
    pub fn new(env: E, agent: DqnAgent<E, Enc, Act, B, Buf>) -> Self {
        Self {
            env,
            session: TrainingSession::new(agent),
            current_obs: None,
            episode: 0,
            episode_step: 0,
            episode_reward: 0.0,
        }
    }

    /// Attach a `TrainingRun` for checkpoint saving and JSONL logging.
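    ///
    /// Chains with the other builders; a sketch, with placeholder run name,
    /// version, and limits:
    ///
    /// ```rust,ignore
    /// let trainer = DqnTrainer::new(env, agent)
    ///     .with_run(TrainingRun::create("cartpole", "v1")?)
    ///     .with_max_steps(200_000)
    ///     .with_checkpoint_freq(5_000)
    ///     .with_keep_checkpoints(5);
    /// ```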
    pub fn with_run(mut self, run: TrainingRun) -> Self {
        self.session = self.session.with_run(run);
        self
    }

    /// Maximum training steps. `train()` stops when this is reached.
    pub fn with_max_steps(mut self, n: usize) -> Self {
        self.session = self.session.with_max_steps(n);
        self
    }

    /// Checkpoint frequency in steps. Default: 10_000.
    pub fn with_checkpoint_freq(mut self, freq: usize) -> Self {
        self.session = self.session.with_checkpoint_freq(freq);
        self
    }

    /// Number of numbered checkpoints to keep on disk. Default: 3.
    pub fn with_keep_checkpoints(mut self, keep: usize) -> Self {
        self.session = self.session.with_keep_checkpoints(keep);
        self
    }

    /// Replace the default stats tracker.
    pub fn with_stats(mut self, stats: StatsTracker) -> Self {
        self.session = self.session.with_stats(stats);
        self
    }

    /// Returns an iterator that yields `StepMetrics` after each environment step.
    ///
    /// The iterator never returns `None` on its own, so bound it with
    /// [`Iterator::take`] or break out of the loop manually.
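    ///
    /// ```rust,ignore
    /// // A sketch: run exactly 1_000 steps and keep the last metrics.
    /// let last = trainer.steps().take(1_000).last();
    /// ```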
    pub fn steps(&mut self) -> TrainIter<'_, E, Enc, Act, B, Buf> {
        TrainIter { trainer: self }
    }

    /// Run until `max_steps` is reached (or forever if not set).
    pub fn train(&mut self) {
        loop {
            self.step_once();
            if self.session.is_done() {
                break;
            }
        }
    }

    /// Run `n_episodes` of greedy evaluation and return an `EvalReport`.
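    ///
    /// A sketch, assuming training has already run:
    ///
    /// ```rust,ignore
    /// let report = trainer.eval(20);
    /// report.print();
    /// ```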
    pub fn eval(&mut self, n_episodes: usize) -> EvalReport {
        let mut eval_stats = StatsTracker::new();
        let mut records = Vec::with_capacity(n_episodes);

        for _ in 0..n_episodes {
            let record = self.run_greedy_episode();
            eval_stats.update(&record);
            self.session.on_eval_episode(&record);
            records.push(record);
        }

        let summary = eval_stats.summary();
        let mean_reward = summary.get("episode_reward").copied().unwrap_or(f64::NAN);
        // Give the session a chance to checkpoint if this is the best mean
        // reward seen so far.
        self.session.maybe_save_best(mean_reward);

        let total_steps = self.session.total_steps();
        EvalReport::new(total_steps, n_episodes, summary)
    }

    /// Read-only access to the underlying session.
    pub fn session(&self) -> &TrainingSession<E, DqnAgent<E, Enc, Act, B, Buf>> {
        &self.session
    }

    /// Consume the trainer and return the inner agent.
    pub fn into_agent(self) -> DqnAgent<E, Enc, Act, B, Buf> {
        self.session.into_agent()
    }

    /// Access the environment directly.
    pub fn env(&self) -> &E {
        &self.env
    }

    // ── Private helpers ───────────────────────────────────────────────────────

    fn step_once(&mut self) -> StepMetrics {
        // Lazily initialise on the first call: the first episode is reset
        // with seed `Some(0)`; later resets pass `None`.
        if self.current_obs.is_none() {
            let (obs, _) = self.env.reset(Some(0));
            self.current_obs = Some(obs);
            self.episode = 0;
            self.episode_step = 0;
            self.episode_reward = 0.0;
            self.session.on_episode_start();
        }

        let obs = self.current_obs.clone().unwrap();
        // Snapshot ε before acting so the metric reflects the rate used for
        // this step.
        let epsilon = self.session.agent().epsilon();
        let action = self.session.act(&obs, ActMode::Explore);

        let result = self.env.step(action.clone());
        let reward = result.reward;
        let done = result.is_done();

        self.episode_reward += reward;
        self.episode_step += 1;

        // Record the transition with the session.
        self.session.observe(Experience::new(
            obs,
            action,
            reward,
            result.observation.clone(),
            result.status.clone(),
        ));

        let metrics = StepMetrics {
            total_steps: self.session.total_steps(),
            episode: self.episode,
            episode_step: self.episode_step,
            reward,
            episode_reward: self.episode_reward,
            epsilon,
            episode_done: done,
            episode_status: result.status.clone(),
        };

        if done {
            // Close out the episode, then reset for the next one.
            self.session.on_episode(
                self.episode_reward,
                self.episode_step,
                result.status,
                self.env.episode_extras(),
            );
            let (next_obs, _) = self.env.reset(None);
            self.current_obs = Some(next_obs);
            self.episode += 1;
            self.episode_step = 0;
            self.episode_reward = 0.0;
        } else {
            self.current_obs = Some(result.observation);
        }

        metrics
    }

    fn run_greedy_episode(&mut self) -> EpisodeRecord {
        // Pure greedy rollout: always exploit, and record no transitions.
        let (mut obs, _) = self.env.reset(None);
        let mut total_reward = 0.0;
        let mut length = 0;

        loop {
            let action = self.session.act(&obs, ActMode::Exploit);
            let result = self.env.step(action);
            total_reward += result.reward;
            length += 1;

            if result.is_done() {
                return EpisodeRecord::new(total_reward, length, result.status);
            }
            obs = result.observation;
        }
    }
}

// ── TrainIter ─────────────────────────────────────────────────────────────────

/// An endless iterator over training steps; see [`DqnTrainer::steps`].
pub struct TrainIter<'a, E, Enc, Act, B, Buf = CircularBuffer<
    <E as Environment>::Observation,
    <E as Environment>::Action,
>>
where
    E: Environment,
    B: AutodiffBackend,
    Buf: ReplayBuffer<E::Observation, E::Action>,
{
    trainer: &'a mut DqnTrainer<E, Enc, Act, B, Buf>,
}

impl<'a, E, Enc, Act, B, Buf> Iterator for TrainIter<'a, E, Enc, Act, B, Buf>
where
    E: Environment,
    E::Observation: Clone + Send + Sync + 'static,
    E::Action: Clone + Send + Sync + 'static,
    Enc: ObservationEncoder<E::Observation, B>
        + ObservationEncoder<E::Observation, B::InnerBackend>,
    Act: DiscreteActionMapper<E::Action>,
    B: AutodiffBackend,
    Buf: ReplayBuffer<E::Observation, E::Action>,
{
    type Item = StepMetrics;

    fn next(&mut self) -> Option<StepMetrics> {
        // Always yields: callers bound the iteration themselves (e.g. `take`).
        Some(self.trainer.step_once())
    }
}