use alloc::vec::Vec;
use core::{
borrow::Borrow,
fmt::{self, Debug},
iter::FusedIterator,
slice::Iter,
};
/// An unordered collection of unique values stored in a plain `Vec`.
///
/// Lookups compare elements with `Eq` by scanning the storage, so most
/// operations are linear in the set's size; in exchange `T` needs neither
/// `Hash` nor `Ord`. Uniqueness is maintained by the inserting methods
/// (`insert`, `replace`, `extend`, ...), not by the storage itself; the `serde`
/// derive reads the vector as-is and therefore does not deduplicate.
#[derive(Clone)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[cfg_attr(feature = "serde", serde(transparent))]
pub struct Set<T> {
    /// Elements in storage order; expected to hold no two equal values.
    backing: Vec<T>,
}

/// Set equality: both sides have the same number of elements and each
/// element of one side is present in the other, regardless of storage order.
impl<T: PartialEq> PartialEq for Set<T> {
    fn eq(&self, other: &Self) -> bool {
        // Checking containment in both directions keeps the relation symmetric
        // even if the no-duplicates invariant was broken (e.g. through `&mut`
        // iteration), where a one-sided check could disagree with its mirror.
        self.backing.len() == other.backing.len()
            && self.backing.iter().all(|item| other.backing.contains(item))
            && other.backing.iter().all(|item| self.backing.contains(item))
    }
}

impl<T: Eq> Eq for Set<T> {}
impl<T> Default for Set<T> {
fn default() -> Self {
Self {
backing: Vec::default(),
}
}
}
impl<T: Eq> Set<T> {
pub fn new() -> Self {
Self {
backing: Vec::new(),
}
}
pub fn with_capacity(capacity: usize) -> Self {
Self {
backing: Vec::with_capacity(capacity),
}
}
pub fn capacity(&self) -> usize {
self.backing.capacity()
}
pub fn clear(&mut self) {
self.backing.clear()
}
pub fn contains<Q>(&self, value: &Q) -> bool
where
T: Borrow<Q>,
Q: Eq + ?Sized,
{
self.backing.iter().any(|v| value.eq(v.borrow()))
}
pub fn difference<'a>(&'a self, other: &'a Self) -> Difference<'a, T> {
Difference {
iter: self.iter(),
other,
}
}
pub fn drain(&mut self) -> alloc::vec::Drain<T> {
self.backing.drain(..)
}
pub fn get<Q>(&self, value: &Q) -> Option<&T>
where
T: Borrow<Q>,
Q: Eq + ?Sized,
{
self.backing.iter().find(|v| value.eq((*v).borrow()))
}
pub fn get_or_insert(&mut self, value: T) -> &T {
let self_ptr = self as *mut Self;
if let Some(value) = self.get(&value) {
return value;
}
unsafe { (*self_ptr).backing.push(value) };
self.backing.last().unwrap()
}
pub fn get_or_insert_with<Q>(&mut self, value: &Q, f: impl FnOnce(&Q) -> T) -> &T
where
T: Borrow<Q>,
Q: Eq + ?Sized,
{
let self_ptr = self as *mut Self;
if let Some(value) = self.get(value) {
return value;
}
unsafe { (*self_ptr).backing.push(f(value)) };
self.backing.last().unwrap()
}
pub fn insert(&mut self, value: T) -> bool {
!self.backing.iter().any(|v| *v == value) && {
self.backing.push(value);
true
}
}
pub fn intersection<'a>(&'a self, other: &'a Self) -> Intersection<'a, T> {
Intersection {
iter: self.iter(),
other,
}
}
pub fn is_disjoint<'a>(&'a self, other: &'a Self) -> bool {
self.intersection(other).count() == 0
}
pub fn is_empty(&self) -> bool {
self.backing.is_empty()
}
pub fn is_subset(&self, other: &Self) -> bool {
self.len() <= other.len() && self.difference(other).count() == 0
}
pub fn is_superset(&self, other: &Self) -> bool {
other.is_subset(self)
}
pub fn iter(&self) -> Iter<T> {
self.backing.iter()
}
pub fn len(&self) -> usize {
self.backing.len()
}
pub fn remove<Q>(&mut self, value: &Q) -> bool
where
T: Borrow<Q>,
Q: Eq + ?Sized,
{
self.take(value).is_some()
}
pub fn replace(&mut self, value: T) -> Option<T> {
match self.backing.iter_mut().find(|v| **v == value) {
Some(v) => Some(core::mem::replace(v, value)),
None => {
self.backing.push(value);
None
}
}
}
pub fn reserve(&mut self, additional: usize) {
self.backing.reserve(additional)
}
pub fn retain(&mut self, f: impl FnMut(&T) -> bool) {
self.backing.retain(f);
}
pub fn shrink_to_fit(&mut self) {
self.backing.shrink_to_fit()
}
pub fn symmetric_difference<'a>(&'a self, other: &'a Self) -> SymmetricDifference<'a, T> {
SymmetricDifference {
iter: self.difference(other).chain(other.difference(self)),
}
}
pub fn take<Q>(&mut self, value: &Q) -> Option<T>
where
T: Borrow<Q>,
Q: Eq + ?Sized,
{
self.backing
.iter()
.position(|v| value.eq(v.borrow()))
.map(|pos| self.backing.swap_remove(pos))
}
pub fn union<'a>(&'a self, other: &'a Self) -> Union<'a, T> {
Union {
iter: self.iter().chain(other.difference(self)),
}
}
pub fn try_reserve(
&mut self,
additional: usize,
) -> Result<(), alloc::collections::TryReserveError> {
self.backing.try_reserve(additional)
}
pub fn shrink_to(&mut self, min_capacity: usize) {
self.backing.shrink_to(min_capacity)
}
}
impl<T: Debug> fmt::Debug for Set<T> {
    /// Formats the set as `{a, b, ...}`, listing elements in storage order.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_set();
        out.entries(&self.backing);
        out.finish()
    }
}
impl<'a, T> IntoIterator for &'a Set<T> {
    type Item = &'a T;
    type IntoIter = core::slice::Iter<'a, T>;

