//! A small file-backed key/value store.
//!
//! A `Database` keeps its entries in an in-memory `HashMap` whose values are
//! stored in the serialized representation of the enabled encoding backend
//! (selected through the `bin` or `yaml` cargo feature). The map is read from
//! the backing file when the database is opened and is written back only when
//! `Database::flush` is called.
//!
//! Writes can be grouped with `Database::transaction`, or performed under an
//! exclusive `Lock` obtained from `Database::lock`.

// Lints promoted to hard errors for the whole crate.
#![deny(
missing_docs,
non_camel_case_types,
non_snake_case,
path_statements,
trivial_casts,
trivial_numeric_casts,
unsafe_code,
unstable_features,
unused_allocation,
unused_import_braces,
unused_imports,
unused_must_use,
unused_mut,
while_true,
)]
extern crate serde;
#[macro_use] extern crate quick_error;
extern crate fs2;
#[cfg(feature = "bin")] extern crate bincode;
#[cfg(feature = "yaml")] extern crate serde_yaml;
#[cfg(test)] extern crate tempfile;
mod error;
#[cfg(feature = "bin")] mod bincode_enc;
#[cfg(feature = "yaml")] mod yaml_enc;
// Facade over the encoding backend chosen at compile time: the rest of the
// crate refers to `enc::Repr`, `enc::serialize` and `enc::deserialize`
// without knowing which backend provides them.
// NOTE(review): both globs are re-exported when `bin` and `yaml` are enabled
// together; presumably the features are meant to be mutually exclusive — confirm.
mod enc {
#[cfg(feature = "bin")] pub use bincode_enc::*;
#[cfg(feature = "yaml")] pub use yaml_enc::*;
}
use std::collections::HashMap;
use std::fs::File;
use std::path::Path;
use std::sync::{RwLock, RwLockWriteGuard, Mutex};
use std::hash::Hash;
use std::borrow::Borrow;
use serde::Serialize;
use serde::de::DeserializeOwned;
pub use error::BreakError;
/// Result alias used by every fallible operation of this crate; the error
/// type is always `BreakError`.
pub type Result<T> = ::std::result::Result<T, BreakError>;
/// A key/value store backed by a single file.
///
/// `T` is the key type. Values are kept in memory in their serialized form
/// (`enc::Repr`) and are only decoded when they are retrieved.
#[derive(Debug)]
pub struct Database<T: Serialize + DeserializeOwned + Eq + Hash> {
/// Handle to the backing file; the mutex serializes writers in `flush`.
file: Mutex<File>,
/// In-memory content of the database, keyed by `T`.
data: RwLock<HashMap<T, enc::Repr>>,
}
impl<T: Serialize + DeserializeOwned + Eq + Hash> Database<T> {
pub fn open<P: AsRef<Path>>(path: P) -> Result<Database<T>> {
use std::fs::OpenOptions;
use fs2::FileExt;
use std::io::Read;
use enc::deserialize;
let mut file = try!(OpenOptions::new().read(true).write(true).create(true).open(path));
try!(file.try_lock_exclusive());
let mut buf = Vec::new();
try!(file.read_to_end(&mut buf));
let map : HashMap<T, enc::Repr> = if !buf.is_empty() {
try!(deserialize(&buf))
} else {
HashMap::new()
};
Ok(Database {
file: Mutex::new(file),
data: RwLock::new(map),
})
}
pub fn insert<S: Serialize + 'static, K: ?Sized>(&self, key: &K, obj: S) -> Result<()>
where T: Borrow<K>, K: Hash + PartialEq + ToOwned<Owned=T>
{
use enc::serialize;
let mut map = try!(self.data.write());
map.insert(key.to_owned(), try!(serialize(&obj)));
Ok(())
}
pub fn delete<K: ?Sized>(&self, key: &K) -> Result<()>
where T: Borrow<K>, K: Hash + Eq
{
let mut map = try!(self.data.write());
map.remove(key.to_owned());
Ok(())
}
pub fn retrieve<S: DeserializeOwned, K: ?Sized>(&self, key: &K) -> Result<S>
where T: Borrow<K>, K: Hash + Eq
{
use enc::deserialize;
let map = try!(self.data.read());
match map.get(key.borrow()) {
Some(t) => Ok(try!(deserialize(t))),
None => Err(BreakError::NotFound),
}
}
pub fn contains_key<S: DeserializeOwned, K: ?Sized>(&self, key: &K) -> Result<bool>
where T: Borrow<K>, K: Hash + Eq
{
let map = try!(self.data.read());
Ok(map.get(key.borrow()).is_some())
}
pub fn flush(&self) -> Result<()> {
use enc::serialize;
use std::io::{Write, Seek, SeekFrom};
let map = try!(self.data.read());
let mut file = try!(self.file.lock());
let buf = try!(serialize(&*map));
try!(file.set_len(0));
try!(file.seek(SeekFrom::Start(0)));
try!(file.write(&buf.as_ref()));
try!(file.sync_all());
Ok(())
}
pub fn transaction(&self) -> Transaction<T> {
Transaction {
lock: &self.data,
data: RwLock::new(HashMap::new()),
}
}
pub fn lock(&self) -> Result<Lock<T>> {
let map = try!(self.data.write());
Ok(Lock {
lock: map,
})
}
}
/// Exclusive access to the content of a `Database`.
///
/// Created by `Database::lock`. It owns the write guard of the database's
/// in-memory map, so the map stays write-locked until this value is dropped.
pub struct Lock<'a, T: Serialize + DeserializeOwned + Eq + Hash + 'a> {
/// Write guard over the database's map.
lock: RwLockWriteGuard<'a, HashMap<T, enc::Repr>>,
}
impl<'a, T: Serialize + DeserializeOwned + Eq + Hash + 'a> Lock<'a, T> {
    /// Stores the serialized form of `obj` under `key` in the locked map,
    /// overwriting whatever was there before.
    ///
    /// # Errors
    ///
    /// Fails if `obj` cannot be serialized.
    pub fn insert<S: Serialize + 'static, K: ?Sized>(&mut self, key: &K, obj: S) -> Result<()>
        where T: Borrow<K>, K: Hash + PartialEq + ToOwned<Owned=T>
    {
        use enc::serialize;

        let owned_key = key.to_owned();
        let encoded = serialize(&obj)?;
        self.lock.insert(owned_key, encoded);
        Ok(())
    }

    /// Fetches the value stored under `key` and decodes it as `S`.
    ///
    /// # Errors
    ///
    /// Returns `BreakError::NotFound` when the key is absent, or a decoding
    /// error when the stored bytes cannot be read as `S`.
    pub fn retrieve<S: DeserializeOwned, K: ?Sized>(&mut self, key: &K) -> Result<S>
        where T: Borrow<K>, K: Hash + Eq
    {
        use enc::deserialize;

        if let Some(raw) = self.lock.get(key) {
            return Ok(deserialize(raw)?);
        }
        Err(BreakError::NotFound)
    }

    /// Begins a transaction on top of this lock. Its writes are staged
    /// separately and reach the locked map only through `TransactionLock::run`.
    pub fn transaction<'b>(&'b mut self) -> TransactionLock<'a, 'b, T> {
        let staged = RwLock::new(HashMap::new());
        TransactionLock {
            data: staged,
            lock: self,
        }
    }
}
/// A transaction running on top of a held `Lock`.
///
/// Created by `Lock::transaction`. Inserted values are collected in `data`
/// and are copied into the locked map only when `run` is called; dropping the
/// value without calling `run` leaves the locked map untouched.
pub struct TransactionLock<'a: 'b, 'b, T: Serialize + DeserializeOwned + Eq + Hash + 'a> {
/// The lock whose map receives the staged entries on `run`.
lock: &'b mut Lock<'a, T>,
/// Entries staged by this transaction, not yet applied.
data: RwLock<HashMap<T, enc::Repr>>,
}
impl<'a: 'b, 'b, T: Serialize + DeserializeOwned + Eq + Hash + 'a> TransactionLock<'a, 'b, T> {
    /// Stages the serialized form of `obj` under `key`.
    ///
    /// The locked map is not modified until `run` is called.
    ///
    /// # Errors
    ///
    /// Fails if the staging lock is poisoned or if `obj` cannot be serialized.
    pub fn insert<S: Serialize + 'static, K: ?Sized>(&mut self, key: &K, obj: S) -> Result<()>
        where T: Borrow<K>, K: Hash + PartialEq + ToOwned<Owned=T>
    {
        use enc::serialize;

        let mut staged = self.data.write()?;
        staged.insert(key.to_owned(), serialize(&obj)?);
        Ok(())
    }

