use burn_core::{
record::{Record, RecorderError},
tensor::backend::Backend,
};
use thiserror::Error;
/// Error type returned by the checkpointing operations declared in this file.
///
/// `Display` is generated by `thiserror` from the `#[error(...)]` attribute on
/// each variant; the wrapped value is interpolated at `{0}`.
#[derive(Error, Debug)]
pub enum CheckpointerError {
/// Wraps a [`std::io::Error`].
#[error("I/O Error: `{0}`")]
IOError(std::io::Error),
/// Wraps a [`RecorderError`] from `burn_core::record`.
#[error("Recorder error: `{0}`")]
RecorderError(RecorderError),
/// Any other failure, carried as a free-form message string.
#[error("Unknown error: `{0}`")]
Unknown(String),
}
/// Interface for storing, removing and reloading records keyed by epoch.
///
/// # Type parameters
///
/// * `R` - the record type being checkpointed; must implement [`Record`] for `B`.
/// * `B` - the [`Backend`] the record is tied to.
///
/// Implementors must be `Send + Sync`, and every method takes `&self`, so a
/// single checkpointer value can be shared between threads.
pub trait Checkpointer<R: Record<B>, B: Backend>: Send + Sync {
    /// Stores `record` under the given `epoch`.
    ///
    /// The record is taken by value.
    ///
    /// # Errors
    ///
    /// Returns a [`CheckpointerError`] if the record could not be stored.
    fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError>;

    /// Removes whatever is stored under the given `epoch`.
    ///
    /// # Errors
    ///
    /// Returns a [`CheckpointerError`] if the removal failed.
    fn delete(&self, epoch: usize) -> Result<(), CheckpointerError>;

    /// Loads and returns the record stored under the given `epoch`.
    ///
    /// `device` is the backend device handed to the implementation;
    /// presumably the restored record is placed on it (confirm against the
    /// concrete implementations).
    ///
    /// # Errors
    ///
    /// Returns a [`CheckpointerError`] if the record could not be loaded.
    fn restore(&self, epoch: usize, device: &B::Device) -> Result<R, CheckpointerError>;
}