use rustix::fd::{AsFd, BorrowedFd};
use std::io::{IoSlice, Result};
use std::net::TcpStream;
#[cfg(unix)]
use std::os::unix::io::{AsRawFd, IntoRawFd, OwnedFd, RawFd};
#[cfg(unix)]
use std::os::unix::net::UnixStream;
#[cfg(windows)]
use std::os::windows::io::{
AsRawSocket, AsSocket, BorrowedSocket, IntoRawSocket, OwnedSocket, RawSocket,
};
use crate::utils::RawFdContainer;
use x11rb_protocol::parse_display::ConnectAddress;
use x11rb_protocol::xauth::Family;
/// Which kind(s) of readiness [`Stream::poll`] should wait for.
#[derive(Debug, Clone, Copy)]
pub enum PollMode {
    /// Interested in readability only.
    Readable,
    /// Interested in writability only.
    Writable,
    /// Interested in both readability and writability.
    ReadAndWritable,
}

impl PollMode {
    /// Returns `true` if this mode includes interest in readability.
    pub fn readable(self) -> bool {
        matches!(self, Self::Readable | Self::ReadAndWritable)
    }

    /// Returns `true` if this mode includes interest in writability.
    pub fn writable(self) -> bool {
        matches!(self, Self::Writable | Self::ReadAndWritable)
    }
}
/// A byte stream (as used for the X11 connection) that can also transfer file descriptors.
pub trait Stream {
    /// Wait until the stream is ready for the kind(s) of I/O selected by `mode`.
    fn poll(&self, mode: PollMode) -> Result<()>;

    /// Read bytes into `buf` and return how many were read.
    ///
    /// File descriptors received together with the data are appended to `fd_storage`.
    fn read(&self, buf: &mut [u8], fd_storage: &mut Vec<RawFdContainer>) -> Result<usize>;

    /// Write bytes from `buf`, sending `fds` along with them, and return how many bytes
    /// were written.
    ///
    /// NOTE(review): the implementations in this file remove the descriptors from `fds`
    /// once they have been sent; confirm this is the intended contract for all implementors.
    fn write(&self, buf: &[u8], fds: &mut Vec<RawFdContainer>) -> Result<usize>;

    /// Vectored variant of [`Stream::write`].
    ///
    /// The default implementation forwards only the first non-empty slice of `bufs` (and
    /// `fds`) to [`Stream::write`]. If every slice is empty, or `bufs` is empty, it returns
    /// `Ok(0)` without calling `write`, so `fds` is left untouched.
    fn write_vectored(&self, bufs: &[IoSlice<'_>], fds: &mut Vec<RawFdContainer>) -> Result<usize> {
        match bufs.iter().find(|slice| !slice.is_empty()) {
            Some(first) => self.write(first, fds),
            None => Ok(0),
        }
    }
}
/// Handle wrapped by [`DefaultStream`] on Unix: an owned file descriptor.
#[cfg(unix)]
type DefaultStreamInner = RawFdContainer;

/// Handle wrapped by [`DefaultStream`] on non-Unix targets: a TCP stream.
#[cfg(not(unix))]
type DefaultStreamInner = TcpStream;

/// The stream used to talk to an X11 server, connected over TCP or (on Unix) a Unix socket.
#[derive(Debug)]
pub struct DefaultStream {
    inner: DefaultStreamInner,
}

/// Peer address used for xauth lookups: the address family and the raw address bytes.
type PeerAddr = (Family, Vec<u8>);
impl DefaultStream {
    /// Connect to the X11 server described by `addr`.
    ///
    /// On success returns the connected, non-blocking stream and the peer address to use
    /// for xauth lookups.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from establishing the connection. Address kinds that are not
    /// supported on this platform produce an `ErrorKind::Other` error.
    pub fn connect(addr: &ConnectAddress<'_>) -> Result<(Self, PeerAddr)> {
        match addr {
            ConnectAddress::Hostname(host, port) => {
                Self::from_tcp_stream(TcpStream::connect((*host, *port))?)
            }
            #[cfg(unix)]
            ConnectAddress::Socket(path) => {
                // Linux/Android: first try an abstract-namespace socket with the same name.
                // Any failure there is ignored and the filesystem socket is used instead.
                #[cfg(any(target_os = "linux", target_os = "android"))]
                if let Ok(inner) = connect_abstract_unix_stream(path.as_bytes()) {
                    return Ok((DefaultStream { inner }, peer_addr::local()));
                }
                Self::from_unix_stream(UnixStream::connect(path)?)
            }
            #[cfg(not(unix))]
            ConnectAddress::Socket(_) => Err(std::io::Error::new(
                std::io::ErrorKind::Other,
                "Unix domain sockets are not supported on Windows",
            )),
            _ => Err(std::io::Error::new(
                std::io::ErrorKind::Other,
                "The given address family is not implemented",
            )),
        }
    }

