use alloc::{boxed::Box, vec::Vec};
use core::net::Ipv4Addr;
use ax_errno::{AxError, AxResult};
use ax_io::prelude::*;
use axnet::{CMsgData, RecvFlags, RecvOptions, SendFlags, SendOptions, SocketAddrEx, SocketOps};
use linux_raw_sys::net::{
MSG_PEEK, MSG_TRUNC, SCM_RIGHTS, SOL_SOCKET, cmsghdr, msghdr, sockaddr, socklen_t,
};
use super::addr::SocketAddrExt;
use crate::{
file::{FileLike, Socket, add_file_like},
mm::{IoVec, IoVectorBuf, UserConstPtr, UserPtr, VmBytes, VmBytesMut},
syscall::net::{CMsg, CMsgBuilder},
};
fn send_impl(
fd: i32,
mut src: impl Read + IoBuf,
flags: u32,
addr: UserConstPtr<sockaddr>,
addrlen: socklen_t,
cmsg: Vec<CMsgData>,
) -> AxResult<isize> {
let addr = if addr.is_null() || addrlen == 0 {
None
} else {
Some(SocketAddrEx::read_from_user(addr, addrlen)?)
};
debug!("sys_send <= fd: {fd}, flags: {flags}, addr: {addr:?}");
let socket = Socket::from_fd(fd)?;
let sent = socket.send(
&mut src,
SendOptions {
to: addr,
flags: SendFlags::default(),
cmsg,
},
)?;
Ok(sent as isize)
}
pub fn sys_sendto(
fd: i32,
buf: *const u8,
len: usize,
flags: u32,
addr: UserConstPtr<sockaddr>,
addrlen: socklen_t,
) -> AxResult<isize> {
send_impl(fd, VmBytes::new(buf, len), flags, addr, addrlen, Vec::new())
}
pub fn sys_sendmsg(fd: i32, msg: UserConstPtr<msghdr>, flags: u32) -> AxResult<isize> {
let msg = msg.get_as_ref()?;
let mut cmsg = Vec::new();
if !msg.msg_control.is_null() {
let mut ptr = msg.msg_control as usize;
let ptr_end = ptr + msg.msg_controllen;
while ptr + size_of::<cmsghdr>() <= ptr_end {
let hdr = UserConstPtr::<cmsghdr>::from(ptr).get_as_ref()?;
if ptr_end - ptr < hdr.cmsg_len {
return Err(AxError::InvalidInput);
}
cmsg.push(Box::new(CMsg::parse(hdr)?) as CMsgData);
ptr += hdr.cmsg_len;
}
}
send_impl(
fd,
IoVectorBuf::new(msg.msg_iov as *const IoVec, msg.msg_iovlen)?.into_io(),
flags,
UserConstPtr::from(msg.msg_name as usize),
msg.msg_namelen as socklen_t,
cmsg,
)
}
fn recv_impl(
fd: i32,
mut dst: impl Write + IoBufMut,
flags: u32,
addr: UserPtr<sockaddr>,
addrlen: UserPtr<socklen_t>,
cmsg_builder: Option<CMsgBuilder>,
) -> AxResult<isize> {
debug!("sys_recv <= fd: {fd}, flags: {flags}");
let socket = Socket::from_fd(fd)?;
let mut recv_flags = RecvFlags::empty();
if flags & MSG_PEEK != 0 {
recv_flags |= RecvFlags::PEEK;
}
if flags & MSG_TRUNC != 0 {
recv_flags |= RecvFlags::TRUNCATE;
}
let mut cmsg = Vec::new();
let mut remote_addr =
(!addr.is_null()).then(|| SocketAddrEx::Ip((Ipv4Addr::UNSPECIFIED, 0).into()));
let recv = socket.recv(
&mut dst,
RecvOptions {
from: remote_addr.as_mut(),
flags: recv_flags,
cmsg: Some(&mut cmsg),
},
)?;
if let Some(remote_addr) = remote_addr {
remote_addr.write_to_user(addr, addrlen.get_as_mut()?)?;
}
if let Some(mut builder) = cmsg_builder {
for cmsg in cmsg {
let Ok(cmsg) = cmsg.downcast::<CMsg>() else {
warn!("received unexpected cmsg");
continue;
};
let pushed = match *cmsg {
CMsg::Rights { fds } => builder.push(SOL_SOCKET, SCM_RIGHTS, |data| {
let mut written = 0;
for (f, chunk) in fds.into_iter().zip(data.chunks_exact_mut(size_of::<i32>())) {
let fd = add_file_like(f, false)?;
chunk.copy_from_slice(&fd.to_ne_bytes());
written += size_of::<i32>();
}
Ok(written)
})?,
};
if !pushed {
break;
}
}
}
debug!("sys_recv => fd: {fd}, recv: {recv}");
Ok(recv as isize)
}
pub fn sys_recvfrom(
fd: i32,
buf: *mut u8,
len: usize,
flags: u32,
addr: UserPtr<sockaddr>,
addrlen: UserPtr<socklen_t>,
) -> AxResult<isize> {
recv_impl(fd, VmBytesMut::new(buf, len), flags, addr, addrlen, None)
}
pub fn sys_recvmsg(fd: i32, msg: UserPtr<msghdr>, flags: u32) -> AxResult<isize> {
let msg = msg.get_as_mut()?;
recv_impl(
fd,
IoVectorBuf::new(msg.msg_iov as *mut IoVec, msg.msg_iovlen)?.into_io(),
flags,
UserPtr::from(msg.msg_name as usize),
UserPtr::from(&mut msg.msg_namelen as *mut _ as *mut socklen_t),
(!msg.msg_control.is_null()).then(|| {
CMsgBuilder::new(
UserPtr::from(msg.msg_control as *mut cmsghdr),
&mut msg.msg_controllen,
)
}),
)
}