use super::*;
use crate::schema::*;
use crate::*;
/// Persistence of a value as a series of checkpoints in the `checkpoints`
/// database table.
///
/// Values are encoded with `bincode` and stored under a key derived from the
/// type name, so each type owns its own series of checkpoints. Within a
/// series, checkpoints are addressed by slot number: `0` is the oldest one and
/// negative numbers count back from the newest (`-1` is the latest).
pub trait Checkpoint
where
    Self: Clone + serde::Serialize + serde::de::DeserializeOwned,
{
    /// Key under which checkpoints of this type are stored: the full type
    /// name followed by `.ckpt`.
    fn checkpoint_name() -> String {
        format!("{}.ckpt", std::any::type_name::<Self>())
    }

    /// Load the value stored in checkpoint slot `n`.
    ///
    /// # Errors
    ///
    /// Returns an error when the slot is out of range, when a database query
    /// fails, or when the stored bytes cannot be deserialized into `Self`.
    fn from_checkpoint_n(db: &DbConnection, n: i32) -> Result<Self> {
        use crate::schema::checkpoints::dsl::*;

        let conn = db.get();
        let ckpt_key = Self::checkpoint_name();
        // Tie-break on `id` so that checkpoints sharing the same `ctime` keep
        // a stable slot order (assumes `id` grows with insertion order).
        let ckpts: Vec<i32> = checkpoints
            .filter(key.eq(&ckpt_key))
            .select(id)
            .order((ctime.asc(), id.asc()))
            .load(&*conn)?;
        let nckpts = ckpts.len();
        info!("Found {} checkpoints with key {}", nckpts, &ckpt_key);

        // Resolve the slot in signed arithmetic so that a negative `n` that
        // reaches before the first checkpoint is rejected explicitly instead
        // of relying on a wrapping `as usize` cast.
        let k = if n < 0 { nckpts as i64 + i64::from(n) } else { i64::from(n) };
        if k < 0 || k >= nckpts as i64 {
            bail!("specified checkpoint {} is out of range.", n);
        }

        let encoded: Vec<u8> = checkpoints
            .filter(id.eq(&ckpts[k as usize]))
            .select(data)
            .first(&*conn)?;
        let x = bincode::deserialize(&encoded)
            .with_context(|| format!("Failed to deserialize from data for checkpoint: {}/{}", ckpt_key, n))?;
        Ok(x)
    }

    /// Serialize `self` and append it as a new checkpoint row.
    ///
    /// # Errors
    ///
    /// Returns an error when serialization or the database insert fails.
    fn commit_checkpoint(&self, db: &DbConnection) -> Result<()> {
        use crate::schema::checkpoints::dsl::*;

        let ckpt_key = Self::checkpoint_name();
        // Propagate serialization failures to the caller instead of panicking.
        let encoded = bincode::serialize(&self)
            .with_context(|| format!("Failed to serialize data for checkpoint: {}", ckpt_key))?;
        let conn = db.get();
        let row = (key.eq(&ckpt_key), data.eq(encoded));
        diesel::insert_into(checkpoints)
            .values(&row)
            .execute(&*conn)
            .with_context(|| {
                format!(
                    "Failed to save checkpoint\n chk key: {}\n db source: {}",
                    ckpt_key,
                    db.database_url()
                )
            })?;
        Ok(())
    }

    /// Overwrite `self` with the latest checkpoint (slot `-1`).
    fn restore_from_checkpoint(&mut self, db: &DbConnection) -> Result<()> {
        self.restore_from_checkpoint_n(db, -1)
    }

    /// Print the slot number and creation time of every checkpoint of this
    /// type to stdout, oldest first.
    #[cfg(feature = "adhoc")]
    fn list_checkpoints(db: &DbConnection) -> Result<()> {
        use crate::schema::checkpoints::dsl::*;

        let conn = db.get();
        let ckpt_key = Self::checkpoint_name();
        // Same ordering as `from_checkpoint_n`, so the printed slot numbers
        // match the slots that can be loaded.
        let ckpts: Vec<(i32, String, String)> = checkpoints
            .filter(key.eq(&ckpt_key))
            .select((id, key, ctime))
            .order((ctime.asc(), id.asc()))
            .load(&*conn)?;
        let nckpts = ckpts.len();
        info!("Found {} checkpoints with key {}", nckpts, &ckpt_key);

        println!("{:^5}\t{:^}", "slot", "create time");
        for (i, (_, _, t)) in ckpts.iter().enumerate() {
            println!("{:^5}\t{:^}", i, t);
        }
        Ok(())
    }

    /// Number of checkpoints stored for this type.
    fn get_number_of_checkpoints(db: &DbConnection) -> Result<u64> {
        use crate::schema::checkpoints::dsl::*;

        let conn = db.get();
        let ckpt_key = Self::checkpoint_name();
        let count: i64 = checkpoints.filter(key.eq(&ckpt_key)).count().get_result(&*conn)?;
        Ok(count as u64)
    }

    /// Overwrite `self` with the value stored in checkpoint slot `n`.
    ///
    /// `self` is left untouched when loading fails.
    fn restore_from_checkpoint_n(&mut self, db: &DbConnection, n: i32) -> Result<()> {
        let x = Self::from_checkpoint_n(db, n)?;
        self.clone_from(&x);
        Ok(())
    }
}
/// Blanket implementation: every type that is `Clone` and can be serialized
/// and deserialized with serde gets the default `Checkpoint` methods.
impl<T: Clone + serde::Serialize + serde::de::DeserializeOwned> Checkpoint for T {}
use gut::cli::*;
use std::path::{Path, PathBuf};
// Command-line options and connection handle for the checkpoint database.
//
// NOTE(review): plain `//` comments are used here on purpose. The struct
// derives `Parser`, and the derive may pick up `///` doc comments as help
// text, which would change the generated command-line interface.
#[derive(Parser, Default, Clone, Debug)]
pub struct CheckpointDb {
// Path of the checkpoint database file, exposed as a long command-line
// option. When it is not set, `create` opens no connection.
#[structopt(long)]
chk_file: Option<PathBuf>,
// Checkpoint slot to restore or load from, exposed as a long command-line
// option. Negative values count back from the newest checkpoint; when it is
// not set, the latest checkpoint (slot -1) is used.
#[structopt(long)]
chk_slot: Option<i32>,
// Open database connection. Skipped by the argument parser and filled in
// by `create`.
#[structopt(skip)]
db_connection: Option<DbConnection>,
}
impl CheckpointDb {
    /// Build a `CheckpointDb` backed by the database file at `d` and connect
    /// to it right away.
    ///
    /// # Panics
    ///
    /// Panics when the connection cannot be established (see [`Self::create`]).
    pub fn new<P: AsRef<Path>>(d: P) -> Self {
        Self {
            chk_file: Some(d.as_ref().to_path_buf()),
            ..Self::default()
        }
        .create()
    }

