// burn_train/learner/rl/off_policy.rs
use std::marker::PhantomData;

use crate::{
    AgentEnvAsyncLoop, AgentEnvLoop, AsyncAgentEnvLoopConfig, EvaluationItem,
    EventProcessorTraining, MultiAgentEnvLoop, RLComponents, RLComponentsTypes, RLEvent,
    RLEventProcessorType, RLStrategy,
};
use burn_core::{self as burn};
use burn_core::{config::Config, data::dataloader::Progress};
use burn_flex::Flex;
use burn_rl::{AsyncPolicy, Policy, PolicyLearner, SliceAccess, TransitionBuffer};

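/// Configuration for the off-policy reinforcement learning strategy.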
#[derive(Config, Debug)]
pub struct OffPolicyConfig {
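    /// Number of environments to run in parallel when collecting experience.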
    #[config(default = 1)]
    pub num_envs: usize,
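    /// Upper bound on how many policy requests are batched together for
    /// inference; the actor loop uses `num_envs.min(autobatch_size)`.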
    #[config(default = 1)]
    pub autobatch_size: usize,
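    /// Capacity of the replay buffer holding collected transitions.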
    #[config(default = 1024)]
    pub replay_buffer_size: usize,
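    /// Number of environment steps to collect between training phases.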
    #[config(default = 1)]
    pub train_interval: usize,
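    /// Number of training steps to run per training phase.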
    #[config(default = 1)]
    pub train_steps: usize,
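    /// Number of environment steps between evaluation runs.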
    #[config(default = 10_000)]
    pub eval_interval: usize,
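    /// Number of episodes to run per evaluation.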
    #[config(default = 1)]
    pub eval_episodes: usize,
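    /// Number of transitions sampled from the replay buffer per training step.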
    #[config(default = 32)]
    pub train_batch_size: usize,
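    /// Number of environment steps to collect before training starts.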
    #[config(default = 0)]
    pub warmup_steps: usize,
}

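/// Off-policy training strategy: interleaves experience collection into a
/// replay buffer, training on mini-batches sampled from it, and periodic
/// evaluation with checkpointing.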
pub struct OffPolicyStrategy<RLC: RLComponentsTypes> {
    config: OffPolicyConfig,
    _components: PhantomData<RLC>,
}

impl<RLC: RLComponentsTypes> OffPolicyStrategy<RLC> {
    pub fn new(config: OffPolicyConfig) -> Self {
        Self {
            config,
            _components: PhantomData,
        }
    }
}

impl<RLC> RLStrategy<RLC> for OffPolicyStrategy<RLC>
where
    RLC: RLComponentsTypes,
    RLC::PolicyObs: SliceAccess<RLC::Backend>,
    RLC::PolicyAction: SliceAccess<RLC::Backend>,
{
    fn train_loop(
        &self,
        training_components: RLComponents<RLC>,
        learner_agent: &mut RLC::LearningAgent,
        starting_epoch: usize,
        env_init: RLC::EnvInit,
    ) -> (RLC::Policy, RLEventProcessorType<RLC>) {
        let mut event_processor = training_components.event_processor;
        let mut checkpointer = training_components.checkpointer;
        let num_steps_total = training_components.num_steps;

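        // Actor loop running `num_envs` environments in parallel; policy
        // requests are served asynchronously, batched up to `autobatch_size`.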
        let mut env_runner = MultiAgentEnvLoop::<Flex, RLC>::new(
            self.config.num_envs,
            env_init.clone(),
            AsyncPolicy::new(
                self.config.num_envs.min(self.config.autobatch_size),
                learner_agent.policy(),
            ),
            false,
            false,
            &Default::default(),
        );
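        // Single-environment loop used for deterministic evaluation.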
        let runner_config = AsyncAgentEnvLoopConfig {
            eval: true,
            deterministic: true,
            id: 0,
        };
        let mut env_runner_valid = AgentEnvAsyncLoop::<Flex, RLC>::new(
            env_init,
            AsyncPolicy::new(1, learner_agent.policy()),
            runner_config,
            &Default::default(),
            None,
            None,
        );

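        // Replay buffer storing collected transitions on the training device.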
        let device = Default::default();
        let mut transition_buffer = TransitionBuffer::<
            RLC::Backend,
            RLC::PolicyObs,
            RLC::PolicyAction,
        >::new(self.config.replay_buffer_size, &device);

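        // Step count at which the next evaluation (and checkpoint) is due.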
        let mut valid_next = self.config.eval_interval + starting_epoch - 1;
        let mut progress = Progress {
            items_processed: starting_epoch,
            items_total: num_steps_total,
        };

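        // Latest policy state produced by training; pushed to the actors
        // before the next collection phase.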
        let mut intermediary_update: Option<<RLC::Policy as Policy<RLC::Backend>>::PolicyState> =
            None;
        while progress.items_processed < num_steps_total {
            if training_components.interrupter.should_stop() {
                let reason = training_components
                    .interrupter
                    .get_message()
                    .unwrap_or_else(|| String::from("Reason unknown"));
                log::info!("Training interrupted: {reason}");
                break;
            }

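            // Collect `train_interval` environment steps with the current
            // actor policy.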
            let previous_steps = progress.items_processed;
            let items = env_runner.run_steps(
                self.config.train_interval,
                &mut event_processor,
                &training_components.interrupter,
                &mut progress,
            );

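            // Push each collected transition into the replay buffer,
            // extracting scalar reward and done flags from their tensors.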
            for item in &items {
                let t = &item.transition;
                let state: RLC::PolicyObs = t.state.clone().into();
                let next_state: RLC::PolicyObs = t.next_state.clone().into();
                let action: RLC::PolicyAction = t.action.clone().into();
                let reward = t.reward.to_data().to_vec::<f32>().unwrap()[0];
                let done = t.done.to_data().to_vec::<f32>().unwrap()[0] > 0.5;
                transition_buffer.push(state, next_state, action, reward, done);
            }

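            // Train once the buffer can fill a batch and warmup is complete.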
            if transition_buffer.len() >= self.config.train_batch_size
                && progress.items_processed >= self.config.warmup_steps
            {
                if let Some(ref u) = intermediary_update {
                    env_runner.update_policy(u.clone());
                }
                for _ in 0..self.config.train_steps {
                    let batch = transition_buffer.sample(self.config.train_batch_size);
                    let train_item = learner_agent.train(batch);
                    intermediary_update = Some(train_item.policy);

                    event_processor.process_train(RLEvent::TrainStep(EvaluationItem::new(
                        train_item.item,
                        progress.clone(),
                        None,
                    )));
                }
            }

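            // Evaluate with the latest learned policy, then checkpoint.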
            if valid_next > previous_steps && valid_next <= progress.items_processed {
                env_runner_valid.update_policy(learner_agent.policy().state());
                env_runner_valid.run_episodes(
                    self.config.eval_episodes,
                    &mut event_processor,
                    &training_components.interrupter,
                    &mut progress,
                );

                if let Some(checkpointer) = &mut checkpointer {
                    checkpointer.checkpoint(
                        &env_runner.policy(),
                        learner_agent,
                        valid_next,
                        &training_components.event_store,
                    );
                }

                valid_next += self.config.eval_interval;
            }
        }

        (learner_agent.policy(), event_processor)
    }
}