#![warn(clippy::pedantic, rust_2018_idioms, missing_docs)]
use std::{
collections::{hash_map::RandomState, HashSet},
fmt::Debug,
hash::{BuildHasher, Hash},
marker::PhantomData,
mem::ManuallyDrop,
};
use mut_guard::MutGuard;
use value_wrapper::ValueWrapper;
use with_size_hint::IteratorExt as _;
#[doc(hidden)]
pub mod iter;
mod mut_guard;
mod value_wrapper;
mod with_size_hint;
#[cfg(feature = "iter_mut")]
pub use gat_lending_iterator::LendingIterator;
/// Implemented by values that carry their own lookup key.
///
/// `K` is the key type; it must be hashable and comparable so it can drive a
/// hash-based collection.
pub trait ExtractKey<K>
where
    K: Hash + Eq,
{
    /// Borrows the key of type `K` from this value.
    fn extract_key(&self) -> &K;
}
/// A hash-based collection of values `V` that are looked up by a key `K`
/// obtained from each value through [`ExtractKey`].
///
/// `S` is the hasher builder and defaults to [`RandomState`].
#[cfg_attr(feature = "typesize", derive(typesize::TypeSize))]
pub struct ExtractMap<K, V, S = RandomState> {
// Backing storage: every value is held inside a `ValueWrapper<K, V>`.
inner: HashSet<ValueWrapper<K, V>, S>,
}
impl<K, V> Default for ExtractMap<K, V, RandomState> {
    /// Produces an empty map with a fresh [`RandomState`], the same result as
    /// [`ExtractMap::new`].
    fn default() -> Self {
        Self::with_hasher(RandomState::new())
    }
}
impl<K, V> ExtractMap<K, V, RandomState> {
    /// Creates an empty map that hashes with a newly created [`RandomState`].
    #[must_use]
    pub fn new() -> Self {
        let hasher = RandomState::new();
        Self::with_hasher(hasher)
    }

    /// Creates an empty map that hashes with a newly created [`RandomState`].
    ///
    /// `capacity` is passed on unchanged to the backing [`HashSet`].
    #[must_use]
    pub fn with_capacity(capacity: usize) -> Self {
        let hasher = RandomState::new();
        Self::with_capacity_and_hasher(capacity, hasher)
    }
}
impl<K, V, S> ExtractMap<K, V, S> {
    /// Creates an empty map that uses `hasher` to hash its entries.
    #[must_use]
    pub fn with_hasher(hasher: S) -> Self {
        let inner = HashSet::with_hasher(hasher);
        Self { inner }
    }

    /// Creates an empty map that uses `hasher` to hash its entries.
    ///
    /// `capacity` is passed on unchanged to [`HashSet::with_capacity_and_hasher`].
    #[must_use]
    pub fn with_capacity_and_hasher(capacity: usize, hasher: S) -> Self {
        let inner = HashSet::with_capacity_and_hasher(capacity, hasher);
        Self { inner }
    }
}
impl<K, V, S> ExtractMap<K, V, S>
where
    K: Hash + Eq,
    V: ExtractKey<K>,
    S: BuildHasher,
{
    /// Stores `value` in the map.
    ///
    /// If the backing set already held an entry equal to the new one, that entry
    /// is replaced and its value is returned; otherwise `None` is returned.
    /// Equality is whatever `ValueWrapper` defines (presumably key equality).
    pub fn insert(&mut self, value: V) -> Option<V> {
        let wrapped = ValueWrapper(value, PhantomData);
        let displaced = self.inner.replace(wrapped)?;
        Some(displaced.0)
    }

    /// Removes the entry found under `key` and returns its value, or `None`
    /// when nothing is stored under `key`.
    pub fn remove(&mut self, key: &K) -> Option<V> {
        let taken = self.inner.take(key)?;
        Some(taken.0)
    }

    /// Reports whether an entry is stored under `key`.
    #[must_use]
    pub fn contains_key(&self, key: &K) -> bool {
        let set = &self.inner;
        set.contains(key)
    }

    /// Borrows the value stored under `key`, or returns `None` when absent.
    #[must_use]
    pub fn get(&self, key: &K) -> Option<&V> {
        let wrapper = self.inner.get(key)?;
        Some(&wrapper.0)
    }

    /// Gives guarded access to the value stored under `key`.
    ///
    /// The value is taken out of the backing set and handed to a `MutGuard`
    /// together with a mutable borrow of this map. Returns `None` when nothing
    /// is stored under `key`.
    // NOTE(review): the guard presumably puts the value back when it is dropped;
    // confirm in `mut_guard`.
    #[must_use]
    pub fn get_mut<'a>(&'a mut self, key: &K) -> Option<MutGuard<'a, K, V, S>> {
        self.inner.take(key).map(|wrapper| MutGuard {
            value: ManuallyDrop::new(wrapper.0),
            map: self,
        })
    }
}
impl<K, V, S> ExtractMap<K, V, S> {
    /// Returns the capacity reported by the backing set.
    #[must_use]
    pub fn capacity(&self) -> usize {
        let Self { inner } = self;
        inner.capacity()
    }

