use std::{
cell::UnsafeCell,
marker::PhantomData,
mem,
mem::{ManuallyDrop, MaybeUninit},
ops, ptr,
sync::{
atomic::{AtomicPtr, AtomicUsize, Ordering},
Arc,
},
};
use parking_lot::{Mutex, RwLock, RwLockUpgradableReadGuard};
use thiserror::Error;
use crate::{
executor::{AsyncExecutor, Blocking, Executor, Nonblock},
prelude::*,
AsyncHandler,
};
/// Graph-scheduling operations, implemented both by [`Scheduler`] and by the
/// [`Handle`] passed to running jobs.
pub trait SchedulerCore<J> {
    /// Creates a node for `payload` that is released once `dependencies`
    /// in-edges have fired.
    ///
    /// # Panics
    /// Panics if `create_node_or_run` returns `None`. For the implementations
    /// in this module that happens when `dependencies == 0`. In that case the
    /// payload has already been submitted before the panic, so use
    /// `create_node_or_run` when zero dependencies are possible.
    fn create_node(&self, payload: J, dependencies: usize) -> NodeBuilder<J> {
        self.create_node_or_run(payload, dependencies).expect(
            "create_node requires at least one dependency; use create_node_or_run for zero",
        )
    }

    /// Creates a node gated on `dependencies` in-edges.
    ///
    /// In this module's implementations, a `dependencies` of zero submits
    /// `payload` immediately and returns `None`.
    fn create_node_or_run(&self, payload: J, dependencies: usize) -> Option<NodeBuilder<J>>;

    /// Submits `payload` for execution. If the job succeeds, its `dependents`
    /// (if any) are fired.
    fn push_with_dependents(&self, payload: J, dependents: OptRcDependents<J>);

