#![allow(rustdoc::private_intra_doc_links)]
#![allow(clippy::derived_hash_with_manual_eq)]
use crate::storable::{Loader, child_from};
use crate::storage::{DEFAULT_CACHE_SIZE, default_storage};
use crate::{
DefaultDB, DefaultHasher,
backend::{OnDiskObject, StorageBackend},
db::DB,
};
use crate::{Storable, WellBehavedHasher};
use base_crypto::hash::PERSISTENT_HASH_BYTES;
#[allow(deprecated)]
use crypto::digest::{Digest, OutputSizeUser, crypto_common::generic_array::GenericArray};
use derive_where::derive_where;
use hex::ToHex;
use parking_lot::{ReentrantMutex as SyncMutex, ReentrantMutexGuard as MutexGuard};
use rand::Rng;
use rand::distributions::{Distribution, Standard};
use serialize::{self, Deserializable, Serializable, Tagged};
use std::any::TypeId;
use std::cell::RefCell;
use std::fmt::Display;
use std::io;
use std::marker::PhantomData;
use std::rc::Rc;
use std::sync::OnceLock;
use std::{
any::Any,
collections::{HashMap, HashSet},
fmt::Debug,
hash::Hash,
io::Read,
ops::Deref,
sync::Arc,
};
/// Per-type construction statistics collected by `BackendLoader::get` when the
/// `test-utilities` feature is enabled.
///
/// Maps `std::any::type_name` of the constructed type to `(number of
/// constructions, total time spent constructing)`. Starts as `None`; the map is
/// created on first use by `ConstructTracker`.
#[cfg(feature = "test-utilities")]
pub static TCONSTRUCT: std::sync::Mutex<
    Option<HashMap<&'static str, (usize, std::time::Duration)>>,
> = std::sync::Mutex::new(None);
/// Computes the content hash of an arena node.
///
/// The digest input is:
/// 1. the length of `root_binary_repr` as a little-endian `u32`;
/// 2. the bytes of `root_binary_repr`;
/// 3. every child hash, in iteration order.
///
/// The length prefix is a truncating `u32` cast. Changing it would change
/// every node hash, so it is kept as-is.
pub(crate) fn hash<'a, H: WellBehavedHasher>(
    root_binary_repr: &[u8],
    child_hashes: impl Iterator<Item = &'a ArenaHash<H>>,
) -> ArenaHash<H> {
    let length_prefix = (root_binary_repr.len() as u32).to_le_bytes();
    let mut digest = H::default();
    digest.update(length_prefix);
    digest.update(root_binary_repr);
    child_hashes.for_each(|child| digest.update(&child.0));
    ArenaHash(digest.finalize())
}
/// An [`ArenaKey`] tagged at the type level with the type `T` of the value it
/// refers to.
///
/// `T` is only a marker (`PhantomData`). All `PhantomData` values compare
/// equal, so the derived comparisons and ordering are determined by `key`.
#[derive_where(Debug, Clone, PartialEq, Eq, Ord, PartialOrd)]
#[derive(Serializable)]
#[phantom(T, H)]
pub struct TypedArenaKey<T: ?Sized, H: WellBehavedHasher> {
    pub key: ArenaKey<H>,
    _phantom: PhantomData<T>,
}
impl<T, H: WellBehavedHasher> TypedArenaKey<T, H> {
pub fn refs(&self) -> Vec<&ArenaHash<H>> {
self.key.refs()
}
}
impl<T, H: WellBehavedHasher> From<TypedArenaKey<T, H>> for ArenaKey<H> {
fn from(val: TypedArenaKey<T, H>) -> Self {
val.key
}
}
impl<T, H: WellBehavedHasher> From<ArenaKey<H>> for TypedArenaKey<T, H> {
    /// Attaches the type tag `T` to an untyped key. No check is performed that
    /// the key actually refers to a `T`.
    fn from(val: ArenaKey<H>) -> Self {
        Self {
            _phantom: PhantomData,
            key: val,
        }
    }
}
impl<T: Tagged, H: WellBehavedHasher> Tagged for TypedArenaKey<T, H> {
    /// `storage-key(<tag of T>)`.
    fn tag() -> std::borrow::Cow<'static, str> {
        let inner = T::tag();
        std::borrow::Cow::from(format!("storage-key({inner})"))
    }
    /// Independent of `T`: every typed key shares the `storage-key` factor.
    fn tag_unique_factor() -> String {
        String::from("storage-key")
    }
}
/// The raw digest output of hasher `H`: a byte array of `H::OutputSize` bytes.
// NOTE(review): `allow(deprecated)` presumably silences a deprecation of
// `GenericArray` in the `crypto` version in use; the same allowance appears
// wherever it is named in this file. Confirm.
#[allow(deprecated)]
type HashArray<H> = GenericArray<u8, <H as OutputSizeUser>::OutputSize>;
/// A content hash identifying a node in the arena (see `hash`).
///
/// Wraps the raw digest output of `H`. `Default` is the all-zero byte array.
#[derive_where(Clone, PartialEq, Eq, Ord, PartialOrd, Default)]
pub struct ArenaHash<H: Digest = DefaultHasher>(pub HashArray<H>);
impl<H: Digest> Tagged for ArenaHash<H> {
    /// Always `storage-hash`, regardless of the hasher.
    fn tag() -> std::borrow::Cow<'static, str> {
        std::borrow::Cow::Borrowed("storage-hash")
    }
    fn tag_unique_factor() -> String {
        String::from("storage-hash")
    }
}
impl<D: DB> Storable<D> for ArenaHash<D::Hasher> {
fn children(&self) -> std::vec::Vec<ArenaKey<<D as DB>::Hasher>> {
std::vec::Vec::new()
}
fn to_binary_repr<W: std::io::Write>(&self, writer: &mut W) -> Result<(), std::io::Error>
where
Self: Sized,
{
writer.write_all(&self.0)?;
Ok(())
}
fn from_binary_repr<R: std::io::Read>(
reader: &mut R,
_child_hashes: &mut impl Iterator<Item = ArenaKey<<D as DB>::Hasher>>,
_loader: &impl Loader<D>,
) -> Result<Self, std::io::Error>
where
Self: Sized,
{
#[allow(deprecated)]
let mut array = GenericArray::<u8, <D::Hasher as OutputSizeUser>::OutputSize>::default();
reader.read_exact(&mut array)?;
Ok(Self(array))
}
}
impl<H: Digest> Debug for ArenaHash<H> {
    /// Formats the hash as its hexadecimal encoding (via `ToHex::encode_hex`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let hex: String = self.0.encode_hex();
        f.write_str(&hex)
    }
}
impl<D: Digest> Hash for ArenaHash<D> {
    /// Feeds the digest bytes into `state`, exactly as hashing the inner array does.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.0.hash(state)
    }
    // `hash_slice` is intentionally not overridden. The previous override
    // copied every digest into a temporary `Vec` only to hash the elements one
    // by one. The default `Hash::hash_slice` already does that without
    // allocating.
}
impl<H: Digest> Serializable for ArenaHash<H> {
    /// Writes the raw digest bytes, with no length prefix.
    fn serialize(&self, writer: &mut impl std::io::Write) -> std::io::Result<()> {
        let bytes: &[u8] = &self.0;
        writer.write_all(bytes)
    }
    /// Always the digest output size of `H`.
    fn serialized_size(&self) -> usize {
        <H as Digest>::output_size()
    }
}
impl<H: Digest> Deserializable for ArenaHash<H> {
fn deserialize(
reader: &mut impl std::io::Read,
_recursive_depth: u32,
) -> std::io::Result<Self> {
let mut res = vec![0u8; <H as Digest>::output_size()];
reader.read_exact(&mut res[..])?;
#[allow(deprecated)]
Ok(ArenaHash(GenericArray::clone_from_slice(&res)))
}
}
impl<H: Digest> Distribution<ArenaHash<H>> for Standard {
fn sample<R: rand::prelude::Rng + ?Sized>(&self, rng: &mut R) -> ArenaHash<H> {
#[allow(deprecated)]
let mut bytes = GenericArray::default();
rng.fill_bytes(&mut bytes);
ArenaHash(bytes)
}
}
impl<H: Digest> serde::Serialize for ArenaHash<H> {
    /// Serializes the digest through `serialize_bytes`.
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let bytes: &[u8] = &self.0;
        serializer.serialize_bytes(bytes)
    }
}
impl<'de, H: Digest> serde::Deserialize<'de> for ArenaHash<H> {
fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
struct ArenaHashVisitor<H: Digest>(std::marker::PhantomData<H>);
impl<'de, H: Digest> serde::de::Visitor<'de> for ArenaHashVisitor<H> {
type Value = ArenaHash<H>;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(
formatter,
"a byte array of length {}",
<H as Digest>::output_size()
)
}
fn visit_bytes<E: serde::de::Error>(self, v: &[u8]) -> Result<Self::Value, E> {
if v.len() != <H as Digest>::output_size() {
return Err(E::invalid_length(v.len(), &self));
}
#[allow(deprecated)]
Ok(ArenaHash(GenericArray::clone_from_slice(v)))
}
fn visit_byte_buf<E: serde::de::Error>(self, v: Vec<u8>) -> Result<Self::Value, E> {
self.visit_bytes(&v)
}
}
deserializer.deserialize_bytes(ArenaHashVisitor(std::marker::PhantomData))
}
}
impl<H: Digest> ArenaHash<H> {
    /// Builds a hash by copying `bs` into a zero-initialised digest array.
    ///
    /// An input shorter than the digest is zero-padded at the end.
    ///
    /// # Panics
    /// If `bs` is longer than the digest (out-of-bounds index).
    pub(crate) fn _from_bytes(bs: &[u8]) -> Self {
        #[allow(deprecated)]
        let mut bytes = GenericArray::default();
        for (i, b) in bs.iter().enumerate() {
            bytes[i] = *b;
        }
        ArenaHash(bytes)
    }
}
/// A reference to an arena node, in one of two representations:
///
/// * `Ref(hash)`: the node is tracked by the arena and identified by its
///   content hash;
/// * `Direct(node)`: the node's data and children are carried inline.
///
/// `child_from` decides which representation a freshly allocated value gets
/// (see `Arena::alloc`). Only `Ref` nodes are reference-counted by `Sp`.
#[derive(Debug, Clone, Storable, Serializable)]
#[derive_where(Hash, PartialEq, Eq, PartialOrd, Ord)]
#[storable(base)]
#[tag = "storage-key[v2]"]
#[phantom(H)]
pub enum ArenaKey<H: WellBehavedHasher = DefaultHasher> {
    Ref(ArenaHash<H>),
    Direct(DirectChildNode<H>),
}
impl<H: WellBehavedHasher> From<ArenaHash<H>> for ArenaKey<H> {
    /// A bare hash always becomes a `Ref` key.
    fn from(value: ArenaHash<H>) -> Self {
        Self::Ref(value)
    }
}
impl<H: WellBehavedHasher> Distribution<ArenaKey<H>> for Standard {
fn sample<R: Rng + ?Sized>(&self, rng: &mut R) -> ArenaKey<H> {
ArenaKey::Ref(rng.r#gen())
}
}
impl<H: WellBehavedHasher> ArenaKey<H> {
    /// Returns the content hash of the referenced node, whichever representation it uses.
    pub fn hash(&self) -> &ArenaHash<H> {
        match self {
            ArenaKey::Direct(node) => &node.hash,
            ArenaKey::Ref(hash) => hash,
        }
    }
    /// Collects every `Ref` hash reachable from this key.
    ///
    /// Inline `Direct` nodes are descended into; `Ref`s are not followed.
    /// Traversal is depth-first with an explicit stack, so later children of a
    /// `Direct` node are reported before earlier ones.
    pub fn refs(&self) -> Vec<&ArenaHash<H>> {
        let mut found = Vec::with_capacity(32);
        let mut pending = Vec::with_capacity(32);
        pending.push(self);
        while let Some(current) = pending.pop() {
            match current {
                ArenaKey::Direct(direct) => pending.extend(direct.children.iter()),
                ArenaKey::Ref(hash) => found.push(hash),
            }
        }
        found
    }
    /// Returns the hash if this is a `Ref` key, `None` for inline nodes.
    #[cfg(test)]
    pub fn into_ref(&self) -> Option<&ArenaHash<H>> {
        if let ArenaKey::Ref(key) = self {
            Some(key)
        } else {
            None
        }
    }
}
/// An inline arena node: its serialized data and child keys, with the content
/// hash and serialized size precomputed by `DirectChildNode::new`.
///
/// Equality compares only `hash` (manual `PartialEq` below). The derived
/// `Ord`/`Hash` use all fields, which stays consistent as long as the hash
/// determines the other fields.
#[derive(Debug, Clone)]
#[derive_where(PartialOrd, Ord, Hash)]
pub struct DirectChildNode<H: WellBehavedHasher> {
    // Shared behind `Arc` so cloning a key does not copy the payload.
    pub data: Arc<Vec<u8>>,
    pub children: Arc<Vec<ArenaKey<H>>>,
    pub(crate) hash: ArenaHash<H>,
    pub(crate) serialized_size: usize,
}
impl<H: WellBehavedHasher> PartialEq for DirectChildNode<H> {
    /// Two inline nodes are equal exactly when their content hashes are.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.hash, &other.hash);
        lhs == rhs
    }
}
impl<H: WellBehavedHasher> Eq for DirectChildNode<H> {}
impl<H: WellBehavedHasher> DirectChildNode<H> {
    /// Builds an inline node, computing its content hash and serialized size up front.
    pub(crate) fn new(data: Vec<u8>, children: Vec<ArenaKey<H>>) -> Self {
        let serialized_size = data.serialized_size() + children.serialized_size();
        let hash = crate::arena::hash(&data, children.iter().map(|child| child.hash()));
        Self {
            data: Arc::new(data),
            children: Arc::new(children),
            hash,
            serialized_size,
        }
    }
}
impl<H: WellBehavedHasher> Serializable for DirectChildNode<H> {
    /// Writes `data`, then `children`. The hash and size are not written; they
    /// are recomputed on deserialization.
    fn serialize(&self, writer: &mut impl std::io::Write) -> std::io::Result<()> {
        self.data
            .serialize(writer)
            .and_then(|()| self.children.serialize(writer))
    }
    /// Returns the size cached at construction time.
    fn serialized_size(&self) -> usize {
        self.serialized_size
    }
}
impl<H: WellBehavedHasher> Tagged for DirectChildNode<H> {
    fn tag() -> std::borrow::Cow<'static, str> {
        "storage-direct-child-node[v1]".into()
    }
    /// Mirrors the serialized layout: a byte vector followed by a vector of keys.
    fn tag_unique_factor() -> String {
        String::from("(vec(u8),vec(storage-key[v2]))")
    }
}
impl<H: WellBehavedHasher> Deserializable for DirectChildNode<H> {
    /// Reads `data` then `children`, each one recursion level deeper. The node
    /// is rebuilt with `new`, which recomputes its hash and size.
    fn deserialize(reader: &mut impl std::io::Read, recursion_depth: u32) -> std::io::Result<Self> {
        let nested_depth = recursion_depth + 1;
        let data = <Vec<u8> as Deserializable>::deserialize(reader, nested_depth)?;
        let children = <Vec<ArenaKey<H>> as Deserializable>::deserialize(reader, nested_depth)?;
        Ok(Self::new(data, children))
    }
}
/// A content-addressed store of `Storable` values.
///
/// Cloning an `Arena` is cheap and the clones share the same state. Each piece
/// of state sits behind an `Arc` and a re-entrant mutex:
/// * the per-key reference counts (`metadata`);
/// * a cache of weak pointers to decoded values (`sp_cache`);
/// * the storage `backend`.
#[derive(Debug)]
#[derive_where(Clone)]
pub struct Arena<D: DB = DefaultDB> {
    metadata: Arc<SyncMutex<RefCell<MetaData<D>>>>,
    sp_cache: Arc<SyncMutex<RefCell<SpCache<D>>>>,
    backend: Arc<SyncMutex<RefCell<StorageBackend<D>>>>,
}
impl<D: DB> Default for Arena<D> {
fn default() -> Self {
Self::new_from_backend(StorageBackend::<D>::new(DEFAULT_CACHE_SIZE, D::default()))
}
}
/// Reference counts of the keys currently tracked by an arena.
type MetaData<D> = HashMap<ArenaHash<<D as DB>::Hasher>, Node>;
/// A content hash paired with the `TypeId` of the value decoded from it.
type DynTypedArenaHash<H> = (ArenaHash<H>, TypeId);
/// Weak pointers to values already decoded from the arena, keyed by hash and
/// type, so that repeated loads of the same node as the same type can share one `Arc`.
type SpCache<D> =
    HashMap<DynTypedArenaHash<<D as DB>::Hasher>, std::sync::Weak<dyn Any + Sync + Send>>;
