jp_multimap/
lib.rs

1// Copyright 2018-2019 Joe Neeman.
2//
3// Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
4// http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
5// <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
6// option. This file may not be copied, modified, or distributed
7// except according to those terms.
8//
9// See the LICENSE-APACHE or LICENSE-MIT files at the top-level directory
10// of this distribution.
11
12// This is just a hacked-up multimap. Eventually, we'll need to move to a fully persistent (in the
13// functional-data-structure sense), on-disk multimap.
14
15use serde::de::{SeqAccess, Visitor};
16use serde::ser::SerializeSeq;
17use serde::{Deserialize, Deserializer, Serialize, Serializer};
18use std::borrow::Borrow;
19use std::collections::{BTreeMap, BTreeSet};
20
/// A multimap: an ordered map in which each key may be bound to a set of distinct values.
///
/// Keys, and the values bound to each key, are stored in ascending order. A key is never left
/// bound to an empty set: `remove` drops the key together with its last value, which is what
/// makes the derived `PartialEq` compare two maps purely by the bindings they contain.
#[derive(Clone, Debug, PartialEq)]
pub struct MMap<K: Ord, V: Ord> {
    map: BTreeMap<K, BTreeSet<V>>,
    // A permanently empty set. `get` and `get_from` iterate over it when the key is missing, so
    // that the "found" and "not found" cases yield the same iterator type. Nothing is ever
    // inserted into it.
    empty_set: BTreeSet<V>,
}

impl<K: Ord, V: Ord> Default for MMap<K, V> {
    /// Returns an empty multimap; equivalent to [`MMap::new`].
    fn default() -> MMap<K, V> {
        MMap::new()
    }
}

impl<K: Ord, V: Ord> MMap<K, V> {
    /// Creates an empty multimap.
    pub fn new() -> MMap<K, V> {
        MMap {
            map: BTreeMap::new(),
            empty_set: BTreeSet::new(),
        }
    }

    /// Returns an iterator over all the values associated with this key, in ascending order.
    ///
    /// The iterator is empty if the key has no values.
    // NOTE: the iterator is boxed on purpose. Returning `impl Iterator<Item = &'_ V> + '_` gave
    // lifetime errors in downstream crates, presumably because a return-position `impl Trait`
    // also captures the type parameter `Q`, tying the iterator to the lookup key as well as to
    // `self`. The boxed trait object only mentions the lifetime of `self`.
    pub fn get<Q>(&'_ self, key: &Q) -> Box<dyn Iterator<Item = &'_ V> + '_>
    where
        K: Borrow<Q>,
        Q: Ord + ?Sized,
    {
        let values = match self.map.get(key) {
            Some(set) => set,
            None => &self.empty_set,
        };
        Box::new(values.iter())
    }

    /// Returns an iterator over all the values associated with this key and that are greater than
    /// or equal to `val`, in ascending order.
    ///
    /// The iterator is empty if the key has no values, or none that are at least `val`.
    pub fn get_from<Q, R>(&'_ self, key: &Q, val: &R) -> Box<dyn Iterator<Item = &'_ V> + '_>
    where
        K: Borrow<Q>,
        Q: Ord + ?Sized,
        V: Borrow<R>,
        R: Ord + ?Sized,
    {
        use std::ops::Bound::{Included, Unbounded};

        let values = self.map.get(key).unwrap_or(&self.empty_set);
        // A `(Bound, Bound)` pair is used instead of `val..` because `RangeFrom<&R>` only
        // implements `RangeBounds<R>` for sized `R`, whereas the pair also works for unsized
        // borrowed forms such as `str`.
        Box::new(values.range::<R, _>((Included(val), Unbounded)))
    }

    /// Binds `val` to `key`. Does nothing if that binding is already present.
    pub fn insert(&mut self, key: K, val: V) {
        self.map.entry(key).or_default().insert(val);
    }

    /// Removes the binding of `val` to `key`.
    ///
    /// Returns `true` if the binding was present, and `false` otherwise.
    pub fn remove<Q, R>(&mut self, key: &Q, val: &R) -> bool
    where
        K: Borrow<Q>,
        Q: Ord + ?Sized,
        V: Borrow<R>,
        R: Ord + ?Sized,
    {
        let (removed, now_empty) = match self.map.get_mut(key) {
            Some(set) => (set.remove(val), set.is_empty()),
            None => return false,
        };
        // Remove empty sets entirely. Partly because it seems reasonable to get rid of unused
        // entries, but mostly because it makes the auto-derived PartialEq implementation
        // correct.
        if now_empty {
            self.map.remove(key);
        }
        removed
    }

    /// Removes every value bound to `key`. Does nothing if the key has no values.
    pub fn remove_all<Q>(&mut self, key: &Q)
    where
        K: Borrow<Q>,
        Q: Ord + ?Sized,
    {
        self.map.remove(key);
    }

    /// Returns `true` if `val` is bound to `key`.
    pub fn contains<Q, R>(&self, key: &Q, val: &R) -> bool
    where
        K: Borrow<Q>,
        Q: Ord + ?Sized,
        V: Borrow<R>,
        R: Ord + ?Sized,
    {
        self.map.get(key).map_or(false, |set| set.contains(val))
    }

    /// Returns an iterator over every `(key, value)` binding, ordered by key and then by value.
    pub fn iter(&self) -> impl Iterator<Item = (&K, &V)> {
        self.map
            .iter()
            .flat_map(|(key, vals)| vals.iter().map(move |val| (key, val)))
    }
}
118
119impl<K: Ord + Serialize, V: Ord + Serialize> Serialize for MMap<K, V> {
120    fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
121        let mut seq = serializer.serialize_seq(None)?;
122        for (k, v) in self.iter() {
123            seq.serialize_element(&(k, v))?;
124        }
125        seq.end()
126    }
127}
128
129impl<'de, K: Ord + Deserialize<'de>, V: Ord + Deserialize<'de>> Deserialize<'de> for MMap<K, V> {
130    fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
131        deserializer.deserialize_seq(MMapVisitor {
132            x: std::marker::PhantomData,
133        })
134    }
135}
136
137struct MMapVisitor<K, V> {
138    x: std::marker::PhantomData<(K, V)>,
139}
140
141impl<'de, K: Ord + Deserialize<'de>, V: Ord + Deserialize<'de>> Visitor<'de> for MMapVisitor<K, V> {
142    type Value = MMap<K, V>;
143
144    fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
145        write!(formatter, "a sequence of tuples")
146    }
147
148    fn visit_seq<S: SeqAccess<'de>>(self, mut access: S) -> Result<Self::Value, S::Error> {
149        let mut ret = MMap::new();
150        while let Some((key, val)) = access.next_element()? {
151            ret.insert(key, val);
152        }
153        Ok(ret)
154    }
155}
156
#[cfg(test)]
mod tests {
    use super::MMap;

