use super::{DB, Update};
#[cfg(feature = "proptest")]
use crate::db::DummyDBStrategy;
use crate::{
DefaultHasher, WellBehavedHasher,
arena::{ArenaHash, ArenaKey},
backend::OnDiskObject,
db::DummyArbitrary,
};
#[allow(deprecated)]
use crypto::digest::generic_array::GenericArray;
#[cfg(feature = "proptest")]
use proptest::prelude::*;
use r2d2::Pool;
use r2d2_sqlite::SqliteConnectionManager;
use rusqlite::{
Connection, OptionalExtension, Result, ToSql, Transaction,
TransactionBehavior::{self, Deferred, Immediate},
config::DbConfig::SQLITE_DBCONFIG_ENABLE_FKEY,
params,
types::FromSql,
};
use serialize::{Deserializable, Serializable};
#[cfg(not(feature = "layout-v2"))]
use std::collections::HashSet;
use std::{
collections::HashMap,
fs::{File, OpenOptions},
marker::PhantomData,
path::Path,
};
#[derive(Debug)]
pub struct SqlDB<H: WellBehavedHasher = DefaultHasher> {
pool: Pool<SqliteConnectionManager>,
_phantom: std::marker::PhantomData<H>,
lock_file: Option<File>,
}
impl<H: WellBehavedHasher> Default for SqlDB<H> {
    /// Creates an exclusively locked database at a freshly created temp path.
    ///
    /// NOTE(review): the `TempPath` guard is moved into `exclusive_file`, so
    /// it does not outlive this call — confirm the resulting on-disk file
    /// lifetime is what is intended.
    ///
    /// # Panics
    ///
    /// Panics if the temporary file cannot be created.
    fn default() -> Self {
        let temp_file = tempfile::NamedTempFile::new().unwrap();
        Self::exclusive_file(temp_file.into_temp_path())
    }
}
impl<H: WellBehavedHasher> SqlDB<H> {
/// Creates a database backed by an in-memory SQLite instance, without any
/// lock file.
pub fn memory() -> Self {
    let manager = SqliteConnectionManager::memory();
    Self::new(manager, None)
}
/// Opens the database file at `path`, taking an exclusive lock on its
/// `.mutex` file.
///
/// # Panics
///
/// Panics under the same conditions as the private `file` constructor, in
/// particular when the exclusive lock cannot be acquired.
pub fn exclusive_file<P: AsRef<Path>>(path: P) -> Self {
    let exclusive = true;
    Self::file(path, exclusive)
}
/// Test-only: opens the database file at `path`, taking a shared lock on
/// its `.mutex` file so several instances can coexist.
#[cfg(test)]
pub(crate) fn non_exclusive_file<P: AsRef<Path>>(path: P) -> Self {
    let exclusive = false;
    Self::file(path, exclusive)
}
/// Opens the SQLite database at `path`, guarded by a lock on a sibling
/// `.mutex` file (the canonicalized path with its extension replaced by
/// `mutex`).
///
/// When `exclusive` is set an exclusive lock is requested, otherwise a
/// shared one. The lock file handle is stored in the returned instance.
///
/// # Panics
///
/// Panics if `path` cannot be canonicalized, if the `.mutex` file cannot be
/// opened, or if the requested lock cannot be acquired.
fn file<P: AsRef<Path>>(path: P, exclusive: bool) -> Self {
    let normalized_path = match path.as_ref().canonicalize() {
        Ok(resolved) => resolved,
        Err(e) => panic!("can't canonicalize path {:?}: {e}", path.as_ref()),
    };
    let mutex_file_path = normalized_path.with_extension("mutex");
    let opened = OpenOptions::new()
        .write(true)
        .create(true)
        .truncate(true)
        .open(&mutex_file_path);
    let lock_file = match opened {
        Ok(handle) => handle,
        Err(e) => panic!("can't open .mutex file {:?}: {e}", &mutex_file_path),
    };
    // Non-blocking lock attempts: a conflicting lock is reported as an error
    // rather than waited on.
    if !exclusive {
        fs2::FileExt::try_lock_shared(&lock_file)
            .expect("can't get shared lock with exclusive lock active");
    } else {
        fs2::FileExt::try_lock_exclusive(&lock_file)
            .expect("can't get exclusive lock with existing locks active");
    }
    let manager = SqliteConnectionManager::file(path);
    Self::new(manager, Some(lock_file))
}
/// Builds the connection pool from `cm`, installs the per-connection
/// initializer and creates the tables.
///
/// Each new connection gets foreign keys enabled, a 10 second busy timeout,
/// and two pragmas that can be overridden through environment variables:
///
/// * `MIDNIGHT_STORAGE_DB_SQL_SYNCHRONOUS` — `synchronous` pragma, parsed as
///   `u32`, defaulting to `0`;
/// * `MIDNIGHT_STORAGE_DB_SQL_JOURNAL_MODE` — `journal_mode` pragma,
///   defaulting to `WAL`.
///
/// # Panics
///
/// Panics if the pool cannot be created, if foreign keys cannot be enabled,
/// or if `MIDNIGHT_STORAGE_DB_SQL_SYNCHRONOUS` is set but is not a valid
/// `u32`.
fn new(cm: SqliteConnectionManager, lock_file: Option<File>) -> Self {
    let init = |conn: &mut Connection| {
        assert!(
            conn.set_db_config(SQLITE_DBCONFIG_ENABLE_FKEY, true)?,
            "foreign keys aren't supported"
        );
        let synchronous: u32 =
            std::env::var("MIDNIGHT_STORAGE_DB_SQL_SYNCHRONOUS").map_or(0, |v| {
                // `expect` takes a plain string, so the offending value has to
                // be interpolated through `panic!` to show up in the message.
                v.parse().unwrap_or_else(|e| {
                    panic!("MIDNIGHT_STORAGE_DB_SQL_SYNCHRONOUS invalid as u32: {v:?}: {e}")
                })
            });
        conn.pragma_update(None, "synchronous", synchronous)?;
        let journal_mode = std::env::var("MIDNIGHT_STORAGE_DB_SQL_JOURNAL_MODE")
            .unwrap_or_else(|_| "WAL".to_string());
        conn.pragma_update(None, "journal_mode", journal_mode)?;
        conn.busy_timeout(std::time::Duration::from_millis(10_000))
    };
    let db = SqlDB {
        pool: Pool::new(cm.with_init(init)).unwrap(),
        _phantom: PhantomData,
        lock_file,
    };
    db.create_tables();
    db
}
/// Creates the `node` and `root` tables and their indexes if they do not
/// exist yet, all inside one immediate transaction.
///
/// Without the `layout-v2` feature the `node` table also carries a
/// `ref_count` column, which gets its own index.
///
/// # Panics
///
/// Panics if any of the statements fails.
fn create_tables(&self) {
    self.with_tx(Immediate, |txn| {
        #[cfg(not(feature = "layout-v2"))]
        let node_table = "CREATE TABLE IF NOT EXISTS node (
key BLOB NOT NULL PRIMARY KEY,
data BLOB NOT NULL,
ref_count INT NOT NULL,
children BLOB NOT NULL
)";
        #[cfg(feature = "layout-v2")]
        let node_table = "CREATE TABLE IF NOT EXISTS node (
key BLOB NOT NULL PRIMARY KEY,
data BLOB NOT NULL,
children BLOB NOT NULL
)";
        txn.execute(node_table, ()).unwrap();
        #[cfg(not(feature = "layout-v2"))]
        {
            let ref_count_index =
                "CREATE INDEX IF NOT EXISTS ix_node_ref_count ON node (ref_count)";
            txn.execute(ref_count_index, ()).unwrap();
        }
        let root_table = "CREATE TABLE IF NOT EXISTS root (
key BLOB NOT NULL PRIMARY KEY,
count INT NOT NULL
)";
        txn.execute(root_table, ()).unwrap();
        let root_count_index = "CREATE INDEX IF NOT EXISTS ix_root_count ON root (count)";
        txn.execute(root_count_index, ()).unwrap();
    })
}
/// Runs `closure` inside a transaction opened with the given `behavior` on a
/// pooled connection, commits, and returns the closure's result.
///
/// # Panics
///
/// Panics if no connection can be obtained from the pool, or if opening or
/// committing the transaction fails.
fn with_tx<F, R>(&self, behavior: TransactionBehavior, closure: F) -> R
where
    F: FnOnce(&Transaction) -> R,
    R: Send,
{
    let mut pooled = self
        .pool
        .get()
        .expect("UNIMPLEMENTED: should retry when connection is not available");
    let txn = pooled.transaction_with_behavior(behavior).unwrap();
    let output = closure(&txn);
    txn.commit().unwrap();
    output
}
#[cfg(not(feature = "layout-v2"))]
fn _gc(&mut self, additional_roots: HashSet<ArenaHash<H>>) {
self.with_tx(Immediate, |tx| {
let sql =
"SELECT key FROM node WHERE key NOT IN (SELECT key FROM root) AND ref_count = 0";
let mut get_unreachable_keys = tx.prepare(sql).unwrap();
let sql = "SELECT children FROM node WHERE key = (?1)";
let mut get_children = tx.prepare(sql).unwrap();
let sql = "UPDATE node SET ref_count = ref_count - 1 WHERE key = (?1)";
let mut dec_ref_count = tx.prepare(sql).unwrap();
let sql = "DELETE FROM node WHERE key = (?1)";
let mut delete_node = tx.prepare(sql).unwrap();
loop {
let unreachable_keys: Vec<_> = get_unreachable_keys
.query_map([], |row| {
let key: ArenaHash<H> = row.get(0)?;
Ok(key)
})
.unwrap()
.map(|r| r.unwrap())
.filter(|k: &ArenaHash<H>| !additional_roots.contains(k))
.collect();
if unreachable_keys.is_empty() {
break;
}
for key in unreachable_keys {
let children: Vec<ArenaKey<H>> = get_children
.query_row(params![key.clone()], |row| {
let children: Children<H> = row.get(0)?;
Ok(children.0)
})
.unwrap();
for child in children.iter().flat_map(|k| k.refs()) {
dec_ref_count.execute(params![child]).unwrap();
}
delete_node.execute(params![key]).unwrap();
}
}
get_unreachable_keys.finalize().unwrap();
get_children.finalize().unwrap();
dec_ref_count.finalize().unwrap();
delete_node.finalize().unwrap();
})
}
/// Test-only: returns another handle sharing this in-memory database's
/// connection pool.
///
/// # Panics
///
/// Panics if this instance holds a lock file, i.e. is file-backed.
#[cfg(test)]
pub(crate) fn clone_memory_db(&self) -> Self {
    if self.lock_file.is_some() {
        panic!("Can't clone file db: found lock file!");
    }
    SqlDB {
        pool: self.pool.clone(),
        _phantom: self._phantom,
        lock_file: None,
    }
}
}
impl<H: WellBehavedHasher> Drop for SqlDB<H> {
    /// Releases the lock on the `.mutex` file, if this instance holds one.
    ///
    /// Unlock failures are only reported on stderr; they never panic.
    fn drop(&mut self) {
        let Some(lock_file) = &self.lock_file else {
            return;
        };
        if let Err(e) = fs2::FileExt::unlock(lock_file) {
            eprintln!("Failed to unlock mutex file: {:?}", e);
        }
    }
}
impl<H: WellBehavedHasher> ToSql for ArenaHash<H> {
    /// Encodes the hash as an owned copy of its raw bytes.
    fn to_sql(&self) -> Result<rusqlite::types::ToSqlOutput<'_>> {
        let bytes: Vec<u8> = self.0.to_vec();
        Ok(rusqlite::types::ToSqlOutput::from(bytes))
    }
}
impl<H: WellBehavedHasher> ToSql for ArenaKey<H> {
    /// Encodes the key as its `Serializable` byte representation.
    ///
    /// # Panics
    ///
    /// Panics if serializing into the in-memory buffer fails.
    fn to_sql(&self) -> Result<rusqlite::types::ToSqlOutput<'_>> {
        let mut encoded: Vec<u8> = Vec::new();
        self.serialize(&mut encoded)
            .expect("serialization to memory should succeed");
        Ok(rusqlite::types::ToSqlOutput::from(encoded))
    }
}
impl<H: WellBehavedHasher> FromSql for ArenaHash<H> {
    /// Decodes a hash from a SQLite blob.
    ///
    /// # Errors
    ///
    /// Returns an error if the value is not a blob/text value, or if its
    /// length does not match the hash size. The length check replaces a
    /// panic inside `GenericArray::from_slice`, so malformed rows surface as
    /// ordinary conversion errors.
    fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult<Self> {
        let bytes = value.as_bytes()?;
        // `from_exact_iter` yields `None` unless the iterator produces exactly
        // as many items as the array holds.
        #[allow(deprecated)]
        let digest = GenericArray::from_exact_iter(bytes.iter().copied()).ok_or_else(|| {
            rusqlite::types::FromSqlError::Other(
                format!("unexpected blob length {} for ArenaHash", bytes.len()).into(),
            )
        })?;
        Ok(ArenaHash(digest))
    }
}
/// Newtype around a node's list of child keys, so the list can be given its
/// own SQL conversions for the `children` column.
struct Children<H: WellBehavedHasher>(Vec<ArenaKey<H>>);
impl<H: WellBehavedHasher> ToSql for Children<H> {
    /// Encodes the child list as its `Serializable` byte representation.
    ///
    /// # Panics
    ///
    /// Panics if serialization fails.
    fn to_sql(&self) -> Result<rusqlite::types::ToSqlOutput<'_>> {
        let mut encoded = Vec::new();
        self.0.serialize(&mut encoded).unwrap();
        Ok(rusqlite::types::ToSqlOutput::from(encoded))
    }
}
impl<H: WellBehavedHasher> FromSql for Children<H> {
    /// Decodes a child list from the bytes stored in the column.
    ///
    /// # Panics
    ///
    /// Panics if the bytes cannot be deserialized.
    fn column_result(value: rusqlite::types::ValueRef<'_>) -> rusqlite::types::FromSqlResult<Self> {
        let mut remaining = value.as_bytes()?;
        let keys = Deserializable::deserialize(&mut remaining, 0).unwrap();
        Ok(Children(keys))
    }
}
impl<H: WellBehavedHasher> DB for SqlDB<H> {
type Hasher = H;
#[cfg(feature = "gc-v1")]
type ScanResumeHandle = ArenaHash<H>;
fn get_node(&self, key: &ArenaHash<H>) -> Option<OnDiskObject<H>> {
let key = key.clone();
self.with_tx(Deferred, |tx| {
#[cfg(not(feature = "layout-v2"))]
let sql = "SELECT data, ref_count, children FROM node WHERE key = (?1)";
#[cfg(feature = "layout-v2")]
let sql = "SELECT data, children FROM node WHERE key = (?1)";
let mut stmt = tx.prepare(sql).unwrap();
let result = stmt
.query_row(params![key], |row| {
let data = row.get(0)?;
#[cfg(not(feature = "layout-v2"))]
let ref_count = row.get::<_, i64>(1)? as u64;
#[cfg(not(feature = "layout-v2"))]
let children: Children<H> = row.get(2)?;
#[cfg(feature = "layout-v2")]
let children: Children<H> = row.get(1)?;
let children: Vec<ArenaKey<H>> = children.0.into_iter().collect();
Ok(OnDiskObject {
data,
#[cfg(not(feature = "layout-v2"))]
ref_count,
children,
})
})
.optional()
.unwrap();
stmt.finalize().unwrap();
result
})
}
/// Returns the keys of all nodes whose `ref_count` is `0` and which are not
/// listed in the `root` table.
///
/// # Panics
///
/// Panics on any SQLite error.
#[cfg(not(feature = "layout-v2"))]
fn get_unreachable_keys(&self) -> Vec<ArenaHash<H>> {
    self.with_tx(Deferred, |txn| {
        let mut select = txn
            .prepare(
                "SELECT key FROM node WHERE key NOT IN (SELECT key FROM root) AND ref_count = 0",
            )
            .unwrap();
        let mut found: Vec<ArenaHash<H>> = Vec::new();
        for row in select
            .query_map([], |row| row.get::<_, ArenaHash<H>>(0))
            .unwrap()
        {
            found.push(row.unwrap());
        }
        select.finalize().unwrap();
        found
    })
}
/// Looks up every key from `keys` inside one deferred transaction.
///
/// The result has one entry per requested key, in request order; keys
/// without a row are reported as `(key, None)`. They used to be dropped from
/// the result, which left the `Option` in the return type always `Some`.
///
/// # Panics
///
/// Panics on any SQLite error other than "no rows".
fn batch_get_nodes<I>(&self, keys: I) -> Vec<(ArenaHash<H>, Option<OnDiskObject<H>>)>
where
    I: Iterator<Item = ArenaHash<H>>,
{
    let keys = keys.collect::<Vec<_>>();
    self.with_tx(Deferred, |tx| {
        #[cfg(not(feature = "layout-v2"))]
        let sql = "SELECT data, ref_count, children FROM node WHERE key = (?1)";
        #[cfg(feature = "layout-v2")]
        let sql = "SELECT data, children FROM node WHERE key = (?1)";
        let mut stmt = tx.prepare(sql).unwrap();
        let result = keys
            .into_iter()
            .map(|key| {
                let node = stmt
                    .query_row(params![key], |row| {
                        let data = row.get(0)?;
                        #[cfg(not(feature = "layout-v2"))]
                        let ref_count = row.get::<_, i64>(1)? as u64;
                        #[cfg(not(feature = "layout-v2"))]
                        let children: Children<H> = row.get(2)?;
                        #[cfg(feature = "layout-v2")]
                        let children: Children<H> = row.get(1)?;
                        Ok(OnDiskObject {
                            data,
                            #[cfg(not(feature = "layout-v2"))]
                            ref_count,
                            children: children.0,
                        })
                    })
                    // "No rows" becomes `None` rather than an error.
                    .optional()
                    .unwrap();
                (key, node)
            })
            .collect();
        stmt.finalize().unwrap();
        result
    })
}
/// Stores `object` under `key`, replacing any existing row for that key.
///
/// # Panics
///
/// Panics on any SQLite error.
fn insert_node(&mut self, key: ArenaHash<H>, object: OnDiskObject<H>) {
    self.with_tx(Immediate, |txn| {
        #[cfg(not(feature = "layout-v2"))]
        let upsert_sql = "INSERT OR REPLACE INTO node (key, data, ref_count, children) \
                          VALUES (?1, ?2, ?3, ?4)";
        #[cfg(feature = "layout-v2")]
        let upsert_sql = "INSERT OR REPLACE INTO node (key, data, children) \
                          VALUES (?1, ?2, ?3)";
        let mut upsert = txn.prepare(upsert_sql).unwrap();
        #[cfg(not(feature = "layout-v2"))]
        upsert
            .execute(params![
                key,
                object.data,
                object.ref_count as i64,
                Children(object.children)
            ])
            .unwrap();
        #[cfg(feature = "layout-v2")]
        upsert
            .execute(params![key, object.data, Children(object.children)])
            .unwrap();
        upsert.finalize().unwrap();
    })
}
/// Deletes the row stored under `key` from the `node` table, if any.
///
/// # Panics
///
/// Panics on any SQLite error.
fn delete_node(&mut self, key: &ArenaHash<H>) {
    let target = key.clone();
    self.with_tx(Immediate, |txn| {
        let mut remove = txn.prepare("DELETE FROM node WHERE key = (?1)").unwrap();
        remove.execute(params![target]).unwrap();
        remove.finalize().unwrap();
    })
}
/// Applies every `(key, update)` pair from `iter`, in order, inside a single
/// immediate transaction.
///
/// `InsertNode` replaces any existing row, `DeleteNode` removes the row, and
/// `SetRootCount` stores the count or, when it is not positive, removes the
/// root entry.
///
/// # Panics
///
/// Panics on any SQLite error.
fn batch_update<I>(&mut self, iter: I)
where
    I: Iterator<Item = (ArenaHash<H>, Update<H>)>,
{
    use Update::*;
    self.with_tx(Immediate, |txn| {
        #[cfg(not(feature = "layout-v2"))]
        let upsert_node_sql = "INSERT OR REPLACE INTO node (key, data, ref_count, children) \
                               VALUES (?1, ?2, ?3, ?4)";
        #[cfg(feature = "layout-v2")]
        let upsert_node_sql = "INSERT OR REPLACE INTO node (key, data, children) \
                               VALUES (?1, ?2, ?3)";
        let mut upsert_node = txn.prepare(upsert_node_sql).unwrap();
        let mut remove_node = txn.prepare("DELETE FROM node WHERE key = (?1)").unwrap();
        let mut upsert_root = txn
            .prepare(
                "INSERT OR REPLACE INTO root (key, count) \
                 VALUES (?1, ?2)",
            )
            .unwrap();
        let mut remove_root = txn.prepare("DELETE FROM root WHERE key = (?1)").unwrap();
        for (key, update) in iter {
            // The number of affected rows is not needed.
            let _affected = match update {
                DeleteNode => remove_node.execute(params![key]).unwrap(),
                #[cfg(not(feature = "layout-v2"))]
                InsertNode(object) => upsert_node
                    .execute(params![
                        key,
                        object.data,
                        object.ref_count as i64,
                        Children(object.children)
                    ])
                    .unwrap(),
                #[cfg(feature = "layout-v2")]
                InsertNode(object) => upsert_node
                    .execute(params![key, object.data, Children(object.children)])
                    .unwrap(),
                SetRootCount(count) if count > 0 => {
                    upsert_root.execute(params![key, count]).unwrap()
                }
                SetRootCount(_) => remove_root.execute(params![key]).unwrap(),
            };
        }
        upsert_node.finalize().unwrap();
        remove_node.finalize().unwrap();
        upsert_root.finalize().unwrap();
        remove_root.finalize().unwrap();
    })
}
/// Returns the number of rows in the `node` table.
///
/// # Panics
///
/// Panics on any SQLite error.
fn size(&self) -> usize {
    self.with_tx(Deferred, |txn| {
        let mut count_rows = txn.prepare("SELECT COUNT(*) FROM node").unwrap();
        let total: i64 = count_rows.query_row([], |row| row.get(0)).unwrap();
        count_rows.finalize().unwrap();
        total as usize
    })
}
/// Returns the count stored for `key` in the `root` table, or `0` when the
/// key has no entry.
///
/// # Panics
///
/// Panics on any SQLite error other than "no rows".
fn get_root_count(&self, key: &ArenaHash<Self::Hasher>) -> u32 {
    let lookup_key = key.clone();
    self.with_tx(Deferred, |txn| {
        let mut select = txn
            .prepare("SELECT count FROM root WHERE key = (?1)")
            .unwrap();
        let stored: Option<u32> = select
            .query_row(params![lookup_key], |row| row.get(0))
            .optional()
            .unwrap();
        select.finalize().unwrap();
        match stored {
            Some(count) => count,
            None => 0,
        }
    })
}
/// Sets the root count of `key`; a count of `0` removes the root entry.
///
/// # Panics
///
/// Panics on any SQLite error.
fn set_root_count(&mut self, key: ArenaHash<Self::Hasher>, count: u32) {
    self.with_tx(Immediate, |txn| {
        if count == 0 {
            let mut remove = txn.prepare("DELETE FROM root WHERE key = (?1)").unwrap();
            remove.execute(params![key]).unwrap();
            remove.finalize().unwrap();
            return;
        }
        let mut upsert = txn
            .prepare(
                "INSERT OR REPLACE INTO root (key, count) \
                 VALUES (?1, ?2)",
            )
            .unwrap();
        upsert.execute(params![key, count]).unwrap();
        upsert.finalize().unwrap();
    })
}
/// Returns every entry of the `root` table as a key-to-count map.
///
/// # Panics
///
/// Panics on any SQLite error.
fn get_roots(&self) -> HashMap<ArenaHash<Self::Hasher>, u32> {
    self.with_tx(Deferred, |txn| {
        let mut select = txn.prepare("SELECT key, count FROM root").unwrap();
        let mut roots = HashMap::new();
        let entries = select
            .query_map([], |row| {
                Ok((row.get::<_, ArenaHash<H>>(0)?, row.get::<_, u32>(1)?))
            })
            .unwrap();
        for entry in entries {
            let (key, count) = entry.unwrap();
            roots.insert(key, count);
        }
        select.finalize().unwrap();
        roots
    })
}
/// Returns up to `batch_size` nodes in ascending key order, starting
/// strictly after `resume_from` when it is given.
///
/// The second element of the result is the key to resume from: the last
/// returned key when the batch came back full, `None` otherwise.
///
/// # Panics
///
/// Panics on any SQLite error.
#[cfg(feature = "gc-v1")]
fn scan(
    &self,
    resume_from: Option<Self::ScanResumeHandle>,
    batch_size: usize,
) -> (
    Vec<(ArenaHash<Self::Hasher>, OnDiskObject<Self::Hasher>)>,
    Option<Self::ScanResumeHandle>,
) {
    // Saturate rather than truncate: an `as u32` cast would wrap large batch
    // sizes to a smaller limit and make the scan stop early.
    let limit = i64::try_from(batch_size).unwrap_or(i64::MAX);
    self.with_tx(Deferred, |tx| {
        let parse_row = |row: &rusqlite::Row| {
            let key: ArenaHash<H> = row.get(0)?;
            let data = row.get(1)?;
            let children: Children<H> = row.get(2)?;
            let children: Vec<ArenaKey<H>> = children.0.into_iter().collect();
            Ok((key, OnDiskObject { data, children }))
        };
        let rows: Vec<_> = match resume_from {
            Some(ref handle) => {
                let sql = "SELECT key, data, children FROM node \
                           WHERE key > (?1) ORDER BY key LIMIT ?2";
                let mut stmt = tx.prepare(sql).expect("Failed to prepare scan statement");
                let result = stmt
                    .query_map(params![handle, limit], parse_row)
                    .expect("Failed to execute scan query")
                    .map(|r| r.expect("Failed to read scan row"))
                    .collect();
                stmt.finalize().expect("Failed to finalize scan statement");
                result
            }
            None => {
                let sql = "SELECT key, data, children FROM node ORDER BY key LIMIT ?1";
                let mut stmt = tx.prepare(sql).expect("Failed to prepare scan statement");
                let result = stmt
                    .query_map(params![limit], parse_row)
                    .expect("Failed to execute scan query")
                    .map(|r| r.expect("Failed to read scan row"))
                    .collect();
                stmt.finalize().expect("Failed to finalize scan statement");
                result
            }
        };
        // A full batch means more rows may follow, so hand back the last key.
        let handle = if rows.len() == batch_size {
            rows.last().map(|(k, _)| k.clone())
        } else {
            None
        };
        (rows, handle)
    })
}
}
// Empty impl: `SqlDB` overrides no items of `DummyArbitrary`.
impl<H: WellBehavedHasher> DummyArbitrary for SqlDB<H> {}
#[cfg(feature = "proptest")]
impl<H: WellBehavedHasher> Arbitrary for SqlDB<H> {
    type Parameters = ();
    type Strategy = DummyDBStrategy<Self>;

