1use std::collections::{HashMap, VecDeque};
2use std::fmt::{Display, Formatter};
3use std::future::Future;
4use std::hash::Hash;
5use std::pin::Pin;
6use std::sync::{Mutex};
7use notify_future::{Notify};
8
/// Failure modes reported by the waiter types in this module.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum WaiterError {
    /// A waiter that has not been canceled is already registered for the slot.
    AlreadyExist,
    /// The timeout elapsed before a result was delivered.
    Timeout,
    /// A result was delivered but no live waiter was registered to receive it.
    NoWaiter,
}
15
16impl Display for WaiterError {
17 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
18 match self {
19 WaiterError::AlreadyExist => write!(f, "AlreadyExist"),
20 WaiterError::Timeout => write!(f, "Timeout"),
21 WaiterError::NoWaiter => write!(f, "NoWaiter"),
22 }
23 }
24}
25
26impl std::error::Error for WaiterError {
27
28}
29pub type WaiterResult<T> = Result<T, WaiterError>;
30
31pub struct ResultFuture<'a, R> {
32 future: Pin<Box<dyn Future<Output = Result<R, WaiterError>> + 'a + Send>>,
33}
34
35impl <'a, R> ResultFuture<'a, R> {
36 pub fn new(future: Pin<Box<dyn Future<Output = Result<R, WaiterError>> + 'a + Send>>) -> Self {
37 Self {
38 future,
39 }
40 }
41}
42
43impl <'a, R> Future for ResultFuture<'a, R> {
44 type Output = Result<R, WaiterError>;
45
46 fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
47 self.get_mut().future.as_mut().poll(cx)
48 }
49}
50
51struct CallbackWaiterState<K, R> {
52 result_notifies: HashMap<K, Option<Notify<R>>>,
53 result_cache: HashMap<K, VecDeque<R>>,
54}
55pub struct CallbackWaiter<K, R> {
56 state: Mutex<CallbackWaiterState<K, R>>,
57}
58
59impl <K: Hash + Eq + Clone + 'static + Send, R: 'static + Send> CallbackWaiter<K, R> {
60 pub fn new() -> Self {
61 Self {
62 state: Mutex::new(CallbackWaiterState {
63 result_notifies: HashMap::new(),
64 result_cache: HashMap::new(),
65 })
66 }
67 }
68
69 pub fn create_result_future(&self, callback_id: K) -> WaiterResult<ResultFuture<R>> {
70 let waiter = {
71 let mut state = self.state.lock().unwrap();
72 let notifies = state.result_notifies.get(&callback_id);
73 if let Some(notifies) = notifies {
74 if let Some(notifies) = notifies {
75 if!notifies.is_canceled() {
76 return Err(WaiterError::AlreadyExist);
77 }
78 }
79 }
80 if let Some(result) = state.result_cache.get_mut(&callback_id) {
81 if let Some(ret) = result.pop_front() {
82 return Ok(ResultFuture::new(Box::pin(async move {
83 Ok(ret)
84 })));
85 }
86 }
87
88 let (notify, waiter) = Notify::new();
89 state.result_notifies.insert(callback_id.clone(), Some(notify));
90 waiter
91 };
92
93 Ok(ResultFuture::new(Box::pin(async move {
94 let ret = waiter.await;
95 {
96 let mut state = self.state.lock().unwrap();
97 state.result_notifies.remove(&callback_id);
98 }
99 Ok(ret)
100 })))
101 }
102
103 pub fn create_timeout_result_future(&self, callback_id: K, timeout: std::time::Duration) -> WaiterResult<ResultFuture<R>> {
104 let waiter = {
105 let mut state = self.state.lock().unwrap();
106 let notifies = state.result_notifies.get(&callback_id);
107 if let Some(notifies) = notifies {
108 if let Some(notifies) = notifies {
109 if!notifies.is_canceled() {
110 return Err(WaiterError::AlreadyExist);
111 }
112 }
113 }
114
115 if let Some(result) = state.result_cache.get_mut(&callback_id) {
116 if let Some(ret) = result.pop_front() {
117 return Ok(ResultFuture::new(Box::pin(async move {
118 Ok(ret)
119 })));
120 }
121 }
122
123 let (notify, waiter) = Notify::new();
124 state.result_notifies.insert(callback_id.clone(), Some(notify));
125 waiter
126 };
127 Ok(ResultFuture::new(Box::pin(async move {
128 let ret = tokio::time::timeout(timeout, waiter).await;
129 {
130 let mut state = self.state.lock().unwrap();
131 state.result_notifies.remove(&callback_id);
132 }
133 match ret {
134 Ok(ret) => Ok(ret),
135 Err(_) => Err(WaiterError::Timeout)
136 }
137 })))
138 }
139
140 pub fn set_result(&self, callback_id: K, result: R) -> Result<(), WaiterError> {
141 let mut state = self.state.lock().unwrap();
142 if let Some(future) = state.result_notifies.get_mut(&callback_id) {
143 if let Some(future) = future.take() {
144 if !future.is_canceled() {
145 future.notify(result);
146 return Ok(());
147 }
148 }
149 }
150 Err(WaiterError::NoWaiter)
151 }
152
153 pub fn set_result_with_cache(&self, callback_id: K, result: R) {
154 let mut state = self.state.lock().unwrap();
155 if let Some(future) = state.result_notifies.get_mut(&callback_id) {
156 if let Some(future) = future.take() {
157 if !future.is_canceled() {
158 future.notify(result);
159 return;
160 }
161 }
162 }
163 if let Some(cache) = state.result_cache.get_mut(&callback_id) {
164 cache.push_back(result);
165 } else {
166 let mut cache = VecDeque::new();
167 cache.push_back(result);
168 state.result_cache.insert(callback_id, cache);
169 }
170 }
171}
172
173struct SingleCallbackWaiterState<R> {
174 result_notify: Option<Option<Notify<R>>>,
175 result_cache: VecDeque<R>,
176}
177
178pub struct SingleCallbackWaiter<R> {
179 state: Mutex<SingleCallbackWaiterState<R>>,
180}
181
182impl <R: 'static + Send> SingleCallbackWaiter<R> {
183 pub fn new() -> Self {
184 Self {
185 state: Mutex::new(SingleCallbackWaiterState {
186 result_notify: None,
187 result_cache: VecDeque::new(),
188 })
189 }
190 }
191
192 pub fn create_result_future(&self) -> WaiterResult<ResultFuture<R>> {
193 let waiter = {
194 let mut state = self.state.lock().unwrap();
195 if let Some(notify) = state.result_notify.as_ref() {
196 if let Some(notify) = notify {
197 if !notify.is_canceled() {
198 return Err(WaiterError::AlreadyExist);
199 }
200 }
201 }
202
203 if let Some(ret) = state.result_cache.pop_front() {
204 return Ok(ResultFuture::new(Box::pin(async move {
205 Ok(ret)
206 })));
207 }
208 let (notify, waiter) = Notify::new();
209 state.result_notify = Some(Some(notify));
210 waiter
211 };
212 Ok(ResultFuture::new(Box::pin(async move {
213 let ret = waiter.await;
214 {
215 let mut state = self.state.lock().unwrap();
216 state.result_notify = None;
217 }
218 Ok(ret)
219 })))
220 }
221
222 pub fn create_timeout_result_future(&self, timeout: std::time::Duration) -> WaiterResult<ResultFuture<R>> {
223 let waiter = {
224 let mut state = self.state.lock().unwrap();
225 if let Some(notify) = state.result_notify.as_ref() {
226 if let Some(notify) = notify {
227 if !notify.is_canceled() {
228 return Err(WaiterError::AlreadyExist);
229 }
230 }
231 }
232
233 if let Some(ret) = state.result_cache.pop_front() {
234 return Ok(ResultFuture::new(Box::pin(async move {
235 Ok(ret)
236 })));
237 }
238
239 let (notify, waiter) = Notify::new();
240 state.result_notify = Some(Some(notify));
241 waiter
242 };
243 Ok(ResultFuture::new(Box::pin(async move {
244 let ret = tokio::time::timeout(timeout, waiter).await;
245 {
246 let mut state = self.state.lock().unwrap();
247 state.result_notify = None;
248 }
249 match ret {
250 Ok(ret) => Ok(ret),
251 Err(_) => {
252 Err(WaiterError::Timeout)
253 }
254 }
255 })))
256 }
257
258 pub fn set_result(&self, result: R) -> Result<(), WaiterError> {
259 let mut state = self.state.lock().unwrap();
260 if let Some(future) = state.result_notify.as_mut() {
261 if let Some(future) = future.take() {
262 if !future.is_canceled() {
263 future.notify(result);
264 return Ok(());
265 }
266 }
267 }
268 Err(WaiterError::NoWaiter)
269 }
270
271 pub fn set_result_with_cache(&self, result: R) {
272 let mut state = self.state.lock().unwrap();
273 if let Some(future) = state.result_notify.as_mut() {
274 if let Some(future) = future.take() {
275 if !future.is_canceled() {
276 future.notify(result);
277 return;
278 }
279 }
280 }
281 state.result_cache.push_back(result);
282 }
283}
#[cfg(test)]
mod test {
    //! Behavioral tests for `CallbackWaiter` and `SingleCallbackWaiter`.
    //!
    //! Spawned setter tasks are always joined so that an assertion failing inside them
    //! fails the test, and timeout expectations are asserted strictly instead of silently
    //! accepting an `Ok` outcome.

