1use std::collections::HashMap;
2use std::fmt::{Display, Formatter};
3use std::future::Future;
4use std::hash::Hash;
5use std::pin::Pin;
6use std::sync::{Mutex};
7use notify_future::{Notify};
8
/// Failure modes reported by the callback waiters defined in this file.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum WaiterError {
    /// A waiter that has not been canceled is already registered.
    AlreadyExist,
    /// The timeout elapsed before a result was delivered.
    Timeout,
    /// There was no waiter available to receive the result.
    NoWaiter,
}
15
16impl Display for WaiterError {
17 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
18 match self {
19 WaiterError::AlreadyExist => write!(f, "AlreadyExist"),
20 WaiterError::Timeout => write!(f, "Timeout"),
21 WaiterError::NoWaiter => write!(f, "NoWaiter"),
22 }
23 }
24}
25
26impl std::error::Error for WaiterError {
27
28}
29pub type WaiterResult<T> = Result<T, WaiterError>;
30
31pub struct ResultFuture<'a, R> {
32 future: Pin<Box<dyn Future<Output = Result<R, WaiterError>> + 'a + Send>>,
33}
34
35impl <'a, R> ResultFuture<'a, R> {
36 pub fn new(future: Pin<Box<dyn Future<Output = Result<R, WaiterError>> + 'a + Send>>) -> Self {
37 Self {
38 future,
39 }
40 }
41}
42
43impl <'a, R> Future for ResultFuture<'a, R> {
44 type Output = Result<R, WaiterError>;
45
46 fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
47 self.get_mut().future.as_mut().poll(cx)
48 }
49}
50
51struct CallbackWaiterState<K, R> {
52 result_notifies: HashMap<K, Option<Notify<R>>>,
53 result_cache: HashMap<K, Vec<R>>,
54}
55pub struct CallbackWaiter<K, R> {
56 state: Mutex<CallbackWaiterState<K, R>>,
57}
58
59impl <K: Hash + Eq + Clone + 'static + Send, R: 'static + Send> CallbackWaiter<K, R> {
60 pub fn new() -> Self {
61 Self {
62 state: Mutex::new(CallbackWaiterState {
63 result_notifies: HashMap::new(),
64 result_cache: HashMap::new(),
65 })
66 }
67 }
68
69 pub fn create_result_future(&self, callback_id: K) -> WaiterResult<ResultFuture<R>> {
70 let waiter = {
71 let mut state = self.state.lock().unwrap();
72 let notifies = state.result_notifies.get(&callback_id);
73 if let Some(notifies) = notifies {
74 if let Some(notifies) = notifies {
75 if!notifies.is_canceled() {
76 return Err(WaiterError::AlreadyExist);
77 }
78 }
79 }
80 if let Some(result) = state.result_cache.get_mut(&callback_id) {
81 if result.len() > 0 {
82 let ret = result.remove(0);
83 return Ok(ResultFuture::new(Box::pin(async move {
84 Ok(ret)
85 })));
86 }
87 }
88
89 let (notify, waiter) = Notify::new();
90 state.result_notifies.insert(callback_id.clone(), Some(notify));
91 waiter
92 };
93
94 Ok(ResultFuture::new(Box::pin(async move {
95 let ret = waiter.await;
96 {
97 let mut state = self.state.lock().unwrap();
98 state.result_notifies.remove(&callback_id);
99 }
100 Ok(ret)
101 })))
102 }
103
104 pub fn create_timeout_result_future(&self, callback_id: K, timeout: std::time::Duration) -> WaiterResult<ResultFuture<R>> {
105 let waiter = {
106 let mut state = self.state.lock().unwrap();
107 let notifies = state.result_notifies.get(&callback_id);
108 if let Some(notifies) = notifies {
109 if let Some(notifies) = notifies {
110 if!notifies.is_canceled() {
111 return Err(WaiterError::AlreadyExist);
112 }
113 }
114 }
115
116 if let Some(result) = state.result_cache.get_mut(&callback_id) {
117 if result.len() > 0 {
118 let ret = result.remove(0);
119 return Ok(ResultFuture::new(Box::pin(async move {
120 Ok(ret)
121 })));
122 }
123 }
124
125 let (notify, waiter) = Notify::new();
126 state.result_notifies.insert(callback_id.clone(), Some(notify));
127 waiter
128 };
129 Ok(ResultFuture::new(Box::pin(async move {
130 let ret = async_std::future::timeout(timeout, waiter).await;
131 {
132 let mut state = self.state.lock().unwrap();
133 state.result_notifies.remove(&callback_id);
134 }
135 match ret {
136 Ok(ret) => Ok(ret),
137 Err(_) => Err(WaiterError::Timeout)
138 }
139 })))
140 }
141
142 pub fn set_result(&self, callback_id: K, result: R) -> Result<(), WaiterError> {
143 let mut state = self.state.lock().unwrap();
144 if let Some(future) = state.result_notifies.get_mut(&callback_id) {
145 if let Some(future) = future.take() {
146 if !future.is_canceled() {
147 future.notify(result);
148 return Ok(());
149 }
150 }
151 }
152 Err(WaiterError::NoWaiter)
153 }
154
155 pub fn set_result_with_cache(&self, callback_id: K, result: R) {
156 let mut state = self.state.lock().unwrap();
157 if let Some(future) = state.result_notifies.get_mut(&callback_id) {
158 if let Some(future) = future.take() {
159 if !future.is_canceled() {
160 future.notify(result);
161 return;
162 }
163 }
164 }
165 if let Some(cache) = state.result_cache.get_mut(&callback_id) {
166 cache.push(result);
167 } else {
168 state.result_cache.insert(callback_id, vec![result]);
169 }
170 }
171}
172
173struct SingleCallbackWaiterState<R> {
174 result_notify: Option<Option<Notify<R>>>,
175 result_cache: Vec<R>,
176}
177
178pub struct SingleCallbackWaiter<R> {
179 state: Mutex<SingleCallbackWaiterState<R>>,
180}
181
182impl <R: 'static + Send> SingleCallbackWaiter<R> {
183 pub fn new() -> Self {
184 Self {
185 state: Mutex::new(SingleCallbackWaiterState {
186 result_notify: None,
187 result_cache: Vec::new(),
188 })
189 }
190 }
191
192 pub fn create_result_future(&self) -> WaiterResult<ResultFuture<R>> {
193 let waiter = {
194 let mut state = self.state.lock().unwrap();
195 if let Some(notify) = state.result_notify.as_ref() {
196 if let Some(notify) = notify {
197 if !notify.is_canceled() {
198 return Err(WaiterError::AlreadyExist);
199 }
200 }
201 }
202
203 if state.result_cache.len() > 0 {
204 let ret = state.result_cache.remove(0);
205 return Ok(ResultFuture::new(Box::pin(async move {
206 Ok(ret)
207 })));
208 }
209 let (notify, waiter) = Notify::new();
210 state.result_notify = Some(Some(notify));
211 waiter
212 };
213 Ok(ResultFuture::new(Box::pin(async move {
214 let ret = waiter.await;
215 {
216 let mut state = self.state.lock().unwrap();
217 state.result_notify = None;
218 }
219 Ok(ret)
220 })))
221 }
222
223 pub fn create_timeout_result_future(&self, timeout: std::time::Duration) -> WaiterResult<ResultFuture<R>> {
224 let waiter = {
225 let mut state = self.state.lock().unwrap();
226 if let Some(notify) = state.result_notify.as_ref() {
227 if let Some(notify) = notify {
228 if !notify.is_canceled() {
229 return Err(WaiterError::AlreadyExist);
230 }
231 }
232 }
233
234 if state.result_cache.len() > 0 {
235 let ret = state.result_cache.remove(0);
236 return Ok(ResultFuture::new(Box::pin(async move {
237 Ok(ret)
238 })));
239 }
240
241 let (notify, waiter) = Notify::new();
242 state.result_notify = Some(Some(notify));
243 waiter
244 };
245 Ok(ResultFuture::new(Box::pin(async move {
246 let ret = async_std::future::timeout(timeout, waiter).await;
247 {
248 let mut state = self.state.lock().unwrap();
249 state.result_notify = None;
250 }
251 match ret {
252 Ok(ret) => Ok(ret),
253 Err(_) => {
254 Err(WaiterError::Timeout)
255 }
256 }
257 })))
258 }
259
260 pub fn set_result(&self, result: R) -> Result<(), WaiterError> {
261 let mut state = self.state.lock().unwrap();
262 if let Some(future) = state.result_notify.as_mut() {
263 if let Some(future) = future.take() {
264 if !future.is_canceled() {
265 future.notify(result);
266 return Ok(());
267 }
268 }
269 }
270 Err(WaiterError::NoWaiter)
271 }
272
273 pub fn set_result_with_cache(&self, result: R) {
274 let mut state = self.state.lock().unwrap();
275 if let Some(future) = state.result_notify.as_mut() {
276 if let Some(future) = future.take() {
277 if !future.is_canceled() {
278 future.notify(result);
279 return;
280 }
281 }
282 }
283 state.result_cache.push(result);
284 }
285}
#[cfg(test)]
mod test {
    use super::*;
    use async_std::task;
    use std::sync::Arc;
    use std::time::Duration;

