1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47
use burn_core::{
record::{Record, RecorderError},
tensor::backend::Backend,
};
/// The error type returned by [`Checkpointer`] operations.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so it can be
/// reported directly or propagated with `?` into `Box<dyn std::error::Error>`.
#[derive(Debug)]
pub enum CheckpointerError {
    /// IO error.
    IOError(std::io::Error),
    /// Recorder error.
    RecorderError(RecorderError),
    /// Other errors.
    Unknown(String),
}

impl std::fmt::Display for CheckpointerError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::IOError(err) => write!(f, "checkpointer I/O error: {err}"),
            // `RecorderError` is only guaranteed to implement `Debug` here (the
            // derive on the enum requires it), so format it with `{:?}`.
            Self::RecorderError(err) => write!(f, "checkpointer recorder error: {err:?}"),
            Self::Unknown(msg) => write!(f, "checkpointer error: {msg}"),
        }
    }
}

impl std::error::Error for CheckpointerError {
    /// Exposes the wrapped `std::io::Error` as the underlying cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::IOError(err) => Some(err),
            // NOTE(review): if `RecorderError` implements `std::error::Error`,
            // it could be returned here as well — confirm before changing.
            Self::RecorderError(_) | Self::Unknown(_) => None,
        }
    }
}
/// Saves, deletes and restores records of type `R`, addressed by epoch number.
///
/// Every operation reports failure through [`CheckpointerError`].
pub trait Checkpointer<R: Record<B>, B: Backend> {
    /// Persist `record` as the checkpoint for `epoch`.
    ///
    /// # Arguments
    ///
    /// * `epoch` - The epoch this record belongs to.
    /// * `record` - The record to persist; ownership is passed to the checkpointer.
    ///
    /// # Errors
    ///
    /// Returns a [`CheckpointerError`] when the record could not be saved.
    fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError>;

    /// Remove the checkpoint stored for `epoch`, if there is one.
    ///
    /// # Arguments
    ///
    /// * `epoch` - The epoch whose checkpoint should be removed.
    ///
    /// # Errors
    ///
    /// Returns a [`CheckpointerError`] when the removal fails.
    fn delete(&self, epoch: usize) -> Result<(), CheckpointerError>;

    /// Load the checkpoint stored for `epoch`.
    ///
    /// # Arguments
    ///
    /// * `epoch` - The epoch whose checkpoint should be loaded.
    /// * `device` - The device the restored record is loaded onto.
    ///
    /// # Returns
    ///
    /// The restored record.
    ///
    /// # Errors
    ///
    /// Returns a [`CheckpointerError`] when the record could not be restored.
    fn restore(&self, epoch: usize, device: &B::Device) -> Result<R, CheckpointerError>;
}