use crate::{optionals::*, utils::*};
use OptionalPair::*;
use hashbrown::{
hash_map::DefaultHashBuilder,
raw::{RawDrain, RawIter, RawTable},
TryReserveError,
};
use core::mem;
use std::{
borrow::Borrow,
default::Default,
fmt,
hash::{BuildHasher, Hash},
iter::FusedIterator,
marker::PhantomData,
};
#[cfg(doc)]
use hashbrown::HashMap;
/// One half of a stored pair: the item plus what is needed to locate its partner.
pub(crate) struct MappingPair<T> {
    /// Identifier shared by both halves of the same pair.
    pub(crate) id: u64,
    /// Hash of the partner item, i.e. where the partner lives in the opposite table.
    pub(crate) hash: u64,
    /// The stored item itself.
    pub(crate) value: T,
}
pub(crate) fn equivalent_key<Q: PartialEq<K> + ?Sized, K>(
k: &Q,
) -> impl Fn(&MappingPair<K>) -> bool + '_ {
move |x| k.eq(&x.value)
}
pub(crate) fn hash_and_id<'a, Q: PartialEq + ?Sized>(
hash: u64,
id: u64,
) -> impl Fn(&MappingPair<Q>) -> bool + 'a {
move |x| id == x.id && hash == x.hash
}
pub(crate) fn just_id<'a, Q: PartialEq + ?Sized>(id: u64) -> impl Fn(&MappingPair<Q>) -> bool + 'a {
move |x| id == x.id
}
/// A one-to-one map between left items and right items.
///
/// Each side lives in its own raw hash table. The two halves of a pair carry the same
/// id, and each half records its partner's hash so the partner can be probed directly.
pub struct CycleMap<L, R, St = DefaultHashBuilder> {
    /// Hasher used for both tables.
    pub(crate) hash_builder: St,
    /// Next id to hand out to a newly inserted pair.
    pub(crate) counter: u64,
    /// Left halves, keyed by the hash of the left item.
    left_set: RawTable<MappingPair<L>>,
    /// Right halves, keyed by the hash of the right item.
    right_set: RawTable<MappingPair<R>>,
}
impl<L, R> CycleMap<L, R, DefaultHashBuilder> {
    /// Creates an empty map that uses the default hasher.
    #[inline]
    pub fn new() -> Self {
        Self::with_hasher(DefaultHashBuilder::default())
    }