    /// Submits `payload` with a fresh dependents list built from `dependents`.
    ///
    /// Returns that shared list so more edges can be attached later via
    /// `Dependents::push`.
    fn push_dependency(
        &self,
        payload: J,
        dependents: impl IntoIterator<Item = Edge<J>>,
    ) -> Arc<Dependents<J>> {
        let deps = Dependents::new(dependents.into_iter().collect());
        self.push_with_dependents(payload, Some(Arc::clone(&deps)));
        deps
    }
}
// Storage for a node's payload. `ManuallyDrop` + `MaybeUninit`: the payload
// is either moved out once by `Node::decrement` or dropped by hand in
// `Node`'s `Drop` impl, never both. `UnsafeCell`: the move-out happens
// through `&self`.
type NodePayload<T> = ManuallyDrop<UnsafeCell<MaybeUninit<T>>>;
/// A job in the dependency graph that is held back until all of its
/// in-edges have fired.
#[derive(Debug)]
struct Node<J> {
    /// The job to run. It is initialized only while `dependencies` is nonzero.
    payload: NodePayload<J>,
    /// An owned `Arc<Dependents<J>>` stored via `Arc::into_raw`, or null if
    /// none has been set (or it has already been taken).
    dependents: AtomicPtr<Dependents<J>>,
    /// Number of outstanding decrements before the payload is released.
    dependencies: AtomicUsize,
}
/// An in-edge of a graph node: a strong reference to the node it targets.
///
/// Firing an edge (as the scheduler does when the upstream job succeeds)
/// decrements the target's pending-dependency count. Merely dropping an edge,
/// as `AdoptableDependents::abandon` does, does not.
#[derive(Debug)]
#[repr(transparent)]
pub struct Edge<J> {
    to: Arc<Node<J>>,
}
/// Error indicating that the node behind a builder can no longer be accessed.
// NOTE(review): not referenced anywhere else in this file; confirm whether it
// is still used by callers.
#[derive(Debug, Clone, Copy, Error)]
#[error("The node associated with this builder can no longer be accessed")]
pub struct NodeDispatched;
/// Hands out the in-edges of a freshly created node.
///
/// Exactly `remaining` more edges can be taken. The `Drop` impl asserts that
/// all of them were taken before the builder goes away.
#[derive(Debug)]
pub struct NodeBuilder<J> {
    /// The node being built. The last handed-out edge takes this reference
    /// instead of cloning it, which leaves `None`.
    node: Option<Arc<Node<J>>>,
    /// Number of in-edges not yet handed out.
    remaining: usize,
}
/// The out-edges to fire once a job completes successfully.
///
/// The list is `Some(edges)` until the job's success takes it, after which it
/// is `None`. Edges pushed after that point are fired immediately instead of
/// being stored.
#[derive(Debug)]
pub struct Dependents<J>(RwLock<Option<Vec<Edge<J>>>>);
/// State machine behind [`AdoptableDependents`].
#[derive(Debug)]
enum AdoptState<J> {
    /// Edges buffered until the set is adopted, abandoned or completed.
    Orphan(Vec<Edge<J>>),
    /// New edges are forwarded to the attached `Dependents`.
    Adopted(Arc<Dependents<J>>),
    /// New edges are dropped without firing.
    Abandoned,
    /// New edges are fired immediately.
    Completed,
    /// Placeholder held while `adopt` forwards buffered edges. It is only
    /// observed if that forwarding panicked part-way.
    Poisoned,
}
/// Returned when an adopt/abandon/complete transition is not allowed from the
/// current [`AdoptableDependents`] state.
#[derive(Debug, Clone, Copy, Error)]
#[error("Adoptable dependents have already been adopted or abandoned")]
pub struct BadAdoptState;
/// A set of dependents that can collect edges before it is decided whether
/// they will be adopted by a `Dependents` list, abandoned, or completed.
#[derive(Debug)]
pub struct AdoptableDependents<J>(AdoptState<J>);
/// A shared, mutex-protected [`AdoptableDependents`]. Clones share the same
/// state.
#[derive(Debug)]
#[repr(transparent)]
pub struct RcAdoptableDependents<J>(Arc<Mutex<AdoptableDependents<J>>>);
/// Optional shared out-edge list carried alongside a job.
type OptRcDependents<J> = Option<Arc<Dependents<J>>>;
/// The unit of work actually submitted to the underlying executor: the user
/// payload plus the dependents to fire if handling it succeeds.
#[derive(Debug)]
pub struct Job<J> {
    payload: J,
    dependents: OptRcDependents<J>,
}
/// Wrapper around an executor handle. It adds the graph API (`SchedulerCore`)
/// and accepts bare payloads (`ExecutorHandle<J>`).
#[derive(Debug, Clone, Copy)]
pub struct Handle<H>(H);
/// An executor wrapped so that submitted jobs can depend on each other.
#[derive(Debug)]
pub struct Scheduler<J, E> {
    executor: E,
    // `fn(J)` ties the job type to the scheduler without claiming to own a
    // `J`; function pointers are always `Send + Sync`.
    _m: PhantomData<fn(J)>,
}
// SAFETY: The only shared-reference access to the payload cell is
// `Node::decrement`. It moves the payload out exactly once, on whichever
// thread's `fetch_sub` takes the count from 1 to 0. That thread need not be
// the one that created the node, so sharing `&Node<J>` across threads moves a
// `J` between threads, which requires `J: Send`. The remaining fields are
// atomics.
unsafe impl<J: Send> Sync for Node<J> {}
impl<J> Node<J> {
    /// Records that one in-edge has fired.
    ///
    /// The call that brings the count to zero moves the payload out and
    /// submits it through `handle`, together with the node's dependents if
    /// any were set.
    fn decrement<H: SchedulerCore<J>>(&self, handle: &H) {
        let previous = self.dependencies.fetch_sub(1, Ordering::SeqCst);
        if previous == 0 || previous == usize::MAX {
            unreachable!();
        }
        if previous != 1 {
            // Other in-edges are still outstanding.
            return;
        }

        // SAFETY: this call observed the 1 -> 0 transition, so it is the only
        // one that will ever read the payload, and the payload is still
        // initialized (it was written at construction). `Drop` skips the
        // payload once the count is 0, so it will not be dropped twice.
        let payload = unsafe { self.payload.get().read().assume_init() };

        let raw = self.dependents.swap(ptr::null_mut(), Ordering::SeqCst);
        // SAFETY: a non-null pointer was produced by `Arc::into_raw` in
        // `set_dependents`; swapping in null transfers its ownership here.
        let dependents = (!raw.is_null()).then(|| unsafe { Arc::from_raw(raw) });

        handle.push_with_dependents(payload, dependents);
    }

