use std::{
borrow::Borrow,
fmt::{self, Debug},
};
use oasis_borsh::{BorshDeserialize, BorshSerialize};
/// A set of unique elements stored in a plain `Vec`.
///
/// Membership is decided with `Eq` alone (no hashing or ordering), so
/// lookups, insertions and removals are linear scans of the backing vector.
/// Elements are kept in the order they were first inserted, which is the
/// order `iter` and `drain` yield them; equality, however, ignores order.
#[derive(Clone, Default)]
pub struct Set<T> {
    // Invariant: no two elements compare equal under `Eq`.
    backing: Vec<T>,
}

/// Two sets are equal when they hold the same elements, in any order.
impl<T: PartialEq> PartialEq for Set<T> {
    fn eq(&self, other: &Self) -> bool {
        // Neither side holds duplicates, so equal lengths plus one-way
        // containment proves both sets hold exactly the same elements.
        self.backing.len() == other.backing.len()
            && self.backing.iter().all(|v| other.backing.contains(v))
    }
}

impl<T: Eq> Eq for Set<T> {}

impl<T: Eq> Set<T> {
    /// Creates an empty set. Does not allocate.
    pub fn new() -> Self {
        Self {
            backing: Vec::new(),
        }
    }

    /// Creates an empty set with space for at least `capacity` elements.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            backing: Vec::with_capacity(capacity),
        }
    }

    /// Returns how many elements the set can hold without reallocating.
    pub fn capacity(&self) -> usize {
        self.backing.capacity()
    }

    /// Removes all elements; the allocated capacity is kept.
    pub fn clear(&mut self) {
        self.backing.clear()
    }

    /// Returns `true` if the set holds an element equal to `value`.
    ///
    /// `value` may be any borrowed form of `T` (e.g. `&str` for `Set<String>`).
    pub fn contains<Q: ?Sized>(&self, value: &Q) -> bool
    where
        T: Borrow<Q>,
        Q: Eq,
    {
        self.backing.iter().any(|v| value.eq(v.borrow()))
    }

    /// Iterates over the elements of `self` that are not in `other`.
    pub fn difference<'a>(
        &'a self,
        other: &'a Self,
    ) -> impl Iterator<Item = &'a T> + DoubleEndedIterator {
        self.backing.iter().filter(move |v| !other.contains(*v))
    }

    /// Removes every element, yielding them by value in storage order.
    pub fn drain(&mut self) -> std::vec::Drain<T> {
        self.backing.drain(..)
    }

    /// Returns a reference to the stored element equal to `value`, if any.
    pub fn get<Q: ?Sized>(&self, value: &Q) -> Option<&T>
    where
        T: Borrow<Q>,
        Q: Eq,
    {
        self.backing.iter().find(|v| value.eq((*v).borrow()))
    }

    /// Returns the stored element equal to `value`, inserting `value` first
    /// when no such element exists.
    pub fn get_or_insert(&mut self, value: T) -> &T {
        // Look up an index rather than a reference so that no borrow of
        // `backing` is still alive when we push.
        let index = match self.backing.iter().position(|v| *v == value) {
            Some(index) => index,
            None => {
                self.backing.push(value);
                self.backing.len() - 1
            }
        };
        &self.backing[index]
    }

    /// Returns the stored element equal to `value`. If there is none, inserts
    /// `f(value)` and returns that instead. `f` is called at most once, and
    /// only when no equal element is present.
    ///
    /// `f(value)` is expected to produce an element equal to `value`;
    /// otherwise later lookups by `value` will not find it.
    pub fn get_or_insert_with<Q: ?Sized>(&mut self, value: &Q, f: impl FnOnce(&Q) -> T) -> &T
    where
        T: Borrow<Q>,
        Q: Eq,
    {
        // Index-based for the same borrow reason as `get_or_insert`.
        let index = match self.backing.iter().position(|v| value.eq(v.borrow())) {
            Some(index) => index,
            None => {
                self.backing.push(f(value));
                self.backing.len() - 1
            }
        };
        &self.backing[index]
    }

    /// Adds `value` to the set.
    ///
    /// Returns `true` if it was inserted, or `false` (dropping `value`) when an
    /// equal element was already present.
    pub fn insert(&mut self, value: T) -> bool {
        if self.backing.contains(&value) {
            return false;
        }
        self.backing.push(value);
        true
    }

    /// Iterates over the elements of `self` that are also in `other`.
    pub fn intersection<'a>(
        &'a self,
        other: &'a Self,
    ) -> impl Iterator<Item = &'a T> + DoubleEndedIterator<Item = &'a T> {
        self.backing.iter().filter(move |v| other.contains(*v))
    }

    /// Returns `true` if `self` and `other` share no elements.
    pub fn is_disjoint<'a>(&'a self, other: &'a Self) -> bool {
        // Stops at the first shared element instead of counting all of them.
        !self.backing.iter().any(|v| other.contains(v))
    }

    /// Returns `true` if the set has no elements.
    pub fn is_empty(&self) -> bool {
        self.backing.is_empty()
    }

    /// Returns `true` if every element of `self` is also in `other`.
    pub fn is_subset(&self, other: &Self) -> bool {
        // Cheap early exit: a larger set cannot be a subset of a smaller one.
        self.len() <= other.len() && self.backing.iter().all(|v| other.contains(v))
    }

    /// Returns `true` if every element of `other` is also in `self`.
    pub fn is_superset(&self, other: &Self) -> bool {
        other.is_subset(self)
    }

    /// Iterates over the elements by reference, in storage order.
    pub fn iter(&self) -> impl Iterator<Item = &T> + DoubleEndedIterator + ExactSizeIterator {
        self.backing.iter()
    }

    /// Returns the number of elements in the set.
    pub fn len(&self) -> usize {
        self.backing.len()
    }

    /// Removes the element equal to `value`; returns `true` if one was present.
    pub fn remove<Q: ?Sized>(&mut self, value: &Q) -> bool
    where
        T: Borrow<Q>,
        Q: Eq,
    {
        self.take(value).is_some()
    }

    /// Inserts `value`. If an equal element was already stored, it is
    /// replaced in place and returned; otherwise `value` is appended and
    /// `None` is returned.
    pub fn replace(&mut self, value: T) -> Option<T> {
        match self.backing.iter_mut().find(|v| **v == value) {
            Some(slot) => Some(std::mem::replace(slot, value)),
            None => {
                self.backing.push(value);
                None
            }
        }
    }

    /// Reserves capacity for at least `additional` more elements.
    pub fn reserve(&mut self, additional: usize) {
        self.backing.reserve(additional)
    }

    /// Keeps only the elements for which `f` returns `true`.
    ///
    /// `f` is called exactly once per element, in storage order, and the
    /// surviving elements keep their relative order.
    pub fn retain(&mut self, f: impl FnMut(&T) -> bool) {
        self.backing.retain(f)
    }

    /// Shrinks the capacity. It stays at least `min_capacity` and at least
    /// `len()`.
    pub fn shrink_to(&mut self, min_capacity: usize) {
        self.backing.shrink_to(min_capacity)
    }

    /// Shrinks the capacity as close to `len()` as the allocator allows.
    pub fn shrink_to_fit(&mut self) {
        self.backing.shrink_to_fit()
    }

    /// Iterates over the elements that are in exactly one of `self` and
    /// `other`. Those only in `self` come first, then those only in `other`.
    pub fn symmetric_difference<'a>(
        &'a self,
        other: &'a Self,
    ) -> impl Iterator<Item = &'a T> + DoubleEndedIterator {
        self.difference(other).chain(other.difference(self))
    }

    /// Removes and returns the element equal to `value`, if any.
    ///
    /// Later elements shift down, so storage order is preserved.
    pub fn take<Q: ?Sized>(&mut self, value: &Q) -> Option<T>
    where
        T: Borrow<Q>,
        Q: Eq,
    {
        let index = self.backing.iter().position(|v| value.eq(v.borrow()))?;
        Some(self.backing.remove(index))
    }

    /// Tries to reserve capacity for at least `additional` more elements.
    ///
    /// # Errors
    /// Returns an error if the capacity would overflow or the allocation fails.
    pub fn try_reserve(
        &mut self,
        additional: usize,
    ) -> Result<(), std::collections::TryReserveError> {
        self.backing.try_reserve(additional)
    }

    /// Iterates over every element of `self`, then over the elements of
    /// `other` that are not in `self`. Each distinct element appears once.
    pub fn union<'a>(
        &'a self,
        other: &'a Self,
    ) -> impl Iterator<Item = &'a T> + DoubleEndedIterator {
        self.iter().chain(other.difference(self))
    }
}
impl<T: Debug> fmt::Debug for Set<T> {
    /// Renders the set as `{a, b, ...}`, listing elements in storage order.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_set();
        builder.entries(&self.backing);
        builder.finish()
    }
}
impl<'a, T> IntoIterator for &'a Set<T> {
    type Item = &'a T;
    type IntoIter = core::slice::Iter<'a, T>;

