1use anyhow::Context;
30use async_observable::Observable;
31use async_std::sync::Mutex;
32use async_std::task::block_on;
33use std::collections::BTreeMap;
34use std::fmt::Debug;
35use std::hash::Hash;
36use std::ops::{Deref, DerefMut};
37use std::sync::Arc;
38
39#[derive(Clone, Debug)]
62pub struct SubscriptionMap<K, V>(Arc<Mutex<BTreeMap<K, SubscriptionEntry<V>>>>)
63where
64 K: Clone + Debug + Eq + Hash + Ord,
65 V: Clone + Debug;
66
67#[derive(Clone, Debug)]
69struct SubscriptionEntry<V>
70where
71 V: Clone + Debug,
72{
73 observable: Observable<V>,
74 rc: usize,
75}
76
77impl<V> SubscriptionEntry<V>
78where
79 V: Clone + Debug,
80{
81 pub fn new(value: V) -> Self {
82 Self {
83 observable: Observable::new(value),
84 rc: 0,
85 }
86 }
87}
88
89impl<K, V> SubscriptionMap<K, V>
90where
91 K: Clone + Debug + Eq + Hash + Ord,
92 V: Clone + Debug,
93{
94 pub fn new() -> Self {
96 Self(Arc::new(Mutex::new(BTreeMap::new())))
97 }
98
99 pub async fn get_or_insert(&self, key: K, value: V) -> SubscriptionRef<K, V> {
101 let mut map = self.0.lock().await;
102 let entry = {
103 let entry = SubscriptionEntry::new(value);
104 map.entry(key.clone()).or_insert(entry)
105 };
106
107 SubscriptionRef::new(key, self.clone(), entry)
108 }
109
110 #[cfg(test)]
111 async fn snapshot(&self) -> BTreeMap<K, SubscriptionEntry<V>> {
112 self.0.lock().await.deref().clone()
113 }
114
115 async fn remove(&self, key: &K) -> anyhow::Result<()> {
116 let mut map = self.0.lock().await;
117
118 let entry = map
119 .get(key)
120 .with_context(|| format!("unable remove not present key {:?} in {:#?}", key, self))?;
121
122 assert!(
123 entry.rc == 0,
124 "invalid removal of referenced subscription at {:?}",
125 key
126 );
127
128 map.remove(key);
129
130 Ok(())
131 }
132}
133
134impl<K, V> SubscriptionMap<K, V>
135where
136 K: Clone + Debug + Eq + Hash + Ord,
137 V: Clone + Debug + Eq,
138{
139 pub async fn publish_if_changed(&self, key: &K, value: V) -> anyhow::Result<bool> {
158 let mut map = self.0.lock().await;
159 let entry = map
160 .get_mut(key)
161 .with_context(|| format!("unable publish new version of not present key {:?}", key))?;
162
163 Ok(entry.observable.publish_if_changed(value))
164 }
165
166 pub async fn modify_and_publish<F, R>(&self, key: &K, modify: F) -> anyhow::Result<()>
184 where
185 F: FnOnce(&mut V) -> R,
186 {
187 let mut map = self.0.lock().await;
188 let entry = map
189 .get_mut(key)
190 .with_context(|| format!("unable modify not present key {:?}", key))?;
191
192 entry.observable.modify(|v| {
193 modify(v);
194 });
195
196 Ok(())
197 }
198}
199
200impl<K, V> Default for SubscriptionMap<K, V>
201where
202 K: Clone + Debug + Eq + Hash + Ord,
203 V: Clone + Debug,
204{
205 fn default() -> Self {
206 Self::new()
207 }
208}
209
210#[derive(Debug)]
214#[must_use = "entries are removed as soon as no one subscribes to them"]
215pub struct SubscriptionRef<K, V>
216where
217 K: Clone + Debug + Eq + Hash + Ord,
218 V: Clone + Debug,
219{
220 key: K,
221 owner: SubscriptionMap<K, V>,
222 observable: Observable<V>,
223}
224
225impl<K, V> SubscriptionRef<K, V>
226where
227 K: Clone + Debug + Eq + Hash + Ord,
228 V: Clone + Debug,
229{
230 fn new(key: K, owner: SubscriptionMap<K, V>, entry: &mut SubscriptionEntry<V>) -> Self {
231 entry.rc += 1;
232
233 Self {
234 key,
235 owner,
236 observable: entry.observable.clone(),
237 }
238 }
239}
240
241impl<K, V> Deref for SubscriptionRef<K, V>
242where
243 K: Clone + Debug + Eq + Hash + Ord,
244 V: Clone + Debug,
245{
246 type Target = Observable<V>;
247
248 fn deref(&self) -> &Self::Target {
249 &self.observable
250 }
251}
252
253impl<K, V> DerefMut for SubscriptionRef<K, V>
254where
255 K: Clone + Debug + Eq + Hash + Ord,
256 V: Clone + Debug,
257{
258 fn deref_mut(&mut self) -> &mut Self::Target {
259 &mut self.observable
260 }
261}
262
263impl<K, V> Drop for SubscriptionRef<K, V>
264where
265 K: Clone + Debug + Eq + Hash + Ord,
266 V: Clone + Debug,
267{
268 fn drop(&mut self) {
269 log::trace!("drop for subscription ref for key {:?}", self.key);
270
271 let mut map = block_on(self.owner.0.lock());
272 let mut entry = match map.get_mut(&self.key) {
273 Some(entry) => entry,
274 None => {
275 log::error!("could not obtain rc in subscription map {:#?}", map.deref());
276 return;
277 }
278 };
279
280 entry.rc -= 1;
281
282 if entry.rc == 0 {
283 drop(map);
284 let res = block_on(self.owner.remove(&self.key));
285
286 if let Err(e) = res {
287 log::error!("error occurred while cleanup subscription ref {}", e);
288 }
289 }
290 }
291}
292
#[cfg(test)]
mod test {
    use super::SubscriptionMap;