    /// Wrap an already-connected TCP stream.
    ///
    /// The peer address is queried for xauth, then the socket is switched to
    /// non-blocking mode.
    pub fn from_tcp_stream(stream: TcpStream) -> Result<(Self, PeerAddr)> {
        let peer = stream.peer_addr().map(|addr| peer_addr::tcp(&addr))?;
        stream.set_nonblocking(true)?;
        Ok((Self { inner: stream.into() }, peer))
    }

    /// Wrap an already-connected Unix stream, switching it to non-blocking mode.
    ///
    /// The returned peer address is the one produced for local connections.
    #[cfg(unix)]
    pub fn from_unix_stream(stream: UnixStream) -> Result<(Self, PeerAddr)> {
        stream.set_nonblocking(true)?;
        Ok((Self { inner: stream.into() }, peer_addr::local()))
    }

    /// Borrow the underlying descriptor/socket of this stream.
    fn as_fd(&self) -> BorrowedFd<'_> {
        self.inner.as_fd()
    }
}
#[cfg(unix)]
impl AsRawFd for DefaultStream {
fn as_raw_fd(&self) -> RawFd {
self.inner.as_raw_fd()
}
}
#[cfg(unix)]
impl AsFd for DefaultStream {
fn as_fd(&self) -> BorrowedFd<'_> {
self.inner.as_fd()
}
}
#[cfg(unix)]
impl IntoRawFd for DefaultStream {
fn into_raw_fd(self) -> RawFd {
self.inner.into_raw_fd()
}
}
#[cfg(unix)]
impl From<DefaultStream> for OwnedFd {
fn from(stream: DefaultStream) -> Self {
stream.inner
}
}
/// Exposes the raw handle of the underlying TCP socket; ownership is kept.
#[cfg(windows)]
impl AsRawSocket for DefaultStream {
    fn as_raw_socket(&self) -> RawSocket {
        let Self { inner } = self;
        inner.as_raw_socket()
    }
}

/// Borrows the underlying TCP socket.
#[cfg(windows)]
impl AsSocket for DefaultStream {
    fn as_socket(&self) -> BorrowedSocket<'_> {
        AsSocket::as_socket(&self.inner)
    }
}

/// Gives up ownership of the underlying socket handle.
#[cfg(windows)]
impl IntoRawSocket for DefaultStream {
    fn into_raw_socket(self) -> RawSocket {
        let Self { inner } = self;
        inner.into_raw_socket()
    }
}