    /// Borrows each element in storage order.
    fn into_iter(self) -> Self::IntoIter {
        self.backing.as_slice().iter()
    }
}
impl<'a, T> IntoIterator for &'a mut Set<T> {
    type Item = &'a mut T;
    type IntoIter = core::slice::IterMut<'a, T>;

    /// Mutably borrows each element in storage order.
    ///
    /// The set does not re-check uniqueness afterwards: changing elements so
    /// that two of them become equal leaves duplicates in the storage.
    fn into_iter(self) -> Self::IntoIter {
        self.backing.as_mut_slice().iter_mut()
    }
}
impl<T> IntoIterator for Set<T> {
type Item = T;
type IntoIter = alloc::vec::IntoIter<T>;
fn into_iter(self) -> <Self as IntoIterator>::IntoIter {
self.backing.into_iter()
}
}
impl<T: Eq> FromIterator<T> for Set<T> {
    /// Collects an iterator into a set, keeping the first of any equal items.
    ///
    /// Storage is pre-sized from the iterator's lower size bound (which may
    /// overshoot when the input has duplicates) and shrunk to fit afterwards.
    fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
        let iter = iter.into_iter();
        // `with_capacity(0)` does not allocate, so an unknown or zero lower
        // bound needs no separate `new()` branch.
        let (lower, _) = iter.size_hint();
        let mut this = Self::with_capacity(lower);
        this.extend(iter);
        this.shrink_to_fit();
        this
    }
}
impl<T: Eq> Extend<T> for Set<T> {
    /// Inserts each item in turn; items equal to one already present are dropped.
    fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
        iter.into_iter().for_each(|item| {
            self.insert(item);
        });
    }
}
impl<'a, T: 'a + Copy + Eq> Extend<&'a T> for Set<T> {
    /// Copies each referenced item into the set, skipping ones already present.
    fn extend<I: IntoIterator<Item = &'a T>>(&mut self, iter: I) {
        // Delegate to the by-value `Extend<T>` impl, which inserts one by one.
        self.extend(iter.into_iter().copied());
    }
}
impl<V: Eq, T: Into<Vec<V>>> From<T> for Set<V> {
    /// Builds a set from anything convertible into a `Vec`, keeping the first of
    /// any equal values and trimming unused capacity afterwards.
    fn from(values: T) -> Self {
        let items: Vec<V> = values.into();
        let mut set = Self::with_capacity(items.len());
        for item in items {
            set.insert(item);
        }
        set.shrink_to_fit();
        set
    }
}
impl<T: Clone + Eq> core::ops::BitOr<&Set<T>> for &Set<T> {
type Output = Set<T>;
fn bitor(self, rhs: &Set<T>) -> Set<T> {
self.union(rhs).cloned().collect()
}
}
impl<T: Clone + Eq> core::ops::BitAnd<&Set<T>> for &Set<T> {
type Output = Set<T>;
fn bitand(self, rhs: &Set<T>) -> Set<T> {
self.intersection(rhs).cloned().collect()
}
}
impl<T: Clone + Eq> core::ops::BitXor<&Set<T>> for &Set<T> {
type Output = Set<T>;
fn bitxor(self, rhs: &Set<T>) -> Set<T> {
self.symmetric_difference(rhs).cloned().collect()
}
}
impl<T: Clone + Eq> core::ops::Sub<&Set<T>> for &Set<T> {
type Output = Set<T>;
fn sub(self, rhs: &Set<T>) -> Set<T> {
self.difference(rhs).cloned().collect()
}
}
/// Iterator over the elements of one [`Set`] that are absent from another.
///
/// Returned by `Set::difference`. Elements are yielded in the first set's
/// storage order; each one is checked against the other set with
/// `Set::contains`.
#[derive(Debug)]
pub struct Difference<'a, T> {
    iter: core::slice::Iter<'a, T>,
    other: &'a Set<T>,
}

