#![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)]
use std::any::Any;
use std::cell::RefCell;
use std::collections::VecDeque;
use std::fmt;
use std::io::{self, Read, Write};
use std::mem;
use std::panic;
use std::pin::Pin;
use std::slice;
use std::sync::atomic::{self, AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Condvar, Mutex, MutexGuard};
use std::task::{Context, Poll, Waker};
use std::thread;
use std::time::Duration;
use futures_channel::{mpsc, oneshot};
use futures_util::future::{self, Future};
use futures_util::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
use futures_util::sink::SinkExt;
use futures_util::stream::{self, Stream, StreamExt};
use futures_util::task::{waker_ref, ArcWake, AtomicWaker};
use futures_util::{pin_mut, ready};
use once_cell::sync::Lazy;
use parking::Parker;
use waker_fn::waker_fn;
// A pinned, boxed, `Send` future that resolves to `T`.
// `Executor::spawn` returns this as the handle for a spawned task.
type Task<T> = Pin<Box<dyn Future<Output = T> + Send>>;
// A unit of work that the executor queues and worker threads run.
struct Runnable {
// Bit flags describing the task: `WOKEN` (0b01) and `RUNNING` (0b10), both
// defined inside `Runnable::run`. A value of 0 means neither is set.
state: AtomicUsize,
// The future to drive. `run` acquires it with `try_lock`, so it is expected
// to be polled by at most one thread at a time.
future: Mutex<Pin<Box<dyn Future<Output = ()> + Send>>>,
}
impl Runnable {
// Polls the future once on the current thread and reschedules it if it was
// woken while the poll was in progress.
//
// Panics if the future's mutex is already locked or poisoned (`try_lock().unwrap()`).
fn run(self: Arc<Runnable>) {
// Set when the task's waker has been invoked.
const WOKEN: usize = 0b01;
// Set while the task is being polled.
const RUNNING: usize = 0b10;
impl ArcWake for Runnable {
// Marks the task as woken. It is pushed onto the executor queue only when
// the previous state was 0, i.e. it was neither already woken (and hence
// already queued or about to be requeued) nor currently running.
fn wake_by_ref(runnable: &Arc<Self>) {
if runnable.state.fetch_or(WOKEN, Ordering::SeqCst) == 0 {
EXECUTOR.schedule(runnable.clone());
}
}
}
// Overwrite the state: clears WOKEN and sets RUNNING before polling.
self.state.store(RUNNING, Ordering::SeqCst);
// Build a waker backed by this `Arc<Runnable>` and poll the future with it.
let waker = waker_ref(&self);
let cx = &mut Context::from_waker(&waker);
let poll = self.future.try_lock().unwrap().as_mut().poll(cx);
// On `Ready` the state is left as RUNNING, so later wake-ups see a non-zero
// state and never reschedule the finished task.
if poll.is_pending() {
// Clear RUNNING. If the previous state was WOKEN | RUNNING, a wake-up arrived
// during the poll and did not schedule the task itself, so do it here.
if self.state.fetch_and(!RUNNING, Ordering::SeqCst) == WOKEN | RUNNING {
EXECUTOR.schedule(self);
}
}
}
}
// The process-wide executor, created lazily on first use.
// It starts with no threads and an empty queue; threads are added by `grow_pool`.
static EXECUTOR: Lazy<Executor> = Lazy::new(|| Executor {
inner: Mutex::new(Inner {
idle_count: 0,
thread_count: 0,
queue: VecDeque::new(),
}),
cvar: Condvar::new(),
});
// A thread pool that runs queued `Runnable`s.
struct Executor {
// Counters and the task queue, guarded by a single lock.
inner: Mutex<Inner>,
// Signalled when a task is queued; worker threads wait on it in `main_loop`.
cvar: Condvar,
}
// Mutable executor state protected by `Executor::inner`.
struct Inner {
// Number of threads counted as idle. A newly spawned thread is counted as
// idle before it starts (see `grow_pool`), and `main_loop` decrements this
// while it is running tasks.
idle_count: usize,
// Total number of threads in the pool, including ones being started.
thread_count: usize,
// Tasks waiting to be run, in FIFO order (`push_back` / `pop_front`).
queue: VecDeque<Arc<Runnable>>,
}
impl Executor {
    /// Spawns `future` onto the global executor and returns a [`Task`] that
    /// resolves to its output.
    ///
    /// The output travels through a oneshot channel. Awaiting the returned task
    /// panics with "future has panicked" if the spawned future is dropped
    /// without producing a value.
    fn spawn<T: Send + 'static>(future: impl Future<Output = T> + Send + 'static) -> Task<T> {
        let (s, r) = oneshot::channel();
        let future = async move {
            // The receiver may already be gone; that is not an error.
            let _ = s.send(future.await);
        };

        let runnable = Arc::new(Runnable {
            state: AtomicUsize::new(0),
            future: Mutex::new(Box::pin(future)),
        });

        EXECUTOR.schedule(runnable);
        Box::pin(async { r.await.expect("future has panicked") })
    }

    /// Runs queued tasks on the current thread until the thread has been idle
    /// for the wait timeout and the queue is empty, then retires the thread.
    fn main_loop(&'static self) {
        let mut inner = self.inner.lock().unwrap();
        loop {
            // This thread is about to run tasks, so it is no longer idle.
            inner.idle_count -= 1;

            while let Some(runnable) = inner.queue.pop_front() {
                // `grow_pool` consumes the guard, so the lock is released
                // before the task runs.
                self.grow_pool(inner);

                // A panicking task must not take the worker thread down with it.
                let _ = panic::catch_unwind(|| runnable.run());

                inner = self.inner.lock().unwrap();
            }

            // Nothing left to run: become idle and wait for more work.
            inner.idle_count += 1;

            let timeout = Duration::from_millis(500);
            let (lock, res) = self.cvar.wait_timeout(inner, timeout).unwrap();
            inner = lock;

            // Retire only when the wait timed out and there is still no work.
            if res.timed_out() && inner.queue.is_empty() {
                inner.idle_count -= 1;
                inner.thread_count -= 1;
                break;
            }
        }
    }

    /// Queues `runnable`, wakes one waiting worker and grows the pool if needed.
    fn schedule(&'static self, runnable: Arc<Runnable>) {
        let mut inner = self.inner.lock().unwrap();
        inner.queue.push_back(runnable);
        self.cvar.notify_one();
        self.grow_pool(inner);
    }

    /// Spawns worker threads while queued tasks outnumber idle threads by more
    /// than 5 to 1 and fewer than 500 threads exist. Consumes (and thereby
    /// releases) the lock guard.
    ///
    /// If the OS refuses to create a thread, the counters are rolled back and
    /// the existing workers are left to handle the queue.
    ///
    /// # Panics
    ///
    /// Panics if a thread cannot be spawned and the pool has no worker at all,
    /// because queued tasks could then never run. The lock is released first so
    /// the executor mutex is not poisoned.
    fn grow_pool(&'static self, mut inner: MutexGuard<'static, Inner>) {
        while inner.queue.len() > inner.idle_count * 5 && inner.thread_count < 500 {
            // The new thread is counted as idle until `main_loop` says otherwise.
            inner.idle_count += 1;
            inner.thread_count += 1;

            // The pool is overloaded: wake every waiting worker.
            self.cvar.notify_all();

            // Monotonic counter used only for thread names.
            static ID: AtomicUsize = AtomicUsize::new(1);
            let id = ID.fetch_add(1, Ordering::Relaxed);

            let spawned = thread::Builder::new()
                .name(format!("blocking-{}", id))
                .spawn(move || self.main_loop());

            if let Err(err) = spawned {
                // Undo the bookkeeping for the thread that never started.
                inner.idle_count -= 1;
                inner.thread_count -= 1;

                if inner.thread_count == 0 {
                    // No worker exists to make progress. Release the lock before
                    // panicking so later callers do not hit a poisoned mutex.
                    drop(inner);
                    panic!("cannot spawn a blocking thread: {}", err);
                }

                // Existing workers will drain the queue; a later call retries.
                break;
            }
        }
    }
}
/// Blocks the current thread until `future` completes and returns its output.
///
/// The future is first polled once with a no-op waker. If it is still pending,
/// it is polled in a loop with a waker that unparks this thread, parking
/// between polls. The parker/waker pair is cached per thread; when the cache is
/// already borrowed (a `block_on` call nested inside another on the same
/// thread), a fresh pair is created for the inner call.
pub fn block_on<T>(future: impl Future<Output = T>) -> T {
    pin_mut!(future);

    // Fast path: one poll that needs neither the thread-local nor a parker.
    let noop_cx = &mut Context::from_waker(futures_util::task::noop_waker_ref());
    if let Poll::Ready(output) = future.as_mut().poll(noop_cx) {
        return output;
    }

    /// Creates a parker together with a waker that unparks it.
    fn parker_and_waker() -> (Parker, Waker) {
        let parker = Parker::new();
        let unparker = parker.unparker();
        (parker, waker_fn(move || unparker.unpark()))
    }

    thread_local! {
        static CACHE: RefCell<(Parker, Waker)> = RefCell::new(parker_and_waker());
    }

    CACHE.with(|cache| {
        // Polls the future until it is ready, parking whenever it is pending.
        let mut drive = |parker: &Parker, waker: &Waker| {
            let cx = &mut Context::from_waker(waker);
            loop {
                if let Poll::Ready(output) = future.as_mut().poll(cx) {
                    break output;
                }
                parker.park();
            }
        };

        match cache.try_borrow_mut() {
            // The borrow is held for the whole run so nested calls take the other arm.
            Ok(cached) => {
                let (parker, waker) = &*cached;
                drive(parker, waker)
            }
            Err(_) => {
                let (parker, waker) = parker_and_waker();
                drive(&parker, &waker)
            }
        }
    })
}
/// Wraps the given code in an `async move` block and runs it to completion on
/// the current thread with `$crate::block_on`.
#[macro_export]
macro_rules! block_on {
($($code:tt)*) => {
$crate::block_on(async move { $($code)* })
};
}
/// A wrapper around a value of type `T`.
///
/// The trait implementations for this type in this file drive the wrapped
/// async stream, reader or writer with `block_on`.
#[derive(Debug)]
pub struct BlockOn<T>(T);