    /// A keyed waiter rejects a second registration and resolves once
    /// `set_result` is called.
    #[test]
    fn test_waiter() {
        task::block_on(async {
            let waiter = Arc::new(CallbackWaiter::new());
            let callback_id = 1;
            let result_future = waiter.create_result_future(callback_id).unwrap();
            assert!(waiter.create_result_future(callback_id).is_err());
            let tmp = waiter.clone();
            let setter = task::spawn(async move {
                task::sleep(Duration::from_millis(1000)).await;
                assert!(tmp.set_result(callback_id, 1).is_ok());
            });
            assert_eq!(result_future.await.unwrap(), 1);
            // Join the setter so its assertion has run before the test ends.
            setter.await;
        });
    }

    /// A result stored with `set_result_with_cache` reaches the waiter whether
    /// it is stored before or after the waiter registers.
    #[test]
    fn test_waiter1() {
        task::block_on(async {
            let waiter = Arc::new(CallbackWaiter::new());
            let callback_id = 1;
            let tmp = waiter.clone();
            let setter = task::spawn(async move {
                tmp.set_result_with_cache(callback_id, 1);
            });
            let result_future = waiter.create_result_future(callback_id).unwrap();
            assert_eq!(result_future.await.unwrap(), 1);
            setter.await;
        });
    }

