use crate::common::ready_future::ReadyFuture;
use crate::common::ready_future_state::ReadyFutureResult;
use crate::net::event_listener;
use futures::{AsyncRead, AsyncWrite, FutureExt};
use mio::Token;
use mio::net::UdpSocket as MioUdpSocket;
use std::fmt::{Debug, Error, Formatter};
use std::io;
use std::net::UdpSocket as StdUdpSocket;
use std::net::{SocketAddr, ToSocketAddrs};
use std::pin::Pin;
use std::task::{Context, Poll};
use std::time::{Duration, Instant};
/// Receiving half of an asynchronous UDP socket.
///
/// Bundles a mio UDP socket with the token and the pending readiness future
/// used to wait for incoming datagrams through the crate's event listener.
pub struct UdpReadSocket {
// mio socket that datagrams are received from.
udp_socket: MioUdpSocket,
// Token handed to the event listener when registering read interest.
read_token: Token,
// Readiness future of a wait that has not completed yet; `None` when no wait is in progress.
read_future: Option<ReadyFuture<()>>,
/// Added to `Instant::now()` to form the deadline of each readiness wait.
pub read_timeout: Duration,
}
impl UdpReadSocket {
    /// Wraps `udp_socket` as the receiving half of an asynchronous UDP socket.
    ///
    /// A fresh token is requested from the event listener and the read timeout
    /// starts at 20 seconds.
    pub fn new(udp_socket: MioUdpSocket) -> Self {
        UdpReadSocket {
            udp_socket,
            read_token: event_listener().next_token(),
            read_future: None,
            read_timeout: Duration::from_secs(20),
        }
    }

    /// Sets the timeout applied to subsequent readiness waits.
    pub fn set_read_timeout(&mut self, duration: Duration) {
        self.read_timeout = duration;
    }

    /// Registers read interest with the event listener and stores the
    /// resulting readiness future in `read_future`.
    ///
    /// The deadline is `Instant::now() + read_timeout`, computed at the moment
    /// of registration.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the event listener when registration fails.
    fn wait_read_data(&mut self) -> io::Result<()> {
        let future = event_listener().listen_read(
            &mut self.udp_socket,
            Instant::now() + self.read_timeout,
            self.read_token,
        )?;
        self.read_future = Some(future);
        Ok(())
    }

    /// Drives one receive attempt, returning the datagram size and its sender.
    ///
    /// The socket is tried directly first. Only when it reports `WouldBlock`
    /// is read interest registered and the readiness future polled. A pending
    /// readiness future left over from an earlier call is resumed before the
    /// socket is touched again.
    ///
    /// # Errors
    ///
    /// * `TimedOut` when the readiness future resolves to
    ///   `ReadyFutureResult::Timeout`.
    /// * Any error from registering with the event listener.
    /// * Any socket error other than `WouldBlock` / `Interrupted`.
    fn poll_read_attempt(
        &mut self,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<(usize, SocketAddr)>> {
        loop {
            // Resume (or start polling) an outstanding readiness wait.
            if let Some(mut future) = self.read_future.take() {
                match future.poll_unpin(cx) {
                    Poll::Pending => {
                        // Keep the future so the next poll continues the same wait.
                        self.read_future = Some(future);
                        return Poll::Pending;
                    }
                    Poll::Ready(ReadyFutureResult::Timeout) => {
                        return Poll::Ready(Err(io::ErrorKind::TimedOut.into()));
                    }
                    // Readiness was signalled: fall through and try the socket.
                    Poll::Ready(_) => {}
                }
            }

            match self.udp_socket.recv_from(buf) {
                Ok((size, addr)) => return Poll::Ready(Ok((size, addr))),
                // Either nothing has arrived yet or the readiness notification
                // was spurious. `WouldBlock` must never escape to the caller of
                // `poll_read`, so register (again) and keep waiting.
                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
                // An interrupted system call is simply retried.
                Err(err) if err.kind() == io::ErrorKind::Interrupted => continue,
                Err(err) => return Poll::Ready(Err(err)),
            }

            if let Err(err) = self.wait_read_data() {
                return Poll::Ready(Err(err));
            }
            // Loop around to poll the freshly stored readiness future so the
            // task's waker gets registered.
        }
    }

    /// Performs a single receive directly on the underlying socket, without
    /// involving the event listener or the read timeout.
    pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
        self.udp_socket.recv_from(buf)
    }

