#![allow(dead_code)]
use std::cell::RefCell;
use std::collections::{HashMap, VecDeque};
use std::future::Future;
use std::io::{self, Read, Write};
use std::mem::MaybeUninit;
use std::net::{SocketAddr, TcpListener as StdTcpListener, TcpStream as StdTcpStream};
use std::os::fd::{AsRawFd, RawFd};
use std::pin::Pin;
use std::task::Poll;
use std::future::poll_fn;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::task::{Context, Wake, Waker};
use std::thread::{self, Thread};
use std::time::Duration;
type Result<T> = std::io::Result<T>;
type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;
/// Errors surfaced by the runtime.
#[derive(Debug)]
pub enum Error {
    /// An underlying operating-system I/O error.
    Io(io::Error),
    /// The resource was closed.
    Closed,
    /// The operation timed out.
    TimedOut,
    /// The operation could not complete without blocking.
    WouldBlock,
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Error::Io(err) => write!(f, "I/O error: {}", err),
            Error::Closed => f.write_str("resource closed"),
            Error::TimedOut => f.write_str("operation timed out"),
            Error::WouldBlock => f.write_str("operation would block"),
        }
    }
}

impl std::error::Error for Error {
    /// Exposes the wrapped `io::Error` so callers can walk the error chain.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Error::Io(err) => Some(err),
            Error::Closed | Error::TimedOut | Error::WouldBlock => None,
        }
    }
}

impl From<io::Error> for Error {
    /// Wraps an `io::Error` so `?` can be used on I/O calls.
    fn from(err: io::Error) -> Self {
        Error::Io(err)
    }
}
/// Bit set describing which I/O directions are of interest / ready.
#[derive(Clone, Copy, Debug, PartialEq)]
struct Ready(u8);

impl Ready {
    /// No readiness at all.
    const EMPTY: Ready = Ready(0b00);
    /// The source can be read.
    const READABLE: Ready = Ready(0b01);
    /// The source can be written.
    const WRITABLE: Ready = Ready(0b10);

    /// Returns `true` when the readable bit is set.
    fn is_readable(self) -> bool {
        self.intersects(Self::READABLE)
    }

    /// Returns `true` when the writable bit is set.
    fn is_writable(self) -> bool {
        self.intersects(Self::WRITABLE)
    }

