use std::io;
use std::net::SocketAddr;
use std::pin::Pin;
use std::collections::HashMap;
use futures::channel::mpsc;
use futures::prelude::*;
use futures::task::{Context, Poll};
use futures::{select, Stream};
use async_std::net::UdpSocket;
use async_std::task::{self, JoinHandle};
use tracing::{span, Level};
use tracing_futures::Instrument;
use bytes::Bytes;
/// Size, in bytes, of the receive buffer each UDP listener task allocates;
/// it bounds the largest datagram payload that can be received intact.
pub const UDP_BUFF_SIZE: usize = 4 * 1024;
/// Transport protocol of a network interface managed by [`Net`].
#[derive(Clone, Debug, PartialEq)]
pub enum NetKind {
    /// Datagram socket.
    Udp,
    /// Stream socket (binding is not implemented yet).
    Tcp,
}
/// A request to add or remove an interface of a given kind at an address.
///
/// NOTE(review): not used within this file; presumably consumed by whatever
/// drives [`Net`] — confirm against callers.
#[derive(Clone, Debug, PartialEq)]
pub enum NetCommand {
    /// Bind a new interface.
    Bind(NetKind, SocketAddr),
    /// Remove an existing interface.
    Unbind(NetKind, SocketAddr),
}
/// A single datagram travelling through [`Net`], in either direction.
#[derive(Clone, Debug, PartialEq)]
pub struct NetMessage {
    /// Index of the local interface (as assigned by `Net` on bind) the
    /// datagram is sent on or was received on.
    pub interface: u32,
    /// Remote peer: the destination when passed to `Net::send`, the source
    /// when yielded by the `Net` stream.
    pub address: SocketAddr,
    /// Datagram payload.
    pub data: Bytes,
}

impl NetMessage {
    /// Assembles a message from its interface index, peer address and payload.
    pub fn new(interface: u32, address: SocketAddr, data: Bytes) -> Self {
        NetMessage { interface, address, data }
    }
}
/// Errors returned by [`Net`] operations.
///
/// I/O failures are reduced to their [`io::ErrorKind`] so that the type can
/// stay `Clone` and `PartialEq`.
#[derive(Debug, PartialEq, Clone)]
pub enum NetError {
    /// An underlying socket operation failed.
    Io(io::ErrorKind),
    /// A channel to or from an interface task could not accept a message
    /// (typically because the other end has gone away).
    Sender(mpsc::SendError),
    /// No bound interface has the requested index.
    NoMatchingInterface,
}

// Implementing `Display` + `Error` lets callers box this as `dyn Error`
// or feed it to error-reporting code instead of only `{:?}`-printing it.
impl std::fmt::Display for NetError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            NetError::Io(kind) => write!(f, "I/O error: {:?}", kind),
            NetError::Sender(e) => write!(f, "channel send error: {}", e),
            NetError::NoMatchingInterface => f.write_str("no matching interface"),
        }
    }
}

impl std::error::Error for NetError {}

impl From<io::Error> for NetError {
    /// Keeps only the error kind; the OS message is discarded.
    fn from(e: io::Error) -> Self {
        Self::Io(e.kind())
    }
}

impl From<mpsc::SendError> for NetError {
    fn from(e: mpsc::SendError) -> Self {
        Self::Sender(e)
    }
}
/// Owns a set of bound network interfaces, each serviced by its own spawned
/// task, and merges everything they receive into a single [`Stream`].
pub struct Net {
    /// Live interfaces keyed by the index handed out at bind time.
    bindings: HashMap<u32, Binding>,
    /// Index the next successful bind will receive.
    index: u32,
    /// Sender half of the shared inbound channel; cloned into every task.
    rx_sink: mpsc::Sender<NetMessage>,
    /// Receiver half of the shared inbound channel, drained by `poll_next`.
    rx_stream: mpsc::Receiver<NetMessage>,
}

// Declared explicitly so `poll_next` can reach `rx_stream` through a plain
// `&mut` borrow of the pinned `Net`.
impl Unpin for Net {}
/// Bookkeeping for one bound interface and the task servicing it.
#[derive(Debug)]
struct Binding {
    /// Handle of the spawned listener task.
    handle: JoinHandle<Result<(), NetError>>,
    /// Outbound queue; the task transmits every message pushed here.
    sink: mpsc::Sender<NetMessage>,
    /// Sending `()` here asks the task to stop.
    exit: mpsc::Sender<()>,
    /// Kind and address reported by `Net::list`.
    info: NetInfo,
}
/// Public description of a bound interface: its transport and local address.
#[derive(Clone, Debug, PartialEq)]
pub struct NetInfo {
    /// Transport protocol.
    pub kind: NetKind,
    /// Address the interface was bound at (as requested by the caller).
    pub addr: SocketAddr,
}

