use core::{
future::Future,
marker::PhantomData,
mem,
pin::Pin,
task::{
Context,
Poll,
Waker,
},
};
use alloc::{
collections::VecDeque,
sync::Arc,
};
use futures::{
future::{
FutureObj,
LocalFutureObj,
UnsafeFutureObj,
},
task::{
LocalSpawn,
Spawn,
SpawnError,
},
};
use lock_api::{
Mutex,
RawMutex,
};
use generational_arena::{
Arena,
Index,
};
use crate::{
future_box,
sleep::*,
wake::{
Wake,
WakeExt,
},
};
/// Default number of task slots pre-allocated in the registry by `AllocExecutor::new`.
const REG_CAP: usize = 16;
/// Default number of run-queue slots pre-allocated by `AllocExecutor::new`
/// (half of `REG_CAP`).
const QUEUE_CAP: usize = REG_CAP / 2;
/// Executor that stores its tasks in a `generational_arena::Arena` and drives
/// them from a shared, mutex-protected run queue.
///
/// * `R` is the raw mutex implementation guarding the queue.
/// * `S` is the sleep/wake handle: the executor calls `sleep` on it from its
///   run loop and hands a clone of it to every task waker.
pub struct AllocExecutor<'a, R: RawMutex, S> {
    /// All live tasks, addressed by their arena `Index`.
    registry: Arena<Task<'a>>,
    /// Pending work: poll requests and futures waiting to be registered.
    queue: QueueHandle<'a, R>,
    /// Sleep/wake handle; cloned into each task's waker on spawn.
    sleep_waker: S,
}
/// Selects which end of the run queue receives the first poll request of a
/// freshly registered task.
enum SpawnLoc {
    /// Enqueue at the front, so the task is polled before already-queued items.
    Front,
    /// Enqueue at the back, behind already-queued items.
    Back,
}
impl<'a, R, S> Default for AllocExecutor<'a, R, S>
where
R: RawMutex,
S: Sleep + Wake + Clone + Default,
{
fn default() -> Self {
Self::new()
}
}
/// Construction, spawning and the run loop of the executor.
impl<'a, R, S> AllocExecutor<'a, R, S>
where
    R: RawMutex,
    S: Sleep + Wake + Clone + Default,
{
    /// Creates an executor with the default capacities (`REG_CAP` registry
    /// slots and `QUEUE_CAP` queue slots).
    pub fn new() -> Self {
        Self::with_capacity(REG_CAP, QUEUE_CAP)
    }

    /// Creates an executor whose task registry is pre-allocated for
    /// `registry` tasks and whose run queue is pre-allocated for `queue`
    /// items. The sleep/wake handle is built with `S::default()`.
    pub fn with_capacity(registry: usize, queue: usize) -> Self {
        AllocExecutor {
            registry: Arena::with_capacity(registry),
            queue: new_queue(queue),
            sleep_waker: S::default(),
        }
    }

    /// Returns a handle that pushes `Send` futures onto this executor's queue.
    pub fn spawner(&self) -> Spawner<'a, R> {
        Spawner::new(self.queue.clone())
    }

    /// Returns a handle that pushes local futures onto this executor's queue.
    pub fn local_spawner(&self) -> LocalSpawner<'a, R> {
        LocalSpawner::new(Spawner::new(self.queue.clone()))
    }

    /// Registers `future` as a task, attaches its waker and enqueues its
    /// first poll request at the queue end selected by `loc`.
    fn spawn_local(&mut self, future: LocalFutureObj<'a, ()>, loc: SpawnLoc) {
        let id = self.registry.insert(Task::new(future));

        // The waker needs the task's id, which only exists after insertion,
        // so it is created afterwards and stored into the task.
        let queue_waker = Arc::new(QueueWaker::new(
            self.queue.clone(),
            id,
            self.sleep_waker.clone(),
        ));
        let waker = queue_waker.into_waker();
        // `id` was returned by the insert just above, so the lookup succeeds.
        self.registry.get_mut(id).unwrap().set_waker(waker);

        let item = QueueItem::Poll(id);
        let mut lock = self.queue.lock();
        match loc {
            SpawnLoc::Front => lock.push_front(item),
            SpawnLoc::Back => lock.push_back(item),
        }
    }

    /// Spawns a future given in its raw `UnsafeFutureObj` form; its first
    /// poll is queued at the back of the run queue.
    pub fn spawn_raw<F>(&mut self, future: F)
    where
        F: UnsafeFutureObj<'a, ()>,
    {
        self.spawn_local(LocalFutureObj::new(future), SpawnLoc::Back)
    }

    /// Spawns `future` after converting it with `future_box::make_local`.
    pub fn spawn<F>(&mut self, future: F)
    where
        F: Future<Output = ()> + 'a,
    {
        self.spawn_raw(future_box::make_local(future));
    }

    /// Polls the task registered under `id` once and removes it from the
    /// registry when it completes. Ids that are no longer registered are
    /// silently ignored.
    ///
    /// # Panics
    ///
    /// Panics if the task has no waker attached.
    fn poll_task(&mut self, id: Index) {
        if let Some(Task { future, waker }) = self.registry.get_mut(id) {
            let future = Pin::new(future);
            let waker = waker
                .as_ref()
                .expect("waker not set, task spawned incorrectly");
            match future.poll(&mut Context::from_waker(waker)) {
                Poll::Ready(_) => {
                    self.registry.remove(id);
                }
                Poll::Pending => {}
            }
        }
    }

    /// Pops the next item from the front of the run queue, holding the lock
    /// only for the duration of the pop.
    fn dequeue(&self) -> Option<QueueItem<'a>> {
        self.queue.lock().pop_front()
    }

    /// Drives the executor until the task registry becomes empty.
    ///
    /// Queued items are processed until the queue is drained; only then is
    /// `sleep` called on the sleep/wake handle (presumably blocking until a
    /// task waker signals it; confirm against the `Sleep` implementation in
    /// use). The function returns as soon as the registry is empty after
    /// processing an item.
    ///
    /// NOTE(review): the emptiness check only runs after an item has been
    /// processed, so calling `run` with nothing queued never returns on its
    /// own.
    pub fn run(&mut self) {
        loop {
            while let Some(item) = self.dequeue() {
                match item {
                    QueueItem::Poll(id) => self.poll_task(id),
                    QueueItem::Spawn(task) => self.spawn_local(task.into(), SpawnLoc::Front),
                }
                if self.registry.is_empty() {
                    return;
                }
            }
            // Sleep only once the queue is drained: sleeping after every item
            // would stall while work is still queued, and skipping it here
            // would busy-spin on the empty queue.
            self.sleep_waker.sleep();
        }
    }
}
/// A registered future together with the waker used when polling it.
struct Task<'a> {
    /// The future driven by the executor.
    future: LocalFutureObj<'a, ()>,
    /// `None` until `set_waker` is called; polling requires it to be set.
    waker: Option<Waker>,
}
impl<'a> Task<'a> {
fn new(future: LocalFutureObj<'a, ()>) -> Task<'a> {
Task {
future,
waker: None,
}
}
fn set_waker(&mut self, waker: Waker) {
self.waker = Some(waker);
}
}
/// Double-ended run queue holding the executor's pending work items.
type Queue<'a> = VecDeque<QueueItem<'a>>;
/// Reference-counted, mutex-guarded `Queue`; `R` is the raw mutex implementation.
type QueueHandle<'a, R> = Arc<Mutex<R, Queue<'a>>>;
fn new_queue<'a, R>(capacity: usize) -> QueueHandle<'a, R>
where
R: RawMutex,
{
Arc::new(Mutex::new(Queue::with_capacity(capacity)))
}
/// A unit of work waiting in the run queue.
enum QueueItem<'a> {
    /// Poll the task registered under this arena index.
    Poll(Index),
    /// Register this future as a new task.
    Spawn(FutureObj<'a, ()>),
}
/// Per-task wake handle: waking it re-queues the task for polling and then
/// signals the wrapped wake handle.
struct QueueWaker<R: RawMutex, W> {
    /// Run queue the poll request is pushed onto (lifetime erased to `'static`).
    queue: QueueHandle<'static, R>,
    /// Arena index of the task this waker belongs to.
    id: Index,
    /// Handle signalled after the poll request has been queued.
    waker: W,
}
impl<R: RawMutex, W: Wake> QueueWaker<R, W> {
    /// Creates the wake handle for task `id`, which will push onto `queue`
    /// and then signal `waker`.
    fn new(queue: QueueHandle<'_, R>, id: Index, waker: W) -> Self {
        // NOTE(review): this transmute only erases the queue's lifetime
        // parameter so the handle can be stored as `'static`. Nothing here
        // stops the resulting waker (and thus the queue contents) from
        // outliving the original lifetime; confirm callers uphold that.
        let queue: QueueHandle<'static, R> = unsafe { mem::transmute(queue) };
        Self { queue, id, waker }
    }
}
impl<R: RawMutex, W: Wake> Wake for QueueWaker<R, W> {
    /// Appends a poll request for this waker's task to the run queue, then
    /// signals the wrapped wake handle.
    fn wake(&self) {
        {
            // Scope the guard so the queue is unlocked before signalling.
            let mut pending = self.queue.lock();
            pending.push_back(QueueItem::Poll(self.id));
        }
        self.waker.wake();
    }
}
/// Spawn handle that accepts local futures and forwards them to the wrapped
/// [`Spawner`].
///
/// NOTE(review): the `PhantomData<LocalFutureObj>` marker presumably exists
/// to give this handle the auto-trait behaviour of `LocalFutureObj`; confirm.
#[derive(Clone)]
pub struct LocalSpawner<'a, R: RawMutex>(Spawner<'a, R>, PhantomData<LocalFutureObj<'a, ()>>);
impl<'a, R: RawMutex> LocalSpawner<'a, R> {
    /// Wraps `spawner` in a local spawn handle.
    fn new(spawner: Spawner<'a, R>) -> Self {
        Self(spawner, PhantomData)
    }
}
impl<'a, R> LocalSpawner<'a, R>
where
R: RawMutex,
{
fn spawn_local(&self, future: LocalFutureObj<'a, ()>) -> Result<(), SpawnError> {
Ok(self
.0
.spawn_obj(unsafe { mem::transmute(future.into_future_obj()) }))
}
pub fn spawn_raw<F>(&mut self, future: F) -> Result<(), SpawnError>
where
F: UnsafeFutureObj<'a, ()>,
{
self.spawn_local(LocalFutureObj::new(future))
}
pub fn spawn<F>(&mut self, future: F) -> Result<(), SpawnError>
where
F: Future<Output = ()> + 'a,
{
self.spawn_raw(future_box::make_local(future))
}
}
impl<'a, R> LocalSpawn for LocalSpawner<'a, R>
where
R: RawMutex,
{
fn spawn_local_obj(&self, future: LocalFutureObj<'a, ()>) -> Result<(), SpawnError> {
self.spawn_local(future)
}
}
/// Spawn handle that pushes futures onto an executor's shared run queue.
pub struct Spawner<'a, R: RawMutex>(QueueHandle<'a, R>);
impl<'a, R: RawMutex> Spawner<'a, R> {
    /// Wraps a run-queue handle in a spawner.
    fn new(handle: QueueHandle<'a, R>) -> Self {
        Self(handle)
    }