    /// Returns `true` when `self` and `other` share at least one bit.
    fn intersects(self, other: Ready) -> bool {
        (self.0 & other.0) != 0
    }
}
/// Bookkeeping record pairing a selector token with the interests and the
/// waker associated with it.
// NOTE(review): nothing in the visible source constructs or reads this type;
// the executor keeps its wakers in a plain `HashMap<usize, Waker>` instead.
#[derive(Debug)]
struct IoRegistration {
// Token identifying the registration.
token: usize,
// Readiness kinds the registration is interested in.
interests: Ready,
// Waker to notify, if one has been stored.
waker: Option<Waker>,
}
// Per-thread handle to the active runtime. `TcpListener::bind` and
// `TcpStream::connect` read it to find the executor and panic with
// "no runtime in current thread" when it is `None`.
thread_local! {
static CURRENT: RefCell<Option<Arc<Runtime>>> = RefCell::new(None);
}
#[cfg(target_os = "linux")]
mod sys {
use super::*;
const EPOLLIN: i32 = 0x001;
const EPOLLOUT: i32 = 0x004;
const EPOLLET: i32 = 1 << 31;
const EPOLL_CTL_ADD: i32 = 1;
const EPOLL_CTL_DEL: i32 = 2;
const EPOLL_CTL_MOD: i32 = 3;
#[repr(C, packed)]
struct EpollEvent {
events: u32,
data: u64,
}
extern "C" {
fn epoll_create1(flags: i32) -> i32;
fn epoll_ctl(epfd: i32, op: i32, fd: i32, event: *mut EpollEvent) -> i32;
fn epoll_wait(epfd: i32, events: *mut EpollEvent, maxevents: i32, timeout: i32) -> i32;
}
pub struct Selector {
epoll_fd: RawFd,
}
impl Selector {
pub fn new() -> io::Result<Self> {
let fd = unsafe { epoll_create1(0) };
if fd < 0 {
return Err(io::Error::last_os_error());
}
Ok(Selector { epoll_fd: fd })
}
pub fn register(&self, fd: RawFd, token: usize, interests: Ready) -> io::Result<()> {
let mut event = EpollEvent {
events: interests_to_epoll(interests),
data: token as u64,
};
let op = EPOLL_CTL_ADD;
let ret = unsafe {
epoll_ctl(self.epoll_fd, op, fd, &mut event)
};
if ret < 0 {
return Err(io::Error::last_os_error());
}
Ok(())
}
pub fn reregister(&self, fd: RawFd, token: usize, interests: Ready) -> io::Result<()> {
let mut event = EpollEvent {
events: interests_to_epoll(interests),
data: token as u64,
};
let ret = unsafe {
epoll_ctl(self.epoll_fd, EPOLL_CTL_MOD, fd, &mut event)
};
if ret < 0 {
return Err(io::Error::last_os_error());
}
Ok(())
}
pub fn deregister(&self, fd: RawFd) -> io::Result<()> {
let mut event = EpollEvent {
events: 0,
data: 0,
};
let ret = unsafe {
epoll_ctl(self.epoll_fd, EPOLL_CTL_DEL, fd, &mut event)
};
if ret < 0 {
return Err(io::Error::last_os_error());
}
Ok(())
}
pub fn select(&self, events: &mut Events, timeout: Option<Duration>) -> io::Result<()> {
let timeout_ms = timeout
.map(|t| t.as_millis() as i32)
.unwrap_or(-1);
let n = unsafe {
epoll_wait(
self.epoll_fd,
events.as_mut_ptr() as *mut EpollEvent,
events.capacity() as i32,
timeout_ms,
)
};
if n < 0 {
return Err(io::Error::last_os_error());
}
unsafe { events.set_len(n as usize) };
Ok(())
}
}
fn interests_to_epoll(interests: Ready) -> u32 {
let mut kind = EPOLLET;
if interests.is_readable() {
kind |= EPOLLIN;
}
if interests.is_writable() {
kind |= EPOLLOUT;
}
kind as u32
}
}
/// macOS readiness backend built directly on `kqueue` / `kevent`.
#[cfg(target_os = "macos")]
mod sys {
    use super::*;
    use std::ffi::c_void;

    const EVFILT_READ: i16 = -1;
    const EVFILT_WRITE: i16 = -2;
    const EV_ADD: u16 = 0x1;
    const EV_DELETE: u16 = 0x2;
    const EV_ENABLE: u16 = 0x4;
    const EV_DISABLE: u16 = 0x8;
    const EV_CLEAR: u16 = 0x20;

    /// Mirror of the C `struct kevent`.
    #[repr(C)]
    #[derive(Clone, Copy)]
    struct KEvent {
        ident: usize,
        filter: i16,
        flags: u16,
        fflags: u32,
        data: isize,
        // `std::ffi::c_void` avoids a dependency on the external `libc` crate.
        udata: *mut c_void,
    }

    /// Mirror of the C `struct timespec`.
    #[repr(C)]
    struct Timespec {
        tv_sec: isize,
        tv_nsec: isize,
    }

    extern "C" {
        fn kqueue() -> i32;
        fn kevent(
            kq: i32,
            changelist: *const KEvent,
            nchanges: i32,
            eventlist: *mut KEvent,
            nevents: i32,
            timeout: *const Timespec,
        ) -> i32;
        fn close(fd: i32) -> i32;
    }

    /// Owns a kqueue descriptor and closes it on drop.
    pub struct Selector {
        kq: RawFd,
    }

    impl Selector {
        /// Creates a new kqueue.
        ///
        /// # Errors
        /// Returns the OS error if `kqueue` fails.
        pub fn new() -> io::Result<Self> {
            // SAFETY: plain system call with no arguments.
            let fd = unsafe { kqueue() };
            if fd < 0 {
                return Err(io::Error::last_os_error());
            }
            Ok(Selector { kq: fd })
        }

