use core::{
future::Future,
marker::PhantomData,
mem,
pin::Pin,
task::{
LocalWaker,
Poll,
},
};
use alloc::{
collections::VecDeque,
sync::Arc,
task::{
local_waker,
Wake,
},
};
use futures::{
future::{
FutureObj,
LocalFutureObj,
UnsafeFutureObj,
},
task::{
Spawn,
SpawnError,
},
};
use lock_api::{
Mutex,
RawMutex,
};
use generational_arena::{
Arena,
Index,
};
use crate::{
future_box,
sleep::*,
};
/// Registry capacity that `AllocExecutor::new` passes to `with_capacity`.
const REG_CAP: usize = 16;
/// Queue capacity that `AllocExecutor::new` passes to `with_capacity` (half of `REG_CAP`).
const QUEUE_CAP: usize = REG_CAP / 2;
/// Executor that keeps spawned tasks in an arena and drives them from a shared,
/// mutex-protected queue of poll/spawn requests.
///
/// `R` is the raw mutex type guarding the queue. `S` is the `Sleep` implementation
/// that `run` calls when the queue is empty while tasks are still registered.
pub struct AllocExecutor<'a, R, S>
where
R: RawMutex + Send + Sync,
S: Sleep,
{
// Spawned tasks, addressed by their arena `Index`.
registry: Arena<Task<'a>>,
// Handle to the request queue; clones are handed to spawners and wakers.
queue: QueueHandle<'a, R>,
// Cloned into every task's waker; `run` calls `sleep()` on it.
sleeper: S,
}
impl<'a, R, S> AllocExecutor<'a, R, S>
where
    R: RawMutex + Send + Sync + 'static,
    S: Sleep,
{
    /// Creates an executor with the default capacities (`REG_CAP` for the task
    /// registry, `QUEUE_CAP` for the request queue).
    pub fn new() -> Self {
        Self::with_capacity(REG_CAP, QUEUE_CAP)
    }

    /// Creates an executor whose task registry and request queue are created
    /// with the given capacities. The sleeper is built with `S::default()`.
    pub fn with_capacity(registry: usize, queue: usize) -> Self {
        AllocExecutor {
            registry: Arena::with_capacity(registry),
            queue: new_queue(queue),
            sleeper: S::default(),
        }
    }

    /// Returns a handle that enqueues `Send` futures on this executor's queue.
    pub fn spawner(&self) -> Spawner<'a, R> {
        Spawner::new(self.queue.clone())
    }

    /// Returns a handle that enqueues futures on this executor's queue without
    /// requiring them to be `Send`.
    pub fn local_spawner(&self) -> LocalSpawner<'a, R> {
        LocalSpawner::new(Spawner::new(self.queue.clone()))
    }

    /// Registers `future` as a task, attaches a waker that re-queues the task's
    /// id, and schedules the first poll.
    fn spawn_local(&mut self, future: LocalFutureObj<'a, ()>) {
        // The waker needs the task's id, so the task is inserted first and the
        // waker is attached afterwards.
        let id = self.registry.insert(Task::new(future));
        let queue_waker = Arc::new(QueueWaker::new(self.queue.clone(), id, self.sleeper.clone()));
        let local_waker = queue_waker.into_local_waker();
        self.registry
            .get_mut(id)
            .expect("task was inserted just above")
            .set_waker(local_waker);
        self.queue.lock().push_back(QueueItem::Poll(id));
    }

    /// Spawns a future that is already in `UnsafeFutureObj` form.
    pub fn spawn_raw<F>(&mut self, future: F)
    where
        F: UnsafeFutureObj<'a, ()>,
    {
        self.spawn_local(LocalFutureObj::new(future))
    }

    /// Spawns `future` after converting it with `future_box::make_local`.
    pub fn spawn<F>(&mut self, future: F)
    where
        F: Future<Output = ()> + 'a,
    {
        self.spawn_raw(future_box::make_local(future));
    }

    /// Polls the task registered under `id` once and removes it from the
    /// registry when it completes. Ids that are no longer registered are ignored.
    ///
    /// # Panics
    ///
    /// Panics if the task has no waker attached.
    fn poll_task(&mut self, id: Index) {
        if let Some(Task { future, waker }) = self.registry.get_mut(id) {
            let future = Pin::new(future);
            let waker = waker
                .as_ref()
                .expect("waker not set, task spawned incorrectly");
            match future.poll(waker) {
                Poll::Ready(_) => {
                    self.registry.remove(id);
                }
                Poll::Pending => {}
            }
        }
    }

    /// Pops the next request; the queue lock is released before returning.
    fn dequeue(&self) -> Option<QueueItem<'a>> {
        self.queue.lock().pop_front()
    }

    /// Runs the executor until the queue has been drained and no task remains
    /// in the registry.
    ///
    /// Each pass drains the queue (polling tasks and registering newly spawned
    /// futures) and then calls `sleeper.sleep()` if tasks are still registered.
    pub fn run(&mut self) {
        loop {
            while let Some(item) = self.dequeue() {
                match item {
                    QueueItem::Poll(id) => {
                        self.poll_task(id);
                    }
                    QueueItem::Spawn(task) => {
                        self.spawn_local(task.into());
                    }
                }
            }
            if self.registry.is_empty() {
                break;
            }
            // NOTE(review): a wake that arrives between the last `dequeue` and
            // this call is only observed if the `Sleep` implementation latches
            // wake-ups -- confirm against the `Sleep` contract.
            self.sleeper.sleep();
        }
    }
}

