use std::net;
use std::sync::Arc;
use std::time;
/// Key under which a receiver's socket is registered with (and re-armed in) its poller.
///
/// Each receiver creates its own poller and registers exactly one socket with it, and
/// the delivered events are only counted, never matched by key. So the value only has
/// to be used consistently for `add` and `modify`.
const SOCKET_POLLING_KEY: usize = 0;
/// Sending half of an unconnected UDP socket, created by `new`.
///
/// The underlying socket is shared with the matching `SocketRx` through an `Arc`.
pub struct SocketTx {
    // Shared with the `SocketRx` produced by the same `new` call.
    socket: Arc<net::UdpSocket>,
}
/// Receiving half of an unconnected UDP socket, created by `new`.
///
/// Owns the poller used to wait for readability and the buffer that received
/// datagrams are copied into.
///
/// Fields are dropped in declaration order, so the `socket` `Arc` is released before
/// `poller` is dropped.
pub struct SocketRx {
    // Shared with the `SocketTx` produced by the same `new` call.
    socket: Arc<net::UdpSocket>,
    // Captured once at construction and returned by `local_addr()`.
    local_addr: net::SocketAddr,
    // Dedicated poller; `socket` is registered in it under `SOCKET_POLLING_KEY`.
    poller: polling::Poller,
    // Event list reused (cleared) on every `wait_for_frame` call.
    poller_events: polling::Events,
    // `frame_size_max` bytes; frames returned by the read methods borrow from it.
    recv_buffer: Box<[u8]>,
}
/// Sending half of a UDP socket connected to a single peer, created by `new_connected`.
///
/// The underlying socket is shared with the matching `ConnectedSocketRx` through an `Arc`.
pub struct ConnectedSocketTx {
    // Shared with the `ConnectedSocketRx` produced by the same `new_connected` call.
    socket: Arc<net::UdpSocket>,
}
/// Receiving half of a UDP socket connected to a single peer, created by `new_connected`.
///
/// Owns the poller used to wait for readability and the buffer that received
/// datagrams are copied into.
///
/// Fields are dropped in declaration order, so the `socket` `Arc` is released before
/// `poller` is dropped.
pub struct ConnectedSocketRx {
    // Shared with the `ConnectedSocketTx` produced by the same `new_connected` call.
    socket: Arc<net::UdpSocket>,
    // Captured once at construction and returned by `local_addr()`.
    local_addr: net::SocketAddr,
    // Captured once after `connect` and returned by `peer_addr()`.
    peer_addr: net::SocketAddr,
    // Dedicated poller; `socket` is registered in it under `SOCKET_POLLING_KEY`.
    poller: polling::Poller,
    // Event list reused (cleared) on every `wait_for_frame` call.
    poller_events: polling::Events,
    // `frame_size_max` bytes; frames returned by the read methods borrow from it.
    recv_buffer: Box<[u8]>,
}
impl SocketTx {
    /// Sends `frame` as one datagram to `addr`, on a best-effort basis.
    ///
    /// The outcome of `UdpSocket::send_to` is discarded. Errors are not reported to
    /// the caller, including `WouldBlock`, which the non-blocking socket set up by
    /// `new` can return.
    pub fn send(&self, frame: &[u8], addr: &net::SocketAddr) {
        let socket: &net::UdpSocket = &self.socket;
        let _discarded: std::io::Result<usize> = socket.send_to(frame, addr);
    }
}
impl SocketRx {
    /// Attempts to receive one datagram without blocking.
    ///
    /// Returns `Ok(Some((frame, sender)))` when a datagram was read, and `Ok(None)`
    /// when none is currently queued.
    ///
    /// `frame` borrows this receiver's internal buffer, so it is only valid until the
    /// next read.
    ///
    /// The buffer holds the `frame_size_max` bytes passed to `new`. NOTE(review):
    /// longer datagrams are expected to be truncated by the OS on Unix-like targets and
    /// reported as an error on Windows. Confirm this for the platforms you support.
    ///
    /// # Errors
    ///
    /// Any error from `UdpSocket::recv_from` other than `WouldBlock` is returned unchanged.
    pub fn try_read_frame(&mut self) -> std::io::Result<Option<(&[u8], net::SocketAddr)>> {
        match self.socket.recv_from(&mut self.recv_buffer) {
            Ok((frame_len, sender_addr)) => Ok(Some((&self.recv_buffer[..frame_len], sender_addr))),
            Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => Ok(None),
            Err(err) => Err(err),
        }
    }

