#![deny(missing_docs)]
#![deny(clippy::all)]
mod awaiting;
mod errors;
mod message;
mod result;
#[cfg(test)]
mod tests;
use async_std::{
net::UdpSocket,
sync::{Arc, Mutex},
task::{self, JoinHandle},
};
use futures::{
channel::{mpsc, oneshot},
future::FutureExt,
sink::SinkExt,
stream::StreamExt,
};
use std::{
future::Future,
net::{SocketAddr, ToSocketAddrs},
ops::Drop,
pin::Pin,
task::{Context, Poll},
};
use awaiting::AwaitingRequestMap;
use message::{RpcHeader, RpcMessage};
use result::Result;
/// An RPC endpoint layered over a single UDP socket.
///
/// Requests are sent with [`RpcSocket::send_to`] and incoming requests are
/// received with [`RpcSocket::recv_from`]. A background task spawned by
/// [`RpcSocket::bind`] reads datagrams from the socket. It hands responses to
/// the matching pending [`ResponseFuture`] and forwards requests to
/// `recv_from`.
pub struct RpcSocket {
    // The bound socket, shared with the background loops and with responders.
    udp: Arc<UdpSocket>,
    // Pending outgoing requests, registered by `send_to` and looked up by
    // (peer address, request id) when a response arrives.
    awaiting_map: Arc<AwaitingRequestMap>,
    // Handle to the `rpc_loop` task; it is stored but never awaited.
    // NOTE(review): dropping an async_std `JoinHandle` detaches the task
    // rather than cancelling it, so the background loops (and the socket
    // they hold) appear to outlive this value until they next fail to
    // forward a message. Confirm that this is intended.
    _handle: JoinHandle<()>,
    // Requests forwarded by `rpc_loop`. The async mutex serializes
    // concurrent `recv_from` callers.
    receiver: Mutex<mpsc::UnboundedReceiver<(RpcMessage, SocketAddr)>>,
}
/// Background dispatcher spawned by `RpcSocket::bind`.
///
/// Starts `receiver_loop` to read messages from `udp`, then routes each one.
/// Requests are forwarded through `sender` to `RpcSocket::recv_from`.
/// Responses are handed to the oneshot sender that `RpcSocket::send_to`
/// registered in `awaiting_map` for the same `(addr, request_id)`. Responses
/// without a registered waiter are discarded.
///
/// The loop ends in two cases: forwarding a request fails because the
/// receiving half is gone, or the reader task finishes and closes the
/// internal channel. Before returning, it waits for the reader task to finish.
async fn rpc_loop(
    udp: Arc<UdpSocket>,
    awaiting_map: Arc<AwaitingRequestMap>,
    mut sender: mpsc::UnboundedSender<(RpcMessage, SocketAddr)>,
) {
    let (incoming_tx, mut incoming_rx) = mpsc::unbounded();
    let reader = task::spawn(receiver_loop(udp, incoming_tx));

    while let Some((msg, addr)) = incoming_rx.next().await {
        if !msg.is_request() {
            // A response: wake the `send_to` caller waiting on this id.
            let waiter = awaiting_map.pop(addr, msg.request_id()).await;
            if let Some(waiter) = waiter {
                // The waiting future may already be gone; ignore that case.
                let _ = waiter.send(msg);
            }
            continue;
        }

        if sender.send((msg, addr)).await.is_err() {
            // Nobody is receiving requests any more.
            break;
        }
    }

    // Closing our end makes the reader's next forward fail, so it exits.
    drop(incoming_rx);
    reader.await;
}
/// Reads messages from `udp` and forwards each one, paired with the address
/// it came from, to `msg_sender`.
///
/// The loop stops at the first error from `RpcMessage::read_from_socket`
/// (the error is discarded) or when the receiving side of `msg_sender` has
/// gone away.
///
/// NOTE(review): a single failed read ends this loop permanently. Whether one
/// malformed or unexpected datagram can cause that depends on
/// `read_from_socket`; confirm.
async fn receiver_loop(
    udp: Arc<UdpSocket>,
    mut msg_sender: mpsc::UnboundedSender<(RpcMessage, SocketAddr)>,
) {
    while let Ok(incoming) = RpcMessage::read_from_socket(&udp).await {
        let forwarded = msg_sender.send(incoming).await;
        if forwarded.is_err() {
            return;
        }
    }
}
impl RpcSocket {
    /// Binds a UDP socket to the first address `addrs` resolves to and
    /// spawns the background task that dispatches incoming messages.
    ///
    /// # Errors
    ///
    /// Fails if address resolution fails, if `addrs` yields no address, or
    /// if the socket cannot be bound.
    pub async fn bind<A: ToSocketAddrs>(addrs: A) -> Result<Self> {
        let udp = Arc::new(UdpSocket::bind(get_addr(addrs)?).await?);
        let awaiting_map = Arc::new(AwaitingRequestMap::default());
        let (request_tx, request_rx) = mpsc::unbounded();
        let _handle = task::spawn(rpc_loop(
            Arc::clone(&udp),
            Arc::clone(&awaiting_map),
            request_tx,
        ));
        Ok(Self {
            udp,
            awaiting_map,
            receiver: Mutex::new(request_rx),
            _handle,
        })
    }

