callback_result/
lib.rs

1use std::collections::HashMap;
2use std::fmt::{Display, Formatter};
3use std::future::Future;
4use std::hash::Hash;
5use std::pin::Pin;
6use std::sync::{Mutex};
7use notify_future::{Notify};
8
/// Errors reported by [`CallbackWaiter`] and [`SingleCallbackWaiter`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum WaiterError {
    /// A non-canceled waiter is already registered for the requested slot.
    AlreadyExist,
    /// The timeout elapsed before a result was delivered.
    Timeout,
    /// A result was offered but no live waiter was registered to receive it.
    NoWaiter,
}

impl Display for WaiterError {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // The user-facing text is exactly the variant name.
        let name = match *self {
            Self::AlreadyExist => "AlreadyExist",
            Self::Timeout => "Timeout",
            Self::NoWaiter => "NoWaiter",
        };
        f.write_str(name)
    }
}

impl std::error::Error for WaiterError {}

/// Shorthand for results whose error type is [`WaiterError`].
pub type WaiterResult<T> = Result<T, WaiterError>;
30
31pub struct ResultFuture<'a, R> {
32    future: Pin<Box<dyn Future<Output = Result<R, WaiterError>> + 'a + Send>>,
33}
34
35impl <'a, R> ResultFuture<'a, R> {
36    pub fn new(future: Pin<Box<dyn Future<Output = Result<R, WaiterError>> + 'a + Send>>) -> Self {
37        Self {
38            future,
39        }
40    }
41}
42
43impl <'a, R> Future for ResultFuture<'a, R> {
44    type Output = Result<R, WaiterError>;
45
46    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
47        self.get_mut().future.as_mut().poll(cx)
48    }
49}
50
51struct CallbackWaiterState<K, R> {
52    result_notifies: HashMap<K, Option<Notify<R>>>,
53    result_cache: HashMap<K, Vec<R>>,
54}
55pub struct CallbackWaiter<K, R> {
56    state: Mutex<CallbackWaiterState<K, R>>,
57}
58
59impl <K: Hash + Eq + Clone + 'static + Send, R: 'static + Send> CallbackWaiter<K, R> {
60    pub fn new() -> Self {
61        Self {
62            state: Mutex::new(CallbackWaiterState {
63                result_notifies: HashMap::new(),
64                result_cache: HashMap::new(),
65            })
66        }
67    }
68
69    pub fn create_result_future(&self, callback_id: K) -> WaiterResult<ResultFuture<R>> {
70        let waiter = {
71            let mut state = self.state.lock().unwrap();
72            let notifies = state.result_notifies.get(&callback_id);
73            if let Some(notifies) = notifies {
74                if let Some(notifies) = notifies {
75                    if!notifies.is_canceled() {
76                        return Err(WaiterError::AlreadyExist);
77                    }
78                }
79            }
80            if let Some(result) = state.result_cache.get_mut(&callback_id) {
81                if result.len() > 0 {
82                    let ret = result.remove(0);
83                    return Ok(ResultFuture::new(Box::pin(async move {
84                        Ok(ret)
85                    })));
86                }
87            }
88
89            let (notify, waiter) = Notify::new();
90            state.result_notifies.insert(callback_id.clone(), Some(notify));
91            waiter
92        };
93
94        Ok(ResultFuture::new(Box::pin(async move {
95            let ret = waiter.await;
96            {
97                let mut state = self.state.lock().unwrap();
98                state.result_notifies.remove(&callback_id);
99            }
100            Ok(ret)
101        })))
102    }
103
104    pub fn create_timeout_result_future(&self, callback_id: K, timeout: std::time::Duration) -> WaiterResult<ResultFuture<R>> {
105        let waiter = {
106            let mut state = self.state.lock().unwrap();
107            let notifies = state.result_notifies.get(&callback_id);
108            if let Some(notifies) = notifies {
109                if let Some(notifies) = notifies {
110                    if!notifies.is_canceled() {
111                        return Err(WaiterError::AlreadyExist);
112                    }
113                }
114            }
115
116            if let Some(result) = state.result_cache.get_mut(&callback_id) {
117                if result.len() > 0 {
118                    let ret = result.remove(0);
119                    return Ok(ResultFuture::new(Box::pin(async move {
120                        Ok(ret)
121                    })));
122                }
123            }
124
125            let (notify, waiter) = Notify::new();
126            state.result_notifies.insert(callback_id.clone(), Some(notify));
127            waiter
128        };
129        Ok(ResultFuture::new(Box::pin(async move {
130            let ret = async_std::future::timeout(timeout, waiter).await;
131            {
132                let mut state = self.state.lock().unwrap();
133                state.result_notifies.remove(&callback_id);
134            }
135            match ret {
136                Ok(ret) => Ok(ret),
137                Err(_) => Err(WaiterError::Timeout)
138            }
139        })))
140    }
141
142    pub fn set_result(&self, callback_id: K, result: R) -> Result<(), WaiterError> {
143        let mut state = self.state.lock().unwrap();
144        if let Some(future) = state.result_notifies.get_mut(&callback_id) {
145            if let Some(future) = future.take() {
146                if !future.is_canceled() {
147                    future.notify(result);
148                    return Ok(());
149                }
150            }
151        }
152        Err(WaiterError::NoWaiter)
153    }
154
155    pub fn set_result_with_cache(&self, callback_id: K, result: R) {
156        let mut state = self.state.lock().unwrap();
157        if let Some(future) = state.result_notifies.get_mut(&callback_id) {
158            if let Some(future) = future.take() {
159                if !future.is_canceled() {
160                    future.notify(result);
161                    return;
162                }
163            }
164        }
165        if let Some(cache) = state.result_cache.get_mut(&callback_id) {
166            cache.push(result);
167        } else {
168            state.result_cache.insert(callback_id, vec![result]);
169        }
170    }
171}
172
173struct SingleCallbackWaiterState<R> {
174    result_notify: Option<Option<Notify<R>>>,
175    result_cache: Vec<R>,
176}
177
178pub struct SingleCallbackWaiter<R> {
179    state: Mutex<SingleCallbackWaiterState<R>>,
180}
181
182impl <R: 'static + Send> SingleCallbackWaiter<R> {
183    pub fn new() -> Self {
184        Self {
185            state: Mutex::new(SingleCallbackWaiterState {
186                result_notify: None,
187                result_cache: Vec::new(),
188            })
189        }
190    }
191
192    pub fn create_result_future(&self) -> WaiterResult<ResultFuture<R>> {
193        let waiter = {
194            let mut state = self.state.lock().unwrap();
195            if let Some(notify) = state.result_notify.as_ref() {
196                if let Some(notify) = notify {
197                    if !notify.is_canceled() {
198                        return Err(WaiterError::AlreadyExist);
199                    }
200                }
201            }
202
203            if state.result_cache.len() > 0 {
204                let ret = state.result_cache.remove(0);
205                return Ok(ResultFuture::new(Box::pin(async move {
206                    Ok(ret)
207                })));
208            }
209            let (notify, waiter) = Notify::new();
210            state.result_notify = Some(Some(notify));
211            waiter
212        };
213        Ok(ResultFuture::new(Box::pin(async move {
214            let ret = waiter.await;
215            {
216                let mut state = self.state.lock().unwrap();
217                state.result_notify = None;
218            }
219            Ok(ret)
220        })))
221    }
222
223    pub fn create_timeout_result_future(&self, timeout: std::time::Duration) -> WaiterResult<ResultFuture<R>> {
224        let waiter = {
225            let mut state = self.state.lock().unwrap();
226            if let Some(notify) = state.result_notify.as_ref() {
227                if let Some(notify) = notify {
228                    if !notify.is_canceled() {
229                        return Err(WaiterError::AlreadyExist);
230                    }
231                }
232            }
233
234            if state.result_cache.len() > 0 {
235                let ret = state.result_cache.remove(0);
236                return Ok(ResultFuture::new(Box::pin(async move {
237                    Ok(ret)
238                })));
239            }
240
241            let (notify, waiter) = Notify::new();
242            state.result_notify = Some(Some(notify));
243            waiter
244        };
245        Ok(ResultFuture::new(Box::pin(async move {
246            let ret = async_std::future::timeout(timeout, waiter).await;
247            {
248                let mut state = self.state.lock().unwrap();
249                state.result_notify = None;
250            }
251            match ret {
252                Ok(ret) => Ok(ret),
253                Err(_) => {
254                    Err(WaiterError::Timeout)
255                }
256            }
257        })))
258    }
259
260    pub fn set_result(&self, result: R) -> Result<(), WaiterError> {
261        let mut state = self.state.lock().unwrap();
262        if let Some(future) = state.result_notify.as_mut() {
263            if let Some(future) = future.take() {
264                if !future.is_canceled() {
265                    future.notify(result);
266                    return Ok(());
267                }
268            }
269        }
270        Err(WaiterError::NoWaiter)
271    }
272
273    pub fn set_result_with_cache(&self, result: R) {
274        let mut state = self.state.lock().unwrap();
275        if let Some(future) = state.result_notify.as_mut() {
276            if let Some(future) = future.take() {
277                if !future.is_canceled() {
278                    future.notify(result);
279                    return;
280                }
281            }
282        }
283        state.result_cache.push(result);
284    }
285}
#[cfg(test)]
mod test {
    use super::*;
    use async_std::task;
    use std::sync::Arc;
    use std::time::Duration;