    /// Installs `dependents` if none are set yet. Otherwise hands them back
    /// unchanged.
    fn set_dependents(&self, dependents: Arc<Dependents<J>>) -> Result<(), Arc<Dependents<J>>> {
        let raw = Arc::into_raw(dependents).cast_mut();
        let exchanged = self.dependents.compare_exchange(
            ptr::null_mut(),
            raw,
            Ordering::SeqCst,
            Ordering::Relaxed,
        );
        match exchanged {
            Ok(_) => Ok(()),
            // SAFETY: the exchange failed, so `raw` was never stored; reclaim
            // the reference we leaked above.
            Err(_) => Err(unsafe { Arc::from_raw(raw) }),
        }
    }
}
impl<J> Drop for Node<J> {
fn drop(&mut self) {
match mem::replace(self.dependencies.get_mut(), 0) {
0 => (),
usize::MAX => unreachable!(),
_ => unsafe {
mem::drop(
ManuallyDrop::take(&mut self.payload)
.into_inner()
.assume_init(),
);
},
}
}
}
impl<J> Edge<J> {
fn new(to: Arc<Node<J>>) -> Self { Self { to } }
}
impl<J> NodeBuilder<J> {
    /// Creates a builder for a node gated on `dependencies` in-edges. If
    /// `dependencies` is zero, `payload` is passed to `run` immediately and
    /// `None` is returned.
    ///
    /// # Panics
    /// Panics if `dependencies == usize::MAX`, which is reserved.
    fn create_or_run(payload: J, dependencies: usize, run: impl FnOnce(J)) -> Option<Self> {
        match dependencies {
            0 => {
                run(payload);
                None
            },
            usize::MAX => panic!("Invalid number of dependencies! (usize::MAX is reserved)"),
            _ => {
                let node = Arc::new(Node {
                    payload: ManuallyDrop::new(UnsafeCell::new(MaybeUninit::new(payload))),
                    dependents: AtomicPtr::new(ptr::null_mut()),
                    dependencies: AtomicUsize::new(dependencies),
                });
                Some(NodeBuilder {
                    node: Some(node),
                    remaining: dependencies,
                })
            },
        }
    }

    /// Takes the next in-edge of the node.
    ///
    /// # Panics
    /// Panics if every in-edge has already been handed out.
    #[inline]
    pub fn get_in_edge(&mut self) -> Edge<J> {
        self.try_get_in_edge()
            .expect("all in-edges of this node have already been handed out")
    }

    /// Takes the next in-edge of the node, or returns `None` once all of them
    /// have been handed out.
    pub fn try_get_in_edge(&mut self) -> Option<Edge<J>> {
        // Invariant: the builder holds the node exactly while edges remain.
        if (self.remaining == 0) != self.node.is_none() {
            unreachable!();
        }
        let node = match self.remaining {
            0 => None,
            // Last edge: give away our own reference instead of cloning it.
            1 => {
                self.remaining = 0;
                self.node.take()
            },
            _ => {
                self.remaining -= 1;
                self.node.clone()
            },
        };
        node.map(Edge::new)
    }

