#![deny(
missing_docs,
non_camel_case_types,
non_snake_case,
path_statements,
trivial_casts,
trivial_numeric_casts,
unsafe_code,
unstable_features,
unused_allocation,
unused_import_braces,
unused_imports,
unused_must_use,
unused_mut,
unused_qualifications,
while_true,
)]
extern crate serde;
#[macro_use] extern crate quick_error;
extern crate fs2;
extern crate bincode;
#[cfg(test)] extern crate tempfile;
mod error;
use std::collections::HashMap;
use std::fs::File;
use std::path::Path;
use std::sync::{RwLock, RwLockWriteGuard, Mutex};
use std::hash::Hash;
use std::borrow::Borrow;
use serde::{Serialize, Deserialize};
pub use error::BreakError;
/// Result type returned by every fallible operation in this crate.
pub type BreakResult<T> = Result<T, BreakError>;
/// A simple persistent key-value store.
///
/// Values of any `Serialize` type are encoded with bincode and kept in an
/// in-memory map keyed by `T`. Nothing is written to the backing file until
/// `Database::flush` is called.
///
/// `Database::open` takes an exclusive, non-blocking lock on the backing file,
/// so opening a path that is already locked fails instead of waiting.
#[derive(Debug)]
pub struct Database<T: Serialize + Deserialize + Eq + Hash> {
    // Backing file, behind a mutex so `flush` can work through `&self`.
    file: Mutex<File>,
    // Keys mapped to the bincode-encoded bytes of their values.
    data: RwLock<HashMap<T, Vec<u8>>>,
}
impl<T: Serialize + Deserialize + Eq + Hash> Database<T> {
    /// Opens the database file at `path`, creating it if it does not exist.
    ///
    /// The file is locked exclusively and its contents, if any, are decoded
    /// into the in-memory map. An empty file yields an empty database.
    ///
    /// # Errors
    ///
    /// Fails if the file cannot be opened or read, if it is already locked,
    /// or if its contents cannot be decoded.
    pub fn open<P: AsRef<Path>>(path: P) -> BreakResult<Database<T>> {
        use std::fs::OpenOptions;
        use fs2::FileExt;
        use std::io::Read;
        use bincode::serde::deserialize;

        let mut file = try!(OpenOptions::new().read(true).write(true).create(true).open(path));
        // Non-blocking: errors out immediately if the file is already locked.
        try!(file.try_lock_exclusive());

        let mut buf = Vec::new();
        try!(file.read_to_end(&mut buf));
        let map: HashMap<T, Vec<u8>> = if buf.is_empty() {
            HashMap::new()
        } else {
            try!(deserialize(&buf))
        };

        Ok(Database {
            file: Mutex::new(file),
            data: RwLock::new(map),
        })
    }

    /// Serializes `obj` and stores it under `key`, replacing any previous value.
    ///
    /// The change is held in memory only until `flush` is called.
    ///
    /// # Errors
    ///
    /// Fails if `obj` cannot be serialized; the database is left unchanged.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock was poisoned by a panicking thread.
    pub fn insert<S: Serialize, K: ?Sized>(&self, key: &K, obj: S) -> BreakResult<()>
        where T: Borrow<K>, K: Hash + PartialEq + ToOwned<Owned=T>
    {
        use bincode::serde::serialize;
        use bincode::SizeLimit;

        // Encode before locking so the write lock is held only for the insert.
        let bytes = try!(serialize(&obj, SizeLimit::Infinite));
        let mut map = self.data.write().expect("database lock poisoned");
        map.insert(key.to_owned(), bytes);
        Ok(())
    }

    /// Decodes and returns the value stored under `key`.
    ///
    /// # Errors
    ///
    /// Returns `BreakError::NotFound` if the key is absent, or a decoding
    /// error if the stored bytes cannot be deserialized as `S`.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock was poisoned by a panicking thread.
    pub fn retrieve<S: Deserialize, K: ?Sized>(&self, key: &K) -> BreakResult<S>
        where T: Borrow<K>, K: Hash + Eq
    {
        use bincode::serde::deserialize;

        let map = self.data.read().expect("database lock poisoned");
        match map.get(key) {
            Some(bytes) => Ok(try!(deserialize(bytes))),
            None => Err(BreakError::NotFound),
        }
    }

    /// Returns whether a value is stored under `key`.
    ///
    /// The `S` type parameter is unused. It is kept only so existing callers
    /// that name it keep compiling. Because it cannot be inferred, callers must
    /// supply some `Deserialize` type for it via turbofish.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock was poisoned by a panicking thread.
    pub fn contains_key<S: Deserialize, K: ?Sized>(&self, key: &K) -> bool
        where T: Borrow<K>, K: Hash + Eq
    {
        let map = self.data.read().expect("database lock poisoned");
        map.contains_key(key)
    }