        /// Registers `fd` for the requested interests (edge-triggered via
        /// `EV_CLEAR`), tagging each filter with `token`.
        pub fn register(&self, fd: RawFd, token: usize, interests: Ready) -> io::Result<()> {
            let change = |filter: i16| KEvent {
                ident: fd as usize,
                filter,
                flags: EV_ADD | EV_ENABLE | EV_CLEAR,
                fflags: 0,
                data: 0,
                udata: token as *mut c_void,
            };
            let mut changes = Vec::with_capacity(2);
            if interests.is_readable() {
                changes.push(change(EVFILT_READ));
            }
            if interests.is_writable() {
                changes.push(change(EVFILT_WRITE));
            }
            // SAFETY: `changes` is a valid array of `changes.len()` records and
            // no event list is requested.
            let ret = unsafe {
                kevent(
                    self.kq,
                    changes.as_ptr(),
                    changes.len() as i32,
                    std::ptr::null_mut(),
                    0,
                    std::ptr::null(),
                )
            };
            if ret < 0 {
                return Err(io::Error::last_os_error());
            }
            Ok(())
        }

        /// Waits for readiness and stores the decoded results in `events`.
        ///
        /// `events` is cleared first, then receives at most `events.capacity()`
        /// entries. `None` blocks indefinitely.
        ///
        /// # Errors
        /// Returns the OS error from `kevent` (including `Interrupted`).
        pub fn select(&self, events: &mut Events, timeout: Option<Duration>) -> io::Result<()> {
            // SAFETY: a length of zero exposes no uninitialised slots.
            unsafe { events.set_len(0) };

            let timeout = timeout.map(|t| Timespec {
                tv_sec: isize::try_from(t.as_secs()).unwrap_or(isize::MAX),
                tv_nsec: t.subsec_nanos() as isize,
            });
            let timeout_ptr = timeout
                .as_ref()
                .map_or(std::ptr::null(), |t| t as *const Timespec);

            // The kernel writes `struct kevent` records, which are larger than
            // `Event`; handing it the `Events` storage would overflow that
            // buffer. Receive into a correctly sized scratch buffer instead.
            let capacity = events.capacity().min(i32::MAX as usize);
            let blank = KEvent {
                ident: 0,
                filter: 0,
                flags: 0,
                fflags: 0,
                data: 0,
                udata: std::ptr::null_mut(),
            };
            let mut raw = vec![blank; capacity];

            // SAFETY: `raw` holds `capacity` writable records, no change list
            // is passed, and `timeout_ptr` is null or points at `timeout`,
            // which outlives the call.
            let n = unsafe {
                kevent(
                    self.kq,
                    std::ptr::null(),
                    0,
                    raw.as_mut_ptr(),
                    capacity as i32,
                    timeout_ptr,
                )
            };
            if n < 0 {
                return Err(io::Error::last_os_error());
            }

            for ev in &raw[..n as usize] {
                let ready = match ev.filter {
                    EVFILT_READ => Ready::READABLE,
                    EVFILT_WRITE => Ready::WRITABLE,
                    _ => Ready::EMPTY,
                };
                events.push(Event {
                    // `register` stored the token in `udata`.
                    token: ev.udata as usize,
                    ready,
                });
            }
            Ok(())
        }
    }

    impl Drop for Selector {
        fn drop(&mut self) {
            // SAFETY: `kq` was returned by `kqueue` and is closed exactly once.
            // A close error cannot be handled meaningfully here.
            let _ = unsafe { close(self.kq) };
        }
    }
}
/// Windows backend built on an I/O completion port.
#[cfg(windows)]
mod sys {
    use super::*;
    use std::os::windows::io::RawHandle;

    /// Mirror of the Win32 `OVERLAPPED` structure.
    #[repr(C)]
    struct OVERLAPPED {
        internal: usize,
        internal_high: usize,
        offset: u32,
        offset_high: u32,
        event: RawHandle,
    }

