//! Asynchronous agent/environment runners.
//!
//! Source file: `burn_train/learner/rl/env_runner/async_runner.rs`.

1use rand::prelude::SliceRandom;
2use std::{
3    sync::mpsc::{Receiver, Sender},
4    thread::spawn,
5};
6
7use burn_core::{Tensor, data::dataloader::Progress, prelude::Backend, tensor::Device};
8use burn_rl::EnvironmentInit;
9use burn_rl::Policy;
10use burn_rl::Transition;
11use burn_rl::{AsyncPolicy, Environment};
12
13use crate::{
14    AgentEnvLoop, AgentEvaluationEvent, EpisodeSummary, EvaluationItem, EventProcessorTraining,
15    Interrupter, RLComponentsTypes, RLEvent, RLEventProcessorType, RLTimeStep, RLTrajectory,
16    RlPolicy, TimeStep, Trajectory,
17};
18
/// Request sent from the driving thread to an environment loop thread.
enum RequestMessage {
    /// Produce and publish a single transition.
    Step(),
    /// Run to the end of the current episode and publish the whole trajectory.
    Episode(),
}
23
/// Configuration for an async agent/environment loop.
pub struct AsyncAgentEnvLoopConfig {
    /// If the loop is used for evaluation (as opposed to training).
    pub eval: bool,
    /// If the agent should take action deterministically.
    pub deterministic: bool,
    /// An arbitrary ID for the loop; reported as `env_id` on emitted time steps.
    pub id: usize,
}
33
/// An asynchronous agent/environment interface.
///
/// The environment runs on its own thread (spawned in [`AgentEnvAsyncLoop::new`]);
/// this handle requests steps or episodes from it over channels.
pub struct AgentEnvAsyncLoop<BT: Backend, RLC: RLComponentsTypes> {
    /// Whether this loop feeds evaluation (vs. training) event processing.
    eval: bool,
    /// Handle to the policy that takes actions on the loop thread.
    agent: AsyncPolicy<RLC::Backend, RlPolicy<RLC>>,
    /// Receives individual transitions from the loop thread.
    transition_receiver: Receiver<RLTimeStep<BT, RLC>>,
    /// Receives whole-episode trajectories from the loop thread.
    trajectory_receiver: Receiver<RLTrajectory<BT, RLC>>,
    /// Sends step/episode requests to the loop thread.
    request_sender: Sender<RequestMessage>,
}
42
impl<BT: Backend, RLC: RLComponentsTypes> AgentEnvAsyncLoop<BT, RLC> {
    /// Create a new asynchronous runner.
    ///
    /// Spawns a dedicated thread that owns the environment. The thread runs one
    /// step ahead of requests ("double batching"): it computes a transition,
    /// then blocks waiting for a [RequestMessage] before publishing it.
    ///
    /// # Arguments
    ///
    /// * `env_init` - A function returning an environment instance.
    /// * `agent` - An [AsyncPolicy](AsyncPolicy) taking actions in the loop.
    /// * `config` - An [AsyncAgentEnvLoopConfig](AsyncAgentEnvLoopConfig).
    /// * `transition_device` - Device on which the reward/done tensors of each transition are created.
    /// * `transition_sender` - Optional sender for transitions if you want to drive the requests from outside of the loop instance.
    /// * `trajectory_sender` - Optional sender for trajectories if you want to drive the requests from outside of the loop instance.
    ///
    /// # Returns
    ///
    /// An async Agent/Environment loop.
    pub fn new(
        env_init: RLC::EnvInit,
        agent: AsyncPolicy<RLC::Backend, RlPolicy<RLC>>,
        config: AsyncAgentEnvLoopConfig,
        transition_device: &Device<BT>,
        transition_sender: Option<Sender<RLTimeStep<BT, RLC>>>,
        trajectory_sender: Option<Sender<RLTrajectory<BT, RLC>>>,
    ) -> Self {
        let (loop_transition_sender, transition_receiver) = std::sync::mpsc::channel();
        let (loop_trajectory_sender, trajectory_receiver) = std::sync::mpsc::channel();
        let (request_sender, request_receiver) = std::sync::mpsc::channel();
        // When external senders are provided (the multi-env case), publish to
        // them instead of this instance's own channels; the local receivers
        // then simply stay idle.
        let loop_transition_sender = transition_sender.unwrap_or(loop_transition_sender);
        let loop_trajectory_sender = trajectory_sender.unwrap_or(loop_trajectory_sender);

        let device = transition_device.clone();
        let mut loop_agent = agent.clone();
        let eval = config.eval;

        // Per-episode accumulators, moved into the loop thread below.
        let mut current_steps = vec![];
        let mut current_reward = 0.0;
        let mut step_num = 0;
        spawn(move || {
            let mut env = env_init.init();
            env.reset();

            // While true, run freely to the end of the episode instead of
            // waiting for one request per step.
            let mut request_episode = false;
            loop {
                let state = env.state();
                let (action, context) =
                    loop_agent.action(state.clone().into(), config.deterministic);

                let env_action = RLC::Action::from(action);
                let step_result = env.step(env_action.clone());

                current_reward += step_result.reward;
                step_num += 1;

                // Reward and termination are materialized as tensors on the
                // requested device; `done || truncated` is encoded as a single
                // 0/1 scalar.
                let transition = Transition::new(
                    state.clone(),
                    step_result.next_state,
                    env_action,
                    Tensor::from_data([step_result.reward], &device),
                    Tensor::from_data(
                        [(step_result.done || step_result.truncated) as i32 as f64],
                        &device,
                    ),
                );

                if !request_episode {
                    // NOTE(review): decrement/increment around the blocking
                    // recv presumably marks this agent idle so the batching
                    // policy doesn't wait on it — confirm against AsyncPolicy.
                    loop_agent.decrement_agents(1);
                    let request = match request_receiver.recv() {
                        Ok(req) => req,
                        Err(err) => {
                            // The requesting side was dropped; shut down.
                            log::error!("Error in env runner : {}", err);
                            break;
                        }
                    };
                    loop_agent.increment_agents(1);

                    match request {
                        RequestMessage::Step() => (),
                        RequestMessage::Episode() => request_episode = true,
                    }
                }

                let time_step = TimeStep {
                    env_id: config.id,
                    transition,
                    done: step_result.done,
                    ep_len: step_num,
                    cum_reward: current_reward,
                    action_context: context[0].clone(),
                };
                // Buffer every step so a whole trajectory can be published if
                // episode mode is (or becomes) active.
                current_steps.push(time_step.clone());

                // Step mode publishes each transition individually.
                if !request_episode && let Err(err) = loop_transition_sender.send(time_step) {
                    log::error!("Error in env runner : {}", err);
                    break;
                }

                if step_result.done || step_result.truncated {
                    // Episode mode publishes the buffered trajectory once the
                    // episode terminates, then reverts to step-by-step mode.
                    if request_episode {
                        request_episode = false;
                        loop_trajectory_sender
                            .send(Trajectory {
                                timesteps: current_steps.clone(),
                            })
                            .expect("Can send trajectory to main thread.");
                    }
                    current_steps.clear();

                    env.reset();
                    current_reward = 0.;
                    step_num = 0;
                }
            }
        });

        Self {
            eval,
            agent,
            transition_receiver,
            trajectory_receiver,
            request_sender,
        }
    }
}
164
165impl<BT, RLC> AgentEnvLoop<BT, RLC> for AgentEnvAsyncLoop<BT, RLC>
166where
167    BT: Backend,
168    RLC: RLComponentsTypes,
169{
170    fn run_steps(
171        &mut self,
172        num_steps: usize,
173        processor: &mut RLEventProcessorType<RLC>,
174        interrupter: &Interrupter,
175        progress: &mut Progress,
176    ) -> Vec<RLTimeStep<BT, RLC>> {
177        let mut items = vec![];
178        for _ in 0..num_steps {
179            self.request_sender
180                .send(RequestMessage::Step())
181                .expect("Can request transitions.");
182            let transition = self
183                .transition_receiver
184                .recv()
185                .expect("Can receive transitions.");
186            items.push(transition.clone());
187
188            if !self.eval {
189                progress.items_processed += 1;
190                processor.process_train(RLEvent::TimeStep(EvaluationItem::new(
191                    transition.action_context,
192                    progress.clone(),
193                    None,
194                )));
195
196                if transition.done {
197                    processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new(
198                        EpisodeSummary {
199                            episode_length: transition.ep_len,
200                            cum_reward: transition.cum_reward,
201                        },
202                        progress.clone(),
203                        None,
204                    )));
205                }
206            }
207
208            if interrupter.should_stop() {
209                break;
210            }
211        }
212        items
213    }
214
215    fn run_episodes(
216        &mut self,
217        num_episodes: usize,
218        processor: &mut RLEventProcessorType<RLC>,
219        interrupter: &Interrupter,
220        _progress: &mut Progress,
221    ) -> Vec<RLTrajectory<BT, RLC>> {
222        let mut items = vec![];
223        self.agent.increment_agents(1);
224        for episode_num in 0..num_episodes {
225            self.request_sender
226                .send(RequestMessage::Episode())
227                .expect("Can request episodes.");
228            let trajectory = self
229                .trajectory_receiver
230                .recv()
231                .expect("Main thread can receive trajectory.");
232
233            for (i, step) in trajectory.timesteps.iter().enumerate() {
234                // TODO : clean this.
235                if self.eval {
236                    processor.process_valid(AgentEvaluationEvent::TimeStep(EvaluationItem::new(
237                        step.action_context.clone(),
238                        Progress::new(i, i),
239                        None,
240                    )));
241
242                    if step.done {
243                        processor.process_valid(AgentEvaluationEvent::EpisodeEnd(
244                            EvaluationItem::new(
245                                EpisodeSummary {
246                                    episode_length: step.ep_len,
247                                    cum_reward: step.cum_reward,
248                                },
249                                Progress::new(episode_num + 1, num_episodes),
250                                None,
251                            ),
252                        ));
253                    }
254                } else {
255                    processor.process_train(RLEvent::TimeStep(EvaluationItem::new(
256                        step.action_context.clone(),
257                        Progress::new(i, i),
258                        None,
259                    )));
260
261                    if step.done {
262                        processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new(
263                            EpisodeSummary {
264                                episode_length: step.ep_len,
265                                cum_reward: step.cum_reward,
266                            },
267                            Progress::new(episode_num + 1, num_episodes),
268                            None,
269                        )));
270                    }
271                }
272            }
273
274            items.push(trajectory);
275            if interrupter.should_stop() {
276                break;
277            }
278        }
279        self.agent.decrement_agents(1);
280        items
281    }
282
283    fn update_policy(&mut self, update: RLC::PolicyState) {
284        self.agent.update(update);
285    }
286
287    fn policy(&self) -> RLC::PolicyState {
288        self.agent.state()
289    }
290}
291
/// An asynchronous runner for multiple agent/environment interfaces.
pub struct MultiAgentEnvLoop<BT: Backend, RLC: RLComponentsTypes> {
    /// Number of environment loop threads spawned.
    num_envs: usize,
    /// Whether this runner feeds evaluation (vs. training) event processing.
    eval: bool,
    /// Shared policy handle used by every loop thread.
    agent: AsyncPolicy<RLC::Backend, RLC::Policy>,
    /// Shared channel receiving transitions from all loop threads.
    transition_receiver: Receiver<RLTimeStep<BT, RLC>>,
    /// Shared channel receiving trajectories from all loop threads.
    trajectory_receiver: Receiver<RLTrajectory<BT, RLC>>,
    /// Per-environment request channels, indexed by env id.
    request_senders: Vec<Sender<RequestMessage>>,
}
301
302impl<BT: Backend, RLC: RLComponentsTypes> MultiAgentEnvLoop<BT, RLC> {
303    /// Create a new asynchronous runner for multiple agent/environement interfaces.
304    pub fn new(
305        num_envs: usize,
306        env_init: RLC::EnvInit,
307        agent: AsyncPolicy<RLC::Backend, RLC::Policy>,
308        eval: bool,
309        deterministic: bool,
310        device: &Device<BT>,
311    ) -> Self {
312        let (transition_sender, transition_receiver) = std::sync::mpsc::channel();
313        let (trajectory_sender, trajectory_receiver) = std::sync::mpsc::channel();
314        let mut request_senders = vec![];
315
316        // Double batching : The environments are always one step ahead of requests. This allows inference for the first batch of steps.
317        agent.increment_agents(num_envs);
318
319        for i in 0..num_envs {
320            let config = AsyncAgentEnvLoopConfig {
321                eval,
322                deterministic,
323                id: i,
324            };
325            let runner = AgentEnvAsyncLoop::<BT, RLC>::new(
326                env_init.clone(),
327                agent.clone(),
328                config,
329                &device.clone(),
330                Some(transition_sender.clone()),
331                Some(trajectory_sender.clone()),
332            );
333            request_senders.push(runner.request_sender.clone());
334        }
335
336        // Double batching : The environments are always one step ahead.
337        request_senders.iter().for_each(|s| {
338            s.send(RequestMessage::Step())
339                .expect("Main thread can send step requests.")
340        });
341
342        Self {
343            num_envs,
344            eval,
345            agent: agent.clone(),
346            transition_receiver,
347            trajectory_receiver,
348            request_senders,
349        }
350    }
351}
352
353impl<BT, RLC> AgentEnvLoop<BT, RLC> for MultiAgentEnvLoop<BT, RLC>
354where
355    BT: Backend,
356    RLC: RLComponentsTypes,
357{
358    fn run_steps(
359        &mut self,
360        num_steps: usize,
361        processor: &mut RLEventProcessorType<RLC>,
362        interrupter: &Interrupter,
363        progress: &mut Progress,
364    ) -> Vec<RLTimeStep<BT, RLC>> {
365        let mut items = vec![];
366        for _ in 0..num_steps {
367            let transition = self
368                .transition_receiver
369                .recv()
370                .expect("Can receive transitions.");
371            items.push(transition.clone());
372
373            self.request_senders[transition.env_id]
374                .send(RequestMessage::Step())
375                .expect("Main thread can request steps.");
376
377            if !self.eval {
378                progress.items_processed += 1;
379                processor.process_train(RLEvent::TimeStep(EvaluationItem::new(
380                    transition.action_context,
381                    progress.clone(),
382                    None,
383                )));
384
385                if transition.done {
386                    processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new(
387                        EpisodeSummary {
388                            episode_length: transition.ep_len,
389                            cum_reward: transition.cum_reward,
390                        },
391                        progress.clone(),
392                        None,
393                    )));
394                }
395            }
396
397            if interrupter.should_stop() {
398                break;
399            }
400        }
401        items
402    }
403
404    fn update_policy(&mut self, update: RLC::PolicyState) {
405        self.agent.update(update);
406    }
407
408    fn run_episodes(
409        &mut self,
410        num_episodes: usize,
411        processor: &mut RLEventProcessorType<RLC>,
412        interrupter: &Interrupter,
413        _progress: &mut Progress,
414    ) -> Vec<RLTrajectory<BT, RLC>> {
415        // Send `num_episodes` initial requests.
416        let mut idx = vec![];
417        if num_episodes < self.num_envs {
418            let mut rng = rand::rng();
419            let mut vec: Vec<usize> = (0..self.num_envs).collect();
420            vec.shuffle(&mut rng);
421            idx = vec.into_iter().take(num_episodes).collect();
422        } else {
423            idx = (0..self.num_envs).collect();
424        }
425        let num_requests = self.num_envs.min(num_episodes);
426        idx.into_iter().for_each(|i| {
427            self.request_senders[i]
428                .send(RequestMessage::Episode())
429                .expect("Main thread can request steps.");
430        });
431
432        let mut items = vec![];
433        for episode_num in 0..num_episodes {
434            let trajectory = self
435                .trajectory_receiver
436                .recv()
437                .expect("Can receive trajectory.");
438            items.push(trajectory.clone());
439            if items.len() + num_requests <= num_episodes {
440                self.request_senders[trajectory.timesteps[0].env_id]
441                    .send(RequestMessage::Episode())
442                    .expect("Main thread can request steps.");
443            }
444            for (i, step) in trajectory.timesteps.iter().enumerate() {
445                if self.eval {
446                    processor.process_valid(AgentEvaluationEvent::TimeStep(EvaluationItem::new(
447                        step.action_context.clone(),
448                        Progress::new(i, i),
449                        None,
450                    )));
451
452                    if step.done {
453                        processor.process_valid(AgentEvaluationEvent::EpisodeEnd(
454                            EvaluationItem::new(
455                                EpisodeSummary {
456                                    episode_length: step.ep_len,
457                                    cum_reward: step.cum_reward,
458                                },
459                                Progress::new(episode_num + 1, num_episodes),
460                                None,
461                            ),
462                        ));
463                    }
464                } else {
465                    processor.process_train(RLEvent::TimeStep(EvaluationItem::new(
466                        step.action_context.clone(),
467                        Progress::new(i, i),
468                        None,
469                    )));
470
471                    if step.done {
472                        processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new(
473                            EpisodeSummary {
474                                episode_length: step.ep_len,
475                                cum_reward: step.cum_reward,
476                            },
477                            Progress::new(episode_num + 1, num_episodes),
478                            None,
479                        )));
480                    }
481                }
482            }
483
484            if interrupter.should_stop() {
485                break;
486            }
487        }
488
489        items
490    }
491
492    fn policy(&self) -> RLC::PolicyState {
493        self.agent.state()
494    }
495}
496
// Unit tests for the single-env and multi-env async runners, built against
// mock components (env, policy, processor) from `crate::learner::tests`.
#[cfg(test)]
#[allow(clippy::needless_range_loop)]
mod tests {
    use burn_core::data::dataloader::Progress;
    use burn_rl::AsyncPolicy;

