1#![deny(missing_docs)]
37#![deny(clippy::all)]
38
39mod awaiting;
40mod errors;
41mod message;
42mod result;
43
44#[cfg(test)]
45mod tests;
46
47use async_std::{
48 net::UdpSocket,
49 sync::{Arc, Mutex},
50 task::{self, JoinHandle},
51};
52use futures::{
53 channel::{mpsc, oneshot},
54 future::FutureExt,
55 sink::SinkExt,
56 stream::StreamExt,
57};
58use std::{
59 future::Future,
60 net::{SocketAddr, ToSocketAddrs},
61 ops::Drop,
62 pin::Pin,
63 task::{Context, Poll},
64};
65
66use awaiting::AwaitingRequestMap;
67use message::{RpcHeader, RpcMessage};
68use result::Result;
69
/// An RPC endpoint layered over a single UDP socket.
///
/// It can both issue requests ([`RpcSocket::send_to`], which hands back a
/// [`ResponseFuture`]) and serve them ([`RpcSocket::recv_from`], which hands
/// back an [`RpcResponder`]). A background task spawned by
/// [`RpcSocket::bind`] demultiplexes incoming datagrams between the two.
pub struct RpcSocket {
    // Shared with the background task and with every `RpcResponder`.
    udp: Arc<UdpSocket>,
    // Outstanding requests keyed by peer address and request id; the
    // background task pops an entry when the matching response arrives.
    awaiting_map: Arc<AwaitingRequestMap>,

