1use std::{borrow::Borrow, collections::HashSet, hash::Hash, mem::replace};
2
/// A stack of [`HashSet`]s forming nested scopes.
///
/// Lookups search from the innermost (most recently pushed) set outwards;
/// insertions always go into the innermost set.
#[derive(Debug, Clone)]
pub struct ChainSet<T> {
    /// Scope stack; the last element is the innermost scope.
    pub(crate) sets: Vec<HashSet<T>>,
}

impl<T: Hash + Eq> ChainSet<T> {
    /// Creates a chain whose single root scope is `set`.
    pub fn new(set: HashSet<T>) -> Self {
        Self { sets: vec![set] }
    }

    /// Inserts `value` into the innermost scope.
    ///
    /// Returns `true` if the innermost scope did not already contain `value`,
    /// mirroring [`HashSet::insert`]. Outer scopes are not consulted, so an
    /// equal value in an enclosing scope is shadowed rather than reported.
    ///
    /// If the chain has no scopes at all (only reachable by emptying the
    /// crate-visible `sets` field directly), a fresh root scope is created.
    pub fn insert(&mut self, value: T) -> bool {
        match self.sets.last_mut() {
            Some(innermost) => innermost.insert(value),
            None => {
                let mut root = HashSet::new();
                root.insert(value);
                self.sets.push(root);
                // The set was brand new, so `value` was definitely not present.
                true
            }
        }
    }

    /// Looks `value` up, searching from the innermost scope outwards, and
    /// returns a reference to the first matching stored element.
    pub fn get<Q: ?Sized>(&self, value: &Q) -> Option<&T>
    where
        T: Borrow<Q>,
        Q: Hash + Eq,
    {
        self.sets.iter().rev().find_map(|set| set.get(value))
    }

    /// Pushes a new, empty innermost scope.
    pub fn new_child(&mut self) {
        self.sets.push(HashSet::new());
    }

    /// Pushes `set` as the new innermost scope.
    pub fn new_child_with(&mut self, set: HashSet<T>) {
        self.sets.push(set);
    }

    /// Removes the innermost scope and returns its contents.
    ///
    /// When only the root scope remains, it is emptied in place (so the chain
    /// never loses its root) and its former contents are returned. Returns
    /// `None` only if the chain had no scopes at all.
    pub fn remove_child(&mut self) -> Option<HashSet<T>> {
        if self.sets.len() == 1 {
            Some(replace(&mut self.sets[0], HashSet::new()))
        } else {
            self.sets.pop()
        }
    }
}

impl<T: Hash + Eq> Default for ChainSet<T> {
    /// Creates a chain with a single empty root scope.
    fn default() -> Self {
        Self::new(HashSet::new())
    }
}
61
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn initialization() {
        let mut test_set = HashSet::new();
        test_set.insert("test");
        let chain_set = ChainSet::new(test_set);

        // `new` seeds exactly one root scope with the given set.
        assert_eq!(chain_set.sets.len(), 1);
        assert_eq!(chain_set.sets[0].get("test"), Some(&"test"));
    }

    #[test]
    fn initialization_default() {
        let chain_set: ChainSet<()> = ChainSet::default();

        assert_eq!(chain_set.sets.len(), 1);
        assert!(chain_set.sets[0].is_empty());
    }

    #[test]
    fn insert() {
        let mut chain_set = ChainSet::default();
        assert!(chain_set.insert("test"));
        // Re-inserting into the same scope does not count as a new value.
        assert!(!chain_set.insert("test"));

        assert_eq!(chain_set.sets[0].get("test"), Some(&"test"));
    }

    #[test]
    fn get() {
        let mut chain_set = ChainSet::default();
        chain_set.insert("test");

        assert_eq!(chain_set.get(&"test"), Some(&"test"));
    }

    #[test]
    fn get_none() {
        let chain_set: ChainSet<&str> = ChainSet::default();
        assert_eq!(chain_set.get(&"test"), None);
    }

    #[test]
    fn new_child() {
        let mut chain_set = ChainSet::default();
        chain_set.insert("test");
        chain_set.new_child();

        assert_eq!(chain_set.sets.len(), 2);
        assert!(chain_set.sets[1].is_empty());
        // Values from the enclosing scope stay visible through the child.
        assert_eq!(chain_set.get("test"), Some(&"test"));
    }

    #[test]
    fn new_child_with() {
        let mut chain_set = ChainSet::default();
        chain_set.insert("outer");
        let mut child = HashSet::new();
        child.insert("inner");
        chain_set.new_child_with(child);

        assert_eq!(chain_set.sets.len(), 2);
        assert_eq!(chain_set.get("inner"), Some(&"inner"));
        assert_eq!(chain_set.get("outer"), Some(&"outer"));
    }

    // NOTE(review): against the current `insert` (which only checks the
    // innermost scope) this assertion holds; confirm why it is ignored.
    #[test]
    #[ignore]
    fn scopes() {
        let mut chain_set = ChainSet::default();
        chain_set.insert("x");
        chain_set.insert("y");
        chain_set.new_child();
        assert!(chain_set.insert("x"));
    }

    #[test]
    fn remove_child() {
        let mut chain_set = ChainSet::default();
        chain_set.insert("x");
        chain_set.insert("y");
        chain_set.new_child();
        chain_set.insert("x");
        let ret = chain_set.remove_child().unwrap();
        assert_eq!(ret.get("x"), Some(&"x"));
        assert_eq!(chain_set.get("x"), Some(&"x"));
    }

    #[test]
    fn remove_child_length_1() {
        let mut chain_set = ChainSet::default();
        chain_set.insert("x");
        let ret = chain_set.remove_child().unwrap();

        // The root's former contents are handed back and the root is emptied,
        // but never removed.
        assert!(ret.contains("x"));
        assert_eq!(chain_set.get("x"), None);
        assert_eq!(chain_set.sets.len(), 1);
    }
}