#![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)]
#![forbid(unsafe_code)]
#![doc(
html_favicon_url = "https://raw.githubusercontent.com/smol-rs/smol/master/assets/images/logo_fullsize_transparent.png"
)]
#![doc(
html_logo_url = "https://raw.githubusercontent.com/smol-rs/smol/master/assets/images/logo_fullsize_transparent.png"
)]
use std::any::Any;
use std::collections::VecDeque;
use std::fmt;
use std::io::{self, Read, Seek, SeekFrom, Write};
use std::num::NonZeroUsize;
use std::panic;
use std::pin::Pin;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Condvar, Mutex, MutexGuard};
use std::task::{Context, Poll};
use std::thread;
use std::time::Duration;
#[cfg(not(target_family = "wasm"))]
use std::env;
use async_channel::{bounded, Receiver};
use async_task::Runnable;
use futures_io::{AsyncRead, AsyncSeek, AsyncWrite};
use futures_lite::{
future::{self, Future},
ready,
stream::Stream,
};
use piper::{pipe, Reader, Writer};
#[doc(no_inline)]
pub use async_task::Task;
/// Environment variable that overrides the size limit of the blocking thread pool.
#[cfg(not(target_family = "wasm"))]
const MAX_THREADS_ENV: &'static str = "BLOCKING_MAX_THREADS";
/// Thread limit used when `BLOCKING_MAX_THREADS` is unset or cannot be parsed.
#[cfg(not(target_family = "wasm"))]
const DEFAULT_MAX_THREADS: usize = 500;
/// Lower bound that a parsed `BLOCKING_MAX_THREADS` value is clamped to.
#[cfg(not(target_family = "wasm"))]
const MIN_MAX_THREADS: usize = 1;
/// Upper bound that a parsed `BLOCKING_MAX_THREADS` value is clamped to.
#[cfg(not(target_family = "wasm"))]
const MAX_MAX_THREADS: usize = 10_000;
/// The process-wide pool of threads on which blocking tasks run.
///
/// All mutable bookkeeping lives behind `inner`; idle worker threads wait on `cvar`.
struct Executor {
    /// Task queue plus the thread counters used to size the pool.
    inner: Mutex<Inner>,
    /// Notified when a task is queued (one waiter) or when the pool grows (all waiters),
    /// so that idle worker threads wake up and look at the queue.
    cvar: Condvar,
}
/// Mutable state of the executor, always accessed under its mutex.
struct Inner {
    /// Number of worker threads counted as idle.
    ///
    /// A thread is counted as idle as soon as `grow_pool` decides to spawn it, before the
    /// thread actually starts; `main_loop` decrements this while it is running tasks.
    idle_count: usize,
    /// Number of worker threads spawned and not yet retired by `main_loop`.
    thread_count: usize,
    /// Tasks waiting for a worker thread, run in FIFO order (`push_back` / `pop_front`).
    queue: VecDeque<Runnable>,
    /// Upper bound on `thread_count`; lowered at runtime when spawning a thread fails.
    thread_limit: NonZeroUsize,
}
impl Executor {
    /// Returns the maximum number of worker threads the pool may use.
    ///
    /// Reads `BLOCKING_MAX_THREADS`. A value that parses as `usize` is clamped to
    /// `MIN_MAX_THREADS..=MAX_MAX_THREADS`. An unset, non-Unicode, or unparsable value yields
    /// `DEFAULT_MAX_THREADS`.
    #[cfg(not(target_family = "wasm"))]
    fn max_threads() -> usize {
        match env::var(MAX_THREADS_ENV) {
            Ok(v) => v
                .parse::<usize>()
                .map(|v| v.clamp(MIN_MAX_THREADS, MAX_MAX_THREADS))
                .unwrap_or(DEFAULT_MAX_THREADS),
            Err(_) => DEFAULT_MAX_THREADS,
        }
    }

