Skip to main content

burn_train/learner/rl/
paradigm.rs

1use crate::checkpoint::{
2    AsyncCheckpointer, CheckpointingStrategy, ComposedCheckpointingStrategy, FileCheckpointer,
3    KeepLastNCheckpoints, MetricCheckpointingStrategy,
4};
5use crate::learner::base::Interrupter;
6use crate::logger::{FileMetricLogger, MetricLogger};
7use crate::metric::store::{Aggregate, Direction, EventStoreClient, LogEventStore, Split};
8use crate::metric::{Adaptor, EpisodeLengthMetric, Metric, Numeric};
9use crate::renderer::{MetricsRenderer, default_renderer};
10use crate::{
11    ApplicationLoggerInstaller, AsyncProcessorTraining, FileApplicationLoggerInstaller, ItemLazy,
12    LearnerSummaryConfig, OffPolicyConfig, OffPolicyStrategy, RLAgentRecord, RLCheckpointer,
13    RLComponents, RLComponentsMarker, RLComponentsTypes, RLEventProcessor, RLMetrics,
14    RLPolicyRecord, RLStrategy,
15};
16use crate::{EpisodeSummary, RLStrategies};
17use burn_core::record::FileRecorder;
18use burn_core::tensor::backend::AutodiffBackend;
19use burn_rl::{Batchable, Environment, EnvironmentInit, Policy, PolicyLearner, SliceAccess};
20use std::collections::BTreeSet;
21use std::path::{Path, PathBuf};
22use std::sync::Arc;
23
/// Structure to configure and launch reinforcement learning trainings.
pub struct RLTraining<RLC: RLComponentsTypes> {
    /// Async checkpointers for the policy and the learning agent, set by
    /// `with_file_checkpointer`. `None` disables checkpointing entirely.
    // Not that complex. Extracting into yet another type would only make it more confusing.
    #[allow(clippy::type_complexity)]
    checkpointers: Option<(
        AsyncCheckpointer<RLPolicyRecord<RLC>, RLC::Backend>,
        AsyncCheckpointer<RLAgentRecord<RLC>, RLC::Backend>,
    )>,
    /// Total number of environment steps to train for (default 1, see `new`).
    num_steps: usize,
    /// Step to resume training from, if any.
    checkpoint: Option<usize>,
    /// Artifact directory (checkpoints, logs, metric files).
    directory: PathBuf,
    /// Optional number of gradient-accumulation steps.
    grad_accumulation: Option<usize>,
    /// Custom renderer; when `None`, `launch` falls back to `default_renderer`.
    renderer: Option<Box<dyn MetricsRenderer + 'static>>,
    /// Registered training/agent/episode metrics.
    metrics: RLMetrics<RLC::TrainingOutput, RLC::ActionContext>,
    /// Store that fans events out to the registered metric loggers.
    event_store: LogEventStore,
    /// Handle used to interrupt training from outside.
    interrupter: Interrupter,
    /// Installer for the application (Rust log) logger; `None` disables capture.
    tracing_logger: Option<Box<dyn ApplicationLoggerInstaller>>,
    /// Strategy deciding which checkpoints are kept/deleted.
    checkpointer_strategy: Box<dyn CheckpointingStrategy>,
    /// Learning strategy; defaults to off-policy (see `new`).
    learning_strategy: RLStrategies<RLC>,
    // Use BTreeSet instead of HashSet for consistent (alphabetical) iteration order
    summary_metrics: BTreeSet<String>,
    /// Whether to display a training summary after `launch`.
    summary: bool,
    /// Specifies how to initialize the environment.
    env_initializer: RLC::EnvInit,
}
48
impl<B, E, EI, A> RLTraining<RLComponentsMarker<B, E, EI, A>>
where
    B: AutodiffBackend,
    E: Environment + 'static,
    EI: EnvironmentInit<E> + Send + 'static,
    A: PolicyLearner<B> + Send + 'static,
    A::TrainContext: ItemLazy + Clone + Send,
    A::InnerPolicy: Policy<B> + Send,
    <A::InnerPolicy as Policy<B>>::Observation: Batchable + Clone + Send,
    <A::InnerPolicy as Policy<B>>::ActionDistribution: Batchable + Clone + Send,
    <A::InnerPolicy as Policy<B>>::Action: Batchable + Clone + Send,
    <A::InnerPolicy as Policy<B>>::ActionContext: ItemLazy + Clone + Send + 'static,
    <A::InnerPolicy as Policy<B>>::PolicyState: Clone + Send,
    E::State: Into<<A::InnerPolicy as Policy<B>>::Observation> + Clone + Send + 'static,
    E::Action: From<<A::InnerPolicy as Policy<B>>::Action>
        + Into<<A::InnerPolicy as Policy<B>>::Action>
        + Clone
        + Send
        + 'static,
{
    /// Creates a new runner for reinforcement learning.
    ///
    /// Defaults: off-policy learning strategy, 1 training step, no checkpointing,
    /// Rust logs captured into `experiment.log` inside `directory`, and a
    /// checkpointing strategy that keeps the last 2 checkpoints plus the best one
    /// by metric (see the note below).
    ///
    /// # Arguments
    ///
    /// * `directory` - The directory to save the checkpoints.
    /// * `env_initializer` - Specifies how to initialize the environment.
    pub fn new(directory: impl AsRef<Path>, env_initializer: EI) -> Self {
        let directory = directory.as_ref().to_path_buf();
        let experiment_log_file = directory.join("experiment.log");
        Self {
            num_steps: 1,
            checkpoint: None,
            checkpointers: None,
            directory,
            grad_accumulation: None,
            metrics: RLMetrics::default(),
            event_store: LogEventStore::default(),
            renderer: None,
            interrupter: Interrupter::new(),
            tracing_logger: Some(Box::new(FileApplicationLoggerInstaller::new(
                experiment_log_file,
            ))),
            checkpointer_strategy: Box::new(
                ComposedCheckpointingStrategy::builder()
                    .add(KeepLastNCheckpoints::new(2))
                    .add(MetricCheckpointingStrategy::new(
                        // NOTE(review): a previous comment claimed this defaults to the
                        // evaluations' cumulative reward, but the metric used here is the
                        // episode length, kept when its mean on the validation split is
                        // lowest. Confirm both the metric and `Direction::Lowest` are the
                        // intended defaults (shorter episodes are not better in every env).
                        &EpisodeLengthMetric::new(),
                        Aggregate::Mean,
                        Direction::Lowest,
                        Split::Valid,
                    ))
                    .build(),
            ),
            learning_strategy: RLStrategies::OffPolicyStrategy(OffPolicyConfig::new()),
            summary_metrics: BTreeSet::new(),
            summary: false,
            env_initializer,
        }
    }
}
109
110impl<RLC: RLComponentsTypes + 'static> RLTraining<RLC> {
111    /// Replace the default learning strategy (Off Policy learning) with the provided one.
112    ///
113    /// # Arguments
114    ///
115    /// * `training_strategy` - The training strategy.
116    pub fn with_learning_strategy(mut self, learning_strategy: RLStrategies<RLC>) -> Self {
117        self.learning_strategy = learning_strategy;
118        self
119    }
120
121    /// Replace the default metric loggers with the provided ones.
122    ///
123    /// # Arguments
124    ///
125    /// * `logger` - The training logger.
126    pub fn with_metric_logger<ML>(mut self, logger: ML) -> Self
127    where
128        ML: MetricLogger + 'static,
129    {
130        self.event_store.register_logger(logger);
131        self
132    }
133
134    /// Update the checkpointing_strategy.
135    pub fn with_checkpointing_strategy<CS: CheckpointingStrategy + 'static>(
136        mut self,
137        strategy: CS,
138    ) -> Self {
139        self.checkpointer_strategy = Box::new(strategy);
140        self
141    }
142
143    /// Replace the default CLI renderer with a custom one.
144    ///
145    /// # Arguments
146    ///
147    /// * `renderer` - The custom renderer.
148    pub fn renderer<MR>(mut self, renderer: MR) -> Self
149    where
150        MR: MetricsRenderer + 'static,
151    {
152        self.renderer = Some(Box::new(renderer));
153        self
154    }
155
156    /// Register numerical metrics for a training step of the agent.
157    pub fn metrics_train<Me: TrainMetricRegistration<RLC>>(self, metrics: Me) -> Self {
158        metrics.register(self)
159    }
160
161    /// Register textual metrics for a training step of the agent.
162    pub fn text_metrics_train<Me: TrainTextMetricRegistration<RLC>>(self, metrics: Me) -> Self {
163        metrics.register(self)
164    }
165
166    /// Register numerical metrics for each action of the agent.
167    pub fn metrics_agent<Me: AgentMetricRegistration<RLC>>(self, metrics: Me) -> Self {
168        metrics.register(self)
169    }
170
171    /// Register textual metrics for each action of the agent.
172    pub fn text_metrics_agent<Me: AgentTextMetricRegistration<RLC>>(self, metrics: Me) -> Self {
173        metrics.register(self)
174    }
175
176    /// Register numerical metrics for a completed episode.
177    pub fn metrics_episode<Me: EpisodeMetricRegistration<RLC>>(self, metrics: Me) -> Self {
178        metrics.register(self)
179    }
180
181    /// Register textual metrics for a completed episode.
182    pub fn text_metrics_episode<Me: EpisodeTextMetricRegistration<RLC>>(self, metrics: Me) -> Self {
183        metrics.register(self)
184    }
185
186    /// Register a textual metric for a training step.
187    pub fn text_metric_train<Me: Metric + 'static>(mut self, metric: Me) -> Self
188    where
189        <RLC::TrainingOutput as ItemLazy>::ItemSync: Adaptor<Me::Input>,
190    {
191        self.metrics.register_text_metric_train(metric);
192        self
193    }
194
195    /// Register a [numeric](crate::metric::Numeric) [metric](Metric) for a training step.
196    pub fn metric_train<Me>(mut self, metric: Me) -> Self
197    where
198        Me: Metric + Numeric + 'static,
199        <RLC::TrainingOutput as ItemLazy>::ItemSync: Adaptor<Me::Input>,
200    {
201        self.summary_metrics.insert(metric.name().to_string());
202        self.metrics.register_metric_train(metric);
203        self
204    }
205
206    /// Register a textual metric for each action taken by the agent.
207    pub fn text_metric_agent<Me: Metric + 'static>(mut self, metric: Me) -> Self
208    where
209        <RLC::ActionContext as ItemLazy>::ItemSync: Adaptor<Me::Input>,
210    {
211        self.metrics.register_text_metric_agent(metric.clone());
212        self.metrics.register_text_metric_agent_valid(metric);
213        self
214    }
215
216    /// Register a [numeric](crate::metric::Numeric) [metric](Metric) for each action taken by the agent.
217    pub fn metric_agent<Me>(mut self, metric: Me) -> Self
218    where
219        Me: Metric + Numeric + 'static,
220        <RLC::ActionContext as ItemLazy>::ItemSync: Adaptor<Me::Input>,
221    {
222        self.summary_metrics.insert(metric.name().to_string());
223        self.metrics.register_agent_metric(metric.clone());
224        self.metrics.register_agent_metric_valid(metric);
225        self
226    }
227
228    /// Register a textual metric for a completed episode.
229    pub fn text_metric_episode<Me: Metric + 'static>(mut self, metric: Me) -> Self
230    where
231        EpisodeSummary: Adaptor<Me::Input> + 'static,
232    {
233        self.metrics.register_text_metric_episode(metric.clone());
234        self.metrics.register_text_metric_episode_valid(metric);
235        self
236    }
237
238    /// Register a [numeric](crate::metric::Numeric) [metric](Metric) for a completed episode.
239    pub fn metric_episode<Me>(mut self, metric: Me) -> Self
240    where
241        Me: Metric + Numeric + 'static,
242        EpisodeSummary: Adaptor<Me::Input> + 'static,
243    {
244        self.summary_metrics.insert(metric.name().to_string());
245        self.metrics.register_episode_metric(metric.clone());
246        self.metrics.register_episode_metric_valid(metric);
247        self
248    }
249
250    /// The number of environment steps to train for.
251    pub fn num_steps(mut self, num_steps: usize) -> Self {
252        self.num_steps = num_steps;
253        self
254    }
255
256    /// The step from which the training must resume.
257    pub fn checkpoint(mut self, checkpoint: usize) -> Self {
258        self.checkpoint = Some(checkpoint);
259        self
260    }
261
262    /// Provides a handle that can be used to interrupt training.
263    pub fn interrupter(&self) -> Interrupter {
264        self.interrupter.clone()
265    }
266
267    /// Override the handle for stopping training with an externally provided handle
268    pub fn with_interrupter(mut self, interrupter: Interrupter) -> Self {
269        self.interrupter = interrupter;
270        self
271    }
272
273    /// By default, Rust logs are captured and written into
274    /// `experiment.log`. If disabled, standard Rust log handling
275    /// will apply.
276    pub fn with_application_logger(
277        mut self,
278        logger: Option<Box<dyn ApplicationLoggerInstaller>>,
279    ) -> Self {
280        self.tracing_logger = logger;
281        self
282    }
283
284    /// Register a checkpointer that will save the environment runner's [policy](Policy)
285    /// and the [PolicyLearner](PolicyLearner) state to different files.
286    pub fn with_file_checkpointer<FR>(mut self, recorder: FR) -> Self
287    where
288        FR: FileRecorder<RLC::Backend> + 'static,
289        FR: FileRecorder<<RLC::Backend as AutodiffBackend>::InnerBackend> + 'static,
290    {
291        let checkpoint_dir = self.directory.join("checkpoint");
292        let checkpointer_policy =
293            FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "policy");
294        let checkpointer_learning =
295            FileCheckpointer::new(recorder.clone(), &checkpoint_dir, "learning-agent");
296
297        self.checkpointers = Some((
298            AsyncCheckpointer::new(checkpointer_policy),
299            AsyncCheckpointer::new(checkpointer_learning),
300        ));
301
302        self
303    }
304
305    /// Enable the training summary report.
306    ///
307    /// The summary will be displayed after `.launch()`, when the renderer is dropped.
308    pub fn summary(mut self) -> Self {
309        self.summary = true;
310        self
311    }
312
313    /// Launch the training with the specified [PolicyLearner](PolicyLearner) on the specified environment.
314    pub fn launch(mut self, learner_agent: RLC::LearningAgent) -> RLResult<RLC::Policy>
315    where
316        RLC::PolicyObs: SliceAccess<RLC::Backend>,
317        RLC::PolicyAction: SliceAccess<RLC::Backend>,
318    {
319        if self.tracing_logger.is_some()
320            && let Err(e) = self.tracing_logger.as_ref().unwrap().install()
321        {
322            log::warn!("Failed to install the experiment logger: {e}");
323        }
324        let renderer = self
325            .renderer
326            .unwrap_or_else(|| default_renderer(self.interrupter.clone(), self.checkpoint));
327
328        if !self.event_store.has_loggers() {
329            self.event_store
330                .register_logger(FileMetricLogger::new(self.directory.clone()));
331        }
332
333        let event_store = Arc::new(EventStoreClient::new(self.event_store));
334        let event_processor = AsyncProcessorTraining::new(RLEventProcessor::new(
335            self.metrics,
336            renderer,
337            event_store.clone(),
338        ));
339
340        let checkpointer = self.checkpointers.map(|(policy, learning_agent)| {
341            RLCheckpointer::new(policy, learning_agent, self.checkpointer_strategy)
342        });
343
344        let summary = if self.summary {
345            Some(LearnerSummaryConfig {
346                directory: self.directory,
347                metrics: self.summary_metrics.into_iter().collect::<Vec<_>>(),
348            })
349        } else {
350            None
351        };
352
353        let components = RLComponents::<RLC> {
354            checkpoint: self.checkpoint,
355            checkpointer,
356            interrupter: self.interrupter,
357            event_processor,
358            event_store,
359            num_steps: self.num_steps,
360            grad_accumulation: self.grad_accumulation,
361            summary,
362        };
363
364        match self.learning_strategy {
365            RLStrategies::OffPolicyStrategy(config) => {
366                let strategy = OffPolicyStrategy::new(config);
367                strategy.train(learner_agent, components, self.env_initializer)
368            }
369            RLStrategies::Custom(strategy) => {
370                strategy.train(learner_agent, components, self.env_initializer)
371            }
372        }
373    }
374}
375
/// The result of reinforcement learning, containing the final policy along with the [renderer](MetricsRenderer).
pub struct RLResult<P> {
    /// The learned policy.
    pub policy: P,
    /// The renderer that can be used for follow up training and evaluation.
    /// Dropping it triggers the training summary display when enabled.
    pub renderer: Box<dyn MetricsRenderer>,
}
383
384/// Trait to fake variadic generics for train step metrics.
/// Trait to fake variadic generics for per-action (agent) numeric metrics.
/// Implemented for tuples of up to 6 metrics by the `gen_tuple!` macro below.
pub trait AgentMetricRegistration<RLC: RLComponentsTypes>: Sized {
    /// Register the metrics.
    fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;
}
389
390/// Trait to fake variadic generics for train step text metrics.
/// Trait to fake variadic generics for per-action (agent) text metrics.
/// Implemented for tuples of up to 6 metrics by the `gen_tuple!` macro below.
pub trait AgentTextMetricRegistration<RLC: RLComponentsTypes>: Sized {
    /// Register the metrics.
    fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;
}
395
396/// Trait to fake variadic generics for env step metrics.
/// Trait to fake variadic generics for training-step numeric metrics.
/// Implemented for tuples of up to 6 metrics by the `gen_tuple!` macro below.
pub trait TrainMetricRegistration<RLC: RLComponentsTypes>: Sized {
    /// Register the metrics.
    fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;
}
401
402/// Trait to fake variadic generics for env step text metrics.
/// Trait to fake variadic generics for training-step text metrics.
/// Implemented for tuples of up to 6 metrics by the `gen_tuple!` macro below.
pub trait TrainTextMetricRegistration<RLC: RLComponentsTypes>: Sized {
    /// Register the metrics.
    fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;
}
407
408/// Trait to fake variadic generics for episode metrics.
/// Trait to fake variadic generics for per-episode numeric metrics.
/// Implemented for tuples of up to 6 metrics by the `gen_tuple!` macro below.
pub trait EpisodeMetricRegistration<RLC: RLComponentsTypes>: Sized {
    /// Register the metrics.
    fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;
}
413
414/// Trait to fake variadic generics for episode text metrics.
/// Trait to fake variadic generics for per-episode text metrics.
/// Implemented for tuples of up to 6 metrics by the `gen_tuple!` macro below.
pub trait EpisodeTextMetricRegistration<RLC: RLComponentsTypes>: Sized {
    /// Register the metrics.
    fn register(self, builder: RLTraining<RLC>) -> RLTraining<RLC>;
}
419
/// Generates, for a tuple of metric types `($M1, $M2, …)`, implementations of all
/// six `*MetricRegistration` traits. Each impl destructures the tuple and folds the
/// builder through the corresponding single-metric registration method
/// (`metric_train`, `text_metric_agent`, `metric_episode`, …), so a tuple of
/// metrics can be passed where a single metric is accepted.
macro_rules! gen_tuple {
    ($($M:ident),*) => {
        impl<$($M,)* RLC: RLComponentsTypes + 'static> TrainTextMetricRegistration<RLC> for ($($M,)*)
        where
            $(<RLC::TrainingOutput as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
            $($M: Metric + 'static,)*
        {
            // Tuple elements are bound to the macro metavariable names ($M1, …).
            #[allow(non_snake_case)]
            fn register(
                self,
                builder: RLTraining<RLC>,
            ) -> RLTraining<RLC> {
                let ($($M,)*) = self;
                $(let builder = builder.text_metric_train($M.clone());)*
                builder
            }
        }

        impl<$($M,)* RLC: RLComponentsTypes + 'static> TrainMetricRegistration<RLC> for ($($M,)*)
        where
            $(<RLC::TrainingOutput as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
            $($M: Metric + Numeric + 'static,)*
        {
            #[allow(non_snake_case)]
            fn register(
                self,
                builder: RLTraining<RLC>,
            ) -> RLTraining<RLC> {
                let ($($M,)*) = self;
                $(let builder = builder.metric_train($M.clone());)*
                builder
            }
        }

        impl<$($M,)* RLC: RLComponentsTypes + 'static> AgentTextMetricRegistration<RLC> for ($($M,)*)
        where
            $(<RLC::ActionContext as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
            $($M: Metric + 'static,)*
        {
            #[allow(non_snake_case)]
            fn register(
                self,
                builder: RLTraining<RLC>,
            ) -> RLTraining<RLC> {
                let ($($M,)*) = self;
                $(let builder = builder.text_metric_agent($M.clone());)*
                builder
            }
        }

        impl<$($M,)* RLC: RLComponentsTypes + 'static> AgentMetricRegistration<RLC> for ($($M,)*)
        where
            $(<RLC::ActionContext as ItemLazy>::ItemSync: Adaptor<$M::Input>,)*
            $($M: Metric + Numeric + 'static,)*
        {
            #[allow(non_snake_case)]
            fn register(
                self,
                builder: RLTraining<RLC>,
            ) -> RLTraining<RLC> {
                let ($($M,)*) = self;
                $(let builder = builder.metric_agent($M.clone());)*
                builder
            }
        }

        impl<$($M,)* RLC: RLComponentsTypes + 'static> EpisodeTextMetricRegistration<RLC> for ($($M,)*)
        where
            $(EpisodeSummary: Adaptor<$M::Input> + 'static,)*
            $($M: Metric + 'static,)*
        {
            #[allow(non_snake_case)]
            fn register(
                self,
                builder: RLTraining<RLC>,
            ) -> RLTraining<RLC> {
                let ($($M,)*) = self;
                $(let builder = builder.text_metric_episode($M.clone());)*
                builder
            }
        }

        impl<$($M,)* RLC: RLComponentsTypes + 'static> EpisodeMetricRegistration<RLC> for ($($M,)*)
        where
            $(EpisodeSummary: Adaptor<$M::Input> + 'static,)*
            $($M: Metric + Numeric + 'static,)*
        {
            #[allow(non_snake_case)]
            fn register(
                self,
                builder: RLTraining<RLC>,
            ) -> RLTraining<RLC> {
                let ($($M,)*) = self;
                $(let builder = builder.metric_episode($M.clone());)*
                builder
            }
        }
    };
}
519
// Implement all six registration traits for metric tuples of arity 1 through 6.
gen_tuple!(M1);
gen_tuple!(M1, M2);
gen_tuple!(M1, M2, M3);
gen_tuple!(M1, M2, M3, M4);
gen_tuple!(M1, M2, M3, M4, M5);
gen_tuple!(M1, M2, M3, M4, M5, M6);