#![cfg(feature = "std")]
use super::Connection;
use crate::{Error, Fd, Result, ResultExt, Unsupported};
use alloc::vec::Vec;
use core::any::type_name;
use core::borrow::{Borrow, BorrowMut};
use core::fmt;
use core::ops::{Deref, DerefMut};
use std::io::{IoSlice, IoSliceMut, Read, Write};
cfg_std_unix! {
use nix::sys::socket;
use std::os::unix::io::{AsRawFd, RawFd};
}
cfg_std_windows! {
use std::{io, net, os::windows::io::{AsRawSocket, RawSocket}};
}
/// A [`Connection`] adapter over a standard-library style byte stream.
///
/// `StdConnection` owns a value `C` (typically something implementing `Read` + `Write`)
/// and forwards connection I/O to it. The wrapper is `#[repr(transparent)]`, so it has
/// the same layout as `C`, and every derived comparison/hashing/defaulting impl simply
/// delegates to the single wrapped field.
#[repr(transparent)]
#[derive(Clone, Copy, Default)]
#[derive(Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct StdConnection<C: ?Sized> {
    /// The wrapped stream.
    inner: C,
}
/// Formats exactly like the wrapped value, so the wrapper is invisible in debug output.
///
/// The bound includes `?Sized` to match the struct definition (and `get_ref`/`AsRef`),
/// so unsized payloads behind a pointer, e.g. `Box<StdConnection<[u8]>>`, can also be
/// debug-printed.
impl<C: fmt::Debug + ?Sized> fmt::Debug for StdConnection<C> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(&self.inner, f)
    }
}
impl<C> StdConnection<C> {
pub fn new(inner: C) -> Self {
Self { inner }
}
pub fn into_inner(self) -> C {
self.inner
}
}
impl<C: ?Sized> StdConnection<C> {
pub fn get_ref(&self) -> &C {
&self.inner
}
pub fn get_mut(&mut self) -> &mut C {
&mut self.inner
}
}
/// Wraps a value as-is; no I/O or validation takes place.
impl<C> From<C> for StdConnection<C> {
    fn from(value: C) -> Self {
        Self { inner: value }
    }
}

/// Views the wrapped value through a shared reference.
impl<C: ?Sized> AsRef<C> for StdConnection<C> {
    fn as_ref(&self) -> &C {
        &self.inner
    }
}

/// Views the wrapped value through an exclusive reference.
impl<C: ?Sized> AsMut<C> for StdConnection<C> {
    fn as_mut(&mut self) -> &mut C {
        &mut self.inner
    }
}
/// Borrows the wrapped value.
///
/// Accepts `C: ?Sized`, like the struct itself and the `AsRef`/`AsMut` impls, so unsized
/// payloads (e.g. `Box<StdConnection<[u8]>>`) can be borrowed as well. The derived `Eq`,
/// `Ord` and `Hash` impls of this single-field wrapper compare and hash exactly `inner`,
/// which keeps them consistent with the borrowed form, as the `Borrow` contract requires.
impl<C: ?Sized> Borrow<C> for StdConnection<C> {
    fn borrow(&self) -> &C {
        &self.inner
    }
}

/// Mutably borrows the wrapped value (also for unsized `C`).
impl<C: ?Sized> BorrowMut<C> for StdConnection<C> {
    fn borrow_mut(&mut self) -> &mut C {
        &mut self.inner
    }
}
/// Dereferences to the wrapped value, so its methods are callable directly on the wrapper.
///
/// Accepts `C: ?Sized`, matching the struct definition and the `AsRef`/`AsMut` impls, so
/// unsized connections behind a pointer (e.g. `Box<StdConnection<[u8]>>`) dereference too.
impl<C: ?Sized> Deref for StdConnection<C> {
    type Target = C;

    fn deref(&self) -> &C {
        &self.inner
    }
}

