1use serde::de::{SeqAccess, Visitor};
16use serde::ser::SerializeSeq;
17use serde::{Deserialize, Deserializer, Serialize, Serializer};
18use std::borrow::Borrow;
19use std::collections::{BTreeMap, BTreeSet};
20
/// A multimap: every key is bound to an ordered set of distinct values.
///
/// Keys and values are both kept sorted, so [`MMap::iter`] yields `(key, value)`
/// pairs ordered by key and then by value. A key is only stored while it has at
/// least one value bound to it.
#[derive(Clone, Debug, PartialEq)]
pub struct MMap<K: Ord, V: Ord> {
    /// Values bound to each key. Invariant: no stored set is empty.
    map: BTreeMap<K, BTreeSet<V>>,
    /// Always-empty set, used to hand out an empty iterator for absent keys.
    empty_set: BTreeSet<V>,
}

impl<K: Ord, V: Ord> Default for MMap<K, V> {
    /// Equivalent to [`MMap::new`].
    fn default() -> MMap<K, V> {
        MMap::new()
    }
}

impl<K: Ord, V: Ord> MMap<K, V> {
    /// Creates an empty multimap.
    pub fn new() -> MMap<K, V> {
        MMap {
            map: BTreeMap::new(),
            empty_set: BTreeSet::new(),
        }
    }