    /// Attaches `dependents` to the node so they fire after it runs
    /// successfully.
    ///
    /// # Errors
    /// Hands `dependents` back if the builder no longer holds the node (all
    /// in-edges were handed out) or if the node already has dependents.
    pub fn set_dependents(
        &mut self,
        dependents: Arc<Dependents<J>>,
    ) -> Result<(), Arc<Dependents<J>>> {
        let Some(node) = self.node.as_ref() else {
            return Err(dependents);
        };
        debug_assert!(self.remaining > 0);
        debug_assert!(node.dependencies.load(Ordering::SeqCst) >= self.remaining);
        node.set_dependents(dependents)
    }
}
impl<J> Drop for NodeBuilder<J> {
    /// Asserts that every in-edge was handed out.
    ///
    /// The check is skipped while the thread is already unwinding. A builder
    /// dropped mid-panic is expected to be incomplete, and panicking again
    /// inside `drop` would abort the process instead of letting the original
    /// panic propagate.
    fn drop(&mut self) {
        if std::thread::panicking() {
            return;
        }
        assert!(
            self.remaining == 0 || self.node.is_none(),
            "Failed to exhaust dependency bag!"
        );
    }
}
impl<J> Dependents<J> {
#[must_use]
pub fn new(dependents: Vec<Edge<J>>) -> Arc<Self> {
Arc::new(Self(RwLock::new(Some(dependents))))
}
pub fn push<H: SchedulerCore<J>>(&self, handle: &H, dependent: Edge<J>) {
let this = self.0.upgradable_read();
if this.is_some() {
let mut this = RwLockUpgradableReadGuard::upgrade(this);
let this = this.as_mut().unwrap_or_else(|| unreachable!());
this.push(dependent);
} else {
dependent.to.decrement(handle);
}
}
}
impl<J> From<Edge<J>> for Arc<Dependents<J>> {
    /// Builds a dependents list containing exactly one edge.
    #[inline]
    fn from(edge: Edge<J>) -> Self {
        let edges = vec![edge];
        Dependents::new(edges)
    }
}
impl<J> std::iter::FromIterator<Edge<J>> for Arc<Dependents<J>> {
    /// Collects edges into a fresh, not-yet-fired dependents list.
    #[inline]
    fn from_iter<I: IntoIterator<Item = Edge<J>>>(it: I) -> Self {
        let edges: Vec<Edge<J>> = it.into_iter().collect();
        Dependents::new(edges)
    }
}
impl<J> AdoptableDependents<J> {
    /// Creates an empty set in the orphan state. Pushed edges are buffered
    /// until the set is adopted, abandoned or completed.
    #[must_use]
    pub fn new() -> Self { Self(AdoptState::Orphan(vec![])) }

    /// Creates a set already adopted by `dependents`.
    #[must_use]
    pub fn adopted(dependents: Arc<Dependents<J>>) -> Self { Self(AdoptState::Adopted(dependents)) }

    /// Creates a set in the abandoned state: pushed edges are dropped.
    #[must_use]
    pub fn abandoned() -> Self { Self(AdoptState::Abandoned) }

    /// Creates a set in the completed state: pushed edges fire immediately.
    #[must_use]
    pub fn completed() -> Self { Self(AdoptState::Completed) }

    /// Wraps the set in a shared mutex.
    #[inline]
    #[must_use]
    pub fn rc(self) -> RcAdoptableDependents<J> {
        RcAdoptableDependents(Arc::new(Mutex::new(self)))
    }

    /// Single panic site for the poisoned state, so every method reports it
    /// with the same message.
    #[cold]
    #[track_caller]
    fn poisoned() -> ! {
        panic!("AdoptableDependents was poisoned")
    }

    /// Adds `dependent` according to the current state:
    /// - orphan: buffered;
    /// - adopted: forwarded to the adopting list;
    /// - abandoned: dropped without firing;
    /// - completed: fired immediately.
    ///
    /// # Panics
    /// Panics if the set is poisoned.
    pub fn push<H: SchedulerCore<J>>(&mut self, handle: &H, dependent: Edge<J>) {
        match self.0 {
            AdoptState::Orphan(ref mut deps) => deps.push(dependent),
            AdoptState::Adopted(ref dependents) => dependents.push(handle, dependent),
            AdoptState::Abandoned => mem::drop(dependent),
            AdoptState::Completed => dependent.to.decrement(handle),
            AdoptState::Poisoned => Self::poisoned(),
        }
    }