    use super::*;
    use std::sync::Arc;
    use tokio::time::{sleep, Duration};

    /// A registered waiter rejects a second registration and receives a later result.
    #[tokio::test]
    async fn test_waiter() {
        let waiter = Arc::new(CallbackWaiter::new());
        let callback_id = 1;
        let result_future = waiter.create_result_future(callback_id).unwrap();
        assert!(waiter.create_result_future(callback_id).is_err());
        let tmp = waiter.clone();
        let setter = tokio::spawn(async move {
            sleep(Duration::from_millis(1000)).await;
            assert!(tmp.set_result(callback_id, 1).is_ok());
        });
        assert_eq!(result_future.await.unwrap(), 1);
        setter.await.unwrap();
    }

    /// A result stored with `set_result_with_cache` reaches the waiter in either ordering.
    #[tokio::test]
    async fn test_waiter1() {
        let waiter = Arc::new(CallbackWaiter::new());
        let callback_id = 1;
        let tmp = waiter.clone();
        let setter = tokio::spawn(async move {
            tmp.set_result_with_cache(callback_id, 1);
        });
        let result_future = waiter.create_result_future(callback_id).unwrap();
        assert_eq!(result_future.await.unwrap(), 1);
        setter.await.unwrap();
    }

    /// A result delivered before the deadline resolves the timeout future successfully.
    #[tokio::test]
    async fn test_waiter_timout() {
        let waiter = Arc::new(CallbackWaiter::new());
        let callback_id = 1;
        let result_future = waiter
            .create_timeout_result_future(callback_id, Duration::from_secs(2))
            .unwrap();
        let tmp = waiter.clone();
        let setter = tokio::spawn(async move {
            sleep(Duration::from_millis(1000)).await;
            assert!(tmp.set_result(callback_id, 1).is_ok());
        });
        assert_eq!(result_future.await.unwrap(), 1);
        setter.await.unwrap();
    }

