uniudp 1.0.0

Unidirectional UDP transport with chunking, redundancy, and Reed-Solomon FEC.
Documentation
use std::io;
use std::net::{SocketAddr, UdpSocket};
use std::time::{Duration, Instant};

#[cfg(unix)]
use mio::unix::SourceFd;
#[cfg(windows)]
use mio::windows::SourceSocket;
use mio::{Events, Interest, Poll, Token};
#[cfg(unix)]
use std::os::fd::{AsRawFd, RawFd};
#[cfg(windows)]
use std::os::windows::io::{AsRawSocket, RawSocket};

use crate::error::{Result, UniUdpError};
#[cfg(feature = "tokio")]
use tokio::net::UdpSocket as TokioUdpSocket;
#[cfg(feature = "tokio")]
use tokio::time::timeout as tokio_timeout;

/// Token identifying the UDP socket in every `Poll` this module creates; it is
/// used both when registering the socket and when matching readiness events.
const SOCKET_TOKEN: Token = Token(0);

/// Guard that snapshots a socket's read timeout and puts it back when dropped.
///
/// Code that temporarily changes the read timeout can hold one of these so the
/// original setting is restored on every exit path, including early `?` returns.
pub(super) struct SocketReadTimeoutGuard<'a> {
    /// Socket whose read timeout is restored on drop.
    socket: &'a UdpSocket,
    /// Read timeout observed at capture time (`None` means reads block indefinitely).
    previous: Option<Duration>,
}

impl<'a> SocketReadTimeoutGuard<'a> {
    /// Records the current read timeout of `socket` so it can be restored later.
    ///
    /// # Errors
    /// Returns an error if the socket's current read timeout cannot be queried.
    pub(super) fn capture(socket: &'a UdpSocket) -> Result<Self> {
        let previous = socket.read_timeout()?;
        Ok(Self { socket, previous })
    }
}

impl Drop for SocketReadTimeoutGuard<'_> {
    fn drop(&mut self) {
        let (socket, previous) = (self.socket, self.previous);
        // `drop` has no way to report failures, so restoring is best-effort.
        let _ = socket.set_read_timeout(previous);
    }
}

/// Waits for a `UdpSocket` to become readable using a dedicated mio `Poll`.
///
/// Only the raw descriptor/handle is stored, which is what the impls use to
/// re-register the socket before each wait and to deregister it on drop.
/// NOTE(review): nothing ties this waiter's lifetime to the socket, so the socket
/// presumably must stay open while the waiter is in use — confirm at call sites.
#[cfg(any(unix, windows))]
pub(super) struct SocketReadinessWaiter {
    /// Poll instance the socket is registered with under `SOCKET_TOKEN`.
    poll: Poll,
    /// Event buffer reused across `Poll::poll` calls.
    events: Events,
    /// Raw file descriptor of the registered socket (Unix only).
    #[cfg(unix)]
    raw_fd: RawFd,
    /// Raw socket handle of the registered socket (Windows only).
    #[cfg(windows)]
    raw_socket: RawSocket,
}

/// Placeholder used on targets without the mio path; the fallback
/// `recv_from_timeout` ignores it and bounds reads with `set_read_timeout`.
#[cfg(not(any(unix, windows)))]
pub(super) struct SocketReadinessWaiter;

#[cfg(unix)]
impl SocketReadinessWaiter {
    /// Creates a private mio `Poll` and registers `socket`'s file descriptor for
    /// read readiness under `SOCKET_TOKEN`.
    ///
    /// # Errors
    /// Fails if the poll instance cannot be created or the registration is rejected.
    pub(super) fn new(socket: &UdpSocket) -> Result<Self> {
        let raw_fd = socket.as_raw_fd();
        let poll = Poll::new()?;
        poll.registry()
            .register(&mut SourceFd(&raw_fd), SOCKET_TOKEN, Interest::READABLE)?;
        let events = Events::with_capacity(4);
        Ok(Self {
            poll,
            events,
            raw_fd,
        })
    }

    /// Waits up to `timeout` for the registered socket to become readable.
    ///
    /// Returns `Ok(true)` once readiness is observed and `Ok(false)` otherwise.
    fn wait_until_readable(&mut self, timeout: Duration) -> io::Result<bool> {
        let Self {
            poll,
            events,
            raw_fd,
        } = self;
        // Refresh the registration before every wait. mio readiness is
        // edge-triggered, so this presumably re-arms the source so that data which
        // was already queued (and produced no new edge) is still reported.
        poll.registry()
            .reregister(&mut SourceFd(&*raw_fd), SOCKET_TOKEN, Interest::READABLE)?;
        poll_until_readable(poll, events, timeout)
    }
}

