use datum::{Flow, Keep, NotUsed, Sink, Source, StreamCompletion, StreamError, StreamResult};
use std::net::SocketAddr;
use std::sync::Arc;
use tokio::net::{ToSocketAddrs, UdpSocket};
use tokio::runtime::Handle;
use tokio::sync::{mpsc, watch};
use tokio::task::JoinHandle;
/// Default size, in bytes, of the buffer each received datagram is read into
/// (the `max_datagram_size` used by the `*_default` constructors).
pub const DEFAULT_MAX_DATAGRAM_SIZE: usize = 65_536;
/// Default capacity of the queue between a socket's receive task and the stream
/// (the `receive_buffer` used by the `*_default` constructors). Datagrams that
/// arrive while this queue is full are discarded.
pub const DEFAULT_RECEIVE_BUFFER: usize = 64;
/// A single UDP datagram: payload bytes plus the peer address.
///
/// For received datagrams `remote` is the sender; for datagrams written to a
/// send sink or flow it is the destination passed to `send_to`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Datagram {
    /// Raw payload bytes.
    pub payload: Vec<u8>,
    /// Peer address (source when received, destination when sent).
    pub remote: SocketAddr,
}
impl Datagram {
    /// Creates a datagram from anything convertible into a byte vector and a peer address.
    #[must_use]
    pub fn new(payload: impl Into<Vec<u8>>, remote: SocketAddr) -> Self {
        let bytes: Vec<u8> = payload.into();
        Self {
            payload: bytes,
            remote,
        }
    }

    /// Borrows the payload bytes.
    #[must_use]
    pub fn payload(&self) -> &[u8] {
        self.payload.as_slice()
    }

    /// Returns the peer address (sender for received datagrams, destination for sent ones).
    #[must_use]
    pub fn remote(&self) -> SocketAddr {
        let Self { remote, .. } = self;
        *remote
    }

    /// Splits the datagram into its payload and peer address.
    #[must_use]
    pub fn into_parts(self) -> (Vec<u8>, SocketAddr) {
        let Self { payload, remote } = self;
        (payload, remote)
    }