    /// Returns the local address the underlying UDP socket is bound to.
    ///
    /// # Errors
    ///
    /// Propagates the error reported by the socket.
    pub fn local_addr(&self) -> Result<SocketAddr> {
        let addr = self.udp.local_addr()?;
        Ok(addr)
    }

    /// Sends `buf` as a request to the first address `addrs` resolves to.
    ///
    /// The request is registered as pending before it is written. On success
    /// the method returns two values: the count reported by
    /// `RpcMessage::write_to_socket`, and a [`ResponseFuture`]. The future
    /// completes when the peer's response with the same request id arrives,
    /// and writes that response into `rsp_buf`. The future has no timeout of
    /// its own. Dropping it before completion releases the pending entry.
    ///
    /// # Errors
    ///
    /// Fails if `addrs` cannot be resolved to an address or if writing to
    /// the socket fails. A failed write removes the pending entry before the
    /// error is returned.
    ///
    /// NOTE(review): if this future is dropped while the write is still in
    /// progress, the entry added by `put` is not removed. Confirm whether
    /// `AwaitingRequestMap` reclaims it.
    // The named `'a` is required: with elision, the returned future would be
    // tied to `&self` rather than to `rsp_buf`. The allow presumably silences
    // a clippy false positive on this async signature.
    #[allow(clippy::needless_lifetimes)]
    pub async fn send_to<'a, A: ToSocketAddrs>(
        &self,
        buf: &[u8],
        rsp_buf: &'a mut [u8],
        addrs: A,
    ) -> Result<(usize, ResponseFuture<'a>)> {
        let addr = get_addr(addrs)?;
        let (rsp_tx, rsp_rx) = oneshot::channel();
        let rid = self.awaiting_map.put(addr, rsp_tx).await;
        let header = RpcHeader::request_from_rid(rid);
        let outcome =
            RpcMessage::write_to_socket(&self.udp, addr, header, buf).await;
        if outcome.is_err() {
            // The request never left, so drop its pending entry.
            self.awaiting_map.pop(addr, rid).await;
        }
        let written = outcome?;
        let response = ResponseFuture {
            rsp_buf,
            addr,
            rid,
            awaiting_map: Arc::clone(&self.awaiting_map),
            receiver: rsp_rx,
        };
        Ok((written, response))
    }

    /// Waits for the next incoming request from any peer.
    ///
    /// The request is written into `buf` via `RpcMessage::write_to_buffer`,
    /// and the count it reports is returned. The method also returns an
    /// [`RpcResponder`] for answering that request. Concurrent callers wait
    /// their turn on an internal mutex.
    ///
    /// # Errors
    ///
    /// Returns an "unexpected channel close" error once the background
    /// dispatcher has stopped.
    pub async fn recv_from(
        &self,
        buf: &mut [u8],
    ) -> Result<(usize, RpcResponder)> {
        let next = self.receiver.lock().await.next().await;
        let (msg, origin) =
            next.ok_or_else(|| errors::other("unexpected channel close"))?;
        let read = msg.write_to_buffer(buf);
        let header = msg.split();
        let responder = RpcResponder {
            origin,
            udp: Arc::clone(&self.udp),
            header,
        };
        Ok((read, responder))
    }

    /// Returns the time-to-live (`IP_TTL`) value of the underlying socket.
    ///
    /// # Errors
    ///
    /// Propagates the error reported by the socket.
    pub fn ttl(&self) -> Result<u32> {
        let ttl = self.udp.ttl()?;
        Ok(ttl)
    }

    /// Sets the time-to-live (`IP_TTL`) value of the underlying socket.
    ///
    /// # Errors
    ///
    /// Propagates the error reported by the socket.
    pub fn set_ttl(&self, ttl: u32) -> Result<()> {
        self.udp.set_ttl(ttl)?;
        Ok(())
    }
}
/// Future returned by [`RpcSocket::send_to`] that resolves to the response
/// for a single outgoing request.
///
/// When the response arrives, it is copied into the caller's buffer, and the
/// future yields the count reported by `RpcMessage::write_to_buffer`.
pub struct ResponseFuture<'a> {
    // Caller-supplied buffer the response is written into.
    rsp_buf: &'a mut [u8],
    // Peer the request was sent to.
    addr: SocketAddr,
    // Request id returned by `AwaitingRequestMap::put`.
    rid: u16,
    // Shared map, used to release the pending entry when this is dropped.
    awaiting_map: Arc<AwaitingRequestMap>,
    // Receives the response that `rpc_loop` delivers for this request.
    receiver: oneshot::Receiver<RpcMessage>,
}
impl ResponseFuture<'_> {
    /// Returns the address the request was sent to, which is the peer whose
    /// response this future is waiting for.
    pub fn remote_addr(&self) -> SocketAddr {
        let Self { addr, .. } = self;
        *addr
    }
}
/// Resolves in one of two ways:
///
/// * With the count reported by `RpcMessage::write_to_buffer`, after the
///   response has been copied into the caller's buffer.
/// * With an "unexpected channel cancel" error, if the oneshot sender is
///   dropped without sending a response.
impl Future for ResponseFuture<'_> {
    type Output = Result<usize>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Every field is `Unpin`, so the future can be unpinned freely.
        let this = self.get_mut();
        let rsp = match this.receiver.poll_unpin(cx) {
            Poll::Pending => return Poll::Pending,
            // The sender went away without delivering a response.
            Poll::Ready(Err(_)) => {
                return Poll::Ready(Err(errors::other(
                    "unexpected channel cancel",
                )));
            }
            Poll::Ready(Ok(rsp)) => rsp,
        };
        Poll::Ready(Ok(rsp.write_to_buffer(this.rsp_buf)))
    }
}
/// Releases this request's entry in the awaiting map if it is still pending.
///
/// No pop is needed once the oneshot channel has resolved. The response may
/// have been delivered, in which case `rpc_loop` popped the entry before
/// sending. Or the sender was dropped, and the map only ever held it as this
/// entry's value. Popping `(addr, rid)` again after that could remove a newer
/// request that has since been given the same id. It would also needlessly
/// call `task::block_on` on the normal completion path.
impl<'a> Drop for ResponseFuture<'a> {
    fn drop(&mut self) {
        // `Ok(None)` means nothing was delivered and the sender is still
        // alive, so the request is presumably still registered in the map.
        if let Ok(None) = self.receiver.try_recv() {
            // Blocks the current thread. This only runs when the future is
            // dropped before its response arrived (e.g. cancelled or timed
            // out).
            task::block_on(self.awaiting_map.pop(self.addr, self.rid));
        }
    }
}
/// Handle for answering one request received through
/// [`RpcSocket::recv_from`].
pub struct RpcResponder {
    // Address the request came from; the response is sent here.
    origin: SocketAddr,
    // Socket shared with the `RpcSocket` that received the request.
    udp: Arc<UdpSocket>,
    // Header of the incoming request, flipped and reused for the response.
    header: RpcHeader,
}
impl RpcResponder {
    /// Returns the address the request came from; [`RpcResponder::respond`]
    /// sends its reply there.
    pub fn origin(&self) -> &SocketAddr {
        let Self { origin, .. } = self;
        origin
    }

    /// Sends `buf` back to the requester as the response to this request,
    /// consuming the responder.
    ///
    /// The request's header is reused after `flip_request` is applied. This
    /// presumably keeps the request id, so the peer can match the response
    /// to its request.
    ///
    /// Returns the count reported by `RpcMessage::write_to_socket`.
    ///
    /// # Errors
    ///
    /// Propagates any error from writing to the socket.
    pub async fn respond(mut self, buf: &[u8]) -> Result<usize> {
        self.header.flip_request();
        RpcMessage::write_to_socket(&self.udp, self.origin, self.header, buf)
            .await
    }
}
/// Resolves `addrs` and returns the first resulting socket address.
///
/// # Errors
///
/// Propagates resolution failures. Returns an invalid-input error if
/// resolution succeeds but yields no address.
fn get_addr<A: ToSocketAddrs>(addrs: A) -> Result<SocketAddr> {
    addrs
        .to_socket_addrs()?
        .next()
        .ok_or_else(|| errors::invalid_input("no addresses to send data to"))
}