    /// Moves every buffered edge into `dependents` and forwards future pushes
    /// there.
    ///
    /// # Errors
    /// Returns [`BadAdoptState`] if the set was already adopted, abandoned or
    /// completed.
    ///
    /// # Panics
    /// Panics if the set is poisoned.
    pub fn adopt<H: SchedulerCore<J>>(
        &mut self,
        handle: &H,
        dependents: Arc<Dependents<J>>,
    ) -> Result<(), BadAdoptState> {
        match self.0 {
            AdoptState::Orphan(_) => (),
            AdoptState::Adopted(_) | AdoptState::Abandoned | AdoptState::Completed => {
                return Err(BadAdoptState);
            },
            AdoptState::Poisoned => Self::poisoned(),
        }
        // Stay `Poisoned` while forwarding, so a panic part-way through leaves
        // the set unusable rather than silently half-adopted.
        let AdoptState::Orphan(deps) = mem::replace(&mut self.0, AdoptState::Poisoned) else {
            unreachable!()
        };
        for dep in deps {
            dependents.push(handle, dep);
        }
        self.0 = AdoptState::Adopted(dependents);
        Ok(())
    }

    /// Drops every buffered edge without firing it. Future pushes are dropped
    /// too.
    ///
    /// Returns `Ok(true)` on the orphan -> abandoned transition and
    /// `Ok(false)` if the set was already abandoned.
    ///
    /// # Errors
    /// Returns [`BadAdoptState`] if the set was adopted or completed.
    ///
    /// # Panics
    /// Panics if the set is poisoned.
    pub fn abandon(&mut self) -> Result<bool, BadAdoptState> {
        match self.0 {
            AdoptState::Orphan(_) => (),
            AdoptState::Adopted(_) | AdoptState::Completed => return Err(BadAdoptState),
            AdoptState::Abandoned => return Ok(false),
            AdoptState::Poisoned => Self::poisoned(),
        }
        let AdoptState::Orphan(jobs) = mem::replace(&mut self.0, AdoptState::Abandoned) else {
            unreachable!()
        };
        mem::drop(jobs);
        Ok(true)
    }

    /// Fires every buffered edge. Future pushes fire immediately.
    ///
    /// Returns `Ok(true)` on the orphan -> completed transition and
    /// `Ok(false)` if the set was already adopted or completed.
    ///
    /// # Errors
    /// Returns [`BadAdoptState`] if the set was abandoned.
    ///
    /// # Panics
    /// Panics if the set is poisoned.
    pub fn complete<H: SchedulerCore<J>>(&mut self, handle: &H) -> Result<bool, BadAdoptState> {
        match self.0 {
            AdoptState::Orphan(_) => (),
            AdoptState::Adopted(_) | AdoptState::Completed => return Ok(false),
            AdoptState::Abandoned => return Err(BadAdoptState),
            AdoptState::Poisoned => Self::poisoned(),
        }
        let AdoptState::Orphan(edges) = mem::replace(&mut self.0, AdoptState::Completed) else {
            unreachable!()
        };
        for edge in edges {
            edge.to.decrement(handle);
        }
        Ok(true)
    }
}
impl<J> Default for AdoptableDependents<J> {
fn default() -> Self { Self::new() }
}
impl<J> ops::Deref for RcAdoptableDependents<J> {
    type Target = Mutex<AdoptableDependents<J>>;

