// rmq_rpc/reply.rs

use std::future::Future;
use std::pin::Pin;
use std::sync::{Arc, MutexGuard, PoisonError};
use std::task::{Context, Poll, Waker};
use tokio::sync::Mutex;
/// Slot shared by every clone of a [`FutureRpcReply`].
#[derive(Debug)]
struct SharedState<T> {
    /// The reply, once [`FutureRpcReply::resolve`] has stored it.
    delivery: Option<T>,
    /// Waker of the task that last polled while no reply was available.
    waker: Option<Waker>,
}

/// A future that completes with the value passed to [`FutureRpcReply::resolve`].
///
/// All clones share one slot: typically one clone is awaited while another is
/// handed to the code that receives the reply and calls `resolve` on it.
/// `resolve` may run either before or after the future is first polled.
#[derive(Clone, Debug)]
pub struct FutureRpcReply<T: Clone + std::fmt::Debug> {
    // A blocking `std` mutex rather than tokio's: `Future::poll` is synchronous
    // and must take the lock without itself returning `Pending`, and the lock
    // is never held across an `.await`.
    shared_state: Arc<std::sync::Mutex<SharedState<T>>>,
}

impl<T: Clone + std::fmt::Debug> FutureRpcReply<T> {
    /// Creates an unresolved reply.
    pub fn new() -> Self {
        Self {
            shared_state: Arc::new(std::sync::Mutex::new(SharedState {
                delivery: None,
                waker: None,
            })),
        }
    }

    /// Stores `delivery` as the result and wakes the task awaiting this reply,
    /// if one has already polled it.
    ///
    /// Resolving before the first poll is fine: the value is kept and returned
    /// by the first poll. Resolving again overwrites the stored value.
    ///
    /// This never suspends; it is `async` only to keep the existing API.
    pub async fn resolve(self, delivery: T) {
        let waker = {
            let mut state = self.lock_state();
            state.delivery = Some(delivery);
            state.waker.take()
        };
        // Wake after the guard is dropped so the woken task can re-poll
        // without contending for the lock.
        if let Some(waker) = waker {
            waker.wake();
        }
    }

    /// Locks the shared slot, recovering the guard if a previous holder panicked.
    ///
    /// Recovery is sound because every critical section only reads or replaces
    /// whole fields, so the state is never left half-updated.
    fn lock_state(&self) -> MutexGuard<'_, SharedState<T>> {
        self.shared_state
            .lock()
            .unwrap_or_else(PoisonError::into_inner)
    }
}

impl<T: Clone + std::fmt::Debug> Default for FutureRpcReply<T> {
    fn default() -> Self {
        Self::new()
    }
}

impl<T: Clone + std::fmt::Debug> Future for FutureRpcReply<T> {
    type Output = T;

    /// Returns a clone of the stored reply if present. Otherwise it records the
    /// caller's waker so that `resolve` can wake it, and returns `Pending`.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let mut state = self.lock_state();
        if let Some(delivery) = &state.delivery {
            return Poll::Ready(delivery.clone());
        }
        // Checking `delivery` and registering the waker happen under one lock
        // acquisition. `resolve` stores the value and takes the waker under the
        // same lock, so a resolution cannot slip in between and go unnoticed.
        let up_to_date = matches!(&state.waker, Some(w) if w.will_wake(cx.waker()));
        if !up_to_date {
            state.waker = Some(cx.waker().clone());
        }
        Poll::Pending
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end check on a tokio runtime: a spawned task resolves the reply
    /// after a short delay while the test task awaits it.
    #[tokio::test]
    async fn resolves_to_specified_result() {
        let awaited = FutureRpcReply::<String>::new();
        let resolver = awaited.clone();

        let _resolver_task = tokio::spawn(async move {
            // Delay resolution so that, typically, the awaiting task polls first.
            tokio::time::sleep(std::time::Duration::from_millis(1)).await;
            resolver.resolve(String::from("Hello!")).await;
        });

        assert_eq!(awaited.await, "Hello!");
    }
}