perigee/data_structures/
bimap.rs

1use serde::{Deserialize, Serialize};
2use std::collections::HashMap;
3use std::hash::Hash;
4use std::ops::Deref;
5use std::rc::Rc;
6
7/// A bidirectional HashMap.
8#[derive(Serialize, Deserialize, Debug, Clone, Default)]
9pub struct BiMap<A, B>
10where
11    A: Eq + Hash + ?Sized,
12    B: Eq + Hash + ?Sized,
13{
14    left_to_right: HashMap<Rc<A>, Rc<B>>,
15    right_to_left: HashMap<Rc<B>, Rc<A>>,
16}
17
18impl<A, B> BiMap<A, B>
19where
20    A: Eq + Hash,
21    B: Eq + Hash,
22{
23    pub fn new() -> Self {
24        BiMap {
25            left_to_right: HashMap::new(),
26            right_to_left: HashMap::new(),
27        }
28    }
29
30    pub fn insert(&mut self, a: A, b: B) {
31        let a = Rc::new(a);
32        let b = Rc::new(b);
33        self.left_to_right.insert(a.clone(), b.clone());
34        self.right_to_left.insert(b, a);
35    }
36
37    pub fn get(&self, a: &A) -> Option<&B> {
38        self.left_to_right.get(a).map(Deref::deref)
39    }
40
41    pub fn get_reverse(&self, b: &B) -> Option<&A> {
42        self.right_to_left.get(b).map(Deref::deref)
43    }
44
45    pub fn remove(&mut self, a: &A) -> bool {
46        self.left_to_right
47            .remove(a)
48            .and_then(|right| self.right_to_left.remove(&*right))
49            .is_some()
50    }
51
52    pub fn remove_reverse(&mut self, b: &B) -> bool {
53        self.right_to_left
54            .remove(b)
55            .and_then(|left| self.left_to_right.remove(&*left))
56            .is_some()
57    }
58}
59
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn insertion_and_retrieval() {
        let mut map = BiMap::new();
        map.insert("hi", 2);
        assert_eq!(map.get(&"hi"), Some(&2));
        assert_eq!(map.get_reverse(&2), Some(&"hi"));
    }

    #[test]
    fn insertion_and_removal() {
        let mut map = BiMap::new();

        map.insert("hi", 2);
        assert_eq!(map.get(&"hi"), Some(&2));
        assert!(map.remove(&"hi"));
        assert_eq!(map.get(&"hi"), None);
        // Removing by the left value must also clear the reverse direction.
        assert_eq!(map.get_reverse(&2), None);

        map.insert("bye", 3);
        assert_eq!(map.get_reverse(&3), Some(&"bye"));
        assert!(map.remove_reverse(&3));
        assert_eq!(map.get_reverse(&3), None);
        // Removing by the right value must also clear the forward direction.
        assert_eq!(map.get(&"bye"), None);
    }

    #[test]
    fn removing_absent_values_reports_false() {
        let mut map: BiMap<&str, i32> = BiMap::new();
        assert!(!map.remove(&"missing"));
        assert!(!map.remove_reverse(&0));

        map.insert("hi", 2);
        assert!(map.remove(&"hi"));
        // A second removal finds nothing left in either direction.
        assert!(!map.remove(&"hi"));
        assert!(!map.remove_reverse(&2));
    }
}