#[cfg(feature = "sqlite")]
mod sql;
#[cfg(feature = "sqlite")]
pub use sql::SqlDB;
#[cfg(feature = "parity-db")]
pub mod paritydb;
#[cfg(feature = "parity-db")]
pub use paritydb::ParityDb;
use crate::DefaultHasher;
use crate::backend::OnDiskObject;
use crate::{
WellBehavedHasher,
arena::{ArenaHash, ArenaKey},
};
#[cfg(feature = "proptest")]
use proptest::{
prelude::*,
strategy::{NewTree, ValueTree},
test_runner::TestRunner,
};
use std::collections::{BTreeMap, HashMap, HashSet};
use std::fmt::Debug;
#[cfg(feature = "proptest")]
use std::marker::PhantomData;
use std::sync::{Arc, Mutex};
/// A single change to apply to the entry stored under a key.
///
/// Values of this type are paired with a key and fed to `DB::batch_update`;
/// `dubious_batch_update` shows how each variant maps onto a `DB` method.
#[derive(Clone, Debug)]
pub enum Update<H: WellBehavedHasher> {
    /// Store the given object under the key (handled via `DB::insert_node`).
    InsertNode(OnDiskObject<H>),
    /// Remove the node stored under the key (handled via `DB::delete_node`).
    DeleteNode,
    /// Set the root count of the key (handled via `DB::set_root_count`).
    SetRootCount(u32),
}
/// Supertrait of `DB`. With the `proptest` feature enabled it additionally
/// requires `Arbitrary`, so every `DB` can be used as a proptest input.
#[cfg(feature = "proptest")]
pub trait DummyArbitrary: Arbitrary {}
/// Supertrait of `DB`. Without the `proptest` feature it carries no
/// requirements and exists only so the `DB` bound list is feature-independent.
#[cfg(not(feature = "proptest"))]
pub trait DummyArbitrary {}
/// Storage backend holding arena nodes (`OnDiskObject`s keyed by `ArenaHash`)
/// together with a per-key root count.
///
/// Implementations must be constructible via `Default` and shareable across
/// threads (`Send + Sync`).
pub trait DB: Default + Sync + Send + Debug + DummyArbitrary + 'static {
    /// Hasher that fixes the key type (`ArenaHash<Self::Hasher>`) of this database.
    type Hasher: WellBehavedHasher;

    /// Cursor returned by `scan` and handed back to it to continue a scan.
    #[cfg(feature = "gc-v1")]
    type ScanResumeHandle: Debug + Send + Sync + Clone + 'static;

    /// Returns the object stored under `key`, or `None` if there is none.
    fn get_node(&self, key: &ArenaHash<Self::Hasher>) -> Option<OnDiskObject<Self::Hasher>>;

    /// Returns the keys of nodes that are considered unreachable.
    ///
    /// The `InMemoryDB` implementation in this file reports every node whose
    /// `ref_count` is zero and that has no root-count entry.
    #[cfg(not(feature = "layout-v2"))]
    fn get_unreachable_keys(&self) -> std::vec::Vec<ArenaHash<Self::Hasher>>;

    /// Stores `object` under `key`.
    fn insert_node(&mut self, key: ArenaHash<Self::Hasher>, object: OnDiskObject<Self::Hasher>);

    /// Removes the node stored under `key`.
    fn delete_node(&mut self, key: &ArenaHash<Self::Hasher>);

    /// Applies a sequence of keyed updates.
    ///
    /// `dubious_batch_update` is a naive implementation that applies the
    /// updates one at a time, in iteration order.
    fn batch_update<I>(&mut self, iter: I)
    where
        I: Iterator<Item = (ArenaHash<Self::Hasher>, Update<Self::Hasher>)>;

    /// Looks up many keys at once, pairing every key with its lookup result.
    ///
    /// NOTE(review): `bfs_get_nodes` appends results in the order returned
    /// here, so implementations presumably preserve the input order — confirm.
    #[allow(clippy::type_complexity)]
    fn batch_get_nodes<I>(
        &self,
        keys: I,
    ) -> std::vec::Vec<(ArenaHash<Self::Hasher>, Option<OnDiskObject<Self::Hasher>>)>
    where
        I: Iterator<Item = ArenaHash<Self::Hasher>>;

    /// Breadth-first traversal starting at `key`, returning the nodes that had
    /// to be fetched from the database.
    ///
    /// The traversal proceeds level by level. Each key is first looked up with
    /// `cache_get`; cache hits are *not* part of the result. Keys missing from
    /// the cache are fetched with one `batch_get_nodes` call per level and
    /// appended to the result in the order that call returns them.
    ///
    /// * `cache_get` — lookup for nodes the caller already holds.
    /// * `truncate` — when `true`, the children of cache hits are not followed;
    ///   when `false`, they are.
    /// * `max_depth` — deepest level to visit (`key` itself is depth 0);
    ///   `None` means unlimited.
    /// * `max_count` — upper bound on the number of returned nodes; `None`
    ///   means unlimited. Once the bound is reached the traversal stops.
    ///
    /// # Panics
    ///
    /// Panics if a key reached at depth 1 or deeper is found neither by
    /// `cache_get` nor in the database. A missing start key (depth 0) is
    /// tolerated and simply yields no entry.
    #[allow(clippy::type_complexity)]
    fn bfs_get_nodes<C>(
        &self,
        key: &ArenaHash<Self::Hasher>,
        cache_get: C,
        truncate: bool,
        max_depth: Option<usize>,
        max_count: Option<usize>,
    ) -> std::vec::Vec<(ArenaHash<Self::Hasher>, OnDiskObject<Self::Hasher>)>
    where
        C: Fn(&ArenaHash<Self::Hasher>) -> Option<OnDiskObject<Self::Hasher>>,
    {
        let mut kvs = vec![];
        let mut visited = HashSet::new();
        let mut current_depth = 0;
        let mut current_keys = vec![key.clone()];
        // Stop as soon as the count budget is used up: any further level could
        // only walk cached nodes and issue empty database lookups.
        while !current_keys.is_empty()
            && max_depth.is_none_or(|max_depth| current_depth <= max_depth)
            && max_count.is_none_or(|max_count| kvs.len() < max_count)
        {
            let mut next_keys = vec![];
            let mut unknown_keys = vec![];
            for k in current_keys {
                // `insert` returns false for keys seen before, so one hash
                // lookup both tests and records the visit.
                if !visited.insert(k.clone()) {
                    continue;
                }
                match cache_get(&k) {
                    Some(node) => {
                        if !truncate {
                            next_keys
                                .extend(node.children.iter().flat_map(ArenaKey::refs).cloned());
                        }
                    }
                    None => unknown_keys.push(k),
                }
            }
            if let Some(max_count) = max_count {
                // Cannot underflow: the loop condition guarantees
                // `kvs.len() < max_count` here.
                unknown_keys.truncate(max_count - kvs.len());
            }
            for (k, v) in self.batch_get_nodes(unknown_keys.into_iter()) {
                match v {
                    Some(node) => {
                        next_keys.extend(node.children.iter().flat_map(ArenaKey::refs).cloned());
                        kvs.push((k, node));
                    }
                    None => {
                        // Only the start key may legitimately be absent.
                        if current_depth > 0 {
                            panic!("child key {k:?} must be in memory or db");
                        }
                    }
                }
            }
            current_depth += 1;
            current_keys = next_keys;
        }
        kvs
    }

    /// Returns the root count of `key`.
    ///
    /// The `InMemoryDB` implementation in this file returns 0 for keys
    /// without an entry.
    fn get_root_count(&self, key: &ArenaHash<Self::Hasher>) -> u32;

    /// Sets the root count of `key` to `count`.
    ///
    /// In the `InMemoryDB` implementation in this file a count of 0 removes
    /// the entry, so the key no longer appears in `get_roots`.
    fn set_root_count(&mut self, key: ArenaHash<Self::Hasher>, count: u32);

    /// Returns all recorded root counts, keyed by node hash.
    fn get_roots(&self) -> HashMap<ArenaHash<Self::Hasher>, u32>;

    /// Returns the number of stored nodes (root counts are not included in
    /// the `InMemoryDB` implementation in this file).
    fn size(&self) -> usize;

    /// Returns up to `batch_size` stored entries together with a handle for
    /// fetching the next batch.
    ///
    /// Pass `None` as `resume_from` to start from the beginning, or the handle
    /// returned by a previous call to continue after it.
    #[cfg(feature = "gc-v1")]
    fn scan(
        &self,
        resume_from: Option<Self::ScanResumeHandle>,
        batch_size: usize,
    ) -> (
        Vec<(ArenaHash<Self::Hasher>, OnDiskObject<Self::Hasher>)>,
        Option<Self::ScanResumeHandle>,
    );
}
pub fn dubious_batch_update<D: DB, I>(db: &mut D, iter: I)
where
I: Iterator<Item = (ArenaHash<D::Hasher>, Update<D::Hasher>)>,
{
use Update::*;
for (k, v) in iter {
match v {
InsertNode(value) => db.insert_node(k, value),
DeleteNode => db.delete_node(&k),
SetRootCount(count) => db.set_root_count(k, count),
}
}
}
/// Naive `DB::batch_get_nodes`: performs one `get_node` call per key and
/// returns the `(key, result)` pairs in input order.
#[allow(clippy::type_complexity)]
pub fn dubious_batch_get_nodes<D: DB, I>(
    db: &D,
    keys: I,
) -> Vec<(ArenaHash<D::Hasher>, Option<OnDiskObject<D::Hasher>>)>
where
    I: Iterator<Item = ArenaHash<D::Hasher>>,
{
    let mut found = Vec::new();
    for key in keys {
        let node = db.get_node(&key);
        found.push((key, node));
    }
    found
}
/// `DB` implementation that keeps everything in process memory.
///
/// Both maps live behind `Arc<Mutex<..>>`, so cloning an `InMemoryDB` yields a
/// handle to the *same* underlying data rather than an independent copy.
#[derive(Clone, Debug)]
pub struct InMemoryDB<H: WellBehavedHasher = DefaultHasher> {
    // Stored nodes, ordered by key (the ordering is what `scan` iterates over).
    nodes: Arc<Mutex<BTreeMap<ArenaHash<H>, OnDiskObject<H>>>>,
    // Root counts; `set_root_count` removes entries whose count drops to 0.
    roots: Arc<Mutex<HashMap<ArenaHash<H>, u32>>>,
}
// Marker impl required by the `DB` supertrait list.
impl<H: WellBehavedHasher> DummyArbitrary for InMemoryDB<H> {}
/// Proptest value tree that always produces a default-constructed database.
#[cfg(feature = "proptest")]
pub struct DummyDBTree<D: DB>(PhantomData<D>);
#[cfg(feature = "proptest")]
impl<D: DB> ValueTree for DummyDBTree<D> {
    type Value = D;
    /// Returns a fresh `D::default()` on every call.
    fn current(&self) -> Self::Value {
        D::default()
    }
    /// There is nothing to shrink, so simplification never makes progress.
    fn simplify(&mut self) -> bool {
        false
    }
    /// There is nothing to shrink, so complication never makes progress.
    fn complicate(&mut self) -> bool {
        false
    }
}
/// Proptest strategy whose only output is a default-constructed database
/// (via `DummyDBTree`).
#[cfg(feature = "proptest")]
#[derive(Debug)]
pub struct DummyDBStrategy<D: DB>(PhantomData<D>);
#[cfg(feature = "proptest")]
impl<D: DB> Strategy for DummyDBStrategy<D> {
    type Tree = DummyDBTree<D>;
    type Value = D;
    /// Always succeeds; the runner is not consulted.
    fn new_tree(&self, _runner: &mut TestRunner) -> NewTree<Self> {
        Ok(DummyDBTree(PhantomData))
    }
}
/// `Arbitrary` for `InMemoryDB` always generates an empty (default) database,
/// using `DummyDBStrategy`.
#[cfg(feature = "proptest")]
impl<H: WellBehavedHasher> Arbitrary for InMemoryDB<H> {
    type Parameters = ();
    type Strategy = DummyDBStrategy<Self>;
    /// Ignores its parameters and returns the dummy strategy.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        DummyDBStrategy::<Self>(PhantomData)
    }
}
impl<H: WellBehavedHasher> InMemoryDB<H> {
    /// Locks the node map.
    ///
    /// # Panics
    ///
    /// Panics with "db lock poisoned" if the mutex is poisoned.
    fn lock_nodes(&self) -> std::sync::MutexGuard<'_, BTreeMap<ArenaHash<H>, OnDiskObject<H>>> {
        let locked = self.nodes.lock();
        locked.expect("db lock poisoned")
    }

    /// Locks the root-count map.
    ///
    /// # Panics
    ///
    /// Panics with "db lock poisoned" if the mutex is poisoned.
    fn lock_roots(&self) -> std::sync::MutexGuard<'_, HashMap<ArenaHash<H>, u32>> {
        let locked = self.roots.lock();
        locked.expect("db lock poisoned")
    }
}
impl<H: WellBehavedHasher> DB for InMemoryDB<H> {
    type Hasher = H;
    // Scans resume after the last key that was handed out.
    #[cfg(feature = "gc-v1")]
    type ScanResumeHandle = ArenaHash<H>;