// Hand-written rather than derived: `#[derive(Clone)]` would demand
// `T: Clone`, yet cloning only copies a slice iterator and a reference.
impl<T> Clone for Difference<'_, T> {
    fn clone(&self) -> Self {
        Difference {
            iter: self.iter.clone(),
            other: self.other,
        }
    }
}

impl<'a, T> Iterator for Difference<'a, T>
where
    T: Eq,
{
    type Item = &'a T;

    fn next(&mut self) -> Option<Self::Item> {
        // Copy the reference out so the closure does not borrow `self`.
        let other = self.other;
        self.iter.find(|elt| !other.contains(*elt))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Any remaining element may be filtered out, so only the upper bound
        // of the underlying iterator carries over.
        (0, self.iter.size_hint().1)
    }
}

impl<T> DoubleEndedIterator for Difference<'_, T>
where
    T: Eq,
{
    fn next_back(&mut self) -> Option<Self::Item> {
        let other = self.other;
        self.iter.rfind(|elt| !other.contains(*elt))
    }
}

impl<T> FusedIterator for Difference<'_, T> where T: Eq {}

/// Iterator over the elements of one [`Set`] that are also in another.
///
/// Returned by `Set::intersection`. Elements are yielded in the first set's
/// storage order; each one is checked against the other set with
/// `Set::contains`.
#[derive(Debug)]
pub struct Intersection<'a, T> {
    iter: core::slice::Iter<'a, T>,
    other: &'a Set<T>,
}

// Hand-written for the same reason as `Difference`: no `T: Clone` needed.
impl<T> Clone for Intersection<'_, T> {
    fn clone(&self) -> Self {
        Intersection {
            iter: self.iter.clone(),
            other: self.other,
        }
    }
}

impl<'a, T> Iterator for Intersection<'a, T>
where
    T: Eq,
{
    type Item = &'a T;

    fn next(&mut self) -> Option<Self::Item> {
        let other = self.other;
        self.iter.find(|elt| other.contains(*elt))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        (0, self.iter.size_hint().1)
    }
}