#[allow(clippy::type_complexity)]
impl<D: DB> Arena<D> {
    /// Locks the reference-count table.
    ///
    /// All three arena locks are `parking_lot` re-entrant mutexes, so a thread
    /// may re-acquire one it already holds. Mutation goes through the inner `RefCell`.
    #[allow(clippy::type_complexity)]
    fn lock_metadata(&self) -> MutexGuard<'_, RefCell<MetaData<D>>> {
        self.metadata.lock()
    }
    /// Locks the storage backend.
    fn lock_backend(&self) -> MutexGuard<'_, RefCell<StorageBackend<D>>> {
        self.backend.lock()
    }
    /// Locks the cache of weak pointers to decoded values.
    fn lock_sp_cache(&self) -> MutexGuard<'_, RefCell<SpCache<D>>> {
        self.sp_cache.lock()
    }
    /// Wraps `backend` in a new arena with empty metadata and sp cache.
    pub(crate) fn new_from_backend(backend: StorageBackend<D>) -> Self {
        Arena {
            backend: Arc::new(SyncMutex::new(RefCell::new(backend))),
            metadata: Arc::new(SyncMutex::new(RefCell::new(HashMap::new()))),
            sp_cache: Arc::new(SyncMutex::new(RefCell::new(HashMap::new()))),
        }
    }
    /// Runs `f` with mutable access to the storage backend while holding its lock.
    pub fn with_backend<R>(&self, f: impl FnOnce(&mut StorageBackend<D>) -> R) -> R {
        f(&mut RefCell::borrow_mut(&self.lock_backend()))
    }
    /// Allocates `value` in the arena and returns a pointer to it.
    ///
    /// `child_from` picks the node's representation:
    /// * a `Ref` node is tracked in the metadata and cached in the backend;
    /// * a `Direct` node is kept inline in the returned pointer and is not tracked.
    ///
    /// # Panics
    /// If `value` has more than 16 children, or if its binary representation
    /// cannot be written.
    pub fn alloc<T: Storable<D>>(&self, value: T) -> Sp<T, D> {
        let children = value.children();
        assert!(
            children.len() <= 16,
            "In order to represent the arena as an MPT Storable values must have no more than 16 children (found: {} on type {})",
            children.len(),
            std::any::type_name::<T>(),
        );
        let mut data: std::vec::Vec<u8> = std::vec::Vec::new();
        value
            .to_binary_repr(&mut data)
            .expect("Storable data should be able to be represented in binary");
        let child_repr = child_from(&data, &children);
        let root_hash = child_repr.hash().clone();
        if let ArenaKey::Ref(_) = &child_repr {
            self.new_sp_locked(
                &mut self.lock_metadata(),
                value,
                root_hash,
                data,
                children,
                child_repr,
            )
        } else {
            Sp {
                arena: self.clone(),
                data: OnceLock::from(Arc::new(value)),
                child_repr,
                root: root_hash,
            }
        }
    }
    /// Tracks `key` and returns an eager pointer to its value.
    ///
    /// If a live `Arc<T>` for `key` is already in the sp cache it is shared and
    /// `value` is dropped. Otherwise `value` is wrapped and cached. The caller
    /// must hold the metadata lock (`metadata`).
    fn new_sp_locked<T: Storable<D>>(
        &self,
        metadata: &mut MutexGuard<'_, RefCell<MetaData<D>>>,
        value: T,
        key: ArenaHash<D::Hasher>,
        data: std::vec::Vec<u8>,
        children: std::vec::Vec<ArenaKey<D::Hasher>>,
        child_repr: ArenaKey<D::Hasher>,
    ) -> Sp<T, D> {
        self.track_locked(metadata, key.clone(), data, children, &child_repr);
        let arc = {
            let guard = &self.lock_sp_cache();
            match self.read_sp_cache_locked(guard, &key) {
                Some(arc) => arc,
                None => {
                    let arc = Arc::new(value);
                    self.write_sp_cache_locked(guard, key.clone(), arc.clone());
                    arc
                }
            }
        };
        Sp::eager(self.clone(), key, arc, child_repr)
    }
    /// Same as `new_sp_locked`, acquiring the metadata lock itself.
    fn new_sp<T: Storable<D>>(
        &self,
        value: T,
        key: ArenaHash<D::Hasher>,
        data: std::vec::Vec<u8>,
        children: std::vec::Vec<ArenaKey<D::Hasher>>,
        child_repr: ArenaKey<D::Hasher>,
    ) -> Sp<T, D> {
        self.new_sp_locked(
            &mut self.lock_metadata(),
            value,
            key,
            data,
            children,
            child_repr,
        )
    }
    /// Returns the live `Arc<T>` cached for `(key, T)`.
    ///
    /// Returns `None` if there is no entry or the value has been dropped.
    fn read_sp_cache_locked<T: Sync + Send + Any>(
        &self,
        sp_cache: &MutexGuard<RefCell<SpCache<D>>>,
        key: &ArenaHash<D::Hasher>,
    ) -> Option<Arc<T>> {
        let type_id = TypeId::of::<T>();
        let cache_key = (key.clone(), type_id);
        let sp_cache = RefCell::borrow(sp_cache);
        sp_cache
            .get(&cache_key)
            .and_then(|weak| weak.upgrade())
            // `upgrade` already yields an owned `Arc`. Entries are keyed by
            // `TypeId::of::<T>()`, so the downcast cannot fail.
            .map(|arc| arc.downcast::<T>().unwrap())
    }
    /// Records a weak pointer to `value` under `(key, T)`, replacing any previous entry.
    fn write_sp_cache_locked<T: Storable<D>>(
        &self,
        sp_cache: &MutexGuard<RefCell<SpCache<D>>>,
        key: ArenaHash<D::Hasher>,
        value: Arc<T>,
    ) {
        let type_id = TypeId::of::<T>();
        let cache_key = (key, type_id);
        let arc: Arc<dyn Any + Send + Sync> = value;
        RefCell::borrow_mut(sp_cache).insert(cache_key, Arc::downgrade(&arc));
    }
    /// Number of keys currently tracked (with live or lazy pointers).
    pub fn size(&self) -> usize {
        self.lock_metadata().borrow().len()
    }
    /// Returns an eager pointer to an already-decoded `T` for `key`, if one is
    /// still alive in the sp cache.
    fn get_from_cache<T: Storable<D>>(&self, key: &ArenaHash<D::Hasher>) -> Option<Sp<T, D>> {
        let _metadata_lock = self.lock_metadata();
        let sp_cache_lock = self.lock_sp_cache();
        self.read_sp_cache_locked::<T>(&sp_cache_lock, key)
            .map(|arc| {
                let child_repr = arc.as_child();
                Sp::eager(self.clone(), key.clone(), arc, child_repr)
            })
    }
    /// Fully loads the value behind `key` (no depth limit).
    pub fn get<T: Storable<D>>(
        &self,
        key: &TypedArenaKey<T, D::Hasher>,
    ) -> Result<Sp<T, D>, std::io::Error> {
        self.get_unversioned(&key.key)
    }
    /// Untyped variant of `get`.
    pub(crate) fn get_unversioned<T: Storable<D>>(
        &self,
        key: &ArenaKey<D::Hasher>,
    ) -> Result<Sp<T, D>, std::io::Error> {
        let max_depth = None;
        Sp::<T, D>::from_arena(self, key, max_depth)
    }
    /// Returns the child keys stored in the backend for `key`.
    ///
    /// # Errors
    /// `NotFound` if `key` is not in the backend.
    pub fn children(
        &self,
        key: &ArenaHash<D::Hasher>,
    ) -> Result<Vec<ArenaKey<D::Hasher>>, io::Error> {
        Ok(self
            .lock_backend()
            .borrow_mut()
            .get(key)
            // Build the error lazily; `format!` would otherwise allocate on every success.
            .ok_or_else(|| {
                io::Error::new(
                    io::ErrorKind::NotFound,
                    format!("BackendLoader::get(): key {key:?} not in storage arena. Are you sure you persisted this key or one of its ancestors?"),
                )
            })?
            .children
            .clone())
    }
    /// Returns a lazy pointer to `key` without loading its value.
    pub fn get_lazy<T: Storable<D> + Tagged>(
        &self,
        key: &TypedArenaKey<T, D::Hasher>,
    ) -> Result<Sp<T, D>, std::io::Error> {
        self.get_lazy_unversioned(&key.key)
    }
    /// Untyped variant of `get_lazy`.
    pub(crate) fn get_lazy_unversioned<T: Storable<D>>(
        &self,
        key: &ArenaKey<D::Hasher>,
    ) -> Result<Sp<T, D>, std::io::Error> {
        let max_depth = Some(0);
        Sp::<T, D>::from_arena(self, key, max_depth)
    }
    /// Starts tracking `key` with a zero reference count if it is not tracked yet.
    ///
    /// For `Ref` keys, the node's data and children are also cached in the backend.
    fn track_locked(
        &self,
        metadata: &MutexGuard<'_, RefCell<MetaData<D>>>,
        key: ArenaHash<D::Hasher>,
        data: std::vec::Vec<u8>,
        children: std::vec::Vec<ArenaKey<D::Hasher>>,
        child_repr: &ArenaKey<D::Hasher>,
    ) {
        if !RefCell::borrow(metadata).contains_key(&key) {
            RefCell::borrow_mut(metadata).insert(key.clone(), Node::new());
            if let ArenaKey::Ref(_) = child_repr {
                RefCell::borrow_mut(&self.lock_backend()).cache(key, data, children);
            }
        }
    }
    /// Like `track_locked`, for a node that has not been loaded.
    ///
    /// For `Ref` keys the backend is only notified through `cache_lazy`.
    fn track_lazy(
        &self,
        metadata: &MutexGuard<'_, RefCell<MetaData<D>>>,
        key: ArenaHash<D::Hasher>,
        child_repr: &ArenaKey<D::Hasher>,
    ) {
        if !RefCell::borrow(metadata).contains_key(&key) {
            RefCell::borrow_mut(metadata).insert(key.clone(), Node::new());
            if let ArenaKey::Ref(_) = child_repr {
                RefCell::borrow_mut(&self.lock_backend()).cache_lazy(key);
            }
        }
    }
    /// Stops tracking `key` and uncaches it from the backend.
    fn remove_locked(
        &self,
        metadata: &mut MutexGuard<'_, RefCell<MetaData<D>>>,
        key: &ArenaHash<D::Hasher>,
    ) {
        RefCell::borrow_mut(metadata).remove(key);
        RefCell::borrow_mut(&self.lock_backend()).uncache(key);
    }
    /// Decrements the reference count of `key`.
    ///
    /// When the count reaches zero the key is removed via `remove_locked`.
    /// Untracked keys are ignored.
    fn decrement_ref_locked(
        &self,
        metadata: &mut MutexGuard<'_, RefCell<MetaData<D>>>,
        key: &ArenaHash<D::Hasher>,
    ) {
        let mut remove = None;
        if let Some(v) = RefCell::borrow_mut(metadata).get_mut(key) {
            v.ref_count -= 1;
            if v.ref_count == 0 {
                remove = Some(key);
            }
        }
        // Removal happens after the `RefCell` borrow above has been released.
        if let Some(key) = remove {
            self.remove_locked(metadata, key);
        }
    }
    /// Same as `decrement_ref_locked`, acquiring the metadata lock itself.
    fn decrement_ref(&self, key: &ArenaHash<D::Hasher>) {
        self.decrement_ref_locked(&mut self.lock_metadata(), key);
    }
    /// Increments the reference count of `key`.
    ///
    /// # Panics
    /// If `key` is not tracked.
    fn increment_ref_locked(
        &self,
        metadata: &mut MutexGuard<'_, RefCell<MetaData<D>>>,
        key: &ArenaHash<D::Hasher>,
    ) {
        let mut metadata = RefCell::borrow_mut(metadata);
        let rc = metadata
            .get_mut(key)
            .expect("attempted to increment non-existant ref");
        rc.ref_count += 1;
    }
    /// Same as `increment_ref_locked`, acquiring the metadata lock itself.
    fn increment_ref(&self, key: &ArenaHash<D::Hasher>) {
        self.increment_ref_locked(&mut self.lock_metadata(), key)
    }
    /// Deserializes a storage graph (as produced by `serialize_to_node_list`)
    /// and loads its last node, the root, as an `Sp<T, D>`.
    ///
    /// Nodes must be topologically sorted: every child index refers to an
    /// earlier node. The loaded value is re-serialized and compared with the
    /// input, so only graphs in normal form are accepted.
    ///
    /// # Errors
    /// * `InvalidData` for out-of-range or forward child indices, or for
    ///   non-normal input.
    /// * `NotFound` for an empty graph.
    /// * Any error raised while decoding the nodes.
    #[inline(always)]
    pub fn deserialize_sp<T: Storable<D>, R: Read>(
        &self,
        reader: &mut R,
        recursive_depth: u32,
    ) -> Result<Sp<T, D>, std::io::Error> {
        let nodes: TopoSortedNodes = Deserializable::deserialize(reader, recursive_depth)?;
        // Returns the already-processed entry at index `i`, rejecting forward or
        // out-of-range references.
        fn idx_existing_nodes<N>(n: &[N], i: u64) -> std::io::Result<&N> {
            if i < n.len() as u64 {
                Ok(&n[i as usize])
            } else {
                Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    format!(
                        "error deserializing storage graph: child node index {i} out of range of processed nodes {}",
                        n.len()
                    ),
                ))
            }
        }
        let node_count = nodes.nodes.len();
        let mut existing_nodes: Vec<IntermediateRepr<D>> = Vec::with_capacity(node_count);
        // `node_hashes[i]` is the content hash of `existing_nodes[i]`. It is
        // computed once per node instead of once per incoming edge.
        let mut node_hashes: Vec<ArenaHash<D::Hasher>> = Vec::with_capacity(node_count);
        let mut key_to_child_repr: HashMap<ArenaHash<<D as DB>::Hasher>, ArenaKey<D::Hasher>> =
            HashMap::with_capacity(node_count);
        for node in nodes.nodes.iter() {
            let children = node
                .child_indices
                .iter()
                .map(|i| idx_existing_nodes(&node_hashes, *i).cloned())
                .collect::<Result<Vec<_>, _>>()?;
            let root = hash::<D::Hasher>(&node.data, children.iter());
            // Children precede their parents, so their representations are known.
            let child_reprs = children
                .iter()
                .map(|h| {
                    key_to_child_repr
                        .get(h)
                        .cloned()
                        .ok_or_else(|| std::io::Error::other("child not in key_to_child_repr"))
                })
                .collect::<Result<Vec<_>, _>>()?;
            key_to_child_repr.insert(root.clone(), child_from(&node.data, &child_reprs));
            existing_nodes.push(IntermediateRepr {
                binary_repr: node.data.clone(),
                children,
                db_type: PhantomData,
            });
            node_hashes.push(root);
        }
        let key = node_hashes.last().cloned().ok_or_else(|| {
            std::io::Error::new(std::io::ErrorKind::NotFound, "no nodes")
        })?;
        let all: HashMap<ArenaHash<D::Hasher>, IntermediateRepr<D>> =
            node_hashes.into_iter().zip(existing_nodes).collect();
        let res: Sp<T, D> = IrLoader {
            arena: self,
            all: &all,
            recursion_depth: recursive_depth,
            visited: Rc::new(RefCell::new(HashSet::new())),
            key_to_child_repr,
        }
        .get(&ArenaKey::Ref(key))?;
        if nodes == res.serialize_to_node_list() {
            Ok(res)
        } else {
            Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "deserialized storage graph not in normal form",
            ))
        }
    }
}
/// A `Loader` that materialises nodes from an `Arena`'s storage backend.
///
/// `max_depth` bounds how many levels are loaded eagerly:
/// * `Some(0)` yields lazy `Sp`s;
/// * each nested level decrements it;
/// * `None` loads everything.
pub struct BackendLoader<'a, D: DB> {
    arena: &'a Arena<D>,
    max_depth: Option<usize>,
    // Nesting level of this loader; 0 for a top-level loader.
    recursion_depth: u32,
}
impl<'a, D: DB> BackendLoader<'a, D> {
    /// Creates a top-level loader (recursion depth 0) over `arena`.
    pub fn new(arena: &'a Arena<D>, max_depth: Option<usize>) -> Self {
        Self {
            recursion_depth: 0,
            max_depth,
            arena,
        }
    }
}
/// Timing guard. When dropped, it records one construction of the named type
/// (field 0) and the time elapsed since the stored instant (field 1) into `TCONSTRUCT`.
#[cfg(feature = "test-utilities")]
struct ConstructTracker(&'static str, std::time::Instant);
#[cfg(feature = "test-utilities")]
impl Drop for ConstructTracker {
    /// Adds one construction and the elapsed time to this type's `TCONSTRUCT` entry.
    fn drop(&mut self) {
        let dt = self.1.elapsed();
        // This runs in `Drop`, possibly while unwinding from a panic in the
        // tracked construction. Panicking again on a poisoned lock would abort
        // the process, so recover the guard instead.
        let mut construct_map = TCONSTRUCT
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        let (nconstruct, tconstruct) = construct_map
            .get_or_insert_default()
            .entry(self.0)
            .or_default();
        *nconstruct += 1;
        *tconstruct += dt;
    }
}
impl<D: DB> Loader<D> for BackendLoader<'_, D> {
    const CHECK_INVARIANTS: bool = false;
    /// Loads the node referenced by `child` as an `Sp<T, D>`.
    ///
    /// * When `max_depth == Some(0)` nothing is read and a lazy `Sp` is returned.
    /// * A `Ref` key already decoded as a live `T` is served from the sp cache.
    /// * Otherwise the node is read from the backend (or taken inline for
    ///   `Direct` keys) and decoded with a loader one level deeper.
    ///
    /// # Errors
    /// `NotFound` if a `Ref` key is not in the backend, plus any error from
    /// `T::from_binary_repr`.
    fn get<T: Storable<D>>(
        &self,
        child: &ArenaKey<<D as DB>::Hasher>,
    ) -> Result<Sp<T, D>, std::io::Error> {
        if self.max_depth == Some(0) {
            // `Sp::lazy` increments the key's reference count, which requires
            // the key to be tracked first.
            if let ArenaKey::Ref(key) = child {
                self.arena
                    .track_lazy(&self.arena.lock_metadata(), key.clone(), child);
            }
            return Ok(Sp::lazy(
                self.arena.clone(),
                child.hash().clone(),
                child.clone(),
            ));
        }
        #[cfg(feature = "test-utilities")]
        let _tracker = ConstructTracker(std::any::type_name::<T>(), std::time::Instant::now());
        let (data, children) = match child {
            ArenaKey::Direct(direct_node) => {
                (direct_node.data.clone(), direct_node.children.clone())
            }
            ArenaKey::Ref(key) => {
                let metadata_lock = self.arena.lock_metadata();
                let maybe_arc = self
                    .arena
                    .read_sp_cache_locked::<T>(&self.arena.lock_sp_cache(), key);
                if let Some(arc) = maybe_arc {
                    return Ok(Sp::eager(
                        self.arena.clone(),
                        key.clone(),
                        arc,
                        child.clone(),
                    ));
                }
                drop(metadata_lock);
                let obj = self
                    .arena
                    .lock_backend()
                    .borrow_mut()
                    .get(key)
                    // Build the error lazily: this runs on every backend load, and
                    // `format!` would otherwise allocate even when the key is found.
                    .ok_or_else(|| {
                        io::Error::new(
                            io::ErrorKind::NotFound,
                            format!("BackendLoader::get(): key {key:?} not in storage arena. Are you sure you persisted this key or one of its ancestors?"),
                        )
                    })?
                    .clone();
                (Arc::new(obj.data), Arc::new(obj.children))
            }
        };
        // Children are decoded one level deeper, with one fewer eager level left.
        let loader = BackendLoader {
            arena: self.arena,
            max_depth: self.max_depth.map(|max_depth| max_depth - 1),
            recursion_depth: self.recursion_depth + 1,
        };
        let value =
            T::from_binary_repr::<&[u8]>(&mut &data[..], &mut children.iter().cloned(), &loader)?;
        match child {
            ArenaKey::Ref(hash) => Ok(self.arena.new_sp(
                value,
                hash.clone(),
                data.deref().clone(),
                children.deref().clone(),
                child.clone(),
            )),
            ArenaKey::Direct(_) => Ok(Sp {
                arena: self.arena.clone(),
                data: OnceLock::from(Arc::new(value)),
                child_repr: child.clone(),
                root: child.hash().clone(),
            }),
        }
    }
    /// Allocates `obj` in the underlying arena.
    fn alloc<T: Storable<D>>(&self, obj: T) -> Sp<T, D> {
        self.arena.alloc(obj)
    }
    fn get_recursion_depth(&self) -> u32 {
        self.recursion_depth
    }
}
/// A `Loader` that decodes nodes from an in-memory map of `IntermediateRepr`s
/// (as built by `Arena::deserialize_sp`) and allocates each decoded value
/// into `arena`.
///
/// `visited` is shared through an `Rc` by all nested loaders. It records which
/// `(hash, type)` pairs have been decoded, so they can be reused from the
/// arena's sp cache.
///
/// NOTE(review): `key_to_child_repr` is owned, so every nested loader created
/// in `Loader::get` clones the whole map (O(nodes) per nested load). Sharing it
/// via `Rc` or a borrow would remove that cost. This needs a coordinated change
/// to the constructors.
pub(crate) struct IrLoader<'a, D: DB> {
    arena: &'a Arena<D>,
    all: &'a HashMap<ArenaHash<D::Hasher>, IntermediateRepr<D>>,
    recursion_depth: u32,
    visited: Rc<RefCell<HashSet<DynTypedArenaHash<D::Hasher>>>>,
    key_to_child_repr: HashMap<ArenaHash<D::Hasher>, ArenaKey<D::Hasher>>,
}
#[cfg(test)]
impl<'a, D: DB> IrLoader<'a, D> {
    /// Creates a top-level loader with recursion depth 0 and an empty `visited` set.
    pub(crate) fn new(
        arena: &'a Arena<D>,
        all: &'a HashMap<ArenaHash<D::Hasher>, IntermediateRepr<D>>,
        key_to_child_repr: HashMap<ArenaHash<D::Hasher>, ArenaKey<D::Hasher>>,
    ) -> IrLoader<'a, D> {
        let visited = Rc::new(RefCell::new(HashSet::new()));
        IrLoader {
            arena,
            all,
            key_to_child_repr,
            visited,
            recursion_depth: 0,
        }
    }
}
impl<D: DB> Loader<D> for IrLoader<'_, D> {
    const CHECK_INVARIANTS: bool = true;
    /// Decodes the node referenced by `child` as a `T` and allocates it in the
    /// arena, always returning an eager (non-lazy) pointer.
    ///
    /// `Direct` keys are decoded from their inline data. A `Ref` key already
    /// decoded as `T` is reused from the sp cache when still alive.
    ///
    /// # Errors
    /// * `NotFound` if a `Ref` key is not in `all`.
    /// * An error once `serialize::RECURSION_LIMIT` is exceeded.
    /// * Any decoding error.
    fn get<T: Storable<D>>(
        &self,
        child: &ArenaKey<<D as DB>::Hasher>,
    ) -> Result<Sp<T, D>, std::io::Error> {
        let key = match child {
            ArenaKey::Direct(child) => {
                let value = T::from_binary_repr(
                    &mut &child.data[..],
                    &mut child.children.iter().cloned(),
                    self,
                )?;
                return Ok(self.arena.alloc(value));
            }
            ArenaKey::Ref(key) => key,
        };
        let typed_key = (key.clone(), TypeId::of::<T>());
        if self.visited.borrow().contains(&typed_key) {
            if let Some(sp) = self.arena.get_from_cache::<T>(key) {
                assert!(!sp.is_lazy(), "BUG: IrLoader MUST return strict sps");
                return Ok(sp);
            }
        }
        let ir = self.all.get(key).ok_or_else(|| {
            io::Error::new(io::ErrorKind::NotFound, "IR not found in `all` map")
        })?;
        if self.recursion_depth > serialize::RECURSION_LIMIT {
            return Err(std::io::Error::other("Reached recursion limit".to_string()));
        }
        // NOTE(review): cloning `key_to_child_repr` here copies the whole map per
        // nested load. See the struct documentation.
        let loader = IrLoader {
            arena: self.arena,
            all: self.all,
            recursion_depth: self.recursion_depth + 1,
            visited: self.visited.clone(),
            key_to_child_repr: self.key_to_child_repr.clone(),
        };
        // Decode straight from the borrowed IR; there is no need to copy its
        // bytes or child list first.
        let value = T::from_binary_repr(
            &mut ir.binary_repr.as_slice(),
            &mut ir.children.iter().map(|k| {
                self.key_to_child_repr
                    .get(k)
                    .expect("should be able to convert child ArenaHash to ArenaKey")
                    .clone()
            }),
            &loader,
        )?;
        let sp = self.arena.alloc(value);
        assert!(!sp.is_lazy(), "BUG: IrLoader MUST return strict sps");
        self.visited.borrow_mut().insert(typed_key);
        Ok(sp)
    }
    /// Allocates `obj` in the underlying arena.
    fn alloc<T: Storable<D>>(&self, obj: T) -> Sp<T, D> {
        self.arena.alloc(obj)
    }
    fn get_recursion_depth(&self) -> u32 {
        self.recursion_depth
    }
}
/// A node read from a serialized graph, before decoding into a concrete type:
/// its binary representation and the content hashes of its children, in order.
#[derive(Debug)]
pub struct IntermediateRepr<D: DB> {
    binary_repr: std::vec::Vec<u8>,
    children: std::vec::Vec<ArenaHash<D::Hasher>>,
    db_type: PhantomData<D>,
}
impl<D: DB> IntermediateRepr<D> {
    /// Builds the intermediate form of `s` from its binary representation and
    /// the hashes of its children.
    ///
    /// # Panics
    /// If `s` cannot write its binary representation.
    #[cfg(test)]
    pub fn from_storable<S: Storable<D>>(s: &S) -> Self {
        let mut binary_repr = std::vec::Vec::new();
        s.to_binary_repr(&mut binary_repr).unwrap();
        let children: std::vec::Vec<ArenaHash<D::Hasher>> = s
            .children()
            .iter()
            .map(|child| child.hash().clone())
            .collect();
        IntermediateRepr {
            binary_repr,
            children,
            db_type: PhantomData,
        }
    }
}
/// Book-keeping for one tracked arena key.
#[derive(Debug, Clone)]
struct Node {
    /// Number of `Ref`-backed `Sp` handles on this key. It is incremented when
    /// a handle is created or cloned and decremented on drop.
    ref_count: u64,
}
impl Node {
    /// A freshly tracked key with no handles yet.
    fn new() -> Self {
        Self { ref_count: 0 }
    }
}
/// A shared pointer to a value stored in an `Arena`.
///
/// `data` is either already populated (eager) or empty (lazy); a lazy pointer
/// is filled on first dereference. For `Ref` keys, each `Sp` holds one
/// reference count on `root` in the arena's metadata, released on drop.
/// `Direct` pointers carry their node inline and are not counted.
pub struct Sp<T: ?Sized + 'static, D: DB = DefaultDB> {
    data: OnceLock<Arc<T>>,
    pub child_repr: ArenaKey<D::Hasher>,
    pub arena: Arena<D>,
    pub root: ArenaHash<D::Hasher>,
}
impl<T: Display, D: DB> Display for Sp<T, D> {
    /// Displays the loaded value, or `<Lazy Sp>` without forcing a load.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if let Some(value) = self.data.get() {
            Display::fmt(value, f)
        } else {
            f.write_str("<Lazy Sp>")
        }
    }
}
impl<T: Debug, D: DB> Debug for Sp<T, D> {
    /// Debug-formats the loaded value, or prints `<Lazy Sp>` without forcing a load.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        if let Some(value) = self.data.get() {
            Debug::fmt(value, f)
        } else {
            f.write_str("<Lazy Sp>")
        }
    }
}
impl<T: Tagged, D: DB> Tagged for Sp<T, D> {
    /// Pointers are transparent for tagging: the tag is that of `T`.
    fn tag() -> std::borrow::Cow<'static, str> {
        <T as Tagged>::tag()
    }
    fn tag_unique_factor() -> String {
        <T as Tagged>::tag_unique_factor()
    }
}
impl<T, D: DB> Hash for Sp<T, D> {
    /// Hashes only the root content hash, consistent with `PartialEq`.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        Hash::hash(&self.root, state);
    }
}
impl<T: Storable<D>, D: DB> Sp<T, D> {
    /// Allocates `value` in the arena of the default storage.
    pub fn new(value: T) -> Self {
        let storage = default_storage();
        storage.arena.alloc(value)
    }
}
impl<T: ?Sized + 'static, D: DB> Sp<T, D> {
    /// Builds an already-loaded pointer holding `arc`; for `Ref` keys it takes a
    /// reference, as `lazy` does.
    fn eager(
        arena: Arena<D>,
        root: ArenaHash<D::Hasher>,
        arc: Arc<T>,
        child_repr: ArenaKey<D::Hasher>,
    ) -> Self {
        let sp = Sp::lazy(arena, root, child_repr);
        // `lazy` always returns an empty cell, so this `set` cannot fail.
        let _ = sp.data.set(arc);
        sp
    }
    /// Returns a tracked (`Ref`) version of this pointer.
    ///
    /// An inline (`Direct`) pointer is loaded, re-encoded and registered in the
    /// arena under the same root hash. A `Ref` pointer is simply cloned.
    ///
    /// # Panics
    /// If the value cannot write its binary representation, or cannot be loaded.
    pub fn into_tracked(&self) -> Self
    where
        T: Storable<D>,
    {
        let ArenaKey::Direct(direct) = &self.child_repr else {
            return self.clone();
        };
        let value = self.force_as_arc();
        let mut encoded: std::vec::Vec<u8> = std::vec::Vec::new();
        value
            .to_binary_repr(&mut encoded)
            .expect("Storable data should be able to be represented in binary");
        let tracked_repr = ArenaKey::Ref(self.root.clone());
        self.arena.new_sp_locked(
            &mut self.arena.lock_metadata(),
            value.as_ref().clone(),
            self.root.clone(),
            encoded,
            direct.children.deref().clone(),
            tracked_repr,
        )
    }
    /// Builds an unloaded pointer; for `Ref` keys it takes a reference on `root`.
    ///
    /// # Panics
    /// For a `Ref` key that the arena is not tracking.
    fn lazy(arena: Arena<D>, root: ArenaHash<D::Hasher>, child_repr: ArenaKey<D::Hasher>) -> Self {
        if matches!(child_repr, ArenaKey::Ref(_)) {
            arena.increment_ref(&root);
        }
        Sp {
            data: OnceLock::new(),
            arena,
            root,
            child_repr,
        }
    }
}
impl<T: Storable<D>, D: DB> Sp<T, D> {
    /// Loads `key` from `arena` with a fresh top-level `BackendLoader`.
    ///
    /// At most `max_depth` levels are loaded eagerly; `None` means unbounded.
    fn from_arena(
        arena: &Arena<D>,
        key: &ArenaKey<D::Hasher>,
        max_depth: Option<usize>,
    ) -> Result<Sp<T, D>, std::io::Error> {
        BackendLoader {
            recursion_depth: 0,
            max_depth,
            arena,
        }
        .get(key)
    }
}
impl<T: Storable<D>, D: DB> Deref for Sp<T, D> {
type Target = T;
fn deref(&self) -> &Self::Target {
self.force_as_arc()
}
}
impl<T: ?Sized, D: DB> Clone for Sp<T, D> {
    /// Shares the loaded value (if any) and, for `Ref` keys, takes another reference.
    fn clone(&self) -> Self {
        if matches!(self.child_repr, ArenaKey::Ref(_)) {
            self.arena.increment_ref(&self.root);
        }
        Self {
            data: self.data.clone(),
            arena: self.arena.clone(),
            child_repr: self.child_repr.clone(),
            root: self.root.clone(),
        }
    }
}
impl<D: DB> Sp<dyn Any + Send + Sync, D> {
    /// Attempts to view this type-erased pointer as an `Sp<T, D>`.
    ///
    /// Returns `None` if the value is loaded and is not a `T`. A lazy pointer
    /// cannot be checked, so it is returned as a lazy `Sp<T, D>`.
    pub fn downcast<T: Any + Send + Sync>(&self) -> Option<Sp<T, D>> {
        let data: OnceLock<Arc<T>> = match self.data.get() {
            Some(arc) => arc.clone().downcast::<T>().ok()?.into(),
            None => OnceLock::new(),
        };
        // Take the reference only once the result is certain to be built. The
        // new `Sp` releases it on drop. Incrementing before a failed downcast
        // leaked a count, so the key was never untracked.
        if let ArenaKey::Ref(_) = self.child_repr {
            self.arena.increment_ref(&self.root);
        }
        Some(Sp {
            root: self.root.clone(),
            child_repr: self.child_repr.clone(),
            arena: self.arena.clone(),
            data,
        })
    }
    /// Views this pointer as an `Sp<T, D>` unconditionally.
    ///
    /// If the loaded value is not a `T`, the result is lazy and will be
    /// (re)loaded as a `T` on first dereference.
    pub fn force_downcast<T: Any + Send + Sync>(&self) -> Sp<T, D> {
        if let ArenaKey::Ref(_) = self.child_repr {
            self.arena.increment_ref(&self.root);
        }
        let data: OnceLock<Arc<T>> = match self.data.get().map(|arc| arc.clone().downcast::<T>()) {
            Some(Ok(concrete_arc)) => concrete_arc.into(),
            None | Some(Err(_)) => OnceLock::new(),
        };
        Sp {
            root: self.root.clone(),
            child_repr: self.child_repr.clone(),
            arena: self.arena.clone(),
            data,
        }
    }
}
impl<T: Any + Send + Sync, D: DB> Sp<T, D> {
    /// Erases the value type, producing an `Sp<dyn Any + Send + Sync, D>` that
    /// shares the loaded value (if any) and takes another reference on `Ref` keys.
    pub fn upcast(&self) -> Sp<dyn Any + Send + Sync, D> {
        if matches!(self.child_repr, ArenaKey::Ref(_)) {
            self.arena.increment_ref(&self.root);
        }
        let data: OnceLock<Arc<dyn Any + Send + Sync>> = OnceLock::new();
        if let Some(arc) = self.data.get() {
            let erased: Arc<dyn Any + Send + Sync> = Arc::clone(arc);
            let _ = data.set(erased);
        }
        Sp {
            root: self.root.clone(),
            child_repr: self.child_repr.clone(),
            arena: self.arena.clone(),
            data,
        }
    }
}
impl<T: ?Sized, D: DB> Sp<T, D> {
    /// `true` if the value has not been loaded into this pointer yet.
    pub fn is_lazy(&self) -> bool {
        self.data.get().is_none()
    }
    /// The content hash of the referenced node.
    pub fn hash(&self) -> ArenaHash<D::Hasher> {
        Clone::clone(&self.root)
    }
    /// This pointer's key, tagged with the value type `T`.
    pub fn as_typed_key(&self) -> TypedArenaKey<T, D::Hasher> {
        let key = self.as_child();
        TypedArenaKey {
            _phantom: PhantomData,
            key,
        }
    }
    /// This pointer's untyped key (`Ref` or inline `Direct`).
    pub fn as_child(&self) -> ArenaKey<D::Hasher> {
        Clone::clone(&self.child_repr)
    }
}
impl<T: Storable<D>, D: DB> Sp<T, D> {
    /// Asks the backend to persist every `Ref` node reachable through this pointer.
    ///
    /// An inline (`Direct`) pointer is first replaced by its tracked form
    /// (`into_tracked`).
    pub fn persist(&mut self) {
        if matches!(self.child_repr, ArenaKey::Direct(..)) {
            let tracked = self.into_tracked();
            *self = tracked;
        }
        let targets = self.child_repr.refs();
        self.arena.with_backend(|backend| {
            for target in targets {
                backend.persist(target);
            }
        });
    }
    /// Asks the backend to unpersist this pointer's root node.
    pub fn unpersist(&self) {
        let root = &self.root;
        self.arena.with_backend(|backend| backend.unpersist(root))
    }
    /// Consumes the pointer and returns the value.
    ///
    /// Returns `Some` only if the value was loaded and no other strong handle
    /// to it remains.
    pub fn into_inner(this: Sp<T, D>) -> Option<T> {
        let loaded = this.data.get().map(Arc::clone);
        drop(this);
        loaded.and_then(Arc::into_inner)
    }
}
impl<T: ?Sized + 'static, D: DB> Sp<T, D> {
    /// Drops this pointer's handle on the loaded value, making it lazy again,
    /// and prunes the sp cache entry if no strong handle remains.
    pub fn unload(&mut self) {
        drop(self.data.take());
        self.gc_weak_pointer();
    }
    /// Removes the `(root, T)` entry from the sp cache if its weak pointer is dead.
    fn gc_weak_pointer(&mut self) {
        let guard = self.arena.lock_sp_cache();
        let mut cache = guard.borrow_mut();
        let cache_key = (self.root.clone(), TypeId::of::<T>());
        let is_dead = match cache.get(&cache_key) {
            Some(weak) => weak.strong_count() == 0,
            None => false,
        };
        if is_dead {
            cache.remove(&cache_key);
        }
    }
}
impl<T: Storable<D>, D: DB> Sp<T, D> {
    /// Returns the loaded value, loading it first if this pointer is lazy.
    ///
    /// A lazy pointer is filled from the arena's sp cache when a live `Arc<T>`
    /// for `root` exists. Otherwise the node is loaded with `max_depth = Some(1)`,
    /// so its children stay lazy. For `Ref` keys the freshly loaded value is
    /// also recorded in the sp cache.
    ///
    /// # Panics
    /// If the node cannot be loaded from the arena.
    fn force_as_arc(&self) -> &Arc<T> {
        if self.data.get().is_none() {
            // The cache lookup, the load and the cache write all happen under
            // these locks. Both mutexes are re-entrant, so the nested
            // `from_arena` call may lock them again.
            let _metadata_lock = self.arena.lock_metadata();
            let cache_lock = self.arena.lock_sp_cache();
            let maybe_arc = self
                .arena
                .read_sp_cache_locked::<T>(&cache_lock, &self.root);
            let arc: Arc<T> = match maybe_arc {
                Some(arc) => arc,
                None => {
                    let max_depth = Some(1);
                    let mut sp: Sp<T, _> =
                        match Sp::from_arena(&self.arena, &self.as_child(), max_depth) {
                            Ok(v) => v,
                            Err(e) => panic!(
                                "root should be in the arena (T={}): {e:?}",
                                std::any::type_name::<T>()
                            ),
                        };
                    // Move the value out of the temporary pointer. `sp` itself is
                    // dropped (releasing its reference) at the end of this arm.
                    let arc = sp
                        .data
                        .take()
                        .expect("result of Sp::from_arena should be initialized");
                    if let ArenaKey::Ref(_) = &self.child_repr {
                        self.arena.write_sp_cache_locked(
                            &cache_lock,
                            self.root.clone(),
                            arc.clone(),
                        );
                    }
                    arc
                }
            };
            // `set` fails only if another caller filled the cell first; in that
            // case this value is discarded.
            let _ = self.data.set(arc);
        }
        self.data.get().unwrap()
    }
    /// Serializes the graph reachable from this pointer into topologically
    /// sorted form: children before parents, root last.
    ///
    /// # Panics
    /// If a reachable `Ref` node is missing from the arena.
    pub fn serialize_to_node_list(&self) -> TopoSortedNodes {
        self.serialize_to_node_list_bounded(u64::MAX)
            .expect("unbounded serialization must succeed")
    }
    /// Like `serialize_to_node_list`, but returns `None` once the sum over
    /// distinct nodes of `PERSISTENT_HASH_BYTES + data length` exceeds
    /// `raw_size_limit`.
    pub fn serialize_to_node_list_bounded(
        &self,
        mut raw_size_limit: u64,
    ) -> Option<TopoSortedNodes> {
        let arena = self.arena.clone();
        let root = self.child_repr.clone();
        // Pass 1: fetch each distinct reachable node once and count its
        // incoming edges (with multiplicity).
        let mut incoming_vertices: HashMap<ArenaHash<_>, usize> = HashMap::new();
        let mut disk_objects = HashMap::new();
        let mut frontier = vec![root.clone()];
        while let Some(child) = frontier.pop() {
            if disk_objects.contains_key(child.hash()) {
                continue;
            }
            let node = match child {
                ArenaKey::Ref(ref key) => arena
                    .lock_backend()
                    .borrow_mut()
                    .get(key)
                    .expect("Arena should contain current serialization target")
                    .clone(),
                ArenaKey::Direct(ref d) => OnDiskObject {
                    data: d.data.as_ref().clone(),
                    #[cfg(not(feature = "layout-v2"))]
                    ref_count: 0,
                    children: d.children.as_ref().clone(),
                },
            };
            for child in node.children.iter() {
                *incoming_vertices.entry(child.hash().clone()).or_default() += 1;
                frontier.push(child.clone());
            }
            raw_size_limit = raw_size_limit
                .checked_sub(PERSISTENT_HASH_BYTES as u64 + node.data.len() as u64)?;
            disk_objects.insert(child.hash().clone(), node);
        }
        // Pass 2 (Kahn's algorithm): number nodes parents-first, starting from
        // the root. A child is queued once all edges into it have been consumed.
        let mut list_indices: HashMap<ArenaHash<_>, u64> = HashMap::new();
        let mut empty_incoming_nodes = vec![root.hash().clone()];
        while let Some(node_hash) = empty_incoming_nodes.pop() {
            if list_indices.contains_key(&node_hash) {
                continue;
            }
            let disk = disk_objects.get(&node_hash).expect("node must be present");
            list_indices.insert(node_hash, list_indices.len() as u64);
            for child in disk.children.iter() {
                let incoming = incoming_vertices
                    .get_mut(child.hash())
                    .expect("node must be present");
                *incoming -= 1;
                if *incoming == 0 {
                    empty_incoming_nodes.push(child.hash().clone());
                }
            }
        }
        // Pass 3: reverse the numbering so children precede parents (root
        // last), rewriting child hashes as list indices.
        let len = list_indices.len();
        let mut list = TopoSortedNodes {
            nodes: vec![TopoSortedNode::default(); len],
        };
        for (child_hash, idx) in list_indices.iter() {
            let disk = disk_objects
                .remove(child_hash)
                .expect("node must be present");
            list.nodes[len - 1 - *idx as usize] = TopoSortedNode {
                child_indices: disk
                    .children
                    .iter()
                    .map(|child| len as u64 - 1 - list_indices[child.hash()])
                    .collect(),
                data: disk.data,
            };
        }
        Some(list)
    }
}
impl<T: ?Sized + 'static, D: DB> Drop for Sp<T, D> {
    /// Drops this pointer's loaded data (`unload`).
    ///
    /// Arena-backed (`Ref`) pointers then call `decrement_ref` on their root
    /// hash. Inline (`Direct`) pointers have nothing further to release here.
    fn drop(&mut self) {
        self.unload();
        match &self.child_repr {
            ArenaKey::Ref(root_hash) => {
                self.arena.decrement_ref(root_hash);
            }
            ArenaKey::Direct(_) => {}
        }
    }
}
// Two pointers are equal exactly when their root hashes are equal; the
// pointed-to values are never loaded for this comparison.
impl<T, D: DB> PartialEq for Sp<T, D> {
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (&self.root, &other.root);
        lhs == rhs
    }
}
// Hash equality is an equivalence relation, so full `Eq` holds.
impl<T, D: DB> Eq for Sp<T, D> {}
impl<T: PartialOrd + Storable<D>, D: DB> PartialOrd for Sp<T, D> {
    /// Pointers with the same root hash compare `Equal` without loading.
    /// Otherwise both values are forced and compared with `T`'s ordering.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        if self.root != other.root {
            let (lhs, rhs) = (self.force_as_arc(), other.force_as_arc());
            lhs.partial_cmp(rhs)
        } else {
            Some(std::cmp::Ordering::Equal)
        }
    }
}
impl<T: Ord + Storable<D>, D: DB> Ord for Sp<T, D> {
    /// Same-root pointers are `Equal` without loading. Otherwise both values
    /// are forced and ordered by `T`.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        if self.root != other.root {
            let (lhs, rhs) = (self.force_as_arc(), other.force_as_arc());
            lhs.cmp(rhs)
        } else {
            std::cmp::Ordering::Equal
        }
    }
}
/// Flat encoding of a DAG of arena nodes, as produced by
/// `Sp::serialize_to_node_list`.
///
/// In lists produced by that method, every node appears after all of its
/// children and the root is the last entry.
#[derive(Clone, PartialEq, Eq, Debug, Serializable)]
pub struct TopoSortedNodes {
/// The deduplicated nodes; `child_indices` of each entry index into this vector.
pub nodes: Vec<TopoSortedNode>,
}
/// One node of a [`TopoSortedNodes`] list.
#[derive(Clone, PartialEq, Eq, Debug, Default, Serializable)]
pub struct TopoSortedNode {
/// Positions of this node's children within the enclosing `nodes` list, in
/// the node's own child order (duplicates allowed).
pub child_indices: Vec<u64>,
/// The node's raw data bytes, copied from the arena's `OnDiskObject::data`.
pub data: Vec<u8>,
}
/// Type-erased storable value: raw data bytes plus type-erased child pointers.
///
/// Used by the `Storable` impl for `Sp<dyn Any + Send + Sync, D>` to load an
/// arena node without knowing its concrete type.
#[derive_where(Clone)]
pub struct Opaque<D: DB> {
// The node's binary representation, kept verbatim.
data: Vec<u8>,
// Child pointers, themselves type-erased.
children: Vec<Sp<dyn Any + Send + Sync, D>>,
}
impl<D: DB> Storable<D> for Opaque<D> {
    /// One child key per retained child pointer, in stored order.
    fn children(&self) -> std::vec::Vec<ArenaKey<<D as DB>::Hasher>> {
        let mut keys = Vec::with_capacity(self.children.len());
        for child in &self.children {
            keys.push(child.as_child());
        }
        keys
    }

