1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
// Copyright 2018-2024 argmin developers
//
// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
// http://apache.org/licenses/LICENSE-2.0> or the MIT license <LICENSE-MIT or
// http://opensource.org/licenses/MIT>, at your option. This file may not be
// copied, modified, or distributed except according to those terms.
//! # Checkpointing
//!
//! Checkpointing is a useful mechanism for mitigating the effects of crashes when software runs
//! in an unstable environment, particularly for long-running jobs. Checkpoints are saved regularly
//! with a user-chosen frequency. Optimizations can then be resumed from a given checkpoint after a
//! crash.
//!
//! For saving checkpoints to disk, `FileCheckpoint` is provided in the
//! `argmin-checkpointing-file` crate. Other checkpointing approaches can be implemented via the
//! `Checkpoint` trait.
//!
//! The `CheckpointingFrequency` defines how often checkpoints are saved and can be chosen to be
//! either `Always` (every iteration), `Every(u64)` (every Nth iteration) or `Never`.
//!
//! The following example shows how the `checkpointing` method is used to activate checkpointing.
//! If no checkpoint is available on disk, the optimization is started from scratch. If a previous
//! run crashed and a checkpoint is found on disk, the optimization resumes from that checkpoint.
//!
//! ## Example
//!
//! ```rust
//! # extern crate argmin;
//! # extern crate argmin_testfunctions;
//! # use argmin::core::{CostFunction, Error, Executor, Gradient, observers::ObserverMode};
//! # #[cfg(feature = "serde1")]
//! use argmin::core::checkpointing::CheckpointingFrequency;
//! # #[cfg(feature = "serde1")]
//! use argmin_checkpointing_file::FileCheckpoint;
//! # use argmin_observer_slog::SlogLogger;
//! # use argmin::solver::landweber::Landweber;
//! # use argmin_testfunctions::{rosenbrock, rosenbrock_derivative};
//! #
//! # #[derive(Default)]
//! # struct Rosenbrock {}
//! #
//! # /// Implement `CostFunction` for `Rosenbrock`
//! # impl CostFunction for Rosenbrock {
//! # /// Type of the parameter vector
//! # type Param = Vec<f64>;
//! # /// Type of the return value computed by the cost function
//! # type Output = f64;
//! #
//! # /// Apply the cost function to a parameter `p`
//! # fn cost(&self, p: &Self::Param) -> Result<Self::Output, Error> {
//! # Ok(rosenbrock(p))
//! # }
//! # }
//! #
//! # /// Implement `Gradient` for `Rosenbrock`
//! # impl Gradient for Rosenbrock {
//! # /// Type of the parameter vector
//! # type Param = Vec<f64>;
//! # /// Type of the return value computed by the cost function
//! # type Gradient = Vec<f64>;
//! #
//! # /// Compute the gradient at parameter `p`.
//! # fn gradient(&self, p: &Self::Param) -> Result<Self::Gradient, Error> {
//! # Ok(rosenbrock_derivative(p))
//! # }
//! # }
//! #
//! # fn run() -> Result<(), Error> {
//! # // define initial parameter vector
//! # let init_param: Vec<f64> = vec![1.2, 1.2];
//! # let my_optimization_problem = Rosenbrock {};
//! #
//! # let iters = 35;
//! # let solver = Landweber::new(0.001);
//!
//! // [...]
//!
//! # #[cfg(feature = "serde1")]
//! let checkpoint = FileCheckpoint::new(
//! ".checkpoints",
//! "optim",
//! CheckpointingFrequency::Every(20)
//! );
//!
//! #
//! # #[cfg(feature = "serde1")]
//! let res = Executor::new(my_optimization_problem, solver)
//! .configure(|config| config.param(init_param).max_iters(iters))
//! .checkpointing(checkpoint)
//! .run()?;
//!
//! // [...]
//! #
//! # Ok(())
//! # }
//! #
//! # fn main() {
//! # if let Err(ref e) = run() {
//! # println!("{}", e);
//! # }
//! # }
//! ```
use crateError;
use Default;
use Display;
/// An interface for checkpointing methods
///
/// Handles saving and loading of checkpoints. The methods [`save`](`Checkpoint::save`) (saving a
/// checkpoint) and [`load`](`Checkpoint::load`) (loading a checkpoint) are mandatory to implement.
/// The method [`save_cond`](`Checkpoint::save_cond`) determines whether the conditions for calling
/// `save` are met and, if so, calls `save`. [`frequency`](`Checkpoint::frequency`) returns these
/// conditions in the form of a [`CheckpointingFrequency`].
///
/// # Example
///
/// ```
/// use argmin::core::Error;
/// use argmin::core::checkpointing::{Checkpoint, CheckpointingFrequency};
/// # #[cfg(feature = "serde1")]
/// use serde::{Serialize, de::DeserializeOwned};
///
/// struct MyCheckpoint {
/// frequency: CheckpointingFrequency,
/// // ..
/// }
///
/// # #[cfg(feature = "serde1")]
/// impl<S, I> Checkpoint<S, I> for MyCheckpoint
/// where
/// // Both `solver` (`S`) and `state` (`I`) (probably) need to be (de)serializable
/// S: Serialize + DeserializeOwned,
/// I: Serialize + DeserializeOwned,
/// # S: Default,
/// # I: Default,
/// {
/// fn save(&self, solver: &S, state: &I) -> Result<(), Error> {
/// // Save `solver` and `state`
/// Ok(())
/// }
///
/// fn load(&self) -> Result<Option<(S, I)>, Error> {
/// // Load `solver` and `state` from checkpoint
/// // Return `Ok(None)` in case checkpoint is not found.
/// # let solver = S::default();
/// # let state = I::default();
/// Ok(Some((solver, state)))
/// }
///
/// fn frequency(&self) -> CheckpointingFrequency {
/// self.frequency
/// }
/// }
/// # fn main() {}
/// ```
/// Defines at which intervals a checkpoint is saved.
///
/// # Example
///
/// ```
/// use argmin::core::checkpointing::CheckpointingFrequency;
///
/// // A checkpoint every 10 iterations
/// let every_10 = CheckpointingFrequency::Every(10);
///
/// // A checkpoint in each iteration
/// let always = CheckpointingFrequency::Always;
///
/// // The default is `CheckpointingFrequency::Always`
/// assert_eq!(CheckpointingFrequency::default(), CheckpointingFrequency::Always);
/// ```