1use crate::episode::EpisodeStatus;
2
/// A single transition record: an observation, an action, a scalar reward,
/// a following observation, and an episode status.
///
/// Generic over the observation type `O` and the action type `A`; the struct
/// itself places no trait bounds on either. `Debug` and `Clone` are derived,
/// so those impls are available whenever both `O` and `A` provide them.
#[derive(Debug, Clone)]
pub struct Experience<O, A> {
    /// Observation at the start of the transition (presumably what the agent
    /// saw before choosing `action`; confirm against the code that records
    /// experiences).
    pub observation: O,

    /// Action associated with this transition.
    pub action: A,

    /// Scalar reward stored for this transition.
    pub reward: f64,

    /// Observation that follows `observation`; same type `O`.
    pub next_observation: O,

    /// Episode status recorded with this transition. `Experience::is_done`
    /// and `Experience::bootstrap_mask` are both derived from this field.
    pub status: EpisodeStatus,
}
32
33impl<O, A> Experience<O, A> {
34 pub fn new(
35 observation: O,
36 action: A,
37 reward: f64,
38 next_observation: O,
39 status: EpisodeStatus,
40 ) -> Self {
41 Self {
42 observation,
43 action,
44 reward,
45 next_observation,
46 status,
47 }
48 }
49
50 #[inline]
52 pub fn is_done(&self) -> bool {
53 self.status.is_done()
54 }
55
56 #[inline]
62 pub fn bootstrap_mask(&self) -> f64 {
63 match self.status {
64 EpisodeStatus::Terminated => 0.0,
65 EpisodeStatus::Continuing | EpisodeStatus::Truncated => 1.0,
66 }
67 }
68
69 pub fn map_obs<O2>(self, f: impl Fn(O) -> O2) -> Experience<O2, A> {
73 Experience {
74 observation: f(self.observation),
75 action: self.action,
76 reward: self.reward,
77 next_observation: f(self.next_observation),
78 status: self.status,
79 }
80 }
81
82 pub fn map_action<A2>(self, f: impl Fn(A) -> A2) -> Experience<O, A2> {
84 Experience {
85 observation: self.observation,
86 action: f(self.action),
87 reward: self.reward,
88 next_observation: self.next_observation,
89 status: self.status,
90 }
91 }
92}
93
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture: an `i32`/`i32` experience whose only varying part is `status`.
    fn exp(status: EpisodeStatus) -> Experience<i32, i32> {
        Experience {
            observation: 0,
            action: 0,
            reward: 1.0,
            next_observation: 1,
            status,
        }
    }

    /// A terminated transition must yield a mask of zero.
    #[test]
    fn bootstrap_mask_is_zero_on_termination() {
        let mask = exp(EpisodeStatus::Terminated).bootstrap_mask();
        assert_eq!(mask, 0.0);
    }

    /// A continuing transition must yield a mask of one.
    #[test]
    fn bootstrap_mask_is_one_when_continuing() {
        let mask = exp(EpisodeStatus::Continuing).bootstrap_mask();
        assert_eq!(mask, 1.0);
    }

    /// A truncated transition is masked like a continuing one, not like a
    /// terminated one.
    #[test]
    fn bootstrap_mask_is_one_when_truncated() {
        let mask = exp(EpisodeStatus::Truncated).bootstrap_mask();
        assert_eq!(mask, 1.0);
    }

    /// `map_obs` must apply the closure to both observations and leave the
    /// action and reward alone.
    #[test]
    fn map_obs_transforms_both_observations() {
        let Experience {
            observation,
            action,
            reward,
            next_observation,
            ..
        } = Experience::new(1_i32, 99_i32, 2.0, 3_i32, EpisodeStatus::Continuing)
            .map_obs(|obs| obs * 10);

        assert_eq!(observation, 10);
        assert_eq!(next_observation, 30);
        assert_eq!(action, 99);
        assert_eq!(reward, 2.0);
    }

    /// `map_action` must change only the action (including its type) and
    /// carry the observations and status over unchanged.
    #[test]
    fn map_action_transforms_action_preserves_observations() {
        let original = Experience::new(5_i32, 2_i32, 0.5, 6_i32, EpisodeStatus::Truncated);
        let halved = original.map_action(|act| f64::from(act) * 0.5);

        assert_eq!(halved.action, 1.0_f64);
        assert_eq!(halved.observation, 5);
        assert_eq!(halved.next_observation, 6);
        assert_eq!(halved.status, EpisodeStatus::Truncated);
    }
}