argmin_checkpointing_file/
lib.rs

// Copyright 2018-2024 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.

//! This crate creates checkpoints on disk for an optimization run.
//!
//! Saves a checkpoint on disk from which an interrupted optimization run can be resumed.
//! For usage details, see the documentation of [`FileCheckpoint`] or have a look at
//! the [example](https://github.com/argmin-rs/argmin/tree/main/examples/checkpoint).
//!
//! # Usage
//!
//! Add the following line to your dependencies list:
//!
//! ```toml
//! [dependencies]
#![doc = concat!("argmin-checkpointing-file = \"", env!("CARGO_PKG_VERSION"), "\"")]
//! ```
//!
//! # License
//!
//! Licensed under either of
//!
//!   * Apache License, Version 2.0,
//!     ([LICENSE-APACHE](https://github.com/argmin-rs/argmin/blob/main/LICENSE-APACHE) or
//!     <http://www.apache.org/licenses/LICENSE-2.0>)
//!   * MIT License ([LICENSE-MIT](https://github.com/argmin-rs/argmin/blob/main/LICENSE-MIT) or
//!     <http://opensource.org/licenses/MIT>)
//!
//! at your option.
//!
//! ## Contribution
//!
//! Unless you explicitly state otherwise, any contribution intentionally submitted for inclusion
//! in the work by you, as defined in the Apache-2.0 license, shall be dual licensed as above,
//! without any additional terms or conditions.

pub use argmin::core::checkpointing::{Checkpoint, CheckpointingFrequency};
use argmin::core::Error;
use serde::{de::DeserializeOwned, Serialize};
use std::default::Default;
use std::fs::File;
use std::io::{BufReader, BufWriter, Write};
use std::path::PathBuf;

