use std::io;
use std::net::{SocketAddr, ToSocketAddrs};
use std::os::unix::io::RawFd;
use libc;
use crate::runtime::g::WaitReason;
use crate::runtime::netpoll::{netpoll_arm, netpoll_unarm, POLL_READ, POLL_WRITE};
use crate::runtime::park::gopark;
/// Creates a TCP (`SOCK_STREAM`) socket for `family` with `O_NONBLOCK` set.
///
/// On Linux the flag is applied atomically via `SOCK_NONBLOCK`. Elsewhere it is set
/// with `fcntl` after creation, and the socket is closed if that fails.
///
/// # Errors
///
/// Returns the OS error from `socket` or `fcntl`.
fn nonblocking_tcp_socket(family: libc::c_int) -> io::Result<RawFd> {
    #[cfg(target_os = "linux")]
    let fd = unsafe { libc::socket(family, libc::SOCK_STREAM | libc::SOCK_NONBLOCK, 0) };
    #[cfg(not(target_os = "linux"))]
    let fd = unsafe { libc::socket(family, libc::SOCK_STREAM, 0) };
    if fd < 0 {
        return Err(io::Error::last_os_error());
    }
    #[cfg(not(target_os = "linux"))]
    {
        let flags = unsafe { libc::fcntl(fd, libc::F_GETFL) };
        if flags < 0 || unsafe { libc::fcntl(fd, libc::F_SETFL, flags | libc::O_NONBLOCK) } < 0 {
            // Capture errno *before* close(), which is allowed to overwrite it.
            let err = io::Error::last_os_error();
            unsafe { libc::close(fd) };
            return Err(err);
        }
    }
    Ok(fd)
}
/// Sets the `SO_REUSEADDR` socket option on `fd` to 1.
///
/// # Errors
///
/// Returns the OS error reported by `setsockopt`.
fn set_reuseaddr(fd: RawFd) -> io::Result<()> {
    let enable: libc::c_int = 1;
    let optlen = std::mem::size_of_val(&enable) as libc::socklen_t;
    let optval = (&enable as *const libc::c_int).cast::<libc::c_void>();
    let rc = unsafe {
        libc::setsockopt(fd, libc::SOL_SOCKET, libc::SO_REUSEADDR, optval, optlen)
    };
    if rc >= 0 {
        Ok(())
    } else {
        Err(io::Error::last_os_error())
    }
}
/// Converts a `SocketAddr` into a C `sockaddr_storage` plus the length of the
/// concrete `sockaddr_in`/`sockaddr_in6` written into it.
///
/// The port is stored in network byte order. The IPv4 address bytes are copied
/// verbatim, so they are already in network order in memory. For IPv6, the flow
/// label and scope id are carried over as well; without the scope id, link-local
/// addresses (e.g. `fe80::1%2`) cannot be bound or connected.
fn to_sockaddr(addr: SocketAddr) -> (libc::sockaddr_storage, libc::socklen_t) {
    // Zeroing fills padding and any platform-specific fields (e.g. BSD `sin_len`).
    let mut storage: libc::sockaddr_storage = unsafe { std::mem::zeroed() };
    match addr {
        SocketAddr::V4(v4) => {
            let sa: &mut libc::sockaddr_in =
                unsafe { &mut *(&mut storage as *mut _ as *mut libc::sockaddr_in) };
            sa.sin_family = libc::AF_INET as libc::sa_family_t;
            sa.sin_port = v4.port().to_be();
            sa.sin_addr.s_addr = u32::from_ne_bytes(v4.ip().octets());
            (storage, std::mem::size_of::<libc::sockaddr_in>() as libc::socklen_t)
        }
        SocketAddr::V6(v6) => {
            let sa: &mut libc::sockaddr_in6 =
                unsafe { &mut *(&mut storage as *mut _ as *mut libc::sockaddr_in6) };
            sa.sin6_family = libc::AF_INET6 as libc::sa_family_t;
            sa.sin6_port = v6.port().to_be();
            sa.sin6_flowinfo = v6.flowinfo();
            sa.sin6_addr.s6_addr = v6.ip().octets();
            sa.sin6_scope_id = v6.scope_id();
            (storage, std::mem::size_of::<libc::sockaddr_in6>() as libc::socklen_t)
        }
    }
}
/// Returns the `AF_*` constant matching the address version: `AF_INET` for IPv4,
/// `AF_INET6` for IPv6.
fn addr_family(addr: SocketAddr) -> libc::c_int {
    if addr.is_ipv4() {
        libc::AF_INET
    } else {
        libc::AF_INET6
    }
}
/// Registers the current goroutine with the netpoller for `mode` readiness
/// (e.g. `POLL_READ` / `POLL_WRITE`) on `fd`, then parks it with
/// `WaitReason::IOWait`.
///
/// Callers in this file invoke it after a syscall reports `EAGAIN`-style
/// "would block" and retry the syscall once it returns.
///
/// # Safety
///
/// Must be called while running on a goroutine, i.e. `current_g()` is non-null.
/// This is only checked by `debug_assert!`, so a release build would pass a null
/// `g` to `netpoll_arm`.
/// NOTE(review): confirm any further preconditions `netpoll_arm`/`gopark` impose
/// (e.g. not holding runtime locks).
///
/// NOTE(review): presumably the netpoller tolerates readiness that arrives
/// between `netpoll_arm` and `gopark` without losing the wakeup — confirm.
unsafe fn park_on_fd(fd: RawFd, mode: u32) {
    let gp = crate::runtime::g::current_g();
    debug_assert!(!gp.is_null(), "park_on_fd: not running on a goroutine");
    unsafe {
        netpoll_arm(fd, mode, gp);
        gopark(WaitReason::IOWait);
    }
}
pub struct TcpListener {
fd: RawFd,
}
impl TcpListener {
pub fn bind<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
let addr = addr
.to_socket_addrs()?
.next()
.ok_or_else(|| io::Error::new(io::ErrorKind::InvalidInput, "no address given"))?;
let fd = nonblocking_tcp_socket(addr_family(addr))?;
set_reuseaddr(fd)?;
let (sa, sa_len) = to_sockaddr(addr);
let ret = unsafe {
libc::bind(fd, &sa as *const _ as *const libc::sockaddr, sa_len)
};
if ret < 0 {
unsafe { libc::close(fd) };
return Err(io::Error::last_os_error());
}
let ret = unsafe { libc::listen(fd, 128) };
if ret < 0 {
unsafe { libc::close(fd) };
return Err(io::Error::last_os_error());
}
Ok(TcpListener { fd })
}
pub fn accept(&self) -> io::Result<TcpStream> {
loop {
let cfd = unsafe {
libc::accept(self.fd, std::ptr::null_mut(), std::ptr::null_mut())
};
if cfd >= 0 {
let flags = unsafe { libc::fcntl(cfd, libc::F_GETFL) };
if flags >= 0 {
unsafe { libc::fcntl(cfd, libc::F_SETFL, flags | libc::O_NONBLOCK) };
}
return Ok(TcpStream { fd: cfd });
}
let err = io::Error::last_os_error();
match err.raw_os_error().unwrap_or(0) {
libc::EAGAIN => {
unsafe { park_on_fd(self.fd, POLL_READ) };
}
_ => return Err(err),
}
}
}
pub fn as_raw_fd(&self) -> RawFd {
self.fd
}
}
impl Drop for TcpListener {
fn drop(&mut self) {
netpoll_unarm(self.fd);
unsafe { libc::close(self.fd) };
}
}
/// A connected TCP socket whose `read`/`write` park the calling goroutine on the
/// netpoller instead of blocking the OS thread.
pub struct TcpStream {
    fd: RawFd,
}

