gosh_database/
checkpoint.rs1use super::*;
2
3use crate::schema::*;
4use crate::*;
5
6pub trait Checkpoint
7where
8 Self: Clone + serde::Serialize + serde::de::DeserializeOwned,
9{
10 fn checkpoint_name() -> String {
15 format!("{}.ckpt", std::any::type_name::<Self>())
16 }
17
18 fn from_checkpoint_n(db: &DbConnection, n: i32) -> Result<Self> {
20 use crate::schema::checkpoints::dsl::*;
21
22 let conn = db.get();
23 let ckpt_key = Self::checkpoint_name();
24 let ckpts: Vec<i32> = checkpoints
25 .filter(key.eq(&ckpt_key))
26 .select(id)
27 .order(ctime.asc())
28 .load(&*conn)?;
29 let nckpts = ckpts.len();
30 info!("Found {} checkpoints with key {}", nckpts, &ckpt_key);
31
32 let k = if n < 0 { nckpts as i32 + n } else { n } as usize;
34 if k >= nckpts {
36 bail!("specified checkpoint {} is out of range.", n);
37 }
38
39 let encoded: Vec<u8> = checkpoints.filter(id.eq(&ckpts[k])).select(data).first(&*conn)?;
41
42 let x = bincode::deserialize(&encoded)
43 .with_context(|| format!("Failed to deserialize from data for checkpoint: {}/{}", ckpt_key, n))?;
44 Ok(x)
45 }
46
47 fn commit_checkpoint(&self, db: &DbConnection) -> Result<()> {
49 use crate::schema::checkpoints::dsl::*;
50
51 let ckpt_key = Self::checkpoint_name();
52 let conn = db.get();
53
54 let row = (key.eq(&ckpt_key), data.eq(bincode::serialize(&self).unwrap()));
55
56 diesel::insert_into(checkpoints)
57 .values(&row)
58 .execute(&*conn)
59 .with_context(|| {
60 format!(
61 "Failed to save checkpoint\n chk key: {}\n db source: {}",
62 ckpt_key,
63 db.database_url()
64 )
65 })?;
66
67 Ok(())
68 }
69
70 fn restore_from_checkpoint(&mut self, db: &DbConnection) -> Result<()> {
72 self.restore_from_checkpoint_n(db, -1)
73 }
74
75 #[cfg(feature = "adhoc")]
77 fn list_checkpoints(db: &DbConnection) -> Result<()> {
78 use crate::schema::checkpoints::dsl::*;
79
80 let conn = db.get();
81 let ckpt_key = Self::checkpoint_name();
82 let ckpts: Vec<(i32, String, String)> = checkpoints
83 .filter(key.eq(&ckpt_key))
84 .select((id, key, ctime))
85 .order(ctime.asc())
86 .load(&*conn)?;
87 let nckpts = ckpts.len();
88 info!("Found {} checkpoints with key {}", nckpts, &ckpt_key);
89
90 println!("{:^5}\t{:^}", "slot", "create time");
91 for (i, (_, _, t)) in ckpts.iter().enumerate() {
92 println!("{:^5}\t{:^}", i, t);
93 }
94
95 Ok(())
96 }
97
98 fn get_number_of_checkpoints(db: &DbConnection) -> Result<u64> {
100 use crate::schema::checkpoints::dsl::*;
101
102 let conn = db.get();
103 let ckpt_key = Self::checkpoint_name();
104 let count: i64 = checkpoints.filter(key.eq(&ckpt_key)).count().get_result(&*conn)?;
105 Ok(count as u64)
106 }
107
108 fn restore_from_checkpoint_n(&mut self, db: &DbConnection, n: i32) -> Result<()> {
111 let x = Self::from_checkpoint_n(db, n)?;
112 self.clone_from(&x);
113 Ok(())
114 }
115}
116
117impl<T> Checkpoint for T where T: Clone + serde::Serialize + serde::de::DeserializeOwned {}
118
119use gut::cli::*;
120use std::path::{Path, PathBuf};
121
// Command-line options that select a checkpoint database, together with the
// connection opened from them.
#[derive(Parser, Default, Clone, Debug)]
pub struct CheckpointDb {
    // Path of the checkpoint database file; `create` connects to it when set.
    #[structopt(long)]
    chk_file: Option<PathBuf>,

    // Checkpoint slot to restore/load. `None` means the latest checkpoint
    // (slot -1); negative slots count back from the newest.
    #[structopt(long)]
    chk_slot: Option<i32>,

