monoio 0.0.9

A thread-per-core runtime based on io_uring.
Documentation
#[cfg(unix)]
use std::os::unix::prelude::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
#[cfg(windows)]
use std::os::windows::prelude::{AsRawHandle, FromRawSocket, RawHandle};
use std::{
    cell::UnsafeCell,
    future::Future,
    io,
    net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6, ToSocketAddrs},
};

use super::stream::TcpStream;
use crate::{
    driver::{op::Op, shared_fd::SharedFd},
    io::stream::Stream,
    net::ListenerConfig,
};

/// A TCP socket server, listening for connections.
///
/// Created with [`TcpListener::bind`] or [`TcpListener::bind_with_config`]; connections are
/// obtained with [`TcpListener::accept`] or by iterating it as a [`Stream`].
pub struct TcpListener {
    /// Descriptor handed to the runtime's I/O operations (e.g. `accept`).
    fd: SharedFd,
    /// Standard-library handle over the same descriptor as `fd`, used for synchronous queries
    /// such as `local_addr`. It is `Some` until `Drop` takes it; on Unix `Drop` releases it with
    /// `into_raw_fd` so the descriptor is not closed through this handle.
    sys_listener: Option<std::net::TcpListener>,
    /// Cached metadata; the `UnsafeCell` lets `local_addr` fill the cache through `&self`.
    meta: UnsafeCell<ListenerMeta>,
}

impl TcpListener {
    /// Wraps a descriptor that is already bound and listening.
    ///
    /// Besides `fd`, a `std::net::TcpListener` is built over the same descriptor so that
    /// synchronous queries (see [`local_addr`](Self::local_addr)) can use the standard library.
    /// That second handle must never close the descriptor; on Unix the `Drop` impl of this type
    /// releases it with `into_raw_fd`.
    pub(crate) fn from_shared_fd(fd: SharedFd) -> Self {
        // SAFETY: `fd.raw_fd()` is an open socket descriptor (assumed to be owned and eventually
        // closed by `fd`). The std listener built here is only a view: `Drop` relinquishes it
        // instead of closing it, so the descriptor is not closed twice.
        #[cfg(unix)]
        let sys_listener = unsafe { std::net::TcpListener::from_raw_fd(fd.raw_fd()) };
        // Windows support is not implemented yet; this branch panics via `todo!()`.
        #[cfg(windows)]
        let sys_listener = unsafe { std::net::TcpListener::from_raw_socket(todo!()) };
        Self {
            fd,
            sys_listener: Some(sys_listener),
            meta: UnsafeCell::new(ListenerMeta::default()),
        }
    }

    /// Binds a new listener to `addr`, applying the socket options from `config`.
    ///
    /// `addr` is resolved with [`ToSocketAddrs`] and only the first resolved address is used;
    /// the socket family (IPv4 or IPv6) follows that address. Before binding, `reuse_port`
    /// (Unix only), `reuse_addr`, `send_buf_size` and `recv_buf_size` are applied, and the
    /// socket then listens with `config.backlog`.
    ///
    /// # Errors
    ///
    /// Returns an error if address resolution fails or yields no address
    /// (`ErrorKind::Other`, "empty address"), or if creating, configuring, binding or
    /// listening on the socket fails.
    pub fn bind_with_config<A: ToSocketAddrs>(
        addr: A,
        config: &ListenerConfig,
    ) -> io::Result<Self> {
        let addr = addr
            .to_socket_addrs()?
            .next()
            .ok_or_else(|| io::Error::new(io::ErrorKind::Other, "empty address"))?;

        let domain = if addr.is_ipv6() {
            socket2::Domain::IPV6
        } else {
            socket2::Domain::IPV4
        };
        let sys_listener =
            socket2::Socket::new(domain, socket2::Type::STREAM, Some(socket2::Protocol::TCP))?;

        // With the `legacy` feature, the socket may need to be non-blocking depending on which
        // driver is active; see `set_non_blocking`.
        #[cfg(all(unix, feature = "legacy"))]
        Self::set_non_blocking(&sys_listener)?;

        let addr = socket2::SockAddr::from(addr);
        #[cfg(unix)]
        if config.reuse_port {
            sys_listener.set_reuse_port(true)?;
        }
        if config.reuse_addr {
            sys_listener.set_reuse_address(true)?;
        }
        if let Some(send_buf_size) = config.send_buf_size {
            sys_listener.set_send_buffer_size(send_buf_size)?;
        }
        if let Some(recv_buf_size) = config.recv_buf_size {
            sys_listener.set_recv_buffer_size(recv_buf_size)?;
        }
        sys_listener.bind(&addr)?;
        sys_listener.listen(config.backlog)?;

        // Hand the raw descriptor over to `SharedFd`; `sys_listener` is consumed and will not
        // close it.
        #[cfg(unix)]
        let fd = SharedFd::new(sys_listener.into_raw_fd())?;

        // Windows support is not implemented yet.
        #[cfg(windows)]
        let fd = unimplemented!();

        Ok(Self::from_shared_fd(fd))
    }