    /// A result delivered before the deadline resolves the timeout future.
    #[test]
    fn test_waiter_timeout() {
        task::block_on(async {
            let waiter = Arc::new(CallbackWaiter::new());
            let callback_id = 1;
            let result_future = waiter
                .create_timeout_result_future(callback_id, Duration::from_secs(2))
                .unwrap();
            let tmp = waiter.clone();
            let setter = task::spawn(async move {
                task::sleep(Duration::from_millis(1000)).await;
                assert!(tmp.set_result(callback_id, 1).is_ok());
            });
            assert_eq!(result_future.await.unwrap(), 1);
            setter.await;
        });
    }

    /// The future times out when the result arrives too late, and the late
    /// `set_result` finds no waiter.
    #[test]
    fn test_waiter_timeout2() {
        task::block_on(async {
            let waiter = Arc::new(CallbackWaiter::new());
            let callback_id = 1;
            let result_future = waiter
                .create_timeout_result_future(callback_id, Duration::from_secs(2))
                .unwrap();
            let tmp = waiter.clone();
            let setter = task::spawn(async move {
                task::sleep(Duration::from_secs(3)).await;
                assert!(tmp.set_result(callback_id, 1).is_err());
            });
            assert_eq!(result_future.await.err(), Some(WaiterError::Timeout));
            // Without this join the test would finish before the setter wakes
            // up, and its assertion would never execute.
            setter.await;
        });
    }