    /// Creates an empty map using the default hasher, with both tables pre-sized for
    /// `capacity` pairs.
    #[inline]
    pub fn with_capacity(capacity: usize) -> Self {
        let hasher: DefaultHashBuilder = Default::default();
        Self::with_capacity_and_hasher(capacity, hasher)
    }
}
impl<L, R, S> CycleMap<L, R, S>
where
L: Eq + Hash,
R: Eq + Hash,
S: BuildHasher,
{
/// Inserts the pair `(left, right)`.
///
/// Any existing pair that contains `left`, and then any that contains `right`, is
/// removed first. The removed pairs are reported through the returned `InsertOptional`.
pub fn insert(&mut self, left: L, right: R) -> InsertOptional<L, R> {
    // Evict conflicts first so each item appears in at most one pair.
    let evicted = InsertOptional::from((
        self.remove_via_left(&left),
        self.remove_via_right(&right),
    ));
    let left_hash = make_hash::<L, S>(&self.hash_builder, &left);
    let right_hash = make_hash::<R, S>(&self.hash_builder, &right);
    let id = self.counter;
    self.counter += 1;
    // Each half stores its partner's hash, so the partner can be probed without rehashing.
    self.left_set.insert(
        left_hash,
        MappingPair {
            value: left,
            hash: right_hash,
            id,
        },
        make_hasher::<MappingPair<L>, S>(&self.hash_builder),
    );
    self.right_set.insert(
        right_hash,
        MappingPair {
            value: right,
            hash: left_hash,
            id,
        },
        make_hasher::<MappingPair<R>, S>(&self.hash_builder),
    );
    evicted
}
/// Returns `true` when `left` and `right` are both stored and belong to the same pair.
pub fn are_paired(&self, left: &L, right: &R) -> bool {
    let l_hash = make_hash::<L, S>(&self.hash_builder, left);
    let r_hash = make_hash::<R, S>(&self.hash_builder, right);
    let found_left = self.left_set.get(l_hash, equivalent_key(left));
    let found_right = self.right_set.get(r_hash, equivalent_key(right));
    if let (Some(lp), Some(rp)) = (found_left, found_right) {
        // Same id, and each half points at the other's hash.
        lp.id == rp.id && rp.hash == l_hash && lp.hash == r_hash
    } else {
        false
    }
}
pub fn contains_left(&self, left: &L) -> bool {
let hash = make_hash::<L, S>(&self.hash_builder, left);
self.left_set.get(hash, equivalent_key(left)).is_some()
}
/// Returns `true` if an item equal to `right` is stored on the right side.
pub fn contains_right(&self, right: &R) -> bool {
    let r_hash = make_hash::<R, S>(&self.hash_builder, right);
    self.get_right_inner_with_hash(right, r_hash).is_some()
}
/// Removes and returns the pair `(left, right)`, but only if the two items are paired
/// with each other; otherwise the map is left unchanged and `None` is returned.
pub fn remove(&mut self, left: &L, right: &R) -> Option<(L, R)> {
    if !self.are_paired(left, right) {
        return None;
    }
    self.remove_via_left(left)
}
/// Removes the pair whose left item equals `item` and returns both halves, or `None`
/// if no such left item is stored.
pub fn remove_via_left(&mut self, item: &L) -> Option<(L, R)> {
    let l_hash = make_hash::<L, S>(&self.hash_builder, item);
    self.left_set
        .remove_entry(l_hash, equivalent_key(item))
        .map(|left_half| {
            // The partner sits under `left_half.hash` and records `l_hash` plus the same id.
            let right_half = self
                .right_set
                .remove_entry(left_half.hash, hash_and_id(l_hash, left_half.id))
                .unwrap();
            (left_half.extract(), right_half.extract())
        })
}
/// Removes the pair whose right item equals `item` and returns both halves, or `None`
/// if no such right item is stored.
pub fn remove_via_right(&mut self, item: &R) -> Option<(L, R)> {
    let r_hash = make_hash::<R, S>(&self.hash_builder, item);
    self.right_set
        .remove_entry(r_hash, equivalent_key(item))
        .map(|right_half| {
            // The partner sits under `right_half.hash` and records `r_hash` plus the same id.
            let left_half = self
                .left_set
                .remove_entry(right_half.hash, hash_and_id(r_hash, right_half.id))
                .unwrap();
            (left_half.extract(), right_half.extract())
        })
}
/// Removes the pair with id `id` whose left item hashes to `l_hash` and whose right item
/// hashes to `r_hash`, without needing the items themselves.
fn remove_via_hashes_and_id(&mut self, l_hash: u64, r_hash: u64, id: u64) -> Option<(L, R)> {
    self.left_set
        .remove_entry(l_hash, hash_and_id(r_hash, id))
        .map(|left_half| {
            let right_half = self
                .right_set
                .remove_entry(r_hash, hash_and_id(l_hash, id))
                .unwrap();
            (left_half.extract(), right_half.extract())
        })
}
pub fn swap_left(&mut self, old: &L, new: L) -> OptionalPair<L, (L, R)> {
let new_l_hash = make_hash::<L, S>(&self.hash_builder, &new);
let eq_opt = self.swap_left_eq_check(old, &new, new_l_hash);
let old_l_hash = make_hash::<L, S>(&self.hash_builder, old);
let l_pairing: &MappingPair<L> = match self.left_set.get(old_l_hash, equivalent_key(old)) {
Some(p) => p,
None => {
return Neither;
}
};
let r_pairing: &mut MappingPair<R> = self
.right_set
.get_mut(l_pairing.hash, hash_and_id(old_l_hash, l_pairing.id))
.unwrap();
r_pairing.hash = new_l_hash;
let new_left_pairing: MappingPair<L> = MappingPair {
value: new,
hash: l_pairing.hash,
id: l_pairing.id,
};
let old_left_item: L = self
.left_set
.remove_entry(old_l_hash, equivalent_key(old))
.unwrap()
.extract();
self.left_set.insert(
new_l_hash,
new_left_pairing,
make_hasher::<MappingPair<L>, S>(&self.hash_builder),
);
OptionalPair::from((Some(old_left_item), eq_opt))
}
/// Like `swap_left`, but only acts when `old` is currently paired with `expected`;
/// otherwise returns `Neither` and leaves the map unchanged.
pub fn swap_left_checked(&mut self, old: &L, expected: &R, new: L) -> OptionalPair<L, (L, R)> {
    if self.are_paired(old, expected) {
        self.swap_left(old, new)
    } else {
        Neither
    }
}
pub fn swap_left_or_insert(
&mut self,
old: &L,
new: L,
to_insert: R,
) -> OptionalPair<L, (L, R)> {
let old_l_hash = make_hash::<L, S>(&self.hash_builder, old);
if self.left_set.get(old_l_hash, equivalent_key(old)).is_some() {
self.swap_left(old, new)
} else {
self.insert(new, to_insert).map_left(|(l, _)| l)
}
}
/// If a left item equal to `new` is stored and `new != old`, removes that item's pair and
/// returns it. Otherwise returns `None` and changes nothing.
fn swap_left_eq_check(&mut self, old: &L, new: &L, new_hash: u64) -> Option<(L, R)> {
    let already_stored = self.left_set.get(new_hash, equivalent_key(new)).is_some();
    if already_stored && new != old {
        self.remove_via_left(new)
    } else {
        None
    }
}
pub fn swap_right(&mut self, old: &R, new: R) -> OptionalPair<R, (L, R)> {
let new_r_hash = make_hash::<R, S>(&self.hash_builder, &new);
let eq_opt = self.swap_right_eq_check(old, &new, new_r_hash);
let old_r_hash = make_hash::<R, S>(&self.hash_builder, old);
let r_pairing: &MappingPair<R> = match self.right_set.get(old_r_hash, equivalent_key(old)) {
Some(p) => p,
None => {
return Neither;
}
};
let l_pairing: &mut MappingPair<L> = self
.left_set
.get_mut(r_pairing.hash, hash_and_id(old_r_hash, r_pairing.id))
.unwrap();
let new_r_hash = make_hash::<R, S>(&self.hash_builder, &new);
l_pairing.hash = new_r_hash;
let new_right_pairing = MappingPair {
value: new,
hash: r_pairing.hash,
id: r_pairing.id,
};
let old_right_item: R = self
.right_set
.remove_entry(old_r_hash, equivalent_key(old))
.unwrap()
.extract();
self.right_set.insert(
new_r_hash,
new_right_pairing,
make_hasher::<MappingPair<R>, S>(&self.hash_builder),
);
OptionalPair::from((Some(old_right_item), eq_opt))
}
/// Like `swap_right`, but only acts when `old` is currently paired with `expected`;
/// otherwise returns `Neither` and leaves the map unchanged.
pub fn swap_right_checked(&mut self, old: &R, expected: &L, new: R) -> OptionalPair<R, (L, R)> {
    if self.are_paired(expected, old) {
        self.swap_right(old, new)
    } else {
        Neither
    }
}
/// Swaps `old` for `new` on the right side if `old` is stored. Otherwise inserts the new
/// pair `(to_insert, new)` and translates the insert result into an `OptionalPair`.
///
/// NOTE(review): in the insert path, `insert` first calls `remove_via_left(&to_insert)`.
/// If `to_insert` is already stored on the left, that returns a pair, which presumably
/// maps to an `InsertOptional` variant other than `Neither`/`SomeRight`. Such a value
/// would hit the `unreachable!` arm below and panic. Confirm against `InsertOptional`'s
/// variants and `From` impl.
pub fn swap_right_or_insert(
    &mut self,
    old: &R,
    new: R,
    to_insert: L,
) -> OptionalPair<R, (L, R)> {
    let old_r_hash = make_hash::<R, S>(&self.hash_builder, old);
    // Is `old` stored on the right side?
    if self
        .right_set
        .get(old_r_hash, equivalent_key(old))
        .is_some()
    {
        self.swap_right(old, new)
    } else {
        // `old` is absent: insert a brand-new pair instead of swapping.
        match self.insert(to_insert, new) {
            InsertOptional::Neither => Neither,
            InsertOptional::SomeRight(pair) => SomeRight(pair),
            _ => {
                unreachable!("There isn't a left item")
            }
        }
    }
}
/// If a right item equal to `new` is stored and `new != old`, removes that item's pair
/// and returns it. Otherwise returns `None` and changes nothing.
fn swap_right_eq_check(&mut self, old: &R, new: &R, new_hash: u64) -> Option<(L, R)> {
    let already_stored = self.right_set.get(new_hash, equivalent_key(new)).is_some();
    if already_stored && new != old {
        self.remove_via_right(new)
    } else {
        None
    }
}
/// Returns the left item paired with the right item equal to `item`, if any.
pub fn get_left<Q>(&self, item: &Q) -> Option<&L>
where
    R: Borrow<Q>,
    Q: Hash + Eq + PartialEq<R>,
{
    let r_hash = make_hash::<_, S>(&self.hash_builder, item);
    self.get_right_inner_with_hash(item, r_hash)
        .and_then(|right| {
            // The partner is stored under `right.hash` and records `r_hash` and the same id.
            self.left_set
                .get(right.hash, hash_and_id(r_hash, right.id))
        })
        .map(|left| &left.value)
}
/// Returns the right item paired with the left item equal to `item`, if any.
pub fn get_right<Q>(&self, item: &Q) -> Option<&R>
where
    L: Borrow<Q>,
    Q: Hash + Eq + PartialEq<L>,
{
    let l_hash = make_hash::<_, S>(&self.hash_builder, item);
    self.get_left_inner_with_hash(item, l_hash)
        .and_then(|left| {
            // The partner is stored under `left.hash` and records `l_hash` and the same id.
            self.right_set
                .get(left.hash, hash_and_id(l_hash, left.id))
        })
        .map(|right| &right.value)
}
/// Looks up the left pairing whose value equals `item`, hashing `item` first.
#[inline]
fn get_left_inner(&self, item: &L) -> Option<&MappingPair<L>> {
    self.get_left_inner_with_hash(item, make_hash::<L, S>(&self.hash_builder, item))
}
/// Looks up the left pairing whose value equals `item`, using a precomputed `hash`.
#[inline]
fn get_left_inner_with_hash<Q>(&self, item: &Q, hash: u64) -> Option<&MappingPair<L>>
where
    L: Borrow<Q>,
    Q: Hash + Eq + PartialEq<L>,
{
    let is_match = equivalent_key(item);
    self.left_set.get(hash, is_match)
}
/// Looks up the right pairing whose value equals `item`, using a precomputed `hash`.
#[inline]
fn get_right_inner_with_hash<Q>(&self, item: &Q, hash: u64) -> Option<&MappingPair<R>>
where
    R: Borrow<Q>,
    Q: Hash + Eq + PartialEq<R>,
{
    let is_match = equivalent_key(item);
    self.right_set.get(hash, is_match)
}
/// Returns an iterator over every `(left, right)` pair, in the left table's order.
pub fn iter(&self) -> Iter<'_, L, R, S> {
    // SAFETY: the returned `Iter` holds `&self` for its whole lifetime, so the table
    // cannot be mutated or freed while the raw iterator is in use.
    let left_iter = unsafe { self.left_set.iter() };
    Iter {
        map_ref: self,
        left_iter,
    }
}
/// Returns an iterator over the left items only.
pub fn iter_left(&self) -> SingleIter<'_, L> {
    // SAFETY: the returned iterator's lifetime is tied to `&self`, so the table outlives it.
    let raw = unsafe { self.left_set.iter() };
    SingleIter {
        marker: PhantomData,
        iter: raw,
    }
}
/// Returns an iterator over the right items only.
pub fn iter_right(&self) -> SingleIter<'_, R> {
    // SAFETY: the returned iterator's lifetime is tied to `&self`, so the table outlives it.
    let raw = unsafe { self.right_set.iter() };
    SingleIter {
        marker: PhantomData,
        iter: raw,
    }
}
/// Removes every pair from the map, yielding them by value.
pub fn drain(&mut self) -> DrainIter<'_, L, R> {
    // Ids only need to be unique among stored pairs; the drain removes the current
    // ones, so id allocation restarts from zero.
    self.counter = 0;
    let left_iter = self.left_set.drain();
    DrainIter {
        left_iter,
        right_ref: &mut self.right_set,
    }
}
/// Returns an iterator that removes and yields every pair for which `f(left, right)` is
/// `true`. Pairs for which `f` returns `false` stay in the map.
pub fn drain_filter<F>(&mut self, f: F) -> DrainFilterIter<'_, L, R, F>
where
    F: FnMut(&L, &R) -> bool,
{
    // SAFETY: the raw iterator is stored alongside the `&mut` borrow of the same table,
    // so the table outlives it.
    let left_iter = unsafe { self.left_set.iter() };
    let inner = DrainFilterInner {
        left_iter,
        left_ref: &mut self.left_set,
        right_ref: &mut self.right_set,
    };
    DrainFilterIter { f, inner }
}
/// Keeps only the pairs for which `f(left, right)` returns `true`.
///
/// The predicate runs over every pair first, then the rejected pairs are removed.
pub fn retain<F>(&mut self, mut f: F)
where
    F: FnMut(&L, &R) -> bool,
{
    // Record (left hash, right hash, id) of rejected pairs; the map cannot be modified
    // while `self.iter()` borrows it.
    let doomed: Vec<(u64, u64, u64)> = self
        .iter()
        .filter(|&(l, r)| !f(l, r))
        .map(|(l, r)| {
            let l_hash = make_hash::<L, S>(&self.hash_builder, l);
            let r_hash = make_hash::<R, S>(&self.hash_builder, r);
            (l_hash, r_hash, self.get_left_inner(l).unwrap().id)
        })
        .collect();
    for (l_hash, r_hash, id) in doomed {
        self.remove_via_hashes_and_id(l_hash, r_hash, id);
    }
}
/// Shrinks both tables, passing `min_capacity` to each table's `shrink_to`.
pub fn shrink_to(&mut self, min_capacity: usize) {
    let hasher = &self.hash_builder;
    self.left_set.shrink_to(min_capacity, make_hasher(hasher));
    self.right_set.shrink_to(min_capacity, make_hasher(hasher));
}
/// Shrinks both tables down to the current number of pairs.
pub fn shrink_to_fit(&mut self) {
    let target = self.len();
    let hasher = &self.hash_builder;
    self.left_set.shrink_to(target, make_hasher(hasher));
    self.right_set.shrink_to(target, make_hasher(hasher));
}
/// Reserves room for `additional` more pairs in both tables.
pub fn reserve(&mut self, additional: usize) {
    let hasher = &self.hash_builder;
    self.left_set.reserve(additional, make_hasher(hasher));
    self.right_set.reserve(additional, make_hasher(hasher));
}
/// Fallible version of `reserve`.
///
/// The left table is reserved first. If it fails, its error is returned and the right
/// table is not touched.
pub fn try_reserve(&mut self, additional: usize) -> Result<(), TryReserveError> {
    let hasher = &self.hash_builder;
    self.left_set.try_reserve(additional, make_hasher(hasher))?;
    self.right_set.try_reserve(additional, make_hasher(hasher))?;
    Ok(())
}
}
impl<L, R, S> CycleMap<L, R, S> {
    /// Creates an empty map that hashes items with `hash_builder`.
    pub const fn with_hasher(hash_builder: S) -> Self {
        Self {
            hash_builder,
            counter: 0,
            left_set: RawTable::new(),
            right_set: RawTable::new(),
        }
    }