#[cfg(unix)]
impl Drop for SocketReadinessWaiter {
    /// Removes the socket's file descriptor from the poll registry.
    fn drop(&mut self) {
        let fd = self.raw_fd;
        // `drop` cannot surface errors, so deregistration is best-effort.
        let _ = self.poll.registry().deregister(&mut SourceFd(&fd));
    }
}

#[cfg(windows)]
impl SocketReadinessWaiter {
    /// Creates a private mio `Poll` and registers `socket`'s raw handle for read
    /// readiness under `SOCKET_TOKEN`.
    ///
    /// # Errors
    /// Fails if the poll instance cannot be created or the registration is rejected.
    pub(super) fn new(socket: &UdpSocket) -> Result<Self> {
        let raw_socket = socket.as_raw_socket();
        let poll = Poll::new()?;
        poll.registry()
            .register(&mut SourceSocket(&raw_socket), SOCKET_TOKEN, Interest::READABLE)?;
        let events = Events::with_capacity(4);
        Ok(Self {
            poll,
            events,
            raw_socket,
        })
    }

    /// Waits up to `timeout` for the registered socket to become readable.
    ///
    /// Returns `Ok(true)` once readiness is observed and `Ok(false)` otherwise.
    fn wait_until_readable(&mut self, timeout: Duration) -> io::Result<bool> {
        let Self {
            poll,
            events,
            raw_socket,
        } = self;
        // Refresh the registration before every wait; presumably this re-arms the
        // edge-triggered source so already-queued data is reported again.
        poll.registry()
            .reregister(&mut SourceSocket(&*raw_socket), SOCKET_TOKEN, Interest::READABLE)?;
        poll_until_readable(poll, events, timeout)
    }
}

#[cfg(windows)]
impl Drop for SocketReadinessWaiter {
    /// Removes the socket's raw handle from the poll registry.
    fn drop(&mut self) {
        let handle = self.raw_socket;
        // `drop` cannot surface errors, so deregistration is best-effort.
        let _ = self.poll.registry().deregister(&mut SourceSocket(&handle));
    }
}

#[cfg(not(any(unix, windows)))]
impl SocketReadinessWaiter {
    /// Builds the placeholder waiter; there is nothing to register on these
    /// targets because the fallback receive path relies on read timeouts.
    pub(super) fn new(_socket: &UdpSocket) -> Result<Self> {
        let waiter = SocketReadinessWaiter;
        Ok(waiter)
    }
}

/// Receives one datagram from `socket`, waiting at most `timeout` for it.
///
/// Readiness is awaited through `readiness`; each readable report is followed by a
/// single `recv_from`. `WouldBlock`, `TimedOut` and `Interrupted` are treated as
/// transient and send the loop back to waiting with whatever time is left.
///
/// Returns `Ok(Some((source, len)))` with the sender address and the number of bytes
/// written into `buffer`, or `Ok(None)` if `timeout` is zero, the readiness wait
/// reports no readiness, or the time budget runs out.
///
/// # Errors
/// Propagates failures from the readiness wait, and any `recv_from` error whose
/// kind is not one of the transient kinds above.
///
/// NOTE(review): if `socket` is in blocking mode, a spurious readiness report could
/// let `recv_from` block past `timeout`; confirm callers configure it non-blocking.
#[cfg(any(unix, windows))]
pub(super) fn recv_from_timeout(
    socket: &UdpSocket,
    timeout: Duration,
    buffer: &mut [u8],
    readiness: &mut SocketReadinessWaiter,
) -> Result<Option<(SocketAddr, usize)>> {
    if timeout.is_zero() {
        return Ok(None);
    }

    let started = Instant::now();
    // Yields `None` as soon as the elapsed time reaches `timeout`.
    while let Some(remaining) = timeout
        .checked_sub(started.elapsed())
        .filter(|left| !left.is_zero())
    {
        let readable = readiness.wait_until_readable(remaining)?;
        if !readable {
            return Ok(None);
        }

        let err = match socket.recv_from(buffer) {
            Ok((len, source)) => return Ok(Some((source, len))),
            Err(err) => err,
        };
        let transient = matches!(
            err.kind(),
            io::ErrorKind::WouldBlock | io::ErrorKind::TimedOut | io::ErrorKind::Interrupted
        );
        if !transient {
            return Err(UniUdpError::Io(err));
        }
    }
    Ok(None)
}

