use filedesc::FileDesc;
use std::io::{IoSlice, IoSliceMut};
use std::os::raw::{c_int, c_void};
use std::os::unix::io::{RawFd, AsRawFd, IntoRawFd, FromRawFd};
use crate::AsSocketAddress;
use crate::ancillary::SocketAncillary;
/// A socket whose address type is fixed at compile time by the `Address` parameter.
///
/// The type parameter only selects which address type the methods accept and return;
/// no `Address` value is stored inside the socket.
pub struct Socket<Address> {
/// The underlying file descriptor, wrapped in a `FileDesc`.
fd: FileDesc,
/// Zero-sized marker tying the socket to `Address`.
/// `fn() -> Address` is used so that the marker does not behave as if the socket owned an `Address`.
_address: std::marker::PhantomData<fn() -> Address>,
}
/// Flags that are OR-ed into the caller supplied flags of send and receive calls.
///
/// Platforms with `MSG_NOSIGNAL` / `MSG_CMSG_CLOEXEC` get those flags.
///
/// Apple platforms are selected with `target_vendor = "apple"`: there is no
/// `target_os = "apple"` (the OS values are `macos`, `ios`, ...), so a predicate
/// on that value never matches.
#[cfg(not(any(target_vendor = "apple", target_os = "solaris")))]
mod extra_flags {
    /// Extra flags for `send`, `sendto` and `sendmsg`.
    pub const SENDMSG: std::os::raw::c_int = libc::MSG_NOSIGNAL;
    /// Extra flags for `recv` and `recvmsg`.
    pub const RECVMSG: std::os::raw::c_int = libc::MSG_CMSG_CLOEXEC;
}