    /// Returns how many values are currently stored.
    #[must_use]
    pub fn len(&self) -> usize {
        let Self { inner } = self;
        inner.len()
    }

    /// Returns `true` when no values are stored.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        let Self { inner } = self;
        inner.is_empty()
    }

    /// Returns an iterator over the map, the same one produced by calling
    /// `into_iter` on `&self`.
    pub fn iter(&self) -> iter::Iter<'_, K, V> {
        IntoIterator::into_iter(self)
    }
}
#[cfg(feature = "iter_mut")]
impl<K, V, S> ExtractMap<K, V, S>
where
K: Hash + Eq + Clone,
V: ExtractKey<K>,
S: BuildHasher,
{
/// Creates an `iter::IterMut` over this map, borrowing it mutably.
///
/// Only available with the `iter_mut` feature.
// The lint is silenced because the returned type is evidently not a std
// `Iterator`. NOTE(review): presumably it implements the `LendingIterator`
// re-exported under the same feature; confirm in `iter`.
#[allow(clippy::iter_not_returning_iterator)]
pub fn iter_mut(&mut self) -> iter::IterMut<'_, K, V, S> {
iter::IterMut::new(self)
}
}
impl<K, V: Clone, S: Clone> Clone for ExtractMap<K, V, S> {
    /// Returns a copy of the map produced by cloning the backing set, which
    /// clones both the stored values and the hasher builder.
    fn clone(&self) -> Self {
        Self {
            inner: self.inner.clone(),
        }
    }

