Skip to main content

callback_result/
lib.rs

1use std::collections::{HashMap, VecDeque};
2use std::fmt::{Display, Formatter};
3use std::future::Future;
4use std::hash::Hash;
5use std::pin::Pin;
6use std::sync::{Mutex};
7use notify_future::{Notify};
8
/// Errors reported by [`CallbackWaiter`] and [`SingleCallbackWaiter`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum WaiterError {
    /// A waiter whose notify is not canceled is already registered.
    AlreadyExist,
    /// The timeout elapsed before a result was delivered.
    Timeout,
    /// `set_result` found no live waiter to hand the result to.
    NoWaiter,
}

impl Display for WaiterError {
    /// Writes the bare variant name; width/fill flags are ignored.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::AlreadyExist => "AlreadyExist",
            Self::Timeout => "Timeout",
            Self::NoWaiter => "NoWaiter",
        };
        f.write_str(name)
    }
}

impl std::error::Error for WaiterError {}

/// Shorthand for a `Result` whose error type is [`WaiterError`].
pub type WaiterResult<T> = Result<T, WaiterError>;

31pub struct ResultFuture<'a, R> {
32    future: Pin<Box<dyn Future<Output = Result<R, WaiterError>> + 'a + Send>>,
33}
34
35impl <'a, R> ResultFuture<'a, R> {
36    pub fn new(future: Pin<Box<dyn Future<Output = Result<R, WaiterError>> + 'a + Send>>) -> Self {
37        Self {
38            future,
39        }
40    }
41}
42
43impl <'a, R> Future for ResultFuture<'a, R> {
44    type Output = Result<R, WaiterError>;
45
46    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
47        self.get_mut().future.as_mut().poll(cx)
48    }
49}
50
51struct CallbackWaiterState<K, R> {
52    result_notifies: HashMap<K, Option<Notify<R>>>,
53    result_cache: HashMap<K, VecDeque<R>>,
54}
55pub struct CallbackWaiter<K, R> {
56    state: Mutex<CallbackWaiterState<K, R>>,
57}
58
59impl <K: Hash + Eq + Clone + 'static + Send, R: 'static + Send> CallbackWaiter<K, R> {
60    pub fn new() -> Self {
61        Self {
62            state: Mutex::new(CallbackWaiterState {
63                result_notifies: HashMap::new(),
64                result_cache: HashMap::new(),
65            })
66        }
67    }
68
69    pub fn create_result_future(&self, callback_id: K) -> WaiterResult<ResultFuture<R>> {
70        let waiter = {
71            let mut state = self.state.lock().unwrap();
72            let notifies = state.result_notifies.get(&callback_id);
73            if let Some(notifies) = notifies {
74                if let Some(notifies) = notifies {
75                    if!notifies.is_canceled() {
76                        return Err(WaiterError::AlreadyExist);
77                    }
78                }
79            }
80            if let Some(result) = state.result_cache.get_mut(&callback_id) {
81                if let Some(ret) = result.pop_front() {
82                    return Ok(ResultFuture::new(Box::pin(async move {
83                        Ok(ret)
84                    })));
85                }
86            }
87
88            let (notify, waiter) = Notify::new();
89            state.result_notifies.insert(callback_id.clone(), Some(notify));
90            waiter
91        };
92
93        Ok(ResultFuture::new(Box::pin(async move {
94            let ret = waiter.await;
95            {
96                let mut state = self.state.lock().unwrap();
97                state.result_notifies.remove(&callback_id);
98            }
99            Ok(ret)
100        })))
101    }
102
103    pub fn create_timeout_result_future(&self, callback_id: K, timeout: std::time::Duration) -> WaiterResult<ResultFuture<R>> {
104        let waiter = {
105            let mut state = self.state.lock().unwrap();
106            let notifies = state.result_notifies.get(&callback_id);
107            if let Some(notifies) = notifies {
108                if let Some(notifies) = notifies {
109                    if!notifies.is_canceled() {
110                        return Err(WaiterError::AlreadyExist);
111                    }
112                }
113            }
114
115            if let Some(result) = state.result_cache.get_mut(&callback_id) {
116                if let Some(ret) = result.pop_front() {
117                    return Ok(ResultFuture::new(Box::pin(async move {
118                        Ok(ret)
119                    })));
120                }
121            }
122
123            let (notify, waiter) = Notify::new();
124            state.result_notifies.insert(callback_id.clone(), Some(notify));
125            waiter
126        };
127        Ok(ResultFuture::new(Box::pin(async move {
128            let ret = tokio::time::timeout(timeout, waiter).await;
129            {
130                let mut state = self.state.lock().unwrap();
131                state.result_notifies.remove(&callback_id);
132            }
133            match ret {
134                Ok(ret) => Ok(ret),
135                Err(_) => Err(WaiterError::Timeout)
136            }
137        })))
138    }
139
140    pub fn set_result(&self, callback_id: K, result: R) -> Result<(), WaiterError> {
141        let mut state = self.state.lock().unwrap();
142        if let Some(future) = state.result_notifies.get_mut(&callback_id) {
143            if let Some(future) = future.take() {
144                if !future.is_canceled() {
145                    future.notify(result);
146                    return Ok(());
147                }
148            }
149        }
150        Err(WaiterError::NoWaiter)
151    }
152
153    pub fn set_result_with_cache(&self, callback_id: K, result: R) {
154        let mut state = self.state.lock().unwrap();
155        if let Some(future) = state.result_notifies.get_mut(&callback_id) {
156            if let Some(future) = future.take() {
157                if !future.is_canceled() {
158                    future.notify(result);
159                    return;
160                }
161            }
162        }
163        if let Some(cache) = state.result_cache.get_mut(&callback_id) {
164            cache.push_back(result);
165        } else {
166            let mut cache = VecDeque::new();
167            cache.push_back(result);
168            state.result_cache.insert(callback_id, cache);
169        }
170    }
171}
172
173struct SingleCallbackWaiterState<R> {
174    result_notify: Option<Option<Notify<R>>>,
175    result_cache: VecDeque<R>,
176}
177
178pub struct SingleCallbackWaiter<R> {
179    state: Mutex<SingleCallbackWaiterState<R>>,
180}
181
182impl <R: 'static + Send> SingleCallbackWaiter<R> {
183    pub fn new() -> Self {
184        Self {
185            state: Mutex::new(SingleCallbackWaiterState {
186                result_notify: None,
187                result_cache: VecDeque::new(),
188            })
189        }
190    }
191
192    pub fn create_result_future(&self) -> WaiterResult<ResultFuture<R>> {
193        let waiter = {
194            let mut state = self.state.lock().unwrap();
195            if let Some(notify) = state.result_notify.as_ref() {
196                if let Some(notify) = notify {
197                    if !notify.is_canceled() {
198                        return Err(WaiterError::AlreadyExist);
199                    }
200                }
201            }
202
203            if let Some(ret) = state.result_cache.pop_front() {
204                return Ok(ResultFuture::new(Box::pin(async move {
205                    Ok(ret)
206                })));
207            }
208            let (notify, waiter) = Notify::new();
209            state.result_notify = Some(Some(notify));
210            waiter
211        };
212        Ok(ResultFuture::new(Box::pin(async move {
213            let ret = waiter.await;
214            {
215                let mut state = self.state.lock().unwrap();
216                state.result_notify = None;
217            }
218            Ok(ret)
219        })))
220    }
221
222    pub fn create_timeout_result_future(&self, timeout: std::time::Duration) -> WaiterResult<ResultFuture<R>> {
223        let waiter = {
224            let mut state = self.state.lock().unwrap();
225            if let Some(notify) = state.result_notify.as_ref() {
226                if let Some(notify) = notify {
227                    if !notify.is_canceled() {
228                        return Err(WaiterError::AlreadyExist);
229                    }
230                }
231            }
232
233            if let Some(ret) = state.result_cache.pop_front() {
234                return Ok(ResultFuture::new(Box::pin(async move {
235                    Ok(ret)
236                })));
237            }
238
239            let (notify, waiter) = Notify::new();
240            state.result_notify = Some(Some(notify));
241            waiter
242        };
243        Ok(ResultFuture::new(Box::pin(async move {
244            let ret = tokio::time::timeout(timeout, waiter).await;
245            {
246                let mut state = self.state.lock().unwrap();
247                state.result_notify = None;
248            }
249            match ret {
250                Ok(ret) => Ok(ret),
251                Err(_) => {
252                    Err(WaiterError::Timeout)
253                }
254            }
255        })))
256    }
257
258    pub fn set_result(&self, result: R) -> Result<(), WaiterError> {
259        let mut state = self.state.lock().unwrap();
260        if let Some(future) = state.result_notify.as_mut() {
261            if let Some(future) = future.take() {
262                if !future.is_canceled() {
263                    future.notify(result);
264                    return Ok(());
265                }
266            }
267        }
268        Err(WaiterError::NoWaiter)
269    }
270
271    pub fn set_result_with_cache(&self, result: R) {
272        let mut state = self.state.lock().unwrap();
273        if let Some(future) = state.result_notify.as_mut() {
274            if let Some(future) = future.take() {
275                if !future.is_canceled() {
276                    future.notify(result);
277                    return;
278                }
279            }
280        }
281        state.result_cache.push_back(result);
282    }
283}
#[cfg(test)]
mod test {
    use super::*;
    use std::sync::Arc;
    use tokio::time::{sleep, Duration};