/// Unwraps the stream into an owned socket.
#[cfg(windows)]
impl From<DefaultStream> for OwnedSocket {
    fn from(stream: DefaultStream) -> Self {
        OwnedSocket::from(stream.inner)
    }
}
/// Send `bufs` over `stream` with `sendmsg`, attaching every descriptor in `fds` as a
/// single `SCM_RIGHTS` control message.
///
/// Returns the byte count reported by `sendmsg`. `fds` is cleared, which drops the local
/// descriptors, only after `sendmsg` succeeded. On error it is left untouched.
#[cfg(unix)]
fn do_write(
    stream: &DefaultStream,
    bufs: &[IoSlice<'_>],
    fds: &mut Vec<RawFdContainer>,
) -> Result<usize> {
    use rustix::io::Errno;
    use rustix::net::{sendmsg, SendAncillaryBuffer, SendAncillaryMessage, SendFlags};
    use std::mem::MaybeUninit;

    /// `sendmsg` that retries for as long as the call fails with `EINTR`.
    fn sendmsg_retrying(
        fd: BorrowedFd<'_>,
        iov: &[IoSlice<'_>],
        cmsgs: &mut SendAncillaryBuffer<'_, '_, '_>,
        flags: SendFlags,
    ) -> Result<usize> {
        loop {
            match sendmsg(fd, iov, cmsgs, flags) {
                Err(Errno::INTR) => continue,
                outcome => return outcome.map_err(std::io::Error::from),
            }
        }
    }

    let socket = stream.as_fd();

    let sent = if fds.is_empty() {
        // No descriptors to pass: send the payload without any control message.
        sendmsg_retrying(socket, bufs, &mut SendAncillaryBuffer::default(), SendFlags::empty())?
    } else {
        let borrowed: Vec<BorrowedFd<'_>> = fds.iter().map(|owned| owned.as_fd()).collect();
        let rights = SendAncillaryMessage::ScmRights(&borrowed);
        // The control buffer is sized exactly for this one message, so the push must succeed.
        let mut space = vec![MaybeUninit::uninit(); rights.size()];
        let mut cmsg_buffer = SendAncillaryBuffer::new(&mut space);
        assert!(cmsg_buffer.push(rights));
        sendmsg_retrying(socket, bufs, &mut cmsg_buffer, SendFlags::empty())?
    };

    // The descriptors went out with this call; the local copies are no longer needed.
    fds.clear();
    Ok(sent)
}
impl Stream for DefaultStream {
    /// Wait, without a timeout (`None`), until the socket is ready for the I/O selected
    /// by `mode`.
    ///
    /// `poll` is retried when interrupted by a signal (`EINTR`). The returned events are
    /// not inspected. Error or hang-up conditions presumably surface on the next
    /// read/write.
    fn poll(&self, mode: PollMode) -> Result<()> {
        use rustix::event::{poll, PollFd, PollFlags};
        use rustix::io::Errno;
        // Translate the requested mode into poll(2) event flags.
        let mut poll_flags = PollFlags::empty();
        if mode.readable() {
            poll_flags |= PollFlags::IN;
        }
        if mode.writable() {
            poll_flags |= PollFlags::OUT;
        }
        let fd = self.as_fd();
        let mut poll_fds = [PollFd::from_borrowed_fd(fd, poll_flags)];
        loop {
            match poll(&mut poll_fds, None) {
                Ok(_) => break,
                // Interrupted by a signal: poll again.
                Err(Errno::INTR) => {}
                Err(e) => return Err(e.into()),
            }
        }
        Ok(())
    }

    /// Read bytes into `buf`, retrying on `EINTR` / `ErrorKind::Interrupted`.
    ///
    /// On Unix, descriptors received as `SCM_RIGHTS` ancillary data are appended to
    /// `fd_storage`. Close-on-exec is requested through `recvmsg::flags()` or applied by
    /// `recvmsg::after_recvmsg`, depending on the target. If applying close-on-exec
    /// reports an error, the received descriptors have already been stored. The error
    /// is then returned and the byte count of this read is not reported. On non-Unix
    /// targets `fd_storage` is not used.
    fn read(&self, buf: &mut [u8], fd_storage: &mut Vec<RawFdContainer>) -> Result<usize> {
        #[cfg(unix)]
        {
            use rustix::io::Errno;
            use rustix::net::{recvmsg, RecvAncillaryBuffer, RecvAncillaryMessage};
            use std::io::IoSliceMut;
            use std::mem::MaybeUninit;
            // Fixed-size stack buffer for ancillary (control) data.
            // NOTE(review): MSG_CTRUNC is not checked, so descriptors that do not fit in
            // 1024 bytes would presumably be lost silently.
            let mut cmsg = [MaybeUninit::uninit(); 1024];
            let mut iov = [IoSliceMut::new(buf)];
            let mut cmsg_buffer = RecvAncillaryBuffer::new(&mut cmsg);
            let fd = self.as_fd();
            let msg = loop {
                // `recvmsg::flags()` refers to the platform helper module below, while the
                // call itself is rustix's `recvmsg` function.
                match recvmsg(fd, &mut iov, &mut cmsg_buffer, recvmsg::flags()) {
                    Ok(msg) => break msg,
                    // Interrupted by a signal: try again.
                    Err(Errno::INTR) => {}
                    Err(e) => return Err(e.into()),
                }
            };
            // Keep only SCM_RIGHTS payloads (flattened into one descriptor iterator); any
            // other control message is ignored.
            let fds_received = cmsg_buffer
                .drain()
                .filter_map(|cmsg| match cmsg {
                    RecvAncillaryMessage::ScmRights(r) => Some(r),
                    _ => None,
                })
                .flatten();
            // `after_recvmsg` records a close-on-exec failure here while the iterator is
            // consumed by `extend`; it is checked only after all descriptors are stored.
            let mut cloexec_error = Ok(());
            fd_storage.extend(recvmsg::after_recvmsg(fds_received, &mut cloexec_error));
            cloexec_error?;
            Ok(msg.bytes)
        }
        #[cfg(not(unix))]
        {
            use std::io::Read;
            // Descriptor passing is not available here; the parameter is intentionally unused.
            let _ = fd_storage;
            loop {
                // `Read` is implemented for `&TcpStream`, so reading works through `&self`.
                match (&mut &self.inner).read(buf) {
                    Ok(n) => return Ok(n),
                    Err(ref e) if e.kind() == std::io::ErrorKind::Interrupted => {}
                    Err(e) => return Err(e),
                }
            }
        }
    }

    /// Write `buf`, sending `fds` along with it.
    ///
    /// On Unix this delegates to `do_write` with a single slice. Elsewhere descriptor
    /// passing is unsupported: a non-empty `fds` yields an `ErrorKind::Other` error.
    /// Otherwise the TCP stream is written, retrying on `ErrorKind::Interrupted`.
    fn write(&self, buf: &[u8], fds: &mut Vec<RawFdContainer>) -> Result<usize> {
        #[cfg(unix)]
        {
            do_write(self, &[IoSlice::new(buf)], fds)
        }
        #[cfg(not(unix))]
        {
            use std::io::{Error, ErrorKind, Write};
            if !fds.is_empty() {
                return Err(Error::new(ErrorKind::Other, "FD passing is unsupported"));
            }
            loop {
                match (&mut &self.inner).write(buf) {
                    Ok(n) => return Ok(n),
                    Err(ref e) if e.kind() == ErrorKind::Interrupted => {}
                    Err(e) => return Err(e),
                }
            }
        }
    }

    /// Vectored write that sends all of `bufs` in one call, instead of the trait's
    /// first-slice-only default.
    ///
    /// The Unix and non-Unix paths mirror [`Stream::write`], using `do_write` /
    /// `Write::write_vectored` respectively.
    fn write_vectored(&self, bufs: &[IoSlice<'_>], fds: &mut Vec<RawFdContainer>) -> Result<usize> {
        #[cfg(unix)]
        {
            do_write(self, bufs, fds)
        }
        #[cfg(not(unix))]
        {
            use std::io::{Error, ErrorKind, Write};
            if !fds.is_empty() {
                return Err(Error::new(ErrorKind::Other, "FD passing is unsupported"));
            }
            loop {
                match (&mut &self.inner).write_vectored(bufs) {
                    Ok(n) => return Ok(n),
                    Err(ref e) if e.kind() == ErrorKind::Interrupted => {}
                    Err(e) => return Err(e),
                }
            }
        }
    }
}
/// Connect to the abstract-namespace Unix stream socket named `path`.
///
/// The socket is created with close-on-exec set. It is switched to non-blocking mode only
/// after `connect` has returned, so the connect call itself runs on a blocking socket.
///
/// # Errors
///
/// Returns the `Errno` of the first failing step: socket creation, address construction,
/// `connect`, or reading/updating the file status flags.
#[cfg(any(target_os = "linux", target_os = "android"))]
fn connect_abstract_unix_stream(
    path: &[u8],
) -> std::result::Result<RawFdContainer, rustix::io::Errno> {
    use rustix::fs::{fcntl_getfl, fcntl_setfl, OFlags};
    use rustix::net::{
        connect, socket_with, AddressFamily, SocketAddrUnix, SocketFlags, SocketType,
    };

    let fd = socket_with(
        AddressFamily::UNIX,
        SocketType::STREAM,
        SocketFlags::CLOEXEC,
        None,
    )?;
    let address = SocketAddrUnix::new_abstract_name(path)?;
    connect(&fd, &address)?;

    let status_flags = fcntl_getfl(&fd)?;
    fcntl_setfl(&fd, status_flags | OFlags::NONBLOCK)?;
    Ok(fd)
}
/// `recvmsg` helpers for targets where close-on-exec can be requested via receive flags.
#[cfg(any(
    target_os = "android",
    target_os = "dragonfly",
    target_os = "freebsd",
    target_os = "linux",
    target_os = "netbsd",
    target_os = "openbsd"
))]
mod recvmsg {
    use super::RawFdContainer;
    use rustix::net::RecvFlags;

