use std::sync::{Arc, Weak};
use takecell::TakeOwnCell;
use crate::ThreadPoolState;
/// Handle to a spawned task whose result is handed over by value when the
/// handle is consumed by `join`.
///
/// The result is stored inside a `TakeOwnCell` so that it can be moved out of
/// the shared task state through a shared reference.
pub struct OwnedTask<T>(Arc<TypedTaskInner<TakeOwnCell<T>>>)
where
    T: 'static + Send;
impl<T: 'static + Send> OwnedTask<T> {
#[inline(always)]
pub(crate) fn spawn(
pool: &Arc<ThreadPoolState>,
f: impl 'static + FnOnce() -> T + Send,
) -> Self {
let inner = Arc::new(TypedTaskInner {
func: TakeOwnCell::new(Box::new(|| TakeOwnCell::new(f()))),
pool: Arc::downgrade(pool),
result: spin::Once::new(),
});
pool.push_task(inner.clone());
Self(inner)
}
#[inline(always)]
pub fn cancel(self) {
if let Some(pool) = self.0.pool.upgrade() {
pool.cancel_task(&self.0);
}
}
#[inline(always)]
pub fn complete(&self) -> bool {
self.0.result.is_completed()
}
#[inline(always)]
pub fn help(&self) {
self.0.run();
}
#[inline(always)]
pub fn join(self) -> T {
self.0.run();
self.0
.result
.wait()
.take()
.expect("Failed to get result of task")
}
#[inline(always)]
pub fn try_join(self) -> Result<T, Self> {
if self.complete() {
Ok(self.join())
} else {
Err(self)
}
}
}
/// Debug output shows only whether the task has completed; the remaining
/// state is elided.
impl<T: 'static + Send> std::fmt::Debug for OwnedTask<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let done = self.complete();
        let mut builder = f.debug_struct("OwnedTask");
        builder.field("complete", &done);
        builder.finish_non_exhaustive()
    }
}
/// Cloneable handle to a spawned task whose result is accessed by shared
/// reference through `join` / `try_join`.
pub struct SharedTask<T>(Arc<TypedTaskInner<T>>)
where
    T: 'static + Send + Sync;
impl<T: 'static + Send + Sync> SharedTask<T> {
    /// Builds the shared state for a task that evaluates `f`, passes a clone of
    /// that state to `pool.push_task`, and returns the handle.
    ///
    /// Only a weak reference to `pool` is kept inside the task state.
    #[inline(always)]
    pub(crate) fn spawn(
        pool: &Arc<ThreadPoolState>,
        f: impl 'static + FnOnce() -> T + Send,
    ) -> Self {
        let job: Box<dyn FnOnce() -> T + Send> = Box::new(f);
        let shared = Arc::new(TypedTaskInner {
            func: TakeOwnCell::new(job),
            pool: Arc::downgrade(pool),
            result: spin::Once::new(),
        });
        pool.push_task(shared.clone());
        Self(shared)
    }

    /// Consumes this handle and, if fewer than three strong references to the
    /// task state exist, asks the owning pool to cancel the task.
    ///
    /// Does nothing when the reference count is three or more, or when the
    /// pool can no longer be upgraded from its weak reference.
    #[inline(always)]
    pub fn cancel(self) {
        // NOTE(review): the threshold presumably accounts for this handle plus
        // the pool's own reference, so that other live clones block
        // cancellation; confirm against `ThreadPoolState`.
        if Arc::strong_count(&self.0) >= 3 {
            return;
        }
        if let Some(pool) = self.0.pool.upgrade() {
            pool.cancel_task(&self.0);
        }
    }

    /// Returns `true` once the task's result has been stored.
    #[inline(always)]
    pub fn complete(&self) -> bool {
        let slot = &self.0.result;
        slot.is_completed()
    }

    /// Tries to execute the task on the calling thread.
    ///
    /// Has no effect if the task's function has already been taken by an
    /// earlier run.
    #[inline(always)]
    pub fn help(&self) {
        // Whether this call actually ran the function is of no interest here.
        let _ = self.0.run();
    }

    /// Returns a reference to the task's result.
    ///
    /// The task is executed on the calling thread if its function has not been
    /// taken yet; otherwise this waits until the result is available.
    #[inline(always)]
    pub fn join(&self) -> &T {
        let task = &*self.0;
        task.run();
        task.result.wait()
    }

    /// Returns a reference to the result if the task has already completed,
    /// and `None` otherwise.
    #[inline(always)]
    pub fn try_join(&self) -> Option<&T> {
        if !self.complete() {
            return None;
        }
        Some(self.join())
    }
}
impl<T: 'static + Send + Sync> Clone for SharedTask<T> {
fn clone(&self) -> Self {
Self(self.0.clone())
}
}
/// Debug output shows only whether the task has completed; the remaining
/// state is elided.
impl<T: 'static + Send + Sync> std::fmt::Debug for SharedTask<T> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let done = self.complete();
        let mut builder = f.debug_struct("SharedTask");
        builder.field("complete", &done);
        builder.finish_non_exhaustive()
    }
}
/// Type-erased view of a task that can be executed through a shared reference.
pub(crate) trait TaskInner: Send {
    /// Attempts to execute the task.
    ///
    /// In the implementation in this file, returns `true` when this call took
    /// the task's function and ran it, and `false` when the function had
    /// already been taken.
    fn run(&self) -> bool;
}
/// Shared state behind a task handle: the pending function, a weak link to the
/// pool it was spawned on, and the slot that receives the result.
struct TypedTaskInner<T>
where
    T: Send + Sync,
{
    /// The work to perform; taken out of the cell by the first `run` call.
    func: TakeOwnCell<Box<dyn FnOnce() -> T + Send>>,
    /// Weak reference to the pool, upgraded by the handles' `cancel` methods.
    pool: Weak<ThreadPoolState>,
    /// Filled with the function's return value by `run`.
    result: spin::Once<T>,
}
impl<T: Send + Sync> TaskInner for TypedTaskInner<T> {
    /// Takes the pending function, if it is still there, and stores its return
    /// value in `result`.
    ///
    /// Returns `true` when this call ran the function and `false` when the
    /// function had already been taken.
    #[inline(always)]
    fn run(&self) -> bool {
        let Some(job) = self.func.take() else {
            return false;
        };
        // The boxed `FnOnce` is itself callable, so it is handed over directly.
        self.result.call_once(job);
        true
    }
}