impl<'a, R, S> Default for AllocExecutor<'a, R, S>
where
    R: RawMutex + Send + Sync + 'static,
    S: Sleep,
{
    /// Equivalent to `AllocExecutor::new()`.
    fn default() -> Self {
        Self::new()
    }
}
/// A spawned future together with the waker passed to it on each poll.
struct Task<'a> {
// The future being driven.
future: LocalFutureObj<'a, ()>,
// `None` until `set_waker` is called.
waker: Option<LocalWaker>,
}
impl<'a> Task<'a> {
fn new(future: LocalFutureObj<'a, ()>) -> Task<'a> {
Task {
future,
waker: None,
}
}
fn set_waker(&mut self, waker: LocalWaker) {
self.waker = Some(waker);
}
}
/// FIFO of pending poll/spawn requests.
type Queue<'a> = VecDeque<QueueItem<'a>>;
/// Shared, mutex-guarded handle to a `Queue`; `R` is the raw mutex implementation.
type QueueHandle<'a, R> = Arc<Mutex<R, Queue<'a>>>;
fn new_queue<'a, R>(capacity: usize) -> QueueHandle<'a, R>
where
R: RawMutex + Send + Sync,
{
Arc::new(Mutex::new(Queue::with_capacity(capacity)))
}
/// A request processed by `AllocExecutor::run`.
enum QueueItem<'a> {
// Poll the registered task with this arena index.
Poll(Index),
// Register this future as a new task.
Spawn(FutureObj<'a, ()>),
}
/// Waker state for one task: waking pushes a poll request for `id` onto the
/// queue and then calls `sleeper.wake()`.
struct QueueWaker<R, S>
where
R: RawMutex + Send + Sync,
{
// Queue handle with its lifetime erased to `'static` (see `QueueWaker::new`).
queue: QueueHandle<'static, R>,
// Arena index of the task this waker re-queues.
id: Index,
// Sleeper notified after the poll request has been queued.
sleeper: S,
}
impl<R, S> QueueWaker<R, S>
where
    R: RawMutex + Send + Sync + 'static,
    S: Sleep,
{
    /// Builds the waker state for task `id`, erasing the queue handle's
    /// lifetime so the waker type itself carries no lifetime parameter.
    fn new<'a>(queue: QueueHandle<'a, R>, id: Index, sleeper: S) -> Self {
        // NOTE(review): this extends `'a` to `'static`. The waker only ever
        // pushes `QueueItem::Poll` through this handle, but a waker clone that
        // outlives `'a` also keeps the queue (and any queued `'a` futures)
        // alive -- confirm callers cannot rely on that being sound.
        let queue: QueueHandle<'static, R> = unsafe { mem::transmute(queue) };
        Self { queue, id, sleeper }
    }