    /// Reads the value for `key` as seen from inside this transaction.
    ///
    /// Entries staged by this transaction take precedence over the locked
    /// map, so a value inserted here is returned even when the key already
    /// exists in the database.
    ///
    /// # Errors
    ///
    /// Returns `BreakError::NotFound` if the key is neither staged nor in the
    /// locked map. Also fails if the staging lock is poisoned or the stored
    /// bytes cannot be decoded as `S`.
    pub fn retrieve<S: DeserializeOwned, K: ?Sized>(&mut self, key: &K) -> Result<S>
        where T: Borrow<K>, K: Hash + Eq
    {
        use enc::deserialize;

        // Staged writes first: they are the newest values this transaction knows.
        {
            let staged = self.data.read()?;
            if let Some(raw) = staged.get(key) {
                return Ok(deserialize(raw)?);
            }
        }

        match self.lock.lock.get(key) {
            Some(raw) => Ok(deserialize(raw)?),
            None => Err(BreakError::NotFound),
        }
    }

    /// Applies every staged entry to the locked map, overwriting entries that
    /// share a key, and consumes the transaction.
    ///
    /// # Errors
    ///
    /// Fails if the staging lock is poisoned.
    pub fn run(self) -> Result<()> {
        let target = &mut self.lock.lock;
        let mut staged = self.data.write()?;
        target.extend(staged.drain());
        Ok(())
    }
}
/// A transaction on a `Database`.
///
/// Created by `Database::transaction`. Inserted values are collected in
/// `data` and are copied into the database's map only when `run` is called;
/// dropping the value without calling `run` leaves the database untouched.
pub struct Transaction<'a, T: Serialize + DeserializeOwned + Eq + Hash + 'a> {
/// The database's map; it is locked only for the duration of a call.
lock: &'a RwLock<HashMap<T, enc::Repr>>,
/// Entries staged by this transaction, not yet applied.
data: RwLock<HashMap<T, enc::Repr>>,
}
impl<'a, T: Serialize + DeserializeOwned + Eq + Hash + 'a> Transaction<'a, T> {
    /// Stages the serialized form of `obj` under `key`.
    ///
    /// The database is not modified until `run` is called.
    ///
    /// # Errors
    ///
    /// Fails if the staging lock is poisoned or if `obj` cannot be serialized.
    pub fn insert<S: Serialize + 'static, K: ?Sized>(&mut self, key: &K, obj: S) -> Result<()>
        where T: Borrow<K>, K: Hash + PartialEq + ToOwned<Owned=T>
    {
        use enc::serialize;

        let mut staged = self.data.write()?;
        staged.insert(key.to_owned(), serialize(&obj)?);
        Ok(())
    }

    /// Reads the value for `key` as seen from inside this transaction.
    ///
    /// Entries staged by this transaction take precedence over the database,
    /// so a value inserted here is returned even when the key already exists
    /// in the database. The database's map is read-locked only if the key is
    /// not staged.
    ///
    /// # Errors
    ///
    /// Returns `BreakError::NotFound` if the key is neither staged nor in the
    /// database. Also fails if a lock is poisoned or the stored bytes cannot
    /// be decoded as `S`.
    pub fn retrieve<S: DeserializeOwned, K: ?Sized>(&self, key: &K) -> Result<S>
        where T: Borrow<K>, K: Hash + Eq
    {
        use enc::deserialize;

        // Staged writes first: they are the newest values this transaction knows.
        {
            let staged = self.data.read()?;
            if let Some(raw) = staged.get(key) {
                return Ok(deserialize(raw)?);
            }
        }

        let committed = self.lock.read()?;
        match committed.get(key) {
            Some(raw) => Ok(deserialize(raw)?),
            None => Err(BreakError::NotFound),
        }
    }

    /// Applies every staged entry to the database under a single write lock,
    /// overwriting entries that share a key, and consumes the transaction.
    ///
    /// # Errors
    ///
    /// Fails if either lock is poisoned.
    pub fn run(self) -> Result<()> {
        let mut committed = self.lock.write()?;
        let mut staged = self.data.write()?;
        committed.extend(staged.drain());
        Ok(())
    }
}
#[cfg(test)]
mod test {
    use super::{Database,BreakError};
    use tempfile::NamedTempFile;

