1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll, Waker};
5use tokio::sync::Mutex;
6
/// State shared between every clone of a `FutureRpcReply`.
///
/// No trait bounds are placed on `T` here: the impl blocks that need
/// `Clone` or `Debug` declare those bounds themselves.
#[derive(Debug)]
struct SharedState<T> {
    /// The reply value, present once `resolve` has stored it.
    delivery: Option<T>,
    /// Waker registered by a `poll` call; `resolve` uses it to wake the
    /// awaiting task. Only a single waker slot exists.
    waker: Option<Waker>,
}
12
13#[derive(Clone, Debug)]
14pub struct FutureRpcReply<T: Clone + std::fmt::Debug> {
15 shared_state: Arc<Mutex<SharedState<T>>>,
16}
17
18impl<T: Clone + std::fmt::Debug> FutureRpcReply<T> {
19 pub fn new() -> Self {
20 let shared_state = Arc::new(Mutex::new(SharedState {
21 delivery: None,
22 waker: None,
23 }));
24 Self { shared_state }
25 }
26
27 pub async fn resolve(self, delivery: T) {
28 let mut shared_state = self.shared_state.lock().await;
29 shared_state.delivery = Some(delivery);
30
31 match shared_state.waker.clone() {
32 Some(waker) => waker.wake(),
33 None => panic!("Future has never awaited before!"), }
35 }
36}
37
38impl<T: Clone + std::fmt::Debug> Future for FutureRpcReply<T> {
39 type Output = T;
40
41 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
42 let mut lock_fut = self.shared_state.lock();
43 let lock = unsafe { Pin::new_unchecked(&mut lock_fut) };
44
45 match lock.poll(cx) {
46 Poll::Pending => Poll::Pending,
47 Poll::Ready(mut v) => {
48 v.waker = Some(cx.waker().clone());
49 match v.delivery.clone() {
50 Some(v) => Poll::Ready(v),
51 None => Poll::Pending,
52 }
53 }
54 }
55 }
56}
57
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// A value passed to `resolve` on another task is what the awaiting
    /// clone receives.
    #[tokio::test]
    async fn resolves_to_specified_result() {
        let awaiting = FutureRpcReply::<String>::new();
        let resolver = awaiting.clone();

        tokio::spawn(async move {
            // Short delay so the awaiting side is likely to poll first.
            tokio::time::sleep(Duration::from_millis(1)).await;
            resolver.resolve(String::from("Hello!")).await;
        });

        let received = awaiting.await;
        assert_eq!(received, "Hello!")
    }
}