    #[tokio::test]
    async fn test_waiter() {
        let waiter = Arc::new(CallbackWaiter::new());
        let callback_id = 1;
        let result_future = waiter.create_result_future(callback_id).unwrap();
        // A second registration must be refused while the first future is alive.
        assert_eq!(
            waiter.create_result_future(callback_id).err(),
            Some(WaiterError::AlreadyExist)
        );
        let tmp = waiter.clone();
        tokio::spawn(async move {
            sleep(Duration::from_millis(1000)).await;
            let ret = tmp.set_result(callback_id, 1);
            assert!(ret.is_ok());
        });
        let ret = result_future.await.unwrap();
        assert_eq!(ret, 1);
    }

    #[tokio::test]
    async fn test_waiter1() {
        let waiter = Arc::new(CallbackWaiter::new());
        let callback_id = 1;
        let tmp = waiter.clone();
        tokio::spawn(async move {
            tmp.set_result_with_cache(callback_id, 1);
        });
        let result_future = waiter.create_result_future(callback_id).unwrap();
        let ret = result_future.await.unwrap();
        assert_eq!(ret, 1);
    }

    #[tokio::test]
    async fn test_waiter_timeout() {
        let waiter = Arc::new(CallbackWaiter::new());
        let callback_id = 1;
        let result_future = waiter
            .create_timeout_result_future(callback_id, Duration::from_secs(2))
            .unwrap();
        let tmp = waiter.clone();
        tokio::spawn(async move {
            sleep(Duration::from_millis(1000)).await;
            let ret = tmp.set_result(callback_id, 1);
            assert!(ret.is_ok());
        });
        let ret = result_future.await.unwrap();
        assert_eq!(ret, 1);
    }

