use crate::Storable;
use crate::arena::{ArenaHash, ArenaKey, Opaque, Sp};
use crate::db::DB;
use crate::storable::Loader;
use crate::storage::{Map, default_storage};
use crate::{self as storage, DefaultDB};
use derive_where::derive_where;
use rand::distributions::{Distribution, Standard};
use serialize::{Deserializable, Serializable, Tagged};
#[cfg(test)]
use std::collections::BTreeMap;
use std::collections::BTreeSet;
#[cfg(feature = "proptest")]
use {proptest::prelude::Arbitrary, serialize::NoStrategy, std::marker::PhantomData};
/// Wrapper around a single [`ArenaKey`].
///
/// Constructing one through [`ChildRef::new`] calls `persist` on the default
/// storage backend for every hash reported by `child.refs()`, and dropping it
/// issues the matching `unpersist` calls. When stored, the wrapped key is
/// reported as the node's only child and no binary payload is written.
#[derive_where(Debug, PartialEq, Eq)]
pub struct ChildRef<D: DB> {
/// The wrapped arena key.
// NOTE(review): the field is public, so a `ChildRef` can also be built with a
// struct literal, which skips the `persist` calls made by `new` — confirm
// callers always go through `new`.
pub child: ArenaKey<D::Hasher>,
}
impl<D: DB> ChildRef<D> {
    /// Wraps `child` in a `ChildRef`.
    ///
    /// Before the wrapper is returned, every hash reported by `child.refs()`
    /// is passed to `persist` on the default storage backend for `D`. The
    /// `Drop` implementation of this type issues the matching `unpersist`
    /// calls.
    pub fn new(child: ArenaKey<D::Hasher>) -> Self {
        default_storage::<D>().with_backend(|backend| {
            for hash in child.refs().iter() {
                backend.persist(hash);
            }
        });
        Self { child }
    }
}
impl<D: DB> Clone for ChildRef<D> {
    /// Clones the wrapped key and routes it through [`ChildRef::new`], so the
    /// copy performs its own round of `persist` calls instead of sharing the
    /// original's.
    fn clone(&self) -> Self {
        let duplicate_key = self.child.clone();
        Self::new(duplicate_key)
    }
}
impl<D: DB> Drop for ChildRef<D> {
    /// Calls `unpersist` on the default storage backend for every hash
    /// reported by `self.child.refs()`, mirroring the `persist` calls made in
    /// [`ChildRef::new`].
    fn drop(&mut self) {
        default_storage::<D>().with_backend(|backend| {
            for hash in self.child.refs().iter() {
                backend.unpersist(hash);
            }
        });
    }
}
impl<D: DB> Storable<D> for ChildRef<D> {
    /// Reports the wrapped key as this node's one and only child.
    fn children(&self) -> std::vec::Vec<ArenaKey<D::Hasher>> {
        let only_child = self.child.clone();
        vec![only_child]
    }

    /// Writes nothing: the whole value is carried by the child list.
    fn to_binary_repr<W: std::io::Write>(&self, _writer: &mut W) -> Result<(), std::io::Error>
    where
        Self: Sized,
    {
        Ok(())
    }

    /// Rebuilds a `ChildRef` from its stored form.
    ///
    /// # Errors
    ///
    /// Returns `InvalidData` unless `children` yields exactly one key and
    /// `reader` yields no bytes. Errors from reading `reader` or from
    /// `loader.get` are propagated unchanged.
    fn from_binary_repr<R: std::io::Read>(
        reader: &mut R,
        children: &mut impl Iterator<Item = ArenaKey<D::Hasher>>,
        loader: &impl Loader<D>,
    ) -> Result<Self, std::io::Error>
    where
        Self: Sized,
    {
        // Drain both inputs up front so the shape check sees everything.
        let mut child_keys: Vec<_> = children.collect();
        let mut payload = Vec::new();
        reader.read_to_end(&mut payload)?;

        if child_keys.len() != 1 || !payload.is_empty() {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "Ref should have exactly one child and no data",
            ));
        }

        let child = child_keys.pop().expect("must be present");
        // The child is loaded and persisted for the duration of `Self::new`,
        // then released again; the order of these four steps is deliberate.
        let mut loaded: Sp<Opaque<D>, D> = loader.get(&child)?;
        loaded.persist();
        let child_ref = Self::new(child);
        loaded.unpersist();
        Ok(child_ref)
    }
}
impl<D: DB> Serializable for ChildRef<D> {
    /// Serializes exactly as the wrapped key does.
    fn serialize(&self, writer: &mut impl std::io::Write) -> std::io::Result<()> {
        let key = &self.child;
        key.serialize(writer)
    }