    extern "system" {
        fn CreateIoCompletionPort(
            file_handle: RawHandle,
            existing_completion_port: RawHandle,
            completion_key: usize,
            concurrent_threads: u32,
        ) -> RawHandle;
        fn GetQueuedCompletionStatus(
            completion_port: RawHandle,
            bytes_transferred: *mut u32,
            completion_key: *mut usize,
            overlapped: *mut *mut OVERLAPPED,
            milliseconds: u32,
        ) -> i32;
        fn CloseHandle(handle: RawHandle) -> i32;
    }

    /// Owns a completion port and closes it on drop.
    pub struct Selector {
        iocp: RawHandle,
    }

    // SAFETY: the only state is the completion-port handle, and the port is
    // only used through system calls that Win32 documents as callable from
    // multiple threads. Without these impls the raw pointer would make the
    // executor impossible to share across threads.
    unsafe impl Send for Selector {}
    unsafe impl Sync for Selector {}

    impl Selector {
        /// Creates a new, unassociated completion port.
        ///
        /// # Errors
        /// Returns the OS error if the port cannot be created.
        pub fn new() -> io::Result<Self> {
            // SAFETY: INVALID_HANDLE_VALUE (-1) with a null existing port asks
            // the system to create a fresh port.
            let iocp = unsafe {
                CreateIoCompletionPort(-1isize as RawHandle, std::ptr::null_mut(), 0, 0)
            };
            if iocp.is_null() {
                return Err(io::Error::last_os_error());
            }
            Ok(Selector { iocp })
        }

        /// Associates `handle` with the port, using `token` as completion key.
        pub fn register(&self, handle: RawHandle, token: usize) -> io::Result<()> {
            // SAFETY: `self.iocp` is a live completion port owned by `self`.
            let ret = unsafe { CreateIoCompletionPort(handle, self.iocp, token, 0) };
            if ret.is_null() {
                return Err(io::Error::last_os_error());
            }
            Ok(())
        }

        /// Dequeues at most one completion packet into `events`.
        ///
        /// `events` is cleared first so packets from earlier calls are never
        /// reported twice. A timeout yields `Ok(())` with no events.
        ///
        /// # Errors
        /// Returns the OS error for any failure other than a timeout.
        pub fn select(&self, events: &mut Events, timeout: Option<Duration>) -> io::Result<()> {
            // SAFETY: a length of zero exposes no uninitialised slots.
            unsafe { events.set_len(0) };

            let mut bytes_transferred: u32 = 0;
            let mut completion_key: usize = 0;
            let mut overlapped: *mut OVERLAPPED = std::ptr::null_mut();

            // `u32::MAX` is INFINITE; finite timeouts are clamped just below it
            // so a long duration neither wraps nor turns into "wait forever".
            let timeout_ms = timeout.map_or(u32::MAX, |t| {
                u32::try_from(t.as_millis())
                    .unwrap_or(u32::MAX - 1)
                    .min(u32::MAX - 1)
            });

            // SAFETY: all out-pointers refer to live locals of the right type.
            let ret = unsafe {
                GetQueuedCompletionStatus(
                    self.iocp,
                    &mut bytes_transferred,
                    &mut completion_key,
                    &mut overlapped,
                    timeout_ms,
                )
            };
            if ret == 0 {
                let error = io::Error::last_os_error();
                if overlapped.is_null() && error.kind() == io::ErrorKind::TimedOut {
                    return Ok(());
                }
                return Err(error);
            }
            // A completion packet does not say which direction completed, so
            // report both.
            events.push(Event {
                token: completion_key,
                ready: Ready(Ready::READABLE.0 | Ready::WRITABLE.0),
            });
            Ok(())
        }
    }

    impl Drop for Selector {
        fn drop(&mut self) {
            // SAFETY: `iocp` was returned by `CreateIoCompletionPort` and is
            // closed exactly once. A close error cannot be handled here.
            let _ = unsafe { CloseHandle(self.iocp) };
        }
    }
}
/// Fixed-capacity buffer of selector events; only the first `len` slots are
/// considered initialised.
struct Events {
    events: Box<[MaybeUninit<Event>]>,
    len: usize,
}