    /// Returns the local address the underlying socket is bound to.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        self.udp_socket.local_addr()
    }
}
impl AsyncRead for UdpReadSocket {
    /// Polls for one datagram and reports only its size; the sender address
    /// produced by `poll_read_attempt` is discarded.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        self.get_mut()
            .poll_read_attempt(cx, buf)
            .map(|outcome| outcome.map(|(size, _sender)| size))
    }
}
impl Drop for UdpReadSocket {
    /// Asks the event listener to stop listening for this socket's read token.
    fn drop(&mut self) {
        // Best effort: a failure here is deliberately ignored.
        let _ = event_listener().stop_listening(&mut self.udp_socket, self.read_token);
    }
}
/// Sending half of an asynchronous UDP socket.
///
/// Bundles a mio UDP socket with the token and the pending readiness future
/// used to wait for writability through the crate's event listener.
pub struct UdpWriteSocket {
// mio socket that datagrams are sent through.
udp_socket: MioUdpSocket,
// Token handed to the event listener when registering write interest.
write_token: Token,
// Readiness future of a wait that has not completed yet; `None` when no wait is in progress.
write_future: Option<ReadyFuture<()>>,
/// Added to `Instant::now()` to form the deadline of each readiness wait.
pub write_timeout: Duration,
}
impl UdpWriteSocket {
    /// Wraps `udp_socket` as the sending half of an asynchronous UDP socket.
    ///
    /// A fresh token is requested from the event listener and the write
    /// timeout starts at 2 seconds.
    pub fn new(udp_socket: MioUdpSocket) -> Self {
        UdpWriteSocket {
            udp_socket,
            write_token: event_listener().next_token(),
            write_future: None,
            write_timeout: Duration::from_secs(2),
        }
    }

    /// Sets the timeout applied to subsequent readiness waits.
    pub fn set_write_timeout(&mut self, duration: Duration) {
        self.write_timeout = duration;
    }

    /// Registers write interest with the event listener and stores the
    /// resulting readiness future in `write_future`.
    ///
    /// The deadline is `Instant::now() + write_timeout`, computed at the
    /// moment of registration.
    ///
    /// # Errors
    ///
    /// Returns the error reported by the event listener when registration fails.
    fn wait_write_ready(&mut self) -> io::Result<()> {
        let future = event_listener().listen_write(
            &mut self.udp_socket,
            Instant::now() + self.write_timeout,
            self.write_token,
        )?;
        self.write_future = Some(future);
        Ok(())
    }

    /// Drives one send attempt on the socket, returning the number of bytes sent.
    ///
    /// Uses `send`, which has no destination argument. The socket is tried
    /// directly first. Only when it reports `WouldBlock` is write interest
    /// registered and the readiness future polled. A pending readiness future
    /// left over from an earlier call is resumed before the socket is touched
    /// again.
    ///
    /// # Errors
    ///
    /// * `TimedOut` when the readiness future resolves to
    ///   `ReadyFutureResult::Timeout`.
    /// * Any error from registering with the event listener.
    /// * Any socket error other than `WouldBlock` / `Interrupted`.
    fn poll_write_attempt(&mut self, cx: &mut Context<'_>, buf: &[u8]) -> Poll<io::Result<usize>> {
        loop {
            // Resume (or start polling) an outstanding readiness wait.
            if let Some(mut future) = self.write_future.take() {
                match future.poll_unpin(cx) {
                    Poll::Pending => {
                        // Keep the future so the next poll continues the same wait.
                        self.write_future = Some(future);
                        return Poll::Pending;
                    }
                    Poll::Ready(ReadyFutureResult::Timeout) => {
                        return Poll::Ready(Err(io::ErrorKind::TimedOut.into()));
                    }
                    // Readiness was signalled: fall through and try the socket.
                    Poll::Ready(_) => {}
                }
            }

            match self.udp_socket.send(buf) {
                Ok(size) => return Poll::Ready(Ok(size)),
                // Either the socket is still not writable or the readiness
                // notification was spurious. `WouldBlock` must never escape to
                // the caller of `poll_write`, so register (again) and keep waiting.
                Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
                // An interrupted system call is simply retried.
                Err(err) if err.kind() == io::ErrorKind::Interrupted => continue,
                Err(err) => return Poll::Ready(Err(err)),
            }

            if let Err(err) = self.wait_write_ready() {
                return Poll::Ready(Err(err));
            }
            // Loop around to poll the freshly stored readiness future so the
            // task's waker gets registered.
        }
    }

