1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
#![warn(missing_debug_implementations, rust_2018_idioms)]

use async_channel::{unbounded, Receiver, Sender, TrySendError};
use async_mutex::Mutex;
use std::collections::HashMap;
use std::io;
use std::net::SocketAddr;
use std::sync::Arc;
use std::fmt;
use tokio::net::{udp, ToSocketAddrs, UdpSocket};

/// Payload of a single received datagram, held as an owned byte buffer while
/// it travels through a per-peer channel.
type Packet = Vec<u8>;

/// Wraps an arbitrary error value in an [`io::Error`] whose kind is
/// [`io::ErrorKind::Other`], keeping the original error as the payload.
fn other<E>(e: E) -> io::Error
where
    E: std::error::Error + Send + Sync + 'static,
{
    // Box explicitly; `io::Error::new` would perform the same conversion.
    let payload: Box<dyn std::error::Error + Send + Sync> = Box::new(e);
    io::Error::new(io::ErrorKind::Other, payload)
}

/// State shared between the receive loop (`serve`) and every stream created
/// from the same socket.
struct Inner {
    /// Channel on which `serve` publishes a new `UdpStream` the first time a
    /// datagram arrives from an address that has no entry in `children`.
    sender: Sender<UdpStream>,
    /// Receive half of the split socket; `serve` locks it for its whole loop.
    rx: Mutex<udp::RecvHalf>,
    /// Send half of the split socket; locked for the duration of each `send_to`.
    tx: Mutex<udp::SendHalf>,
    /// Maps a peer address to the sender feeding that peer's packet queue.
    children: Mutex<HashMap<SocketAddr, Sender<Packet>>>,
}

impl Inner {
    /// Sends `buf` to `target` through the shared send half of the socket.
    ///
    /// Waits for the `tx` mutex first, so concurrent callers are serialized.
    /// Returns whatever the underlying `send_to` call returns.
    async fn send_to(&self, buf: &[u8], target: &SocketAddr) -> io::Result<usize> {
        self.tx.lock().await.send_to(buf, target).await
    }

    /// Receive loop: reads datagrams from the socket and routes each one to
    /// the packet queue registered for its source address.
    ///
    /// The `rx` lock is acquired once and held for as long as the loop runs.
    /// When a datagram arrives from an address with no entry in `children`,
    /// a new unbounded channel is created, its sender is registered, and a
    /// `UdpStream` for that address is published on `self.sender`. If a
    /// peer's queue turns out to be closed, its entry is removed and the
    /// datagram that revealed this is dropped.
    ///
    /// # Errors
    ///
    /// The loop only ends through an error: either `recv_from` failed, or a
    /// newly created stream could not be published on `self.sender` (that
    /// failure is wrapped via `other`).
    async fn serve(self: Arc<Inner>) -> io::Result<()> {
        let socket = &mut self.rx.lock().await;
        // One scratch buffer is reused for every datagram. Each queued packet
        // is an exact-sized copy, so it does not keep 64 KiB of capacity alive
        // while it waits in an unbounded channel (`Vec::truncate` would leave
        // the capacity untouched).
        let mut scratch = vec![0u8; 65536];
        loop {
            let (size, addr) = socket.recv_from(&mut scratch).await?;
            let packet = scratch[..size].to_vec();

            let mut children = self.children.lock().await;
            let sender = match children.get(&addr) {
                Some(sender) => sender.clone(),
                None => {
                    // First datagram from this address: register a queue and
                    // hand the matching stream to the listener.
                    let (tx, rx) = unbounded();
                    let stream = UdpStream::new(self.clone(), addr, rx);
                    children.insert(addr, tx.clone());
                    self.sender.try_send(stream).map_err(other)?;
                    tx
                }
            };
            match sender.try_send(packet) {
                Ok(_) => {}
                Err(TrySendError::Closed(_)) => {
                    // The receiving side is gone; forget this peer.
                    children.remove(&addr);
                }
                // The channel was created with `unbounded()`, so it cannot be full.
                _ => unreachable!(),
            };
        }
    }
}

/// Sending side of a `UdpStream`: sends datagrams to one fixed peer address
/// through the socket shared via `Inner`.
pub struct SendHalf {
    /// Shared socket state used to perform the actual send.
    inner: Arc<Inner>,
    /// Address every datagram sent through this half is directed to.
    target: SocketAddr,
}

