use futures::executor::enter;
use futures::future::{FutureObj, LocalFutureObj, RemoteHandle};
use futures::prelude::*;
use futures::stream::FuturesUnordered;
use futures::task::{waker_ref, ArcWake, LocalSpawn, Spawn, SpawnError};
use std::cell::RefCell;
use std::pin::Pin;
use std::rc::{Rc, Weak};
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use std::task::{Context, Poll};
use std::thread;
use std::thread::Thread;
/// Pairs a payload with the numeric id it was tagged with.
///
/// The pool stores spawned futures inside this wrapper, and a completed
/// wrapped future yields its output inside the same wrapper so the id
/// travels with the result.
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
struct IndexWrapper<T> {
    /// The wrapped value: a future while pending, its output once complete.
    data: T,
    /// Id associated with `data`.
    index: usize,
}
impl<T> IndexWrapper<T> {
// Generates a pin-projection accessor named `data` for the `data` field;
// the `Future` impl for this type calls it on a pinned `self` to reach the
// wrapped future.
// NOTE(review): the macro is an `unsafe_*` helper, so soundness presumably
// relies on `data` never being moved out of a pinned `IndexWrapper` — confirm.
pin_utils::unsafe_pinned!(data: T);
}
/// Polling an `IndexWrapper` polls the wrapped future and, on completion,
/// hands back the output tagged with the same id.
impl<T> Future for IndexWrapper<T>
where
    T: Future,
{
    type Output = IndexWrapper<T::Output>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // `index` is `Copy`; read it up front so the result can be tagged
        // without touching `self` again after the inner poll.
        let index = self.index;
        match self.as_mut().data().as_mut().poll(cx) {
            Poll::Ready(data) => Poll::Ready(IndexWrapper { data, index }),
            Poll::Pending => Poll::Pending,
        }
    }
}
/// A single-threaded task pool in which every spawned task carries a numeric
/// id, so callers can learn which task completed.
#[derive(Debug)]
pub struct LocalPool {
/// Tasks currently owned by the pool, each wrapped together with its id.
pool: FuturesUnordered<IndexWrapper<LocalFutureObj<'static, ()>>>,
/// Queue of spawned tasks not yet moved into `pool`; spawners hold weak
/// references to this same cell.
incoming: Rc<Incoming>,
}
/// A cloneable handle for spawning tasks onto a [`LocalPool`].
///
/// Only a weak reference to the pool's incoming queue is held; when that
/// reference can no longer be upgraded, spawning and status checks return
/// `SpawnError::shutdown()`.
#[derive(Clone, Debug)]
pub struct LocalSpawner {
/// Weak link to the owning pool's incoming queue.
incoming: Weak<Incoming>,
}
/// Spawn-side state shared between a pool and its spawners.
#[derive(Debug, Default)]
struct IncomingTracking {
/// Tasks that have been spawned, paired with their ids, waiting to be moved
/// into the pool the next time it is polled.
queue: Vec<(usize, LocalFutureObj<'static, ()>)>,
/// Id that will be assigned to the next spawned task; incremented on every
/// spawn.
index: usize,
}
/// Interior-mutable container for the spawn-side state.
type Incoming = RefCell<IncomingTracking>;
/// Waker state used by the executor loops: waking sets a flag and unparks the
/// recorded thread.
pub(crate) struct ThreadNotify {
/// Thread to unpark when a wake-up arrives.
thread: Thread,
/// Set to `true` by a wake-up and reset by the executor loop, so a wake that
/// arrives while the executor is not parked is still observed.
unparked: AtomicBool,
}
thread_local! {
// Per-thread notifier, bound to the thread that first accesses it and
// starting with no pending wake-up. The executor helpers build their waker
// from this value.
static CURRENT_THREAD_NOTIFY: Arc<ThreadNotify> = Arc::new(ThreadNotify {
thread: thread::current(),
unparked: AtomicBool::new(false),
});
}
impl ArcWake for ThreadNotify {
    /// Records a wake-up and unparks the executor thread if no wake-up was
    /// already pending.
    fn wake_by_ref(arc_self: &Arc<Self>) {
        // `Release` publishes everything the waking side wrote before this
        // call; it pairs with the `Acquire` swap the executor loop performs
        // on the same flag. A `Relaxed` swap would leave that `Acquire`
        // without a matching release, so the executor could re-poll without
        // being guaranteed to see the state change that triggered the wake.
        let already_unparked = arc_self.unparked.swap(true, Ordering::Release);
        if !already_unparked {
            // First wake since the flag was last cleared: make sure the
            // thread leaves (or skips) its next `park()`. If the flag was
            // already set, an earlier wake has taken care of this.
            arc_self.thread.unpark();
        }
    }
}
/// Drives `f` on the current thread until it returns `Poll::Ready`, parking
/// the thread between polls while no wake-up is pending.
///
/// `f` is called with a [`Context`] whose waker is the thread-local
/// `ThreadNotify`.
///
/// # Panics
///
/// Panics if `enter()` fails, i.e. when called from within another executor.
fn run_executor<T, F: FnMut(&mut Context<'_>) -> Poll<T>>(mut f: F) -> T {
    let _enter = enter().expect(
        "cannot execute `LocalPool` executor from within \
         another executor",
    );

    CURRENT_THREAD_NOTIFY.with(|thread_notify| {
        let waker = waker_ref(thread_notify);
        let mut cx = Context::from_waker(&waker);
        loop {
            if let Poll::Ready(t) = f(&mut cx) {
                return t;
            }

            // Wait for a wake-up. The flag is consumed only through this
            // atomic swap: an unconditional `store(false)` after `park()`
            // could erase a wake that arrived without an accompanying
            // `unpark()` (because the flag was already set), and the next
            // `park()` would then block forever. Looping also means a
            // spurious return from `park()` just parks again instead of
            // triggering a needless poll.
            while !thread_notify.unparked.swap(false, Ordering::Acquire) {
                // If a wake lands right here, its `unpark()` token is still
                // available and `park()` returns immediately.
                thread::park();
            }
        }
    })
}
/// Calls `f` exactly once with a [`Context`] backed by the thread-local
/// notifier and returns its result; the thread is never parked.
///
/// # Panics
///
/// Panics if `enter()` fails, i.e. when called from within another executor.
fn poll_executor<T, F: FnMut(&mut Context<'_>) -> T>(mut f: F) -> T {
    let _guard = enter().expect(
        "cannot execute `LocalPool` executor from within \
         another executor",
    );

    CURRENT_THREAD_NOTIFY.with(|notify| {
        let waker = waker_ref(notify);
        let mut context = Context::from_waker(&waker);
        f(&mut context)
    })
}
impl LocalPool {
    /// Creates an empty pool with no tasks and an empty incoming queue.
    pub fn new() -> LocalPool {
        LocalPool {
            pool: FuturesUnordered::new(),
            incoming: Rc::default(),
        }
    }

