starry-kernel 0.5.11

A Linux-compatible OS kernel built on ArceOS unikernel
Documentation
use alloc::{boxed::Box, vec::Vec};
use core::{net::Ipv4Addr, time::Duration};

use ax_errno::{AxError, AxResult};
use ax_hal::time::wall_time;
use ax_io::prelude::*;
use axnet::{CMsgData, RecvFlags, RecvOptions, SendFlags, SendOptions, SocketAddrEx, SocketOps};
use linux_raw_sys::{
    general::timespec,
    net::{
        MSG_DONTWAIT, MSG_PEEK, MSG_TRUNC, SCM_RIGHTS, SOL_SOCKET, cmsghdr, mmsghdr, msghdr,
        sockaddr, socklen_t,
    },
};

use super::addr::{
    SocketAddrExt, normalize_socket_addr_ex_for_ip_stack, socket_addr_ex_for_user_name,
};
use crate::{
    file::{FileLike, PacketSocket, Socket, add_file_like, get_file_like, netlink::NetlinkSocket},
    mm::{IoVec, IoVectorBuf, UserConstPtr, UserPtr, VmBytes, VmBytesMut},
    syscall::net::{CMsg, CMsgBuilder},
    time::TimeValueLike,
};

// Upper bound applied to the `vlen` argument of `sys_sendmmsg` and
// `sys_recvmmsg`. The value mirrors Linux's UIO_MAXIOV (1024).
const MMSG_MAX_VLEN: u32 = 1024;

/// Reads the optional `recvmmsg` timeout from user memory.
///
/// A null pointer means "no timeout" and yields `Ok(None)`. Otherwise the
/// `timespec` is read, converted with `try_into_time_value`, and handed back
/// as a [`Duration`]. Failures while reading or converting are propagated.
fn parse_recvmmsg_timeout(timeout: UserConstPtr<timespec>) -> AxResult<Option<Duration>> {
    if !timeout.is_null() {
        let raw = timeout.get_as_ref()?;
        let value = (*raw).try_into_time_value()?;
        let (secs, nanos) = (value.as_secs(), value.subsec_nanos());
        return Ok(Some(Duration::new(secs, nanos)));
    }
    Ok(None)
}

/// Parses the ancillary-data (control message) buffer of an outgoing message.
///
/// Walks the `cmsghdr` records stored in the user range
/// `[control_ptr, control_ptr + control_len)` and converts each of them with
/// `CMsg::parse`. Iteration stops as soon as another full header no longer
/// fits in the remaining bytes.
///
/// Returns an empty list when `control_ptr` is null or `control_len` is zero.
///
/// # Errors
///
/// * `AxError::InvalidInput` if the buffer range overflows the address space,
///   or if a record's `cmsg_len` is smaller than the header itself or reaches
///   past the end of the buffer.
/// * Any error produced while reading a header from user memory or by
///   `CMsg::parse`.
fn parse_send_cmsgs(control_ptr: usize, control_len: usize) -> AxResult<Vec<CMsgData>> {
    let mut cmsg = Vec::new();
    if control_ptr == 0 || control_len == 0 {
        return Ok(cmsg);
    }

    let mut ptr = control_ptr;
    let ptr_end = ptr.checked_add(control_len).ok_or(AxError::InvalidInput)?;

    while let Some(next) = ptr.checked_add(size_of::<cmsghdr>()) {
        // Not enough room left for another header: end of the list.
        if next > ptr_end {
            break;
        }

        let hdr = UserConstPtr::<cmsghdr>::from(ptr).get_as_ref()?;
        // `ptr <= ptr_end` holds here because `next <= ptr_end`, so the
        // subtraction cannot underflow.
        if hdr.cmsg_len < size_of::<cmsghdr>() || ptr_end - ptr < hdr.cmsg_len {
            return Err(AxError::InvalidInput);
        }

        cmsg.push(Box::new(CMsg::parse(hdr)?) as CMsgData);

        // Successive records are separated by CMSG_ALIGN(cmsg_len), i.e. the
        // record length rounded up to the machine word, not by the raw
        // `cmsg_len`. Stepping by the unaligned length would read the next
        // header from the padding of the current one. The final record may
        // legitimately omit its trailing padding; in that case the cursor
        // lands past `ptr_end` and the bound check above ends the loop.
        let Some(advanced) = hdr
            .cmsg_len
            .checked_next_multiple_of(size_of::<usize>())
            .and_then(|step| ptr.checked_add(step))
        else {
            break;
        };
        ptr = advanced;
    }

    Ok(cmsg)
}

