use alloc::{boxed::Box, vec::Vec};
use core::{net::Ipv4Addr, time::Duration};
use ax_errno::{AxError, AxResult};
use ax_hal::time::wall_time;
use ax_io::prelude::*;
use axnet::{CMsgData, RecvFlags, RecvOptions, SendFlags, SendOptions, SocketAddrEx, SocketOps};
use linux_raw_sys::{
general::timespec,
net::{
MSG_DONTWAIT, MSG_PEEK, MSG_TRUNC, SCM_RIGHTS, SOL_SOCKET, cmsghdr, mmsghdr, msghdr,
sockaddr, socklen_t,
},
};
use super::addr::{
SocketAddrExt, normalize_socket_addr_ex_for_ip_stack, socket_addr_ex_for_user_name,
};
use crate::{
file::{FileLike, PacketSocket, Socket, add_file_like, get_file_like, netlink::NetlinkSocket},
mm::{IoVec, IoVectorBuf, UserConstPtr, UserPtr, VmBytes, VmBytesMut},
syscall::net::{CMsg, CMsgBuilder},
time::TimeValueLike,
};
/// Upper bound on the `vlen` argument (number of `mmsghdr` entries) of one
/// `sendmmsg`/`recvmmsg` call; the same value as Linux's `UIO_MAXIOV` (1024).
const MMSG_MAX_VLEN: u32 = 1024;
/// Reads the optional `recvmmsg` timeout from user space.
///
/// A null pointer means "no timeout" and yields `Ok(None)`. Otherwise the user
/// `timespec` is converted with `try_into_time_value`, and any error from reading
/// user memory or from that conversion is propagated.
fn parse_recvmmsg_timeout(timeout: UserConstPtr<timespec>) -> AxResult<Option<Duration>> {
    if !timeout.is_null() {
        let value = (*timeout.get_as_ref()?).try_into_time_value()?;
        return Ok(Some(Duration::new(value.as_secs(), value.subsec_nanos())));
    }
    Ok(None)
}
/// Parses the control-message buffer of a `sendmsg`/`sendmmsg` call into
/// [`CMsgData`] items.
///
/// The buffer is walked the way Linux's `CMSG_FIRSTHDR`/`CMSG_NXTHDR` do:
/// - a trailing region too short to hold a `cmsghdr` ends the walk;
/// - consecutive headers are `CMSG_ALIGN(cmsg_len)` bytes apart, i.e. each
///   message's length rounded up to a multiple of `size_of::<usize>()`.
///
/// A null pointer or a zero length yields an empty list.
///
/// # Errors
///
/// - `InvalidInput` in either of these cases:
///   - `control_ptr + control_len` overflows;
///   - a header's `cmsg_len` is smaller than the header itself or runs past the
///     end of the buffer.
/// - Errors from reading user memory or from `CMsg::parse` are propagated.
fn parse_send_cmsgs(control_ptr: usize, control_len: usize) -> AxResult<Vec<CMsgData>> {
    // Linux's CMSG_ALIGN pads to `sizeof(long)`, which equals `usize` on Linux targets.
    const CMSG_ALIGN: usize = size_of::<usize>();

    let mut cmsg = Vec::new();
    if control_ptr == 0 || control_len == 0 {
        return Ok(cmsg);
    }
    let mut ptr = control_ptr;
    let ptr_end = ptr.checked_add(control_len).ok_or(AxError::InvalidInput)?;
    while let Some(next) = ptr.checked_add(size_of::<cmsghdr>()) {
        if next > ptr_end {
            break;
        }
        let hdr = UserConstPtr::<cmsghdr>::from(ptr).get_as_ref()?;
        if hdr.cmsg_len < size_of::<cmsghdr>() || ptr_end - ptr < hdr.cmsg_len {
            return Err(AxError::InvalidInput);
        }
        cmsg.push(Box::new(CMsg::parse(hdr)?) as CMsgData);
        // Skip the payload *and* its alignment padding, exactly like CMSG_NXTHDR.
        // Stepping by the raw `cmsg_len` would misplace every header after the
        // first whenever a payload is not a multiple of the alignment.
        // An overflow here can only mean we are past the end of the buffer.
        let Some(next_ptr) = hdr
            .cmsg_len
            .checked_next_multiple_of(CMSG_ALIGN)
            .and_then(|step| ptr.checked_add(step))
        else {
            break;
        };
        ptr = next_ptr;
    }
    Ok(cmsg)
}
/// Shared backend for `sendto`, `sendmsg` and `sendmmsg`.
///
/// Dispatches on the kind of socket behind `fd`:
/// - packet sockets get the whole payload via `send_packet`; the address,
///   flags and control messages are not used;
/// - regular sockets get the optional destination address, the send flags and
///   the control messages. On `AF_INET6` sockets the address is first passed
///   through `normalize_socket_addr_ex_for_ip_stack`;
/// - netlink sockets get the payload through `write`.
///
/// Returns the number of bytes sent. Error cases:
/// - a non-null `addr` with `addrlen == 0` yields `InvalidInput`;
/// - a lookup failure from `get_file_like` (e.g. a bad descriptor) is propagated;
/// - a descriptor that exists but is none of the kinds above yields `NotASocket`.
fn send_impl(
    fd: i32,
    mut src: impl Read + IoBuf,
    flags: u32,
    addr: UserConstPtr<sockaddr>,
    addrlen: socklen_t,
    cmsg: Vec<CMsgData>,
) -> AxResult<isize> {
    if let Ok(packet) = PacketSocket::from_fd(fd) {
        let written = packet.send_packet(&mut src)?;
        return Ok(written as isize);
    }

    let Ok(socket) = Socket::from_fd(fd) else {
        if let Ok(netlink) = NetlinkSocket::from_fd(fd) {
            let written = netlink.write(&mut src)?;
            return Ok(written as isize);
        }
        // Surface the lookup error for an invalid fd; otherwise it just isn't a socket.
        get_file_like(fd)?;
        return Err(AxError::NotASocket);
    };

    let addr = match (addr.is_null(), addrlen) {
        (true, _) => None,
        (false, 0) => return Err(AxError::InvalidInput),
        (false, len) => {
            let parsed = SocketAddrEx::read_from_user(addr, len)?;
            let parsed = if socket.ip_domain() == linux_raw_sys::net::AF_INET6 {
                normalize_socket_addr_ex_for_ip_stack(parsed, false)?
            } else {
                parsed
            };
            Some(parsed)
        }
    };

    debug!("sys_send <= fd: {fd}, flags: {flags:#x}, addr: {addr:?}");
    let options = SendOptions {
        to: addr,
        flags: SendFlags::from_bits_retain(flags),
        cmsg,
    };
    let written = socket.send(&mut src, options)?;
    Ok(written as isize)
}
/// `sendto(2)`: sends the `len` bytes at user address `buf` on socket `fd`.
///
/// `addr`/`addrlen` optionally name the destination. `sendto` carries no
/// control messages, so an empty list is forwarded to `send_impl`.
pub fn sys_sendto(
    fd: i32,
    buf: *const u8,
    len: usize,
    flags: u32,
    addr: UserConstPtr<sockaddr>,
    addrlen: socklen_t,
) -> AxResult<isize> {
    let payload = VmBytes::new(buf, len);
    let no_control = Vec::new();
    send_impl(fd, payload, flags, addr, addrlen, no_control)
}
/// `sendmsg(2)`: sends the scatter list described by the user `msghdr` on `fd`.
///
/// The control messages (`msg_control`/`msg_controllen`) are parsed before
/// anything is sent. As in Linux, a non-null `msg_name` whose `msg_namelen` is 0
/// is treated as "no destination address" rather than as an invalid address.
pub fn sys_sendmsg(fd: i32, msg: UserConstPtr<msghdr>, flags: u32) -> AxResult<isize> {
    let msg = msg.get_as_ref()?;
    let cmsg = parse_send_cmsgs(msg.msg_control as usize, msg.msg_controllen)?;
    // Linux (`__copy_msghdr` in net/socket.c) drops the name when its length is 0
    // instead of failing. Passing a null pointer makes `send_impl` see "no address".
    let name = if msg.msg_namelen == 0 {
        0
    } else {
        msg.msg_name as usize
    };
    send_impl(
        fd,
        IoVectorBuf::new(msg.msg_iov as *const IoVec, msg.msg_iovlen)?.into_io(),
        flags,
        UserConstPtr::from(name),
        msg.msg_namelen as socklen_t,
        cmsg,
    )
}
/// Shared backend for `recvfrom`, `recvmsg` and `recvmmsg`.
///
/// Receives into `dst` from the socket behind `fd` and returns the byte count
/// reported by the socket layer. Dispatch order:
/// 1. Packet sockets: `recv_packet`. The sender address is written through
///    `addr`/`addrlen` when `addr` is non-null. `flags`, `cmsg_builder` and
///    `truncated_out` are not used on this path.
/// 2. Regular sockets: see below.
/// 3. Netlink sockets: `read`. The address from `kernel_addr()` is written
///    through `addr`/`addrlen` when `addr` is non-null. `flags`, `cmsg_builder`
///    and `truncated_out` are not used on this path.
/// 4. Anything else: the error from `get_file_like` (e.g. a bad descriptor) is
///    propagated. Otherwise, if the descriptor exists, `NotASocket` is returned.
///
/// Regular sockets:
/// - Only `MSG_PEEK`, `MSG_TRUNC` and `MSG_DONTWAIT` are translated into
///   `RecvFlags`; any other bit in `flags` is ignored by this function.
/// - `truncated_out` is handed to the socket as `RecvOptions::truncated`.
///   Presumably the socket sets it when data was truncated — confirm in `axnet`.
/// - When `addr` is non-null, the peer address is written back through
///   `addr`/`addrlen`.
/// - Received control messages are delivered through `cmsg_builder` when one is
///   given. Without a builder they are simply dropped.
fn recv_impl(
    fd: i32,
    mut dst: impl Write + IoBufMut,
    flags: u32,
    addr: UserPtr<sockaddr>,
    addrlen: UserPtr<socklen_t>,
    cmsg_builder: Option<CMsgBuilder>,
    truncated_out: &mut bool,
) -> AxResult<isize> {
    debug!("sys_recv <= fd: {fd}, flags: {flags}");
    if let Ok(packet) = PacketSocket::from_fd(fd) {
        let (recv, from) = packet.recv_packet(&mut dst)?;
        if !addr.is_null() {
            from.write_to_user(
                addr.address().as_usize() as *mut sockaddr,
                addrlen.get_as_mut()?,
            )?;
        }
        return Ok(recv as isize);
    }
    let Ok(socket) = Socket::from_fd(fd) else {
        if let Ok(netlink) = NetlinkSocket::from_fd(fd) {
            let recv = netlink.read(&mut dst)?;
            if !addr.is_null() {
                super::addr::write_netlink_addr(
                    &netlink.kernel_addr(),
                    addr,
                    addrlen.get_as_mut()?,
                )?;
            }
            return Ok(recv as isize);
        }
        // Propagates the lookup error for an invalid fd; a valid non-socket fd
        // falls through to NotASocket.
        get_file_like(fd)?;
        return Err(AxError::NotASocket);
    };
    // Translate the Linux MSG_* bits this layer understands into RecvFlags.
    let mut recv_flags = RecvFlags::empty();
    if flags & MSG_PEEK != 0 {
        recv_flags |= RecvFlags::PEEK;
    }
    if flags & MSG_TRUNC != 0 {
        recv_flags |= RecvFlags::TRUNCATE;
    }
    if flags & MSG_DONTWAIT != 0 {
        recv_flags |= RecvFlags::DONTWAIT;
    }
    let mut cmsg = Vec::new();
    // Placeholder out-parameter, created only when the caller supplied an address
    // buffer. The socket is expected to overwrite it with the real peer address.
    let mut remote_addr =
        (!addr.is_null()).then(|| SocketAddrEx::Ip((Ipv4Addr::UNSPECIFIED, 0).into()));
    let recv = socket.recv(
        &mut dst,
        RecvOptions {
            from: remote_addr.as_mut(),
            flags: recv_flags,
            cmsg: Some(&mut cmsg),
            truncated: Some(truncated_out),
        },
    )?;
    if let Some(remote_addr) = remote_addr {
        socket_addr_ex_for_user_name(socket.ip_domain(), remote_addr)
            .write_to_user(addr, addrlen.get_as_mut()?)?;
    }
    if let Some(mut builder) = cmsg_builder {
        for cmsg in cmsg {
            // Only CMsg values are understood here; anything else is logged and skipped.
            let Ok(cmsg) = cmsg.downcast::<CMsg>() else {
                warn!("received unexpected cmsg");
                continue;
            };
            let pushed = match *cmsg {
                // SCM_RIGHTS: each passed file is installed with `add_file_like`
                // (second argument presumably close-on-exec — TODO confirm), and the
                // new fd number is written as a native-endian i32.
                // NOTE(review): `zip` drops files that do not fit into `data` without
                // any MSG_CTRUNC indication. An `add_file_like` failure also leaves
                // the already-installed fds in place. Confirm both are intended.
                CMsg::Rights { fds } => builder.push(SOL_SOCKET, SCM_RIGHTS, |data| {
                    let mut written = 0;
                    for (f, chunk) in fds.into_iter().zip(data.chunks_exact_mut(size_of::<i32>())) {
                        let fd = add_file_like(f, false)?;
                        chunk.copy_from_slice(&fd.to_ne_bytes());
                        written += size_of::<i32>();
                    }
                    Ok(written)
                })?,
            };
            // `push` returned false — presumably no room left in the user control
            // buffer — so the remaining control messages are dropped.
            if !pushed {
                break;
            }
        }
    }
    debug!("sys_recv => fd: {fd}, recv: {recv}");
    Ok(recv as isize)
}
/// `recvfrom(2)`: receives up to `len` bytes from socket `fd` into user buffer `buf`.
///
/// If `addr` is non-null the sender's address is written back through
/// `addr`/`addrlen`. `recvfrom` has no control buffer and no `msg_flags`, so no
/// control-message builder is passed and the truncation indicator is discarded.
pub fn sys_recvfrom(
    fd: i32,
    buf: *mut u8,
    len: usize,
    flags: u32,
    addr: UserPtr<sockaddr>,
    addrlen: UserPtr<socklen_t>,
) -> AxResult<isize> {
    let sink = VmBytesMut::new(buf, len);
    let mut truncated_ignored = false;
    recv_impl(fd, sink, flags, addr, addrlen, None, &mut truncated_ignored)
}
/// `recvmsg(2)`: receives into the scatter list described by the user `msghdr`.
///
/// Data written back through the header:
/// - the source address, when `msg_name` is non-null (length in `msg_namelen`);
/// - control messages, when `msg_control` is non-null (length in `msg_controllen`);
/// - on success only, `msg_flags`: `MSG_TRUNC` if the socket reported truncation,
///   `0` otherwise.
pub fn sys_recvmsg(fd: i32, msg: UserPtr<msghdr>, flags: u32) -> AxResult<isize> {
    let hdr = msg.get_as_mut()?;
    let sink = IoVectorBuf::new(hdr.msg_iov as *mut IoVec, hdr.msg_iovlen)?.into_io();
    let name = UserPtr::from(hdr.msg_name as usize);
    let name_len = UserPtr::from(&mut hdr.msg_namelen as *mut _ as *mut socklen_t);
    let control = if hdr.msg_control.is_null() {
        None
    } else {
        Some(CMsgBuilder::new(
            UserPtr::from(hdr.msg_control as *mut cmsghdr),
            &mut hdr.msg_controllen,
        ))
    };
    let mut truncated = false;
    let outcome = recv_impl(fd, sink, flags, name, name_len, control, &mut truncated);
    if outcome.is_ok() {
        hdr.msg_flags = if truncated { MSG_TRUNC } else { 0 };
    }
    outcome
}
pub fn sys_sendmmsg(fd: i32, msgvec: UserPtr<mmsghdr>, vlen: u32, flags: u32) -> AxResult<isize> {
if vlen == 0 {
return Ok(0);
}
if vlen > MMSG_MAX_VLEN {
return Err(AxError::InvalidInput);
}
let msgvec = msgvec.get_as_mut_slice(vlen as usize)?;
let mut sent = 0;
for msg in msgvec.iter_mut() {
let cmsg = parse_send_cmsgs(msg.msg_hdr.msg_control as usize, msg.msg_hdr.msg_controllen)?;
match send_impl(
fd,
IoVectorBuf::new(msg.msg_hdr.msg_iov as *const IoVec, msg.msg_hdr.msg_iovlen)?
.into_io(),
flags,
UserConstPtr::from(msg.msg_hdr.msg_name as usize),
msg.msg_hdr.msg_namelen as socklen_t,
cmsg,
) {
Ok(n) => {
msg.msg_len = n as u32;
sent += 1;
}
Err(e) => {
if sent == 0 {
return Err(e);
}
break;
}
}
}
Ok(sent)
}
pub fn sys_recvmmsg(
fd: i32,
msgvec: UserPtr<mmsghdr>,
vlen: u32,
flags: u32,
timeout: UserConstPtr<timespec>,
) -> AxResult<isize> {
if vlen == 0 {
return Ok(0);
}
if vlen > MMSG_MAX_VLEN {
return Err(AxError::InvalidInput);
}
let timeout = parse_recvmmsg_timeout(timeout)?;
let deadline = timeout.map(|t| wall_time() + t);
let _socket = Socket::from_fd(fd)?;
let msgvec = msgvec.get_as_mut_slice(vlen as usize)?;
let mut received = 0;
for msg in msgvec.iter_mut() {
if let Some(deadline) = deadline
&& wall_time() >= deadline
{
if received == 0 {
return Err(AxError::WouldBlock);
}
break;
}
let recv = recv_impl(
fd,
IoVectorBuf::new(msg.msg_hdr.msg_iov as *mut IoVec, msg.msg_hdr.msg_iovlen)?.into_io(),
flags,
UserPtr::from(msg.msg_hdr.msg_name as usize),
UserPtr::from(&mut msg.msg_hdr.msg_namelen as *mut _ as *mut socklen_t),
(!msg.msg_hdr.msg_control.is_null()).then(|| {
CMsgBuilder::new(
UserPtr::from(msg.msg_hdr.msg_control as *mut cmsghdr),
&mut msg.msg_hdr.msg_controllen,
)
}),
&mut false,
);
match recv {
Ok(n) => {
msg.msg_len = n as u32;
received += 1;
}
Err(e) => {
if received == 0 {
return Err(e);
}
break;
}
}
}
Ok(received)
}