    // Handle to the background `rpc_loop` task spawned in `bind`; it is never
    // awaited. NOTE(review): dropping an async-std `JoinHandle` detaches the
    // task instead of cancelling it — confirm that is the intended shutdown.
    _handle: JoinHandle<()>,
    // Incoming requests forwarded by the background task. `recv_from` holds
    // the mutex while waiting, so concurrent receivers are served one at a time.
    receiver: Mutex<mpsc::UnboundedReceiver<(RpcMessage, SocketAddr)>>,
}
85
86async fn rpc_loop(
87 udp: Arc<UdpSocket>,
88 awaiting_map: Arc<AwaitingRequestMap>,
89 mut sender: mpsc::UnboundedSender<(RpcMessage, SocketAddr)>,
90) {
91 let (msg_sender, mut msg_receiver) = mpsc::unbounded();
92 let receiver_handle = task::spawn(receiver_loop(udp, msg_sender));
93
94 while let Some((msg, addr)) = msg_receiver.next().await {
95 if msg.is_request() {
96 if sender.send((msg, addr)).await.is_err() {
97 break;
98 }
99 } else if let Some(rsp_sender) =
100 awaiting_map.pop(addr, msg.request_id()).await
101 {
102 let _ = rsp_sender.send(msg);
103 }
104 }
105
106 drop(msg_receiver);
107 receiver_handle.await;
108}
109
110async fn receiver_loop(
111 udp: Arc<UdpSocket>,
112 mut msg_sender: mpsc::UnboundedSender<(RpcMessage, SocketAddr)>,
113) {
114 while let Ok((msg, addr)) = RpcMessage::read_from_socket(&udp).await {
116 if msg_sender.send((msg, addr)).await.is_err() {
117 break;
118 }
119 }
120}
121
122impl RpcSocket {
123 pub async fn bind<A: ToSocketAddrs>(addrs: A) -> Result<Self> {
131 let addr = get_addr(addrs)?;
132 let udp = Arc::new(UdpSocket::bind(addr).await?);
133 let awaiting_map = Arc::new(AwaitingRequestMap::default());
134
135 let (sender, receiver) = mpsc::unbounded();
136 let receiver = Mutex::new(receiver);
137
138 let _handle =
139 task::spawn(rpc_loop(udp.clone(), awaiting_map.clone(), sender));
140
141 Ok(Self {
142 udp,
143 awaiting_map,
144 receiver,
145 _handle,
146 })
147 }
148
149 pub fn local_addr(&self) -> Result<SocketAddr> {
154 Ok(self.udp.local_addr()?)
155 }
156
157 #[allow(clippy::needless_lifetimes)]
162 pub async fn send_to<'a, A: ToSocketAddrs>(
163 &self,
164 buf: &[u8],
165 rsp_buf: &'a mut [u8],
166 addrs: A,
167 ) -> Result<(usize, ResponseFuture<'a>)> {
168 let addr = get_addr(addrs)?;
169
170 let (sender, receiver) = oneshot::channel();
171
172 let rid = self.awaiting_map.put(addr, sender).await;
173 let header = RpcHeader::request_from_rid(rid);
174
175 let written =
176 match RpcMessage::write_to_socket(&self.udp, addr, header, buf)
177 .await
178 {
179 Ok(written) => written,
180 Err(err) => {
181 self.awaiting_map.pop(addr, rid).await;
182 return Err(err);
183 }
184 };
185
186 Ok((
187 written,
188 ResponseFuture {
189 rsp_buf,
190 addr,
191 rid,
192 awaiting_map: self.awaiting_map.clone(),
193 receiver,
194 },
195 ))
196 }
197
198 pub async fn recv_from(
202 &self,
203 buf: &mut [u8],
204 ) -> Result<(usize, RpcResponder)> {
205 match self.receiver.lock().await.next().await {
206 Some((msg, addr)) => {
207 let read = msg.write_to_buffer(buf);
208 let header = msg.split();
209 Ok((
210 read,
211 RpcResponder {
212 origin: addr,
213 udp: self.udp.clone(),
214 header,
215 },
216 ))
217 }
218 None => Err(errors::other("unexpected channel close")),
219 }
220 }
221
222 pub fn ttl(&self) -> Result<u32> {
228 Ok(self.udp.ttl()?)
229 }
230
231 pub fn set_ttl(&self, ttl: u32) -> Result<()> {
236 Ok(self.udp.set_ttl(ttl)?)
237 }
238}
239
/// Future returned by [`RpcSocket::send_to`] that resolves when the matching
/// response arrives.
///
/// On completion, the response payload is written into the buffer passed to
/// `send_to`, and the number of bytes written is returned. Dropping the
/// future deregisters the pending request.
pub struct ResponseFuture<'a> {
    // Caller-provided buffer the response payload is copied into.
    rsp_buf: &'a mut [u8],
    // Peer the request was sent to. Together with `rid`, this is the key of
    // this request's entry in `awaiting_map`.
    addr: SocketAddr,
    // Request id returned by `awaiting_map.put` in `send_to`.
    rid: u16,
    awaiting_map: Arc<AwaitingRequestMap>,
    // Completed by the background task with the response message.
    receiver: oneshot::Receiver<RpcMessage>,
}
255
256impl<'a> ResponseFuture<'a> {
257 pub fn remote_addr(&self) -> SocketAddr {
262 self.addr
263 }
264}
265
266impl<'a> Future for ResponseFuture<'a> {
267 type Output = Result<usize>;
268
269 fn poll(
270 mut self: Pin<&mut Self>,
271 cx: &mut Context<'_>,
272 ) -> Poll<Self::Output> {
273 let this = &mut *self;
274 this.receiver.poll_unpin(cx).map(|res| match res {
275 Ok(rsp) => {
276 let read = rsp.write_to_buffer(this.rsp_buf);
277 Ok(read)
278 }
279 Err(_) => Err(errors::other("unexpected channel cancel")),
280 })
281 }
282}
283
284impl<'a> Drop for ResponseFuture<'a> {
288 fn drop(&mut self) {
289 task::block_on(self.awaiting_map.pop(self.addr, self.rid));
290 }
291}
292
/// Handle for answering a single request received via
/// [`RpcSocket::recv_from`].
pub struct RpcResponder {
    // Address the request came from; the response is sent back here.
    origin: SocketAddr,
    udp: Arc<UdpSocket>,

    // Header of the received request. It is reused, with its request flag
    // flipped, as the response header.
    header: RpcHeader,
}
302
303impl RpcResponder {
304 pub fn origin(&self) -> &SocketAddr {
306 &self.origin
307 }
308
309 pub async fn respond(mut self, buf: &[u8]) -> Result<usize> {
313 self.header.flip_request();
314 let written = RpcMessage::write_to_socket(
315 &self.udp,
316 self.origin,
317 self.header,
318 buf,
319 )
320 .await?;
321 Ok(written)
322 }
323}
324
325fn get_addr<A: ToSocketAddrs>(addrs: A) -> Result<SocketAddr> {
326 match addrs.to_socket_addrs()?.next() {
327 Some(addr) => Ok(addr),
328 None => Err(errors::invalid_input("no addresses to send data to")),
329 }
330}