/// Shared backend of the `send*` syscalls.
///
/// The descriptor is probed in this order: `PacketSocket`, `Socket`,
/// `NetlinkSocket`. Packet and netlink descriptors simply write `src`; the
/// destination address, flags and control messages are only used on the
/// `Socket` path.
///
/// Returns the number of bytes sent. If `fd` is none of the socket kinds, the
/// error from `get_file_like` is propagated when the descriptor is invalid,
/// otherwise `AxError::NotASocket` is returned.
fn send_impl(
    fd: i32,
    mut src: impl Read + IoBuf,
    flags: u32,
    addr: UserConstPtr<sockaddr>,
    addrlen: socklen_t,
    cmsg: Vec<CMsgData>,
) -> AxResult<isize> {
    if let Ok(packet) = PacketSocket::from_fd(fd) {
        let sent = packet.send_packet(&mut src)?;
        return Ok(sent as isize);
    }

    let Ok(socket) = Socket::from_fd(fd) else {
        return match NetlinkSocket::from_fd(fd) {
            Ok(netlink) => {
                let sent = netlink.write(&mut src)?;
                Ok(sent as isize)
            }
            Err(_) => {
                // Distinguish "bad descriptor" from "valid, but not a socket".
                get_file_like(fd)?;
                Err(AxError::NotASocket)
            }
        };
    };

    let addr = match (addr.is_null(), addrlen) {
        // addr == NULL: treat as no address regardless of addrlen.
        // Linux sendto(..., NULL, nonzero) sends to connected peer or
        // returns EDESTADDRREQ on unconnected socket, never EINVAL.
        (true, _) => None,
        // A non-null address with a zero length is malformed.
        (false, 0) => return Err(AxError::InvalidInput),
        (false, _) => {
            let parsed = SocketAddrEx::read_from_user(addr, addrlen)?;
            Some(if socket.ip_domain() == linux_raw_sys::net::AF_INET6 {
                normalize_socket_addr_ex_for_ip_stack(parsed, false)?
            } else {
                parsed
            })
        }
    };

    debug!("sys_send <= fd: {fd}, flags: {flags:#x}, addr: {addr:?}");

    let options = SendOptions {
        to: addr,
        flags: SendFlags::from_bits_retain(flags),
        cmsg,
    };
    let sent = socket.send(&mut src, options)?;
    Ok(sent as isize)
}

/// `sendto(2)`: sends the `len` bytes at `buf` on `fd`, optionally to the
/// address described by `addr`/`addrlen`. No control messages are attached.
///
/// Returns the number of bytes sent.
pub fn sys_sendto(
    fd: i32,
    buf: *const u8,
    len: usize,
    flags: u32,
    addr: UserConstPtr<sockaddr>,
    addrlen: socklen_t,
) -> AxResult<isize> {
    let payload = VmBytes::new(buf, len);
    let no_cmsg = Vec::new();
    send_impl(fd, payload, flags, addr, addrlen, no_cmsg)
}

pub fn sys_sendmsg(fd: i32, msg: UserConstPtr<msghdr>, flags: u32) -> AxResult<isize> {
    let msg = msg.get_as_ref()?;
    let cmsg = parse_send_cmsgs(msg.msg_control as usize, msg.msg_controllen)?;
    send_impl(
        fd,
        IoVectorBuf::new(msg.msg_iov as *const IoVec, msg.msg_iovlen)?.into_io(),
        flags,
        UserConstPtr::from(msg.msg_name as usize),
        msg.msg_namelen as socklen_t,
        cmsg,
    )
}