    /// Returns the `DummyDBStrategy` for this database type; the parameters
    /// are ignored.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        let strategy: DummyDBStrategy<Self> = DummyDBStrategy(PhantomData);
        strategy
    }
}
#[cfg(test)]
mod tests {
use super::{SqlDB, Update::*};
use crate::{
DefaultHasher, WellBehavedHasher, arena::ArenaHash, backend::OnDiskObject, db::DB,
};
use rand::Rng;
use rusqlite::TransactionBehavior::Deferred;
use rusqlite::types::FromSql;
#[cfg(not(feature = "layout-v2"))]
use std::collections::HashSet;
/// Runs the concurrent-access scenario against handles that all share one
/// in-memory database.
#[test]
#[ignore = "always fails, indep of busy timeout"]
fn concurrent_access_memory() {
    let shared = SqlDB::memory();
    test_concurrent_access(|| shared.clone_memory_db());
}
/// Runs the concurrent-access scenario against non-exclusive handles that
/// all open the same temporary database file.
#[test]
fn concurrent_access_file() {
    let temp_file = tempfile::NamedTempFile::new().unwrap();
    let path: tempfile::TempPath = temp_file.into_temp_path();
    test_concurrent_access(|| SqlDB::non_exclusive_file(&path));
}
// Number of threads running `insert_read_delete_loop`.
const NUM_WRITE_JOBS: usize = 5;
// Number of threads running `bulk_insert_loop`.
const NUM_BULK_JOBS: usize = 10;
// Number of threads running `read_loop`.
const NUM_READ_JOBS: usize = 100;
// Iterations per loop job; also the number of updates in each bulk batch.
const ITERS_PER_JOB: usize = 10;
/// Spawns writer, bulk-writer and reader threads that all work on the same
/// random key, each with its own handle from `mk_db`, and waits for them.
///
/// Fails if any of the threads panics.
fn test_concurrent_access(mk_db: impl Fn() -> SqlDB) {
    let mut rng = rand::thread_rng();
    let k: ArenaHash<_> = rng.r#gen();
    let v: OnDiskObject<_> = rng.r#gen();
    let mut handles = Vec::new();
    handles.extend((0..NUM_WRITE_JOBS).map(|_| {
        let (key, value, db) = (k.clone(), v.clone(), mk_db());
        std::thread::spawn(move || insert_read_delete_loop(key, value, db))
    }));
    handles.extend((0..NUM_BULK_JOBS).map(|_| {
        let (key, value, db) = (k.clone(), v.clone(), mk_db());
        std::thread::spawn(move || bulk_insert_loop(key, value, db))
    }));
    handles.extend((0..NUM_READ_JOBS).map(|_| {
        let (key, db) = (k.clone(), mk_db());
        std::thread::spawn(move || read_loop(key, db))
    }));
    for handle in handles {
        handle.join().unwrap();
    }
}
/// Inserts, reads back and deletes the node `k` -> `v`, `ITERS_PER_JOB`
/// times in a row. The read result is discarded.
fn insert_read_delete_loop<H: WellBehavedHasher>(
    k: ArenaHash<H>,
    v: OnDiskObject<H>,
    mut db: SqlDB<H>,
) {
    let mut remaining = ITERS_PER_JOB;
    while remaining > 0 {
        db.insert_node(k.clone(), v.clone());
        let _ = db.get_node(&k);
        db.delete_node(&k);
        remaining -= 1;
    }
}
/// Submits one batch that inserts `k` -> `v` `ITERS_PER_JOB` times, then
/// deletes the node.
fn bulk_insert_loop<H: WellBehavedHasher>(
    k: ArenaHash<H>,
    v: OnDiskObject<H>,
    mut db: SqlDB<H>,
) {
    let batch = std::iter::repeat_n((k.clone(), InsertNode(v)), ITERS_PER_JOB);
    db.batch_update(batch);
    db.delete_node(&k);
}
/// Reads the node under `k` `ITERS_PER_JOB` times, discarding the results.
fn read_loop<H: WellBehavedHasher>(k: ArenaHash<H>, db: SqlDB<H>) {
    (0..ITERS_PER_JOB).for_each(|_| {
        let _ = db.get_node(&k);
    });
}
/// Exercises `_gc` on a small DAG under different root configurations.
#[test]
#[cfg(not(feature = "layout-v2"))]
fn db_level_gc() {
    use crate::backend::raw_node::RawNode;
    // Graph: n1 -> n2 -> {n4, n3} -> n5.
    // The second argument to `RawNode::new` matches each node's number of
    // parents — presumably its stored ref count; TODO confirm.
    let n5 = RawNode::new(&[5], 2, vec![]);
    let n4 = RawNode::new(&[4], 1, vec![&n5]);
    let n3 = RawNode::new(&[3], 1, vec![&n5]);
    let n2 = RawNode::new(&[2], 1, vec![&n4, &n3]);
    let n1 = RawNode::new(&[1], 0, vec![&n2]);
    let nodes: [&RawNode; 5] = [&n5, &n4, &n3, &n2, &n1];
    // Builds a fresh database holding all five nodes.
    let init_db = || {
        let mut db = SqlDB::default();
        for n in nodes.iter() {
            n.insert_into_db(&mut db);
        }
        for n in nodes.iter() {
            assert!(db.get_node(&n.key).is_some());
        }
        db
    };
    // Scenario 1: with the top node rooted, GC removes nothing.
    let mut db = init_db();
    db.set_root_count(n1.key.clone(), 1);
    db._gc(HashSet::new());
    for n in nodes.iter() {
        assert!(db.get_node(&n.key).is_some());
    }
    // Un-rooting the top node lets GC remove the whole graph.
    db.set_root_count(n1.key.clone(), 0);
    db._gc(HashSet::new());
    assert_eq!(db.size(), 0);
    // Scenario 2: rooting n2 keeps everything below it; only n1 goes.
    let mut db = init_db();
    db.set_root_count(n2.key.clone(), 1);
    db._gc(HashSet::new());
    assert!(db.get_node(&n1.key).is_none());
    assert!(db.get_node(&n2.key).is_some());
    assert!(db.get_node(&n3.key).is_some());
    assert!(db.get_node(&n4.key).is_some());
    assert!(db.get_node(&n5.key).is_some());
    // Moving the root to n3 removes n2 and n4; n3 and its child n5 stay.
    db.set_root_count(n2.key.clone(), 0);
    db.set_root_count(n3.key.clone(), 1);
    db._gc(HashSet::new());
    assert!(db.get_node(&n1.key).is_none());
    assert!(db.get_node(&n2.key).is_none());
    assert!(db.get_node(&n3.key).is_some());
    assert!(db.get_node(&n4.key).is_none());
    assert!(db.get_node(&n5.key).is_some());
    // Dropping the last root empties the database.
    db.set_root_count(n3.key.clone(), 0);
    db._gc(HashSet::new());
    assert_eq!(db.size(), 0);
    // Scenario 3: no DB roots, but n3 and n4 are passed as additional roots,
    // so they and their child n5 survive.
    let mut db = init_db();
    let additional_roots = [n3.key.clone(), n4.key.clone()].into_iter().collect();
    db._gc(additional_roots);
    assert!(db.get_node(&n1.key).is_none());
    assert!(db.get_node(&n2.key).is_none());
    assert!(db.get_node(&n3.key).is_some());
    assert!(db.get_node(&n4.key).is_some());
    assert!(db.get_node(&n5.key).is_some());
}
/// Abstraction over how a test action is executed, so the same scenario can
/// run on the calling thread or on a separate one.
trait Runner {
    /// Runs `action` to completion.
    fn run(&self, action: impl FnOnce() + Send);
}
/// Runner that invokes the action directly on the calling thread.
struct LocalRunner;
impl Runner for LocalRunner {
    fn run(&self, action: impl FnOnce()) {
        let job = action;
        job()
    }
}
/// Runner that executes the action on a scoped thread and waits for it,
/// re-raising a panic from the action on the calling thread.
struct ThreadRunner;
impl Runner for ThreadRunner {
    fn run(&self, action: impl FnOnce() + Send) {
        std::thread::scope(|scope| {
            let worker = scope.spawn(action);
            worker.join().unwrap()
        });
    }
}
/// Checks lock exclusivity with every open attempt made on the test thread.
#[test]
fn exclusivity_local() {
    let runner = LocalRunner;
    test_exclusivity(runner);
}
/// Checks lock exclusivity with every open attempt made on a separate
/// thread.
#[test]
fn exclusivity_threaded() {
    let runner = ThreadRunner;
    test_exclusivity(runner);
}
/// Verifies the locking rules of file-backed databases.
///
/// While a non-exclusive handle is open, another non-exclusive open
/// succeeds and an exclusive open panics. While an exclusive handle is open,
/// both kinds of open panic.
fn test_exclusivity(runner: impl Runner) {
    let temp_file = tempfile::NamedTempFile::new().unwrap();
    let path: tempfile::TempPath = temp_file.into_temp_path();

    // Phase 1: a shared lock is held.
    let shared_holder = SqlDB::<DefaultHasher>::non_exclusive_file(&path);
    runner.run(|| {
        SqlDB::<DefaultHasher>::non_exclusive_file(&path);
    });
    runner.run(|| {
        assert!(
            std::panic::catch_unwind(|| {
                SqlDB::<DefaultHasher>::exclusive_file(&path);
            })
            .is_err()
        );
    });
    drop(shared_holder);

    // Phase 2: an exclusive lock is held.
    let exclusive_holder = SqlDB::<DefaultHasher>::exclusive_file(&path);
    runner.run(|| {
        assert!(
            std::panic::catch_unwind(|| {
                SqlDB::<DefaultHasher>::non_exclusive_file(&path);
            })
            .is_err()
        );
    });
    runner.run(|| {
        assert!(
            std::panic::catch_unwind(|| {
                SqlDB::<DefaultHasher>::exclusive_file(&path);
            })
            .is_err()
        );
    });
    drop(exclusive_holder);
}
/// Reads the current value of `pragma` from `db` as a `T`.
///
/// Panics if the pragma query fails or the value cannot be converted.
fn query_pragma<T: FromSql + Sync + Send>(db: &SqlDB, pragma: &str) -> T {
    db.with_tx(Deferred, |txn| {
        let value = txn.pragma_query_value(None, pragma, |row| row.get::<_, T>(0));
        value.unwrap()
    })
}
/// Checks that a fresh file database reports journal mode `WAL` and
/// `synchronous = 0`.
#[test]
fn default_sqlite_params() {
    let temp_file = tempfile::NamedTempFile::new().unwrap();
    let path: tempfile::TempPath = temp_file.into_temp_path();
    let db = SqlDB::exclusive_file(&path);
    let journal_mode = query_pragma::<String>(&db, "journal_mode");
    assert_eq!(journal_mode.to_uppercase(), "WAL");
    let synchronous = query_pragma::<i32>(&db, "synchronous");
    assert_eq!(synchronous, 0);
}
/// Checks that the journal mode and `synchronous` pragmas follow the
/// `MIDNIGHT_STORAGE_DB_SQL_*` environment variables.
#[test]
#[ignore = "unsafe because it overrides the shared env"]
fn env_override_sqlite_params() {
    use std::env;
    // SAFETY: NOTE(review) — mutating the process environment is only sound
    // while no other thread accesses it; the test is ignored by default for
    // that reason and presumably has to be run on its own.
    unsafe {
        env::set_var("MIDNIGHT_STORAGE_DB_SQL_JOURNAL_MODE", "DELETE");
        env::set_var("MIDNIGHT_STORAGE_DB_SQL_SYNCHRONOUS", "2");
    }
    let temp_file = tempfile::NamedTempFile::new().unwrap();
    let path: tempfile::TempPath = temp_file.into_temp_path();
    let db = SqlDB::exclusive_file(&path);
    let journal_mode = query_pragma::<String>(&db, "journal_mode");
    assert_eq!(journal_mode.to_uppercase(), "DELETE");
    let synchronous = query_pragma::<i32>(&db, "synchronous");
    assert_eq!(synchronous, 2);
}
}