// mpmc_map/lib.rs

1use std::{borrow::Borrow, sync::Arc};
2
3use arc_swap::{ArcSwap, Guard};
4use im::HashMap;
5use std::hash::Hash;
6use tokio::sync::{mpsc, oneshot};
7
8pub type AtomicOp<K, V> = Box<dyn FnOnce(&HashMap<K, V>) -> Option<HashMap<K, V>> + Send>;
9
10enum MpmcMapMutationOp<
11    K: Send + Sync + Hash + Clone + Eq + 'static,
12    V: Send + Clone + Sync + 'static,
13> {
14    Reset(Arc<HashMap<K, V>>),
15    Atomic(AtomicOp<K, V>),
16    Insert(K, V),
17    Remove(K),
18}
19
/// Reply sent back by the updater task after it has processed a mutation.
enum MpmcMapMutationResponse<V>
where
    V: Send + Clone + Sync + 'static,
{
    /// Nothing to report (e.g. no previous value existed, or the map was reset).
    None,
    /// Whether an atomic operation published a new map.
    Bool(bool),
    /// The value that was previously stored under the affected key.
    Value(V),
}
25
26struct MpmcMapMutation<
27    K: Send + Sync + Hash + Clone + Eq + 'static,
28    V: Send + Clone + Sync + 'static,
29> {
30    op: MpmcMapMutationOp<K, V>,
31    response: oneshot::Sender<MpmcMapMutationResponse<V>>,
32}
33
34#[derive(Debug)]
35pub struct MpmcMap<K: Send + Sync + Hash + Clone + Eq + 'static, V: Send + Clone + Sync + 'static> {
36    inner: Arc<ArcSwap<HashMap<K, V>>>,
37    sender: mpsc::Sender<MpmcMapMutation<K, V>>,
38}
39
40impl<K: Send + Sync + Hash + Clone + Eq + 'static, V: Send + Clone + Sync + 'static> Default
41    for MpmcMap<K, V>
42{
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48impl<K: Send + Sync + Hash + Clone + Eq + 'static, V: Send + Clone + Sync + 'static> Clone
49    for MpmcMap<K, V>
50{
51    fn clone(&self) -> Self {
52        Self {
53            inner: self.inner.clone(),
54            sender: self.sender.clone(),
55        }
56    }
57}
58
59impl<K: Send + Sync + Hash + Clone + Eq + 'static, V: Send + Clone + Sync + 'static> MpmcMap<K, V> {
60    pub fn new() -> Self {
61        let (sender, receiver) = mpsc::channel(512);
62
63        let new_self = MpmcMap {
64            inner: Arc::new(ArcSwap::new(Arc::new(HashMap::new()))),
65            sender,
66        };
67        tokio::spawn(Self::updater(new_self.inner.clone(), receiver));
68        new_self
69    }
70
71    async fn updater(
72        map: Arc<ArcSwap<HashMap<K, V>>>,
73        mut receiver: mpsc::Receiver<MpmcMapMutation<K, V>>,
74    ) {
75        while let Some(mutation) = receiver.recv().await {
76            match mutation.op {
77                MpmcMapMutationOp::Insert(key, value) => {
78                    let map_load = map.load();
79                    if let Some((old_value, prior)) = map_load.extract(&key) {
80                        let new_map = prior.update(key, value);
81                        map.store(Arc::new(new_map));
82                        mutation
83                            .response
84                            .send(MpmcMapMutationResponse::Value(old_value))
85                            .ok();
86                    } else {
87                        let new_map = map_load.update(key, value);
88                        map.store(Arc::new(new_map));
89                        mutation.response.send(MpmcMapMutationResponse::None).ok();
90                    }
91                }
92                MpmcMapMutationOp::Remove(key) => {
93                    if let Some((old_value, new_map)) = map.load().extract(&key) {
94                        map.store(Arc::new(new_map));
95                        mutation
96                            .response
97                            .send(MpmcMapMutationResponse::Value(old_value))
98                            .ok();
99                    } else {
100                        mutation.response.send(MpmcMapMutationResponse::None).ok();
101                    }
102                }
103                MpmcMapMutationOp::Reset(value) => {
104                    map.store(value);
105                    mutation.response.send(MpmcMapMutationResponse::None).ok();
106                }
107                MpmcMapMutationOp::Atomic(op) => {
108                    let map_load = map.load();
109                    let new_map = op(&**map_load);
110                    let mutated = new_map.is_some();
111                    if let Some(new_map) = new_map {
112                        map.store(Arc::new(new_map));
113                    }
114                    mutation.response.send(MpmcMapMutationResponse::Bool(mutated)).ok();
115                }
116            }
117        }
118    }
119
120    pub async fn insert(&self, key: K, value: V) -> Option<V> {
121        let (response, receiver) = oneshot::channel::<MpmcMapMutationResponse<V>>();
122        self.sender
123            .send(MpmcMapMutation {
124                op: MpmcMapMutationOp::Insert(key, value),
125                response,
126            })
127            .await
128            .ok()
129            .expect("failed to send insert mutation");
130        match receiver
131            .await
132            .expect("failed to receive mpmc map mutation response")
133        {
134            MpmcMapMutationResponse::None => None,
135            MpmcMapMutationResponse::Bool(_) => None,
136            MpmcMapMutationResponse::Value(v) => Some(v),
137        }
138    }
139
140    pub async fn remove(&self, key: K) -> Option<V> {
141        let (response, receiver) = oneshot::channel::<MpmcMapMutationResponse<V>>();
142        self.sender
143            .send(MpmcMapMutation {
144                op: MpmcMapMutationOp::Remove(key),
145                response,
146            })
147            .await
148            .ok()
149            .expect("failed to send insert mutation");
150        match receiver
151            .await
152            .expect("failed to receive mpmc map mutation response")
153        {
154            MpmcMapMutationResponse::None => None,
155            MpmcMapMutationResponse::Bool(_) => None,
156            MpmcMapMutationResponse::Value(v) => Some(v),
157        }
158    }
159
160    pub fn get<BK: ?Sized>(&self, key: &BK) -> Option<V>
161    where
162        BK: Hash + Eq,
163        K: Borrow<BK>,
164    {
165        self.inner.load().get(key).cloned()
166    }
167
168    pub fn contains_key<BK: ?Sized>(&self, key: &BK) -> bool
169    where
170        BK: Hash + Eq,
171        K: Borrow<BK>,
172    {
173        self.inner.load().contains_key(key)
174    }
175
176    pub fn inner_full(&self) -> Arc<HashMap<K, V>> {
177        self.inner.load_full()
178    }
179
180    pub fn inner(&self) -> Guard<Arc<HashMap<K, V>>> {
181        self.inner.load()
182    }
183
184    // pending updates will be applied to the new value
185    // this function should generally be used sparingly, as it can be hard to ensure correct semantics
186    // look at `MpmcMap::reset` instead
187    #[doc(hidden)]
188    pub fn reset_now(&self, value: Arc<HashMap<K, V>>) {
189        self.inner.store(value);
190    }
191
192    // this function will apply pending updates before reseting the internal map. Those updates will be lost, but will not mutate the new version of the hashmap, `value`.
193    pub async fn reset(&self, value: Arc<HashMap<K, V>>) {
194        let (response, receiver) = oneshot::channel::<MpmcMapMutationResponse<V>>();
195        self.sender
196            .send(MpmcMapMutation {
197                op: MpmcMapMutationOp::Reset(value),
198                response,
199            })
200            .await
201            .ok()
202            .expect("failed to send insert mutation");
203        receiver
204            .await
205            .expect("failed to receive mpmc map mutation response");
206    }
207
208    // performs an atomic mutation, returns true if a mutation took place (`op` returned `Some`)
209    pub async fn atomic(&self, op: AtomicOp<K, V>) -> bool {
210        let (response, receiver) = oneshot::channel::<MpmcMapMutationResponse<V>>();
211        self.sender
212            .send(MpmcMapMutation {
213                op: MpmcMapMutationOp::Atomic(op),
214                response,
215            })
216            .await
217            .ok()
218            .expect("failed to send insert mutation");
219        match receiver
220            .await
221            .expect("failed to receive mpmc map mutation response")
222        {
223            MpmcMapMutationResponse::None => false,
224            MpmcMapMutationResponse::Bool(b) => b,
225            MpmcMapMutationResponse::Value(_) => false,
226        }
227    }
228
229    pub fn len(&self) -> usize {
230        self.inner().len()
231    }
232
233    pub fn is_empty(&self) -> bool {
234        self.inner().is_empty()
235    }
236}