use crate::park::{Park, Unpark};
use crate::task::{self, queue::MpscQueues, JoinHandle, Schedule, ScheduleSendOnly, Task};
use std::cell::Cell;
use std::fmt;
use std::future::Future;
use std::mem::ManuallyDrop;
use std::ptr;
use std::sync::Arc;
use std::task::{RawWaker, RawWakerVTable, Waker};
use std::time::Duration;
/// Scheduler that executes spawned tasks on the thread calling `block_on`.
#[derive(Debug)]
pub(crate) struct BasicScheduler<P: Park> {
    /// State shared with every `Spawner` and every cloned waker.
    scheduler: Arc<SchedulerPriv>,

    /// State reached only through `&mut self` (tick counter and parker).
    local: LocalState<P>,
}
/// Cloneable handle for spawning tasks onto a `BasicScheduler`.
#[derive(Clone, Debug)]
pub(crate) struct Spawner {
    /// Shared scheduler state; cloning a `Spawner` clones this `Arc`.
    scheduler: Arc<SchedulerPriv>,
}
/// Scheduler state shared between the driving thread, `Spawner`s and wakers.
pub(super) struct SchedulerPriv {
    /// Run queues. `schedule` uses `push_local` when called on the driving
    /// thread and the lock returned by `remote()` otherwise.
    queues: MpscQueues<Self>,

    /// Handle used to wake the driving thread when it is parked.
    unpark: Box<dyn Unpark>,
}
// SAFETY: NOTE(review) — these impls are asserted by hand, so the compiler does
// not check them.
// - Every `unsafe` queue operation in this file runs either on the thread where
//   `ACTIVE` points at this scheduler (`block_on`/`schedule`) or through
//   `&mut BasicScheduler` (`Drop`). This covers `push_local`, `next_task`,
//   `drain_pending_drop`, `add_task`, `release_local`, `shutdown`,
//   `drain_queues` and `has_tasks_remaining`.
// - Cross-thread scheduling goes through `queues.remote()`, which presumably
//   takes a lock.
// Confirm that `MpscQueues` and `Box<dyn Unpark>` actually uphold these
// assumptions.
unsafe impl Send for SchedulerPriv {}
unsafe impl Sync for SchedulerPriv {}
/// Scheduler state owned by `BasicScheduler` and reached only via `&mut self`.
#[derive(Debug)]
struct LocalState<P> {
    /// Counter handed to `next_task`; advanced (wrapping) once per task slot.
    tick: u8,