    /// Delay comfortably shorter than `LONG`; used for "arrives in time" paths and
    /// as the timeout in "arrives too late" paths.
    const SHORT: Duration = Duration::from_millis(100);
    /// Delay comfortably longer than `SHORT`.
    const LONG: Duration = Duration::from_millis(500);

    #[test]
    fn test_waiter() {
        task::block_on(async {
            let waiter = Arc::new(CallbackWaiter::new());
            let callback_id = 1;
            let result_future = waiter.create_result_future(callback_id).unwrap();
            assert_eq!(
                waiter.create_result_future(callback_id).err(),
                Some(WaiterError::AlreadyExist)
            );
            let setter = {
                let waiter = Arc::clone(&waiter);
                task::spawn(async move {
                    task::sleep(SHORT).await;
                    waiter.set_result(callback_id, 1)
                })
            };
            assert_eq!(result_future.await, Ok(1));
            // Await the producer so its outcome is actually checked.
            assert_eq!(setter.await, Ok(()));
        });
    }

    #[test]
    fn test_waiter1() {
        task::block_on(async {
            let waiter = Arc::new(CallbackWaiter::new());
            let callback_id = 1;
            let tmp = Arc::clone(&waiter);
            // Await the producer so the result deterministically lands in the cache first.
            task::spawn(async move { tmp.set_result_with_cache(callback_id, 1) }).await;
            let result_future = waiter.create_result_future(callback_id).unwrap();
            assert_eq!(result_future.await, Ok(1));
        });
    }

