burn_train/checkpoint/
file.rs1use std::path::{Path, PathBuf};
2
3use super::{Checkpointer, CheckpointerError};
4use burn_core::{
5 record::{FileRecorder, Record},
6 tensor::backend::Backend,
7};
8
/// A [`Checkpointer`] that persists each epoch's record as a file on disk.
///
/// Checkpoint files live in `directory`, are named from `name` and the epoch
/// number, and are serialized / deserialized by the recorder `FR`.
pub struct FileCheckpointer<FR> {
    /// Folder holding every checkpoint file written by this checkpointer.
    directory: PathBuf,
    /// Prefix shared by all checkpoint file names.
    name: String,
    /// Recorder used to write and read the checkpoint files.
    recorder: FR,
}
15
16impl<FR> FileCheckpointer<FR> {
17 pub fn new(recorder: FR, directory: impl AsRef<Path>, name: &str) -> Self {
25 let directory = directory.as_ref();
26 std::fs::create_dir_all(directory).ok();
27
28 Self {
29 directory: directory.to_path_buf(),
30 name: name.to_string(),
31 recorder,
32 }
33 }
34
35 fn path_for_epoch(&self, epoch: usize) -> PathBuf {
36 self.directory.join(format!("{}-{}", self.name, epoch))
37 }
38}
39
40impl<FR, R, B> Checkpointer<R, B> for FileCheckpointer<FR>
41where
42 R: Record<B>,
43 FR: FileRecorder<B>,
44 B: Backend,
45{
46 fn save(&self, epoch: usize, record: R) -> Result<(), CheckpointerError> {
47 let file_path = self.path_for_epoch(epoch);
48 log::info!("Saving checkpoint {} to {}", epoch, file_path.display());
49
50 self.recorder
51 .record(record, file_path)
52 .map_err(CheckpointerError::RecorderError)?;
53
54 Ok(())
55 }
56
57 fn restore(&self, epoch: usize, device: &B::Device) -> Result<R, CheckpointerError> {
58 let file_path = self.path_for_epoch(epoch);
59 log::info!(
60 "Restoring checkpoint {} from {}",
61 epoch,
62 file_path.display()
63 );
64 let record = self
65 .recorder
66 .load(file_path, device)
67 .map_err(CheckpointerError::RecorderError)?;
68
69 Ok(record)
70 }
71
72 fn delete(&self, epoch: usize) -> Result<(), CheckpointerError> {
73 let file_to_remove = format!(
74 "{}.{}",
75 self.path_for_epoch(epoch).display(),
76 FR::file_extension(),
77 );
78
79 if std::path::Path::new(&file_to_remove).exists() {
80 log::info!("Removing checkpoint {file_to_remove}");
81 std::fs::remove_file(file_to_remove).map_err(CheckpointerError::IOError)?;
82 }
83
84 Ok(())
85 }
86}