    /// Exposes the shared mutex so callers can lock the adoption state.
    fn deref(&self) -> &Self::Target {
        &*self.0
    }
}
impl<J> Clone for RcAdoptableDependents<J> {
fn clone(&self) -> Self { Self(Arc::clone(&self.0)) }
}
impl<J> From<J> for Job<J> {
#[inline]
fn from(payload: J) -> Self {
Self {
payload,
dependents: None,
}
}
}
impl<J, H: ExecutorHandle<Job<J>>> SchedulerCore<J> for Handle<H> {
fn create_node_or_run(&self, payload: J, dependencies: usize) -> Option<NodeBuilder<J>> {
NodeBuilder::create_or_run(payload, dependencies, |j| self.0.push(j.into()))
}
#[inline]
fn push_with_dependents(&self, payload: J, dependents: OptRcDependents<J>) {
self.0.push(Job {
payload,
dependents,
});
}
}
impl<J, H: ExecutorHandle<Job<J>>> ExecutorHandle<J> for Handle<H> {
#[inline]
fn push(&self, job: J) { self.0.push(job.into()); }
}
/// Finishes a job after its handler returned `res`.
///
/// On `Ok(())`, the job's dependents list (if any) is emptied and every edge
/// in it is fired through `handle`. On `Err(())` nothing is fired, and the
/// job's `Arc` to the dependents is simply dropped.
fn process_result<J, H: ExecutorHandle<Job<J>> + Copy>(
    res: Result<(), ()>,
    handle: Handle<H>,
    dependents: OptRcDependents<J>,
) {
    let (Ok(()), Some(dependents)) = (res, dependents) else {
        return;
    };
    // Take the list in its own statement so the write guard is released
    // before any edge fires. A guard created in a `for` loop's head would
    // stay alive until the loop ends. After the take, the list is `None`, so
    // concurrent `Dependents::push` calls fire their edge directly.
    let edges = mem::take(&mut *dependents.0.write());
    for edge in edges.into_iter().flatten() {
        edge.to.decrement(&handle);
    }
}
impl<J, E: ExecutorCore<Job<J>>> Scheduler<J, E> {
    /// Builds a scheduler on top of a synchronous executor builder.
    ///
    /// Each `Job` the executor runs is unpacked, and its payload is passed to
    /// `f` together with a `Handle` wrapping the executor's handle. The result
    /// is passed to `process_result`, which fires the job's dependents only on
    /// `Ok(())`.
    ///
    /// # Errors
    /// Propagates any error returned by `b.build`.
    fn new<
        B: ExecutorBuilderSync<Job<J>, Executor = E>,
        F: Fn(J, Handle<E::Handle<'_>>) -> Result<(), ()> + Clone + Send + 'static,
    >(
        b: B,
        f: F,
    ) -> Result<Self, B::Error> {
        b.build(
            move |Job {
                      payload,
                      dependents,
                  },
                  handle| {
                let handle = Handle(handle);
                let res = f(payload, handle);
                process_result(res, handle, dependents);
            },
        )
        .map(|executor| Self {
            executor,
            _m: PhantomData,
        })
    }
}
impl<J: Send, E: ExecutorCore<Job<J>>> Scheduler<J, E>
where for<'a> E::Handle<'a>: Send
{
    /// Builds a scheduler on top of an asynchronous executor builder.
    ///
    /// `f` is adapted into a `Job`-level async handler. That handler runs `f`
    /// on each payload, then fires the job's dependents via `process_result`
    /// when `f` resolves to `Ok(())`.
    ///
    /// # Errors
    /// Propagates any error returned by `b.build_async`.
    fn new_async<
        B: ExecutorBuilderAsync<Job<J>, Executor = E>,
        F: for<'h> AsyncHandler<J, Handle<E::Handle<'h>>, Output = Result<(), ()>>
            + Clone
            + Send
            + Sync
            + 'static,
    >(
        b: B,
        f: F,
    ) -> Result<Self, B::Error> {
        // Adapter from a payload-level handler `F` to the `Job`-level handler
        // the executor expects.
        #[derive(Clone)]
        struct Handler<F>(F);
        impl<
                J: Send,
                H: ExecutorHandle<Job<J>> + Copy + Send,
                F: AsyncHandler<J, Handle<H>, Output = Result<(), ()>> + Sync,
            > AsyncHandler<Job<J>, H> for Handler<F>
        {
            type Output = ();

            /// Unpacks `job`, awaits the wrapped handler on its payload, then
            /// settles the job's dependents based on the result.
            async fn handle(&self, job: Job<J>, handle: H) {
                let Job {
                    payload,
                    dependents,
                } = job;
                let handle = Handle(handle);
                let res = self.0.handle(payload, handle).await;
                process_result(res, handle, dependents);
            }
        }
        b.build_async(Handler(f)).map(|executor| Self {
            executor,
            _m: PhantomData,
        })
    }
}
impl<J, E> std::ops::Deref for Scheduler<J, E> {
    type Target = E;