    /// Returns a clone of the object stored under `key`, if any.
    fn get_node(&self, key: &ArenaHash<H>) -> Option<OnDiskObject<H>> {
        let nodes = self.lock_nodes();
        nodes.get(key).cloned()
    }

    /// Returns, in key order, every key whose node has `ref_count == 0` and
    /// that has no root-count entry.
    #[cfg(not(feature = "layout-v2"))]
    fn get_unreachable_keys(&self) -> std::vec::Vec<ArenaHash<Self::Hasher>> {
        // Lock order (nodes, then roots) is kept identical on every call.
        let nodes = self.lock_nodes();
        let roots = self.lock_roots();
        nodes
            .iter()
            .filter(|(key, node)| node.ref_count == 0 && !roots.contains_key(*key))
            .map(|(key, _)| key.clone())
            .collect()
    }

    /// Stores `object` under `key`, replacing any previous object.
    fn insert_node(&mut self, key: ArenaHash<H>, object: OnDiskObject<H>) {
        let mut nodes = self.lock_nodes();
        nodes.insert(key, object);
    }

    /// Removes the node stored under `key`; absent keys are ignored.
    fn delete_node(&mut self, key: &ArenaHash<H>) {
        let mut nodes = self.lock_nodes();
        nodes.remove(key);
    }

