Skip to main content

ember_rl/
stats.rs

1use std::collections::HashMap;
2
3use rl_traits::EpisodeStatus;
4
5// ── EpisodeStatus serde ───────────────────────────────────────────────────────
6// EpisodeStatus lives in rl-traits which doesn't derive serde.
7// We use a string-based representation via #[serde(with)].
8
9mod episode_status_serde {
10    use rl_traits::EpisodeStatus;
11    use serde::{Deserialize, Deserializer, Serialize, Serializer};
12
13    pub fn serialize<S: Serializer>(status: &EpisodeStatus, s: S) -> Result<S::Ok, S::Error> {
14        let tag = match status {
15            EpisodeStatus::Continuing => "Continuing",
16            EpisodeStatus::Terminated => "Terminated",
17            EpisodeStatus::Truncated => "Truncated",
18        };
19        tag.serialize(s)
20    }
21
22    pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<EpisodeStatus, D::Error> {
23        let tag = String::deserialize(d)?;
24        match tag.as_str() {
25            "Continuing" => Ok(EpisodeStatus::Continuing),
26            "Terminated" => Ok(EpisodeStatus::Terminated),
27            "Truncated" => Ok(EpisodeStatus::Truncated),
28            other => Err(serde::de::Error::unknown_variant(other, &["Continuing", "Terminated", "Truncated"])),
29        }
30    }
31}
32
33// ── Records ──────────────────────────────────────────────────────────────────
34
35/// Stats recorded for a single completed episode.
36///
37/// Produced by the runner at every episode boundary and fed to `StatsTracker`.
38/// The `extras` map holds any additional per-episode scalars the user wants
39/// to track (e.g. custom environment metrics).
40#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
41pub struct EpisodeRecord {
42    /// Total undiscounted reward for the episode.
43    pub total_reward: f64,
44
45    /// Number of steps in the episode.
46    pub length: usize,
47
48    /// How the episode ended.
49    #[serde(with = "episode_status_serde")]
50    pub status: EpisodeStatus,
51
52    /// Arbitrary scalar extras provided by the user or environment.
53    pub extras: HashMap<String, f64>,
54}
55
56impl EpisodeRecord {
57    pub fn new(total_reward: f64, length: usize, status: EpisodeStatus) -> Self {
58        Self {
59            total_reward,
60            length,
61            status,
62            extras: HashMap::new(),
63        }
64    }
65
66    pub fn with_extra(mut self, key: impl Into<String>, value: f64) -> Self {
67        self.extras.insert(key.into(), value);
68        self
69    }
70}
71
// ── Aggregator trait ─────────────────────────────────────────────────────────

/// Folds a stream of `f64` samples into one summary number.
///
/// Implement this trait to plug a custom aggregator into a `StatsTracker`.
/// The crate ships [`Mean`], [`Max`], [`Min`], [`Last`] and [`RollingMean`].
/// Every implementor must also be `Send + Sync`.
pub trait Aggregator: Send + Sync {
    /// Feeds one sample into the aggregate.
    fn update(&mut self, value: f64);

    /// Current summary value; `f64::NAN` while no sample has been fed.
    fn value(&self) -> f64;