    use crate::learner::rl::env_runner::async_runner::AsyncAgentEnvLoopConfig;
    use crate::learner::rl::env_runner::base::AgentEnvLoop;
    use crate::learner::tests::{MockPolicyState, MockProcessor};
    use crate::{
        AgentEnvAsyncLoop, TestBackend,
        learner::tests::{MockEnvInit, MockPolicy, MockRLComponents},
    };
    use crate::{AsyncProcessorTraining, Interrupter, MultiAgentEnvLoop};

    // Build a single-env async loop with a mock policy holding `state`.
    fn setup_async_loop(
        state: usize,
        eval: bool,
        deterministic: bool,
    ) -> AgentEnvAsyncLoop<TestBackend, MockRLComponents> {
        let env_init = MockEnvInit;
        let agent = MockPolicy(state);
        let config = AsyncAgentEnvLoopConfig {
            eval,
            deterministic,
            id: 0,
        };
        AgentEnvAsyncLoop::<TestBackend, MockRLComponents>::new(
            env_init,
            AsyncPolicy::new(1, agent),
            config,
            &Default::default(),
            None,
            None,
        )
    }

    // Build a multi-env runner with `num_envs` loops and the given
    // auto-batching size for the shared policy.
    fn setup_multi_loop(
        num_envs: usize,
        autobatch_size: usize,
        state: usize,
        eval: bool,
        deterministic: bool,
    ) -> MultiAgentEnvLoop<TestBackend, MockRLComponents> {
        let env_init = MockEnvInit;
        let agent = MockPolicy(state);
        MultiAgentEnvLoop::<TestBackend, MockRLComponents>::new(
            num_envs,
            env_init,
            AsyncPolicy::new(autobatch_size, agent),
            eval,
            deterministic,
            &Default::default(),
        )
    }