    /// Writes the raw data bytes unchanged. Children are not part of the
    /// binary representation.
    fn to_binary_repr<W: std::io::Write>(&self, writer: &mut W) -> Result<(), std::io::Error>
    where
        Self: Sized,
    {
        writer.write_all(&self.data)
    }

    /// Consumes the rest of `reader` as opaque data. Each child key is fetched
    /// through `loader.get` as another `Opaque` and upcast, so a subtree can
    /// be read without knowing its concrete types. Stops at the first error.
    fn from_binary_repr<R: std::io::Read>(
        reader: &mut R,
        child_nodes: &mut impl Iterator<Item = ArenaKey<<D as DB>::Hasher>>,
        loader: &impl Loader<D>,
    ) -> Result<Self, std::io::Error>
    where
        Self: Sized,
    {
        let mut data = Vec::new();
        reader.read_to_end(&mut data)?;
        let mut children = Vec::new();
        for child_key in child_nodes {
            let child = loader.get::<Opaque<D>>(&child_key)?;
            children.push(child.upcast());
        }
        Ok(Self { data, children })
    }
}
impl<D: DB> Storable<D> for Sp<dyn Any + Send + Sync, D> {
fn children(&self) -> std::vec::Vec<ArenaKey<<D as DB>::Hasher>> {
match &self.child_repr {
ArenaKey::Direct(key) => key.children.deref().clone(),
ArenaKey::Ref(hash) => self.arena.with_backend(|backend| {
backend
.get(hash)
.expect("ref Sp must be in backend")
.children
.clone()
}),
}
}
fn from_binary_repr<R: std::io::Read>(
reader: &mut R,
child_nodes: &mut impl Iterator<Item = ArenaKey<<D as DB>::Hasher>>,
loader: &impl Loader<D>,
) -> Result<Self, std::io::Error>
where
Self: Sized,
{
Opaque::from_binary_repr(reader, child_nodes, loader).map(|opaque| Sp::new(opaque).upcast())
}
fn to_binary_repr<W: std::io::Write>(&self, writer: &mut W) -> Result<(), std::io::Error>
where
Self: Sized,
{
match &self.child_repr {
ArenaKey::Direct(key) => writer.write_all(&key.data),
ArenaKey::Ref(hash) => self.arena.with_backend(|backend| {
writer.write_all(&backend.get(hash).expect("ref Sp must be in backend").data)
}),
}
}
}
impl<D: DB, T: Storable<D>> Storable<D> for Sp<T, D> {
fn children(&self) -> std::vec::Vec<ArenaKey<D::Hasher>> {
self.deref().children()
}
fn from_binary_repr<R: std::io::Read>(
reader: &mut R,
child_hashes: &mut impl Iterator<Item = ArenaKey<D::Hasher>>,
loader: &impl Loader<D>,
) -> Result<Self, std::io::Error> {
T::from_binary_repr(reader, child_hashes, loader).map(|sp| loader.alloc(sp))
}
fn to_binary_repr<W: std::io::Write>(&self, writer: &mut W) -> Result<(), std::io::Error> {
self.deref().to_binary_repr(writer)
}
fn check_invariant(&self) -> Result<(), std::io::Error> {
T::check_invariant(self)
}
}
impl<T: Storable<D>, D: DB> Serializable for Sp<T, D> {
    /// Writes the whole DAG behind this pointer as a topologically sorted
    /// node list.
    fn serialize(&self, writer: &mut impl std::io::Write) -> std::io::Result<()> {
        let nodes = self.serialize_to_node_list();
        nodes.serialize(writer)
    }