    /// Converts the shared waker state into a `LocalWaker`.
    fn into_local_waker(self: Arc<Self>) -> LocalWaker {
        // NOTE(review): `local_waker` is an unsafe constructor; its
        // requirements on `wake_local` are presumably met because this type
        // does not override it -- confirm against the `alloc::task` docs.
        unsafe { local_waker(self) }
    }
}
impl<R, S> Wake for QueueWaker<R, S>
where
    R: RawMutex + Send + Sync,
    S: Sleep,
{
    /// Queues a poll request for this waker's task, then wakes the sleeper.
    fn wake(arc_self: &Arc<Self>) {
        let request = QueueItem::Poll(arc_self.id);
        {
            // The guard is released at the end of this scope, before the
            // sleeper is notified.
            let mut pending = arc_self.queue.lock();
            pending.push_back(request);
        }
        arc_self.sleeper.wake();
    }
}
/// Spawner handle that accepts futures which are not required to be `Send`.
///
/// Wraps a `Spawner` for the same queue. The `PhantomData` marker carries no
/// data; presumably it is there so this handle inherits the auto-trait
/// properties of `LocalFutureObj` (TODO confirm).
pub struct LocalSpawner<'a, R>(Spawner<'a, R>, PhantomData<LocalFutureObj<'a, ()>>)
where
    R: RawMutex + Send + Sync;

