burn_train/checkpoint/
base.rs

1use burn_core::{
2    record::{Record, RecorderError},
3    tensor::backend::Backend,
4};
5
6/// The error type for checkpointer.
7#[derive(Debug)]
8pub enum CheckpointerError {
9    /// IO error.
10    IOError(std::io::Error),
11
12    /// Recorder error.
13    RecorderError(RecorderError),
14
15    /// Other errors.
16    Unknown(String),
17}
18
19/// The trait for checkpointer.
20pub trait Checkpointer<R, B>: Send + Sync
21where
22    R: Record<B>,
23    B: Backend,
24{
25    /// Save the record.
26    ///
27    /// # Arguments
28    ///
29    /// * `epoch` - The epoch.
30    /// * `record` - The record.
31    fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError>;
32
33    /// Delete the record at the given epoch if present.
34    fn delete(&self, epoch: usize) -> Result<(), CheckpointerError>;
35
36    /// Restore the record.
37    ///
38    /// # Arguments
39    ///
40    /// * `epoch` - The epoch.
41    /// * `device` - The device used to restore the record.
42    ///
43    /// # Returns
44    ///
45    /// The record.
46    fn restore(&self, epoch: usize, device: &B::Device) -> Result<R, CheckpointerError>;
47}