argmin_checkpointing_file/lib.rs
1// Copyright 2018-2024 argmin developers
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
5// http://opensource.org/licenses/MIT>, at your option. This file may not be
6// copied, modified, or distributed except according to those terms.
7
8//! This crate creates checkpoints on disk for an optimization run.
9//!
10//! Saves a checkpoint on disk from which an interrupted optimization run can be resumed.
11//! For details on the usage please see the documentation of [`FileCheckpoint`] or have a look at
12//! the [example](https://github.com/argmin-rs/argmin/tree/main/examples/checkpoint).
13//!
14//! # Usage
15//!
16//! Add the following line to your dependencies list:
17//!
18//! ```toml
19//! [dependencies]
20#![doc = concat!("argmin-checkpointing-file = \"", env!("CARGO_PKG_VERSION"), "\"")]
21//! ```
22//!
23//! # License
24//!
25//! Licensed under either of
26//!
27//! * Apache License, Version 2.0,
28//! ([LICENSE-APACHE](https://github.com/argmin-rs/argmin/blob/main/LICENSE-APACHE) or
29//! <http://www.apache.org/licenses/LICENSE-2.0>)
30//! * MIT License ([LICENSE-MIT](https://github.com/argmin-rs/argmin/blob/main/LICENSE-MIT) or
31//! <http://opensource.org/licenses/MIT>)
32//!
33//! at your option.
34//!
35//! ## Contribution
36//!
37//! Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion
38//! in the work by you, as defined in the Apache-2.0 license, shall be dual licensed as above,
39//! without any additional terms or conditions.
40
41pub use argmin::core::checkpointing::{Checkpoint, CheckpointingFrequency};
42use argmin::core::Error;
43use serde::{de::DeserializeOwned, Serialize};
44use std::default::Default;
45use std::fs::File;
46use std::io::{BufReader, BufWriter};
47use std::path::PathBuf;
48
49/// Handles saving a checkpoint to disk as a binary file.
50#[derive(Clone, Eq, PartialEq, Debug, Hash)]
51pub struct FileCheckpoint {
52 /// Indicates how often a checkpoint is created
53 pub frequency: CheckpointingFrequency,
54 /// Directory where the checkpoints are saved to
55 pub directory: PathBuf,
56 /// Name of the checkpoint files
57 pub filename: PathBuf,
58}
59
60impl Default for FileCheckpoint {
61 /// Create a default `FileCheckpoint` instance.
62 ///
63 /// This will save the checkpoint in the file `.checkpoints/checkpoint.arg`.
64 ///
65 /// # Example
66 ///
67 /// ```
68 /// use argmin_checkpointing_file::FileCheckpoint;
69 /// # use argmin::core::checkpointing::CheckpointingFrequency;
70 /// # use std::path::PathBuf;
71 ///
72 /// let checkpoint = FileCheckpoint::default();
73 /// # assert_eq!(checkpoint.frequency, CheckpointingFrequency::default());
74 /// # assert_eq!(checkpoint.directory, PathBuf::from(".checkpoints"));
75 /// # assert_eq!(checkpoint.filename, PathBuf::from("checkpoint.arg"));
76 /// ```
77 fn default() -> FileCheckpoint {
78 FileCheckpoint {
79 frequency: CheckpointingFrequency::default(),
80 directory: PathBuf::from(".checkpoints"),
81 filename: PathBuf::from("checkpoint.arg"),
82 }
83 }
84}
85
86impl FileCheckpoint {
87 /// Create a new `FileCheckpoint` instance
88 ///
89 /// # Example
90 ///
91 /// ```
92 /// use argmin_checkpointing_file::{FileCheckpoint, CheckpointingFrequency};
93 /// # use std::path::PathBuf;
94 ///
95 /// let directory = "checkpoints";
96 /// let filename = "optimization";
97 ///
98 /// // When passed to an `Executor`, this will save a checkpoint in the file
99 /// // `checkpoints/optimization.arg` in every iteration.
100 /// let checkpoint = FileCheckpoint::new(directory, filename, CheckpointingFrequency::Always);
101 /// # assert_eq!(checkpoint.frequency, CheckpointingFrequency::Always);
102 /// # assert_eq!(checkpoint.directory, PathBuf::from("checkpoints"));
103 /// # assert_eq!(checkpoint.filename, PathBuf::from("optimization.arg"));
104 /// ```
105 pub fn new<N: AsRef<str>>(directory: N, name: N, frequency: CheckpointingFrequency) -> Self {
106 FileCheckpoint {
107 frequency,
108 directory: PathBuf::from(directory.as_ref()),
109 filename: PathBuf::from(format!("{}.arg", name.as_ref())),
110 }
111 }
112}
113
114impl<S, I> Checkpoint<S, I> for FileCheckpoint
115where
116 S: Serialize + DeserializeOwned,
117 I: Serialize + DeserializeOwned,
118{
119 /// Writes checkpoint to disk.
120 ///
121 /// If the directory does not exist already, it will be created. It uses `bincode` to serialize
122 /// the data.
123 /// It will return an error if creating the directory or file or serialization failed.
124 ///
125 /// # Example
126 ///
127 /// ```
128 /// use argmin_checkpointing_file::{FileCheckpoint, CheckpointingFrequency, Checkpoint};
129 ///
130 /// # use std::fs::File;
131 /// # use std::io::BufReader;
132 /// # let checkpoint = FileCheckpoint::new(".checkpoints", "save_test" , CheckpointingFrequency::Always);
133 /// # let solver: u64 = 12;
134 /// # let state: u64 = 21;
135 /// # let _ = std::fs::remove_file(".checkpoints/save_test.arg");
136 /// checkpoint.save(&solver, &state);
137 /// # let (f_solver, f_state): (u64, u64) = bincode::deserialize_from(
138 /// # BufReader::new(File::open(".checkpoints/save_test.arg").unwrap())
139 /// # ).unwrap();
140 /// # assert_eq!(solver, f_solver);
141 /// # assert_eq!(state, f_state);
142 /// # let _ = std::fs::remove_file(".checkpoints/save_test.arg");
143 /// ```
144 fn save(&self, solver: &S, state: &I) -> Result<(), Error> {
145 if !self.directory.exists() {
146 std::fs::create_dir_all(&self.directory)?
147 }
148 let fname = self.directory.join(&self.filename);
149 let f = BufWriter::new(File::create(fname)?);
150 bincode::serialize_into(f, &(solver, state))?;
151 Ok(())
152 }
153
154 /// Load a checkpoint from disk.
155 ///
156 ///
157 /// If there is no checkpoint on disk, it will return `Ok(None)`.
158 /// Returns an error if opening the file or deserialization failed.
159 ///
160 /// # Example
161 ///
162 /// ```
163 /// use argmin_checkpointing_file::{FileCheckpoint, CheckpointingFrequency, Checkpoint};
164 /// # use argmin::core::Error;
165 ///
166 /// # use std::fs::File;
167 /// # use std::io::BufWriter;
168 /// # fn main() -> Result<(), Error> {
169 /// # std::fs::DirBuilder::new().recursive(true).create(".checkpoints").unwrap();
170 /// # let f = BufWriter::new(File::create(".checkpoints/load_test.arg")?);
171 /// # let f_solver: u64 = 12;
172 /// # let f_state: u64 = 21;
173 /// # bincode::serialize_into(f, &(f_solver, f_state))?;
174 /// # let checkpoint = FileCheckpoint::new(".checkpoints", "load_test" , CheckpointingFrequency::Always);
175 /// let (solver, state) = checkpoint.load()?.unwrap();
176 /// # // Let the compiler know which types to expect.
177 /// # let blah1: u64 = solver;
178 /// # let blah2: u64 = state;
179 /// # assert_eq!(solver, f_solver);
180 /// # assert_eq!(state, f_state);
181 /// # let _ = std::fs::remove_file(".checkpoints/load_test.arg");
182 /// #
183 /// # // Return none if File does not exist
184 /// # let checkpoint = FileCheckpoint::new(".checkpoints", "certainly_does_not_exist" , CheckpointingFrequency::Always);
185 /// # let loaded: Option<(u64, u64)> = checkpoint.load()?;
186 /// # assert!(loaded.is_none());
187 /// # Ok(())
188 /// # }
189 /// ```
190 fn load(&self) -> Result<Option<(S, I)>, Error> {
191 let path = &self.directory.join(&self.filename);
192 if !path.exists() {
193 return Ok(None);
194 }
195 let file = File::open(path)?;
196 let reader = BufReader::new(file);
197 Ok(Some(bincode::deserialize_from(reader)?))
198 }
199
200 /// Returns the how often a checkpoint is to be saved.
201 ///
202 /// Used internally by [`save_cond`](`argmin::core::checkpointing::Checkpoint::save_cond`).
203 fn frequency(&self) -> CheckpointingFrequency {
204 self.frequency
205 }
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use argmin::core::test_utils::TestSolver;
    use argmin::core::{IterState, State};

    /// Concrete state type used by the tests; the alias keeps the signatures readable.
    type TestState = IterState<Vec<f64>, (), (), (), (), f64>;

    /// Builds a checkpoint inside the system temp directory so tests do not litter the
    /// working directory. Each test passes its own `name` to avoid clashing file names.
    fn temp_checkpoint(name: &str) -> FileCheckpoint {
        FileCheckpoint {
            frequency: CheckpointingFrequency::Always,
            directory: std::env::temp_dir().join("argmin_checkpointing_file_tests"),
            filename: PathBuf::from(format!("{}.arg", name)),
        }
    }

    #[test]
    fn test_save() {
        let solver = TestSolver::new();
        let state: TestState = IterState::new().param(vec![1.0f64, 0.0]);
        let check = temp_checkpoint("solver");
        check.save_cond(&solver, &state, 20).unwrap();

        let loaded: Option<(TestSolver, TestState)> = check.load().unwrap();
        let (_solver, loaded_state) = loaded.expect("a checkpoint should exist after saving");
        assert_eq!(loaded_state.get_param(), Some(&vec![1.0f64, 0.0]));

        let _ = std::fs::remove_file(check.directory.join(&check.filename));
    }

    #[test]
    fn test_load_missing_returns_none() {
        let check = temp_checkpoint("certainly_does_not_exist");
        let _ = std::fs::remove_file(check.directory.join(&check.filename));

        let loaded: Option<(TestSolver, TestState)> = check.load().unwrap();
        assert!(loaded.is_none());
    }
}