    /// Consumes the datagram and keeps only its payload.
    #[must_use]
    pub fn into_payload(self) -> Vec<u8> {
        let (payload, _remote) = self.into_parts();
        payload
    }
}
/// Materialized value of a bound (unconnected) UDP source or flow.
///
/// Reports the local address returned by `UdpSocket::local_addr` after binding.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct UdpBinding {
    /// Address the socket is bound to (useful when binding to port 0).
    pub local_addr: SocketAddr,
}
impl UdpBinding {
    /// Address the socket ended up bound to.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        let Self { local_addr } = *self;
        local_addr
    }
}
/// Materialized value of a connected UDP flow.
///
/// Both addresses are read from the socket after `connect` succeeds.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct UdpConnection {
    /// Address the socket is bound to, from `UdpSocket::local_addr`.
    pub local_addr: SocketAddr,
    /// Peer the socket is connected to, from `UdpSocket::peer_addr`.
    pub remote_addr: SocketAddr,
}
impl UdpConnection {
    /// Local address of the connected socket.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        let Self { local_addr, .. } = *self;
        local_addr
    }

    /// Peer address the socket is connected to.
    #[must_use]
    pub fn remote_addr(&self) -> SocketAddr {
        let Self { remote_addr, .. } = *self;
        remote_addr
    }
}
/// Namespace for constructing UDP sources, sinks and flows backed by Tokio sockets.
pub struct TokioUdp;
/// Short alias for [`TokioUdp`].
pub type Udp = TokioUdp;
/// Message sent from a background receive task to the stream side of its queue.
enum ReceiveResponse<T> {
    /// A received payload or datagram.
    Item(T),
    /// A socket failure; the receive task stops after sending this.
    Error(StreamError),
}
/// Result of a non-blocking attempt to hand a received item to the stream.
enum QueueOutcome {
    /// The item was enqueued.
    Queued,
    /// The queue was full, so the item was discarded.
    Dropped,
    /// The stream side dropped its receiver; the task should stop.
    Closed,
}
/// Per-materialization state of a receive source created with `Source::unfold_resource`.
struct ReceiveResource<T> {
    /// Stream side of the queue fed by the background receive task.
    receiver: mpsc::Receiver<ReceiveResponse<T>>,
    /// Set to `true` (or dropped) to ask the receive task to stop.
    cancel: watch::Sender<bool>,
    /// Handle of the spawned receive task, aborted on close/drop.
    task: JoinHandle<()>,
}
impl<T> Drop for ReceiveResource<T> {
    /// Stops the background receive task whenever the resource goes away.
    fn drop(&mut self) {
        // Sending only fails once every watch receiver is gone, i.e. the task has already
        // exited, so the result carries no useful information here.
        let _: Result<(), _> = self.cancel.send(true);
        // Abort as well, so a task that is not currently waiting on the cancel signal is
        // also torn down.
        self.task.abort();
    }
}
/// State used by the send stages: the shared socket and the runtime handle whose
/// `block_on` drives each send to completion.
struct SendResource {
    /// Socket shared with any receive half of the same flow.
    socket: Arc<UdpSocket>,
    /// Runtime handle captured when the socket was bound.
    handle: Handle,
}
/// Converts a socket I/O error into a stream failure carrying the error's display text.
fn io_error(error: std::io::Error) -> StreamError {
    let message = error.to_string();
    StreamError::Failed(message)
}
/// Builds `StreamError::AbruptTermination`.
///
/// Used by `receive_next_item` when a receive queue closes without delivering
/// another item or an error.
fn abrupt_termination() -> StreamError {
    StreamError::AbruptTermination
}
impl TokioUdp {
    /// Enforces the preconditions shared by every constructor that spawns a receive task.
    ///
    /// Keeping the checks in one place guarantees that `bind`, `bind_flow` and `connect` panic
    /// with exactly the same messages.
    ///
    /// # Panics
    ///
    /// Panics if `max_datagram_size` or `receive_buffer` is zero.
    fn assert_receive_settings(max_datagram_size: usize, receive_buffer: usize) {
        assert!(
            max_datagram_size > 0,
            "maximum datagram size must be greater than zero"
        );
        assert!(
            receive_buffer > 0,
            "receive buffer must be greater than zero"
        );
    }