    /// Sends `buf` to `target` directly on the underlying socket, without
    /// involving the event listener or the write timeout.
    pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
        self.udp_socket.send_to(buf, target)
    }

    /// Returns the local address the underlying socket is bound to.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        self.udp_socket.local_addr()
    }
}
impl AsyncWrite for UdpWriteSocket {
    /// Delegates to `poll_write_attempt` to send `buf`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        this.poll_write_attempt(cx, buf)
    }

    /// Always completes immediately with success; nothing is flushed here.
    fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let flushed: io::Result<()> = Ok(());
        Poll::Ready(flushed)
    }

    /// Always completes immediately with success; no close action is taken here.
    fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let closed: io::Result<()> = Ok(());
        Poll::Ready(closed)
    }
}
impl Drop for UdpWriteSocket {
    /// Asks the event listener to stop listening for this socket's write token.
    fn drop(&mut self) {
        // Best effort: a failure here is deliberately ignored.
        let _ = event_listener().stop_listening(&mut self.udp_socket, self.write_token);
    }
}
/// Asynchronous UDP socket made of an independent read half and write half.
pub struct UdpSocket {
// Half used for receiving datagrams.
read_socket: UdpReadSocket,
// Half used for sending datagrams.
write_socket: UdpWriteSocket,
}
impl UdpSocket {
    /// Builds an asynchronous socket from a standard library UDP socket.
    ///
    /// The socket is switched to non-blocking mode and duplicated with
    /// `try_clone`: the clone backs the read half, the original backs the
    /// write half.
    ///
    /// # Errors
    ///
    /// Returns any error from `set_nonblocking` or `try_clone`.
    pub fn from(udp_socket: StdUdpSocket) -> io::Result<UdpSocket> {
        udp_socket.set_nonblocking(true)?;
        Ok(UdpSocket {
            read_socket: UdpReadSocket::new(MioUdpSocket::from_std(udp_socket.try_clone()?)),
            write_socket: UdpWriteSocket::new(MioUdpSocket::from_std(udp_socket)),
        })
    }

    /// Binds a new UDP socket to `addr` and wraps it.
    ///
    /// # Errors
    ///
    /// Returns any error from binding or from [`UdpSocket::from`].
    pub fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<UdpSocket> {
        Self::from(StdUdpSocket::bind(addr)?)
    }

