use crate::SmallMap;
use std::borrow::Borrow;
use std::collections::{BTreeSet, HashSet};
use std::fmt::{self, Debug};
use std::hash::{BuildHasher, Hash};
use std::iter::FromIterator;
/// Read-only membership interface implemented by [`SmallSet`] and by the
/// standard-library `HashSet` / `BTreeSet`, so set operations can accept any
/// of them as the "other" set.
pub trait AnySet<T> {
    /// Returns `true` if `value` is an element of the set.
    fn contains(&self, value: &T) -> bool;

    /// Returns the number of elements in the set.
    fn len(&self) -> usize;

    /// Returns `true` when the set holds no elements.
    ///
    /// The default implementation is expressed in terms of [`AnySet::len`].
    fn is_empty(&self) -> bool {
        matches!(self.len(), 0)
    }
}
/// Lets a [`SmallSet`] be used wherever an [`AnySet`] is expected, e.g. as the
/// `other` argument of another set's `difference` or `is_subset`.
impl<T, const N: usize> AnySet<T> for SmallSet<T, N>
where
    T: Eq + Hash,
{
    fn contains(&self, value: &T) -> bool {
        // Fully qualified so this forwards to the inherent method instead of
        // recursing into the trait method.
        SmallSet::contains(self, value)
    }

    fn len(&self) -> usize {
        // `self` is already `&SmallSet`; no extra borrow needed.
        SmallSet::len(self)
    }

    fn is_empty(&self) -> bool {
        // Forward to the inherent method (which asks the backing map) rather
        // than relying on the trait's `len() == 0` default.
        SmallSet::is_empty(self)
    }
}
/// [`AnySet`] view of a standard-library `HashSet` with any hasher `S`.
impl<T, S> AnySet<T> for HashSet<T, S>
where
    T: Eq + Hash,
    S: BuildHasher,
{
    fn contains(&self, value: &T) -> bool {
        // Qualified paths make it explicit that these forward to the inherent
        // `HashSet` methods rather than recursing into the trait.
        HashSet::contains(self, value)
    }

    fn len(&self) -> usize {
        HashSet::len(self)
    }
}
/// [`AnySet`] view of a standard-library `BTreeSet`.
impl<T> AnySet<T> for BTreeSet<T>
where
    T: Ord,
{
    fn contains(&self, value: &T) -> bool {
        // Qualified paths forward to the inherent `BTreeSet` methods rather
        // than recursing into the trait.
        BTreeSet::contains(self, value)
    }

    fn len(&self) -> usize {
        BTreeSet::len(self)
    }
}
/// A set of unique `T` values backed by a [`SmallMap`] whose values are `()`.
///
/// `N` is forwarded to the map as its const capacity parameter. The tests in
/// this file expect the set to report `is_on_stack()` while it holds at most
/// `N` elements and to stop doing so once it grows past `N`.
pub struct SmallSet<T, const N: usize> {
    // Element storage: every element is a map key; the value is always `()`.
    map: SmallMap<T, (), N>,
}
impl<T, const N: usize> SmallSet<T, N>
where
T: Eq + Hash,
{
pub fn new() -> Self {
Self {
map: SmallMap::new(),
}
}
#[inline]
pub fn is_on_stack(&self) -> bool {
self.map.is_on_stack()
}
#[inline]
pub fn len(&self) -> usize {
self.map.len()
}
#[inline]
pub fn is_empty(&self) -> bool {
self.map.is_empty()
}
#[inline]
pub fn clear(&mut self) {
self.map.clear();
}
pub fn insert(&mut self, value: T) -> bool {
if self.map.contains_key(&value) {
false
} else {
self.map.insert(value, ());
true
}
}
pub fn contains<Q>(&self, value: &Q) -> bool
where
T: Borrow<Q>,
Q: Hash + Eq + ?Sized,
{
self.map.contains_key(value)
}
pub fn remove<Q>(&mut self, value: &Q) -> bool
where
T: Borrow<Q>,
Q: Hash + Eq + ?Sized,
{
self.map.remove(value).is_some()
}
pub fn retain<F>(&mut self, mut f: F)
where
F: FnMut(&T) -> bool,
{
let old_map = std::mem::replace(&mut self.map, SmallMap::new());
for (k, _) in old_map {
if f(&k) {
self.map.insert(k, ());
}
}
}
pub fn iter(&self) -> SetRefIter<'_, T> {
SetRefIter {
iter: self.map.iter(),
}
}
pub fn difference<'a, S>(&'a self, other: &'a S) -> impl Iterator<Item = &'a T>
where
S: AnySet<T>,
{
self.iter().filter(move |v| !other.contains(v))
}
pub fn intersection<'a, S>(&'a self, other: &'a S) -> impl Iterator<Item = &'a T>
where
S: AnySet<T>,
{
self.iter().filter(move |v| other.contains(v))
}
pub fn union<'a, I>(&'a self, other: I) -> impl Iterator<Item = &'a T>
where
I: IntoIterator<Item = &'a T>,
I::IntoIter: 'a,
{
self.iter()
.chain(other.into_iter().filter(move |v| !self.contains(v)))
}
pub fn is_disjoint<S>(&self, other: &S) -> bool
where
S: AnySet<T>,
{
self.iter().all(|v| !other.contains(v))
}
pub fn is_subset<S>(&self, other: &S) -> bool
where
S: AnySet<T>,
{
self.iter().all(|v| other.contains(v))
}
pub fn is_superset<'a, I>(&self, other: I) -> bool
where
T: 'a,
I: IntoIterator<Item = &'a T>,
{
other.into_iter().all(|v| self.contains(v))
}
pub fn symmetric_difference<'a>(
&'a self,
other: &'a SmallSet<T, N>,
) -> impl Iterator<Item = &'a T> {
self.difference(other).chain(other.difference(self))
}
}
impl<T, const N: usize> Clone for SmallSet<T, N>
where
T: Eq + Hash + Clone,
{
fn clone(&self) -> Self {
Self {
map: self.map.clone(),
}
}
}
impl<T: Eq + Hash, const N: usize> Default for SmallSet<T, N> {
fn default() -> Self {
Self::new()
}
}
/// Formats the set with `Formatter::debug_set`, i.e. as `{a, b, ...}`, listing
/// the elements in iteration order.
impl<T: Debug + Eq + Hash, const N: usize> Debug for SmallSet<T, N> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_set();
        for element in self.iter() {
            builder.entry(element);
        }
        builder.finish()
    }
}
impl<T: Eq + Hash, const N: usize> FromIterator<T> for SmallSet<T, N> {
fn from_iter<I: IntoIterator<Item = T>>(iter: I) -> Self {
let mut set = SmallSet::new();
for val in iter {
set.insert(val);
}
set
}
}
/// Consumes the set, yielding each element by value.
impl<T: Eq + Hash, const N: usize> IntoIterator for SmallSet<T, N> {
    type Item = T;
    type IntoIter = SmallSetIntoIter<T, N>;

