gosh_database/
checkpoint.rs

1use super::*;
2
3use crate::schema::*;
4use crate::*;
5
6pub trait Checkpoint
7where
8    Self: Clone + serde::Serialize + serde::de::DeserializeOwned,
9{
10    // Return a key associated with a group of checkpoints.
11    // const CKPT_KEY: &'static str;
12
13    /// Return an unique name as the container for your data.
14    fn checkpoint_name() -> String {
15        format!("{}.ckpt", std::any::type_name::<Self>())
16    }
17
18    /// Load from the specified checkpoint `n` (ordered by create time).
19    fn from_checkpoint_n(db: &DbConnection, n: i32) -> Result<Self> {
20        use crate::schema::checkpoints::dsl::*;
21
22        let conn = db.get();
23        let ckpt_key = Self::checkpoint_name();
24        let ckpts: Vec<i32> = checkpoints
25            .filter(key.eq(&ckpt_key))
26            .select(id)
27            .order(ctime.asc())
28            .load(&*conn)?;
29        let nckpts = ckpts.len();
30        info!("Found {} checkpoints with key {}", nckpts, &ckpt_key);
31
32        // Allow negative index into the list.
33        let k = if n < 0 { nckpts as i32 + n } else { n } as usize;
34        // Avoid panic when n is invalid.
35        if k >= nckpts {
36            bail!("specified checkpoint {} is out of range.", n);
37        }
38
39        // Get encoded data.
40        let encoded: Vec<u8> = checkpoints.filter(id.eq(&ckpts[k])).select(data).first(&*conn)?;
41
42        let x = bincode::deserialize(&encoded)
43            .with_context(|| format!("Failed to deserialize from data for checkpoint: {}/{}", ckpt_key, n))?;
44        Ok(x)
45    }
46
47    /// Set a checkpoint
48    fn commit_checkpoint(&self, db: &DbConnection) -> Result<()> {
49        use crate::schema::checkpoints::dsl::*;
50
51        let ckpt_key = Self::checkpoint_name();
52        let conn = db.get();
53
54        let row = (key.eq(&ckpt_key), data.eq(bincode::serialize(&self).unwrap()));
55
56        diesel::insert_into(checkpoints)
57            .values(&row)
58            .execute(&*conn)
59            .with_context(|| {
60                format!(
61                    "Failed to save checkpoint\n chk key: {}\n db source: {}",
62                    ckpt_key,
63                    db.database_url()
64                )
65            })?;
66
67        Ok(())
68    }
69
70    /// Restore state from the latest checkpoint.
71    fn restore_from_checkpoint(&mut self, db: &DbConnection) -> Result<()> {
72        self.restore_from_checkpoint_n(db, -1)
73    }
74
75    /// List available checkpoints in `db`.
76    #[cfg(feature = "adhoc")]
77    fn list_checkpoints(db: &DbConnection) -> Result<()> {
78        use crate::schema::checkpoints::dsl::*;
79
80        let conn = db.get();
81        let ckpt_key = Self::checkpoint_name();
82        let ckpts: Vec<(i32, String, String)> = checkpoints
83            .filter(key.eq(&ckpt_key))
84            .select((id, key, ctime))
85            .order(ctime.asc())
86            .load(&*conn)?;
87        let nckpts = ckpts.len();
88        info!("Found {} checkpoints with key {}", nckpts, &ckpt_key);
89
90        println!("{:^5}\t{:^}", "slot", "create time");
91        for (i, (_, _, t)) in ckpts.iter().enumerate() {
92            println!("{:^5}\t{:^}", i, t);
93        }
94
95        Ok(())
96    }
97
98    /// Return the number of available checkpoints in database.
99    fn get_number_of_checkpoints(db: &DbConnection) -> Result<u64> {
100        use crate::schema::checkpoints::dsl::*;
101
102        let conn = db.get();
103        let ckpt_key = Self::checkpoint_name();
104        let count: i64 = checkpoints.filter(key.eq(&ckpt_key)).count().get_result(&*conn)?;
105        Ok(count as u64)
106    }
107
108    /// Restore state from the specified checkpoint `n` (ordered by create
109    /// time).
110    fn restore_from_checkpoint_n(&mut self, db: &DbConnection, n: i32) -> Result<()> {
111        let x = Self::from_checkpoint_n(db, n)?;
112        self.clone_from(&x);
113        Ok(())
114    }
115}
116
117impl<T> Checkpoint for T where T: Clone + serde::Serialize + serde::de::DeserializeOwned {}
118
119use gut::cli::*;
120use std::path::{Path, PathBuf};
121
122#[derive(Parser, Default, Clone, Debug)]
123pub struct CheckpointDb {
124    /// Path to a checkpoint file for resuming calculation later.
125    #[structopt(long)]
126    chk_file: Option<PathBuf>,
127
128    /// Index of checkpoint frame to restore (0-base). The default is to restore
129    /// from the lastest (--chk-slot=-1)
130    #[structopt(long)]
131    chk_slot: Option<i32>,
132
133    // internal: database connection
134    #[structopt(skip)]
135    db_connection: Option<DbConnection>,
136}
137
138impl CheckpointDb {
139    /// Construct Checkpoint from `path` to a file.
140    pub fn new<P: AsRef<Path>>(d: P) -> Self {
141        let mut chk = Self::default();
142        chk.chk_file = Some(d.as_ref().to_path_buf());
143        chk.create()
144    }
145
146    /// Construct with checkpoint slot `n`.
147    pub fn slot(mut self, n: i32) -> Self {
148        self.chk_slot = Some(n);
149        self
150    }
151
152    /// Create missing db_connection field if `chk_file` is not None. Mainly for cmdline uses.
153    pub fn create(&self) -> Self {
154        if let Some(dbfile) = &self.chk_file {
155            let url = format!("{}", dbfile.display());
156            let dbc = DbConnection::connect(&url).expect("failed to connect to db src");
157            let mut chk = self.clone();
158            chk.db_connection = Some(dbc);
159            chk
160        } else {
161            self.to_owned()
162        }
163    }
164}
165
166impl CheckpointDb {
167    /// Restore `chain` from checkpoint. Return true if restored successfuly,
168    /// false otherwise.
169    pub fn restore<T: Checkpoint>(&self, data: &mut T) -> Result<bool> {
170        // use resumed `data` from checkpoint if possible
171        if let Some(db) = &self.db_connection {
172            if let Some(n) = self.chk_slot {
173                if let Err(e) = data.restore_from_checkpoint_n(db, n) {
174                    warn!("failed to restore from checkpoint");
175                    dbg!(e);
176                }
177            } else {
178                if let Err(e) = data.restore_from_checkpoint(db) {
179                    warn!("failed to restore from checkpoint");
180                    dbg!(e);
181                    return Ok(false);
182                }
183            }
184            Ok(true)
185        } else {
186            Ok(false)
187        }
188    }
189
190    #[deprecated(note = "Please use load_from_latest instead")]
191    /// Return checkpointed `T`
192    pub fn restored<T: Checkpoint>(&self) -> Result<T> {
193        self.load_from_latest()
194    }
195
196    /// Load latest struct `T` from checkpoint
197    pub fn load_from_latest<T: Checkpoint>(&self) -> Result<T> {
198        let n = self.chk_slot.unwrap_or(-1);
199        self.load_from_slot_n(n)
200    }
201
202    /// Load struct `T` from checkpoint in `slot`
203    pub fn load_from_slot_n<T: Checkpoint>(&self, slot: i32) -> Result<T> {
204        let db = self.db_connection.as_ref().expect("no db connection");
205        Ok(T::from_checkpoint_n(db, slot)?)
206    }
207
208    /// Commit a checkpoint into database. Return true if committed, false
209    /// otherwise.
210    pub fn commit<T: Checkpoint>(&self, data: &T) -> Result<bool> {
211        if let Some(db) = &self.db_connection {
212            data.commit_checkpoint(db)?;
213            Ok(true)
214        } else {
215            Ok(false)
216        }
217    }
218
219    /// List available checkpoints in database.
220    #[cfg(feature = "adhoc")]
221    pub fn list<T: Checkpoint>(&self) -> Result<bool> {
222        if let Some(db) = &self.db_connection {
223            T::list_checkpoints(db)?;
224            Ok(true)
225        } else {
226            Ok(false)
227        }
228    }
229    /// Return the number of available checkpoints in database.
230    pub fn get_number_of_checkpoints<T: Checkpoint>(&self) -> Result<u64> {
231        if let Some(db) = &self.db_connection {
232            let n = T::get_number_of_checkpoints(db)?;
233            Ok(n)
234        } else {
235            Ok(0)
236        }
237    }
238}
239
#[cfg(test)]
mod test {
    use super::*;

