use crate::{
mem::{Ref, Wrapper},
Overwritten,
};
use std::{
borrow::Borrow,
collections::{hash_map, HashMap},
fmt,
hash::{BuildHasher, Hash},
iter::{Extend, FromIterator, FusedIterator},
rc::Rc,
};
/// A bijective map backed by two `HashMap`s: one keyed by left values, one keyed by right
/// values.
///
/// Every left value is paired with exactly one right value and vice versa. Each value is
/// allocated once in an `Rc` (wrapped in `Ref`) that both internal maps share.
///
/// `LS` and `RS` are the hash builders of the left-keyed and right-keyed maps; both default
/// to `RandomState`.
pub struct BiHashMap<L, R, LS = hash_map::RandomState, RS = hash_map::RandomState> {
    /// Maps each left value to its paired right value.
    left2right: HashMap<Ref<L>, Ref<R>, LS>,
    /// Maps each right value back to its paired left value; kept as the mirror image of
    /// `left2right` by every mutating method.
    right2left: HashMap<Ref<R>, Ref<L>, RS>,
}
impl<L, R> BiHashMap<L, R, hash_map::RandomState, hash_map::RandomState>
where
L: Eq + Hash,
R: Eq + Hash,
{
pub fn new() -> Self {
Self {
left2right: HashMap::new(),
right2left: HashMap::new(),
}
}
pub fn with_capacity(capacity: usize) -> Self {
Self {
left2right: HashMap::with_capacity(capacity),
right2left: HashMap::with_capacity(capacity),
}
}
}
impl<L, R, LS, RS> BiHashMap<L, R, LS, RS>
where
    L: Eq + Hash,
    R: Eq + Hash,
{
    /// Returns the number of left-right pairs in the bimap.
    pub fn len(&self) -> usize {
        // Each pair contributes exactly one entry to each internal map.
        self.left2right.len()
    }

    /// Returns `true` if the bimap contains no pairs.
    pub fn is_empty(&self) -> bool {
        self.left2right.is_empty()
    }

    /// Returns how many pairs the bimap can hold without reallocating: the smaller of the
    /// two internal maps' capacities.
    pub fn capacity(&self) -> usize {
        let left_cap = self.left2right.capacity();
        let right_cap = self.right2left.capacity();
        Ord::min(left_cap, right_cap)
    }

    /// Removes every pair from both internal maps.
    pub fn clear(&mut self) {
        self.left2right.clear();
        self.right2left.clear();
    }

    /// Returns an iterator over `(&L, &R)` pairs in arbitrary order.
    pub fn iter(&self) -> Iter<'_, L, R> {
        let inner = self.left2right.iter();
        Iter { inner }
    }

    /// Returns an iterator over the left values in arbitrary order.
    pub fn left_values(&self) -> LeftValues<'_, L, R> {
        let inner = self.left2right.iter();
        LeftValues { inner }
    }

    /// Returns an iterator over the right values in arbitrary order.
    pub fn right_values(&self) -> RightValues<'_, L, R> {
        let inner = self.right2left.iter();
        RightValues { inner }
    }
}
impl<L, R, LS, RS> BiHashMap<L, R, LS, RS>
where
L: Eq + Hash,
R: Eq + Hash,
LS: BuildHasher,
RS: BuildHasher,
{
pub fn with_hashers(hash_builder_left: LS, hash_builder_right: RS) -> Self {
Self {
left2right: HashMap::with_hasher(hash_builder_left),
right2left: HashMap::with_hasher(hash_builder_right),
}
}
pub fn with_capacity_and_hashers(
capacity: usize,
hash_builder_left: LS,
hash_builder_right: RS,
) -> Self {
Self {
left2right: HashMap::with_capacity_and_hasher(capacity, hash_builder_left),
right2left: HashMap::with_capacity_and_hasher(capacity, hash_builder_right),
}
}
pub fn reserve(&mut self, additional: usize) {
self.left2right.reserve(additional);
self.right2left.reserve(additional);
}
pub fn shrink_to_fit(&mut self) {
self.left2right.shrink_to_fit();
self.right2left.shrink_to_fit();
}
pub fn shrink_to(&mut self, min_capacity: usize) {
self.left2right.shrink_to(min_capacity);
self.right2left.shrink_to(min_capacity);
}
pub fn get_by_left<Q>(&self, left: &Q) -> Option<&R>
where
L: Borrow<Q>,
Q: Eq + Hash + ?Sized,
{
self.left2right.get(Wrapper::wrap(left)).map(|r| &*r.0)
}
pub fn get_by_right<Q>(&self, right: &Q) -> Option<&L>
where
R: Borrow<Q>,
Q: Eq + Hash + ?Sized,
{
self.right2left.get(Wrapper::wrap(right)).map(|l| &*l.0)
}
pub fn contains_left<Q>(&self, left: &Q) -> bool
where
L: Borrow<Q>,
Q: Eq + Hash + ?Sized,
{
self.left2right.contains_key(Wrapper::wrap(left))
}
pub fn contains_right<Q>(&self, right: &Q) -> bool
where
R: Borrow<Q>,
Q: Eq + Hash + ?Sized,
{
self.right2left.contains_key(Wrapper::wrap(right))
}
pub fn remove_by_left<Q>(&mut self, left: &Q) -> Option<(L, R)>
where
L: Borrow<Q>,
Q: Eq + Hash + ?Sized,
{
self.left2right.remove(Wrapper::wrap(left)).map(|right_rc| {
let left_rc = self.right2left.remove(&right_rc).unwrap();
(
Rc::try_unwrap(left_rc.0).ok().unwrap(),
Rc::try_unwrap(right_rc.0).ok().unwrap(),
)
})
}
pub fn remove_by_right<Q>(&mut self, right: &Q) -> Option<(L, R)>
where
R: Borrow<Q>,
Q: Eq + Hash + ?Sized,
{
self.right2left.remove(Wrapper::wrap(right)).map(|left_rc| {
let right_rc = self.left2right.remove(&left_rc).unwrap();
(
Rc::try_unwrap(left_rc.0).ok().unwrap(),
Rc::try_unwrap(right_rc.0).ok().unwrap(),
)
})
}
pub fn insert(&mut self, left: L, right: R) -> Overwritten<L, R> {
let retval = match (self.remove_by_left(&left), self.remove_by_right(&right)) {
(None, None) => Overwritten::Neither,
(None, Some(r_pair)) => Overwritten::Right(r_pair.0, r_pair.1),
(Some(l_pair), None) => {
if l_pair.1 == right {
Overwritten::Pair(l_pair.0, l_pair.1)
} else {
Overwritten::Left(l_pair.0, l_pair.1)
}
}
(Some(l_pair), Some(r_pair)) => Overwritten::Both(l_pair, r_pair),
};
self.insert_unchecked(left, right);
retval
}
pub fn insert_no_overwrite(&mut self, left: L, right: R) -> Result<(), (L, R)> {
if self.contains_left(&left) || self.contains_right(&right) {
Err((left, right))
} else {
self.insert_unchecked(left, right);
Ok(())
}
}
pub fn retain<F>(&mut self, f: F)
where
F: FnMut(&L, &R) -> bool,
{
let mut f = f;
let right2left = &mut self.right2left;
self.left2right.retain(|l, r| {
let to_retain = f(&l.0, &r.0);
if !to_retain {
right2left.remove(r);
}
to_retain
});
}
fn insert_unchecked(&mut self, left: L, right: R) {
let left = Ref(Rc::new(left));
let right = Ref(Rc::new(right));
self.left2right.insert(left.clone(), right.clone());
self.right2left.insert(right, left);
}
}
impl<L, R, LS, RS> Clone for BiHashMap<L, R, LS, RS>
where
    L: Clone + Eq + Hash,
    R: Clone + Eq + Hash,
    LS: BuildHasher + Clone,
    RS: BuildHasher + Clone,
{
    /// Deep-clones the bimap.
    ///
    /// Every value is cloned and re-inserted, so the copy gets its own `Rc` allocations and
    /// shares none with `self`. The hashers are cloned, and the copy is pre-sized to
    /// `self.capacity()`.
    fn clone(&self) -> BiHashMap<L, R, LS, RS> {
        let capacity = self.capacity();
        let left_hasher = self.left2right.hasher().clone();
        let right_hasher = self.right2left.hasher().clone();
        let mut copy = BiHashMap::with_capacity_and_hashers(capacity, left_hasher, right_hasher);
        copy.extend(self.iter().map(|(l, r)| (l.clone(), r.clone())));
        copy
    }
}
impl<L, R, LS, RS> fmt::Debug for BiHashMap<L, R, LS, RS>
where
    L: fmt::Debug,
    R: fmt::Debug,
{
    /// Formats the bimap as a set of `left <> right` entries, e.g. `{'a' <> 1, 'b' <> 2}`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        /// Borrowed view of a single pair that renders as `left <> right`.
        struct PairView<'a, A, B>(&'a A, &'a B);

        impl<A: fmt::Debug, B: fmt::Debug> fmt::Debug for PairView<'_, A, B> {
            fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
                // Pass `f` straight through so flags such as `{:#?}` reach both sides.
                fmt::Debug::fmt(self.0, f)?;
                f.write_str(" <> ")?;
                fmt::Debug::fmt(self.1, f)
            }
        }

        let pairs = self.left2right.iter().map(|(l, r)| PairView(l, r));
        f.debug_set().entries(pairs).finish()
    }
}
impl<L, R, LS, RS> Default for BiHashMap<L, R, LS, RS>
where
L: Eq + Hash,
R: Eq + Hash,
LS: BuildHasher + Default,
RS: BuildHasher + Default,
{
fn default() -> BiHashMap<L, R, LS, RS> {
BiHashMap {
left2right: HashMap::default(),
right2left: HashMap::default(),
}
}
}
/// Marker impl: the `PartialEq` comparison of pairs is a full equivalence relation because
/// both `L` and `R` are `Eq`.
impl<L: Eq + Hash, R: Eq + Hash, LS: BuildHasher, RS: BuildHasher> Eq
    for BiHashMap<L, R, LS, RS>
{
}
impl<L, R, LS, RS> FromIterator<(L, R)> for BiHashMap<L, R, LS, RS>
where
    L: Eq + Hash,
    R: Eq + Hash,
    LS: BuildHasher + Default,
    RS: BuildHasher + Default,
{
    /// Builds a bimap from `(left, right)` pairs using default hash builders.
    ///
    /// Pairs are inserted in iteration order with `insert`. Later pairs therefore overwrite
    /// earlier pairs that share a left or right value.
    fn from_iter<I>(iter: I) -> BiHashMap<L, R, LS, RS>
    where
        I: IntoIterator<Item = (L, R)>,
    {
        let iter = iter.into_iter();
        // Pre-size from the lower bound only, as std collections do. The upper bound can be a
        // gross over-estimate for adapters such as `filter`/`take_while`. Trusting it could
        // request a huge allocation (or a capacity overflow panic, e.g. for
        // `(0..usize::MAX).take_while(..)`) for an iterator that yields only a few pairs.
        let (lower, _) = iter.size_hint();
        let mut bimap = BiHashMap::with_capacity_and_hashers(lower, LS::default(), RS::default());
        for (left, right) in iter {
            bimap.insert(left, right);
        }
        bimap
    }
}
impl<'a, L, R, LS, RS> IntoIterator for &'a BiHashMap<L, R, LS, RS>
where
L: Eq + Hash,
R: Eq + Hash,
{
type Item = (&'a L, &'a R);
type IntoIter = Iter<'a, L, R>;
fn into_iter(self) -> Iter<'a, L, R> {
self.iter()
}
}
impl<L, R, LS, RS> IntoIterator for BiHashMap<L, R, LS, RS>
where
L: Eq + Hash,
R: Eq + Hash,
{
type Item = (L, R);
type IntoIter = IntoIter<L, R>;
fn into_iter(self) -> IntoIter<L, R> {
IntoIter {
inner: self.left2right.into_iter(),
}
}
}
impl<L, R, LS, RS> Extend<(L, R)> for BiHashMap<L, R, LS, RS>
where
    L: Eq + Hash,
    R: Eq + Hash,
    LS: BuildHasher,
    RS: BuildHasher,
{
    /// Inserts every pair from `iter` in order, with the same overwrite semantics as
    /// `BiHashMap::insert`.
    fn extend<T: IntoIterator<Item = (L, R)>>(&mut self, iter: T) {
        for (left, right) in iter {
            self.insert(left, right);
        }
    }
}
impl<L: Eq + Hash, R: Eq + Hash, LS: BuildHasher, RS: BuildHasher> PartialEq
    for BiHashMap<L, R, LS, RS>
{
    /// Two bimaps are equal when they contain exactly the same left-right pairs.
    fn eq(&self, other: &Self) -> bool {
        // `right2left` always mirrors `left2right`, so comparing one direction is enough.
        let (ours, theirs) = (&self.left2right, &other.left2right);
        ours == theirs
    }
}
/// An owning iterator over the left-right pairs of a `BiHashMap`, in arbitrary order.
///
/// Created by calling `into_iter` on a `BiHashMap` by value.
#[derive(Debug)]
pub struct IntoIter<L, R> {
    inner: hash_map::IntoIter<Ref<L>, Ref<R>>,
}

