burn_train/checkpoint/
file.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
use std::path::{Path, PathBuf};

use super::{Checkpointer, CheckpointerError};
use burn_core::{
    record::{FileRecorder, Record},
    tensor::backend::Backend,
};

/// Checkpointer that persists training records as files on the local filesystem.
///
/// Every checkpoint lives inside `directory` and is named after `name` plus the epoch
/// number; the actual (de)serialization is delegated to the wrapped file recorder.
pub struct FileCheckpointer<FR> {
    /// Recorder used to write and read records.
    recorder: FR,
    /// Directory in which checkpoint files are stored.
    directory: PathBuf,
    /// Base name shared by all checkpoint files of this checkpointer.
    name: String,
}

impl<FR> FileCheckpointer<FR> {
    /// Creates a new file checkpointer.
    ///
    /// # Arguments
    ///
    /// * `recorder` - The file recorder.
    /// * `directory` - The directory to save the checkpoints.
    /// * `name` - The name of the checkpoint.
    pub fn new(recorder: FR, directory: impl AsRef<Path>, name: &str) -> Self {
        let directory = directory.as_ref();
        std::fs::create_dir_all(directory).ok();

        Self {
            directory: directory.to_path_buf(),
            name: name.to_string(),
            recorder,
        }
    }

    fn path_for_epoch(&self, epoch: usize) -> PathBuf {
        self.directory.join(format!("{}-{}", self.name, epoch))
    }
}

impl<FR, R, B> Checkpointer<R, B> for FileCheckpointer<FR>
where
    R: Record<B>,
    FR: FileRecorder<B>,
    B: Backend,
{
    /// Saves `record` as the checkpoint for `epoch` using the wrapped file recorder.
    ///
    /// # Errors
    ///
    /// Returns [`CheckpointerError::RecorderError`] if the recorder fails to write the record.
    fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> {
        let file_path = self.path_for_epoch(epoch);
        log::info!("Saving checkpoint {} to {}", epoch, file_path.display());

        self.recorder
            .record(record, file_path)
            .map_err(CheckpointerError::RecorderError)?;

        Ok(())
    }

    /// Loads the checkpoint for `epoch` onto `device` using the wrapped file recorder.
    ///
    /// # Errors
    ///
    /// Returns [`CheckpointerError::RecorderError`] if the recorder fails to load the record.
    fn restore(&self, epoch: usize, device: &B::Device) -> Result<R, CheckpointerError> {
        let file_path = self.path_for_epoch(epoch);
        log::info!(
            "Restoring checkpoint {} from {}",
            epoch,
            file_path.display()
        );
        let record = self
            .recorder
            .load(file_path, device)
            .map_err(CheckpointerError::RecorderError)?;

        Ok(record)
    }

    /// Deletes the checkpoint file `<directory>/<name>-<epoch>.<extension>` for `epoch`.
    ///
    /// A missing file is not an error.
    ///
    /// # Errors
    ///
    /// Returns [`CheckpointerError::IOError`] if the file exists but cannot be removed.
    fn delete(&self, epoch: usize) -> Result<(), CheckpointerError> {
        // Append ".<ext>" to the OS string rather than formatting through `Path::display()`,
        // which replaces non-UTF-8 bytes and would then target the wrong file.
        let mut file_to_remove = self.path_for_epoch(epoch).into_os_string();
        file_to_remove.push(".");
        file_to_remove.push(FR::file_extension());
        let file_to_remove = PathBuf::from(file_to_remove);

        // Remove directly instead of checking `exists()` first. This avoids a
        // check-then-act race, and it avoids `exists()` reporting errors as "missing".
        match std::fs::remove_file(&file_to_remove) {
            Ok(()) => {
                log::info!("Removing checkpoint {}", file_to_remove.display());
                Ok(())
            }
            // Nothing to delete: the checkpoint was never written or is already gone.
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()),
            Err(err) => Err(CheckpointerError::IOError(err)),
        }
    }
}