Skip to main content

ember_rl/
stats.rs

1use std::collections::HashMap;
2
3use rl_traits::EpisodeStatus;
4
5// ── EpisodeStatus serde ───────────────────────────────────────────────────────
6// EpisodeStatus lives in rl-traits which doesn't derive serde.
7// We use a string-based representation via #[serde(with)].
8
9mod episode_status_serde {
10    use rl_traits::EpisodeStatus;
11    use serde::{Deserialize, Deserializer, Serialize, Serializer};
12
13    pub fn serialize<S: Serializer>(status: &EpisodeStatus, s: S) -> Result<S::Ok, S::Error> {
14        let tag = match status {
15            EpisodeStatus::Continuing => "Continuing",
16            EpisodeStatus::Terminated => "Terminated",
17            EpisodeStatus::Truncated => "Truncated",
18        };
19        tag.serialize(s)
20    }
21
22    pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<EpisodeStatus, D::Error> {
23        let tag = String::deserialize(d)?;
24        match tag.as_str() {
25            "Continuing" => Ok(EpisodeStatus::Continuing),
26            "Terminated" => Ok(EpisodeStatus::Terminated),
27            "Truncated" => Ok(EpisodeStatus::Truncated),
28            other => Err(serde::de::Error::unknown_variant(other, &["Continuing", "Terminated", "Truncated"])),
29        }
30    }
31}
32
33// ── Records ──────────────────────────────────────────────────────────────────
34
35/// Stats recorded for a single completed episode.
36///
37/// Produced by the runner at every episode boundary and fed to `StatsTracker`.
38/// The `extras` map holds any additional per-episode scalars the user wants
39/// to track (e.g. custom environment metrics).
40#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
41pub struct EpisodeRecord {
42    /// Total undiscounted reward for the episode.
43    pub total_reward: f64,
44
45    /// Number of steps in the episode.
46    pub length: usize,
47
48    /// How the episode ended.
49    #[serde(with = "episode_status_serde")]
50    pub status: EpisodeStatus,
51
52    /// Arbitrary scalar extras provided by the user or environment.
53    pub extras: HashMap<String, f64>,
54}
55
56impl EpisodeRecord {
57    pub fn new(total_reward: f64, length: usize, status: EpisodeStatus) -> Self {
58        Self {
59            total_reward,
60            length,
61            status,
62            extras: HashMap::new(),
63        }
64    }
65
66    pub fn with_extra(mut self, key: impl Into<String>, value: f64) -> Self {
67        self.extras.insert(key.into(), value);
68        self
69    }
70
71    pub fn with_extras(mut self, extras: HashMap<String, f64>) -> Self {
72        self.extras.extend(extras);
73        self
74    }
75}
76
77// ── Aggregator trait ─────────────────────────────────────────────────────────
78
/// Accumulates a stream of `f64` values into a single summary statistic.
///
/// Implement this to add custom aggregators. Built-ins: [`Mean`], [`Max`],
/// [`Min`], [`Last`], [`RollingMean`], [`Std`].
///
/// The `Send + Sync` supertraits are required because aggregators are stored
/// as `Box<dyn Aggregator>` by the stats tracker.
pub trait Aggregator: Send + Sync {
    /// Record a new value.
    fn update(&mut self, value: f64);

    /// Return the current aggregate. `f64::NAN` if no values have been seen.
    fn value(&self) -> f64;