    fn into_iter(self) -> Self::IntoIter {
        let iter = self.map.into_iter();
        SmallSetIntoIter { iter }
    }
}
/// Owning iterator over the elements of a [`SmallSet`], returned by its
/// by-value `into_iter`.
pub struct SmallSetIntoIter<T, const N: usize> {
    // Owning iterator of the backing map; the `()` values are discarded when
    // items are yielded.
    iter: crate::SmallMapIntoIter<T, (), N>,
}
/// Yields the set's elements by value, in the backing map's order.
impl<T, const N: usize> Iterator for SmallSetIntoIter<T, N> {
    type Item = T;

    fn next(&mut self) -> Option<Self::Item> {
        // Keep the key, drop the `()` value.
        self.iter.next().map(|(k, _)| k)
    }

    /// Forwards the backing iterator's bounds. Each map entry yields exactly
    /// one element, so those bounds apply unchanged. This lets `collect` and
    /// `extend` preallocate.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}
/// Inserts every element of `iter`, skipping values already present.
///
/// Elements are moved in and never cloned, so only `T: Eq + Hash` is
/// required (the previous `T: Clone` bound was unnecessary).
impl<T, const N: usize> Extend<T> for SmallSet<T, N>
where
    T: Eq + Hash,
{
    fn extend<I: IntoIterator<Item = T>>(&mut self, iter: I) {
        for item in iter {
            self.insert(item);
        }
    }
}
/// Inserts copies of the referenced elements, skipping values already present.
impl<'a, T, const N: usize> Extend<&'a T> for SmallSet<T, N>
where
    T: 'a + Eq + Hash + Clone + Copy,
{
    fn extend<I: IntoIterator<Item = &'a T>>(&mut self, iter: I) {
        iter.into_iter().copied().for_each(|value| {
            self.insert(value);
        });
    }
}
/// Set equality against any [`AnySet`].
///
/// This covers another `SmallSet` of any capacity, a `HashSet`, or a
/// `BTreeSet`. Two sets are equal when they have the same number of elements
/// and every element of `self` is contained in `other`. `T: Clone` is not
/// needed because no element is ever copied.
impl<T, const N: usize, S> PartialEq<S> for SmallSet<T, N>
where
    T: Eq + Hash,
    S: AnySet<T>,
{
    fn eq(&self, other: &S) -> bool {
        // For sets, equal sizes plus `self ⊆ other` imply `other ⊆ self`.
        self.len() == other.len() && self.is_subset(other)
    }
}