impl<T> BlockOn<T> {
    /// Wraps `io`.
    pub fn new(io: T) -> BlockOn<T> {
        Self(io)
    }

    /// Returns a shared reference to the wrapped value.
    pub fn get_ref(&self) -> &T {
        let Self(inner) = self;
        inner
    }

    /// Returns a mutable reference to the wrapped value.
    pub fn get_mut(&mut self) -> &mut T {
        let Self(inner) = self;
        inner
    }

    /// Consumes the wrapper and returns the wrapped value.
    pub fn into_inner(self) -> T {
        let Self(inner) = self;
        inner
    }
}
impl<T: Stream + Unpin> Iterator for BlockOn<T> {
    type Item = T::Item;

    /// Blocks the current thread until the wrapped stream yields its next item
    /// or ends.
    fn next(&mut self) -> Option<Self::Item> {
        let next_item = self.0.next();
        block_on(next_item)
    }
}
impl<T: AsyncRead + Unpin> Read for BlockOn<T> {
    /// Blocks the current thread until the wrapped async reader completes one
    /// `read` into `buf`.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let pending_read = self.0.read(buf);
        block_on(pending_read)
    }
}
impl<T: AsyncWrite + Unpin> Write for BlockOn<T> {
    /// Blocks the current thread until the wrapped async writer completes one
    /// `write` of `buf`.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let pending_write = self.0.write(buf);
        block_on(pending_write)
    }

    /// Blocks the current thread until the wrapped async writer is flushed.
    fn flush(&mut self) -> io::Result<()> {
        let pending_flush = self.0.flush();
        block_on(pending_flush)
    }
}
pub async fn unblock<T, F>(f: F) -> T
where
F: FnOnce() -> T + Send + 'static,
T: Send + 'static,
{
let (sender, receiver) = oneshot::channel();
let task = Executor::spawn(async move {
let _ = sender.send(f());
});
task.await;
receiver.await.expect("future has panicked")
}
/// Wraps the given code in a `move` closure, passes it to `$crate::unblock`
/// and awaits the result. The expansion contains `.await`, so it must be used
/// inside an async context.
#[macro_export]
macro_rules! unblock {
($($code:tt)*) => {
$crate::unblock(move || { $($code)* }).await
};
}
pub struct Unblock<T>(State<T>);
impl<T> Unblock<T> {
pub fn new(io: T) -> Unblock<T> {
Unblock(State::Idle(Some(Box::new(io))))
}
pub async fn get_mut(&mut self) -> &mut T {
let _ = future::poll_fn(|cx| self.poll_stop(cx)).await;
match &mut self.0 {
State::Idle(t) => t.as_mut().expect("inner value was taken out"),
State::WithMut(..) | State::Streaming(..) | State::Reading(..) | State::Writing(..) => {
unreachable!("when stopped, the state machine must be in idle state");
}
}
}
pub async fn with_mut<R, F>(&mut self, op: F) -> R
where
F: FnOnce(&mut T) -> R + Send + 'static,
R: Send + 'static,
T: Send + 'static,
{
let _ = future::poll_fn(|cx| self.poll_stop(cx)).await;
let mut t = match &mut self.0 {
State::Idle(t) => t.take().expect("inner value was taken out"),
State::WithMut(..) | State::Streaming(..) | State::Reading(..) | State::Writing(..) => {
unreachable!("when stopped, the state machine must be in idle state");
}
};
let (sender, receiver) = oneshot::channel();
let task = Executor::spawn(async move {
let _ = sender.send(op(&mut t));
t
});
self.0 = State::WithMut(task);
receiver.await.expect("`with_mut()` operation has panicked")
}
pub async fn into_inner(self) -> T {
let mut this = self;
let _ = future::poll_fn(|cx| this.poll_stop(cx)).await;
match &mut this.0 {
State::Idle(t) => *t.take().expect("inner value was taken out"),
State::WithMut(..) | State::Streaming(..) | State::Reading(..) | State::Writing(..) => {
unreachable!("when stopped, the state machine must be in idle state");
}
}
}
fn poll_stop(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
loop {
match &mut self.0 {
State::Idle(_) => return Poll::Ready(Ok(())),
State::WithMut(task) => {
let t = ready!(Pin::new(task).poll(cx));
self.0 = State::Idle(Some(t));
}
State::Streaming(any, task) => {
any.take();
let iter = ready!(Pin::new(task).poll(cx));
self.0 = State::Idle(Some(iter));
}
State::Reading(reader, task) => {
reader.take();
let (res, io) = ready!(Pin::new(task).poll(cx));
self.0 = State::Idle(Some(io));
res?;
}
State::Writing(writer, task) => {
writer.take();
let (res, io) = ready!(Pin::new(task).poll(cx));
self.0 = State::Idle(Some(io));
res?;
}
}
}
}
}
impl<T: fmt::Debug> fmt::Debug for Unblock<T> {
    /// Formats as `Unblock { io: .. }`, where `io` is the wrapped value when
    /// idle, `<closed>` when the value has been taken out, or `<blocked>`
    /// while a task holds the value.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        /// Placeholder printed when the idle slot is empty.
        struct Closed;
        impl fmt::Debug for Closed {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.write_str("<closed>")
            }
        }

        /// Placeholder printed while a task owns the value.
        struct Blocked;
        impl fmt::Debug for Blocked {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.write_str("<blocked>")
            }
        }

        // Pick what to show for the `io` field, then format once.
        let shown: &dyn fmt::Debug = match &self.0 {
            State::Idle(None) => &Closed,
            State::Idle(Some(boxed)) => {
                let inner: &T = boxed;
                inner
            }
            State::WithMut(..) | State::Streaming(..) | State::Reading(..) | State::Writing(..) => {
                &Blocked
            }
        };
        f.debug_struct("Unblock").field("io", shown).finish()
    }
}
// The state machine behind `Unblock<T>`.
// In every non-idle state the boxed value lives inside the task, which
// returns it when it finishes.
enum State<T> {
// No task is running. `None` means the value has been taken out
// (`poll_close` sets this; `into_inner` takes the value).
Idle(Option<Box<T>>),
// A `with_mut` operation is running; the task returns the value.
WithMut(Task<Box<T>>),
// The value is being iterated in a task. The first field holds the item
// receiver as a type-erased box; `poll_stop` drops it before waiting for the task.
Streaming(Option<Box<dyn Any + Send>>, Task<Box<T>>),
// A task is copying from the value into a pipe; this side holds the pipe's
// reader. The task returns its I/O result together with the value.
Reading(Option<Reader>, Task<(io::Result<()>, Box<T>)>),
// A task is copying from a pipe into the value; this side holds the pipe's
// writer. The task returns its I/O result together with the value.
Writing(Option<Writer>, Task<(io::Result<()>, Box<T>)>),
}
// Exposes a blocking iterator as an async stream by running the iterator in
// a spawned task and forwarding items over a bounded channel.
impl<T: Iterator + Send + 'static> Stream for Unblock<T>
where
T::Item: Send + 'static,
{
type Item = T::Item;
// Panics if the inner iterator has been taken out of the idle state.
fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T::Item>> {
loop {
match &mut self.0 {
// Some other operation (or a stopping stream) is in progress: wait for
// it to finish, ignoring its I/O result, then retry.
State::WithMut(..)
| State::Streaming(None, _)
| State::Reading(..)
| State::Writing(..) => {
let _ = ready!(self.poll_stop(cx));
}
// Idle: move the iterator into a task that feeds a channel.
State::Idle(iter) => {
let mut iter = iter.take().expect("inner iterator was taken out");
// Bounded channel with a buffer of 8 * 1024 items.
let (mut sender, receiver) = mpsc::channel(8 * 1024);
let task = Executor::spawn(async move {
for item in &mut iter {
// A send error means the receiver is gone; stop iterating.
if sender.send(item).await.is_err() {
break;
}
}
// Hand the iterator back so the state can return to idle.
iter
});
// The receiver is fused so it can be polled again after it ends, and
// type-erased because `State` does not name `T::Item`.
self.0 = State::Streaming(Some(Box::new(receiver.fuse())), task);
}
State::Streaming(Some(any), task) => {
let receiver = any
.downcast_mut::<stream::Fuse<mpsc::Receiver<T::Item>>>()
.unwrap();
let opt = ready!(Pin::new(receiver).poll_next(cx));
// End of stream: wait for the task to return the iterator before
// reporting `None`.
if opt.is_none() {
let iter = ready!(Pin::new(task).poll(cx));
self.0 = State::Idle(Some(iter));
}
return Poll::Ready(opt);
}
}
}
}
}
// Exposes a blocking reader as an async reader: a spawned task reads from
// the inner value into a pipe and `poll_read` drains the pipe.
impl<T: Read + Send + 'static> AsyncRead for Unblock<T> {
// Panics if the inner value has been taken out of the idle state.
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &mut [u8],
) -> Poll<io::Result<usize>> {
loop {
match &mut self.0 {
// Another kind of operation (or a stopping read) is in progress: wait
// for it to finish, propagating its I/O error, then retry.
State::WithMut(..)
| State::Reading(None, _)
| State::Streaming(..)
| State::Writing(..) => {
ready!(self.poll_stop(cx))?;
}
// Idle: move the reader into a task that fills the pipe.
State::Idle(io) => {
let mut io = io.take().expect("inner value was taken out");
// Pipe capacity: 8 * 1024 * 1024 bytes.
let (reader, mut writer) = pipe(8 * 1024 * 1024);
let task = Executor::spawn(async move {
loop {
match future::poll_fn(|cx| writer.fill(cx, &mut io)).await {
// `fill` yields 0 when nothing was read from `io` or the pipe is closed.
Ok(0) => return (Ok(()), io),
Ok(_) => {}
Err(err) => return (Err(err), io),
}
}
});
self.0 = State::Reading(Some(reader), task);
}
State::Reading(Some(reader), task) => {
// `drain` writes pipe bytes into `buf` (via `Write for &mut [u8]`).
let n = ready!(reader.drain(cx, buf))?;
// Zero bytes: collect the task's result and the value before
// returning, surfacing any I/O error from the task.
if n == 0 {
let (res, io) = ready!(Pin::new(task).poll(cx));
self.0 = State::Idle(Some(io));
res?;
}
return Poll::Ready(Ok(n));
}
}
}
}
}
// Exposes a blocking writer as an async writer: `poll_write` fills a pipe and
// a spawned task drains the pipe into the inner value.
impl<T: Write + Send + 'static> AsyncWrite for Unblock<T> {
// Panics if the inner value has been taken out of the idle state.
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
loop {
match &mut self.0 {
// Another kind of operation (or a stopping write) is in progress: wait
// for it to finish, propagating its I/O error, then retry.
State::WithMut(..)
| State::Writing(None, _)
| State::Streaming(..)
| State::Reading(..) => {
ready!(self.poll_stop(cx))?;
}
// Idle: move the writer into a task that drains the pipe into it.
State::Idle(io) => {
let mut io = io.take().expect("inner value was taken out");
// Pipe capacity: 8 * 1024 * 1024 bytes.
let (mut reader, writer) = pipe(8 * 1024 * 1024);
let task = Executor::spawn(async move {
loop {
match future::poll_fn(|cx| reader.drain(cx, &mut io)).await {
// Zero bytes drained: flush and return the flush result with the value.
Ok(0) => return (io.flush(), io),
Ok(_) => {}
Err(err) => {
// Best-effort flush; the original write error is the one reported.
let _ = io.flush();
return (Err(err), io);
}
}
}
});
self.0 = State::Writing(Some(writer), task);
}
// `fill` reads from `buf` (via `Read for &[u8]`) into the pipe.
State::Writing(Some(writer), _) => return writer.fill(cx, buf),
}
}
}
// Stops whatever task is running (which, for a write, drains the pipe and
// flushes the inner value) and completes once the state is idle.
fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
loop {
match &mut self.0 {
State::WithMut(..)
| State::Streaming(..)
| State::Writing(..)
| State::Reading(..) => {
ready!(self.poll_stop(cx))?;
}
State::Idle(_) => return Poll::Ready(Ok(())),
}
}
}
// Flushes, then drops the inner value by replacing the state with `Idle(None)`.
fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
ready!(Pin::new(&mut self).poll_flush(cx))?;
self.0 = State::Idle(None);
Poll::Ready(Ok(()))
}
}
/// Creates a byte pipe backed by a buffer of `cap` bytes and returns its
/// reading and writing ends.
///
/// # Panics
///
/// Panics if `cap` is zero or if `2 * cap` overflows `usize` (the head and
/// tail indices range up to twice the capacity).
fn pipe(cap: usize) -> (Reader, Writer) {
    assert!(cap > 0, "capacity must be positive");
    assert!(cap.checked_mul(2).is_some(), "capacity is too large");

    // Allocate the storage and leak it; `Pipe`'s `Drop` impl rebuilds the
    // `Vec` from the raw pointer and frees it.
    let buffer = {
        let mut storage = mem::ManuallyDrop::new(Vec::<u8>::with_capacity(cap));
        storage.as_mut_ptr()
    };

    let shared = Arc::new(Pipe {
        head: AtomicUsize::new(0),
        tail: AtomicUsize::new(0),
        reader: AtomicWaker::new(),
        writer: AtomicWaker::new(),
        closed: AtomicBool::new(false),
        buffer,
        cap,
    });

    // Both ends start with empty cached indices.
    let reading_end = Reader {
        inner: Arc::clone(&shared),
        head: 0,
        tail: 0,
    };
    let writing_end = Writer {
        inner: shared,
        head: 0,
        tail: 0,
        zeroed_until: 0,
    };
    (reading_end, writing_end)
}
// The reading end of a pipe created by `pipe`.
struct Reader {
// State shared with the `Writer`.
inner: Arc<Pipe>,
// This reader's head index; `drain` advances it and publishes it to
// `inner.head`. It ranges over `0..2*cap` and wraps to 0.
head: usize,
// Cached copy of the shared tail, refreshed from `inner.tail` in `drain`
// when the pipe appears empty.
tail: usize,
}
// The writing end of a pipe created by `pipe`.
struct Writer {
// State shared with the `Reader`.
inner: Arc<Pipe>,
// Cached copy of the shared head, refreshed from `inner.head` in `fill`
// when the pipe appears full.
head: usize,
// This writer's tail index; `fill` advances it and publishes it to
// `inner.tail`. It ranges over `0..2*cap` and wraps to 0.
tail: usize,
// Buffer bytes at offsets below this value have been zero-filled by `fill`.
zeroed_until: usize,
}
// `Pipe` contains a raw `*mut u8`, so `Arc<Pipe>` (and thus each endpoint) is
// not `Send` automatically; these impls assert it manually.
// NOTE(review): soundness appears to rely on the reader only reading bytes
// between head and tail and the writer only writing outside that range, with
// the indices published through acquire/release atomics - confirm.
unsafe impl Send for Reader {}
unsafe impl Send for Writer {}
// State shared between the `Reader` and `Writer` of a pipe.
struct Pipe {
// Read position, stored by the reader and loaded by the writer.
head: AtomicUsize,
// Write position, stored by the writer and loaded by the reader.
tail: AtomicUsize,
// Waker registered by the reader while the pipe is empty.
reader: AtomicWaker,
// Waker registered by the writer while the pipe is full.
writer: AtomicWaker,
// Set when either end is dropped.
closed: AtomicBool,
// Start of the byte buffer allocated in `pipe` and freed in `Drop`.
buffer: *mut u8,
// Size of the buffer in bytes.
cap: usize,
}
impl Drop for Pipe {
    /// Frees the buffer that `pipe` allocated and leaked.
    fn drop(&mut self) {
        // SAFETY: `pipe` obtains `buffer` from `Vec::with_capacity(cap)` and
        // leaks that vector, so rebuilding a `Vec` with the same pointer and
        // capacity reclaims the allocation. The length is 0, so no elements
        // are read or dropped.
        let storage = unsafe { Vec::from_raw_parts(self.buffer, 0, self.cap) };
        drop(storage);
    }
}
impl Drop for Reader {
// Marks the pipe as closed and wakes the writer's registered waker so a
// pending `fill` can observe the closure.
fn drop(&mut self) {
self.inner.closed.store(true, Ordering::SeqCst);
self.inner.writer.wake();
}
}
impl Drop for Writer {
// Marks the pipe as closed and wakes the reader's registered waker so a
// pending `drain` can observe the closure.
fn drop(&mut self) {
self.inner.closed.store(true, Ordering::SeqCst);
self.inner.reader.wake();
}
}
impl Reader {
    /// Moves bytes from the pipe into `dest`.
    ///
    /// Returns `Poll::Pending` (with the task's waker registered) when the pipe
    /// is empty and still open, `Ok(0)` when it is empty and closed, and
    /// otherwise the number of bytes written to `dest`. Writing stops as soon
    /// as `dest.write` returns 0, which happens when the pipe has been emptied
    /// or `dest` accepts no more bytes.
    ///
    /// # Errors
    ///
    /// Returns any error produced by `dest.write`.
    fn drain(&mut self, cx: &mut Context<'_>, mut dest: impl Write) -> Poll<io::Result<usize>> {
        let cap = self.inner.cap;

        // Number of bytes between two indices in `0..2*cap`.
        let distance = |a: usize, b: usize| {
            if a <= b {
                b - a
            } else {
                2 * cap - (a - b)
            }
        };

        // If the pipe appears empty, refresh the cached tail.
        if distance(self.head, self.tail) == 0 {
            self.tail = self.inner.tail.load(Ordering::Acquire);

            if distance(self.head, self.tail) == 0 {
                // Register for a wake-up, then look again so a write that
                // raced with the registration is not missed.
                self.inner.reader.register(cx.waker());
                atomic::fence(Ordering::SeqCst);
                self.tail = self.inner.tail.load(Ordering::Acquire);

                if distance(self.head, self.tail) == 0 {
                    // Acquire pairs with the `closed` store in `Drop for Writer`.
                    if !self.inner.closed.load(Ordering::Acquire) {
                        return Poll::Pending;
                    }

                    // The writer publishes its final tail before setting
                    // `closed`. It may have done both after the tail load
                    // above, so re-read the tail before reporting end of
                    // stream; otherwise the last bytes would be lost.
                    self.tail = self.inner.tail.load(Ordering::Acquire);
                    if distance(self.head, self.tail) == 0 {
                        return Poll::Ready(Ok(0));
                    }
                }
            }
        }

        // There is data to read, so a wake-up is no longer needed.
        self.inner.reader.take();

        // Maps an index in `0..2*cap` to a buffer offset in `0..cap`.
        let real_index = |i: usize| {
            if i < cap {
                i
            } else {
                i - cap
            }
        };

        // Total bytes handed to `dest` so far.
        let mut count = 0;

        loop {
            // At most 128 KiB per step, no more than is buffered, and never
            // past the end of the buffer.
            let n = (128 * 1024)
                .min(distance(self.head, self.tail))
                .min(cap - real_index(self.head));

            // SAFETY: the offset plus `n` stays within the `cap`-byte buffer
            // because of the `cap - real_index(head)` bound above.
            // NOTE(review): this also assumes the writer has initialized the
            // `n` bytes between head and tail and is not writing them now.
            let pipe_slice =
                unsafe { slice::from_raw_parts(self.inner.buffer.add(real_index(self.head)), n) };

            let n = dest.write(pipe_slice)?;
            count += n;

            // Nothing moved: the pipe is drained or `dest` took no bytes.
            if n == 0 {
                return Poll::Ready(Ok(count));
            }

            // Advance the head, wrapping at `2 * cap`.
            if self.head + n < 2 * cap {
                self.head += n;
            } else {
                self.head = 0;
            }

            // Publish the new head and let the writer know there is room.
            self.inner.head.store(self.head, Ordering::Release);
            self.inner.writer.wake();
        }
    }
}
impl Writer {
// Moves bytes from `src` into the pipe.
//
// Returns `Ok(0)` if the pipe is closed, `Poll::Pending` (with the task's
// waker registered) if it is full and still open, and otherwise the number
// of bytes read from `src`. Reading stops as soon as `src.read` returns 0,
// which happens when the pipe is full or `src` has no more bytes to give.
// Any error from `src.read` is returned.
fn fill(&mut self, cx: &mut Context<'_>, mut src: impl Read) -> Poll<io::Result<usize>> {
// Quick check: nothing to do once either end has been dropped.
if self.inner.closed.load(Ordering::Relaxed) {
return Poll::Ready(Ok(0));
}
let cap = self.inner.cap;
// Number of bytes between two indices in `0..2*cap`.
let distance = |a: usize, b: usize| {
if a <= b {
b - a
} else {
2 * cap - (a - b)
}
};
// If the pipe appears full, refresh the cached head.
if distance(self.head, self.tail) == cap {
self.head = self.inner.head.load(Ordering::Acquire);
if distance(self.head, self.tail) == cap {
// Register for a wake-up, then look again so a read that raced with
// the registration is not missed.
self.inner.writer.register(cx.waker());
atomic::fence(Ordering::SeqCst);
self.head = self.inner.head.load(Ordering::Acquire);
if distance(self.head, self.tail) == cap {
if self.inner.closed.load(Ordering::Relaxed) {
return Poll::Ready(Ok(0));
} else {
return Poll::Pending;
}
}
}
}
// There is room to write, so a wake-up is no longer needed.
self.inner.writer.take();
// Maps an index in `0..2*cap` to a buffer offset in `0..cap`.
let real_index = |i: usize| {
if i < cap {
i
} else {
i - cap
}
};
// Total bytes taken from `src` so far.
let mut count = 0;
loop {
// At most 128 KiB per step, limited by how much has been zeroed so far
// (so zero-filling grows gradually), by the free space in the pipe, and
// by the end of the buffer.
let n = (128 * 1024) .min(self.zeroed_until * 2 + 4096) .min(cap - distance(self.head, self.tail)) .min(cap - real_index(self.tail));
let pipe_slice_mut = unsafe {
let from = real_index(self.tail);
let to = from + n;
// Zero-fill any part of the target range not yet initialized, so the
// slice handed to `src.read` never exposes uninitialized memory.
if self.zeroed_until < to {
self.inner
.buffer
.add(self.zeroed_until)
.write_bytes(0u8, to - self.zeroed_until);
self.zeroed_until = to;
}
slice::from_raw_parts_mut(self.inner.buffer.add(from), n)
};
let n = src.read(pipe_slice_mut)?;
count += n;
// Nothing moved: the pipe is full or `src` produced no bytes.
if n == 0 {
return Poll::Ready(Ok(count));
}
// Advance the tail, wrapping at `2 * cap`.
if self.tail + n < 2 * cap {
self.tail += n;
} else {
self.tail = 0;
}
// Publish the new tail and let the reader know there is data.
self.inner.tail.store(self.tail, Ordering::Release);
self.inner.reader.wake();
}
}
}