    /// Returns a spawner that feeds this pool's incoming queue.
    pub fn spawner(&self) -> LocalSpawner {
        let incoming = Rc::downgrade(&self.incoming);
        LocalSpawner { incoming }
    }

    /// Blocks the current thread, polling the pool until `poll_pool` reports
    /// that no tasks remain.
    ///
    /// # Panics
    ///
    /// Panics when called from within another executor.
    pub fn run(&mut self) {
        run_executor(|cx| self.poll_pool(cx))
    }

    /// Blocks the current thread until `future` completes and returns its
    /// output. Whenever `future` is pending, the pool's tasks are polled too.
    ///
    /// # Panics
    ///
    /// Panics when called from within another executor.
    pub fn run_until<F: Future>(&mut self, future: F) -> F::Output {
        pin_utils::pin_mut!(future);

        run_executor(|cx| match future.as_mut().poll(cx) {
            Poll::Ready(output) => Poll::Ready(output),
            Poll::Pending => {
                // The main future is not done yet: give the spawned tasks a
                // turn. Their completion state is irrelevant here.
                let _ = self.poll_pool(cx);
                Poll::Pending
            }
        })
    }

    /// Polls the pool without blocking until one task completes, returning
    /// its id, or returns `None` when a poll finishes no task and nothing
    /// new is waiting in the incoming queue.
    ///
    /// # Panics
    ///
    /// Panics when called from within another executor.
    pub fn try_run_one(&mut self) -> Option<usize> {
        poll_executor(|cx| loop {
            if let Poll::Ready(Some(id)) = self.poll_pool_once(cx) {
                break Some(id);
            }

            // Nothing finished. Another pass can only help if the poll
            // queued up new tasks.
            if self.incoming.borrow().queue.is_empty() {
                break None;
            }
        })
    }

    /// Polls the pool without blocking until `poll_pool` returns, i.e. the
    /// pool is either finished or pending with nothing new queued.
    ///
    /// # Panics
    ///
    /// Panics when called from within another executor.
    pub fn run_until_stalled(&mut self) {
        poll_executor(|cx| {
            let _ = self.poll_pool(cx);
        });
    }