/// Equality is a full equivalence because elements are compared with `T: Eq`.
impl<T, const N: usize> Eq for SmallSet<T, N> where T: Eq + Hash {}
/// Borrowing iterator over the elements of a [`SmallSet`], returned by
/// `SmallSet::iter` and by `(&set).into_iter()`.
pub struct SetRefIter<'a, T> {
    // Borrowing iterator of the backing map; the map values are discarded when
    // items are yielded.
    iter: crate::SmallMapIter<'a, T, ()>,
}
/// Yields a reference to each element, in the backing map's order.
impl<'a, T: 'a> Iterator for SetRefIter<'a, T> {
    type Item = &'a T;

    fn next(&mut self) -> Option<Self::Item> {
        // Keep the key reference, drop the value.
        self.iter.next().map(|(k, _)| k)
    }

    /// Forwards the backing iterator's bounds. Each map entry yields exactly
    /// one reference, so those bounds apply unchanged.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}
impl<'a, T, const N: usize> IntoIterator for &'a SmallSet<T, N>
where
T: Eq + Hash,
{
type Item = &'a T;
type IntoIter = SetRefIter<'a, T>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for `SmallSet`: inline/heap storage transitions, set algebra,
    //! and interop with the standard-library sets through `AnySet`.

    use super::*;
    use std::collections::{BTreeSet, HashSet};

    #[test]
    fn test_set_stack_ops_basic() {
        let mut set: SmallSet<i32, 4> = SmallSet::new();
        assert!(set.is_empty());
        assert_eq!(set.len(), 0);
        assert!(set.is_on_stack());

        assert!(set.insert(10));
        assert!(set.insert(20));
        assert_eq!(set.len(), 2);
        assert!(set.contains(&10));
        assert!(!set.contains(&99));

        assert!(set.remove(&10));
        assert!(!set.contains(&10));
        assert_eq!(set.len(), 1);

        set.clear();
        assert!(set.is_empty());
        // Clearing must not move the set off inline storage.
        assert!(set.is_on_stack());
    }

    #[test]
    fn test_set_stack_duplicate_insertion() {
        let mut set: SmallSet<String, 4> = SmallSet::new();
        assert!(set.insert("A".to_string()));
        assert_eq!(set.len(), 1);
        // A second equal value is rejected and the size is unchanged.
        assert!(!set.insert("A".to_string()));
        assert_eq!(set.len(), 1);
    }

    #[test]
    fn test_set_spill_trigger_on_insert() {
        let mut set: SmallSet<i32, 2> = SmallSet::new();
        set.insert(1);
        set.insert(2);
        assert!(set.is_on_stack());
        // The (N + 1)-th element forces a spill.
        set.insert(3);
        assert!(!set.is_on_stack());
        assert_eq!(set.len(), 3);
        assert!(set.contains(&1));
        assert!(set.contains(&2));
        assert!(set.contains(&3));
    }

    #[test]
    fn test_set_any_storage_growth_on_heap() {
        let mut set: SmallSet<i32, 2> = SmallSet::new();
        for i in 0..100 {
            set.insert(i);
        }
        assert!(!set.is_on_stack());
        assert_eq!(set.len(), 100);
        assert!(set.contains(&50));
    }

    #[test]
    fn test_set_traits_iter() {
        let set: SmallSet<i32, 4> = vec![1, 2, 3].into_iter().collect();
        let collected: Vec<_> = set.iter().cloned().collect();
        assert_eq!(collected.len(), 3);
        assert!(collected.contains(&1));
        assert!(collected.contains(&2));
        assert!(collected.contains(&3));
    }

    #[test]
    fn test_set_stack_into_iter() {
        let mut set: SmallSet<i32, 4> = SmallSet::new();
        set.insert(1);
        set.insert(2);
        let vec: Vec<i32> = set.into_iter().collect();
        assert_eq!(vec.len(), 2);
        assert!(vec.contains(&1));
        assert!(vec.contains(&2));
    }

    #[test]
    fn test_set_any_storage_into_iter_heap() {
        let mut set: SmallSet<i32, 2> = SmallSet::new();
        set.insert(1);
        set.insert(2);
        set.insert(3);
        let vec: Vec<i32> = set.into_iter().collect();
        assert_eq!(vec.len(), 3);
        assert!(vec.contains(&1));
    }

    #[test]
    fn test_set_any_set_difference() {
        let a: SmallSet<i32, 4> = vec![1, 2, 3].into_iter().collect();
        let b: SmallSet<i32, 4> = vec![3, 4, 5].into_iter().collect();
        let diff: Vec<_> = a.difference(&b).cloned().collect();
        assert_eq!(diff.len(), 2);
        assert!(diff.contains(&1));
        assert!(diff.contains(&2));
        assert!(!diff.contains(&3));
    }

    #[test]
    fn test_set_any_set_intersection() {
        let a: SmallSet<i32, 4> = vec![1, 2, 3].into_iter().collect();
        let b: SmallSet<i32, 4> = vec![2, 3, 4].into_iter().collect();
        let int: Vec<_> = a.intersection(&b).cloned().collect();
        assert_eq!(int.len(), 2);
        assert!(int.contains(&2));
        assert!(int.contains(&3));
        assert!(!int.contains(&1));
    }

    #[test]
    fn test_set_any_set_union() {
        let a: SmallSet<i32, 4> = vec![1, 2].into_iter().collect();
        let b: SmallSet<i32, 4> = vec![2, 3].into_iter().collect();
        let u: Vec<_> = a.union(&b).cloned().collect();
        assert_eq!(u.len(), 3);
        assert!(u.contains(&1));
        assert!(u.contains(&2));
        assert!(u.contains(&3));
    }

    #[test]
    fn test_set_any_set_disjoint() {
        let a: SmallSet<i32, 4> = vec![1, 2].into_iter().collect();
        let b: SmallSet<i32, 4> = vec![3, 4].into_iter().collect();
        let c: SmallSet<i32, 4> = vec![2, 3].into_iter().collect();
        assert!(a.is_disjoint(&b));
        // `a` and `c` share the element 2.
        assert!(!a.is_disjoint(&c));
    }

    #[test]
    fn test_set_any_set_subset() {
        let sub: SmallSet<i32, 4> = vec![1, 2].into_iter().collect();
        let sup: SmallSet<i32, 4> = vec![1, 2, 3].into_iter().collect();
        assert!(sub.is_subset(&sup));
        assert!(!sup.is_subset(&sub));
        // The empty set is a subset of everything.
        let empty: SmallSet<i32, 4> = SmallSet::new();
        assert!(empty.is_subset(&sub));
    }

    #[test]
    fn test_set_any_set_superset() {
        let sup: SmallSet<i32, 4> = vec![1, 2, 3].into_iter().collect();
        let sub_vec = vec![1, 2];
        assert!(sup.is_superset(&sub_vec));
        // 99 is not in `sup`.
        assert!(!sup.is_superset(&vec![1, 99]));
    }

    #[test]
    fn test_set_traits_interop_hashset() {
        let small: SmallSet<i32, 4> = vec![1, 2].into_iter().collect();
        let std_set: HashSet<i32> = vec![1, 2, 3].into_iter().collect();
        assert!(small.is_subset(&std_set));
        let diff: Vec<_> = small.difference(&std_set).collect();
        assert!(diff.is_empty());
    }

    #[test]
    fn test_set_traits_interop_btreeset() {
        let small: SmallSet<i32, 4> = vec![1, 2].into_iter().collect();
        let btree: BTreeSet<i32> = vec![2, 3].into_iter().collect();
        let int: Vec<_> = small.intersection(&btree).cloned().collect();
        assert_eq!(int, vec![2]);
    }

    #[test]
    fn test_set_any_storage_retain() {
        let mut set: SmallSet<i32, 4> = vec![1, 2, 3, 4, 5].into_iter().collect();
        assert!(!set.is_on_stack());
        set.retain(|x| x % 2 == 0);
        assert_eq!(set.len(), 2);
        assert!(set.contains(&2));
        assert!(set.contains(&4));
        assert!(!set.contains(&1));
    }

    #[test]
    fn test_set_any_set_symmetric_difference() {
        let a: SmallSet<i32, 4> = vec![1, 2, 3].into_iter().collect();
        let b: SmallSet<i32, 4> = vec![3, 4, 5].into_iter().collect();
        let sym: Vec<_> = a.symmetric_difference(&b).cloned().collect();
        assert_eq!(sym.len(), 4);
        assert!(sym.contains(&1));
        assert!(sym.contains(&4));
        // 3 is in both sets, so it is excluded.
        assert!(!sym.contains(&3));
    }

    #[test]
    fn test_set_traits_equality() {
        let a: SmallSet<i32, 4> = vec![1, 2, 3].into_iter().collect();
        // Same elements in a different insertion order.
        let b: SmallSet<i32, 4> = vec![3, 2, 1].into_iter().collect();
        let c: SmallSet<i32, 2> = vec![1, 2].into_iter().collect();
        assert_eq!(a, b);
        assert_ne!(a, c);
    }

    #[test]
    fn test_set_traits_extend() {
        let mut set: SmallSet<i32, 4> = SmallSet::new();
        set.insert(1);
        let more = vec![2, 3, 4, 5];
        set.extend(more);
        assert_eq!(set.len(), 5);
        assert!(!set.is_on_stack());
        assert!(set.contains(&5));
    }

    #[test]
    fn test_set_traits_clone() {
        let mut a: SmallSet<i32, 4> = SmallSet::new();
        a.insert(1);
        let mut b = a.clone();
        b.insert(2);
        // Mutating the clone must not affect the original.
        assert!(a.contains(&1));
        assert!(!a.contains(&2));
        assert!(b.contains(&1));
        assert!(b.contains(&2));
    }

    #[test]
    fn test_set_any_storage_clone_heap() {
        let mut original: SmallSet<String, 4> = SmallSet::new();
        original.insert("A".to_string());
        original.insert("B".to_string());
        let mut copy = original.clone();
        copy.insert("C".to_string());
        copy.remove("A");
        assert!(original.contains("A"));
        assert!(!original.contains("C"));
        assert_eq!(original.len(), 2);
        assert!(!copy.contains("A"));
        assert!(copy.contains("C"));
        assert_eq!(copy.len(), 2);
    }

    #[test]
    fn test_set_traits_equality_different_capacities() {
        let mut s1: SmallSet<i32, 4> = SmallSet::new();
        let mut s2: SmallSet<i32, 8> = SmallSet::new();
        s1.insert(1);
        s1.insert(2);
        s2.insert(2);
        s2.insert(1);
        assert_eq!(s1, s2);
        s2.insert(3);
        assert_ne!(s1, s2);
    }

    #[test]
    fn test_set_traits_equality_interop() {
        let mut small: SmallSet<i32, 4> = SmallSet::new();
        small.insert(1);
        small.insert(2);

        let mut hash_set = HashSet::new();
        hash_set.insert(1);
        hash_set.insert(2);
        assert_eq!(small, hash_set);
        hash_set.insert(3);
        assert_ne!(small, hash_set);

        let mut btree_set = BTreeSet::new();
        btree_set.insert(1);
        btree_set.insert(2);
        assert_eq!(small, btree_set);
    }

    #[test]
    fn test_set_any_storage_heap_remove() {
        let mut set: SmallSet<i32, 2> = vec![1, 2, 3].into_iter().collect();
        assert!(!set.is_on_stack());
        assert!(set.remove(&2));
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_set_any_storage_clone_heap_v2() {
        let set: SmallSet<i32, 2> = vec![1, 2, 3].into_iter().collect();
        let cloned = set.clone();
        assert_eq!(cloned.len(), 3);
        assert!(!cloned.is_on_stack());
    }

    #[test]
    fn test_set_traits_debug_display() {
        let set: SmallSet<i32, 2> = vec![1].into_iter().collect();
        let debug = format!("{:?}", set);
        assert!(debug.contains("1"));
    }

    #[test]
    fn test_set_traits_any_set_impl() {
        let set: SmallSet<i32, 2> = vec![1, 2].into_iter().collect();
        let any: &dyn AnySet<i32> = &set;
        assert_eq!(any.len(), 2);
        assert!(any.contains(&1));
        assert!(!any.is_empty());
    }

    #[test]
    fn test_set_coverage_gaps() {
        let set: SmallSet<i32, 4> = Default::default();
        assert!(set.is_empty());

        // Exercises the `Extend<&T>` impl.
        let mut set: SmallSet<i32, 4> = SmallSet::new();
        let refs = vec![1, 2, 3];
        set.extend(&refs);
        assert_eq!(set.len(), 3);
        assert!(set.contains(&2));
    }

    #[test]
    fn test_set_remove_missing_returns_false() {
        let mut set: SmallSet<i32, 4> = vec![1, 2].into_iter().collect();
        assert!(!set.remove(&3));
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn test_set_traits_interop_disjoint_hashset() {
        let small: SmallSet<i32, 4> = vec![1, 2].into_iter().collect();
        let other: HashSet<i32> = vec![3, 4].into_iter().collect();
        assert!(small.is_disjoint(&other));
    }

    #[test]
    fn test_set_traits_debug_empty() {
        let set: SmallSet<i32, 4> = SmallSet::new();
        assert_eq!(format!("{:?}", set), "{}");
    }
}