#![forbid(unsafe_code)]
#![warn(missing_docs, missing_debug_implementations, rust_2018_idioms)]
use std::cell::Cell;
use std::future::Future;
use std::marker::PhantomData;
use std::panic::{RefUnwindSafe, UnwindSafe};
use std::pin::Pin;
use std::rc::Rc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex, RwLock};
use std::task::{Context, Poll, Waker};
use concurrent_queue::ConcurrentQueue;
use futures_lite::future;
/// Alias for the `async_task::Task<()>` values this module queues; they are executed with
/// `run()` and put back on a queue with `schedule()`.
type Runnable = async_task::Task<()>;
/// A handle to a spawned task.
///
/// Awaiting the handle yields the task's output. Dropping the handle cancels the task; call
/// `detach()` to let it keep running in the background instead.
///
/// The inner option is emptied by `detach()` and `cancel()`, which stops the `Drop`
/// implementation from cancelling the task afterwards.
#[must_use = "tasks get canceled when dropped, use `.detach()` to run them in the background"]
#[derive(Debug)]
pub struct Task<T>(Option<async_task::JoinHandle<T, ()>>);
impl<T> Task<T> {
    /// Detaches the task so that it keeps running in the background.
    ///
    /// The join handle is taken out of `self` and dropped without being cancelled, so the
    /// `Drop` implementation of `Task` finds nothing left to cancel.
    pub fn detach(mut self) {
        let _handle = self.0.take().unwrap();
    }

    /// Requests cancellation of the task and waits for the join handle to resolve.
    ///
    /// Resolves to whatever the underlying join handle yields after cancellation was
    /// requested — presumably `Some(output)` if the task had already completed and `None`
    /// otherwise (NOTE(review): confirm against the `async_task` documentation).
    pub async fn cancel(mut self) -> Option<T> {
        // Take the handle out first so that dropping `self` does not cancel a second time.
        let handle = self.0.take().unwrap();
        handle.cancel();
        handle.await
    }
}
impl<T> Drop for Task<T> {
    /// Cancels the task unless its handle was already taken by `detach()` or `cancel()`.
    fn drop(&mut self) {
        if let Some(handle) = self.0.as_ref() {
            handle.cancel();
        }
    }
}
impl<T> Future for Task<T> {
    type Output = T;

    /// Polls the underlying join handle.
    ///
    /// # Panics
    ///
    /// Panics with "task has failed" if the join handle resolves to `None`.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let handle = self.0.as_mut().unwrap();
        Pin::new(handle)
            .poll(cx)
            .map(|output| output.expect("task has failed"))
    }
}
/// State shared between an executor, the tickers driving it, and the schedule callbacks.
#[derive(Debug)]
struct State {
/// Global queue; the schedule callbacks push runnable tasks here.
queue: ConcurrentQueue<Runnable>,
/// Local queues ("shards") of all live tickers; each ticker registers its own on creation.
shards: RwLock<Vec<Arc<ConcurrentQueue<Runnable>>>>,
/// Fast-path flag consulted by `notify()`: while it is `true`, `notify()` does nothing.
notified: AtomicBool,
/// Registry of sleeping tickers and their wakers.
sleepers: Mutex<Sleepers>,
}
impl State {
    /// Creates a fresh state: empty unbounded global queue, no shards, no sleepers.
    ///
    /// `notified` starts out `true`, so `notify()` is a no-op until a ticker goes to sleep
    /// and recomputes the flag.
    fn new() -> State {
        State {
            queue: ConcurrentQueue::unbounded(),
            shards: RwLock::new(Vec::new()),
            notified: AtomicBool::new(true),
            sleepers: Mutex::new(Sleepers {
                count: 0,
                wakers: Vec::new(),
                id_gen: 0,
            }),
        }
    }

