//! Experience transitions for `rl_traits` — the `(s, a, r, s', status)` tuple
//! stored in replay buffers and consumed by agent updates.

1use crate::episode::EpisodeStatus;
2
3/// A single transition: `(s, a, r, s', status)`.
4///
5/// The fundamental unit of experience stored in replay buffers and used
6/// for agent updates. Corresponds to one (s, a, r, s', done) tuple in
7/// classical RL literature, but with a richer `status` field that
8/// distinguishes natural termination from truncation.
9#[derive(Debug, Clone)]
10pub struct Experience<O, A> {
11    /// The observation at the start of this transition.
12    pub observation: O,
13
14    /// The action taken.
15    pub action: A,
16
17    /// The scalar reward received.
18    pub reward: f64,
19
20    /// The observation after taking the action.
21    pub next_observation: O,
22
23    /// Whether the episode ended and why.
24    ///
25    /// Algorithms that bootstrap value estimates (DQN, PPO, SAC) must
26    /// inspect this to handle terminal states correctly:
27    /// - `Terminated`: bootstrap with zero value
28    /// - `Truncated`: bootstrap with V(next_observation)
29    /// - `Continuing`: bootstrap with V(next_observation)
30    pub status: EpisodeStatus,
31}
32
33impl<O, A> Experience<O, A> {
34    pub fn new(
35        observation: O,
36        action: A,
37        reward: f64,
38        next_observation: O,
39        status: EpisodeStatus,
40    ) -> Self {
41        Self {
42            observation,
43            action,
44            reward,
45            next_observation,
46            status,
47        }
48    }
49
50    /// Returns `true` if this transition ends an episode.
51    #[inline]
52    pub fn is_done(&self) -> bool {
53        self.status.is_done()
54    }
55
56    /// Returns the bootstrap mask: `1.0` if the episode continues or was
57    /// truncated (i.e. the next state has non-zero value), `0.0` if terminated.
58    ///
59    /// Multiply value estimates by this mask when computing TD targets:
60    /// `target = reward + gamma * bootstrap_mask() * V(next_obs)`
61    #[inline]
62    pub fn bootstrap_mask(&self) -> f64 {
63        match self.status {
64            EpisodeStatus::Terminated => 0.0,
65            EpisodeStatus::Continuing | EpisodeStatus::Truncated => 1.0,
66        }
67    }
68
69    /// Map the observation to a different type.
70    ///
71    /// Useful for observation-wrapping layers that preprocess before storage.
72    pub fn map_obs<O2>(self, f: impl Fn(O) -> O2) -> Experience<O2, A> {
73        Experience {
74            observation: f(self.observation),
75            action: self.action,
76            reward: self.reward,
77            next_observation: f(self.next_observation),
78            status: self.status,
79        }
80    }
81
82    /// Map the action to a different type.
83    pub fn map_action<A2>(self, f: impl Fn(A) -> A2) -> Experience<O, A2> {
84        Experience {
85            observation: self.observation,
86            action: f(self.action),
87            reward: self.reward,
88            next_observation: self.next_observation,
89            status: self.status,
90        }
91    }
92}
93
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal `Experience<i32, i32>` fixture with the given status.
    fn transition(status: EpisodeStatus) -> Experience<i32, i32> {
        Experience::new(0, 0, 1.0, 1, status)
    }

    // ── bootstrap_mask ───────────────────────────────────────────────────────

    #[test]
    fn bootstrap_mask_is_zero_on_termination() {
        // A natural MDP termination: the next state is worthless by definition.
        let e = transition(EpisodeStatus::Terminated);
        assert_eq!(e.bootstrap_mask(), 0.0);
    }

    #[test]
    fn bootstrap_mask_is_one_when_continuing() {
        let e = transition(EpisodeStatus::Continuing);
        assert_eq!(e.bootstrap_mask(), 1.0);
    }

    #[test]
    fn bootstrap_mask_is_one_when_truncated() {
        // An externally cut-off episode: the next state still carries value,
        // so zeroing the mask here would bias return estimates low — the
        // algorithm must bootstrap through truncation.
        let e = transition(EpisodeStatus::Truncated);
        assert_eq!(e.bootstrap_mask(), 1.0);
    }

    // ── map_obs / map_action ─────────────────────────────────────────────────

    #[test]
    fn map_obs_transforms_both_observations() {
        let mapped = Experience::new(1_i32, 99_i32, 2.0, 3_i32, EpisodeStatus::Continuing)
            .map_obs(|o| o * 10);
        assert_eq!(mapped.observation, 10);
        assert_eq!(mapped.next_observation, 30);
        assert_eq!(mapped.action, 99);
        assert_eq!(mapped.reward, 2.0);
    }

    #[test]
    fn map_action_transforms_action_preserves_observations() {
        let mapped = Experience::new(5_i32, 2_i32, 0.5, 6_i32, EpisodeStatus::Truncated)
            .map_action(|a| a as f64 * 0.5);
        assert_eq!(mapped.action, 1.0_f64);
        assert_eq!(mapped.observation, 5);
        assert_eq!(mapped.next_observation, 6);
        assert_eq!(mapped.status, EpisodeStatus::Truncated);
    }
}