    // policy() reflects the state the agent was constructed with.
    #[test]
    fn test_policy_async_loop() {
        let runner = setup_async_loop(1000, false, false);
        let policy_state = runner.policy();
        assert_eq!(policy_state.0, 1000);
    }

    // update_policy() propagates a new state to the shared agent.
    #[test]
    fn test_update_policy_async_loop() {
        let mut runner = setup_async_loop(0, false, false);

        runner.update_policy(MockPolicyState(1));
        assert_eq!(runner.policy().0, 1);
    }

    // run_steps returns exactly the number of transitions requested.
    #[test]
    fn run_steps_returns_requested_number_async_loop() {
        let mut runner = setup_async_loop(0, false, false);
        let mut processor = AsyncProcessorTraining::new(MockProcessor);
        let interrupter = Interrupter::new();
        let mut progress = Progress {
            items_processed: 0,
            items_total: 1,
        };

        let steps = runner.run_steps(1, &mut processor, &interrupter, &mut progress);
        assert_eq!(steps.len(), 1);
        let steps = runner.run_steps(8, &mut processor, &interrupter, &mut progress);
        assert_eq!(steps.len(), 8);
    }

    // run_episodes returns the requested number of non-empty trajectories.
    #[test]
    fn run_episodes_returns_requested_number_async_loop() {
        let mut runner = setup_async_loop(0, false, false);
        let mut processor = AsyncProcessorTraining::new(MockProcessor);
        let interrupter = Interrupter::new();
        let mut progress = Progress {
            items_processed: 0,
            items_total: 1,
        };

        let trajectories = runner.run_episodes(1, &mut processor, &interrupter, &mut progress);
        assert_eq!(trajectories.len(), 1);
        assert_ne!(trajectories[0].timesteps.len(), 0);
        let trajectories = runner.run_episodes(8, &mut processor, &interrupter, &mut progress);
        assert_eq!(trajectories.len(), 8);
        for i in 0..8 {
            assert_ne!(trajectories[i].timesteps.len(), 0);
        }
    }