    /// Parker used to block the thread when there is nothing to run.
    park: P,
}
/// Maximum number of tasks `SchedulerPriv::tick` runs before it parks with a
/// zero timeout and returns to `block_on`. `block_on` then polls the future it
/// is blocking on again.
const MAX_TASKS_PER_TICK: usize = 61;
thread_local! {
    /// Scheduler currently driven by `block_on` on this thread, or null if none.
    /// `SchedulerPriv::schedule` compares against it to choose the local queue.
    static ACTIVE: Cell<*const SchedulerPriv> = Cell::new(ptr::null());
}
impl<P> BasicScheduler<P>
where
    P: Park,
{
    /// Creates a scheduler that parks the current thread on `park` when idle.
    ///
    /// The `Unpark` handle obtained from `park` is stored in the shared state,
    /// so that wakers and remote spawns can wake the driving thread.
    pub(crate) fn new(park: P) -> BasicScheduler<P> {
        let unpark = park.unpark();

        BasicScheduler {
            scheduler: Arc::new(SchedulerPriv {
                queues: MpscQueues::new(),
                unpark: Box::new(unpark),
            }),
            local: LocalState { tick: 0, park },
        }
    }

    /// Returns a cloneable handle that can spawn tasks onto this scheduler.
    pub(crate) fn spawner(&self) -> Spawner {
        Spawner {
            scheduler: self.scheduler.clone(),
        }
    }

    /// Spawns `future` onto this scheduler and returns a handle to its output.
    ///
    /// Queued tasks are executed by `block_on`.
    pub(crate) fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static,
    {
        let (task, handle) = task::joinable(future);
        self.scheduler.schedule(task, true);
        handle
    }

    /// Runs `future` to completion on the current thread, executing spawned
    /// tasks between polls of `future`.
    ///
    /// While this call is active, `ACTIVE` points at this scheduler, so tasks
    /// scheduled from this thread go to the local queue. The previous value is
    /// restored on return or unwind.
    ///
    /// # Panics
    ///
    /// Panics if parking the thread fails. It presumably also panics if
    /// `runtime::enter` refuses to enter (e.g. nested runtimes); confirm
    /// against `runtime::enter`.
    pub(crate) fn block_on<F>(&mut self, mut future: F) -> F::Output
    where
        F: Future,
    {
        use crate::runtime;
        use std::pin::Pin;
        use std::task::Context;
        use std::task::Poll::Ready;

        let local = &mut self.local;
        let scheduler = &*self.scheduler;

        /// Restores the previous `ACTIVE` value when dropped.
        struct Guard {
            old: *const SchedulerPriv,
        }

        impl Drop for Guard {
            fn drop(&mut self) {
                ACTIVE.with(|cell| cell.set(self.old));
            }
        }

        // Track the current scheduler. `_guard` is declared before `_enter`,
        // so it is dropped after it.
        let _guard = ACTIVE.with(|cell| {
            let guard = Guard { old: cell.get() };
            cell.set(scheduler as *const SchedulerPriv);
            guard
        });

        let _enter = runtime::enter();

        // Borrowed (non-owning) waker pointing at the shared scheduler state.
        // `wake` and `drop` are `sched_noop` because they can never run:
        // - the waker is wrapped in `ManuallyDrop`, so `drop` is never called;
        // - by-value `wake` is unreachable through the `&Waker` in `cx`.
        // Clones go through `sched_clone_waker`, which takes a new strong
        // reference.
        //
        // NOTE(review): `sched_clone_waker` calls `Arc::from_raw` on this
        // pointer, which is derived from `&*self.scheduler`. Where the
        // toolchain allows, `Arc::as_ptr(&self.scheduler)` would keep
        // provenance over the whole Arc allocation (relevant under Miri).
        let raw_waker = RawWaker::new(
            scheduler as *const SchedulerPriv as *const (),
            &RawWakerVTable::new(sched_clone_waker, sched_noop, sched_wake_by_ref, sched_noop),
        );

        // SAFETY: the vtable upholds the `RawWaker` contract as described
        // above. The pointee outlives the waker because `self.scheduler` is
        // borrowed for the whole call.
        let waker = ManuallyDrop::new(unsafe { Waker::from_raw(raw_waker) });
        let mut cx = Context::from_waker(&waker);

        // SAFETY: `block_on` owns `future`, and the original binding is
        // shadowed here, so the value can no longer be moved after pinning.
        let mut future = unsafe { Pin::new_unchecked(&mut future) };

        loop {
            if let Ready(v) = future.as_mut().poll(&mut cx) {
                return v;
            }

            scheduler.tick(local);

            // Maintenance work.
            // SAFETY: this runs on the thread that drives the scheduler
            // (`ACTIVE` points at it for the duration of this call).
            unsafe {
                scheduler.queues.drain_pending_drop();
            }
        }
    }
}
impl Spawner {
pub(crate) fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
let (task, handle) = task::joinable(future);
self.scheduler.schedule(task, true);
handle
}
}
impl SchedulerPriv {
    /// Runs up to `MAX_TASKS_PER_TICK` queued tasks on the current thread.
    ///
    /// If the queues become empty first, the thread parks until unparked and
    /// the method returns. If the budget is used up, the thread is parked with
    /// a zero timeout before returning. This presumably lets the parker process
    /// pending events without blocking; confirm against the `Park`
    /// implementation.
    ///
    /// # Panics
    ///
    /// Panics if parking fails.
    fn tick(&self, local: &mut LocalState<impl Park>) {
        let mut budget = MAX_TASKS_PER_TICK;

        while budget > 0 {
            budget -= 1;

            // Read the current tick and advance it (wrapping) for the next slot.
            let current = local.tick;
            local.tick = current.wrapping_add(1);

            // SAFETY: `tick` is only called from `block_on`, i.e. from the
            // thread that drives this scheduler.
            let next = unsafe { self.queues.next_task(current) };

            let task = if let Some(task) = next {
                task
            } else {
                // Nothing runnable: block until something unparks the thread.
                local.park.park().ok().expect("failed to park");
                return;
            };

            // A task handed back by `run` is queued locally again.
            if let Some(task) = task.run(&mut || Some(self.into())) {
                // SAFETY: still on the driving thread (see above).
                unsafe {
                    self.queues.push_local(task);
                }
            }
        }

        local
            .park
            .park_timeout(Duration::from_millis(0))
            .ok()
            .expect("failed to park");
    }

