use core::{
borrow::Borrow,
fmt,
hash::{BuildHasher, Hash},
};
#[cfg(feature = "zeroize")]
use zeroize::Zeroize;
use hash32::{BuildHasherDefault, FnvHasher};
use crate::index_map::{self, IndexMap};
/// A set whose values are stored as the keys of an [`IndexMap`] with `()` values.
///
/// `S` is the hasher builder and `N` the capacity parameter; both are forwarded unchanged to
/// the underlying map. Iteration order, `first`, and `last` follow the map's entry order.
#[cfg_attr(feature = "zeroize", derive(Zeroize), zeroize(bound = "T: Zeroize"))]
pub struct IndexSet<T, S, const N: usize> {
    map: IndexMap<T, (), S, N>,
}

/// An [`IndexSet`] that hashes its values with `hash32`'s FNV hasher (`FnvHasher`).
pub type FnvIndexSet<T, const N: usize> =
    IndexSet<T, BuildHasherDefault<FnvHasher>, N>;
impl<T, S, const N: usize> IndexSet<T, BuildHasherDefault<S>, N> {
    /// Creates an empty set using a default-constructed `BuildHasherDefault<S>`.
    ///
    /// Being a `const fn`, this can be evaluated in constant contexts such as `static`
    /// initializers.
    pub const fn new() -> Self {
        let map = IndexMap::new();
        Self { map }
    }
}
impl<T, S, const N: usize> IndexSet<T, S, N> {
    /// Returns the number of values currently stored in the set.
    pub fn len(&self) -> usize {
        self.map.len()
    }

    /// Returns `true` if the set holds no values.
    pub fn is_empty(&self) -> bool {
        self.map.is_empty()
    }

    /// Returns `true` if the underlying map reports that it is full.
    pub fn is_full(&self) -> bool {
        self.map.is_full()
    }

    /// Returns the capacity reported by the underlying map.
    pub fn capacity(&self) -> usize {
        self.map.capacity()
    }

    /// Returns the first value in iteration order, or `None` if the set is empty.
    pub fn first(&self) -> Option<&T> {
        let (value, _) = self.map.first()?;
        Some(value)
    }

    /// Returns the last value in iteration order, or `None` if the set is empty.
    pub fn last(&self) -> Option<&T> {
        let (value, _) = self.map.last()?;
        Some(value)
    }

