1use std::collections::HashMap;
2
3use rl_traits::EpisodeStatus;
4
5mod episode_status_serde {
10 use rl_traits::EpisodeStatus;
11 use serde::{Deserialize, Deserializer, Serialize, Serializer};
12
13 pub fn serialize<S: Serializer>(status: &EpisodeStatus, s: S) -> Result<S::Ok, S::Error> {
14 let tag = match status {
15 EpisodeStatus::Continuing => "Continuing",
16 EpisodeStatus::Terminated => "Terminated",
17 EpisodeStatus::Truncated => "Truncated",
18 };
19 tag.serialize(s)
20 }
21
22 pub fn deserialize<'de, D: Deserializer<'de>>(d: D) -> Result<EpisodeStatus, D::Error> {
23 let tag = String::deserialize(d)?;
24 match tag.as_str() {
25 "Continuing" => Ok(EpisodeStatus::Continuing),
26 "Terminated" => Ok(EpisodeStatus::Terminated),
27 "Truncated" => Ok(EpisodeStatus::Truncated),
28 other => Err(serde::de::Error::unknown_variant(other, &["Continuing", "Terminated", "Truncated"])),
29 }
30 }
31}
32
/// Summary record for one episode, consumed by `StatsTracker::update`.
///
/// Serializable with serde. The `status` field goes through the
/// `episode_status_serde` adapter, so it is stored as a plain
/// variant-name string.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct EpisodeRecord {
    /// Total reward for the episode, as supplied by the caller.
    /// Read by `StatSource::TotalReward`.
    pub total_reward: f64,

    /// Episode length as supplied by the caller (presumably a step count —
    /// TODO confirm). Read by `StatSource::Length`, converted to `f64`.
    pub length: usize,

    /// Episode status, (de)serialized as one of the strings
    /// `"Continuing"`, `"Terminated"` or `"Truncated"`.
    #[serde(with = "episode_status_serde")]
    pub status: EpisodeStatus,

    /// Additional named scalar metrics. Read through `StatSource::Extra`,
    /// which yields NaN for a record that lacks the key.
    /// NOTE(review): there is no `#[serde(default)]`, so this field must be
    /// present when deserializing.
    pub extras: HashMap<String, f64>,
}
55
56impl EpisodeRecord {
57 pub fn new(total_reward: f64, length: usize, status: EpisodeStatus) -> Self {
58 Self {
59 total_reward,
60 length,
61 status,
62 extras: HashMap::new(),
63 }
64 }
65
66 pub fn with_extra(mut self, key: impl Into<String>, value: f64) -> Self {
67 self.extras.insert(key.into(), value);
68 self
69 }
70
71 pub fn with_extras(mut self, extras: HashMap<String, f64>) -> Self {
72 self.extras.extend(extras);
73 self
74 }
75}
76
/// Streaming reducer that folds a sequence of `f64` observations into a
/// single summary value.
///
/// The `Send + Sync` supertraits make `Box<dyn Aggregator>` (as stored by
/// `TrackedStat`) safe to move and share across threads.
pub trait Aggregator: Send + Sync {
    /// Folds one observation into the aggregator's state.
    fn update(&mut self, value: f64);

    /// Returns the current aggregate without consuming state.
    /// The built-in aggregators in this module return NaN before they have
    /// observed any value.
    fn value(&self) -> f64;