    // A deleted key can no longer be retrieved.
    #[test]
    fn insert_and_delete() {
        let backing = NamedTempFile::new().unwrap();
        let store = Database::open(backing.path()).unwrap();

        store.insert("test", "Hello World!").unwrap();
        store.delete("test").unwrap();

        let outcome : Result<String,BreakError> = store.retrieve("test");
        assert!(outcome.is_err())
    }

    // A stored value comes back unchanged.
    #[test]
    fn simple_insert_and_retrieve() {
        let backing = NamedTempFile::new().unwrap();
        let store = Database::open(backing.path()).unwrap();

        store.insert("test", "Hello World!").unwrap();

        let greeting : String = store.retrieve("test").unwrap();
        assert_eq!(greeting, "Hello World!");
    }

    // Flushed data is visible after the database is closed and reopened.
    #[test]
    fn test_persistence() {
        let backing = NamedTempFile::new().unwrap();
        {
            let store = Database::open(backing.path()).unwrap();
            store.insert("test", "Hello World!").unwrap();
            store.flush().unwrap();
            // Dropped explicitly so the file handle is released before reopening.
            drop(store);
        }

        let reopened : Database<String> = Database::open(backing.path()).unwrap();
        let greeting : String = reopened.retrieve("test").unwrap();
        assert_eq!(greeting, "Hello World!");
    }