    #[tokio::test]
    async fn test_waiter_timeout2() {
        let waiter = Arc::new(CallbackWaiter::new());
        let callback_id = 1;
        let result_future = waiter
            .create_timeout_result_future(callback_id, Duration::from_secs(2))
            .unwrap();
        let tmp = waiter.clone();
        let late_setter = tokio::spawn(async move {
            sleep(Duration::from_secs(3)).await;
            // The waiter has already timed out, so nobody is left to receive the result.
            assert_eq!(tmp.set_result(callback_id, 1), Err(WaiterError::NoWaiter));
        });
        assert_eq!(result_future.await, Err(WaiterError::Timeout));
        // Await the setter so its assertion actually runs before the runtime shuts down.
        late_setter.await.unwrap();
    }

    #[tokio::test]
    async fn test_waiter_timeout3() {
        let waiter = Arc::new(CallbackWaiter::new());
        let callback_id = 1;
        let tmp = waiter.clone();
        tokio::spawn(async move {
            // No waiter is registered yet and plain `set_result` does not cache.
            assert_eq!(tmp.set_result(callback_id, 1), Err(WaiterError::NoWaiter));
        })
        .await
        .unwrap();
        let result_future = waiter
            .create_timeout_result_future(callback_id, Duration::from_secs(2))
            .unwrap();
        assert_eq!(
            waiter
                .create_timeout_result_future(callback_id, Duration::from_secs(2))
                .err(),
            Some(WaiterError::AlreadyExist)
        );
        // The earlier result was discarded, so this future can only time out.
        assert_eq!(result_future.await, Err(WaiterError::Timeout));
    }

    #[tokio::test]
    async fn test_single_waiter() {
        let waiter = Arc::new(SingleCallbackWaiter::new());
        let result_future = waiter.create_result_future().unwrap();
        assert_eq!(
            waiter.create_result_future().err(),
            Some(WaiterError::AlreadyExist)
        );
        let tmp = waiter.clone();
        tokio::spawn(async move {
            sleep(Duration::from_millis(1000)).await;
            let ret = tmp.set_result(1);
            assert!(ret.is_ok());
        });
        let ret = result_future.await.unwrap();
        assert_eq!(ret, 1);
    }