    /// Discards all accumulated observations.
    fn reset(&mut self);
}
93
94#[derive(Debug, Clone, Default)]
98pub struct Mean {
99 sum: f64,
100 count: usize,
101}
102
103impl Aggregator for Mean {
104 fn update(&mut self, value: f64) {
105 self.sum += value;
106 self.count += 1;
107 }
108
109 fn value(&self) -> f64 {
110 if self.count == 0 { f64::NAN } else { self.sum / self.count as f64 }
111 }
112
113 fn reset(&mut self) {
114 self.sum = 0.0;
115 self.count = 0;
116 }
117}
118
119#[derive(Debug, Clone, Default)]
121pub struct Max {
122 max: Option<f64>,
123}
124
125impl Aggregator for Max {
126 fn update(&mut self, value: f64) {
127 self.max = Some(self.max.map_or(value, |m| m.max(value)));
128 }
129
130 fn value(&self) -> f64 {
131 self.max.unwrap_or(f64::NAN)
132 }
133
134 fn reset(&mut self) {
135 self.max = None;
136 }
137}
138
139#[derive(Debug, Clone, Default)]
141pub struct Min {
142 min: Option<f64>,
143}
144
145impl Aggregator for Min {
146 fn update(&mut self, value: f64) {
147 self.min = Some(self.min.map_or(value, |m| m.min(value)));
148 }
149
150 fn value(&self) -> f64 {
151 self.min.unwrap_or(f64::NAN)
152 }
153
154 fn reset(&mut self) {
155 self.min = None;
156 }
157}
158
159#[derive(Debug, Clone, Default)]
161pub struct Last {
162 last: Option<f64>,
163}
164
165impl Aggregator for Last {
166 fn update(&mut self, value: f64) {
167 self.last = Some(value);
168 }
169
170 fn value(&self) -> f64 {
171 self.last.unwrap_or(f64::NAN)
172 }
173
174 fn reset(&mut self) {
175 self.last = None;
176 }
177}
178
179#[derive(Debug, Clone)]
181pub struct RollingMean {
182 window: usize,
183 buf: std::collections::VecDeque<f64>,
184}
185
186impl RollingMean {
187 pub fn new(window: usize) -> Self {
188 assert!(window > 0, "window must be > 0");
189 Self { window, buf: std::collections::VecDeque::with_capacity(window) }
190 }
191}
192
193impl Aggregator for RollingMean {
194 fn update(&mut self, value: f64) {
195 if self.buf.len() == self.window {
196 self.buf.pop_front();
197 }
198 self.buf.push_back(value);
199 }
200
201 fn value(&self) -> f64 {
202 if self.buf.is_empty() {
203 f64::NAN
204 } else {
205 self.buf.iter().sum::<f64>() / self.buf.len() as f64
206 }
207 }
208
209 fn reset(&mut self) {
210 self.buf.clear();
211 }
212}
213
214#[derive(Debug, Clone, Default)]
218pub struct Std {
219 count: usize,
220 mean: f64,
221 m2: f64,
222}
223
224impl Aggregator for Std {
225 fn update(&mut self, value: f64) {
226 self.count += 1;
227 let delta = value - self.mean;
228 self.mean += delta / self.count as f64;
229 let delta2 = value - self.mean;
230 self.m2 += delta * delta2;
231 }
232
233 fn value(&self) -> f64 {
234 if self.count < 2 { f64::NAN } else { (self.m2 / self.count as f64).sqrt() }
235 }
236
237 fn reset(&mut self) {
238 self.count = 0;
239 self.mean = 0.0;
240 self.m2 = 0.0;
241 }
242}
243
/// One named statistic registered with `StatsTracker`: how to pull a scalar
/// out of each episode record, and how to aggregate those scalars.
struct TrackedStat {
    /// Key under which the aggregate is reported by `StatsTracker::summary`.
    name: String,
    /// Maps an episode record to the scalar that is fed into `aggregator`.
    extractor: Box<dyn Fn(&EpisodeRecord) -> f64 + Send + Sync>,
    /// Reducer that accumulates the extracted scalars.
    aggregator: Box<dyn Aggregator>,
}
251
/// Accumulates per-episode statistics across many `EpisodeRecord`s.
///
/// Each registered statistic pairs an extractor with an `Aggregator`.
/// `StatsTracker::update` feeds every record to all of them, and
/// `StatsTracker::summary` reports the current values keyed by name.
/// Statistics are kept in registration order. If two share a name, the one
/// registered later wins in `summary`, because later map inserts overwrite
/// earlier ones.
pub struct StatsTracker {
    stats: Vec<TrackedStat>,
}
274
/// Selects which field of an `EpisodeRecord` a tracked statistic reads.
///
/// Derives `Debug`/`Clone`/`PartialEq`/`Eq`/`Hash` so callers can log,
/// duplicate and compare tracker configurations. These derives are additive
/// and do not change existing behaviour.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum StatSource {
    /// The record's `total_reward` field.
    TotalReward,
    /// The record's `length` field, converted to `f64`.
    Length,
    /// The named entry of the record's `extras` map; records that lack the
    /// key contribute NaN.
    Extra(String),
}
284
285impl StatSource {
286 fn into_extractor(self) -> Box<dyn Fn(&EpisodeRecord) -> f64 + Send + Sync> {
287 match self {
288 StatSource::TotalReward => Box::new(|r: &EpisodeRecord| r.total_reward),
289 StatSource::Length => Box::new(|r: &EpisodeRecord| r.length as f64),
290 StatSource::Extra(key) => Box::new(move |r: &EpisodeRecord| {
291 r.extras.get(&key).copied().unwrap_or(f64::NAN)
292 }),
293 }
294 }
295}
296
297impl StatsTracker {
298 pub fn new() -> Self {
300 let mut tracker = Self { stats: Vec::new() };
301 tracker = tracker.with("episode_reward", StatSource::TotalReward, Mean::default());
302 tracker = tracker.with("episode_length", StatSource::Length, Mean::default());
303 tracker
304 }
305
306 pub fn empty() -> Self {
308 Self { stats: Vec::new() }
309 }
310
311 pub fn with(mut self, name: impl Into<String>, source: StatSource, aggregator: impl Aggregator + 'static) -> Self {
313 self.stats.push(TrackedStat {
314 name: name.into(),
315 extractor: source.into_extractor(),
316 aggregator: Box::new(aggregator),
317 });
318 self
319 }
320
321 pub fn with_custom(
323 mut self,
324 name: impl Into<String>,
325 f: impl Fn(&EpisodeRecord) -> f64 + Send + Sync + 'static,
326 aggregator: impl Aggregator + 'static,
327 ) -> Self {
328 self.stats.push(TrackedStat {
329 name: name.into(),
330 extractor: Box::new(f),
331 aggregator: Box::new(aggregator),
332 });
333 self
334 }
335
336 pub fn update(&mut self, record: &EpisodeRecord) {
338 for stat in &mut self.stats {
339 let value = (stat.extractor)(record);
340 stat.aggregator.update(value);
341 }
342 }
343
344 pub fn summary(&self) -> HashMap<String, f64> {
346 self.stats.iter()
347 .map(|s| (s.name.clone(), s.aggregator.value()))
348 .collect()
349 }
350
351 pub fn reset(&mut self) {
353 for stat in &mut self.stats {
354 stat.aggregator.reset();
355 }
356 }
357}
358
359impl Default for StatsTracker {
360 fn default() -> Self {
361 Self::new()
362 }
363}
364
/// Snapshot of aggregated statistics taken at an evaluation point.
///
/// Derives `PartialEq` so reports can be compared (e.g. in tests). Because
/// `NaN != NaN`, two reports containing a NaN statistic never compare equal.
#[derive(Debug, Clone, PartialEq)]
pub struct EvalReport {
    /// Step counter at which the evaluation was taken (shown in the header
    /// printed by `EvalReport::print`).
    pub total_steps: usize,

    /// Number of episodes reported in the header; presumably the number the
    /// statistics were aggregated over — TODO confirm with callers.
    pub n_episodes: usize,

    /// Statistic name -> aggregated value, typically the output of
    /// `StatsTracker::summary`.
    pub stats: HashMap<String, f64>,
}
382
383impl EvalReport {
384 pub fn new(total_steps: usize, n_episodes: usize, stats: HashMap<String, f64>) -> Self {
385 Self { total_steps, n_episodes, stats }
386 }
387
388 pub fn print(&self) {
390 println!("=== Eval @ step {} ({} episodes) ===", self.total_steps, self.n_episodes);
391 let mut keys: Vec<_> = self.stats.keys().collect();
392 keys.sort();
393 for key in keys {
394 println!(" {}: {:.3}", key, self.stats[key]);
395 }
396 }
397}
398