    /// Connects both halves to the remote address `addr`.
    ///
    /// `addr` may resolve to several candidates. Each is tried in order until
    /// both halves connect successfully.
    ///
    /// # Errors
    ///
    /// * The resolution error when `addr` cannot be resolved.
    /// * The error of the last failed attempt when no candidate could be
    ///   connected.
    /// * `InvalidInput` when `addr` resolves to no address at all, instead of
    ///   silently reporting success on an unconnected socket.
    pub fn connect<A: ToSocketAddrs>(&self, addr: A) -> io::Result<()> {
        let mut last_err = None;
        for candidate in addr.to_socket_addrs()? {
            let attempt = self
                .read_socket
                .udp_socket
                .connect(candidate)
                .and_then(|()| self.write_socket.udp_socket.connect(candidate));
            match attempt {
                Ok(()) => return Ok(()),
                Err(err) => last_err = Some(err),
            }
        }
        Err(last_err.unwrap_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "could not resolve to any addresses",
            )
        }))
    }

    /// Binds a new socket to `addr` and connects it to `to_addr`.
    ///
    /// # Errors
    ///
    /// Returns any error from [`UdpSocket::bind`] or [`UdpSocket::connect`].
    pub fn bind_and_connect<A: ToSocketAddrs, B: ToSocketAddrs>(
        addr: A,
        to_addr: B,
    ) -> io::Result<UdpSocket> {
        let socket = Self::bind(addr)?;
        socket.connect(to_addr)?;
        Ok(socket)
    }

    /// Shared access to the read half.
    pub fn read_socket(&self) -> &UdpReadSocket {
        &self.read_socket
    }

    /// Exclusive access to the read half.
    pub fn read_socket_mut(&mut self) -> &mut UdpReadSocket {
        &mut self.read_socket
    }

    /// Shared access to the write half.
    pub fn write_socket(&self) -> &UdpWriteSocket {
        &self.write_socket
    }

    /// Exclusive access to the write half.
    pub fn write_socket_mut(&mut self) -> &mut UdpWriteSocket {
        &mut self.write_socket
    }

    /// Sets the read timeout of the read half.
    pub fn set_read_timeout(&mut self, duration: Duration) {
        self.read_socket.set_read_timeout(duration);
    }

    /// Sets the write timeout of the write half.
    pub fn set_write_timeout(&mut self, duration: Duration) {
        self.write_socket.set_write_timeout(duration);
    }

    /// Sends `buf` to `target` through the write half's direct `send_to`.
    pub fn send_to(&self, buf: &[u8], target: SocketAddr) -> io::Result<usize> {
        self.write_socket.send_to(buf, target)
    }

    /// Receives one datagram through the read half's direct `recv_from`.
    pub fn recv_from(&self, buf: &mut [u8]) -> io::Result<(usize, SocketAddr)> {
        self.read_socket.recv_from(buf)
    }

    /// Returns the local address, as reported by the read half.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        self.read_socket.local_addr()
    }

    /// Consumes the socket and returns its read and write halves.
    pub fn split(self) -> (UdpReadSocket, UdpWriteSocket) {
        (self.read_socket, self.write_socket)
    }
}
impl AsyncRead for UdpSocket {
    /// Forwards the read to the read half.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let reader = &mut self.get_mut().read_socket;
        Pin::new(reader).poll_read(cx, buf)
    }
}
impl AsyncWrite for UdpSocket {
    /// Forwards the write to the write half.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let writer = &mut self.get_mut().write_socket;
        Pin::new(writer).poll_write(cx, buf)
    }

    /// Forwards the flush to the write half.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let writer = &mut self.get_mut().write_socket;
        Pin::new(writer).poll_flush(cx)
    }

    /// Forwards the close to the write half.
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let writer = &mut self.get_mut().write_socket;
        Pin::new(writer).poll_close(cx)
    }
}
impl Debug for UdpSocket {
    /// Formats the socket as the `Debug` output of the read half's underlying
    /// mio socket.
    fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
        let inner = &self.read_socket.udp_socket;
        f.write_fmt(format_args!("{:?}", inner))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::timer::timer::Timer;
    use futures::executor::block_on;
    use futures::AsyncWriteExt;
    use std::sync::{Arc, Mutex};
    use std::thread;
    use std::time::Duration;

    // Binds two independent loopback sockets on OS-assigned ports.
    fn setup_test_sockets() -> (StdUdpSocket, StdUdpSocket) {
        let server = StdUdpSocket::bind("127.0.0.1:0").unwrap();
        let client = StdUdpSocket::bind("127.0.0.1:0").unwrap();
        (server, client)
    }

    #[test]
    fn test_udp_wrapper_creation() {
        let socket = StdUdpSocket::bind("127.0.0.1:0").unwrap();
        let addr = socket.local_addr().unwrap();
        let wrapper = UdpSocket::from(socket);
        assert!(wrapper.is_ok());
        let wrapper = wrapper.unwrap();
        assert_eq!(wrapper.local_addr().unwrap(), addr);
    }