    #[test]
    fn test_waiter_cache_fifo() {
        task::block_on(async {
            let waiter = CallbackWaiter::new();
            waiter.set_result_with_cache(7, "a");
            waiter.set_result_with_cache(7, "b");
            assert_eq!(waiter.create_result_future(7).unwrap().await, Ok("a"));
            assert_eq!(waiter.create_result_future(7).unwrap().await, Ok("b"));
            // Cache drained: the next future must wait for a fresh result.
            let pending = waiter.create_result_future(7).unwrap();
            assert_eq!(waiter.set_result(7, "c"), Ok(()));
            assert_eq!(pending.await, Ok("c"));
        });
    }

    #[test]
    fn test_waiter_reuse_id_after_result() {
        task::block_on(async {
            let waiter = CallbackWaiter::new();
            let first = waiter.create_result_future(1).unwrap();
            assert_eq!(waiter.set_result(1, 10), Ok(()));
            // Register a new waiter for the same id before the first future cleans up.
            let second = waiter.create_result_future(1).unwrap();
            assert_eq!(first.await, Ok(10));
            // The first future's cleanup must not have removed the second registration.
            assert_eq!(waiter.set_result(1, 20), Ok(()));
            assert_eq!(second.await, Ok(20));
        });
    }

    #[test]
    fn test_waiter_timout() {
        task::block_on(async {
            let waiter = Arc::new(CallbackWaiter::new());
            let callback_id = 1;
            let result_future = waiter
                .create_timeout_result_future(callback_id, LONG)
                .unwrap();
            let setter = {
                let waiter = Arc::clone(&waiter);
                task::spawn(async move {
                    task::sleep(SHORT).await;
                    waiter.set_result(callback_id, 1)
                })
            };
            assert_eq!(result_future.await, Ok(1));
            assert_eq!(setter.await, Ok(()));
        });
    }

    #[test]
    fn test_waiter_timout2() {
        task::block_on(async {
            let waiter = Arc::new(CallbackWaiter::new());
            let callback_id = 1;
            let result_future = waiter
                .create_timeout_result_future(callback_id, SHORT)
                .unwrap();
            let setter = {
                let waiter = Arc::clone(&waiter);
                task::spawn(async move {
                    task::sleep(LONG).await;
                    waiter.set_result(callback_id, 1)
                })
            };
            assert_eq!(result_future.await, Err(WaiterError::Timeout));
            // The late result finds no waiter: the timed-out future released its slot.
            assert_eq!(setter.await, Err(WaiterError::NoWaiter));
        });
    }