    /// Returns an iterator over references to the values, in the map's entry order.
    pub fn iter(&self) -> Iter<'_, T> {
        let iter = self.map.iter();
        Iter { iter }
    }

    /// Removes every value from the set.
    pub fn clear(&mut self) {
        self.map.clear()
    }
}
impl<T, S, const N: usize> IndexSet<T, S, N>
where
    T: Eq + Hash,
    S: BuildHasher,
{
    /// Returns a lazy iterator over the values of `self` that are not in `other`.
    ///
    /// Values are yielded in `self`'s iteration order.
    pub fn difference<'a, S2, const N2: usize>(
        &'a self,
        other: &'a IndexSet<T, S2, N2>,
    ) -> Difference<'a, T, S2, N2>
    where
        S2: BuildHasher,
    {
        Difference {
            iter: self.iter(),
            other,
        }
    }

    /// Returns a lazy iterator over the values that are in exactly one of `self` and `other`.
    ///
    /// The values of `self` missing from `other` come first (in `self`'s order), followed by
    /// the values of `other` missing from `self` (in `other`'s order).
    pub fn symmetric_difference<'a, S2, const N2: usize>(
        &'a self,
        other: &'a IndexSet<T, S2, N2>,
    ) -> impl Iterator<Item = &'a T>
    where
        S2: BuildHasher,
    {
        self.difference(other).chain(other.difference(self))
    }

    /// Returns a lazy iterator over the values of `self` that are also in `other`.
    ///
    /// Values are yielded in `self`'s iteration order.
    pub fn intersection<'a, S2, const N2: usize>(
        &'a self,
        other: &'a IndexSet<T, S2, N2>,
    ) -> Intersection<'a, T, S2, N2>
    where
        S2: BuildHasher,
    {
        Intersection {
            iter: self.iter(),
            other,
        }
    }

    /// Returns a lazy iterator over the values in `self` or `other`, each yielded once.
    ///
    /// Every value of `self` comes first, followed by the values of `other` not in `self`.
    pub fn union<'a, S2, const N2: usize>(
        &'a self,
        other: &'a IndexSet<T, S2, N2>,
    ) -> impl Iterator<Item = &'a T>
    where
        S2: BuildHasher,
    {
        self.iter().chain(other.difference(self))
    }

    /// Returns `true` if the set contains `value`.
    ///
    /// `value` may be any borrowed form of `T`, provided its `Eq` and `Hash` agree with `T`'s.
    pub fn contains<Q>(&self, value: &Q) -> bool
    where
        T: Borrow<Q>,
        Q: ?Sized + Eq + Hash,
    {
        self.map.contains_key(value)
    }

    /// Returns `true` if `self` and `other` have no values in common.
    pub fn is_disjoint<S2, const N2: usize>(&self, other: &IndexSet<T, S2, N2>) -> bool
    where
        S2: BuildHasher,
    {
        // Probe with the smaller set so the number of hash lookups is min(len, other.len).
        if self.len() <= other.len() {
            self.iter().all(|v| !other.contains(v))
        } else {
            other.iter().all(|v| !self.contains(v))
        }
    }

    /// Returns `true` if every value of `self` is also in `other`.
    pub fn is_subset<S2, const N2: usize>(&self, other: &IndexSet<T, S2, N2>) -> bool
    where
        S2: BuildHasher,
    {
        // Values within a set are pairwise distinct, so a larger set can never be contained
        // in a smaller one; bail out before hashing anything in that case.
        self.len() <= other.len() && self.iter().all(|v| other.contains(v))
    }

    /// Returns `true` if every value of `other` is also in `self`.
    pub fn is_superset<S2, const N2: usize>(&self, other: &IndexSet<T, S2, N2>) -> bool
    where
        S2: BuildHasher,
    {
        other.is_subset(self)
    }

    /// Adds `value` to the set.
    ///
    /// Returns `Ok(true)` if the value was newly inserted and `Ok(false)` if an equal value
    /// was already present.
    ///
    /// # Errors
    ///
    /// If the underlying `IndexMap::insert` rejects the entry, the value is handed back as
    /// `Err(value)`. NOTE(review): presumably this happens when the set is full; confirm
    /// against `IndexMap::insert`.
    pub fn insert(&mut self, value: T) -> Result<bool, T> {
        self.map
            .insert(value, ())
            .map(|old| old.is_none())
            .map_err(|(k, _)| k)
    }

    /// Removes `value` from the set, returning `true` if it was present.
    pub fn remove<Q>(&mut self, value: &Q) -> bool
    where
        T: Borrow<Q>,
        Q: ?Sized + Eq + Hash,
    {
        self.map.remove(value).is_some()
    }

    /// Retains only the values for which `f` returns `true`, delegating to `IndexMap::retain`.
    pub fn retain<F>(&mut self, mut f: F)
    where
        F: FnMut(&T) -> bool,
    {
        self.map.retain(move |k, _| f(k));
    }
}
impl<T, S, const N: usize> Clone for IndexSet<T, S, N>
where
    S: Clone,
    T: Clone,
{
    /// Produces an independent copy of the set by cloning the underlying map.
    fn clone(&self) -> Self {
        let map = self.map.clone();
        Self { map }
    }
}
impl<T, S, const N: usize> fmt::Debug for IndexSet<T, S, N>
where
    T: fmt::Debug,
{
    /// Formats the set as `{a, b, ...}`, listing values in iteration order.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_set();
        builder.entries(self.iter());
        builder.finish()
    }
}
impl<T, S, const N: usize> Default for IndexSet<T, S, N>
where
    S: Default,
{
    /// Creates an empty set backed by a default-constructed map.
    fn default() -> Self {
        Self {
            map: Default::default(),
        }
    }
}
impl<T, S1, S2, const N1: usize, const N2: usize> PartialEq<IndexSet<T, S2, N2>>
    for IndexSet<T, S1, N1>
where
    T: Eq + Hash,
    S1: BuildHasher,
    S2: BuildHasher,
{
    /// Two sets are equal when they hold the same values; order and capacity are ignored.
    fn eq(&self, other: &IndexSet<T, S2, N2>) -> bool {
        // Equal sizes plus one-way containment implies both directions.
        if self.len() != other.len() {
            return false;
        }
        self.is_subset(other)
    }
}
impl<T, S, const N: usize> Extend<T> for IndexSet<T, S, N>
where
    T: Eq + Hash,
    S: BuildHasher,
{
    /// Inserts every value yielded by `iterable` via the underlying map's `Extend` impl.
    ///
    /// NOTE(review): overflow behavior (e.g. panicking when full) is whatever
    /// `IndexMap`'s `Extend` does — confirm there.
    fn extend<I>(&mut self, iterable: I)
    where
        I: IntoIterator<Item = T>,
    {
        let entries = iterable.into_iter().map(|value| (value, ()));
        self.map.extend(entries);
    }
}
impl<'a, T, S, const N: usize> Extend<&'a T> for IndexSet<T, S, N>
where
    T: 'a + Eq + Hash + Copy,
    S: BuildHasher,
{
    /// Copies each referenced value and inserts it, forwarding to the by-value `Extend` impl.
    fn extend<I>(&mut self, iterable: I)
    where
        I: IntoIterator<Item = &'a T>,
    {
        let values = iterable.into_iter().copied();
        self.extend(values);
    }
}
impl<T, S, const N: usize> FromIterator<T> for IndexSet<T, S, N>
where
    T: Eq + Hash,
    S: BuildHasher + Default,
{
    /// Builds a set from a default (empty) one, then extends it with every value of `iter`.
    fn from_iter<I>(iter: I) -> Self
    where
        I: IntoIterator<Item = T>,
    {
        let mut collected = Self::default();
        collected.extend(iter);
        collected
    }
}
impl<'a, T, S, const N: usize> IntoIterator for &'a IndexSet<T, S, N>
where
T: Eq + Hash,
S: BuildHasher,
{
type Item = &'a T;
type IntoIter = Iter<'a, T>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
/// An iterator over references to the values of an [`IndexSet`].
///
/// Created by [`IndexSet::iter`].
pub struct Iter<'a, T> {
    iter: index_map::Iter<'a, T, ()>,
}