    #[test]
    fn test_udp_wrapper_bind() {
        let wrapper = UdpSocket::bind("127.0.0.1:0");
        assert!(wrapper.is_ok());
        let wrapper = wrapper.unwrap();
        let addr = wrapper.local_addr().unwrap();
        assert!(addr.port() > 0);
        assert_eq!(addr.ip().to_string(), "127.0.0.1");
    }

    #[test]
    fn test_udp_wrapper_bind_and_connect() {
        let (server, _) = setup_test_sockets();
        let server_addr = server.local_addr().unwrap();
        let wrapper = UdpSocket::bind_and_connect("127.0.0.1:0", server_addr);
        assert!(wrapper.is_ok());
    }

    #[test]
    fn test_timeout_setters() {
        let mut wrapper = UdpSocket::bind("127.0.0.1:0").unwrap();
        wrapper.set_read_timeout(Duration::from_secs(30));
        wrapper.set_write_timeout(Duration::from_secs(5));
        assert_eq!(wrapper.read_socket().read_timeout, Duration::from_secs(30));
        assert_eq!(wrapper.write_socket().write_timeout, Duration::from_secs(5));
    }

    #[test]
    fn test_socket_accessors() {
        let mut wrapper = UdpSocket::bind("127.0.0.1:0").unwrap();
        let read_socket = wrapper.read_socket();
        assert_eq!(read_socket.read_timeout, Duration::from_secs(20));
        let read_socket_mut = wrapper.read_socket_mut();
        read_socket_mut.set_read_timeout(Duration::from_secs(15));
        assert_eq!(read_socket_mut.read_timeout, Duration::from_secs(15));
        let write_socket = wrapper.write_socket();
        assert_eq!(write_socket.write_timeout, Duration::from_secs(2));
        let write_socket_mut = wrapper.write_socket_mut();
        write_socket_mut.set_write_timeout(Duration::from_secs(10));
        assert_eq!(write_socket_mut.write_timeout, Duration::from_secs(10));
    }

    #[test]
    fn test_sync_send_recv() {
        let (server, client) = setup_test_sockets();
        let server_addr = server.local_addr().unwrap();
        let client_addr = client.local_addr().unwrap();
        let server_wrapper = UdpSocket::from(server).unwrap();
        let client_wrapper = UdpSocket::from(client).unwrap();
        let test_data = b"Hello UDP!";
        let sent = client_wrapper.send_to(test_data, server_addr);
        assert!(sent.is_ok());
        assert_eq!(sent.unwrap(), test_data.len());
        // The wrapped socket is non-blocking, so give the datagram time to arrive.
        thread::sleep(Duration::from_millis(10));
        let mut buf = [0u8; 1024];
        let received = server_wrapper.recv_from(&mut buf);
        assert!(received.is_ok());
        let (size, addr) = received.unwrap();
        assert_eq!(size, test_data.len());
        assert_eq!(&buf[..size], test_data);
        assert_eq!(addr, client_addr);
    }

    #[test]
    fn test_async_read_write() {
        let (server, client) = setup_test_sockets();
        let server_addr = server.local_addr().unwrap();
        // Echo server: sends the first datagram back to its sender.
        thread::spawn(move || {
            let mut buf = [0u8; 1024];
            if let Ok((size, addr)) = server.recv_from(&mut buf) {
                let _ = server.send_to(&buf[..size], addr);
            }
        });
        thread::sleep(Duration::from_millis(10));
        let test_future = async {
            let wrapper = UdpSocket::from(client).unwrap();
            let test_data = b"Async UDP test!";
            let sent = wrapper.send_to(test_data, server_addr);
            assert!(sent.is_ok());
            let mut buf = [0u8; 1024];
            // Non-blocking receive: the echo may not have arrived yet, in which
            // case the content checks below are skipped.
            let read_result = wrapper.recv_from(&mut buf);
            if let Ok((size, addr)) = read_result {
                assert_eq!(size, test_data.len());
                assert_eq!(&buf[..size], test_data);
                assert_eq!(addr, server_addr);
            }
        };
        block_on(test_future);
    }