/// Flags that are OR-ed into the caller supplied flags of send and receive calls.
///
/// On Apple platforms and Solaris no extra flags are added.
#[cfg(any(target_vendor = "apple", target_os = "solaris"))]
mod extra_flags {
    /// Extra flags for `send`, `sendto` and `sendmsg` (none on these platforms).
    pub const SENDMSG: std::os::raw::c_int = 0;
    /// Extra flags for `recv` and `recvmsg` (none on these platforms).
    pub const RECVMSG: std::os::raw::c_int = 0;
}
impl<Address: AsSocketAddress> Socket<Address> {
/// Wrap a file descriptor in a `Socket` and apply platform specific socket setup.
///
/// On Apple platforms this enables `SO_NOSIGPIPE` on the socket.
/// The platform check uses `target_vendor = "apple"`, because `target_os` is never
/// `"apple"` (it is `macos`, `ios`, ...), which would leave the option unset.
///
/// # Errors
/// Returns the OS error if setting the socket option fails.
fn wrap(fd: FileDesc) -> std::io::Result<Self> {
    let wrapped = Self {
        fd,
        _address: std::marker::PhantomData,
    };

    #[cfg(target_vendor = "apple")]
    wrapped.set_option(libc::SOL_SOCKET, libc::SO_NOSIGPIPE, 1 as c_int)?;

    Ok(wrapped)
}
/// Create a new socket in the address family of `Address`.
///
/// The domain is taken from `Address::static_family()`; `kind` and `protocol`
/// are forwarded unchanged to [`Self::new_generic`].
pub fn new(kind: c_int, protocol: c_int) -> std::io::Result<Self>
where
    Address: crate::SpecificSocketAddress,
{
    let domain = Address::static_family() as c_int;
    Self::new_generic(domain, kind, protocol)
}
pub fn new_generic(domain: c_int, kind: c_int, protocol: c_int) -> std::io::Result<Self> {
socket(domain, kind | libc::SOCK_CLOEXEC, protocol)
.or_else(|e| {
if e.raw_os_error() == Some(libc::EINVAL) {
let fd = socket(domain, kind, protocol)?;
fd.set_close_on_exec(true)?;
Ok(fd)
} else {
Err(e)
}
})
.and_then(Self::wrap)
}
/// Create a pair of connected sockets in the address family of `Address`.
///
/// The domain is taken from `Address::static_family()`; `kind` and `protocol`
/// are forwarded unchanged to [`Self::pair_generic`].
pub fn pair(kind: c_int, protocol: c_int) -> std::io::Result<(Self, Self)>
where
    Address: crate::SpecificSocketAddress,
{
    let domain = Address::static_family() as c_int;
    Self::pair_generic(domain, kind, protocol)
}
pub fn pair_generic(domain: c_int, kind: c_int, protocol: c_int) -> std::io::Result<(Self, Self)> {
socketpair(domain, kind | libc::SOCK_CLOEXEC, protocol)
.or_else(|e| {
if e.raw_os_error() == Some(libc::EINVAL) {
let (a, b) = socketpair(domain, kind, protocol)?;
a.set_close_on_exec(true)?;
b.set_close_on_exec(true)?;
Ok((a, b))
} else {
Err(e)
}
})
.and_then(|(a, b)| {
Ok((Self::wrap(a)?, Self::wrap(b)?))
})
}
/// Create a second `Socket` from a duplicate of this socket's file descriptor.
///
/// # Errors
/// Returns the error reported by `FileDesc::duplicate`.
pub fn try_clone(&self) -> std::io::Result<Self> {
    let fd = self.fd.duplicate()?;
    let _address = std::marker::PhantomData;
    Ok(Self { fd, _address })
}
/// Wrap a raw file descriptor in a `Socket`.
///
/// No socket options are changed; the descriptor is only wrapped.
///
/// # Safety
/// The descriptor is handed to `FileDesc::from_raw_fd`, so the caller must uphold
/// the requirements of that function.
pub unsafe fn from_raw_fd(fd: RawFd) -> Self {
    let fd = FileDesc::from_raw_fd(fd);
    Self {
        fd,
        _address: std::marker::PhantomData,
    }
}
pub fn as_raw_fd(&self) -> RawFd {
self.fd.as_raw_fd()
}
pub fn into_raw_fd(self) -> RawFd {
self.fd.into_raw_fd()
}
/// Set a socket option with `setsockopt`.
///
/// `value` is passed to the kernel as `size_of::<T>()` raw bytes.
///
/// # Errors
/// Returns the OS error if `setsockopt` reports failure.
fn set_option<T: Copy>(&self, level: c_int, option: c_int, value: T) -> std::io::Result<()> {
    let value_ptr = (&value as *const T).cast::<c_void>();
    let value_len = std::mem::size_of::<T>() as libc::socklen_t;
    // SAFETY: `value_ptr` points at `value`, which lives until the end of this
    // function and is exactly `value_len` bytes large.
    let ret = unsafe { libc::setsockopt(self.as_raw_fd(), level, option, value_ptr, value_len) };
    check_ret(ret).map(drop)
}
/// Read a socket option with `getsockopt`.
///
/// The output buffer is zero-initialised and `size_of::<T>()` bytes large.
///
/// # Errors
/// Returns the OS error if `getsockopt` reports failure.
///
/// # Panics
/// Panics if the length reported back by `getsockopt` differs from `size_of::<T>()`.
fn get_option<T: Copy>(&self, level: c_int, option: c_int) -> std::io::Result<T> {
    let expected_len = std::mem::size_of::<T>() as libc::socklen_t;
    let mut reported_len = expected_len;
    let mut storage = std::mem::MaybeUninit::<T>::zeroed();
    unsafe {
        let storage_ptr = storage.as_mut_ptr().cast::<c_void>();
        check_ret(libc::getsockopt(self.as_raw_fd(), level, option, storage_ptr, &mut reported_len))?;
        assert_eq!(reported_len, expected_len);
        Ok(storage.assume_init())
    }
}
pub fn set_nonblocking(&self, non_blocking: bool) -> std::io::Result<()> {
self.set_option(libc::SOL_SOCKET, libc::O_NONBLOCK, bool_to_c_int(non_blocking))
}
pub fn get_nonblocking(&self) -> std::io::Result<bool> {
let raw: c_int = self.get_option(libc::SOL_SOCKET, libc::O_NONBLOCK)?;
Ok(raw != 0)
}
/// Read the `SO_ERROR` option of the socket.
///
/// Returns `Ok(None)` when the option value is zero, and the value converted to an
/// [`std::io::Error`] otherwise. The outer `Result` reports failure to read the option.
pub fn take_error(&self) -> std::io::Result<Option<std::io::Error>> {
    let raw: c_int = self.get_option(libc::SOL_SOCKET, libc::SO_ERROR)?;
    Ok(match raw {
        0 => None,
        code => Some(std::io::Error::from_raw_os_error(code)),
    })
}
/// Get the local address of the socket using `getsockname`.
///
/// The address storage starts out zeroed and is converted into an `Address`
/// by `Address::finalize` together with the length reported by the kernel.
pub fn local_addr(&self) -> std::io::Result<Address> {
    unsafe {
        let mut storage = std::mem::MaybeUninit::<Address>::zeroed();
        let mut storage_len = Address::max_len();
        let ret = libc::getsockname(self.as_raw_fd(), Address::as_sockaddr_mut(&mut storage), &mut storage_len);
        check_ret(ret)?;
        Address::finalize(storage, storage_len)
    }
}
/// Get the address of the connected peer using `getpeername`.
///
/// The address storage starts out zeroed and is converted into an `Address`
/// by `Address::finalize` together with the length reported by the kernel.
pub fn peer_addr(&self) -> std::io::Result<Address> {
    unsafe {
        let mut storage = std::mem::MaybeUninit::<Address>::zeroed();
        let mut storage_len = Address::max_len();
        let ret = libc::getpeername(self.as_raw_fd(), Address::as_sockaddr_mut(&mut storage), &mut storage_len);
        check_ret(ret)?;
        Address::finalize(storage, storage_len)
    }
}
/// Connect the socket to `address` using `connect`.
///
/// # Errors
/// Returns the OS error if `connect` reports failure.
pub fn connect(&self, address: &Address) -> std::io::Result<()> {
    unsafe {
        let ret = libc::connect(self.as_raw_fd(), address.as_sockaddr(), address.len());
        check_ret(ret).map(drop)
    }
}
/// Bind the socket to `address` using `bind`.
///
/// # Errors
/// Returns the OS error if `bind` reports failure.
pub fn bind(&self, address: &Address) -> std::io::Result<()> {
    unsafe {
        let ret = libc::bind(self.as_raw_fd(), address.as_sockaddr(), address.len());
        check_ret(ret).map(drop)
    }
}
/// Call `listen` on the socket with the given `backlog`.
///
/// # Errors
/// Returns the OS error if `listen` reports failure.
pub fn listen(&self, backlog: c_int) -> std::io::Result<()> {
    // SAFETY: `listen` only takes integer arguments.
    let ret = unsafe { libc::listen(self.as_raw_fd(), backlog) };
    check_ret(ret).map(drop)
}
/// Accept a connection with `accept4`, requesting `SOCK_CLOEXEC` for the new descriptor.
///
/// Returns the accepted socket together with the peer address written by the kernel.
///
/// # Errors
/// Returns the OS error if `accept4` fails, or any error from wrapping the new
/// descriptor or from `Address::finalize`.
pub fn accept(&self) -> std::io::Result<(Self, Address)> {
    unsafe {
        let mut peer = std::mem::MaybeUninit::zeroed();
        let mut peer_len = Address::max_len();
        let raw_fd = libc::accept4(self.as_raw_fd(), Address::as_sockaddr_mut(&mut peer), &mut peer_len, libc::SOCK_CLOEXEC);
        let raw_fd = check_ret(raw_fd)?;
        // Wrap the descriptor first, so that it is closed again if the address is rejected.
        let connection = Self::wrap(FileDesc::from_raw_fd(raw_fd))?;
        let peer = Address::finalize(peer, peer_len)?;
        Ok((connection, peer))
    }
}
/// Send `data` over the socket with `send`.
///
/// `extra_flags::SENDMSG` is OR-ed into `flags`.
/// Returns the number of bytes reported as transferred by the kernel.
pub fn send(&self, data: &[u8], flags: c_int) -> std::io::Result<usize> {
    let all_flags = flags | extra_flags::SENDMSG;
    // SAFETY: the pointer and length describe the borrowed `data` slice.
    let sent = unsafe { libc::send(self.as_raw_fd(), data.as_ptr().cast::<c_void>(), data.len(), all_flags) };
    check_ret_isize(sent).map(|count| count as usize)
}
/// Send `data` to `address` with `sendto`.
///
/// `extra_flags::SENDMSG` is OR-ed into `flags`.
/// Returns the number of bytes reported as transferred by the kernel.
pub fn send_to(&self, data: &[u8], address: &Address, flags: c_int) -> std::io::Result<usize> {
    let all_flags = flags | extra_flags::SENDMSG;
    unsafe {
        let sent = libc::sendto(
            self.as_raw_fd(),
            data.as_ptr().cast::<c_void>(),
            data.len(),
            all_flags,
            address.as_sockaddr(),
            address.len(),
        );
        check_ret_isize(sent).map(|count| count as usize)
    }
}
/// Send a message made of multiple buffers, with optional control data, using `sendmsg`.
///
/// `data` is passed as the I/O vector and `cdata` (if any) as the raw control buffer.
/// `extra_flags::SENDMSG` is OR-ed into `flags`.
/// Returns the number of bytes reported as transferred by the kernel.
///
/// # Errors
/// Returns the OS error if `sendmsg` reports failure.
pub fn send_msg(&self, data: &[IoSlice], cdata: Option<&[u8]>, flags: c_int) -> std::io::Result<usize> {
    let (control, control_len) = match cdata {
        Some(bytes) => (bytes.as_ptr() as *mut c_void, bytes.len()),
        None => (std::ptr::null_mut(), 0),
    };
    unsafe {
        let mut header = std::mem::zeroed::<libc::msghdr>();
        header.msg_iov = data.as_ptr() as *mut libc::iovec;
        // The integer types of `msg_iovlen` and `msg_controllen` differ between libc
        // targets (`size_t` vs `c_int` / `socklen_t`), so cast to the field type.
        header.msg_iovlen = data.len() as _;
        header.msg_control = control;
        header.msg_controllen = control_len as _;
        let ret = check_ret_isize(libc::sendmsg(self.as_raw_fd(), &header, flags | extra_flags::SENDMSG))?;
        Ok(ret as usize)
    }
}
/// Send a message made of multiple buffers to `address`, with optional control data, using `sendmsg`.
///
/// `data` is passed as the I/O vector and `cdata` (if any) as the raw control buffer.
/// `extra_flags::SENDMSG` is OR-ed into `flags`.
/// Returns the number of bytes reported as transferred by the kernel.
///
/// # Errors
/// Returns the OS error if `sendmsg` reports failure.
pub fn send_msg_to(&self, address: &Address, data: &[IoSlice], cdata: Option<&[u8]>, flags: c_int) -> std::io::Result<usize> {
    let (control, control_len) = match cdata {
        Some(bytes) => (bytes.as_ptr() as *mut c_void, bytes.len()),
        None => (std::ptr::null_mut(), 0),
    };
    unsafe {
        let mut header = std::mem::zeroed::<libc::msghdr>();
        header.msg_name = address.as_sockaddr() as *mut c_void;
        header.msg_namelen = address.len();
        header.msg_iov = data.as_ptr() as *mut libc::iovec;
        // The integer types of `msg_iovlen` and `msg_controllen` differ between libc
        // targets (`size_t` vs `c_int` / `socklen_t`), so cast to the field type.
        header.msg_iovlen = data.len() as _;
        header.msg_control = control;
        header.msg_controllen = control_len as _;
        let ret = check_ret_isize(libc::sendmsg(self.as_raw_fd(), &header, flags | extra_flags::SENDMSG))?;
        Ok(ret as usize)
    }
}
/// Receive data into `buffer` with `recv`.
///
/// `extra_flags::RECVMSG` is OR-ed into `flags`.
/// Returns the number of bytes reported as transferred by the kernel.
pub fn recv(&self, buffer: &mut [u8], flags: c_int) -> std::io::Result<usize> {
    let all_flags = flags | extra_flags::RECVMSG;
    // SAFETY: the pointer and length describe the exclusively borrowed `buffer` slice.
    let received = unsafe { libc::recv(self.as_raw_fd(), buffer.as_mut_ptr().cast::<c_void>(), buffer.len(), all_flags) };
    check_ret_isize(received).map(|count| count as usize)
}
/// Receive data into `buffer` with `recvfrom`, also returning the sender address.
///
/// `flags` is passed to `recvfrom` unchanged.
/// Returns the address produced by `Address::finalize` and the number of bytes
/// reported as transferred by the kernel.
pub fn recv_from(&self, buffer: &mut [u8], flags: c_int) -> std::io::Result<(Address, usize)> {
    unsafe {
        let mut sender = std::mem::MaybeUninit::zeroed();
        let mut sender_len = Address::max_len();
        let received = libc::recvfrom(
            self.as_raw_fd(),
            buffer.as_mut_ptr().cast::<c_void>(),
            buffer.len(),
            flags,
            Address::as_sockaddr_mut(&mut sender),
            &mut sender_len,
        );
        let received = check_ret_isize(received)?;
        let sender = Address::finalize(sender, sender_len)?;
        Ok((sender, received as usize))
    }
}
/// Receive a message into multiple buffers, with control data, using `recvmsg`.
///
/// The control buffer of `cdata` is handed to the kernel, unless its capacity is
/// zero, in which case a null control buffer is passed. `extra_flags::RECVMSG` is
/// OR-ed into `flags`.
///
/// After a successful call, `cdata.length` holds the control length reported by
/// the kernel and `cdata.truncated` tells whether `MSG_CTRUNC` was set.
///
/// Returns the number of bytes received and the message flags reported by the kernel.
///
/// # Errors
/// Returns the OS error if `recvmsg` reports failure.
pub fn recv_msg(&self, data: &[IoSliceMut], cdata: &mut SocketAncillary, flags: c_int) -> std::io::Result<(usize, c_int)> {
    let (control, control_len) = if cdata.capacity() == 0 {
        (std::ptr::null_mut(), 0)
    } else {
        (cdata.buffer.as_mut_ptr() as *mut c_void, cdata.capacity())
    };

    unsafe {
        let mut header = std::mem::zeroed::<libc::msghdr>();
        header.msg_iov = data.as_ptr() as *mut libc::iovec;
        // The integer types of `msg_iovlen` and `msg_controllen` differ between libc
        // targets (`size_t` vs `c_int` / `socklen_t`), so cast to the field type.
        header.msg_iovlen = data.len() as _;
        header.msg_control = control;
        header.msg_controllen = control_len as _;

        let ret = check_ret_isize(libc::recvmsg(self.as_raw_fd(), &mut header, flags | extra_flags::RECVMSG))?;

        cdata.length = header.msg_controllen as usize;
        cdata.truncated = header.msg_flags & libc::MSG_CTRUNC != 0;
        Ok((ret as usize, header.msg_flags))
    }
}
/// Receive a message into multiple buffers, with control data and sender address, using `recvmsg`.
///
/// The control buffer of `cdata` is handed to the kernel, unless its capacity is
/// zero, in which case a null control buffer is passed. `extra_flags::RECVMSG` is
/// OR-ed into `flags`.
///
/// After a successful `recvmsg`, `cdata.length` holds the control length reported
/// by the kernel and `cdata.truncated` tells whether `MSG_CTRUNC` was set. This is
/// recorded before the address is converted, so the ancillary state matches what
/// the kernel wrote even when `Address::finalize` fails.
///
/// Returns the sender address, the number of bytes received and the message flags.
///
/// # Errors
/// Returns the OS error if `recvmsg` reports failure, or the error from `Address::finalize`.
pub fn recv_msg_from(&self, data: &[IoSliceMut], cdata: &mut SocketAncillary, flags: c_int) -> std::io::Result<(Address, usize, c_int)> {
    let (control, control_len) = if cdata.capacity() == 0 {
        (std::ptr::null_mut(), 0)
    } else {
        (cdata.buffer.as_mut_ptr() as *mut c_void, cdata.capacity())
    };

    unsafe {
        let mut address = std::mem::MaybeUninit::zeroed();
        let mut header = std::mem::zeroed::<libc::msghdr>();
        header.msg_name = Address::as_sockaddr_mut(&mut address) as *mut c_void;
        header.msg_namelen = Address::max_len();
        header.msg_iov = data.as_ptr() as *mut libc::iovec;
        // The integer types of `msg_iovlen` and `msg_controllen` differ between libc
        // targets (`size_t` vs `c_int` / `socklen_t`), so cast to the field type.
        header.msg_iovlen = data.len() as _;
        header.msg_control = control;
        header.msg_controllen = control_len as _;

        let ret = check_ret_isize(libc::recvmsg(self.as_raw_fd(), &mut header, flags | extra_flags::RECVMSG))?;

        // Record the received control data before anything else can fail.
        cdata.length = header.msg_controllen as usize;
        cdata.truncated = header.msg_flags & libc::MSG_CTRUNC != 0;

        let address = Address::finalize(address, header.msg_namelen)?;
        Ok((address, ret as usize, header.msg_flags))
    }
}
}
impl<Address: AsSocketAddress> FromRawFd for Socket<Address> {
unsafe fn from_raw_fd(fd: RawFd) -> Self {
Self::from_raw_fd(fd)
}
}
impl<Address: AsSocketAddress> AsRawFd for Socket<Address> {
fn as_raw_fd(&self) -> RawFd {
self.as_raw_fd()
}
}
impl<Address: AsSocketAddress> AsRawFd for &'_ Socket<Address> {
fn as_raw_fd(&self) -> RawFd {
(*self).as_raw_fd()
}
}
impl<Address: AsSocketAddress> IntoRawFd for Socket<Address> {
fn into_raw_fd(self) -> RawFd {
self.into_raw_fd()
}
}
/// Turn the return value of a C function into an `io::Result`.
///
/// A value of `-1` is mapped to the last OS error; every other value is passed through.
fn check_ret(ret: c_int) -> std::io::Result<c_int> {
    match ret {
        -1 => Err(std::io::Error::last_os_error()),
        value => Ok(value),
    }
}
/// Turn the `isize` return value of a C function into an `io::Result`.
///
/// A value of `-1` is mapped to the last OS error; every other value is passed through.
fn check_ret_isize(ret: isize) -> std::io::Result<isize> {
    match ret {
        -1 => Err(std::io::Error::last_os_error()),
        value => Ok(value),
    }
}
/// Create a socket with `socket` and wrap the new descriptor in a `FileDesc`.
///
/// # Errors
/// Returns the OS error if `socket` reports failure.
fn socket(domain: c_int, kind: c_int, protocol: c_int) -> std::io::Result<FileDesc> {
    // SAFETY: `socket` only takes integer arguments.
    let raw_fd = check_ret(unsafe { libc::socket(domain, kind, protocol) })?;
    // SAFETY: `raw_fd` was just returned by a successful `socket` call.
    Ok(unsafe { FileDesc::from_raw_fd(raw_fd) })
}
/// Create a pair of sockets with `socketpair` and wrap both descriptors in a `FileDesc`.
///
/// # Errors
/// Returns the OS error if `socketpair` reports failure.
fn socketpair(domain: c_int, kind: c_int, protocol: c_int) -> std::io::Result<(FileDesc, FileDesc)> {
    let mut raw_fds = [0; 2];
    unsafe {
        check_ret(libc::socketpair(domain, kind, protocol, raw_fds.as_mut_ptr()))?;
        let [first, second] = raw_fds;
        let first = FileDesc::from_raw_fd(first);
        let second = FileDesc::from_raw_fd(second);
        Ok((first, second))
    }
}
/// Convert a `bool` to a C integer: `true` becomes `1` and `false` becomes `0`.
fn bool_to_c_int(value: bool) -> c_int {
    c_int::from(value)
}