ember_rl/training/
runner.rs

use burn::tensor::backend::AutodiffBackend;
use rand::rngs::SmallRng;
use rand::SeedableRng;
use rl_traits::{Environment, Experience};

use crate::algorithms::dqn::DqnAgent;
use crate::encoding::{DiscreteActionMapper, ObservationEncoder};

/// Metrics emitted after every environment step.
///
/// The runner yields one of these per call to `Iterator::next()`.
/// Use these to log progress, plot curves, or decide when to stop.
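///
/// As a sketch (assuming a `runner` built as in the `DqnRunner` docs
/// below), a typical logging loop looks like:
///
/// ```rust,ignore
/// for m in runner.steps().take(10_000) {
///     if m.episode_done {
///         println!(
///             "ep {:>4} | total steps {:>6} | return {:>8.2} | ε = {:.3}",
///             m.episode, m.total_steps, m.episode_reward, m.epsilon,
///         );
///     }
/// }
/// ```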
#[derive(Debug, Clone)]
pub struct StepMetrics {
    /// Total environment steps taken so far.
    pub total_steps: usize,

    /// Which episode this step belongs to.
    pub episode: usize,

    /// Step index within the current episode.
    pub episode_step: usize,

    /// Reward received this step.
    pub reward: f64,

    /// Cumulative reward in the current episode so far.
    pub episode_reward: f64,

    /// Current ε (exploration rate) used for action selection this step.
    pub epsilon: f64,

    /// Whether a gradient update was performed this step.
    pub did_update: bool,

    /// Whether this step ended the episode.
    pub episode_done: bool,
}

/// The imperative training runner.
///
/// Drives the interaction between an environment and a DQN agent,
/// exposing it as an iterator that yields `StepMetrics` after every step.
///
/// # Usage
///
/// ```rust,ignore
/// let mut runner = DqnRunner::new(env, agent, seed);
///
/// for step in runner.steps().take(50_000) {
///     if step.episode_done {
///         println!("Episode {} reward: {}", step.episode, step.episode_reward);
///     }
/// }
/// ```
///
/// # Why an iterator?
///
/// - You control the loop: add early stopping, custom logging, checkpointing
///   (see the sketch below)
/// - bevy-gym can drive the same runner one step per ECS tick
/// - No callbacks, no closures, no inversion of control
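///
/// For instance, early stopping on a trailing mean episode return is plain
/// iterator plumbing. A sketch (the window size and threshold here are
/// illustrative, not part of the API):
///
/// ```rust,ignore
/// use std::collections::VecDeque;
///
/// let mut window: VecDeque<f64> = VecDeque::with_capacity(100);
/// for step in runner.steps() {
///     if step.episode_done {
///         if window.len() == 100 {
///             window.pop_front();
///         }
///         window.push_back(step.episode_reward);
///         let mean = window.iter().sum::<f64>() / window.len() as f64;
///         if window.len() == 100 && mean >= 475.0 {
///             break; // "solved" under this illustrative criterion
///         }
///     }
/// }
/// ```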
pub struct DqnRunner<E, Enc, Act, B>
where
    E: Environment,
    B: AutodiffBackend,
{
    env: E,
    agent: DqnAgent<E, Enc, Act, B>,
    rng: SmallRng,
    seed: u64,

    // Episode state
    current_obs: Option<E::Observation>,
    episode: usize,
    episode_step: usize,
    episode_reward: f64,
}

impl<E, Enc, Act, B> DqnRunner<E, Enc, Act, B>
where
    E: Environment,
    E::Observation: Clone + Send + Sync + 'static,
    E::Action: Clone + Send + Sync + 'static,
    Enc: ObservationEncoder<E::Observation, B>
        + ObservationEncoder<E::Observation, B::InnerBackend>,
    Act: DiscreteActionMapper<E::Action>,
    B: AutodiffBackend,
{
    pub fn new(env: E, agent: DqnAgent<E, Enc, Act, B>, seed: u64) -> Self {
        Self {
            env,
            agent,
            rng: SmallRng::seed_from_u64(seed),
            seed,
            current_obs: None,
            episode: 0,
            episode_step: 0,
            episode_reward: 0.0,
        }
    }

    /// Returns an iterator that yields `StepMetrics` after each environment step.
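    ///
    /// The iterator borrows the runner mutably, so training can be
    /// interleaved with inspection between bursts. A sketch:
    ///
    /// ```rust,ignore
    /// runner.steps().take(1_000).for_each(drop); // train for 1k steps
    /// let agent = runner.agent();                // inspect between phases
    /// ```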
    pub fn steps(&mut self) -> StepIter<'_, E, Enc, Act, B> {
        StepIter { runner: self }
    }

    /// Access the agent for evaluation or inspection.
    pub fn agent(&self) -> &DqnAgent<E, Enc, Act, B> {
        &self.agent
    }

    /// Access the environment directly.
    pub fn env(&self) -> &E {
        &self.env
    }

    /// Perform one step. Called by `StepIter::next()`.
    fn step_once(&mut self) -> StepMetrics {
        // Initialise the first episode, seeding the environment with the
        // seed passed to `new` so runs are reproducible.
        if self.current_obs.is_none() {
            let (obs, _info) = self.env.reset(Some(self.seed));
            self.current_obs = Some(obs);
            self.episode = 0;
            self.episode_step = 0;
            self.episode_reward = 0.0;
        }

        let obs = self.current_obs.clone().unwrap();

        // ε-greedy action selection
        let action = self.agent.act_epsilon_greedy(&obs, &mut self.rng);
        let epsilon = self.agent.epsilon();

        // Step environment
        let result = self.env.step(action.clone());
        let reward = result.reward;
        let done = result.is_done();

        self.episode_reward += reward;
        self.episode_step += 1;

        // Store experience
        let experience = Experience::new(
            obs,
            action,
            reward,
            result.observation.clone(),
            result.status.clone(),
        );
        let did_update = self.agent.observe(experience);

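        // Build the metrics before handling the episode boundary, so the
        // final step of an episode still reports that episode's totals.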
        let metrics = StepMetrics {
            total_steps: self.agent.total_steps(),
            episode: self.episode,
            episode_step: self.episode_step,
            reward,
            episode_reward: self.episode_reward,
            epsilon,
            did_update,
            episode_done: done,
        };

        // Handle episode boundary
        if done {
            let (next_obs, _info) = self.env.reset(None);
            self.current_obs = Some(next_obs);
            self.episode += 1;
            self.episode_step = 0;
            self.episode_reward = 0.0;
        } else {
            self.current_obs = Some(result.observation);
        }

        metrics
    }
}

/// The iterator returned by `DqnRunner::steps()`.
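///
/// `StepIter` is deliberately infinite: bound it from the outside, e.g.
/// with `Iterator::take`, as in the `DqnRunner` docs.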
pub struct StepIter<'a, E, Enc, Act, B>
where
    E: Environment,
    B: AutodiffBackend,
{
    runner: &'a mut DqnRunner<E, Enc, Act, B>,
}

impl<'a, E, Enc, Act, B> Iterator for StepIter<'a, E, Enc, Act, B>
where
    E: Environment,
    E::Observation: Clone + Send + Sync + 'static,
    E::Action: Clone + Send + Sync + 'static,
    Enc: ObservationEncoder<E::Observation, B>
        + ObservationEncoder<E::Observation, B::InnerBackend>,
    Act: DiscreteActionMapper<E::Action>,
    B: AutodiffBackend,
{
    type Item = StepMetrics;

    fn next(&mut self) -> Option<StepMetrics> {
        // The iterator is infinite: training stops when the caller stops
        // consuming it (e.g. via `.take(n)` or a manual break).
        Some(self.runner.step_once())
    }
}