burn_train/learner/rl/
strategy.rs1use std::sync::Arc;
2
3use crate::{
4 Interrupter, LearnerSummaryConfig, OffPolicyConfig, RLCheckpointer, RLComponentsTypes, RLEvent,
5 RLEventProcessorType, RLResult,
6 metric::{processor::EventProcessorTraining, store::EventStoreClient},
7};
8
9pub struct RLComponents<RLC: RLComponentsTypes> {
11 pub num_steps: usize,
13 pub checkpoint: Option<usize>,
15 pub checkpointer: Option<RLCheckpointer<RLC>>,
17 pub grad_accumulation: Option<usize>,
19 pub interrupter: Interrupter,
21 pub event_processor: RLEventProcessorType<RLC>,
23 pub event_store: Arc<EventStoreClient>,
25 pub summary: Option<LearnerSummaryConfig>,
27}
28
29#[derive(Clone)]
31pub enum RLStrategies<RLC: RLComponentsTypes> {
32 OffPolicyStrategy(OffPolicyConfig),
34 Custom(CustomRLStrategy<RLC>),
36}
37
38pub type CustomRLStrategy<LC> = Arc<dyn RLStrategy<LC>>;
40
41pub trait RLStrategy<RLC: RLComponentsTypes> {
43 fn train(
45 &self,
46 mut learner_agent: RLC::LearningAgent,
47 mut training_components: RLComponents<RLC>,
48 env_init: RLC::EnvInit,
49 ) -> RLResult<RLC::Policy> {
50 let starting_epoch = match training_components.checkpoint {
51 Some(checkpoint) => {
52 if let Some(checkpointer) = &mut training_components.checkpointer {
53 learner_agent = checkpointer.load_checkpoint(
54 learner_agent,
55 &Default::default(),
56 checkpoint,
57 );
58 }
59 checkpoint + 1
60 }
61 None => 1,
62 };
63
64 let summary_config = training_components.summary.clone();
65
66 training_components
68 .event_processor
69 .process_train(RLEvent::Start);
70
71 let (policy, mut event_processor) = self.train_loop(
73 training_components,
74 &mut learner_agent,
75 starting_epoch,
76 env_init,
77 );
78
79 let summary = summary_config.and_then(|summary| summary.init().ok());
80
81 event_processor.process_train(RLEvent::End(summary));
84
85 let renderer = event_processor.renderer();
87
88 RLResult { policy, renderer }
89 }
90
91 fn train_loop(
93 &self,
94 training_components: RLComponents<RLC>,
95 learner_agent: &mut RLC::LearningAgent,
96 starting_epoch: usize,
97 env_init: RLC::EnvInit,
98 ) -> (RLC::Policy, RLEventProcessorType<RLC>);
99}