burn_train/learner/rl/env_runner/
base.rs

use std::marker::PhantomData;

use burn_core::data::dataloader::Progress;
use burn_core::{Tensor, prelude::Backend};
use burn_rl::{Environment, EnvironmentInit, Policy, Transition};
use derive_new::new;

use crate::RLEvent;
use crate::{
    AgentEvaluationEvent, EpisodeSummary, EvaluationItem, EventProcessorTraining,
    RLEventProcessorType,
};
use crate::{Interrupter, RLComponentsTypes};

/// A trajectory, i.e. an ordered list of [TimeStep](TimeStep)s.
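///
/// For example, an episode's return can be read from the final timestep
/// (a sketch; `trajectory` is assumed to come from a runner):
///
/// ```ignore
/// let episode_return = trajectory.timesteps.last().map(|step| step.cum_reward);
/// ```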
#[derive(Clone, new)]
pub struct Trajectory<B: Backend, S, A, C> {
    /// The ordered list of [TimeStep](TimeStep)s.
    pub timesteps: Vec<TimeStep<B, S, A, C>>,
}

/// A timestep describing an iteration of the state/decision process.
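///
/// For example (a sketch; `step` is assumed to come from a runner):
///
/// ```ignore
/// if step.done {
///     println!("episode of {} steps, return {}", step.ep_len, step.cum_reward);
/// }
/// ```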
#[derive(Clone)]
pub struct TimeStep<B: Backend, S, A, C> {
    /// The environment id.
    pub env_id: usize,
    /// The [burn_rl::Transition](burn_rl::Transition).
    pub transition: Transition<B, S, A>,
    /// True if the environment reaches a terminal state.
    pub done: bool,
    /// The running length of the current episode.
    pub ep_len: usize,
    /// The running cumulative reward.
    pub cum_reward: f64,
    /// The action's context for this timestep.
    pub action_context: C,
}

pub(crate) type RLTimeStep<B, RLC> = TimeStep<
    B,
    <RLC as RLComponentsTypes>::State,
    <RLC as RLComponentsTypes>::Action,
    <RLC as RLComponentsTypes>::ActionContext,
>;

pub(crate) type RLTrajectory<B, RLC> = Trajectory<
    B,
    <RLC as RLComponentsTypes>::State,
    <RLC as RLComponentsTypes>::Action,
    <RLC as RLComponentsTypes>::ActionContext,
>;

/// Trait for a structure that implements an agent/environment interface.
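///
/// A minimal driving loop might look as follows (a sketch: `runner`,
/// `processor`, `interrupter`, `progress`, and `new_policy_state` are
/// placeholders from the caller's training setup, and the step budget of
/// 128 is arbitrary):
///
/// ```ignore
/// // Collect a batch of experience, then sync fresh policy weights.
/// let steps = runner.run_steps(128, &mut processor, &interrupter, &mut progress);
/// runner.update_policy(new_policy_state);
/// ```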
pub trait AgentEnvLoop<BT: Backend, RLC: RLComponentsTypes> {
    /// Run a certain number of timesteps.
    ///
    /// # Arguments
    ///
    /// * `num_steps` - The number of timesteps to run.
    /// * `processor` - An [crate::EventProcessorTraining](crate::EventProcessorTraining).
    /// * `interrupter` - An [crate::Interrupter](crate::Interrupter).
    /// * `progress` - A mutable reference to the learning progress.
    ///
    /// # Returns
    ///
    /// A list of ordered timesteps.
    fn run_steps(
        &mut self,
        num_steps: usize,
        processor: &mut RLEventProcessorType<RLC>,
        interrupter: &Interrupter,
        progress: &mut Progress,
    ) -> Vec<RLTimeStep<BT, RLC>>;

    /// Run a certain number of episodes.
    ///
    /// # Arguments
    ///
    /// * `num_episodes` - The number of episodes to run.
    /// * `processor` - An [crate::EventProcessorTraining](crate::EventProcessorTraining).
    /// * `interrupter` - An [crate::Interrupter](crate::Interrupter).
    /// * `progress` - A mutable reference to the learning progress.
    ///
    /// # Returns
    ///
    /// A list of trajectories, one per completed episode.
    fn run_episodes(
        &mut self,
        num_episodes: usize,
        processor: &mut RLEventProcessorType<RLC>,
        interrupter: &Interrupter,
        progress: &mut Progress,
    ) -> Vec<RLTrajectory<BT, RLC>>;

    /// Update the runner's agent.
    fn update_policy(&mut self, update: RLC::PolicyState);

    /// Get the state of the runner's agent.
    fn policy(&self) -> RLC::PolicyState;
}

/// A simple, synchronous agent/environment interface.
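///
/// For instance, an evaluation rollout might be driven like this (a sketch;
/// `MyBackend`, `MyComponents`, `env_init`, `policy`, and the
/// processor/interrupter/progress values are placeholders from a
/// hypothetical setup):
///
/// ```ignore
/// let mut runner = AgentEnvBaseLoop::<MyBackend, MyComponents>::new(
///     env_init, policy, /* eval */ true, /* deterministic */ true,
/// );
/// let trajectories = runner.run_episodes(10, &mut processor, &interrupter, &mut progress);
/// ```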
pub struct AgentEnvBaseLoop<B: Backend, RLC: RLComponentsTypes> {
    env: RLC::Env,
    eval: bool,
    agent: RLC::Policy,
    deterministic: bool,
    current_reward: f64,
    run_num: usize,
    step_num: usize,
    _backend: PhantomData<B>,
}

impl<B: Backend, RLC: RLComponentsTypes> AgentEnvBaseLoop<B, RLC> {
    /// Create a new base runner.
    pub fn new(
        env_init: RLC::EnvInit,
        agent: RLC::Policy,
        eval: bool,
        deterministic: bool,
    ) -> Self {
        let mut env = env_init.init();
        env.reset();

        Self {
            env,
            eval,
            agent,
            deterministic,
            current_reward: 0.0,
            run_num: 0,
            step_num: 0,
            _backend: PhantomData,
        }
    }
}

impl<BT, RLC> AgentEnvLoop<BT, RLC> for AgentEnvBaseLoop<BT, RLC>
where
    BT: Backend,
    RLC: RLComponentsTypes,
{
    fn run_steps(
        &mut self,
        num_steps: usize,
        processor: &mut RLEventProcessorType<RLC>,
        interrupter: &Interrupter,
        progress: &mut Progress,
    ) -> Vec<RLTimeStep<BT, RLC>> {
        let mut items = vec![];
        let device = Default::default();
        for _ in 0..num_steps {
            // Query the policy for an action in the current state.
            let state = self.env.state();
            let (action, context) = self.agent.action(state.clone().into(), self.deterministic);

            let step_result = self.env.step(RLC::Action::from(action.clone()));

            self.current_reward += step_result.reward;
            self.step_num += 1;

            // Record the transition; the stored `done` flag also marks
            // truncation as an episode boundary.
            let transition = Transition::new(
                state.clone(),
                step_result.next_state,
                RLC::Action::from(action),
                Tensor::from_data([step_result.reward], &device),
                Tensor::from_data(
                    [(step_result.done || step_result.truncated) as i32 as f64],
                    &device,
                ),
            );
            items.push(TimeStep {
                env_id: 0,
                transition,
                done: step_result.done,
                ep_len: self.step_num,
                cum_reward: self.current_reward,
                action_context: context[0].clone(),
            });

            // During training, report each step (and episode end) to the processor.
            if !self.eval {
                progress.items_processed += 1;
                processor.process_train(RLEvent::TimeStep(EvaluationItem::new(
                    context[0].clone(),
                    progress.clone(),
                    None,
                )));

                if step_result.done {
                    processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new(
                        EpisodeSummary {
                            episode_length: self.step_num,
                            cum_reward: self.current_reward,
                        },
                        progress.clone(),
                        None,
                    )));
                }
            }

            if interrupter.should_stop() {
                break;
            }

            // Start a new episode when the current one terminates or is truncated.
            if step_result.done || step_result.truncated {
                self.env.reset();
                self.current_reward = 0.;
                self.step_num = 0;
                self.run_num += 1;
            }
        }
        items
    }

    fn update_policy(&mut self, update: RLC::PolicyState) {
        self.agent.update(update);
    }

    fn run_episodes(
        &mut self,
        num_episodes: usize,
        processor: &mut RLEventProcessorType<RLC>,
        interrupter: &Interrupter,
        progress: &mut Progress,
    ) -> Vec<RLTrajectory<BT, RLC>> {
        self.env.reset();

        let mut items = vec![];
        for ep in 0..num_episodes {
            let mut steps = vec![];
            loop {
                // `run_steps(1, ..)` always yields exactly one timestep.
                let step = self
                    .run_steps(1, processor, interrupter, progress)
                    .pop()
                    .expect("`run_steps(1, ..)` yields one timestep");
                steps.push(step.clone());

                // During evaluation, report each step (and episode end) as
                // validation events.
                if self.eval {
                    processor.process_valid(AgentEvaluationEvent::TimeStep(EvaluationItem::new(
                        step.action_context.clone(),
                        Progress::new(steps.len() + 1, steps.len() + 1),
                        None,
                    )));

                    if step.done {
                        processor.process_valid(AgentEvaluationEvent::EpisodeEnd(
                            EvaluationItem::new(
                                EpisodeSummary {
                                    episode_length: step.ep_len,
                                    cum_reward: step.cum_reward,
                                },
                                Progress::new(ep + 1, num_episodes),
                                None,
                            ),
                        ));
                    }
                }

                if interrupter.should_stop() || step.done {
                    break;
                }
            }
            items.push(Trajectory::new(steps));

            if interrupter.should_stop() {
                break;
            }
        }
        items
    }

    fn policy(&self) -> RLC::PolicyState {
        self.agent.state()
    }
}

#[cfg(test)]
mod tests {
    use crate::{AsyncProcessorTraining, TestBackend};

    use crate::learner::tests::{
        MockEnvInit, MockPolicy, MockPolicyState, MockProcessor, MockRLComponents,
    };

    use super::*;

    fn setup(
        state: usize,
        eval: bool,
        deterministic: bool,
    ) -> AgentEnvBaseLoop<TestBackend, MockRLComponents> {
        let env_init = MockEnvInit;
        let agent = MockPolicy(state);
        AgentEnvBaseLoop::<TestBackend, MockRLComponents>::new(env_init, agent, eval, deterministic)
    }

    #[test]
    fn test_policy_returns_agent_state() {
        let runner = setup(1000, false, false);
        let policy_state = runner.policy();
        assert_eq!(policy_state.0, 1000);
    }

    #[test]
    fn test_update_policy() {
        let mut runner = setup(0, false, false);

        runner.update_policy(MockPolicyState(1));
        assert_eq!(runner.policy().0, 1);
    }

    #[test]
    fn run_steps_returns_requested_number() {
        let mut runner = setup(0, false, false);
        let mut processor = AsyncProcessorTraining::new(MockProcessor);
        let interrupter = Interrupter::new();
        let mut progress = Progress {
            items_processed: 0,
            items_total: 1,
        };

        let steps = runner.run_steps(1, &mut processor, &interrupter, &mut progress);
        assert_eq!(steps.len(), 1);
        let steps = runner.run_steps(8, &mut processor, &interrupter, &mut progress);
        assert_eq!(steps.len(), 8);
    }

    #[test]
    fn run_episodes_returns_requested_number() {
        let mut runner = setup(0, false, false);
        let mut processor = AsyncProcessorTraining::new(MockProcessor);
        let interrupter = Interrupter::new();
        let mut progress = Progress {
            items_processed: 0,
            items_total: 1,
        };

        let trajectories = runner.run_episodes(1, &mut processor, &interrupter, &mut progress);
        assert_eq!(trajectories.len(), 1);
        assert!(!trajectories[0].timesteps.is_empty());
        let trajectories = runner.run_episodes(8, &mut processor, &interrupter, &mut progress);
        assert_eq!(trajectories.len(), 8);
        for trajectory in &trajectories {
            assert!(!trajectory.timesteps.is_empty());
        }
    }
}
343}