use alloc::{collections::BTreeMap, vec::Vec};
use core::{borrow::Borrow, cmp, fmt, ops};
use hashbrown::HashMap;
/// Set of modifications relative to a parent storage: for each key, either a new value
/// (`Some`) or an erasure (`None`), together with a piece of user data of type `T`.
#[derive(Clone)]
pub struct TrieDiff<T = ()> {
    /// Sorted index over exactly the same keys as `hashmap`, used for ordered range queries.
    /// Each key maps to `true` if the diff inserts a value for it, and `false` if the diff
    /// erases it.
    btree: BTreeMap<Vec<u8>, bool>,
    /// Primary storage of the entries: key => (new value, or `None` for an erasure; user data).
    hashmap: HashMap<Vec<u8>, (Option<Vec<u8>>, T), fnv::FnvBuildHasher>,
}
impl<T> TrieDiff<T> {
    /// Builds a new diff that contains no entry.
    pub fn empty() -> Self {
        Self {
            btree: BTreeMap::default(),
            hashmap: HashMap::with_capacity_and_hasher(0, Default::default()),
        }
    }

    /// Removes every entry from the diff.
    pub fn clear(&mut self) {
        self.hashmap.clear();
        self.btree.clear();
    }

    /// Records that `key` is set to `value`, attaching `user_data` to the entry.
    ///
    /// Returns the entry this diff previously held for `key` (new value or `None` for an
    /// erasure, plus its user data), or `None` if the diff did not mention `key`.
    pub fn diff_insert(
        &mut self,
        key: impl Into<Vec<u8>>,
        value: impl Into<Vec<u8>>,
        user_data: T,
    ) -> Option<(Option<Vec<u8>>, T)> {
        let key = key.into();
        let previous = self
            .hashmap
            .insert(key.clone(), (Some(value.into()), user_data));
        match &previous {
            // Already an insertion: the sorted index already says `true`.
            Some((Some(_), _)) => {
                debug_assert_eq!(self.btree.get(&key), Some(&true));
            }
            // New key, or an erasure turned into an insertion: update the sorted index.
            None | Some((None, _)) => {
                self.btree.insert(key, true);
            }
        }
        previous
    }

    /// Records that `key` is erased, attaching `user_data` to the entry.
    ///
    /// Returns the entry this diff previously held for `key`, or `None` if the diff did not
    /// mention `key`.
    pub fn diff_insert_erase(
        &mut self,
        key: impl Into<Vec<u8>>,
        user_data: T,
    ) -> Option<(Option<Vec<u8>>, T)> {
        let key = key.into();
        let previous = self.hashmap.insert(key.clone(), (None, user_data));
        match &previous {
            // Already an erasure: the sorted index already says `false`.
            Some((None, _)) => {
                debug_assert_eq!(self.btree.get(&key), Some(&false));
            }
            // New key, or an insertion turned into an erasure: update the sorted index.
            None | Some((Some(_), _)) => {
                self.btree.insert(key, false);
            }
        }
        previous
    }

    /// Removes `key` from the diff altogether, so that the diff neither inserts nor erases it.
    ///
    /// Returns the removed entry, or `None` if the diff did not mention `key`.
    pub fn diff_remove(&mut self, key: impl AsRef<[u8]>) -> Option<(Option<Vec<u8>>, T)> {
        let previous = self.hashmap.remove(key.as_ref());
        if let Some(_previous) = &previous {
            let _in_btree = self.btree.remove(key.as_ref());
            debug_assert_eq!(_in_btree, Some(_previous.0.is_some()));
        }
        previous
    }

    /// Returns the entry of the diff for `key`, or `None` if the diff doesn't mention it.
    ///
    /// Within the entry, the value is `None` if the diff erases `key`.
    pub fn diff_get(&self, key: &[u8]) -> Option<(Option<&[u8]>, &T)> {
        self.hashmap
            .get(key)
            .map(|(v, ud)| (v.as_ref().map(|v| &v[..]), ud))
    }

