// ember_rl/training/session.rs

1use std::collections::HashMap;
2use std::time::Instant;
3
4use rl_traits::{Environment, Experience};
5
6use crate::stats::{EpisodeRecord, EvalReport, StatsTracker};
7use crate::traits::{ActMode, LearningAgent};
8use crate::training::run::TrainingRun;
9
10
/// Tunable knobs for a `TrainingSession`.
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Training is considered done once `total_steps >= max_steps`.
    /// `usize::MAX` (the default) means "no limit".
    pub max_steps: usize,

    /// A numbered checkpoint is written every `checkpoint_freq` steps.
    /// Default: 10_000.
    pub checkpoint_freq: usize,

    /// How many recent numbered checkpoints survive pruning. Default: 3.
    pub keep_checkpoints: usize,
}

impl Default for SessionConfig {
    fn default() -> Self {
        SessionConfig {
            keep_checkpoints: 3,
            checkpoint_freq: 10_000,
            max_steps: usize::MAX,
        }
    }
}
33
/// A self-contained, loop-agnostic training coordinator.
///
/// `TrainingSession` wires together a [`LearningAgent`], an optional
/// [`TrainingRun`], and a [`StatsTracker`]. It is driven purely by incoming
/// data -- it does not own a training loop. Feed it experiences and episode
/// boundaries from wherever your loop lives: a plain `for` loop, Bevy's ECS,
/// or anything else.
///
/// # Usage
///
/// ```rust,ignore
/// let session = TrainingSession::new(agent)
///     .with_run(TrainingRun::create("cartpole", "v1")?)
///     .with_max_steps(200_000)
///     .with_checkpoint_freq(10_000);
///
/// // Each environment step:
/// session.observe(experience);
///
/// // Each episode end:
/// session.on_episode(total_reward, steps, status, env_extras);
///
/// if session.is_done() { break; }
/// ```
pub struct TrainingSession<E: Environment, A> {
    // The learning agent; the session delegates step counting and persistence to it.
    agent: A,
    // Optional on-disk run for checkpoints and JSONL logs; `None` = in-memory only.
    run: Option<TrainingRun>,
    // In-memory episode statistics, updated on every `on_episode` call.
    stats: StatsTracker,
    // Step budget, checkpoint cadence, and retention settings.
    config: SessionConfig,
    // Highest eval mean reward seen so far; gates `maybe_save_best`.
    best_eval_reward: f64,
    // Session creation time; denominator for `steps_per_sec`.
    start_time: Instant,
    // Ties the session to one environment type without storing an instance.
    _env: std::marker::PhantomData<E>,
}
67
68impl<E, A> TrainingSession<E, A>
69where
70    E: Environment,
71    E::Observation: Clone + Send + Sync + 'static,
72    E::Action: Clone + Send + Sync + 'static,
73    A: LearningAgent<E>,
74{
75    /// Create a session with no run attached.
76    ///
77    /// Stats are tracked in memory but nothing is persisted. Attach a run with
78    /// [`with_run`] to enable checkpointing and JSONL logging.
79    ///
80    /// [`with_run`]: TrainingSession::with_run
81    pub fn new(agent: A) -> Self {
82        Self {
83            agent,
84            run: None,
85            stats: StatsTracker::new(),
86            config: SessionConfig::default(),
87            best_eval_reward: f64::NEG_INFINITY,
88            start_time: Instant::now(),
89            _env: std::marker::PhantomData,
90        }
91    }
92
93    /// Attach a `TrainingRun` for checkpointing and JSONL episode logging.
94    pub fn with_run(mut self, run: TrainingRun) -> Self {
95        self.run = Some(run);
96        self
97    }
98
99    /// Maximum number of steps before `is_done()` returns `true`. Default: no limit.
100    pub fn with_max_steps(mut self, n: usize) -> Self {
101        self.config.max_steps = n;
102        self
103    }
104
105    /// Checkpoint frequency in steps. Default: 10_000.
106    pub fn with_checkpoint_freq(mut self, freq: usize) -> Self {
107        self.config.checkpoint_freq = freq;
108        self
109    }
110
111    /// Number of numbered checkpoints to retain on disk. Default: 3.
112    pub fn with_keep_checkpoints(mut self, keep: usize) -> Self {
113        self.config.keep_checkpoints = keep;
114        self
115    }
116
117    /// Replace the default `StatsTracker` with a custom one.
118    pub fn with_stats(mut self, stats: StatsTracker) -> Self {
119        self.stats = stats;
120        self
121    }
122
123    // ── Data ingestion ────────────────────────────────────────────────────────
124
125    /// Select an action for `obs` according to `mode`.
126    pub fn act(&mut self, obs: &E::Observation, mode: ActMode) -> E::Action {
127        self.agent.act(obs, mode)
128    }
129
130    /// Record a transition. Checkpoints + prunes if a step milestone is hit.
131    pub fn observe(&mut self, experience: Experience<E::Observation, E::Action>) {
132        self.agent.observe(experience);
133
134        let total = self.agent.total_steps();
135        if total > 0 && total.is_multiple_of(self.config.checkpoint_freq) {
136            if let Some(run) = &mut self.run {
137                let path = run.checkpoint_path(total).with_extension("");
138                let _ = self.agent.save(&path);
139                let latest = run.latest_checkpoint_path().with_extension("");
140                let _ = self.agent.save(&latest);
141                let _ = run.prune_checkpoints(self.config.keep_checkpoints);
142            }
143        }
144    }
145
146    /// Record an episode boundary.
147    ///
148    /// Merges agent and environment extras into the record, updates stats,
149    /// and appends to the training JSONL log (if a run is attached).
150    ///
151    /// `env_extras` should come from [`crate::traits::EpisodeStats::episode_extras`]
152    /// if the environment implements it, or an empty map otherwise.
153    pub fn on_episode(
154        &mut self,
155        total_reward: f64,
156        steps: usize,
157        status: rl_traits::EpisodeStatus,
158        env_extras: HashMap<String, f64>,
159    ) {
160        let agent_extras = self.agent.episode_extras();
161        let record = EpisodeRecord::new(total_reward, steps, status)
162            .with_extras(env_extras)
163            .with_extras(agent_extras);
164
165        self.stats.update(&record);
166
167        if let Some(run) = &mut self.run {
168            let _ = run.log_train_episode(&record);
169            let _ = run.update_metadata(self.agent.total_steps(), 0);
170        }
171
172        self.agent.on_episode_start();
173    }
174
175    /// Signal the start of a new episode (resets per-episode agent aggregators).
176    pub fn on_episode_start(&mut self) {
177        self.agent.on_episode_start();
178    }
179
180    /// Total environment steps observed so far.
181    pub fn total_steps(&self) -> usize {
182        self.agent.total_steps()
183    }
184
185    /// Average environment steps per wall-clock second since the session was created.
186    pub fn steps_per_sec(&self) -> f64 {
187        let elapsed = self.start_time.elapsed().as_secs_f64();
188        if elapsed < 1e-6 { return 0.0; }
189        self.agent.total_steps() as f64 / elapsed
190    }
191
192    /// Returns `true` when `total_steps >= max_steps`.
193    pub fn is_done(&self) -> bool {
194        self.config.max_steps != usize::MAX
195            && self.agent.total_steps() >= self.config.max_steps
196    }
197
198    // ── Eval ──────────────────────────────────────────────────────────────────
199
200    /// Log an eval episode to the run (if attached).
201    pub fn on_eval_episode(&self, record: &EpisodeRecord) {
202        if let Some(run) = &self.run {
203            let _ = run.log_eval_episode(record, self.agent.total_steps());
204        }
205    }
206
207    /// Save `best.mpk` if `mean_reward` exceeds the best seen so far.
208    pub fn maybe_save_best(&mut self, mean_reward: f64) {
209        if mean_reward > self.best_eval_reward {
210            self.best_eval_reward = mean_reward;
211            if let Some(run) = &self.run {
212                let best = run.best_checkpoint_path().with_extension("");
213                let _ = self.agent.save(&best);
214            }
215        }
216    }
217
218    // ── Access ────────────────────────────────────────────────────────────────
219
220    /// Read-only access to the agent.
221    pub fn agent(&self) -> &A {
222        &self.agent
223    }
224
225    /// Mutable access to the agent.
226    pub fn agent_mut(&mut self) -> &mut A {
227        &mut self.agent
228    }
229
230    /// Current stats summary.
231    pub fn stats_summary(&self) -> HashMap<String, f64> {
232        self.stats.summary()
233    }
234
235    /// Read-only access to the run (if attached).
236    pub fn run(&self) -> Option<&TrainingRun> {
237        self.run.as_ref()
238    }
239
240    /// Consume the session and return the inner agent.
241    pub fn into_agent(self) -> A {
242        self.agent
243    }
244
245    /// Snapshot the current stats as an `EvalReport`.
246    pub fn eval_report(&self, n_episodes: usize) -> EvalReport {
247        EvalReport::new(self.agent.total_steps(), n_episodes, self.stats.summary())
248    }
249}