    /// Size of the serialized form, which is the wrapped key's own size.
    fn serialized_size(&self) -> usize {
        let key = &self.child;
        key.serialized_size()
    }
}
impl<D: DB> Deserializable for ChildRef<D> {
    /// Reads an `ArenaKey` from `reader` and wraps it via [`ChildRef::new`].
    ///
    /// Any error from decoding the key is returned as-is; `ChildRef::new` is
    /// only reached on success.
    fn deserialize(reader: &mut impl std::io::Read, recursive_depth: u32) -> std::io::Result<Self> {
        let key = ArenaKey::<D::Hasher>::deserialize(reader, recursive_depth)?;
        Ok(ChildRef::new(key))
    }
}
impl<D: DB> Distribution<ChildRef<D>> for Standard {
    /// Draws a random hash from `rng`, wraps it as `ArenaKey::Ref`, and builds
    /// the result through [`ChildRef::new`].
    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> ChildRef<D> {
        let random_hash = rng.r#gen();
        let key = ArenaKey::Ref(random_hash);
        ChildRef::new(key)
    }
}
impl<D: DB> Tagged for ChildRef<D> {
    /// Tag string for this type.
    fn tag() -> std::borrow::Cow<'static, str> {
        std::borrow::Cow::Borrowed("childref[v1]")
    }

    /// Unique-factor string for this type.
    fn tag_unique_factor() -> String {
        String::from("children[v1]")
    }
}
/// Reference-count table split across two storage maps.
///
/// The updating methods on this type take `&self` and return a new `RcMap`
/// instead of mutating in place. Both fields are public only when the
/// `public-internal-structure` feature is enabled; otherwise they are private.
#[derive_where(Debug, Clone, PartialEq, Eq)]
#[derive(Storable)]
#[storable(db = D)]
#[tag = "rcmap[v1]"]
pub struct RcMap<D: DB = DefaultDB> {
/// Counts keyed by arena hash. `modify_rc` only inserts entries here when
/// the updated count is at least 1.
#[cfg(feature = "public-internal-structure")]
pub rc_ge_1: Map<ArenaHash<D::Hasher>, u64, D>,
/// Counts keyed by arena hash. `modify_rc` only inserts entries here when
/// the updated count is at least 1.
#[cfg(not(feature = "public-internal-structure"))]
rc_ge_1: Map<ArenaHash<D::Hasher>, u64, D>,
/// Keys that `get_rc` reports with a count of 0. `ins_root` and `modify_rc`
/// map each key to a `ChildRef` wrapping that same key.
// NOTE(review): storing a `ChildRef` presumably makes each key a storage
// child of this map (see the `rc_0_keys_are_descendants` test) — confirm.
#[cfg(feature = "public-internal-structure")]
pub rc_0: Map<ArenaKey<D::Hasher>, ChildRef<D>, D>,
/// Keys that `get_rc` reports with a count of 0. `ins_root` and `modify_rc`
/// map each key to a `ChildRef` wrapping that same key.
#[cfg(not(feature = "public-internal-structure"))]
rc_0: Map<ArenaKey<D::Hasher>, ChildRef<D>, D>,
}
impl<D: DB> RcMap<D> {
    /// Returns `true` when [`Self::get_rc`] has any count recorded for `key`.
    pub(crate) fn contains(&self, key: &ArenaKey<D::Hasher>) -> bool {
        self.get_rc(key).is_some()
    }

    /// Looks up the count recorded for `key`.
    ///
    /// A `Ref` key found in `rc_ge_1` yields its stored count. Otherwise a key
    /// present in `rc_0` yields `Some(0)`, and anything else yields `None`.
    pub(crate) fn get_rc(&self, key: &ArenaKey<D::Hasher>) -> Option<u64> {
        // Only hash-backed keys can have an entry in `rc_ge_1`.
        let stored_count = match key {
            ArenaKey::Ref(hash) => self.rc_ge_1.get(hash).copied(),
            ArenaKey::Direct(_) => None,
        };
        stored_count.or_else(|| self.rc_0.contains_key(key).then_some(0))
    }

