// burn_train/checkpoint/base.rs

1use burn_core::{
2    record::{Record, RecorderError},
3    tensor::backend::Backend,
4};
5use thiserror::Error;
6
7/// The error type for checkpointer.
8#[derive(Error, Debug)]
9pub enum CheckpointerError {
10    /// IO error.
11    #[error("I/O Error: `{0}`")]
12    IOError(std::io::Error),
13
14    /// Recorder error.
15    #[error("Recorder error: `{0}`")]
16    RecorderError(RecorderError),
17
18    /// Other errors.
19    #[error("Unknown error: `{0}`")]
20    Unknown(String),
21}
22
23/// The trait for checkpointer.
24pub trait Checkpointer<R, B>: Send + Sync
25where
26    R: Record<B>,
27    B: Backend,
28{
29    /// Save the record.
30    ///
31    /// # Arguments
32    ///
33    /// * `epoch` - The epoch.
34    /// * `record` - The record.
35    fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError>;
36
37    /// Delete the record at the given epoch if present.
38    fn delete(&self, epoch: usize) -> Result<(), CheckpointerError>;
39
40    /// Restore the record.
41    ///
42    /// # Arguments
43    ///
44    /// * `epoch` - The epoch.
45    /// * `device` - The device used to restore the record.
46    ///
47    /// # Returns
48    ///
49    /// The record.
50    fn restore(&self, epoch: usize, device: &B::Device) -> Result<R, CheckpointerError>;
51}