burn_train/learner/rl/
checkpointer.rs1use burn_core::tensor::Device;
2use burn_rl::{Policy, PolicyLearner, PolicyState};
3
4use crate::RLAgentRecord;
5use crate::{
6 RLComponentsTypes, RLPolicyRecord,
7 checkpoint::Checkpointer,
8 checkpoint::{AsyncCheckpointer, CheckpointingAction, CheckpointingStrategy},
9 metric::store::EventStoreClient,
10};
11
/// Checkpointer for reinforcement-learning training.
///
/// Persists the policy and the learning agent as two separate asynchronous
/// checkpoints, and delegates the save/delete decisions for each epoch to a
/// pluggable [`CheckpointingStrategy`].
#[derive(new)]
pub struct RLCheckpointer<RLC: RLComponentsTypes> {
    /// Async checkpointer handling the policy's record.
    policy: AsyncCheckpointer<RLPolicyRecord<RLC>, RLC::Backend>,
    /// Async checkpointer handling the learning agent's record.
    learning_agent: AsyncCheckpointer<RLAgentRecord<RLC>, RLC::Backend>,
    /// Strategy deciding, per epoch, whether to save and which epochs to delete.
    strategy: Box<dyn CheckpointingStrategy>,
}
19
20impl<RLC: RLComponentsTypes> RLCheckpointer<RLC> {
21 pub fn checkpoint(
23 &mut self,
24 policy: &RLC::PolicyState,
25 learning_agent: &RLC::LearningAgent,
26 epoch: usize,
27 store: &EventStoreClient,
28 ) {
29 let actions = self.strategy.checkpointing(epoch, store);
30
31 for action in actions {
32 match action {
33 CheckpointingAction::Delete(epoch) => {
34 self.policy
35 .delete(epoch)
36 .expect("Can delete policy checkpoint.");
37 self.learning_agent
38 .delete(epoch)
39 .expect("Can delete learning agent checkpoint.")
40 }
41 CheckpointingAction::Save => {
42 self.policy
43 .save(epoch, policy.clone().into_record())
44 .expect("Can save policy checkpoint.");
45 self.learning_agent
46 .save(epoch, learning_agent.record())
47 .expect("Can save learning agent checkpoint.");
48 }
49 }
50 }
51 }
52
53 pub fn load_checkpoint(
55 &self,
56 learning_agent: RLC::LearningAgent,
57 device: &Device<RLC::Backend>,
58 epoch: usize,
59 ) -> RLC::LearningAgent {
60 let record = self
61 .policy
62 .restore(epoch, device)
63 .expect("Can load model checkpoint.");
64 let policy = learning_agent.policy().load_record(record);
65
66 let record = self
67 .learning_agent
68 .restore(epoch, device)
69 .expect("Can load learning agent checkpoint.");
70 let mut learning_agent = learning_agent.load_record(record);
71 learning_agent.update_policy(policy);
72
73 learning_agent
74 }
75}