    /// Select the checkpoint slot used by later restore and load calls.
    pub fn slot(mut self, n: i32) -> Self {
        self.chk_slot = Some(n);
        self
    }

    /// Return a copy of `self` with the database connection opened.
    ///
    /// When no checkpoint file is configured the copy is returned unchanged.
    ///
    /// # Panics
    ///
    /// Panics when connecting to the configured checkpoint file fails.
    pub fn create(&self) -> Self {
        let mut connected = self.clone();
        if let Some(dbfile) = &self.chk_file {
            let url = dbfile.display().to_string();
            let dbc = DbConnection::connect(&url).expect("failed to connect to db src");
            connected.db_connection = Some(dbc);
        }
        connected
    }
}
impl CheckpointDb {
pub fn restore<T: Checkpoint>(&self, data: &mut T) -> Result<bool> {
if let Some(db) = &self.db_connection {
if let Some(n) = self.chk_slot {
if let Err(e) = data.restore_from_checkpoint_n(db, n) {
warn!("failed to restore from checkpoint");
dbg!(e);
}
} else {
if let Err(e) = data.restore_from_checkpoint(db) {
warn!("failed to restore from checkpoint");
dbg!(e);
return Ok(false);
}
}
Ok(true)
} else {
Ok(false)
}
}
#[deprecated(note = "Please use load_from_latest instead")]
pub fn restored<T: Checkpoint>(&self) -> Result<T> {
self.load_from_latest()
}
pub fn load_from_latest<T: Checkpoint>(&self) -> Result<T> {
let n = self.chk_slot.unwrap_or(-1);
self.load_from_slot_n(n)
}
pub fn load_from_slot_n<T: Checkpoint>(&self, slot: i32) -> Result<T> {
let db = self.db_connection.as_ref().expect("no db connection");
Ok(T::from_checkpoint_n(db, slot)?)
}
pub fn commit<T: Checkpoint>(&self, data: &T) -> Result<bool> {
if let Some(db) = &self.db_connection {
data.commit_checkpoint(db)?;
Ok(true)
} else {
Ok(false)
}
}
#[cfg(feature = "adhoc")]
pub fn list<T: Checkpoint>(&self) -> Result<bool> {
if let Some(db) = &self.db_connection {
T::list_checkpoints(db)?;
Ok(true)
} else {
Ok(false)
}
}
pub fn get_number_of_checkpoints<T: Checkpoint>(&self) -> Result<u64> {
if let Some(db) = &self.db_connection {
let n = T::get_number_of_checkpoints(db)?;
Ok(n)
} else {
Ok(0)
}
}
}
#[cfg(test)]
mod test {
    use super::*;

    #[derive(Clone, Debug, Serialize, Deserialize)]
    struct TestObject {
        data: f64,
    }

    /// Commit three checkpoints into a temporary database and check that they
    /// can be counted and restored by slot number.
    #[test]
    fn test_checkpoint() -> Result<()> {
        // setup a temporary database
        let tdir = tempfile::tempdir()?;
        let tmpdb = tdir.path().join("test.sqlite");
        let url = tmpdb.display().to_string();
        let db = DbConnection::connect(&url)?;

        // commit three checkpoints: -12.0, 1.0, 0.0
        let mut x = TestObject { data: -12.0 };
        x.commit_checkpoint(&db)?;
        x.data = 1.0;
        x.commit_checkpoint(&db)?;
        x.data = 0.0;
        x.commit_checkpoint(&db)?;
        assert_eq!(x.data, 0.0);

        // `get_number_of_checkpoints` is not feature-gated, so check it in
        // every build.
        assert_eq!(TestObject::get_number_of_checkpoints(&db)?, 3);

        // the latest checkpoint
        x.restore_from_checkpoint(&db)?;
        assert_eq!(x.data, 0.0);

        // slots counted from the oldest checkpoint
        x.restore_from_checkpoint_n(&db, 0)?;
        assert_eq!(x.data, -12.0);
        x.restore_from_checkpoint_n(&db, 1)?;
        assert_eq!(x.data, 1.0);

        // negative slots count back from the newest checkpoint
        x.restore_from_checkpoint_n(&db, -3)?;
        assert_eq!(x.data, -12.0);

        // slots outside of the stored range are rejected and leave `x` as is
        assert!(x.restore_from_checkpoint_n(&db, 3).is_err());
        assert!(x.restore_from_checkpoint_n(&db, -4).is_err());
        assert_eq!(x.data, -12.0);

        Ok(())
    }
}