ember_rl/training/
session.rs1use std::collections::HashMap;
2
3use rl_traits::{Environment, Experience};
4
5use crate::stats::{EpisodeRecord, EvalReport, StatsTracker};
6use crate::traits::{ActMode, LearningAgent};
7use crate::training::run::TrainingRun;
8
/// Tunable knobs for a training session: step budget and checkpoint policy.
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Hard cap on total agent steps; `usize::MAX` means "no limit".
    pub max_steps: usize,

    /// Save a checkpoint every time the agent's step counter reaches a
    /// multiple of this value.
    pub checkpoint_freq: usize,

    /// How many numbered checkpoints to retain when pruning.
    pub keep_checkpoints: usize,
}

impl Default for SessionConfig {
    /// Defaults: unlimited steps, checkpoint every 10 000 steps, keep 3.
    fn default() -> Self {
        SessionConfig {
            checkpoint_freq: 10_000,
            keep_checkpoints: 3,
            max_steps: usize::MAX,
        }
    }
}
31
32pub struct TrainingSession<E: Environment, A> {
57 agent: A,
58 run: Option<TrainingRun>,
59 stats: StatsTracker,
60 config: SessionConfig,
61 best_eval_reward: f64,
62 _env: std::marker::PhantomData<E>,
63}
64
65impl<E, A> TrainingSession<E, A>
66where
67 E: Environment,
68 E::Observation: Clone + Send + Sync + 'static,
69 E::Action: Clone + Send + Sync + 'static,
70 A: LearningAgent<E>,
71{
72 pub fn new(agent: A) -> Self {
79 Self {
80 agent,
81 run: None,
82 stats: StatsTracker::new(),
83 config: SessionConfig::default(),
84 best_eval_reward: f64::NEG_INFINITY,
85 _env: std::marker::PhantomData,
86 }
87 }
88
89 pub fn with_run(mut self, run: TrainingRun) -> Self {
91 self.run = Some(run);
92 self
93 }
94
95 pub fn with_max_steps(mut self, n: usize) -> Self {
97 self.config.max_steps = n;
98 self
99 }
100
101 pub fn with_checkpoint_freq(mut self, freq: usize) -> Self {
103 self.config.checkpoint_freq = freq;
104 self
105 }
106
107 pub fn with_keep_checkpoints(mut self, keep: usize) -> Self {
109 self.config.keep_checkpoints = keep;
110 self
111 }
112
113 pub fn with_stats(mut self, stats: StatsTracker) -> Self {
115 self.stats = stats;
116 self
117 }
118
119 pub fn act(&mut self, obs: &E::Observation, mode: ActMode) -> E::Action {
123 self.agent.act(obs, mode)
124 }
125
126 pub fn observe(&mut self, experience: Experience<E::Observation, E::Action>) {
128 self.agent.observe(experience);
129
130 let total = self.agent.total_steps();
131 if total > 0 && total.is_multiple_of(self.config.checkpoint_freq) {
132 if let Some(run) = &mut self.run {
133 let path = run.checkpoint_path(total).with_extension("");
134 let _ = self.agent.save(&path);
135 let latest = run.latest_checkpoint_path().with_extension("");
136 let _ = self.agent.save(&latest);
137 let _ = run.prune_checkpoints(self.config.keep_checkpoints);
138 }
139 }
140 }
141
142 pub fn on_episode(
150 &mut self,
151 total_reward: f64,
152 steps: usize,
153 status: rl_traits::EpisodeStatus,
154 env_extras: HashMap<String, f64>,
155 ) {
156 let agent_extras = self.agent.episode_extras();
157 let record = EpisodeRecord::new(total_reward, steps, status)
158 .with_extras(env_extras)
159 .with_extras(agent_extras);
160
161 self.stats.update(&record);
162
163 if let Some(run) = &mut self.run {
164 let _ = run.log_train_episode(&record);
165 let _ = run.update_metadata(self.agent.total_steps(), 0);
166 }
167
168 self.agent.on_episode_start();
169 }
170
171 pub fn on_episode_start(&mut self) {
173 self.agent.on_episode_start();
174 }
175
176 pub fn total_steps(&self) -> usize {
178 self.agent.total_steps()
179 }
180
181 pub fn is_done(&self) -> bool {
183 self.config.max_steps != usize::MAX
184 && self.agent.total_steps() >= self.config.max_steps
185 }
186
187 pub fn on_eval_episode(&self, record: &EpisodeRecord) {
191 if let Some(run) = &self.run {
192 let _ = run.log_eval_episode(record, self.agent.total_steps());
193 }
194 }
195
196 pub fn maybe_save_best(&mut self, mean_reward: f64) {
198 if mean_reward > self.best_eval_reward {
199 self.best_eval_reward = mean_reward;
200 if let Some(run) = &self.run {
201 let best = run.best_checkpoint_path().with_extension("");
202 let _ = self.agent.save(&best);
203 }
204 }
205 }
206
207 pub fn agent(&self) -> &A {
211 &self.agent
212 }
213
214 pub fn agent_mut(&mut self) -> &mut A {
216 &mut self.agent
217 }
218
219 pub fn stats_summary(&self) -> HashMap<String, f64> {
221 self.stats.summary()
222 }
223
224 pub fn run(&self) -> Option<&TrainingRun> {
226 self.run.as_ref()
227 }
228
229 pub fn into_agent(self) -> A {
231 self.agent
232 }
233
234 pub fn eval_report(&self, n_episodes: usize) -> EvalReport {
236 EvalReport::new(self.agent.total_steps(), n_episodes, self.stats.summary())
237 }
238}