    // Live connection. Not a CLI option (`skip`); filled in by `create`/`new`.
    #[structopt(skip)]
    db_connection: Option<DbConnection>,
}
137
138impl CheckpointDb {
139 pub fn new<P: AsRef<Path>>(d: P) -> Self {
141 let mut chk = Self::default();
142 chk.chk_file = Some(d.as_ref().to_path_buf());
143 chk.create()
144 }
145
146 pub fn slot(mut self, n: i32) -> Self {
148 self.chk_slot = Some(n);
149 self
150 }
151
152 pub fn create(&self) -> Self {
154 if let Some(dbfile) = &self.chk_file {
155 let url = format!("{}", dbfile.display());
156 let dbc = DbConnection::connect(&url).expect("failed to connect to db src");
157 let mut chk = self.clone();
158 chk.db_connection = Some(dbc);
159 chk
160 } else {
161 self.to_owned()
162 }
163 }
164}
165
166impl CheckpointDb {
167 pub fn restore<T: Checkpoint>(&self, data: &mut T) -> Result<bool> {
170 if let Some(db) = &self.db_connection {
172 if let Some(n) = self.chk_slot {
173 if let Err(e) = data.restore_from_checkpoint_n(db, n) {
174 warn!("failed to restore from checkpoint");
175 dbg!(e);
176 }
177 } else {
178 if let Err(e) = data.restore_from_checkpoint(db) {
179 warn!("failed to restore from checkpoint");
180 dbg!(e);
181 return Ok(false);
182 }
183 }
184 Ok(true)
185 } else {
186 Ok(false)
187 }
188 }
189
190 #[deprecated(note = "Please use load_from_latest instead")]
191 pub fn restored<T: Checkpoint>(&self) -> Result<T> {
193 self.load_from_latest()
194 }
195
196 pub fn load_from_latest<T: Checkpoint>(&self) -> Result<T> {
198 let n = self.chk_slot.unwrap_or(-1);
199 self.load_from_slot_n(n)
200 }
201
202 pub fn load_from_slot_n<T: Checkpoint>(&self, slot: i32) -> Result<T> {
204 let db = self.db_connection.as_ref().expect("no db connection");
205 Ok(T::from_checkpoint_n(db, slot)?)
206 }
207
208 pub fn commit<T: Checkpoint>(&self, data: &T) -> Result<bool> {
211 if let Some(db) = &self.db_connection {
212 data.commit_checkpoint(db)?;
213 Ok(true)
214 } else {
215 Ok(false)
216 }
217 }
218
219 #[cfg(feature = "adhoc")]
221 pub fn list<T: Checkpoint>(&self) -> Result<bool> {
222 if let Some(db) = &self.db_connection {
223 T::list_checkpoints(db)?;
224 Ok(true)
225 } else {
226 Ok(false)
227 }
228 }
229 pub fn get_number_of_checkpoints<T: Checkpoint>(&self) -> Result<u64> {
231 if let Some(db) = &self.db_connection {
232 let n = T::get_number_of_checkpoints(db)?;
233 Ok(n)
234 } else {
235 Ok(0)
236 }
237 }
238}
239
#[cfg(test)]
mod test {
    use super::*;

    #[derive(Clone, Debug, Serialize, Deserialize)]
    struct TestObject {
        data: f64,
    }

    /// Round-trip several checkpoints through a temporary SQLite file and check
    /// counting, latest/indexed restore, negative slots, and out-of-range slots.
    #[test]
    fn test_checkpoint() -> Result<()> {
        let tdir = tempfile::tempdir()?;
        let tmpdb = tdir.path().join("test.sqlite");
        let url = format!("{}", tmpdb.display());
        let db = DbConnection::connect(&url)?;

        // Fresh database: nothing stored, and restoring is an error (not a panic).
        assert_eq!(TestObject::get_number_of_checkpoints(&db)?, 0);
        let mut x = TestObject { data: -12.0 };
        assert!(x.restore_from_checkpoint(&db).is_err());
        assert_eq!(x.data, -12.0);

        x.commit_checkpoint(&db)?;
        x.data = 1.0;
        x.commit_checkpoint(&db)?;
        x.data = 0.0;
        x.commit_checkpoint(&db)?;
        assert_eq!(x.data, 0.0);

        // Counting is not feature-gated, so it is always checked.
        assert_eq!(TestObject::get_number_of_checkpoints(&db)?, 3);

        x.restore_from_checkpoint(&db)?;
        assert_eq!(x.data, 0.0);
        x.restore_from_checkpoint_n(&db, 0)?;
        assert_eq!(x.data, -12.0);
        x.restore_from_checkpoint_n(&db, 1)?;
        assert_eq!(x.data, 1.0);

        // Negative slots count back from the newest checkpoint.
        x.restore_from_checkpoint_n(&db, -3)?;
        assert_eq!(x.data, -12.0);

        // Out-of-range slots are errors and leave the value unchanged.
        assert!(x.restore_from_checkpoint_n(&db, 3).is_err());
        assert!(x.restore_from_checkpoint_n(&db, -4).is_err());
        assert_eq!(x.data, -12.0);

        // The `CheckpointDb` front-end sees the same checkpoints.
        let chk = CheckpointDb::new(&tmpdb).slot(1);
        assert_eq!(chk.get_number_of_checkpoints::<TestObject>()?, 3);
        let loaded: TestObject = chk.load_from_latest()?;
        assert_eq!(loaded.data, 1.0);

        Ok(())
    }
}