ember_rl/training/run.rs
1use std::fs::{self, OpenOptions};
2use std::io::Write;
3use std::path::{Path, PathBuf};
4
5use chrono::Local;
6use serde::{Deserialize, Serialize};
7
8use crate::stats::EpisodeRecord;
9
10// ── Metadata ──────────────────────────────────────────────────────────────────
11
12/// Persisted metadata for a training run.
13///
14/// Written to `metadata.json` in the run directory. Updated after each
15/// checkpoint via `TrainingRun::update_metadata`.
16#[derive(Debug, Clone, Serialize, Deserialize)]
17pub struct RunMetadata {
18 /// Human-readable name (e.g. `"cartpole"`).
19 pub name: String,
20
21 /// Version string (e.g. `"v1"`, `"baseline"`).
22 pub version: String,
23
24 /// Timestamp string used as the run directory name (`YYYYMMDD_HHMMSS`).
25 pub run_id: String,
26
27 /// Total environment steps at last update.
28 pub total_steps: usize,
29
30 /// Total episodes completed at last update.
31 pub total_episodes: usize,
32
33 /// ISO-8601 datetime this run was created.
34 pub started_at: String,
35
36 /// ISO-8601 datetime of the last metadata update.
37 pub last_updated: String,
38}
39
40// ── EvalEntry (JSONL row for eval episodes) ────────────────────────────────────
41
42#[derive(Serialize)]
43struct EvalEntry<'a> {
44 total_steps_at_eval: usize,
45 #[serde(flatten)]
46 record: &'a EpisodeRecord,
47}
48
49// ── TrainingRun ───────────────────────────────────────────────────────────────
50
51/// Manages the on-disk artefacts for a single training run.
52///
53/// Directory layout:
54/// ```text
55/// runs/<name>/<version>/<YYYYMMDD_HHMMSS>/
56/// metadata.json ← name, version, step counts, timestamps
57/// config.json ← serialized hyperparams (written by caller)
58/// checkpoints/
59/// step_<N>.mpk ← periodic checkpoints
60/// latest.mpk ← symlink-equivalent: overwritten each checkpoint
61/// best.mpk ← best eval-reward checkpoint
62/// train_episodes.jsonl ← one EpisodeRecord per line (training)
63/// eval_episodes.jsonl ← one tagged EpisodeRecord per line (eval)
64/// ```
65///
66/// `TrainingRun` is **not** generic over the neural network backend. It manages
67/// directories and JSON; the caller (e.g. `DqnTrainer`) handles actual
68/// network serialization by using the paths returned by the checkpoint methods.
69///
70/// # Usage
71///
72/// ```rust,ignore
73/// // Start a new run
74/// let run = TrainingRun::create("cartpole", "v1")?;
75/// run.write_config(&(&config, &encoder, &mapper))?;
76///
77/// // During training
78/// run.log_train_episode(&episode_record)?;
79/// run.update_metadata(total_steps, total_episodes)?;
80/// // (save network to run.checkpoint_path(step) yourself)
81///
82/// // Resume
83/// let run = TrainingRun::resume("runs/cartpole/v1")?; // picks latest
84/// ```
85pub struct TrainingRun {
86 /// Root directory for this run (`.../runs/<name>/<version>/<run_id>/`).
87 dir: PathBuf,
88
89 /// Loaded/created metadata.
90 pub metadata: RunMetadata,
91}
92
93impl TrainingRun {
94 // ── Constructors ──────────────────────────────────────────────────────────
95
96 /// Create a brand-new run directory under `runs/<name>/<version>/<timestamp>/`.
97 ///
98 /// Returns an error if the directory cannot be created or metadata cannot
99 /// be written.
100 pub fn create(name: impl Into<String>, version: impl Into<String>) -> std::io::Result<Self> {
101 let name = name.into();
102 let version = version.into();
103 let run_id = Local::now().format("%Y%m%d_%H%M%S").to_string();
104 let now = Local::now().to_rfc3339();
105
106 let dir = PathBuf::from("runs")
107 .join(&name)
108 .join(&version)
109 .join(&run_id);
110
111 fs::create_dir_all(dir.join("checkpoints"))?;
112
113 let metadata = RunMetadata {
114 name,
115 version,
116 run_id,
117 total_steps: 0,
118 total_episodes: 0,
119 started_at: now.clone(),
120 last_updated: now,
121 };
122
123 let run = Self { dir, metadata };
124 run.write_metadata()?;
125 Ok(run)
126 }
127
128 /// Resume the most recent run found under `base_path`.
129 ///
130 /// `base_path` can be:
131 /// - An exact run directory (`runs/cartpole/v1/20260322_120000`) -- used directly.
132 /// - A name/version directory (`runs/cartpole/v1`) -- picks the lexicographically
133 /// latest subdirectory (timestamps sort correctly).
134 /// - A name directory (`runs/cartpole`) -- picks latest version, then latest run.
135 ///
136 /// Returns an error if no run is found or `metadata.json` is missing/corrupt.
137 pub fn resume(base_path: impl AsRef<Path>) -> std::io::Result<Self> {
138 let dir = Self::resolve_latest(base_path.as_ref())?;
139 let metadata_path = dir.join("metadata.json");
140 let raw = fs::read_to_string(&metadata_path)?;
141 let metadata: RunMetadata = serde_json::from_str(&raw)
142 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
143 Ok(Self { dir, metadata })
144 }
145
146 /// The root directory of this run.
147 pub fn dir(&self) -> &Path {
148 &self.dir
149 }
150
151 // ── Config ────────────────────────────────────────────────────────────────
152
153 /// Write an arbitrary serialisable value to `config.json`.
154 ///
155 /// Typically called once after `create` with a tuple of
156 /// `(&config, &encoder, &action_mapper)`.
157 pub fn write_config<T: Serialize>(&self, config: &T) -> std::io::Result<()> {
158 let json = serde_json::to_string_pretty(config)
159 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
160 fs::write(self.dir.join("config.json"), json)
161 }
162
163 // ── Checkpoint paths ──────────────────────────────────────────────────────
164
165 /// Path for a numbered checkpoint: `checkpoints/step_<N>.mpk`.
166 ///
167 /// Pass this to `DqnAgent::save` (or `network.save_file`).
168 pub fn checkpoint_path(&self, step: usize) -> PathBuf {
169 self.dir.join("checkpoints").join(format!("step_{}.mpk", step))
170 }
171
172 /// Path for the rolling "latest" checkpoint: `checkpoints/latest.mpk`.
173 ///
174 /// Overwrite this on every checkpoint save so users can always resume
175 /// from the most recent state without knowing the step number.
176 pub fn latest_checkpoint_path(&self) -> PathBuf {
177 self.dir.join("checkpoints").join("latest.mpk")
178 }
179
180 /// Path for the best-eval-reward checkpoint: `checkpoints/best.mpk`.
181 pub fn best_checkpoint_path(&self) -> PathBuf {
182 self.dir.join("checkpoints").join("best.mpk")
183 }
184
185 /// Delete old numbered checkpoints, keeping the `keep` most recent.
186 ///
187 /// `latest.mpk` and `best.mpk` are never deleted.
188 pub fn prune_checkpoints(&self, keep: usize) -> std::io::Result<()> {
189 let ckpt_dir = self.dir.join("checkpoints");
190 let mut numbered: Vec<PathBuf> = fs::read_dir(&ckpt_dir)?
191 .filter_map(|e| e.ok())
192 .map(|e| e.path())
193 .filter(|p| {
194 p.file_name()
195 .and_then(|n| n.to_str())
196 .map(|n| n.starts_with("step_") && n.ends_with(".mpk"))
197 .unwrap_or(false)
198 })
199 .collect();
200
201 // Sort lexicographically -- step_ prefix + zero-padded or not: sort by step number
202 numbered.sort_by_key(|p| {
203 p.file_stem()
204 .and_then(|s| s.to_str())
205 .and_then(|s| s.strip_prefix("step_"))
206 .and_then(|s| s.parse::<usize>().ok())
207 .unwrap_or(0)
208 });
209
210 let to_delete = numbered.len().saturating_sub(keep);
211 for path in numbered.into_iter().take(to_delete) {
212 fs::remove_file(path)?;
213 }
214 Ok(())
215 }
216
217 // ── Stats logging ─────────────────────────────────────────────────────────
218
219 /// Append an episode record to `train_episodes.jsonl`.
220 pub fn log_train_episode(&self, record: &EpisodeRecord) -> std::io::Result<()> {
221 self.append_jsonl("train_episodes.jsonl", record)
222 }
223
224 /// Append an episode record (tagged with `total_steps_at_eval`) to `eval_episodes.jsonl`.
225 pub fn log_eval_episode(&self, record: &EpisodeRecord, total_steps: usize) -> std::io::Result<()> {
226 let entry = EvalEntry { total_steps_at_eval: total_steps, record };
227 self.append_jsonl("eval_episodes.jsonl", &entry)
228 }
229
230 // ── Metadata ──────────────────────────────────────────────────────────────
231
232 /// Update step/episode counts and `last_updated` timestamp, then flush to disk.
233 pub fn update_metadata(&mut self, total_steps: usize, total_episodes: usize) -> std::io::Result<()> {
234 self.metadata.total_steps = total_steps;
235 self.metadata.total_episodes = total_episodes;
236 self.metadata.last_updated = Local::now().to_rfc3339();
237 self.write_metadata()
238 }
239
240 // ── Private helpers ───────────────────────────────────────────────────────
241
242 fn write_metadata(&self) -> std::io::Result<()> {
243 let json = serde_json::to_string_pretty(&self.metadata)
244 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
245 fs::write(self.dir.join("metadata.json"), json)
246 }
247
248 fn append_jsonl<T: Serialize>(&self, filename: &str, value: &T) -> std::io::Result<()> {
249 let line = serde_json::to_string(value)
250 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
251 let mut file = OpenOptions::new()
252 .create(true)
253 .append(true)
254 .open(self.dir.join(filename))?;
255 writeln!(file, "{}", line)
256 }
257
258 /// Walk `path` downward, always picking the lexicographically latest child
259 /// directory until we find one that contains `metadata.json`.
260 fn resolve_latest(path: &Path) -> std::io::Result<PathBuf> {
261 if path.join("metadata.json").exists() {
262 return Ok(path.to_path_buf());
263 }
264
265 let latest = Self::latest_subdir(path)?;
266 Self::resolve_latest(&latest)
267 }
268
269 fn latest_subdir(dir: &Path) -> std::io::Result<PathBuf> {
270 let mut subdirs: Vec<PathBuf> = fs::read_dir(dir)?
271 .filter_map(|e| e.ok())
272 .map(|e| e.path())
273 .filter(|p| p.is_dir())
274 .collect();
275
276 if subdirs.is_empty() {
277 return Err(std::io::Error::new(
278 std::io::ErrorKind::NotFound,
279 format!("no subdirectories in {}", dir.display()),
280 ));
281 }
282
283 subdirs.sort();
284 Ok(subdirs.pop().unwrap())
285 }
286}