    /// Iterates over all the entries of the diff, in an unspecified order.
    pub fn diff_iter_unordered(
        &self,
    ) -> impl ExactSizeIterator<Item = (&[u8], Option<&[u8]>, &T)> + Clone {
        self.hashmap
            .iter()
            .map(|(k, (v, ud))| (&k[..], v.as_ref().map(|v| &v[..]), ud))
    }

    /// Consumes the diff and yields all of its entries, in an unspecified order.
    pub fn diff_into_iter_unordered(
        self,
    ) -> impl ExactSizeIterator<Item = (Vec<u8>, Option<Vec<u8>>, T)> {
        self.hashmap.into_iter().map(|(k, (v, ud))| (k, v, ud))
    }

    /// Iterates, in increasing key order, over the keys of the diff that are within `range`.
    ///
    /// Each key is paired with `true` if the diff inserts a value for it, and `false` if the
    /// diff erases it.
    pub fn diff_range_ordered<U: ?Sized>(
        &self,
        range: impl ops::RangeBounds<U>,
    ) -> impl Iterator<Item = (&[u8], bool)> + Clone
    where
        U: Ord,
        Vec<u8>: Borrow<U>,
    {
        self.btree
            .range(range)
            .map(|(k, inserts)| (&k[..], *inserts))
    }

