1use crate::protocol::{ErrorMessageBody, Message};
2use crate::Error;
3use crate::Result;
4use crate::Runtime;
5use std::future::Future;
6use std::pin::{pin, Pin};
7use std::task::Poll;
8use tokio::select;
9use tokio::sync::oneshot::Receiver;
10use tokio::sync::{oneshot, OnceCell};
11use tokio_context::context::Context;
12
/// Handle to the reply of an outstanding RPC request.
///
/// The reply can be obtained either by awaiting the handle directly (see the
/// `Future` impl, no timeout) or through `RPCResult::done_with`, which also
/// stops waiting when the supplied `Context` is done.
pub struct RPCResult {
    /// Runtime the reply sender for `msg_id` was registered with (see `rpc`).
    runtime: Runtime,
    /// Receiving half of the reply channel; emptied once the reply has been
    /// consumed or the call has been abandoned.
    rx: OnceCell<Receiver<Message>>,
    /// Id of the request message; the key passed to `release_rpc_sender`.
    msg_id: u64,
}
36
37impl RPCResult {
38 #[must_use]
39 pub fn new(msg_id: u64, rx: Receiver<Message>, runtime: Runtime) -> RPCResult {
40 RPCResult {
41 runtime,
42 rx: OnceCell::new_with(Some(rx)),
43 msg_id,
44 }
45 }
46
47 pub fn done(&mut self) {
64 drop(self.rx.take());
65 drop(self.runtime.release_rpc_sender(self.msg_id));
66 }
67
68 pub async fn done_with(&mut self, mut ctx: Context) -> Result<Message> {
87 let result: Result<Message>;
88 let rx = match self.rx.take() {
89 Some(x) => x,
90 None => return Err(Box::new(Error::Abort)),
91 };
92
93 select! {
94 data = rx => match data {
95 Ok(resp) => result = rpc_msg_type(resp),
96 Err(err) => result = Err(Box::new(err)),
97 },
98 _ = ctx.done() => result = Err(Box::new(Error::Timeout)),
99 }
100
101 drop(self.runtime.release_rpc_sender(self.msg_id));
102
103 result
104 }
105}
106
107impl Drop for RPCResult {
108 fn drop(&mut self) {
109 self.done();
110 }
111}
112
113impl Future for RPCResult {
131 type Output = Result<Message>;
132
133 fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
134 let rx = pin!(match self.rx.get_mut() {
135 Some(x) => x,
136 None => return Poll::Ready(Err(Box::new(Error::Abort))),
137 });
138
139 match rx.poll(cx) {
140 Poll::Ready(t) => {
141 let _ = self.rx.take();
142 match t {
143 Err(e) => Poll::Ready(Err(Box::new(e))),
144 Ok(m) => Poll::Ready(rpc_msg_type(m)),
145 }
146 }
147 Poll::Pending => Poll::Pending,
148 }
149 }
150}
151
152pub(crate) async fn rpc(runtime: Runtime, msg_id: u64, req: Result<String>) -> Result<RPCResult> {
153 let req_str = req?;
154
155 let (tx, rx) = oneshot::channel::<Message>();
156
157 let _ = runtime.insert_rpc_sender(msg_id, tx).await;
158
159 if let Err(err) = runtime.send_raw(req_str.as_str()).await {
160 let _ = runtime.release_rpc_sender(msg_id).await;
161 return Err(err);
162 }
163
164 Ok(RPCResult::new(msg_id, rx, runtime))
165}
166
167fn rpc_msg_type(m: Message) -> Result<Message> {
168 if m.body.is_error() {
169 Err(Box::new(Error::from(&m.body)))
170 } else {
171 Ok(m)
172 }
173}
174
175pub fn is_rpc_error<T>(t: &Result<T>) -> bool {
176 match t {
177 Ok(_) => false,
178 Err(e) => e.downcast_ref::<Error>().is_some(),
179 }
180}
181
182pub fn rpc_err_to_response<T>(t: &Result<T>) -> Option<ErrorMessageBody> {
183 match t {
184 Ok(_) => None,
185 Err(e) => e
186 .downcast_ref::<Error>()
187 .map(|t| ErrorMessageBody::from_error(t.clone())),
188 }
189}