    /// Writes the whole in-memory map to the backing file, replacing its
    /// previous contents.
    ///
    /// # Errors
    ///
    /// Fails if the map cannot be serialized or if any file operation fails.
    ///
    /// # Panics
    ///
    /// Panics if an internal lock was poisoned by a panicking thread.
    pub fn flush(&self) -> BreakResult<()> {
        use bincode::serde::serialize;
        use bincode::SizeLimit;
        use std::io::{Seek, SeekFrom, Write};

        // Keep the read guard for the whole write. Concurrent flushes then see
        // the same snapshot, and an older snapshot cannot overwrite a newer one.
        let map = self.data.read().expect("database lock poisoned");
        let mut file = self.file.lock().expect("database file lock poisoned");
        let buf = try!(serialize(&*map, SizeLimit::Infinite));

        // `open` leaves the cursor at EOF after reading, and earlier flushes
        // advance it too. Rewind so the file is overwritten instead of appended to.
        try!(file.seek(SeekFrom::Start(0)));
        // `write` may stop short; `write_all` keeps going until every byte is out.
        try!(file.write_all(&buf));
        // Cut off stale trailing bytes when the new encoding is shorter.
        try!(file.set_len(buf.len() as u64));
        try!(file.flush());
        try!(file.sync_all());
        Ok(())
    }

    /// Starts a transaction whose inserts are buffered until `Transaction::run`.
    ///
    /// Dropping the transaction without calling `run` discards its inserts.
    pub fn transaction<'a>(&'a self) -> Transaction<'a, T> {
        Transaction {
            lock: &self.data,
            data: RwLock::new(HashMap::new()),
        }
    }

    /// Takes the database's write lock and returns it as a `Lock`.
    ///
    /// The lock is held until the `Lock` is dropped, which allows
    /// read-modify-write sequences that no other thread can interleave with.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock was poisoned by a panicking thread.
    pub fn lock<'a>(&'a self) -> Lock<'a, T> {
        Lock {
            cow: self.data.write().expect("database lock poisoned"),
        }
    }
}
/// Exclusive access to a `Database`, obtained from `Database::lock`.
///
/// Holds the database's write lock until dropped, so no other thread can
/// read or modify the database's map in the meantime.
pub struct Lock<'a, T: Serialize + Deserialize + Eq + Hash + 'a> {
    // Write guard over the database's shared map. Despite the name, this is
    // not a copy-on-write buffer: changes go straight into the shared map.
    cow: RwLockWriteGuard<'a, HashMap<T, Vec<u8>>>,
}
impl<'a, T: Serialize + Deserialize + Eq + Hash + 'a> Lock<'a, T> {
    /// Serializes `obj` and stores it under `key`, replacing any previous value.
    ///
    /// # Errors
    ///
    /// Fails if `obj` cannot be serialized; the map is left unchanged.
    pub fn insert<S: Serialize, K: ?Sized>(&mut self, key: &K, obj: S) -> BreakResult<()>
        where T: Borrow<K>, K: Hash + PartialEq + ToOwned<Owned=T>
    {
        use bincode::serde::serialize;
        use bincode::SizeLimit;

        let bytes = try!(serialize(&obj, SizeLimit::Infinite));
        self.cow.insert(key.to_owned(), bytes);
        Ok(())
    }

    /// Decodes and returns the value stored under `key`.
    ///
    /// Reading does not need exclusive access to the `Lock` itself, so this
    /// takes `&self`. Existing `&mut` callers still work through coercion.
    ///
    /// # Errors
    ///
    /// Returns `BreakError::NotFound` if the key is absent, or a decoding
    /// error if the stored bytes cannot be deserialized as `S`.
    pub fn retrieve<S: Deserialize, K: ?Sized>(&self, key: &K) -> BreakResult<S>
        where T: Borrow<K>, K: Hash + Eq
    {
        use bincode::serde::deserialize;

        match self.cow.get(key) {
            Some(bytes) => Ok(try!(deserialize(bytes))),
            None => Err(BreakError::NotFound),
        }
    }
}
/// A batch of inserts against a `Database`, created by `Database::transaction`.
///
/// Inserts are buffered inside the transaction and only reach the database's
/// map when `run` is called. Dropping the transaction discards them.
pub struct Transaction<'a, T: Serialize + Deserialize + Eq + Hash + 'a> {
    // The database's shared map; `run` merges the pending inserts into it.
    lock: &'a RwLock<HashMap<T, Vec<u8>>>,
    // Pending inserts, encoded the same way as the database's values.
    data: RwLock<HashMap<T, Vec<u8>>>,
}
impl<'a, T: Serialize + Deserialize + Eq + Hash + 'a> Transaction<'a, T> {
    /// Serializes `obj` and buffers it under `key` in this transaction.
    ///
    /// The database itself is not touched until `run`.
    ///
    /// # Errors
    ///
    /// Fails if `obj` cannot be serialized; nothing is buffered in that case.
    ///
    /// # Panics
    ///
    /// Panics if the transaction's internal lock was poisoned.
    pub fn insert<S: Serialize, K: ?Sized>(&mut self, key: &K, obj: S) -> BreakResult<()>
        where T: Borrow<K>, K: Hash + PartialEq + ToOwned<Owned=T>
    {
        use bincode::serde::serialize;
        use bincode::SizeLimit;

        let bytes = try!(serialize(&obj, SizeLimit::Infinite));
        // `&mut self` already guarantees exclusive access, so no locking is needed.
        let pending = self.data.get_mut().expect("transaction lock poisoned");
        pending.insert(key.to_owned(), bytes);
        Ok(())
    }

    /// Decodes and returns the value for `key` as seen by this transaction.
    ///
    /// A value buffered by this transaction takes precedence over the one
    /// already in the database, so the transaction reads its own writes.
    ///
    /// # Errors
    ///
    /// Returns `BreakError::NotFound` if neither the transaction nor the
    /// database holds the key, or a decoding error if the bytes cannot be
    /// deserialized as `S`.
    ///
    /// # Panics
    ///
    /// Panics if either internal lock was poisoned.
    pub fn retrieve<S: Deserialize, K: ?Sized>(&self, key: &K) -> BreakResult<S>
        where T: Borrow<K>, K: Hash + Eq
    {
        use bincode::serde::deserialize;

        let pending = self.data.read().expect("transaction lock poisoned");
        if let Some(bytes) = pending.get(key) {
            return Ok(try!(deserialize(bytes)));
        }

        let committed = self.lock.read().expect("database lock poisoned");
        match committed.get(key) {
            Some(bytes) => Ok(try!(deserialize(bytes))),
            None => Err(BreakError::NotFound),
        }
    }

    /// Merges all buffered inserts into the database and consumes the
    /// transaction. Existing values for the same keys are overwritten.
    ///
    /// The database's write lock is held for the whole merge, so readers
    /// see either none or all of the inserts. The database still has to be
    /// flushed for the changes to reach the backing file.
    ///
    /// # Panics
    ///
    /// Panics if either internal lock was poisoned.
    pub fn run(self) -> BreakResult<()> {
        // The transaction is consumed, so its buffer can be taken without locking.
        let pending = self.data.into_inner().expect("transaction lock poisoned");
        let mut committed = self.lock.write().expect("database lock poisoned");
        committed.extend(pending);
        Ok(())
    }
}
#[cfg(test)]
mod test {
    use super::Database;
    use tempfile::NamedTempFile;