    /// Size in bytes that `serialize` would write. This builds the same
    /// node list, so it costs as much as a full serialization.
    fn serialized_size(&self) -> usize {
        let nodes = self.serialize_to_node_list();
        nodes.serialized_size()
    }
}
impl<T: Storable<D>, D: DB> Deserializable for Sp<T, D> {
    /// Rebuilds a pointer from `reader` inside the arena of `default_storage()`.
    ///
    /// Presumably this reads the node-list format written by `serialize`;
    /// the round-trip tests exercise that pairing.
    fn deserialize(
        reader: &mut impl std::io::Read,
        recursive_depth: u32,
    ) -> Result<Self, std::io::Error> {
        let storage = default_storage();
        storage.arena.clone().deserialize_sp(reader, recursive_depth)
    }
}
#[cfg(any(test, feature = "test-utilities"))]
pub mod bin_tree {
    use super::*;
    use crate::{self as storage, storable::SMALL_OBJECT_LIMIT};
    use macros::Storable;
    use std::fmt;

    // Binary tree of u64 values used as a test fixture. Every node carries
    // SMALL_OBJECT_LIMIT bytes of zero padding in `_data`. Presumably this
    // keeps nodes above the small-object threshold so they become separate
    // arena entries — confirm.
    #[derive(Storable)]
    #[derive_where(Clone, PartialEq, Eq)]
    #[tag = "test-bin-tree"]
    #[storable(db = D)]
    pub struct BinTree<D: DB> {
        value: u64,
        pub(crate) left: Option<Sp<BinTree<D>, D>>,
        pub(crate) right: Option<Sp<BinTree<D>, D>>,
        _data: [u8; SMALL_OBJECT_LIMIT],
    }