    // Multi-env variant of test_policy_async_loop.
    #[test]
    fn test_policy_multi_loop() {
        let runner = setup_multi_loop(4, 4, 1000, false, false);
        let policy_state = runner.policy();
        assert_eq!(policy_state.0, 1000);
    }

    // Multi-env variant of test_update_policy_async_loop.
    #[test]
    fn test_update_policy_multi_loop() {
        let mut runner = setup_multi_loop(4, 4, 0, false, false);

        runner.update_policy(MockPolicyState(1));
        assert_eq!(runner.policy().0, 1);
    }

    // run_steps must return exactly `n` transitions for every combination of
    // env count and auto-batch size (equal, smaller, larger).
    #[test]
    fn run_steps_returns_requested_number_multi_loop() {
        fn run_test(num_envs: usize, autobatch_size: usize) {
            let mut runner = setup_multi_loop(num_envs, autobatch_size, 0, false, false);
            let mut processor = AsyncProcessorTraining::new(MockProcessor);
            let interrupter = Interrupter::new();
            let mut progress = Progress {
                items_processed: 0,
                items_total: 1,
            };

            // Kickstart tests by running some steps to make sure it's not a double batching edge case success.
            let steps = runner.run_steps(8, &mut processor, &interrupter, &mut progress);
            assert_eq!(steps.len(), 8);

            for i in 0..16 {
                let steps = runner.run_steps(i, &mut processor, &interrupter, &mut progress);
                assert_eq!(steps.len(), i);
            }
        }

        // num_envs == autobatch_size
        run_test(1, 1);
        run_test(4, 4);
        // num_envs < autobatch_size
        run_test(1, 2);
        run_test(1, 3);
        run_test(2, 3);
        run_test(2, 4);
        run_test(5, 19);
        // num_envs > autobatch_size
        run_test(2, 1);
        run_test(8, 1);
        run_test(3, 2);
        run_test(8, 2);
        run_test(8, 3);
        run_test(8, 7);
    }

