Skip to main content

aurpc/
lib.rs

1//! Asynchronous UDP RPCs.
2//!
3//! Exposes a socket-like interface allowing for sending requests and awaiting a
4//! response as well as listening to requests, with UDP as transport.
5//!
//! This is achieved by implementing a 24-bit protocol header on top of UDP
//! containing 8-bit flags and a 16-bit request id.
8//!
9//! ```text
//!  0                   1                   2
//!  0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3
12//! +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
13//! |     Flags     |           Request Id          |
14//! +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
15//! ```
16//!
//! A UDP datagram sent over IPv4 can carry at most 65507 bytes of data, so
//! after subtracting the 3-byte header each message can be at most `65504`
//! bytes.
20//!
21//! # Examples
22//!
23//! ```no_run
24//! # fn main() -> std::io::Result<()> { async_std::task::block_on(async {
25//! use aurpc::RpcSocket;
26//!
27//! let socket = RpcSocket::bind("127.0.0.1:8080").await?;
28//! let mut buf = vec![0u8; 1024];
29//!
30//! loop {
31//!     let (n, responder) = socket.recv_from(&mut buf).await?;
32//!     responder.respond(&buf[..n]).await?;
33//! }
34//! # }) }
35//! ```
36#![deny(missing_docs)]
37#![deny(clippy::all)]
38
39mod awaiting;
40mod errors;
41mod message;
42mod result;
43
44#[cfg(test)]
45mod tests;
46
47use async_std::{
48    net::UdpSocket,
49    sync::{Arc, Mutex},
50    task::{self, JoinHandle},
51};
52use futures::{
53    channel::{mpsc, oneshot},
54    future::FutureExt,
55    sink::SinkExt,
56    stream::StreamExt,
57};
58use std::{
59    future::Future,
60    net::{SocketAddr, ToSocketAddrs},
61    ops::Drop,
62    pin::Pin,
63    task::{Context, Poll},
64};
65
66use awaiting::AwaitingRequestMap;
67use message::{RpcHeader, RpcMessage};
68use result::Result;
69
/// An RPC socket.
///
/// After creating a `RpcSocket` by [`bind`]ing it to a socket address, RPCs
/// can be [`sent to`] and [`received from`] any other socket address.
///
/// [`bind`]: #method.bind
/// [`sent to`]: #method.send_to
/// [`received from`]: #method.recv_from
pub struct RpcSocket {
    /// Underlying UDP socket. Shared with the background receive task and
    /// with every `RpcResponder` handed out by `recv_from`.
    udp: Arc<UdpSocket>,
    /// Outstanding requests keyed by destination address and request id.
    /// Each entry holds the oneshot sender that the background task completes
    /// when the matching response arrives.
    awaiting_map: Arc<AwaitingRequestMap>,

