// llam 0.1.4
//
// Safe, Go-style Rust bindings for the LLAM runtime.
use crate::cancel::CancelToken;
use crate::error::{Error, Result};
use crate::sys;
use crate::task::{JoinError, SpawnOptions};
use std::any::Any;
use std::cell::Cell;
use std::marker::PhantomData;
use std::mem;
use std::os::raw::c_void;
use std::panic::{self, AssertUnwindSafe};
use std::ptr;
use std::sync::{Arc, Mutex};

/// Outcome of running a scoped task body under `catch_unwind`: either the
/// closure's return value or the payload of the panic that was caught.
type TaskOutput<T> = std::result::Result<T, Box<dyn Any + Send + 'static>>;

struct ScopedTaskRecord {
    task: Mutex<*mut sys::llam_task_t>,
}

unsafe impl Send for ScopedTaskRecord {}
unsafe impl Sync for ScopedTaskRecord {}

impl ScopedTaskRecord {
    fn new(task: *mut sys::llam_task_t) -> Self {
        Self {
            task: Mutex::new(task),
        }
    }

    fn join_raw(&self) -> Result<()> {
        let mut task = self.task.lock().expect("scoped LLAM task mutex poisoned");
        if task.is_null() {
            return Ok(());
        }

        let rc = unsafe { sys::llam_join(*task) };
        if rc == 0 {
            *task = ptr::null_mut();
            Ok(())
        } else {
            Err(Error::last())
        }
    }
}

/// Heap-allocated payload passed to `scoped_trampoline` through the `void*`
/// task argument.
struct ScopedTaskEntry<F, T> {
    /// The task body. The trampoline takes it out exactly once.
    f: Option<F>,
    /// Where the trampoline publishes the body's outcome for the join handle.
    slot: Arc<Mutex<Option<TaskOutput<T>>>>,
    /// Whatever `SpawnOptions::retained_cancel` returned (presumably the
    /// options' cancel token). It is kept alive until the entry is dropped
    /// at the end of the task.
    _cancel: Option<CancelToken>,
}

/// A join handle tied to a structured LLAM scope.
///
/// Dropping the handle does not detach the task. The enclosing [`Scope`] or
/// [`Nursery`] will still join it before returning.
pub struct ScopedJoinHandle<'scope, T> {
    record: Arc<ScopedTaskRecord>,
    slot: Arc<Mutex<Option<TaskOutput<T>>>>,
    _scope: PhantomData<&'scope mut &'scope ()>,
}

unsafe impl<T: Send> Send for ScopedJoinHandle<'_, T> {}

impl<T> ScopedJoinHandle<'_, T> {
    pub fn join(self) -> std::result::Result<T, JoinError> {
        self.record.join_raw().map_err(JoinError::Runtime)?;
        match self
            .slot
            .lock()
            .expect("scoped task result mutex poisoned")
            .take()
        {
            Some(Ok(value)) => Ok(value),
            Some(Err(panic)) => Err(JoinError::Panic(panic)),
            None => Err(JoinError::MissingResult),
        }
    }
}

/// Structured task scope.
///
/// Tasks spawned in this scope may borrow from the caller, and any task whose
/// handle is not explicitly joined is joined automatically when the scope ends.
pub struct Scope<'scope, 'env: 'scope> {
    records: Mutex<Vec<Arc<ScopedTaskRecord>>>,
    _scope: PhantomData<&'scope mut &'scope ()>,
    _env: PhantomData<&'env mut &'env ()>,
}

