1use std::collections::HashMap;
2
3use rl_traits::EpisodeStatus;
4
5mod episode_status_serde {
10 use rl_traits::EpisodeStatus;
11 use serde::{Deserialize, Deserializer, Serialize, Serializer};
12
13 pub fn serialize<S: Serializer>(status: &EpisodeStatus, s: S) -> Result<S::Ok, S::Error> {
14 let tag = match status {
15 EpisodeStatus::Continuing => "Continuing",
16 EpisodeStatus::Terminated => "Terminated",
17 EpisodeStatus::Truncated => "Truncated",
18 };
19 tag.serialize(s)
20 }
21
22 pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<EpisodeStatus, D::Error> {
23 let tag = String::deserialize(d)?;
24 match tag.as_str() {
25 "Continuing" => Ok(EpisodeStatus::Continuing),
26 "Terminated" => Ok(EpisodeStatus::Terminated),
27 "Truncated" => Ok(EpisodeStatus::Truncated),
28 other => Err(serde::de::Error::unknown_variant(other, &["Continuing", "Terminated", "Truncated"])),
29 }
30 }
31}
32
33#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
41pub struct EpisodeRecord {
42 pub total_reward: f64,
44
45 pub length: usize,
47
48 #[serde(with = "episode_status_serde")]
50 pub status: EpisodeStatus,
51
52 pub extras: HashMap<String, f64>,
54}
55
56impl EpisodeRecord {
57 pub fn new(total_reward: f64, length: usize, status: EpisodeStatus) -> Self {
58 Self {
59 total_reward,
60 length,
61 status,
62 extras: HashMap::new(),
63 }
64 }
65
66 pub fn with_extra(mut self, key: impl Into<String>, value: f64) -> Self {
67 self.extras.insert(key.into(), value);
68 self
69 }
70}
71
/// Incremental reducer over a stream of `f64` observations.
///
/// Implementors must be `Send + Sync`.
pub trait Aggregator: Send + Sync {
    /// Folds one more observation into the aggregate.
    fn update(&mut self, value: f64);

    /// Returns the current aggregate value.
    ///
    /// The implementations in this file return `f64::NAN` while no
    /// observation has been recorded.
    fn value(&self) -> f64;

