1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
use burn_core::{
    record::{Record, RecorderError},
    tensor::backend::Backend,
};

/// The error type for checkpointer.
#[derive(Debug)]
pub enum CheckpointerError {
    /// IO error.
    IOError(std::io::Error),

    /// Recorder error.
    RecorderError(RecorderError),

    /// Other errors.
    Unknown(String),
}

/// Saves, deletes and restores records of type `R` for backend `B`, addressed by epoch.
///
/// Implementors choose the storage mechanism; this trait only defines the operations.
pub trait Checkpointer<R: Record<B>, B: Backend> {
    /// Persist `record` as the checkpoint associated with `epoch`.
    ///
    /// # Arguments
    ///
    /// * `epoch` - Epoch number the record belongs to.
    /// * `record` - Record to persist (taken by value).
    ///
    /// # Errors
    ///
    /// Returns a [`CheckpointerError`] when the record cannot be saved.
    fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError>;

    /// Remove the checkpoint associated with `epoch`, if one exists.
    ///
    /// # Arguments
    ///
    /// * `epoch` - Epoch number whose record should be removed.
    ///
    /// # Errors
    ///
    /// Returns a [`CheckpointerError`] when the deletion fails.
    fn delete(&self, epoch: usize) -> Result<(), CheckpointerError>;

    /// Load the checkpoint associated with `epoch`.
    ///
    /// # Arguments
    ///
    /// * `epoch` - Epoch number of the record to load.
    /// * `device` - Device on which the restored record is materialized.
    ///
    /// # Returns
    ///
    /// The restored record.
    ///
    /// # Errors
    ///
    /// Returns a [`CheckpointerError`] when the record cannot be restored.
    fn restore(&self, epoch: usize, device: &B::Device) -> Result<R, CheckpointerError>;
}