1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
use super::{Checkpointer, CheckpointerError};
use burn_core::{
    record::{FileRecorder, Record},
    tensor::backend::Backend,
};

/// A checkpointer that stores every checkpoint as a file on the local filesystem.
///
/// Files live under `directory` and are named from `name` plus the epoch number; encoding and
/// decoding of the records themselves is delegated to the wrapped file recorder.
pub struct FileCheckpointer<FR> {
    /// File recorder responsible for writing and reading the checkpoint records.
    recorder: FR,
    /// Directory that holds the checkpoint files.
    directory: String,
    /// Base name shared by all checkpoint files of this checkpointer.
    name: String,
}

impl<FR> FileCheckpointer<FR> {
    /// Creates a new file checkpointer.
    ///
    /// # Arguments
    ///
    /// * `recorder` - The file recorder.
    /// * `directory` - The directory to save the checkpoints.
    /// * `name` - The name of the checkpoint.
    pub fn new(recorder: FR, directory: &str, name: &str) -> Self {
        std::fs::create_dir_all(directory).ok();

        Self {
            directory: directory.to_string(),
            name: name.to_string(),
            recorder,
        }
    }
    fn path_for_epoch(&self, epoch: usize) -> String {
        format!("{}/{}-{}", self.directory, self.name, epoch)
    }
}

impl<FR, R, B> Checkpointer<R, B> for FileCheckpointer<FR>
where
    R: Record<B>,
    FR: FileRecorder<B>,
    B: Backend,
{
    /// Saves `record` as the checkpoint for `epoch` using the wrapped file recorder.
    ///
    /// # Errors
    ///
    /// Returns [`CheckpointerError::RecorderError`] if the recorder fails to write the record.
    fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> {
        let file_path = self.path_for_epoch(epoch);
        log::info!("Saving checkpoint {} to {}", epoch, file_path);

        self.recorder
            .record(record, file_path.into())
            .map_err(CheckpointerError::RecorderError)?;

        Ok(())
    }

    /// Loads the checkpoint for `epoch` onto `device` using the wrapped file recorder.
    ///
    /// # Errors
    ///
    /// Returns [`CheckpointerError::RecorderError`] if the recorder fails to load the record.
    fn restore(&self, epoch: usize, device: &B::Device) -> Result<R, CheckpointerError> {
        let file_path = self.path_for_epoch(epoch);
        log::info!("Restoring checkpoint {} from {}", epoch, file_path);
        let record = self
            .recorder
            .load(file_path.into(), device)
            .map_err(CheckpointerError::RecorderError)?;

        Ok(record)
    }

    /// Deletes the checkpoint file for `epoch`, if it exists.
    ///
    /// Deleting a checkpoint that is not present is not an error.
    ///
    /// # Errors
    ///
    /// Returns [`CheckpointerError::IOError`] if the file exists but cannot be removed, or if
    /// the filesystem reports any failure other than "not found".
    fn delete(&self, epoch: usize) -> Result<(), CheckpointerError> {
        let file_to_remove = format!("{}.{}", self.path_for_epoch(epoch), FR::file_extension());

        // Remove directly rather than checking `exists()` first: a separate existence check
        // races with concurrent deletion, and `exists()` also reports `false` when metadata
        // cannot be read (e.g. permission errors), which would hide a real failure.
        match std::fs::remove_file(&file_to_remove) {
            Ok(()) => {
                log::info!("Removing checkpoint {}", file_to_remove);
                Ok(())
            }
            // Nothing to delete: keep `delete` idempotent.
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()),
            Err(err) => Err(CheckpointerError::IOError(err)),
        }
    }
}