// Implemented by hand rather than derived: `#[derive(Clone)]` would add an
// `R: Clone` bound, which raw mutex types generally do not satisfy. Cloning
// only needs to clone the inner queue handle.
impl<'a, R> Clone for LocalSpawner<'a, R>
where
    R: RawMutex + Send + Sync,
{
    fn clone(&self) -> Self {
        LocalSpawner(self.0.clone(), PhantomData)
    }
}
impl<'a, R> LocalSpawner<'a, R>
where
R: RawMutex + Send + Sync,
{
/// Wraps `spawner`; the second field is a zero-sized `PhantomData` marker.
fn new(spawner: Spawner<'a, R>) -> Self {
LocalSpawner(spawner, PhantomData)
}
}
impl<'a, R> LocalSpawner<'a, R>
where
    R: RawMutex + Send + Sync,
{
    /// Queues `future` for registration through the wrapped `Spawner`.
    fn spawn_local(&mut self, future: LocalFutureObj<'a, ()>) {
        // NOTE(review): this unsafely re-labels the local future as a
        // `FutureObj` so it fits the shared queue item type. Soundness
        // presumably relies on the queued future only being polled by the
        // executor this handle came from -- confirm.
        let queued = unsafe { future.into_future_obj() };
        self.0.spawn_obj(queued);
    }

    /// Spawns a future that is already in `UnsafeFutureObj` form.
    pub fn spawn_raw<F>(&mut self, future: F)
    where
        F: UnsafeFutureObj<'a, ()>,
    {
        let object = LocalFutureObj::new(future);
        self.spawn_local(object);
    }

    /// Spawns `future` after converting it with `future_box::make_local`.
    pub fn spawn<F>(&mut self, future: F)
    where
        F: Future<Output = ()> + 'a,
    {
        let boxed = future_box::make_local(future);
        self.spawn_raw(boxed);
    }
}
/// Handle that enqueues `Send` futures on an executor's queue.
///
/// Holds a clone of the executor's shared queue handle.
pub struct Spawner<'a, R>(QueueHandle<'a, R>)
where
R: RawMutex + Send + Sync;
impl<'a, R> Spawner<'a, R>
where
R: RawMutex + Send + Sync,
{
/// Wraps the shared queue `handle` in a spawner.
fn new(handle: QueueHandle<'a, R>) -> Self {
Spawner(handle)
}
}
impl<'a, R> Spawner<'a, R>
where
    R: RawMutex + Send + Sync,
{
    /// Pushes a spawn request for `future` onto the shared queue.
    ///
    /// This only enqueues the request; no sleeper is notified here.
    fn spawn_obj(&mut self, future: FutureObj<'a, ()>) {
        let mut pending = self.0.lock();
        pending.push_back(QueueItem::Spawn(future));
    }

    /// Spawns a `Send` future that is already in `UnsafeFutureObj` form.
    pub fn spawn_raw<F>(&mut self, future: F)
    where
        F: UnsafeFutureObj<'a, ()> + Send,
    {
        let object = FutureObj::new(future);
        self.spawn_obj(object);
    }

    /// Spawns a `Send` future after converting it with `future_box::make_obj`.
    pub fn spawn<F>(&mut self, future: F)
    where
        F: Future<Output = ()> + Send + 'a,
    {
        let boxed = future_box::make_obj(future);
        self.spawn_raw(boxed);
    }
}
impl<'a, R> Clone for Spawner<'a, R>
where
R: RawMutex + Send + Sync,
{
fn clone(&self) -> Self {
Spawner(self.0.clone())
}
}
impl<'a, R> Spawn for Spawner<'a, R>
where
R: RawMutex + Send + Sync,
{
fn spawn_obj(&mut self, future: FutureObj<'static, ()>) -> Result<(), SpawnError> {
self.spawn_obj(future);
Ok(())
}
}
impl<'a, R> From<LocalSpawner<'a, R>> for Spawner<'a, R>
where
R: RawMutex + Send + Sync,
{
fn from(other: LocalSpawner<'a, R>) -> Self {
other.0
}
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::sleep::Sleep;
    use core::sync::atomic::{
        AtomicBool,
        Ordering,
        ATOMIC_BOOL_INIT,
    };
    use futures::{
        future::{
            self,
            FutureExt,
            FutureObj,
        },
        task::Spawn,
    };
    use lock_api::GuardSend;

    /// Minimal spinlock used as the raw mutex in tests; `true` means locked.
    pub struct RawSpinlock(AtomicBool);

    unsafe impl RawMutex for RawSpinlock {
        const INIT: RawSpinlock = RawSpinlock(ATOMIC_BOOL_INIT);
        type GuardMarker = GuardSend;

        /// Spins until the lock has been acquired.
        fn lock(&self) {
            while !self.try_lock() {}
        }

        /// Attempts to take the lock, returning `true` only on success.
        fn try_lock(&self) -> bool {
            // `swap` returns the previous value, so the lock was acquired only
            // if the flag was `false` before this call.
            !self.0.swap(true, Ordering::Acquire)
        }

        /// Releases the lock.
        fn unlock(&self) {
            self.0.store(false, Ordering::Release);
        }
    }

    /// `Sleep` implementation whose `sleep` and `wake` do nothing.
    #[derive(Copy, Clone, Default)]
    struct NopSleep;

    impl Sleep for NopSleep {
        fn sleep(&self) {}
        fn wake(&self) {}
    }

    /// Future that resolves immediately to 5.
    fn foo() -> impl Future<Output = i32> {
        future::ready(5)
    }

    /// Prints the result of `foo` and resolves to that value plus one.
    fn bar() -> impl Future<Output = i32> {
        foo().then(|a| {
            println!("{}", a);
            let b = a + 1;
            future::ready(b)
        })
    }

    /// Takes the value `c` produced by `bar` and spawns one printing task per
    /// value in `c..25` through the generic `Spawn` interface.
    fn baz<S: Spawn>(mut spawner: S) -> impl Future<Output = ()> {
        bar().then(move |c| {
            for i in c..25 {
                let spam = future::lazy(move |_| println!("{}", i));
                spawner
                    .spawn_obj(FutureObj::new(future_box::make_obj(spam)))
                    .unwrap();
            }
            future::ready(())
        })
    }

    /// Smoke test: spawns tasks directly on the executor and indirectly
    /// through spawner handles, then runs the executor until it returns.
    #[test]
    fn executor() {
        let mut executor = AllocExecutor::<RawSpinlock, NopSleep>::new();
        let mut spawner = executor.spawner();
        let entry = future::lazy(move |_| {
            for i in 0..10 {
                spawner.spawn_raw(future_box::make_obj(future::lazy(move |_| {
                    println!("{}", i);
                })));
            }
        });
        executor.spawn(entry);
        executor.spawn(baz(executor.spawner()));
        executor.run();
    }
}