impl<L, R> ExactSizeIterator for IntoIter<L, R> {}

impl<L, R> FusedIterator for IntoIter<L, R> {}

impl<L, R> Iterator for IntoIter<L, R> {
    type Item = (L, R);

    fn next(&mut self) -> Option<Self::Item> {
        let (left, right) = self.inner.next()?;
        // `BiHashMap::into_iter` has already dropped the right-to-left map, so each `Ref`
        // yielded here is the only remaining handle to its value.
        let left = Rc::try_unwrap(left.0)
            .ok()
            .expect("BiHashMap IntoIter: left value still shared");
        let right = Rc::try_unwrap(right.0)
            .ok()
            .expect("BiHashMap IntoIter: right value still shared");
        Some((left, right))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}
/// An iterator over borrowed left-right pairs of a `BiHashMap`, in arbitrary order.
///
/// Created by `BiHashMap::iter` or by iterating over `&BiHashMap`.
#[derive(Debug)]
pub struct Iter<'a, L, R> {
    inner: hash_map::Iter<'a, Ref<L>, Ref<R>>,
}

// Written by hand because `#[derive(Clone)]` would require `L: Clone` and `R: Clone`, even
// though cloning the iterator only copies borrowed state and never clones the values.
impl<L, R> Clone for Iter<'_, L, R> {
    fn clone(&self) -> Self {
        Iter {
            inner: self.inner.clone(),
        }
    }
}