    /// A result delivered after the deadline is rejected and the future times out.
    #[tokio::test]
    async fn test_waiter_timout2() {
        let waiter = Arc::new(CallbackWaiter::new());
        let callback_id = 1;
        let result_future = waiter
            .create_timeout_result_future(callback_id, Duration::from_secs(2))
            .unwrap();
        let tmp = waiter.clone();
        let setter = tokio::spawn(async move {
            sleep(Duration::from_secs(3)).await;
            assert!(tmp.set_result(callback_id, 1).is_err());
        });
        assert_eq!(result_future.await, Err(WaiterError::Timeout));
        // Wait for the late setter so its assertion is actually evaluated.
        setter.await.unwrap();
    }

    /// A result set before any waiter exists is dropped, so a later waiter times out.
    #[tokio::test]
    async fn test_waiter_timout3() {
        let waiter = Arc::new(CallbackWaiter::new());
        let callback_id = 1;
        let tmp = waiter.clone();
        tokio::spawn(async move {
            assert!(tmp.set_result(callback_id, 1).is_err());
        })
        .await
        .unwrap();
        let result_future = waiter
            .create_timeout_result_future(callback_id, Duration::from_secs(2))
            .unwrap();
        assert!(waiter
            .create_timeout_result_future(callback_id, Duration::from_secs(2))
            .is_err());
        assert_eq!(result_future.await, Err(WaiterError::Timeout));
    }

    /// Single-slot counterpart of `test_waiter`.
    #[tokio::test]
    async fn test_signle_waiter() {
        let waiter = Arc::new(SingleCallbackWaiter::new());
        let result_future = waiter.create_result_future().unwrap();
        assert!(waiter.create_result_future().is_err());
        let tmp = waiter.clone();
        let setter = tokio::spawn(async move {
            sleep(Duration::from_millis(1000)).await;
            assert!(tmp.set_result(1).is_ok());
        });
        assert_eq!(result_future.await.unwrap(), 1);
        setter.await.unwrap();
    }