    /// Binds a new listener to `addr` using [`ListenerConfig::default`].
    ///
    /// See [`bind_with_config`](Self::bind_with_config) for address resolution and errors.
    pub fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
        Self::bind_with_config(addr, &ListenerConfig::default())
    }

    #[cfg(unix)]
    /// Accepts a new incoming connection on this listener.
    ///
    /// Resolves once the runtime completes an accept operation, yielding the connected
    /// [`TcpStream`] and the peer's address.
    ///
    /// # Errors
    ///
    /// Returns an error if the accept operation cannot be created or fails, if the accepted
    /// descriptor cannot be wrapped in a `SharedFd`, or (`ErrorKind::InvalidInput`) if the peer
    /// address family is neither IPv4 nor IPv6 — in that last case the accepted stream is
    /// dropped.
    pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
        let op = Op::accept(&self.fd)?;

        // Await the completion of the event.
        let completion = op.await;

        // The operation's result is the accepted descriptor.
        let fd = completion.meta.result?;

        let stream = TcpStream::from_shared_fd(SharedFd::new(fd as _)?);

        let storage = completion.data.addr.0.as_ptr() as *const _ as *const libc::sockaddr_storage;
        // SAFETY: `storage` points into the address buffer owned by `completion`, which stays
        // alive for the rest of this call; it is assumed to hold the peer address written by
        // the accept operation, as `socket_addr_from_storage` requires.
        let addr = unsafe { Self::socket_addr_from_storage(storage) }?;

        Ok((stream, addr))
    }

    #[cfg(unix)]
    /// Converts a socket address held in a `sockaddr_storage` into a [`SocketAddr`].
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::InvalidInput` if `ss_family` is neither `AF_INET` nor `AF_INET6`.
    ///
    /// # Safety
    ///
    /// `storage` must be non-null, aligned, and valid for reads of a `sockaddr_storage`, and the
    /// structure selected by its `ss_family` (`sockaddr_in` for `AF_INET`, `sockaddr_in6` for
    /// `AF_INET6`) must be fully initialized.
    unsafe fn socket_addr_from_storage(
        storage: *const libc::sockaddr_storage,
    ) -> io::Result<SocketAddr> {
        match (*storage).ss_family as libc::c_int {
            libc::AF_INET => {
                // Per the contract above, an `AF_INET` family means `storage` holds a
                // `sockaddr_in`.
                let addr = &*(storage as *const libc::sockaddr_in);
                // `s_addr` is kept in network byte order, so its in-memory bytes are the octets
                // in address order.
                let ip = Ipv4Addr::from(addr.sin_addr.s_addr.to_ne_bytes());
                let port = u16::from_be(addr.sin_port);
                Ok(SocketAddr::V4(SocketAddrV4::new(ip, port)))
            }
            libc::AF_INET6 => {
                // Per the contract above, an `AF_INET6` family means `storage` holds a
                // `sockaddr_in6`.
                let addr = &*(storage as *const libc::sockaddr_in6);
                let ip = Ipv6Addr::from(addr.sin6_addr.s6_addr);
                let port = u16::from_be(addr.sin6_port);
                Ok(SocketAddr::V6(SocketAddrV6::new(
                    ip,
                    port,
                    addr.sin6_flowinfo,
                    addr.sin6_scope_id,
                )))
            }
            _ => Err(io::Error::new(io::ErrorKind::InvalidInput, "unsupported address family")),
        }
    }

    #[cfg(windows)]
    /// Accepts a new incoming connection. Not implemented on Windows yet: always panics.
    pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
        unimplemented!()
    }

    /// Returns the local address that this listener is bound to.
    ///
    /// The first successful lookup is cached and returned by later calls.
    ///
    /// # Errors
    ///
    /// Returns any error from the underlying `std::net::TcpListener::local_addr`; errors are
    /// not cached, so a later call retries the lookup.
    pub fn local_addr(&self) -> io::Result<SocketAddr> {
        let meta = self.meta.get();
        // SAFETY: the `UnsafeCell` makes `TcpListener` `!Sync`, and no reference into `meta`
        // outlives a single statement in this method, so this copy-out cannot alias a write.
        if let Some(addr) = unsafe { (*meta).local_addr } {
            return Ok(addr);
        }
        let addr = self
            .sys_listener
            .as_ref()
            .expect("sys_listener is only taken in Drop")
            .local_addr()?;
        // SAFETY: as above; no other reference into `meta` is alive during this write.
        unsafe { (*meta).local_addr = Some(addr) };
        Ok(addr)
    }

    /// Puts `socket` into non-blocking mode when the driver reported by
    /// `crate::driver::CURRENT` is the legacy one; with the io_uring driver (Linux with the
    /// `iouring` feature) the socket is left unchanged.
    #[cfg(all(unix, feature = "legacy"))]
    fn set_non_blocking(socket: &socket2::Socket) -> io::Result<()> {
        crate::driver::CURRENT.with(|x| match x {
            // TODO: windows ioring support
            #[cfg(all(target_os = "linux", feature = "iouring"))]
            crate::driver::Inner::Uring(_) => Ok(()),
            crate::driver::Inner::Legacy(_) => socket.set_nonblocking(true),
        })
    }
}

impl Stream for TcpListener {
    /// Each item is the outcome of one `accept` call.
    type Item = io::Result<(TcpStream, SocketAddr)>;

    type NextFuture<'a> = impl Future<Output = Option<Self::Item>> + 'a;

    /// Accepts the next connection. This stream never terminates: the future always resolves
    /// to `Some`, wrapping either the accepted connection or the accept error.
    fn next(&mut self) -> Self::NextFuture<'_> {
        async move {
            let accepted = self.accept().await;
            Some(accepted)
        }
    }
}

impl std::fmt::Debug for TcpListener {
    /// Formats the listener showing only its descriptor.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut builder = f.debug_struct("TcpListener");
        builder.field("fd", &self.fd);
        builder.finish()
    }
}

#[cfg(unix)]
impl AsRawFd for TcpListener {
    /// Returns the raw descriptor held by the listener's `fd` field.
    fn as_raw_fd(&self) -> RawFd {
        let Self { fd, .. } = self;
        fd.raw_fd()
    }
}

#[cfg(windows)]
impl AsRawHandle for TcpListener {
    /// Returns the raw handle held by the listener's `fd` field.
    fn as_raw_handle(&self) -> RawHandle {
        let Self { fd, .. } = self;
        fd.raw_handle()
    }
}

impl Drop for TcpListener {
    /// Releases the standard-library handle without closing the descriptor.
    ///
    /// `sys_listener` wraps the same descriptor as `fd` (see `from_shared_fd`), so letting its
    /// destructor run would close that descriptor through a second handle. Ownership is left
    /// with `fd` (presumably `SharedFd` closes it when dropped).
    fn drop(&mut self) {
        if let Some(listener) = self.sys_listener.take() {
            #[cfg(unix)]
            let _ = listener.into_raw_fd();
            // No `into_raw_*` counterpart is imported here; forgetting the handle skips its
            // destructor, so the socket is not closed (previously this path panicked in `Drop`).
            #[cfg(not(unix))]
            std::mem::forget(listener);
        }
    }
}

/// Lazily filled metadata cached inside a `TcpListener`.
#[derive(Debug, Default, Clone)]
struct ListenerMeta {
    /// Result of the first successful `TcpListener::local_addr` lookup; `None` until then.
    local_addr: Option<SocketAddr>,
}