    // Minimal serializable payload used to exercise the checkpoint round trip.
    #[derive(Clone, Debug, Serialize, Deserialize)]
    struct TestObject {
        data: f64,
    }

    /// Commit three checkpoints, then check counting, restoring by positive
    /// and negative slot, and rejection of out-of-range slots.
    #[test]
    fn test_checkpoint() -> Result<()> {
        // setup database in a temp directory
        let tdir = tempfile::tempdir()?;
        let tmpdb = tdir.path().join("test.sqlite");
        let url = format!("{}", tmpdb.display());
        let db = DbConnection::connect(&url)?;

        // commit checkpoint
        let mut x = TestObject { data: -12.0 };
        x.commit_checkpoint(&db)?;
        // commit a new checkpoint
        x.data = 1.0;
        x.commit_checkpoint(&db)?;
        // commit a new checkpoint again
        x.data = 0.0;
        x.commit_checkpoint(&db)?;
        assert_eq!(x.data, 0.0);

        // Counting is not feature-gated, so always check it.
        assert_eq!(TestObject::get_number_of_checkpoints(&db)?, 3);

        // restore from checkpoint
        x.restore_from_checkpoint(&db)?;
        assert_eq!(x.data, 0.0);
        x.restore_from_checkpoint_n(&db, 0)?;
        assert_eq!(x.data, -12.0);
        x.restore_from_checkpoint_n(&db, 1)?;
        assert_eq!(x.data, 1.0);

        // negative slots count from the end
        x.restore_from_checkpoint_n(&db, -1)?;
        assert_eq!(x.data, 0.0);
        x.restore_from_checkpoint_n(&db, -3)?;
        assert_eq!(x.data, -12.0);

        // out-of-range slots are errors and leave the value untouched
        assert!(x.restore_from_checkpoint_n(&db, 3).is_err());
        assert!(x.restore_from_checkpoint_n(&db, -4).is_err());
        assert_eq!(x.data, -12.0);

        Ok(())
    }
}