impl<'scope, 'env: 'scope> Scope<'scope, 'env> {
    fn new() -> Self {
        Self {
            records: Mutex::new(Vec::new()),
            _scope: PhantomData,
            _env: PhantomData,
        }
    }

    pub fn spawn<F, T>(&'scope self, f: F) -> ScopedJoinHandle<'scope, T>
    where
        F: FnOnce() -> T + Send + 'scope,
        T: Send + 'scope,
    {
        self.try_spawn(f).expect("llam scoped task spawn failed")
    }

    pub fn try_spawn<F, T>(&'scope self, f: F) -> Result<ScopedJoinHandle<'scope, T>>
    where
        F: FnOnce() -> T + Send + 'scope,
        T: Send + 'scope,
    {
        self.try_spawn_with(SpawnOptions::new(), f)
    }

    pub fn spawn_with<F, T>(&'scope self, opts: SpawnOptions, f: F) -> ScopedJoinHandle<'scope, T>
    where
        F: FnOnce() -> T + Send + 'scope,
        T: Send + 'scope,
    {
        self.try_spawn_with(opts, f)
            .expect("llam scoped task spawn failed")
    }

    pub fn try_spawn_with<F, T>(
        &'scope self,
        opts: SpawnOptions,
        f: F,
    ) -> Result<ScopedJoinHandle<'scope, T>>
    where
        F: FnOnce() -> T + Send + 'scope,
        T: Send + 'scope,
    {
        spawn_scoped(self, opts, f)
    }

    pub fn join_all(&self) -> Result<()> {
        let records = self
            .records
            .lock()
            .expect("scoped LLAM task list mutex poisoned")
            .clone();
        for record in records {
            record.join_raw()?;
        }
        Ok(())
    }
}

/// Run a structured LLAM task scope.
///
/// This mirrors `std::thread::scope`: spawned tasks may borrow from the caller,
/// and unjoined tasks are joined before this function returns. Unlike
/// `std::thread::scope`, a panic inside a spawned task is only reported through
/// that task's [`ScopedJoinHandle`]. The automatic join does not inspect it.
///
/// # Panics
///
/// Panics if the runtime fails to join an unjoined task, and resumes any panic
/// raised by `f` itself. Use [`try_scope`] to receive join failures as errors.
pub fn scope<'env, F, R>(f: F) -> R
where
    F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> R,
{
    let outcome = try_scope(f);
    outcome.expect("llam scoped task join failed")
}

/// Run a structured LLAM task scope and report automatic-join failures.
///
/// Every task still unjoined when `f` returns (or panics) is joined before
/// this function returns. A panic from `f` is resumed after that join, and any
/// join error is then discarded.
///
/// Panics inside spawned tasks are only surfaced through their
/// [`ScopedJoinHandle`]. The automatic join never inspects task results, so
/// the panic of a task whose handle was dropped is not reported here.
///
/// # Errors
///
/// Returns [`JoinError::Runtime`] if the runtime fails while joining the
/// remaining tasks.
pub fn try_scope<'env, F, R>(f: F) -> std::result::Result<R, JoinError>
where
    F: for<'scope> FnOnce(&'scope Scope<'scope, 'env>) -> R,
{
    let task_scope = Scope::new();
    let outcome = panic::catch_unwind(AssertUnwindSafe(|| f(&task_scope)));

    // Children must be joined even when the body panicked, because they may
    // still borrow data owned by the caller.
    let joined = task_scope.join_all();

    let value = outcome.unwrap_or_else(|payload| panic::resume_unwind(payload));
    joined.map_err(JoinError::Runtime)?;
    Ok(value)
}

/// Cancellation-aware structured task scope.
///
/// Every task spawned through a nursery has the nursery's [`CancelToken`]
/// attached through `SpawnOptions::cancel`.
pub struct Nursery<'scope, 'env: 'scope> {
    scope: Scope<'scope, 'env>,
    token: CancelToken,
    /// Whether [`nursery`] cancels the token when its closure returns `Err`.
    cancel_on_error: Cell<bool>,
}