/// Mutable counterpart of the `Deref` impl (also for unsized `C`).
impl<C: ?Sized> DerefMut for StdConnection<C> {
    fn deref_mut(&mut self) -> &mut C {
        &mut self.inner
    }
}
cfg_std_unix! {
    // Expose the wrapped value's file descriptor, so the wrapper can be handed to any API
    // that works on raw descriptors.
    impl<C: AsRawFd + ?Sized> AsRawFd for StdConnection<C> {
        fn as_raw_fd(&self) -> RawFd {
            C::as_raw_fd(&self.inner)
        }
    }
}
cfg_std_windows! {
    // Windows counterpart: expose the wrapped value's raw socket handle.
    impl<C: AsRawSocket + ?Sized> AsRawSocket for StdConnection<C> {
        fn as_raw_socket(&self) -> RawSocket {
            C::as_raw_socket(&self.inner)
        }
    }
}
// Expands to the `Connection` methods whose implementation is the same on every OS:
// blocking send, receive and flush, each forwarded to the wrapped value's `Read` /
// `Write` impls.
//
// The macro input is the borrow operator written in front of `self.inner` at every I/O
// call. The invocations in this file pass `&mut` (owned `StdConnection<C>`, using
// `C: Read + Write`) or `&` (`&StdConnection<C>`, using `&C: Read + Write`), so one set
// of bodies serves both impls.
//
// Each method enters a trace-level span, tagged with `type_name::<Self>()`, that is held
// until the method returns.
// NOTE(review): `tracing` span names are constant strings, not format strings, so
// `{tyname}` appears verbatim in each span name and the concrete type is only visible
// through the `tyname` field. Confirm this is intended.
macro_rules! impl_non_os_specific_items {
($($inner: tt)*) => {
// Sends `slices` with a single `write_vectored` call and returns the byte count it
// reports, which may be less than the combined length of `slices`.
// This wrapper never transmits file descriptors: a non-empty `fds` is logged and
// rejected with `Unsupported::Fds` before anything is written.
fn send_slices_and_fds(
&mut self,
slices: &[IoSlice<'_>],
fds: &mut Vec<Fd>,
) -> Result<usize> {
let span = tracing::trace_span!(
"{tyname}::send_slices_and_fds",
tyname = type_name::<Self>()
);
let _enter = span.enter();
if !fds.is_empty() {
tracing::error!("Attempted to send fds with non-unix connection");
return Err(Error::make_unsupported(Unsupported::Fds));
}
($($inner)* self.inner)
.write_vectored(slices)
.map_err(Error::io)
// `ResultExt::trace` is defined elsewhere; presumably it runs the closure only on `Ok`.
.trace(|amt| {
tracing::trace!("Sent {} bytes", amt);
})
}
// Same as `send_slices_and_fds` without the descriptor check: one `write_vectored`
// call, returning the byte count written.
fn send_slices(&mut self, slices: &[IoSlice<'_>]) -> Result<usize> {
let span = tracing::trace_span!(
"{tyname}::send_slices",
tyname = type_name::<Self>()
);
let _enter = span.enter();
($($inner)* self.inner)
.write_vectored(slices)
.map_err(Error::io)
.trace(|amt| {
tracing::trace!("Sent {} bytes", amt);
})
}
// Single-buffer send: one `write` call, which may write only part of `slice`.
fn send_slice(&mut self, slice: &[u8]) -> Result<usize> {
let span = tracing::trace_span!(
"{tyname}::send_slice",
tyname = type_name::<Self>()
);
let _enter = span.enter();
($($inner)* self.inner)
.write(slice)
.map_err(Error::io)
.trace(|amt| {
tracing::trace!("Sent {} bytes", amt);
})
}
// Blocking vectored receive via a single `read_vectored` call. Descriptors are never
// received on this path, so `_fds` is left untouched.
fn recv_slices_and_fds(
&mut self,
slices: &mut [IoSliceMut<'_>],
_fds: &mut Vec<Fd>,
) -> Result<usize> {
let span = tracing::trace_span!(
"{tyname}::recv_slices_and_fds",
tyname = type_name::<Self>()
);
let _enter = span.enter();
($($inner)* self.inner)
.read_vectored(slices)
.map_err(Error::io)
.trace(|amt| {
tracing::trace!("Received {} bytes", amt);
})
}
// Blocking single-buffer receive via one `read` call; `_fds` is left untouched.
fn recv_slice_and_fds(&mut self, slice: &mut [u8], _fds: &mut Vec<Fd>) -> Result<usize> {
let span = tracing::trace_span!(
"{tyname}::recv_slice_and_fds",
tyname = type_name::<Self>()
);
let _enter = span.enter();
($($inner)* self.inner)
.read(slice)
.map_err(Error::io)
.trace(|amt| {
tracing::trace!("Received {} bytes", amt);
})
}
// Blocking single-buffer receive via one `read` call.
fn recv_slice(&mut self, slice: &mut [u8]) -> Result<usize> {
let span = tracing::trace_span!(
"{tyname}::recv_slice",
tyname = type_name::<Self>()
);
let _enter = span.enter();
($($inner)* self.inner)
.read(slice)
.map_err(Error::io)
.trace(|amt| {
tracing::trace!("Received {} bytes", amt);
})
}
// Flushes the wrapped writer and maps any I/O error through `Error::io`. Unlike the
// methods above, it emits no success trace.
fn flush(&mut self) -> Result<()> {
let span = tracing::trace_span!(
"{tyname}::flush",
tyname = type_name::<Self>()
);
let _enter = span.enter();
($($inner)* self.inner).flush().map_err(Error::io)
}
};
}
// Unix-only `Connection` impls.
cfg_std_unix! {
// Expands to the complete Unix `Connection` method set:
// - the OS-independent methods from `impl_non_os_specific_items!`, to which the borrow
//   operator `$inner` is forwarded;
// - non-blocking receives and shutdown, which call `nix` directly on the wrapped value's
//   raw file descriptor.
// The fd-based methods reach the descriptor with `self.inner.as_raw_fd()` and so do not
// use `$inner` themselves.
//
// NOTE(review): `recv`, `recvmsg` and `shutdown` operate on sockets, but the impls below
// only bound `C: AsRawFd`. Wrapping a non-socket descriptor (e.g. a pipe) would make
// these calls fail at runtime. Confirm wrapped types are always sockets.
macro_rules! impl_items_unix {
($($inner: tt)*) => {
impl_non_os_specific_items! { $($inner)* }
// Non-blocking vectored receive: a single `recvmsg` with `MSG_DONTWAIT`. When nothing
// is queued, the call returns an error instead of waiting (presumably `EAGAIN` /
// `EWOULDBLOCK`, wrapped by `Error::nix`). No control-message buffer is supplied
// (`None`), so descriptors are never received and `_fds` is left untouched.
fn non_blocking_recv_slices_and_fds(
&mut self,
slices: &mut [IoSliceMut<'_>],
_fds: &mut Vec<Fd>,
) -> Result<usize> {
let span = tracing::trace_span!(
"{tyname}::non_blocking_recv_slices_and_fds",
tyname = type_name::<Self>()
);
let _enter = span.enter();
let raw_fd = self.inner.as_raw_fd();
let msg = socket::recvmsg::<()>(
raw_fd,
slices,
None,
socket::MsgFlags::MSG_DONTWAIT,
).map_err(Error::nix)?;
tracing::trace!("Received {} bytes", msg.bytes);
Ok(msg.bytes)
}
// Single-buffer variant of the above, using `recv` with `MSG_DONTWAIT`. `_fds` is
// likewise left untouched.
fn non_blocking_recv_slice_and_fds(
&mut self,
slice: &mut [u8],
_fds: &mut Vec<Fd>,
) -> Result<usize> {
let span = tracing::trace_span!(
"{tyname}::non_blocking_recv_slice_and_fds",
tyname = type_name::<Self>()
);
let _enter = span.enter();
let raw_fd = self.inner.as_raw_fd();
socket::recv(
raw_fd,
slice,
socket::MsgFlags::MSG_DONTWAIT
).map_err(Error::nix)
.trace(|amt| {
tracing::trace!("Received {} bytes", amt);
})
}
// Shuts down both the read and the write half of the socket (`Shutdown::Both`).
fn shutdown(&self) -> Result<()> {
let span = tracing::trace_span!(
"{tyname}::shutdown",
tyname = type_name::<Self>()
);
let _enter = span.enter();
let raw_fd = self.inner.as_raw_fd();
socket::shutdown(raw_fd, socket::Shutdown::Both).map_err(Error::nix)
}
}
}
// Owned form: I/O goes through `&mut self.inner`, so `C` itself must be `Read + Write`.
// `?Sized` permits unsized connection types behind a pointer.
impl<C: Read + Write + AsRawFd + ?Sized> Connection for StdConnection<C> {
impl_items_unix! { &mut }
}
// Shared form: code holding only `&StdConnection<C>` can drive the connection when `&C`
// implements `Read + Write` (as std's `&UnixStream` does).
impl<'a, C: AsRawFd + ?Sized> Connection for &'a StdConnection<C>
where &'a C: Read + Write
{
impl_items_unix! { & }
}
}
cfg_std_windows! {
    // Windows `Connection` impls. Non-blocking receives are emulated here: ask how many
    // bytes are already queued on the socket (via the `fionread` helper, presumably the
    // `FIONREAD` ioctl judging by its name). Fall through to the ordinary blocking
    // receive only when that count is non-zero; otherwise report `WouldBlock`.
    //
    // NOTE(review): if the queued-byte count is also 0 after the peer has closed the
    // stream, these methods report `WouldBlock` at EOF instead of `Ok(0)`. Confirm that
    // callers have another way to notice a closed connection.
    impl<C: Read + Write + AsRawSocket> Connection for StdConnection<C> {
        impl_non_os_specific_items! { &mut }

        fn non_blocking_recv_slices_and_fds(
            &mut self,
            slices: &mut [IoSliceMut<'_>],
            fds: &mut Vec<Fd>,
        ) -> Result<usize> {
            let queued = fionread::fionread(&*self).map_err(Error::io)?;
            if queued != 0 {
                // Data is already waiting, so the blocking receive returns promptly.
                return self.recv_slices_and_fds(slices, fds);
            }
            Err(Error::io(io::ErrorKind::WouldBlock.into()))
        }

        fn shutdown(&self) -> Result<()> {
            let socket = socket2::SockRef::from(self);
            socket.shutdown(net::Shutdown::Both).map_err(Error::io)
        }
    }

    // Shared-reference form, usable whenever `&C` implements `Read + Write`.
    impl<'a, C: AsRawSocket> Connection for &'a StdConnection<C>
    where
        &'a C: Read + Write,
    {
        impl_non_os_specific_items! { & }

        fn non_blocking_recv_slices_and_fds(
            &mut self,
            slices: &mut [IoSliceMut<'_>],
            fds: &mut Vec<Fd>,
        ) -> Result<usize> {
            let queued = fionread::fionread(*self).map_err(Error::io)?;
            if queued != 0 {
                // Data is already waiting, so the blocking receive returns promptly.
                return self.recv_slices_and_fds(slices, fds);
            }
            Err(Error::io(io::ErrorKind::WouldBlock.into()))
        }

        fn shutdown(&self) -> Result<()> {
            let socket = socket2::SockRef::from(*self);
            socket.shutdown(net::Shutdown::Both).map_err(Error::io)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::StdConnection;
    use crate::connection::Connection;
    use std::io::{IoSliceMut, Read, Write};
    #[cfg(unix)]
    use std::os::unix::io::AsRawFd;
    #[cfg(windows)]
    use std::os::windows::io::AsRawSocket;

    /// Creates a connected pair of local stream sockets.
    #[cfg(unix)]
    fn pair() -> (impl Read + Write + AsRawFd, impl Read + Write + AsRawFd) {
        std::os::unix::net::UnixStream::pair().unwrap()
    }

    /// Creates a connected pair of local stream sockets.
    #[cfg(windows)]
    fn pair() -> (
        impl Read + Write + AsRawSocket,
        impl Read + Write + AsRawSocket,
    ) {
        uds_windows::UnixStream::pair().unwrap()
    }

    /// Bytes sent on one end arrive intact on the other. Both byte counts are checked, so
    /// a short write or read is reported as such rather than as a content mismatch.
    #[test]
    fn basic() {
        let (left, right) = pair();
        let mut writer = StdConnection::new(left);
        let mut reader = StdConnection::new(right);
        let data = b"Hello, world!";
        assert_eq!(writer.send_slice(data).unwrap(), data.len());
        let mut buffer = [0u8; 13];
        assert_eq!(reader.recv_slice(&mut buffer).unwrap(), data.len());
        assert_eq!(buffer, *data);
    }

    /// The non-blocking receive fails instead of waiting when nothing is queued, and
    /// returns the queued bytes (without any descriptors) once data has been sent.
    #[test]
    fn non_blocking_recv() {
        let (left, right) = pair();
        let mut writer = StdConnection::new(left);
        let mut reader = StdConnection::new(right);
        let mut buffer = [0u8; 4];
        let mut fds = std::vec::Vec::new();

        assert!(reader
            .non_blocking_recv_slices_and_fds(&mut [IoSliceMut::new(&mut buffer)], &mut fds)
            .is_err());

        assert_eq!(writer.send_slice(b"ping").unwrap(), 4);
        let received = reader
            .non_blocking_recv_slices_and_fds(&mut [IoSliceMut::new(&mut buffer)], &mut fds)
            .unwrap();
        assert_eq!(&buffer[..received], b"ping");
        assert!(fds.is_empty());
    }

    /// The `Connection` impl for `&StdConnection<C>` works through shared references.
    #[cfg(unix)]
    #[test]
    fn shared_reference() {
        let (left, right) = std::os::unix::net::UnixStream::pair().unwrap();
        let writer = StdConnection::new(left);
        let reader = StdConnection::new(right);
        let data = b"shared";
        assert_eq!((&writer).send_slice(data).unwrap(), data.len());
        let mut buffer = [0u8; 6];
        assert_eq!((&reader).recv_slice(&mut buffer).unwrap(), data.len());
        assert_eq!(buffer, *data);
    }
}