use std::{
io::{IoSlice, IoSliceMut},
os::fd::{BorrowedFd, RawFd},
};
use anyhow::anyhow;
use libc::EINVAL;
use nix::{
cmsg_space,
sys::socket::{
ControlMessage, ControlMessageOwned, MsgFlags, getsockopt, recvmsg, sendmsg,
sockopt::SockType,
},
unistd::{read, write},
};
use crate::consts::MEMFD_COUNT;
pub(crate) fn block_read_full(conn_fd: RawFd, data: &mut [u8]) -> Result<(), anyhow::Error> {
let mut read_size = 0;
while read_size < data.len() {
let n = read(
unsafe { BorrowedFd::borrow_raw(conn_fd) },
&mut data[read_size..],
)
.map_err(|e| {
anyhow!(
"read_full failed, had read_size:{read_size}, reason:{}",
e.desc()
)
})?;
read_size += n;
if n == 0 {
return Err(anyhow!("EOF"));
}
}
Ok(())
}
pub(crate) fn block_write_full(conn_fd: RawFd, data: &[u8]) -> Result<(), anyhow::Error> {
let mut written = 0;
while written < data.len() {
let n = write(unsafe { BorrowedFd::borrow_raw(conn_fd) }, &data[written..])?;
written += n;
}
Ok(())
}
pub(crate) fn send_fd(conn_fd: RawFd, fds: &[RawFd]) -> Result<(), anyhow::Error> {
let mut iov = [IoSlice::new(&[0u8; 0])];
let mut cmsgs = Vec::with_capacity(1);
if !fds.is_empty() {
let borrowed_fd = unsafe { BorrowedFd::borrow_raw(conn_fd) };
let sock_type = getsockopt(&borrowed_fd, SockType)?;
if sock_type != nix::sys::socket::SockType::Datagram {
iov[0] = IoSlice::new(&[0u8; 1]);
}
cmsgs.push(ControlMessage::ScmRights(fds))
}
Ok(sendmsg::<()>(
conn_fd,
iov.as_slice(),
cmsgs.as_slice(),
MsgFlags::empty(),
None,
)
.map(|_| ())?)
}
pub(crate) fn block_read_out_of_bound_for_fd(conn_fd: RawFd) -> Result<Vec<RawFd>, anyhow::Error> {
let mut iov = [IoSliceMut::new(&mut [0u8; 0])];
let borrowed_fd = unsafe { BorrowedFd::borrow_raw(conn_fd) };
let sock_type = getsockopt(&borrowed_fd, SockType)?;
let mut buf = [0u8; 1];
if sock_type != nix::sys::socket::SockType::Datagram {
iov[0] = IoSliceMut::new(&mut buf);
}
let mut cmsg_buffer = cmsg_space!([RawFd; MEMFD_COUNT]);
let recv_msg = recvmsg::<()>(
conn_fd,
iov.as_mut_slice(),
Some(cmsg_buffer.as_mut()),
MsgFlags::empty(),
)
.map_err(|err| anyhow!("try recv fd from peer failed, reason:{}", err))?;
tracing::info!("recvmsg finished");
if let Some(msgs) = recv_msg.cmsgs()?.next() {
if let ControlMessageOwned::ScmRights(fds) = msgs {
Ok(fds)
} else {
Err(anyhow!(
"parse fd from unix domain failed, reason errno:{}",
EINVAL
))
}
} else {
Err(anyhow!("parse socket control message ret is nil"))
}
}