use crate::cancel::CancelToken;
use crate::error::{Error, Result};
use crate::sys;
use crate::task::{JoinError, SpawnOptions};
use std::any::Any;
use std::cell::Cell;
use std::marker::PhantomData;
use std::mem;
use std::os::raw::c_void;
use std::panic::{self, AssertUnwindSafe};
use std::ptr;
use std::sync::{Arc, Mutex};
/// Outcome of a scoped task's closure: `Ok` with its return value, or `Err`
/// carrying the panic payload captured by `catch_unwind`.
///
/// Identical to `std::thread::Result<T>`; spelled out here because this module
/// shadows `Result` with the crate's own alias.
type TaskOutput<T> = std::result::Result<T, Box<dyn Any + Send + 'static>>;
struct ScopedTaskRecord {
task: Mutex<*mut sys::llam_task_t>,
}
unsafe impl Send for ScopedTaskRecord {}
unsafe impl Sync for ScopedTaskRecord {}
impl ScopedTaskRecord {
fn new(task: *mut sys::llam_task_t) -> Self {
Self {
task: Mutex::new(task),
}
}
fn join_raw(&self) -> Result<()> {
let mut task = self.task.lock().expect("scoped LLAM task mutex poisoned");
if task.is_null() {
return Ok(());
}
let rc = unsafe { sys::llam_join(*task) };
if rc == 0 {
*task = ptr::null_mut();
Ok(())
} else {
Err(Error::last())
}
}
}
/// Heap-allocated payload passed to `scoped_trampoline` through the runtime's
/// untyped `*mut c_void` task argument.
///
/// `spawn_scoped` boxes this and hands it over with `Box::into_raw`. The
/// trampoline reclaims it with `Box::from_raw` and drops it when the task body
/// finishes. If the spawn itself fails, `spawn_scoped` reclaims and drops it instead.
struct ScopedTaskEntry<F, T> {
    /// The task closure; wrapped in `Option` so the trampoline can move it out with `take()`.
    f: Option<F>,
    /// Result slot shared with the task's `ScopedJoinHandle`.
    slot: Arc<Mutex<Option<TaskOutput<T>>>>,
    /// Token from `SpawnOptions::retained_cancel`. It is never read, only held so it
    /// is dropped together with the entry — presumably to keep the cancel token
    /// alive while the task runs (NOTE(review): confirm against `SpawnOptions`).
    _cancel: Option<CancelToken>,
}
/// Owned permission to join one task spawned inside a [`Scope`] and take its result.
///
/// The handle may be dropped without calling [`join`](Self::join). The task is
/// still joined when the enclosing scope's `join_all` runs, but its result is
/// then discarded.
///
/// `Send` is derived automatically, so no `unsafe impl` is needed. Every field
/// is `Send` exactly when `T: Send`: `ScopedTaskRecord` is `Send + Sync`, and the
/// result slot holds a `T` or a `Send` panic payload.
pub struct ScopedJoinHandle<'scope, T> {
    /// Joinable record shared with the owning scope.
    record: Arc<ScopedTaskRecord>,
    /// Result slot filled by the task's trampoline.
    slot: Arc<Mutex<Option<TaskOutput<T>>>>,
    // Ties the handle (invariantly) to the `'scope` lifetime it was spawned in.
    _scope: PhantomData<&'scope mut &'scope ()>,
}