impl<'a, T> Iterator for Iter<'a, T> {
    type Item = &'a T;

    fn next(&mut self) -> Option<Self::Item> {
        self.iter.next().map(|(k, _)| k)
    }

    // Each map entry yields exactly one value, so the inner iterator's bounds are exact for
    // this one too. Forwarding them lets consumers such as `collect` size up front.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}

// Written by hand instead of derived: `#[derive(Clone)]` would add a `T: Clone` bound, but
// only the borrowing inner iterator needs to be cloned.
impl<T> Clone for Iter<'_, T> {
    fn clone(&self) -> Self {
        Self {
            iter: self.iter.clone(),
        }
    }
}
/// A lazy iterator over the values of one set that are absent from another.
///
/// Created by [`IndexSet::difference`]. Trait bounds live on the `Iterator` impl rather than
/// on the struct, since only iteration needs hashing.
pub struct Difference<'a, T, S, const N: usize> {
    iter: Iter<'a, T>,
    other: &'a IndexSet<T, S, N>,
}

impl<'a, T, S, const N: usize> Iterator for Difference<'a, T, S, N>
where
    S: BuildHasher,
    T: Eq + Hash,
{
    type Item = &'a T;

    fn next(&mut self) -> Option<Self::Item> {
        // Copy the shared reference out so the closure does not borrow `self`
        // while `self.iter` is mutably borrowed by `find`.
        let other = self.other;
        self.iter.find(|value| !other.contains(*value))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Any number of values may be filtered out, but never more are yielded than the
        // underlying iterator still has.
        let (_, upper) = self.iter.size_hint();
        (0, upper)
    }
}
/// A lazy iterator over the values of one set that are also present in another.
///
/// Created by [`IndexSet::intersection`]. Trait bounds live on the `Iterator` impl rather
/// than on the struct, since only iteration needs hashing.
pub struct Intersection<'a, T, S, const N: usize> {
    iter: Iter<'a, T>,
    other: &'a IndexSet<T, S, N>,
}

impl<'a, T, S, const N: usize> Iterator for Intersection<'a, T, S, N>
where
    S: BuildHasher,
    T: Eq + Hash,
{
    type Item = &'a T;

    fn next(&mut self) -> Option<Self::Item> {
        // Copy the shared reference out so the closure does not borrow `self`
        // while `self.iter` is mutably borrowed by `find`.
        let other = self.other;
        self.iter.find(|value| other.contains(*value))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // Any number of values may be filtered out, but never more are yielded than the
        // underlying iterator still has.
        let (_, upper) = self.iter.size_hint();
        (0, upper)
    }
}
#[cfg(test)]
mod tests {
    use static_assertions::assert_not_impl_any;

    use super::{BuildHasherDefault, IndexSet};

    // A set holding non-`Send` values (raw pointers) must not itself be `Send`.
    assert_not_impl_any!(IndexSet<*const (), BuildHasherDefault<()>, 4>: Send);

    #[test]
    #[cfg(feature = "zeroize")]
    fn test_index_set_zeroize() {
        use zeroize::Zeroize;

        let mut set: IndexSet<u8, BuildHasherDefault<hash32::FnvHasher>, 8> = IndexSet::new();

        // Fill the set with 1..=8.
        (1..=8u8).for_each(|value| {
            set.insert(value).unwrap();
        });
        assert_eq!(set.len(), 8);
        assert!(set.contains(&8));

        // Zeroizing must leave an empty set...
        set.zeroize();
        assert_eq!(set.len(), 0);
        assert!(set.is_empty());

        // ...that is still usable afterwards.
        set.insert(1).unwrap();
        assert_eq!(set.len(), 1);
        assert!(set.contains(&1));
    }
}