    /// Returns the key that immediately follows `key` (or is equal to it, if `or_equal` is
    /// `true`) once this diff is applied on top of its parent storage.
    ///
    /// `in_parent_next_key` must be the parent's answer to the same question. That is the first
    /// parent key strictly greater than `key` (or greater or equal, if `or_equal`), or `None`
    /// if the parent has no such key.
    ///
    /// If [`StorageNextKey::Found`] is returned, the search is over. If
    /// [`StorageNextKey::NextOf`] is returned, the diff erases the returned key. The search must
    /// then continue by calling this function again with `key` set to that key, together with
    /// the corresponding parent next key. Presumably `or_equal` should be `false` in that call;
    /// confirm against callers.
    ///
    /// # Panics
    ///
    /// Panics if `in_parent_next_key` is strictly smaller than `key`, or equal to `key` while
    /// `or_equal` is `false`.
    pub fn storage_next_key<'a>(
        &'a self,
        key: &[u8],
        in_parent_next_key: Option<&'a [u8]>,
        or_equal: bool,
    ) -> StorageNextKey<'a> {
        if let Some(in_parent_next_key) = in_parent_next_key {
            // With `or_equal`, the parent may legitimately report `key` itself.
            assert!(in_parent_next_key > key || (or_equal && in_parent_next_key == key));
        }

        // First entry of the diff that is after (or at, if `or_equal`) `key`.
        let in_diff = self
            .btree
            .range::<[u8], _>((
                if or_equal {
                    ops::Bound::Included(key)
                } else {
                    ops::Bound::Excluded(key)
                },
                ops::Bound::Unbounded,
            ))
            .next();

        match (in_parent_next_key, in_diff) {
            // The parent's key comes first, or ties with a key that the diff inserts.
            (Some(a), Some((b, true))) if a <= &b[..] => StorageNextKey::Found(Some(a)),
            // The parent's key comes strictly before the first key that the diff erases.
            (Some(a), Some((b, false))) if a < &b[..] => StorageNextKey::Found(Some(a)),
            // The diff erases a key at or before the parent's key: the answer can't be decided
            // here, and the search must resume after `b`.
            (Some(a), Some((b, false))) => {
                debug_assert!(a >= &b[..]);
                debug_assert!(&b[..] > key || or_equal);
                StorageNextKey::NextOf(b)
            }
            // The diff inserts a key strictly before the parent's key.
            (Some(a), Some((b, true))) => {
                debug_assert!(a >= &b[..]);
                StorageNextKey::Found(Some(&b[..]))
            }
            (Some(a), None) => StorageNextKey::Found(Some(a)),
            (None, Some((b, true))) => StorageNextKey::Found(Some(&b[..])),
            // The parent has nothing left, so the answer is the next key that the diff inserts.
            (None, Some((b, false))) => {
                debug_assert!(&b[..] > key || or_equal);
                let found = self
                    .btree
                    .range::<[u8], _>((ops::Bound::Excluded(&b[..]), ops::Bound::Unbounded))
                    .find(|(_, inserts)| **inserts)
                    .map(|(k, _)| &k[..]);
                StorageNextKey::Found(found)
            }
            (None, None) => StorageNextKey::Found(None),
        }
    }

    /// Applies every entry of `other` on top of `self`.
    ///
    /// Entries of `self` whose key also appears in `other` are overwritten.
    pub fn merge(&mut self, other: &TrieDiff<T>)
    where
        T: Clone,
    {
        self.merge_map(other, |v| v.clone())
    }

    /// Same as [`TrieDiff::merge`], but converts the user data of `other` using `map`.
    pub fn merge_map<U>(&mut self, other: &TrieDiff<U>, mut map: impl FnMut(&U) -> T) {
        // Avoid repeated rehashing when merging a large diff.
        self.hashmap.reserve(other.hashmap.len());
        for (key, (value, user_data)) in &other.hashmap {
            self.hashmap
                .insert(key.clone(), (value.clone(), map(user_data)));
            self.btree.insert(key.clone(), value.is_some());
        }
    }
}
impl<T: fmt::Debug> fmt::Debug for TrieDiff<T> {
    /// Formats the diff by delegating to the `Debug` output of its entry map.
    ///
    /// The sorted key index is not printed, since it only mirrors the map's keys.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let Self { hashmap, .. } = self;
        fmt::Debug::fmt(hashmap, f)
    }
}
impl<T, U> cmp::PartialEq<TrieDiff<U>> for TrieDiff<T>
where
    T: cmp::PartialEq<U>,
{
    /// Two diffs are equal when they contain exactly the same keys, each with the same value
    /// (or erasure) and equal user data.
    fn eq(&self, other: &TrieDiff<U>) -> bool {
        self.hashmap.len() == other.hashmap.len()
            && self
                .hashmap
                .iter()
                .all(|(key, (lhs_value, lhs_user_data))| match other.hashmap.get(key) {
                    Some((rhs_value, rhs_user_data)) => {
                        lhs_value == rhs_value && lhs_user_data == rhs_user_data
                    }
                    None => false,
                })
    }
}
/// `TrieDiff` equality is an equivalence relation whenever the user data's equality is.
impl<T: cmp::Eq> cmp::Eq for TrieDiff<T> {}
impl<T> Default for TrieDiff<T> {
fn default() -> Self {
TrieDiff::empty()
}
}
impl<T> FromIterator<(Vec<u8>, Option<Vec<u8>>, T)> for TrieDiff<T> {
fn from_iter<I>(iter: I) -> Self
where
I: IntoIterator<Item = (Vec<u8>, Option<Vec<u8>>, T)>,
{
let hashmap = iter
.into_iter()
.map(|(k, v, ud)| (k, (v, ud)))
.collect::<HashMap<Vec<u8>, (Option<Vec<u8>>, T), fnv::FnvBuildHasher>>();
let btree = hashmap
.iter()
.map(|(k, (v, _))| (k.clone(), v.is_some()))
.collect();
Self { btree, hashmap }
}
}
/// Outcome of [`TrieDiff::storage_next_key`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StorageNextKey<'a> {
    /// The search is finished. Contains the next key, or `None` if there is no next key.
    Found(Option<&'a [u8]>),
    /// The diff erases this key, so no answer could be determined yet. The search must be
    /// continued from this key.
    NextOf(&'a [u8]),
}