use exonum::storage::{Database, Patch, Result as StorageResult, Snapshot};
use std::sync::{Arc, RwLock};
/// Database wrapper that can log changes made after a checkpoint and roll them back.
///
/// Reads and writes are forwarded to the wrapped database `T`. Checkpoints are set and rolled
/// back through a `CheckpointDbHandler` obtained from `CheckpointDb::handler`; the handler and
/// this database share the same state through an `Arc<RwLock<_>>`.
#[derive(Debug)]
pub struct CheckpointDb<T> {
    // Shared with every `CheckpointDbHandler` created from this database.
    inner: Arc<RwLock<CheckpointDbInner<T>>>,
}
impl<T: Database> CheckpointDb<T> {
    /// Wraps `db`, starting with no checkpoints set.
    ///
    /// Until a checkpoint is set through a handler, merges are passed straight through to
    /// `db` without being logged.
    pub fn new(db: T) -> Self {
        let state = CheckpointDbInner::new(db);
        Self {
            inner: Arc::new(RwLock::new(state)),
        }
    }

    /// Returns a handler that can set checkpoints on, and roll back, this database.
    ///
    /// The handler shares this database's state, so checkpoints and rollbacks made through it
    /// affect all subsequent operations on the database.
    pub fn handler(&self) -> CheckpointDbHandler<T> {
        let shared = self.inner.clone();
        CheckpointDbHandler { inner: shared }
    }
}
impl<T: Database> Database for CheckpointDb<T> {
fn snapshot(&self) -> Box<dyn Snapshot> {
self.inner
.read()
.expect("Cannot lock CheckpointDb for snapshot")
.snapshot()
}
fn merge(&self, patch: Patch) -> StorageResult<()> {
self.inner
.write()
.expect("Cannot lock CheckpointDb for merge")
.merge(patch)
}
fn merge_sync(&self, patch: Patch) -> StorageResult<()> {
self.merge(patch)
}
}
/// Converts a `CheckpointDb` into a shared, type-erased database handle.
impl<T: Database> From<CheckpointDb<T>> for Arc<dyn Database> {
    fn from(db: CheckpointDb<T>) -> Arc<dyn Database> {
        // `Arc::new` followed by unsized coercion places `db` directly in the reference-counted
        // allocation. Converting from a `Box` would allocate twice and copy the value between
        // the two allocations.
        Arc::new(db)
    }
}
/// Handle for setting checkpoints on, and rolling back, a `CheckpointDb`.
///
/// Obtained from `CheckpointDb::handler`; it shares state with the database it was created from.
#[derive(Debug)]
pub struct CheckpointDbHandler<T> {
    // The same shared state as the originating `CheckpointDb`.
    inner: Arc<RwLock<CheckpointDbInner<T>>>,
}
impl<T: Database> CheckpointDbHandler<T> {
    /// Sets a new checkpoint.
    ///
    /// Checkpoints nest: each call pushes a new level onto the checkpoint stack. Merges made
    /// into the associated `CheckpointDb` while this level is the innermost one are logged
    /// so that a matching `rollback` can undo them.
    ///
    /// # Panics
    ///
    /// Panics if the shared lock is poisoned.
    pub fn checkpoint(&self) {
        self.inner
            .write()
            .expect("Cannot lock CheckpointDb for checkpoint")
            .checkpoint();
    }

    /// Undoes every merge logged since the most recent checkpoint and removes that checkpoint.
    ///
    /// # Panics
    ///
    /// Panics if no checkpoint is set, if the shared lock is poisoned, or if applying a
    /// roll-back patch to the wrapped database fails.
    pub fn rollback(&self) {
        self.inner
            .write()
            .expect("Cannot lock CheckpointDb for rollback")
            .rollback();
    }
}
/// State shared by `CheckpointDb` and its `CheckpointDbHandler`s.
#[derive(Debug)]
struct CheckpointDbInner<T> {
    // The wrapped database that all reads and writes go to.
    db: T,
    // One entry per active checkpoint, innermost last. Each entry holds, in merge order, the
    // reverse patches for merges made while that checkpoint was the innermost one.
    backup_stack: Vec<Vec<Patch>>,
}
impl<T: Database> CheckpointDbInner<T> {
    /// Creates the state for a `CheckpointDb` with no checkpoints set.
    fn new(db: T) -> Self {
        CheckpointDbInner {
            db,
            backup_stack: Vec::new(),
        }
    }