    /// Returns the recorded root count of `key`, or 0 when there is no entry.
    fn get_root_count(&self, key: &ArenaHash<Self::Hasher>) -> u32 {
        self.lock_roots().get(key).copied().unwrap_or_default()
    }

    /// Records `count` for `key`; a count of 0 removes the entry instead.
    fn set_root_count(&mut self, key: ArenaHash<Self::Hasher>, count: u32) {
        let mut roots = self.lock_roots();
        if count == 0 {
            roots.remove(&key);
        } else {
            roots.insert(key, count);
        }
    }

    /// Returns a snapshot copy of all root counts.
    fn get_roots(&self) -> HashMap<ArenaHash<Self::Hasher>, u32> {
        let roots = self.lock_roots();
        roots.clone()
    }

    /// Returns the number of stored nodes.
    fn size(&self) -> usize {
        let nodes = self.lock_nodes();
        nodes.len()
    }

    /// Applies the updates one by one via `dubious_batch_update`.
    fn batch_update<I>(&mut self, iter: I)
    where
        I: Iterator<Item = (ArenaHash<Self::Hasher>, Update<Self::Hasher>)>,
    {
        dubious_batch_update::<Self, I>(self, iter);
    }

    /// Looks the keys up one by one via `dubious_batch_get_nodes`.
    fn batch_get_nodes<I>(
        &self,
        keys: I,
    ) -> Vec<(ArenaHash<Self::Hasher>, Option<OnDiskObject<Self::Hasher>>)>
    where
        I: Iterator<Item = ArenaHash<Self::Hasher>>,
    {
        dubious_batch_get_nodes::<Self, I>(self, keys)
    }