    #[tokio::test]
    async fn test_single_waiter1() {
        let waiter = Arc::new(SingleCallbackWaiter::new());
        let tmp = waiter.clone();
        tokio::spawn(async move {
            tmp.set_result_with_cache(1);
        });
        let result_future = waiter.create_result_future().unwrap();
        let ret = result_future.await.unwrap();
        assert_eq!(ret, 1);
    }

    #[tokio::test]
    async fn test_single_waiter_timeout() {
        let waiter = Arc::new(SingleCallbackWaiter::new());
        let result_future = waiter
            .create_timeout_result_future(Duration::from_secs(2))
            .unwrap();
        assert_eq!(
            waiter
                .create_timeout_result_future(Duration::from_secs(2))
                .err(),
            Some(WaiterError::AlreadyExist)
        );
        let tmp = waiter.clone();
        tokio::spawn(async move {
            sleep(Duration::from_millis(1000)).await;
            let ret = tmp.set_result(1);
            assert!(ret.is_ok());
        });
        let ret = result_future.await.unwrap();
        assert_eq!(ret, 1);
    }

    #[tokio::test]
    async fn test_single_waiter_timeout2() {
        let waiter = Arc::new(SingleCallbackWaiter::new());
        let result_future = waiter
            .create_timeout_result_future(Duration::from_secs(2))
            .unwrap();
        let tmp = waiter.clone();
        let late_setter = tokio::spawn(async move {
            sleep(Duration::from_secs(3)).await;
            // The waiter has already timed out, so nobody is left to receive the result.
            assert_eq!(tmp.set_result(1), Err(WaiterError::NoWaiter));
        });
        assert_eq!(result_future.await, Err(WaiterError::Timeout));
        // Await the setter so its assertion actually runs before the runtime shuts down.
        late_setter.await.unwrap();
    }

    #[tokio::test]
    async fn test_single_waiter_timeout3() {
        let waiter = Arc::new(SingleCallbackWaiter::new());
        let tmp = waiter.clone();
        tokio::spawn(async move {
            // No waiter is registered yet and plain `set_result` does not cache.
            assert_eq!(tmp.set_result(1), Err(WaiterError::NoWaiter));
        })
        .await
        .unwrap();
        let result_future = waiter
            .create_timeout_result_future(Duration::from_secs(2))
            .unwrap();
        // The earlier result was discarded, so this future can only time out.
        assert_eq!(result_future.await, Err(WaiterError::Timeout));
    }

    #[tokio::test]
    async fn test_waiter_reregister_after_future_drop() {
        let waiter = Arc::new(CallbackWaiter::new());
        let callback_id = 42;
        let dropped_future = waiter.create_result_future(callback_id).unwrap();
        drop(dropped_future);

        sleep(Duration::from_millis(10)).await;

        let result_future = waiter.create_result_future(callback_id).unwrap();
        let tmp = waiter.clone();
        tokio::spawn(async move {
            tmp.set_result(callback_id, 7).unwrap();
        });

        let ret = result_future.await.unwrap();
        assert_eq!(ret, 7);
    }

    #[tokio::test]
    async fn test_waiter_cache_fifo_under_load() {
        let waiter = CallbackWaiter::new();
        let callback_id = 1;
        let total = 200;

        for i in 0..total {
            waiter.set_result_with_cache(callback_id, i);
        }

        for expected in 0..total {
            let ret = waiter
                .create_result_future(callback_id)
                .unwrap()
                .await
                .unwrap();
            assert_eq!(ret, expected);
        }
    }

    #[tokio::test]
    async fn test_waiter_timeout_set_result_race() {
        for callback_id in 0..50 {
            let waiter = Arc::new(CallbackWaiter::new());
            let result_future = waiter
                .create_timeout_result_future(callback_id, Duration::from_millis(50))
                .unwrap();

            let tmp = waiter.clone();
            let set_task = tokio::spawn(async move {
                sleep(Duration::from_millis(50)).await;
                tmp.set_result(callback_id, 1)
            });

            let future_result = result_future.await;
            let set_result = set_task.await.unwrap();

            match (future_result, set_result) {
                (Ok(1), Ok(())) => {}
                (Err(WaiterError::Timeout), Err(WaiterError::NoWaiter)) => {}
                (other_future, other_set) => {
                    panic!("unexpected race outcome: {:?}, {:?}", other_future, other_set);
                }
            }
        }
    }
}