use crate::{
self as concordium_std, cmp::Ordering, marker::PhantomData, mem, prims, vec::Vec, Deletable,
Deserial, DeserialWithState, Get, HasStateApi, ParseResult, Read, Serial, Serialize, StateApi,
StateItemPrefix, StateMap, StateRef, StateRefMut, UnwrapAbort, Write, STATE_ITEM_PREFIX_SIZE,
};
/// An ordered map kept in the contract state.
///
/// The map is made of two collections: a `StateMap` holding the key-value
/// pairs and a `StateBTreeSet` holding the keys in order. The methods of this
/// type keep the two in sync.
///
/// The const parameter `M` is passed on to the `StateBTreeSet` and determines
/// the size limits of its nodes (see the constants on `Node`). It defaults to
/// 8.
#[derive(Serial)]
pub struct StateBTreeMap<K, V, const M: usize = 8> {
    /// Mapping from each key to its value.
    pub(crate) key_value: StateMap<K, V, StateApi>,
    /// Ordered set of the keys present in `key_value`.
    pub(crate) key_order: StateBTreeSet<K, M>,
}
impl<const M: usize, K, V> StateBTreeMap<K, V, M> {
#[must_use]
pub fn insert(&mut self, key: K, value: V) -> Option<V>
where
K: Serialize + Ord,
V: Serial + DeserialWithState<StateApi>, {
let old_value_option = self.key_value.insert_borrowed(&key, value);
if old_value_option.is_none() && !self.key_order.insert(key) {
crate::trap();
}
old_value_option
}
#[must_use]
pub fn remove_and_get(&mut self, key: &K) -> Option<V>
where
K: Serialize + Ord,
V: Serial + DeserialWithState<StateApi> + Deletable, {
let v = self.key_value.remove_and_get(key);
if v.is_some() && !self.key_order.remove(key) {
crate::trap();
}
v
}
pub fn remove(&mut self, key: &K)
where
K: Serialize + Ord,
V: Serial + DeserialWithState<StateApi> + Deletable, {
if self.key_order.remove(key) {
self.key_value.remove(key);
}
}
pub fn get(&self, key: &K) -> Option<StateRef<V>>
where
K: Serialize,
V: Serial + DeserialWithState<StateApi>, {
if self.key_order.is_empty() {
None
} else {
self.key_value.get(key)
}
}
pub fn get_mut(&mut self, key: &K) -> Option<StateRefMut<V, StateApi>>
where
K: Serialize,
V: Serial + DeserialWithState<StateApi>, {
if self.key_order.is_empty() {
None
} else {
self.key_value.get_mut(key)
}
}
#[inline(always)]
pub fn contains_key(&self, key: &K) -> bool
where
K: Serialize + Ord, {
self.key_order.contains(key)
}
#[inline(always)]
pub fn higher(&self, key: &K) -> Option<StateRef<K>>
where
K: Serialize + Ord, {
self.key_order.higher(key)
}
#[inline(always)]
pub fn eq_or_higher(&self, key: &K) -> Option<StateRef<K>>
where
K: Serialize + Ord, {
self.key_order.eq_or_higher(key)
}
#[inline(always)]
pub fn lower(&self, key: &K) -> Option<StateRef<K>>
where
K: Serialize + Ord, {
self.key_order.lower(key)
}
#[inline(always)]
pub fn eq_or_lower(&self, key: &K) -> Option<StateRef<K>>
where
K: Serialize + Ord, {
self.key_order.eq_or_lower(key)
}
#[inline(always)]
pub fn first_key(&self) -> Option<StateRef<K>>
where
K: Serialize + Ord, {
self.key_order.first()
}
#[inline(always)]
pub fn last_key(&self) -> Option<StateRef<K>>
where
K: Serialize + Ord, {
self.key_order.last()
}
#[inline(always)]
pub fn len(&self) -> u32 { self.key_order.len() }
#[inline(always)]
pub fn is_empty(&self) -> bool { self.key_order.is_empty() }
#[inline(always)]
pub fn iter(&self) -> StateBTreeMapIter<K, V, M> {
StateBTreeMapIter {
key_iter: self.key_order.iter(),
map: &self.key_value,
}
}
pub fn clear(&mut self)
where
K: Serialize,
V: Serial + DeserialWithState<StateApi> + Deletable, {
self.key_value.clear();
self.key_order.clear();
}
pub fn clear_flat(&mut self)
where
K: Serialize,
V: Serialize, {
self.key_value.clear_flat();
self.key_order.clear();
}
}
/// An ordered set implemented as a B-tree whose nodes are stored as separate
/// entries in the contract state.
///
/// The const parameter `M` determines the size limits of each node (see the
/// constants on `Node`). It defaults to 8.
pub struct StateBTreeSet<K, const M: usize = 8> {
    /// Marker for the key type; no key is stored in this struct itself.
    _marker_key: PhantomData<K>,
    /// Prefix prepended to a node id to form the state key of that node.
    prefix: StateItemPrefix,
    /// Handle used to create, look up and delete node entries.
    state_api: StateApi,
    /// Id of the root node, or `None` when the set is empty.
    root: Option<NodeId>,
    /// Number of keys in the set.
    len: u32,
    /// Id that will be given to the next node that is created.
    next_node_id: NodeId,
}
impl<const M: usize, K> StateBTreeSet<K, M> {
pub(crate) fn new(state_api: StateApi, prefix: StateItemPrefix) -> Self {
Self {
_marker_key: Default::default(),
prefix,
state_api,
root: None,
len: 0,
next_node_id: NodeId {
id: 0,
},
}
}
/// Inserts `key` into the set.
///
/// Returns `true` if the key was added and `false` if an equal key was
/// already present.
pub fn insert(&mut self, key: K) -> bool
where
    K: Serialize + Ord, {
    // Empty set: create a root leaf holding only this key.
    let Some(root_id) = self.root else {
        let (node_id, _) = self.create_node(crate::vec![key], Vec::new());
        self.root = Some(node_id);
        self.len = 1;
        return true;
    };
    let root_node = self.get_node_mut(root_id);
    if !root_node.is_full() {
        // The root has room, so the regular top-down insertion applies.
        let new = self.insert_non_full(root_node, key);
        if new {
            self.len += 1;
        }
        return new;
    } else if root_node.keys.binary_search(&key).is_ok() {
        // The root is full but already holds the key: nothing to insert and
        // no reason to split.
        return false;
    }
    // The root is full: create a new root with the old root as its only
    // child and split the old root below it.
    let (new_root_id, mut new_root) = self.create_node(Vec::new(), crate::vec![root_id]);
    self.root = Some(new_root_id);
    let mut child = root_node;
    let new_larger_child = self.split_child(&mut new_root, 0, &mut child);
    // SAFETY: `split_child` inserted the key it moved up into `new_root.keys`
    // at index 0, so the vector is non-empty.
    let key_in_root = unsafe { new_root.keys.get_unchecked(0) };
    // Continue in the half whose key range covers `key`.
    let child = if key_in_root < &key {
        new_larger_child
    } else {
        child
    };
    let new = self.insert_non_full(child, key);
    if new {
        self.len += 1;
    }
    new
}
pub fn contains(&self, key: &K) -> bool
where
K: Serialize + Ord, {
let Some(root_node_id) = self.root else {
return false;
};
let mut node = self.get_node(root_node_id);
loop {
let Err(child_index) = node.keys.binary_search(key) else {
return true;
};
if node.is_leaf() {
return false;
}
let child_node_id = unsafe { *node.children.get_unchecked(child_index) };
node = self.get_node(child_node_id);
}
}
/// Returns the number of keys in the set.
pub fn len(&self) -> u32 {
    self.len
}
/// Returns `true` when the set holds no keys, i.e. when it has no root node.
pub fn is_empty(&self) -> bool {
    matches!(self.root, None)
}
pub fn iter(&self) -> StateBTreeSetIter<K, M> {
StateBTreeSetIter {
length: self.len.try_into().unwrap_abort(),
next_node: self.root,
depth_first_stack: Vec::new(),
tree: self,
_marker_lifetime: Default::default(),
}
}
/// Removes all keys from the set.
///
/// The bookkeeping is reset to that of a new set and every state entry
/// stored under the prefix of this set is deleted.
pub fn clear(&mut self) {
    // Reset the bookkeeping first.
    self.len = 0;
    self.root = None;
    self.next_node_id = NodeId {
        id: 0,
    };
    // Then delete all node entries in one go.
    self.state_api.delete_prefix(&self.prefix).unwrap_abort();
}
pub fn higher(&self, key: &K) -> Option<StateRef<K>>
where
K: Serialize + Ord, {
let Some(root_node_id) = self.root else {
return None;
};
let mut node = self.get_node(root_node_id);
let mut higher_so_far = None;
loop {
let higher_key_index = match node.keys.binary_search(key) {
Ok(index) => index + 1,
Err(index) => index,
};
if node.is_leaf() {
return if higher_key_index < node.keys.len() {
Some(StateRef::new(node.keys.swap_remove(higher_key_index)))
} else {
higher_so_far
};
} else {
if higher_key_index < node.keys.len() {
higher_so_far = Some(StateRef::new(node.keys.swap_remove(higher_key_index)))
}
let child_node_id = unsafe { *node.children.get_unchecked(higher_key_index) };
node = self.get_node(child_node_id);
}
}
}
pub fn eq_or_higher(&self, key: &K) -> Option<StateRef<K>>
where
K: Serialize + Ord, {
let Some(root_node_id) = self.root else {
return None;
};
let mut node = self.get_node(root_node_id);
let mut higher_so_far = None;
loop {
let higher_key_index = match node.keys.binary_search(key) {
Ok(index) => {
return Some(StateRef::new(node.keys.swap_remove(index)));
}
Err(index) => index,
};
if node.is_leaf() {
return if higher_key_index < node.keys.len() {
Some(StateRef::new(node.keys.swap_remove(higher_key_index)))
} else {
higher_so_far
};
} else {
if higher_key_index < node.keys.len() {
higher_so_far = Some(StateRef::new(node.keys.swap_remove(higher_key_index)))
}
let child_node_id = unsafe { *node.children.get_unchecked(higher_key_index) };
node = self.get_node(child_node_id);
}
}
}
pub fn lower(&self, key: &K) -> Option<StateRef<K>>
where
K: Serialize + Ord, {
let Some(root_node_id) = self.root else {
return None;
};
let mut node = self.get_node(root_node_id);
let mut lower_so_far = None;
loop {
let lower_key_index = match node.keys.binary_search(key) {
Ok(index) => index,
Err(index) => index,
};
if node.is_leaf() {
return if lower_key_index == 0 {
lower_so_far
} else {
Some(StateRef::new(node.keys.swap_remove(lower_key_index - 1)))
};
} else {
if lower_key_index > 0 {
lower_so_far = Some(StateRef::new(node.keys.swap_remove(lower_key_index - 1)));
}
let child_node_id = unsafe { node.children.get_unchecked(lower_key_index) };
node = self.get_node(*child_node_id)
}
}
}
pub fn eq_or_lower(&self, key: &K) -> Option<StateRef<K>>
where
K: Serialize + Ord, {
let Some(root_node_id) = self.root else {
return None;
};
let mut node = self.get_node(root_node_id);
let mut lower_so_far = None;
loop {
let lower_key_index = match node.keys.binary_search(key) {
Ok(index) => {
return Some(StateRef::new(node.keys.swap_remove(index)));
}
Err(index) => index,
};
if node.is_leaf() {
return if lower_key_index == 0 {
lower_so_far
} else {
Some(StateRef::new(node.keys.swap_remove(lower_key_index - 1)))
};
} else {
if lower_key_index > 0 {
lower_so_far = Some(StateRef::new(node.keys.swap_remove(lower_key_index - 1)));
}
let child_node_id = unsafe { node.children.get_unchecked(lower_key_index) };
node = self.get_node(*child_node_id)
}
}
}
pub fn first(&self) -> Option<StateRef<K>>
where
K: Serialize + Ord, {
let Some(root_node_id) = self.root else {
return None;
};
let mut root = self.get_node(root_node_id);
if root.is_leaf() {
Some(StateRef::new(root.keys.swap_remove(0)))
} else {
Some(StateRef::new(self.get_lowest_key(&root, 0)))
}
}
pub fn last(&self) -> Option<StateRef<K>>
where
K: Serialize + Ord, {
let Some(root_node_id) = self.root else {
return None;
};
let mut root = self.get_node(root_node_id);
if root.is_leaf() {
Some(StateRef::new(root.keys.pop().unwrap_abort()))
} else {
Some(StateRef::new(self.get_highest_key(&root, root.children.len() - 1)))
}
}
/// Removes `key` from the set.
///
/// Returns `true` if the key was present and has been removed, `false`
/// otherwise.
///
/// The tree is walked from the root towards a leaf. Before stepping into a
/// child, the child is prepared (by taking a key from a sibling or by merging
/// with one) so that a key can be removed from it.
pub fn remove(&mut self, key: &K) -> bool
where
    K: Ord + Serialize, {
    let Some(root_node_id) = self.root else {
        return false;
    };
    let deleted_something = {
        let mut node = self.get_node_mut(root_node_id);
        loop {
            match node.keys.binary_search(key) {
                Ok(index) => {
                    if node.is_leaf() {
                        // The key is in a leaf: simply remove it.
                        node.keys.remove(index);
                        break true;
                    }
                    // The key is in an internal node. Replace it with its
                    // predecessor if the child to its left can spare a key.
                    let mut left_child = self.get_node_mut(node.children[index]);
                    if !left_child.is_at_min() {
                        node.keys[index] = self.remove_largest_key(left_child);
                        break true;
                    }
                    // Otherwise replace it with its successor if the child to
                    // its right can spare a key.
                    let right_child = self.get_node_mut(node.children[index + 1]);
                    if !right_child.is_at_min() {
                        node.keys[index] = self.remove_smallest_key(right_child);
                        break true;
                    }
                    // Neither child can spare a key: merge them around the key
                    // and continue the removal in the merged child.
                    self.merge(&mut node, index, &mut left_child, right_child);
                    node = left_child;
                    continue;
                }
                Err(index) => {
                    if node.is_leaf() {
                        // Reached a leaf without finding the key.
                        break false;
                    }
                    // Prepare the child covering the key before descending.
                    node = self.prepare_child_for_key_removal(node, index);
                }
            };
        }
    };
    let root = self.get_node_mut(root_node_id);
    if deleted_something {
        self.len -= 1;
        if self.len == 0 {
            // The set became empty: drop the root node as well.
            self.root = None;
            self.delete_node(root_node_id, root);
            return true;
        }
    }
    // A root without keys in a non-empty set has a single child left (after a
    // merge); that child becomes the new root.
    if root.keys.is_empty() {
        self.root = Some(root.children[0]);
        self.delete_node(root_node_id, root);
    }
    deleted_something
}
/// Removes and returns the largest key in the subtree rooted at `node`.
///
/// Follows the last child down to a leaf, preparing each child for removal
/// before stepping into it, and pops the last key of that leaf.
fn remove_largest_key(&mut self, mut node: StateRefMut<'_, Node<M, K>, StateApi>) -> K
where
    K: Ord + Serialize, {
    loop {
        if node.is_leaf() {
            return node.keys.pop().unwrap_abort();
        }
        let last_child_index = node.children.len() - 1;
        node = self.prepare_child_for_key_removal(node, last_child_index);
    }
}
/// Removes and returns the smallest key in the subtree rooted at `node`.
///
/// Follows the first child down to a leaf, preparing each child for removal
/// before stepping into it, and removes the first key of that leaf.
fn remove_smallest_key(&mut self, mut node: StateRefMut<'_, Node<M, K>, StateApi>) -> K
where
    K: Ord + Serialize, {
    loop {
        if node.is_leaf() {
            return node.keys.remove(0);
        }
        node = self.prepare_child_for_key_removal(node, 0);
    }
}
/// Prepares the child at `index` of `node` so that a key can be removed from
/// it, and returns the node in which the removal should continue.
///
/// - A child that is not at the minimum number of keys is returned as is.
/// - Otherwise a key is rotated in from the smaller sibling, or else from the
///   larger sibling, through the separating key in `node`, provided that
///   sibling is not at the minimum itself.
/// - Otherwise the child is merged with a sibling and the merged node is
///   returned.
///
/// Traps if the child has no sibling at all.
fn prepare_child_for_key_removal<'c>(
    &mut self,
    mut node: StateRefMut<Node<M, K>, StateApi>,
    index: usize,
) -> StateRefMut<'c, Node<M, K>, StateApi>
where
    K: Ord + Serialize, {
    let mut child = self.get_node_mut(node.children[index]);
    if !child.is_at_min() {
        return child;
    }
    let has_smaller_sibling = 0 < index;
    let has_larger_sibling = index < node.children.len() - 1;
    let smaller_sibling = if has_smaller_sibling {
        let mut smaller_sibling = self.get_node_mut(node.children[index - 1]);
        if !smaller_sibling.is_at_min() {
            // Rotate right: the largest key of the smaller sibling replaces
            // the separating key in `node`, which moves to the front of the
            // child.
            let largest_key_sibling = smaller_sibling.keys.pop().unwrap_abort();
            let swapped_node_key = mem::replace(&mut node.keys[index - 1], largest_key_sibling);
            child.keys.insert(0, swapped_node_key);
            if !child.is_leaf() {
                // The last child of the sibling follows its largest key.
                child.children.insert(0, smaller_sibling.children.pop().unwrap_abort());
            }
            return child;
        }
        Some(smaller_sibling)
    } else {
        None
    };
    let larger_sibling = if has_larger_sibling {
        let mut larger_sibling = self.get_node_mut(node.children[index + 1]);
        if !larger_sibling.is_at_min() {
            // Rotate left: the smallest key of the larger sibling replaces the
            // separating key in `node`, which moves to the end of the child.
            let first_key_sibling = larger_sibling.keys.remove(0);
            let swapped_node_key = mem::replace(&mut node.keys[index], first_key_sibling);
            child.keys.push(swapped_node_key);
            if !child.is_leaf() {
                // The first child of the sibling follows its smallest key.
                child.children.push(larger_sibling.children.remove(0));
            }
            return child;
        }
        Some(larger_sibling)
    } else {
        None
    };
    // No sibling can spare a key, so merge with one of them. `merge` keeps the
    // node holding the smaller keys.
    if let Some(sibling) = larger_sibling {
        self.merge(&mut node, index, &mut child, sibling);
        child
    } else if let Some(mut sibling) = smaller_sibling {
        self.merge(&mut node, index - 1, &mut sibling, child);
        sibling
    } else {
        // A child without any sibling should not occur.
        crate::trap();
    }
}
/// Returns the largest key in the subtree below child `child_index` of
/// `node`, without modifying the tree.
fn get_highest_key(&self, node: &Node<M, K>, child_index: usize) -> K
where
    K: Ord + Serialize, {
    let mut current: Node<M, K> = self.get_node(node.children[child_index]);
    // Follow the last child until a leaf (a node without children) is reached.
    while let Some(&last_child_id) = current.children.last() {
        current = self.get_node(last_child_id);
    }
    // `current` is a private copy, so the key can be moved out of it.
    current.keys.pop().unwrap_abort()
}
/// Returns the smallest key in the subtree below child `child_index` of
/// `node`, without modifying the tree.
fn get_lowest_key(&self, node: &Node<M, K>, child_index: usize) -> K
where
    K: Ord + Serialize, {
    let mut current: Node<M, K> = self.get_node(node.children[child_index]);
    // Follow the first child until a leaf (a node without children) is reached.
    while let Some(&first_child_id) = current.children.first() {
        current = self.get_node(first_child_id);
    }
    // `current` is a private copy, so the key can be moved out of it.
    current.keys.swap_remove(0)
}
/// Merges `larger_child` into `child`, together with the key of
/// `parent_node` that separates them.
///
/// `index` is the position of the separating key in `parent_node.keys`; the
/// child at `index + 1` is removed from `parent_node.children` and its id is
/// used to delete `larger_child` from the state. Afterwards `child` holds its
/// own keys, the separating key and the keys of `larger_child`, in that order,
/// followed by the children of `larger_child`.
fn merge(
    &mut self,
    parent_node: &mut Node<M, K>,
    index: usize,
    child: &mut Node<M, K>,
    mut larger_child: StateRefMut<Node<M, K>, StateApi>,
) where
    K: Ord + Serialize, {
    let parent_key = parent_node.keys.remove(index);
    let larger_child_id = parent_node.children.remove(index + 1);
    child.keys.push(parent_key);
    child.keys.append(&mut larger_child.keys);
    child.children.append(&mut larger_child.children);
    // The larger child has been emptied into `child`; remove it from the state.
    self.delete_node(larger_child_id, larger_child);
}
/// Creates a new node with the given `keys` and `children` and stores it in
/// the state.
///
/// The node gets the next unused node id. Returns that id together with a
/// mutable reference to the stored node.
fn create_node<'b>(
    &mut self,
    keys: Vec<K>,
    children: Vec<NodeId>,
) -> (NodeId, StateRefMut<'b, Node<M, K>, StateApi>)
where
    K: Serialize, {
    // Take the current id and advance the counter.
    let node_id = self.next_node_id.fetch_and_add();
    let node = Node {
        keys,
        children,
    };
    // The state key of a node is the prefix of the set followed by the node id.
    let entry = self.state_api.create_entry(&node_id.as_key(&self.prefix)).unwrap_abort();
    let mut ref_mut: StateRefMut<'_, Node<M, K>, StateApi> =
        StateRefMut::new(entry, self.state_api.clone());
    ref_mut.set(node);
    (node_id, ref_mut)
}
/// Deletes the node with id `node_id` from the state.
///
/// `node` is the mutable reference to that node; it is consumed with
/// `drop_without_storing` before the entry is deleted.
fn delete_node(&mut self, node_id: NodeId, node: StateRefMut<Node<M, K>, StateApi>)
where
    K: Serial, {
    let key = node_id.as_key(&self.prefix);
    node.drop_without_storing();
    // SAFETY: the pointer and the length both describe the local array `key`,
    // which is alive for the duration of the call.
    // NOTE(review): the value returned by `state_delete_entry` is ignored.
    unsafe { prims::state_delete_entry(key.as_ptr(), key.len() as u32) };
}
/// Inserts `key` into the subtree rooted at `initial_node`.
///
/// Returns `true` if the key was inserted and `false` if an equal key was
/// found on the way down. Full children are split before they are entered;
/// the callers in this file only pass an `initial_node` that is not full.
fn insert_non_full(&mut self, initial_node: StateRefMut<Node<M, K>, StateApi>, key: K) -> bool
where
    K: Serialize + Ord, {
    let mut node = initial_node;
    loop {
        // An equal key in this node means the key is already in the set.
        let Err(insert_index) = node.keys.binary_search(&key) else {
            return false;
        };
        if node.is_leaf() {
            node.keys.insert(insert_index, key);
            return true;
        }
        // SAFETY: `insert_index` is at most `keys.len()`; this relies on the
        // tree invariant that a node with children has one more child than it
        // has keys.
        let child_id = unsafe { node.children.get_unchecked(insert_index) };
        let mut child = self.get_node_mut(*child_id);
        node = if !child.is_full() {
            child
        } else {
            // Split the full child. The key moved up from it now sits at
            // `insert_index` in `node` and decides which half to continue in.
            let larger_child = self.split_child(&mut node, insert_index, &mut child);
            let moved_up_key = &node.keys[insert_index];
            match moved_up_key.cmp(&key) {
                Ordering::Equal => return false,
                Ordering::Less => larger_child,
                Ordering::Greater => child,
            }
        };
    }
}
/// Splits `child`, the child at `child_index` of `node`, into two nodes.
///
/// The keys (and, unless `child` is a leaf, the children) from index
/// `MINIMUM_KEY_LEN + 1` onwards move to a newly created sibling. The last
/// key remaining in `child` is then moved up into `node` at `child_index`,
/// and the new sibling is registered in `node` right after `child`.
///
/// Returns a mutable reference to the new sibling, which holds the larger
/// keys. The callers in this file only split a child that is full.
fn split_child<'b>(
    &mut self,
    node: &mut Node<M, K>,
    child_index: usize,
    child: &mut Node<M, K>,
) -> StateRefMut<'b, Node<M, K>, StateApi>
where
    K: Serialize + Ord, {
    let split_index = Node::<M, K>::MINIMUM_KEY_LEN + 1;
    let (new_larger_sibling_id, new_larger_sibling) = self.create_node(
        child.keys.split_off(split_index),
        if child.is_leaf() {
            Vec::new()
        } else {
            child.children.split_off(split_index)
        },
    );
    // The last key left in `child` becomes the separator in the parent.
    let key = child.keys.pop().unwrap_abort();
    node.children.insert(child_index + 1, new_larger_sibling_id);
    node.keys.insert(child_index, key);
    new_larger_sibling
}
/// Reads the node with id `node_id` from the state and returns it as an
/// owned value.
///
/// The key type of the returned node is chosen by the caller through `Key`.
/// Aborts if the entry does not exist or cannot be deserialized.
fn get_node<Key>(&self, node_id: NodeId) -> Node<M, Key>
where
    Key: Deserial, {
    let state_key = node_id.as_key(&self.prefix);
    let mut node_entry = self.state_api.lookup_entry(&state_key).unwrap_abort();
    node_entry.get().unwrap_abort()
}
/// Looks up the state entry of the node with id `node_id` and wraps it in a
/// mutable reference.
///
/// Aborts if the entry does not exist.
fn get_node_mut<'b>(&mut self, node_id: NodeId) -> StateRefMut<'b, Node<M, K>, StateApi>
where
    K: Serial, {
    let state_key = node_id.as_key(&self.prefix);
    let node_entry = self.state_api.lookup_entry(&state_key).unwrap_abort();
    let state_api = self.state_api.clone();
    StateRefMut::new(node_entry, state_api)
}
}
/// Iterator over the keys of a `StateBTreeSet`.
///
/// Nodes are loaded from the state as the iteration reaches them and are kept
/// on a stack until all their keys have been yielded.
pub struct StateBTreeSetIter<'a, 'b, K, const M: usize> {
    /// Number of keys that have not been yielded yet.
    length: usize,
    /// Node to load and descend into on the next call to `next`, if any.
    next_node: Option<NodeId>,
    /// Loaded nodes on the current path, each paired with the index of the
    /// next key to yield from it.
    depth_first_stack: Vec<(Node<M, KeyWrapper<K>>, usize)>,
    /// The set being iterated.
    tree: &'a StateBTreeSet<K, M>,
    /// Ties the lifetime `'b` of the yielded references to the iterator type.
    _marker_lifetime: PhantomData<&'b K>,
}
impl<'a, 'b, const M: usize, K> Iterator for StateBTreeSetIter<'a, 'b, K, M>
where
    'a: 'b,
    K: Deserial,
{
    type Item = StateRef<'b, K>;

    /// Yields the next key of the set.
    fn next(&mut self) -> Option<Self::Item> {
        // Descend along first children from the pending node, if any, pushing
        // every node on the way onto the stack.
        while let Some(node_id) = self.next_node.take() {
            let node = self.tree.get_node(node_id);
            if let Some(&first_child_id) = node.children.first() {
                self.next_node = Some(first_child_id);
            }
            self.depth_first_stack.push((node, 0));
        }
        // The top of the stack holds the next key; an empty stack means the
        // iteration is over.
        let (node, position) = self.depth_first_stack.last_mut()?;
        let key = node.keys[*position].key.take().unwrap_abort();
        *position += 1;
        let exhausted = *position == node.keys.len();
        // In an internal node the child following the yielded key comes next.
        if !node.is_leaf() {
            self.next_node = Some(node.children[*position]);
        }
        // A node whose keys have all been yielded is no longer needed.
        if exhausted {
            let _ = self.depth_first_stack.pop();
        }
        self.length -= 1;
        Some(StateRef::new(key))
    }

    /// The number of remaining keys is known exactly.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let remaining = self.length;
        (remaining, Some(remaining))
    }
}
/// Iterator over the entries of a `StateBTreeMap`.
///
/// Keys come from the iterator of the ordered key set; the value of each key
/// is looked up in the value map.
pub struct StateBTreeMapIter<'a, 'b, K, V, const M: usize> {
    /// Iterator over the keys of the map.
    key_iter: StateBTreeSetIter<'a, 'b, K, M>,
    /// The map in which the value of each key is looked up.
    map: &'a StateMap<K, V, StateApi>,
}
impl<'a, 'b, const M: usize, K, V> Iterator for StateBTreeMapIter<'a, 'b, K, V, M>
where
    'a: 'b,
    K: Serialize,
    V: Serial + DeserialWithState<StateApi> + 'b,
{
    type Item = (StateRef<'b, K>, StateRef<'b, V>);

    /// Yields the next key together with its value.
    ///
    /// Aborts if the value map has no value for a key produced by the key
    /// iterator.
    fn next(&mut self) -> Option<Self::Item> {
        let value_map = self.map;
        self.key_iter.next().map(|key| {
            let value = value_map.get(&key).unwrap_abort();
            (key, value)
        })
    }

    /// Same bounds as the underlying key iterator.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.key_iter.size_hint()
    }
}
/// Identifier of a node in the B-tree. Combined with the prefix of the set it
/// forms the state key under which the node is stored.
#[derive(Debug, Copy, Clone, Serialize)]
#[repr(transparent)]
struct NodeId {
    /// Numeric value of the identifier.
    id: u64,
}
/// Length in bytes of the state key of a node: the prefix of the set followed
/// by the serialized node id.
const BTREE_NODE_KEY_SIZE: usize = STATE_ITEM_PREFIX_SIZE + NodeId::SERIALIZED_BYTE_SIZE;
impl NodeId {
    /// Number of bytes a node id occupies in a state key (the `u64` id).
    const SERIALIZED_BYTE_SIZE: usize = 8;
    /// Returns the current id and increments `self` by one.
    fn fetch_and_add(&mut self) -> Self {
        let current = *self;
        self.id += 1;
        current
    }
    /// Builds the state key of this node: the bytes of `prefix` followed by
    /// the id in little-endian byte order.
    fn as_key(&self, prefix: &StateItemPrefix) -> [u8; BTREE_NODE_KEY_SIZE] {
        // SAFETY: the value claimed to be initialized is an array of
        // `MaybeUninit<u8>`, and `MaybeUninit` does not require initialization.
        let mut prefixed: [mem::MaybeUninit<u8>; BTREE_NODE_KEY_SIZE] =
            unsafe { mem::MaybeUninit::uninit().assume_init() };
        // Write the prefix into the front of the key.
        for (place, value) in prefixed.iter_mut().zip(prefix) {
            place.write(*value);
        }
        // Write the id after the prefix.
        let id_bytes = self.id.to_le_bytes();
        for (place, value) in prefixed[STATE_ITEM_PREFIX_SIZE..].iter_mut().zip(id_bytes) {
            place.write(value);
        }
        // SAFETY: sound provided the two loops above wrote every byte, i.e.
        // `prefix` yields `STATE_ITEM_PREFIX_SIZE` bytes and the id fills the
        // remaining `SERIALIZED_BYTE_SIZE` bytes.
        // NOTE(review): confirm the length of `StateItemPrefix` where it is
        // defined.
        unsafe { mem::transmute(prefixed) }
    }
}
/// A node of the B-tree, stored as one entry in the state.
#[derive(Debug, Serialize)]
struct Node<const M: usize, K> {
    /// Keys held by this node, kept sorted (they are binary searched).
    keys: Vec<K>,
    /// Ids of the child nodes; empty for a leaf.
    children: Vec<NodeId>,
}
impl<const M: usize, K> Node<M, K> {
    /// Largest number of children a node may have.
    const MAXIMUM_CHILD_LEN: usize = 2 * M;
    /// Largest number of keys a node may have.
    const MAXIMUM_KEY_LEN: usize = Self::MAXIMUM_CHILD_LEN - 1;
    /// Smallest number of children of a node with children; the test-only
    /// invariant check does not apply this to the root.
    const MINIMUM_CHILD_LEN: usize = M;
    /// Smallest number of keys of a node; the test-only invariant check only
    /// requires the root to be non-empty.
    // NOTE(review): `M - 1` presumably requires `M >= 1` — confirm where `M` is
    // chosen.
    const MINIMUM_KEY_LEN: usize = Self::MINIMUM_CHILD_LEN - 1;
    /// Returns `true` when the node holds the maximum number of keys.
    #[inline(always)]
    fn is_full(&self) -> bool { self.keys.len() == Self::MAXIMUM_KEY_LEN }
    /// Returns `true` when the node has no children.
    #[inline(always)]
    fn is_leaf(&self) -> bool { self.children.is_empty() }
    /// Returns `true` when the node holds exactly the minimum number of keys.
    #[inline(always)]
    fn is_at_min(&self) -> bool { self.keys.len() == Self::MINIMUM_KEY_LEN }
}
/// Wrapper around a key that allows the key to be moved out of a loaded node.
/// The set iterator uses it to `take` a key while leaving the node's key
/// vector at the same length.
#[repr(transparent)]
struct KeyWrapper<K> {
    /// The key, or `None` once it has been taken.
    key: Option<K>,
}
impl<K: Deserial> Deserial for KeyWrapper<K> {
    /// Reads a plain `K` from `source` and wraps it, so the wrapper uses the
    /// same byte representation as the key itself.
    fn deserial<R: Read>(source: &mut R) -> ParseResult<Self> {
        K::deserial(source).map(|key| Self {
            key: Some(key),
        })
    }
}
impl<const M: usize, K> Serial for StateBTreeSet<K, M> {
    /// Writes the bookkeeping of the set in the order: prefix, root, length,
    /// next node id. The nodes themselves are separate state entries and are
    /// not written here.
    fn serial<W: Write>(&self, out: &mut W) -> Result<(), W::Err> {
        self.prefix.serial(out)?;
        self.root.serial(out)?;
        self.len.serial(out)?;
        self.next_node_id.serial(out)?;
        Ok(())
    }
}
impl<const M: usize, K> DeserialWithState<StateApi> for StateBTreeSet<K, M> {
fn deserial_with_state<R: Read>(state: &StateApi, source: &mut R) -> ParseResult<Self> {
let prefix = source.get()?;
let root = source.get()?;
let len = source.get()?;
let next_node_id = source.get()?;
Ok(Self {
_marker_key: Default::default(),
prefix,
state_api: state.clone(),
root,
len,
next_node_id,
})
}
}
impl<const M: usize, K, V> DeserialWithState<StateApi> for StateBTreeMap<K, V, M> {
    /// Reads the value map followed by the ordered key set, matching the
    /// field order of the struct.
    fn deserial_with_state<R: Read>(state: &StateApi, source: &mut R) -> ParseResult<Self> {
        Ok(Self {
            key_value: StateMap::deserial_with_state(state, source)?,
            key_order: StateBTreeSet::deserial_with_state(state, source)?,
        })
    }
}
impl<const M: usize, K, V> Deletable for StateBTreeMap<K, V, M>
where
    K: Serialize,
    V: Serial + DeserialWithState<StateApi> + Deletable,
{
    /// Deletes the map by clearing both the value map and the ordered key
    /// set.
    fn delete(mut self) {
        self.clear();
    }
}
#[cfg(feature = "internal-wasm-test")]
mod wasm_test_btree {
use super::*;
use crate::{claim, claim_eq, concordium_test, StateApi, StateBuilder};
/// The ways in which a B-tree can fail the checks done by `check_invariants`.
#[derive(Debug)]
pub(crate) enum InvariantViolation {
    /// The set has no root although its length is not zero.
    NonZeroLenWithNoRoot,
    /// The root node holds no keys.
    ZeroKeysInRoot,
    /// Iterating the set did not yield the keys in order.
    IterationOutOfOrder,
    /// Not all leaves are at the same depth.
    LeafAtDifferentDepth,
    /// The keys within a node are not strictly increasing.
    NodeKeysOutOfOrder,
    /// A node with children does not have exactly one more child than keys.
    MismatchingChildrenLenKeyLen,
    /// A node has fewer keys than the minimum.
    KeysLenBelowMin,
    /// A node has more keys than the maximum.
    KeysLenAboveMax,
    /// A leaf node has children.
    LeafWithChildren,
    /// A node with children has fewer children than the minimum.
    ChildrenLenBelowMin,
    /// A node has more children than the maximum.
    ChildrenLenAboveMax,
}
impl<K, const M: usize> StateBTreeSet<K, M> {
/// Checks the invariants of the whole tree and reports the first violation
/// found.
///
/// The root is checked separately since the minimum sizes do not apply to it;
/// every other node is checked with `Node::check_invariants`. In addition all
/// leaves must be at the same depth and iterating the set must yield strictly
/// increasing keys.
fn check_invariants(&self) -> Result<(), InvariantViolation>
where
    K: Serialize + Ord, {
    use crate::ops::Deref;
    let Some(root_node_id) = self.root else {
        return if self.len == 0 {
            Ok(())
        } else {
            Err(InvariantViolation::NonZeroLenWithNoRoot)
        };
    };
    let root: Node<M, K> = self.get_node(root_node_id);
    if root.keys.is_empty() {
        return Err(InvariantViolation::ZeroKeysInRoot);
    }
    for i in 1..root.keys.len() {
        if &root.keys[i - 1] >= &root.keys[i] {
            return Err(InvariantViolation::NodeKeysOutOfOrder);
        }
    }
    if root.keys.len() > Node::<M, K>::MAXIMUM_KEY_LEN {
        return Err(InvariantViolation::KeysLenAboveMax);
    }
    if root.is_leaf() {
        if !root.children.is_empty() {
            return Err(InvariantViolation::LeafWithChildren);
        }
    } else {
        if root.children.len() != root.keys.len() + 1 {
            return Err(InvariantViolation::MismatchingChildrenLenKeyLen);
        }
        if root.children.len() > Node::<M, K>::MAXIMUM_CHILD_LEN {
            return Err(InvariantViolation::ChildrenLenAboveMax);
        }
    }
    // Visit every node below the root, one level of children at a time, and
    // record the level at which the first leaf was seen.
    let mut stack = vec![(0usize, root.children)];
    let mut leaf_depth = None;
    while let Some((node_level, mut nodes)) = stack.pop() {
        while let Some(node_id) = nodes.pop() {
            let node: Node<M, K> = self.get_node(node_id);
            node.check_invariants()?;
            if node.is_leaf() {
                let depth = leaf_depth.get_or_insert(node_level);
                if *depth != node_level {
                    return Err(InvariantViolation::LeafAtDifferentDepth);
                }
            } else {
                stack.push((node_level + 1, node.children));
            }
        }
    }
    // A set never holds the same key twice, so consecutive keys must be
    // strictly increasing; equal neighbours are a violation too.
    let mut prev = None;
    for key in self.iter() {
        if let Some(p) = prev.as_deref() {
            if p >= key.deref() {
                return Err(InvariantViolation::IterationOutOfOrder);
            }
        }
        prev = Some(key);
    }
    Ok(())
}
/// Renders the nodes of the tree as text for debugging.
///
/// The root is printed first, followed by every other node along with its
/// id and the result of its invariant check.
pub(crate) fn debug(&self) -> String
where
    K: Serialize + std::fmt::Debug + Ord, {
    let root_node_id = match self.root {
        Some(id) => id,
        None => return String::from("no root"),
    };
    let root: Node<M, K> = self.get_node(root_node_id);
    let mut output = String::new();
    output.push_str(&format!("root: {:#?}", root));
    let mut pending = root.children;
    while let Some(node_id) = pending.pop() {
        let node: Node<M, K> = self.get_node(node_id);
        let rendered = format!("node {} {:?}: {:#?},\n", node_id.id, node.check_invariants(), node);
        output.push_str(&rendered);
        pending.extend(node.children);
    }
    output
}
}
impl<const M: usize, K> Node<M, K> {
    /// Checks the invariants of a single non-root node: strictly increasing
    /// keys, a key count within the allowed range and, for a node with
    /// children, a child count that is one more than the key count and within
    /// the allowed range.
    pub(crate) fn check_invariants(&self) -> Result<(), InvariantViolation>
    where
        K: Ord, {
        let key_count = self.keys.len();
        let child_count = self.children.len();

        let keys_out_of_order = self.keys.windows(2).any(|pair| pair[0] >= pair[1]);
        if keys_out_of_order {
            return Err(InvariantViolation::NodeKeysOutOfOrder);
        }
        if key_count < Self::MINIMUM_KEY_LEN {
            return Err(InvariantViolation::KeysLenBelowMin);
        }
        if key_count > Self::MAXIMUM_KEY_LEN {
            return Err(InvariantViolation::KeysLenAboveMax);
        }
        if self.is_leaf() {
            return if self.children.is_empty() {
                Ok(())
            } else {
                Err(InvariantViolation::LeafWithChildren)
            };
        }
        if child_count != key_count + 1 {
            return Err(InvariantViolation::MismatchingChildrenLenKeyLen);
        }
        if child_count < Self::MINIMUM_CHILD_LEN {
            return Err(InvariantViolation::ChildrenLenBelowMin);
        }
        if child_count > Self::MAXIMUM_CHILD_LEN {
            return Err(InvariantViolation::ChildrenLenAboveMax);
        }
        Ok(())
    }
}
/// Inserts `2 * M` ascending keys, one more than a single node can hold, and
/// checks that all of them are found and counted.
#[concordium_test]
fn test_btree_insert_asc_above_max_branching_degree() {
    const M: usize = 5;
    let key_count = (2 * M) as u32;
    let mut state_builder = StateBuilder::open(StateApi::open());
    let mut tree = state_builder.new_btree_set_degree::<M, _>();
    for key in 0..key_count {
        claim!(tree.insert(key));
    }
    for key in 0..key_count {
        claim!(tree.contains(&key));
    }
    claim_eq!(tree.len(), key_count)
}
/// Inserts 16 ascending keys into a tree with `M = 2` and checks that all of
/// them, and nothing else, are found.
#[concordium_test]
fn test_btree_insert_asc_height_3() {
    let mut state_builder = StateBuilder::open(StateApi::open());
    let mut tree = state_builder.new_btree_set_degree::<2, _>();
    for key in 0..16 {
        claim!(tree.insert(key));
    }
    for key in 0..16 {
        claim!(tree.contains(&key));
    }
    claim_eq!(tree.len(), 16);
    // A key that was never inserted must not be reported as present.
    claim!(!tree.contains(&17));
}
/// Inserts the keys 0 to 7 in a shuffled order and checks that all of them
/// are found and counted.
#[concordium_test]
fn test_btree_insert_random_order() {
    let mut state_builder = StateBuilder::open(StateApi::open());
    let mut tree = state_builder.new_btree_set_degree::<2, _>();
    for key in [0, 3, 2, 1, 5, 7, 6, 4] {
        tree.insert(key);
    }
    for key in 0..=7 {
        claim!(tree.contains(&key));
    }
    claim_eq!(tree.len(), 8)
}
/// Checks `higher` for every query from 0 to 8 on the set {1, 2, 3, 4, 5, 7}.
#[concordium_test]
fn test_btree_higher() {
    let mut state_builder = StateBuilder::open(StateApi::open());
    let mut tree = state_builder.new_btree_set_degree::<2, _>();
    for key in [1, 2, 3, 4, 5, 7] {
        tree.insert(key);
    }
    // Pairs of (query, expected result).
    let expectations = [
        (0, Some(1)),
        (1, Some(2)),
        (2, Some(3)),
        (3, Some(4)),
        (4, Some(5)),
        (5, Some(7)),
        (6, Some(7)),
        (7, None),
        (8, None),
    ];
    for (query, expected) in expectations {
        claim_eq!(tree.higher(&query).as_deref(), expected.as_ref());
    }
}
/// Checks `lower` for every query from 0 to 8 on the set {1, 2, 3, 4, 5, 7}.
#[concordium_test]
fn test_btree_lower() {
    let mut state_builder = StateBuilder::open(StateApi::open());
    let mut tree = state_builder.new_btree_set_degree::<2, _>();
    for key in [1, 2, 3, 4, 5, 7] {
        tree.insert(key);
    }
    // Pairs of (query, expected result).
    let expectations = [
        (0, None),
        (1, None),
        (2, Some(1)),
        (3, Some(2)),
        (4, Some(3)),
        (5, Some(4)),
        (6, Some(5)),
        (7, Some(5)),
        (8, Some(7)),
    ];
    for (query, expected) in expectations {
        claim_eq!(tree.lower(&query).as_deref(), expected.as_ref());
    }
}
/// Checks `eq_or_higher` for every query from 0 to 8 on the set
/// {1, 2, 3, 4, 5, 7}.
#[concordium_test]
fn test_btree_eq_or_higher() {
    let mut state_builder = StateBuilder::open(StateApi::open());
    let mut tree = state_builder.new_btree_set_degree::<2, _>();
    for key in [1, 2, 3, 4, 5, 7] {
        tree.insert(key);
    }
    // Pairs of (query, expected result).
    let expectations = [
        (0, Some(1)),
        (1, Some(1)),
        (2, Some(2)),
        (3, Some(3)),
        (4, Some(4)),
        (5, Some(5)),
        (6, Some(7)),
        (7, Some(7)),
        (8, None),
    ];
    for (query, expected) in expectations {
        claim_eq!(tree.eq_or_higher(&query).as_deref(), expected.as_ref());
    }
}
/// Checks `eq_or_lower` for every query from 0 to 8 on the set
/// {1, 2, 3, 4, 5, 7}.
#[concordium_test]
fn test_btree_eq_or_lower() {
    let mut state_builder = StateBuilder::open(StateApi::open());
    let mut tree = state_builder.new_btree_set_degree::<2, _>();
    for key in [1, 2, 3, 4, 5, 7] {
        tree.insert(key);
    }
    // Pairs of (query, expected result).
    let expectations = [
        (0, None),
        (1, Some(1)),
        (2, Some(2)),
        (3, Some(3)),
        (4, Some(4)),
        (5, Some(5)),
        (6, Some(5)),
        (7, Some(7)),
        (8, Some(7)),
    ];
    for (query, expected) in expectations {
        claim_eq!(tree.eq_or_lower(&query).as_deref(), expected.as_ref());
    }
}
/// Inserts 1000 distinct keys (the first half ascending, the second half
/// descending) and checks that inserting each of them again is rejected and
/// leaves the length unchanged.
#[concordium_test]
fn test_btree_insert_a_lot_then_reinsert() {
    let mut state_builder = StateBuilder::open(StateApi::open());
    let mut tree = state_builder.new_btree_set_degree::<2, _>();
    for key in 0..500 {
        claim!(tree.insert(key));
    }
    for key in (500..1000).rev() {
        claim!(tree.insert(key));
    }
    for key in 0..1000 {
        claim!(tree.contains(&key))
    }
    claim_eq!(tree.len(), 1000);
    // Every key is present now, so a second insertion must report `false`.
    for key in 0..1000 {
        claim!(!tree.insert(key))
    }
    claim_eq!(tree.len(), 1000)
}
/// Removes the middle key from a set holding 0, 1 and 2 and checks that only
/// that key is gone.
#[concordium_test]
fn test_btree_remove_from_one_node_tree() {
    let mut state_builder = StateBuilder::open(StateApi::open());
    let mut tree = state_builder.new_btree_set_degree::<2, _>();
    for key in 0..3 {
        tree.insert(key);
    }
    claim!(tree.contains(&1));
    claim!(tree.remove(&1));
    claim!(tree.contains(&0));
    claim!(!tree.contains(&1));
    claim!(tree.contains(&2));
}
/// Inserts 0 to 3 in ascending order, removes 3 and then removes 0; checks
/// that 1 and 2 remain.
#[concordium_test]
fn test_btree_remove_only_key_lower_leaf_in_three_node() {
    let mut state_builder = StateBuilder::open(StateApi::open());
    let mut tree = state_builder.new_btree_set_degree::<2, _>();
    for key in 0..4 {
        tree.insert(key);
    }
    tree.remove(&3);
    claim!(tree.remove(&0));
    claim!(!tree.contains(&0));
    for remaining in [1, 2] {
        claim!(tree.contains(&remaining));
    }
}
/// Inserts 0 to 3 in descending order, removes 3 and then removes 2; checks
/// that 0 and 1 remain.
#[concordium_test]
fn test_btree_remove_only_key_higher_leaf_in_three_node() {
    let mut state_builder = StateBuilder::open(StateApi::open());
    let mut tree = state_builder.new_btree_set_degree::<2, _>();
    for key in (0..4).rev() {
        tree.insert(key);
    }
    tree.remove(&3);
    claim!(tree.contains(&2));
    claim!(tree.remove(&2));
    for remaining in [0, 1] {
        claim!(tree.contains(&remaining));
    }
    claim!(!tree.contains(&2));
}
/// Inserts 0 to 3 in descending order and removes the largest key.
#[concordium_test]
fn test_btree_remove_from_higher_leaf_in_three_node_taking_from_sibling() {
    let mut state_builder = StateBuilder::open(StateApi::open());
    let mut tree = state_builder.new_btree_set_degree::<2, _>();
    for key in (0..4).rev() {
        tree.insert(key);
    }
    let largest = 3;
    claim!(tree.contains(&largest));
    claim!(tree.remove(&largest));
    claim!(!tree.contains(&largest));
}
/// Inserts 0 to 3 in ascending order and removes the smallest key; checks
/// that the other three remain.
#[concordium_test]
fn test_btree_remove_from_lower_leaf_in_three_node_taking_from_sibling() {
    let mut state_builder = StateBuilder::open(StateApi::open());
    let mut tree = state_builder.new_btree_set_degree::<2, _>();
    for key in 0..4 {
        tree.insert(key);
    }
    claim!(tree.remove(&0));
    claim!(!tree.contains(&0));
    for remaining in [1, 2, 3] {
        claim!(tree.contains(&remaining));
    }
}
/// Inserts 0 to 3 in ascending order, removes 3 and then removes 1; checks
/// that 0 and 2 remain.
#[concordium_test]
fn test_btree_remove_from_root_in_three_node_causing_merge() {
    let mut state_builder = StateBuilder::open(StateApi::open());
    let mut tree = state_builder.new_btree_set_degree::<2, _>();
    for key in 0..4 {
        tree.insert(key);
    }
    tree.remove(&3);
    claim!(tree.remove(&1));
    claim!(tree.contains(&0));
    claim!(!tree.contains(&1));
    claim!(tree.contains(&2));
}
/// Inserts 0 to 3 in ascending order and removes the key 1.
#[concordium_test]
fn test_btree_remove_from_root_in_three_node_taking_key_from_higher_child() {
    let mut state_builder = StateBuilder::open(StateApi::open());
    let mut tree = state_builder.new_btree_set_degree::<2, _>();
    for key in 0..4 {
        tree.insert(key);
    }
    let target = 1;
    claim!(tree.contains(&target));
    claim!(tree.remove(&target));
    claim!(!tree.contains(&target));
}
/// Inserts 0 to 3 in descending order and removes the key 2.
#[concordium_test]
fn test_btree_remove_from_root_in_three_node_taking_key_from_lower_child() {
    let mut state_builder = StateBuilder::open(StateApi::open());
    let mut tree = state_builder.new_btree_set_degree::<2, _>();
    for key in (0..4).rev() {
        tree.insert(key);
    }
    let target = 2;
    claim!(tree.contains(&target));
    claim!(tree.remove(&target));
    claim!(!tree.contains(&target));
}
/// Inserts the keys 0 to 14 and checks that iterating the set yields exactly
/// those keys in ascending order.
#[concordium_test]
fn test_btree_iter() {
    let mut state_builder = StateBuilder::open(StateApi::open());
    let mut tree = state_builder.new_btree_set_degree::<2, _>();
    let keys: Vec<u32> = (0..15).collect();
    for &key in &keys {
        tree.insert(key);
    }
    let iter_keys: Vec<u32> = tree.iter().map(|key| *key).collect();
    claim_eq!(keys, iter_keys);
}
/// Checks that inserting a key that is already in the set reports `false`.
#[concordium_test]
fn test_btree_insert_present_key() {
    let mut state_builder = StateBuilder::open(StateApi::open());
    let mut tree = state_builder.new_btree_set_degree::<2, _>();
    for key in [0, 3, 4, 1, 2] {
        tree.insert(key);
    }
    claim!(!tree.insert(1));
}
#[allow(deprecated)]
mod quickcheck {
use super::super::*;
use crate::{
self as concordium_std, concordium_quickcheck, concordium_test, fail, StateApi,
StateBuilder, StateError,
};
use ::quickcheck::{Arbitrary, Gen, TestResult};
#[concordium_quickcheck]
fn quickcheck_btree_inserts(items: Vec<u32>) -> TestResult {
let mut state_builder = StateBuilder::open(StateApi::open());
let mut tree = state_builder.new_btree_set_degree::<2, _>();
for k in items.clone() {
tree.insert(k);
}
if let Err(violation) = tree.check_invariants() {
return TestResult::error(format!("Invariant violated: {:?}", violation));
}
for k in items.iter() {
if !tree.contains(k) {
return TestResult::error(format!("Missing key: {}", k));
}
}
TestResult::passed()
}
/// Property: after `clear`, none of the previously inserted keys is reported by `contains`, and
/// looking up an iterator over the set's prefix in the state fails with
/// `StateError::SubtreeWithPrefixNotFound`.
#[concordium_quickcheck]
fn quickcheck_btree_clear(items: Vec<u32>) -> TestResult {
    let mut builder = StateBuilder::open(StateApi::open());
    let mut set = builder.new_btree_set_degree::<2, _>();
    for &key in &items {
        set.insert(key);
    }
    set.clear();

    // No key may be visible through the set API any more.
    if let Some(k) = items.iter().find(|&k| set.contains(k)) {
        return TestResult::error(format!("Found {k} in a cleared btree"));
    }

    // A successful iterator lookup means entries are still stored under the set's prefix.
    let state_api = StateApi::open();
    match state_api.iterator(&set.prefix) {
        Err(StateError::SubtreeWithPrefixNotFound) => TestResult::passed(),
        Err(err) => {
            TestResult::error(format!("Failed to get iterator for btree nodes: {err:?}"))
        }
        Ok(node_iter) => {
            let nodes_in_state = node_iter.count();
            TestResult::error(format!(
                "Found {} nodes still stored in the state",
                nodes_in_state
            ))
        }
    }
}
#[concordium_quickcheck(num_tests = 100)]
fn quickcheck_btree_iter(mut items: Vec<u32>) -> TestResult {
let mut state_builder = StateBuilder::open(StateApi::open());
let mut tree = state_builder.new_btree_set_degree::<2, _>();
for k in items.clone() {
tree.insert(k);
}
if let Err(violation) = tree.check_invariants() {
return TestResult::error(format!("Invariant violated: {:?}", violation));
}
items.sort();
items.dedup();
for (value, expected) in tree.iter().zip(items.into_iter()) {
if *value != expected {
return TestResult::error(format!("Got {} but expected {expected}", *value));
}
}
TestResult::passed()
}
/// Property: for a degree-2 set built from arbitrary keys, the neighbour queries agree with the
/// sorted, de-duplicated input.
///
/// For every adjacent pair `(l, r)` of sorted keys:
/// - `higher(l)` is `r` and `lower(r)` is `l`;
/// - when there is a gap between them, `eq_or_higher(l + 1)` is `r` and `eq_or_lower(r - 1)`
///   is `l`.
///
/// Additionally, `lower` of the smallest key and `higher` of the largest key are `None`.
///
/// Each failure message reports the value returned by the query that actually failed.
#[concordium_quickcheck(num_tests = 100)]
fn quickcheck_btree_higher_lower(mut items: Vec<u32>) -> TestResult {
    let mut state_builder = StateBuilder::open(StateApi::open());
    let mut tree = state_builder.new_btree_set_degree::<2, _>();
    for &k in &items {
        tree.insert(k);
    }
    if let Err(violation) = tree.check_invariants() {
        return TestResult::error(format!("Invariant violated: {:?}", violation));
    }

    items.sort_unstable();
    items.dedup();

    for window in items.windows(2) {
        let l = &window[0];
        let r = &window[1];

        let l_higher = tree.higher(l);
        if l_higher.as_deref() != Some(r) {
            return TestResult::error(format!(
                "higher({l}) gave {:?} instead of the expected Some({r})",
                l_higher.as_deref()
            ));
        }

        let r_lower = tree.lower(r);
        if r_lower.as_deref() != Some(l) {
            return TestResult::error(format!(
                "lower({r}) gave {:?} instead of the expected Some({l})",
                r_lower.as_deref()
            ));
        }

        // The inclusive queries are only probed with keys that are not in the set, which
        // requires at least one unused value strictly between `l` and `r`.
        let space_between = r - l > 1;
        if space_between {
            let l_eq_or_higher = tree.eq_or_higher(&(l + 1));
            if l_eq_or_higher.as_deref() != Some(r) {
                return TestResult::error(format!(
                    "eq_or_higher({}) gave {:?} instead of the expected Some({r})",
                    l + 1,
                    l_eq_or_higher.as_deref()
                ));
            }

            let r_eq_or_lower = tree.eq_or_lower(&(r - 1));
            if r_eq_or_lower.as_deref() != Some(l) {
                return TestResult::error(format!(
                    "eq_or_lower({}) gave {:?} instead of the expected Some({l})",
                    r - 1,
                    r_eq_or_lower.as_deref()
                ));
            }
        }
    }

    if let Some(first) = items.first() {
        let lower = tree.lower(first);
        if lower.is_some() {
            return TestResult::error(format!(
                "lower({first}) gave {:?} instead of the expected None",
                lower.as_deref()
            ));
        }
    }

    if let Some(last) = items.last() {
        let higher = tree.higher(last);
        if higher.is_some() {
            return TestResult::error(format!(
                "higher({last}) gave {:?} instead of the expected None",
                higher.as_deref()
            ));
        }
    }

    TestResult::passed()
}
#[concordium_quickcheck(num_tests = 500)]
fn quickcheck_btree_inserts_removes(mutations: Mutations<u32>) -> TestResult {
let mut state_builder = StateBuilder::open(StateApi::open());
let mut tree = state_builder.new_btree_set_degree::<2, _>();
if let Err(err) = run_mutations(&mut tree, &mutations.mutations) {
TestResult::error(format!("Error: {}, tree: {}", err, tree.debug()))
} else {
TestResult::passed()
}
}
/// A sequence of set mutations used as quickcheck input, together with the keys expected to
/// be in the set afterwards.
#[derive(Debug, Clone)]
struct Mutations<K> {
    /// The keys left in the set after applying all of `mutations`, as computed by the
    /// `Arbitrary` generator for this type.
    expected_keys: crate::collections::BTreeSet<K>,
    /// The mutations in the order they are applied; each entry pairs a key with the operation
    /// to perform on it.
    mutations: Vec<(K, Operation)>,
}
/// The kind of a single set mutation, including the outcome `run_mutations` expects from it.
#[derive(Debug, Clone, Copy)]
enum Operation {
    /// Insert a key that is not in the set; the insertion is expected to return `true`.
    InsertKeyNotPresent,
    /// Insert a key that is already in the set; the insertion is expected to return `false`.
    InsertKeyPresent,
    /// Remove a key that is in the set; the removal is expected to return `true`.
    RemoveKeyPresent,
    /// Remove a key that is not in the set; the removal is expected to return `false`.
    RemoveKeyNotPresent,
}
/// Applies `mutations` to `tree` in order, verifying the tree along the way.
///
/// The tree invariants are checked before every mutation and once more after the last one, so
/// a violation introduced by the final mutation is detected as well. Each mutation's return
/// value is compared with what its [`Operation`] kind expects.
///
/// Returns `Err` with a description of the first invariant violation or unexpected return
/// value, and `Ok(())` if everything matched.
fn run_mutations<const M: usize>(
    tree: &mut StateBTreeSet<u32, M>,
    mutations: &[(u32, Operation)],
) -> Result<(), String> {
    for (k, op) in mutations {
        if let Err(violation) = tree.check_invariants() {
            return Err(format!("Invariant violated: {:?}", violation));
        }
        match op {
            Operation::InsertKeyPresent => {
                if tree.insert(*k) {
                    return Err(format!("InsertKeyPresent was not present: {}", k));
                }
            }
            Operation::InsertKeyNotPresent => {
                if !tree.insert(*k) {
                    return Err(format!("InsertKeyNotPresent was present: {}", k));
                }
            }
            Operation::RemoveKeyNotPresent => {
                if tree.remove(k) {
                    return Err(format!("RemoveKeyNotPresent was present: {}", k));
                }
            }
            Operation::RemoveKeyPresent => {
                if !tree.remove(k) {
                    return Err(format!("RemoveKeyPresent was not present: {}", k));
                }
            }
        }
    }
    // The loop only validates the state *before* each mutation; validate the final state too.
    if let Err(violation) = tree.check_invariants() {
        return Err(format!("Invariant violated: {:?}", violation));
    }
    Ok(())
}
impl Arbitrary for Operation {
fn arbitrary(g: &mut Gen) -> Self {
g.choose(&[
Self::InsertKeyNotPresent,
Self::InsertKeyPresent,
Self::RemoveKeyPresent,
Self::RemoveKeyNotPresent,
])
.unwrap()
.clone()
}
}
/// Generation and shrinking of valid mutation sequences.
///
/// A sequence is valid when every operation's precondition (key present / key not present)
/// holds at the point where it is applied, starting from an empty set.
impl<K> Arbitrary for Mutations<K>
where
    K: Arbitrary + Ord,
{
    /// Generates `g.size()` mutations while simulating the set contents, so that each
    /// generated operation is consistent with the keys present at that point.
    fn arbitrary(g: &mut Gen) -> Self {
        // Simulated set contents, kept sorted so `binary_search` can be used for lookups.
        let mut inserted_keys: Vec<K> = Vec::new();
        let mut mutations = Vec::new();

        while mutations.len() < g.size() {
            let op: Operation = Operation::arbitrary(g);
            match op {
                Operation::InsertKeyPresent if !inserted_keys.is_empty() => {
                    let k = g.choose(&inserted_keys).unwrap().clone();
                    mutations.push((k, op));
                }
                Operation::InsertKeyNotPresent => {
                    let k = K::arbitrary(g);
                    // Only usable if the random key is not already in the simulated set.
                    if let Err(index) = inserted_keys.binary_search(&k) {
                        inserted_keys.insert(index, k.clone());
                        mutations.push((k, op));
                    }
                }
                Operation::RemoveKeyPresent if !inserted_keys.is_empty() => {
                    // `Gen` only offers `choose`, so pick a position via a list of indexes.
                    let indexes: Vec<usize> = (0..inserted_keys.len()).collect();
                    let k_index = *g.choose(&indexes).unwrap();
                    let k = inserted_keys.remove(k_index);
                    mutations.push((k, op));
                }
                Operation::RemoveKeyNotPresent => {
                    let k = K::arbitrary(g);
                    // Only usable if the random key is absent from the simulated set.
                    if inserted_keys.binary_search(&k).is_err() {
                        mutations.push((k, op));
                    }
                }
                // No key is present yet, so these operations cannot be generated; retry.
                Operation::InsertKeyPresent | Operation::RemoveKeyPresent => {}
            }
        }

        Self {
            expected_keys: inserted_keys.into_iter().collect(),
            mutations,
        }
    }

    /// Produces smaller, still valid, mutation sequences:
    /// - the sequence without its last mutation (only when there is one);
    /// - the sequence without one mutation that does not change the set
    ///   (`InsertKeyPresent`, `RemoveKeyNotPresent`);
    /// - the sequence without one `RemoveKeyPresent` and the insertions of the same key
    ///   leading up to it.
    ///
    /// `expected_keys` is kept in sync with the mutations of every candidate.
    fn shrink(&self) -> Box<dyn Iterator<Item = Self>> {
        let mut candidates = Vec::new();

        // Dropping the last mutation always leaves a valid sequence. Skip it for an empty
        // sequence, where the candidate would be identical to `self`.
        let mut shorter = self.clone();
        if let Some((k, op)) = shorter.mutations.pop() {
            // Undo the effect the dropped mutation had on the final key set.
            match op {
                Operation::InsertKeyNotPresent => {
                    shorter.expected_keys.remove(&k);
                }
                Operation::RemoveKeyPresent => {
                    shorter.expected_keys.insert(k);
                }
                Operation::InsertKeyPresent | Operation::RemoveKeyNotPresent => {}
            }
            candidates.push(shorter);
        }

        for (i, (k, op)) in self.mutations.iter().enumerate() {
            match op {
                // These do not change the set, so they can be dropped on their own.
                Operation::InsertKeyPresent | Operation::RemoveKeyNotPresent => {
                    let mut candidate = self.clone();
                    candidate.mutations.remove(i);
                    candidates.push(candidate);
                }
                Operation::RemoveKeyPresent => {
                    let mut candidate = self.clone();
                    candidate.mutations.remove(i);
                    // Walk backwards and drop the insertions of `k` that this removal
                    // depends on. Indexes are visited in decreasing order and are all
                    // smaller than `i`, so earlier removals do not shift them.
                    let mut found_insertion = false;
                    for (j, (k2, earlier_op)) in self.mutations[..i].iter().enumerate().rev() {
                        if k != k2 {
                            continue;
                        }
                        match earlier_op {
                            Operation::InsertKeyPresent => {
                                candidate.mutations.remove(j);
                            }
                            Operation::InsertKeyNotPresent => {
                                candidate.mutations.remove(j);
                                found_insertion = true;
                                break;
                            }
                            Operation::RemoveKeyPresent | Operation::RemoveKeyNotPresent => {}
                        }
                    }
                    if !found_insertion {
                        fail!("No insertion found before");
                    }
                    // Dropping an insert/remove pair leaves the final key set unchanged, so
                    // `expected_keys` must not be touched: `k` is only in it if it was
                    // inserted again later, and that insertion is still in the candidate.
                    candidates.push(candidate);
                }
                // Dropping the insertion alone could invalidate later operations on the key.
                Operation::InsertKeyNotPresent => {}
            }
        }

        Box::new(candidates.into_iter())
    }
}
}
}