    /// Exposes the wrapped executor's API directly on the scheduler.
    fn deref(&self) -> &Self::Target {
        let Self { executor, .. } = self;
        executor
    }
}
/// Extension methods that turn an executor builder for `Job`s into a
/// dependency-graph [`Scheduler`].
pub trait ExecutorBuilderExt<J>: Sized + ExecutorBuilderCore<Job<J>> {
    /// Builds a [`Scheduler`] whose jobs are handled by the synchronous
    /// function `work`.
    ///
    /// `work` receives each payload plus a [`Handle`] that can create and
    /// submit further graph nodes. Returning `Ok(())` fires the job's
    /// dependents; returning `Err(())` does not.
    ///
    /// # Errors
    /// Propagates the builder's error if the executor fails to build.
    fn build_graph<
        F: Fn(J, Handle<<Self::Executor as ExecutorCore<Job<J>>>::Handle<'_>>) -> Result<(), ()>
            + Clone
            + Send
            + 'static,
    >(
        self,
        work: F,
    ) -> Result<Scheduler<J, Self::Executor>, Self::Error>
    where
        Self: ExecutorBuilderSync<Job<J>>,
    {
        Scheduler::new(self, work)
    }

    /// Async counterpart of `build_graph`: `work` is an [`AsyncHandler`]
    /// whose `Ok(())` output fires the job's dependents.
    ///
    /// # Errors
    /// Propagates the builder's error if the executor fails to build.
    fn build_graph_async<
        F: for<'h> AsyncHandler<
                J,
                Handle<<Self::Executor as ExecutorCore<Job<J>>>::Handle<'h>>,
                Output = Result<(), ()>,
            > + Clone
            + Send
            + Sync
            + 'static,
    >(
        self,
        work: F,
    ) -> Result<Scheduler<J, Self::Executor>, Self::Error>
    where
        J: Send,
        Self: ExecutorBuilderAsync<Job<J>>,
        for<'a> <Self::Executor as ExecutorCore<Job<J>>>::Handle<'a>: Send,
    {
        Scheduler::new_async(self, work)
    }
}
/// Every executor builder for `Job`s gets the graph-building methods.
impl<J, B> ExecutorBuilderExt<J> for B where B: ExecutorBuilderCore<Job<J>> + Sized {}
impl<J, E> ExecutorHandle<J> for Scheduler<J, E>
where E: ExecutorCore<Job<J>>
{
    /// Submits a bare payload, with no dependents, to the wrapped executor.
    #[inline]
    fn push(&self, job: J) {
        let job: Job<J> = job.into();
        self.executor.push(job);
    }
}
impl<J, E> ExecutorCore<J> for Scheduler<J, E>
where E: ExecutorCore<Job<J>>
{
    /// Jobs see the executor's own handle wrapped in [`Handle`], which adds
    /// the graph-scheduling API.
    type Handle<'a> = Handle<<E as ExecutorCore<Job<J>>>::Handle<'a>>;
}
impl<J> Scheduler<J, Executor<Job<J>, Blocking>>
where J: Send + 'static
{
    /// Consumes the scheduler and calls the blocking executor's `join`.
    #[inline]
    pub fn join(self) {
        let Self { executor, .. } = self;
        executor.join();
    }

    /// Consumes the scheduler and calls the blocking executor's `abort`.
    #[inline]
    pub fn abort(self) {
        let Self { executor, .. } = self;
        executor.abort();
    }
}
impl<J, E> Scheduler<J, Executor<Job<J>, Nonblock<E>>>
where
    J: Send + 'static,
    E: AsyncExecutor,
{
    /// Consumes the scheduler and returns the executor's `join_async` future.
    #[inline]
    pub fn join_async(self) -> impl std::future::Future<Output = ()> + Send {
        let Self { executor, .. } = self;
        executor.join_async()
    }

    /// Consumes the scheduler and returns the executor's `abort_async` future.
    #[inline]
    pub fn abort_async(self) -> impl std::future::Future<Output = ()> + Send {
        let Self { executor, .. } = self;
        executor.abort_async()
    }
}
impl<J, E: ExecutorCore<Job<J>>> SchedulerCore<J> for Scheduler<J, E> {
fn create_node_or_run(&self, payload: J, dependencies: usize) -> Option<NodeBuilder<J>> {
NodeBuilder::create_or_run(payload, dependencies, |j| self.executor.push(j.into()))
}
#[inline]
fn push_with_dependents(&self, payload: J, dependents: OptRcDependents<J>) {
self.executor.push(Job {
payload,
dependents,
});
}
}