// llam 0.1.3
//
// Safe, Go-style Rust bindings for the LLAM runtime.
use crate::io;
use crate::sys;
use socket2::{Domain, Protocol, SockAddr, Socket, Type};
use std::io::{Read, Result, Write};
use std::net::{
    SocketAddr, TcpListener as StdTcpListener, TcpStream as StdTcpStream, ToSocketAddrs,
    UdpSocket as StdUdpSocket,
};

#[cfg(unix)]
use std::os::fd::{AsRawFd, FromRawFd};
#[cfg(unix)]
use std::os::unix::net::{UnixListener as StdUnixListener, UnixStream as StdUnixStream};
#[cfg(windows)]
use std::os::windows::io::{AsRawSocket, FromRawSocket, RawSocket};
#[cfg(unix)]
use std::path::Path;

/// A TCP listening socket whose `accept` goes through `sys::llam_accept`.
///
/// [`TcpListener::bind`] switches the wrapped `std::net::TcpListener` to
/// non-blocking mode before it is stored here.
#[derive(Debug)]
pub struct TcpListener {
    inner: StdTcpListener,
}

/// A connected TCP stream whose `Read`/`Write` impls go through
/// `crate::io::read` / `crate::io::write` on the raw descriptor.
///
/// Every constructor in this module (`connect`, `TcpListener::accept`) puts
/// the wrapped stream into non-blocking mode.
#[derive(Debug)]
pub struct TcpStream {
    inner: StdTcpStream,
}

/// A UDP socket whose `send`/`recv` go through `crate::io::write` /
/// `crate::io::read` on the raw descriptor.
///
/// [`UdpSocket::bind`] switches the wrapped socket to non-blocking mode.
#[derive(Debug)]
pub struct UdpSocket {
    inner: StdUdpSocket,
}

/// A Unix-domain listening socket whose `accept` goes through
/// `sys::llam_accept`.
#[cfg(unix)]
#[derive(Debug)]
pub struct UnixListener {
    inner: StdUnixListener,
}

/// A Unix-domain stream whose `Read`/`Write` impls go through
/// `crate::io::read` / `crate::io::write`.
///
/// Both constructors in this module (`connect`, `UnixListener::accept`) put
/// the wrapped stream into non-blocking mode.
#[cfg(unix)]
#[derive(Debug)]
pub struct UnixStream {
    inner: StdUnixStream,
}

impl TcpListener {
    pub fn bind<A: ToSocketAddrs>(addr: A) -> Result<Self> {
        let inner = StdTcpListener::bind(addr)?;
        inner.set_nonblocking(true)?;
        Ok(Self { inner })
    }

    pub fn accept(&self) -> Result<(TcpStream, Option<SocketAddr>)> {
        let fd = unsafe {
            sys::llam_accept(
                raw_listener_fd(&self.inner),
                std::ptr::null_mut(),
                std::ptr::null_mut(),
            )
        };
        if sys::fd_is_invalid(fd) {
            return Err(crate::Error::last().into());
        }
        let stream = unsafe { stream_from_fd(fd) };
        stream.set_nonblocking(true)?;
        let peer = stream.peer_addr().ok();
        Ok((TcpStream { inner: stream }, peer))
    }

    pub fn local_addr(&self) -> Result<SocketAddr> {
        self.inner.local_addr()
    }

    pub fn into_std(self) -> StdTcpListener {
        self.inner
    }
}

impl TcpStream {
    pub fn connect<A: ToSocketAddrs>(addr: A) -> Result<Self> {
        if unsafe { sys::llam_current_task() }.is_null() {
            let stream = StdTcpStream::connect(addr)?;
            stream.set_nonblocking(true)?;
            return Ok(Self { inner: stream });
        }

        let mut last_error = None;
        for addr in addr.to_socket_addrs()? {
            match connect_one(addr) {
                Ok(stream) => return Ok(stream),
                Err(error) => last_error = Some(error),
            }
        }
        Err(last_error.unwrap_or_else(|| {
            std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "no socket address resolved",
            )
        }))
    }

    pub fn into_std(self) -> StdTcpStream {
        self.inner
    }

    pub fn peer_addr(&self) -> Result<SocketAddr> {
        self.inner.peer_addr()
    }

    pub fn local_addr(&self) -> Result<SocketAddr> {
        self.inner.local_addr()
    }

    pub fn try_clone(&self) -> Result<Self> {
        Ok(Self {
            inner: self.inner.try_clone()?,
        })
    }
}