impl<T> ScopedJoinHandle<'_, T> {
    /// Waits for the task to finish and returns the closure's return value.
    ///
    /// # Errors
    ///
    /// - [`JoinError::Runtime`] if the LLAM runtime fails to join the task.
    /// - [`JoinError::Panic`] carrying the payload if the task's closure panicked.
    /// - [`JoinError::MissingResult`] if the join succeeded but no result was
    ///   stored. NOTE(review): this presumably means the closure never ran, for
    ///   example because the task was cancelled before starting — confirm with
    ///   the runtime.
    ///
    /// # Panics
    ///
    /// Panics if an internal mutex is poisoned.
    pub fn join(self) -> std::result::Result<T, JoinError> {
        self.record.join_raw().map_err(JoinError::Runtime)?;
        // Take the result in its own statement so the lock is released before
        // the match below.
        let output = self
            .slot
            .lock()
            .expect("scoped task result mutex poisoned")
            .take();
        match output {
            Some(Ok(value)) => Ok(value),
            Some(Err(panic)) => Err(JoinError::Panic(panic)),
            None => Err(JoinError::MissingResult),
        }
    }
}
pub struct Scope<'scope, 'env: 'scope> {
records: Mutex<Vec<Arc<ScopedTaskRecord>>>,
_scope: PhantomData<&'scope mut &'scope ()>,
_env: PhantomData<&'env mut &'env ()>,
}
impl<'scope, 'env: 'scope> Scope<'scope, 'env> {
fn new() -> Self {
Self {
records: Mutex::new(Vec::new()),
_scope: PhantomData,
_env: PhantomData,
}
}
pub fn spawn<F, T>(&'scope self, f: F) -> ScopedJoinHandle<'scope, T>
where
F: FnOnce() -> T + Send + 'scope,
T: Send + 'scope,
{
self.try_spawn(f).expect("llam scoped task spawn failed")
}
pub fn try_spawn<F, T>(&'scope self, f: F) -> Result<ScopedJoinHandle<'scope, T>>
where
F: FnOnce() -> T + Send + 'scope,
T: Send + 'scope,
{
self.try_spawn_with(SpawnOptions::new(), f)
}
pub fn spawn_with<F, T>(&'scope self, opts: SpawnOptions, f: F) -> ScopedJoinHandle<'scope, T>
where
F: FnOnce() -> T + Send + 'scope,
T: Send + 'scope,
{
self.try_spawn_with(opts, f)
.expect("llam scoped task spawn failed")
}
pub fn try_spawn_with<F, T>(
&'scope self,
opts: SpawnOptions,
f: F,
) -> Result<ScopedJoinHandle<'scope, T>>
where
F: FnOnce() -> T + Send + 'scope,
T: Send + 'scope,
{
spawn_scoped(self, opts, f)
}
pub fn join_all(&self) -> Result<()> {
let records = self
.records
.lock()
.expect("scoped LLAM task list mutex poisoned")
.clone();
for record in records {
record.join_raw()?;
}
Ok(())
}
}
/// Runs `f` with a fresh [`Scope`] and joins the tasks spawned through it before
/// returning `f`'s result.
///
/// # Panics
///
/// Panics if joining the scoped tasks fails (use [`try_scope`] to get that error
/// as a value), and resumes any panic raised by `f` itself.
pub fn scope<'env, F, R>(f: F) -> R
where
    F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> R,
{
    let joined = try_scope(f);
    joined.expect("llam scoped task join failed")
}
/// Runs `f` with a fresh [`Scope`], then joins the scope's tasks.
///
/// The tasks are joined even if `f` panics. In that case the panic is resumed
/// after joining, and any join error is discarded.
///
/// # Errors
///
/// Returns [`JoinError::Runtime`] if `f` returned normally but joining the
/// scoped tasks failed.
pub fn try_scope<'env, F, R>(f: F) -> std::result::Result<R, JoinError>
where
    F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> R,
{
    let scope = Scope::new();
    let outcome = panic::catch_unwind(AssertUnwindSafe(|| f(&scope)));
    // Join before inspecting the outcome: tasks may borrow from 'env and must
    // finish before this function returns or unwinds.
    let joined = scope.join_all();
    let value = match outcome {
        Ok(value) => value,
        Err(payload) => panic::resume_unwind(payload),
    };
    joined.map_err(JoinError::Runtime)?;
    Ok(value)
}
pub struct Nursery<'scope, 'env: 'scope> {
scope: Scope<'scope, 'env>,
token: CancelToken,
cancel_on_error: Cell<bool>,
}
impl<'scope, 'env: 'scope> Nursery<'scope, 'env> {
fn new(token: CancelToken) -> Self {
Self {
scope: Scope::new(),
token,
cancel_on_error: Cell::new(true),
}
}
pub fn token(&self) -> CancelToken {
self.token.clone()
}
pub fn cancel(&self) -> Result<()> {
self.token.cancel()
}
pub fn is_cancelled(&self) -> bool {
self.token.is_cancelled()
}
pub fn set_cancel_on_error(&self, enabled: bool) {
self.cancel_on_error.set(enabled);
}
pub fn spawn<F, T>(&'scope self, f: F) -> ScopedJoinHandle<'scope, T>
where
F: FnOnce() -> T + Send + 'scope,
T: Send + 'scope,
{
self.try_spawn(f).expect("llam nursery task spawn failed")
}
pub fn try_spawn<F, T>(&'scope self, f: F) -> Result<ScopedJoinHandle<'scope, T>>
where
F: FnOnce() -> T + Send + 'scope,
T: Send + 'scope,
{
self.try_spawn_with(SpawnOptions::new(), f)
}
pub fn spawn_with<F, T>(&'scope self, opts: SpawnOptions, f: F) -> ScopedJoinHandle<'scope, T>
where
F: FnOnce() -> T + Send + 'scope,
T: Send + 'scope,
{
self.try_spawn_with(opts, f)
.expect("llam nursery task spawn failed")
}
pub fn try_spawn_with<F, T>(
&'scope self,
opts: SpawnOptions,
f: F,
) -> Result<ScopedJoinHandle<'scope, T>>
where
F: FnOnce() -> T + Send + 'scope,
T: Send + 'scope,
{
self.scope
.try_spawn_with(opts.cancel(self.token.clone()), f)
}
pub fn join_all(&self) -> Result<()> {
self.scope.join_all()
}
}
/// Runs `f` inside a fresh [`Nursery`] and joins all of its tasks before returning.
///
/// The nursery's token is cancelled in two cases: when `f` panics, and when `f`
/// returns `Err` while cancel-on-error is enabled (the default). Cancellation is
/// best-effort, and its own errors are ignored.
///
/// # Errors
///
/// - The error from `CancelToken::new` if the token cannot be created.
/// - `f`'s own error. This takes precedence over any join error.
/// - Otherwise, the error from joining the nursery's tasks.
///
/// # Panics
///
/// If `f` panics, the panic is resumed after the tasks have been joined.
pub fn nursery<'env, F, R>(f: F) -> Result<R>
where
    F: for<'scope> FnOnce(&'scope Nursery<'scope, 'env>) -> Result<R>,
{
    let nursery = Nursery::new(CancelToken::new()?);
    let outcome = panic::catch_unwind(AssertUnwindSafe(|| f(&nursery)));
    let should_cancel = match &outcome {
        Ok(Ok(_)) => false,
        Ok(Err(_)) => nursery.cancel_on_error.get(),
        Err(_) => true,
    };
    if should_cancel {
        // Best-effort: the join below runs regardless of whether cancel succeeded.
        let _ = nursery.cancel();
    }
    let joined = nursery.join_all();
    match outcome {
        Ok(Ok(value)) => joined.map(|()| value),
        Ok(Err(error)) => Err(error),
        Err(payload) => panic::resume_unwind(payload),
    }
}
fn spawn_scoped<'scope, 'env, F, T>(
scope: &Scope<'scope, 'env>,
opts: SpawnOptions,
f: F,
) -> Result<ScopedJoinHandle<'scope, T>>
where
'env: 'scope,
F: FnOnce() -> T + Send + 'scope,
T: Send + 'scope,
{
let slot = Arc::new(Mutex::new(None));
let entry = Box::new(ScopedTaskEntry {
f: Some(f),
slot: Arc::clone(&slot),
_cancel: opts.retained_cancel(),
});
let arg = Box::into_raw(entry) as *mut c_void;
let task = unsafe {
sys::llam_spawn_ex(
scoped_trampoline::<F, T>,
arg,
opts.raw_for_spawn(),
mem::size_of_val(opts.raw_for_spawn()),
)
};
if task.is_null() {
unsafe {
drop(Box::from_raw(arg as *mut ScopedTaskEntry<F, T>));
}
return Err(Error::last());
}
let record = Arc::new(ScopedTaskRecord::new(task));
scope
.records
.lock()
.expect("scoped LLAM task list mutex poisoned")
.push(Arc::clone(&record));
Ok(ScopedJoinHandle {
record,
slot,
_scope: PhantomData,
})
}
/// C-ABI entry point that LLAM runs for each scoped task.
///
/// Reclaims the boxed `ScopedTaskEntry` produced by `spawn_scoped` and runs the
/// closure under `catch_unwind`. It then stores either the closure's return value
/// or its panic payload in the result slot shared with the join handle.
///
/// Nothing may unwind out of this `extern "C"` function. The closure's own panic
/// is captured as its result. Any later panic aborts the process explicitly
/// instead of unwinding across the FFI boundary. Sources of such a panic:
/// - a poisoned result mutex;
/// - a destructor run when the entry drops. If the join handle has already been
///   dropped, this includes the stored `T` itself.
///
/// # Safety
///
/// `arg` must be the pointer produced by `Box::into_raw` on a
/// `ScopedTaskEntry<F, T>` with exactly these `F` and `T`. It must be passed to
/// this function at most once.
unsafe extern "C" fn scoped_trampoline<F, T>(arg: *mut c_void)
where
    F: FnOnce() -> T + Send,
    T: Send,
{
    let entry = Box::from_raw(arg as *mut ScopedTaskEntry<F, T>);
    let completed = panic::catch_unwind(AssertUnwindSafe(move || {
        let mut entry = entry;
        let f = entry.f.take().expect("LLAM scoped task closure missing");
        let result: std::result::Result<T, Box<dyn Any + Send + 'static>> =
            panic::catch_unwind(AssertUnwindSafe(f));
        *entry
            .slot
            .lock()
            .expect("scoped task result mutex poisoned") = Some(result);
        // `entry` drops at the end of this closure, so any panicking destructor
        // is still caught by the outer `catch_unwind`.
    }));
    if completed.is_err() {
        // Abort without dropping the payload: its destructor could panic too.
        std::process::abort();
    }
}