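#![warn(missing_debug_implementations, rust_2018_idioms)]

use async_channel::{unbounded, Receiver, Sender, TrySendError};
use async_mutex::Mutex;
use std::collections::HashMap;
use std::fmt;
use std::io;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::{udp, ToSocketAddrs, UdpSocket};

/// A single datagram received from a peer.
type Packet = Vec<u8>;

/// Wraps any error in an `io::Error` of kind `Other`.
fn other<E: std::error::Error + Send + Sync + 'static>(e: E) -> io::Error {
    io::Error::new(io::ErrorKind::Other, e)
}

/// State shared by the background receive task and every `UdpStream`.
struct Inner {
    /// Announces newly seen peers to `UdpListener::next`.
    sender: Sender<UdpStream>,
    rx: Mutex<udp::RecvHalf>,
    tx: Mutex<udp::SendHalf>,
    /// One packet channel per peer address.
    children: Mutex<HashMap<SocketAddr, Sender<Packet>>>,
}

impl Inner {
    /// Sends `buf` to `target` through the shared socket.
    async fn send_to(&self, buf: &[u8], target: &SocketAddr) -> io::Result<usize> {
        self.tx.lock().await.send_to(buf, target).await
    }

    /// Reads datagrams in a loop and routes each one to the stream for its
    /// source address, creating a stream (and announcing it to the listener)
    /// the first time an address is seen.
    ///
    /// Returns an error if the socket fails, or if a new peer appears after
    /// the `UdpListener` has been dropped.
    async fn serve(self: Arc<Inner>) -> io::Result<()> {
        // Only this task reads from the socket, so it keeps the receive half
        // locked for its whole lifetime.
        let mut socket = self.rx.lock().await;
        loop {
            // 64 KiB is large enough for any UDP datagram.
            let mut buf = vec![0u8; 65536];
            let (size, addr) = socket.recv_from(&mut buf).await?;
            buf.truncate(size);

            let mut children = self.children.lock().await;
            let sender = match children.get(&addr) {
                Some(sender) => sender.clone(),
                None => {
                    let (tx, rx) = unbounded();
                    let stream = UdpStream::new(self.clone(), addr, rx);
                    children.insert(addr, tx.clone());
                    // Fails, stopping this task, if the listener was dropped.
                    self.sender.try_send(stream).map_err(other)?;
                    tx
                }
            };
            match sender.try_send(buf) {
                Ok(()) => {}
                // The stream for this peer was dropped: forget it so the next
                // datagram from `addr` opens a new stream. This one is discarded.
                Err(TrySendError::Closed(_)) => {
                    children.remove(&addr);
                }
                // The channel is unbounded, so it can never be full.
                _ => unreachable!("unbounded channel reported full"),
            }
        }
    }
}

/// The sending half of a `UdpStream`.
pub struct SendHalf {
    inner: Arc<Inner>,
    target: SocketAddr,
}

impl fmt::Debug for SendHalf {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("SendHalf")
            .field("target", &self.target)
            .finish()
    }
}

/// The receiving half of a `UdpStream`.
#[derive(Debug)]
pub struct RecvHalf {
    receiver: Receiver<Packet>,
}

impl SendHalf {
    /// Sends one datagram to the peer and returns the number of bytes sent.
    pub async fn send(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.inner.send_to(buf, &self.target).await
    }
}

impl RecvHalf {
    /// Receives one datagram from the peer into `buf` and returns the number
    /// of bytes copied. If the datagram is longer than `buf`, the excess bytes
    /// are discarded, as with `UdpSocket::recv`.
    pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let p = self.receiver.recv().await.map_err(other)?;
        let len = std::cmp::min(buf.len(), p.len());
        // Copy into a prefix of `buf`: `copy_from_slice` panics unless both
        // slices have the same length.
        buf[..len].copy_from_slice(&p[..len]);
        Ok(len)
    }
}

/// A pseudo-connection to one remote peer, demultiplexed by source address
/// from the socket shared with its `UdpListener`.
pub struct UdpStream {
    tx: SendHalf,
    rx: RecvHalf,
}

impl fmt::Debug for UdpStream {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("UdpStream")
            .field("target", &self.tx.target)
            .finish()
    }
}

impl UdpStream {
    fn new(inner: Arc<Inner>, target: SocketAddr, receiver: Receiver<Packet>) -> UdpStream {
        UdpStream {
            tx: SendHalf { inner, target },
            rx: RecvHalf { receiver },
        }
    }

    /// Sends one datagram to the peer.
    pub async fn send(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.tx.send(buf).await
    }

    /// Receives one datagram from the peer; see `RecvHalf::recv`.
    pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        self.rx.recv(buf).await
    }

    /// Splits the stream into halves that can be used independently.
    pub fn split(self) -> (RecvHalf, SendHalf) {
        (self.rx, self.tx)
    }
}

/// Listens on a UDP socket and yields a `UdpStream` for each peer address
/// that sends it a datagram.
pub struct UdpListener {
    receiver: Receiver<UdpStream>,
}

impl fmt::Debug for UdpListener {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("UdpListener").finish()
    }
}

impl UdpListener {
    /// Binds a new socket to `addr` and starts listening on it.
    pub async fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<UdpListener> {
        Self::from_tokio(UdpSocket::bind(addr).await?)
    }

    /// Wraps an existing Tokio socket and spawns the background task that
    /// routes incoming datagrams. Panics if called outside a Tokio runtime,
    /// because it uses `tokio::spawn`.
    pub fn from_tokio(udp: UdpSocket) -> io::Result<UdpListener> {
        let (rx, tx) = udp.split();
        let (sender, receiver) = unbounded();
        let inner = Arc::new(Inner {
            sender,
            rx: Mutex::new(rx),
            tx: Mutex::new(tx),
            children: Mutex::new(HashMap::new()),
        });
        tokio::spawn(inner.clone().serve());
        Ok(UdpListener { receiver })
    }

    /// Like `from_tokio`, but starts from a standard library socket.
    pub fn from_std(socket: std::net::UdpSocket) -> io::Result<UdpListener> {
        Self::from_tokio(UdpSocket::from_std(socket)?)
    }

    /// Waits until a datagram arrives from an address that has no stream yet,
    /// and returns the new stream for that peer.
    pub async fn next(&mut self) -> io::Result<UdpStream> {
        self.receiver.recv().await.map_err(other)
    }
}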