    #[test]
    fn test_async_with_timer() {
        let mut timer = Timer::new();
        let (server, client) = setup_test_sockets();
        let server_addr = server.local_addr().unwrap();
        thread::spawn(move || {
            let mut buf = [0u8; 1024];
            if let Ok((size, addr)) = server.recv_from(&mut buf) {
                thread::sleep(Duration::from_millis(50));
                let _ = server.send_to(&buf[..size], addr);
            }
        });
        let test_future = async {
            let wrapper = UdpSocket::from(client).unwrap();
            timer.wait(Duration::from_millis(20)).await;
            let test_data = b"Delayed UDP!";
            let sent = wrapper.send_to(test_data, server_addr);
            assert!(sent.is_ok());
        };
        block_on(test_future);
    }

    #[test]
    fn test_concurrent_operations() {
        let server = StdUdpSocket::bind("127.0.0.1:0").unwrap();
        let server_addr = server.local_addr().unwrap();
        let response_count = Arc::new(Mutex::new(0));
        let response_count_clone = response_count.clone();
        // Echo server that counts every datagram it answers.
        thread::spawn(move || {
            let mut buf = [0u8; 1024];
            for _ in 0..3 {
                if let Ok((size, addr)) = server.recv_from(&mut buf) {
                    let _ = server.send_to(&buf[..size], addr);
                    let mut count = response_count_clone.lock().unwrap();
                    *count += 1;
                }
            }
        });
        thread::sleep(Duration::from_millis(10));
        let test_future = async {
            let mut futures = Vec::new();
            for i in 0..3 {
                let test_data = format!("Message {}", i);
                let future = async move {
                    let client = StdUdpSocket::bind("127.0.0.1:0").unwrap();
                    let wrapper = UdpSocket::from(client).unwrap();
                    let sent = wrapper.send_to(test_data.as_bytes(), server_addr);
                    assert!(sent.is_ok());
                };
                futures.push(future);
            }
            futures::future::join_all(futures).await;
        };
        block_on(test_future);
        // Let the server thread finish handling all three datagrams.
        thread::sleep(Duration::from_millis(100));
        let count = response_count.lock().unwrap();
        assert_eq!(*count, 3);
    }

    #[test]
    fn test_timeout_behavior() {
        let mut wrapper = UdpSocket::bind("127.0.0.1:0").unwrap();
        wrapper.set_read_timeout(Duration::from_millis(50));
        let test_future = async {
            let mut buf = [0u8; 1024];
            // The direct `recv_from` does not wait on the event listener, so
            // with no sender the only acceptable outcome is `WouldBlock`.
            let result = wrapper.recv_from(&mut buf);
            match result {
                Ok(_) => {
                    panic!("Unexpected data received");
                }
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {}
                Err(e) => {
                    panic!("Unexpected error: {:?}", e);
                }
            }
        };
        block_on(test_future);
    }

    #[test]
    fn test_multiple_sends_to_different_addresses() {
        let server1 = StdUdpSocket::bind("127.0.0.1:0").unwrap();
        let server2 = StdUdpSocket::bind("127.0.0.1:0").unwrap();
        let server1_addr = server1.local_addr().unwrap();
        let server2_addr = server2.local_addr().unwrap();
        let (_, client) = setup_test_sockets();
        let wrapper = UdpSocket::from(client).unwrap();
        let data1 = b"Hello Server 1";
        let sent1 = wrapper.send_to(data1, server1_addr);
        assert!(sent1.is_ok());
        assert_eq!(sent1.unwrap(), data1.len());
        let data2 = b"Hello Server 2";
        let sent2 = wrapper.send_to(data2, server2_addr);
        assert!(sent2.is_ok());
        assert_eq!(sent2.unwrap(), data2.len());
        thread::sleep(Duration::from_millis(10));
        let mut buf1 = [0u8; 1024];
        let received1 = server1.recv_from(&mut buf1);
        assert!(received1.is_ok());
        let (size1, _) = received1.unwrap();
        assert_eq!(&buf1[..size1], data1);
        let mut buf2 = [0u8; 1024];
        let received2 = server2.recv_from(&mut buf2);
        assert!(received2.is_ok());
        let (size2, _) = received2.unwrap();
        assert_eq!(&buf2[..size2], data2);
    }

