1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
use crate::checkpoint::Checkpointer;
use crate::LearnerCallback;
use burn_core::module::ADModule;
use burn_core::optim::Optimizer;
use burn_core::tensor::backend::Backend;
/// Orchestrates model training: owns the model, its optimizer, the
/// progress callback and the checkpointing configuration.
///
/// Type parameters:
/// - `M`: the auto-diff module (model) being trained.
/// - `O`: the optimizer used to update `M`.
/// - `TO`, `VO`: item types forwarded to the [`LearnerCallback`]
///   (presumably training / validation outputs — names suggest so,
///   but their production is not visible in this chunk).
pub struct Learner<M, O, TO, VO>
where
M: ADModule,
{
/// The model under training.
pub(super) model: M,
/// The optimizer that updates the model's parameters.
pub(super) optim: O,
/// Total number of epochs to run.
pub(super) num_epochs: usize,
/// Callback notified with training (`TO`) and validation (`VO`) items.
pub(super) callback: Box<dyn LearnerCallback<TO, VO>>,
/// NOTE(review): presumably the epoch of a checkpoint to restore
/// before training (fed to `load_checkpoint`) — confirm against caller.
pub(super) checkpoint: Option<usize>,
/// Checkpointer persisting the model state; `None` disables model checkpoints.
pub(super) checkpointer_model: Option<Box<dyn Checkpointer<<M::Backend as Backend>::Elem>>>,
/// Checkpointer persisting the optimizer state; `None` disables optimizer checkpoints.
pub(super) checkpointer_optimizer: Option<Box<dyn Checkpointer<<M::Backend as Backend>::Elem>>>,
/// NOTE(review): presumably the number of steps over which gradients
/// are accumulated before an optimizer step — not read in this chunk; confirm.
pub(super) grad_accumulation: Option<usize>,
/// Devices on which training runs.
pub(super) devices: Vec<<M::Backend as Backend>::Device>,
}
impl<M, O, TO, VO> Learner<M, O, TO, VO>
where
    VO: Send + Sync + 'static,
    TO: Send + Sync + 'static,
    M: ADModule,
    O: Optimizer<Backend = M::Backend>,
{
    /// Persists the model and optimizer state for the given `epoch`.
    ///
    /// Each state is saved only when the corresponding checkpointer is
    /// configured; a missing checkpointer is silently skipped.
    ///
    /// # Panics
    ///
    /// Panics if a configured checkpointer fails to save its state.
    pub(super) fn checkpoint(&self, epoch: usize) {
        if let Some(checkpointer) = &self.checkpointer_model {
            // `expect` over bare `unwrap`: a failed save should name what
            // broke instead of panicking without context.
            checkpointer
                .save(epoch, self.model.state())
                .expect("Failed to save the model checkpoint");
        }
        if let Some(checkpointer) = &self.checkpointer_optimizer {
            checkpointer
                .save(epoch, self.optim.state(&self.model))
                .expect("Failed to save the optimizer checkpoint");
        }
    }

    /// Restores the model and optimizer state saved for the given `epoch`.
    ///
    /// Each state is restored only when the corresponding checkpointer is
    /// configured; a missing checkpointer is silently skipped.
    ///
    /// # Panics
    ///
    /// Panics if a configured checkpointer cannot restore the state for
    /// `epoch`, or if the restored state cannot be loaded into the model
    /// or optimizer.
    pub(super) fn load_checkpoint(&mut self, epoch: usize) {
        if let Some(checkpointer) = &self.checkpointer_model {
            let state = checkpointer
                .restore(epoch)
                .expect("Failed to restore the model checkpoint");
            self.model
                .load(&state)
                .expect("Failed to load the restored state into the model");
        }
        if let Some(checkpointer) = &self.checkpointer_optimizer {
            let state = checkpointer
                .restore(epoch)
                .expect("Failed to restore the optimizer checkpoint");
            self.optim
                .load(&self.model, &state)
                .expect("Failed to load the restored state into the optimizer");
        }
    }
}