use core::ffi::{c_int, c_void};
use core::mem::size_of;
use core::ptr;
use std::io;
pub(crate) mod ffi;
#[cfg(any(target_os = "macos", target_os = "ios"))]
mod poller_kq;
#[cfg(target_os = "linux")]
mod poller_ep;
#[cfg(any(target_os = "macos", target_os = "ios"))]
pub use poller_kq::Poller;
#[cfg(target_os = "linux")]
pub use poller_ep::Poller;
// Numeric constants passed straight to the libc socket/fcntl calls. They are hard-coded
// instead of coming from a libc binding, so each value must match the target's C headers.
// The values in this first group are identical on Linux and on the Apple platforms.
const AF_INET: c_int = 2;
const SOCK_STREAM: c_int = 1;
const IPPROTO_TCP: c_int = 6;
const TCP_NODELAY: c_int = 1;
const F_GETFL: c_int = 3;
const F_SETFL: c_int = 4;
// Linux values.
// NOTE(review): these match the asm-generic numbering used by x86, ARM and RISC-V. Some
// Linux architectures (e.g. MIPS, SPARC, Alpha, PA-RISC) use different numbers for
// SOL_SOCKET / SO_REUSEADDR / SO_REUSEPORT / O_NONBLOCK (and SOCK_STREAM on MIPS), but the
// cfg below only checks `target_os` — confirm which architectures must be supported.
#[cfg(target_os = "linux")]
const SOL_SOCKET: c_int = 1;
#[cfg(target_os = "linux")]
const SO_REUSEADDR: c_int = 2;
#[cfg(target_os = "linux")]
const SO_REUSEPORT: c_int = 15;
#[cfg(target_os = "linux")]
const O_NONBLOCK: c_int = 0x800;
// macOS / iOS (BSD-derived) values.
#[cfg(any(target_os = "macos", target_os = "ios"))]
const SOL_SOCKET: c_int = 0xffff;
#[cfg(any(target_os = "macos", target_os = "ios"))]
const SO_REUSEADDR: c_int = 0x0004;
#[cfg(any(target_os = "macos", target_os = "ios"))]
const SO_REUSEPORT: c_int = 0x0200;
#[cfg(any(target_os = "macos", target_os = "ios"))]
const O_NONBLOCK: c_int = 0x0004;
/// Linux layout of the C `struct sockaddr_in` (IPv4 socket address, 16 bytes).
#[cfg(target_os = "linux")]
#[repr(C)]
struct SockaddrIn {
/// Address family; set to `AF_INET` by `SockaddrIn::new`.
sin_family: u16,
/// Port in network (big-endian) byte order.
sin_port: u16,
/// IPv4 address in network byte order (the four octets stored as written, e.g. 127,0,0,1).
sin_addr: u32,
/// Zero-filled padding.
sin_zero: [u8; 8],
}
/// BSD/Apple layout of the C `struct sockaddr_in`: a leading length byte followed by an
/// 8-bit family field; the remaining fields match the Linux layout.
#[cfg(any(target_os = "macos", target_os = "ios"))]
#[repr(C)]
struct SockaddrIn {
/// Total size of this structure in bytes; filled in by `SockaddrIn::new`.
sin_len: u8,
/// Address family; set to `AF_INET` by `SockaddrIn::new`.
sin_family: u8,
/// Port in network (big-endian) byte order.
sin_port: u16,
/// IPv4 address in network byte order (the four octets stored as written).
sin_addr: u32,
/// Zero-filled padding.
sin_zero: [u8; 8],
}
impl SockaddrIn {
    /// Builds an IPv4 socket address for `ip`:`port` (Linux layout).
    ///
    /// `ip` is given in dotted order (`[127, 0, 0, 1]`). Those bytes are kept as-is in memory,
    /// which is already network byte order. The port is converted to big-endian.
    #[cfg(target_os = "linux")]
    fn new(ip: [u8; 4], port: u16) -> Self {
        let network_port = port.to_be();
        let network_addr = u32::from_ne_bytes(ip);
        SockaddrIn {
            sin_family: AF_INET as u16,
            sin_port: network_port,
            sin_addr: network_addr,
            sin_zero: [0; 8],
        }
    }

