use alloc::boxed::Box;
use core::{future::Future, marker::PhantomData, ptr::NonNull};
use async_task::{Runnable, Task};
use crate::{
job::{HeapJob, JobRef},
latch::{CountLatch, Latch},
thread_pool::{ThreadPool, WorkerThread},
util::CallOnDrop,
};
/// Zero-sized marker tying [`Scope`] to its `'scope` lifetime.
///
/// `'scope` appears both inside the trait object's argument type and as the
/// object's lifetime bound, which makes `Scope` invariant over `'scope`. The
/// `Send + Sync` bounds on the boxed trait object mean the marker itself does
/// not strip those auto traits from `Scope`.
type ScopeMarker<'scope> = PhantomData<Box<dyn FnOnce(&Scope<'scope>) + Send + Sync + 'scope>>;

/// A scope onto which closures and futures borrowing `'scope` data are spawned.
///
/// Outstanding work is tracked by `job_completed_latch`: it starts at one (held
/// by the scope itself) and is incremented once per spawn.
pub struct Scope<'scope> {
    /// Pool that receives every job spawned through this scope.
    thread_pool: &'static ThreadPool,
    /// Counter of outstanding work; each spawn increments it and later releases
    /// it through `Latch::set`.
    job_completed_latch: CountLatch,
    marker: ScopeMarker<'scope>,
}
impl<'scope> Scope<'scope> {
    /// Creates a scope whose work is spawned onto `owner`'s thread pool.
    ///
    /// The completion latch starts at one; that initial count belongs to the
    /// scope itself and is released by [`Scope::complete`].
    ///
    /// # Safety
    ///
    /// Every job and future spawned on the returned scope keeps a raw pointer
    /// to it, so the `Scope` must not be dropped or moved while any spawned
    /// work is still outstanding.
    /// NOTE(review): any further requirements are defined by this function's
    /// callers, which are not visible here — confirm and document them.
    pub unsafe fn new(owner: &WorkerThread) -> Scope<'scope> {
        Scope {
            thread_pool: owner.thread_pool(),
            job_completed_latch: CountLatch::with_count(1, owner),
            marker: PhantomData,
        }
    }

    /// Spawns the closure `f` as a job on this scope's thread pool.
    ///
    /// `f` receives a reference to this scope, so it can spawn further work.
    /// The latch is incremented before the job is queued and released when the
    /// job ends. The release happens from a drop guard, so it also occurs if
    /// `f` unwinds, and a panicking job cannot leave [`Scope::complete`]
    /// waiting forever.
    pub fn spawn<F>(&self, f: F)
    where
        F: FnOnce(&Scope<'scope>) + Send + 'scope,
    {
        self.job_completed_latch.increment();
        let scope_ptr = ScopePtr(self);
        let job = HeapJob::new(move || {
            // SAFETY: the latch incremented above is only released by `_guard`
            // below, and the scope cannot finish before that, so the pointee is
            // still alive here.
            let scope = unsafe { scope_ptr.as_ref() };
            // Release the latch when this job ends, whether `f` returns or
            // unwinds; this is the same guard pattern `spawn_future` uses.
            let _guard = CallOnDrop(move || {
                // SAFETY: as above. Nothing dereferences `scope` after this call.
                unsafe { Latch::set(&scope.job_completed_latch) };
            });
            f(scope);
        });
        // SAFETY: NOTE(review): `into_job_ref`'s contract is defined elsewhere.
        // The job only borrows `'scope` data, which stays valid until the latch
        // above is released.
        let job_ref = unsafe { job.into_job_ref() };
        self.thread_pool.inject_or_push(job_ref);
    }

    /// Spawns `future` onto this scope's thread pool and returns its [`Task`].
    ///
    /// The latch is incremented up front. It is released by a guard living
    /// inside the wrapped future, so it is set once the future either completes
    /// or is dropped without completing. Whenever the task is scheduled, the
    /// [`Runnable`] is converted into a raw [`JobRef`] and handed to the pool.
    pub fn spawn_future<F, T>(&self, future: F) -> Task<T>
    where
        F: Future<Output = T> + Send + 'scope,
        T: Send + 'scope,
    {
        self.job_completed_latch.increment();
        let scope_ptr = ScopePtr(self);
        let future = async move {
            // Dropped when the async block finishes or when the future is
            // dropped mid-flight, so the latch is released on every path.
            let _guard = CallOnDrop(move || {
                // SAFETY: the latch was incremented above and is only released
                // here, so the scope is still alive at this point.
                unsafe {
                    let scope = scope_ptr.as_ref();
                    Latch::set(&scope.job_completed_latch);
                }
            });
            future.await
        };
        // The pool reference is `'static`, so the scheduler holds it directly
        // instead of dereferencing the scope every time the task is woken.
        let thread_pool = self.thread_pool;
        let schedule = move |runnable: Runnable| {
            // SAFETY: NOTE(review): relies on each `JobRef` being executed
            // exactly once, so the raw pointer from `Runnable::into_raw` is
            // turned back into a `Runnable` exactly once.
            let job_ref = unsafe {
                JobRef::new_raw(runnable.into_raw().as_ptr(), |this| {
                    let this = NonNull::new_unchecked(this.cast_mut());
                    let runnable = Runnable::<()>::from_raw(this);
                    runnable.run();
                })
            };
            thread_pool.inject_or_push(job_ref);
        };
        // SAFETY: any `'scope` borrows held by `future` stay valid for as long
        // as it exists, because the scope cannot finish until `_guard` above
        // has released the latch.
        let (runnable, task) = unsafe { async_task::spawn_unchecked(future, schedule) };
        runnable.schedule();
        task
    }

    /// Spawns the future produced by calling `f` with this scope.
    ///
    /// `f` runs on the pool as the first step of the spawned task, not on the
    /// calling thread.
    pub fn spawn_async<F, Fut, T>(&self, f: F) -> Task<T>
    where
        F: FnOnce(&Scope<'scope>) -> Fut + Send + 'static,
        Fut: Future<Output = T> + Send + 'static,
        T: Send + 'static,
    {
        let scope_ptr = ScopePtr(self);
        let future = async move {
            // SAFETY: `spawn_future` increments the latch before this future
            // can be polled, so the scope is alive whenever this body runs.
            let scope = unsafe { scope_ptr.as_ref() };
            f(scope).await
        };
        self.spawn_future(future)
    }

    /// Finishes the scope and waits for every job spawned on it.
    ///
    /// Releases the initial count created in [`Scope::new`], then calls
    /// `owner.run_until` on the latch (presumably running other work until the
    /// latch is set).
    ///
    /// NOTE(review): spawned jobs hold raw pointers to this `Scope`, yet `self`
    /// is taken by value here, which is semantically a move. Confirm the scope
    /// is never relocated while work is outstanding, or consider `&self`.
    pub fn complete(self, owner: &WorkerThread) {
        // SAFETY: releases only the scope's own initial count. Each spawn holds
        // a separate count, so outstanding work keeps the latch from completing.
        unsafe { Latch::set(&self.job_completed_latch) };
        owner.run_until(&self.job_completed_latch);
    }
}
/// Lifetime-free raw pointer used to move a borrowed value (in this file, a
/// `Scope`) into jobs and futures whose lifetimes the compiler cannot tie to
/// the original borrow.
///
/// Like `&T`, it is `Send` and `Sync` exactly when `T: Sync`.
struct ScopePtr<T>(*const T);

// SAFETY: a `ScopePtr<T>` only ever yields shared `&T` access (see `as_ref`),
// and `&T` may be sent to another thread exactly when `T: Sync`.
unsafe impl<T> Send for ScopePtr<T> where T: Sync {}

// SAFETY: sharing a `&ScopePtr<T>` likewise grants nothing beyond `&T`.
unsafe impl<T> Sync for ScopePtr<T> where T: Sync {}

impl<T> ScopePtr<T> {
    /// Reborrows the pointee as a shared reference.
    ///
    /// # Safety
    ///
    /// The wrapped pointer must be non-null, aligned, and point to a live `T`
    /// for the whole lifetime of the returned reference, with no concurrent
    /// mutable access.
    unsafe fn as_ref(&self) -> &T {
        let raw: *const T = self.0;
        // SAFETY: guaranteed by the caller per this function's contract.
        unsafe { &*raw }
    }
}