/// Connects to a single resolved `addr` through `sys::llam_connect`.
///
/// Used by `TcpStream::connect` when an LLAM task is current. A socket of the
/// matching address family is created with `socket2` and set non-blocking
/// before being handed to `llam_connect`. On success it is converted into a
/// `std::net::TcpStream`, which is also explicitly marked non-blocking.
///
/// # Errors
///
/// Socket creation and `set_nonblocking` errors are propagated. A non-zero
/// return from `llam_connect` becomes the error from `crate::Error::last()`.
fn connect_one(addr: SocketAddr) -> Result<TcpStream> {
    let family = match addr {
        SocketAddr::V4(_) => Domain::IPV4,
        SocketAddr::V6(_) => Domain::IPV6,
    };
    let socket = Socket::new(family, Type::STREAM, Some(Protocol::TCP))?;
    socket.set_nonblocking(true)?;

    let target = SockAddr::from(addr);
    // SAFETY: `target` outlives the call, and the pointer/length pair comes
    // from the same `SockAddr`. The descriptor belongs to `socket`, which is
    // still alive.
    let status = unsafe {
        sys::llam_connect(
            raw_socket_fd(&socket),
            target.as_ptr().cast::<libc::sockaddr>(),
            target.len(),
        )
    };
    if status != 0 {
        return Err(crate::Error::last().into());
    }

    let inner = StdTcpStream::from(socket);
    inner.set_nonblocking(true)?;
    Ok(TcpStream { inner })
}

impl Read for TcpStream {
    /// Reads into `buf` by calling `crate::io::read` on the stream's raw descriptor.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        let fd = raw_stream_fd(&self.inner);
        io::read(fd, buf)
    }
}

impl Write for TcpStream {
    /// Writes `buf` by calling `crate::io::write` on the stream's raw descriptor.
    fn write(&mut self, buf: &[u8]) -> Result<usize> {
        let fd = raw_stream_fd(&self.inner);
        io::write(fd, buf)
    }

    /// No-op: this wrapper holds no buffer of its own, so there is nothing to flush.
    fn flush(&mut self) -> Result<()> {
        Ok(())
    }
}

impl UdpSocket {
    pub fn bind<A: ToSocketAddrs>(addr: A) -> Result<Self> {
        let inner = StdUdpSocket::bind(addr)?;
        inner.set_nonblocking(true)?;
        Ok(Self { inner })
    }

    pub fn connect<A: ToSocketAddrs>(&self, addr: A) -> Result<()> {
        self.inner.connect(addr)
    }

    pub fn send(&self, buf: &[u8]) -> Result<usize> {
        io::write(raw_udp_fd(&self.inner), buf)
    }

    pub fn recv(&self, buf: &mut [u8]) -> Result<usize> {
        io::read(raw_udp_fd(&self.inner), buf)
    }

    pub fn local_addr(&self) -> Result<SocketAddr> {
        self.inner.local_addr()
    }

    pub fn peer_addr(&self) -> Result<SocketAddr> {
        self.inner.peer_addr()
    }

    pub fn into_std(self) -> StdUdpSocket {
        self.inner
    }
}

#[cfg(unix)]
impl UnixListener {
    /// Binds a Unix-domain listener at `path` and switches it to non-blocking mode.
    ///
    /// This mirrors `TcpListener::bind`, which also makes its listener
    /// non-blocking before `accept` hands the descriptor to
    /// `sys::llam_accept`. NOTE(review): this assumes `llam_accept` expects a
    /// non-blocking listener (as the TCP path implies), and that a blocking
    /// one could stall the OS thread. Confirm against the LLAM C API.
    ///
    /// # Errors
    ///
    /// Propagates failures from `std::os::unix::net::UnixListener::bind` and
    /// from `set_nonblocking`.
    pub fn bind<P: AsRef<Path>>(path: P) -> Result<Self> {
        let inner = StdUnixListener::bind(path)?;
        inner.set_nonblocking(true)?;
        Ok(Self { inner })
    }

    /// Accepts one connection through `sys::llam_accept`.
    ///
    /// The accepted stream is switched to non-blocking mode.
    ///
    /// # Errors
    ///
    /// If `llam_accept` returns an invalid descriptor, the error from
    /// `crate::Error::last()` is returned. A failure of `set_nonblocking` is
    /// also propagated.
    pub fn accept(&self) -> Result<UnixStream> {
        // SAFETY: the descriptor belongs to `self.inner`, which is borrowed for
        // the whole call. Null address/length pointers mean the peer address
        // is not requested.
        let fd = unsafe {
            sys::llam_accept(
                self.inner.as_raw_fd(),
                std::ptr::null_mut(),
                std::ptr::null_mut(),
            )
        };
        if sys::fd_is_invalid(fd) {
            return Err(crate::Error::last().into());
        }
        // SAFETY: `fd` is a fresh, valid descriptor returned by `llam_accept`
        // and owned by nobody else; the std stream takes ownership of it.
        let stream = unsafe { StdUnixStream::from_raw_fd(fd) };
        stream.set_nonblocking(true)?;
        Ok(UnixStream { inner: stream })
    }

    /// Consumes the wrapper and returns the underlying std listener.
    pub fn into_std(self) -> StdUnixListener {
        self.inner
    }
}

#[cfg(unix)]
impl UnixStream {
    pub fn connect<P: AsRef<Path>>(path: P) -> Result<Self> {
        let stream = StdUnixStream::connect(path)?;
        stream.set_nonblocking(true)?;
        Ok(Self { inner: stream })
    }

    pub fn into_std(self) -> StdUnixStream {
        self.inner
    }
}