    /// Polls repeatedly until the pool is pending or exhausted and the
    /// incoming queue is empty.
    fn poll_pool(&mut self, cx: &mut Context<'_>) -> Poll<()> {
        loop {
            let outcome = self.poll_pool_once(cx);
            let queue_drained = self.incoming.borrow().queue.is_empty();

            match (queue_drained, outcome) {
                // New tasks were spawned during the poll: adopt and poll them.
                (false, _) => continue,
                (true, Poll::Pending) => return Poll::Pending,
                (true, Poll::Ready(None)) => return Poll::Ready(()),
                // A task completed: keep going.
                (true, Poll::Ready(Some(_))) => {}
            }
        }
    }

    /// Moves every queued task into the pool, then polls the pool once.
    ///
    /// Yields the id of a completed task, `None` when the pool's stream has
    /// ended, or `Poll::Pending`.
    fn poll_pool_once(&mut self, cx: &mut Context<'_>) -> Poll<Option<usize>> {
        let tasks = &mut self.pool;
        let mut tracking = self.incoming.borrow_mut();
        for (index, data) in tracking.queue.drain(..) {
            tasks.push(IndexWrapper { data, index });
        }
        // Release the `RefCell` borrow before polling: tasks may spawn more
        // tasks, which needs to borrow the queue again.
        drop(tracking);

        match self.pool.poll_next_unpin(cx) {
            Poll::Ready(finished) => Poll::Ready(finished.map(|wrapper| wrapper.index)),
            Poll::Pending => Poll::Pending,
        }
    }
}
impl Default for LocalPool {
fn default() -> Self {
Self::new()
}
}
impl Spawn for LocalSpawner {
    /// Queues `future` on the pool and discards the id it was assigned.
    fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), SpawnError> {
        let _id = self.spawn_obj_with_id(future)?;
        Ok(())
    }

    /// Reports `Ok(())` while the pool's incoming queue is still alive and a
    /// shutdown error otherwise.
    fn status(&self) -> Result<(), SpawnError> {
        match self.incoming.upgrade() {
            Some(_) => Ok(()),
            None => Err(SpawnError::shutdown()),
        }
    }
}
impl LocalSpawn for LocalSpawner {
    /// Queues a non-`Send` `future` on the pool and discards the id it was
    /// assigned.
    fn spawn_local_obj(&self, future: LocalFutureObj<'static, ()>) -> Result<(), SpawnError> {
        let _id = self.spawn_local_obj_with_id(future)?;
        Ok(())
    }

    /// Reports `Ok(())` while the pool's incoming queue is still alive and a
    /// shutdown error otherwise.
    fn status_local(&self) -> Result<(), SpawnError> {
        self.incoming
            .upgrade()
            .map(drop)
            .ok_or_else(SpawnError::shutdown)
    }
}
impl SpawnWithId for LocalSpawner {
fn spawn_obj_with_id(&self, future: FutureObj<'static, ()>) -> Result<usize, SpawnError> {
if let Some(incoming) = self.incoming.upgrade() {
let mut incoming = incoming.borrow_mut();
let id = incoming.index;
incoming.index += 1;
incoming.queue.push((id, future.into()));
Ok(id)
} else {
Err(SpawnError::shutdown())
}
}
}
impl LocalSpawnWithId for LocalSpawner {
    /// Assigns the next id to `future`, queues it for the pool and returns
    /// the id.
    ///
    /// # Errors
    ///
    /// Returns `SpawnError::shutdown()` when the pool's incoming queue no
    /// longer exists.
    fn spawn_local_obj_with_id(
        &self,
        future: LocalFutureObj<'static, ()>,
    ) -> Result<usize, SpawnError> {
        let shared = match self.incoming.upgrade() {
            Some(shared) => shared,
            None => return Err(SpawnError::shutdown()),
        };

        let mut tracking = shared.borrow_mut();
        let id = tracking.index;
        tracking.index += 1;
        tracking.queue.push((id, future));
        Ok(id)
    }
}
/// Spawning of `Send` future objects that reports a numeric id for each
/// spawned task.
pub trait SpawnWithId {
/// Spawns `future` and returns the id assigned to it, or a [`SpawnError`]
/// if the task could not be spawned.
fn spawn_obj_with_id(&self, future: FutureObj<'static, ()>) -> Result<usize, SpawnError>;
}
/// Spawning of local (not necessarily `Send`) future objects that reports a
/// numeric id for each spawned task.
pub trait LocalSpawnWithId {
/// Spawns `future` and returns the id assigned to it, or a [`SpawnError`]
/// if the task could not be spawned.
fn spawn_local_obj_with_id(
&self,
future: LocalFutureObj<'static, ()>,
) -> Result<usize, SpawnError>;
}
// Blanket impls: every `SpawnWithId` / `LocalSpawnWithId` implementor,
// including unsized ones, gets the corresponding extension methods.
impl<Sp: ?Sized> SpawnWithIdExt for Sp where Sp: SpawnWithId {}
impl<Sp: ?Sized> LocalSpawnWithIdExt for Sp where Sp: LocalSpawnWithId {}
/// Convenience methods for spawning plain `Send` futures through a
/// [`SpawnWithId`] implementor.
pub trait SpawnWithIdExt: SpawnWithId {
    /// Boxes `future`, spawns it and returns the id assigned to the task.
    ///
    /// # Errors
    ///
    /// Forwards any [`SpawnError`] from `spawn_obj_with_id`.
    fn spawn<Fut>(&self, future: Fut) -> Result<usize, SpawnError>
    where
        Fut: Future<Output = ()> + Send + 'static,
    {
        let boxed = Box::new(future);
        self.spawn_obj_with_id(FutureObj::new(boxed))
    }

