use crate::collections::{append, append_slice, Vector};
use crate::{env, IntoStorageKey};
use borsh::{BorshDeserialize, BorshSerialize};
use std::mem::size_of;
const ERR_INCONSISTENT_STATE: &str = "The collection is an inconsistent state. Did previous smart contract execution terminate unexpectedly?";
const ERR_ELEMENT_SERIALIZATION: &str = "Cannot serialize element with Borsh";
#[derive(BorshSerialize, BorshDeserialize)]
pub struct UnorderedSet<T> {
element_index_prefix: Vec<u8>,
elements: Vector<T>,
}
impl<T> UnorderedSet<T> {
pub fn len(&self) -> u64 {
self.elements.len()
}
pub fn is_empty(&self) -> bool {
self.elements.is_empty()
}
pub fn new<S>(prefix: S) -> Self
where
S: IntoStorageKey,
{
let prefix = prefix.into_storage_key();
let element_index_prefix = append(&prefix, b'i');
let elements_prefix = append(&prefix, b'e');
Self { element_index_prefix, elements: Vector::new(elements_prefix) }
}
fn serialize_index(index: u64) -> [u8; size_of::<u64>()] {
index.to_le_bytes()
}
fn deserialize_index(raw_index: &[u8]) -> u64 {
let mut result = [0u8; size_of::<u64>()];
result.copy_from_slice(raw_index);
u64::from_le_bytes(result)
}
fn raw_element_to_index_lookup(&self, element_raw: &[u8]) -> Vec<u8> {
append_slice(&self.element_index_prefix, element_raw)
}
fn contains_raw(&self, element_raw: &[u8]) -> bool {
let index_lookup = self.raw_element_to_index_lookup(element_raw);
env::storage_has_key(&index_lookup)
}
pub fn insert_raw(&mut self, element_raw: &[u8]) -> bool {
let index_lookup = self.raw_element_to_index_lookup(element_raw);
match env::storage_read(&index_lookup) {
Some(_index_raw) => false,
None => {
let next_index = self.len();
let next_index_raw = Self::serialize_index(next_index);
env::storage_write(&index_lookup, &next_index_raw);
self.elements.push_raw(element_raw);
true
}
}
}
pub fn remove_raw(&mut self, element_raw: &[u8]) -> bool {
let index_lookup = self.raw_element_to_index_lookup(element_raw);
match env::storage_read(&index_lookup) {
Some(index_raw) => {
#[allow(clippy::branches_sharing_code)]
if self.len() == 1 {
env::storage_remove(&index_lookup);
} else {
let last_element_raw = match self.elements.get_raw(self.len() - 1) {
Some(x) => x,
None => env::panic_str(ERR_INCONSISTENT_STATE),
};
env::storage_remove(&index_lookup);
if last_element_raw != element_raw {
let last_lookup_element =
self.raw_element_to_index_lookup(&last_element_raw);
env::storage_write(&last_lookup_element, &index_raw);
}
}
let index = Self::deserialize_index(&index_raw);
self.elements.swap_remove_raw(index);
true
}
None => false,
}
}
}
impl<T> UnorderedSet<T>
where
T: BorshSerialize + BorshDeserialize,
{
fn serialize_element(element: &T) -> Vec<u8> {
match element.try_to_vec() {
Ok(x) => x,
Err(_) => env::panic_str(ERR_ELEMENT_SERIALIZATION),
}
}
pub fn contains(&self, element: &T) -> bool {
self.contains_raw(&Self::serialize_element(element))
}
pub fn remove(&mut self, element: &T) -> bool {
self.remove_raw(&Self::serialize_element(element))
}
pub fn insert(&mut self, element: &T) -> bool {
self.insert_raw(&Self::serialize_element(element))
}
pub fn clear(&mut self) {
for raw_element in self.elements.iter_raw() {
let index_lookup = self.raw_element_to_index_lookup(&raw_element);
env::storage_remove(&index_lookup);
}
self.elements.clear();
}
pub fn to_vec(&self) -> std::vec::Vec<T> {
self.iter().collect()
}
pub fn iter(&self) -> impl Iterator<Item = T> + '_ {
self.elements.iter()
}
pub fn extend<IT: IntoIterator<Item = T>>(&mut self, iter: IT) {
for el in iter {
self.insert(&el);
}
}
pub fn as_vector(&self) -> &Vector<T> {
&self.elements
}
}
impl<T> std::fmt::Debug for UnorderedSet<T>
where
T: std::fmt::Debug + BorshSerialize + BorshDeserialize,
{
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("UnorderedSet")
.field("element_index_prefix", &self.element_index_prefix)
.field("elements", &self.elements)
.finish()
}
}
#[cfg(not(target_arch = "wasm32"))]
#[cfg(test)]
mod tests {
use crate::collections::UnorderedSet;
use rand::seq::SliceRandom;
use rand::{Rng, SeedableRng};
use std::collections::HashSet;
use std::iter::FromIterator;
#[test]
pub fn test_insert_one() {
let mut map = UnorderedSet::new(b"m");
assert!(map.insert(&1));
assert!(!map.insert(&1));
}
#[test]
pub fn test_insert() {
let mut set = UnorderedSet::new(b"s");
let mut rng = rand_xorshift::XorShiftRng::seed_from_u64(0);
for _ in 0..500 {
let key = rng.gen::<u64>();
set.insert(&key);
}
}
#[test]
pub fn test_insert_remove() {
let mut set = UnorderedSet::new(b"s");
let mut rng = rand_xorshift::XorShiftRng::seed_from_u64(1);
let mut keys = vec![];
for _ in 0..100 {
let key = rng.gen::<u64>();
keys.push(key);
set.insert(&key);
}
keys.shuffle(&mut rng);
for key in keys {
assert!(set.remove(&key));
}
}
#[test]
pub fn test_remove_last_reinsert() {
let mut set = UnorderedSet::new(b"s");
let key1 = 1u64;
set.insert(&key1);
let key2 = 2u64;
set.insert(&key2);
let actual = set.remove(&key2);
assert!(actual);
let actual_reinsert = set.insert(&key2);
assert!(actual_reinsert);
}
#[test]
pub fn test_insert_override_remove() {
let mut set = UnorderedSet::new(b"s");
let mut rng = rand_xorshift::XorShiftRng::seed_from_u64(2);
let mut keys = vec![];
for _ in 0..100 {
let key = rng.gen::<u64>();
keys.push(key);
set.insert(&key);
}
keys.shuffle(&mut rng);
for key in &keys {
assert!(!set.insert(key));
}
keys.shuffle(&mut rng);
for key in keys {
assert!(set.remove(&key));
}
}
#[test]
pub fn test_contains_non_existent() {
let mut set = UnorderedSet::new(b"s");
let mut rng = rand_xorshift::XorShiftRng::seed_from_u64(3);
let mut set_tmp = HashSet::new();
for _ in 0..500 {
let key = rng.gen::<u64>() % 20_000;
set_tmp.insert(key);
set.insert(&key);
}
for _ in 0..500 {
let key = rng.gen::<u64>() % 20_000;
assert_eq!(set.contains(&key), set_tmp.contains(&key));
}
}
#[test]
pub fn test_to_vec() {
let mut set = UnorderedSet::new(b"s");
let mut rng = rand_xorshift::XorShiftRng::seed_from_u64(4);
let mut keys = HashSet::new();
for _ in 0..500 {
let key = rng.gen::<u64>();
keys.insert(key);
set.insert(&key);
}
let actual = HashSet::from_iter(set.to_vec());
assert_eq!(actual, keys);
}
#[test]
pub fn test_clear() {
let mut set = UnorderedSet::new(b"s");
let mut rng = rand_xorshift::XorShiftRng::seed_from_u64(5);
for _ in 0..10 {
for _ in 0..=(rng.gen::<u64>() % 20 + 1) {
let key = rng.gen::<u64>();
set.insert(&key);
}
assert!(!set.to_vec().is_empty());
set.clear();
assert!(set.to_vec().is_empty());
}
}
#[test]
pub fn test_iter() {
let mut set = UnorderedSet::new(b"s");
let mut rng = rand_xorshift::XorShiftRng::seed_from_u64(4);
let mut keys = HashSet::new();
for _ in 0..500 {
let key = rng.gen::<u64>();
keys.insert(key);
set.insert(&key);
}
let actual: HashSet<u64> = set.iter().collect();
assert_eq!(actual, keys);
}
#[test]
pub fn test_extend() {
let mut set = UnorderedSet::new(b"s");
let mut rng = rand_xorshift::XorShiftRng::seed_from_u64(4);
let mut keys = HashSet::new();
for _ in 0..100 {
let key = rng.gen::<u64>();
keys.insert(key);
set.insert(&key);
}
for _ in 0..10 {
let mut tmp = vec![];
for _ in 0..=(rng.gen::<u64>() % 20 + 1) {
let key = rng.gen::<u64>();
tmp.push(key);
}
keys.extend(tmp.iter().cloned());
set.extend(tmp.iter().cloned());
}
let actual: HashSet<u64> = set.iter().collect();
assert_eq!(actual, keys);
}
#[test]
fn test_debug() {
let mut set = UnorderedSet::new(b"m");
set.insert(&1u64);
set.insert(&3u64);
set.insert(&2u64);
if cfg!(feature = "expensive-debug") {
assert_eq!(
format!("{:?}", set),
"UnorderedSet { element_index_prefix: [109, 105], elements: [1, 3, 2] }"
);
} else {
assert_eq!(
format!("{:?}", set),
"UnorderedSet { element_index_prefix: [109, 105], elements: Vector { len: 3, prefix: [109, 101] } }"
);
}
}
}