    /// Receive flags asking for received descriptors to be marked close-on-exec.
    pub(crate) fn flags() -> RecvFlags {
        RecvFlags::CMSG_CLOEXEC
    }

    /// Post-process received descriptors.
    ///
    /// There is nothing to do here, because [`flags`] already requested close-on-exec.
    /// The descriptors are passed through unchanged and `_cloexec_error` is never written.
    pub(crate) fn after_recvmsg<'a>(
        fds: impl Iterator<Item = RawFdContainer> + 'a,
        _cloexec_error: &'a mut Result<(), rustix::io::Errno>,
    ) -> impl Iterator<Item = RawFdContainer> + 'a {
        fds
    }
}

/// `recvmsg` helpers for the remaining Unix targets: close-on-exec is set after receipt.
#[cfg(all(
    unix,
    not(any(
        target_os = "android",
        target_os = "dragonfly",
        target_os = "freebsd",
        target_os = "linux",
        target_os = "netbsd",
        target_os = "openbsd"
    ))
))]
mod recvmsg {
    use super::RawFdContainer;
    use rustix::io::{fcntl_getfd, fcntl_setfd, FdFlags};
    use rustix::net::RecvFlags;

    /// No special receive flags; close-on-exec is applied by [`after_recvmsg`] instead.
    pub(crate) fn flags() -> RecvFlags {
        RecvFlags::empty()
    }