    impl<D: DB> fmt::Debug for BinTree<D> {
        /// Shows `value`, `left` and `right`; the `_data` padding is omitted.
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            let mut out = f.debug_struct("BinTree");
            out.field("value", &self.value);
            out.field("left", &self.left);
            out.field("right", &self.right);
            out.finish()
        }
    }

    impl<D: DB> BinTree<D> {
        /// Builds a node with zeroed padding.
        pub fn new(
            value: u64,
            left: Option<Sp<BinTree<D>, D>>,
            right: Option<Sp<BinTree<D>, D>>,
        ) -> BinTree<D> {
            Self {
                value,
                left,
                right,
                _data: [0u8; SMALL_OBJECT_LIMIT],
            }
        }

        /// Recursive sum of all values in the subtree. Absent children
        /// contribute zero; present children are loaded.
        #[cfg(all(
            feature = "test-utilities",
            any(feature = "parity-db", feature = "sqlite")
        ))]
        pub fn sum(&self) -> u64 {
            let left_sum = self.left.as_ref().map_or(0, |l| l.sum());
            let right_sum = self.right.as_ref().map_or(0, |r| r.sum());
            self.value + left_sum + right_sum
        }
    }

    /// Allocates a complete binary tree of the given height (at least 1).
    ///
    /// Nodes are numbered in heap order: the root is `1`, and the children of
    /// `v` are `2v` and `2v + 1`. Every value is distinct, so no subtrees are
    /// shared.
    #[cfg(any(
        test,
        all(
            feature = "test-utilities",
            any(feature = "parity-db", feature = "sqlite")
        )
    ))]
    pub fn counting_tree<D: DB>(arena: &Arena<D>, height: usize) -> Sp<BinTree<D>, D> {
        fn go<D: DB>(arena: &Arena<D>, value: u64, height: usize) -> Sp<BinTree<D>, D> {
            assert!(height > 0);
            let (left, right) = if height > 1 {
                (
                    Some(go(arena, 2 * value, height - 1)),
                    Some(go(arena, 2 * value + 1, height - 1)),
                )
            } else {
                (None, None)
            };
            arena.alloc(BinTree::new(value, left, right))
        }
        go(arena, 1, height)
    }
}
pub mod test_helpers {
    use super::*;