    /// Borrows the set and yields each element by reference, in storage order.
    fn into_iter(self) -> Self::IntoIter {
        self.backing.as_slice().iter()
    }
}
impl<T> IntoIterator for Set<T> {
type Item = T;
type IntoIter = std::vec::IntoIter<T>;
fn into_iter(self) -> <Self as IntoIterator>::IntoIter {
self.backing.into_iter()
}
}
impl<T: Eq> core::iter::FromIterator<T> for Set<T> {
    /// Builds a set from `iter`. When several items are equal, only the
    /// first one is kept and the later ones are dropped.
    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
        iter.into_iter().fold(Self::new(), |mut acc, item| {
            acc.insert(item);
            acc
        })
    }
}
impl<T: Eq> Extend<T> for Set<T> {
    /// Inserts each item from `iter` in turn. Items equal to an element
    /// already in the set are dropped.
    fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
        iter.into_iter().for_each(|item| {
            self.insert(item);
        });
    }
}
impl<'a, T: 'a + Copy + Eq> Extend<&'a T> for Set<T> {
    /// Copies each referenced item into the set, skipping any that are
    /// already present.
    fn extend<I: IntoIterator<Item = &'a T>>(&mut self, iter: I) {
        for &item in iter {
            self.insert(item);
        }
    }
}
impl<T: Clone + Eq> core::ops::BitOr<&Set<T>> for &Set<T> {
type Output = Set<T>;
fn bitor(self, rhs: &Set<T>) -> Set<T> {
self.union(rhs).cloned().collect()
}
}
impl<T: Clone + Eq> core::ops::BitAnd<&Set<T>> for &Set<T> {
type Output = Set<T>;
fn bitand(self, rhs: &Set<T>) -> Set<T> {
self.intersection(rhs).cloned().collect()
}
}
impl<T: Clone + Eq> core::ops::BitXor<&Set<T>> for &Set<T> {
    type Output = Set<T>;