    /// Discards every sample, restoring the freshly-constructed state.
    fn reset(&mut self);
}
88
89// ── Built-in aggregators ──────────────────────────────────────────────────────
90
91/// Running mean over all values seen since the last reset.
92#[derive(Debug, Clone, Default)]
93pub struct Mean {
94    sum: f64,
95    count: usize,
96}
97
98impl Aggregator for Mean {
99    fn update(&mut self, value: f64) {
100        self.sum += value;
101        self.count += 1;
102    }
103
104    fn value(&self) -> f64 {
105        if self.count == 0 { f64::NAN } else { self.sum / self.count as f64 }
106    }
107
108    fn reset(&mut self) {
109        self.sum = 0.0;
110        self.count = 0;
111    }
112}
113
114/// Maximum value seen since the last reset.
115#[derive(Debug, Clone, Default)]
116pub struct Max {
117    max: Option<f64>,
118}
119
120impl Aggregator for Max {
121    fn update(&mut self, value: f64) {
122        self.max = Some(self.max.map_or(value, |m| m.max(value)));
123    }
124
125    fn value(&self) -> f64 {
126        self.max.unwrap_or(f64::NAN)
127    }
128
129    fn reset(&mut self) {
130        self.max = None;
131    }
132}
133
134/// Minimum value seen since the last reset.
135#[derive(Debug, Clone, Default)]
136pub struct Min {
137    min: Option<f64>,
138}
139
140impl Aggregator for Min {
141    fn update(&mut self, value: f64) {
142        self.min = Some(self.min.map_or(value, |m| m.min(value)));
143    }
144
145    fn value(&self) -> f64 {
146        self.min.unwrap_or(f64::NAN)
147    }
148
149    fn reset(&mut self) {
150        self.min = None;
151    }
152}
153
154/// The most recent value, ignoring history.
155#[derive(Debug, Clone, Default)]
156pub struct Last {
157    last: Option<f64>,
158}
159
160impl Aggregator for Last {
161    fn update(&mut self, value: f64) {
162        self.last = Some(value);
163    }
164
165    fn value(&self) -> f64 {
166        self.last.unwrap_or(f64::NAN)
167    }
168
169    fn reset(&mut self) {
170        self.last = None;
171    }
172}
173
174/// Mean over the last `window` values (sliding window).
175#[derive(Debug, Clone)]
176pub struct RollingMean {
177    window: usize,
178    buf: std::collections::VecDeque<f64>,
179}
180
181impl RollingMean {
182    pub fn new(window: usize) -> Self {
183        assert!(window > 0, "window must be > 0");
184        Self { window, buf: std::collections::VecDeque::with_capacity(window) }
185    }
186}
187
188impl Aggregator for RollingMean {
189    fn update(&mut self, value: f64) {
190        if self.buf.len() == self.window {
191            self.buf.pop_front();
192        }
193        self.buf.push_back(value);
194    }
195
196    fn value(&self) -> f64 {
197        if self.buf.is_empty() {
198            f64::NAN
199        } else {
200            self.buf.iter().sum::<f64>() / self.buf.len() as f64
201        }
202    }
203
204    fn reset(&mut self) {
205        self.buf.clear();
206    }
207}
208
209// ── StatsTracker ──────────────────────────────────────────────────────────────
210
211struct TrackedStat {
212    name: String,
213    extractor: Box<dyn Fn(&EpisodeRecord) -> f64 + Send + Sync>,
214    aggregator: Box<dyn Aggregator>,
215}
216
217/// Accumulates per-episode stats and reports summary aggregates.
218///
219/// By default tracks `episode_reward` (mean) and `episode_length` (mean).
220/// Use the builder methods to add custom stats or change aggregators.
221///
222/// # Usage
223///
224/// ```rust,ignore
225/// let mut tracker = StatsTracker::new()
226///     .with("episode_reward_max", StatSource::TotalReward, Max::default())
227///     .with_custom("ep_len_last10", |r| r.length as f64, RollingMean::new(10));
228///
229/// // Feed an episode record:
230/// tracker.update(&record);
231///
232/// // Print summary:
233/// let summary = tracker.summary();
234/// println!("mean reward: {:.1}", summary["episode_reward"]);
235/// ```
236pub struct StatsTracker {
237    stats: Vec<TrackedStat>,
238}
239
/// Predefined per-episode values that `StatsTracker::with` can track.
///
/// Derives the usual value-type traits so sources can be logged, cloned,
/// compared and used as map keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum StatSource {
    /// `EpisodeRecord::total_reward`
    TotalReward,
    /// `EpisodeRecord::length` cast to `f64`
    Length,
    /// The value stored under this key in `EpisodeRecord::extras`
    Extra(String),
}
249
250impl StatSource {
251    fn into_extractor(self) -> Box<dyn Fn(&EpisodeRecord) -> f64 + Send + Sync> {
252        match self {
253            StatSource::TotalReward => Box::new(|r: &EpisodeRecord| r.total_reward),
254            StatSource::Length => Box::new(|r: &EpisodeRecord| r.length as f64),
255            StatSource::Extra(key) => Box::new(move |r: &EpisodeRecord| {
256                r.extras.get(&key).copied().unwrap_or(f64::NAN)
257            }),
258        }
259    }
260}
261
262impl StatsTracker {
263    /// Create a tracker with the default stats: episode_reward (mean) and episode_length (mean).
264    pub fn new() -> Self {
265        let mut tracker = Self { stats: Vec::new() };
266        tracker = tracker.with("episode_reward", StatSource::TotalReward, Mean::default());
267        tracker = tracker.with("episode_length", StatSource::Length, Mean::default());
268        tracker
269    }
270
271    /// Create a tracker with no default stats.
272    pub fn empty() -> Self {
273        Self { stats: Vec::new() }
274    }
275
276    /// Track a predefined stat field with the given aggregator.
277    pub fn with(mut self, name: impl Into<String>, source: StatSource, aggregator: impl Aggregator + 'static) -> Self {
278        self.stats.push(TrackedStat {
279            name: name.into(),
280            extractor: source.into_extractor(),
281            aggregator: Box::new(aggregator),
282        });
283        self
284    }
285
286    /// Track an arbitrary derived value from each episode record.
287    pub fn with_custom(
288        mut self,
289        name: impl Into<String>,
290        f: impl Fn(&EpisodeRecord) -> f64 + Send + Sync + 'static,
291        aggregator: impl Aggregator + 'static,
292    ) -> Self {
293        self.stats.push(TrackedStat {
294            name: name.into(),
295            extractor: Box::new(f),
296            aggregator: Box::new(aggregator),
297        });
298        self
299    }
300
301    /// Feed a completed episode into all tracked stats.
302    pub fn update(&mut self, record: &EpisodeRecord) {
303        for stat in &mut self.stats {
304            let value = (stat.extractor)(record);
305            stat.aggregator.update(value);
306        }
307    }
308
309    /// Snapshot of all current aggregate values.
310    pub fn summary(&self) -> HashMap<String, f64> {
311        self.stats.iter()
312            .map(|s| (s.name.clone(), s.aggregator.value()))
313            .collect()
314    }
315
316    /// Reset all aggregators (e.g. between eval runs).
317    pub fn reset(&mut self) {
318        for stat in &mut self.stats {
319            stat.aggregator.reset();
320        }
321    }
322}
323
324impl Default for StatsTracker {
325    fn default() -> Self {
326        Self::new()
327    }
328}
329
// ── EvalReport ────────────────────────────────────────────────────────────────