    /// Returns the backend's root count for `key`, holding the backend lock
    /// only for the duration of the lookup.
    pub fn get_root_count<D: DB>(arena: &Arena<D>, key: &ArenaHash<D::Hasher>) -> u32 {
        let backend = arena.lock_backend();
        let root_count = backend.borrow().get_root_count(key);
        root_count
    }

    /// Looks up `key` in the arena's `Sp` cache as a `T`, under the cache lock.
    pub fn read_sp_cache<D: DB, T: Storable<D>>(
        arena: &Arena<D>,
        key: &ArenaHash<D::Hasher>,
    ) -> Option<Arc<T>> {
        let sp_cache = arena.lock_sp_cache();
        arena.read_sp_cache_locked::<T>(&sp_cache, key)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate as storage;
    use crate::DefaultHasher;
    use crate::storable::SMALL_OBJECT_LIMIT;
    use macros::Storable;

    /// Builds a fresh arena over its own `DefaultDB` backend, so a test does
    /// not observe allocations made through another arena.
    // NOTE(review): the `16` is presumably the backend cache size — confirm.
    fn new_arena() -> Arena<DefaultDB> {
        Arena::<DefaultDB>::new_from_backend(StorageBackend::<DefaultDB>::new(
            16,
            DefaultDB::default(),
        ))
    }

    #[test]
    fn alloc() {
        let val: u8 = 2;
        let map = new_arena();
        let alloced = map.alloc::<u8>(val);
        assert_eq!(*alloced, val);
    }

    #[test]
    fn dedup() {
        let val = [0; SMALL_OBJECT_LIMIT];
        let map = new_arena();
        let _malloced_a = map.alloc::<[u8; SMALL_OBJECT_LIMIT]>(val);
        let _malloced_b = map.alloc::<[u8; SMALL_OBJECT_LIMIT]>(val);
        assert_eq!(map.size(), 1)
    }

    #[test]
    fn drop_node() {
        let map = new_arena();
        let _malloc_a = map.alloc::<[u8; SMALL_OBJECT_LIMIT]>([0; SMALL_OBJECT_LIMIT]);
        {
            let _malloc_b = map.alloc::<[u8; SMALL_OBJECT_LIMIT]>([1; SMALL_OBJECT_LIMIT]);
            assert_eq!(map.size(), 2);
        }
        assert_eq!(map.size(), 1);
    }

    #[test]
    fn clone_increment_refcount() {
        let map = new_arena();
        let payload = [0; SMALL_OBJECT_LIMIT];
        let malloc_a = map.alloc::<[u8; SMALL_OBJECT_LIMIT]>(payload);
        let malloc_b = malloc_a.clone();
        let ref_count = map
            .lock_metadata()
            .borrow()
            .get(&malloc_a.root)
            .unwrap()
            .ref_count;
        assert_eq!(malloc_a, malloc_b);
        assert_eq!(ref_count, 2);
    }

    #[test]
    fn into_inner() {
        let arena = new_arena();
        let sp1 = arena.alloc(42u32);
        let sp2 = sp1.clone();
        assert!(Sp::into_inner(sp1).is_none());
        assert!(Sp::into_inner(sp2).is_some());
    }

    #[test]
    fn test_sp_nesting() {
        let arena = new_arena();
        #[derive(Clone, PartialOrd, Ord, PartialEq, Eq)]
        struct Nesty(Option<Sp<Nesty>>);
        impl Storable<DefaultDB> for Nesty {
            fn children(&self) -> std::vec::Vec<ArenaKey<<DefaultDB as DB>::Hasher>> {
                self.0.children()
            }
            fn to_binary_repr<W: std::io::Write>(
                &self,
                writer: &mut W,
            ) -> Result<(), std::io::Error> {
                self.0.to_binary_repr(writer)
            }
            fn from_binary_repr<R: std::io::Read>(
                reader: &mut R,
                child_hashes: &mut impl Iterator<Item = ArenaKey<DefaultHasher>>,
                loader: &impl Loader<DefaultDB>,
            ) -> Result<Self, std::io::Error> {
                Ok(Nesty(
                    <Option<Sp<Nesty>> as Storable<DefaultDB>>::from_binary_repr(
                        reader,
                        child_hashes,
                        loader,
                    )?,
                ))
            }
        }
        // Iterative drop so a 100k-deep chain does not overflow the stack.
        impl Drop for Nesty {
            fn drop(&mut self) {
                if self.0.is_none() {
                    return;
                }
                let mut frontier = std::mem::take(&mut self.0)
                    .into_iter()
                    .collect::<std::vec::Vec<_>>();
                while let Some(nest) = frontier.pop() {
                    frontier
                        .extend(Sp::into_inner(nest).and_then(|mut n| std::mem::take(&mut n.0)));
                }
            }
        }
        let mut nest = Nesty(None);
        for _ in 0..100_000 {
            nest = Nesty(Some(arena.alloc(nest)));
        }
        drop(nest);
    }

    #[test]
    fn sp_cache_sp_drop() {
        let arena = &new_arena();
        let sp1 = arena.alloc([42u8; SMALL_OBJECT_LIMIT]);
        let root_key = sp1.root.clone();
        let type_id = TypeId::of::<[u8; SMALL_OBJECT_LIMIT]>();
        let cache_key = (root_key.clone(), type_id);
        {
            let sp_cache = arena.lock_sp_cache();
            let sp_cache = sp_cache.borrow();
            assert!(sp_cache.get(&cache_key).is_some());
            let weak_ref = sp_cache.get(&cache_key).unwrap();
            assert!(weak_ref.upgrade().is_some());
            let dyn_arc = weak_ref.upgrade().unwrap();
            let arc = dyn_arc.downcast::<[u8; SMALL_OBJECT_LIMIT]>().unwrap();
            assert!(Arc::ptr_eq(&arc, sp1.data.get().unwrap()));
        }
        let sp2 = sp1.clone();
        assert!(Arc::ptr_eq(
            sp1.data.get().unwrap(),
            sp2.data.get().unwrap()
        ));
        assert_eq!(Arc::strong_count(sp1.data.get().unwrap()), 2);
        drop(sp2);
        assert_eq!(Arc::strong_count(sp1.data.get().unwrap()), 1);
        {
            let sp_cache = arena.lock_sp_cache();
            let sp_cache = sp_cache.borrow();
            assert!(sp_cache.get(&cache_key).is_some());
            let weak_ref = sp_cache.get(&cache_key).unwrap();
            assert!(weak_ref.upgrade().is_some());
            let dyn_arc = weak_ref.upgrade().unwrap();
            let arc = dyn_arc.downcast::<[u8; SMALL_OBJECT_LIMIT]>().unwrap();
            assert!(Arc::ptr_eq(&arc, sp1.data.get().unwrap()));
        }
        drop(sp1);
        {
            let sp_cache = arena.lock_sp_cache();
            let sp_cache = sp_cache.borrow();
            assert!(
                sp_cache.get(&cache_key).is_none(),
                "the weak reference should be gone"
            );
        }
    }

    #[test]
    fn sp_cache_sp_unload() {
        let arena = &new_arena();
        let mut sp1 = arena.alloc([42u8; SMALL_OBJECT_LIMIT]);
        let mut sp2 = sp1.clone();
        let cache_key = (sp1.root.clone(), TypeId::of::<[u8; SMALL_OBJECT_LIMIT]>());
        {
            let sp_cache = arena.lock_sp_cache();
            let sp_cache = sp_cache.borrow();
            let weak_ref = sp_cache.get(&cache_key).unwrap();
            assert!(
                weak_ref.upgrade().is_some(),
                "weak reference should be valid before unload"
            );
        }
        sp1.unload();
        {
            let sp_cache = arena.lock_sp_cache();
            let sp_cache = sp_cache.borrow();
            let weak_ref = sp_cache.get(&cache_key).unwrap();
            assert!(
                weak_ref.upgrade().is_some(),
                "weak reference should still be valid after unloading sp1"
            );
        }
        sp2.unload();
        {
            let sp_cache = arena.lock_sp_cache();
            let sp_cache = sp_cache.borrow();
            assert!(
                sp_cache.get(&cache_key).is_none(),
                "the weak reference should be gone after unloading sp2"
            );
        }
    }

    #[test]
    fn sp_cache_alloc_same_data_twice() {
        let arena = &new_arena();
        let sp1 = arena.alloc([0u8; SMALL_OBJECT_LIMIT]);
        let sp2 = arena.alloc([0u8; SMALL_OBJECT_LIMIT]);
        let data1 = sp1.data.get().unwrap();
        let data2 = sp2.data.get().unwrap();
        assert!(
            Arc::ptr_eq(data1, data2),
            "underlying Arc should be shared when allocating the same data"
        );
    }

    #[test]
    fn lazy_load_large_data_structure() {
        use super::bin_tree::*;
        let arena = &new_arena();
        type BinTree = super::bin_tree::BinTree<DefaultDB>;
        {
            let mut bt = BinTree::new(0, None, None);
            let depth = 5;
            for i in 1..depth {
                bt = BinTree::new(i, Some(arena.alloc(bt.clone())), Some(arena.alloc(bt)));
            }
            let mut bt = arena.alloc(bt);
            bt.persist();
            bt.unload();
            // Walk (and thereby load) the left spine only, so every right
            // child should still render as lazy below.
            let mut p = Some(&bt);
            for _ in 0..depth {
                p = p.unwrap().left.as_ref();
            }
            let actual = format!("{:?}", bt);
            assert!(
                actual.ends_with("right: Some(<Lazy Sp>) }), right: Some(<Lazy Sp>) }), right: Some(<Lazy Sp>) }), right: Some(<Lazy Sp>) }"),
                "unexpected Debug output: {}",
                actual
            );
        }
        {
            let mut bt1 = BinTree::new(0, None, None);
            let depth = 100;
            for i in 1..depth {
                bt1 = BinTree::new(i, Some(arena.alloc(bt1.clone())), Some(arena.alloc(bt1)));
            }
            let mut bt1 = arena.alloc(bt1);
            let key = bt1.as_typed_key();
            bt1.persist();
            bt1.unload();
            let bt2 = arena.get_lazy::<BinTree>(&key).unwrap();
            let mut p1 = Some(&bt1);
            let mut p2 = Some(&bt2);
            for _ in 0..depth {
                assert!(p1.unwrap().data.get().is_none());
                assert!(p2.unwrap().data.get().is_none());
                assert!(Arc::ptr_eq(
                    p1.unwrap().force_as_arc(),
                    p2.unwrap().force_as_arc(),
                ));
                p1 = p1.unwrap().left.as_ref();
                p2 = p2.unwrap().right.as_ref();
            }
        }
        {
            let depth = 13;
            let mut bt = counting_tree(arena, depth);
            assert_eq!(arena.lock_sp_cache().borrow().len(), (1 << depth) - 1);
            bt.persist();
            bt.unload();
            let mut p = Some(&bt);
            let random: u64 = 0x616a7011af5e1b64;
            for i in 0..depth {
                if (random >> i) & 1 == 0 {
                    assert!(p.unwrap().data.get().is_none());
                    p = p.unwrap().left.as_ref();
                } else {
                    assert!(p.unwrap().data.get().is_none());
                    p = p.unwrap().right.as_ref();
                }
            }
            assert_eq!(arena.lock_sp_cache().borrow().len(), depth);
        }
    }

    #[test]
    fn concurrent_arena_access() {
        use std::thread;
        use std::time::{Duration, Instant};
        type Ty = Sp<Option<Sp<Option<Sp<Option<Sp<Option<Sp<u32>>>>>>>>>;
        let mut threads = std::vec::Vec::new();
        let num_threads = 20;
        for i in 0..num_threads {
            let arena = default_storage().arena.clone();
            threads.push(thread::spawn(move || {
                let mk_sp = |value: u32| -> Ty {
                    let sp = arena.alloc(value);
                    let sp = arena.alloc(Some(sp));
                    let sp = arena.alloc(Some(sp));
                    let sp = arena.alloc(Some(sp));
                    let sp = arena.alloc(Some(sp));
                    sp.clone()
                };
                let mut common_sp = mk_sp(0);
                let mut sp_unique = mk_sp((i + 1) as u32);
                let force_sp = |sp: &Ty| -> u32 {
                    let sp = sp.deref().as_ref().unwrap();
                    let sp = sp.deref().as_ref().unwrap();
                    let sp = sp.deref().as_ref().unwrap();
                    let sp = sp.deref().as_ref().unwrap();
                    *sp.deref()
                };
                for _ in 0..100 {
                    common_sp.unload();
                    sp_unique.unload();
                    let common_val = force_sp(&common_sp);
                    let unique_val = force_sp(&sp_unique);
                    assert_eq!(common_val, 0);
                    assert_eq!(unique_val, (i + 1) as u32);
                    assert_eq!(arena.get(&common_sp.as_typed_key()).unwrap(), common_sp);
                    assert_eq!(
                        arena.get_lazy(&sp_unique.as_typed_key()).unwrap(),
                        sp_unique
                    );
                }
            }));
        }
        // Poll rather than sleeping a fixed interval. The test finishes as soon
        // as every worker is done, yet still fails instead of hanging if the
        // workers deadlock.
        let deadline = Instant::now() + Duration::from_secs(10);
        while !threads.iter().all(|t| t.is_finished()) {
            assert!(
                Instant::now() < deadline,
                "deadlock: the threads should finish well within 10 seconds"
            );
            thread::sleep(Duration::from_millis(10));
        }
        // `is_finished` is also true for a thread that panicked, so join every
        // worker and re-raise its panic. Otherwise a failed assertion inside a
        // worker would be silently ignored.
        for t in threads {
            if let Err(payload) = t.join() {
                std::panic::resume_unwind(payload);
            }
        }
    }

    #[test]
    fn serialize_sp() {
        let arena = &new_arena();
        let sp = arena.alloc(42u32);
        let sp = arena.alloc(Some(sp));
        let sp = arena.alloc(Some(sp));
        let sp = arena.alloc(Some(sp));
        let mut sp = arena.alloc(Some(sp));
        let eager_size = Sp::serialized_size(&sp);
        let mut eager_serialization = vec![];
        Sp::serialize(&sp, &mut eager_serialization).unwrap();
        assert_eq!(eager_serialization.len(), eager_size);
        sp.unload();
        let lazy_size = Sp::serialized_size(&sp);
        sp.unload();
        let mut lazy_serialization = vec![];
        Sp::serialize(&sp, &mut lazy_serialization).unwrap();
        assert_eq!(lazy_serialization.len(), lazy_size);
    }

    #[test]
    fn serialize_highly_duplicated_dag() {
        use std::thread;
        use std::time::Duration;
        #[derive(Storable, Clone, PartialEq, Eq, Debug)]
        #[tag = "test-bin-tree"]
        struct BinTree {
            value: u32,
            left: Option<Sp<BinTree>>,
            right: Option<Sp<BinTree>>,
        }
        let arena = &new_arena();
        let mut bt = BinTree {
            value: 0,
            left: None,
            right: None,
        };
        let height = 30;
        for i in 1..height {
            bt = BinTree {
                value: i,
                left: Some(arena.alloc(bt.clone())),
                right: Some(arena.alloc(bt)),
            };
        }
        let sp = arena.alloc(bt);
        let handle = std::thread::spawn(move || {
            let mut serialized = vec![];
            Sp::serialize(&sp, &mut serialized).unwrap();
            serialized
        });
        for _ in 0..50 {
            thread::sleep(Duration::from_millis(100));
            if handle.is_finished() {
                break;
            }
        }
        if !handle.is_finished() {
            panic!("serialize_highly_duplicated_dag: serialization took too long!");
        }
        let serialized = handle.join().unwrap();
        let handle = std::thread::spawn(move || {
            let recursive_depth = 0;
            Sp::<BinTree>::deserialize(&mut serialized.as_slice(), recursive_depth).unwrap();
        });
        for _ in 0..50 {
            thread::sleep(Duration::from_millis(100));
            if handle.is_finished() {
                break;
            }
        }
        if !handle.is_finished() {
            panic!("serialize_highly_duplicated_dag: deserialization took too long!");
        }
        handle.join().unwrap();
    }

    #[test]
    fn deserialize_same_key_at_two_different_types() {
        #[derive(Clone, Storable)]
        #[tag = "test-pair"]
        struct Pair {
            #[storable(child)]
            x: Sp<u32>,
            #[storable(child)]
            y: Sp<u64>,
        }
        let arena = &new_arena();
        let x = arena.alloc(0u32);
        let y = arena.alloc(0u64);
        assert_eq!(x.as_typed_key().key, y.as_typed_key().key);
        assert_ne!(x.type_id(), y.type_id());
        let sp = arena.alloc(Pair { x, y });
        assert_eq!(
            sp.children().len(),
            2,
            "children were inlined, need to fix `Pair as Storable` impl"
        );
        let mut bytes: Vec<u8> = vec![];
        Sp::serialize(&sp, &mut bytes).unwrap();
        drop(sp);
        let _ = Sp::<Pair, _>::deserialize(&mut bytes.as_slice(), 0).unwrap();
    }

    #[test]
    fn get_unknown_key() {
        let arena = new_arena();
        let sp = arena.alloc([0; SMALL_OBJECT_LIMIT]);
        let key = sp.as_typed_key();
        assert!(arena.get::<[u8; SMALL_OBJECT_LIMIT]>(&key).is_ok());
        let arena = new_arena();
        assert!(arena.get::<[u8; SMALL_OBJECT_LIMIT]>(&key).is_err());
    }

    #[test]
    fn metadata_sp_cache_race() {
        use std::thread;
        let arena = new_arena();
        let mut sp = arena.alloc(42u32);
        let key = sp.as_typed_key();
        sp.persist();
        drop(sp);
        let arena1 = arena.clone();
        let key1 = key.clone();
        let t1 = thread::spawn(move || {
            for _ in 0..1000 {
                let sp = arena1.get::<u32>(&key1).unwrap();
                drop(sp);
            }
        });
        for i in 0..1000 {
            if i % 2 == 0 {
                let sp = arena.get_lazy::<u32>(&key).unwrap();
                drop(sp);
            } else {
                let sp = arena.get::<u32>(&key).unwrap();
                drop(sp);
            }
        }
        t1.join().unwrap();
    }

    #[test]
    fn sp_is_lazy() {
        let arena = new_arena();
        let mut sp = arena.alloc([42u8; SMALL_OBJECT_LIMIT]);
        assert!(!sp.is_lazy());
        sp.persist();
        sp.unload();
        assert!(sp.is_lazy());
        let _ = sp.deref();
        assert!(!sp.is_lazy());
        let key = sp.as_typed_key();
        sp.persist();
        drop(sp);
        let sp = arena.get_lazy::<[u8; SMALL_OBJECT_LIMIT]>(&key).unwrap();
        assert!(sp.is_lazy());
        let sp = arena.get::<[u8; SMALL_OBJECT_LIMIT]>(&key).unwrap();
        assert!(!sp.is_lazy());
    }

    #[test]
    fn serialize_small_sp() {
        let arena = new_arena();
        let sp = arena.alloc(42u32);
        let mut bytes: Vec<u8> = vec![];
        Sp::serialize(&sp, &mut bytes).unwrap();
        let other_sp = Sp::deserialize(&mut bytes.as_slice(), 0).unwrap();
        assert_eq!(sp, other_sp);
    }

    #[test]
    fn force_as_arc_lock_ordering_regression() {
        use std::sync::mpsc;
        use std::time::Duration;
        #[derive(Storable, Clone, PartialEq, Eq)]
        struct Parent {
            #[storable(child)]
            a: Sp<u32>,
            #[storable(child)]
            b: Sp<u32>,
        }
        let arena = new_arena();
        let parent = arena.alloc(Parent {
            a: arena.alloc(42u32),
            b: arena.alloc(99u32),
        });
        let mut tracked = parent.into_tracked();
        tracked.persist();
        let root_key: ArenaKey<_> = ArenaKey::Ref(tracked.root.clone());
        arena.with_backend(|b| b.flush_all_changes_to_db());
        tracked.unload();
        drop(tracked);
        let (tx, rx) = mpsc::channel();
        let tx_a = tx.clone();
        let tx_b = tx;
        let arena_a = arena.clone();
        let key_a = root_key.clone();
        std::thread::spawn(move || {
            for _ in 0..10_000 {
                let sp = arena_a
                    .get_lazy_unversioned::<Parent>(&key_a)
                    .expect("get_lazy_unversioned");
                // Force the lazy pointer to load (exercises force_as_arc).
                let _ = &*sp;
                drop(sp);
            }
            tx_a.send("A").ok();
        });
        let arena_b = arena.clone();
        std::thread::spawn(move || {
            for i in 0u32..10_000 {
                let sp = arena_b.alloc(i);
                let _ = sp.into_tracked();
            }
            tx_b.send("B").ok();
        });
        let timeout = Duration::from_secs(30);
        rx.recv_timeout(timeout)
            .expect("DEADLOCK: neither thread completed within 30s — lock ordering violation between force_as_arc (sp_cache → metadata) and new_sp_locked (metadata → sp_cache)");
        rx.recv_timeout(timeout)
            .expect("DEADLOCK: only one thread completed — lock ordering violation");
    }

    #[test]
    fn serialize_direct_and_ref_same_hash() {
        #[derive(Storable, Clone, PartialEq, Eq)]
        struct DualLeaf {
            #[storable(child)]
            a: Sp<u32>,
            #[storable(child)]
            b: Sp<u32>,
        }
        let arena = &new_arena();
        let leaf: Sp<u32, DefaultDB> = arena.alloc(42u32);
        assert!(matches!(leaf.child_repr, ArenaKey::Direct(_)));
        let leaf_tracked = leaf.into_tracked();
        assert!(matches!(leaf_tracked.child_repr, ArenaKey::Ref(_)));
        assert_eq!(leaf.child_repr.hash(), leaf_tracked.child_repr.hash());
        assert_ne!(leaf.child_repr, leaf_tracked.child_repr);
        let root = arena.alloc(DualLeaf {
            a: leaf,
            b: leaf_tracked,
        });
        let nodes = root.serialize_to_node_list();
        assert_eq!(nodes.nodes.len(), 2, "root + one deduplicated leaf");
        let root_node = nodes.nodes.last().unwrap();
        assert_eq!(
            root_node.child_indices[0], root_node.child_indices[1],
            "both children deduplicate to the same leaf index"
        );
    }
}