    #[test]
    fn test_large_data_transmission() {
        let (server, client) = setup_test_sockets();
        let server_addr = server.local_addr().unwrap();
        thread::spawn(move || {
            let mut buf = [0u8; 2048];
            if let Ok((size, addr)) = server.recv_from(&mut buf) {
                let _ = server.send_to(&buf[..size], addr);
            }
        });
        thread::sleep(Duration::from_millis(10));
        let wrapper = UdpSocket::from(client).unwrap();
        let large_data = vec![0xAB; 1400];
        let sent = wrapper.send_to(&large_data, server_addr);
        assert!(sent.is_ok());
        assert_eq!(sent.unwrap(), large_data.len());
        // Wait for the echo before the non-blocking receive below.
        thread::sleep(Duration::from_millis(20));
        let mut buf = [0u8; 2048];
        let received = wrapper.recv_from(&mut buf);
        assert!(received.is_ok());
        let (size, addr) = received.unwrap();
        assert_eq!(size, large_data.len());
        assert_eq!(&buf[..size], &large_data[..]);
        assert_eq!(addr, server_addr);
    }

    #[test]
    fn test_drop_behavior() {
        let wrapper = UdpSocket::bind("127.0.0.1:0").unwrap();
        let addr = wrapper.local_addr().unwrap();
        drop(wrapper);
        thread::sleep(Duration::from_millis(10));
        // The address must be bindable again once the wrapper has been dropped.
        let new_wrapper = UdpSocket::bind(addr);
        assert!(new_wrapper.is_ok());
    }

    #[test]
    fn test_split_sockets_independently() {
        let (server, client) = setup_test_sockets();
        let server_addr = server.local_addr().unwrap();
        thread::spawn(move || {
            let mut buf = [0u8; 1024];
            if let Ok((size, addr)) = server.recv_from(&mut buf) {
                let _ = server.send_to(&buf[..size], addr);
            }
        });
        thread::sleep(Duration::from_millis(10));
        let test_future = async {
            let wrapper = UdpSocket::from(client).unwrap();
            let (read_socket, write_socket) = wrapper.split();
            let test_data = b"Split socket test";
            write_socket.send_to(test_data, server_addr).unwrap();
            let mut buf = [0u8; 1024];
            let received = read_socket.recv_from(&mut buf);
            match received {
                Ok((size, addr)) => {
                    assert_eq!(&buf[..size], test_data);
                    assert_eq!(addr, server_addr);
                }
                // The echo may not have arrived yet on the non-blocking socket.
                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {}
                Err(e) => panic!("Unexpected error: {:?}", e),
            }
        };
        block_on(test_future);
    }

    #[test]
    fn test_connected_socket_operations() {
        let (server, client) = setup_test_sockets();
        let server_addr = server.local_addr().unwrap();
        // Bound the blocking receive so a lost datagram fails the test instead
        // of hanging it.
        server
            .set_read_timeout(Some(Duration::from_secs(1)))
            .unwrap();
        let mut wrapper = UdpSocket::from(client).unwrap();
        let client_addr = wrapper.local_addr().unwrap();
        wrapper.connect(server_addr).unwrap();
        let test_data = b"Connected test";
        // `AsyncWrite` uses `send`, which only works on a connected socket.
        let sent = block_on(wrapper.write(test_data));
        assert_eq!(sent.unwrap(), test_data.len());
        let mut buf = [0u8; 1024];
        let (size, addr) = server.recv_from(&mut buf).unwrap();
        assert_eq!(&buf[..size], test_data);
        assert_eq!(addr, client_addr);
    }
}