#[cfg(unix)]
impl Read for UnixStream {
    /// Reads into `buf` by calling `crate::io::read` on the stream's raw descriptor.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize> {
        let fd = AsRawFd::as_raw_fd(&self.inner);
        io::read(fd, buf)
    }
}

#[cfg(unix)]
impl Write for UnixStream {
    /// Writes `buf` by calling `crate::io::write` on the stream's raw descriptor.
    fn write(&mut self, buf: &[u8]) -> Result<usize> {
        let fd = AsRawFd::as_raw_fd(&self.inner);
        io::write(fd, buf)
    }

    /// No-op: this wrapper holds no buffer of its own, so there is nothing to flush.
    fn flush(&mut self) -> Result<()> {
        Ok(())
    }
}

// Descriptor plumbing between std/socket2 sockets and the `sys::llam_fd_t`
// handle type taken by the LLAM C API.
//
// Unix: the raw fd is returned as-is.

#[cfg(unix)]
fn raw_listener_fd(listener: &StdTcpListener) -> sys::llam_fd_t {
    AsRawFd::as_raw_fd(listener)
}

#[cfg(unix)]
fn raw_stream_fd(stream: &StdTcpStream) -> sys::llam_fd_t {
    AsRawFd::as_raw_fd(stream)
}

#[cfg(unix)]
fn raw_udp_fd(socket: &StdUdpSocket) -> sys::llam_fd_t {
    AsRawFd::as_raw_fd(socket)
}

#[cfg(unix)]
fn raw_socket_fd(socket: &Socket) -> sys::llam_fd_t {
    AsRawFd::as_raw_fd(socket)
}

/// Takes ownership of `fd` as a `std::net::TcpStream`.
///
/// # Safety
///
/// `fd` must be an open TCP socket descriptor that is owned by nothing else;
/// the returned stream closes it on drop.
#[cfg(unix)]
unsafe fn stream_from_fd(fd: sys::llam_fd_t) -> StdTcpStream {
    StdTcpStream::from_raw_fd(fd)
}

// Windows: the `RawSocket` handle is cast to `sys::llam_fd_t`.

#[cfg(windows)]
fn raw_listener_fd(listener: &StdTcpListener) -> sys::llam_fd_t {
    AsRawSocket::as_raw_socket(listener) as sys::llam_fd_t
}

#[cfg(windows)]
fn raw_stream_fd(stream: &StdTcpStream) -> sys::llam_fd_t {
    AsRawSocket::as_raw_socket(stream) as sys::llam_fd_t
}

#[cfg(windows)]
fn raw_udp_fd(socket: &StdUdpSocket) -> sys::llam_fd_t {
    AsRawSocket::as_raw_socket(socket) as sys::llam_fd_t
}

#[cfg(windows)]
fn raw_socket_fd(socket: &Socket) -> sys::llam_fd_t {
    AsRawSocket::as_raw_socket(socket) as sys::llam_fd_t
}

/// Takes ownership of the socket handle `fd` as a `std::net::TcpStream`.
///
/// # Safety
///
/// `fd` must be an open TCP socket handle that is owned by nothing else; the
/// returned stream closes it on drop.
#[cfg(windows)]
unsafe fn stream_from_fd(fd: sys::llam_fd_t) -> StdTcpStream {
    StdTcpStream::from_raw_socket(fd as _)
}

// Raw-handle access for the wrapper types; each impl delegates to the wrapped
// std socket.

#[cfg(unix)]
impl AsRawFd for TcpListener {
    fn as_raw_fd(&self) -> std::os::fd::RawFd {
        AsRawFd::as_raw_fd(&self.inner)
    }
}

#[cfg(unix)]
impl AsRawFd for TcpStream {
    fn as_raw_fd(&self) -> std::os::fd::RawFd {
        AsRawFd::as_raw_fd(&self.inner)
    }
}

#[cfg(unix)]
impl AsRawFd for UdpSocket {
    fn as_raw_fd(&self) -> std::os::fd::RawFd {
        AsRawFd::as_raw_fd(&self.inner)
    }
}

#[cfg(unix)]
impl AsRawFd for UnixListener {
    fn as_raw_fd(&self) -> std::os::fd::RawFd {
        AsRawFd::as_raw_fd(&self.inner)
    }
}

#[cfg(unix)]
impl AsRawFd for UnixStream {
    fn as_raw_fd(&self) -> std::os::fd::RawFd {
        AsRawFd::as_raw_fd(&self.inner)
    }
}

#[cfg(windows)]
impl AsRawSocket for TcpListener {
    fn as_raw_socket(&self) -> RawSocket {
        AsRawSocket::as_raw_socket(&self.inner)
    }
}

#[cfg(windows)]
impl AsRawSocket for TcpStream {
    fn as_raw_socket(&self) -> RawSocket {
        AsRawSocket::as_raw_socket(&self.inner)
    }
}

#[cfg(windows)]
impl AsRawSocket for UdpSocket {
    fn as_raw_socket(&self) -> RawSocket {
        AsRawSocket::as_raw_socket(&self.inner)
    }
}