impl TcpStream {
    /// Opens a non-blocking TCP connection to `addr`, parking the current
    /// goroutine while the handshake is in progress.
    ///
    /// If `addr` resolves to several addresses, each is tried in order and the
    /// first successful connection is returned, mirroring
    /// `std::net::TcpStream::connect`.
    ///
    /// # Errors
    ///
    /// Returns any error from address resolution. Returns `InvalidInput` if `addr`
    /// resolves to no addresses. Otherwise returns the error from the last address
    /// tried (e.g. `ECONNREFUSED` reported via `SO_ERROR`).
    pub fn connect<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
        let mut last_err = None;
        for candidate in addr.to_socket_addrs()? {
            match Self::connect_one(candidate) {
                Ok(stream) => return Ok(stream),
                Err(err) => last_err = Some(err),
            }
        }
        Err(last_err.unwrap_or_else(|| {
            io::Error::new(io::ErrorKind::InvalidInput, "no address given")
        }))
    }

    /// Connects a fresh socket to a single address.
    fn connect_one(addr: SocketAddr) -> io::Result<Self> {
        let fd = nonblocking_tcp_socket(addr_family(addr))?;
        // Own the fd right away. Every error path below drops `stream`, which
        // unarms (it may have been armed by a park) and closes the fd. The error
        // is constructed before that drop, so errno is preserved.
        let stream = TcpStream { fd };
        let (sa, sa_len) = to_sockaddr(addr);
        let ret = unsafe {
            libc::connect(fd, &sa as *const _ as *const libc::sockaddr, sa_len)
        };
        if ret < 0 {
            let err = io::Error::last_os_error();
            match err.raw_os_error() {
                // EINTR: per POSIX the connection attempt continues asynchronously,
                // so it is handled exactly like EINPROGRESS. (EAGAIN is a real
                // failure for TCP connect, e.g. no free local ports.)
                Some(libc::EINPROGRESS) | Some(libc::EINTR) => Self::await_connect(fd)?,
                _ => return Err(err),
            }
        }
        Ok(stream)
    }

    /// Parks until the pending connect on `fd` resolves, then reports its outcome
    /// via `SO_ERROR`. Re-parks if the socket still reports the attempt as pending.
    fn await_connect(fd: RawFd) -> io::Result<()> {
        loop {
            unsafe { park_on_fd(fd, POLL_WRITE) };
            match Self::socket_error(fd)? {
                0 => return Ok(()),
                libc::EINPROGRESS | libc::EALREADY | libc::EINTR => {}
                code => return Err(io::Error::from_raw_os_error(code)),
            }
        }
    }

    /// Reads and clears the pending socket error (`SO_ERROR`) on `fd`.
    fn socket_error(fd: RawFd) -> io::Result<libc::c_int> {
        let mut code: libc::c_int = 0;
        let mut len = std::mem::size_of::<libc::c_int>() as libc::socklen_t;
        let ret = unsafe {
            libc::getsockopt(
                fd,
                libc::SOL_SOCKET,
                libc::SO_ERROR,
                &mut code as *mut _ as *mut libc::c_void,
                &mut len,
            )
        };
        if ret < 0 {
            Err(io::Error::last_os_error())
        } else {
            Ok(code)
        }
    }

    /// Reads up to `buf.len()` bytes and returns the number read (0 at EOF).
    ///
    /// Parks the current goroutine while no data is available (`EAGAIN`) and
    /// retries calls interrupted by a signal (`EINTR`).
    pub fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        loop {
            let n = unsafe {
                libc::read(self.fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len())
            };
            if n >= 0 {
                return Ok(n as usize);
            }
            let err = io::Error::last_os_error();
            match err.raw_os_error() {
                Some(libc::EAGAIN) => unsafe { park_on_fd(self.fd, POLL_READ) },
                Some(libc::EINTR) => {}
                _ => return Err(err),
            }
        }
    }

    /// Writes up to `buf.len()` bytes and returns the number written. This may be
    /// a short write.
    ///
    /// Parks the current goroutine while the send buffer is full (`EAGAIN`) and
    /// retries calls interrupted by a signal (`EINTR`).
    pub fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        loop {
            // On Linux, send(MSG_NOSIGNAL) turns writing to a reset connection into
            // an EPIPE error instead of raising SIGPIPE, whose default action
            // terminates the process.
            #[cfg(target_os = "linux")]
            let n = unsafe {
                libc::send(
                    self.fd,
                    buf.as_ptr() as *const libc::c_void,
                    buf.len(),
                    libc::MSG_NOSIGNAL,
                )
            };
            #[cfg(not(target_os = "linux"))]
            let n = unsafe {
                libc::write(self.fd, buf.as_ptr() as *const libc::c_void, buf.len())
            };
            if n >= 0 {
                return Ok(n as usize);
            }
            let err = io::Error::last_os_error();
            match err.raw_os_error() {
                Some(libc::EAGAIN) => unsafe { park_on_fd(self.fd, POLL_WRITE) },
                Some(libc::EINTR) => {}
                _ => return Err(err),
            }
        }
    }

    /// Returns the underlying file descriptor without transferring ownership.
    pub fn as_raw_fd(&self) -> RawFd {
        self.fd
    }
}

impl Drop for TcpStream {
    /// Deregisters the descriptor from the netpoller, then closes it.
    fn drop(&mut self) {
        netpoll_unarm(self.fd);
        unsafe { libc::close(self.fd) };
    }
}