impl fmt::Debug for SendHalf {
    /// Formats the half as a `SendHalf` struct showing only the `target`
    /// field; the shared socket state is left out.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("SendHalf");
        builder.field("target", &self.target);
        builder.finish()
    }
}

/// Receiving side of a `UdpStream`: yields the packets queued for one peer.
#[derive(Debug)]
pub struct RecvHalf {
    /// Channel end from which this peer's queued packets are read.
    receiver: Receiver<Packet>,
}

impl SendHalf {
    /// Sends `buf` as one datagram to this half's peer address.
    ///
    /// Delegates to the shared socket state and returns its result unchanged.
    pub async fn send(&mut self, buf: &[u8]) -> io::Result<usize> {
        // `SocketAddr` is `Copy`, so take a local copy of the destination.
        let destination = self.target;
        let shared = &self.inner;
        shared.send_to(buf, &destination).await
    }
}

impl RecvHalf {
    /// Receives the next packet queued for this peer and copies it into `buf`.
    ///
    /// Waits until a packet is available on the internal channel. At most
    /// `buf.len()` bytes are copied; if the packet is longer than `buf`, the
    /// remaining bytes are discarded. Returns the number of bytes copied.
    ///
    /// # Errors
    ///
    /// Returns an `io::Error` of kind `Other` when receiving from the
    /// internal channel fails.
    pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let packet = self.receiver.recv().await.map_err(other)?;
        let len = buf.len().min(packet.len());
        // `copy_from_slice` panics unless source and destination have the same
        // length, so the destination must be limited to `len` bytes as well.
        buf[..len].copy_from_slice(&packet[..len]);
        Ok(len)
    }
}

/// A bidirectional datagram "stream" bound to a single peer address, made of
/// a sending half and a receiving half.
pub struct UdpStream {
    /// Half used by `send`.
    tx: SendHalf,
    /// Half used by `recv`.
    rx: RecvHalf,
}

impl fmt::Debug for UdpStream {
    /// Formats the stream as a `UdpStream` struct showing only the peer
    /// address, under the field name `target`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let peer = &self.tx.target;
        let mut builder = f.debug_struct("UdpStream");
        builder.field("target", peer);
        builder.finish()
    }
}

impl UdpStream {
    fn new(inner: Arc<Inner>, target: SocketAddr, receiver: Receiver<Packet>) -> UdpStream {
        UdpStream {
            tx: SendHalf { inner, target },
            rx: RecvHalf { receiver },
        }
    }
    pub async fn send(&mut self, buf: &[u8]) -> io::Result<usize> {
        self.tx.send(buf).await
    }
    pub async fn recv(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        self.rx.recv(buf).await
    }
    pub fn split(self) -> (RecvHalf, SendHalf) {
        (self.rx, self.tx)
    }
}


/// Hands out a `UdpStream` for each peer address that sends a datagram to the
/// underlying socket.
pub struct UdpListener {
    /// Channel end on which newly created streams arrive.
    receiver: Receiver<UdpStream>,
}

impl fmt::Debug for UdpListener {
    /// Formats the listener as a `UdpListener` struct with no fields shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("UdpListener");
        builder.finish()
    }
}

impl UdpListener {
    pub async fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<UdpListener> {
        Self::from_tokio(UdpSocket::bind(addr).await?)
    }
    pub fn from_tokio(udp: UdpSocket) -> io::Result<UdpListener> {
        let (rx, tx) = udp.split();
        let (sender, receiver) = unbounded();
        let inner = Arc::new(Inner {
            sender,
            rx: Mutex::new(rx),
            tx: Mutex::new(tx),
            children: Mutex::new(HashMap::new()),
        });
        tokio::spawn(inner.clone().serve());
        Ok(UdpListener { receiver })
    }
    pub fn from_std(socket: std::net::UdpSocket) -> io::Result<UdpListener> {
        Self::from_tokio(UdpSocket::from_std(socket)?)
    }
    pub async fn next(&mut self) -> io::Result<UdpStream> {
        self.receiver.recv().await.map_err(other)
    }
}