1use std::cell::{Ref, RefCell};
2
3use std::future::{ready, Future};
4
5use std::boxed::Box;
6use std::pin::Pin;
7use std::sync::atomic::{AtomicU64, Ordering};
8use std::sync::{Arc, RwLock};
9use std::task::{Context, Poll, Waker};
10
11use crate::{AsyncKey, AsyncMap, AsyncStorable, FactoryBorrow};
12
13use futures::FutureExt;
14
15use im::HashMap;
16
17use tokio::sync::mpsc::{self, UnboundedSender};
18use tokio::sync::oneshot;
19
20enum MapAction<K: AsyncKey, V: AsyncStorable> {
21 GetOrCreate(
22 K,
23 Box<dyn FactoryBorrow<K, V>>,
24 oneshot::Sender<(V, MapHolder<K, V>)>,
25 Waker,
26 ),
27}
28
29struct MapReturnFuture<K: AsyncKey, V: AsyncStorable, B>
30where
31 B: FactoryBorrow<K, V> + Unpin,
32{
33 update_sender: UnboundedSender<MapAction<K, V>>,
34 key: K,
35 factory: Option<B>,
36 result_sender: Option<oneshot::Sender<(V, MapHolder<K, V>)>>,
37}
38
39impl<'a, K: AsyncKey, V: AsyncStorable, B> Future for MapReturnFuture<K, V, B>
40where
41 B: FactoryBorrow<K, V> + Unpin,
42{
43 type Output = ();
44 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
45 let mut mutable = self;
46
47 if mutable.result_sender.is_none() {
48 Poll::Ready(())
49 } else {
50 let result_sender = mutable.result_sender.take().unwrap();
51 match mutable.factory.take() {
52 None => {
53 todo!()
54 }
55 Some(factory) => {
56 match mutable.update_sender.send(MapAction::GetOrCreate(
57 mutable.key.clone(),
58 Box::new(factory),
59 result_sender,
60 cx.waker().clone(),
61 )) {
62 Ok(_) => Poll::Pending,
63 Err(_) => Poll::Pending,
64 }
65 }
66 }
67 }
68 }
69}
70
71#[derive(Clone)]
72struct MapHolder<K: AsyncKey, V: AsyncStorable> {
73 version: u64,
74 map: HashMap<K, V>,
75}
76
77pub struct VersionedMap<K: AsyncKey, V: AsyncStorable> {
78 latest_version: Arc<AtomicU64>,
79 map_holder: RefCell<MapHolder<K, V>>,
80 update_sender: UnboundedSender<MapAction<K, V>>,
81 update_receiver: UpdateReceiver<K, V>,
82 latest_map_holder: Arc<RwLock<MapHolder<K, V>>>,
83}
84
85struct UpdateReceiver<K: AsyncKey, V: AsyncStorable> {
86 receiver: RefCell<Option<oneshot::Receiver<MapHolder<K, V>>>>,
87}
88
89impl<K: AsyncKey, V: AsyncStorable> Default for UpdateReceiver<K, V> {
90 fn default() -> Self {
91 UpdateReceiver {
92 receiver: RefCell::new(None),
93 }
94 }
95}
96
97impl<K: AsyncKey, V: AsyncStorable> UpdateReceiver<K, V> {
98 pub fn updater(&self) -> MapUpdater<K, V> {
99 let (sender, receiver) = oneshot::channel();
100 self.receiver.replace(Some(receiver));
103 MapUpdater { sender }
104 }
105
106 pub fn get_update(&self) -> Option<MapHolder<K, V>> {
107 self.receiver.take().and_then(|mut receiver| {
108 match receiver.try_recv() {
109 Err(oneshot::error::TryRecvError::Empty) => {
110 self.receiver.replace(Some(receiver));
112 None
113 }
114 Err(oneshot::error::TryRecvError::Closed) => {
115 println!("get_if_present: closed");
116 std::process::exit(-1);
117 }
118 Ok(holder) => Some(holder),
119 }
120 })
121 }
122}
123
124struct MapUpdater<K: AsyncKey, V: AsyncStorable> {
125 sender: oneshot::Sender<MapHolder<K, V>>,
126}
127
128impl<K: AsyncKey, V: AsyncStorable> MapUpdater<K, V> {
129 pub fn apply(self, new_map: MapHolder<K, V>) {
130 if let Err(_) = self.sender.send(new_map) {
131 }
133 }
134}
135
136impl<K: AsyncKey, V: AsyncStorable> AsyncMap for VersionedMap<K, V> {
137 type Key = K;
138 type Value = V;
139
140 fn get_if_present(&self, key: &Self::Key) -> Option<Self::Value> {
142 self.latest_map().map.get(key).map(V::clone)
143 }
144
145 fn get<'a, 'b, B: FactoryBorrow<K, V>>(
146 &'a self,
147 key: &'a Self::Key,
148 factory: B,
149 ) -> Pin<Box<dyn Future<Output = Self::Value> + Send + 'b>> {
150 match self.get_if_present(key) {
151 Some(x) => Box::pin(ready(x)),
152 None => self.send_update(key.clone(), factory),
153 }
154 }
155}
156
157impl<K: AsyncKey, V: AsyncStorable> Clone for VersionedMap<K, V> {
158 fn clone(&self) -> Self {
159 VersionedMap {
160 latest_version: self.latest_version.clone(),
161 map_holder: self.map_holder.clone(),
162 update_sender: self.update_sender.clone(),
163 update_receiver: UpdateReceiver::default(), latest_map_holder: self.latest_map_holder.clone(),
165 }
166 }
167}
168
169impl<K: AsyncKey, V: AsyncStorable> VersionedMap<K, V> {
170 pub fn new() -> Self {
171 let (update_sender, mut update_receiver) = mpsc::unbounded_channel();
172
173 let initial_version = 0;
174 let latest_version = Arc::new(AtomicU64::new(initial_version));
175 let map = HashMap::default();
176
177 let map_holder = MapHolder {
178 version: initial_version,
179 map: map.clone(),
180 };
181
182 let current_map_holder = Arc::new(RwLock::new(MapHolder {
183 version: initial_version,
184 map: map,
185 }));
186
187 let non_locking_map: VersionedMap<K, V> = VersionedMap {
188 latest_version: latest_version.clone(),
189 map_holder: RefCell::new(map_holder),
190 update_sender,
191 update_receiver: UpdateReceiver::default(),
192 latest_map_holder: current_map_holder.clone(),
193 };
194
195 Some(tokio::task::spawn(async move {
196 let lockable_map_holder = current_map_holder;
197 while let Some(action) = update_receiver.recv().await {
198 match action {
199 MapAction::GetOrCreate(key, factory, result_sender, waker) => {
200 let read_lock = lockable_map_holder.read();
201
202 let updated = match read_lock {
203 Err(_) => todo!(),
204 Ok(map_holder) => VersionedMap::create_if_necessary(
205 &latest_version,
206 &map_holder.map,
207 key,
208 factory,
209 result_sender,
210 ),
211 }; if let Some((new_map, new_version)) = updated {
214 let write_lock = lockable_map_holder.write();
215
216 match write_lock {
217 Err(_) => todo!(),
218 Ok(mut map_holder) => {
219 map_holder.version = new_version;
220 map_holder.map = new_map;
221 }
222 }
223 }
224
225 waker.wake();
226 }
227 }
228 }
229 }));
230
231 non_locking_map
232 }
233
234 fn send_update<'a, 'b, B: FactoryBorrow<K, V>>(
235 &self,
236 key: K,
237 factory: B,
238 ) -> Pin<Box<dyn Future<Output = V> + Send + 'b>> {
239 let (tx, mut rx) = oneshot::channel();
240 let map_updater = self.get_updater();
241
242 self.create_return_future(key, factory, tx)
243 .then(move |_| match rx.try_recv() {
244 Err(_) => {
245 std::process::exit(-1);
246 }
247 Ok((value, map_holder)) => {
248 map_updater.apply(map_holder);
249 ready(value)
250 }
251 })
252 .boxed()
253 }
254
255 fn create_return_future<B: FactoryBorrow<K, V>>(
256 &self,
257 key: K,
258 factory: B,
259 sender: oneshot::Sender<(V, MapHolder<K, V>)>,
260 ) -> MapReturnFuture<K, V, B> {
261 MapReturnFuture {
262 key: key,
263 factory: Some(factory),
264 update_sender: self.update_sender.clone(),
265 result_sender: Some(sender),
266 }
267 }
268
269 fn get_updater(&self) -> MapUpdater<K, V> {
270 self.update_receiver.updater()
271 }
272
273 fn latest_map(&self) -> Ref<MapHolder<K, V>> {
274 let latest_version = self.latest_version.load(Ordering::Acquire);
275
276 let received_update = self
278 .get_received_update()
279 .filter(|holder| holder.version == latest_version);
280 if let Some(new_map_holder) = received_update {
281 self.map_holder.replace(new_map_holder);
282 } else {
283 let mut current = self.map_holder.borrow_mut();
284
285 if current.version != latest_version {
286 let latest = self.get_latest();
287
288 current.map = latest.map;
289 current.version = latest.version;
290 }
291 }
292
293 self.map_holder.borrow()
294 }
295
296 fn get_received_update(&self) -> Option<MapHolder<K, V>> {
297 self.update_receiver.get_update()
298 }
299
300 fn get_latest(&self) -> MapHolder<K, V> {
301 let lock_result = self.latest_map_holder.read();
302
303 match lock_result {
304 Err(_) => todo!(),
305 Ok(guard) => {
306 let latest_holder = guard.clone();
307 latest_holder
308 }
309 }
310 }
311
312 fn create_if_necessary(
313 latest_version: &Arc<AtomicU64>,
314 map: &HashMap<K, V>,
315 key: K,
316 factory: Box<dyn FactoryBorrow<K, V>>,
317 result_sender: oneshot::Sender<(V, MapHolder<K, V>)>,
318 ) -> Option<(HashMap<K, V>, u64)> {
319 match map.get(&key) {
320 Some(v) => {
321 if let Err(_) = result_sender.send((
323 v.clone(),
324 MapHolder {
325 version: latest_version.load(Ordering::Acquire),
326 map: map.clone(),
327 },
328 )) {
329 todo!()
330 }
331 None
332 }
333 None => {
334 let value = (*factory).borrow()(&key);
335
336 let updated = map.update(key, value.clone());
338
339 let new_version = latest_version.fetch_add(1, Ordering::AcqRel) + 1;
341
342 if let Err(_) = result_sender.send((
343 value,
344 MapHolder {
345 version: new_version,
346 map: updated.clone(),
347 },
348 )) {
349 todo!()
350 }
351 Some((updated, new_version))
352 }
353 }
354 }
355}
356
#[cfg(test)]
mod test {
    use super::VersionedMap;
    use crate::{AsyncFactory, AsyncMap};

    /// Factory used by the tests: greets the key.
    fn hello_factory(key: &String) -> String {
        format!("Hello, {}!", key)
    }

    /// A freshly created map contains nothing.
    #[tokio::test]
    async fn get_sync() {
        let cache = VersionedMap::<String, String>::new();
        let missing = String::from("foo");

        assert_eq!(None, cache.get_if_present(&missing));
    }

    /// A miss is only filled once the returned future is awaited, after which
    /// the value is visible through `get_if_present`.
    #[tokio::test]
    async fn get_sync2() {
        let cache = VersionedMap::<String, String>::new();
        let name = String::from("foo");

        let factory = Box::new(hello_factory) as Box<dyn AsyncFactory<String, String>>;
        let pending = cache.get(&name, factory);

        // Futures are lazy: nothing has been inserted yet.
        assert_eq!(None, cache.get_if_present(&name));

        let greeting = pending.await;
        assert_eq!("Hello, foo!", greeting);
        assert_eq!("Hello, foo!", cache.get_if_present(&name).unwrap());
    }
}