    /// Returns the process-wide executor, initializing it on first use.
    ///
    /// The thread limit is read from the environment once, at initialization. No worker
    /// threads are started here; `grow_pool` spawns them on demand.
    ///
    /// # Panics
    ///
    /// Panics on WASM targets, where there is no blocking thread pool.
    #[inline]
    fn get() -> &'static Self {
        #[cfg(not(target_family = "wasm"))]
        {
            use async_lock::OnceCell;
            static EXECUTOR: OnceCell<Executor> = OnceCell::new();
            return EXECUTOR.get_or_init_blocking(|| {
                let thread_limit = Self::max_threads();
                Executor {
                    inner: Mutex::new(Inner {
                        idle_count: 0,
                        thread_count: 0,
                        queue: VecDeque::new(),
                        // `max_threads` never returns less than `MIN_MAX_THREADS` (1).
                        thread_limit: NonZeroUsize::new(thread_limit).unwrap(),
                    }),
                    cvar: Condvar::new(),
                }
            });
        }
        #[cfg(target_family = "wasm")]
        panic!("cannot spawn a blocking task on WASM")
    }

    /// Spawns `future` onto the blocking thread pool and returns a handle to its output.
    ///
    /// The task is built with `propagate_panic(true)`, so a panic inside `future` is
    /// propagated to the returned [`Task`]. Each time the task is scheduled it is pushed onto
    /// the global queue via `schedule`. It is scheduled once immediately.
    fn spawn<T: Send + 'static>(future: impl Future<Output = T> + Send + 'static) -> Task<T> {
        let (runnable, task) = async_task::Builder::new().propagate_panic(true).spawn(
            move |()| future,
            |r| {
                let executor = Self::get();
                executor.schedule(r)
            },
        );
        runnable.schedule();
        task
    }

    /// Body of a worker thread: runs queued tasks, then sleeps until woken.
    ///
    /// The thread is already counted in `idle_count` and `thread_count` when it starts, because
    /// `grow_pool` increments both before spawning it. If a 500 ms wait times out while the queue
    /// is empty, the thread removes itself from both counters and exits.
    fn main_loop(&'static self) {
        #[cfg(feature = "tracing")]
        let _span = tracing::trace_span!("blocking::main_loop").entered();
        let mut inner = self.inner.lock().unwrap();
        loop {
            // About to look for work, so this thread stops counting as idle.
            inner.idle_count -= 1;
            while let Some(runnable) = inner.queue.pop_front() {
                // `grow_pool` takes the guard by value and drops it, so the task below runs
                // without holding the lock.
                self.grow_pool(inner);
                // Keep the worker thread alive even if running the task unwinds.
                panic::catch_unwind(|| runnable.run()).ok();
                inner = self.inner.lock().unwrap();
            }
            inner.idle_count += 1;
            let timeout = Duration::from_millis(500);
            #[cfg(feature = "tracing")]
            tracing::trace!(?timeout, "going to sleep");
            let (lock, res) = self.cvar.wait_timeout(inner, timeout).unwrap();
            inner = lock;
            // Retire the thread only if nothing arrived during the whole wait.
            if res.timed_out() && inner.queue.is_empty() {
                inner.idle_count -= 1;
                inner.thread_count -= 1;
                break;
            }
            #[cfg(feature = "tracing")]
            tracing::trace!("notified");
        }
        #[cfg(feature = "tracing")]
        tracing::trace!("shutting down due to lack of tasks");
    }

    /// Queues `runnable`, wakes one idle worker, and grows the pool if the backlog warrants it.
    fn schedule(&'static self, runnable: Runnable) {
        let mut inner = self.inner.lock().unwrap();
        inner.queue.push_back(runnable);
        self.cvar.notify_one();
        self.grow_pool(inner);
    }

    /// Spawns worker threads while the queue outnumbers idle threads more than 5:1 and the
    /// thread limit allows it. Consumes (and thereby releases) the lock guard.
    ///
    /// If a thread cannot be spawned, the counters are rolled back and `thread_limit` is lowered
    /// to the current thread count (never below 1). No further spawn is attempted in this call.
    fn grow_pool(&'static self, mut inner: MutexGuard<'static, Inner>) {
        #[cfg(feature = "tracing")]
        let _span = tracing::trace_span!(
            "grow_pool",
            queue_len = inner.queue.len(),
            idle_count = inner.idle_count,
            thread_count = inner.thread_count,
        )
        .entered();
        while inner.queue.len() > inner.idle_count * 5
            && inner.thread_count < inner.thread_limit.get()
        {
            #[cfg(feature = "tracing")]
            tracing::trace!("spawning a new thread to handle blocking tasks");
            // The new thread is counted as idle until its `main_loop` picks up work.
            inner.idle_count += 1;
            inner.thread_count += 1;
            // Wake every idle thread: the backlog is large enough to need all of them.
            self.cvar.notify_all();
            static ID: AtomicUsize = AtomicUsize::new(1);
            let id = ID.fetch_add(1, Ordering::Relaxed);
            if let Err(_e) = thread::Builder::new()
                .name(format!("blocking-{}", id))
                .spawn(move || self.main_loop())
            {
                #[cfg(feature = "tracing")]
                tracing::error!("failed to spawn a blocking thread: {}", _e);
                inner.idle_count -= 1;
                inner.thread_count -= 1;
                // The current thread count is probably the system's limit, so adopt it. Never
                // drop to zero, so spawning can still be retried later.
                inner.thread_limit = {
                    let new_limit = inner.thread_count;
                    NonZeroUsize::new(new_limit).unwrap_or_else(|| {
                        #[cfg(feature = "tracing")]
                        tracing::warn!(
                            "attempted to lower thread_limit to zero; setting to one instead"
                        );
                        NonZeroUsize::new(1).unwrap()
                    })
                };
                // Stop here. With zero live threads the limit was just reset to 1, so the loop
                // condition would stay true. We would then retry in a tight loop while holding
                // the executor lock for as long as spawning keeps failing. The next `schedule`
                // call will try again.
                break;
            }
        }
    }
}
/// Runs blocking code on a thread pool.
///
/// `f` is executed on one of the pool's worker threads. The returned [`Task`] resolves to the
/// closure's output. If `f` panics, the panic is propagated to the returned task.
///
/// # Panics
///
/// Panics on WASM targets, where there is no blocking thread pool.
pub fn unblock<T, F>(f: F) -> Task<T>
where
    F: FnOnce() -> T + Send + 'static,
    T: Send + 'static,
{
    // The executor runs futures, so wrap the closure in one that calls it when first polled.
    let job = async move { f() };
    Executor::spawn(job)
}
/// Runs blocking I/O on a thread pool.
///
/// Wraps a blocking handle of type `T` and exposes it through async interfaces:
/// [`Stream`] when `T: Iterator`, [`AsyncRead`] when `T: Read`, [`AsyncWrite`] when `T: Write`,
/// and [`AsyncSeek`] when `T: Seek`. Each operation moves the handle into a task running on
/// the blocking thread pool (see [`unblock`]). The handle comes back when that task finishes
/// or is stopped.
pub struct Unblock<T> {
    /// What the wrapped handle is doing right now: idle, or owned by a running blocking task.
    state: State<T>,
    /// Capacity of the channel/pipe between async and blocking code. `None` selects the
    /// defaults: 8192 items for iterators, 8 MiB for readers and writers.
    cap: Option<usize>,
}
impl<T> Unblock<T> {
pub fn new(io: T) -> Unblock<T> {
Unblock {
state: State::Idle(Some(Box::new(io))),
cap: None,
}
}
pub fn with_capacity(cap: usize, io: T) -> Unblock<T> {
Unblock {
state: State::Idle(Some(Box::new(io))),
cap: Some(cap),
}
}
pub async fn get_mut(&mut self) -> &mut T {
future::poll_fn(|cx| self.poll_stop(cx)).await.ok();
match &mut self.state {
State::Idle(t) => t.as_mut().expect("inner value was taken out"),
State::WithMut(..)
| State::Streaming(..)
| State::Reading(..)
| State::Writing(..)
| State::Seeking(..) => {
unreachable!("when stopped, the state machine must be in idle state");
}
}
}
pub async fn with_mut<R, F>(&mut self, op: F) -> R
where
F: FnOnce(&mut T) -> R + Send + 'static,
R: Send + 'static,
T: Send + 'static,
{
future::poll_fn(|cx| self.poll_stop(cx)).await.ok();
let mut t = match &mut self.state {
State::Idle(t) => t.take().expect("inner value was taken out"),
State::WithMut(..)
| State::Streaming(..)
| State::Reading(..)
| State::Writing(..)
| State::Seeking(..) => {
unreachable!("when stopped, the state machine must be in idle state");
}
};
let (sender, receiver) = bounded(1);
let task = Executor::spawn(async move {
sender.try_send(op(&mut t)).ok();
t
});
self.state = State::WithMut(task);
receiver
.recv()
.await
.expect("`Unblock::with_mut()` operation has panicked")
}
pub async fn into_inner(self) -> T {
let mut this = self;
future::poll_fn(|cx| this.poll_stop(cx)).await.ok();
match &mut this.state {
State::Idle(t) => *t.take().expect("inner value was taken out"),
State::WithMut(..)
| State::Streaming(..)
| State::Reading(..)
| State::Writing(..)
| State::Seeking(..) => {
unreachable!("when stopped, the state machine must be in idle state");
}
}
}
fn poll_stop(&mut self, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
loop {
match &mut self.state {
State::Idle(_) => return Poll::Ready(Ok(())),
State::WithMut(task) => {
let io = ready!(Pin::new(task).poll(cx));
self.state = State::Idle(Some(io));
}
State::Streaming(any, task) => {
any.take();
let iter = ready!(Pin::new(task).poll(cx));
self.state = State::Idle(Some(iter));
}
State::Reading(reader, task) => {
reader.take();
let (res, io) = ready!(Pin::new(task).poll(cx));
self.state = State::Idle(Some(io));
res?;
}
State::Writing(writer, task) => {
writer.take();
let (res, io) = ready!(Pin::new(task).poll(cx));
self.state = State::Idle(Some(io));
res?;
}
State::Seeking(task) => {
let (_, res, io) = ready!(Pin::new(task).poll(cx));
self.state = State::Idle(Some(io));
res?;
}
}
}
}
}
impl<T: fmt::Debug> fmt::Debug for Unblock<T> {
    /// Formats as `Unblock { io: .. }`.
    ///
    /// The handle itself is shown only while idle. `<closed>` is shown once it has been
    /// taken out, and `<blocked>` while a blocking task owns it.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        struct Closed;
        struct Blocked;