    /// `&a ^ &b`: a new set holding clones of the elements found in exactly
    /// one of `a` and `b`.
    fn bitxor(self, rhs: &Set<T>) -> Set<T> {
        self.symmetric_difference(rhs)
            .map(T::clone)
            .collect::<Set<T>>()
    }
}
impl<T: Clone + Eq> core::ops::Sub<&Set<T>> for &Set<T> {
type Output = Set<T>;
fn sub(self, rhs: &Set<T>) -> Set<T> {
self.difference(rhs).cloned().collect()
}
}
impl<T> BorshSerialize for Set<T>
where
    T: BorshSerialize + PartialOrd,
{
    /// Writes the element count as a `u32`, then each element in ascending
    /// `partial_cmp` order.
    ///
    /// Sorting makes the encoding independent of insertion order, provided
    /// `PartialOrd` is consistent with `Eq`. Elements that compare `Equal`
    /// but are distinct keep their storage order, because the sort is stable.
    ///
    /// # Errors
    /// Returns `InvalidInput` if the set has more than `u32::MAX` elements.
    /// Also propagates any error from the writer or from serializing an
    /// element.
    ///
    /// # Panics
    /// Panics if two elements are incomparable (`partial_cmp` returns `None`).
    fn serialize<W: std::io::Write>(&self, writer: &mut W) -> Result<(), std::io::Error> {
        use std::convert::TryFrom;

        let mut sorted = self.backing.iter().collect::<Vec<_>>();
        sorted.sort_by(|a, b| {
            a.partial_cmp(b)
                .expect("Set elements must be comparable to be serialized")
        });
        // The length prefix is a `u32`. Refuse to serialize rather than
        // silently truncating the count and emitting corrupt output.
        let len = u32::try_from(sorted.len()).map_err(|_| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "Set has too many elements to serialize",
            )
        })?;
        len.serialize(writer)?;
        for item in sorted {
            item.serialize(writer)?;
        }
        Ok(())
    }
}
impl<T> BorshDeserialize for Set<T>
where
    T: BorshDeserialize + Eq,
{
    /// Reads a `u32` element count followed by that many elements.
    ///
    /// Every element equal to one already read is dropped, and the first
    /// occurrence is kept. The result therefore never contains duplicates,
    /// even when the input repeats an element non-adjacently. Like
    /// `Set::insert`, each duplicate check is a linear scan.
    ///
    /// # Errors
    /// Propagates any error from reading the length or an element.
    fn deserialize<R: std::io::Read>(reader: &mut R) -> Result<Self, std::io::Error> {
        // Upper bound on elements preallocated up front. `len` comes from the
        // input and is untrusted: a forged length must not force a huge
        // allocation before any element has actually been read.
        const MAX_PREALLOC: usize = 1024;

        let len = u32::deserialize(reader)?;
        let mut backing: Vec<T> = Vec::with_capacity((len as usize).min(MAX_PREALLOC));
        for _ in 0..len {
            let item = T::deserialize(reader)?;
            if !backing.contains(&item) {
                backing.push(item);
            }
        }
        Ok(Self { backing })
    }
}
#[cfg(test)]
// NOTE(review): this blanket allow has no stated reason. Presumably the tests
// deliberately use explicit, repetitive asserts. Consider narrowing it.
#[allow(clippy::all)]
mod test_set {
    // Unit tests for `Set`: construction and capacity, set algebra,
    // iteration, trait impls, and the Borsh round trip.
    use super::*;