    /// A value stored with `insert` comes back unchanged from `retrieve`.
    #[test]
    fn simple_insert_and_retrieve() {
        let backing = NamedTempFile::new().unwrap();
        let store: Database<String> = Database::open(backing.path()).unwrap();

        store.insert("test", "Hello World!").unwrap();

        let greeting: String = store.retrieve("test").unwrap();
        assert_eq!(greeting, "Hello World!");
    }

    /// Flushed data survives closing and reopening the database file.
    #[test]
    fn test_persistence() {
        let backing = NamedTempFile::new().unwrap();
        {
            let store: Database<String> = Database::open(backing.path()).unwrap();
            store.insert("test", "Hello World!").unwrap();
            store.flush().unwrap();
            // Leaving this scope closes the first handle and releases its lock.
        }

        let reopened: Database<String> = Database::open(backing.path()).unwrap();
        let greeting: String = reopened.retrieve("test").unwrap();
        assert_eq!(greeting, "Hello World!");
    }

    /// A committed transaction is visible; a dropped one is discarded.
    #[test]
    fn simple_transaction() {
        let backing = NamedTempFile::new().unwrap();
        let store: Database<String> = Database::open(backing.path()).unwrap();
        assert!(store.retrieve::<String, str>("test").is_err());

        {
            let mut committed = store.transaction();
            committed.insert("test", "Hello World!").unwrap();
            committed.run().unwrap();
        }

        {
            let mut abandoned = store.transaction();
            abandoned.insert("test", "Hello World too!!").unwrap();
            drop(abandoned);
        }

        let greeting: String = store.retrieve("test").unwrap();
        assert_eq!(greeting, "Hello World!");
    }

    /// Concurrent read-modify-write sequences under `lock` lose no updates.
    #[test]
    fn multithreaded_locking() {
        use std::sync::Arc;
        use std::thread;

        let backing = NamedTempFile::new().unwrap();
        let store: Arc<Database<String>> = Arc::new(Database::open(backing.path()).unwrap());
        store.insert("value", 0i64).unwrap();

        let workers: Vec<_> = (0..10)
            .map(|_| {
                let shared = Arc::clone(&store);
                thread::spawn(move || {
                    let mut guard = shared.lock();
                    let current = guard.retrieve::<i64, str>("value").unwrap();
                    guard.insert("value", current + 1).unwrap();
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }

        let total = store.retrieve::<i64, str>("value").unwrap();
        assert_eq!(total, 10);
    }
}