//! Asynchronous agent-environment loops. Each environment runs on its own
//! thread and exchanges transitions, trajectories, and step/episode requests
//! with the main thread over mpsc channels.

use rand::prelude::SliceRandom;
use std::{
    sync::mpsc::{Receiver, Sender},
    thread::spawn,
};

use burn_core::{Tensor, data::dataloader::Progress, prelude::Backend, tensor::Device};
use burn_rl::EnvironmentInit;
use burn_rl::Policy;
use burn_rl::Transition;
use burn_rl::{AsyncPolicy, Environment};

use crate::{
    AgentEnvLoop, AgentEvaluationEvent, EpisodeSummary, EvaluationItem, EventProcessorTraining,
    Interrupter, RLComponentsTypes, RLEvent, RLEventProcessorType, RLTimeStep, RLTrajectory,
    RlPolicy, TimeStep, Trajectory,
};

/// Control message sent from the main thread to an environment runner thread.
enum RequestMessage {
    /// Advance the environment by a single step.
    Step,
    /// Run the environment until the current episode terminates.
    Episode,
}

/// Configuration for a single asynchronous agent-environment loop.
pub struct AsyncAgentEnvLoopConfig {
    /// Whether the loop is used for evaluation rather than training.
    pub eval: bool,
    /// Whether the policy should act deterministically.
    pub deterministic: bool,
    /// Identifier of the environment, reported in each time step.
    pub id: usize,
}

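/// Runs a single environment against a shared [`AsyncPolicy`] on a dedicated
/// background thread. Transitions and trajectories produced by the thread are
/// consumed through the receivers held here; work is requested one step or one
/// episode at a time through `request_sender`.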
pub struct AgentEnvAsyncLoop<BT: Backend, RLC: RLComponentsTypes> {
    eval: bool,
    agent: AsyncPolicy<RLC::Backend, RlPolicy<RLC>>,
    transition_receiver: Receiver<RLTimeStep<BT, RLC>>,
    trajectory_receiver: Receiver<RLTrajectory<BT, RLC>>,
    request_sender: Sender<RequestMessage>,
}

impl<BT: Backend, RLC: RLComponentsTypes> AgentEnvAsyncLoop<BT, RLC> {
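    /// Spawns the environment thread and wires up its channels.
    ///
    /// When `transition_sender` or `trajectory_sender` is `None`, the loop's own
    /// channels are used and results are read back through this struct's
    /// receivers. A minimal usage sketch (the environment initializer, policy,
    /// and device are placeholders, not part of this file):
    ///
    /// ```ignore
    /// let runner = AgentEnvAsyncLoop::<MyBackend, MyComponents>::new(
    ///     my_env_init,
    ///     AsyncPolicy::new(1, my_policy),
    ///     AsyncAgentEnvLoopConfig { eval: false, deterministic: false, id: 0 },
    ///     &device,
    ///     None, // keep transitions on the loop's own channel
    ///     None, // keep trajectories on the loop's own channel
    /// );
    /// ```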
    pub fn new(
        env_init: RLC::EnvInit,
        agent: AsyncPolicy<RLC::Backend, RlPolicy<RLC>>,
        config: AsyncAgentEnvLoopConfig,
        transition_device: &Device<BT>,
        transition_sender: Option<Sender<RLTimeStep<BT, RLC>>>,
        trajectory_sender: Option<Sender<RLTrajectory<BT, RLC>>>,
    ) -> Self {
        let (loop_transition_sender, transition_receiver) = std::sync::mpsc::channel();
        let (loop_trajectory_sender, trajectory_receiver) = std::sync::mpsc::channel();
        let (request_sender, request_receiver) = std::sync::mpsc::channel();
        // Fall back to the loop's own channels when no external senders are provided.
        let loop_transition_sender = transition_sender.unwrap_or(loop_transition_sender);
        let loop_trajectory_sender = trajectory_sender.unwrap_or(loop_trajectory_sender);

        let device = transition_device.clone();
        let mut loop_agent = agent.clone();
        let eval = config.eval;

        let mut current_steps = vec![];
        let mut current_reward = 0.0;
        let mut step_num = 0;
        spawn(move || {
            let mut env = env_init.init();
            env.reset();

            let mut request_episode = false;
            loop {
                let state = env.state();
                let (action, context) =
                    loop_agent.action(state.clone().into(), config.deterministic);

                let env_action = RLC::Action::from(action);
                let step_result = env.step(env_action.clone());

                current_reward += step_result.reward;
                step_num += 1;

                let transition = Transition::new(
                    state.clone(),
                    step_result.next_state,
                    env_action,
                    Tensor::from_data([step_result.reward], &device),
                    Tensor::from_data(
                        [(step_result.done || step_result.truncated) as i32 as f64],
                        &device,
                    ),
                );

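                // Backpressure handshake: unless a whole episode was requested, the
                // runner blocks here until the main thread asks for more work. While
                // blocked it deregisters from the shared policy so auto-batched
                // inference is not stalled waiting on this environment.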
                if !request_episode {
                    loop_agent.decrement_agents(1);
                    let request = match request_receiver.recv() {
                        Ok(req) => req,
                        Err(err) => {
                            log::error!("Error in env runner: {}", err);
                            break;
                        }
                    };
                    loop_agent.increment_agents(1);

                    match request {
                        RequestMessage::Step => (),
                        RequestMessage::Episode => request_episode = true,
                    }
                }

                let time_step = TimeStep {
                    env_id: config.id,
                    transition,
                    done: step_result.done,
                    ep_len: step_num,
                    cum_reward: current_reward,
                    action_context: context[0].clone(),
                };
                current_steps.push(time_step.clone());

                // When collecting a whole episode, individual transitions are not
                // streamed out; they are buffered and sent as one trajectory below.
                if !request_episode && let Err(err) = loop_transition_sender.send(time_step) {
                    log::error!("Error in env runner: {}", err);
                    break;
                }

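                // Episode boundary: flush the buffered trajectory if one was
                // requested, then reset the environment and per-episode counters.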
                if step_result.done || step_result.truncated {
                    if request_episode {
                        request_episode = false;
                        loop_trajectory_sender
                            .send(Trajectory {
                                timesteps: current_steps.clone(),
                            })
                            .expect("Can send trajectory to main thread.");
                    }
                    current_steps.clear();

                    env.reset();
                    current_reward = 0.;
                    step_num = 0;
                }
            }
        });

        Self {
            eval,
            agent,
            transition_receiver,
            trajectory_receiver,
            request_sender,
        }
    }
}

impl<BT, RLC> AgentEnvLoop<BT, RLC> for AgentEnvAsyncLoop<BT, RLC>
where
    BT: Backend,
    RLC: RLComponentsTypes,
{
    fn run_steps(
        &mut self,
        num_steps: usize,
        processor: &mut RLEventProcessorType<RLC>,
        interrupter: &Interrupter,
        progress: &mut Progress,
    ) -> Vec<RLTimeStep<BT, RLC>> {
        let mut items = vec![];
        for _ in 0..num_steps {
            self.request_sender
                .send(RequestMessage::Step)
                .expect("Can request transitions.");
            let transition = self
                .transition_receiver
                .recv()
                .expect("Can receive transitions.");
            items.push(transition.clone());

            if !self.eval {
                progress.items_processed += 1;
                processor.process_train(RLEvent::TimeStep(EvaluationItem::new(
                    transition.action_context,
                    progress.clone(),
                    None,
                )));

                if transition.done {
                    processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new(
                        EpisodeSummary {
                            episode_length: transition.ep_len,
                            cum_reward: transition.cum_reward,
                        },
                        progress.clone(),
                        None,
                    )));
                }
            }

            if interrupter.should_stop() {
                break;
            }
        }
        items
    }

    fn run_episodes(
        &mut self,
        num_episodes: usize,
        processor: &mut RLEventProcessorType<RLC>,
        interrupter: &Interrupter,
        _progress: &mut Progress,
    ) -> Vec<RLTrajectory<BT, RLC>> {
        let mut items = vec![];
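        // Temporarily register one more active agent with the shared policy
        // while whole episodes are being collected.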
        self.agent.increment_agents(1);
        for episode_num in 0..num_episodes {
            self.request_sender
                .send(RequestMessage::Episode)
                .expect("Can request episodes.");
            let trajectory = self
                .trajectory_receiver
                .recv()
                .expect("Main thread can receive trajectory.");

            for (i, step) in trajectory.timesteps.iter().enumerate() {
                if self.eval {
                    processor.process_valid(AgentEvaluationEvent::TimeStep(EvaluationItem::new(
                        step.action_context.clone(),
                        Progress::new(i, i),
                        None,
                    )));

                    if step.done {
                        processor.process_valid(AgentEvaluationEvent::EpisodeEnd(
                            EvaluationItem::new(
                                EpisodeSummary {
                                    episode_length: step.ep_len,
                                    cum_reward: step.cum_reward,
                                },
                                Progress::new(episode_num + 1, num_episodes),
                                None,
                            ),
                        ));
                    }
                } else {
                    processor.process_train(RLEvent::TimeStep(EvaluationItem::new(
                        step.action_context.clone(),
                        Progress::new(i, i),
                        None,
                    )));

                    if step.done {
                        processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new(
                            EpisodeSummary {
                                episode_length: step.ep_len,
                                cum_reward: step.cum_reward,
                            },
                            Progress::new(episode_num + 1, num_episodes),
                            None,
                        )));
                    }
                }
            }

            items.push(trajectory);
            if interrupter.should_stop() {
                break;
            }
        }
        self.agent.decrement_agents(1);
        items
    }

    fn update_policy(&mut self, update: RLC::PolicyState) {
        self.agent.update(update);
    }

    fn policy(&self) -> RLC::PolicyState {
        self.agent.state()
    }
}

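/// Runs `num_envs` environments concurrently, one [`AgentEnvAsyncLoop`] thread
/// each, multiplexing their transitions and trajectories onto shared channels
/// that are drained by the main thread.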
pub struct MultiAgentEnvLoop<BT: Backend, RLC: RLComponentsTypes> {
    num_envs: usize,
    eval: bool,
    agent: AsyncPolicy<RLC::Backend, RLC::Policy>,
    transition_receiver: Receiver<RLTimeStep<BT, RLC>>,
    trajectory_receiver: Receiver<RLTrajectory<BT, RLC>>,
    request_senders: Vec<Sender<RequestMessage>>,
}

impl<BT: Backend, RLC: RLComponentsTypes> MultiAgentEnvLoop<BT, RLC> {
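    /// Spawns one environment runner per environment and pre-arms each with a
    /// step request. A minimal usage sketch (the initializer, policy, and device
    /// are placeholders, not part of this file):
    ///
    /// ```ignore
    /// let runner = MultiAgentEnvLoop::<MyBackend, MyComponents>::new(
    ///     4,                              // num_envs
    ///     my_env_init,
    ///     AsyncPolicy::new(4, my_policy), // batch inference across the runners
    ///     false,                          // eval
    ///     false,                          // deterministic
    ///     &device,
    /// );
    /// ```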
    pub fn new(
        num_envs: usize,
        env_init: RLC::EnvInit,
        agent: AsyncPolicy<RLC::Backend, RLC::Policy>,
        eval: bool,
        deterministic: bool,
        device: &Device<BT>,
    ) -> Self {
        let (transition_sender, transition_receiver) = std::sync::mpsc::channel();
        let (trajectory_sender, trajectory_receiver) = std::sync::mpsc::channel();
        let mut request_senders = vec![];

        // Register all runner threads as consumers of the shared policy up front.
        agent.increment_agents(num_envs);

        for i in 0..num_envs {
            let config = AsyncAgentEnvLoopConfig {
                eval,
                deterministic,
                id: i,
            };
            let runner = AgentEnvAsyncLoop::<BT, RLC>::new(
                env_init.clone(),
                agent.clone(),
                config,
                device,
                Some(transition_sender.clone()),
                Some(trajectory_sender.clone()),
            );
            request_senders.push(runner.request_sender.clone());
        }

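        // Pre-arm every runner with one step request so `run_steps` can first
        // receive a transition and then re-arm the environment that produced it.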
        request_senders.iter().for_each(|s| {
            s.send(RequestMessage::Step)
                .expect("Main thread can send step requests.")
        });

        Self {
            num_envs,
            eval,
            agent,
            transition_receiver,
            trajectory_receiver,
            request_senders,
        }
    }
}

impl<BT, RLC> AgentEnvLoop<BT, RLC> for MultiAgentEnvLoop<BT, RLC>
where
    BT: Backend,
    RLC: RLComponentsTypes,
{
    fn run_steps(
        &mut self,
        num_steps: usize,
        processor: &mut RLEventProcessorType<RLC>,
        interrupter: &Interrupter,
        progress: &mut Progress,
    ) -> Vec<RLTimeStep<BT, RLC>> {
        let mut items = vec![];
        for _ in 0..num_steps {
            let transition = self
                .transition_receiver
                .recv()
                .expect("Can receive transitions.");
            items.push(transition.clone());

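            // Re-arm the environment that just produced a transition; all other
            // runners still have a request in flight.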
            self.request_senders[transition.env_id]
                .send(RequestMessage::Step)
                .expect("Main thread can request steps.");

            if !self.eval {
                progress.items_processed += 1;
                processor.process_train(RLEvent::TimeStep(EvaluationItem::new(
                    transition.action_context,
                    progress.clone(),
                    None,
                )));

                if transition.done {
                    processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new(
                        EpisodeSummary {
                            episode_length: transition.ep_len,
                            cum_reward: transition.cum_reward,
                        },
                        progress.clone(),
                        None,
                    )));
                }
            }

            if interrupter.should_stop() {
                break;
            }
        }
        items
    }

    fn update_policy(&mut self, update: RLC::PolicyState) {
        self.agent.update(update);
    }

    fn run_episodes(
        &mut self,
        num_episodes: usize,
        processor: &mut RLEventProcessorType<RLC>,
        interrupter: &Interrupter,
        _progress: &mut Progress,
    ) -> Vec<RLTrajectory<BT, RLC>> {
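        // When fewer episodes than environments are requested, pick a random
        // subset of environments to run them on.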
        let idx: Vec<usize> = if num_episodes < self.num_envs {
            let mut rng = rand::rng();
            let mut envs: Vec<usize> = (0..self.num_envs).collect();
            envs.shuffle(&mut rng);
            envs.into_iter().take(num_episodes).collect()
        } else {
            (0..self.num_envs).collect()
        };
        let num_requests = self.num_envs.min(num_episodes);
        idx.into_iter().for_each(|i| {
            self.request_senders[i]
                .send(RequestMessage::Episode)
                .expect("Main thread can request episodes.");
        });

        let mut items = vec![];
        for episode_num in 0..num_episodes {
            let trajectory = self
                .trajectory_receiver
                .recv()
                .expect("Can receive trajectory.");
            items.push(trajectory.clone());
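            // Keep at most `num_requests` episode requests in flight: only re-arm
            // the finished environment while the outstanding requests are not yet
            // enough to reach `num_episodes`.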
            if items.len() + num_requests <= num_episodes {
                self.request_senders[trajectory.timesteps[0].env_id]
                    .send(RequestMessage::Episode)
                    .expect("Main thread can request episodes.");
            }
            for (i, step) in trajectory.timesteps.iter().enumerate() {
                if self.eval {
                    processor.process_valid(AgentEvaluationEvent::TimeStep(EvaluationItem::new(
                        step.action_context.clone(),
                        Progress::new(i, i),
                        None,
                    )));

                    if step.done {
                        processor.process_valid(AgentEvaluationEvent::EpisodeEnd(
                            EvaluationItem::new(
                                EpisodeSummary {
                                    episode_length: step.ep_len,
                                    cum_reward: step.cum_reward,
                                },
                                Progress::new(episode_num + 1, num_episodes),
                                None,
                            ),
                        ));
                    }
                } else {
                    processor.process_train(RLEvent::TimeStep(EvaluationItem::new(
                        step.action_context.clone(),
                        Progress::new(i, i),
                        None,
                    )));

                    if step.done {
                        processor.process_train(RLEvent::EpisodeEnd(EvaluationItem::new(
                            EpisodeSummary {
                                episode_length: step.ep_len,
                                cum_reward: step.cum_reward,
                            },
                            Progress::new(episode_num + 1, num_episodes),
                            None,
                        )));
                    }
                }
            }

            if interrupter.should_stop() {
                break;
            }
        }

        items
    }

    fn policy(&self) -> RLC::PolicyState {
        self.agent.state()
    }
}

#[cfg(test)]
#[allow(clippy::needless_range_loop)]
mod tests {
    use burn_core::data::dataloader::Progress;
    use burn_rl::AsyncPolicy;

    use crate::learner::rl::env_runner::async_runner::AsyncAgentEnvLoopConfig;
    use crate::learner::rl::env_runner::base::AgentEnvLoop;
    use crate::learner::tests::{MockPolicyState, MockProcessor};
    use crate::{
        AgentEnvAsyncLoop, TestBackend,
        learner::tests::{MockEnvInit, MockPolicy, MockRLComponents},
    };
    use crate::{AsyncProcessorTraining, Interrupter, MultiAgentEnvLoop};

    fn setup_async_loop(
        state: usize,
        eval: bool,
        deterministic: bool,
    ) -> AgentEnvAsyncLoop<TestBackend, MockRLComponents> {
        let env_init = MockEnvInit;
        let agent = MockPolicy(state);
        let config = AsyncAgentEnvLoopConfig {
            eval,
            deterministic,
            id: 0,
        };
        AgentEnvAsyncLoop::<TestBackend, MockRLComponents>::new(
            env_init,
            AsyncPolicy::new(1, agent),
            config,
            &Default::default(),
            None,
            None,
        )
    }

    fn setup_multi_loop(
        num_envs: usize,
        autobatch_size: usize,
        state: usize,
        eval: bool,
        deterministic: bool,
    ) -> MultiAgentEnvLoop<TestBackend, MockRLComponents> {
        let env_init = MockEnvInit;
        let agent = MockPolicy(state);
        MultiAgentEnvLoop::<TestBackend, MockRLComponents>::new(
            num_envs,
            env_init,
            AsyncPolicy::new(autobatch_size, agent),
            eval,
            deterministic,
            &Default::default(),
        )
    }

    #[test]
    fn test_policy_async_loop() {
        let runner = setup_async_loop(1000, false, false);
        let policy_state = runner.policy();
        assert_eq!(policy_state.0, 1000);
    }

    #[test]
    fn test_update_policy_async_loop() {
        let mut runner = setup_async_loop(0, false, false);

        runner.update_policy(MockPolicyState(1));
        assert_eq!(runner.policy().0, 1);
    }

    #[test]
    fn run_steps_returns_requested_number_async_loop() {
        let mut runner = setup_async_loop(0, false, false);
        let mut processor = AsyncProcessorTraining::new(MockProcessor);
        let interrupter = Interrupter::new();
        let mut progress = Progress {
            items_processed: 0,
            items_total: 1,
        };

        let steps = runner.run_steps(1, &mut processor, &interrupter, &mut progress);
        assert_eq!(steps.len(), 1);
        let steps = runner.run_steps(8, &mut processor, &interrupter, &mut progress);
        assert_eq!(steps.len(), 8);
    }

    #[test]
    fn run_episodes_returns_requested_number_async_loop() {
        let mut runner = setup_async_loop(0, false, false);
        let mut processor = AsyncProcessorTraining::new(MockProcessor);
        let interrupter = Interrupter::new();
        let mut progress = Progress {
            items_processed: 0,
            items_total: 1,
        };

        let trajectories = runner.run_episodes(1, &mut processor, &interrupter, &mut progress);
        assert_eq!(trajectories.len(), 1);
        assert_ne!(trajectories[0].timesteps.len(), 0);
        let trajectories = runner.run_episodes(8, &mut processor, &interrupter, &mut progress);
        assert_eq!(trajectories.len(), 8);
        for i in 0..8 {
            assert_ne!(trajectories[i].timesteps.len(), 0);
        }
    }

    #[test]
    fn test_policy_multi_loop() {
        let runner = setup_multi_loop(4, 4, 1000, false, false);
        let policy_state = runner.policy();
        assert_eq!(policy_state.0, 1000);
    }

    #[test]
    fn test_update_policy_multi_loop() {
        let mut runner = setup_multi_loop(4, 4, 0, false, false);

        runner.update_policy(MockPolicyState(1));
        assert_eq!(runner.policy().0, 1);
    }

    #[test]
    fn run_steps_returns_requested_number_multi_loop() {
        fn run_test(num_envs: usize, autobatch_size: usize) {
            let mut runner = setup_multi_loop(num_envs, autobatch_size, 0, false, false);
            let mut processor = AsyncProcessorTraining::new(MockProcessor);
            let interrupter = Interrupter::new();
            let mut progress = Progress {
                items_processed: 0,
                items_total: 1,
            };

            let steps = runner.run_steps(8, &mut processor, &interrupter, &mut progress);
            assert_eq!(steps.len(), 8);

            for i in 0..16 {
                let steps = runner.run_steps(i, &mut processor, &interrupter, &mut progress);
                assert_eq!(steps.len(), i);
            }
        }

        run_test(1, 1);
        run_test(4, 4);
        run_test(1, 2);
        run_test(1, 3);
        run_test(2, 3);
        run_test(2, 4);
        run_test(5, 19);
        run_test(2, 1);
        run_test(8, 1);
        run_test(3, 2);
        run_test(8, 2);
        run_test(8, 3);
        run_test(8, 7);
    }

    #[test]
    fn run_episodes_returns_requested_number_multi_loop() {
        fn run_test(num_envs: usize, autobatch_size: usize) {
            let mut runner = setup_multi_loop(num_envs, autobatch_size, 0, false, false);
            let mut processor = AsyncProcessorTraining::new(MockProcessor);
            let interrupter = Interrupter::new();
            let mut progress = Progress {
                items_processed: 0,
                items_total: 1,
            };

            let trajectories = runner.run_episodes(8, &mut processor, &interrupter, &mut progress);
            assert_eq!(trajectories.len(), 8);
            for j in 0..8 {
                assert_ne!(trajectories[j].timesteps.len(), 0);
            }

            for i in 0..16 {
                let trajectories =
                    runner.run_episodes(i, &mut processor, &interrupter, &mut progress);
                assert_eq!(trajectories.len(), i);
                for j in 0..i {
                    assert_ne!(trajectories[j].timesteps.len(), 0);
                }
            }
        }

        run_test(1, 1);
        run_test(4, 4);
        run_test(1, 2);
        run_test(1, 3);
        run_test(2, 3);
        run_test(2, 4);
        run_test(5, 19);
        run_test(2, 1);
        run_test(8, 1);
        run_test(3, 2);
        run_test(8, 2);
        run_test(8, 3);
        run_test(8, 7);
    }
}