    /// Waits up to `timeout` for the socket to become readable, then reads one datagram.
    ///
    /// A `timeout` of `None` waits indefinitely, as with `polling::Poller::wait`.
    ///
    /// Returns `Ok(None)` when the wait ends without a readiness event. It also
    /// returns `Ok(None)` when the subsequent read reports `WouldBlock` (a spurious
    /// wake-up).
    ///
    /// # Errors
    ///
    /// Propagates errors from re-arming the poller, from waiting on it, and from
    /// `try_read_frame`.
    pub fn wait_for_frame(
        &mut self,
        timeout: Option<time::Duration>,
    ) -> std::io::Result<Option<(&[u8], net::SocketAddr)>> {
        // `polling` delivers readiness in oneshot mode, so interest must be re-armed
        // before every wait.
        self.poller
            .modify(&*self.socket, polling::Event::readable(SOCKET_POLLING_KEY))?;
        self.poller_events.clear();
        let ready = self.poller.wait(&mut self.poller_events, timeout)?;
        if ready == 0 {
            return Ok(None);
        }
        self.try_read_frame()
    }

    /// Local address the socket is bound to, captured when it was created.
    pub fn local_addr(&self) -> net::SocketAddr {
        self.local_addr
    }
}

impl Drop for SocketRx {
    fn drop(&mut self) {
        // `Poller::add` requires the socket to be deleted from the poller before it is
        // dropped. Fields drop in declaration order, which releases `socket` before
        // `poller`; if the `SocketTx` is already gone, that would close a
        // still-registered socket. `drop` runs before any field is dropped, so the
        // socket is still alive here. A failure cannot be acted on during drop and is
        // ignored.
        let _ = self.poller.delete(&*self.socket);
    }
}

/// Binds an unconnected UDP socket and splits it into a sending and a receiving half.
///
/// The socket is switched to non-blocking mode and registered for readability with a
/// poller owned by the returned `SocketRx`.
///
/// `frame_size_max` is the size, in bytes, of the buffer that received datagrams are
/// copied into.
///
/// # Errors
///
/// Returns any I/O error from binding, enabling non-blocking mode, querying the local
/// address, or creating the poller and registering the socket with it.
///
/// # Panics
///
/// Allocating the `frame_size_max`-byte receive buffer may panic or abort if the
/// allocation cannot be satisfied.
pub fn new<A>(bind_address: A, frame_size_max: usize) -> std::io::Result<(SocketTx, SocketRx)>
where
    A: net::ToSocketAddrs,
{
    let socket = net::UdpSocket::bind(bind_address)?;
    socket.set_nonblocking(true)?;
    let local_addr = socket.local_addr()?;
    // Allocate before registering so a failed allocation cannot unwind past a live
    // poller registration.
    let recv_buffer = vec![0; frame_size_max].into_boxed_slice();
    let poller = polling::Poller::new()?;
    // SAFETY: `Poller::add` requires the socket to be deleted from `poller` before it
    // is dropped.
    // - On success, both move into `SocketRx`, whose `Drop` impl deletes the
    //   registration while the receiver still holds the socket's `Arc`.
    // - If `add` fails, `poller` (declared after `socket`) is dropped first.
    unsafe {
        poller.add(&socket, polling::Event::readable(SOCKET_POLLING_KEY))?;
    }
    let socket = Arc::new(socket);
    let tx = SocketTx {
        socket: Arc::clone(&socket),
    };
    let rx = SocketRx {
        socket,
        local_addr,
        poller,
        poller_events: polling::Events::new(),
        recv_buffer,
    };
    Ok((tx, rx))
}
impl ConnectedSocketTx {
pub fn send(&self, frame: &[u8]) {
let _ = self.socket.send(frame);
}
}
impl ConnectedSocketRx {
    /// Attempts to receive one datagram from the connected peer without blocking.
    ///
    /// Returns `Ok(Some(frame))` when a datagram was read, and `Ok(None)` when none is
    /// currently queued.
    ///
    /// `frame` borrows this receiver's internal buffer, so it is only valid until the
    /// next read.
    ///
    /// The buffer holds the `frame_size_max` bytes passed to `new_connected`.
    /// NOTE(review): longer datagrams are expected to be truncated by the OS on
    /// Unix-like targets and reported as an error on Windows. Confirm this for the
    /// platforms you support.
    ///
    /// # Errors
    ///
    /// Any error from `UdpSocket::recv` other than `WouldBlock` is returned unchanged.
    pub fn try_read_frame(&mut self) -> std::io::Result<Option<&[u8]>> {
        match self.socket.recv(&mut self.recv_buffer) {
            Ok(frame_len) => Ok(Some(&self.recv_buffer[..frame_len])),
            Err(err) if err.kind() == std::io::ErrorKind::WouldBlock => Ok(None),
            Err(err) => Err(err),
        }
    }