    /// Single-slot counterpart of `test_waiter1`.
    #[tokio::test]
    async fn test_single_waiter1() {
        let waiter = Arc::new(SingleCallbackWaiter::new());
        let tmp = waiter.clone();
        let setter = tokio::spawn(async move {
            tmp.set_result_with_cache(1);
        });
        let result_future = waiter.create_result_future().unwrap();
        assert_eq!(result_future.await.unwrap(), 1);
        setter.await.unwrap();
    }

    /// Single-slot counterpart of `test_waiter_timout`, including the duplicate check.
    #[tokio::test]
    async fn test_single_waiter_timout() {
        let waiter = Arc::new(SingleCallbackWaiter::new());
        let result_future = waiter
            .create_timeout_result_future(Duration::from_secs(2))
            .unwrap();
        assert!(waiter
            .create_timeout_result_future(Duration::from_secs(2))
            .is_err());
        let tmp = waiter.clone();
        let setter = tokio::spawn(async move {
            sleep(Duration::from_millis(1000)).await;
            assert!(tmp.set_result(1).is_ok());
        });
        assert_eq!(result_future.await.unwrap(), 1);
        setter.await.unwrap();
    }

    /// Single-slot counterpart of `test_waiter_timout2`.
    #[tokio::test]
    async fn test_single_waiter_timout2() {
        let waiter = Arc::new(SingleCallbackWaiter::new());
        let result_future = waiter
            .create_timeout_result_future(Duration::from_secs(2))
            .unwrap();
        let tmp = waiter.clone();
        let setter = tokio::spawn(async move {
            sleep(Duration::from_secs(3)).await;
            assert!(tmp.set_result(1).is_err());
        });
        assert_eq!(result_future.await, Err(WaiterError::Timeout));
        // Wait for the late setter so its assertion is actually evaluated.
        setter.await.unwrap();
    }

    /// Single-slot counterpart of `test_waiter_timout3`.
    #[tokio::test]
    async fn test_single_waiter_timout3() {
        let waiter = Arc::new(SingleCallbackWaiter::new());
        let tmp = waiter.clone();
        tokio::spawn(async move {
            assert!(tmp.set_result(1).is_err());
        })
        .await
        .unwrap();
        let result_future = waiter
            .create_timeout_result_future(Duration::from_secs(2))
            .unwrap();
        assert_eq!(result_future.await, Err(WaiterError::Timeout));
    }

    /// Dropping a pending future frees its id for a new registration.
    #[tokio::test]
    async fn test_waiter_reregister_after_future_drop() {
        let waiter = Arc::new(CallbackWaiter::new());
        let callback_id = 42;
        let dropped_future = waiter.create_result_future(callback_id).unwrap();
        drop(dropped_future);

        sleep(Duration::from_millis(10)).await;

        let result_future = waiter.create_result_future(callback_id).unwrap();
        let tmp = waiter.clone();
        let setter = tokio::spawn(async move {
            tmp.set_result(callback_id, 7).unwrap();
        });

        assert_eq!(result_future.await.unwrap(), 7);
        setter.await.unwrap();
    }

    /// Cached results are handed out in the order they were stored.
    #[tokio::test]
    async fn test_waiter_cache_fifo_under_load() {
        let waiter = CallbackWaiter::new();
        let callback_id = 1;
        let total = 200;

        for i in 0..total {
            waiter.set_result_with_cache(callback_id, i);
        }

        for expected in 0..total {
            let ret = waiter
                .create_result_future(callback_id)
                .unwrap()
                .await
                .unwrap();
            assert_eq!(ret, expected);
        }
    }

    /// When delivery and deadline coincide, the future and the setter must agree on who won.
    #[tokio::test]
    async fn test_waiter_timeout_set_result_race() {
        for callback_id in 0..50 {
            let waiter = Arc::new(CallbackWaiter::new());
            let result_future = waiter
                .create_timeout_result_future(callback_id, Duration::from_millis(50))
                .unwrap();

            let tmp = waiter.clone();
            let set_task = tokio::spawn(async move {
                sleep(Duration::from_millis(50)).await;
                tmp.set_result(callback_id, 1)
            });

            let future_result = result_future.await;
            let set_result = set_task.await.unwrap();

            match (future_result, set_result) {
                (Ok(1), Ok(())) => {}
                (Err(WaiterError::Timeout), Err(WaiterError::NoWaiter)) => {}
                (other_future, other_set) => {
                    panic!("unexpected race outcome: {:?}, {:?}", other_future, other_set);
                }
            }
        }
    }
}