    /// Appends `future` to the back of the run queue as a spawn request.
    fn spawn_obj(&self, future: FutureObj<'a, ()>) {
        let mut pending = self.0.lock();
        pending.push_back(QueueItem::Spawn(future));
    }

    /// Queues a `Send` future given in its raw `UnsafeFutureObj` form.
    pub fn spawn_raw<F>(&self, future: F)
    where
        F: UnsafeFutureObj<'a, ()> + Send + 'a,
    {
        let obj = FutureObj::new(future);
        Self::spawn_obj(self, obj);
    }

    /// Queues `future` after converting it with `future_box::make_obj`.
    pub fn spawn<F>(&mut self, future: F)
    where
        F: Future<Output = ()> + Send + 'a,
    {
        let raw = future_box::make_obj(future);
        self.spawn_raw(raw);
    }
}
impl<'a, R> Clone for Spawner<'a, R>
where
R: RawMutex,
{
fn clone(&self) -> Self {
Spawner(self.0.clone())
}
}
/// `futures::task::Spawn` support; queuing never fails, so the result is
/// always `Ok(())`.
impl<'a, R: RawMutex> Spawn for Spawner<'a, R> {
    fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), SpawnError> {
        // Path syntax selects the inherent `spawn_obj`, as in the rest of the file.
        Spawner::spawn_obj(self, future);
        Ok(())
    }
}
impl<'a, R> From<LocalSpawner<'a, R>> for Spawner<'a, R>
where
R: RawMutex,
{
fn from(other: LocalSpawner<'a, R>) -> Self {
other.0
}
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::sleep::Sleep;
    use core::sync::atomic::{
        AtomicBool,
        Ordering,
    };
    use futures::{
        future::{
            self,
            FutureObj,
        },
        task::Spawn,
    };
    use lock_api::GuardSend;

    /// Minimal spinlock used as the executor's raw mutex in the tests.
    /// The flag is `true` while the lock is held.
    pub struct RawSpinlock(AtomicBool);

    unsafe impl RawMutex for RawSpinlock {
        const INIT: RawSpinlock = RawSpinlock(AtomicBool::new(false));
        type GuardMarker = GuardSend;

        /// Spins until the lock has been acquired.
        fn lock(&self) {
            while !self.try_lock() {}
        }

        /// Returns `true` only if this call acquired the lock.
        fn try_lock(&self) -> bool {
            // `swap` hands back the previous flag: `false` means the lock was
            // free and is now ours, `true` means someone else still holds it.
            !self.0.swap(true, Ordering::Acquire)
        }

        /// Releases the lock by clearing the flag.
        unsafe fn unlock(&self) {
            self.0.store(false, Ordering::Release);
        }
    }

    /// Sleep/wake handle whose `sleep` and `wake` both do nothing.
    #[derive(Copy, Clone, Default)]
    struct NopSleep;

    impl Sleep for NopSleep {
        fn sleep(&self) {}
    }

    impl Wake for NopSleep {
        fn wake(&self) {}
    }

    /// Resolves to a constant.
    async fn foo() -> i32 {
        5
    }

    /// Awaits `foo`, prints its value and returns it plus one.
    async fn bar() -> i32 {
        let a = foo().await;
        println!("{}", a);
        a + 1
    }

    /// Awaits `bar`, then spawns one printing task per value from the result
    /// up to (excluding) 25 through the generic `Spawn` interface.
    async fn baz<S: Spawn>(spawner: S) {
        let c = bar().await;
        for i in c..25 {
            let spam = async move {
                println!("{}", i);
            };
            println!("spawning!");
            spawner
                .spawn_obj(FutureObj::new(future_box::make_obj(spam)))
                .unwrap();
        }
    }

    /// Spawns tasks that themselves spawn further tasks and checks that
    /// `run` returns once all of them have completed.
    #[test]
    fn executor() {
        let mut executor = AllocExecutor::<RawSpinlock, NopSleep>::new();
        let spawner = executor.spawner();
        let entry = future::lazy(move |_| {
            for i in 0..10 {
                spawner.spawn_raw(future_box::make_obj(future::lazy(move |_| {
                    println!("{}", i);
                })));
            }
        });
        executor.spawn(entry);
        executor.spawn(baz(executor.spawner()));
        executor.run();
    }
}