impl<'scope, 'env: 'scope> Nursery<'scope, 'env> {
    fn new(token: CancelToken) -> Self {
        // Cancelling on closure error is the default behaviour.
        let cancel_on_error = Cell::new(true);
        let scope = Scope::new();
        Self {
            scope,
            token,
            cancel_on_error,
        }
    }

    /// Returns a clone of the nursery's cancellation token.
    pub fn token(&self) -> CancelToken {
        self.token.clone()
    }

    /// Requests cancellation through the nursery's token.
    ///
    /// # Errors
    ///
    /// Propagates any error from cancelling the token.
    pub fn cancel(&self) -> Result<()> {
        self.token.cancel()
    }

    /// Reports whether the nursery's token has been cancelled.
    pub fn is_cancelled(&self) -> bool {
        self.token.is_cancelled()
    }

    /// Sets whether [`nursery`] cancels the token when its closure returns
    /// `Err`. This is enabled by default.
    pub fn set_cancel_on_error(&self, enabled: bool) {
        self.cancel_on_error.set(enabled);
    }

    /// Spawns `f` in this nursery with default [`SpawnOptions`].
    ///
    /// # Panics
    ///
    /// Panics if the LLAM runtime fails to spawn the task.
    pub fn spawn<F, T>(&'scope self, f: F) -> ScopedJoinHandle<'scope, T>
    where
        F: FnOnce() -> T + Send + 'scope,
        T: Send + 'scope,
    {
        self.try_spawn(f).expect("llam nursery task spawn failed")
    }

    /// Spawns `f` in this nursery with default [`SpawnOptions`].
    ///
    /// # Errors
    ///
    /// Returns the runtime error if the task could not be spawned.
    pub fn try_spawn<F, T>(&'scope self, f: F) -> Result<ScopedJoinHandle<'scope, T>>
    where
        F: FnOnce() -> T + Send + 'scope,
        T: Send + 'scope,
    {
        self.try_spawn_with(SpawnOptions::new(), f)
    }

    /// Spawns `f` in this nursery using `opts`, plus the nursery's token.
    ///
    /// # Panics
    ///
    /// Panics if the LLAM runtime fails to spawn the task.
    pub fn spawn_with<F, T>(&'scope self, opts: SpawnOptions, f: F) -> ScopedJoinHandle<'scope, T>
    where
        F: FnOnce() -> T + Send + 'scope,
        T: Send + 'scope,
    {
        self.try_spawn_with(opts, f)
            .expect("llam nursery task spawn failed")
    }

    /// Spawns `f` in this nursery using `opts`, plus the nursery's token.
    ///
    /// # Errors
    ///
    /// Returns the runtime error if the task could not be spawned.
    pub fn try_spawn_with<F, T>(
        &'scope self,
        opts: SpawnOptions,
        f: F,
    ) -> Result<ScopedJoinHandle<'scope, T>>
    where
        F: FnOnce() -> T + Send + 'scope,
        T: Send + 'scope,
    {
        let opts = opts.cancel(self.token.clone());
        self.scope.try_spawn_with(opts, f)
    }

    /// Joins every task spawned in this nursery that has not been joined yet.
    ///
    /// # Errors
    ///
    /// Returns an error reported by the runtime while joining.
    pub fn join_all(&self) -> Result<()> {
        self.scope.join_all()
    }
}

/// Run a cancellation-aware structured LLAM task scope.
///
/// The nursery requests cancellation before waiting for unjoined children in
/// two cases:
/// * the closure panics;
/// * the closure returns `Err` while cancel-on-error is enabled (the default;
///   see [`Nursery::set_cancel_on_error`]).
///
/// Errors from that cancellation request are ignored. Children are always
/// joined before this function returns or resumes a panic.
///
/// # Errors
///
/// Returns one of the following:
/// * an error if the cancellation token cannot be created;
/// * the closure's own error if it returned `Err` (join errors are discarded
///   in that case);
/// * otherwise, an error from the automatic join.
pub fn nursery<'env, F, R>(f: F) -> Result<R>
where
    F: for<'scope> FnOnce(&'scope Nursery<'scope, 'env>) -> Result<R>,
{
    let group = Nursery::new(CancelToken::new()?);
    let outcome = panic::catch_unwind(AssertUnwindSafe(|| f(&group)));

    let should_cancel = match &outcome {
        Ok(Ok(_)) => false,
        Ok(Err(_)) => group.cancel_on_error.get(),
        Err(_) => true,
    };
    if should_cancel {
        let _ = group.cancel();
    }

    let joined = group.join_all();
    match outcome {
        Err(payload) => panic::resume_unwind(payload),
        Ok(Err(error)) => Err(error),
        Ok(Ok(value)) => joined.map(|()| value),
    }
}

fn spawn_scoped<'scope, 'env, F, T>(
    scope: &Scope<'scope, 'env>,
    opts: SpawnOptions,
    f: F,
) -> Result<ScopedJoinHandle<'scope, T>>
where
    'env: 'scope,
    F: FnOnce() -> T + Send + 'scope,
    T: Send + 'scope,
{
    let slot = Arc::new(Mutex::new(None));
    let entry = Box::new(ScopedTaskEntry {
        f: Some(f),
        slot: Arc::clone(&slot),
        _cancel: opts.retained_cancel(),
    });
    let arg = Box::into_raw(entry) as *mut c_void;
    let task = unsafe {
        sys::llam_spawn_ex(
            scoped_trampoline::<F, T>,
            arg,
            opts.raw_for_spawn(),
            mem::size_of_val(opts.raw_for_spawn()),
        )
    };
    if task.is_null() {
        unsafe {
            drop(Box::from_raw(arg as *mut ScopedTaskEntry<F, T>));
        }
        return Err(Error::last());
    }

    let record = Arc::new(ScopedTaskRecord::new(task));
    scope
        .records
        .lock()
        .expect("scoped LLAM task list mutex poisoned")
        .push(Arc::clone(&record));

    Ok(ScopedJoinHandle {
        record,
        slot,
        _scope: PhantomData,
    })
}

/// C entry point that the LLAM runtime invokes for every scoped task.
///
/// Reclaims the boxed `ScopedTaskEntry` created by `spawn_scoped`, runs the
/// closure with panics caught, and publishes the outcome in the shared slot.
///
/// # Safety
///
/// `arg` must be a pointer produced by `Box::into_raw` on a
/// `ScopedTaskEntry<F, T>` in `spawn_scoped`. This function must be called at
/// most once per such pointer.
unsafe extern "C" fn scoped_trampoline<F, T>(arg: *mut c_void)
where
    F: FnOnce() -> T + Send,
    T: Send,
{
    // SAFETY: upheld by the caller contract above. Ownership of the box moves
    // back into Rust here, and it is freed when `entry` goes out of scope.
    let mut entry = unsafe { Box::from_raw(arg.cast::<ScopedTaskEntry<F, T>>()) };
    let body = entry.f.take().expect("LLAM scoped task closure missing");

    // A panic must not unwind across the `extern "C"` boundary, so it is
    // caught here and stored for whoever joins the task.
    let outcome: std::result::Result<T, Box<dyn Any + Send + 'static>> =
        panic::catch_unwind(AssertUnwindSafe(body));

    let mut published = entry
        .slot
        .lock()
        .expect("scoped task result mutex poisoned");
    *published = Some(outcome);
}