    /// Clear all accumulated values, as if freshly constructed.
    fn reset(&mut self);
}
93
94// ── Built-in aggregators ──────────────────────────────────────────────────────
95
96/// Running mean over all values seen since the last reset.
97#[derive(Debug, Clone, Default)]
98pub struct Mean {
99    sum: f64,
100    count: usize,
101}
102
103impl Aggregator for Mean {
104    fn update(&mut self, value: f64) {
105        self.sum += value;
106        self.count += 1;
107    }
108
109    fn value(&self) -> f64 {
110        if self.count == 0 { f64::NAN } else { self.sum / self.count as f64 }
111    }
112
113    fn reset(&mut self) {
114        self.sum = 0.0;
115        self.count = 0;
116    }
117}
118
119/// Maximum value seen since the last reset.
120#[derive(Debug, Clone, Default)]
121pub struct Max {
122    max: Option<f64>,
123}
124
125impl Aggregator for Max {
126    fn update(&mut self, value: f64) {
127        self.max = Some(self.max.map_or(value, |m| m.max(value)));
128    }
129
130    fn value(&self) -> f64 {
131        self.max.unwrap_or(f64::NAN)
132    }
133
134    fn reset(&mut self) {
135        self.max = None;
136    }
137}
138
139/// Minimum value seen since the last reset.
140#[derive(Debug, Clone, Default)]
141pub struct Min {
142    min: Option<f64>,
143}
144
145impl Aggregator for Min {
146    fn update(&mut self, value: f64) {
147        self.min = Some(self.min.map_or(value, |m| m.min(value)));
148    }
149
150    fn value(&self) -> f64 {
151        self.min.unwrap_or(f64::NAN)
152    }
153
154    fn reset(&mut self) {
155        self.min = None;
156    }
157}
158
159/// The most recent value, ignoring history.
160#[derive(Debug, Clone, Default)]
161pub struct Last {
162    last: Option<f64>,
163}
164
165impl Aggregator for Last {
166    fn update(&mut self, value: f64) {
167        self.last = Some(value);
168    }
169
170    fn value(&self) -> f64 {
171        self.last.unwrap_or(f64::NAN)
172    }
173
174    fn reset(&mut self) {
175        self.last = None;
176    }
177}
178
179/// Mean over the last `window` values (sliding window).
180#[derive(Debug, Clone)]
181pub struct RollingMean {
182    window: usize,
183    buf: std::collections::VecDeque<f64>,
184}
185
186impl RollingMean {
187    pub fn new(window: usize) -> Self {
188        assert!(window > 0, "window must be > 0");
189        Self { window, buf: std::collections::VecDeque::with_capacity(window) }
190    }
191}
192
193impl Aggregator for RollingMean {
194    fn update(&mut self, value: f64) {
195        if self.buf.len() == self.window {
196            self.buf.pop_front();
197        }
198        self.buf.push_back(value);
199    }
200
201    fn value(&self) -> f64 {
202        if self.buf.is_empty() {
203            f64::NAN
204        } else {
205            self.buf.iter().sum::<f64>() / self.buf.len() as f64
206        }
207    }
208
209    fn reset(&mut self) {
210        self.buf.clear();
211    }
212}
213
214/// Population standard deviation over all values seen since the last reset.
215///
216/// Uses Welford's online algorithm for numerical stability.
217#[derive(Debug, Clone, Default)]
218pub struct Std {
219    count: usize,
220    mean: f64,
221    m2: f64,
222}
223
224impl Aggregator for Std {
225    fn update(&mut self, value: f64) {
226        self.count += 1;
227        let delta = value - self.mean;
228        self.mean += delta / self.count as f64;
229        let delta2 = value - self.mean;
230        self.m2 += delta * delta2;
231    }
232
233    fn value(&self) -> f64 {
234        if self.count < 2 { f64::NAN } else { (self.m2 / self.count as f64).sqrt() }
235    }
236
237    fn reset(&mut self) {
238        self.count = 0;
239        self.mean = 0.0;
240        self.m2 = 0.0;
241    }
242}
243
244// ── StatsTracker ──────────────────────────────────────────────────────────────
245
/// One named statistic tracked by [`StatsTracker`]: how to pull a scalar out
/// of each [`EpisodeRecord`] and how to aggregate the resulting stream.
struct TrackedStat {
    /// Key under which the aggregate appears in `StatsTracker::summary`.
    name: String,
    /// Maps an episode record to the value fed into `aggregator`.
    extractor: Box<dyn Fn(&EpisodeRecord) -> f64 + Send + Sync>,
    /// Accumulates the extracted values.
    aggregator: Box<dyn Aggregator>,
}
251
/// Accumulates per-episode stats and reports summary aggregates.
///
/// `StatsTracker::new()` (and `Default`) tracks `episode_reward` (mean) and
/// `episode_length` (mean); `StatsTracker::empty()` starts with no stats.
/// Use the builder methods to add custom stats or change aggregators.
///
/// # Usage
///
/// ```rust,ignore
/// let mut tracker = StatsTracker::new()
///     .with("episode_reward_max", StatSource::TotalReward, Max::default())
///     .with_custom("ep_len_last10", |r| r.length as f64, RollingMean::new(10));
///
/// // Feed an episode record:
/// tracker.update(&record);
///
/// // Print summary:
/// let summary = tracker.summary();
/// println!("mean reward: {:.1}", summary["episode_reward"]);
/// ```
pub struct StatsTracker {
    /// Tracked stats, in the order they were registered.
    stats: Vec<TrackedStat>,
}
274
/// Predefined sources for `StatsTracker::with`.
///
/// Each variant selects one scalar from an `EpisodeRecord`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum StatSource {
    /// `EpisodeRecord::total_reward`
    TotalReward,
    /// `EpisodeRecord::length` cast to `f64`
    Length,
    /// A key from `EpisodeRecord::extras`; episodes lacking the key yield `f64::NAN`.
    Extra(String),
}
284
285impl StatSource {
286    fn into_extractor(self) -> Box<dyn Fn(&EpisodeRecord) -> f64 + Send + Sync> {
287        match self {
288            StatSource::TotalReward => Box::new(|r: &EpisodeRecord| r.total_reward),
289            StatSource::Length => Box::new(|r: &EpisodeRecord| r.length as f64),
290            StatSource::Extra(key) => Box::new(move |r: &EpisodeRecord| {
291                r.extras.get(&key).copied().unwrap_or(f64::NAN)
292            }),
293        }
294    }
295}
296
297impl StatsTracker {
298    /// Create a tracker with the default stats: episode_reward (mean) and episode_length (mean).
299    pub fn new() -> Self {
300        let mut tracker = Self { stats: Vec::new() };
301        tracker = tracker.with("episode_reward", StatSource::TotalReward, Mean::default());
302        tracker = tracker.with("episode_length", StatSource::Length, Mean::default());
303        tracker
304    }
305
306    /// Create a tracker with no default stats.
307    pub fn empty() -> Self {
308        Self { stats: Vec::new() }
309    }
310
311    /// Track a predefined stat field with the given aggregator.
312    pub fn with(mut self, name: impl Into<String>, source: StatSource, aggregator: impl Aggregator + 'static) -> Self {
313        self.stats.push(TrackedStat {
314            name: name.into(),
315            extractor: source.into_extractor(),
316            aggregator: Box::new(aggregator),
317        });
318        self
319    }
320
321    /// Track an arbitrary derived value from each episode record.
322    pub fn with_custom(
323        mut self,
324        name: impl Into<String>,
325        f: impl Fn(&EpisodeRecord) -> f64 + Send + Sync + 'static,
326        aggregator: impl Aggregator + 'static,
327    ) -> Self {
328        self.stats.push(TrackedStat {
329            name: name.into(),
330            extractor: Box::new(f),
331            aggregator: Box::new(aggregator),
332        });
333        self
334    }
335
336    /// Feed a completed episode into all tracked stats.
337    pub fn update(&mut self, record: &EpisodeRecord) {
338        for stat in &mut self.stats {
339            let value = (stat.extractor)(record);
340            stat.aggregator.update(value);
341        }
342    }
343
344    /// Snapshot of all current aggregate values.
345    pub fn summary(&self) -> HashMap<String, f64> {
346        self.stats.iter()
347            .map(|s| (s.name.clone(), s.aggregator.value()))
348            .collect()
349    }
350
351    /// Reset all aggregators (e.g. between eval runs).
352    pub fn reset(&mut self) {
353        for stat in &mut self.stats {
354            stat.aggregator.reset();
355        }
356    }
357}
358
359impl Default for StatsTracker {
360    fn default() -> Self {
361        Self::new()
362    }
363}
364
365// ── EvalReport ────────────────────────────────────────────────────────────────
366
/// Summary statistics from a single evaluation run.
///
/// Returned by `DqnTrainer::eval()`. Contains the aggregated stats from
/// all eval episodes plus the step count at which eval was performed.
#[derive(Debug, Clone)]
pub struct EvalReport {
    /// Total agent steps at the time of evaluation.
    pub total_steps: usize,

    /// Number of episodes evaluated.
    pub n_episodes: usize,

    /// Aggregated statistics (same keys as the `StatsTracker` that produced this).
    ///
    /// Values are whatever each aggregator reported, so an entry may be
    /// `f64::NAN` if its aggregator saw no values.
    pub stats: HashMap<String, f64>,
}
382
383impl EvalReport {
384    pub fn new(total_steps: usize, n_episodes: usize, stats: HashMap<String, f64>) -> Self {
385        Self { total_steps, n_episodes, stats }
386    }
387
388    /// Pretty-print all stats to stdout.
389    pub fn print(&self) {
390        println!("=== Eval @ step {} ({} episodes) ===", self.total_steps, self.n_episodes);
391        let mut keys: Vec<_> = self.stats.keys().collect();
392        keys.sort();
393        for key in keys {
394            println!("  {}: {:.3}", key, self.stats[key]);
395        }
396    }
397}
398