    /// Number of entries currently stored in `map`.
    async fn len_of(map: &SubscriptionMap<usize, usize>) -> usize {
        map.snapshot().await.len()
    }

    /// Reference count of the entry under `key`; panics if the key is absent.
    async fn rc_of(map: &SubscriptionMap<usize, usize>, key: &usize) -> usize {
        let snapshot = map.snapshot().await;
        snapshot.get(key).unwrap().rc
    }

    /// Whether `map` currently holds an entry for `key`.
    async fn holds(map: &SubscriptionMap<usize, usize>, key: &usize) -> bool {
        let snapshot = map.snapshot().await;
        snapshot.get(key).is_some()
    }

    #[async_std::test]
    async fn should_immediately_remove_unused() {
        let map: SubscriptionMap<usize, usize> = SubscriptionMap::new();
        assert_eq!(len_of(&map).await, 0);

        drop(map.get_or_insert(1, 1).await);
        assert_eq!(len_of(&map).await, 0);

        drop(map.get_or_insert(2, 2).await);
        assert_eq!(len_of(&map).await, 0);
    }

    #[async_std::test]
    async fn should_remove_entries_on_ref_drop() {
        let map: SubscriptionMap<usize, usize> = SubscriptionMap::new();
        assert_eq!(len_of(&map).await, 0);

        let first = map.get_or_insert(1, 1).await;
        assert_eq!(len_of(&map).await, 1);

        let second = map.get_or_insert(2, 2).await;
        assert_eq!(len_of(&map).await, 2);

        drop(first);
        assert_eq!(len_of(&map).await, 1);
        assert!(!holds(&map, &1).await);
        assert!(holds(&map, &2).await);

        drop(second);
        assert_eq!(len_of(&map).await, 0);
        assert!(!holds(&map, &1).await);
        assert!(!holds(&map, &2).await);
    }

    #[async_std::test]
    async fn should_keep_track_of_ref_count() {
        let map: SubscriptionMap<usize, usize> = SubscriptionMap::new();
        assert_eq!(len_of(&map).await, 0);

        let first = map.get_or_insert(1, 1).await;
        assert_eq!(rc_of(&map, &1).await, 1);

        let second = map.get_or_insert(1, 1).await;
        assert_eq!(rc_of(&map, &1).await, 2);

        drop(first);
        assert_eq!(rc_of(&map, &1).await, 1);

        drop(second);
        assert_eq!(len_of(&map).await, 0);
    }

    #[async_std::test]
    #[should_panic]
    async fn shouldnt_remove_if_rc_is_not_zero() {
        let map: SubscriptionMap<usize, usize> = SubscriptionMap::new();
        assert_eq!(len_of(&map).await, 0);

        let _held = map.get_or_insert(1, 1).await;
        assert_eq!(rc_of(&map, &1).await, 1);

        map.remove(&1).await.unwrap();
    }
}