// ember_rl/training/session.rs

use std::collections::HashMap;
2use std::time::Instant;
3
4use rl_traits::{Environment, Experience};
5
6use crate::stats::{EpisodeRecord, EvalReport, StatsTracker};
7use crate::traits::{ActMode, LearningAgent};
8use crate::training::run::TrainingRun;
9
10
/// Tunable knobs controlling a training session's stopping point and
/// checkpointing cadence.
///
/// All fields are plain `usize`, so the type is `Copy` and cheaply
/// comparable; derives follow the eager-derive convention for public
/// config types.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SessionConfig {
    /// Stop training once the agent has taken this many steps.
    /// `usize::MAX` (the default) means "no step limit".
    pub max_steps: usize,

    /// Save a checkpoint every `checkpoint_freq` agent steps.
    pub checkpoint_freq: usize,

    /// Number of step-numbered checkpoints to retain when pruning.
    pub keep_checkpoints: usize,
}

impl Default for SessionConfig {
    fn default() -> Self {
        Self {
            max_steps: usize::MAX,   // unbounded by default
            checkpoint_freq: 10_000, // checkpoint every 10k steps
            keep_checkpoints: 3,
        }
    }
}
33
/// Orchestrates training of agent `A` in environment `E`: action selection,
/// experience feeding, periodic checkpointing, and episode/eval bookkeeping.
pub struct TrainingSession<E: Environment, A> {
    // The learning agent being trained.
    agent: A,
    // Optional on-disk run (checkpoints + episode logs); `None` disables
    // all persistence.
    run: Option<TrainingRun>,
    // Rolling statistics over completed training episodes.
    stats: StatsTracker,
    // Step-limit and checkpointing knobs.
    config: SessionConfig,
    // Best mean eval reward seen so far; starts at NEG_INFINITY so the
    // first eval always becomes the "best" checkpoint.
    best_eval_reward: f64,
    // Session construction time, used for steps-per-second reporting.
    start_time: Instant,
    // Ties the environment type parameter to the struct without storing it.
    _env: std::marker::PhantomData<E>,
}
67
impl<E, A> TrainingSession<E, A>
where
    E: Environment,
    E::Observation: Clone + Send + Sync + 'static,
    E::Action: Clone + Send + Sync + 'static,
    A: LearningAgent<E>,
{
    /// Creates a session with default config, fresh stats, and no run
    /// attached (persistence disabled until `with_run` is called).
    pub fn new(agent: A) -> Self {
        Self {
            agent,
            run: None,
            stats: StatsTracker::new(),
            config: SessionConfig::default(),
            best_eval_reward: f64::NEG_INFINITY,
            start_time: Instant::now(),
            _env: std::marker::PhantomData,
        }
    }

    /// Attaches an on-disk run, enabling checkpointing and episode logging.
    pub fn with_run(mut self, run: TrainingRun) -> Self {
        self.run = Some(run);
        self
    }

    /// Sets the step limit checked by `is_done`.
    pub fn with_max_steps(mut self, n: usize) -> Self {
        self.config.max_steps = n;
        self
    }

    /// Sets how often (in agent steps) `observe` saves checkpoints.
    pub fn with_checkpoint_freq(mut self, freq: usize) -> Self {
        self.config.checkpoint_freq = freq;
        self
    }

    /// Sets how many step-numbered checkpoints to retain when pruning.
    pub fn with_keep_checkpoints(mut self, keep: usize) -> Self {
        self.config.keep_checkpoints = keep;
        self
    }

    /// Replaces the stats tracker (e.g. to resume with prior history).
    pub fn with_stats(mut self, stats: StatsTracker) -> Self {
        self.stats = stats;
        self
    }

    /// Chooses an action for `obs` under the given `ActMode`.
    /// Pure delegation to the agent.
    pub fn act(&mut self, obs: &E::Observation, mode: ActMode) -> E::Action {
        self.agent.act(obs, mode)
    }

    /// Feeds one experience to the agent, then checkpoints if the agent's
    /// total step count has hit a multiple of `checkpoint_freq`.
    ///
    /// Save/prune errors are deliberately ignored (`let _`): checkpointing
    /// is best-effort and must not abort training.
    pub fn observe(&mut self, experience: Experience<E::Observation, E::Action>) {
        self.agent.observe(experience);

        let total = self.agent.total_steps();
        // `total > 0` guards against checkpointing before any step;
        // note `is_multiple_of(0)` is false for nonzero totals, so a
        // zero `checkpoint_freq` effectively disables checkpointing.
        if total > 0 && total.is_multiple_of(self.config.checkpoint_freq) {
            if let Some(run) = &mut self.run {
                // `with_extension("")` strips the path's extension —
                // presumably the agent's `save` appends its own format
                // extension(s). TODO(review): confirm against agent impls.
                let path = run.checkpoint_path(total).with_extension("");
                let _ = self.agent.save(&path);
                // Also refresh the "latest" alias checkpoint.
                let latest = run.latest_checkpoint_path().with_extension("");
                let _ = self.agent.save(&latest);
                // Drop old step-numbered checkpoints beyond the keep limit.
                let _ = run.prune_checkpoints(self.config.keep_checkpoints);
            }
        }
    }

    /// Records a completed training episode: merges env- and agent-provided
    /// extras into a record, updates stats, logs to the run, and notifies
    /// the agent.
    pub fn on_episode(
        &mut self,
        total_reward: f64,
        steps: usize,
        status: rl_traits::EpisodeStatus,
        env_extras: HashMap<String, f64>,
    ) {
        let agent_extras = self.agent.episode_extras();
        // Agent extras are applied after env extras — presumably agent keys
        // win on collision; TODO(review): confirm `with_extras` merge order.
        let record = EpisodeRecord::new(total_reward, steps, status)
            .with_extras(env_extras)
            .with_extras(agent_extras);

        self.stats.update(&record);

        if let Some(run) = &mut self.run {
            let _ = run.log_train_episode(&record);
            // NOTE(review): second argument is hardcoded to 0 — looks like
            // an episode/eval counter that is never threaded through; verify
            // against `TrainingRun::update_metadata`.
            let _ = run.update_metadata(self.agent.total_steps(), 0);
        }

        // NOTE(review): called at episode *end*, apparently to prime the
        // agent for the next episode — confirm this intent.
        self.agent.on_episode_start();
    }

    /// Notifies the agent that a new episode is starting.
    pub fn on_episode_start(&mut self) {
        self.agent.on_episode_start();
    }

    /// Total environment steps taken, as counted by the agent.
    pub fn total_steps(&self) -> usize {
        self.agent.total_steps()
    }

    /// Average steps per wall-clock second since session construction.
    /// Returns 0.0 in the first microsecond to avoid dividing by ~zero.
    pub fn steps_per_sec(&self) -> f64 {
        let elapsed = self.start_time.elapsed().as_secs_f64();
        if elapsed < 1e-6 { return 0.0; }
        self.agent.total_steps() as f64 / elapsed
    }

    /// True once the configured step limit has been reached.
    /// `max_steps == usize::MAX` is the "no limit" sentinel and never
    /// reports done.
    pub fn is_done(&self) -> bool {
        self.config.max_steps != usize::MAX
            && self.agent.total_steps() >= self.config.max_steps
    }

    /// Logs an evaluation episode to the run, if one is attached
    /// (best-effort; errors ignored).
    pub fn on_eval_episode(&self, record: &EpisodeRecord) {
        if let Some(run) = &self.run {
            let _ = run.log_eval_episode(record, self.agent.total_steps());
        }
    }

    /// Saves a "best" checkpoint when `mean_reward` strictly improves on
    /// the best eval reward seen so far. Note a NaN `mean_reward` never
    /// compares greater and is therefore ignored.
    pub fn maybe_save_best(&mut self, mean_reward: f64) {
        if mean_reward > self.best_eval_reward {
            self.best_eval_reward = mean_reward;
            if let Some(run) = &self.run {
                // Extension stripped as in `observe` — see note there.
                let best = run.best_checkpoint_path().with_extension("");
                let _ = self.agent.save(&best);
            }
        }
    }

    /// Shared access to the agent.
    pub fn agent(&self) -> &A {
        &self.agent
    }

    /// Mutable access to the agent.
    pub fn agent_mut(&mut self) -> &mut A {
        &mut self.agent
    }

    /// Summary of accumulated training statistics.
    pub fn stats_summary(&self) -> HashMap<String, f64> {
        self.stats.summary()
    }

    /// The attached run, if any.
    pub fn run(&self) -> Option<&TrainingRun> {
        self.run.as_ref()
    }

    /// Consumes the session, returning the (trained) agent.
    pub fn into_agent(self) -> A {
        self.agent
    }

    /// Builds an eval report from the current step count, the given episode
    /// count, and the current stats summary.
    pub fn eval_report(&self, n_episodes: usize) -> EvalReport {
        EvalReport::new(self.agent.total_steps(), n_episodes, self.stats.summary())
    }
}