    // run_episodes must return exactly `n` non-empty trajectories for every
    // combination of env count and auto-batch size.
    #[test]
    fn run_episodes_returns_requested_number_multi_loop() {
        fn run_test(num_envs: usize, autobatch_size: usize) {
            let mut runner = setup_multi_loop(num_envs, autobatch_size, 0, false, false);
            let mut processor = AsyncProcessorTraining::new(MockProcessor);
            let interrupter = Interrupter::new();
            let mut progress = Progress {
                items_processed: 0,
                items_total: 1,
            };

            // Kickstart tests by running some episodes to make sure it's not a double batching edge case success.
            let trajectories = runner.run_episodes(8, &mut processor, &interrupter, &mut progress);
            assert_eq!(trajectories.len(), 8);
            for j in 0..8 {
                assert_ne!(trajectories[j].timesteps.len(), 0);
            }

            for i in 0..16 {
                let trajectories =
                    runner.run_episodes(i, &mut processor, &interrupter, &mut progress);
                assert_eq!(trajectories.len(), i);
                for j in 0..i {
                    assert_ne!(trajectories[j].timesteps.len(), 0);
                }
            }
        }

        // num_envs == autobatch_size
        run_test(1, 1);
        run_test(4, 4);
        // num_envs < autobatch_size
        run_test(1, 2);
        run_test(1, 3);
        run_test(2, 3);
        run_test(2, 4);
        run_test(5, 19);
        // num_envs > autobatch_size
        run_test(2, 1);
        run_test(8, 1);
        run_test(3, 2);
        run_test(8, 2);
        run_test(8, 3);
        run_test(8, 7);
    }
}