// rl_traits/episode.rs

/// Whether an episode is ongoing, has naturally ended, or was cut short.
///
/// This distinction is critical for bootstrapping in RL algorithms.
///
/// # Why this matters
///
/// When computing value targets (e.g. TD targets, GAE), the treatment of
/// the terminal state depends on *why* the episode ended:
///
/// - `Terminated`: the agent reached a natural terminal state. The value of
///   the next state is zero — there is no future reward to bootstrap.
///
/// - `Truncated`: the episode was cut short by something external (e.g. a
///   time limit, the agent going out of bounds). The environment has not
///   actually terminated — the agent simply stopped. The value of the next
///   state is *non-zero* and must be bootstrapped from the value function.
///
/// Confusing these two is one of the most common bugs in policy gradient
/// implementations. Gymnasium introduced this distinction in v0.26; we
/// encode it correctly from the start.
///
/// # Copy semantics
///
/// The enum is a fieldless tag, so it is `Copy`: it can be passed by value,
/// stored, and compared without cloning or borrowing.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EpisodeStatus {
    /// The episode is ongoing.
    Continuing,

    /// The episode reached a natural terminal state (MDP termination).
    ///
    /// Bootstrap target: `r + gamma * 0` — no future value.
    Terminated,

    /// The episode was cut short by an external condition (e.g. time limit).
    ///
    /// Bootstrap target: `r + gamma * V(s')` — future value is non-zero.
    Truncated,
}
36
37impl EpisodeStatus {
38    /// Returns `true` if the episode is over for any reason.
39    #[inline]
40    pub fn is_done(&self) -> bool {
41        matches!(self, Self::Terminated | Self::Truncated)
42    }
43
44    /// Returns `true` only for natural MDP termination.
45    /// Use this to decide whether to bootstrap the next-state value.
46    #[inline]
47    pub fn is_terminal(&self) -> bool {
48        matches!(self, Self::Terminated)
49    }
50
51    /// Returns `true` if the episode was cut short externally.
52    #[inline]
53    pub fn is_truncated(&self) -> bool {
54        matches!(self, Self::Truncated)
55    }
56}
57
58/// The output of a single environment step.
59///
60/// Returned by [`crate::Environment::step`]. Contains everything an agent needs
61/// to learn: the next observation, the reward signal, whether the episode
62/// is done, and any auxiliary info.
63#[derive(Debug, Clone)]
64pub struct StepResult<O, I> {
65    /// The observation after taking the action.
66    pub observation: O,
67
68    /// The scalar reward signal.
69    pub reward: f64,
70
71    /// Whether the episode continues, terminated, or was truncated.
72    pub status: EpisodeStatus,
73
74    /// Auxiliary information (e.g. diagnostics, hidden state, sub-rewards).
75    /// Typed — no `HashMap<String, Any>` here.
76    pub info: I,
77}
78
79impl<O, I> StepResult<O, I> {
80    pub fn new(observation: O, reward: f64, status: EpisodeStatus, info: I) -> Self {
81        Self {
82            observation,
83            reward,
84            status,
85            info,
86        }
87    }
88
89    /// Convenience: is the episode over for any reason?
90    #[inline]
91    pub fn is_done(&self) -> bool {
92        self.status.is_done()
93    }
94
95    /// Map the observation to a different type (useful for wrapper implementations).
96    pub fn map_obs<O2>(self, f: impl FnOnce(O) -> O2) -> StepResult<O2, I> {
97        StepResult {
98            observation: f(self.observation),
99            reward: self.reward,
100            status: self.status,
101            info: self.info,
102        }
103    }
104}
105
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects the three status predicates as `(is_done, is_terminal, is_truncated)`
    /// so each test can compare them in a single assertion.
    fn flags(status: &EpisodeStatus) -> (bool, bool, bool) {
        (
            status.is_done(),
            status.is_terminal(),
            status.is_truncated(),
        )
    }

    // ── EpisodeStatus ────────────────────────────────────────────────────────

    #[test]
    fn continuing_is_not_done() {
        assert_eq!(flags(&EpisodeStatus::Continuing), (false, false, false));
    }

    #[test]
    fn terminated_is_done_and_terminal() {
        assert_eq!(flags(&EpisodeStatus::Terminated), (true, true, false));
    }

    #[test]
    fn truncated_is_done_but_not_terminal() {
        // The critical distinction: truncation ends the episode, yet the
        // next-state value is non-zero, so algorithms must NOT zero the bootstrap.
        assert_eq!(flags(&EpisodeStatus::Truncated), (true, false, true));
    }

    // ── StepResult ───────────────────────────────────────────────────────────

    #[test]
    fn step_result_is_done_follows_status() {
        // `StepResult::is_done` must agree with the status it wraps.
        let cases = [
            (EpisodeStatus::Continuing, false),
            (EpisodeStatus::Terminated, true),
            (EpisodeStatus::Truncated, true),
        ];
        for (status, expected) in cases {
            let step = StepResult::new((), 0.0, status, ());
            assert_eq!(step.is_done(), expected);
        }
    }

    #[test]
    fn map_obs_transforms_observation_preserves_rest() {
        let original = StepResult::new(2_i32, 1.5, EpisodeStatus::Continuing, "info");
        let mapped = original.map_obs(|obs| obs * 10);

        // Only the observation changes; everything else is carried over.
        assert_eq!(mapped.observation, 20);
        assert_eq!(mapped.reward, 1.5);
        assert_eq!(mapped.status, EpisodeStatus::Continuing);
        assert_eq!(mapped.info, "info");
    }
}