    /// Handle to the `rpc_loop` task spawned in `bind`; it is never awaited.
    /// With async-std, dropping a `JoinHandle` detaches the task rather than
    /// cancelling it.
    /// NOTE(review): the UDP socket therefore appears to stay bound after the
    /// `RpcSocket` is dropped, until further datagrams arrive and the
    /// background tasks notice their channels are closed. Confirm whether
    /// this is acceptable.
    _handle: JoinHandle<()>,
    /// Incoming requests forwarded by `rpc_loop`. Wrapped in a `Mutex` so
    /// that `recv_from` can take `&self`.
    receiver: Mutex<mpsc::UnboundedReceiver<(RpcMessage, SocketAddr)>>,
}
85
86async fn rpc_loop(
87    udp: Arc<UdpSocket>,
88    awaiting_map: Arc<AwaitingRequestMap>,
89    mut sender: mpsc::UnboundedSender<(RpcMessage, SocketAddr)>,
90) {
91    let (msg_sender, mut msg_receiver) = mpsc::unbounded();
92    let receiver_handle = task::spawn(receiver_loop(udp, msg_sender));
93
94    while let Some((msg, addr)) = msg_receiver.next().await {
95        if msg.is_request() {
96            if sender.send((msg, addr)).await.is_err() {
97                break;
98            }
99        } else if let Some(rsp_sender) =
100            awaiting_map.pop(addr, msg.request_id()).await
101        {
102            let _ = rsp_sender.send(msg);
103        }
104    }
105
106    drop(msg_receiver);
107    receiver_handle.await;
108}
109
110async fn receiver_loop(
111    udp: Arc<UdpSocket>,
112    mut msg_sender: mpsc::UnboundedSender<(RpcMessage, SocketAddr)>,
113) {
114    // TODO Handle the possibility of errors better
115    while let Ok((msg, addr)) = RpcMessage::read_from_socket(&udp).await {
116        if msg_sender.send((msg, addr)).await.is_err() {
117            break;
118        }
119    }
120}
121
122impl RpcSocket {
123    /// Creates a RPC socket from the given address.
124    ///
125    /// Binding with a port number of 0 will request that the OS assigns a port
126    /// to this socket. THe port allocated can be queried via the [`local_addr`]
127    /// method.
128    ///
129    /// [`local_addr`]: #method.local_addr
130    pub async fn bind<A: ToSocketAddrs>(addrs: A) -> Result<Self> {
131        let addr = get_addr(addrs)?;
132        let udp = Arc::new(UdpSocket::bind(addr).await?);
133        let awaiting_map = Arc::new(AwaitingRequestMap::default());
134
135        let (sender, receiver) = mpsc::unbounded();
136        let receiver = Mutex::new(receiver);
137
138        let _handle =
139            task::spawn(rpc_loop(udp.clone(), awaiting_map.clone(), sender));
140
141        Ok(Self {
142            udp,
143            awaiting_map,
144            receiver,
145            _handle,
146        })
147    }
148
149    /// Returns the local address that this listener is bound to.
150    ///
151    /// This can be useful, for example, when binding to port 0 to figure out
152    /// which port was actually bound.
153    pub fn local_addr(&self) -> Result<SocketAddr> {
154        Ok(self.udp.local_addr()?)
155    }
156
157    /// Sends an RPC on the socket to the given address.
158    ///
159    /// On success, returns the number of bytes written to and read from the
160    /// socket.
161    #[allow(clippy::needless_lifetimes)]
162    pub async fn send_to<'a, A: ToSocketAddrs>(
163        &self,
164        buf: &[u8],
165        rsp_buf: &'a mut [u8],
166        addrs: A,
167    ) -> Result<(usize, ResponseFuture<'a>)> {
168        let addr = get_addr(addrs)?;
169
170        let (sender, receiver) = oneshot::channel();
171
172        let rid = self.awaiting_map.put(addr, sender).await;
173        let header = RpcHeader::request_from_rid(rid);
174
175        let written =
176            match RpcMessage::write_to_socket(&self.udp, addr, header, buf)
177                .await
178            {
179                Ok(written) => written,
180                Err(err) => {
181                    self.awaiting_map.pop(addr, rid).await;
182                    return Err(err);
183                }
184            };
185
186        Ok((
187            written,
188            ResponseFuture {
189                rsp_buf,
190                addr,
191                rid,
192                awaiting_map: self.awaiting_map.clone(),
193                receiver,
194            },
195        ))
196    }
197
198    /// Receives RPC from the socket.
199    ///
200    /// On success, returns the number of bytes read and the origin.
201    pub async fn recv_from(
202        &self,
203        buf: &mut [u8],
204    ) -> Result<(usize, RpcResponder)> {
205        match self.receiver.lock().await.next().await {
206            Some((msg, addr)) => {
207                let read = msg.write_to_buffer(buf);
208                let header = msg.split();
209                Ok((
210                    read,
211                    RpcResponder {
212                        origin: addr,
213                        udp: self.udp.clone(),
214                        header,
215                    },
216                ))
217            }
218            None => Err(errors::other("unexpected channel close")),
219        }
220    }
221
222    /// Gets the value of the `IP_TTL` option for this socket.
223    ///
224    /// For more information about this option, see [`set_ttl`].
225    ///
226    /// [`set_ttl`]: #method.set_ttl
227    pub fn ttl(&self) -> Result<u32> {
228        Ok(self.udp.ttl()?)
229    }
230
231    /// Sets the value for the `IP_TTL` option on this socket.
232    ///
233    /// This value sets the time-to-live field that is used in every packet sent
234    /// from this socket.
235    pub fn set_ttl(&self, ttl: u32) -> Result<()> {
236        Ok(self.udp.set_ttl(ttl)?)
237    }
238}
239
/// Future resolving to the response of a request sent with [`send_to`].
///
/// Awaiting happens in two steps. First, `send_to` itself is awaited, which
/// sends the request and returns this future. Then this future is awaited,
/// which waits for the response, writes its payload into the buffer passed to
/// `send_to`, and yields the number of bytes read.
///
/// Dropping signals disinterest in the response. The pending request is
/// unregistered, and a response arriving later is discarded.
///
/// [`send_to`]: struct.RpcSocket.html#method.send_to
pub struct ResponseFuture<'a> {
    /// Caller-provided buffer that receives the response payload.
    rsp_buf: &'a mut [u8],
    /// Address the request was sent to and the response is expected from.
    addr: SocketAddr,
    /// Request id assigned by `awaiting_map` when the request was registered.
    rid: u16,
    /// Shared map of pending requests, used to unregister on drop.
    awaiting_map: Arc<AwaitingRequestMap>,
    /// Completed by the background task when the matching response arrives.
    receiver: oneshot::Receiver<RpcMessage>,
}
255
256impl<'a> ResponseFuture<'a> {
257    /// Address one is expecting to receive a response from.
258    ///
259    /// Useful for situations where one doesn't know what `ToSocketAddrs`
260    /// resolves into.
261    pub fn remote_addr(&self) -> SocketAddr {
262        self.addr
263    }
264}
265
266impl<'a> Future for ResponseFuture<'a> {
267    type Output = Result<usize>;
268
269    fn poll(
270        mut self: Pin<&mut Self>,
271        cx: &mut Context<'_>,
272    ) -> Poll<Self::Output> {
273        let this = &mut *self;
274        this.receiver.poll_unpin(cx).map(|res| match res {
275            Ok(rsp) => {
276                let read = rsp.write_to_buffer(this.rsp_buf);
277                Ok(read)
278            }
279            Err(_) => Err(errors::other("unexpected channel cancel")),
280        })
281    }
282}
283
284/// Explicitly remove from map when dropped.
285///
286/// This allows for timeouts to be implemented outside this crate, for instance.
287impl<'a> Drop for ResponseFuture<'a> {
288    fn drop(&mut self) {
289        task::block_on(self.awaiting_map.pop(self.addr, self.rid));
290    }
291}
292
/// Allows [`respond`]ing to RPCs.
///
/// One is handed out by `RpcSocket::recv_from` with every received request.
///
/// [`respond`]: struct.RpcResponder.html#method.respond
pub struct RpcResponder {
    /// Address the request came from; the response is sent back here.
    origin: SocketAddr,
    /// The socket the request arrived on, shared with the `RpcSocket`.
    udp: Arc<UdpSocket>,

    /// Header of the received request. `respond` flips its request flag and
    /// reuses it for the response.
    header: RpcHeader,
}
302
303impl RpcResponder {
304    /// The endpoint the RPC originates from.
305    pub fn origin(&self) -> &SocketAddr {
306        &self.origin
307    }
308
309    /// Responds to the received RPC and consumes the responder.
310    ///
311    /// On success, returns the number of bytes written.
312    pub async fn respond(mut self, buf: &[u8]) -> Result<usize> {
313        self.header.flip_request();
314        let written = RpcMessage::write_to_socket(
315            &self.udp,
316            self.origin,
317            self.header,
318            buf,
319        )
320        .await?;
321        Ok(written)
322    }
323}
324
325fn get_addr<A: ToSocketAddrs>(addrs: A) -> Result<SocketAddr> {
326    match addrs.to_socket_addrs()?.next() {
327        Some(addr) => Ok(addr),
328        None => Err(errors::invalid_input("no addresses to send data to")),
329    }
330}