impl NetInfo {
    /// Builds a `NetInfo`; note the argument order is `(addr, kind)`.
    pub fn new(addr: SocketAddr, kind: NetKind) -> Self {
        NetInfo { kind, addr }
    }
}
impl Net {
    /// Creates a `Net` with no bound interfaces.
    ///
    /// The shared inbound channel is created here; every interface task gets a
    /// clone of its sender, and the receiver is drained by the `Stream` impl.
    pub fn new() -> Self {
        let (rx_sink, rx_stream) = mpsc::channel::<NetMessage>(0);
        Net {
            bindings: HashMap::new(),
            index: 0,
            rx_sink,
            rx_stream,
        }
    }

    /// Returns the kind and address of every interface currently registered.
    ///
    /// The order is `HashMap` iteration order, not interface-index order.
    pub fn list(&self) -> Vec<NetInfo> {
        self.bindings.values().map(|b| b.info.clone()).collect()
    }

    /// Binds a new interface of `kind` at `addr`.
    ///
    /// Successful binds are numbered from 0 upward; that number is the
    /// `interface` used by [`Net::send`], [`Net::unbind`] and in received
    /// [`NetMessage`]s.
    ///
    /// # Errors
    /// Returns [`NetError::Io`] if the socket cannot be bound.
    ///
    /// # Panics
    /// [`NetKind::Tcp`] is not implemented and panics.
    pub async fn bind(&mut self, kind: NetKind, addr: SocketAddr) -> Result<(), NetError> {
        match kind {
            NetKind::Udp => self.listen_udp(addr).await,
            NetKind::Tcp => unimplemented!(),
        }
    }

    /// Removes interface `interface` and signals its task to stop.
    ///
    /// This only signals the task; it does not wait for it to finish, so the
    /// socket may still be open briefly after this returns.
    ///
    /// # Errors
    /// [`NetError::NoMatchingInterface`] if no such interface is registered.
    /// [`NetError::Sender`] if the task had already stopped; the binding has
    /// still been removed in that case.
    pub async fn unbind(&mut self, interface: u32) -> Result<(), NetError> {
        let mut binding = self
            .bindings
            .remove(&interface)
            .ok_or(NetError::NoMatchingInterface)?;
        binding.exit.send(()).await?;
        Ok(())
    }

    /// Queues `msg` for transmission on interface `msg.interface` to `msg.address`.
    ///
    /// # Errors
    /// [`NetError::NoMatchingInterface`] if the interface is not registered, or
    /// [`NetError::Sender`] if its task has stopped.
    pub async fn send(&mut self, msg: NetMessage) -> Result<(), NetError> {
        let binding = self
            .bindings
            .get_mut(&msg.interface)
            .ok_or(NetError::NoMatchingInterface)?;
        binding.sink.send(msg).await?;
        Ok(())
    }