impl Events {
    /// Allocates `capacity` uninitialised slots with a length of zero.
    fn with_capacity(capacity: usize) -> Self {
        let slots: Vec<MaybeUninit<Event>> = std::iter::repeat_with(MaybeUninit::uninit)
            .take(capacity)
            .collect();
        Self {
            events: slots.into_boxed_slice(),
            len: 0,
        }
    }

    /// Raw pointer to the first slot.
    fn as_mut_ptr(&mut self) -> *mut Event {
        self.events.as_mut_ptr().cast::<Event>()
    }

    /// Total number of slots.
    fn capacity(&self) -> usize {
        self.events.len()
    }

    /// Sets the number of slots treated as initialised.
    ///
    /// # Safety
    /// The first `len` slots must hold initialised `Event` values and `len`
    /// must not exceed `capacity()`.
    unsafe fn set_len(&mut self, len: usize) {
        self.len = len;
    }

    /// Appends `event`; the event is silently dropped when the buffer is full.
    fn push(&mut self, event: Event) {
        let next = self.len;
        if let Some(slot) = self.events.get_mut(next) {
            *slot = MaybeUninit::new(event);
            self.len = next + 1;
        }
    }
}
/// One readiness notification delivered by a selector.
#[derive(Debug)]
struct Event {
// Token the source was registered with; the executor uses it to look up the
// waker to notify.
token: usize,
// Readiness reported for the source.
ready: Ready,
}
/// A spawned unit of work: the boxed future plus the executor that runs it.
struct Task {
    /// The future to drive; locked for the duration of each poll.
    future: Mutex<BoxFuture<()>>,
    /// Executor whose run queue receives this task when it is woken.
    executor: Arc<Executor>,
}

impl Wake for Task {
    /// Puts the task back on its executor's run queue.
    fn wake(self: Arc<Self>) {
        let handle = Arc::clone(&self);
        self.executor.schedule(handle);
    }
}
/// Handle returned by `Runtime::spawn` for a spawned task.
pub struct JoinHandle<T> {
// Keeps the spawned task alive for as long as the handle exists.
task: Arc<Task>,
// Records the task's output type without storing a value of it.
_phantom: std::marker::PhantomData<T>,
}
#[allow(dead_code)]
impl<T> Future for JoinHandle<T> {
type Output = T;
// NOTE(review): this always returns `Pending` and never stores the waker, so
// awaiting a `JoinHandle` never completes. The visible `Runtime::spawn` also
// discards the task's output, so producing a `T` here would need a slot
// shared between `spawn` and this handle.
fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<T> {
Poll::Pending }
}
struct Executor {
tasks: Mutex<VecDeque<Arc<Task>>>,
selector: sys::Selector,
wakers: Mutex<HashMap<usize, Waker>>,
parker: Parker,
}
impl Executor {
fn new() -> io::Result<Self> {
Ok(Executor {
tasks: Mutex::new(VecDeque::new()),
selector: sys::Selector::new()?,
wakers: Mutex::new(HashMap::new()),
parker: Parker::new(),
})
}
fn schedule(&self, task: Arc<Task>) {
self.tasks.lock().unwrap().push_back(task);
self.parker.unpark();
}
fn run(&self) {
let mut events = Events::with_capacity(1024);
loop {
while let Some(task) = self.tasks.lock().unwrap().pop_front() {
let waker = Waker::from(task.clone());
let mut cx = Context::from_waker(&waker);
let _ = task.future.lock().unwrap().as_mut().poll(&mut cx);
}
match self.selector.select(&mut events, Some(Duration::from_millis(100))) {
Ok(()) => {
for i in 0..events.len {
let event = unsafe { events.events[i].assume_init_read() };
if let Some(waker) = self.wakers.lock().unwrap().get(&event.token) {
waker.wake_by_ref();
}
}
}
Err(ref e) if e.kind() == io::ErrorKind::Interrupted => continue,
Err(e) => panic!("polling error: {}", e),
}
if self.tasks.lock().unwrap().is_empty() {
self.parker.park();
}
}
}
fn run_once(&self) {
while let Some(task) = self.tasks.lock().unwrap().pop_front() {
let waker = Waker::from(task.clone());
let mut cx = Context::from_waker(&waker);
let _ = task.future.lock().unwrap().as_mut().poll(&mut cx);
}
let mut events = Events::with_capacity(1024);
if let Ok(()) = self.selector.select(&mut events, Some(Duration::from_millis(10))) {
for i in 0..events.len {
let event = unsafe { events.events[i].assume_init_read() };
if let Some(waker) = self.wakers.lock().unwrap().get(&event.token) {
waker.wake_by_ref();
}
}
}
}
fn register_io(&self, token: usize, waker: Waker) {
self.wakers.lock().unwrap().insert(token, waker);
}
}
/// Single-waiter park/unpark primitive with a sticky notification flag.
///
/// Intended for one parking thread at a time: `unpark` targets the thread
/// that most recently called `park` (initially the creating thread).
struct Parker {
    /// Thread to unpark; updated on every `park` call so the notification
    /// reaches the thread that is actually waiting.
    thread: Mutex<Thread>,
    /// Set by `unpark`, consumed by `park`.
    unparked: AtomicBool,
}