/// Shared backend of the `recv*` syscalls.
///
/// The descriptor is probed in this order: `PacketSocket`, `Socket`,
/// `NetlinkSocket`. Packet and netlink descriptors read into `dst` and, when
/// `addr` is non-null, report the source address; `flags`, `cmsg_builder` and
/// `truncated_out` are only used on the `Socket` path.
///
/// On the `Socket` path:
/// * `MSG_PEEK`, `MSG_TRUNC` and `MSG_DONTWAIT` are translated to `RecvFlags`;
///   other bits of `flags` are ignored.
/// * The peer address is written to `addr`/`addrlen` when `addr` is non-null.
/// * Received control messages are emitted through `cmsg_builder` when one is
///   supplied. Emission stops at the first message the builder reports as
///   not pushed.
/// * `truncated_out` is handed to the socket layer as its truncation flag.
///
/// Returns the number of bytes received. If `fd` is none of the socket kinds,
/// the error from `get_file_like` is propagated when the descriptor is
/// invalid, otherwise `AxError::NotASocket` is returned.
fn recv_impl(
    fd: i32,
    mut dst: impl Write + IoBufMut,
    flags: u32,
    addr: UserPtr<sockaddr>,
    addrlen: UserPtr<socklen_t>,
    cmsg_builder: Option<CMsgBuilder>,
    truncated_out: &mut bool,
) -> AxResult<isize> {
    debug!("sys_recv <= fd: {fd}, flags: {flags}");

    if let Ok(packet) = PacketSocket::from_fd(fd) {
        let (count, source) = packet.recv_packet(&mut dst)?;
        if !addr.is_null() {
            let name = addr.address().as_usize() as *mut sockaddr;
            source.write_to_user(name, addrlen.get_as_mut()?)?;
        }
        return Ok(count as isize);
    }

    let socket = match Socket::from_fd(fd) {
        Ok(socket) => socket,
        Err(_) => {
            let Ok(netlink) = NetlinkSocket::from_fd(fd) else {
                // Distinguish "bad descriptor" from "valid, but not a socket".
                get_file_like(fd)?;
                return Err(AxError::NotASocket);
            };

            let count = netlink.read(&mut dst)?;
            if !addr.is_null() {
                super::addr::write_netlink_addr(
                    &netlink.kernel_addr(),
                    addr,
                    addrlen.get_as_mut()?,
                )?;
            }
            return Ok(count as isize);
        }
    };

    // Translate the subset of MSG_* bits the socket layer understands.
    let mut recv_flags = RecvFlags::empty();
    for (bit, flag) in [
        (MSG_PEEK, RecvFlags::PEEK),
        (MSG_TRUNC, RecvFlags::TRUNCATE),
        (MSG_DONTWAIT, RecvFlags::DONTWAIT),
    ] {
        if flags & bit != 0 {
            recv_flags |= flag;
        }
    }

    let mut ancillary = Vec::new();

    // Only ask for the peer address when the caller supplied somewhere to
    // store it; the placeholder value is overwritten by the socket layer.
    let mut peer = if addr.is_null() {
        None
    } else {
        Some(SocketAddrEx::Ip((Ipv4Addr::UNSPECIFIED, 0).into()))
    };

    let options = RecvOptions {
        from: peer.as_mut(),
        flags: recv_flags,
        cmsg: Some(&mut ancillary),
        truncated: Some(truncated_out),
    };
    let recv = socket.recv(&mut dst, options)?;

    if let Some(peer) = peer {
        let user_name = socket_addr_ex_for_user_name(socket.ip_domain(), peer);
        user_name.write_to_user(addr, addrlen.get_as_mut()?)?;
    }

    if let Some(mut builder) = cmsg_builder {
        for item in ancillary {
            let message = match item.downcast::<CMsg>() {
                Ok(message) => message,
                Err(_) => {
                    warn!("received unexpected cmsg");
                    continue;
                }
            };

            let fits = match *message {
                CMsg::Rights { fds } => builder.push(SOL_SOCKET, SCM_RIGHTS, |data| {
                    // One `i32` slot per passed file; files that do not fit
                    // in the provided space are not installed.
                    let slots = data.chunks_exact_mut(size_of::<i32>());
                    let mut written = 0;
                    for (file, slot) in fds.into_iter().zip(slots) {
                        let new_fd = add_file_like(file, false)?;
                        slot.copy_from_slice(&new_fd.to_ne_bytes());
                        written += size_of::<i32>();
                    }
                    Ok(written)
                })?,
            };
            if !fits {
                break;
            }
        }
    }

    debug!("sys_recv => fd: {fd}, recv: {recv}");
    Ok(recv as isize)
}

/// `recvfrom(2)`: receives up to `len` bytes from `fd` into `buf` and, when
/// `addr` is non-null, reports the sender through `addr`/`addrlen`.
///
/// No control messages are collected and the truncation indicator produced
/// by the receive path is discarded. Returns the number of bytes received.
pub fn sys_recvfrom(
    fd: i32,
    buf: *mut u8,
    len: usize,
    flags: u32,
    addr: UserPtr<sockaddr>,
    addrlen: UserPtr<socklen_t>,
) -> AxResult<isize> {
    let dst = VmBytesMut::new(buf, len);
    // recvfrom has no msg_flags to report truncation in; the value is unused.
    let mut truncated = false;
    recv_impl(fd, dst, flags, addr, addrlen, None, &mut truncated)
}

