
// burn_train/learner/rl/strategy.rs

use std::sync::Arc;

use crate::{
    Interrupter, LearnerSummaryConfig, OffPolicyConfig, RLCheckpointer, RLComponentsTypes, RLEvent,
    RLEventProcessorType, RLResult,
    metric::{processor::EventProcessorTraining, store::EventStoreClient},
};
/// Struct to minimise parameters passed to [RLStrategy::train].
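///
/// # Example
///
/// A minimal sketch of assembling the components by hand. `MyComponents`,
/// `checkpointer`, `interrupter`, `event_processor`, and `event_store` are
/// placeholders for values built elsewhere, not confirmed constructors:
///
/// ```ignore
/// let components = RLComponents::<MyComponents> {
///     num_steps: 100_000,
///     checkpoint: None,           // `Some(n)` resumes training after step `n`
///     checkpointer: Some(checkpointer),
///     grad_accumulation: Some(4), // accumulate gradients over 4 steps
///     interrupter,
///     event_processor,
///     event_store,
///     summary: None,
/// };
/// ```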
pub struct RLComponents<RLC: RLComponentsTypes> {
    /// The total number of environment steps.
    pub num_steps: usize,
    /// The step number from which to resume training.
    pub checkpoint: Option<usize>,
    /// A checkpointer used to load and save learning checkpoints.
    pub checkpointer: Option<RLCheckpointer<RLC>>,
    /// The number of steps over which to accumulate gradients, if enabled.
    pub grad_accumulation: Option<usize>,
    /// An [Interrupter] that allows aborting the training/evaluation process early.
    pub interrupter: Interrupter,
    /// An [EventProcessor](crate::EventProcessorTraining) that processes events happening during training and evaluation.
    pub event_processor: RLEventProcessorType<RLC>,
    /// A reference to an [EventStoreClient].
    pub event_store: Arc<EventStoreClient>,
    /// Config for creating a summary of the learning run.
    pub summary: Option<LearnerSummaryConfig>,
}

/// The strategy for reinforcement learning.
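///
/// # Example
///
/// A sketch of selecting a strategy. `off_policy_config` stands in for an
/// [OffPolicyConfig] built elsewhere, and `MyCustomStrategy` is a hypothetical
/// [RLStrategy] implementation:
///
/// ```ignore
/// let built_in = RLStrategies::<MyComponents>::OffPolicyStrategy(off_policy_config);
/// let custom = RLStrategies::<MyComponents>::Custom(Arc::new(MyCustomStrategy));
/// ```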
#[derive(Clone)]
pub enum RLStrategies<RLC: RLComponentsTypes> {
    /// Off-policy training configured by an [OffPolicyConfig].
    OffPolicyStrategy(OffPolicyConfig),
    /// Training using a custom learning strategy.
    Custom(CustomRLStrategy<RLC>),
}

/// A reference to an implementation of [RLStrategy].
pub type CustomRLStrategy<RLC> = Arc<dyn RLStrategy<RLC>>;

/// Provides the `train` function for any learning strategy.
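///
/// # Example
///
/// A minimal sketch of a custom strategy that implements only the required
/// [train_loop](RLStrategy::train_loop). `agent.policy()` is a hypothetical
/// accessor used for illustration; a real strategy would step environments
/// built from `env_init`, update the agent, and emit [RLEvent]s:
///
/// ```ignore
/// struct NoOpStrategy;
///
/// impl<RLC: RLComponentsTypes> RLStrategy<RLC> for NoOpStrategy {
///     fn train_loop(
///         &self,
///         training_components: RLComponents<RLC>,
///         agent: &mut RLC::LearningAgent,
///         _starting_epoch: usize,
///         _env_init: RLC::EnvInit,
///     ) -> (RLC::Policy, RLEventProcessorType<RLC>) {
///         // No training: return the current policy along with the event
///         // processor so the default `train` can emit the end event.
///         (agent.policy(), training_components.event_processor)
///     }
/// }
/// ```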
pub trait RLStrategy<RLC: RLComponentsTypes> {
    /// Train the learner agent with this strategy.
    fn train(
        &self,
        mut learner_agent: RLC::LearningAgent,
        mut training_components: RLComponents<RLC>,
        env_init: RLC::EnvInit,
    ) -> RLResult<RLC::Policy> {
        // Resume from a checkpoint when one is given; otherwise start at epoch 1.
        let starting_epoch = match training_components.checkpoint {
            Some(checkpoint) => {
                if let Some(checkpointer) = &mut training_components.checkpointer {
                    learner_agent = checkpointer.load_checkpoint(
                        learner_agent,
                        &Default::default(),
                        checkpoint,
                    );
                }
                checkpoint + 1
            }
            None => 1,
        };

        let summary_config = training_components.summary.clone();

        // Signal training start to the event processor.
        training_components
            .event_processor
            .process_train(RLEvent::Start);

        // Run the strategy's training loop.
        let (policy, mut event_processor) = self.train_loop(
            training_components,
            &mut learner_agent,
            starting_epoch,
            env_init,
        );

        let summary = summary_config.and_then(|summary| summary.init().ok());

        // Signal training end. For the TUI renderer, this handles the exit & return to main screen.
        // TODO: does a summary make sense for RL?
        event_processor.process_train(RLEvent::End(summary));

        let renderer = event_processor.renderer();

        RLResult { policy, renderer }
    }

    /// The training loop for this strategy.
    fn train_loop(
        &self,
        training_components: RLComponents<RLC>,
        learner_agent: &mut RLC::LearningAgent,
        starting_epoch: usize,
        env_init: RLC::EnvInit,
    ) -> (RLC::Policy, RLEventProcessorType<RLC>);
}
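
// Usage sketch (illustrative only): driving training through the
// [CustomRLStrategy] alias. `agent`, `components`, and `env_init` are assumed
// to be built elsewhere; `NoOpStrategy` is the hypothetical implementation
// from the trait's doc example above.
//
//     let strategy: CustomRLStrategy<MyComponents> = Arc::new(NoOpStrategy);
//     let result = strategy.train(agent, components, env_init);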