ember_rl/traits.rs
1//! Core training traits for ember-rl.
2//!
3//! These traits define the composable building blocks that all algorithms
4//! and environments plug into. They live in ember-rl (not rl-traits) because
5//! they are training machinery, not environment contract.
6
7use std::collections::HashMap;
8use std::path::Path;
9
10use rl_traits::{Environment, Experience};
11
// ── ActMode ───────────────────────────────────────────────────────────────────

/// Controls whether an agent acts to explore or exploit.
///
/// Passed to [`LearningAgent::act`] to select the agent's action strategy.
/// Algorithms interpret this mode internally — DQN uses epsilon-greedy for
/// `Explore` and greedy argmax for `Exploit`.
///
/// Derives `Copy`/`Eq` so call sites can pass and compare modes freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ActMode {
    /// Act to explore: use the algorithm's exploration strategy (e.g. ε-greedy).
    Explore,
    /// Act to exploit: select the greedy/best-known action.
    Exploit,
}
26
// ── Checkpointable ────────────────────────────────────────────────────────────

/// An agent whose weights can be saved to and loaded from disk.
///
/// The `Sized` bound is required because [`load`](Checkpointable::load)
/// consumes `self` and returns `Self` by value.
pub trait Checkpointable: Sized {
    /// Save weights to `path` (without extension — implementations add their own).
    ///
    /// # Errors
    /// Returns whatever I/O or serialization error the implementation hits.
    fn save(&self, path: &Path) -> anyhow::Result<()>;

    /// Load weights from `path`, consuming and returning `self`.
    ///
    /// Takes `self` by value so the implementation can replace internal state
    /// wholesale and hand back the updated agent.
    ///
    /// # Errors
    /// Returns whatever I/O or deserialization error the implementation hits.
    fn load(self, path: &Path) -> anyhow::Result<Self>;
}
37
38// ── LearningAgent ─────────────────────────────────────────────────────────────
39
40/// An agent that can act, learn from experience, and report training stats.
41///
42/// Implemented by all algorithm agents (`DqnAgent`, future `PpoAgent`, etc.).
43/// The agent owns its exploration RNG internally — no external RNG is needed
44/// at call sites.
45///
46/// # Episode extras
47///
48/// Algorithms should maintain internal aggregators (e.g. `Mean`, `Std`, `Max`
49/// from [`crate::stats`]) over per-step values during each episode, reset them
50/// at episode start, and report summaries via [`episode_extras`]. These are
51/// merged into [`crate::stats::EpisodeRecord::extras`] automatically by
52/// [`crate::training::TrainingSession`].
53///
54/// Example extras a DQN agent might report:
55/// ```text
56/// { "epsilon": 0.12, "loss_mean": 0.043, "loss_std": 0.012, "loss_max": 0.21 }
57/// ```
58///
59/// [`episode_extras`]: LearningAgent::episode_extras
60pub trait LearningAgent<E: Environment>: Checkpointable {
61 /// Select an action for `obs` according to `mode`.
62 fn act(&mut self, obs: &E::Observation, mode: ActMode) -> E::Action;
63
64 /// Record a transition and update the agent's internal state.
65 fn observe(&mut self, experience: Experience<E::Observation, E::Action>);
66
67 /// Total number of `observe` calls since construction.
68 fn total_steps(&self) -> usize;
69
70 /// Per-episode aggregates of step-level values, reported at episode end.
71 ///
72 /// The default implementation returns an empty map. Algorithms override
73 /// this to expose training dynamics (loss statistics, epsilon, etc.).
74 fn episode_extras(&self) -> HashMap<String, f64> {
75 HashMap::new()
76 }
77
78 /// Called by [`crate::training::TrainingSession`] at the start of each
79 /// episode so the agent can reset its per-episode aggregators.
80 fn on_episode_start(&mut self) {}
81}
82