// burn_train/learner/rl/off_policy.rs

use std::marker::PhantomData;

use crate::{
    AgentEnvAsyncLoop, AgentEnvLoop, AsyncAgentEnvLoopConfig, EvaluationItem,
    EventProcessorTraining, MultiAgentEnvLoop, RLComponents, RLComponentsTypes, RLEvent,
    RLEventProcessorType, RLStrategy,
};
use burn_core::{self as burn};
use burn_core::{config::Config, data::dataloader::Progress};
use burn_flex::Flex;
use burn_rl::{AsyncPolicy, Policy, PolicyLearner, SliceAccess, TransitionBuffer};

/// Parameters of an off-policy training run with multiple simultaneous environments and double-batching.
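///
/// # Example
///
/// A minimal sketch of building a config, assuming the `new` constructor and
/// `with_*` setters generated by Burn's `#[derive(Config)]`:
///
/// ```rust,ignore
/// let config = OffPolicyConfig::new()
///     .with_num_envs(8)
///     .with_train_batch_size(64)
///     .with_warmup_steps(1_000);
/// ```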
#[derive(Config, Debug)]
pub struct OffPolicyConfig {
    /// The number of environments to run simultaneously for experience collection.
    #[config(default = 1)]
    pub num_envs: usize,
    /// Number of environment states to accumulate before running one step of inference with the policy.
    /// Must be less than or equal to the number of simultaneous environments.
    #[config(default = 1)]
    pub autobatch_size: usize,
    /// Max number of transitions stored in the replay buffer.
    #[config(default = 1024)]
    pub replay_buffer_size: usize,
    /// The number of steps to collect between each step of training.
    #[config(default = 1)]
    pub train_interval: usize,
    /// Number of optimization steps performed at each `train_interval`.
    #[config(default = 1)]
    pub train_steps: usize,
    /// The number of steps to collect between each evaluation.
    #[config(default = 10_000)]
    pub eval_interval: usize,
    /// The number of episodes to run for each evaluation.
    #[config(default = 1)]
    pub eval_episodes: usize,
    /// The number of transitions sampled from the replay buffer per training batch.
    #[config(default = 32)]
    pub train_batch_size: usize,
    /// Number of steps to collect before starting to train.
    #[config(default = 0)]
    pub warmup_steps: usize,
}

/// Off-policy reinforcement learning strategy with multi-env experience collection and double-batching.
pub struct OffPolicyStrategy<RLC: RLComponentsTypes> {
    config: OffPolicyConfig,
    _components: PhantomData<RLC>,
}

impl<RLC: RLComponentsTypes> OffPolicyStrategy<RLC> {
    /// Create a new off-policy base strategy.
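    ///
    /// A minimal sketch, assuming a hypothetical `MyComponents` type that
    /// implements `RLComponentsTypes`:
    ///
    /// ```rust,ignore
    /// let strategy = OffPolicyStrategy::<MyComponents>::new(OffPolicyConfig::new());
    /// ```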
    pub fn new(config: OffPolicyConfig) -> Self {
        Self {
            config,
            _components: PhantomData,
        }
    }
}

impl<RLC> RLStrategy<RLC> for OffPolicyStrategy<RLC>
where
    RLC: RLComponentsTypes,
    RLC::PolicyObs: SliceAccess<RLC::Backend>,
    RLC::PolicyAction: SliceAccess<RLC::Backend>,
{
    fn train_loop(
        &self,
        training_components: RLComponents<RLC>,
        learner_agent: &mut RLC::LearningAgent,
        starting_epoch: usize,
        env_init: RLC::EnvInit,
    ) -> (RLC::Policy, RLEventProcessorType<RLC>) {
        let mut event_processor = training_components.event_processor;
        let mut checkpointer = training_components.checkpointer;
        let num_steps_total = training_components.num_steps;

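        // Experience collection loop: `num_envs` environments run in parallel, with the
        // policy answering action requests in autobatches of up to `autobatch_size` observations.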
        let mut env_runner = MultiAgentEnvLoop::<Flex, RLC>::new(
            self.config.num_envs,
            env_init.clone(),
            AsyncPolicy::new(
                self.config.num_envs.min(self.config.autobatch_size),
                learner_agent.policy(),
            ),
            false,
            false,
            &Default::default(),
        );
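
        // Dedicated single-environment loop used for periodic deterministic evaluation.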
        let runner_config = AsyncAgentEnvLoopConfig {
            eval: true,
            deterministic: true,
            id: 0,
        };
        let mut env_runner_valid = AgentEnvAsyncLoop::<Flex, RLC>::new(
            env_init,
            AsyncPolicy::new(1, learner_agent.policy()),
            runner_config,
            &Default::default(),
            None,
            None,
        );

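        // Replay buffer storing up to `replay_buffer_size` transitions on the default device.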
        let device = Default::default();
        let mut transition_buffer = TransitionBuffer::<
            RLC::Backend,
            RLC::PolicyObs,
            RLC::PolicyAction,
        >::new(self.config.replay_buffer_size, &device);

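        // Step count at which the next evaluation is due, plus overall progress tracking.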
        let mut valid_next = self.config.eval_interval + starting_epoch - 1;
        let mut progress = Progress {
            items_processed: starting_epoch,
            items_total: num_steps_total,
        };

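        // Policy state from the latest optimization step. It is pushed to the collection
        // loop just before the next training phase, so actors keep collecting with the
        // previous policy while training runs (the double-batching scheme).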
        let mut intermediary_update: Option<<RLC::Policy as Policy<RLC::Backend>>::PolicyState> =
            None;
        while progress.items_processed < num_steps_total {
            if training_components.interrupter.should_stop() {
                let reason = training_components
                    .interrupter
                    .get_message()
                    .unwrap_or(String::from("Reason unknown"));
                log::info!("Training interrupted: {reason}");
                break;
            }

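            // Collect `train_interval` environment steps of new experience.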
            let previous_steps = progress.items_processed;
            let items = env_runner.run_steps(
                self.config.train_interval,
                &mut event_processor,
                &training_components.interrupter,
                &mut progress,
            );

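            // Convert each collected transition into buffer types and store it,
            // extracting scalar reward and done values from single-element tensors.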
            for item in &items {
                let t = &item.transition;
                let state: RLC::PolicyObs = t.state.clone().into();
                let next_state: RLC::PolicyObs = t.next_state.clone().into();
                let action: RLC::PolicyAction = t.action.clone().into();
                let reward = t.reward.to_data().to_vec::<f32>().unwrap()[0];
                let done = t.done.to_data().to_vec::<f32>().unwrap()[0] > 0.5;
                transition_buffer.push(state, next_state, action, reward, done);
            }

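            // Optimize once the buffer can fill a batch and warmup collection is over.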
            if transition_buffer.len() >= self.config.train_batch_size
                && progress.items_processed >= self.config.warmup_steps
            {
                if let Some(ref u) = intermediary_update {
                    env_runner.update_policy(u.clone());
                }
                for _ in 0..self.config.train_steps {
                    let batch = transition_buffer.sample(self.config.train_batch_size);
                    let train_item = learner_agent.train(batch);
                    intermediary_update = Some(train_item.policy);

                    event_processor.process_train(RLEvent::TrainStep(EvaluationItem::new(
                        train_item.item,
                        progress.clone(),
                        None,
                    )));
                }
            }

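            // Evaluate (and checkpoint) once the step count crosses `valid_next`.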
            if valid_next > previous_steps && valid_next <= progress.items_processed {
                env_runner_valid.update_policy(learner_agent.policy().state());
                env_runner_valid.run_episodes(
                    self.config.eval_episodes,
                    &mut event_processor,
                    &training_components.interrupter,
                    &mut progress,
                );

                if let Some(checkpointer) = &mut checkpointer {
                    checkpointer.checkpoint(
                        &env_runner.policy(),
                        learner_agent,
                        valid_next,
                        &training_components.event_store,
                    );
                }

                valid_next += self.config.eval_interval;
            }
        }

        (learner_agent.policy(), event_processor)
    }
}