impl<T> DoubleEndedIterator for Intersection<'_, T>
where
    T: Eq,
{
    fn next_back(&mut self) -> Option<Self::Item> {
        let other = self.other;
        self.iter.rfind(|elt| other.contains(*elt))
    }
}

impl<T> FusedIterator for Intersection<'_, T> where T: Eq {}

/// Iterator over the elements in exactly one of two [`Set`]s.
///
/// Returned by `Set::symmetric_difference`: first the elements only in the
/// first set, then those only in the second.
#[derive(Debug)]
pub struct SymmetricDifference<'a, T> {
    iter: core::iter::Chain<Difference<'a, T>, Difference<'a, T>>,
}

impl<T> Clone for SymmetricDifference<'_, T> {
    fn clone(&self) -> Self {
        SymmetricDifference {
            iter: self.iter.clone(),
        }
    }
}

impl<'a, T> Iterator for SymmetricDifference<'a, T>
where
    T: Eq,
{
    type Item = &'a T;

    fn next(&mut self) -> Option<Self::Item> {
        self.iter.next()
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Both halves are filtering iterators, so no element is guaranteed.
        let (_, upper) = self.iter.size_hint();
        (0, upper)
    }
}

impl<T> DoubleEndedIterator for SymmetricDifference<'_, T>
where
    T: Eq,
{
    fn next_back(&mut self) -> Option<Self::Item> {
        self.iter.next_back()
    }
}

impl<T> FusedIterator for SymmetricDifference<'_, T> where T: Eq {}

/// Iterator over the elements in either of two [`Set`]s.
///
/// Returned by `Set::union`: every element of the first set, followed by the
/// elements of the second set that the first lacks.
#[derive(Debug)]
pub struct Union<'a, T> {
    iter: core::iter::Chain<core::slice::Iter<'a, T>, Difference<'a, T>>,
}

impl<T> Clone for Union<'_, T> {
    fn clone(&self) -> Self {
        Union {
            iter: self.iter.clone(),
        }
    }
}

impl<'a, T> Iterator for Union<'a, T>
where
    T: Eq,
{
    type Item = &'a T;

    fn next(&mut self) -> Option<Self::Item> {
        self.iter.next()
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // The first half is an unfiltered slice iterator, so its remaining
        // length is a guaranteed lower bound; the chain already reports it.
        self.iter.size_hint()
    }
}

impl<T> DoubleEndedIterator for Union<'_, T>
where
    T: Eq,
{
    fn next_back(&mut self) -> Option<Self::Item> {
        self.iter.next_back()
    }
}

impl<T> FusedIterator for Union<'_, T> where T: Eq {}
/// Unit tests for `Set`.
#[cfg(test)]
mod test_set {
    use super::Set;

    /// Sorted copy of a set's elements, for order-independent comparisons.
    fn sorted(set: &Set<i32>) -> Vec<i32> {
        let mut items: Vec<i32> = set.iter().copied().collect();
        items.sort_unstable();
        items
    }

    #[test]
    fn test_zero_capacities() {
        type S = Set<i32>;
        let s = S::new();
        assert_eq!(s.capacity(), 0);
        let s = S::default();
        assert_eq!(s.capacity(), 0);
        let s = S::with_capacity(0);
        assert_eq!(s.capacity(), 0);
        let mut s = S::new();
        s.insert(1);
        s.insert(2);
        s.remove(&1);
        s.remove(&2);
        s.shrink_to_fit();
        assert_eq!(s.capacity(), 0);
        let mut s = S::new();
        s.reserve(0);
        assert_eq!(s.capacity(), 0);
    }

    #[test]
    fn test_disjoint() {
        let mut xs = Set::new();
        let mut ys = Set::new();
        assert!(xs.is_disjoint(&ys));
        assert!(ys.is_disjoint(&xs));
        assert!(xs.insert(5));
        assert!(ys.insert(11));
        assert!(xs.is_disjoint(&ys));
        assert!(ys.is_disjoint(&xs));
        assert!(xs.insert(7));
        assert!(xs.insert(19));
        assert!(xs.insert(4));
        assert!(ys.insert(2));
        assert!(ys.insert(-11));
        assert!(xs.is_disjoint(&ys));
        assert!(ys.is_disjoint(&xs));
        assert!(ys.insert(7));
        assert!(!xs.is_disjoint(&ys));
        assert!(!ys.is_disjoint(&xs));
    }