/// Receives one datagram from `socket`, waiting at most `timeout` for it.
///
/// Fallback for targets without the mio readiness path: each attempt bounds the
/// blocking `recv_from` by setting the socket's read timeout to the time left.
/// `WouldBlock`, `TimedOut` and `Interrupted` are retried until the budget runs out.
///
/// Returns `Ok(Some((source, len)))` with the sender address and the number of bytes
/// written into `buffer`, or `Ok(None)` if `timeout` is zero or expires.
///
/// # Errors
/// Propagates failures from `set_read_timeout`, and any `recv_from` error whose
/// kind is not one of the transient kinds above.
///
/// NOTE(review): this leaves the socket's read timeout modified; presumably callers
/// hold a `SocketReadTimeoutGuard` around it — verify.
#[cfg(not(any(unix, windows)))]
pub(super) fn recv_from_timeout(
    socket: &UdpSocket,
    timeout: Duration,
    buffer: &mut [u8],
    _readiness: &mut SocketReadinessWaiter,
) -> Result<Option<(SocketAddr, usize)>> {
    if timeout.is_zero() {
        return Ok(None);
    }

    let started = Instant::now();
    // Yields `None` as soon as the elapsed time reaches `timeout`.
    while let Some(remaining) = timeout
        .checked_sub(started.elapsed())
        .filter(|left| !left.is_zero())
    {
        socket.set_read_timeout(Some(remaining))?;

        match socket.recv_from(buffer) {
            Ok((len, source)) => return Ok(Some((source, len))),
            Err(err) => match err.kind() {
                io::ErrorKind::WouldBlock
                | io::ErrorKind::TimedOut
                | io::ErrorKind::Interrupted => {}
                _ => return Err(UniUdpError::Io(err)),
            },
        }
    }
    Ok(None)
}

/// Async counterpart of `recv_from_timeout` for a tokio `UdpSocket`.
///
/// Each attempt is wrapped in a tokio timeout sized to the time left; a
/// `WouldBlock` result retries with the remaining budget.
///
/// Returns `Ok(Some((source, len)))` with the sender address and the number of bytes
/// written into `buffer`, or `Ok(None)` if `timeout` is zero or expires.
///
/// # Errors
/// Any `recv_from` error other than `WouldBlock` is returned as `UniUdpError::Io`.
#[cfg(feature = "tokio")]
pub(super) async fn recv_from_timeout_async(
    socket: &TokioUdpSocket,
    timeout: Duration,
    buffer: &mut [u8],
) -> Result<Option<(SocketAddr, usize)>> {
    if timeout.is_zero() {
        return Ok(None);
    }

    let started = Instant::now();
    // Yields `None` as soon as the elapsed time reaches `timeout`.
    while let Some(remaining) = timeout
        .checked_sub(started.elapsed())
        .filter(|left| !left.is_zero())
    {
        let attempt = match tokio_timeout(remaining, socket.recv_from(buffer)).await {
            // The deadline passed without a datagram arriving.
            Err(_elapsed) => return Ok(None),
            Ok(attempt) => attempt,
        };
        match attempt {
            Ok((len, source)) => return Ok(Some((source, len))),
            Err(err) if err.kind() == io::ErrorKind::WouldBlock => {}
            Err(err) => return Err(UniUdpError::Io(err)),
        }
    }
    Ok(None)
}

/// Blocks on `poll` until the source registered under `SOCKET_TOKEN` reports
/// readability, or until `timeout` has fully elapsed.
///
/// Returns `Ok(true)` when a readable event for `SOCKET_TOKEN` is received, and
/// `Ok(false)` when the whole `timeout` has elapsed (immediately if it is zero).
///
/// # Errors
/// Propagates any `Poll::poll` error other than `Interrupted`; interrupted polls
/// are retried with the remaining time.
fn poll_until_readable(
    poll: &mut Poll,
    events: &mut Events,
    timeout: Duration,
) -> io::Result<bool> {
    if timeout.is_zero() {
        return Ok(false);
    }

    let start = Instant::now();
    loop {
        let elapsed = start.elapsed();
        if elapsed >= timeout {
            return Ok(false);
        }
        let remaining = timeout - elapsed;

        events.clear();
        match poll.poll(events, Some(remaining)) {
            Ok(()) => {
                if events
                    .iter()
                    .any(|event| event.token() == SOCKET_TOKEN && event.is_readable())
                {
                    return Ok(true);
                }
                // An empty (or unrelated) batch does not prove the timeout expired.
                // mio can return early without user events, e.g. after handling only
                // internal completions on Windows or on a spurious wake-up. Loop so
                // the elapsed-time check above decides, instead of reporting a
                // premature timeout to the caller.
            }
            Err(err) if err.kind() == io::ErrorKind::Interrupted => {}
            Err(err) => return Err(err),
        }
    }
}