Skip to main content

ember_rl/training/
run.rs

1use std::fs::{self, OpenOptions};
2use std::io::Write;
3use std::path::{Path, PathBuf};
4
5use chrono::Local;
6use serde::{Deserialize, Serialize};
7
8use crate::stats::EpisodeRecord;
9
10// ── Metadata ──────────────────────────────────────────────────────────────────
11
/// Persisted metadata for a training run.
///
/// Serialized as pretty-printed JSON to `metadata.json` in the run directory.
/// It is first written by `TrainingRun::create` and rewritten every time
/// `TrainingRun::update_metadata` is called (intended cadence: once per
/// checkpoint). `TrainingRun::resume` reads it back.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RunMetadata {
    /// Human-readable name (e.g. `"cartpole"`); the `<name>` path component.
    pub name: String,

    /// Version string (e.g. `"v1"`, `"baseline"`); the `<version>` path component.
    pub version: String,

    /// Timestamp string used as the run directory name (`YYYYMMDD_HHMMSS`, local time).
    pub run_id: String,

    /// Total environment steps as of the last metadata update (0 for a fresh run).
    pub total_steps: usize,

    /// Total episodes completed as of the last metadata update (0 for a fresh run).
    pub total_episodes: usize,

    /// Creation time of this run, as an RFC 3339 (ISO 8601 profile) string
    /// carrying the local UTC offset.
    pub started_at: String,

    /// Time of the last metadata update, in the same format as `started_at`.
    pub last_updated: String,
}
39
40// ── EvalEntry (JSONL row for eval episodes) ────────────────────────────────────
41
/// One line of `eval_episodes.jsonl`: an `EpisodeRecord` tagged with the
/// training step count at which the evaluation ran.
///
/// `#[serde(flatten)]` inlines the record's fields next to
/// `total_steps_at_eval`, producing one flat JSON object per line instead of a
/// nested `{"total_steps_at_eval": .., "record": {..}}` (assumes
/// `EpisodeRecord` serializes as a struct/map, which `flatten` requires).
// NOTE(review): if `EpisodeRecord` ever gains a field named
// `total_steps_at_eval`, the flattened output would contain a duplicate key —
// confirm against `crate::stats`.
#[derive(Serialize)]
struct EvalEntry<'a> {
    /// The `total_steps` value passed to `TrainingRun::log_eval_episode`.
    total_steps_at_eval: usize,
    /// The evaluation episode's record, serialized inline (see `flatten` above).
    #[serde(flatten)]
    record: &'a EpisodeRecord,
}
48
49// ── TrainingRun ───────────────────────────────────────────────────────────────
50
51/// Manages the on-disk artefacts for a single training run.
52///
53/// Directory layout:
54/// ```text
55/// runs/<name>/<version>/<YYYYMMDD_HHMMSS>/
56///     metadata.json          ← name, version, step counts, timestamps
57///     config.json            ← serialized hyperparams (written by caller)
58///     checkpoints/
59///         step_<N>.mpk       ← periodic checkpoints
60///         latest.mpk         ← symlink-equivalent: overwritten each checkpoint
61///         best.mpk           ← best eval-reward checkpoint
62///     train_episodes.jsonl   ← one EpisodeRecord per line (training)
63///     eval_episodes.jsonl    ← one tagged EpisodeRecord per line (eval)
64/// ```
65///
66/// `TrainingRun` is **not** generic over the neural network backend. It manages
67/// directories and JSON; the caller (e.g. `DqnTrainer`) handles actual
68/// network serialization by using the paths returned by the checkpoint methods.
69///
70/// # Usage
71///
72/// ```rust,ignore
73/// // Start a new run
74/// let run = TrainingRun::create("cartpole", "v1")?;
75/// run.write_config(&(&config, &encoder, &mapper))?;
76///
77/// // During training
78/// run.log_train_episode(&episode_record)?;
79/// run.update_metadata(total_steps, total_episodes)?;
80/// // (save network to run.checkpoint_path(step) yourself)
81///
82/// // Resume
83/// let run = TrainingRun::resume("runs/cartpole/v1")?; // picks latest
84/// ```
85pub struct TrainingRun {
86    /// Root directory for this run (`.../runs/<name>/<version>/<run_id>/`).
87    dir: PathBuf,
88
89    /// Loaded/created metadata.
90    pub metadata: RunMetadata,
91}
92
93impl TrainingRun {
94    // ── Constructors ──────────────────────────────────────────────────────────
95
96    /// Create a brand-new run directory under `runs/<name>/<version>/<timestamp>/`.
97    ///
98    /// Returns an error if the directory cannot be created or metadata cannot
99    /// be written.
100    pub fn create(name: impl Into<String>, version: impl Into<String>) -> std::io::Result<Self> {
101        let name = name.into();
102        let version = version.into();
103        let run_id = Local::now().format("%Y%m%d_%H%M%S").to_string();
104        let now = Local::now().to_rfc3339();
105
106        let dir = PathBuf::from("runs")
107            .join(&name)
108            .join(&version)
109            .join(&run_id);
110
111        fs::create_dir_all(dir.join("checkpoints"))?;
112
113        let metadata = RunMetadata {
114            name,
115            version,
116            run_id,
117            total_steps: 0,
118            total_episodes: 0,
119            started_at: now.clone(),
120            last_updated: now,
121        };
122
123        let run = Self { dir, metadata };
124        run.write_metadata()?;
125        Ok(run)
126    }
127
128    /// Resume the most recent run found under `base_path`.
129    ///
130    /// `base_path` can be:
131    /// - An exact run directory (`runs/cartpole/v1/20260322_120000`) -- used directly.
132    /// - A name/version directory (`runs/cartpole/v1`) -- picks the lexicographically
133    ///   latest subdirectory (timestamps sort correctly).
134    /// - A name directory (`runs/cartpole`) -- picks latest version, then latest run.
135    ///
136    /// Returns an error if no run is found or `metadata.json` is missing/corrupt.
137    pub fn resume(base_path: impl AsRef<Path>) -> std::io::Result<Self> {
138        let dir = Self::resolve_latest(base_path.as_ref())?;
139        let metadata_path = dir.join("metadata.json");
140        let raw = fs::read_to_string(&metadata_path)?;
141        let metadata: RunMetadata = serde_json::from_str(&raw)
142            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
143        Ok(Self { dir, metadata })
144    }
145
146    /// The root directory of this run.
147    pub fn dir(&self) -> &Path {
148        &self.dir
149    }
150
151    // ── Config ────────────────────────────────────────────────────────────────
152
153    /// Write an arbitrary serialisable value to `config.json`.
154    ///
155    /// Typically called once after `create` with a tuple of
156    /// `(&config, &encoder, &action_mapper)`.
157    pub fn write_config<T: Serialize>(&self, config: &T) -> std::io::Result<()> {
158        let json = serde_json::to_string_pretty(config)
159            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
160        fs::write(self.dir.join("config.json"), json)
161    }
162
163    // ── Checkpoint paths ──────────────────────────────────────────────────────
164
165    /// Path for a numbered checkpoint: `checkpoints/step_<N>.mpk`.
166    ///
167    /// Pass this to `DqnAgent::save` (or `network.save_file`).
168    pub fn checkpoint_path(&self, step: usize) -> PathBuf {
169        self.dir.join("checkpoints").join(format!("step_{}.mpk", step))
170    }
171
172    /// Path for the rolling "latest" checkpoint: `checkpoints/latest.mpk`.
173    ///
174    /// Overwrite this on every checkpoint save so users can always resume
175    /// from the most recent state without knowing the step number.
176    pub fn latest_checkpoint_path(&self) -> PathBuf {
177        self.dir.join("checkpoints").join("latest.mpk")
178    }
179
180    /// Path for the best-eval-reward checkpoint: `checkpoints/best.mpk`.
181    pub fn best_checkpoint_path(&self) -> PathBuf {
182        self.dir.join("checkpoints").join("best.mpk")
183    }
184
185    /// Delete old numbered checkpoints, keeping the `keep` most recent.
186    ///
187    /// `latest.mpk` and `best.mpk` are never deleted.
188    pub fn prune_checkpoints(&self, keep: usize) -> std::io::Result<()> {
189        let ckpt_dir = self.dir.join("checkpoints");
190        let mut numbered: Vec<PathBuf> = fs::read_dir(&ckpt_dir)?
191            .filter_map(|e| e.ok())
192            .map(|e| e.path())
193            .filter(|p| {
194                p.file_name()
195                    .and_then(|n| n.to_str())
196                    .map(|n| n.starts_with("step_") && n.ends_with(".mpk"))
197                    .unwrap_or(false)
198            })
199            .collect();
200
201        // Sort lexicographically -- step_ prefix + zero-padded or not: sort by step number
202        numbered.sort_by_key(|p| {
203            p.file_stem()
204                .and_then(|s| s.to_str())
205                .and_then(|s| s.strip_prefix("step_"))
206                .and_then(|s| s.parse::<usize>().ok())
207                .unwrap_or(0)
208        });
209
210        let to_delete = numbered.len().saturating_sub(keep);
211        for path in numbered.into_iter().take(to_delete) {
212            fs::remove_file(path)?;
213        }
214        Ok(())
215    }
216
217    // ── Stats logging ─────────────────────────────────────────────────────────
218
219    /// Append an episode record to `train_episodes.jsonl`.
220    pub fn log_train_episode(&self, record: &EpisodeRecord) -> std::io::Result<()> {
221        self.append_jsonl("train_episodes.jsonl", record)
222    }
223
224    /// Append an episode record (tagged with `total_steps_at_eval`) to `eval_episodes.jsonl`.
225    pub fn log_eval_episode(&self, record: &EpisodeRecord, total_steps: usize) -> std::io::Result<()> {
226        let entry = EvalEntry { total_steps_at_eval: total_steps, record };
227        self.append_jsonl("eval_episodes.jsonl", &entry)
228    }
229
230    // ── Metadata ──────────────────────────────────────────────────────────────
231
232    /// Update step/episode counts and `last_updated` timestamp, then flush to disk.
233    pub fn update_metadata(&mut self, total_steps: usize, total_episodes: usize) -> std::io::Result<()> {
234        self.metadata.total_steps = total_steps;
235        self.metadata.total_episodes = total_episodes;
236        self.metadata.last_updated = Local::now().to_rfc3339();
237        self.write_metadata()
238    }
239
240    // ── Private helpers ───────────────────────────────────────────────────────
241
242    fn write_metadata(&self) -> std::io::Result<()> {
243        let json = serde_json::to_string_pretty(&self.metadata)
244            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
245        fs::write(self.dir.join("metadata.json"), json)
246    }
247
248    fn append_jsonl<T: Serialize>(&self, filename: &str, value: &T) -> std::io::Result<()> {
249        let line = serde_json::to_string(value)
250            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
251        let mut file = OpenOptions::new()
252            .create(true)
253            .append(true)
254            .open(self.dir.join(filename))?;
255        writeln!(file, "{}", line)
256    }
257
258    /// Walk `path` downward, always picking the lexicographically latest child
259    /// directory until we find one that contains `metadata.json`.
260    fn resolve_latest(path: &Path) -> std::io::Result<PathBuf> {
261        if path.join("metadata.json").exists() {
262            return Ok(path.to_path_buf());
263        }
264
265        let latest = Self::latest_subdir(path)?;
266        Self::resolve_latest(&latest)
267    }
268
269    fn latest_subdir(dir: &Path) -> std::io::Result<PathBuf> {
270        let mut subdirs: Vec<PathBuf> = fs::read_dir(dir)?
271            .filter_map(|e| e.ok())
272            .map(|e| e.path())
273            .filter(|p| p.is_dir())
274            .collect();
275
276        if subdirs.is_empty() {
277            return Err(std::io::Error::new(
278                std::io::ErrorKind::NotFound,
279                format!("no subdirectories in {}", dir.display()),
280            ));
281        }
282
283        subdirs.sort();
284        Ok(subdirs.pop().unwrap())
285    }
286}