    #[test]
    fn test_zero_capacities() {
        type S = Set<i32>;
        let s = S::new();
        assert_eq!(s.capacity(), 0);
        let s = S::default();
        assert_eq!(s.capacity(), 0);
        let s = S::with_capacity(0);
        assert_eq!(s.capacity(), 0);
        // Emptying a set and shrinking it releases its allocation.
        let mut s = S::new();
        s.insert(1);
        s.insert(2);
        s.remove(&1);
        s.remove(&2);
        s.shrink_to_fit();
        assert_eq!(s.capacity(), 0);
        let mut s = S::new();
        s.reserve(0);
        assert_eq!(s.capacity(), 0);
    }

    #[test]
    fn test_disjoint() {
        let mut xs = Set::new();
        let mut ys = Set::new();
        assert!(xs.is_disjoint(&ys));
        assert!(ys.is_disjoint(&xs));
        assert!(xs.insert(5));
        assert!(ys.insert(11));
        assert!(xs.is_disjoint(&ys));
        assert!(ys.is_disjoint(&xs));
        assert!(xs.insert(7));
        assert!(xs.insert(19));
        assert!(xs.insert(4));
        assert!(ys.insert(2));
        assert!(ys.insert(-11));
        assert!(xs.is_disjoint(&ys));
        assert!(ys.is_disjoint(&xs));
        assert!(ys.insert(7));
        assert!(!xs.is_disjoint(&ys));
        assert!(!ys.is_disjoint(&xs));
    }

    #[test]
    fn test_subset_and_superset() {
        let mut a = Set::new();
        assert!(a.insert(0));
        assert!(a.insert(5));
        assert!(a.insert(11));
        assert!(a.insert(7));
        let mut b = Set::new();
        assert!(b.insert(0));
        assert!(b.insert(7));
        assert!(b.insert(19));
        assert!(b.insert(250));
        assert!(b.insert(11));
        assert!(b.insert(200));
        assert!(!a.is_subset(&b));
        assert!(!a.is_superset(&b));
        assert!(!b.is_subset(&a));
        assert!(!b.is_superset(&a));
        assert!(b.insert(5));
        assert!(a.is_subset(&b));
        assert!(!a.is_superset(&b));
        assert!(!b.is_subset(&a));
        assert!(b.is_superset(&a));
    }

    #[test]
    fn test_iterate() {
        let mut a = Set::new();
        for i in 0..32 {
            assert!(a.insert(i));
        }
        // Each element sets its own bit, so all 32 bits set means every
        // element was yielded.
        let mut observed: u32 = 0;
        for k in &a {
            observed |= 1 << *k;
        }
        assert_eq!(observed, 0xFFFF_FFFF);
    }