    /// Spawns `future` and returns the task's id together with a
    /// [`RemoteHandle`] through which its output can be awaited.
    ///
    /// # Errors
    ///
    /// Forwards any [`SpawnError`] from `spawn`.
    fn spawn_with_handle<Fut>(
        &self,
        future: Fut,
    ) -> Result<(usize, RemoteHandle<Fut::Output>), SpawnError>
    where
        Fut: Future + Send + 'static,
        Fut::Output: Send,
    {
        let (remote, handle) = future.remote_handle();
        self.spawn(remote).map(|id| (id, handle))
    }
}
/// Convenience methods for spawning plain (possibly non-`Send`) futures
/// through a [`LocalSpawnWithId`] implementor.
pub trait LocalSpawnWithIdExt: LocalSpawnWithId {
    /// Boxes `future`, spawns it and returns the id assigned to the task.
    ///
    /// # Errors
    ///
    /// Forwards any [`SpawnError`] from `spawn_local_obj_with_id`.
    fn spawn_local<Fut>(&self, future: Fut) -> Result<usize, SpawnError>
    where
        Fut: Future<Output = ()> + 'static,
    {
        let boxed = Box::new(future);
        self.spawn_local_obj_with_id(LocalFutureObj::new(boxed))
    }

    /// Spawns `future` and returns the task's id together with a
    /// [`RemoteHandle`] through which its output can be awaited.
    ///
    /// # Errors
    ///
    /// Forwards any [`SpawnError`] from `spawn_local`.
    fn spawn_local_with_handle<Fut>(
        &self,
        future: Fut,
    ) -> Result<(usize, RemoteHandle<Fut::Output>), SpawnError>
    where
        Fut: Future + 'static,
    {
        let (remote, handle) = future.remote_handle();
        match self.spawn_local(remote) {
            Ok(id) => Ok((id, handle)),
            Err(err) => Err(err),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Spawns two immediately-ready tasks and checks that `try_run_one`
    /// reports each spawned id exactly once and that both handles resolve.
    #[test]
    fn tracking() {
        let mut pool = LocalPool::new();
        let spawner = pool.spawner();

        let (id1, handle1) = spawner
            .spawn_with_handle(futures::future::ready(1i32))
            .unwrap();
        let (id2, handle2) = spawner
            .spawn_with_handle(futures::future::ready(2u32))
            .unwrap();
        assert_ne!(id1, id2, "each spawned task must get its own id");

        let mut spawned_ids = std::collections::HashSet::new();
        spawned_ids.insert(id1);
        spawned_ids.insert(id2);

        while !spawned_ids.is_empty() {
            // Fail fast rather than spin forever if the pool stops making
            // progress while tasks are still outstanding.
            let completed = pool
                .try_run_one()
                .expect("pool stalled before every spawned task completed");
            assert!(
                spawned_ids.remove(&completed),
                "unexpected or repeated task id {}",
                completed
            );
        }

        assert_eq!(handle1.now_or_never().unwrap(), 1);
        assert_eq!(handle2.now_or_never().unwrap(), 2);
    }
}