    /// Returns up to `batch_size` entries in key order, starting strictly
    /// after `resume_from` (or at the first key when it is `None`).
    ///
    /// The returned handle is the last key of the batch, or `None` when the
    /// batch is empty.
    #[cfg(feature = "gc-v1")]
    fn scan(
        &self,
        resume_from: Option<Self::ScanResumeHandle>,
        batch_size: usize,
    ) -> (
        Vec<(ArenaHash<Self::Hasher>, OnDiskObject<Self::Hasher>)>,
        Option<Self::ScanResumeHandle>,
    ) {
        use std::ops::Bound;

        let lower = resume_from.map_or(Bound::Unbounded, Bound::Excluded);
        let batch: Vec<_> = {
            // Hold the lock only while copying the batch out.
            let nodes = self.lock_nodes();
            nodes
                .range((lower, Bound::Unbounded))
                .take(batch_size)
                .map(|(key, object)| (key.clone(), object.clone()))
                .collect()
        };
        let next_handle = batch.last().map(|(last_key, _)| last_key.clone());
        (batch, next_handle)
    }
}
impl<H: WellBehavedHasher> Default for InMemoryDB<H> {
fn default() -> Self {
Self {
nodes: Arc::default(),
roots: Arc::default(),
}
}
}
#[cfg(test)]
mod tests {
use super::Update::*;
use crate::backend::raw_node::RawNode;
use crate::{
DefaultHasher,
arena::ArenaHash,
backend::OnDiskObject,
db::{DB, InMemoryDB},
};
use rand::Rng;
use std::collections::{HashMap, HashSet};
/// Number of random key/value pairs each `bulk_read_*` test passes to `test_bulk_read`.
const BULK_READ_NUM_KVS: usize = 1000;
/// Runs the bulk-read scenario against `InMemoryDB` for several chunk sizes.
#[test]
fn bulk_read_inmemorydb() {
    for chunk_size in [10, 100, 1000] {
        // Clones share the same storage, so "reopening" sees the inserted data.
        let shared = InMemoryDB::<DefaultHasher>::default();
        test_bulk_read(BULK_READ_NUM_KVS, chunk_size, || shared.clone());
    }
}
/// Runs the bulk-read scenario against an in-memory `SqlDB` for several chunk sizes.
#[cfg(feature = "sqlite")]
#[test]
fn bulk_read_sqldb_memory() {
    for chunk_size in [10, 100, 1000] {
        let shared = crate::db::SqlDB::<DefaultHasher>::memory();
        test_bulk_read(BULK_READ_NUM_KVS, chunk_size, || shared.clone_memory_db());
    }
}
/// Runs the bulk-read scenario against a file-backed `SqlDB` for several chunk sizes.
#[cfg(feature = "sqlite")]
#[test]
fn bulk_read_sqldb_file() {
    for chunk_size in [10, 100, 1000] {
        let temp_file = tempfile::NamedTempFile::new().unwrap();
        // `path` must outlive every reopen performed by the scenario.
        let path = temp_file.into_temp_path();
        test_bulk_read(BULK_READ_NUM_KVS, chunk_size, || {
            crate::db::SqlDB::<DefaultHasher>::exclusive_file(&path)
        });
    }
}
/// Runs the bulk-read scenario against `ParityDb` for several chunk sizes.
#[cfg(feature = "parity-db")]
#[test]
fn bulk_read_paritydb() {
    for chunk_size in [10, 100, 1000] {
        let temp_dir = tempfile::TempDir::new().unwrap();
        let path = temp_dir.keep();
        test_bulk_read(BULK_READ_NUM_KVS, chunk_size, || {
            crate::db::ParityDb::<DefaultHasher>::open(&path)
        });
    }
}
/// Shared bulk-read scenario: batch-inserts `num_kvs` random entries, then
/// reads them back one by one and again in batches of `chunk_size`, printing
/// the ratio of the two read timings.
///
/// `open_db` is called three times; every call must give access to the data
/// written through the earlier handles.
fn test_bulk_read<D: DB, F: Fn() -> D>(num_kvs: usize, chunk_size: usize, open_db: F) {
    let mut db = open_db();
    let mut rng = rand::thread_rng();
    let kvs = (0..num_kvs)
        .map(|_| rng.r#gen())
        .collect::<Vec<(ArenaHash<_>, OnDiskObject<_>)>>();
    let mut t = crate::test::Timer::new("test_bulk_read");
    let iter = kvs.iter().map(|(k, v)| (k.clone(), InsertNode(v.clone())));
    db.batch_update(iter);
    t.delta("batch insert kvs");
    // Drop the handle before reopening so the reads go through a fresh handle.
    drop(db);
    let db = open_db();
    t.delta("reopen db");
    // Every inserted key must be readable individually.
    for (k, _) in &kvs {
        db.get_node(k).unwrap();
    }
    let delta_1by1 = t.delta("read kvs one-by-one");
    drop(db);
    let db = open_db();
    t.delta("reopen db");
    let iter = kvs.clone().into_iter().map(|(k, _)| k);
    use itertools::Itertools;
    let chunks = iter.chunks(chunk_size);
    // Every inserted key must also be found by the batched lookup.
    for chunk in chunks.into_iter() {
        for (_, v) in db.batch_get_nodes(chunk) {
            v.unwrap();
        }
    }
    let delta_batch = t.delta("batch read kvs");
    println!(
        "Speedup for num_kvs={}, chunk_size={}: {:.1}",
        num_kvs,
        chunk_size,
        delta_1by1 / delta_batch
    );
}
/// Number of random key/value pairs each `all_ops_*` test passes to `test_all_ops`.
const ALL_OPS_NUM_KVS: usize = 100;
/// Runs the all-operations scenario against a fresh `InMemoryDB`.
#[test]
fn all_ops_inmemorydb() {
    test_all_ops(
        ALL_OPS_NUM_KVS,
        &mut InMemoryDB::<DefaultHasher>::default(),
    );
}
/// Runs the all-operations scenario against an in-memory `SqlDB`.
#[cfg(feature = "sqlite")]
#[test]
fn all_ops_sqldb_memory() {
    test_all_ops(
        ALL_OPS_NUM_KVS,
        &mut crate::db::SqlDB::<DefaultHasher>::memory(),
    );
}
/// Runs the all-operations scenario against a `SqlDB` backed by a temporary file.
#[cfg(feature = "sqlite")]
#[test]
fn all_ops_sqldb_file() {
    let temp_path = tempfile::NamedTempFile::new().unwrap().into_temp_path();
    let mut sql_db = crate::db::SqlDB::<DefaultHasher>::exclusive_file(temp_path);
    test_all_ops(ALL_OPS_NUM_KVS, &mut sql_db);
}
/// Runs the all-operations scenario against a default-constructed `ParityDb`.
#[cfg(feature = "parity-db")]
#[test]
fn all_ops_paritydb() {
    test_all_ops(
        ALL_OPS_NUM_KVS,
        &mut crate::db::ParityDb::<DefaultHasher>::default(),
    );
}
/// Shared scenario exercising every basic `DB` operation on `num_kvs` random
/// entries: single inserts/gets/deletes, root counts, and the batched
/// equivalents. `db` is expected to be empty on entry and is empty again on
/// return.
fn test_all_ops<D: DB>(num_kvs: usize, db: &mut D) {
    let mut t = crate::test::Timer::new("test_all_ops");
    let mut rng = rand::thread_rng();
    let kvs = (0..num_kvs)
        .map(|_| rng.r#gen())
        .collect::<Vec<(ArenaHash<_>, OnDiskObject<_>)>>();
    t.delta("gen kvs");
    // Single inserts followed by single reads.
    for (k, v) in kvs.clone() {
        db.insert_node(k, v);
    }
    t.delta("insert kvs");
    for (k, v) in kvs.clone() {
        assert_eq!(db.get_node(&k), Some(v));
    }
    assert_eq!(db.size(), num_kvs);
    t.delta("get kvs");
    // Root counts: entry `i` gets count `i`, so the first entry gets 0.
    for (i, (k, _)) in kvs.clone().into_iter().enumerate() {
        db.set_root_count(k, i as u32);
    }
    // Setting root counts must not change the number of stored nodes.
    assert_eq!(db.size(), num_kvs);
    t.delta("set root counts");
    for (i, (k, _)) in kvs.iter().enumerate() {
        assert_eq!(db.get_root_count(k), i as u32);
    }
    t.delta("get root counts");
    // Single deletes: each node is present before and absent after.
    for (k, v) in kvs.clone() {
        assert_eq!(db.get_node(&k), Some(v));
        db.delete_node(&k);
        assert_eq!(db.get_node(&k), None);
    }
    assert_eq!(db.size(), 0);
    t.delta("get, delete, and get kvs");
    // Batched insert + root count for every entry.
    let iter = kvs.iter().enumerate().flat_map(|(i, (k, v))| {
        [
            (k.clone(), InsertNode(v.clone())),
            (k.clone(), SetRootCount(i as u32)),
        ]
    });
    db.batch_update(iter);
    for (i, (k, v)) in kvs.clone().into_iter().enumerate() {
        assert_eq!(db.get_node(&k), Some(v));
        assert_eq!(db.get_root_count(&k), i as u32);
    }
    assert_eq!(db.size(), num_kvs);
    t.delta("batch insert and get kvs and root counts");
    // The first entry has root count 0 and is therefore expected to be
    // missing from `get_roots`, hence the `skip(1)`.
    let root_counts_golden: HashMap<_, _> = kvs
        .iter()
        .enumerate()
        .map(|(i, (k, _))| (k.clone(), i as u32))
        .skip(1)
        .collect();
    let root_counts_db = db.get_roots();
    assert_eq!(root_counts_golden.len(), root_counts_db.len());
    assert_eq!(root_counts_golden, root_counts_db);
    t.delta("batch get all roots");
    // Batched delete + root-count reset leaves the database empty.
    let iter = kvs
        .iter()
        .flat_map(|(k, _)| [(k.clone(), DeleteNode), (k.clone(), SetRootCount(0))]);
    db.batch_update(iter);
    for (k, _) in kvs.clone() {
        assert_eq!(db.get_node(&k), None);
        assert_eq!(db.get_root_count(&k), 0);
    }
    assert_eq!(db.size(), 0);
    t.delta("batch delete and get kvs and root counts");
}
/// Runs the `bfs_get_nodes` scenario against `InMemoryDB`.
#[test]
fn bfs_get_nodes_inmemorydb() {
    test_bfs_get_nodes::<InMemoryDB>();
}
/// Runs the `bfs_get_nodes` scenario against `SqlDB`.
#[cfg(feature = "sqlite")]
#[test]
fn bfs_get_nodes_sqldb() {
    test_bfs_get_nodes::<crate::db::SqlDB>();
}
/// Runs the `bfs_get_nodes` scenario against `ParityDb`.
#[cfg(feature = "parity-db")]
#[test]
fn bfs_get_nodes_paritydb() {
    test_bfs_get_nodes::<crate::db::ParityDb>();
}
/// Shared scenario for `DB::bfs_get_nodes`: builds two overlapping node graphs
/// in a default-constructed database and checks traversal order, interaction
/// with a cache, `truncate`, `max_depth`, and `max_count`.
fn test_bfs_get_nodes<D: DB<Hasher = DefaultHasher>>() {
    use crate::backend::raw_node::RawNode;
    // "n" graph, built leaves first; `n11` is its root.
    let n41 = RawNode::new(&[1, 4, 1], 1, vec![]);
    let n42 = RawNode::new(&[1, 4, 2], 3, vec![]);
    let n43 = RawNode::new(&[1, 4, 3], 2, vec![]);
    let n44 = RawNode::new(&[1, 4, 4], 2, vec![]);
    let n31 = RawNode::new(&[1, 3, 1], 2, vec![&n41, &n42]);
    let n32 = RawNode::new(&[1, 3, 2], 2, vec![&n42, &n43]);
    let n33 = RawNode::new(&[1, 3, 3], 1, vec![&n43, &n44]);
    let n21 = RawNode::new(&[1, 2, 1], 2, vec![&n31, &n42, &n32]);
    let n22 = RawNode::new(&[1, 2, 2], 1, vec![&n32, &n33]);
    let n11 = RawNode::new(&[1, 1, 1], 0, vec![&n31, &n21, &n22]);
    // "o" graph; its root `o11` also points into the "n" graph (`n21`, `n44`).
    let o31 = RawNode::new(&[2, 3, 1], 1, vec![]);
    let o32 = RawNode::new(&[2, 3, 2], 1, vec![]);
    let o21 = RawNode::new(&[2, 2, 1], 1, vec![&o31, &o32]);
    let o11 = RawNode::new(&[2, 1, 1], 0, vec![&n21, &n44, &o21]);
    let n_nodes = [&n41, &n42, &n43, &n44, &n31, &n32, &n33, &n21, &n22, &n11];
    let o_nodes: [&RawNode; 4] = [&o31, &o32, &o21, &o11];
    let mut db = D::default();
    for n in n_nodes.iter().chain(o_nodes.iter()) {
        n.insert_into_db(&mut db);
    }
    // No cache, no limits: the whole "n" graph in exact breadth-first order.
    let kvs = db.bfs_get_nodes(&n11.key, |_| None, false, None, None);
    let keys: std::vec::Vec<_> = kvs.clone().into_iter().map(|(k, _)| k).collect();
    let expected_keys: std::vec::Vec<_> =
        [&n11, &n31, &n21, &n22, &n41, &n42, &n32, &n33, &n43, &n44]
            .map(|n| n.key.clone())
            .into_iter()
            .collect();
    assert_eq!(keys, expected_keys);
    // With the whole "n" graph cached, traversing from `o11` returns only the
    // "o" nodes.
    let cache: HashMap<_, _> = kvs.into_iter().collect();
    let kvs = db.bfs_get_nodes(&o11.key, |key| cache.get(key).cloned(), false, None, None);
    let keys: HashSet<_> = kvs.into_iter().map(|(k, _)| k).collect();
    let expected_keys: HashSet<_> = o_nodes.iter().map(|n| n.key.clone()).collect();
    assert_eq!(keys, expected_keys);
    // Partial cache without truncation: everything not in the cache is returned.
    let cache: HashMap<_, _> = [&n21, &n22, &n41]
        .map(|n| (n.key.clone(), n.clone().into_obj()))
        .into_iter()
        .collect();
    let kvs = db.bfs_get_nodes(&n11.key, |key| cache.get(key).cloned(), false, None, None);
    let keys: HashSet<_> = kvs.into_iter().map(|(k, _)| k).collect();
    let expected_keys: HashSet<_> = n_nodes
        .iter()
        .filter(|n| !cache.contains_key(&n.key))
        .map(|n| n.key.clone())
        .collect();
    assert_eq!(keys, expected_keys);
    // Same cache with truncation: children of cached nodes are not followed.
    let kvs = db.bfs_get_nodes(&n11.key, |key| cache.get(key).cloned(), true, None, None);
    let keys: HashSet<_> = kvs.into_iter().map(|(k, _)| k).collect();
    let expected_keys: HashSet<_> = [&n11, &n31, &n42]
        .map(|n| n.key.clone())
        .into_iter()
        .collect();
    assert_eq!(keys, expected_keys);
    // Depth limit 2: levels 0, 1 and 2 only (order is not checked here).
    let kvs = db.bfs_get_nodes(&n11.key, |_| None, false, Some(2), None);
    let mut keys: std::vec::Vec<_> = kvs.into_iter().map(|(k, _)| k).collect();
    let mut expected_keys: std::vec::Vec<_> = [
        &n11.key, &n21.key, &n22.key, &n31.key, &n32.key, &n33.key, &n41.key, &n42.key,
    ]
    .into_iter()
    .cloned()
    .collect();
    keys.sort();
    expected_keys.sort();
    assert_eq!(keys, expected_keys);
    // Count limit 5: only the number of returned nodes is checked.
    let kvs = db.bfs_get_nodes(&n11.key, |_| None, false, None, Some(5));
    let keys: std::vec::Vec<_> = kvs.into_iter().map(|(k, _)| k).collect();
    assert_eq!(keys.len(), 5);
}
/// Runs the `get_unreachable_keys` scenario against `InMemoryDB`.
#[cfg(not(feature = "layout-v2"))]
#[test]
fn get_unreachable_keys_inmemorydb() {
    test_get_unreachable_keys::<InMemoryDB>();
}
/// Runs the `get_unreachable_keys` scenario against `SqlDB`.
#[cfg(all(feature = "sqlite", not(feature = "layout-v2")))]
#[test]
fn get_unreachable_keys_sqldb() {
    test_get_unreachable_keys::<crate::db::SqlDB>();
}
/// Runs the `get_unreachable_keys` scenario against `ParityDb`.
#[cfg(all(feature = "parity-db", not(feature = "layout-v2")))]
#[test]
fn get_unreachable_keys_paritydb() {
    test_get_unreachable_keys::<crate::db::ParityDb>();
}
/// Shared scenario for `DB::get_unreachable_keys`: the nodes constructed with
/// a second `RawNode::new` argument of 0 are expected to be reported, except
/// those that currently have a non-zero root count.
#[cfg(not(feature = "layout-v2"))]
fn test_get_unreachable_keys<D: DB<Hasher = DefaultHasher>>() {
    let mut db = D::default();
    let n41 = RawNode::new(&[4, 1], 0, vec![]);
    let n31 = RawNode::new(&[3, 1], 1, vec![]);
    let n32 = RawNode::new(&[3, 2], 0, vec![]);
    let n33 = RawNode::new(&[3, 3], 1, vec![]);
    let n21 = RawNode::new(&[2, 1], 0, vec![&n31, &n33]);
    let n22 = RawNode::new(&[2, 2], 1, vec![]);
    let n11 = RawNode::new(&[1, 1], 0, vec![&n22]);
    let nodes = [&n41, &n31, &n32, &n33, &n21, &n22, &n11];
    for n in nodes {
        n.insert_into_db(&mut db);
    }
    // Without any roots, exactly the four nodes created with 0 are reported.
    let keys: HashSet<_> = [&n11, &n21, &n32, &n41]
        .map(|n| n.key.clone())
        .into_iter()
        .collect();
    assert_eq!(keys, db.get_unreachable_keys().into_iter().collect());
    // Marking `n11` as a root removes it from the result; `n22` was never in it.
    db.set_root_count(n11.key.clone(), 1);
    db.set_root_count(n22.key.clone(), 1);
    let keys: HashSet<_> = [&n21, &n32, &n41]
        .map(|n| n.key.clone())
        .into_iter()
        .collect();
    assert_eq!(keys, db.get_unreachable_keys().into_iter().collect());
    // Reset the root counts; nothing is asserted after this point.
    db.set_root_count(n11.key.clone(), 0);
    db.set_root_count(n22.key.clone(), 0);
}
/// Runs the ref-count overwrite scenario against `InMemoryDB`.
#[test]
fn update_ref_count_inmemorydb() {
    test_update_ref_count::<InMemoryDB>();
}
/// Runs the ref-count overwrite scenario against `SqlDB`.
#[cfg(feature = "sqlite")]
#[test]
fn update_ref_count_sqldb() {
    test_update_ref_count::<crate::db::SqlDB>();
}
/// Runs the ref-count overwrite scenario against `ParityDb`.
#[cfg(feature = "parity-db")]
#[test]
fn update_ref_count_paritydb() {
    test_update_ref_count::<crate::db::ParityDb>();
}
/// Shared scenario: inserting a second node under an already-used key must
/// replace the stored object.
fn test_update_ref_count<D: DB>() {
    let mut db = D::default();

    // Two nodes built from the same bytes but a different second argument
    // must share one key.
    let original = RawNode::new(&[1], 0, vec![]);
    let updated = RawNode::new(&[1], 1, vec![]);
    assert_eq!(original.key, updated.key);
    let key = original.key.clone();

    original.insert_into_db(&mut db);
    let stored = db.get_node(&key).unwrap();
    assert_eq!(stored, original.into_obj());

    // The second insert overwrites what the first one stored.
    updated.insert_into_db(&mut db);
    let stored = db.get_node(&key).unwrap();
    assert_eq!(stored, updated.into_obj());
}
}