    #[test]
    fn test_intersection() {
        let mut a = Set::new();
        let mut b = Set::new();
        assert!(a.intersection(&b).next().is_none());
        assert!(a.insert(11));
        assert!(a.insert(1));
        assert!(a.insert(3));
        assert!(a.insert(77));
        assert!(a.insert(103));
        assert!(a.insert(5));
        assert!(a.insert(-5));
        assert!(b.insert(2));
        assert!(b.insert(11));
        assert!(b.insert(77));
        assert!(b.insert(-9));
        assert!(b.insert(-42));
        assert!(b.insert(5));
        assert!(b.insert(3));
        let mut i = 0;
        let expected = [3, 5, 11, 77];
        for x in a.intersection(&b) {
            assert!(expected.contains(x));
            i += 1
        }
        assert_eq!(i, expected.len());
        assert!(a.insert(9));
        i = 0;
        for x in a.intersection(&b) {
            assert!(expected.contains(x));
            i += 1
        }
        assert_eq!(i, expected.len());
        i = 0;
        for x in b.intersection(&a) {
            assert!(expected.contains(x));
            i += 1
        }
        assert_eq!(i, expected.len());
    }

    #[test]
    fn test_difference() {
        let mut a = Set::new();
        let mut b = Set::new();
        assert!(a.insert(1));
        assert!(a.insert(3));
        assert!(a.insert(5));
        assert!(a.insert(9));
        assert!(a.insert(11));
        assert!(b.insert(3));
        assert!(b.insert(9));
        let mut i = 0;
        let expected = [1, 5, 11];
        for x in a.difference(&b) {
            assert!(expected.contains(x));
            i += 1
        }
        assert_eq!(i, expected.len());
    }

    #[test]
    fn test_symmetric_difference() {
        let mut a = Set::new();
        let mut b = Set::new();
        assert!(a.insert(1));
        assert!(a.insert(3));
        assert!(a.insert(5));
        assert!(a.insert(9));
        assert!(a.insert(11));
        assert!(b.insert(-2));
        assert!(b.insert(3));
        assert!(b.insert(9));
        assert!(b.insert(14));
        assert!(b.insert(22));
        let mut i = 0;
        let expected = [-2, 1, 5, 11, 14, 22];
        for x in a.symmetric_difference(&b) {
            assert!(expected.contains(x));
            i += 1
        }
        assert_eq!(i, expected.len());
    }

    #[test]
    fn test_union() {
        let mut a = Set::new();
        let mut b = Set::new();
        assert!(a.union(&b).next().is_none());
        assert!(b.union(&a).next().is_none());
        assert!(a.insert(1));
        assert!(a.insert(3));
        assert!(a.insert(11));
        assert!(a.insert(16));
        assert!(a.insert(19));
        assert!(a.insert(24));
        assert!(b.insert(-2));
        assert!(b.insert(1));
        assert!(b.insert(5));
        assert!(b.insert(9));
        assert!(b.insert(13));
        assert!(b.insert(19));
        let mut i = 0;
        let expected = [-2, 1, 3, 5, 9, 11, 13, 16, 19, 24];
        for x in a.union(&b) {
            assert!(expected.contains(x));
            i += 1
        }
        assert_eq!(i, expected.len());
        // Adding elements already in `b` must not change the union.
        assert!(a.insert(9));
        assert!(a.insert(5));
        i = 0;
        for x in a.union(&b) {
            assert!(expected.contains(x));
            i += 1
        }
        assert_eq!(i, expected.len());
        i = 0;
        for x in b.union(&a) {
            assert!(expected.contains(x));
            i += 1
        }
        assert_eq!(i, expected.len());
    }

    #[test]
    fn test_from_iter() {
        let xs = [1, 2, 2, 3, 4, 5, 6, 7, 8, 9];
        let set: Set<_> = xs.iter().cloned().collect();
        for x in &xs {
            assert!(set.contains(x));
        }
        // The duplicated `2` is stored once.
        assert_eq!(set.iter().len(), xs.len() - 1);
    }

    #[test]
    fn test_move_iter() {
        let hs = {
            let mut hs = Set::new();
            hs.insert('a');
            hs.insert('b');
            hs
        };
        let v = hs.into_iter().collect::<Vec<char>>();
        assert!(v == ['a', 'b'] || v == ['b', 'a']);
    }

    #[test]
    fn test_eq() {
        let mut s1 = Set::new();
        s1.insert(1);
        s1.insert(2);
        s1.insert(3);
        let mut s2 = Set::new();
        s2.insert(1);
        s2.insert(2);
        assert!(s1 != s2);
        s2.insert(3);
        assert_eq!(s1, s2);
    }

    #[test]
    fn test_show() {
        let mut set = Set::new();
        let empty = Set::<i32>::new();
        set.insert(1);
        set.insert(2);
        let set_str = format!("{:?}", set);
        assert!(set_str == "{1, 2}" || set_str == "{2, 1}");
        assert_eq!(format!("{:?}", empty), "{}");
    }

