burn_train/checkpoint/
file.rs

1use std::path::{Path, PathBuf};
2
3use super::{Checkpointer, CheckpointerError};
4use burn_core::{
5    record::{FileRecorder, Record},
6    tensor::backend::Backend,
7};
8
/// A checkpointer that keeps its checkpoints as files inside a single directory.
pub struct FileCheckpointer<FR> {
    /// Directory that holds the checkpoint files.
    directory: PathBuf,
    /// Prefix shared by the file names of every checkpoint.
    name: String,
    /// Recorder handed to the checkpointer at construction.
    recorder: FR,
}

impl<FR> FileCheckpointer<FR> {
    /// Builds a file checkpointer and attempts to create its target directory.
    ///
    /// Directory creation (missing parents included) is best effort: a failure is
    /// discarded here and the checkpointer is returned regardless.
    ///
    /// # Arguments
    ///
    /// * `recorder` - The file recorder.
    /// * `directory` - The directory in which the checkpoints are stored.
    /// * `name` - The name of the checkpoint, used as the file name prefix.
    pub fn new(recorder: FR, directory: impl AsRef<Path>, name: &str) -> Self {
        let directory = directory.as_ref().to_path_buf();
        // The outcome is intentionally ignored; construction never fails.
        let _ = std::fs::create_dir_all(&directory);

        Self {
            directory,
            name: String::from(name),
            recorder,
        }
    }

    /// Returns `<directory>/<name>-<epoch>`, the checkpoint path for `epoch` with no
    /// file extension appended.
    fn path_for_epoch(&self, epoch: usize) -> PathBuf {
        let file_name = format!("{}-{}", self.name, epoch);
        self.directory.join(file_name)
    }
}
39
40impl<FR, R, B> Checkpointer<R, B> for FileCheckpointer<FR>
41where
42    R: Record<B>,
43    FR: FileRecorder<B>,
44    B: Backend,
45{
46    fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> {
47        let file_path = self.path_for_epoch(epoch);
48        log::info!("Saving checkpoint {} to {}", epoch, file_path.display());
49
50        self.recorder
51            .record(record, file_path)
52            .map_err(CheckpointerError::RecorderError)?;
53
54        Ok(())
55    }
56
57    fn restore(&self, epoch: usize, device: &B::Device) -> Result<R, CheckpointerError> {
58        let file_path = self.path_for_epoch(epoch);
59        log::info!(
60            "Restoring checkpoint {} from {}",
61            epoch,
62            file_path.display()
63        );
64        let record = self
65            .recorder
66            .load(file_path, device)
67            .map_err(CheckpointerError::RecorderError)?;
68
69        Ok(record)
70    }
71
72    fn delete(&self, epoch: usize) -> Result<(), CheckpointerError> {
73        let file_to_remove = format!(
74            "{}.{}",
75            self.path_for_epoch(epoch).display(),
76            FR::file_extension(),
77        );
78
79        if std::path::Path::new(&file_to_remove).exists() {
80            log::info!("Removing checkpoint {file_to_remove}");
81            std::fs::remove_file(file_to_remove).map_err(CheckpointerError::IOError)?;
82        }
83
84        Ok(())
85    }
86}