    /// Builds an IPv4 socket address for `ip`:`port` (BSD/Apple layout, which also records
    /// the structure's own size in `sin_len`).
    #[cfg(any(target_os = "macos", target_os = "ios"))]
    fn new(ip: [u8; 4], port: u16) -> Self {
        let network_port = port.to_be();
        let network_addr = u32::from_ne_bytes(ip);
        SockaddrIn {
            sin_len: size_of::<SockaddrIn>() as u8,
            sin_family: AF_INET as u8,
            sin_port: network_port,
            sin_addr: network_addr,
            sin_zero: [0; 8],
        }
    }

    /// Returns an all-zero address, suitable as an out-buffer for `getsockname`.
    fn zeroed() -> Self {
        // SAFETY: every field is a plain integer or integer array, for which the all-zero
        // bit pattern is a valid value.
        unsafe { core::mem::zeroed() }
    }
}
/// Owning wrapper around a raw socket file descriptor.
///
/// The descriptor is closed when the `Socket` is dropped.
pub struct Socket {
// Raw descriptor (from socket(2), accept(2) or `from_raw_fd`), treated as owned by this value.
fd: c_int,
}
impl Socket {
#[inline]
pub fn raw(&self) -> i32 {
self.fd
}
#[inline]
pub unsafe fn from_raw_fd(fd: i32) -> Socket {
Socket { fd }
}
pub fn accept(&self) -> io::Result<Socket> {
let fd = unsafe { ffi::accept(self.fd, ptr::null_mut(), ptr::null_mut()) };
if fd < 0 {
return Err(io::Error::last_os_error());
}
Ok(Socket { fd })
}
pub fn read(&self, buf: &mut [u8]) -> io::Result<usize> {
loop {
let n = unsafe { ffi::read(self.fd, buf.as_mut_ptr() as *mut c_void, buf.len()) };
if n < 0 {
let e = io::Error::last_os_error();
if e.kind() == io::ErrorKind::Interrupted {
continue;
}
return Err(e);
}
return Ok(n as usize);
}
}
pub fn write(&self, buf: &[u8]) -> io::Result<usize> {
loop {
let n = unsafe { ffi::write(self.fd, buf.as_ptr() as *const c_void, buf.len()) };
if n < 0 {
let e = io::Error::last_os_error();
if e.kind() == io::ErrorKind::Interrupted {
continue;
}
return Err(e);
}
return Ok(n as usize);
}
}
pub fn write_all(&self, mut buf: &[u8]) -> io::Result<()> {
while !buf.is_empty() {
let n = self.write(buf)?;
if n == 0 {
return Err(io::Error::new(io::ErrorKind::WriteZero, "write returned 0"));
}
buf = &buf[n..];
}
Ok(())
}
pub fn set_nonblocking(&self) -> io::Result<()> {
set_fd_nonblocking(self.fd)
}
pub fn set_nodelay(&self) -> io::Result<()> {
let one: c_int = 1;
let r = unsafe {
ffi::setsockopt(
self.fd,
IPPROTO_TCP,
TCP_NODELAY,
&one as *const c_int as *const c_void,
size_of::<c_int>() as u32,
)
};
if r < 0 {
return Err(io::Error::last_os_error());
}
Ok(())
}
pub fn local_port(&self) -> io::Result<u16> {
let mut addr = SockaddrIn::zeroed();
let mut len = size_of::<SockaddrIn>() as u32;
let r = unsafe {
ffi::getsockname(
self.fd,
&mut addr as *mut SockaddrIn as *mut c_void,
&mut len,
)
};
if r < 0 {
return Err(io::Error::last_os_error());
}
Ok(u16::from_be(addr.sin_port))
}
}
impl Drop for Socket {
    /// Closes the owned descriptor.
    fn drop(&mut self) {
        // The result of close(2) is ignored on purpose: `drop` has no way to report it.
        let owned = self.fd;
        // SAFETY: `owned` belongs to this `Socket` and is never used again after this call.
        unsafe {
            ffi::close(owned);
        }
    }
}
/// Adds `O_NONBLOCK` to the file status flags of `fd`, keeping every other flag.
///
/// This is a read-modify-write of the flags via `F_GETFL` and `F_SETFL`.
fn set_fd_nonblocking(fd: c_int) -> io::Result<()> {
    // SAFETY: F_GETFL takes no pointer argument; a bad fd is reported as -1/errno.
    let current = unsafe { ffi::fcntl(fd, F_GETFL, 0) };
    if current < 0 {
        return Err(io::Error::last_os_error());
    }
    let wanted = current | O_NONBLOCK;
    // SAFETY: F_SETFL takes a plain integer flag set.
    match unsafe { ffi::fcntl(fd, F_SETFL, wanted) } {
        rc if rc < 0 => Err(io::Error::last_os_error()),
        _ => Ok(()),
    }
}
/// Sets the `c_int`-valued socket option `name` at protocol `level` on `fd` to `val`.
fn setsockopt_int(fd: c_int, level: c_int, name: c_int, val: c_int) -> io::Result<()> {
    let optval: *const c_void = (&val as *const c_int).cast();
    let optlen = size_of::<c_int>() as u32;
    // SAFETY: `optval` points at `val`, which lives for the whole call, and `optlen` is its
    // exact size.
    let rc = unsafe { ffi::setsockopt(fd, level, name, optval, optlen) };
    if rc < 0 {
        Err(io::Error::last_os_error())
    } else {
        Ok(())
    }
}
/// Creates an IPv4 TCP socket, sets `SO_REUSEADDR` (and `SO_REUSEPORT` when `reuseport`),
/// binds it to `ip`:`port` and starts listening with the given `backlog`.
///
/// The descriptor is wrapped in a [`Socket`] right after creation, so any failure later on
/// closes it when the wrapper is dropped.
fn listen_inner(ip: [u8; 4], port: u16, backlog: i32, reuseport: bool) -> io::Result<Socket> {
    // SAFETY: socket(2) takes only integer arguments.
    let raw = unsafe { ffi::socket(AF_INET, SOCK_STREAM, 0) };
    if raw < 0 {
        return Err(io::Error::last_os_error());
    }
    let sock = Socket { fd: raw };

    setsockopt_int(raw, SOL_SOCKET, SO_REUSEADDR, 1)?;
    if reuseport {
        setsockopt_int(raw, SOL_SOCKET, SO_REUSEPORT, 1)?;
    }

    let addr = SockaddrIn::new(ip, port);
    let addr_len = size_of::<SockaddrIn>() as u32;
    // SAFETY: `addr` is a fully initialised sockaddr_in that outlives the call, and
    // `addr_len` is its exact size.
    let bound = unsafe { ffi::bind(raw, &addr as *const SockaddrIn as *const c_void, addr_len) };
    if bound < 0 {
        return Err(io::Error::last_os_error());
    }

    // SAFETY: listen(2) takes only integer arguments.
    let listening = unsafe { ffi::listen(raw, backlog) };
    if listening < 0 {
        return Err(io::Error::last_os_error());
    }
    Ok(sock)
}
/// Creates a blocking IPv4 TCP listener bound to `ip`:`port`, with `SO_REUSEADDR` set.
///
/// Pass port `0` to let the kernel choose a port, then read it back with
/// [`Socket::local_port`].
///
/// # Errors
///
/// Returns the OS error from socket creation, option setting, `bind`, or `listen`.
pub fn tcp_listen(ip: [u8; 4], port: u16, backlog: i32) -> io::Result<Socket> {
    const SHARE_PORT: bool = false;
    listen_inner(ip, port, backlog, SHARE_PORT)
}
/// Like `tcp_listen`, but also sets `SO_REUSEPORT` so several listeners can bind the same
/// address and port.
///
/// # Errors
///
/// Returns the OS error from socket creation, option setting, `bind`, or `listen`.
pub fn tcp_listen_reuseport(ip: [u8; 4], port: u16, backlog: i32) -> io::Result<Socket> {
    const SHARE_PORT: bool = true;
    listen_inner(ip, port, backlog, SHARE_PORT)
}
/// Cross-thread wake-up handle built on a non-blocking pipe (created by `waker()`).
///
/// Register `read_fd()` with a poller for readability. Call `wake()` from any thread to make
/// it readable, and `drain()` after waking to consume the pending bytes.
pub struct Waker {
// Read end of the pipe; this is the descriptor registered with the poller.
read_fd: c_int,
// Write end of the pipe; `wake` writes a single byte here.
write_fd: c_int,
}
/// Creates a [`Waker`] backed by a fresh pipe with both ends set to non-blocking.
///
/// # Errors
///
/// Returns the OS error from `pipe(2)` or from setting `O_NONBLOCK`. If the second step
/// fails, both pipe ends are closed by the `Waker`'s `Drop`.
pub fn waker() -> io::Result<Waker> {
    let mut pair: [c_int; 2] = [0; 2];
    // SAFETY: `pair` provides the two writable c_int slots that pipe(2) fills in.
    let rc = unsafe { ffi::pipe(pair.as_mut_ptr()) };
    if rc < 0 {
        return Err(io::Error::last_os_error());
    }
    let [read_fd, write_fd] = pair;
    // Take ownership before configuring, so an error below still closes both ends.
    let handle = Waker { read_fd, write_fd };
    for end in [handle.read_fd, handle.write_fd] {
        set_fd_nonblocking(end)?;
    }
    Ok(handle)
}
impl Waker {
    /// Returns the read end of the pipe; register it with a poller for readability.
    #[inline]
    pub fn read_fd(&self) -> i32 {
        self.read_fd
    }