    /// Returns a copy of this map with `key` inserted into `rc_0`, paired with
    /// a fresh `ChildRef` for it. `rc_ge_1` is carried over unchanged.
    #[must_use]
    pub(crate) fn ins_root(&self, key: ArenaKey<D::Hasher>) -> Self {
        let rc_ge_1 = self.rc_ge_1.clone();
        let root_ref = ChildRef::new(key.clone());
        let rc_0 = self.rc_0.insert(key, root_ref);
        Self { rc_ge_1, rc_0 }
    }

    /// Returns a copy of this map with `key` removed from `rc_0`. `rc_ge_1` is
    /// carried over unchanged.
    #[must_use]
    pub(crate) fn rm_root(&self, key: &ArenaKey<D::Hasher>) -> Self {
        let rc_ge_1 = self.rc_ge_1.clone();
        let rc_0 = self.rc_0.remove(key);
        Self { rc_ge_1, rc_0 }
    }

    /// Returns a copy of this map in which the count for `key` is `updated`.
    ///
    /// The current count is read from `rc_ge_1` and treated as 0 when absent.
    /// An `updated` of 0 places the key in `rc_0` and removes any positive
    /// entry from `rc_ge_1`. A positive `updated` is written to `rc_ge_1`, and
    /// the key is removed from `rc_0` if its count was previously 0.
    #[must_use]
    pub(crate) fn modify_rc(&self, key: &ArenaHash<D::Hasher>, updated: u64) -> Self {
        let had_positive_count = self.rc_ge_1.get(key).copied().unwrap_or(0) != 0;

        if updated == 0 {
            let rc_ge_1 = if had_positive_count {
                self.rc_ge_1.remove(key)
            } else {
                self.rc_ge_1.clone()
            };
            let zero_key = ArenaKey::Ref(key.clone());
            let rc_0 = self
                .rc_0
                .insert(zero_key.clone(), ChildRef::new(zero_key));
            Self { rc_ge_1, rc_0 }
        } else {
            let rc_ge_1 = self.rc_ge_1.insert(key.clone(), updated);
            let rc_0 = if had_positive_count {
                self.rc_0.clone()
            } else {
                self.rc_0.remove(&ArenaKey::Ref(key.clone()))
            };
            Self { rc_ge_1, rc_0 }
        }
    }

    /// Returns an iterator over the keys of `rc_0` that are not in `roots`.
    pub(crate) fn get_unreachable_keys_not_in(
        &self,
        roots: &BTreeSet<ArenaKey<D::Hasher>>,
    ) -> impl Iterator<Item = ArenaKey<D::Hasher>> {
        let zero_count_keys = self.rc_0.keys();
        zero_count_keys.filter(|candidate| !roots.contains(candidate))
    }

    /// Removes `key` from `rc_0`.
    ///
    /// Returns the updated map, or `None` when `key` is not in `rc_0`.
    #[must_use]
    pub(crate) fn remove_unreachable_key(&self, key: &ArenaKey<D::Hasher>) -> Option<Self> {
        if !self.rc_0.contains_key(key) {
            return None;
        }
        let rc_ge_1 = self.rc_ge_1.clone();
        let rc_0 = self.rc_0.remove(key);
        Some(Self { rc_ge_1, rc_0 })
    }

