use async_trait::async_trait;
use bytes::Bytes;
use mio::net::UdpSocket as MIOUdpStream;
use parking_lot::Mutex;
use std::io;
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::UdpSocket as TokioUdpStream;
use tokio::sync::mpsc;
use tokio::task::JoinHandle;
use super::pool::{AddressedStreamFactory, NonBlockingStream};
use super::TransportError;
use crate::hub::event::IOSource;
use crate::hub::utils::error;
const MAX_UDP_PAYLOAD_SIZE: usize = 65507;
const INIT_MESSAGE: &[u8] = b"init";
/// Lets a connected mio UDP socket be driven as a non-blocking stream.
///
/// Every operation is a single syscall on the socket. Any I/O error (not only
/// `WouldBlock`) is reported to the caller as `TransportError::NotReady`.
impl NonBlockingStream for MIOUdpStream {
    /// Reads one datagram from the connected peer.
    ///
    /// The scratch buffer is `MAX_UDP_PAYLOAD_SIZE` bytes (the IPv4 UDP payload
    /// limit) so a whole datagram fits; only the received bytes are copied out.
    fn try_recv(&mut self) -> Result<Bytes, TransportError> {
        let mut scratch = [0u8; MAX_UDP_PAYLOAD_SIZE];
        self.recv(&mut scratch)
            .map(|len| Bytes::copy_from_slice(&scratch[..len]))
            .map_err(|_| TransportError::NotReady)
    }

    /// Sends `data` as a single datagram to the connected peer.
    ///
    /// Returns `Ok(false)` when there is nothing to send and `Ok(true)` once the
    /// datagram was handed to the OS. The payload is taken by value, so it is
    /// dropped (not handed back) if the send fails.
    fn try_send(&mut self, data: Option<Bytes>) -> Result<bool, TransportError> {
        match data {
            None => Ok(false),
            Some(payload) => self
                .send(&payload)
                .map(|_| true)
                .map_err(|_| TransportError::NotReady),
        }
    }

    /// Exposes the socket for readiness registration.
    fn source(&mut self) -> IOSource {
        IOSource::MIO(self)
    }

    /// UDP has no connection state to tear down, so this always succeeds.
    fn shutdown(&mut self, _how: std::net::Shutdown) -> io::Result<()> {
        Ok(())
    }
}
struct FactoryInner {
listen_handle: Option<JoinHandle<()>>,
accepted_stream: Mutex<mpsc::Receiver<MIOUdpStream>>,
}
impl Drop for FactoryInner {
fn drop(&mut self) {
if let Some(h) = self.listen_handle.take() {
h.abort()
}
}
}
#[derive(Clone)]
pub struct Factory(Arc<FactoryInner>);
impl Factory {
pub fn new(listen_addr: Option<SocketAddr>) -> Self {
let (tx, accepted_stream) = mpsc::channel(1);
let listen_handle = listen_addr.map(|listen_addr| {
tokio::spawn(async move {
let listener = match TokioUdpStream::bind(listen_addr).await {
Ok(l) => l,
Err(e) => {
error!("[Udp] failed to bind to address {}: {}", listen_addr, e);
return
}
};
loop {
let mut buf = [0u8; INIT_MESSAGE.len()];
if let Ok((_, client_addr)) = listener.recv_from(&mut buf).await {
{
if INIT_MESSAGE != &buf[..] {
continue
}
let udp_socket = match create_udp_socket().await {
Ok(s) => s,
Err(e) => {
error!("[Udp] failed to create udp socket: {}", e);
return
}
};
match udp_socket.connect(client_addr).await {
Ok(_) => match udp_socket.send(INIT_MESSAGE).await {
Ok(_) => {
tx.send(tokio_to_mio_stream(udp_socket)).await.ok();
}
Err(e) => error!("[Udp] failed to send Init message: {}", e),
},
Err(e) => {
error!("[Udp] failed to connect to the client {}: {}", client_addr, e)
}
}
}
}
}
})
});
Self(Arc::new(FactoryInner {
accepted_stream: Mutex::new(accepted_stream),
listen_handle,
}))
}
}
#[async_trait]
impl AddressedStreamFactory for Factory {
async fn create_stream(&self, addr: &str) -> Option<Box<dyn NonBlockingStream>> {
let addr: SocketAddr = addr.parse().ok()?;
let udp_socket = match create_udp_socket().await {
Ok(s) => s,
Err(e) => {
error!("[Udp] failed to create udp socket: {}", e);
return None
}
};
let mut buf = [0u8; INIT_MESSAGE.len()];
if let Ok(_) = udp_socket.send_to(INIT_MESSAGE, addr).await {
match udp_socket.recv_from(&mut buf).await {
Ok((_, server_socket)) => {
if INIT_MESSAGE == &buf[..] {
match udp_socket.connect(server_socket).await {
Ok(_) => Some(Box::new(tokio_to_mio_stream(udp_socket))),
Err(_) => None,
}
} else {
None
}
}
Err(_) => None,
}
} else {
None
}
}
async fn discover_stream(&self) -> Box<dyn NonBlockingStream> {
match self.0.accepted_stream.lock().recv().await {
None => futures::future::pending().await,
Some(s) => Box::new(s),
}
}
}
/// Converts a Tokio UDP socket into a mio one so it can be driven by the
/// transport's own readiness loop.
///
/// Tokio's `into_std` hands back the socket in non-blocking mode, which is
/// what `mio::net::UdpSocket::from_std` requires of its caller.
///
/// # Panics
///
/// Panics if Tokio fails to deregister the socket from its reactor.
pub(crate) fn tokio_to_mio_stream(src: TokioUdpStream) -> MIOUdpStream {
    let std_socket = src
        .into_std()
        .expect("[Udp] failed to detach a registered UDP socket from the tokio reactor");
    MIOUdpStream::from_std(std_socket)
}
pub(crate) async fn create_udp_socket() -> io::Result<TokioUdpStream> {
TokioUdpStream::bind("127.0.0.1:0").await
}