    // Only a transaction that was `run` changes the database.
    #[test]
    fn simple_transaction() {
        let backing = NamedTempFile::new().unwrap();
        let store = Database::open(backing.path()).unwrap();
        assert!(store.retrieve::<String, str>("test").is_err());

        // Applied transaction.
        {
            let mut applied = store.transaction();
            applied.insert("test", "Hello World!").unwrap();
            applied.run().unwrap();
        }

        // Abandoned transaction: its write must not become visible.
        {
            let mut abandoned = store.transaction();
            abandoned.insert("test", "Hello World too!!").unwrap();
            drop(abandoned);
        }

        let greeting : String = store.retrieve("test").unwrap();
        assert_eq!(greeting, "Hello World!");
    }

    // Ten threads each commit one increment and abandon one decrement.
    #[test]
    fn multithreaded_locking() {
        use std::sync::Arc;
        use std::thread;

        let backing = NamedTempFile::new().unwrap();
        let shared = Arc::new(Database::open(backing.path()).unwrap());
        shared.insert("value", 0i64).unwrap();

        let workers: Vec<_> = (0..10)
            .map(|_| {
                let handle = shared.clone();
                thread::spawn(move || {
                    let mut guard = handle.lock().unwrap();
                    {
                        let mut increment = guard.transaction();
                        let current = increment.retrieve::<i64, str>("value").unwrap();
                        increment.insert("value", current + 1).unwrap();
                        increment.run().unwrap();
                    }
                    {
                        let mut decrement = guard.transaction();
                        let current = decrement.retrieve::<i64, str>("value").unwrap();
                        decrement.insert("value", current - 1).unwrap();
                        drop(decrement);
                    }
                })
            })
            .collect();

        for worker in workers {
            worker.join().unwrap();
        }

        let total = shared.retrieve::<i64, str>("value").unwrap();
        assert_eq!(total, 10);
    }
}