    /// Test helper: flattens both maps into a single key → count map.
    ///
    /// `rc_0` keys are written first with a count of 0, then `rc_ge_1`
    /// entries, so an `rc_ge_1` count wins if a key appears in both.
    #[cfg(test)]
    pub(crate) fn get_rcs(&self) -> BTreeMap<ArenaKey<D::Hasher>, u64> {
        let mut tallies: BTreeMap<ArenaKey<D::Hasher>, u64> = BTreeMap::new();
        tallies.extend(self.rc_0.keys().map(|key| (key, 0)));
        tallies.extend(
            self.rc_ge_1
                .iter()
                .map(|(hash, count)| (ArenaKey::Ref(hash.clone()), *count)),
        );
        tallies
    }
}
impl<D: DB> Default for RcMap<D> {
    /// Builds an `RcMap` whose two maps are both freshly created with
    /// `Map::new()`.
    fn default() -> Self {
        let rc_ge_1 = Map::new();
        let rc_0 = Map::new();
        Self { rc_ge_1, rc_0 }
    }
}
#[cfg(feature = "proptest")]
impl<D: DB> Arbitrary for RcMap<D>
where
    D::Hasher: Arbitrary,
{
    type Parameters = ();
    type Strategy = NoStrategy<RcMap<D>>;

    /// Returns the `NoStrategy` placeholder for this type; the parameters are
    /// ignored.
    // NOTE(review): `NoStrategy` presumably generates no values and exists
    // only to satisfy `Arbitrary` bounds — confirm against its definition.
    fn arbitrary_with(_args: Self::Parameters) -> Self::Strategy {
        NoStrategy(PhantomData)
    }
}
impl<D: DB> Distribution<RcMap<D>> for Standard {
    /// Samples both maps independently from `rng`, `rc_ge_1` first and then
    /// `rc_0`.
    fn sample<R: rand::Rng + ?Sized>(&self, rng: &mut R) -> RcMap<D> {
        let rc_ge_1 = rng.r#gen();
        let rc_0 = rng.r#gen();
        RcMap { rc_ge_1, rc_0 }
    }
}
// Unit tests for `ChildRef` and `RcMap`, all run against `InMemoryDB`.
#[cfg(test)]
pub(crate) mod tests {
use super::*;
use crate::arena::Sp;
use crate::db::InMemoryDB;
use crate::storable::SMALL_OBJECT_LIMIT;
// Round-trips a vector of `Sp<ChildRef>` through `to_binary_repr` /
// `from_binary_repr` and checks that the result equals the input.
#[test]
fn keyref_round_trip_storable() {
// Keep `val` alive for the whole test; `key` is its child key.
let val = Sp::<_, InMemoryDB>::new([0u8; 1024]);
let key = val.as_child();
let keyref = ChildRef::<InMemoryDB>::new(key);
// `let _ =` does not bind, so this `Sp` is dropped at the end of the statement.
let _ = Sp::new(keyref.clone());
let keyrefs = vec![
Sp::new(keyref.clone()),
Sp::new(keyref.clone()),
Sp::new(keyref.clone()),
];
let mut bytes = Vec::new();
keyrefs.to_binary_repr(&mut bytes).unwrap();
let mut reader = &bytes[..];
let mut child_iter = keyrefs.children().into_iter();
// Deserialize through a loader built on the default storage arena.
let arena = &crate::storage::default_storage().arena;
let loader = storage_core::arena::BackendLoader::new(arena, None);
let deserialized: Vec<Sp<ChildRef<InMemoryDB>, InMemoryDB>> =
Storable::from_binary_repr(&mut reader, &mut child_iter, &loader).unwrap();
assert_eq!(keyrefs, deserialized);
}
// Collects every key reachable from `rcmap.children()`.
//
// Uses a worklist: `Direct` keys contribute their inline `children`, while
// `Ref` keys are looked up in the backend and contribute the stored object's
// `children`. Panics if a `Ref` key is missing from the backend.
#[cfg(test)]
pub(crate) fn get_rcmap_descendants<D: DB>(
rcmap: &RcMap<D>,
) -> std::collections::BTreeSet<ArenaKey<D::Hasher>> {
let mut visited = std::collections::BTreeSet::new();
let mut to_visit = rcmap.children();
let arena = &crate::storage::default_storage::<D>().arena;
while let Some(current) = to_visit.pop() {
// `insert` returns false for keys already seen; skip those.
if !visited.insert(current.clone()) {
continue;
}
match current {
ArenaKey::Direct(d) => to_visit.extend(d.children.iter().cloned()),
ArenaKey::Ref(ref r) => {
arena.with_backend(|backend| {
let disk_obj = backend.get(r).expect("Key should exist in backend");
to_visit.extend(disk_obj.children.clone());
});
}
}
}
visited
}
// A key placed in `rc_0` via `modify_rc(.., 0)` must show up among the
// descendants of the `RcMap`.
#[test]
fn rc_0_keys_are_descendants() {
// NOTE(review): the value is sized `SMALL_OBJECT_LIMIT`, presumably so that
// it is stored by reference rather than inline — confirm.
let val = Sp::<_, InMemoryDB>::new([42u8; SMALL_OBJECT_LIMIT]);
let key = val.root.clone();
let rcmap = RcMap::<InMemoryDB>::default().modify_rc(&key, 0);
assert!(rcmap.rc_0.contains_key(&ArenaKey::Ref(key.clone())));
let descendants = get_rcmap_descendants(&rcmap);
assert!(
descendants.contains(&val.as_child()),
"Key in rc_0 must be a descendant of RcMap"
);
}
// Walks one key through 0 -> 1 -> 3 -> 0, then checks multi-key lookups and
// `remove_unreachable_key` for zero-count, positive-count and absent keys.
#[test]
fn rcmap_operations() {
// Three distinct values; each is required to produce an `ArenaKey::Ref`.
let val1 = Sp::<_, InMemoryDB>::new([1u8; 1024]);
let key1 = val1.as_child();
let ArenaKey::Ref(hash1) = key1.clone() else {
panic!("testing refs");
};
let val2 = Sp::<_, InMemoryDB>::new([2u8; 1024]);
let key2 = val2.as_child();
let ArenaKey::Ref(hash2) = key2.clone() else {
panic!("testing refs");
};
let val3 = Sp::<_, InMemoryDB>::new([3u8; 1024]);
let key3 = val3.as_child();
let ArenaKey::Ref(hash3) = key3.clone() else {
panic!("testing refs");
};
// `ins_root` places key1 in `rc_0` only.
let rcmap = RcMap::<InMemoryDB>::default().ins_root(key1.clone());
assert_eq!(rcmap.get_rc(&key1), Some(0), "get_rc should return 0");
assert!(rcmap.rc_0.contains_key(&key1), "key1 should be in rc_0 map");
assert!(
!rcmap.rc_ge_1.contains_key(&hash1),
"key1 should not be in rc_ge_1 map"
);
// 0 -> 1 moves key1 from `rc_0` to `rc_ge_1`.
let rcmap = rcmap.modify_rc(&hash1, 1);
assert_eq!(rcmap.get_rc(&key1), Some(1), "get_rc should return 1");
assert!(
!rcmap.rc_0.contains_key(&key1),
"key1 should not be in rc_0 map"
);
assert!(
rcmap.rc_ge_1.contains_key(&hash1),
"key1 should be in rc_ge_1 map"
);
// Changes between positive counts keep key1 in `rc_ge_1`.
let rcmap = rcmap.modify_rc(&hash1, 2);
let rcmap = rcmap.modify_rc(&hash1, 3);
assert_eq!(rcmap.get_rc(&key1), Some(3), "get_rc should return 3");
assert!(
rcmap.rc_ge_1.contains_key(&hash1),
"key1 should remain in rc_ge_1 map"
);
let rcmap = rcmap.modify_rc(&hash1, 2);
let rcmap = rcmap.modify_rc(&hash1, 1);
assert!(
rcmap.rc_ge_1.contains_key(&hash1),
"key1 should still be in rc_ge_1 map"
);
// 1 -> 0 moves key1 back to `rc_0`.
let rcmap = rcmap.modify_rc(&hash1, 0);
assert_eq!(rcmap.get_rc(&key1), Some(0), "get_rc should return 0");
assert!(
rcmap.rc_0.contains_key(&key1),
"key1 should be back in rc_0 map"
);
assert!(
!rcmap.rc_ge_1.contains_key(&hash1),
"key1 should not be in rc_ge_1 map"
);
assert_eq!(
rcmap.get_rc(&key2),
None,
"get_rc on nonexistent key should return None"
);
// Several keys tracked at once, with counts 0, 1 and 2.
let rcmap = rcmap.modify_rc(&hash2, 1);
let rcmap = rcmap.modify_rc(&hash3, 2);
assert_eq!(rcmap.get_rc(&key1), Some(0));
assert_eq!(rcmap.get_rc(&key2), Some(1));
assert_eq!(rcmap.get_rc(&key3), Some(2));
assert!(rcmap.rc_0.contains_key(&key1));
assert!(rcmap.rc_ge_1.contains_key(&hash2));
assert!(rcmap.rc_ge_1.contains_key(&hash3));
// Removal succeeds only for a key that is currently in `rc_0`.
let rcmap_new = rcmap.remove_unreachable_key(&key1);
assert!(
rcmap_new.is_some(),
"remove_unreachable_key should succeed for rc=0 key"
);
let rcmap = rcmap_new.unwrap();
assert!(!rcmap.contains(&key1), "key1 should no longer be in rcmap");
assert_eq!(
rcmap.get_rc(&key1),
None,
"get_rc should return None for removed key"
);
let rcmap_new = rcmap.remove_unreachable_key(&key2);
assert!(
rcmap_new.is_none(),
"remove_unreachable_key should fail for rc>0 key"
);
let rcmap_new = rcmap.remove_unreachable_key(&key1);
assert!(
rcmap_new.is_none(),
"remove_unreachable_key should fail for nonexistent key"
);
}
}