    /// Creates an empty map that hashes with `hash_builder`, with both tables pre-sized
    /// for `capacity` pairs.
    pub fn with_capacity_and_hasher(capacity: usize, hash_builder: S) -> Self {
        let left_set = RawTable::with_capacity(capacity);
        let right_set = RawTable::with_capacity(capacity);
        Self {
            hash_builder,
            counter: 0,
            left_set,
            right_set,
        }
    }

    /// Returns a reference to the map's hasher.
    pub fn hasher(&self) -> &S {
        &self.hash_builder
    }

    /// Returns the capacity of the left table. The right table is sized with the same
    /// arguments by this type's reserve/shrink methods.
    pub fn capacity(&self) -> usize {
        self.left_set.capacity()
    }

    /// Returns the number of stored pairs.
    pub fn len(&self) -> usize {
        self.left_set.len()
    }

    /// Returns `true` if no pairs are stored.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Removes every pair (left table first, then right).
    pub fn clear(&mut self) {
        self.left_set.clear();
        self.right_set.clear();
    }
}
impl<L, R, S> Clone for CycleMap<L, R, S>
where
    L: Eq + Hash + Clone,
    R: Eq + Hash + Clone,
    S: BuildHasher + Clone,
{
    /// Deep-copies both tables, the id counter and the hasher.
    fn clone(&self) -> Self {
        let left_set = self.left_set.clone();
        let right_set = self.right_set.clone();
        let hash_builder = self.hash_builder.clone();
        Self {
            hash_builder,
            counter: self.counter,
            left_set,
            right_set,
        }
    }
}
impl<L, R, S: Default> Default for CycleMap<L, R, S> {
    /// Creates an empty map with a default-constructed hasher.
    fn default() -> Self {
        let hasher = S::default();
        Self::with_hasher(hasher)
    }
}
impl<L, R, S> fmt::Debug for CycleMap<L, R, S>
where
    L: Hash + Eq + fmt::Debug,
    R: Hash + Eq + fmt::Debug,
    S: BuildHasher,
{
    /// Formats the map as a set of `(left, right)` tuples.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut set = f.debug_set();
        for pair in self.iter() {
            set.entry(&pair);
        }
        set.finish()
    }
}
impl<L, R, S> PartialEq<CycleMap<L, R, S>> for CycleMap<L, R, S>
where
    L: Hash + Eq,
    R: Hash + Eq,
    S: BuildHasher,
{
    /// Two maps are equal when they hold the same number of pairs and every pair of
    /// `self` is also paired in `other`.
    fn eq(&self, other: &Self) -> bool {
        self.len() == other.len() && self.iter().all(|(l, r)| other.are_paired(l, r))
    }
}
/// Pair-wise equality is reflexive, so `CycleMap` is `Eq` under the same bounds.
impl<L: Hash + Eq, R: Hash + Eq, S: BuildHasher> Eq for CycleMap<L, R, S> {}
impl<L, R, S> Extend<(L, R)> for CycleMap<L, R, S>
where
    L: Hash + Eq,
    R: Hash + Eq,
    S: BuildHasher,
{
    /// Inserts each pair in order. Later pairs evict earlier ones that share an item,
    /// as `insert` does.
    #[inline]
    fn extend<T: IntoIterator<Item = (L, R)>>(&mut self, iter: T) {
        iter.into_iter().for_each(|(left, right)| {
            self.insert(left, right);
        });
    }
}
impl<L, R> FromIterator<(L, R)> for CycleMap<L, R>
where
L: Hash + Eq,
R: Hash + Eq,
{
fn from_iter<T: IntoIterator<Item = (L, R)>>(iter: T) -> Self {
let mut digest = CycleMap::default();
digest.extend(iter);
digest
}
}
/// Iterator over `(&L, &R)` pairs, returned by `CycleMap::iter`.
pub struct Iter<'a, L, R, S> {
    /// Raw cursor over the left table.
    left_iter: RawIter<MappingPair<L>>,
    /// The map being iterated; used to look up each left item's partner.
    map_ref: &'a CycleMap<L, R, S>,
}
impl<L, R, S> Clone for Iter<'_, L, R, S> {
fn clone(&self) -> Self {
Self {
left_iter: self.left_iter.clone(),
map_ref: self.map_ref,
}
}
}
impl<L, R, S> fmt::Debug for Iter<'_, L, R, S>
where
    L: Hash + Eq + fmt::Debug,
    R: Hash + Eq + fmt::Debug,
    S: BuildHasher,
{
    /// Lists the pairs not yet yielded, without advancing `self`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let pending = self.clone();
        f.debug_list().entries(pending).finish()
    }
}
impl<'a, L, R, S> Iterator for Iter<'a, L, R, S>
where
    L: Hash + Eq,
    R: Hash + Eq,
    S: BuildHasher,
{
    type Item = (&'a L, &'a R);

    /// Yields the next left item together with its partner, found via `get_right`.
    fn next(&mut self) -> Option<Self::Item> {
        let bucket = self.left_iter.next()?;
        // SAFETY: the bucket comes from `map_ref.left_set`, which is borrowed for `'a`.
        let left: &'a L = unsafe { &bucket.as_ref().value };
        let right = self.map_ref.get_right(left).unwrap();
        Some((left, right))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.left_iter.size_hint()
    }
}
impl<L, R, S> ExactSizeIterator for Iter<'_, L, R, S>
where
    L: Hash + Eq,
    R: Hash + Eq,
    S: BuildHasher,
{
    /// Number of pairs left to yield (one per remaining left bucket).
    fn len(&self) -> usize {
        ExactSizeIterator::len(&self.left_iter)
    }
}
/// Once exhausted, `Iter` keeps returning `None`.
impl<L: Hash + Eq, R: Hash + Eq, S: BuildHasher> FusedIterator for Iter<'_, L, R, S> {}
/// Iterator over the items of one side of a `CycleMap` (see `iter_left`/`iter_right`).
pub struct SingleIter<'a, T> {
    /// Raw cursor over one of the map's tables.
    iter: RawIter<MappingPair<T>>,
    /// Ties the yielded references to the borrow of the map.
    marker: PhantomData<&'a T>,
}
impl<T> Clone for SingleIter<'_, T> {
fn clone(&self) -> Self {
Self {
iter: self.iter.clone(),
marker: PhantomData,
}
}
}
impl<T: Hash + Eq + fmt::Debug> fmt::Debug for SingleIter<'_, T> {
    /// Lists the items not yet yielded, without advancing `self`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let pending = self.clone();
        f.debug_list().entries(pending).finish()
    }
}
impl<'a, T> Iterator for SingleIter<'a, T>
where
    T: 'a + Hash + Eq,
{
    type Item = &'a T;

    /// Yields a reference to the next stored item.
    fn next(&mut self) -> Option<Self::Item> {
        self.iter.next().map(|bucket| {
            // SAFETY: the table outlives `'a`, the borrow recorded in `marker`.
            unsafe { &bucket.as_ref().value }
        })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.iter.size_hint()
    }
}
impl<T: Hash + Eq> ExactSizeIterator for SingleIter<'_, T> {
    /// Number of items left to yield.
    fn len(&self) -> usize {
        ExactSizeIterator::len(&self.iter)
    }
}
/// Once exhausted, `SingleIter` keeps returning `None`.
impl<T: Hash + Eq> FusedIterator for SingleIter<'_, T> {}
/// Draining iterator returned by `CycleMap::drain`; yields every pair by value.
///
/// Each yielded pair removes its right half from the right table as it is produced.
// No `Debug` impl is provided for this iterator (presumably because printing it would
// require consuming it).
#[allow(missing_debug_implementations)]
pub struct DrainIter<'a, L, R> {
    /// Drains the left table.
    left_iter: RawDrain<'a, MappingPair<L>>,
    /// The right table; partners are removed from it as pairs are yielded.
    right_ref: &'a mut RawTable<MappingPair<R>>,
}