    /// Makes the read end readable by writing one byte into the pipe.
    ///
    /// Interrupted writes are retried. If the pipe is full (`WouldBlock`), unread bytes are
    /// already pending, so a wake-up is already signalled and this returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// Any other error from `write(2)`.
    pub fn wake(&self) -> io::Result<()> {
        let byte = [1u8];
        loop {
            // SAFETY: `byte` is a live one-byte buffer.
            let n = unsafe { ffi::write(self.write_fd, byte.as_ptr() as *const c_void, 1) };
            if n < 0 {
                let e = io::Error::last_os_error();
                match e.kind() {
                    io::ErrorKind::Interrupted => continue,
                    io::ErrorKind::WouldBlock => return Ok(()),
                    _ => return Err(e),
                }
            }
            return Ok(());
        }
    }

    /// Consumes all pending wake-up bytes from the pipe.
    ///
    /// Reads until the non-blocking pipe reports empty (`WouldBlock`), end-of-file, or an
    /// error. Errors are deliberately ignored because draining is best-effort. Interrupted
    /// reads are retried, as in `wake`, so a signal cannot leave stale bytes behind.
    pub fn drain(&self) {
        let mut buf = [0u8; 64];
        loop {
            // SAFETY: the pointer/length pair describes `buf`, which is valid for writes.
            let n = unsafe { ffi::read(self.read_fd, buf.as_mut_ptr() as *mut c_void, buf.len()) };
            if n > 0 {
                continue;
            }
            if n < 0 && io::Error::last_os_error().kind() == io::ErrorKind::Interrupted {
                continue;
            }
            break;
        }
    }
}
impl Drop for Waker {
    /// Closes both ends of the pipe: the read end first, then the write end.
    fn drop(&mut self) {
        // Close results are ignored: `drop` has no way to report them.
        for end in [self.read_fd, self.write_fd] {
            // SAFETY: both descriptors are owned by this `Waker` and never used after drop.
            unsafe {
                ffi::close(end);
            }
        }
    }
}
unsafe impl Send for Waker {}
unsafe impl Sync for Waker {}
/// A readiness notification for one descriptor, as collected by `Poller::wait`.
#[derive(Debug, Clone, Copy)]
pub struct Event {
/// Descriptor this notification refers to.
pub fd: i32,
/// Ready for reading; for a listening socket, a connection can be accepted.
pub readable: bool,
/// Ready for writing.
pub writable: bool,
/// Hang-up / peer-closed condition. NOTE(review): the exact mapping is set by the platform
/// poller module, which is not visible here — confirm.
pub hup: bool,
}
/// Presumably the maximum number of events collected by a single `Poller::wait` call. It is
/// not used in this part of the file, so its consumer is likely a platform poller module —
/// TODO confirm.
const WAIT_CAPACITY: usize = 1024;
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Read, Write};

    /// A byte sent by a std client is read and echoed back by a `Socket` from `accept`.
    #[test]
    fn listen_accept_roundtrip() {
        let listener = tcp_listen([127, 0, 0, 1], 0, 16).unwrap();
        let port = listener.local_port().unwrap();
        assert_ne!(port, 0);
        let server = std::thread::spawn(move || {
            let conn = listener.accept().unwrap();
            let mut b = [0u8; 1];
            assert_eq!(conn.read(&mut b).unwrap(), 1);
            conn.write_all(&b).unwrap();
        });
        let mut client = std::net::TcpStream::connect(("127.0.0.1", port)).unwrap();
        client.write_all(b"Z").unwrap();
        let mut got = [0u8; 1];
        assert_eq!(client.read(&mut got).unwrap(), 1);
        assert_eq!(&got, b"Z");
        server.join().unwrap();
    }

    /// An incoming connection makes a registered non-blocking listener readable.
    #[test]
    fn poller_signals_listener_readable() {
        let listener = tcp_listen([127, 0, 0, 1], 0, 16).unwrap();
        listener.set_nonblocking().unwrap();
        let port = listener.local_port().unwrap();
        let poller = Poller::new().unwrap();
        poller.add(listener.raw(), true, false).unwrap();
        let _client = std::net::TcpStream::connect(("127.0.0.1", port)).unwrap();
        let mut events = Vec::new();
        let n = poller.wait(&mut events, Some(2000)).unwrap();
        assert!(n >= 1, "expected a readiness event");
        assert!(events.iter().any(|e| e.fd == listener.raw() && e.readable));
        listener.accept().unwrap();
    }

    /// `wake` from another thread makes the waker's read end readable.
    #[test]
    fn waker_wakes_poller() {
        let w = std::sync::Arc::new(waker().unwrap());
        let poller = Poller::new().unwrap();
        poller.add(w.read_fd(), true, false).unwrap();
        let w2 = std::sync::Arc::clone(&w);
        // Keep the handle: a failure inside the waking thread must fail this test rather than
        // being dropped silently and showing up only as a confusing timeout.
        let waking = std::thread::spawn(move || w2.wake().unwrap());
        let mut events = Vec::new();
        let n = poller.wait(&mut events, Some(2000)).unwrap();
        waking.join().unwrap();
        assert!(n >= 1, "waker should have woken the poller");
        assert!(events.iter().any(|e| e.fd == w.read_fd() && e.readable));
        w.drain();
    }

    /// Two `SO_REUSEPORT` listeners can bind the same address and port.
    #[test]
    fn reuseport_allows_shared_port() {
        let l1 = tcp_listen_reuseport([127, 0, 0, 1], 0, 16).unwrap();
        let port = l1.local_port().unwrap();
        let l2 = tcp_listen_reuseport([127, 0, 0, 1], port, 16).unwrap();
        assert_eq!(l2.local_port().unwrap(), port);
    }
}