    #[test]
    fn test_waiter_timout3() {
        task::block_on(async {
            let waiter = CallbackWaiter::new();
            let callback_id = 1;
            assert_eq!(waiter.set_result(callback_id, 1), Err(WaiterError::NoWaiter));
            let result_future = waiter
                .create_timeout_result_future(callback_id, SHORT)
                .unwrap();
            assert_eq!(
                waiter.create_timeout_result_future(callback_id, SHORT).err(),
                Some(WaiterError::AlreadyExist)
            );
            assert_eq!(result_future.await, Err(WaiterError::Timeout));
            // After the timeout the id can be waited on again.
            assert!(waiter.create_result_future(callback_id).is_ok());
        });
    }

    #[test]
    fn test_signle_waiter() {
        task::block_on(async {
            let waiter = Arc::new(SingleCallbackWaiter::new());
            let result_future = waiter.create_result_future().unwrap();
            assert_eq!(
                waiter.create_result_future().err(),
                Some(WaiterError::AlreadyExist)
            );
            let setter = {
                let waiter = Arc::clone(&waiter);
                task::spawn(async move {
                    task::sleep(SHORT).await;
                    waiter.set_result(1)
                })
            };
            assert_eq!(result_future.await, Ok(1));
            assert_eq!(setter.await, Ok(()));
        });
    }

    #[test]
    fn test_single_waiter1() {
        task::block_on(async {
            let waiter = Arc::new(SingleCallbackWaiter::new());
            let tmp = Arc::clone(&waiter);
            // Await the producer so the result deterministically lands in the cache first.
            task::spawn(async move { tmp.set_result_with_cache(1) }).await;
            let result_future = waiter.create_result_future().unwrap();
            assert_eq!(result_future.await, Ok(1));
        });
    }

    #[test]
    fn test_single_waiter_cache_fifo() {
        task::block_on(async {
            let waiter = SingleCallbackWaiter::new();
            waiter.set_result_with_cache("a");
            waiter.set_result_with_cache("b");
            assert_eq!(waiter.create_result_future().unwrap().await, Ok("a"));
            assert_eq!(waiter.create_result_future().unwrap().await, Ok("b"));
            let pending = waiter.create_result_future().unwrap();
            assert_eq!(waiter.set_result("c"), Ok(()));
            assert_eq!(pending.await, Ok("c"));
        });
    }

    #[test]
    fn test_single_waiter_reuse_after_result() {
        task::block_on(async {
            let waiter = SingleCallbackWaiter::new();
            let first = waiter.create_result_future().unwrap();
            assert_eq!(waiter.set_result(10), Ok(()));
            // Register a new waiter before the first future cleans up.
            let second = waiter.create_result_future().unwrap();
            assert_eq!(first.await, Ok(10));
            // The first future's cleanup must not have cleared the second registration.
            assert_eq!(waiter.set_result(20), Ok(()));
            assert_eq!(second.await, Ok(20));
        });
    }

    #[test]
    fn test_single_waiter_timout() {
        task::block_on(async {
            let waiter = Arc::new(SingleCallbackWaiter::new());
            let result_future = waiter.create_timeout_result_future(LONG).unwrap();
            assert_eq!(
                waiter.create_timeout_result_future(LONG).err(),
                Some(WaiterError::AlreadyExist)
            );
            let setter = {
                let waiter = Arc::clone(&waiter);
                task::spawn(async move {
                    task::sleep(SHORT).await;
                    waiter.set_result(1)
                })
            };
            assert_eq!(result_future.await, Ok(1));
            assert_eq!(setter.await, Ok(()));
        });
    }

    #[test]
    fn test_single_waiter_timout2() {
        task::block_on(async {
            let waiter = Arc::new(SingleCallbackWaiter::new());
            let result_future = waiter.create_timeout_result_future(SHORT).unwrap();
            let setter = {
                let waiter = Arc::clone(&waiter);
                task::spawn(async move {
                    task::sleep(LONG).await;
                    waiter.set_result(1)
                })
            };
            assert_eq!(result_future.await, Err(WaiterError::Timeout));
            assert_eq!(setter.await, Err(WaiterError::NoWaiter));
        });
    }

    #[test]
    fn test_single_waiter_timout3() {
        task::block_on(async {
            let waiter = SingleCallbackWaiter::new();
            assert_eq!(waiter.set_result(1), Err(WaiterError::NoWaiter));
            let result_future = waiter.create_timeout_result_future(SHORT).unwrap();
            assert_eq!(result_future.await, Err(WaiterError::Timeout));
            // After the timeout the slot can be waited on again.
            assert!(waiter.create_result_future().is_ok());
        });
    }
}