// async_map/versioned_map.rs

1use std::cell::{Ref, RefCell};
2
3use std::future::{ready, Future};
4
5use std::boxed::Box;
6use std::pin::Pin;
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::sync::{Arc, RwLock};
9use std::task::{Context, Poll, Waker};
10
11use crate::{AsyncKey, AsyncMap, AsyncStorable, FactoryBorrow};
12
13use futures::FutureExt;
14
15use im::HashMap;
16
17use tokio::sync::mpsc::{self, UnboundedSender};
18use tokio::sync::oneshot;
19
20enum MapAction<K: AsyncKey, V: AsyncStorable> {
21    GetOrCreate(
22        K,
23        Box<dyn FactoryBorrow<K, V>>,
24        oneshot::Sender<(V, MapHolder<K, V>)>,
25        Waker,
26    ),
27}
28
29struct MapReturnFuture<K: AsyncKey, V: AsyncStorable, B>
30where
31    B: FactoryBorrow<K, V> + Unpin,
32{
33    update_sender: UnboundedSender<MapAction<K, V>>,
34    key: K,
35    factory: Option<B>,
36    result_sender: Option<oneshot::Sender<(V, MapHolder<K, V>)>>,
37}
38
39impl<'a, K: AsyncKey, V: AsyncStorable, B> Future for MapReturnFuture<K, V, B>
40where
41    B: FactoryBorrow<K, V> + Unpin,
42{
43    type Output = ();
44    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
45        let mut mutable = self;
46
47        if mutable.result_sender.is_none() {
48            Poll::Ready(())
49        } else {
50            let result_sender = mutable.result_sender.take().unwrap();
51            match mutable.factory.take() {
52                None => {
53                    todo!()
54                }
55                Some(factory) => {
56                    match mutable.update_sender.send(MapAction::GetOrCreate(
57                        mutable.key.clone(),
58                        Box::new(factory),
59                        result_sender,
60                        cx.waker().clone(),
61                    )) {
62                        Ok(_) => Poll::Pending,
63                        Err(_) => Poll::Pending,
64                    }
65                }
66            }
67        }
68    }
69}
70
71#[derive(Clone)]
72struct MapHolder<K: AsyncKey, V: AsyncStorable> {
73    version: u64,
74    map: HashMap<K, V>,
75}
76
77pub struct VersionedMap<K: AsyncKey, V: AsyncStorable> {
78    latest_version: Arc<AtomicU64>,
79    map_holder: RefCell<MapHolder<K, V>>,
80    update_sender: UnboundedSender<MapAction<K, V>>,
81    update_receiver: UpdateReceiver<K, V>,
82    latest_map_holder: Arc<RwLock<MapHolder<K, V>>>,
83}
84
85struct UpdateReceiver<K: AsyncKey, V: AsyncStorable> {
86    receiver: RefCell<Option<oneshot::Receiver<MapHolder<K, V>>>>,
87}
88
89impl<K: AsyncKey, V: AsyncStorable> Default for UpdateReceiver<K, V> {
90    fn default() -> Self {
91        UpdateReceiver {
92            receiver: RefCell::new(None),
93        }
94    }
95}
96
97impl<K: AsyncKey, V: AsyncStorable> UpdateReceiver<K, V> {
98    pub fn updater(&self) -> MapUpdater<K, V> {
99        let (sender, receiver) = oneshot::channel();
100        // Note that any prior receiver will be lost. Since updates are
101        // linear, that is not an issue
102        self.receiver.replace(Some(receiver));
103        MapUpdater { sender }
104    }
105
106    pub fn get_update(&self) -> Option<MapHolder<K, V>> {
107        self.receiver.take().and_then(|mut receiver| {
108            match receiver.try_recv() {
109                Err(oneshot::error::TryRecvError::Empty) => {
110                    // Not ready yet - put it back
111                    self.receiver.replace(Some(receiver));
112                    None
113                }
114                Err(oneshot::error::TryRecvError::Closed) => {
115                    println!("get_if_present: closed");
116                    std::process::exit(-1);
117                }
118                Ok(holder) => Some(holder),
119            }
120        })
121    }
122}
123
124struct MapUpdater<K: AsyncKey, V: AsyncStorable> {
125    sender: oneshot::Sender<MapHolder<K, V>>,
126}
127
128impl<K: AsyncKey, V: AsyncStorable> MapUpdater<K, V> {
129    pub fn apply(self, new_map: MapHolder<K, V>) {
130        if let Err(_) = self.sender.send(new_map) {
131            // probably the map was alread dropped; ignore
132        }
133    }
134}
135
136impl<K: AsyncKey, V: AsyncStorable> AsyncMap for VersionedMap<K, V> {
137    type Key = K;
138    type Value = V;
139
140    /// Synchronously returns the value associated with the provided key, if present; otherwise None
141    fn get_if_present(&self, key: &Self::Key) -> Option<Self::Value> {
142        self.latest_map().map.get(key).map(V::clone)
143    }
144
145    fn get<'a, 'b, B: FactoryBorrow<K, V>>(
146        &'a self,
147        key: &'a Self::Key,
148        factory: B,
149    ) -> Pin<Box<dyn Future<Output = Self::Value> + Send + 'b>> {
150        match self.get_if_present(key) {
151            Some(x) => Box::pin(ready(x)),
152            None => self.send_update(key.clone(), factory),
153        }
154    }
155}
156
157impl<K: AsyncKey, V: AsyncStorable> Clone for VersionedMap<K, V> {
158    fn clone(&self) -> Self {
159        VersionedMap {
160            latest_version: self.latest_version.clone(),
161            map_holder: self.map_holder.clone(),
162            update_sender: self.update_sender.clone(),
163            update_receiver: UpdateReceiver::default(), // The clone will start the process of listening for updates independently
164            latest_map_holder: self.latest_map_holder.clone(),
165        }
166    }
167}
168
169impl<K: AsyncKey, V: AsyncStorable> VersionedMap<K, V> {
170    pub fn new() -> Self {
171        let (update_sender, mut update_receiver) = mpsc::unbounded_channel();
172
173        let initial_version = 0;
174        let latest_version = Arc::new(AtomicU64::new(initial_version));
175        let map = HashMap::default();
176
177        let map_holder = MapHolder {
178            version: initial_version,
179            map: map.clone(),
180        };
181
182        let current_map_holder = Arc::new(RwLock::new(MapHolder {
183            version: initial_version,
184            map: map,
185        }));
186
187        let non_locking_map: VersionedMap<K, V> = VersionedMap {
188            latest_version: latest_version.clone(),
189            map_holder: RefCell::new(map_holder),
190            update_sender,
191            update_receiver: UpdateReceiver::default(),
192            latest_map_holder: current_map_holder.clone(),
193        };
194
195        Some(tokio::task::spawn(async move {
196            let lockable_map_holder = current_map_holder;
197            while let Some(action) = update_receiver.recv().await {
198                match action {
199                    MapAction::GetOrCreate(key, factory, result_sender, waker) => {
200                        let read_lock = lockable_map_holder.read();
201
202                        let updated = match read_lock {
203                            Err(_) => todo!(),
204                            Ok(map_holder) => VersionedMap::create_if_necessary(
205                                &latest_version,
206                                &map_holder.map,
207                                key,
208                                factory,
209                                result_sender,
210                            ),
211                        }; // Read lock dropped here.
212
213                        if let Some((new_map, new_version)) = updated {
214                            let write_lock = lockable_map_holder.write();
215
216                            match write_lock {
217                                Err(_) => todo!(),
218                                Ok(mut map_holder) => {
219                                    map_holder.version = new_version;
220                                    map_holder.map = new_map;
221                                }
222                            }
223                        }
224
225                        waker.wake();
226                    }
227                }
228            }
229        }));
230
231        non_locking_map
232    }
233
234    fn send_update<'a, 'b, B: FactoryBorrow<K, V>>(
235        &self,
236        key: K,
237        factory: B,
238    ) -> Pin<Box<dyn Future<Output = V> + Send + 'b>> {
239        let (tx, mut rx) = oneshot::channel();
240        let map_updater = self.get_updater();
241
242        self.create_return_future(key, factory, tx)
243            .then(move |_| match rx.try_recv() {
244                Err(_) => {
245                    std::process::exit(-1);
246                }
247                Ok((value, map_holder)) => {
248                    map_updater.apply(map_holder);
249                    ready(value)
250                }
251            })
252            .boxed()
253    }
254
255    fn create_return_future<B: FactoryBorrow<K, V>>(
256        &self,
257        key: K,
258        factory: B,
259        sender: oneshot::Sender<(V, MapHolder<K, V>)>,
260    ) -> MapReturnFuture<K, V, B> {
261        MapReturnFuture {
262            key: key,
263            factory: Some(factory),
264            update_sender: self.update_sender.clone(),
265            result_sender: Some(sender),
266        }
267    }
268
269    fn get_updater(&self) -> MapUpdater<K, V> {
270        self.update_receiver.updater()
271    }
272
273    fn latest_map(&self) -> Ref<MapHolder<K, V>> {
274        let latest_version = self.latest_version.load(Ordering::Acquire);
275
276        // Get any update received from a write op, filtering on version
277        let received_update = self
278            .get_received_update()
279            .filter(|holder| holder.version == latest_version);
280        if let Some(new_map_holder) = received_update {
281            self.map_holder.replace(new_map_holder);
282        } else {
283            let mut current = self.map_holder.borrow_mut();
284
285            if current.version != latest_version {
286                let latest = self.get_latest();
287
288                current.map = latest.map;
289                current.version = latest.version;
290            }
291        }
292
293        self.map_holder.borrow()
294    }
295
296    fn get_received_update(&self) -> Option<MapHolder<K, V>> {
297        self.update_receiver.get_update()
298    }
299
300    fn get_latest(&self) -> MapHolder<K, V> {
301        let lock_result = self.latest_map_holder.read();
302
303        match lock_result {
304            Err(_) => todo!(),
305            Ok(guard) => {
306                let latest_holder = guard.clone();
307                latest_holder
308            }
309        }
310    }
311
312    fn create_if_necessary(
313        latest_version: &Arc<AtomicU64>,
314        map: &HashMap<K, V>,
315        key: K,
316        factory: Box<dyn FactoryBorrow<K, V>>,
317        result_sender: oneshot::Sender<(V, MapHolder<K, V>)>,
318    ) -> Option<(HashMap<K, V>, u64)> {
319        match map.get(&key) {
320            Some(v) => {
321                // nothing to do; probably multiple creates were queued up for the same key
322                if let Err(_) = result_sender.send((
323                    v.clone(),
324                    MapHolder {
325                        version: latest_version.load(Ordering::Acquire),
326                        map: map.clone(),
327                    },
328                )) {
329                    todo!()
330                }
331                None
332            }
333            None => {
334                let value = (*factory).borrow()(&key);
335
336                // println!("Length: {}", map.len());
337                let updated = map.update(key, value.clone());
338
339                // fetch_add returns the prior value!
340                let new_version = latest_version.fetch_add(1, Ordering::AcqRel) + 1;
341
342                if let Err(_) = result_sender.send((
343                    value,
344                    MapHolder {
345                        version: new_version,
346                        map: updated.clone(),
347                    },
348                )) {
349                    todo!()
350                }
351                Some((updated, new_version))
352            }
353        }
354    }
355}
356
#[cfg(test)]
mod test {
    use super::VersionedMap;
    use crate::{AsyncFactory, AsyncMap};

    /// Factory used by the tests: greets the key.
    fn hello_factory(key: &String) -> String {
        format!("Hello, {}!", key)
    }

    /// A freshly created map has no entries.
    #[tokio::test]
    async fn get_sync() {
        let map = VersionedMap::<String, String>::new();
        let missing = "foo".to_owned();

        assert_eq!(None, map.get_if_present(&missing));
    }

    /// `get` creates the value via the factory and makes it visible to synchronous reads.
    #[tokio::test]
    async fn get_sync2() {
        let map = VersionedMap::<String, String>::new();
        let key = "foo".to_owned();
        let factory = Box::new(hello_factory) as Box<dyn AsyncFactory<String, String>>;

        let pending = map.get(&key, factory);

        // Nothing is inserted until the returned future is driven.
        assert_eq!(None, map.get_if_present(&key));

        let value = pending.await;
        assert_eq!("Hello, foo!", value);
        assert_eq!("Hello, foo!", map.get_if_present(&key).unwrap());
    }
}