    /// Returns a snapshot of the wrapped database.
    fn snapshot(&self) -> Box<dyn Snapshot> {
        self.db.snapshot()
    }

    /// Merges `patch` into the wrapped database.
    ///
    /// If at least one checkpoint is set, the merge is logged under the innermost checkpoint;
    /// otherwise it is forwarded to the wrapped database as-is.
    ///
    /// # Errors
    ///
    /// Returns any error produced by the wrapped database's `merge`.
    fn merge(&mut self, patch: Patch) -> StorageResult<()> {
        if self.backup_stack.is_empty() {
            self.db.merge(patch)
        } else {
            self.merge_with_logging(patch)
        }
    }

    /// Merges `patch` and records a reverse patch under the innermost checkpoint.
    ///
    /// For every key touched by `patch`, the reverse patch restores the value the key had
    /// before the merge, or removes the key if it was absent. The reverse patch is built from
    /// a pre-merge snapshot while only borrowing `patch`, so `patch` can be moved into the
    /// database instead of being deep-cloned. Nothing is recorded if the merge fails.
    ///
    /// # Errors
    ///
    /// Returns any error produced by the wrapped database's `merge`.
    ///
    /// # Panics
    ///
    /// Panics if no checkpoint is set. `merge` only calls this when one is.
    fn merge_with_logging(&mut self, patch: Patch) -> StorageResult<()> {
        let snapshot = self.db.snapshot();
        let mut rev_fork = self.db.fork();
        for (name, changes) in patch.iter() {
            for (key, _) in changes.iter() {
                match snapshot.get(name, key) {
                    Some(value) => {
                        rev_fork.put(name, key.clone(), value);
                    }
                    None => {
                        rev_fork.remove(name, key.clone());
                    }
                }
            }
        }
        self.db.merge(patch)?;
        self.backup_stack
            .last_mut()
            .expect("`merge_with_logging` called before checkpoint has been set")
            .push(rev_fork.into_patch());
        Ok(())
    }

    /// Pushes a new, empty changelog level onto the checkpoint stack.
    fn checkpoint(&mut self) {
        self.backup_stack.push(Vec::new())
    }

