rl_traits/episode.rs
/// Whether an episode is ongoing, has naturally ended, or was cut short.
///
/// This distinction is critical for bootstrapping in RL algorithms.
///
/// # Why this matters
///
/// When computing value targets (e.g. TD targets, GAE), the treatment of
/// the final transition depends on *why* the episode ended:
///
/// - `Terminated`: the agent reached a natural terminal state. The value of
///   the next state is zero — there is no future reward to bootstrap.
///
/// - `Truncated`: the episode was cut short by something external (e.g. a
///   time limit, the agent going out of bounds). The environment has not
///   actually terminated — the agent simply stopped. The value of the next
///   state is generally *non-zero* and must be bootstrapped from the value
///   function.
///
/// Confusing these two is one of the most common bugs in policy gradient
/// implementations. Gymnasium introduced this distinction in v0.26; we
/// encode it correctly from the start.
///
/// # Copy semantics
///
/// The enum carries no data, so it is `Copy`: a status can be stored, passed
/// by value, and compared without cloning or giving up the original.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EpisodeStatus {
    /// The episode is ongoing.
    Continuing,

    /// The episode reached a natural terminal state (MDP termination).
    ///
    /// Bootstrap target: `r + gamma * 0` — no future value.
    Terminated,

    /// The episode was cut short by an external condition (e.g. time limit).
    ///
    /// Bootstrap target: `r + gamma * V(s')` — the future value still counts.
    Truncated,
}
36
37impl EpisodeStatus {
38 /// Returns `true` if the episode is over for any reason.
39 #[inline]
40 pub fn is_done(&self) -> bool {
41 matches!(self, Self::Terminated | Self::Truncated)
42 }
43
44 /// Returns `true` only for natural MDP termination.
45 /// Use this to decide whether to bootstrap the next-state value.
46 #[inline]
47 pub fn is_terminal(&self) -> bool {
48 matches!(self, Self::Terminated)
49 }
50
51 /// Returns `true` if the episode was cut short externally.
52 #[inline]
53 pub fn is_truncated(&self) -> bool {
54 matches!(self, Self::Truncated)
55 }
56}
57
58/// The output of a single environment step.
59///
60/// Returned by [`crate::Environment::step`]. Contains everything an agent needs
61/// to learn: the next observation, the reward signal, whether the episode
62/// is done, and any auxiliary info.
63#[derive(Debug, Clone)]
64pub struct StepResult<O, I> {
65 /// The observation after taking the action.
66 pub observation: O,
67
68 /// The scalar reward signal.
69 pub reward: f64,
70
71 /// Whether the episode continues, terminated, or was truncated.
72 pub status: EpisodeStatus,
73
74 /// Auxiliary information (e.g. diagnostics, hidden state, sub-rewards).
75 /// Typed — no `HashMap<String, Any>` here.
76 pub info: I,
77}
78
79impl<O, I> StepResult<O, I> {
80 pub fn new(observation: O, reward: f64, status: EpisodeStatus, info: I) -> Self {
81 Self {
82 observation,
83 reward,
84 status,
85 info,
86 }
87 }
88
89 /// Convenience: is the episode over for any reason?
90 #[inline]
91 pub fn is_done(&self) -> bool {
92 self.status.is_done()
93 }
94
95 /// Map the observation to a different type (useful for wrapper implementations).
96 pub fn map_obs<O2>(self, f: impl FnOnce(O) -> O2) -> StepResult<O2, I> {
97 StepResult {
98 observation: f(self.observation),
99 reward: self.reward,
100 status: self.status,
101 info: self.info,
102 }
103 }
104}
105
#[cfg(test)]
mod tests {
    use super::*;

    // ── EpisodeStatus ────────────────────────────────────────────────────────

    #[test]
    fn continuing_is_not_done() {
        assert!(!EpisodeStatus::Continuing.is_done());
        assert!(!EpisodeStatus::Continuing.is_terminal());
        assert!(!EpisodeStatus::Continuing.is_truncated());
    }

    #[test]
    fn terminated_is_done_and_terminal() {
        assert!(EpisodeStatus::Terminated.is_done());
        assert!(EpisodeStatus::Terminated.is_terminal());
        assert!(!EpisodeStatus::Terminated.is_truncated());
    }

    #[test]
    fn truncated_is_done_but_not_terminal() {
        // This is the critical distinction: Truncated ends the episode but the
        // next-state value still counts, so algorithms must NOT zero the
        // bootstrap.
        assert!(EpisodeStatus::Truncated.is_done());
        assert!(!EpisodeStatus::Truncated.is_terminal());
        assert!(EpisodeStatus::Truncated.is_truncated());
    }

    // ── StepResult ───────────────────────────────────────────────────────────

    #[test]
    fn step_result_is_done_follows_status() {
        // The convenience method must agree with the status for every variant.
        let step = |status| StepResult::new((), 0.0, status, ());
        assert!(!step(EpisodeStatus::Continuing).is_done());
        assert!(step(EpisodeStatus::Terminated).is_done());
        assert!(step(EpisodeStatus::Truncated).is_done());
    }

    #[test]
    fn map_obs_transforms_observation_preserves_rest() {
        let result = StepResult::new(2_i32, 1.5, EpisodeStatus::Continuing, "info");
        let mapped = result.map_obs(|o| o * 10);
        assert_eq!(mapped.observation, 20);
        assert_eq!(mapped.reward, 1.5);
        assert_eq!(mapped.status, EpisodeStatus::Continuing);
        assert_eq!(mapped.info, "info");
    }

    #[test]
    fn map_obs_can_change_observation_type() {
        // Wrappers typically re-encode observations, so the output type may
        // differ from the input type; the other fields must still carry over.
        let result = StepResult::new(7_u8, -0.25, EpisodeStatus::Truncated, "info");
        let mapped = result.map_obs(|o| o.to_string());
        assert_eq!(mapped.observation, "7");
        assert_eq!(mapped.reward, -0.25);
        assert_eq!(mapped.status, EpisodeStatus::Truncated);
        assert_eq!(mapped.info, "info");
    }
}