// burn_train/checkpoint/base.rs
1use burn_core::{
2 record::{Record, RecorderError},
3 tensor::backend::Backend,
4};
5
6/// The error type for checkpointer.
7#[derive(Debug)]
8pub enum CheckpointerError {
9 /// IO error.
10 IOError(std::io::Error),
11
12 /// Recorder error.
13 RecorderError(RecorderError),
14
15 /// Other errors.
16 Unknown(String),
17}
18
19/// The trait for checkpointer.
20pub trait Checkpointer<R, B>: Send + Sync
21where
22 R: Record<B>,
23 B: Backend,
24{
25 /// Save the record.
26 ///
27 /// # Arguments
28 ///
29 /// * `epoch` - The epoch.
30 /// * `record` - The record.
31 fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError>;
32
33 /// Delete the record at the given epoch if present.
34 fn delete(&self, epoch: usize) -> Result<(), CheckpointerError>;
35
36 /// Restore the record.
37 ///
38 /// # Arguments
39 ///
40 /// * `epoch` - The epoch.
41 /// * `device` - The device used to restore the record.
42 ///
43 /// # Returns
44 ///
45 /// The record.
46 fn restore(&self, epoch: usize, device: &B::Device) -> Result<R, CheckpointerError>;
47}