    /// Pops the innermost checkpoint and applies its reverse patches to the wrapped database.
    ///
    /// The reverse patches are merged directly, without logging, because their level has
    /// already been removed from the stack.
    ///
    /// # Panics
    ///
    /// Panics if no checkpoint is set, or if merging a roll-back patch fails.
    fn rollback(&mut self) {
        let changelog = self
            .backup_stack
            .pop()
            .expect("Checkpoint has not been set yet");
        // Apply newest-first. When several merges touched the same key, the value that had
        // been saved before the checkpoint is then the last one written.
        for patch in changelog.into_iter().rev() {
            self.db.merge(patch).expect("Cannot merge roll-back patch");
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for `CheckpointDb`, its handler, and the backup stack it maintains.

    use super::*;
    use exonum::storage::{Change, MemoryDB};

    /// Mirror of `Change` that derives `Ord`, so patch entries can be collected into a `BTreeSet`.
    #[derive(Debug, PartialOrd, Ord, PartialEq, Eq)]
    enum OrdChange {
        Put(Vec<u8>),
        Delete,
    }

    impl From<Change> for OrdChange {
        fn from(change: Change) -> Self {
            match change {
                Change::Put(value) => OrdChange::Put(value),
                Change::Delete => OrdChange::Delete,
            }
        }
    }

    impl<'a> From<&'a Change> for OrdChange {
        fn from(change: &'a Change) -> Self {
            match *change {
                Change::Put(ref value) => OrdChange::Put(value.clone()),
                Change::Delete => OrdChange::Delete,
            }
        }
    }

    /// Asserts that `patch` contains exactly the `(table name, key, change)` entries in
    /// `changes`, ignoring order.
    fn check_patch<'a, I>(patch: &Patch, changes: I)
    where
        I: IntoIterator<Item = (&'a str, Vec<u8>, Change)>,
    {
        use std::collections::BTreeSet;
        use std::iter::FromIterator;
        let mut patch_set: BTreeSet<(&str, _, _)> = BTreeSet::new();
        for (name, changes) in patch.iter() {
            for (key, value) in changes.iter() {
                patch_set.insert((name, key.clone(), OrdChange::from(value)));
            }
        }
        let expected_set = BTreeSet::from_iter(
            changes
                .into_iter()
                .map(|(name, key, value)| (name, key, OrdChange::from(value))),
        );
        assert_eq!(patch_set, expected_set);
    }

    /// Returns the number of checkpoints currently set on `db`.
    fn stack_len<T>(db: &CheckpointDb<T>) -> usize {
        let inner = db.inner.read().unwrap();
        inner.backup_stack.len()
    }

    /// Each checkpoint pushes one level onto the stack and each rollback pops one.
    #[test]
    fn test_backup_stack() {
        let db = CheckpointDb::new(MemoryDB::new());
        let handler = db.handler();
        assert_eq!(stack_len(&db), 0);
        handler.checkpoint();
        assert_eq!(stack_len(&db), 1);
        handler.rollback();
        assert_eq!(stack_len(&db), 0);
        handler.checkpoint();
        handler.checkpoint();
        assert_eq!(stack_len(&db), 2);
        handler.rollback();
        assert_eq!(stack_len(&db), 1);
    }

    /// Merges made after a checkpoint record the expected reverse patches.
    #[test]
    fn test_backup() {
        let db = CheckpointDb::new(MemoryDB::new());
        let handler = db.handler();
        handler.checkpoint();
        let mut fork = db.fork();
        fork.put("foo", vec![], vec![2]);
        db.merge(fork.into_patch()).unwrap();
        {
            let inner = db.inner.read().unwrap();
            let backup = &inner.backup_stack[0];
            assert_eq!(backup.len(), 1);
            check_patch(&backup[0], vec![("foo", vec![], Change::Delete)]);
        }
        let snapshot = db.snapshot();
        assert_eq!(snapshot.get("foo", &[]), Some(vec![2]));
        handler.checkpoint();
        let mut fork = db.fork();
        fork.put("foo", vec![], vec![3]);
        fork.put("bar", vec![1], vec![4]);
        fork.put("bar2", vec![5], vec![6]);
        db.merge(fork.into_patch()).unwrap();
        {
            let inner = db.inner.read().unwrap();
            let stack = &inner.backup_stack;
            assert_eq!(stack.len(), 2);
            let recent_backup = &stack[1];
            let older_backup = &stack[0];
            check_patch(&older_backup[0], vec![("foo", vec![], Change::Delete)]);
            check_patch(
                &recent_backup[0],
                vec![
                    ("bar2", vec![5], Change::Delete),
                    ("bar", vec![1], Change::Delete),
                    ("foo", vec![], Change::Put(vec![2])),
                ],
            );
        }
        // The snapshot taken before the second merge still sees the old value.
        assert_eq!(snapshot.get("foo", &[]), Some(vec![2]));
        let snapshot = db.snapshot();
        assert_eq!(snapshot.get("foo", &[]), Some(vec![3]));
    }

    /// Rollbacks restore the database state and stack depth at the matching checkpoint.
    #[test]
    fn test_rollback() {
        let db = CheckpointDb::new(MemoryDB::new());
        let handler = db.handler();
        let mut fork = db.fork();
        fork.put("foo", vec![], vec![2]);
        db.merge(fork.into_patch()).unwrap();
        handler.checkpoint();
        handler.checkpoint();
        let mut fork = db.fork();
        fork.put("foo", vec![], vec![3]);
        fork.put("bar", vec![1], vec![4]);
        db.merge(fork.into_patch()).unwrap();
        {
            let inner = db.inner.read().unwrap();
            let stack = &inner.backup_stack;
            assert_eq!(stack.len(), 2);
            assert_eq!(stack[1].len(), 1);
            assert_eq!(stack[0].len(), 0);
        }
        let snapshot = db.snapshot();
        assert_eq!(snapshot.get("foo", &[]), Some(vec![3]));
        handler.rollback();
        let snapshot = db.snapshot();
        assert_eq!(snapshot.get("foo", &[]), Some(vec![2]));
        assert_eq!(snapshot.get("bar", &[1]), None);
        {
            let inner = db.inner.read().unwrap();
            let stack = &inner.backup_stack;
            assert_eq!(stack.len(), 1);
            assert_eq!(stack[0].len(), 0);
        }
        handler.checkpoint();
        let mut fork = db.fork();
        fork.put("foo", vec![], vec![4]);
        fork.put("foo", vec![0, 0], vec![255]);
        db.merge(fork.into_patch()).unwrap();
        {
            let inner = db.inner.read().unwrap();
            let stack = &inner.backup_stack;
            assert_eq!(stack.len(), 2);
            assert_eq!(stack[1].len(), 1);
            assert_eq!(stack[0].len(), 0);
        }
        let snapshot = db.snapshot();
        assert_eq!(snapshot.get("foo", &[]), Some(vec![4]));
        assert_eq!(snapshot.get("foo", &[0, 0]), Some(vec![255]));
        let mut fork = db.fork();
        fork.put("bar", vec![1], vec![254]);
        db.merge(fork.into_patch()).unwrap();
        {
            let inner = db.inner.read().unwrap();
            let stack = &inner.backup_stack;
            assert_eq!(stack.len(), 2);
            assert_eq!(stack[1].len(), 2);
            assert_eq!(stack[0].len(), 0);
        }
        let new_snapshot = db.snapshot();
        assert_eq!(new_snapshot.get("foo", &[]), Some(vec![4]));
        assert_eq!(new_snapshot.get("foo", &[0, 0]), Some(vec![255]));
        assert_eq!(new_snapshot.get("bar", &[1]), Some(vec![254]));
        handler.rollback();
        {
            let inner = db.inner.read().unwrap();
            let stack = &inner.backup_stack;
            assert_eq!(stack.len(), 1);
            assert_eq!(stack[0].len(), 0);
        }
        let snapshot = db.snapshot();
        assert_eq!(snapshot.get("foo", &[]), Some(vec![2]));
        assert_eq!(snapshot.get("foo", &[0, 0]), None);
        assert_eq!(snapshot.get("bar", &[1]), None);
        // A snapshot taken before the rollback is unaffected by it.
        assert_eq!(new_snapshot.get("foo", &[]), Some(vec![4]));
        assert_eq!(new_snapshot.get("foo", &[0, 0]), Some(vec![255]));
        assert_eq!(new_snapshot.get("bar", &[1]), Some(vec![254]));
        handler.rollback();
        {
            let inner = db.inner.read().unwrap();
            let stack = &inner.backup_stack;
            assert_eq!(stack.len(), 0);
        }
    }

    /// A handler obtained from the database can roll back merges made through the database.
    #[test]
    fn test_handler() {
        let db = CheckpointDb::new(MemoryDB::new());
        let handler = db.handler();
        handler.checkpoint();
        let mut fork = db.fork();
        fork.put("foo", vec![], vec![2]);
        db.merge(fork.into_patch()).unwrap();
        let snapshot = db.snapshot();
        assert_eq!(snapshot.get("foo", &[]), Some(vec![2]));
        handler.rollback();
        let snapshot = db.snapshot();
        assert_eq!(snapshot.get("foo", &[]), None);
    }

    /// Rolling back more times than checkpoints were set panics with the dedicated message.
    /// The message is pinned so that an unrelated panic (e.g. lock poisoning) cannot make
    /// this test pass.
    #[test]
    #[should_panic(expected = "Checkpoint has not been set yet")]
    fn test_extra_rollback() {
        let db = CheckpointDb::new(MemoryDB::new());
        let handler = db.handler();
        handler.checkpoint();
        handler.checkpoint();
        handler.rollback();
        handler.rollback();
        handler.rollback();
    }
}