//! Training session coordinator (ember_rl/training/session.rs).
use std::collections::HashMap;

use rl_traits::{Environment, Experience};

use crate::stats::{EpisodeRecord, EvalReport, StatsTracker};
use crate::traits::{ActMode, LearningAgent};
use crate::training::run::TrainingRun;

/// Configuration for a `TrainingSession`.
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Step budget: the session reports done once `total_steps >= max_steps`.
    /// Default: `usize::MAX`, i.e. run forever.
    pub max_steps: usize,

    /// Interval, in steps, between numbered checkpoints. Default: 10_000.
    pub checkpoint_freq: usize,

    /// How many recent numbered checkpoints survive pruning. Default: 3.
    pub keep_checkpoints: usize,
}
22
23impl Default for SessionConfig {
24    fn default() -> Self {
25        Self {
26            max_steps: usize::MAX,
27            checkpoint_freq: 10_000,
28            keep_checkpoints: 3,
29        }
30    }
31}
32
33/// A self-contained, loop-agnostic training coordinator.
34///
35/// `TrainingSession` wires together a [`LearningAgent`], an optional
36/// [`TrainingRun`], and a [`StatsTracker`]. It is driven purely by incoming
37/// data -- it does not own a training loop. Feed it experiences and episode
38/// boundaries from wherever your loop lives: a plain `for` loop, Bevy's ECS,
39/// or anything else.
40///
41/// # Usage
42///
43/// ```rust,ignore
44/// let session = TrainingSession::new(agent)
45///     .with_run(TrainingRun::create("cartpole", "v1")?)
46///     .with_max_steps(200_000)
47///     .with_checkpoint_freq(10_000);
48///
49/// // Each environment step:
50/// session.observe(experience);
51///
52/// // Each episode end:
53/// session.on_episode(total_reward, steps, status, env_extras);
54///
55/// if session.is_done() { break; }
56/// ```
57pub struct TrainingSession<E: Environment, A> {
58    agent: A,
59    run: Option<TrainingRun>,
60    stats: StatsTracker,
61    config: SessionConfig,
62    best_eval_reward: f64,
63    _env: std::marker::PhantomData<E>,
64}
65
66impl<E, A> TrainingSession<E, A>
67where
68    E: Environment,
69    E::Observation: Clone + Send + Sync + 'static,
70    E::Action: Clone + Send + Sync + 'static,
71    A: LearningAgent<E>,
72{
73    /// Create a session with no run attached.
74    ///
75    /// Stats are tracked in memory but nothing is persisted. Attach a run with
76    /// [`with_run`] to enable checkpointing and JSONL logging.
77    ///
78    /// [`with_run`]: TrainingSession::with_run
79    pub fn new(agent: A) -> Self {
80        Self {
81            agent,
82            run: None,
83            stats: StatsTracker::new(),
84            config: SessionConfig::default(),
85            best_eval_reward: f64::NEG_INFINITY,
86            _env: std::marker::PhantomData,
87        }
88    }
89
90    /// Attach a `TrainingRun` for checkpointing and JSONL episode logging.
91    pub fn with_run(mut self, run: TrainingRun) -> Self {
92        self.run = Some(run);
93        self
94    }
95
96    /// Maximum number of steps before `is_done()` returns `true`. Default: no limit.
97    pub fn with_max_steps(mut self, n: usize) -> Self {
98        self.config.max_steps = n;
99        self
100    }
101
102    /// Checkpoint frequency in steps. Default: 10_000.
103    pub fn with_checkpoint_freq(mut self, freq: usize) -> Self {
104        self.config.checkpoint_freq = freq;
105        self
106    }
107
108    /// Number of numbered checkpoints to retain on disk. Default: 3.
109    pub fn with_keep_checkpoints(mut self, keep: usize) -> Self {
110        self.config.keep_checkpoints = keep;
111        self
112    }
113
114    /// Replace the default `StatsTracker` with a custom one.
115    pub fn with_stats(mut self, stats: StatsTracker) -> Self {
116        self.stats = stats;
117        self
118    }
119
120    // ── Data ingestion ────────────────────────────────────────────────────────
121
122    /// Select an action for `obs` according to `mode`.
123    pub fn act(&mut self, obs: &E::Observation, mode: ActMode) -> E::Action {
124        self.agent.act(obs, mode)
125    }
126
127    /// Record a transition. Checkpoints + prunes if a step milestone is hit.
128    pub fn observe(&mut self, experience: Experience<E::Observation, E::Action>) {
129        self.agent.observe(experience);
130
131        let total = self.agent.total_steps();
132        if total > 0 && total.is_multiple_of(self.config.checkpoint_freq) {
133            if let Some(run) = &mut self.run {
134                let path = run.checkpoint_path(total).with_extension("");
135                let _ = self.agent.save(&path);
136                let latest = run.latest_checkpoint_path().with_extension("");
137                let _ = self.agent.save(&latest);
138                let _ = run.prune_checkpoints(self.config.keep_checkpoints);
139            }
140        }
141    }
142
143    /// Record an episode boundary.
144    ///
145    /// Merges agent and environment extras into the record, updates stats,
146    /// and appends to the training JSONL log (if a run is attached).
147    ///
148    /// `env_extras` should come from [`crate::traits::EpisodeStats::episode_extras`]
149    /// if the environment implements it, or an empty map otherwise.
150    pub fn on_episode(
151        &mut self,
152        total_reward: f64,
153        steps: usize,
154        status: rl_traits::EpisodeStatus,
155        env_extras: HashMap<String, f64>,
156    ) {
157        let agent_extras = self.agent.episode_extras();
158        let record = EpisodeRecord::new(total_reward, steps, status)
159            .with_extras(env_extras)
160            .with_extras(agent_extras);
161
162        self.stats.update(&record);
163
164        if let Some(run) = &mut self.run {
165            let _ = run.log_train_episode(&record);
166            let _ = run.update_metadata(self.agent.total_steps(), 0);
167        }
168
169        self.agent.on_episode_start();
170    }
171
172    /// Signal the start of a new episode (resets per-episode agent aggregators).
173    pub fn on_episode_start(&mut self) {
174        self.agent.on_episode_start();
175    }
176
177    /// Total environment steps observed so far.
178    pub fn total_steps(&self) -> usize {
179        self.agent.total_steps()
180    }
181
182    /// Returns `true` when `total_steps >= max_steps`.
183    pub fn is_done(&self) -> bool {
184        self.config.max_steps != usize::MAX
185            && self.agent.total_steps() >= self.config.max_steps
186    }
187
188    // ── Eval ──────────────────────────────────────────────────────────────────
189
190    /// Log an eval episode to the run (if attached).
191    pub fn on_eval_episode(&self, record: &EpisodeRecord) {
192        if let Some(run) = &self.run {
193            let _ = run.log_eval_episode(record, self.agent.total_steps());
194        }
195    }
196
197    /// Save `best.mpk` if `mean_reward` exceeds the best seen so far.
198    pub fn maybe_save_best(&mut self, mean_reward: f64) {
199        if mean_reward > self.best_eval_reward {
200            self.best_eval_reward = mean_reward;
201            if let Some(run) = &self.run {
202                let best = run.best_checkpoint_path().with_extension("");
203                let _ = self.agent.save(&best);
204            }
205        }
206    }
207
208    // ── Access ────────────────────────────────────────────────────────────────
209
210    /// Read-only access to the agent.
211    pub fn agent(&self) -> &A {
212        &self.agent
213    }
214
215    /// Mutable access to the agent.
216    pub fn agent_mut(&mut self) -> &mut A {
217        &mut self.agent
218    }
219
220    /// Current stats summary.
221    pub fn stats_summary(&self) -> HashMap<String, f64> {
222        self.stats.summary()
223    }
224
225    /// Read-only access to the run (if attached).
226    pub fn run(&self) -> Option<&TrainingRun> {
227        self.run.as_ref()
228    }
229
230    /// Consume the session and return the inner agent.
231    pub fn into_agent(self) -> A {
232        self.agent
233    }
234
235    /// Snapshot the current stats as an `EvalReport`.
236    pub fn eval_report(&self, n_episodes: usize) -> EvalReport {
237        EvalReport::new(self.agent.total_steps(), n_episodes, self.stats.summary())
238    }
239}