Skip to main content

burn_train/learner/rl/
checkpointer.rs

1use burn_core::tensor::Device;
2use burn_rl::{Policy, PolicyLearner, PolicyState};
3
4use crate::RLAgentRecord;
5use crate::{
6    RLComponentsTypes, RLPolicyRecord,
7    checkpoint::Checkpointer,
8    checkpoint::{AsyncCheckpointer, CheckpointingAction, CheckpointingStrategy},
9    metric::store::EventStoreClient,
10};
11
12#[derive(new)]
13/// Used to create, delete, or load checkpoints of the training process.
14pub struct RLCheckpointer<RLC: RLComponentsTypes> {
15    policy: AsyncCheckpointer<RLPolicyRecord<RLC>, RLC::Backend>,
16    learning_agent: AsyncCheckpointer<RLAgentRecord<RLC>, RLC::Backend>,
17    strategy: Box<dyn CheckpointingStrategy>,
18}
19
20impl<RLC: RLComponentsTypes> RLCheckpointer<RLC> {
21    /// Create checkpoint for the training process.
22    pub fn checkpoint(
23        &mut self,
24        policy: &RLC::PolicyState,
25        learning_agent: &RLC::LearningAgent,
26        epoch: usize,
27        store: &EventStoreClient,
28    ) {
29        let actions = self.strategy.checkpointing(epoch, store);
30
31        for action in actions {
32            match action {
33                CheckpointingAction::Delete(epoch) => {
34                    self.policy
35                        .delete(epoch)
36                        .expect("Can delete policy checkpoint.");
37                    self.learning_agent
38                        .delete(epoch)
39                        .expect("Can delete learning agent checkpoint.")
40                }
41                CheckpointingAction::Save => {
42                    self.policy
43                        .save(epoch, policy.clone().into_record())
44                        .expect("Can save policy checkpoint.");
45                    self.learning_agent
46                        .save(epoch, learning_agent.record())
47                        .expect("Can save learning agent checkpoint.");
48                }
49            }
50        }
51    }
52
53    /// Load a training checkpoint.
54    pub fn load_checkpoint(
55        &self,
56        learning_agent: RLC::LearningAgent,
57        device: &Device<RLC::Backend>,
58        epoch: usize,
59    ) -> RLC::LearningAgent {
60        let record = self
61            .policy
62            .restore(epoch, device)
63            .expect("Can load model checkpoint.");
64        let policy = learning_agent.policy().load_record(record);
65
66        let record = self
67            .learning_agent
68            .restore(epoch, device)
69            .expect("Can load learning agent checkpoint.");
70        let mut learning_agent = learning_agent.load_record(record);
71        learning_agent.update_policy(policy);
72
73        learning_agent
74    }
75}