    #[test]
    fn test_subset_and_superset() {
        let mut a = Set::new();
        assert!(a.insert(0));
        assert!(a.insert(5));
        assert!(a.insert(11));
        assert!(a.insert(7));
        let mut b = Set::new();
        assert!(b.insert(0));
        assert!(b.insert(7));
        assert!(b.insert(19));
        assert!(b.insert(250));
        assert!(b.insert(11));
        assert!(b.insert(200));
        assert!(!a.is_subset(&b));
        assert!(!a.is_superset(&b));
        assert!(!b.is_subset(&a));
        assert!(!b.is_superset(&a));
        assert!(b.insert(5));
        assert!(a.is_subset(&b));
        assert!(!a.is_superset(&b));
        assert!(!b.is_subset(&a));
        assert!(b.is_superset(&a));
    }

    #[test]
    fn test_iterate() {
        let mut a = Set::new();
        for i in 0..32 {
            assert!(a.insert(i));
        }
        assert_eq!(a.len(), 32);
        let mut observed: u32 = 0;
        for k in a {
            observed |= 1 << k;
        }
        assert_eq!(observed, 0xFFFF_FFFF);
    }

    #[test]
    fn test_iterate_ref() {
        let a = Set::from_iter(0..32);
        assert_eq!(a.len(), 32);
        let mut observed: u32 = 0;
        for &k in &a {
            observed |= 1 << k;
        }
        assert_eq!(observed, 0xFFFF_FFFF);
    }

    #[test]
    fn test_iterate_mut() {
        let mut a: Set<_> = (0..32).collect();
        assert_eq!(a.len(), 32);
        let mut observed: u32 = 0;
        for &mut k in &mut a {
            observed |= 1 << k;
        }
        assert_eq!(observed, 0xFFFF_FFFF);
    }

    #[test]
    fn test_intersection() {
        let mut a = Set::new();
        let mut b = Set::new();
        assert!(a.intersection(&b).next().is_none());
        assert!(a.insert(11));
        assert!(a.insert(1));
        assert!(a.insert(3));
        assert!(a.insert(77));
        assert!(a.insert(103));
        assert!(a.insert(5));
        assert!(a.insert(-5));
        assert!(b.insert(2));
        assert!(b.insert(11));
        assert!(b.insert(77));
        assert!(b.insert(-9));
        assert!(b.insert(-42));
        assert!(b.insert(5));
        assert!(b.insert(3));
        let mut i = 0;
        let expected = [3, 5, 11, 77];
        for x in a.intersection(&b) {
            assert!(expected.contains(x));
            i += 1
        }
        assert_eq!(i, expected.len());
        assert!(a.insert(9));
        i = 0;
        for x in a.intersection(&b) {
            assert!(expected.contains(x));
            i += 1
        }
        assert_eq!(i, expected.len());
        i = 0;
        for x in b.intersection(&a) {
            assert!(expected.contains(x));
            i += 1
        }
        assert_eq!(i, expected.len());
    }

    #[test]
    fn test_difference() {
        let mut a = Set::new();
        let mut b = Set::new();
        assert!(a.insert(1));
        assert!(a.insert(3));
        assert!(a.insert(5));
        assert!(a.insert(9));
        assert!(a.insert(11));
        assert!(b.insert(3));
        assert!(b.insert(9));
        let mut i = 0;
        let expected = [1, 5, 11];
        for x in a.difference(&b) {
            assert!(expected.contains(x));
            i += 1
        }
        assert_eq!(i, expected.len());
    }