    /// `set_result` fails without a waiter, a duplicate registration is
    /// rejected, and an unanswered timeout future reports `Timeout`.
    #[test]
    fn test_waiter_timeout3() {
        task::block_on(async {
            let waiter = Arc::new(CallbackWaiter::new());
            let callback_id = 1;
            let tmp = waiter.clone();
            task::spawn(async move {
                assert!(tmp.set_result(callback_id, 1).is_err());
            })
            .await;
            let result_future = waiter
                .create_timeout_result_future(callback_id, Duration::from_secs(2))
                .unwrap();
            assert!(waiter
                .create_timeout_result_future(callback_id, Duration::from_secs(2))
                .is_err());
            assert_eq!(result_future.await.err(), Some(WaiterError::Timeout));
        });
    }

    /// The single-slot waiter rejects a second registration and resolves once
    /// `set_result` is called.
    #[test]
    fn test_single_waiter() {
        task::block_on(async {
            let waiter = Arc::new(SingleCallbackWaiter::new());
            let result_future = waiter.create_result_future().unwrap();
            assert!(waiter.create_result_future().is_err());
            let tmp = waiter.clone();
            let setter = task::spawn(async move {
                task::sleep(Duration::from_millis(1000)).await;
                assert!(tmp.set_result(1).is_ok());
            });
            assert_eq!(result_future.await.unwrap(), 1);
            setter.await;
        });
    }

    /// A result stored with `set_result_with_cache` reaches the single-slot
    /// waiter whether it is stored before or after the waiter registers.
    #[test]
    fn test_single_waiter1() {
        task::block_on(async {
            let waiter = Arc::new(SingleCallbackWaiter::new());
            let tmp = waiter.clone();
            let setter = task::spawn(async move {
                tmp.set_result_with_cache(1);
            });
            let result_future = waiter.create_result_future().unwrap();
            assert_eq!(result_future.await.unwrap(), 1);
            setter.await;
        });
    }

    /// A result delivered before the deadline resolves the single-slot timeout
    /// future; a duplicate registration is rejected meanwhile.
    #[test]
    fn test_single_waiter_timeout() {
        task::block_on(async {
            let waiter = Arc::new(SingleCallbackWaiter::new());
            let result_future = waiter
                .create_timeout_result_future(Duration::from_secs(2))
                .unwrap();
            assert!(waiter
                .create_timeout_result_future(Duration::from_secs(2))
                .is_err());
            let tmp = waiter.clone();
            let setter = task::spawn(async move {
                task::sleep(Duration::from_millis(1000)).await;
                assert!(tmp.set_result(1).is_ok());
            });
            assert_eq!(result_future.await.unwrap(), 1);
            setter.await;
        });
    }

    /// The single-slot future times out when the result arrives too late, and
    /// the late `set_result` finds no waiter.
    #[test]
    fn test_single_waiter_timeout2() {
        task::block_on(async {
            let waiter = Arc::new(SingleCallbackWaiter::new());
            let result_future = waiter
                .create_timeout_result_future(Duration::from_secs(2))
                .unwrap();
            let tmp = waiter.clone();
            let setter = task::spawn(async move {
                task::sleep(Duration::from_secs(3)).await;
                assert!(tmp.set_result(1).is_err());
            });
            assert_eq!(result_future.await.err(), Some(WaiterError::Timeout));
            // Without this join the test would finish before the setter wakes
            // up, and its assertion would never execute.
            setter.await;
        });
    }

    /// `set_result` fails without a waiter and an unanswered single-slot
    /// timeout future reports `Timeout`.
    #[test]
    fn test_single_waiter_timeout3() {
        task::block_on(async {
            let waiter = Arc::new(SingleCallbackWaiter::new());
            let tmp = waiter.clone();
            task::spawn(async move {
                assert!(tmp.set_result(1).is_err());
            })
            .await;
            let result_future = waiter
                .create_timeout_result_future(Duration::from_secs(2))
                .unwrap();
            assert_eq!(result_future.await.err(), Some(WaiterError::Timeout));
        });
    }
}