use burn_core::{
record::{Record, RecorderError},
tensor::backend::Backend,
};
#[derive(Debug)]
pub enum CheckpointerError {
IOError(std::io::Error),
RecorderError(RecorderError),
Unknown(String),
}
/// Stores and retrieves records of type `R` (for backend `B`), indexed by epoch number.
///
/// Implementors must be both [`Send`] and [`Sync`].
pub trait Checkpointer<R: Record<B>, B: Backend>: Send + Sync {
    /// Saves `record` as the checkpoint associated with `epoch`.
    ///
    /// # Errors
    ///
    /// Returns a [`CheckpointerError`] when the checkpoint cannot be saved.
    fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError>;

    /// Deletes the checkpoint associated with `epoch`.
    ///
    /// # Errors
    ///
    /// Returns a [`CheckpointerError`] when the checkpoint cannot be deleted.
    fn delete(&self, epoch: usize) -> Result<(), CheckpointerError>;

    /// Restores the checkpoint associated with `epoch`, using `device` as the target
    /// backend device for the returned record.
    ///
    /// # Errors
    ///
    /// Returns a [`CheckpointerError`] when the checkpoint cannot be restored.
    fn restore(&self, epoch: usize, device: &B::Device) -> Result<R, CheckpointerError>;
}