    /// Set `FD_CLOEXEC` on each descriptor as the returned iterator is consumed.
    ///
    /// Every descriptor is yielded even if updating its flags fails. A failure is stored
    /// in `cloexec_error`, and a later failure overwrites an earlier one.
    pub(crate) fn after_recvmsg<'a>(
        fds: impl Iterator<Item = RawFdContainer> + 'a,
        cloexec_error: &'a mut rustix::io::Result<()>,
    ) -> impl Iterator<Item = RawFdContainer> + 'a {
        fds.inspect(move |fd| {
            let outcome =
                fcntl_getfd(fd).and_then(|current| fcntl_setfd(fd, current | FdFlags::CLOEXEC));
            if let Err(errno) = outcome {
                *cloexec_error = Err(errno);
            }
        })
    }
}
/// Helpers that compute the xauth peer address (`(Family, address bytes)`) of a connection.
mod peer_addr {
    use super::{Family, PeerAddr};
    use std::net::{Ipv4Addr, SocketAddr};

    /// Peer address for a local connection: `Family::LOCAL` plus this machine's hostname.
    ///
    /// On Unix the hostname's raw bytes are used unchanged, so a hostname that is not valid
    /// UTF-8 is still reported, and can still match the hostname recorded for the local
    /// machine. On other platforms a hostname that is not valid UTF-8 yields an empty
    /// address.
    pub(super) fn local() -> PeerAddr {
        let hostname = gethostname::gethostname();
        #[cfg(unix)]
        let hostname = {
            use std::os::unix::ffi::OsStringExt;
            hostname.into_vec()
        };
        #[cfg(not(unix))]
        let hostname = hostname
            .to_str()
            .map_or_else(Vec::new, |s| s.as_bytes().to_vec());
        (Family::LOCAL, hostname)
    }

    /// Peer address for a TCP connection to `addr`.
    ///
    /// Loopback peers (IPv4 or IPv6) are treated as local connections (see [`local`]).
    /// IPv6 addresses that `Ipv6Addr::to_ipv4` converts to IPv4 are handled as IPv4 and
    /// reported as `Family::INTERNET` with 4 octets. All other IPv6 addresses are reported
    /// as `Family::INTERNET6` with 16 octets.
    pub(super) fn tcp(addr: &SocketAddr) -> PeerAddr {
        let ip = match addr {
            SocketAddr::V4(addr) => *addr.ip(),
            SocketAddr::V6(addr) => {
                let ip = addr.ip();
                if ip.is_loopback() {
                    // Map `::1` to the IPv4 loopback so the local case below applies.
                    Ipv4Addr::LOCALHOST
                } else if let Some(ip) = ip.to_ipv4() {
                    // Let the IPv4 handling below deal with this address.
                    ip
                } else {
                    // A genuine IPv6 peer.
                    return (Family::INTERNET6, ip.octets().to_vec());
                }
            }
        };
        if ip.is_loopback() {
            local()
        } else {
            (Family::INTERNET, ip.octets().to_vec())
        }
    }
}