/// Summary statistics from a single evaluation run.
///
/// Returned by `DqnTrainer::eval()`. Contains the aggregated stats from
/// all eval episodes plus the step count at which eval was performed.
///
/// The `Display` impl renders exactly the text that [`EvalReport::print`]
/// writes to stdout, so a report can also be sent to a logger or a file.
#[derive(Debug, Clone, PartialEq)]
pub struct EvalReport {
    /// Total agent steps at the time of evaluation.
    pub total_steps: usize,

    /// Number of episodes evaluated.
    pub n_episodes: usize,

    /// Aggregated statistics (same keys as the `StatsTracker` that produced this).
    pub stats: HashMap<String, f64>,
}

impl EvalReport {
    /// Bundles the results of one evaluation run.
    pub fn new(total_steps: usize, n_episodes: usize, stats: HashMap<String, f64>) -> Self {
        Self { total_steps, n_episodes, stats }
    }

    /// Pretty-print all stats to stdout: a header line, then one
    /// `  key: value` line per stat, sorted by key, values to 3 decimals.
    pub fn print(&self) {
        print!("{}", self);
    }
}

impl std::fmt::Display for EvalReport {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "=== Eval @ step {} ({} episodes) ===", self.total_steps, self.n_episodes)?;
        // HashMap iteration order is unspecified; sort once so output is stable.
        let mut entries: Vec<(&String, &f64)> = self.stats.iter().collect();
        entries.sort_unstable_by(|a, b| a.0.cmp(b.0));
        for (key, value) in entries {
            writeln!(f, "  {}: {:.3}", key, *value)?;
        }
        Ok(())
    }
}
363