use crate::{
mem::{Ref, Wrapper},
Overwritten,
};
use alloc::{
collections::{btree_map, BTreeMap},
rc::Rc,
};
use core::{
borrow::Borrow,
cmp::Ordering,
fmt,
hash::{Hash, Hasher},
iter::{Extend, FromIterator, FusedIterator},
ops::RangeBounds,
};
/// A bidirectional map between left values of type `L` and right values of type `R`,
/// backed by two `BTreeMap`s.
///
/// Every left value is paired with exactly one right value and vice versa. Each value
/// is allocated once, in an `Rc` whose two handles live in the two inner maps, so the
/// pairs can be looked up from either side.
pub struct BiBTreeMap<L, R> {
    // Maps each left value to its paired right value; keyed (and iterated) by left value.
    left2right: BTreeMap<Ref<L>, Ref<R>>,
    // Mirror of `left2right`: maps each right value back to its paired left value.
    right2left: BTreeMap<Ref<R>, Ref<L>>,
}
impl<L, R> BiBTreeMap<L, R>
where
L: Ord,
R: Ord,
{
pub fn new() -> Self {
Self {
left2right: BTreeMap::new(),
right2left: BTreeMap::new(),
}
}
pub fn len(&self) -> usize {
self.left2right.len()
}
pub fn is_empty(&self) -> bool {
self.left2right.is_empty()
}
pub fn clear(&mut self) {
self.left2right.clear();
self.right2left.clear();
}
pub fn iter(&self) -> Iter<'_, L, R> {
Iter {
inner: self.left2right.iter(),
}
}
pub fn left_values(&self) -> LeftValues<'_, L, R> {
LeftValues {
inner: self.left2right.iter(),
}
}
pub fn right_values(&self) -> RightValues<'_, L, R> {
RightValues {
inner: self.right2left.iter(),
}
}
pub fn get_by_left<Q>(&self, left: &Q) -> Option<&R>
where
L: Borrow<Q>,
Q: Ord + ?Sized,
{
self.left2right.get(Wrapper::wrap(left)).map(|l| &*l.0)
}
pub fn get_by_right<Q>(&self, right: &Q) -> Option<&L>
where
R: Borrow<Q>,
Q: Ord + ?Sized,
{
self.right2left.get(Wrapper::wrap(right)).map(|r| &*r.0)
}
pub fn contains_left<Q>(&self, left: &Q) -> bool
where
L: Borrow<Q>,
Q: Ord + ?Sized,
{
self.left2right.contains_key(Wrapper::wrap(left))
}
pub fn contains_right<Q>(&self, right: &Q) -> bool
where
R: Borrow<Q>,
Q: Ord + ?Sized,
{
self.right2left.contains_key(Wrapper::wrap(right))
}
pub fn remove_by_left<Q>(&mut self, left: &Q) -> Option<(L, R)>
where
L: Borrow<Q>,
Q: Ord + ?Sized,
{
self.left2right.remove(Wrapper::wrap(left)).map(|right_rc| {
let left_rc = self.right2left.remove(&right_rc).unwrap();
(
Rc::try_unwrap(left_rc.0).ok().unwrap(),
Rc::try_unwrap(right_rc.0).ok().unwrap(),
)
})
}
pub fn remove_by_right<Q>(&mut self, right: &Q) -> Option<(L, R)>
where
R: Borrow<Q>,
Q: Ord + ?Sized,
{
self.right2left.remove(Wrapper::wrap(right)).map(|left_rc| {
let right_rc = self.left2right.remove(&left_rc).unwrap();
(
Rc::try_unwrap(left_rc.0).ok().unwrap(),
Rc::try_unwrap(right_rc.0).ok().unwrap(),
)
})
}
pub fn insert(&mut self, left: L, right: R) -> Overwritten<L, R> {
let retval = match (self.remove_by_left(&left), self.remove_by_right(&right)) {
(None, None) => Overwritten::Neither,
(None, Some(r_pair)) => Overwritten::Right(r_pair.0, r_pair.1),
(Some(l_pair), None) => {
if l_pair.1 == right {
Overwritten::Pair(l_pair.0, l_pair.1)
} else {
Overwritten::Left(l_pair.0, l_pair.1)
}
}
(Some(l_pair), Some(r_pair)) => Overwritten::Both(l_pair, r_pair),
};
self.insert_unchecked(left, right);
retval
}
pub fn insert_no_overwrite(&mut self, left: L, right: R) -> Result<(), (L, R)> {
if self.contains_left(&left) || self.contains_right(&right) {
Err((left, right))
} else {
self.insert_unchecked(left, right);
Ok(())
}
}
fn insert_unchecked(&mut self, left: L, right: R) {
let left = Ref(Rc::new(left));
let right_rc = Ref(Rc::new(right));
self.left2right.insert(left.clone(), right_rc.clone());
self.right2left.insert(right_rc, left);
}
pub fn left_range<T, A>(&self, range: A) -> LeftRange<'_, L, R>
where
L: Borrow<T>,
A: RangeBounds<T>,
T: Ord + ?Sized,
{
let start = Wrapper::wrap_bound(range.start_bound());
let end = Wrapper::wrap_bound(range.end_bound());
LeftRange {
inner: self.left2right.range::<Wrapper<_>, _>((start, end)),
}
}
pub fn right_range<T, A>(&self, range: A) -> RightRange<'_, L, R>
where
R: Borrow<T>,
A: RangeBounds<T>,
T: Ord + ?Sized,
{
let start = Wrapper::wrap_bound(range.start_bound());
let end = Wrapper::wrap_bound(range.end_bound());
RightRange {
inner: self.right2left.range::<Wrapper<_>, _>((start, end)),
}
}
}
impl<L, R> Clone for BiBTreeMap<L, R>
where
L: Clone + Ord,
R: Clone + Ord,
{
fn clone(&self) -> BiBTreeMap<L, R> {
self.iter().map(|(l, r)| (l.clone(), r.clone())).collect()
}
}
impl<L, R> fmt::Debug for BiBTreeMap<L, R>
where
    L: fmt::Debug + Ord,
    R: fmt::Debug + Ord,
{
    /// Writes the bimap as `{left <> right, left <> right, ...}`, pairs ordered by
    /// left value.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{{")?;
        // Nothing precedes the first pair; every later pair is preceded by ", ".
        let mut separator = "";
        for (left, right) in &self.left2right {
            write!(f, "{}{:?} <> {:?}", separator, left, right)?;
            separator = ", ";
        }
        write!(f, "}}")
    }
}
impl<L, R> Default for BiBTreeMap<L, R>
where
L: Ord,
R: Ord,
{
fn default() -> BiBTreeMap<L, R> {
BiBTreeMap {
left2right: BTreeMap::default(),
right2left: BTreeMap::default(),
}
}
}
// `Ord` implies `Eq` for both value types, so the `PartialEq` impl (which compares
// the pairs) is a total equivalence relation.
impl<L: Ord, R: Ord> Eq for BiBTreeMap<L, R> {}
impl<L, R> FromIterator<(L, R)> for BiBTreeMap<L, R>
where
L: Ord,
R: Ord,
{
fn from_iter<I>(iter: I) -> BiBTreeMap<L, R>
where
I: IntoIterator<Item = (L, R)>,
{
let mut bimap = BiBTreeMap::new();
for (left, right) in iter {
bimap.insert(left, right);
}
bimap
}
}
impl<'a, L, R> IntoIterator for &'a BiBTreeMap<L, R>
where
L: Ord,
R: Ord,
{
type Item = (&'a L, &'a R);
type IntoIter = Iter<'a, L, R>;
fn into_iter(self) -> Iter<'a, L, R> {
self.iter()
}
}
impl<L, R> IntoIterator for BiBTreeMap<L, R>
where
L: Ord,
R: Ord,
{
type Item = (L, R);
type IntoIter = IntoIter<L, R>;
fn into_iter(self) -> IntoIter<L, R> {
IntoIter {
inner: self.left2right.into_iter(),
}
}
}
impl<L, R> Extend<(L, R)> for BiBTreeMap<L, R>
where
    L: Ord,
    R: Ord,
{
    /// Inserts every pair in iteration order, overwriting conflicting pairs as
    /// `insert` does.
    fn extend<T: IntoIterator<Item = (L, R)>>(&mut self, iter: T) {
        for (left, right) in iter {
            self.insert(left, right);
        }
    }
}
impl<L, R> Ord for BiBTreeMap<L, R>
where
    L: Ord,
    R: Ord,
{
    /// Orders bimaps by comparing their left-to-right maps, i.e. lexicographically
    /// over the pairs taken in left-value order.
    fn cmp(&self, other: &Self) -> Ordering {
        Ord::cmp(&self.left2right, &other.left2right)
    }
}
impl<L, R> PartialEq for BiBTreeMap<L, R>
where
    L: Ord,
    R: Ord,
{
    /// Two bimaps are equal when they hold the same pairs. Comparing the
    /// left-to-right maps is enough, because the right-to-left maps mirror them.
    fn eq(&self, other: &Self) -> bool {
        PartialEq::eq(&self.left2right, &other.left2right)
    }
}
impl<L, R> PartialOrd for BiBTreeMap<L, R>
where
    L: Ord,
    R: Ord,
{
    /// Delegates to the `Ord` impl so the two orderings can never disagree.
    /// This is the canonical form for a type that implements `Ord`.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl<L, R> Hash for BiBTreeMap<L, R>
where
    L: Hash,
    R: Hash,
{
    /// Hashes only the left-to-right map. This stays consistent with `PartialEq`,
    /// which also compares only that map.
    fn hash<H: Hasher>(&self, state: &mut H) {
        Hash::hash(&self.left2right, state)
    }
}
/// An owning iterator over the left-right pairs of a `BiBTreeMap`, ordered by left
/// value. Created by `BiBTreeMap::into_iter`.
#[derive(Debug)]
pub struct IntoIter<L, R> {
    inner: btree_map::IntoIter<Ref<L>, Ref<R>>,
}

impl<L, R> IntoIter<L, R> {
    /// Moves a pair out of its `Rc` wrappers.
    ///
    /// # Panics
    ///
    /// Panics if either `Rc` is still shared. `BiBTreeMap::into_iter` consumes the
    /// map, so its right-to-left map is gone before this iterator can be advanced.
    /// That leaves each handle yielded by `inner` unique.
    fn unwrap_pair((left, right): (Ref<L>, Ref<R>)) -> (L, R) {
        (
            Rc::try_unwrap(left.0)
                .ok()
                .expect("BiBTreeMap::IntoIter: left value still shared"),
            Rc::try_unwrap(right.0)
                .ok()
                .expect("BiBTreeMap::IntoIter: right value still shared"),
        )
    }
}

impl<L, R> DoubleEndedIterator for IntoIter<L, R> {
    fn next_back(&mut self) -> Option<Self::Item> {
        self.inner.next_back().map(Self::unwrap_pair)
    }
}

// `btree_map::IntoIter` reports an exact `size_hint`, which is forwarded below.
impl<L, R> ExactSizeIterator for IntoIter<L, R> {}

impl<L, R> FusedIterator for IntoIter<L, R> {}

impl<L, R> Iterator for IntoIter<L, R> {
    type Item = (L, R);

    fn next(&mut self) -> Option<Self::Item> {
        self.inner.next().map(Self::unwrap_pair)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}
/// An iterator over borrowed left-right pairs of a `BiBTreeMap`, ordered by left
/// value. Created by `BiBTreeMap::iter`.
#[derive(Debug)]
pub struct Iter<'a, L, R> {
    inner: btree_map::Iter<'a, Ref<L>, Ref<R>>,
}

impl<'a, L, R> DoubleEndedIterator for Iter<'a, L, R> {
    fn next_back(&mut self) -> Option<Self::Item> {
        self.inner.next_back().map(|(l, r)| (&*l.0, &*r.0))
    }
}

// `btree_map::Iter` reports an exact `size_hint`, which is forwarded below.
impl<'a, L, R> ExactSizeIterator for Iter<'a, L, R> {}

impl<'a, L, R> FusedIterator for Iter<'a, L, R> {}

impl<'a, L, R> Iterator for Iter<'a, L, R> {
    type Item = (&'a L, &'a R);

    fn next(&mut self) -> Option<Self::Item> {
        self.inner.next().map(|(l, r)| (&*l.0, &*r.0))
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}
/// An iterator over the left values of a `BiBTreeMap`, in ascending order.
/// Created by `BiBTreeMap::left_values`.
#[derive(Debug)]
pub struct LeftValues<'a, L, R> {
    inner: btree_map::Iter<'a, Ref<L>, Ref<R>>,
}

impl<'a, L, R> DoubleEndedIterator for LeftValues<'a, L, R> {
    fn next_back(&mut self) -> Option<Self::Item> {
        self.inner.next_back().map(|(l, _)| &*l.0)
    }
}

// `btree_map::Iter` reports an exact `size_hint`, which is forwarded below.
impl<'a, L, R> ExactSizeIterator for LeftValues<'a, L, R> {}

impl<'a, L, R> FusedIterator for LeftValues<'a, L, R> {}

impl<'a, L, R> Iterator for LeftValues<'a, L, R> {
    type Item = &'a L;

    fn next(&mut self) -> Option<Self::Item> {
        self.inner.next().map(|(l, _)| &*l.0)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}
/// An iterator over the right values of a `BiBTreeMap`, in ascending order.
/// Created by `BiBTreeMap::right_values`.
#[derive(Debug)]
pub struct RightValues<'a, L, R> {
    inner: btree_map::Iter<'a, Ref<R>, Ref<L>>,
}

impl<'a, L, R> DoubleEndedIterator for RightValues<'a, L, R> {
    fn next_back(&mut self) -> Option<Self::Item> {
        self.inner.next_back().map(|(r, _)| &*r.0)
    }
}

// `btree_map::Iter` reports an exact `size_hint`, which is forwarded below.
impl<'a, L, R> ExactSizeIterator for RightValues<'a, L, R> {}

impl<'a, L, R> FusedIterator for RightValues<'a, L, R> {}

impl<'a, L, R> Iterator for RightValues<'a, L, R> {
    type Item = &'a R;

    fn next(&mut self) -> Option<Self::Item> {
        self.inner.next().map(|(r, _)| &*r.0)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}
/// An iterator over the pairs of a `BiBTreeMap` whose left values lie in a range,
/// ordered by left value. Created by `BiBTreeMap::left_range`.
///
/// This type deliberately does not implement `ExactSizeIterator`. The underlying
/// `btree_map::Range` is not exact-size: its `size_hint` does not report the true
/// length. A default `len()` would therefore fail its `size_hint` consistency
/// assertion and panic.
#[derive(Debug)]
pub struct LeftRange<'a, L, R> {
    inner: btree_map::Range<'a, Ref<L>, Ref<R>>,
}

impl<'a, L, R> DoubleEndedIterator for LeftRange<'a, L, R> {
    fn next_back(&mut self) -> Option<Self::Item> {
        self.inner.next_back().map(|(l, r)| (&*l.0, &*r.0))
    }
}

impl<'a, L, R> FusedIterator for LeftRange<'a, L, R> {}

impl<'a, L, R> Iterator for LeftRange<'a, L, R> {
    type Item = (&'a L, &'a R);

    fn next(&mut self) -> Option<Self::Item> {
        self.inner.next().map(|(l, r)| (&*l.0, &*r.0))
    }

    // Forwarded as-is: a valid (possibly loose) bound, not an exact length.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}
/// An iterator over the pairs of a `BiBTreeMap` whose right values lie in a range,
/// ordered by right value. Items are yielded as `(left, right)`. Created by
/// `BiBTreeMap::right_range`.
///
/// This type deliberately does not implement `ExactSizeIterator`. The underlying
/// `btree_map::Range` is not exact-size: its `size_hint` does not report the true
/// length. A default `len()` would therefore fail its `size_hint` consistency
/// assertion and panic.
#[derive(Debug)]
pub struct RightRange<'a, L, R> {
    inner: btree_map::Range<'a, Ref<R>, Ref<L>>,
}

impl<'a, L, R> DoubleEndedIterator for RightRange<'a, L, R> {
    fn next_back(&mut self) -> Option<Self::Item> {
        self.inner.next_back().map(|(r, l)| (&*l.0, &*r.0))
    }
}

impl<'a, L, R> FusedIterator for RightRange<'a, L, R> {}

impl<'a, L, R> Iterator for RightRange<'a, L, R> {
    type Item = (&'a L, &'a R);

    fn next(&mut self) -> Option<Self::Item> {
        self.inner.next().map(|(r, l)| (&*l.0, &*r.0))
    }

    // Forwarded as-is: a valid (possibly loose) bound, not an exact length.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}
// SAFETY: `Rc` is neither `Send` nor `Sync`, but the `Rc`s in this map never escape
// it. Each allocation is created in `insert_unchecked`, cloned exactly once, and both
// handles are stored in this map's two private `BTreeMap`s. Moving the map therefore
// moves every handle to each allocation together, so the non-atomic reference counts
// are never touched from two threads at once. Only `L` and `R` themselves cross the
// thread boundary, hence the `Send` bounds.
unsafe impl<L, R> Send for BiBTreeMap<L, R>
where
    L: Send,
    R: Send,
{
}
// SAFETY: through `&BiBTreeMap`, the code in this file only dereferences the `Rc`s to
// hand out `&L` / `&R` (iterators, `get_by_*`, `Clone`, `Debug`). Reference counts
// change only under `&mut self` or when the map is consumed by value. Sharing the map
// is therefore equivalent to sharing `&L` and `&R`, hence the `Sync` bounds.
// NOTE(review): this assumes `Ref`/`Wrapper` in `crate::mem` never clone or drop the
// inner `Rc` through a shared reference. Confirm there.
unsafe impl<L, R> Sync for BiBTreeMap<L, R>
where
    L: Sync,
    R: Sync,
{
}
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(not(feature = "std"))]
    use alloc::vec::Vec;

    #[test]
    fn clone() {
        let mut bimap = BiBTreeMap::new();
        bimap.insert('a', 1);
        bimap.insert('b', 2);
        let bimap2 = bimap.clone();
        assert_eq!(bimap, bimap2);
    }

    /// The clone must own separate allocations. Mutating either copy must leave the
    /// other intact, and removals (which unwrap the `Rc`s) must not panic.
    #[test]
    fn deep_clone() {
        let mut bimap = BiBTreeMap::new();
        bimap.insert('a', 1);
        bimap.insert('b', 2);
        let mut bimap2 = bimap.clone();
        bimap.insert('b', 5);
        bimap2.insert('a', 12);
        assert_eq!(bimap2.remove_by_left(&'a'), Some(('a', 12)));
        assert_eq!(bimap.remove_by_right(&2), None);
        assert_eq!(bimap, BiBTreeMap::from_iter(vec![('a', 1), ('b', 5)]));
        assert_eq!(bimap2, BiBTreeMap::from_iter(vec![('b', 2)]));
    }

    #[test]
    fn debug() {
        let mut bimap = BiBTreeMap::new();
        assert_eq!("{}", format!("{:?}", bimap));
        bimap.insert('a', 1);
        assert_eq!("{'a' <> 1}", format!("{:?}", bimap));
        bimap.insert('b', 2);
        assert_eq!("{'a' <> 1, 'b' <> 2}", format!("{:?}", bimap));
    }

    #[test]
    fn default() {
        let _ = BiBTreeMap::<char, i32>::default();
    }

    #[test]
    fn eq() {
        let mut bimap = BiBTreeMap::new();
        assert_eq!(bimap, bimap);
        bimap.insert('a', 1);
        assert_eq!(bimap, bimap);
        bimap.insert('b', 2);
        assert_eq!(bimap, bimap);
        let mut bimap2 = BiBTreeMap::new();
        assert_ne!(bimap, bimap2);
        bimap2.insert('a', 1);
        assert_ne!(bimap, bimap2);
        bimap2.insert('b', 2);
        assert_eq!(bimap, bimap2);
        bimap2.insert('c', 3);
        assert_ne!(bimap, bimap2);
    }

    #[test]
    fn from_iter() {
        let bimap = BiBTreeMap::from_iter(vec![
            ('a', 1),
            ('b', 2),
            ('c', 3),
            ('b', 2),
            ('a', 4),
            ('b', 3),
        ]);
        let mut bimap2 = BiBTreeMap::new();
        bimap2.insert('a', 4);
        bimap2.insert('b', 3);
        assert_eq!(bimap, bimap2);
    }

    #[test]
    fn into_iter() {
        let mut bimap = BiBTreeMap::new();
        bimap.insert('a', 3);
        bimap.insert('b', 2);
        bimap.insert('c', 1);
        let pairs = bimap.into_iter().collect::<Vec<_>>();
        assert_eq!(pairs, vec![('a', 3), ('b', 2), ('c', 1)]);
    }

    #[test]
    fn into_iter_ref() {
        let mut bimap = BiBTreeMap::new();
        bimap.insert('a', 3);
        bimap.insert('b', 2);
        bimap.insert('c', 1);
        let pairs = (&bimap).into_iter().collect::<Vec<_>>();
        assert_eq!(pairs, vec![(&'a', &3), (&'b', &2), (&'c', &1)]);
    }

    #[test]
    fn extend() {
        let mut bimap = BiBTreeMap::new();
        bimap.insert('a', 3);
        bimap.insert('b', 2);
        bimap.extend(vec![('c', 3), ('b', 1), ('a', 4)]);
        let mut bimap2 = BiBTreeMap::new();
        bimap2.insert('a', 4);
        bimap2.insert('b', 1);
        bimap2.insert('c', 3);
        assert_eq!(bimap, bimap2);
    }

    #[test]
    fn cmp() {
        let bimap = BiBTreeMap::from_iter(vec![('a', 2)]);
        let bimap2 = BiBTreeMap::from_iter(vec![('b', 1)]);
        assert_eq!(bimap.partial_cmp(&bimap2), Some(Ordering::Less));
        assert_eq!(bimap.cmp(&bimap2), Ordering::Less);
        assert_eq!(bimap2.partial_cmp(&bimap), Some(Ordering::Greater));
        assert_eq!(bimap2.cmp(&bimap), Ordering::Greater);
        assert_eq!(bimap.cmp(&bimap), Ordering::Equal);
        assert_eq!(bimap2.cmp(&bimap2), Ordering::Equal);
    }

    #[test]
    fn iter() {
        let mut bimap = BiBTreeMap::new();
        bimap.insert('a', 1);
        bimap.insert('b', 2);
        bimap.insert('c', 3);
        let pairs = bimap.iter().map(|(c, i)| (*c, *i)).collect::<Vec<_>>();
        assert_eq!(pairs, vec![('a', 1), ('b', 2), ('c', 3)]);
    }

    #[test]
    fn iter_rev() {
        let mut bimap = BiBTreeMap::new();
        bimap.insert('a', 1);
        bimap.insert('b', 2);
        bimap.insert('c', 3);
        let mut iter = bimap.iter();
        assert_eq!(iter.next_back(), Some((&'c', &3)));
        assert_eq!(iter.next_back(), Some((&'b', &2)));
        assert_eq!(iter.next_back(), Some((&'a', &1)));
    }

    #[test]
    fn into_iter_rev() {
        let mut bimap = BiBTreeMap::new();
        bimap.insert('a', 1);
        bimap.insert('b', 2);
        bimap.insert('c', 3);
        let mut iter = bimap.into_iter();
        assert_eq!(iter.next_back(), Some(('c', 3)));
        assert_eq!(iter.next_back(), Some(('b', 2)));
        assert_eq!(iter.next_back(), Some(('a', 1)));
    }

    /// Iterators that claim `ExactSizeIterator` must report the true length.
    #[test]
    fn exact_size_iters() {
        let bimap = BiBTreeMap::from_iter(vec![('a', 3), ('b', 2), ('c', 1)]);
        assert_eq!(bimap.iter().len(), 3);
        assert_eq!(bimap.left_values().len(), 3);
        assert_eq!(bimap.right_values().len(), 3);
        let mut owned = bimap.clone().into_iter();
        owned.next();
        assert_eq!(owned.len(), 2);
    }

    #[test]
    fn left_values() {
        let mut bimap = BiBTreeMap::new();
        bimap.insert('a', 3);
        bimap.insert('b', 2);
        bimap.insert('c', 1);
        let left_values = bimap.left_values().cloned().collect::<Vec<_>>();
        assert_eq!(left_values, vec!['a', 'b', 'c'])
    }

    #[test]
    fn left_values_rev() {
        let mut bimap = BiBTreeMap::new();
        bimap.insert('a', 3);
        bimap.insert('b', 2);
        bimap.insert('c', 1);
        let mut iter = bimap.left_values();
        assert_eq!(iter.next_back(), Some(&'c'));
        assert_eq!(iter.next_back(), Some(&'b'));
        assert_eq!(iter.next_back(), Some(&'a'));
    }

    #[test]
    fn right_values() {
        let mut bimap = BiBTreeMap::new();
        bimap.insert('a', 3);
        bimap.insert('b', 2);
        bimap.insert('c', 1);
        let right_values = bimap.right_values().cloned().collect::<Vec<_>>();
        assert_eq!(right_values, vec![1, 2, 3])
    }

    #[test]
    fn right_values_rev() {
        let mut bimap = BiBTreeMap::new();
        bimap.insert('a', 3);
        bimap.insert('b', 2);
        bimap.insert('c', 1);
        let mut iter = bimap.right_values();
        assert_eq!(iter.next_back(), Some(&3));
        assert_eq!(iter.next_back(), Some(&2));
        assert_eq!(iter.next_back(), Some(&1));
    }

    #[test]
    fn left_range() {
        let mut bimap = BiBTreeMap::new();
        bimap.insert('a', 4);
        bimap.insert('b', 3);
        bimap.insert('c', 2);
        bimap.insert('d', 1);
        let left_range = bimap
            .left_range('b'..'d')
            .map(|(l, r)| (*l, *r))
            .collect::<Vec<_>>();
        assert_eq!(left_range, vec![('b', 3), ('c', 2)])
    }

    #[test]
    fn left_range_rev() {
        let mut bimap = BiBTreeMap::new();
        bimap.insert('a', 4);
        bimap.insert('b', 3);
        bimap.insert('c', 2);
        bimap.insert('d', 1);
        let mut left_range = bimap.left_range('b'..'d');
        assert_eq!(left_range.next_back(), Some((&'c', &2)));
        assert_eq!(left_range.next_back(), Some((&'b', &3)));
        assert_eq!(left_range.next_back(), None);
    }

    #[test]
    fn right_range() {
        let mut bimap = BiBTreeMap::new();
        bimap.insert('a', 4);
        bimap.insert('b', 3);
        bimap.insert('c', 2);
        bimap.insert('d', 1);
        let right_range = bimap
            .right_range(2..4)
            .map(|(l, r)| (*l, *r))
            .collect::<Vec<_>>();
        assert_eq!(right_range, vec![('c', 2), ('b', 3)])
    }

    #[test]
    fn right_range_rev() {
        let mut bimap = BiBTreeMap::new();
        bimap.insert('a', 4);
        bimap.insert('b', 3);
        bimap.insert('c', 2);
        bimap.insert('d', 1);
        let mut right_range = bimap.right_range(2..4);
        assert_eq!(right_range.next_back(), Some((&'b', &3)));
        assert_eq!(right_range.next_back(), Some((&'c', &2)));
        assert_eq!(right_range.next_back(), None);
    }

    #[test]
    fn clear() {
        let mut bimap = BiBTreeMap::from_iter(vec![('a', 1)]);
        assert_eq!(bimap.len(), 1);
        assert!(!bimap.is_empty());
        bimap.clear();
        assert_eq!(bimap.len(), 0);
        assert!(bimap.is_empty());
    }

    #[test]
    fn get_contains() {
        let bimap = BiBTreeMap::from_iter(vec![('a', 1)]);
        assert_eq!(bimap.get_by_left(&'a'), Some(&1));
        assert!(bimap.contains_left(&'a'));
        assert_eq!(bimap.get_by_left(&'b'), None);
        assert!(!bimap.contains_left(&'b'));
        assert_eq!(bimap.get_by_right(&1), Some(&'a'));
        assert!(bimap.contains_right(&1));
        assert_eq!(bimap.get_by_right(&2), None);
        assert!(!bimap.contains_right(&2));
    }

    #[test]
    fn insert() {
        let mut bimap = BiBTreeMap::new();
        assert_eq!(bimap.insert('a', 1), Overwritten::Neither);
        assert_eq!(bimap.insert('a', 2), Overwritten::Left('a', 1));
        assert_eq!(bimap.insert('b', 2), Overwritten::Right('a', 2));
        assert_eq!(bimap.insert('b', 2), Overwritten::Pair('b', 2));
        assert_eq!(bimap.insert('c', 3), Overwritten::Neither);
        assert_eq!(bimap.insert('b', 3), Overwritten::Both(('b', 2), ('c', 3)));
    }

    #[test]
    fn insert_no_overwrite() {
        let mut bimap = BiBTreeMap::new();
        assert!(bimap.insert_no_overwrite('a', 1).is_ok());
        assert!(bimap.insert_no_overwrite('a', 2).is_err());
        assert!(bimap.insert_no_overwrite('b', 1).is_err());
    }

    #[test]
    #[cfg(feature = "std")]
    fn hash() {
        use core::iter::{self, FromIterator};
        use std::collections::HashSet;
        let mut hashset = HashSet::new();
        hashset.insert(BiBTreeMap::new());
        hashset.insert(BiBTreeMap::from_iter(iter::once((0, '0'))));
        hashset.insert(BiBTreeMap::from_iter(vec![(0, '0'), (0, '1'), (1, '0')]));
        hashset.insert(BiBTreeMap::from_iter(vec![(1, '0'), (0, '1'), (0, '0')]));
        hashset.insert(BiBTreeMap::from_iter(vec![(0, '0'), (0, '1'), (1, '0')]));
        assert_eq!(
            hashset,
            HashSet::from_iter(vec![
                BiBTreeMap::new(),
                BiBTreeMap::from_iter(iter::once((0, '0'))),
                BiBTreeMap::from_iter(vec![(0, '0'), (1, '0'), (0, '1')]),
            ])
        );
    }
}