use burn::tensor::backend::AutodiffBackend;
use rand::rngs::SmallRng;
use rand::SeedableRng;
use rl_traits::{Environment, Experience};

use crate::algorithms::dqn::DqnAgent;
use crate::encoding::{DiscreteActionMapper, ObservationEncoder};

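/// Metrics describing a single environment step taken by [`DqnRunner`].
///
/// A minimal sketch of consuming these metrics (assumes a `runner` value
/// constructed elsewhere; the step budget is arbitrary):
///
/// ```ignore
/// let returns: Vec<f64> = runner
///     .steps()
///     .take(10_000)
///     .filter(|m| m.episode_done)
///     .map(|m| m.episode_reward)
///     .collect();
/// ```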
#[derive(Debug, Clone)]
pub struct StepMetrics {
    /// Total steps observed by the agent across all episodes.
    pub total_steps: usize,

    /// Zero-based index of the current episode.
    pub episode: usize,

    /// Step index within the current episode.
    pub episode_step: usize,

    /// Reward received on this step.
    pub reward: f64,

    /// Cumulative reward accrued so far in the current episode.
    pub episode_reward: f64,

    /// The agent's exploration rate (epsilon) for this step.
    pub epsilon: f64,

    /// Whether observing this experience triggered a learning update.
    pub did_update: bool,

    /// Whether this step ended the episode.
    pub episode_done: bool,
}

/// Drives a [`DqnAgent`] through an [`Environment`], stepping, collecting
/// experience, and resetting between episodes.
pub struct DqnRunner<E, Enc, Act, B>
where
    E: Environment,
    B: AutodiffBackend,
{
    env: E,
    agent: DqnAgent<E, Enc, Act, B>,
    rng: SmallRng,

    /// Observation to act on next; populated lazily on the first step.
    current_obs: Option<E::Observation>,
    episode: usize,
    episode_step: usize,
    episode_reward: f64,
}

impl<E, Enc, Act, B> DqnRunner<E, Enc, Act, B>
where
    E: Environment,
    E::Observation: Clone + Send + Sync + 'static,
    E::Action: Clone + Send + Sync + 'static,
    Enc: ObservationEncoder<E::Observation, B>
        + ObservationEncoder<E::Observation, B::InnerBackend>,
    Act: DiscreteActionMapper<E::Action>,
    B: AutodiffBackend,
{
    pub fn new(env: E, agent: DqnAgent<E, Enc, Act, B>, seed: u64) -> Self {
        Self {
            env,
            agent,
            rng: SmallRng::seed_from_u64(seed),
            current_obs: None,
            episode: 0,
            episode_step: 0,
            episode_reward: 0.0,
        }
    }
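
    /// Returns an iterator that performs one training step per `next()` call.
    /// The iterator never yields `None`, so bound it with `take` or a break.
    ///
    /// A driver-loop sketch (the `env` and `agent` values are assumed to be
    /// built by the caller; `42` is just an example seed):
    ///
    /// ```ignore
    /// let mut runner = DqnRunner::new(env, agent, 42);
    /// for m in runner.steps().take(50_000) {
    ///     if m.episode_done {
    ///         println!("episode {:>4} | return {:>8.2} | eps {:.3}",
    ///             m.episode, m.episode_reward, m.epsilon);
    ///     }
    /// }
    /// ```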
    pub fn steps(&mut self) -> StepIter<'_, E, Enc, Act, B> {
        StepIter { runner: self }
    }

    /// Read-only access to the agent, e.g. for inspecting its step counter.
    pub fn agent(&self) -> &DqnAgent<E, Enc, Act, B> {
        &self.agent
    }

    /// Read-only access to the environment.
    pub fn env(&self) -> &E {
        &self.env
    }

    fn step_once(&mut self) -> StepMetrics {
        // Lazily initialize on the first call: reset the environment and
        // zero out the episode counters.
        if self.current_obs.is_none() {
            let (obs, _info) = self.env.reset(Some(0));
            self.current_obs = Some(obs);
            self.episode = 0;
            self.episode_step = 0;
            self.episode_reward = 0.0;
        }

        let obs = self.current_obs.clone().unwrap();

        // Select an action epsilon-greedily and record the epsilon in effect.
        let action = self.agent.act_epsilon_greedy(&obs, &mut self.rng);
        let epsilon = self.agent.epsilon();

        let result = self.env.step(action.clone());
        let reward = result.reward;
        let done = result.is_done();

        self.episode_reward += reward;
        self.episode_step += 1;

        // Hand the transition to the agent; `observe` reports whether it
        // triggered a learning update.
        let experience = Experience::new(
            obs,
            action,
            reward,
            result.observation.clone(),
            result.status.clone(),
        );
        let did_update = self.agent.observe(experience);

        let metrics = StepMetrics {
            total_steps: self.agent.total_steps(),
            episode: self.episode,
            episode_step: self.episode_step,
            reward,
            episode_reward: self.episode_reward,
            epsilon,
            did_update,
            episode_done: done,
        };

        // On episode end, reset the environment and roll the counters over;
        // otherwise carry the new observation into the next step.
        if done {
            let (next_obs, _info) = self.env.reset(None);
            self.current_obs = Some(next_obs);
            self.episode += 1;
            self.episode_step = 0;
            self.episode_reward = 0.0;
        } else {
            self.current_obs = Some(result.observation);
        }

        metrics
    }
}

/// Endless iterator over training steps; see [`DqnRunner::steps`].
pub struct StepIter<'a, E, Enc, Act, B>
where
    E: Environment,
    B: AutodiffBackend,
{
    runner: &'a mut DqnRunner<E, Enc, Act, B>,
}

impl<'a, E, Enc, Act, B> Iterator for StepIter<'a, E, Enc, Act, B>
where
    E: Environment,
    E::Observation: Clone + Send + Sync + 'static,
    E::Action: Clone + Send + Sync + 'static,
    Enc: ObservationEncoder<E::Observation, B>
        + ObservationEncoder<E::Observation, B::InnerBackend>,
    Act: DiscreteActionMapper<E::Action>,
    B: AutodiffBackend,
{
    type Item = StepMetrics;

    fn next(&mut self) -> Option<Self::Item> {
        // Always yields a step: the training length is bounded by the caller.
        Some(self.runner.step_once())
    }
}