//! ember_rl/training/session.rs — `TrainingSession`: ties a learning agent,
//! rolling episode statistics, and an optional on-disk `TrainingRun`
//! (checkpointing + logging) together.

use std::collections::HashMap;

use rl_traits::{Environment, Experience};

use crate::stats::{EpisodeRecord, EvalReport, StatsTracker};
use crate::traits::{ActMode, LearningAgent};
use crate::training::run::TrainingRun;

/// Tunable knobs for a [`TrainingSession`].
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Stop training once the agent has taken this many steps.
    /// `usize::MAX` means "no budget" (see `TrainingSession::is_done`).
    pub max_steps: usize,

    /// Save a checkpoint every `checkpoint_freq` agent steps.
    pub checkpoint_freq: usize,

    /// How many numbered checkpoints to retain when pruning old ones.
    pub keep_checkpoints: usize,
}
22
23impl Default for SessionConfig {
24 fn default() -> Self {
25 Self {
26 max_steps: usize::MAX,
27 checkpoint_freq: 10_000,
28 keep_checkpoints: 3,
29 }
30 }
31}
32
/// Orchestrates one training loop: owns the agent, accumulates episode
/// statistics, and (optionally) checkpoints and logs through a [`TrainingRun`].
// NOTE(review): per convention the `E: Environment` bound could live on the
// impl blocks instead of the struct definition — kept here as-is.
pub struct TrainingSession<E: Environment, A> {
    // The learning agent being trained.
    agent: A,
    // Optional on-disk run (checkpoints + episode logs); `None` = in-memory only.
    run: Option<TrainingRun>,
    // Rolling statistics over completed training episodes.
    stats: StatsTracker,
    // Session knobs: step budget, checkpoint cadence, retention count.
    config: SessionConfig,
    // Best mean eval reward observed so far; guards `maybe_save_best`.
    best_eval_reward: f64,
    // Marker tying the session to an environment type without storing one.
    _env: std::marker::PhantomData<E>,
}
65
66impl<E, A> TrainingSession<E, A>
67where
68 E: Environment,
69 E::Observation: Clone + Send + Sync + 'static,
70 E::Action: Clone + Send + Sync + 'static,
71 A: LearningAgent<E>,
72{
73 pub fn new(agent: A) -> Self {
80 Self {
81 agent,
82 run: None,
83 stats: StatsTracker::new(),
84 config: SessionConfig::default(),
85 best_eval_reward: f64::NEG_INFINITY,
86 _env: std::marker::PhantomData,
87 }
88 }
89
90 pub fn with_run(mut self, run: TrainingRun) -> Self {
92 self.run = Some(run);
93 self
94 }
95
96 pub fn with_max_steps(mut self, n: usize) -> Self {
98 self.config.max_steps = n;
99 self
100 }
101
102 pub fn with_checkpoint_freq(mut self, freq: usize) -> Self {
104 self.config.checkpoint_freq = freq;
105 self
106 }
107
108 pub fn with_keep_checkpoints(mut self, keep: usize) -> Self {
110 self.config.keep_checkpoints = keep;
111 self
112 }
113
114 pub fn with_stats(mut self, stats: StatsTracker) -> Self {
116 self.stats = stats;
117 self
118 }
119
120 pub fn act(&mut self, obs: &E::Observation, mode: ActMode) -> E::Action {
124 self.agent.act(obs, mode)
125 }
126
127 pub fn observe(&mut self, experience: Experience<E::Observation, E::Action>) {
129 self.agent.observe(experience);
130
131 let total = self.agent.total_steps();
132 if total > 0 && total.is_multiple_of(self.config.checkpoint_freq) {
133 if let Some(run) = &mut self.run {
134 let path = run.checkpoint_path(total).with_extension("");
135 let _ = self.agent.save(&path);
136 let latest = run.latest_checkpoint_path().with_extension("");
137 let _ = self.agent.save(&latest);
138 let _ = run.prune_checkpoints(self.config.keep_checkpoints);
139 }
140 }
141 }
142
143 pub fn on_episode(
151 &mut self,
152 total_reward: f64,
153 steps: usize,
154 status: rl_traits::EpisodeStatus,
155 env_extras: HashMap<String, f64>,
156 ) {
157 let agent_extras = self.agent.episode_extras();
158 let record = EpisodeRecord::new(total_reward, steps, status)
159 .with_extras(env_extras)
160 .with_extras(agent_extras);
161
162 self.stats.update(&record);
163
164 if let Some(run) = &mut self.run {
165 let _ = run.log_train_episode(&record);
166 let _ = run.update_metadata(self.agent.total_steps(), 0);
167 }
168
169 self.agent.on_episode_start();
170 }
171
172 pub fn on_episode_start(&mut self) {
174 self.agent.on_episode_start();
175 }
176
177 pub fn total_steps(&self) -> usize {
179 self.agent.total_steps()
180 }
181
182 pub fn is_done(&self) -> bool {
184 self.config.max_steps != usize::MAX
185 && self.agent.total_steps() >= self.config.max_steps
186 }
187
188 pub fn on_eval_episode(&self, record: &EpisodeRecord) {
192 if let Some(run) = &self.run {
193 let _ = run.log_eval_episode(record, self.agent.total_steps());
194 }
195 }
196
197 pub fn maybe_save_best(&mut self, mean_reward: f64) {
199 if mean_reward > self.best_eval_reward {
200 self.best_eval_reward = mean_reward;
201 if let Some(run) = &self.run {
202 let best = run.best_checkpoint_path().with_extension("");
203 let _ = self.agent.save(&best);
204 }
205 }
206 }
207
208 pub fn agent(&self) -> &A {
212 &self.agent
213 }
214
215 pub fn agent_mut(&mut self) -> &mut A {
217 &mut self.agent
218 }
219
220 pub fn stats_summary(&self) -> HashMap<String, f64> {
222 self.stats.summary()
223 }
224
225 pub fn run(&self) -> Option<&TrainingRun> {
227 self.run.as_ref()
228 }
229
230 pub fn into_agent(self) -> A {
232 self.agent
233 }
234
235 pub fn eval_report(&self, n_episodes: usize) -> EvalReport {
237 EvalReport::new(self.agent.total_steps(), n_episodes, self.stats.summary())
238 }
239}