    /// Returns the values bound to `key` in ascending order.
    ///
    /// An absent key yields an empty iterator.
    pub fn get<Q>(&'_ self, key: &Q) -> Box<dyn Iterator<Item = &'_ V> + '_>
    where
        K: Borrow<Q>,
        Q: Ord + ?Sized,
    {
        Box::new(self.map.get(key).unwrap_or(&self.empty_set).iter())
    }

    /// Returns the values bound to `key` that are greater than or equal to
    /// `val`, in ascending order.
    ///
    /// An absent key yields an empty iterator.
    pub fn get_from<Q, R>(&'_ self, key: &Q, val: &R) -> Box<dyn Iterator<Item = &'_ V> + '_>
    where
        K: Borrow<Q>,
        Q: Ord + ?Sized,
        V: Borrow<R>,
        // `R` must be `Sized`: the `val..` range relies on `RangeFrom<&R>: RangeBounds<R>`.
        R: Ord,
    {
        Box::new(self.map.get(key).unwrap_or(&self.empty_set).range(val..))
    }

    /// Binds `val` to `key`. Inserting a pair that is already present is a no-op.
    pub fn insert(&mut self, key: K, val: V) {
        self.map.entry(key).or_default().insert(val);
    }

    /// Removes the binding of `val` to `key`.
    ///
    /// Returns `true` if the pair was present. When the last value of a key is
    /// removed, the key itself is dropped so that no empty sets are retained.
    pub fn remove<Q, R>(&mut self, key: &Q, val: &R) -> bool
    where
        K: Borrow<Q>,
        Q: Ord + ?Sized,
        V: Borrow<R>,
        R: Ord + ?Sized,
    {
        // `key` is already a reference. Passing `&key` would look up a `&&Q`,
        // which needs `K: Borrow<&Q>`, and these bounds do not provide that.
        if let Some(set) = self.map.get_mut(key) {
            let removed = set.remove(val);
            if set.is_empty() {
                // Preserve the invariant that no key maps to an empty set.
                self.map.remove(key);
            }
            removed
        } else {
            false
        }
    }

    /// Removes `key` together with every value bound to it.
    pub fn remove_all<Q>(&mut self, key: &Q)
    where
        K: Borrow<Q>,
        Q: Ord + ?Sized,
    {
        self.map.remove(key);
    }

    /// Returns `true` if `val` is bound to `key`.
    pub fn contains<Q, R>(&self, key: &Q, val: &R) -> bool
    where
        K: Borrow<Q>,
        Q: Ord + ?Sized,
        V: Borrow<R>,
        R: Ord + ?Sized,
    {
        match self.map.get(key) {
            Some(values) => values.contains(val),
            None => false,
        }
    }

    /// Iterates over every `(key, value)` pair, ordered by key and then by value.
    pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> {
        self.map
            .iter()
            .flat_map(|(key, values)| values.iter().map(move |value| (key, value)))
    }
}
118
119impl<K: Ord + Serialize, V: Ord + Serialize> Serialize for MMap<K, V> {
120 fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
121 let mut seq = serializer.serialize_seq(None)?;
122 for (k, v) in self.iter() {
123 seq.serialize_element(&(k, v))?;
124 }
125 seq.end()
126 }
127}
128
129impl<'de, K: Ord + Deserialize<'de>, V: Ord + Deserialize<'de>> Deserialize<'de> for MMap<K, V> {
130 fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
131 deserializer.deserialize_seq(MMapVisitor {
132 x: std::marker::PhantomData,
133 })
134 }
135}
136
/// Serde visitor that rebuilds an `MMap` from a sequence of `(key, value)` tuples.
///
/// It carries no data; the marker only records the key and value types.
struct MMapVisitor<K, V> {
    x: ::std::marker::PhantomData<(K, V)>,
}
140
141impl<'de, K: Ord + Deserialize<'de>, V: Ord + Deserialize<'de>> Visitor<'de> for MMapVisitor<K, V> {
142 type Value = MMap<K, V>;
143
144 fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145 write!(formatter, "a sequence of tuples")
146 }
147
148 fn visit_seq<S: SeqAccess<'de>>(self, mut access: S) -> Result<Self::Value, S::Error> {
149 let mut ret = MMap::new();
150 while let Some((key, val)) = access.next_element()? {
151 ret.insert(key, val);
152 }
153 Ok(ret)
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::MMap;

    #[test]
    fn get_empty() {
        let mut map = MMap::new();
        assert!(map.get(&1).next().is_none());
        map.insert(1, 2);
        assert!(map.get(&1).next().is_some());
        assert!(map.get(&2).next().is_none());
    }

    #[test]
    fn get_many() {
        let mut map = MMap::new();
        map.insert(1, 2);
        map.insert(1, 3);
        map.insert(1, 2);
        map.insert(1, 1);
        assert_eq!(map.get(&1).cloned().collect::<Vec<_>>(), vec![1, 2, 3]);
    }

    #[test]
    fn get_from() {
        let mut map = MMap::new();
        for &v in &[7, 1, 5, 3] {
            map.insert(1, v);
        }
        assert_eq!(map.get_from(&1, &4).cloned().collect::<Vec<_>>(), vec![5, 7]);
        assert_eq!(map.get_from(&1, &3).cloned().collect::<Vec<_>>(), vec![3, 5, 7]);
        assert!(map.get_from(&2, &0).next().is_none());
    }

    #[test]
    fn contains() {
        let mut map = MMap::new();
        map.insert(1, 2);
        map.insert(1, 3);
        assert!(map.contains(&1, &2));
        assert!(!map.contains(&2, &1));
        assert!(!map.contains(&1, &4));
    }

    #[test]
    fn remove() {
        let mut map = MMap::new();
        map.insert(1, 2);
        map.insert(1, 3);
        assert!(map.remove(&1, &2));
        assert!(!map.remove(&1, &2));
        assert!(!map.remove(&2, &3));
        assert!(map.contains(&1, &3));
        assert!(map.remove(&1, &3));
        // Removing the last value of a key drops the key entirely.
        assert_eq!(map, MMap::new());
    }

    #[test]
    fn remove_all() {
        let mut map = MMap::new();
        map.insert(1, 2);
        map.insert(1, 3);
        map.insert(2, 4);
        map.remove_all(&1);
        assert!(map.get(&1).next().is_none());
        assert!(map.contains(&2, &4));
    }

    #[test]
    fn iter_is_ordered() {
        let mut map = MMap::new();
        map.insert(2, 1);
        map.insert(1, 3);
        map.insert(1, 2);
        let pairs: Vec<(i32, i32)> = map.iter().map(|(&k, &v)| (k, v)).collect();
        assert_eq!(pairs, vec![(1, 2), (1, 3), (2, 1)]);
    }

    #[test]
    fn serde() {
        let mut map = MMap::new();
        map.insert(1, 2);
        map.insert(1, 3);

        let mut buf = Vec::new();
        serde_yaml::to_writer(&mut buf, &map).unwrap();
        let map2: MMap<_, _> = serde_yaml::from_reader(&buf[..]).unwrap();
        assert_eq!(map, map2);
    }
}