    /// Wakes one sleeping ticker, unless a notification is already pending.
    #[inline]
    fn notify(&self) {
        // Only the caller that flips `notified` from `false` to `true` proceeds.
        // `compare_exchange` replaces the deprecated `compare_and_swap`; `Ok` means the
        // previous value was `false`, which is exactly what the old `!prev` check tested.
        let flipped = self
            .notified
            .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
            .is_ok();

        if flipped {
            // Bind the waker first so the mutex guard is released before `wake()` runs.
            let waker = self.sleepers.lock().unwrap().notify();
            if let Some(w) = waker {
                w.wake();
            }
        }
    }
}
/// Bookkeeping for tickers that are registered as sleeping.
#[derive(Debug)]
struct Sleepers {
    /// Number of tickers currently registered as sleeping.
    count: usize,
    /// Wakers of sleeping tickers that have not been notified yet, keyed by ticker id.
    wakers: Vec<(u64, Waker)>,
    /// Next id handed out by `insert`.
    id_gen: u64,
}

impl Sleepers {
    /// Registers a new sleeping ticker and returns the id assigned to it.
    fn insert(&mut self, waker: &Waker) -> u64 {
        let assigned = self.id_gen;
        self.id_gen += 1;
        self.count += 1;
        self.wakers.push((assigned, waker.clone()));
        assigned
    }

    /// Refreshes the waker of an already registered ticker.
    ///
    /// Returns `false` if an entry for `id` was still present (its waker is replaced only
    /// when it would not wake the same task). Returns `true` if no entry was found and the
    /// waker had to be pushed again.
    fn update(&mut self, id: u64, waker: &Waker) -> bool {
        let found = self.wakers.iter().position(|(slot, _)| *slot == id);
        match found {
            Some(pos) => {
                let stored = &mut self.wakers[pos].1;
                if !stored.will_wake(waker) {
                    *stored = waker.clone();
                }
                false
            }
            None => {
                self.wakers.push((id, waker.clone()));
                true
            }
        }
    }

    /// Unregisters a sleeping ticker, dropping its waker if it is still listed.
    fn remove(&mut self, id: u64) {
        self.count -= 1;
        // Search from the back, matching the order in which `notify` consumes entries.
        let found = self.wakers.iter().rposition(|(slot, _)| *slot == id);
        if let Some(pos) = found {
            self.wakers.remove(pos);
        }
    }

    /// Returns `true` if nobody is sleeping or at least one sleeper's waker has been taken.
    fn is_notified(&self) -> bool {
        let listed = self.wakers.len();
        self.count == 0 || listed < self.count
    }