    /// Waits up to `timeout` for the socket to become readable, then reads one datagram.
    ///
    /// A `timeout` of `None` waits indefinitely, as with `polling::Poller::wait`.
    ///
    /// Returns `Ok(None)` when the wait ends without a readiness event. It also
    /// returns `Ok(None)` when the subsequent read reports `WouldBlock` (a spurious
    /// wake-up).
    ///
    /// # Errors
    ///
    /// Propagates errors from re-arming the poller, from waiting on it, and from
    /// `try_read_frame`.
    pub fn wait_for_frame(
        &mut self,
        timeout: Option<time::Duration>,
    ) -> std::io::Result<Option<&[u8]>> {
        // `polling` delivers readiness in oneshot mode, so interest must be re-armed
        // before every wait.
        self.poller
            .modify(&*self.socket, polling::Event::readable(SOCKET_POLLING_KEY))?;
        self.poller_events.clear();
        let ready = self.poller.wait(&mut self.poller_events, timeout)?;
        if ready == 0 {
            return Ok(None);
        }
        self.try_read_frame()
    }

    /// Local address the socket is bound to, captured when it was created.
    pub fn local_addr(&self) -> net::SocketAddr {
        self.local_addr
    }

    /// Address of the connected peer, captured right after connecting.
    pub fn peer_addr(&self) -> net::SocketAddr {
        self.peer_addr
    }
}

impl Drop for ConnectedSocketRx {
    fn drop(&mut self) {
        // `Poller::add` requires the socket to be deleted from the poller before it is
        // dropped. Fields drop in declaration order, which releases `socket` before
        // `poller`; if the `ConnectedSocketTx` is already gone, that would close a
        // still-registered socket. `drop` runs before any field is dropped, so the
        // socket is still alive here. A failure cannot be acted on during drop and is
        // ignored.
        let _ = self.poller.delete(&*self.socket);
    }
}

/// Binds a UDP socket, connects it to `connect_address`, and splits it into a sending
/// and a receiving half.
///
/// The socket is switched to non-blocking mode and registered for readability with a
/// poller owned by the returned `ConnectedSocketRx`.
///
/// `frame_size_max` is the size, in bytes, of the buffer that received datagrams are
/// copied into.
///
/// # Errors
///
/// Returns any I/O error from binding, enabling non-blocking mode, connecting,
/// querying the local or peer address, or creating the poller and registering the
/// socket with it.
///
/// # Panics
///
/// Allocating the `frame_size_max`-byte receive buffer may panic or abort if the
/// allocation cannot be satisfied.
pub fn new_connected<A, B>(
    bind_address: A,
    connect_address: B,
    frame_size_max: usize,
) -> std::io::Result<(ConnectedSocketTx, ConnectedSocketRx)>
where
    A: net::ToSocketAddrs,
    B: net::ToSocketAddrs,
{
    let socket = net::UdpSocket::bind(bind_address)?;
    socket.set_nonblocking(true)?;
    socket.connect(connect_address)?;
    let local_addr = socket.local_addr()?;
    let peer_addr = socket.peer_addr()?;
    // Allocate before registering so a failed allocation cannot unwind past a live
    // poller registration.
    let recv_buffer = vec![0; frame_size_max].into_boxed_slice();
    let poller = polling::Poller::new()?;
    // SAFETY: `Poller::add` requires the socket to be deleted from `poller` before it
    // is dropped.
    // - On success, both move into `ConnectedSocketRx`, whose `Drop` impl deletes the
    //   registration while the receiver still holds the socket's `Arc`.
    // - If `add` fails, `poller` (declared after `socket`) is dropped first.
    unsafe {
        poller.add(&socket, polling::Event::readable(SOCKET_POLLING_KEY))?;
    }
    let socket = Arc::new(socket);
    let tx = ConnectedSocketTx {
        socket: Arc::clone(&socket),
    };
    let rx = ConnectedSocketRx {
        socket,
        local_addr,
        peer_addr,
        poller,
        poller_events: polling::Events::new(),
        recv_buffer,
    };
    Ok((tx, rx))
}