    /// Binds a UDP socket at `address`, spawns the task that services it and
    /// registers it under the next interface index.
    ///
    /// The task forwards each received datagram into the shared inbound channel
    /// and transmits whatever is queued through `send`. It stops when:
    /// - it is told to via `unbind`;
    /// - its `Binding` is dropped (e.g. `Net` itself is dropped);
    /// - a receive error occurs;
    /// - a send or forward error occurs.
    ///
    /// The task's result is not observed here, and a binding whose task has
    /// stopped stays registered until unbound.
    async fn listen_udp(&mut self, address: SocketAddr) -> Result<(), NetError> {
        let socket = UdpSocket::bind(address).await?;
        let interface = self.index;
        let mut rx_sink = self.rx_sink.clone();
        let (tx_sink, mut tx_stream) = mpsc::channel::<NetMessage>(0);
        let (exit_sink, mut exit_stream) = mpsc::channel::<()>(0);

        debug!("Starting UDP listener {}: {}", interface, address);

        let handle = task::spawn(
            async move {
                let mut buff = vec![0u8; UDP_BUFF_SIZE];
                loop {
                    select! {
                        res = socket.recv_from(&mut buff).fuse() => {
                            match res {
                                Ok((n, address)) => {
                                    let data = Bytes::copy_from_slice(&buff[..n]);
                                    event!(Level::TRACE, kind="UDP receive", address = %address);
                                    let msg = NetMessage { interface, address, data };
                                    // Fails only once the `Net` (and its receiver) is gone.
                                    rx_sink.send(msg).await?;
                                },
                                Err(e) => {
                                    error!("receive error: {:?}", e);
                                    break
                                },
                            }
                        },
                        res = tx_stream.next() => {
                            match res {
                                Some(d) => {
                                    event!(Level::TRACE, kind="UDP transmit", address = %d.address);
                                    socket.send_to(&d.data, &d.address).await?;
                                },
                                // Terminated streams are skipped by `select!` from now on.
                                None => debug!("tx stream closed"),
                            }
                        },
                        res = exit_stream.next() => {
                            // `None` means the owning `Binding` was dropped without an
                            // explicit unbind; stop as well so the socket is released
                            // instead of lingering until the next datagram arrives.
                            match res {
                                Some(()) => debug!("Received exit"),
                                None => debug!("exit channel closed"),
                            }
                            break;
                        },
                    }
                }
                Ok(())
            }
            .instrument(span!(Level::TRACE, "UDP", interface, address=%address)),
        );

        let binding = Binding {
            handle,
            sink: tx_sink,
            exit: exit_sink,
            info: NetInfo::new(address, NetKind::Udp),
        };
        self.bindings.insert(interface, binding);
        self.index += 1;
        Ok(())
    }
}

impl Default for Net {
    /// Equivalent to [`Net::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl Stream for Net {
    type Item = NetMessage;

    /// Yields the next datagram received on any bound interface.
    ///
    /// Delegates to the shared inbound receiver. Because `Net` holds its own
    /// sender clone, the channel never closes while `Net` is alive.
    fn poll_next(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<Option<Self::Item>> {
        #[cfg(feature = "profile")]
        let _fg = ::flame::start_guard("net::poll_next");
        self.rx_stream.poll_next_unpin(ctx)
    }
}
#[cfg(test)]
mod test {
    use super::*;
    use tracing_subscriber::FmtSubscriber;

    /// Sends `payload` out of interface `via` to `to`, then checks that the
    /// next message yielded by `net` arrived on interface `on` from `from`.
    async fn round_trip(
        net: &mut Net,
        via: u32,
        to: SocketAddr,
        on: u32,
        from: SocketAddr,
        payload: &Bytes,
    ) {
        net.send(NetMessage::new(via, to, payload.clone()))
            .await
            .expect("Error sending net message");
        let received = net.next().await.expect("Error awaiting net message");
        assert_eq!(received, NetMessage::new(on, from, payload.clone()));
    }

    /// Binds two loopback UDP interfaces, bounces one datagram each way
    /// between them, then unbinds both.
    #[test]
    fn test_udp() {
        let addr_a: SocketAddr = "127.0.0.1:19993".parse().unwrap();
        let addr_b: SocketAddr = "127.0.0.1:19994".parse().unwrap();

        // Another test may already have installed a global subscriber.
        let _ = FmtSubscriber::builder()
            .with_max_level(Level::DEBUG)
            .try_init();

        task::block_on(async {
            let mut net = Net::new();
            assert_eq!(net.list().len(), 0);

            net.bind(NetKind::Udp, addr_a)
                .await
                .expect("error binding udp interface 1");
            assert_eq!(net.list().len(), 1);
            net.bind(NetKind::Udp, addr_b)
                .await
                .expect("error binding udp interface 2");
            assert_eq!(net.list().len(), 2);

            let payload = Bytes::copy_from_slice(&[0x11, 0x22, 0x33, 0x44]);
            // Interface 0 (addr_a) -> addr_b, picked up by interface 1.
            round_trip(&mut net, 0, addr_b, 1, addr_a, &payload).await;
            // Interface 1 (addr_b) -> addr_a, picked up by interface 0.
            round_trip(&mut net, 1, addr_a, 0, addr_b, &payload).await;

            for index in 0..2 {
                net.unbind(index).await.unwrap();
            }
            assert_eq!(net.list().len(), 0);
        })
    }
}