impl<L, R> Drop for DrainIter<'_, L, R> {
    /// Clears whatever is left of the right table.
    ///
    /// Per hashbrown's `RawDrain`, dropping `left_iter` drops the left halves it did not
    /// yield and leaves the left table empty. Without this impl, the matching right
    /// halves would stay behind with no partner, and later partner lookups (which
    /// `unwrap`) would panic.
    fn drop(&mut self) {
        self.right_ref.clear();
    }
}
impl<'a, L, R> Iterator for DrainIter<'a, L, R>
where
    L: Hash + Eq,
    R: Hash + Eq,
{
    type Item = (L, R);

    /// Takes the next left half and removes its partner (same id, stored under the
    /// left half's recorded hash) from the right table.
    fn next(&mut self) -> Option<Self::Item> {
        self.left_iter.next().map(|left| {
            let right = self
                .right_ref
                .remove_entry(left.hash, just_id(left.id))
                .unwrap();
            (left.extract(), right.extract())
        })
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.left_iter.size_hint()
    }
}
impl<L: Hash + Eq, R: Hash + Eq> ExactSizeIterator for DrainIter<'_, L, R> {
    /// Number of pairs left to yield.
    fn len(&self) -> usize {
        ExactSizeIterator::len(&self.left_iter)
    }
}
/// Once exhausted, `DrainIter` keeps returning `None`.
impl<L: Hash + Eq, R: Hash + Eq> FusedIterator for DrainIter<'_, L, R> {}
/// Iterator returned by `CycleMap::drain_filter`. It removes and yields the pairs
/// accepted by its predicate.
// No `Debug` impl is provided (the predicate `F` generally is not `Debug`).
#[allow(missing_debug_implementations)]
pub struct DrainFilterIter<'a, L, R: Eq, F: FnMut(&L, &R) -> bool> {
    /// Predicate deciding which pairs are removed.
    f: F,
    /// Cursor and table borrows doing the actual removal.
    inner: DrainFilterInner<'a, L, R>,
}
/// Finishes the filter when the iterator is dropped: every remaining pair accepted by
/// the predicate is still removed from the map and dropped.
impl<'a, L, R, F> Drop for DrainFilterIter<'a, L, R, F>
where
    R: Eq,
    F: FnMut(&L, &R) -> bool,
{
    fn drop(&mut self) {
        while let Some(item) = self.next() {
            // If dropping `item` panics, this guard's own `drop` runs during unwinding and
            // drains (and drops) everything that is left.
            let guard = ConsumeAllOnDrop(self);
            drop(item);
            // Normal path: disarm the guard and continue the loop here instead.
            mem::forget(guard);
        }
    }
}
/// Guard that, when dropped, runs the borrowed iterator to completion and drops every
/// item it yields.
pub(super) struct ConsumeAllOnDrop<'a, T: Iterator>(pub(super) &'a mut T);
impl<T: Iterator> Drop for ConsumeAllOnDrop<'_, T> {
    fn drop(&mut self) {
        for item in &mut *self.0 {
            drop(item);
        }
    }
}
impl<L, R: Eq, F> Iterator for DrainFilterIter<'_, L, R, F>
where
    F: FnMut(&L, &R) -> bool,
{
    type Item = (L, R);

    /// Removes and returns the next pair accepted by the predicate.
    fn next(&mut self) -> Option<Self::Item> {
        let predicate = &mut self.f;
        self.inner.next(predicate)
    }

    /// Any number of the remaining pairs may be rejected, so the lower bound is zero.
    #[inline]
    fn size_hint(&self) -> (usize, Option<usize>) {
        let (_, upper) = self.inner.left_iter.size_hint();
        (0, upper)
    }
}
/// Once the left table has been fully scanned, `DrainFilterIter` keeps returning `None`.
impl<L, R, F> FusedIterator for DrainFilterIter<'_, L, R, F>
where
    R: Eq,
    F: FnMut(&L, &R) -> bool,
{
}
/// Predicate-independent state of a `DrainFilterIter`.
pub(super) struct DrainFilterInner<'a, L, R> {
    /// Raw cursor over `left_ref`'s buckets.
    pub(super) left_iter: RawIter<MappingPair<L>>,
    /// Left table that accepted pairs are removed from.
    pub(super) left_ref: &'a mut RawTable<MappingPair<L>>,
    /// Right table that accepted pairs' partners are removed from.
    pub(super) right_ref: &'a mut RawTable<MappingPair<R>>,
}
impl<L, R: Eq> DrainFilterInner<'_, L, R> {
    /// Advances to the next pair accepted by `f`, removes it from both tables and
    /// returns it. Returns `None` once every remaining left bucket has been visited.
    pub(super) fn next<F>(&mut self, f: &mut F) -> Option<(L, R)>
    where
        F: FnMut(&L, &R) -> bool,
    {
        // SAFETY (review): relies on `left_iter` having been created from `*left_ref`,
        // with that table not freed or resized while the cursor is used. hashbrown's
        // `RawIter` allows erasing buckets it has already yielded, which is the only
        // mutation of the left table done here.
        unsafe {
            for left in self.left_iter.by_ref() {
                let l_pairing = left.as_ref();
                // The partner is stored under the left pairing's recorded (right) hash
                // and shares its id.
                let right = self
                    .right_ref
                    .find(l_pairing.hash, just_id(l_pairing.id))
                    .unwrap();
                if f(&l_pairing.value, &right.as_ref().value) {
                    // Remove both halves so the map stays consistent.
                    let l = self.left_ref.remove(left).extract();
                    let r = self.right_ref.remove(right).extract();
                    return Some((l, r));
                }
            }
        }
        None
    }
}
impl<T> MappingPair<T> {
pub(crate) fn extract(self) -> T {
self.value
}
}
impl<T: Hash> Hash for MappingPair<T> {
    /// Hashes only the stored value, so a pairing hashes exactly like its item.
    fn hash<H: std::hash::Hasher>(&self, hasher: &mut H) {
        Hash::hash(&self.value, hasher);
    }
}
impl<T: PartialEq> PartialEq for MappingPair<T> {
    /// Two pairings are equal when they share an id and hold equal values; the
    /// recorded partner hash is not compared.
    fn eq(&self, other: &Self) -> bool {
        self.id == other.id && self.value == other.value
    }
}
impl<T: PartialEq> PartialEq<T> for MappingPair<T> {
    /// Compares the stored value directly against a bare item.
    fn eq(&self, other: &T) -> bool {
        self.value == *other
    }
}
/// Equality on id plus value is reflexive whenever the value's equality is.
impl<T> Eq for MappingPair<T> where T: Eq {}
impl<T: Clone> Clone for MappingPair<T> {
fn clone(&self) -> Self {
Self {
value: self.value.clone(),
hash: self.hash,
id: self.id,
}
}
}
impl<T: fmt::Debug> fmt::Debug for MappingPair<T> {
    /// Formats as `MappingPair { value: .., hash: .., id: .. }`.
    ///
    /// Uses `debug_struct`, so the formatter's flags (e.g. pretty-printing with `{:#?}`)
    /// are honored. The previous hand-written `write!` always produced single-line
    /// output and did not pass flags on to `value`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("MappingPair")
            .field("value", &self.value)
            .field("hash", &self.hash)
            .field("id", &self.id)
            .finish()
    }
}