    /// Overwrites `self` with a copy of `source`.
    ///
    /// Delegates to the backing set's `clone_from` so that the existing
    /// allocation of `self` can be reused instead of building a new table.
    fn clone_from(&mut self, source: &Self) {
        self.inner.clone_from(&source.inner);
    }
}
impl<K, V, S> Debug for ExtractMap<K, V, S>
where
    K: Debug + Hash + Eq,
    V: Debug + ExtractKey<K>,
{
    /// Formats the map as a debug map with one `key: value` entry per stored
    /// value, where the key comes from [`ExtractKey::extract_key`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut debug_map = f.debug_map();
        for value in self.iter() {
            debug_map.entry(value.extract_key(), value);
        }
        debug_map.finish()
    }
}
impl<K, V, S> PartialEq for ExtractMap<K, V, S>
where
    K: Hash + Eq,
    V: ExtractKey<K> + PartialEq,
    S: BuildHasher,
{
    /// Two maps are equal when they have the same length and every value in
    /// `self` has a counterpart in `other`, found under the same key, whose
    /// key and value both compare equal.
    fn eq(&self, other: &Self) -> bool {
        if self.len() != other.len() {
            return false;
        }

        for value in self.iter() {
            let key = value.extract_key();
            let Some(candidate) = other.get(key) else {
                return false;
            };

            let same_entry = key == candidate.extract_key() && value == candidate;
            if !same_entry {
                return false;
            }
        }

        true
    }
}
impl<K, V, S> FromIterator<V> for ExtractMap<K, V, S>
where
    K: Hash + Eq,
    V: ExtractKey<K>,
    S: BuildHasher + Default,
{
    /// Builds a map from `iter` using a default-constructed hasher.
    ///
    /// Values are added with [`ExtractMap::insert`], so when several items
    /// compare equal as entries the one yielded last is kept. This is the same
    /// outcome as inserting or extending item by item.
    fn from_iter<T: IntoIterator<Item = V>>(iter: T) -> Self {
        let iter = iter.into_iter();
        // Pre-size from the iterator's guaranteed minimum length.
        let (lower_bound, _) = iter.size_hint();
        let mut map = Self::with_capacity_and_hasher(lower_bound, S::default());
        for item in iter {
            // Collecting straight into the backing set would keep the FIRST of
            // several equal entries; `insert` replaces, so the last one wins.
            map.insert(item);
        }
        map
    }
}
impl<K, V, S> Extend<V> for ExtractMap<K, V, S>
where
    K: Hash + Eq,
    V: ExtractKey<K>,
    S: BuildHasher,
{
    /// Inserts every item of `iter` through [`ExtractMap::insert`]; an item
    /// replaces any equal entry already stored.
    ///
    /// Space is reserved up front from the iterator's lower size bound so a
    /// bulk extend does not grow the table step by step.
    fn extend<T: IntoIterator<Item = V>>(&mut self, iter: T) {
        let iter = iter.into_iter();
        let (lower_bound, _) = iter.size_hint();
        // Same heuristic the std hash collections use: a non-empty map may
        // already hold some of the incoming keys, so only reserve for about
        // half of them. `saturating_add` keeps the rounding overflow-free.
        let additional = if self.is_empty() {
            lower_bound
        } else {
            lower_bound.saturating_add(1) / 2
        };
        self.inner.reserve(additional);

        for item in iter {
            self.insert(item);
        }
    }
}
#[cfg(feature = "serde")]
impl<'de, K, V, S> serde::Deserialize<'de> for ExtractMap<K, V, S>
where
    K: Hash + Eq,
    V: ExtractKey<K> + serde::Deserialize<'de>,
    S: BuildHasher + Default,
{
    /// Deserializes a map from either a sequence of values or a map whose
    /// values are the elements.
    ///
    /// For map input the serialized keys are read as `IgnoredAny` and thrown
    /// away. The request goes through `deserialize_any`, so the data format
    /// decides which of the two visitor methods runs.
    fn deserialize<D: serde::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        use serde::de::{IgnoredAny, MapAccess, SeqAccess};

        /// Visitor that collects sequence or map input into an `ExtractMap`.
        struct Visitor<K, V, S>(PhantomData<(K, V, S)>);

        impl<'de, K, V, S> serde::de::Visitor<'de> for Visitor<K, V, S>
        where
            K: Hash + Eq,
            V: ExtractKey<K> + serde::Deserialize<'de>,
            S: BuildHasher + Default,
        {
            type Value = ExtractMap<K, V, S>;

            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                // Both `visit_seq` and `visit_map` are implemented, so the
                // message names both accepted shapes.
                formatter.write_str("a sequence or map")
            }

            fn visit_map<A: MapAccess<'de>>(self, mut map: A) -> Result<Self::Value, A::Error> {
                let size_hint = map.size_hint();
                // Pull entries until the access reports the end, keeping only
                // the value of each entry; the first error aborts the collect.
                std::iter::from_fn(|| map.next_entry::<IgnoredAny, V>().transpose())
                    .map(|res| res.map(|(_, v)| v))
                    .with_size_hint(size_hint)
                    .collect()
            }

            fn visit_seq<A: SeqAccess<'de>>(self, mut seq: A) -> Result<Self::Value, A::Error> {
                let size_hint = seq.size_hint();
                // Pull elements until the access reports the end; the first
                // error aborts the collect.
                std::iter::from_fn(|| seq.next_element().transpose())
                    .with_size_hint(size_hint)
                    .collect()
            }
        }

        deserializer.deserialize_any(Visitor(PhantomData))
    }
}
#[cfg(feature = "serde")]
impl<K, V: serde::Serialize, H> serde::Serialize for ExtractMap<K, V, H> {
    /// Serializes the map as a sequence of its values, in iteration order.
    fn serialize<S: serde::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        let values = self.iter();
        serializer.collect_seq(values)
    }
}
/// Serializes `map` as a serde map instead of a sequence.
///
/// Each stored value becomes one entry whose key is the result of
/// [`ExtractKey::extract_key`] and whose value is the stored value itself.
///
/// # Errors
///
/// Returns whatever error the serializer `ser` produces.
#[cfg(feature = "serde")]
pub fn serialize_as_map<K, V, H, S>(map: &ExtractMap<K, V, H>, ser: S) -> Result<S::Ok, S::Error>
where
    K: serde::Serialize + Hash + Eq,
    V: serde::Serialize + ExtractKey<K>,
    S: serde::Serializer,
{
    let entries = map.iter().map(|value| (value.extract_key(), value));
    ser.collect_map(entries)
}