    #[test]
    fn test_trivial_drain() {
        let mut s = Set::<i32>::new();
        for _ in s.drain() {}
        assert!(s.is_empty());
        drop(s);
        let mut s = Set::<i32>::new();
        drop(s.drain());
        assert!(s.is_empty());
    }

    #[test]
    fn test_drain() {
        let mut s: Set<_> = (1..100).collect();
        for _ in 0..20 {
            assert_eq!(s.len(), 99);
            {
                // Partially consume the drain; dropping it must still empty `s`.
                let mut last_i = 0;
                let mut d = s.drain();
                for (i, x) in d.by_ref().take(50).enumerate() {
                    last_i = i;
                    assert!(x != 0);
                }
                assert_eq!(last_i, 49);
            }
            for _ in &s {
                panic!("s should be empty!");
            }
            s.extend(1..100);
        }
    }

    #[test]
    fn test_replace() {
        // `Foo` compares equal on its name only, so `replace` can swap in a
        // value with a different payload.
        #[derive(Debug)]
        struct Foo(&'static str, i32);
        impl PartialEq for Foo {
            fn eq(&self, other: &Self) -> bool {
                self.0 == other.0
            }
        }
        impl Eq for Foo {}
        let mut s = Set::new();
        assert_eq!(s.replace(Foo("a", 1)), None);
        assert_eq!(s.len(), 1);
        assert_eq!(s.replace(Foo("a", 2)), Some(Foo("a", 1)));
        assert_eq!(s.len(), 1);
        let mut it = s.iter();
        assert_eq!(it.next(), Some(&Foo("a", 2)));
        assert_eq!(it.next(), None);
    }

    #[test]
    fn test_extend_ref() {
        let mut a = Set::new();
        a.insert(1);
        a.extend(&[2, 3, 4]);
        assert_eq!(a.len(), 4);
        assert!(a.contains(&1));
        assert!(a.contains(&2));
        assert!(a.contains(&3));
        assert!(a.contains(&4));
        let mut b = Set::new();
        b.insert(5);
        b.insert(6);
        a.extend(&b);
        assert_eq!(a.len(), 6);
        assert!(a.contains(&1));
        assert!(a.contains(&2));
        assert!(a.contains(&3));
        assert!(a.contains(&4));
        assert!(a.contains(&5));
        assert!(a.contains(&6));
    }

    #[test]
    fn test_retain() {
        let xs = [1, 2, 3, 4, 5, 6];
        let mut set: Set<i32> = xs.iter().cloned().collect();
        set.retain(|&k| k % 2 == 0);
        assert_eq!(set.len(), 3);
        assert!(set.contains(&2));
        assert!(set.contains(&4));
        assert!(set.contains(&6));
    }

    #[test]
    fn test_get_or_insert() {
        let mut s = Set::new();
        assert_eq!(*s.get_or_insert(1), 1);
        assert_eq!(*s.get_or_insert(1), 1);
        assert_eq!(*s.get_or_insert(2), 2);
        assert_eq!(s.len(), 2);
        let mut names: Set<String> = Set::new();
        assert_eq!(names.get_or_insert_with("a", |q: &str| q.to_string()), "a");
        // The element already exists, so the constructor must not run.
        assert_eq!(
            names.get_or_insert_with("a", |_: &str| -> String { unreachable!() }),
            "a"
        );
        assert_eq!(names.len(), 1);
    }

    #[test]
    fn test_borsh_roundtrip() {
        let mut s = Set::new();
        s.insert("that".to_string());
        s.insert("the other thing".to_string());
        s.insert("this".to_string());
        let s2: Set<String> = BorshDeserialize::try_from_slice(&s.try_to_vec().unwrap()).unwrap();
        assert_eq!(s2, s);
    }

    #[test]
    fn test_borsh_nonunique() {
        // A serialized `Vec` containing a duplicate still decodes to a set
        // with unique elements.
        let mut not_set = Vec::new();
        not_set.push("a".to_string());
        not_set.push("a".to_string());
        not_set.push("b".to_string());
        let s2: Set<String> =
            BorshDeserialize::try_from_slice(&not_set.try_to_vec().unwrap()).unwrap();
        assert_eq!(s2, not_set.iter().cloned().collect());
    }
}