    #[test]
    fn test_symmetric_difference() {
        let mut a = Set::new();
        let mut b = Set::new();
        assert!(a.insert(1));
        assert!(a.insert(3));
        assert!(a.insert(5));
        assert!(a.insert(9));
        assert!(a.insert(11));
        assert!(b.insert(-2));
        assert!(b.insert(3));
        assert!(b.insert(9));
        assert!(b.insert(14));
        assert!(b.insert(22));
        let mut i = 0;
        let expected = [-2, 1, 5, 11, 14, 22];
        for x in a.symmetric_difference(&b) {
            assert!(expected.contains(x));
            i += 1
        }
        assert_eq!(i, expected.len());
    }

    #[test]
    fn test_union() {
        let mut a = Set::new();
        let mut b = Set::new();
        assert!(a.union(&b).next().is_none());
        assert!(b.union(&a).next().is_none());
        assert!(a.insert(1));
        assert!(a.insert(3));
        assert!(a.insert(11));
        assert!(a.insert(16));
        assert!(a.insert(19));
        assert!(a.insert(24));
        assert!(b.insert(-2));
        assert!(b.insert(1));
        assert!(b.insert(5));
        assert!(b.insert(9));
        assert!(b.insert(13));
        assert!(b.insert(19));
        let mut i = 0;
        let expected = [-2, 1, 3, 5, 9, 11, 13, 16, 19, 24];
        for x in a.union(&b) {
            assert!(expected.contains(x));
            i += 1
        }
        assert_eq!(i, expected.len());
        assert!(a.insert(9));
        assert!(a.insert(5));
        i = 0;
        for x in a.union(&b) {
            assert!(expected.contains(x));
            i += 1
        }
        assert_eq!(i, expected.len());
        i = 0;
        for x in b.union(&a) {
            assert!(expected.contains(x));
            i += 1
        }
        assert_eq!(i, expected.len());
    }

    #[test]
    fn test_from_iter() {
        let xs = [1, 2, 2, 3, 4, 5, 6, 7, 8, 9];
        let set: Set<_> = xs.iter().cloned().collect();
        for x in &xs {
            assert!(set.contains(x));
        }
        assert_eq!(set.iter().len(), xs.len() - 1);
    }

    #[test]
    fn test_move_iter() {
        let hs = {
            let mut hs = Set::new();
            hs.insert('a');
            hs.insert('b');
            hs
        };
        let v = hs.into_iter().collect::<Vec<char>>();
        assert!(v == ['a', 'b'] || v == ['b', 'a']);
    }

    #[test]
    fn test_eq() {
        let mut s1 = Set::new();
        s1.insert(1);
        s1.insert(2);
        s1.insert(3);
        let mut s2 = Set::new();
        s2.insert(1);
        s2.insert(2);
        assert!(s1 != s2);
        s2.insert(3);
        assert_eq!(s1, s2);
    }

    #[test]
    fn test_show() {
        let mut set = Set::new();
        let empty = Set::<i32>::new();
        set.insert(1);
        set.insert(2);
        let set_str = format!("{:?}", set);
        assert!(set_str == "{1, 2}" || set_str == "{2, 1}");
        assert_eq!(format!("{:?}", empty), "{}");
    }

    #[test]
    fn test_trivial_drain() {
        let mut s = Set::<i32>::new();
        for _ in s.drain() {}
        assert!(s.is_empty());
        drop(s);
        let mut s = Set::<i32>::new();
        drop(s.drain());
        assert!(s.is_empty());
    }

    #[test]
    fn test_drain() {
        let mut s: Set<_> = (1..100).collect();
        for _ in 0..20 {
            assert_eq!(s.len(), 99);
            {
                let mut last_i = 0;
                let mut d = s.drain();
                for (i, x) in d.by_ref().take(50).enumerate() {
                    last_i = i;
                    assert!(x != 0);
                }
                assert_eq!(last_i, 49);
            }
            assert_eq!(s.iter().next(), None, "s should be empty!");
            s.extend(1..100);
        }
    }

