maelstrom/
rpc.rs

1use crate::protocol::{ErrorMessageBody, Message};
2use crate::Error;
3use crate::Result;
4use crate::Runtime;
5use std::future::Future;
6use std::pin::{pin, Pin};
7use std::task::Poll;
8use tokio::select;
9use tokio::sync::oneshot::Receiver;
10use tokio::sync::{oneshot, OnceCell};
11use tokio_context::context::Context;
12
13/// Represents a result of a RPC call. Can be awaited with or without timeout.
14///
15/// Example:
16///
17/// ```
18/// use maelstrom::protocol::Message;
19/// use maelstrom::{RPCResult, Runtime, Result};
20/// use serde::Serialize;
21/// use tokio_context::context::Context;
22///
23/// async fn call<T>(ctx: Context, runtime: Runtime, node: String, msg: T) -> Result<Message>
24/// where
25///     T: Serialize,
26/// {
27///     let mut res: RPCResult = runtime.rpc(node, msg).await?;
28///     return res.done_with(ctx).await;
29/// }
30/// ```
31pub struct RPCResult {
32    runtime: Runtime,
33    rx: OnceCell<Receiver<Message>>,
34    msg_id: u64,
35}
36
37impl RPCResult {
38    #[must_use]
39    pub fn new(msg_id: u64, rx: Receiver<Message>, runtime: Runtime) -> RPCResult {
40        RPCResult {
41            runtime,
42            rx: OnceCell::new_with(Some(rx)),
43            msg_id,
44        }
45    }
46
47    /// Releases RPC call resources. Drop calls `Self::done`().
48    ///
49    /// Example:
50    ///
51    /// ```
52    /// use maelstrom::protocol::Message;
53    /// use maelstrom::{RPCResult, Runtime, Result};
54    /// use serde::Serialize;
55    ///
56    /// async fn call<T>(runtime: Runtime, node: String, msg: T)
57    /// where
58    ///     T: Serialize,
59    /// {
60    ///     let _ = runtime.rpc(node, msg).await;
61    /// }
62    /// ```
63    pub fn done(&mut self) {
64        drop(self.rx.take());
65        drop(self.runtime.release_rpc_sender(self.msg_id));
66    }
67
68    /// Acquires a RPC call response within specific timeout.
69    ///
70    /// Example:
71    ///
72    /// ```
73    /// use maelstrom::protocol::Message;
74    /// use maelstrom::{RPCResult, Runtime, Result};
75    /// use serde::Serialize;
76    /// use tokio_context::context::Context;
77    ///
78    /// async fn call<T>(ctx: Context, runtime: Runtime, node: String, msg: T) -> Result<Message>
79    /// where
80    ///     T: Serialize,
81    /// {
82    ///     let mut res: RPCResult = runtime.rpc(node, msg).await?;
83    ///     return res.done_with(ctx).await;
84    /// }
85    /// ```
86    pub async fn done_with(&mut self, mut ctx: Context) -> Result<Message> {
87        let result: Result<Message>;
88        let rx = match self.rx.take() {
89            Some(x) => x,
90            None => return Err(Box::new(Error::Abort)),
91        };
92
93        select! {
94            data = rx => match data {
95                Ok(resp) => result = rpc_msg_type(resp),
96                Err(err) => result = Err(Box::new(err)),
97            },
98            _ = ctx.done() => result = Err(Box::new(Error::Timeout)),
99        }
100
101        drop(self.runtime.release_rpc_sender(self.msg_id));
102
103        result
104    }
105}
106
107impl Drop for RPCResult {
108    fn drop(&mut self) {
109        self.done();
110    }
111}
112
113/// Makes `RPCResult` an awaitable future.
114///
115/// Example:
116///
117/// ```
118/// use maelstrom::protocol::Message;
119/// use maelstrom::{RPCResult, Runtime, Result};
120/// use serde::Serialize;
121///
122/// async fn call<T>(runtime: Runtime, node: String, msg: T) -> Result<Message>
123/// where
124///     T: Serialize,
125/// {
126///     let mut res: RPCResult = runtime.rpc(node, msg).await?;
127///     return res.await;
128/// }
129/// ```
130impl Future for RPCResult {
131    type Output = Result<Message>;
132
133    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
134        let rx = pin!(match self.rx.get_mut() {
135            Some(x) => x,
136            None => return Poll::Ready(Err(Box::new(Error::Abort))),
137        });
138
139        match rx.poll(cx) {
140            Poll::Ready(t) => {
141                let _ = self.rx.take();
142                match t {
143                    Err(e) => Poll::Ready(Err(Box::new(e))),
144                    Ok(m) => Poll::Ready(rpc_msg_type(m)),
145                }
146            }
147            Poll::Pending => Poll::Pending,
148        }
149    }
150}
151
152pub(crate) async fn rpc(runtime: Runtime, msg_id: u64, req: Result<String>) -> Result<RPCResult> {
153    let req_str = req?;
154
155    let (tx, rx) = oneshot::channel::<Message>();
156
157    let _ = runtime.insert_rpc_sender(msg_id, tx).await;
158
159    if let Err(err) = runtime.send_raw(req_str.as_str()).await {
160        let _ = runtime.release_rpc_sender(msg_id).await;
161        return Err(err);
162    }
163
164    Ok(RPCResult::new(msg_id, rx, runtime))
165}
166
167fn rpc_msg_type(m: Message) -> Result<Message> {
168    if m.body.is_error() {
169        Err(Box::new(Error::from(&m.body)))
170    } else {
171        Ok(m)
172    }
173}
174
175pub fn is_rpc_error<T>(t: &Result<T>) -> bool {
176    match t {
177        Ok(_) => false,
178        Err(e) => e.downcast_ref::<Error>().is_some(),
179    }
180}
181
182pub fn rpc_err_to_response<T>(t: &Result<T>) -> Option<ErrorMessageBody> {
183    match t {
184        Ok(_) => None,
185        Err(e) => e
186            .downcast_ref::<Error>()
187            .map(|t| ErrorMessageBody::from_error(t.clone())),
188    }
189}