//! ember_rl/training/session.rs — data-driven training session coordinator.

use std::collections::HashMap;

use rl_traits::{Environment, Experience};

use crate::stats::{EpisodeRecord, EvalReport, StatsTracker};
use crate::traits::{ActMode, LearningAgent};
use crate::training::run::TrainingRun;

/// Configuration for a `TrainingSession`.
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Training counts as finished once `total_steps >= max_steps`.
    /// Default: `usize::MAX` (no limit).
    pub max_steps: usize,

    /// Interval, in environment steps, between numbered checkpoints.
    /// Default: 10_000.
    pub checkpoint_freq: usize,

    /// How many recent numbered checkpoints to retain on disk. Default: 3.
    pub keep_checkpoints: usize,
}

22impl Default for SessionConfig {
23    fn default() -> Self {
24        Self {
25            max_steps: usize::MAX,
26            checkpoint_freq: 10_000,
27            keep_checkpoints: 3,
28        }
29    }
30}
31
32/// A self-contained, loop-agnostic training coordinator.
33///
34/// `TrainingSession` wires together a [`LearningAgent`], an optional
35/// [`TrainingRun`], and a [`StatsTracker`]. It is driven purely by incoming
36/// data — it does not own a training loop. Feed it experiences and episode
37/// boundaries from wherever your loop lives: a plain `for` loop, Bevy's ECS,
38/// or anything else.
39///
40/// # Usage
41///
42/// ```rust,ignore
43/// let session = TrainingSession::new(agent)
44///     .with_run(TrainingRun::create("cartpole", "v1")?)
45///     .with_max_steps(200_000)
46///     .with_checkpoint_freq(10_000);
47///
48/// // Each environment step:
49/// session.observe(experience);
50///
51/// // Each episode end:
52/// session.on_episode(total_reward, steps, status, env_extras);
53///
54/// if session.is_done() { break; }
55/// ```
56pub struct TrainingSession<E: Environment, A> {
57    agent: A,
58    run: Option<TrainingRun>,
59    stats: StatsTracker,
60    config: SessionConfig,
61    best_eval_reward: f64,
62    _env: std::marker::PhantomData<E>,
63}
64
65impl<E, A> TrainingSession<E, A>
66where
67    E: Environment,
68    E::Observation: Clone + Send + Sync + 'static,
69    E::Action: Clone + Send + Sync + 'static,
70    A: LearningAgent<E>,
71{
72    /// Create a session with no run attached.
73    ///
74    /// Stats are tracked in memory but nothing is persisted. Attach a run with
75    /// [`with_run`] to enable checkpointing and JSONL logging.
76    ///
77    /// [`with_run`]: TrainingSession::with_run
78    pub fn new(agent: A) -> Self {
79        Self {
80            agent,
81            run: None,
82            stats: StatsTracker::new(),
83            config: SessionConfig::default(),
84            best_eval_reward: f64::NEG_INFINITY,
85            _env: std::marker::PhantomData,
86        }
87    }
88
89    /// Attach a `TrainingRun` for checkpointing and JSONL episode logging.
90    pub fn with_run(mut self, run: TrainingRun) -> Self {
91        self.run = Some(run);
92        self
93    }
94
95    /// Maximum number of steps before `is_done()` returns `true`. Default: no limit.
96    pub fn with_max_steps(mut self, n: usize) -> Self {
97        self.config.max_steps = n;
98        self
99    }
100
101    /// Checkpoint frequency in steps. Default: 10_000.
102    pub fn with_checkpoint_freq(mut self, freq: usize) -> Self {
103        self.config.checkpoint_freq = freq;
104        self
105    }
106
107    /// Number of numbered checkpoints to retain on disk. Default: 3.
108    pub fn with_keep_checkpoints(mut self, keep: usize) -> Self {
109        self.config.keep_checkpoints = keep;
110        self
111    }
112
113    /// Replace the default `StatsTracker` with a custom one.
114    pub fn with_stats(mut self, stats: StatsTracker) -> Self {
115        self.stats = stats;
116        self
117    }
118
119    // ── Data ingestion ────────────────────────────────────────────────────────
120
121    /// Select an action for `obs` according to `mode`.
122    pub fn act(&mut self, obs: &E::Observation, mode: ActMode) -> E::Action {
123        self.agent.act(obs, mode)
124    }
125
126    /// Record a transition. Checkpoints + prunes if a step milestone is hit.
127    pub fn observe(&mut self, experience: Experience<E::Observation, E::Action>) {
128        self.agent.observe(experience);
129
130        let total = self.agent.total_steps();
131        if total > 0 && total.is_multiple_of(self.config.checkpoint_freq) {
132            if let Some(run) = &mut self.run {
133                let path = run.checkpoint_path(total).with_extension("");
134                let _ = self.agent.save(&path);
135                let latest = run.latest_checkpoint_path().with_extension("");
136                let _ = self.agent.save(&latest);
137                let _ = run.prune_checkpoints(self.config.keep_checkpoints);
138            }
139        }
140    }
141
142    /// Record an episode boundary.
143    ///
144    /// Merges agent and environment extras into the record, updates stats,
145    /// and appends to the training JSONL log (if a run is attached).
146    ///
147    /// `env_extras` should come from [`crate::traits::EpisodeStats::episode_extras`]
148    /// if the environment implements it, or an empty map otherwise.
149    pub fn on_episode(
150        &mut self,
151        total_reward: f64,
152        steps: usize,
153        status: rl_traits::EpisodeStatus,
154        env_extras: HashMap<String, f64>,
155    ) {
156        let agent_extras = self.agent.episode_extras();
157        let record = EpisodeRecord::new(total_reward, steps, status)
158            .with_extras(env_extras)
159            .with_extras(agent_extras);
160
161        self.stats.update(&record);
162
163        if let Some(run) = &mut self.run {
164            let _ = run.log_train_episode(&record);
165            let _ = run.update_metadata(self.agent.total_steps(), 0);
166        }
167
168        self.agent.on_episode_start();
169    }
170
171    /// Signal the start of a new episode (resets per-episode agent aggregators).
172    pub fn on_episode_start(&mut self) {
173        self.agent.on_episode_start();
174    }
175
176    /// Total environment steps observed so far.
177    pub fn total_steps(&self) -> usize {
178        self.agent.total_steps()
179    }
180
181    /// Returns `true` when `total_steps >= max_steps`.
182    pub fn is_done(&self) -> bool {
183        self.config.max_steps != usize::MAX
184            && self.agent.total_steps() >= self.config.max_steps
185    }
186
187    // ── Eval ──────────────────────────────────────────────────────────────────
188
189    /// Log an eval episode to the run (if attached).
190    pub fn on_eval_episode(&self, record: &EpisodeRecord) {
191        if let Some(run) = &self.run {
192            let _ = run.log_eval_episode(record, self.agent.total_steps());
193        }
194    }
195
196    /// Save `best.mpk` if `mean_reward` exceeds the best seen so far.
197    pub fn maybe_save_best(&mut self, mean_reward: f64) {
198        if mean_reward > self.best_eval_reward {
199            self.best_eval_reward = mean_reward;
200            if let Some(run) = &self.run {
201                let best = run.best_checkpoint_path().with_extension("");
202                let _ = self.agent.save(&best);
203            }
204        }
205    }
206
207    // ── Access ────────────────────────────────────────────────────────────────
208
209    /// Read-only access to the agent.
210    pub fn agent(&self) -> &A {
211        &self.agent
212    }
213
214    /// Mutable access to the agent.
215    pub fn agent_mut(&mut self) -> &mut A {
216        &mut self.agent
217    }
218
219    /// Current stats summary.
220    pub fn stats_summary(&self) -> HashMap<String, f64> {
221        self.stats.summary()
222    }
223
224    /// Read-only access to the run (if attached).
225    pub fn run(&self) -> Option<&TrainingRun> {
226        self.run.as_ref()
227    }
228
229    /// Consume the session and return the inner agent.
230    pub fn into_agent(self) -> A {
231        self.agent
232    }
233
234    /// Snapshot the current stats as an `EvalReport`.
235    pub fn eval_report(&self, n_episodes: usize) -> EvalReport {
236        EvalReport::new(self.agent.total_steps(), n_episodes, self.stats.summary())
237    }
238}