impl Parker {
    /// Creates a parker that initially targets the calling thread.
    fn new() -> Self {
        Parker {
            thread: Mutex::new(thread::current()),
            unparked: AtomicBool::new(false),
        }
    }

    /// Blocks the calling thread until `unpark` has been called.
    ///
    /// Returns immediately if a notification is already pending. The flag is
    /// consumed on return, and spurious wakeups of `thread::park` are absorbed
    /// by re-checking it.
    fn park(&self) {
        *self.thread.lock().unwrap() = thread::current();
        while !self.unparked.swap(false, Ordering::SeqCst) {
            thread::park();
        }
    }

    /// Records a notification and wakes the parked thread, if any.
    fn unpark(&self) {
        // Set the flag first so a thread woken by the unpark observes it.
        self.unparked.store(true, Ordering::SeqCst);
        self.thread.lock().unwrap().unpark();
    }
}
pub struct Runtime {
executor: Arc<Executor>,
shutdown: Arc<AtomicBool>,
}
impl Runtime {
pub fn new() -> io::Result<Self> {
Ok(Runtime {
executor: Arc::new(Executor::new()?),
shutdown: Arc::new(AtomicBool::new(false)),
})
}
pub fn block_on<F: Future>(&self, future: F) -> F::Output {
let mut future = Box::pin(future);
let waker = Waker::from(Arc::new(Task {
future: Mutex::new(Box::pin(async {})),
executor: self.executor.clone(),
}));
let mut cx = Context::from_waker(&waker);
let executor = self.executor.clone();
let shutdown = self.shutdown.clone();
thread::spawn(move || {
while !shutdown.load(Ordering::SeqCst) {
executor.run_once();
}
});
loop {
match future.as_mut().poll(&mut cx) {
Poll::Ready(output) => {
self.shutdown.store(true, Ordering::SeqCst);
return output;
}
Poll::Pending => thread::park(),
}
}
}
pub fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
let task = Arc::new(Task {
future: Mutex::new(Box::pin(async move {
let _ = future.await;
})),
executor: self.executor.clone(),
});
self.executor.schedule(task.clone());
JoinHandle {
task,
_phantom: std::marker::PhantomData,
}
}
}
/// Poll-based byte source.
pub trait AsyncRead {
/// Attempts to read into `buf`.
///
/// Returns `Poll::Ready(Ok(n))` with the number of bytes read,
/// `Poll::Ready(Err(_))` on failure, or `Poll::Pending` when no result is
/// available yet.
fn poll_read(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>>;
}
/// Poll-based byte sink.
pub trait AsyncWrite {
/// Attempts to write bytes from `buf`.
///
/// Returns `Poll::Ready(Ok(n))` with the number of bytes written,
/// `Poll::Ready(Err(_))` on failure, or `Poll::Pending` when no result is
/// available yet.
fn poll_write(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>>;
/// Attempts to flush buffered output.
fn poll_flush(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>>;
/// Attempts to close the sink.
fn poll_close(
self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<io::Result<()>>;
}
/// Asynchronous TCP listener registered with the runtime's selector.
pub struct TcpListener {
// Underlying std listener; `bind` switches it to non-blocking mode.
inner: StdTcpListener,
// Executor whose selector and waker table this listener uses.
executor: Arc<Executor>,
// Selector token identifying this listener.
token: usize,
}
/// Source of process-wide unique selector tokens; starts at 1.
static NEXT_TOKEN: AtomicUsize = AtomicUsize::new(1);

/// Hands out the next unused token.
fn next_token() -> usize {
    let token = NEXT_TOKEN.fetch_add(1, Ordering::Relaxed);
    token
}
impl TcpListener {
pub async fn bind(addr: impl Into<SocketAddr>) -> io::Result<Self> {
let inner = StdTcpListener::bind(addr.into())?;
inner.set_nonblocking(true)?;
let executor = CURRENT.with(|cell| {
cell.borrow()
.as_ref()
.map(|rt| rt.executor.clone())
.expect("no runtime in current thread")
});
let token = next_token();
executor.selector.register(
inner.as_raw_fd(),
token,
Ready::READABLE,
)?;
Ok(TcpListener {
inner,
executor,
token,
})
}
pub async fn accept(&self) -> io::Result<(TcpStream, SocketAddr)> {
loop {
match self.inner.accept() {
Ok((stream, addr)) => {
stream.set_nonblocking(true)?;
let token = next_token();
let stream = TcpStream::new(stream, self.executor.clone(), token)?;
return Ok((stream, addr));
}
Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
poll_fn(|cx| {
self.executor.register_io(self.token, cx.waker().clone());
Poll::<()>::Pending
}).await;
continue;
}
Err(e) => return Err(e),
}
}
}
}
pub struct TcpStream {
inner: StdTcpStream,
executor: Arc<Executor>,
token: usize,
}
impl TcpStream {
fn new(stream: StdTcpStream, executor: Arc<Executor>, token: usize) -> io::Result<Self> {
executor.selector.register(
stream.as_raw_fd(),
token,
Ready(Ready::READABLE.0 | Ready::WRITABLE.0),
)?;
Ok(TcpStream {
inner: stream,
executor,
token,
})
}
pub async fn connect(addr: impl Into<SocketAddr>) -> io::Result<Self> {
let stream = StdTcpStream::connect(addr.into())?;
stream.set_nonblocking(true)?;
let executor = CURRENT.with(|cell| {
cell.borrow()
.as_ref()
.map(|rt| rt.executor.clone())
.expect("no runtime in current thread")
});
let token = next_token();
Self::new(stream, executor, token)
}
pub async fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
poll_fn(|cx| {
match self.inner.read(buf) {
Ok(n) => Poll::Ready(Ok(n)),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
self.executor.register_io(self.token, cx.waker().clone());
Poll::Pending
}
Err(e) => Poll::Ready(Err(e)),
}
})
.await
}
pub async fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
let mut written = 0;
while written < buf.len() {
written += poll_fn(|cx| {
match self.inner.write(&buf[written..]) {
Ok(n) => Poll::Ready(Ok(n)),
Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
self.executor.register_io(self.token, cx.waker().clone());
Poll::Pending
}
Err(e) => Poll::Ready(Err(e)),
}
})
.await?;
}
Ok(())
}
}
impl AsyncRead for TcpStream {
    /// Attempts a non-blocking read; registers the waker and returns
    /// `Pending` when the socket has no data.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        match this.inner.read(buf) {
            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {}
            result => return Poll::Ready(result),
        }
        this.executor.register_io(this.token, cx.waker().clone());
        // Retry once after registering: with edge-triggered notifications,
        // data that arrived between the first attempt and the registration
        // would otherwise never produce a wakeup.
        match this.inner.read(buf) {
            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
            result => Poll::Ready(result),
        }
    }
}
impl AsyncWrite for TcpStream {
    /// Attempts a non-blocking write; registers the waker and returns
    /// `Pending` when the socket cannot accept data.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let this = self.get_mut();
        match this.inner.write(buf) {
            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {}
            result => return Poll::Ready(result),
        }
        this.executor.register_io(this.token, cx.waker().clone());
        // Retry once after registering: with edge-triggered notifications,
        // writability that arrived between the first attempt and the
        // registration would otherwise never produce a wakeup.
        match this.inner.write(buf) {
            Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
            result => Poll::Ready(result),
        }
    }

    /// Flushes the underlying stream; always completes immediately.
    fn poll_flush(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<io::Result<()>> {
        Poll::Ready(self.get_mut().inner.flush())
    }

    /// Reports success without touching the socket.
    // NOTE(review): this does not shut down the write half; the connection is
    // only closed when the stream is dropped. Confirm that is intended.
    fn poll_close(
        self: Pin<&mut Self>,
        _cx: &mut Context<'_>,
    ) -> Poll<io::Result<()>> {
        Poll::Ready(Ok(()))
    }
}
/// Gives other tasks a chance to run: the first poll wakes the current task
/// and returns `Pending`, the second poll completes.
async fn yield_now() {
    let mut yielded = false;
    poll_fn(|cx: &mut Context<'_>| {
        // Mark the yield as done before waking, then report the old state.
        let already_yielded = std::mem::replace(&mut yielded, true);
        if already_yielded {
            return Poll::Ready(());
        }
        cx.waker().wake_by_ref();
        Poll::Pending
    })
    .await
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4};

    /// Installs a runtime in `CURRENT` for the test thread and clears it again
    /// when dropped.
    struct RuntimeGuard {
        runtime: Arc<Runtime>,
    }

    impl RuntimeGuard {
        fn new() -> io::Result<Self> {
            let runtime = Arc::new(Runtime::new()?);
            CURRENT.with(|cell| {
                *cell.borrow_mut() = Some(Arc::clone(&runtime));
            });
            Ok(Self { runtime })
        }
    }

    impl Drop for RuntimeGuard {
        fn drop(&mut self) {
            CURRENT.with(|cell| {
                *cell.borrow_mut() = None;
            });
        }
    }

    /// Loopback address with port 0, so the OS picks a free port.
    fn loopback_any_port() -> SocketAddr {
        SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0))
    }

    /// Accepts connections one at a time and echoes everything read back to
    /// the peer, using a buffer of `buf_len` bytes per connection.
    async fn echo_server(listener: TcpListener, buf_len: usize) {
        while let Ok((mut socket, _)) = listener.accept().await {
            let mut buf = vec![0u8; buf_len];
            while let Ok(n) = socket.read(&mut buf).await {
                if n == 0 {
                    break;
                }
                socket.write_all(&buf[..n]).await.unwrap();
            }
        }
    }

    #[test]
    fn test_tcp_echo() {
        let guard = RuntimeGuard::new().unwrap();
        guard.runtime.block_on(async {
            let listener = TcpListener::bind(loopback_any_port()).await.unwrap();
            let addr = listener.inner.local_addr().unwrap();
            guard.runtime.spawn(echo_server(listener, 1024));

            let mut client = TcpStream::connect(addr).await.unwrap();
            client.write_all(b"Hello ZSync!").await.unwrap();

            let mut reply = vec![0; 1024];
            let n = client.read(&mut reply).await.unwrap();
            assert_eq!(&reply[..n], b"Hello ZSync!");
        });
    }

    #[test]
    fn test_tcp_large_data() {
        let guard = RuntimeGuard::new().unwrap();
        let runtime = guard.runtime.clone();
        guard.runtime.block_on(async {
            let listener = TcpListener::bind(loopback_any_port()).await.unwrap();
            let addr = listener.inner.local_addr().unwrap();
            runtime.spawn(echo_server(listener, 8192));

            let mut client = TcpStream::connect(addr).await.unwrap();
            let data = vec![42u8; 65536];
            for chunk in data.chunks(8192) {
                client.write_all(chunk).await.unwrap();
            }

            // Collect the echo until everything is back or the peer closes.
            let mut received = vec![0; 65536];
            let mut total = 0;
            while total < data.len() {
                let n = client.read(&mut received[total..]).await.unwrap();
                if n == 0 {
                    break;
                }
                total += n;
            }
            assert_eq!(&received[..total], &data[..]);
        });
    }
}