        impl fmt::Debug for Closed {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.write_str("<closed>")
            }
        }

        impl fmt::Debug for Blocked {
            fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
                f.write_str("<blocked>")
            }
        }

        let shown: &dyn fmt::Debug = match &self.state {
            State::Idle(Some(handle)) => &**handle,
            State::Idle(None) => &Closed,
            State::WithMut(..)
            | State::Streaming(..)
            | State::Reading(..)
            | State::Writing(..)
            | State::Seeking(..) => &Blocked,
        };

        f.debug_struct("Unblock").field("io", shown).finish()
    }
}
/// What an [`Unblock`]'s inner handle is currently doing.
///
/// Every non-idle variant holds a blocking [`Task`] that owns the handle and yields it back,
/// boxed, when it finishes.
enum State<T> {
    /// No operation is running. Holds the handle, or `None` once it has been taken out
    /// (for example, dropped by `poll_close`).
    Idle(Option<Box<T>>),
    /// A `with_mut` closure is running on the handle.
    WithMut(Task<Box<T>>),
    /// An iterator is being drained on a blocking thread.
    ///
    /// The first field is the type-erased receiving end of the item channel (a
    /// `Pin<Box<Receiver<T::Item>>>`). `poll_stop` drops it before waiting for the task.
    Streaming(Option<Box<dyn Any + Send + Sync>>, Task<Box<T>>),
    /// A blocking task copies data from the handle into a pipe. The first field is the async
    /// (reading) end of that pipe, dropped by `poll_stop` before waiting for the task. The
    /// task yields the read loop's result together with the handle.
    Reading(Option<Reader>, Task<(io::Result<()>, Box<T>)>),
    /// A blocking task copies data from a pipe into the handle. The first field is the async
    /// (writing) end of that pipe, dropped by `poll_stop` before waiting for the task. The task
    /// yields the write/flush result together with the handle.
    Writing(Option<Writer>, Task<(io::Result<()>, Box<T>)>),
    /// A seek is running. The task yields the requested position, the seek result, and the
    /// handle.
    Seeking(Task<(SeekFrom, io::Result<u64>, Box<T>)>),
}
impl<T: Iterator + Send + 'static> Stream for Unblock<T>
where
    T::Item: Send + 'static,
{
    type Item = T::Item;

    /// Yields items produced by a blocking task that drives the inner iterator.
    ///
    /// The first poll moves the iterator into a blocking task that sends items through a
    /// bounded channel (default capacity 8192). When the channel reports its end, the iterator
    /// is taken back and the state returns to idle. Other in-progress operations are stopped
    /// first, and their I/O errors are ignored.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<T::Item>> {
        loop {
            match &mut self.state {
                State::Streaming(Some(erased), task) => {
                    let rx = erased
                        .downcast_mut::<Pin<Box<Receiver<T::Item>>>>()
                        .unwrap();
                    let next = ready!(rx.as_mut().poll_next(cx));
                    if next.is_none() {
                        // The producer is done; reclaim the iterator before reporting the end.
                        let source = ready!(Pin::new(task).poll(cx));
                        self.state = State::Idle(Some(source));
                    }
                    return Poll::Ready(next);
                }
                State::Idle(slot) => {
                    let mut source = slot.take().expect("inner iterator was taken out");
                    let capacity = self.cap.unwrap_or(8 * 1024);
                    let (tx, rx) = bounded(capacity);
                    let task = Executor::spawn(async move {
                        for item in &mut source {
                            // A failed send means the receiving side was dropped: stop early.
                            if sender_closed(tx.send(item).await) {
                                break;
                            }
                        }
                        source
                    });
                    self.state = State::Streaming(Some(Box::new(Box::pin(rx))), task);
                }
                State::WithMut(..)
                | State::Streaming(None, _)
                | State::Reading(..)
                | State::Writing(..)
                | State::Seeking(..) => {
                    let _ = ready!(self.poll_stop(cx));
                }
            }
        }

        /// Returns `true` when a channel send failed.
        fn sender_closed<E>(res: Result<(), E>) -> bool {
            res.is_err()
        }
    }
}
impl<T: Read + Send + 'static> AsyncRead for Unblock<T> {
    /// Reads bytes that a blocking task copies out of the inner reader into a pipe.
    ///
    /// The first read starts that task (pipe capacity defaults to 8 MiB); later reads drain the
    /// pipe. A zero-byte drain is treated as the end. The task is then awaited, the reader goes
    /// back to idle, and any error the task hit is returned. Other in-progress operations are
    /// stopped first.
    fn poll_read(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut [u8],
    ) -> Poll<io::Result<usize>> {
        loop {
            match &mut self.state {
                State::Reading(Some(pipe_rx), task) => {
                    let n = ready!(pipe_rx.poll_drain(cx, buf))?;
                    if n == 0 {
                        // Reclaim the reader (and go idle) before surfacing the task's error.
                        let (outcome, handle) = ready!(Pin::new(task).poll(cx));
                        self.state = State::Idle(Some(handle));
                        outcome?;
                    }
                    return Poll::Ready(Ok(n));
                }
                State::Idle(slot) => {
                    let mut handle = slot.take().expect("inner value was taken out");
                    let capacity = self.cap.unwrap_or(8 * 1024 * 1024);
                    let (pipe_rx, mut pipe_tx) = pipe(capacity);
                    let task = Executor::spawn(async move {
                        loop {
                            match future::poll_fn(|cx| pipe_tx.poll_fill(cx, &mut handle)).await {
                                Ok(0) => return (Ok(()), handle),
                                Ok(_) => continue,
                                Err(err) => return (Err(err), handle),
                            }
                        }
                    });
                    self.state = State::Reading(Some(pipe_rx), task);
                }
                State::WithMut(..)
                | State::Reading(None, _)
                | State::Streaming(..)
                | State::Writing(..)
                | State::Seeking(..) => {
                    ready!(self.poll_stop(cx))?;
                }
            }
        }
    }
}
impl<T: Write + Send + 'static> AsyncWrite for Unblock<T> {
    /// Writes bytes into a pipe that a blocking task drains into the inner writer.
    ///
    /// The first write starts that task (pipe capacity defaults to 8 MiB). When the task sees a
    /// zero-byte drain it flushes the writer and finishes. On a drain error it still attempts a
    /// flush and then reports the drain error. Other in-progress operations are stopped first.
    fn poll_write(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        loop {
            match &mut self.state {
                State::Writing(Some(pipe_tx), _) => return pipe_tx.poll_fill(cx, buf),
                State::Idle(slot) => {
                    let mut handle = slot.take().expect("inner value was taken out");
                    let capacity = self.cap.unwrap_or(8 * 1024 * 1024);
                    let (mut pipe_rx, pipe_tx) = pipe(capacity);
                    let task = Executor::spawn(async move {
                        loop {
                            match future::poll_fn(|cx| pipe_rx.poll_drain(cx, &mut handle)).await {
                                Ok(0) => return (handle.flush(), handle),
                                Ok(_) => continue,
                                Err(err) => {
                                    // Best effort: the drain error is the one worth reporting.
                                    let _ = handle.flush();
                                    return (Err(err), handle);
                                }
                            }
                        }
                    });
                    self.state = State::Writing(Some(pipe_tx), task);
                }
                State::WithMut(..)
                | State::Writing(None, _)
                | State::Streaming(..)
                | State::Reading(..)
                | State::Seeking(..) => {
                    ready!(self.poll_stop(cx))?;
                }
            }
        }
    }

    /// Flushes by stopping the running operation.
    ///
    /// `poll_stop` returns `Ready(Ok(()))` only once the state is idle. For a writer, getting
    /// there means the pipe's write end was dropped and the blocking task's final `flush` has
    /// completed. Its error, if any, is returned.
    fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        self.poll_stop(cx)
    }

    /// Flushes, then drops the inner handle, leaving the state idle with no handle.
    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        ready!(self.as_mut().poll_flush(cx))?;
        self.state = State::Idle(None);
        Poll::Ready(Ok(()))
    }
}
impl<T: Seek + Send + 'static> AsyncSeek for Unblock<T> {
    /// Seeks the inner handle on a blocking thread.
    ///
    /// Other in-progress operations are stopped first. The seek task records the position it
    /// was started with. If a finished seek targeted a different position than `pos`, for
    /// example one left over from an earlier call that was not polled to completion, its result
    /// is discarded (unless it failed) and a new seek to `pos` is started.
    fn poll_seek(
        mut self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        pos: SeekFrom,
    ) -> Poll<io::Result<u64>> {
        loop {
            match &mut self.state {
                State::Seeking(task) => {
                    let (requested, outcome, handle) = ready!(Pin::new(task).poll(cx));
                    // Go idle before surfacing a possible error.
                    self.state = State::Idle(Some(handle));
                    let offset = outcome?;
                    if requested == pos {
                        return Poll::Ready(Ok(offset));
                    }
                }
                State::Idle(slot) => {
                    let mut handle = slot.take().expect("inner value was taken out");
                    self.state = State::Seeking(Executor::spawn(async move {
                        let outcome = handle.seek(pos);
                        (pos, outcome, handle)
                    }));
                }
                State::WithMut(..)
                | State::Streaming(..)
                | State::Reading(..)
                | State::Writing(..) => {
                    ready!(self.poll_stop(cx))?;
                }
            }
        }
    }
}
#[cfg(all(test, not(target_family = "wasm")))]
mod tests {
    use super::*;

    /// Checks parsing and clamping of `BLOCKING_MAX_THREADS`, including the unset case.
    #[test]
    fn test_max_threads() {
        // The variable is process-global: remember any pre-existing value so it can be
        // restored instead of leaking the test's last setting to the rest of the process.
        let saved = env::var_os(MAX_THREADS_ENV);

        env::set_var(MAX_THREADS_ENV, "100");
        assert_eq!(100, Executor::max_threads());
        env::set_var(MAX_THREADS_ENV, "0");
        assert_eq!(1, Executor::max_threads());
        env::set_var(MAX_THREADS_ENV, "50000");
        assert_eq!(10000, Executor::max_threads());
        env::set_var(MAX_THREADS_ENV, "");
        assert_eq!(500, Executor::max_threads());
        env::set_var(MAX_THREADS_ENV, "NOTINT");
        assert_eq!(500, Executor::max_threads());

        // Unset variable: exercises the `env::var` error branch.
        env::remove_var(MAX_THREADS_ENV);
        assert_eq!(500, Executor::max_threads());

        match saved {
            Some(value) => env::set_var(MAX_THREADS_ENV, value),
            None => env::remove_var(MAX_THREADS_ENV),
        }
    }
}