    // Looking up a key yields nothing until a value has been bound to that exact key.
    #[test]
    fn get_empty() {
        let mut map = MMap::new();
        assert!(map.get(&1).next().is_none());
        map.insert(1, 2);
        assert!(map.get(&1).next().is_some());
        assert!(map.get(&2).next().is_none());
    }

    // Values come back sorted and without duplicates.
    #[test]
    fn get_many() {
        let mut map = MMap::new();
        map.insert(1, 2);
        map.insert(1, 3);
        map.insert(1, 2);
        map.insert(1, 1);
        assert_eq!(map.get(&1).cloned().collect::<Vec<_>>(), vec![1, 2, 3]);
    }

    // Only values greater than or equal to the lower bound are returned.
    #[test]
    fn get_from() {
        let mut map: MMap<i32, i32> = MMap::new();
        map.insert(1, 1);
        map.insert(1, 2);
        map.insert(1, 3);
        assert_eq!(map.get_from(&1, &2).cloned().collect::<Vec<_>>(), vec![2, 3]);
        assert_eq!(map.get_from(&1, &0).cloned().collect::<Vec<_>>(), vec![1, 2, 3]);
        assert!(map.get_from(&1, &4).next().is_none());
        assert!(map.get_from(&2, &0).next().is_none());
    }

    #[test]
    fn contains() {
        let mut map = MMap::new();
        map.insert(1, 2);
        map.insert(1, 3);
        assert!(map.contains(&1, &2));
        assert!(!map.contains(&2, &1));
        assert!(!map.contains(&1, &4));
    }

    // `remove` reports whether the binding existed, and drops a key along with its last value so
    // that the map compares equal to one that never held the key.
    #[test]
    fn remove() {
        let mut map: MMap<i32, i32> = MMap::new();
        map.insert(1, 2);
        map.insert(1, 3);
        assert!(!map.remove(&1, &4));
        assert!(!map.remove(&2, &2));
        assert!(map.remove(&1, &2));
        assert!(!map.contains(&1, &2));
        assert!(map.contains(&1, &3));
        assert!(map.remove(&1, &3));
        assert!(map.get(&1).next().is_none());
        assert_eq!(map, MMap::new());
    }

    #[test]
    fn remove_all() {
        let mut map: MMap<i32, i32> = MMap::new();
        map.insert(1, 2);
        map.insert(1, 3);
        map.insert(2, 4);
        map.remove_all(&1);
        map.remove_all(&5);
        assert!(map.get(&1).next().is_none());
        assert!(map.contains(&2, &4));
    }

    // Bindings are visited in key order, then value order.
    #[test]
    fn iter() {
        let mut map: MMap<i32, i32> = MMap::new();
        map.insert(2, 1);
        map.insert(1, 5);
        map.insert(1, 4);
        let pairs: Vec<(i32, i32)> = map.iter().map(|(k, v)| (*k, *v)).collect();
        assert_eq!(pairs, vec![(1, 4), (1, 5), (2, 1)]);
    }

    // A map survives a round trip through YAML unchanged.
    #[test]
    fn serde() {
        let mut map = MMap::new();
        map.insert(1, 2);
        map.insert(1, 3);

        let mut buf = Vec::new();
        serde_yaml::to_writer(&mut buf, &map).unwrap();
        let map2: MMap<_, _> = serde_yaml::from_reader(&buf[..]).unwrap();
        assert_eq!(map, map2);
    }
}
201}