impl<'a, L, R> ExactSizeIterator for Iter<'a, L, R> {}

impl<'a, L, R> FusedIterator for Iter<'a, L, R> {}

impl<'a, L, R> Iterator for Iter<'a, L, R> {
    type Item = (&'a L, &'a R);

    fn next(&mut self) -> Option<Self::Item> {
        self.inner.next().map(|(l, r)| (&*l.0, &*r.0))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}
/// An iterator over the left values of a `BiHashMap`, in arbitrary order.
///
/// Created by `BiHashMap::left_values`.
#[derive(Debug)]
pub struct LeftValues<'a, L, R> {
    inner: hash_map::Iter<'a, Ref<L>, Ref<R>>,
}

// Written by hand because `#[derive(Clone)]` would require `L: Clone` and `R: Clone`, even
// though cloning the iterator only copies borrowed state and never clones the values.
impl<L, R> Clone for LeftValues<'_, L, R> {
    fn clone(&self) -> Self {
        LeftValues {
            inner: self.inner.clone(),
        }
    }
}

impl<'a, L, R> ExactSizeIterator for LeftValues<'a, L, R> {}

impl<'a, L, R> FusedIterator for LeftValues<'a, L, R> {}

impl<'a, L, R> Iterator for LeftValues<'a, L, R> {
    type Item = &'a L;