    /// Queues `task` for execution.
    ///
    /// On the thread where `ACTIVE` points at this scheduler, the task goes to
    /// the local queue. Otherwise it is pushed onto the remote queue and the
    /// driving thread is unparked.
    fn schedule(&self, task: Task<Self>, spawn: bool) {
        let this = self as *const SchedulerPriv;

        if !ACTIVE.with(|cell| cell.get() == this) {
            let mut remote = self.queues.remote();
            remote.schedule(task, spawn);

            // Unpark while the remote lock is still held, then release it.
            self.unpark.unpark();
            drop(remote);
            return;
        }

        // SAFETY: `ACTIVE` equals `self` only on the thread inside this
        // scheduler's `block_on`.
        unsafe {
            self.queues.push_local(task);
        }
    }
}
impl Schedule for SchedulerPriv {
    /// Re-schedules an existing task; forwards with `spawn = false`.
    fn schedule(&self, task: Task<Self>) {
        SchedulerPriv::schedule(self, task, false);
    }

    /// Registers `task` with the scheduler's queues.
    fn bind(&self, task: &Task<Self>) {
        // SAFETY: NOTE(review) — `add_task` is `unsafe`; presumably it must run
        // on the driving thread. Confirm against `MpscQueues::add_task`.
        unsafe {
            self.queues.add_task(task);
        }
    }

    /// Releases `task` through the local (same-thread) path.
    fn release_local(&self, task: &Task<Self>) {
        // SAFETY: NOTE(review) — as with `bind`, presumably only valid on the
        // driving thread. Confirm against `MpscQueues::release_local`.
        unsafe {
            self.queues.release_local(task);
        }
    }

    /// Releases `task` through the remote (cross-thread) path.
    fn release(&self, task: Task<Self>) {
        self.queues.release_remote(task);
    }
}

/// Marker impl; no methods are overridden.
impl ScheduleSendOnly for SchedulerPriv {}
impl<P> Drop for BasicScheduler<P>
where
    P: Park,
{
    /// Shuts the queues down, then keeps draining them until no tasks remain,
    /// parking the thread between attempts.
    ///
    /// # Panics
    ///
    /// Panics if parking fails.
    fn drop(&mut self) {
        // SAFETY: `drop` holds `&mut self`, so no `block_on` is running and
        // this thread owns the scheduler's local state.
        unsafe {
            self.scheduler.queues.shutdown();
        }

        loop {
            // SAFETY: same justification as above. Only the queue operations
            // are unsafe, so the unsafe block is limited to them.
            let finished = unsafe {
                self.scheduler.queues.drain_pending_drop();
                self.scheduler.queues.drain_queues();
                !self.scheduler.queues.has_tasks_remaining()
            };

            if finished {
                break;
            }

            // Parking is safe code, so it stays outside the unsafe block.
            // Outstanding tasks are presumably released elsewhere, which
            // unparks this thread.
            self.local.park.park().ok().expect("park failed");
        }
    }
}
impl fmt::Debug for SchedulerPriv {
    /// Formats as `Scheduler { queues: .. }`; the `unpark` handle is omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut out = f.debug_struct("Scheduler");
        out.field("queues", &self.queues);
        out.finish()
    }
}
unsafe fn sched_clone_waker(ptr: *const ()) -> RawWaker {
let s1 = ManuallyDrop::new(Arc::from_raw(ptr as *const SchedulerPriv));
#[allow(clippy::redundant_clone)]
let s2 = s1.clone();
RawWaker::new(
&**s2 as *const SchedulerPriv as *const (),
&RawWakerVTable::new(sched_clone_waker, sched_wake, sched_wake_by_ref, sched_drop),
)
}
unsafe fn sched_wake(ptr: *const ()) {
let scheduler = Arc::from_raw(ptr as *const SchedulerPriv);
scheduler.unpark.unpark();
}
unsafe fn sched_wake_by_ref(ptr: *const ()) {
let scheduler = ManuallyDrop::new(Arc::from_raw(ptr as *const SchedulerPriv));
scheduler.unpark.unpark();
}
/// `drop` entry of the owning vtable: releases the waker's strong reference.
///
/// # Safety
///
/// `ptr` must carry one strong reference produced by `sched_clone_waker`.
unsafe fn sched_drop(ptr: *const ()) {
    drop(Arc::from_raw(ptr as *const SchedulerPriv));
}
/// Fills the `wake` and `drop` slots of the borrowed waker built in `block_on`.
///
/// Despite its name, this does not silently do nothing. Those slots must never
/// be reached (the waker is kept in `ManuallyDrop` and only lent out by
/// reference), so reaching them panics.
unsafe fn sched_noop(_ptr: *const ()) {
    unreachable!()
}