    #[test]
    fn test_replace() {
        #[derive(Debug)]
        struct Foo(&'static str, i32);
        impl PartialEq for Foo {
            fn eq(&self, other: &Self) -> bool {
                self.0 == other.0
            }
        }
        impl Eq for Foo {}
        let mut s = Set::new();
        assert_eq!(s.replace(Foo("a", 1)), None);
        assert_eq!(s.len(), 1);
        assert_eq!(s.replace(Foo("a", 2)), Some(Foo("a", 1)));
        assert_eq!(s.len(), 1);
        let mut it = s.iter();
        assert_eq!(it.next(), Some(&Foo("a", 2)));
        assert_eq!(it.next(), None);
    }

    #[test]
    fn test_extend_ref() {
        let mut a = Set::new();
        a.insert(1);
        a.extend(&[2, 3, 4]);
        assert_eq!(a.len(), 4);
        assert!(a.contains(&1));
        assert!(a.contains(&2));
        assert!(a.contains(&3));
        assert!(a.contains(&4));
        let mut b = Set::new();
        b.insert(5);
        b.insert(6);
        a.extend(&b);
        assert_eq!(a.len(), 6);
        assert!(a.contains(&1));
        assert!(a.contains(&2));
        assert!(a.contains(&3));
        assert!(a.contains(&4));
        assert!(a.contains(&5));
        assert!(a.contains(&6));
    }

    #[test]
    fn test_retain() {
        let xs = [1, 2, 3, 4, 5, 6];
        let mut set: Set<i32> = xs.iter().cloned().collect();
        set.retain(|&k| k % 2 == 0);
        assert_eq!(set.len(), 3);
        assert!(set.contains(&2));
        assert!(set.contains(&4));
        assert!(set.contains(&6));
    }

    #[test]
    fn test_default() {
        struct NoDefault;
        let _: Vec<NoDefault> = Default::default();
        let _: Set<NoDefault> = Default::default();
    }

    #[test]
    fn test_from_into_vec() {
        #[allow(clippy::useless_conversion)]
        let _: Vec<()> = vec![()].into();
        let _: Set<()> = vec![()].into();
        let _: Vec<()> = [()].into();
        let _: Set<()> = [()].into();
        let expected: Set<char> = ['a', 'b'].iter().copied().collect();
        let actual: Set<char> = ['a', 'b', 'a'].into();
        assert_eq!(expected, actual, "Values should be de-duped");
    }

    #[test]
    fn test_get_or_insert_keeps_existing() {
        #[derive(Debug)]
        struct Tagged(&'static str, i32);
        impl PartialEq for Tagged {
            fn eq(&self, other: &Self) -> bool {
                self.0 == other.0
            }
        }
        impl Eq for Tagged {}
        let mut s = Set::new();
        assert_eq!(s.get_or_insert(Tagged("a", 1)).1, 1);
        assert_eq!(s.get_or_insert(Tagged("a", 2)).1, 1);
        assert_eq!(s.get_or_insert(Tagged("b", 3)).1, 3);
        assert_eq!(s.len(), 2);
    }

    #[test]
    fn test_get_or_insert_with() {
        let mut s: Set<String> = Set::new();
        let mut calls = 0;
        assert_eq!(
            s.get_or_insert_with("x", |k| {
                calls += 1;
                k.to_owned()
            }),
            "x"
        );
        assert_eq!(
            s.get_or_insert_with("x", |k| {
                calls += 1;
                k.to_owned()
            }),
            "x"
        );
        assert_eq!(calls, 1);
        assert_eq!(s.len(), 1);
    }

    #[test]
    fn test_take_and_remove() {
        let mut s: Set<_> = (1..=4).collect();
        assert_eq!(s.take(&2), Some(2));
        assert_eq!(s.take(&2), None);
        assert!(s.remove(&4));
        assert!(!s.remove(&4));
        assert_eq!(s.len(), 2);
        assert!(s.contains(&1));
        assert!(s.contains(&3));
    }

    #[test]
    fn test_operators() {
        let a: Set<i32> = [1, 2, 3].into();
        let b: Set<i32> = [2, 3, 4].into();
        assert_eq!(sorted(&(&a | &b)), [1, 2, 3, 4]);
        assert_eq!(sorted(&(&a & &b)), [2, 3]);
        assert_eq!(sorted(&(&a ^ &b)), [1, 4]);
        assert_eq!(sorted(&(&a - &b)), [1]);
    }
}