    fn next(&mut self) -> Option<Self::Item> {
        self.inner.next().map(|(l, _)| &*l.0)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}
/// An iterator over the right values of a `BiHashMap`, in arbitrary order.
///
/// Created by `BiHashMap::right_values`.
#[derive(Debug)]
pub struct RightValues<'a, L, R> {
    inner: hash_map::Iter<'a, Ref<R>, Ref<L>>,
}

// Written by hand because `#[derive(Clone)]` would require `L: Clone` and `R: Clone`, even
// though cloning the iterator only copies borrowed state and never clones the values.
impl<L, R> Clone for RightValues<'_, L, R> {
    fn clone(&self) -> Self {
        RightValues {
            inner: self.inner.clone(),
        }
    }
}

impl<'a, L, R> ExactSizeIterator for RightValues<'a, L, R> {}

impl<'a, L, R> FusedIterator for RightValues<'a, L, R> {}

impl<'a, L, R> Iterator for RightValues<'a, L, R> {
    type Item = &'a R;

    fn next(&mut self) -> Option<Self::Item> {
        self.inner.next().map(|(r, _)| &*r.0)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}
// SAFETY: `BiHashMap` stores its values in `Rc`s (via `Ref`), which are not `Send`, so this
// impl is not derived automatically.
// In this file, both `Rc` handles of a value are owned by the two internal maps of the same
// `BiHashMap`:
// - they are cloned only in `insert_unchecked`, which takes `&mut self`;
// - the public accessors hand out `&L` / `&R`, never the `Rc` itself.
// Moving the whole map to another thread therefore moves every handle to each allocation
// together, so the non-atomic reference counts are never touched from two threads.
// NOTE(review): this relies on `crate::mem::Ref` not exposing its inner `Rc` outside this
// module — confirm whenever that type changes.
unsafe impl<L, R, LS, RS> Send for BiHashMap<L, R, LS, RS>
where
    L: Send,
    R: Send,
    LS: Send,
    RS: Send,
{
}
// SAFETY: Through `&BiHashMap`, callers can only:
// - read `&L` / `&R` (the `get_by_*` methods, `iter`, `left_values`, `right_values`);
// - hash lookups through `&LS` / `&RS`;
// - build an independent map (`Clone` allocates fresh `Rc`s via `insert`).
// No `&self` method in this file clones or drops an `Rc`, so shared access never modifies
// the non-atomic reference counts. Sharing across threads then only needs the values and
// hashers themselves to be `Sync`.
// NOTE(review): this also assumes `Ref`'s `Debug`/`Hash`/`Eq` impls in `crate::mem` only
// read through the `Rc` — confirm whenever that type changes.
unsafe impl<L, R, LS, RS> Sync for BiHashMap<L, R, LS, RS>
where
    L: Sync,
    R: Sync,
    LS: Sync,
    RS: Sync,
{
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn clone() {
        let mut bimap = BiHashMap::new();
        bimap.insert('a', 1);
        bimap.insert('b', 2);
        let bimap2 = bimap.clone();
        assert_eq!(bimap, bimap2);
    }

    #[test]
    fn deep_clone() {
        let mut bimap = BiHashMap::new();
        bimap.insert('a', 1);
        bimap.insert('b', 2);
        let mut bimap2 = bimap.clone();

        // These calls remove pairs, which unwraps the stored `Rc`s. They would panic if
        // `clone()` shared allocations with the original instead of deep cloning.
        assert_eq!(bimap.insert('b', 5), Overwritten::Left('b', 2));
        assert_eq!(bimap2.insert('a', 12), Overwritten::Left('a', 1));
        assert_eq!(bimap2.remove_by_left(&'a'), Some(('a', 12)));
        assert_eq!(bimap.remove_by_right(&2), None);

        // Each map must reflect only its own modifications.
        assert_eq!(bimap.len(), 2);
        assert_eq!(bimap.get_by_left(&'a'), Some(&1));
        assert_eq!(bimap.get_by_left(&'b'), Some(&5));
        assert_eq!(bimap2.len(), 1);
        assert_eq!(bimap2.get_by_left(&'b'), Some(&2));
        assert_eq!(bimap2.get_by_right(&2), Some(&'b'));
    }

    #[test]
    fn debug() {
        let mut bimap = BiHashMap::new();
        assert_eq!("{}", format!("{:?}", bimap));
        bimap.insert('a', 1);
        assert_eq!("{'a' <> 1}", format!("{:?}", bimap));
        bimap.insert('b', 2);
        let expected1 = "{'a' <> 1, 'b' <> 2}";
        let expected2 = "{'b' <> 2, 'a' <> 1}";
        let formatted = format!("{:?}", bimap);
        assert!(formatted == expected1 || formatted == expected2);
    }

    #[test]
    fn default() {
        let _ = BiHashMap::<char, i32>::default();
    }

    #[test]
    fn eq() {
        let mut bimap = BiHashMap::new();
        assert_eq!(bimap, bimap);
        bimap.insert('a', 1);
        assert_eq!(bimap, bimap);
        bimap.insert('b', 2);
        assert_eq!(bimap, bimap);
        let mut bimap2 = BiHashMap::new();
        assert_ne!(bimap, bimap2);
        bimap2.insert('a', 1);
        assert_ne!(bimap, bimap2);
        bimap2.insert('b', 2);
        assert_eq!(bimap, bimap2);
        bimap2.insert('c', 3);
        assert_ne!(bimap, bimap2);
    }

    #[test]
    fn from_iter() {
        let bimap = BiHashMap::from_iter(vec![
            ('a', 1),
            ('b', 2),
            ('c', 3),
            ('b', 2),
            ('a', 4),
            ('b', 3),
        ]);
        let mut bimap2 = BiHashMap::new();
        bimap2.insert('a', 4);
        bimap2.insert('b', 3);
        assert_eq!(bimap, bimap2);
    }

    #[test]
    fn into_iter() {
        let mut bimap = BiHashMap::new();
        bimap.insert('a', 3);
        bimap.insert('b', 2);
        bimap.insert('c', 1);
        let mut pairs = bimap.into_iter().collect::<Vec<_>>();
        pairs.sort();
        assert_eq!(pairs, vec![('a', 3), ('b', 2), ('c', 1)]);
    }

    #[test]
    fn into_iter_ref() {
        let mut bimap = BiHashMap::new();
        bimap.insert('a', 3);
        bimap.insert('b', 2);
        bimap.insert('c', 1);
        let mut pairs = (&bimap).into_iter().collect::<Vec<_>>();
        pairs.sort();
        assert_eq!(pairs, vec![(&'a', &3), (&'b', &2), (&'c', &1)]);
    }

    #[test]
    fn extend() {
        let mut bimap = BiHashMap::new();
        bimap.insert('a', 3);
        bimap.insert('b', 2);
        bimap.extend(vec![('c', 3), ('b', 1), ('a', 4)]);
        let mut bimap2 = BiHashMap::new();
        bimap2.insert('a', 4);
        bimap2.insert('b', 1);
        bimap2.insert('c', 3);
        assert_eq!(bimap, bimap2);
    }

    #[test]
    fn iter() {
        let mut bimap = BiHashMap::new();
        bimap.insert('a', 1);
        bimap.insert('b', 2);
        bimap.insert('c', 3);
        let mut pairs = bimap.iter().map(|(c, i)| (*c, *i)).collect::<Vec<_>>();
        pairs.sort();
        assert_eq!(pairs, vec![('a', 1), ('b', 2), ('c', 3)]);
    }

    #[test]
    fn left_values() {
        let mut bimap = BiHashMap::new();
        bimap.insert('a', 3);
        bimap.insert('b', 2);
        bimap.insert('c', 1);
        let mut left_values = bimap.left_values().cloned().collect::<Vec<_>>();
        left_values.sort();
        assert_eq!(left_values, vec!['a', 'b', 'c'])
    }

    #[test]
    fn right_values() {
        let mut bimap = BiHashMap::new();
        bimap.insert('a', 3);
        bimap.insert('b', 2);
        bimap.insert('c', 1);
        let mut right_values = bimap.right_values().cloned().collect::<Vec<_>>();
        right_values.sort();
        assert_eq!(right_values, vec![1, 2, 3])
    }

    #[test]
    fn capacity() {
        let bimap = BiHashMap::<char, i32>::with_capacity(10);
        assert!(bimap.capacity() >= 10);
    }

    #[test]
    fn with_hashers() {
        let s_left = hash_map::RandomState::new();
        let s_right = hash_map::RandomState::new();
        let mut bimap = BiHashMap::<char, i32>::with_hashers(s_left, s_right);
        bimap.insert('a', 42);
        assert_eq!(Some(&'a'), bimap.get_by_right(&42));
        assert_eq!(Some(&42), bimap.get_by_left(&'a'));
    }

    #[test]
    fn reserve() {
        let mut bimap = BiHashMap::<char, i32>::new();
        assert!(bimap.is_empty());
        assert_eq!(bimap.len(), 0);
        assert_eq!(bimap.capacity(), 0);
        bimap.reserve(10);
        assert!(bimap.is_empty());
        assert_eq!(bimap.len(), 0);
        assert!(bimap.capacity() >= 10);
    }

    #[test]
    fn shrink_to_fit() {
        let mut bimap = BiHashMap::<char, i32>::with_capacity(100);
        assert!(bimap.is_empty());
        assert_eq!(bimap.len(), 0);
        assert!(bimap.capacity() >= 100);
        bimap.insert('a', 1);
        bimap.insert('b', 2);
        assert!(!bimap.is_empty());
        assert_eq!(bimap.len(), 2);
        assert!(bimap.capacity() >= 100);
        bimap.shrink_to_fit();
        assert!(!bimap.is_empty());
        assert_eq!(bimap.len(), 2);
        assert!(bimap.capacity() >= 2);
    }

    #[test]
    fn shrink_to() {
        let mut bimap = BiHashMap::<char, i32>::with_capacity(100);
        assert!(bimap.is_empty());
        assert_eq!(bimap.len(), 0);
        assert!(bimap.capacity() >= 100);
        bimap.insert('a', 1);
        bimap.insert('b', 2);
        assert!(!bimap.is_empty());
        assert_eq!(bimap.len(), 2);
        assert!(bimap.capacity() >= 100);
        bimap.shrink_to(10);
        assert!(!bimap.is_empty());
        assert_eq!(bimap.len(), 2);
        assert!(bimap.capacity() >= 10);
        bimap.shrink_to(0);
        assert!(!bimap.is_empty());
        assert_eq!(bimap.len(), 2);
        assert!(bimap.capacity() >= 2);
    }

    #[test]
    fn clear() {
        let mut bimap = vec![('a', 1)].into_iter().collect::<BiHashMap<_, _>>();
        assert_eq!(bimap.len(), 1);
        assert!(!bimap.is_empty());
        bimap.clear();
        assert_eq!(bimap.len(), 0);
        assert!(bimap.is_empty());
    }

    #[test]
    fn get_contains() {
        let bimap = vec![('a', 1)].into_iter().collect::<BiHashMap<_, _>>();
        assert_eq!(bimap.get_by_left(&'a'), Some(&1));
        assert!(bimap.contains_left(&'a'));
        assert_eq!(bimap.get_by_left(&'b'), None);
        assert!(!bimap.contains_left(&'b'));
        assert_eq!(bimap.get_by_right(&1), Some(&'a'));
        assert!(bimap.contains_right(&1));
        assert_eq!(bimap.get_by_right(&2), None);
        assert!(!bimap.contains_right(&2));
    }

    #[test]
    fn insert() {
        let mut bimap = BiHashMap::new();
        assert_eq!(bimap.insert('a', 1), Overwritten::Neither);
        assert_eq!(bimap.insert('a', 2), Overwritten::Left('a', 1));
        assert_eq!(bimap.insert('b', 2), Overwritten::Right('a', 2));
        assert_eq!(bimap.insert('b', 2), Overwritten::Pair('b', 2));
        assert_eq!(bimap.insert('c', 3), Overwritten::Neither);
        assert_eq!(bimap.insert('b', 3), Overwritten::Both(('b', 2), ('c', 3)));
    }

    #[test]
    fn insert_no_overwrite() {
        let mut bimap = BiHashMap::new();
        assert!(bimap.insert_no_overwrite('a', 1).is_ok());
        assert!(bimap.insert_no_overwrite('a', 2).is_err());
        assert!(bimap.insert_no_overwrite('b', 1).is_err());
    }

    #[test]
    fn retain_calls_f_once() {
        let mut bimap = BiHashMap::new();
        bimap.insert('a', 1);
        bimap.insert('b', 2);
        bimap.insert('c', 3);
        let mut i = 0;
        bimap.retain(|_l, _r| {
            i += 1;
            i <= 1
        });
        assert_eq!(bimap.len(), 1);
        assert_eq!(i, 3);
    }

    #[test]
    fn retain_keeps_both_directions_in_sync() {
        let mut bimap = BiHashMap::<char, i32>::new();
        bimap.insert('a', 1);
        bimap.insert('b', 2);
        bimap.insert('c', 3);
        bimap.retain(|_l, r| *r % 2 == 1);
        assert_eq!(bimap.len(), 2);
        // The dropped pair must be gone from both directions.
        assert!(!bimap.contains_left(&'b'));
        assert!(!bimap.contains_right(&2));
        assert_eq!(bimap.right_values().count(), 2);
        assert_eq!(bimap.get_by_right(&1), Some(&'a'));
        assert_eq!(bimap.get_by_right(&3), Some(&'c'));
    }
}