    /// Discards everything recorded so far.
    fn reset(&mut self);
}
88
89#[derive(Debug, Clone, Default)]
93pub struct Mean {
94 sum: f64,
95 count: usize,
96}
97
98impl Aggregator for Mean {
99 fn update(&mut self, value: f64) {
100 self.sum += value;
101 self.count += 1;
102 }
103
104 fn value(&self) -> f64 {
105 if self.count == 0 { f64::NAN } else { self.sum / self.count as f64 }
106 }
107
108 fn reset(&mut self) {
109 self.sum = 0.0;
110 self.count = 0;
111 }
112}
113
114#[derive(Debug, Clone, Default)]
116pub struct Max {
117 max: Option<f64>,
118}
119
120impl Aggregator for Max {
121 fn update(&mut self, value: f64) {
122 self.max = Some(self.max.map_or(value, |m| m.max(value)));
123 }
124
125 fn value(&self) -> f64 {
126 self.max.unwrap_or(f64::NAN)
127 }
128
129 fn reset(&mut self) {
130 self.max = None;
131 }
132}
133
134#[derive(Debug, Clone, Default)]
136pub struct Min {
137 min: Option<f64>,
138}
139
140impl Aggregator for Min {
141 fn update(&mut self, value: f64) {
142 self.min = Some(self.min.map_or(value, |m| m.min(value)));
143 }
144
145 fn value(&self) -> f64 {
146 self.min.unwrap_or(f64::NAN)
147 }
148
149 fn reset(&mut self) {
150 self.min = None;
151 }
152}
153
154#[derive(Debug, Clone, Default)]
156pub struct Last {
157 last: Option<f64>,
158}
159
160impl Aggregator for Last {
161 fn update(&mut self, value: f64) {
162 self.last = Some(value);
163 }
164
165 fn value(&self) -> f64 {
166 self.last.unwrap_or(f64::NAN)
167 }
168
169 fn reset(&mut self) {
170 self.last = None;
171 }
172}
173
174#[derive(Debug, Clone)]
176pub struct RollingMean {
177 window: usize,
178 buf: std::collections::VecDeque<f64>,
179}
180
181impl RollingMean {
182 pub fn new(window: usize) -> Self {
183 assert!(window > 0, "window must be > 0");
184 Self { window, buf: std::collections::VecDeque::with_capacity(window) }
185 }
186}
187
188impl Aggregator for RollingMean {
189 fn update(&mut self, value: f64) {
190 if self.buf.len() == self.window {
191 self.buf.pop_front();
192 }
193 self.buf.push_back(value);
194 }
195
196 fn value(&self) -> f64 {
197 if self.buf.is_empty() {
198 f64::NAN
199 } else {
200 self.buf.iter().sum::<f64>() / self.buf.len() as f64
201 }
202 }
203
204 fn reset(&mut self) {
205 self.buf.clear();
206 }
207}
208
/// One tracked statistic: a name, a function that extracts an `f64` from an
/// [`EpisodeRecord`], and the aggregator the extracted values are fed into.
struct TrackedStat {
    /// Key under which the aggregate is reported.
    name: String,
    /// Pulls the raw value for this statistic out of a record.
    extractor: Box<dyn Fn(&EpisodeRecord) -> f64 + Send + Sync>,
    /// Reduces the extracted values to a single number.
    aggregator: Box<dyn Aggregator>,
}
216
/// Collection of named statistics that are updated from episode records.
pub struct StatsTracker {
    /// Tracked statistics, kept in the order they were registered.
    stats: Vec<TrackedStat>,
}
239
240pub enum StatSource {
242 TotalReward,
244 Length,
246 Extra(String),
248}
249
250impl StatSource {
251 fn into_extractor(self) -> Box<dyn Fn(&EpisodeRecord) -> f64 + Send + Sync> {
252 match self {
253 StatSource::TotalReward => Box::new(|r: &EpisodeRecord| r.total_reward),
254 StatSource::Length => Box::new(|r: &EpisodeRecord| r.length as f64),
255 StatSource::Extra(key) => Box::new(move |r: &EpisodeRecord| {
256 r.extras.get(&key).copied().unwrap_or(f64::NAN)
257 }),
258 }
259 }
260}
261
262impl StatsTracker {
263 pub fn new() -> Self {
265 let mut tracker = Self { stats: Vec::new() };
266 tracker = tracker.with("episode_reward", StatSource::TotalReward, Mean::default());
267 tracker = tracker.with("episode_length", StatSource::Length, Mean::default());
268 tracker
269 }
270
271 pub fn empty() -> Self {
273 Self { stats: Vec::new() }
274 }
275
276 pub fn with(mut self, name: impl Into<String>, source: StatSource, aggregator: impl Aggregator + 'static) -> Self {
278 self.stats.push(TrackedStat {
279 name: name.into(),
280 extractor: source.into_extractor(),
281 aggregator: Box::new(aggregator),
282 });
283 self
284 }
285
286 pub fn with_custom(
288 mut self,
289 name: impl Into<String>,
290 f: impl Fn(&EpisodeRecord) -> f64 + Send + Sync + 'static,
291 aggregator: impl Aggregator + 'static,
292 ) -> Self {
293 self.stats.push(TrackedStat {
294 name: name.into(),
295 extractor: Box::new(f),
296 aggregator: Box::new(aggregator),
297 });
298 self
299 }
300
301 pub fn update(&mut self, record: &EpisodeRecord) {
303 for stat in &mut self.stats {
304 let value = (stat.extractor)(record);
305 stat.aggregator.update(value);
306 }
307 }
308
309 pub fn summary(&self) -> HashMap<String, f64> {
311 self.stats.iter()
312 .map(|s| (s.name.clone(), s.aggregator.value()))
313 .collect()
314 }
315
316 pub fn reset(&mut self) {
318 for stat in &mut self.stats {
319 stat.aggregator.reset();
320 }
321 }
322}
323
324impl Default for StatsTracker {
325 fn default() -> Self {
326 Self::new()
327 }
328}
329
/// Result of an evaluation run: a step counter, an episode count, and a map
/// of named statistics.
///
/// The report implements [`std::fmt::Display`], so it can be rendered to a
/// string or a logger as well as printed with [`EvalReport::print`].
#[derive(Debug, Clone)]
pub struct EvalReport {
    /// Step count shown in the report header ("Eval @ step N").
    pub total_steps: usize,

    /// Episode count shown in the report header ("(N episodes)").
    pub n_episodes: usize,

    /// Named statistic values included in the report.
    pub stats: HashMap<String, f64>,
}

impl EvalReport {
    /// Bundles the given values into a report.
    pub fn new(total_steps: usize, n_episodes: usize, stats: HashMap<String, f64>) -> Self {
        Self { total_steps, n_episodes, stats }
    }

    /// Prints the report to stdout: a header line followed by one line per
    /// statistic, sorted by name, with values shown to three decimals.
    ///
    /// The whole report goes out in a single `println!`, so its lines are not
    /// interleaved with other `println!` output from the same process.
    pub fn print(&self) {
        println!("{}", self);
    }
}

/// Renders the same text that [`EvalReport::print`] writes, without the
/// trailing newline.
impl std::fmt::Display for EvalReport {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "=== Eval @ step {} ({} episodes) ===", self.total_steps, self.n_episodes)?;

        // HashMap iteration order is unspecified; sort by name so the output
        // is stable. Keys are unique, so an unstable sort is sufficient.
        let mut entries: Vec<(&String, &f64)> = self.stats.iter().collect();
        entries.sort_unstable_by(|a, b| a.0.cmp(b.0));

        for (name, value) in entries {
            write!(f, "\n {}: {:.3}", name, value)?;
        }
        Ok(())
    }
}
363