pub fn sys_recvmsg(fd: i32, msg: UserPtr<msghdr>, flags: u32) -> AxResult<isize> {
    let msg = msg.get_as_mut()?;
    let mut truncated = false;
    let recv = recv_impl(
        fd,
        IoVectorBuf::new(msg.msg_iov as *mut IoVec, msg.msg_iovlen)?.into_io(),
        flags,
        UserPtr::from(msg.msg_name as usize),
        UserPtr::from(&mut msg.msg_namelen as *mut _ as *mut socklen_t),
        (!msg.msg_control.is_null()).then(|| {
            CMsgBuilder::new(
                UserPtr::from(msg.msg_control as *mut cmsghdr),
                &mut msg.msg_controllen,
            )
        }),
        &mut truncated,
    );
    // Linux: on success, set msg.msg_flags to indicate truncation etc.
    if recv.is_ok() {
        msg.msg_flags = if truncated { MSG_TRUNC } else { 0 };
    }
    recv
}

/// Send multiple datagrams in one syscall.
pub fn sys_sendmmsg(fd: i32, msgvec: UserPtr<mmsghdr>, vlen: u32, flags: u32) -> AxResult<isize> {
    if vlen == 0 {
        return Ok(0);
    }
    if vlen > MMSG_MAX_VLEN {
        return Err(AxError::InvalidInput);
    }

    let msgvec = msgvec.get_as_mut_slice(vlen as usize)?;
    let mut sent = 0;
    for msg in msgvec.iter_mut() {
        let cmsg = parse_send_cmsgs(msg.msg_hdr.msg_control as usize, msg.msg_hdr.msg_controllen)?;
        match send_impl(
            fd,
            IoVectorBuf::new(msg.msg_hdr.msg_iov as *const IoVec, msg.msg_hdr.msg_iovlen)?
                .into_io(),
            flags,
            UserConstPtr::from(msg.msg_hdr.msg_name as usize),
            msg.msg_hdr.msg_namelen as socklen_t,
            cmsg,
        ) {
            Ok(n) => {
                msg.msg_len = n as u32;
                sent += 1;
            }
            Err(e) => {
                if sent == 0 {
                    return Err(e);
                }
                break;
            }
        }
    }
    Ok(sent)
}

/// Receive multiple datagrams in one syscall.
pub fn sys_recvmmsg(
    fd: i32,
    msgvec: UserPtr<mmsghdr>,
    vlen: u32,
    flags: u32,
    timeout: UserConstPtr<timespec>,
) -> AxResult<isize> {
    if vlen == 0 {
        return Ok(0);
    }
    if vlen > MMSG_MAX_VLEN {
        return Err(AxError::InvalidInput);
    }

    let timeout = parse_recvmmsg_timeout(timeout)?;
    // TODO: deadline is only checked between recv_impl calls. If a single
    // recv_impl blocks waiting for data (socket has nothing to read), the
    // deadline cannot interrupt it. Needs a non-blocking recv path or
    // SO_RCVTIMEO support at the socket layer to fix.
    let deadline = timeout.map(|t| wall_time() + t);
    let _socket = Socket::from_fd(fd)?;
    let msgvec = msgvec.get_as_mut_slice(vlen as usize)?;
    let mut received = 0;
    for msg in msgvec.iter_mut() {
        if let Some(deadline) = deadline
            && wall_time() >= deadline
        {
            if received == 0 {
                return Err(AxError::WouldBlock);
            }
            break;
        }

        let recv = recv_impl(
            fd,
            IoVectorBuf::new(msg.msg_hdr.msg_iov as *mut IoVec, msg.msg_hdr.msg_iovlen)?.into_io(),
            flags,
            UserPtr::from(msg.msg_hdr.msg_name as usize),
            UserPtr::from(&mut msg.msg_hdr.msg_namelen as *mut _ as *mut socklen_t),
            (!msg.msg_hdr.msg_control.is_null()).then(|| {
                CMsgBuilder::new(
                    UserPtr::from(msg.msg_hdr.msg_control as *mut cmsghdr),
                    &mut msg.msg_hdr.msg_controllen,
                )
            }),
            &mut false,
        );

        match recv {
            Ok(n) => {
                msg.msg_len = n as u32;
                received += 1;
            }
            Err(e) => {
                if received == 0 {
                    return Err(e);
                }
                break;
            }
        }
    }

    Ok(received)
}