49/// Handles saving a checkpoint to disk as a binary file.
50#[derive(Clone, Eq, PartialEq, Debug, Hash)]
51pub struct FileCheckpoint {
52    /// Indicates how often a checkpoint is created
53    pub frequency: CheckpointingFrequency,
54    /// Directory where the checkpoints are saved to
55    pub directory: PathBuf,
56    /// Name of the checkpoint files
57    pub filename: PathBuf,
58}
59
60impl Default for FileCheckpoint {
61    /// Create a default `FileCheckpoint` instance.
62    ///
63    /// This will save the checkpoint in the file `.checkpoints/checkpoint.arg`.
64    ///
65    /// # Example
66    ///
67    /// ```
68    /// use argmin_checkpointing_file::FileCheckpoint;
69    /// # use argmin::core::checkpointing::CheckpointingFrequency;
70    /// # use std::path::PathBuf;
71    ///
72    /// let checkpoint = FileCheckpoint::default();
73    /// # assert_eq!(checkpoint.frequency, CheckpointingFrequency::default());
74    /// # assert_eq!(checkpoint.directory, PathBuf::from(".checkpoints"));
75    /// # assert_eq!(checkpoint.filename, PathBuf::from("checkpoint.arg"));
76    /// ```
77    fn default() -> FileCheckpoint {
78        FileCheckpoint {
79            frequency: CheckpointingFrequency::default(),
80            directory: PathBuf::from(".checkpoints"),
81            filename: PathBuf::from("checkpoint.arg"),
82        }
83    }
84}
85
86impl FileCheckpoint {
87    /// Create a new `FileCheckpoint` instance
88    ///
89    /// # Example
90    ///
91    /// ```
92    /// use argmin_checkpointing_file::{FileCheckpoint, CheckpointingFrequency};
93    /// # use std::path::PathBuf;
94    ///
95    /// let directory = "checkpoints";
96    /// let filename = "optimization";
97    ///
98    /// // When passed to an `Executor`, this will save a checkpoint in the file
99    /// // `checkpoints/optimization.arg` in every iteration.
100    /// let checkpoint = FileCheckpoint::new(directory, filename, CheckpointingFrequency::Always);
101    /// # assert_eq!(checkpoint.frequency, CheckpointingFrequency::Always);
102    /// # assert_eq!(checkpoint.directory, PathBuf::from("checkpoints"));
103    /// # assert_eq!(checkpoint.filename, PathBuf::from("optimization.arg"));
104    /// ```
105    pub fn new<N: AsRef<str>>(directory: N, name: N, frequency: CheckpointingFrequency) -> Self {
106        FileCheckpoint {
107            frequency,
108            directory: PathBuf::from(directory.as_ref()),
109            filename: PathBuf::from(format!("{}.arg", name.as_ref())),
110        }
111    }
112}
113
114impl<S, I> Checkpoint<S, I> for FileCheckpoint
115where
116    S: Serialize + DeserializeOwned,
117    I: Serialize + DeserializeOwned,
118{
119    /// Writes checkpoint to disk.
120    ///
121    /// If the directory does not exist already, it will be created. It uses `bincode` to serialize
122    /// the data.
123    /// It will return an error if creating the directory or file or serialization failed.
124    ///
125    /// # Example
126    ///
127    /// ```
128    /// use argmin_checkpointing_file::{FileCheckpoint, CheckpointingFrequency, Checkpoint};
129    ///
130    /// # use std::fs::File;
131    /// # use std::io::BufReader;
132    /// # let checkpoint = FileCheckpoint::new(".checkpoints", "save_test" , CheckpointingFrequency::Always);
133    /// # let solver: u64 = 12;
134    /// # let state: u64 = 21;
135    /// # let _ = std::fs::remove_file(".checkpoints/save_test.arg");
136    /// checkpoint.save(&solver, &state);
137    /// # let (f_solver, f_state): (u64, u64) = bincode::deserialize_from(
138    /// #     BufReader::new(File::open(".checkpoints/save_test.arg").unwrap())
139    /// # ).unwrap();
140    /// # assert_eq!(solver, f_solver);
141    /// # assert_eq!(state, f_state);
142    /// # let _ = std::fs::remove_file(".checkpoints/save_test.arg");
143    /// ```
144    fn save(&self, solver: &S, state: &I) -> Result<(), Error> {
145        if !self.directory.exists() {
146            std::fs::create_dir_all(&self.directory)?
147        }
148        let fname = self.directory.join(&self.filename);
149        let f = BufWriter::new(File::create(fname)?);
150        bincode::serialize_into(f, &(solver, state))?;
151        Ok(())
152    }
153
154    /// Load a checkpoint from disk.
155    ///
156    ///
157    /// If there is no checkpoint on disk, it will return `Ok(None)`.
158    /// Returns an error if opening the file or deserialization failed.
159    ///
160    /// # Example
161    ///
162    /// ```
163    /// use argmin_checkpointing_file::{FileCheckpoint, CheckpointingFrequency, Checkpoint};
164    /// # use argmin::core::Error;
165    ///
166    /// # use std::fs::File;
167    /// # use std::io::BufWriter;
168    /// # fn main() -> Result<(), Error> {
169    /// # std::fs::DirBuilder::new().recursive(true).create(".checkpoints").unwrap();
170    /// # let f = BufWriter::new(File::create(".checkpoints/load_test.arg")?);
171    /// # let f_solver: u64 = 12;
172    /// # let f_state: u64 = 21;
173    /// # bincode::serialize_into(f, &(f_solver, f_state))?;
174    /// # let checkpoint = FileCheckpoint::new(".checkpoints", "load_test" , CheckpointingFrequency::Always);
175    /// let (solver, state) = checkpoint.load()?.unwrap();
176    /// # // Let the compiler know which types to expect.
177    /// # let blah1: u64 = solver;
178    /// # let blah2: u64 = state;
179    /// # assert_eq!(solver, f_solver);
180    /// # assert_eq!(state, f_state);
181    /// # let _ = std::fs::remove_file(".checkpoints/load_test.arg");
182    /// #
183    /// # // Return none if File does not exist
184    /// # let checkpoint = FileCheckpoint::new(".checkpoints", "certainly_does_not_exist" , CheckpointingFrequency::Always);
185    /// # let loaded: Option<(u64, u64)> = checkpoint.load()?;
186    /// # assert!(loaded.is_none());
187    /// # Ok(())
188    /// # }
189    /// ```
190    fn load(&self) -> Result<Option<(S, I)>, Error> {
191        let path = &self.directory.join(&self.filename);
192        if !path.exists() {
193            return Ok(None);
194        }
195        let file = File::open(path)?;
196        let reader = BufReader::new(file);
197        Ok(Some(bincode::deserialize_from(reader)?))
198    }
199
200    /// Returns the how often a checkpoint is to be saved.
201    ///
202    /// Used internally by [`save_cond`](`argmin::core::checkpointing::Checkpoint::save_cond`).
203    fn frequency(&self) -> CheckpointingFrequency {
204        self.frequency
205    }
206}
207
#[cfg(test)]
mod tests {
    use super::*;
    use argmin::core::test_utils::TestSolver;
    use argmin::core::{IterState, State};

    /// State type used by the tests; aliased to keep signatures readable.
    type TestState = IterState<Vec<f64>, (), (), (), (), f64>;

    /// Saving and loading must round-trip the state, and `load` must find the saved file.
    #[test]
    fn test_save() {
        let solver = TestSolver::new();
        let state: TestState = IterState::new().param(vec![1.0f64, 0.0]);
        let check = FileCheckpoint::new("checkpoints", "solver", CheckpointingFrequency::Always);
        check.save_cond(&solver, &state, 20).unwrap();

        let loaded: Option<(TestSolver, TestState)> = check.load().unwrap();
        // Clean up before asserting so a failing assertion does not leave the file behind.
        let _ = std::fs::remove_file(check.directory.join(&check.filename));

        let (_solver, loaded_state) = loaded.expect("a checkpoint was just saved");
        assert_eq!(loaded_state.get_param(), Some(&vec![1.0f64, 0.0]));
    }

    /// A checkpoint file that does not exist must be reported as `Ok(None)`, not as an error.
    #[test]
    fn test_load_missing_returns_none() {
        let check = FileCheckpoint::new(
            "checkpoints",
            "this_checkpoint_does_not_exist",
            CheckpointingFrequency::Always,
        );
        let loaded: Option<(TestSolver, TestState)> = check.load().unwrap();
        assert!(loaded.is_none());
    }
}