    /// Takes the most recently added waker, but only if no sleeper has been notified yet.
    fn notify(&mut self) -> Option<Waker> {
        if self.wakers.len() != self.count {
            return None;
        }
        self.wakers.pop().map(|(_, waker)| waker)
    }
}
/// An executor for `Send` futures.
///
/// Tasks are created with `spawn()` and only make progress while some caller is awaiting
/// `run()`.
#[derive(Debug)]
pub struct Executor {
/// Shared state, created lazily on first use so that `new()` can be a `const fn`.
state: once_cell::sync::OnceCell<Arc<State>>,
}
// Explicitly mark `Executor` as unwind-safe, both by value and behind a shared reference.
impl UnwindSafe for Executor {}
impl RefUnwindSafe for Executor {}
impl Executor {
pub const fn new() -> Executor {
Executor {
state: once_cell::sync::OnceCell::new(),
}
}
pub fn spawn<T: Send + 'static>(
&self,
future: impl Future<Output = T> + Send + 'static,
) -> Task<T> {
let (runnable, handle) = async_task::spawn(future, self.schedule(), ());
runnable.schedule();
Task(Some(handle))
}
pub async fn run<T>(&self, future: impl Future<Output = T>) -> T {
let ticker = Ticker::new(self.state());
future::race(
future,
future::poll_fn(|cx| {
for _ in 0..200 {
if !ticker.tick(cx.waker()) {
return Poll::Pending;
}
}
cx.waker().wake_by_ref();
Poll::Pending
}),
)
.await
}
fn schedule(&self) -> impl Fn(Runnable) + Send + Sync + 'static {
let state = self.state().clone();
move |runnable| {
state.queue.push(runnable).unwrap();
state.notify();
}
}
fn state(&self) -> &Arc<State> {
self.state.get_or_init(|| Arc::new(State::new()))
}
}
impl Default for Executor {
fn default() -> Executor {
Executor::new()
}
}
/// A worker that pulls tasks out of the executor's queues and runs them one at a time.
#[derive(Debug)]
struct Ticker<'a> {
/// The executor state this ticker works on.
state: &'a State,
/// This ticker's local queue; it is also registered in `state.shards` so others can steal.
shard: Arc<ConcurrentQueue<Runnable>>,
/// `Some(id)` while this ticker is registered in `state.sleepers`, `None` otherwise.
sleeping: Cell<Option<u64>>,
/// Wrapping counter of tasks run; used to periodically pull from the global queue.
ticks: Cell<usize>,
}
// Explicitly mark `Ticker` as unwind-safe, both by value and behind a shared reference.
impl UnwindSafe for Ticker<'_> {}
impl RefUnwindSafe for Ticker<'_> {}
impl Ticker<'_> {
/// Creates a ticker with an empty local queue bounded at 512 entries and registers that
/// queue in `state.shards`.
fn new(state: &State) -> Ticker<'_> {
let ticker = Ticker {
state,
shard: Arc::new(ConcurrentQueue::bounded(512)),
sleeping: Cell::new(None),
ticks: Cell::new(0),
};
state.shards.write().unwrap().push(ticker.shard.clone());
ticker
}
/// Registers this ticker as sleeping, or refreshes its waker if it is already registered.
///
/// Returns `false` if the ticker was already registered and its waker was still listed
/// (nothing notified it since it went to sleep). Returns `true` if it was newly registered,
/// or if its waker had been consumed and was pushed again.
fn sleep(&self, waker: &Waker) -> bool {
let mut sleepers = self.state.sleepers.lock().unwrap();
match self.sleeping.get() {
// Not sleeping yet: register and remember the assigned id.
None => self.sleeping.set(Some(sleepers.insert(waker))),
// Already sleeping: refresh the waker; bail out early if it was still listed.
Some(id) => {
if !sleepers.update(id, waker) {
return false;
}
}
}
// Recompute the fast-path flag while the sleepers lock is still held.
self.state
.notified
.swap(sleepers.is_notified(), Ordering::SeqCst);
true
}
/// Unregisters this ticker from the sleepers if it was registered, then recomputes the
/// `notified` flag. Does nothing if the ticker was not sleeping.
fn wake(&self) {
if let Some(id) = self.sleeping.take() {
let mut sleepers = self.state.sleepers.lock().unwrap();
sleepers.remove(id);
self.state
.notified
.swap(sleepers.is_notified(), Ordering::SeqCst);
}
}
/// Finds and runs a single task.
///
/// Returns `true` after running one task. Returns `false` when no task was found and the
/// ticker was already sleeping without having been notified; `waker` stays registered.
pub fn tick(&self, waker: &Waker) -> bool {
loop {
match self.search() {
None => {
// Nothing to run: go to sleep. If `sleep` reports a fresh or renewed
// registration, loop and search once more before giving up.
if !self.sleep(waker) {
return false;
}
}
Some(r) => {
// Leave the sleeping state before running the task.
self.wake();
// Wake another sleeping ticker (if any) before this one runs the task.
self.state.notify();
let ticks = self.ticks.get();
self.ticks.set(ticks.wrapping_add(1));
// Every 64th tick, pull a batch from the global queue into the local shard.
if ticks % 64 == 0 {
steal(&self.state.queue, &self.shard);
}
r.run();
return true;
}
}
}
}
/// Looks for a runnable task: first in the local shard, then in the global queue, and
/// finally by stealing from the other tickers' shards.
fn search(&self) -> Option<Runnable> {
if let Ok(r) = self.shard.pop() {
return Some(r);
}
if let Ok(r) = self.state.queue.pop() {
// Also move a batch from the global queue into the local shard.
steal(&self.state.queue, &self.shard);
return Some(r);
}
let shards = self.state.shards.read().unwrap();
// This ticker's own shard is registered in `new` and removed only on drop, so `n` is at
// least 1 here (presumably `fastrand::usize` rejects an empty range — confirm).
let n = shards.len();
let start = fastrand::usize(..n);
// Visit every shard exactly once, starting at a random position and wrapping around.
let iter = shards.iter().chain(shards.iter()).skip(start).take(n);
// Skip this ticker's own shard.
let iter = iter.filter(|shard| !Arc::ptr_eq(shard, &self.shard));
for shard in iter {
steal(shard, &self.shard);
if let Ok(r) = self.shard.pop() {
return Some(r);
}
}
None
}
}
impl Drop for Ticker<'_> {
    /// Unregisters the ticker and hands its remaining local tasks back for scheduling.
    fn drop(&mut self) {
        // Leave the sleepers registry, if registered.
        self.wake();

        // Remove this ticker's shard from the shared list; the write lock is released at the
        // end of this scope, before any task is re-scheduled.
        {
            let mut shards = self.state.shards.write().unwrap();
            shards.retain(|shard| !Arc::ptr_eq(shard, &self.shard));
        }

        // Re-schedule whatever is still sitting in the local queue.
        loop {
            match self.shard.pop() {
                Ok(runnable) => {
                    runnable.schedule();
                }
                Err(_) => break,
            }
        }

        // Let another ticker pick up from here.
        self.state.notify();
    }
}
/// Moves up to half of `src` (rounded up) into `dest`, limited by the free space in `dest`
/// when `dest` is bounded.
///
/// Stops early if `src` runs empty in the meantime.
///
/// # Panics
///
/// Panics if a push into `dest` fails.
fn steal<T>(src: &ConcurrentQueue<T>, dest: &ConcurrentQueue<T>) {
    let half = (src.len() + 1) / 2;
    if half == 0 {
        return;
    }

    // Never take more than fits into the destination.
    let budget = match dest.capacity() {
        Some(cap) => half.min(cap - dest.len()),
        None => half,
    };

    for _ in 0..budget {
        match src.pop() {
            Ok(t) => assert!(dest.push(t).is_ok()),
            Err(_) => break,
        }
    }
}
/// An executor whose `spawn()` accepts futures that are not `Send`.
///
/// It wraps an `Executor` and adds a `PhantomData<Rc<()>>` marker, which makes the type
/// neither `Send` nor `Sync`.
#[derive(Debug)]
pub struct LocalExecutor {
/// The wrapped executor, created lazily on first use so that `new()` can be a `const fn`.
inner: once_cell::unsync::OnceCell<Executor>,
/// Marker that opts this type out of `Send` and `Sync`.
_marker: PhantomData<Rc<()>>,
}
impl LocalExecutor {
    /// Creates a new local executor. The wrapped executor is created on first use.
    pub const fn new() -> LocalExecutor {
        LocalExecutor {
            inner: once_cell::unsync::OnceCell::new(),
            _marker: PhantomData,
        }
    }

    /// Spawns a future (which need not be `Send`) and returns a handle to it.
    ///
    /// The task is scheduled immediately but only runs while `run()` is being driven.
    pub fn spawn<T: 'static>(&self, future: impl Future<Output = T> + 'static) -> Task<T> {
        let (first_run, join) = async_task::spawn_local(future, self.schedule(), ());
        first_run.schedule();
        Task(Some(join))
    }

    /// Runs queued tasks until `future` completes, then returns its output.
    pub async fn run<T>(&self, future: impl Future<Output = T>) -> T {
        let executor = self.inner();
        executor.run(future).await
    }

    /// Builds the scheduling callback by reusing the wrapped executor's callback, which
    /// pushes onto the global queue and notifies a sleeping ticker.
    fn schedule(&self) -> impl Fn(Runnable) + Send + Sync + 'static {
        self.inner().schedule()
    }

    /// Returns the wrapped executor, creating it on first access.
    fn inner(&self) -> &Executor {
        self.inner.get_or_init(Executor::new)
    }
}
impl Default for LocalExecutor {
fn default() -> LocalExecutor {
LocalExecutor::new()
}
}