    /// Creates a source that binds a UDP socket to `addr` and emits every datagram received on it.
    ///
    /// This call does not create a socket. Binding happens inside the factory passed to
    /// `Source::lazy_future_source`. That factory clones `addr` each time it runs and captures
    /// the current Tokio runtime handle for the receive task. Bind and `local_addr` failures are
    /// converted with `io_error` and returned from the factory's future.
    ///
    /// * `max_datagram_size` — size in bytes of the buffer each datagram is read into.
    /// * `receive_buffer` — capacity of the queue between the receive task and the stream.
    ///   Datagrams that arrive while it is full are discarded rather than applying backpressure.
    ///
    /// # Panics
    ///
    /// Panics if `max_datagram_size` or `receive_buffer` is zero.
    #[must_use]
    pub fn bind<A>(
        addr: A,
        max_datagram_size: usize,
        receive_buffer: usize,
    ) -> Source<Datagram, StreamCompletion<UdpBinding>>
    where
        A: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        Self::assert_receive_settings(max_datagram_size, receive_buffer);
        Source::lazy_future_source(move || {
            let addr = addr.clone();
            async move {
                let handle = Handle::current();
                let socket = UdpSocket::bind(addr).await.map_err(io_error)?;
                let local_addr = socket.local_addr().map_err(io_error)?;
                Ok(datagram_source_from_socket(
                    Arc::new(socket),
                    local_addr,
                    handle,
                    max_datagram_size,
                    receive_buffer,
                ))
            }
        })
    }

    /// [`TokioUdp::bind`] with [`DEFAULT_MAX_DATAGRAM_SIZE`] and [`DEFAULT_RECEIVE_BUFFER`].
    #[must_use]
    pub fn bind_default<A>(addr: A) -> Source<Datagram, StreamCompletion<UdpBinding>>
    where
        A: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        Self::bind(addr, DEFAULT_MAX_DATAGRAM_SIZE, DEFAULT_RECEIVE_BUFFER)
    }

    /// Creates a sink that sends each [`Datagram`] to its `remote` address.
    ///
    /// Sends go out from a socket bound to `local_addr`. The socket is bound inside the factory
    /// passed to `Flow::future_flow`. Each send blocks the stage on the captured runtime handle
    /// until it completes. A send that writes fewer bytes than the payload fails the stream.
    ///
    /// The materialized value comes from `Sink::ignore()` (`Keep::right`). Presumably it
    /// completes when the upstream finishes; confirm with datum's `Sink::ignore` docs.
    #[must_use]
    pub fn send_sink<A>(local_addr: A) -> Sink<Datagram, StreamCompletion<NotUsed>>
    where
        A: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        Flow::<Datagram, NotUsed>::future_flow(move || {
            let local_addr = local_addr.clone();
            async move {
                let handle = Handle::current();
                let socket = UdpSocket::bind(local_addr).await.map_err(io_error)?;
                Ok(datagram_send_flow_from_socket(Arc::new(socket), handle))
            }
        })
        .to_mat(Sink::ignore(), Keep::right)
    }

    /// Creates a flow over one bound UDP socket.
    ///
    /// Incoming [`Datagram`]s are sent to their `remote` addresses. Datagrams received on the
    /// same socket are emitted downstream. The send and receive halves share the socket through
    /// an `Arc` and are joined with `Flow::from_sink_and_source`. The materialized value reports
    /// the bound local address.
    ///
    /// `max_datagram_size` and `receive_buffer` have the same meaning as in [`TokioUdp::bind`].
    ///
    /// # Panics
    ///
    /// Panics if `max_datagram_size` or `receive_buffer` is zero.
    #[must_use]
    pub fn bind_flow<A>(
        addr: A,
        max_datagram_size: usize,
        receive_buffer: usize,
    ) -> Flow<Datagram, Datagram, StreamCompletion<UdpBinding>>
    where
        A: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        Self::assert_receive_settings(max_datagram_size, receive_buffer);
        Flow::<Datagram, Datagram>::future_flow(move || {
            let addr = addr.clone();
            async move {
                let handle = Handle::current();
                let socket = Arc::new(UdpSocket::bind(addr).await.map_err(io_error)?);
                let local_addr = socket.local_addr().map_err(io_error)?;
                let sink = datagram_send_flow_from_socket(Arc::clone(&socket), handle.clone())
                    .to_mat(Sink::ignore(), Keep::right);
                let source = datagram_source_from_socket(
                    Arc::clone(&socket),
                    local_addr,
                    handle,
                    max_datagram_size,
                    receive_buffer,
                );
                Ok(Flow::from_sink_and_source(sink, source)
                    .map_materialized_value(move |_| UdpBinding { local_addr }))
            }
        })
    }

    /// [`TokioUdp::bind_flow`] with [`DEFAULT_MAX_DATAGRAM_SIZE`] and [`DEFAULT_RECEIVE_BUFFER`].
    #[must_use]
    pub fn bind_flow_default<A>(addr: A) -> Flow<Datagram, Datagram, StreamCompletion<UdpBinding>>
    where
        A: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        Self::bind_flow(addr, DEFAULT_MAX_DATAGRAM_SIZE, DEFAULT_RECEIVE_BUFFER)
    }

    /// Creates a flow over a UDP socket bound to `local_addr` and connected to `peer`.
    ///
    /// Incoming payloads are sent to the connected peer. Payloads received from it are emitted
    /// downstream without address information. Bind, connect and address-lookup failures are
    /// converted with `io_error` inside the factory passed to `Flow::future_flow`. The
    /// materialized [`UdpConnection`] reports the socket's local and peer addresses.
    ///
    /// `max_datagram_size` and `receive_buffer` have the same meaning as in [`TokioUdp::bind`].
    ///
    /// # Panics
    ///
    /// Panics if `max_datagram_size` or `receive_buffer` is zero.
    #[must_use]
    pub fn connect<A, P>(
        local_addr: A,
        peer: P,
        max_datagram_size: usize,
        receive_buffer: usize,
    ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<UdpConnection>>
    where
        A: ToSocketAddrs + Clone + Send + Sync + 'static,
        P: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        Self::assert_receive_settings(max_datagram_size, receive_buffer);
        Flow::<Vec<u8>, Vec<u8>>::future_flow(move || {
            let local_addr = local_addr.clone();
            let peer = peer.clone();
            async move {
                let handle = Handle::current();
                let socket = UdpSocket::bind(local_addr).await.map_err(io_error)?;
                socket.connect(peer).await.map_err(io_error)?;
                let connection = UdpConnection {
                    local_addr: socket.local_addr().map_err(io_error)?,
                    remote_addr: socket.peer_addr().map_err(io_error)?,
                };
                let socket = Arc::new(socket);
                let sink = connected_send_flow_from_socket(Arc::clone(&socket), handle.clone())
                    .to_mat(Sink::ignore(), Keep::right);
                let source = connected_source_from_socket(
                    Arc::clone(&socket),
                    handle,
                    max_datagram_size,
                    receive_buffer,
                );
                Ok(Flow::from_sink_and_source(sink, source)
                    .map_materialized_value(move |_| connection))
            }
        })
    }

    /// [`TokioUdp::connect`] with [`DEFAULT_MAX_DATAGRAM_SIZE`] and [`DEFAULT_RECEIVE_BUFFER`].
    #[must_use]
    pub fn connect_default<A, P>(
        local_addr: A,
        peer: P,
    ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<UdpConnection>>
    where
        A: ToSocketAddrs + Clone + Send + Sync + 'static,
        P: ToSocketAddrs + Clone + Send + Sync + 'static,
    {
        Self::connect(
            local_addr,
            peer,
            DEFAULT_MAX_DATAGRAM_SIZE,
            DEFAULT_RECEIVE_BUFFER,
        )
    }
}
/// Builds the receive half used by `bind` and `bind_flow` for an unconnected socket.
///
/// Each time `Source::unfold_resource` opens its resource, it does the following:
/// - creates a fresh bounded queue of `receive_buffer` entries and a cancel signal;
/// - spawns `run_datagram_receive_task` on `handle`.
///
/// Items are pulled with `receive_next_item`, and the task is stopped by
/// `close_receive_resource`. The materialized value is replaced by a [`UdpBinding`] carrying
/// `local_addr`.
fn datagram_source_from_socket(
    socket: Arc<UdpSocket>,
    local_addr: SocketAddr,
    handle: Handle,
    max_datagram_size: usize,
    receive_buffer: usize,
) -> Source<Datagram, UdpBinding> {
    Source::unfold_resource(
        move || {
            let (stop, stop_rx) = watch::channel(false);
            let (queue_tx, receiver) = mpsc::channel(receive_buffer);
            let worker = run_datagram_receive_task(
                Arc::clone(&socket),
                max_datagram_size,
                queue_tx,
                stop_rx,
            );
            Ok(ReceiveResource {
                receiver,
                cancel: stop,
                task: handle.spawn(worker),
            })
        },
        receive_next_item,
        close_receive_resource,
    )
    .map_materialized_value(move |_| UdpBinding { local_addr })
}
/// Builds the receive half used by `connect` for a connected socket.
///
/// This mirrors `datagram_source_from_socket`, except the spawned task is
/// `run_connected_receive_task` and elements are bare payloads. The materialized value is left
/// as `NotUsed`.
fn connected_source_from_socket(
    socket: Arc<UdpSocket>,
    handle: Handle,
    max_datagram_size: usize,
    receive_buffer: usize,
) -> Source<Vec<u8>, NotUsed> {
    Source::unfold_resource(
        move || {
            let (stop, stop_rx) = watch::channel(false);
            let (queue_tx, receiver) = mpsc::channel(receive_buffer);
            let worker = run_connected_receive_task(
                Arc::clone(&socket),
                max_datagram_size,
                queue_tx,
                stop_rx,
            );
            Ok(ReceiveResource {
                receiver,
                cancel: stop,
                task: handle.spawn(worker),
            })
        },
        receive_next_item,
        close_receive_resource,
    )
}
fn receive_next_item<T>(resource: &mut ReceiveResource<T>) -> StreamResult<Option<T>>
where
T: Send + 'static,
{
match resource.receiver.blocking_recv() {
Some(ReceiveResponse::Item(item)) => Ok(Some(item)),
Some(ReceiveResponse::Error(error)) => Err(error),
None => Err(abrupt_termination()),
}
}
fn close_receive_resource<T>(resource: ReceiveResource<T>) -> StreamResult<()>
where
T: Send + 'static,
{
let _ = resource.cancel.send(true);
resource.task.abort();
Ok(())
}
async fn run_datagram_receive_task(
socket: Arc<UdpSocket>,
max_datagram_size: usize,
sender: mpsc::Sender<ReceiveResponse<Datagram>>,
mut cancel: watch::Receiver<bool>,
) {
let mut buffer = vec![0_u8; max_datagram_size];
loop {
let received = tokio::select! {
received = socket.recv_from(&mut buffer) => received,
changed = cancel.changed() => {
let _ = changed;
return;
}
};
match received {
Ok((read, remote)) => {
let datagram = Datagram::new(buffer[..read].to_vec(), remote);
match try_send_received_item(&sender, datagram) {
QueueOutcome::Queued => {}
QueueOutcome::Dropped => {
if let Err(error) = drain_ready_datagrams(&socket, &mut buffer) {
let _ = send_receive_error(&sender, error, &mut cancel).await;
return;
}
}
QueueOutcome::Closed => return,
}
}
Err(error) if error.kind() == std::io::ErrorKind::Interrupted => {}
Err(error) => {
let _ = send_receive_error(&sender, io_error(error), &mut cancel).await;
return;
}
}
}
}
async fn run_connected_receive_task(
socket: Arc<UdpSocket>,
max_datagram_size: usize,
sender: mpsc::Sender<ReceiveResponse<Vec<u8>>>,
mut cancel: watch::Receiver<bool>,
) {
let mut buffer = vec![0_u8; max_datagram_size];
loop {
let received = tokio::select! {
received = socket.recv(&mut buffer) => received,
changed = cancel.changed() => {
let _ = changed;
return;
}
};
match received {
Ok(read) => match try_send_received_item(&sender, buffer[..read].to_vec()) {
QueueOutcome::Queued => {}
QueueOutcome::Dropped => {
if let Err(error) = drain_ready_connected_datagrams(&socket, &mut buffer) {
let _ = send_receive_error(&sender, error, &mut cancel).await;
return;
}
}
QueueOutcome::Closed => return,
},
Err(error) if error.kind() == std::io::ErrorKind::Interrupted => {}
Err(error) => {
let _ = send_receive_error(&sender, io_error(error), &mut cancel).await;
return;
}
}
}
}
fn try_send_received_item<T>(sender: &mpsc::Sender<ReceiveResponse<T>>, item: T) -> QueueOutcome
where
T: Send + 'static,
{
match sender.try_send(ReceiveResponse::Item(item)) {
Ok(()) => QueueOutcome::Queued,
Err(mpsc::error::TrySendError::Full(_)) => QueueOutcome::Dropped,
Err(mpsc::error::TrySendError::Closed(_)) => QueueOutcome::Closed,
}
}
fn drain_ready_datagrams(socket: &UdpSocket, buffer: &mut [u8]) -> StreamResult<()> {
loop {
match socket.try_recv_from(buffer) {
Ok((_read, _remote)) => {}
Err(error) if error.kind() == std::io::ErrorKind::WouldBlock => return Ok(()),
Err(error) if error.kind() == std::io::ErrorKind::Interrupted => {}
Err(error) => return Err(io_error(error)),
}
}
}
fn drain_ready_connected_datagrams(socket: &UdpSocket, buffer: &mut [u8]) -> StreamResult<()> {
loop {
match socket.try_recv(buffer) {
Ok(_read) => {}
Err(error) if error.kind() == std::io::ErrorKind::WouldBlock => return Ok(()),
Err(error) if error.kind() == std::io::ErrorKind::Interrupted => {}
Err(error) => return Err(io_error(error)),
}
}
}
async fn send_receive_error<T>(
sender: &mpsc::Sender<ReceiveResponse<T>>,
error: StreamError,
cancel: &mut watch::Receiver<bool>,
) -> bool
where
T: Send + 'static,
{
tokio::select! {
result = sender.send(ReceiveResponse::Error(error)) => result.is_ok(),
changed = cancel.changed() => {
let _ = changed;
false
}
}
}
/// Builds a flow stage that sends every incoming [`Datagram`] with `send_datagram`.
///
/// Each successful send emits `NotUsed`. Each time the stage opens its resource, it takes a
/// new `Arc` clone of `socket` and a clone of `handle`. Closing the resource emits nothing
/// extra.
fn datagram_send_flow_from_socket(
    socket: Arc<UdpSocket>,
    handle: Handle,
) -> Flow<Datagram, NotUsed, NotUsed> {
    Flow::<Datagram, Datagram>::identity().map_with_resource(
        move || {
            let resource = SendResource {
                socket: Arc::clone(&socket),
                handle: handle.clone(),
            };
            Ok(resource)
        },
        |resource, outgoing| {
            send_datagram(resource, outgoing)?;
            Ok(NotUsed)
        },
        |_closing| Ok(None),
    )
}
/// Builds a flow stage that sends every incoming payload to the connected peer with
/// `send_connected_payload`.
///
/// Each successful send emits `NotUsed`. Closing the resource emits nothing extra.
fn connected_send_flow_from_socket(
    socket: Arc<UdpSocket>,
    handle: Handle,
) -> Flow<Vec<u8>, NotUsed, NotUsed> {
    Flow::<Vec<u8>, Vec<u8>>::identity().map_with_resource(
        move || {
            let resource = SendResource {
                socket: Arc::clone(&socket),
                handle: handle.clone(),
            };
            Ok(resource)
        },
        |resource, outgoing| {
            send_connected_payload(resource, outgoing)?;
            Ok(NotUsed)
        },
        |_closing| Ok(None),
    )
}
fn send_datagram(resource: &SendResource, datagram: Datagram) -> StreamResult<()> {
let expected = datagram.payload.len();
let sent = resource.handle.block_on(async {
resource
.socket
.send_to(&datagram.payload, datagram.remote)
.await
.map_err(io_error)
})?;
if sent == expected {
Ok(())
} else {
Err(short_send_error(sent, expected))
}
}
fn send_connected_payload(resource: &SendResource, payload: Vec<u8>) -> StreamResult<()> {
let expected = payload.len();
let sent = resource
.handle
.block_on(async { resource.socket.send(&payload).await.map_err(io_error) })?;
if sent == expected {
Ok(())
} else {
Err(short_send_error(sent, expected))
}
}
/// Builds the failure reported when a send wrote only `sent` of `expected` payload bytes.
fn short_send_error(sent: usize, expected: usize) -> StreamError {
    let message = format!(
        "UDP socket sent {} bytes from {}-byte datagram",
        sent, expected
    );
    StreamError::Failed(message)
}