use std::marker::PhantomData;

use burn_core::data::dataloader::Progress;
use burn_core::{Tensor, prelude::Backend};
use burn_rl::Policy;
use burn_rl::Transition;
use burn_rl::{Environment, EnvironmentInit};
use derive_new::new;

use crate::RLEvent;
use crate::{
    AgentEvaluationEvent, EpisodeSummary, EvaluationItem, EventProcessorTraining,
    RLEventProcessorType,
};
use crate::{Interrupter, RLComponentsTypes};

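/// The ordered timesteps collected over one episode (possibly partial if the run
/// was interrupted).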
#[derive(Clone, new)]
pub struct Trajectory<B: Backend, S, A, C> {
    pub timesteps: Vec<TimeStep<B, S, A, C>>,
}

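/// A single agent-environment interaction, together with bookkeeping about the
/// episode it belongs to.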
#[derive(Clone)]
pub struct TimeStep<B: Backend, S, A, C> {
    /// Identifier of the environment that produced this step.
    pub env_id: usize,
    /// The recorded environment transition (states, action, reward, done signal).
    pub transition: Transition<B, S, A>,
    /// Whether the episode terminated at this step.
    pub done: bool,
    /// Number of steps taken in the episode so far.
    pub ep_len: usize,
    /// Reward accumulated over the episode so far.
    pub cum_reward: f64,
    /// Policy-specific context produced when the action was selected.
    pub action_context: C,
}

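/// [`TimeStep`] specialized to the types of an [`RLComponentsTypes`] implementation.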
pub(crate) type RLTimeStep<B, RLC> = TimeStep<
    B,
    <RLC as RLComponentsTypes>::State,
    <RLC as RLComponentsTypes>::Action,
    <RLC as RLComponentsTypes>::ActionContext,
>;

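/// [`Trajectory`] specialized to the types of an [`RLComponentsTypes`] implementation.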
pub(crate) type RLTrajectory<B, RLC> = Trajectory<
    B,
    <RLC as RLComponentsTypes>::State,
    <RLC as RLComponentsTypes>::Action,
    <RLC as RLComponentsTypes>::ActionContext,
>;

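/// Runs an agent in an environment, collecting timesteps or whole trajectories
/// while emitting events and honoring interrupt requests.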
pub trait AgentEnvLoop<BT: Backend, RLC: RLComponentsTypes> {
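    /// Runs the environment for up to `num_steps` steps, returning one timestep per
    /// step taken.
    ///
    /// Finished episodes are reset in place, so the returned steps may span several
    /// episodes. Stops early when the [`Interrupter`] signals a stop.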
    fn run_steps(
        &mut self,
        num_steps: usize,
        processor: &mut RLEventProcessorType<RLC>,
        interrupter: &Interrupter,
        progress: &mut Progress,
    ) -> Vec<RLTimeStep<BT, RLC>>;
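    /// Runs up to `num_episodes` complete episodes, returning one trajectory per
    /// episode.
    ///
    /// Stops early when the [`Interrupter`] signals a stop; an interrupted episode
    /// yields a partial trajectory.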
    fn run_episodes(
        &mut self,
        num_episodes: usize,
        processor: &mut RLEventProcessorType<RLC>,
        interrupter: &Interrupter,
        progress: &mut Progress,
    ) -> Vec<RLTrajectory<BT, RLC>>;
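    /// Applies a policy update to the agent.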
    fn update_policy(&mut self, update: RLC::PolicyState);
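    /// Returns the agent's current policy state.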
    fn policy(&self) -> RLC::PolicyState;
}

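/// Basic single-environment implementation of [`AgentEnvLoop`]: one agent stepping
/// one environment.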
pub struct AgentEnvBaseLoop<B: Backend, RLC: RLComponentsTypes> {
    env: RLC::Env,
    /// Whether this loop reports evaluation events instead of training events.
    eval: bool,
    agent: RLC::Policy,
    /// Whether the agent selects actions deterministically.
    deterministic: bool,
    /// Reward accumulated in the current episode.
    current_reward: f64,
    /// Number of episodes completed so far.
    run_num: usize,
    /// Number of steps taken in the current episode.
    step_num: usize,
    _backend: PhantomData<B>,
}

impl<B: Backend, RLC: RLComponentsTypes> AgentEnvBaseLoop<B, RLC> {
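    /// Creates a new loop, initializing the environment and resetting it so the
    /// first step starts from a fresh state.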
    pub fn new(
        env_init: RLC::EnvInit,
        agent: RLC::Policy,
        eval: bool,
        deterministic: bool,
    ) -> Self {
        let mut env = env_init.init();
        env.reset();

        Self {
            env,
            eval,
            agent,
            deterministic,
            current_reward: 0.0,
            run_num: 0,
            step_num: 0,
            _backend: PhantomData,
        }
    }
}

impl<BT, RLC> AgentEnvLoop<BT, RLC> for AgentEnvBaseLoop<BT, RLC>
where
    BT: Backend,
    RLC: RLComponentsTypes,
{
    fn run_steps(
        &mut self,
        num_steps: usize,
        processor: &mut RLEventProcessorType<RLC>,
        interrupter: &Interrupter,
        progress: &mut Progress,
    ) -> Vec<RLTimeStep<BT, RLC>> {
        let mut items = vec![];
        let device = Default::default();
        for _ in 0..num_steps {
            let state = self.env.state();
            let (action, context) = self.agent.action(state.clone().into(), self.deterministic);

            let step_result = self.env.step(RLC::Action::from(action.clone()));

            self.current_reward += step_result.reward;
            self.step_num += 1;

            // Encode termination (done or truncated) as a 0/1 tensor in the transition.
            let transition = Transition::new(
                state,
                step_result.next_state,
                RLC::Action::from(action),
                Tensor::from_data([step_result.reward], &device),
                Tensor::from_data(
                    [(step_result.done || step_result.truncated) as i32 as f64],
                    &device,
                ),
            );
            items.push(TimeStep {
                env_id: 0,
                transition,
                done: step_result.done,
                ep_len: self.step_num,
                cum_reward: self.current_reward,
                action_context: context[0].clone(),
            });

            if !self.eval {
                progress.items_processed += 1;
                processor.process_train(RLEvent::TimeStep(EvaluationItem::new(
                    context[0].clone(),
                    progress.clone(),
                    None,
                )));

                if step_result.done {
                    processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new(
                        EpisodeSummary {
                            episode_length: self.step_num,
                            cum_reward: self.current_reward,
                        },
                        progress.clone(),
                        None,
                    )));
                }
            }

            // Reset the finished episode before checking for an interrupt, so an
            // interrupted loop never leaves the environment in a terminal state.
            if step_result.done || step_result.truncated {
                self.env.reset();
                self.current_reward = 0.;
                self.step_num = 0;
                self.run_num += 1;
            }

            if interrupter.should_stop() {
                break;
            }
        }
        items
    }

    fn update_policy(&mut self, update: RLC::PolicyState) {
        self.agent.update(update);
    }

    fn run_episodes(
        &mut self,
        num_episodes: usize,
        processor: &mut RLEventProcessorType<RLC>,
        interrupter: &Interrupter,
        progress: &mut Progress,
    ) -> Vec<RLTrajectory<BT, RLC>> {
        // Start from a fresh episode: clear the episode counters along with the
        // environment, since an earlier interrupted run may have left them mid-episode.
        self.env.reset();
        self.current_reward = 0.;
        self.step_num = 0;

        let mut items = vec![];
        for ep in 0..num_episodes {
            let mut steps = vec![];
            loop {
                let step = self.run_steps(1, processor, interrupter, progress)[0].clone();
                steps.push(step.clone());

                if self.eval {
                    processor.process_valid(AgentEvaluationEvent::TimeStep(EvaluationItem::new(
                        step.action_context.clone(),
                        Progress::new(steps.len() + 1, steps.len() + 1),
                        None,
                    )));

                    if step.done {
                        processor.process_valid(AgentEvaluationEvent::EpisodeEnd(
                            EvaluationItem::new(
                                EpisodeSummary {
                                    episode_length: step.ep_len,
                                    cum_reward: step.cum_reward,
                                },
                                Progress::new(ep + 1, num_episodes),
                                None,
                            ),
                        ));
                    }
                }

                if interrupter.should_stop() || step.done {
                    break;
                }
            }
            items.push(Trajectory::new(steps));

            if interrupter.should_stop() {
                break;
            }
        }
        items
    }

    fn policy(&self) -> RLC::PolicyState {
        self.agent.state()
    }
}

#[cfg(test)]
#[allow(clippy::needless_range_loop)]
mod tests {
    use crate::{AsyncProcessorTraining, TestBackend};

    use crate::learner::tests::{
        MockEnvInit, MockPolicy, MockPolicyState, MockProcessor, MockRLComponents,
    };

    use super::*;

    fn setup(
        state: usize,
        eval: bool,
        deterministic: bool,
    ) -> AgentEnvBaseLoop<TestBackend, MockRLComponents> {
        let env_init = MockEnvInit;
        let agent = MockPolicy(state);
        AgentEnvBaseLoop::<TestBackend, MockRLComponents>::new(env_init, agent, eval, deterministic)
    }

    #[test]
    fn test_policy_returns_agent_state() {
        let runner = setup(1000, false, false);
        let policy_state = runner.policy();
        assert_eq!(policy_state.0, 1000);
    }

    #[test]
    fn test_update_policy() {
        let mut runner = setup(0, false, false);

        runner.update_policy(MockPolicyState(1));
        assert_eq!(runner.policy().0, 1);
    }

    #[test]
    fn run_steps_returns_requested_number() {
        let mut runner = setup(0, false, false);
        let mut processor = AsyncProcessorTraining::new(MockProcessor);
        let interrupter = Interrupter::new();
        let mut progress = Progress {
            items_processed: 0,
            items_total: 1,
        };

        let steps = runner.run_steps(1, &mut processor, &interrupter, &mut progress);
        assert_eq!(steps.len(), 1);
        let steps = runner.run_steps(8, &mut processor, &interrupter, &mut progress);
        assert_eq!(steps.len(), 8);
    }

    #[test]
    fn run_episodes_returns_requested_number() {
        let mut runner = setup(0, false, false);
        let mut processor = AsyncProcessorTraining::new(MockProcessor);
        let interrupter = Interrupter::new();
        let mut progress = Progress {
            items_processed: 0,
            items_total: 1,
        };

        let trajectories = runner.run_episodes(1, &mut processor, &interrupter, &mut progress);
        assert_eq!(trajectories.len(), 1);
        assert_ne!(trajectories[0].timesteps.len(), 0);
        let trajectories = runner.run_episodes(8, &mut processor, &interrupter, &mut progress);
        assert_eq!(trajectories.len(), 8);
        for i in 0..8 {
            assert_ne!(trajectories[i].timesteps.len(), 0);
        }
    }
}