// burn_train/checkpoint/base.rs
use burn_core::{
    record::{Record, RecorderError},
    tensor::backend::Backend,
};
use thiserror::Error;

7#[derive(Error, Debug)]
9pub enum CheckpointerError {
10 #[error("I/O Error: `{0}`")]
12 IOError(std::io::Error),
13
14 #[error("Recorder error: `{0}`")]
16 RecorderError(RecorderError),
17
18 #[error("Unknown error: `{0}`")]
20 Unknown(String),
21}
22
23pub trait Checkpointer<R, B>: Send + Sync
25where
26 R: Record<B>,
27 B: Backend,
28{
29 fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError>;
36
37 fn delete(&self, epoch: usize) -> Result<(), CheckpointerError>;
39
40 fn restore(&self, epoch: usize, device: &B::Device) -> Result<R, CheckpointerError>;
51}