use crate::error::{Error, Result};
use crate::sys;
use crate::CancelToken;
use std::any::Any;
use std::ffi::CStr;
use std::mem;
use std::os::raw::c_void;
use std::panic::{self, AssertUnwindSafe};
use std::ptr;
use std::sync::{Arc, Mutex};
use std::time::Duration;
/// Scheduling class of an LLAM task; each variant corresponds to one `LLAM_TASK_CLASS_*` code.
///
/// Derives `PartialEq`/`Eq`/`Hash` so classes returned by the runtime (e.g. from
/// `current_class()`) can be compared against variants and used as map keys.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum TaskClass {
    /// Raw code `LLAM_TASK_CLASS_LATENCY`.
    Latency,
    /// Raw code `LLAM_TASK_CLASS_DEFAULT`.
    Default,
    /// Raw code `LLAM_TASK_CLASS_BATCH`.
    Batch,
}
impl TaskClass {
pub(crate) fn raw(self) -> u32 {
match self {
Self::Latency => sys::LLAM_TASK_CLASS_LATENCY,
Self::Default => sys::LLAM_TASK_CLASS_DEFAULT,
Self::Batch => sys::LLAM_TASK_CLASS_BATCH,
}
}
pub fn from_raw(raw: u32) -> Option<Self> {
match raw {
sys::LLAM_TASK_CLASS_LATENCY => Some(Self::Latency),
sys::LLAM_TASK_CLASS_DEFAULT => Some(Self::Default),
sys::LLAM_TASK_CLASS_BATCH => Some(Self::Batch),
_ => None,
}
}
}
/// Stack class requested for a spawned task; each variant corresponds to one
/// `LLAM_STACK_CLASS_*` code.
///
/// Derives `PartialEq`/`Eq`/`Hash` so stack classes can be compared and used as map keys.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum StackClass {
    /// Raw code `LLAM_STACK_CLASS_DEFAULT`.
    Default,
    /// Raw code `LLAM_STACK_CLASS_LARGE`.
    Large,
    /// Raw code `LLAM_STACK_CLASS_HUGE`.
    Huge,
}
impl StackClass {
pub(crate) fn raw(self) -> u32 {
match self {
Self::Default => sys::LLAM_STACK_CLASS_DEFAULT,
Self::Large => sys::LLAM_STACK_CLASS_LARGE,
Self::Huge => sys::LLAM_STACK_CLASS_HUGE,
}
}
}
/// Builder for the raw `llam_spawn_opts_t` handed to `llam_spawn_ex` and
/// `llam_task_group_spawn_ex`.
///
/// Besides the raw C struct, it owns the [`CancelToken`] (if any) whose raw handle was copied
/// into the options, so the token lives at least as long as the options do.
#[derive(Clone)]
pub struct SpawnOptions {
    /// Raw options struct, initialised by `llam_spawn_opts_init` in [`SpawnOptions::new`].
    raw: sys::llam_spawn_opts_t,
    /// Owning copy of the token whose raw handle is stored in `raw.cancel_token`.
    cancel: Option<CancelToken>,
}

impl SpawnOptions {
    /// Creates options pre-filled with whatever `llam_spawn_opts_init` writes.
    ///
    /// # Panics
    ///
    /// Panics if `llam_spawn_opts_init` returns a non-zero status.
    #[must_use]
    pub fn new() -> Self {
        // SAFETY: assumes an all-zero `llam_spawn_opts_t` is a valid value (plain C struct);
        // it is then handed to `llam_spawn_opts_init` for initialisation.
        let mut raw = unsafe { mem::zeroed::<sys::llam_spawn_opts_t>() };
        // SAFETY: `raw` is a live, writable struct and the size passed is its exact size.
        let rc = unsafe { sys::llam_spawn_opts_init(&mut raw, mem::size_of_val(&raw)) };
        assert_eq!(rc, 0, "llam_spawn_opts_init failed");
        Self { raw, cancel: None }
    }

    /// Sets the task's scheduling class.
    #[must_use]
    pub fn class(mut self, class: TaskClass) -> Self {
        self.raw.task_class = class.raw();
        self
    }

    /// Sets the task's stack class.
    #[must_use]
    pub fn stack(mut self, stack: StackClass) -> Self {
        self.raw.stack_class = stack.raw();
        self
    }

    /// Sets the deadline to `duration` from now, as computed by `crate::time::deadline_after`.
    #[must_use]
    pub fn deadline_after(self, duration: Duration) -> Self {
        self.deadline_at(crate::time::deadline_after(duration))
    }

    /// Sets an absolute deadline, stored verbatim in the raw `deadline_ns` field.
    ///
    /// `deadline_ns` must use the same clock and units as the values produced by
    /// `crate::time::deadline_after` (the same convention `JoinHandle::join_until` uses).
    #[must_use]
    pub fn deadline_at(mut self, deadline_ns: u64) -> Self {
        self.raw.deadline_ns = deadline_ns;
        self
    }

    /// Attaches a cancellation token to the task.
    ///
    /// The token's raw handle is copied into the options, and the token itself is kept so that
    /// the raw handle stays backed by a live token. A later call replaces the earlier token.
    #[must_use]
    pub fn cancel(mut self, token: CancelToken) -> Self {
        self.raw.cancel_token = token.raw();
        self.cancel = Some(token);
        self
    }

    /// Sets or clears `LLAM_SPAWN_F_PINNED`.
    #[must_use]
    pub fn pinned(mut self, enabled: bool) -> Self {
        self.set_flag(sys::LLAM_SPAWN_F_PINNED, enabled);
        self
    }

    /// Sets or clears `LLAM_SPAWN_F_LATENCY_CRITICAL`.
    #[must_use]
    pub fn latency_critical(mut self, enabled: bool) -> Self {
        self.set_flag(sys::LLAM_SPAWN_F_LATENCY_CRITICAL, enabled);
        self
    }

    /// Sets or clears `LLAM_SPAWN_F_NO_PREEMPT`.
    #[must_use]
    pub fn no_preempt(mut self, enabled: bool) -> Self {
        self.set_flag(sys::LLAM_SPAWN_F_NO_PREEMPT, enabled);
        self
    }

    /// Sets or clears `LLAM_SPAWN_F_SYS_TASK`.
    #[must_use]
    pub fn system_task(mut self, enabled: bool) -> Self {
        self.set_flag(sys::LLAM_SPAWN_F_SYS_TASK, enabled);
        self
    }

    /// Borrows the raw C options struct exactly as it will be passed to the runtime.
    pub fn raw_options(&self) -> &sys::llam_spawn_opts_t {
        &self.raw
    }

    /// Sets (`enabled == true`) or clears the bits of `flag` in the raw flags word.
    fn set_flag(&mut self, flag: u32, enabled: bool) {
        if enabled {
            self.raw.flags |= flag;
        } else {
            self.raw.flags &= !flag;
        }
    }
}

impl Default for SpawnOptions {
    /// Same as [`SpawnOptions::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Reasons why joining a task could not return the task's value.
#[derive(Debug)]
pub enum JoinError {
    /// The runtime reported a failure; carries the error captured with `Error::last()`.
    Runtime(Error),
    /// The task's closure panicked; carries the payload caught by `catch_unwind`.
    Panic(Box<dyn Any + Send + 'static>),
    /// No result was available: the task finished without publishing one, or the handle had
    /// already been joined.
    MissingResult,
}

impl JoinError {
    /// Returns the panic message if this is a [`JoinError::Panic`] with a string payload.
    ///
    /// `panic!` with a literal message produces a `&'static str` payload and with a formatted
    /// message a `String`. Any other payload type, and the other variants, yield `None`.
    pub fn panic_message(&self) -> Option<&str> {
        match self {
            Self::Panic(payload) => {
                if let Some(message) = payload.downcast_ref::<&'static str>() {
                    Some(*message)
                } else {
                    payload.downcast_ref::<String>().map(String::as_str)
                }
            }
            Self::Runtime(_) | Self::MissingResult => None,
        }
    }
}

impl std::fmt::Display for JoinError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Runtime(error) => write!(f, "{error}"),
            // Keep the historical "task panicked" prefix and append the message when the
            // payload carries one, so the panic reason is not lost in logs.
            Self::Panic(_) => match self.panic_message() {
                Some(message) => write!(f, "task panicked: {message}"),
                None => write!(f, "task panicked"),
            },
            Self::MissingResult => write!(f, "task completed without publishing a result"),
        }
    }
}

impl std::error::Error for JoinError {}

impl From<Error> for JoinError {
    /// Wraps a runtime error as [`JoinError::Runtime`].
    fn from(value: Error) -> Self {
        Self::Runtime(value)
    }
}
type TaskOutput<T> = std::thread::Result<T>;
struct TaskEntry<F, T> {
f: Option<F>,
slot: Arc<Mutex<Option<TaskOutput<T>>>>,
_cancel: Option<CancelToken>,
}
pub struct JoinHandle<T> {
task: *mut sys::llam_task_t,
slot: Arc<Mutex<Option<TaskOutput<T>>>>,
}
unsafe impl<T: Send> Send for JoinHandle<T> {}
impl<T> JoinHandle<T> {
pub fn join(mut self) -> std::result::Result<T, JoinError> {
let task = self.task;
self.task = ptr::null_mut();
let rc = unsafe { sys::llam_join(task) };
if rc != 0 {
return Err(JoinError::Runtime(Error::last()));
}
match self.slot.lock().expect("task result mutex poisoned").take() {
Some(Ok(value)) => Ok(value),
Some(Err(panic)) => Err(JoinError::Panic(panic)),
None => Err(JoinError::MissingResult),
}
}
pub fn join_until(&mut self, deadline_ns: u64) -> std::result::Result<Option<T>, JoinError> {
if self.task.is_null() {
return Err(JoinError::MissingResult);
}
let task = self.task;
let rc = unsafe { sys::llam_join_until(task, deadline_ns) };
if rc != 0 {
let error = Error::last();
if error.is_timed_out() {
return Ok(None);
}
return Err(JoinError::Runtime(error));
}
self.task = ptr::null_mut();
match self.slot.lock().expect("task result mutex poisoned").take() {
Some(Ok(value)) => Ok(Some(value)),
Some(Err(panic)) => Err(JoinError::Panic(panic)),
None => Err(JoinError::MissingResult),
}
}
pub fn join_timeout(&mut self, timeout: Duration) -> std::result::Result<Option<T>, JoinError> {
self.join_until(crate::time::deadline_after(timeout))
}
pub fn detach(mut self) -> Result<()> {
let task = self.task;
self.task = ptr::null_mut();
let rc = unsafe { sys::llam_detach(task) };
if rc == 0 {
Ok(())
} else {
Err(Error::last())
}
}
pub fn raw(&self) -> *mut sys::llam_task_t {
self.task
}
pub fn id(&self) -> Option<u64> {
if self.task.is_null() {
None
} else {
Some(unsafe { sys::llam_task_id(self.task) })
}
}
pub fn flags(&self) -> u32 {
if self.task.is_null() {
0
} else {
unsafe { sys::llam_task_flags(self.task) }
}
}
pub fn class(&self) -> Option<TaskClass> {
if self.task.is_null() {
None
} else {
TaskClass::from_raw(unsafe { sys::llam_task_class(self.task) })
}
}
pub fn state_name(&self) -> Option<String> {
if self.task.is_null() {
return None;
}
let ptr = unsafe { sys::llam_task_state_name(self.task) };
if ptr.is_null() {
None
} else {
Some(
unsafe { CStr::from_ptr(ptr) }
.to_string_lossy()
.into_owned(),
)
}
}
}
impl<T> Drop for JoinHandle<T> {
fn drop(&mut self) {
if !self.task.is_null() {
let _ = unsafe { sys::llam_detach(self.task) };
self.task = ptr::null_mut();
}
}
}
/// Spawns `f` as a new LLAM task with default [`SpawnOptions`].
///
/// # Panics
///
/// Panics with `"llam task spawn failed"` if the runtime refuses to create the task. Use
/// [`try_spawn`] to handle that case instead.
pub fn spawn<F, T>(f: F) -> JoinHandle<T>
where
    T: Send + 'static,
    F: FnOnce() -> T + Send + 'static,
{
    let outcome = try_spawn(f);
    outcome.expect("llam task spawn failed")
}
/// Fallible variant of [`spawn`]: spawns `f` with default [`SpawnOptions`].
///
/// # Errors
///
/// Returns the runtime's last error if the task could not be created.
pub fn try_spawn<F, T>(f: F) -> Result<JoinHandle<T>>
where
    T: Send + 'static,
    F: FnOnce() -> T + Send + 'static,
{
    let defaults = SpawnOptions::default();
    try_spawn_with(defaults, f)
}
/// Spawns `f` as a new LLAM task configured by `opts`.
///
/// # Panics
///
/// Panics with `"llam task spawn failed"` if the runtime refuses to create the task. Use
/// [`try_spawn_with`] to handle that case instead.
pub fn spawn_with<F, T>(opts: SpawnOptions, f: F) -> JoinHandle<T>
where
    T: Send + 'static,
    F: FnOnce() -> T + Send + 'static,
{
    let outcome = try_spawn_with(opts, f);
    outcome.expect("llam task spawn failed")
}
pub fn try_spawn_with<F, T>(opts: SpawnOptions, f: F) -> Result<JoinHandle<T>>
where
F: FnOnce() -> T + Send + 'static,
T: Send + 'static,
{
let slot = Arc::new(Mutex::new(None));
let entry = Box::new(TaskEntry {
f: Some(f),
slot: Arc::clone(&slot),
_cancel: opts.cancel.clone(),
});
let arg = Box::into_raw(entry) as *mut c_void;
let task = unsafe {
sys::llam_spawn_ex(
trampoline::<F, T>,
arg,
&opts.raw,
mem::size_of_val(&opts.raw),
)
};
if task.is_null() {
unsafe {
drop(Box::from_raw(arg as *mut TaskEntry<F, T>));
}
return Err(Error::last());
}
Ok(JoinHandle { task, slot })
}
/// C-ABI entry point the runtime invokes for tasks created by `try_spawn_with`.
///
/// Reclaims the boxed `TaskEntry`, runs the closure under `catch_unwind`, and publishes the
/// outcome (value or panic payload) into the slot shared with the task's `JoinHandle`.
///
/// Nothing here may panic. Unwinding out of an `extern "C"` function is undefined behaviour on
/// older toolchains and aborts the process on newer ones.
///
/// # Safety
///
/// `arg` must be the pointer produced by `Box::into_raw` for a `TaskEntry<F, T>` in
/// `try_spawn_with`, and this function must be called at most once for it.
unsafe extern "C" fn trampoline<F, T>(arg: *mut c_void)
where
    F: FnOnce() -> T + Send + 'static,
    T: Send + 'static,
{
    // SAFETY: guaranteed by the caller contract above.
    let mut entry = unsafe { Box::from_raw(arg.cast::<TaskEntry<F, T>>()) };
    // A freshly spawned entry always holds its closure. If it were missing, leave the slot empty
    // so joining reports `JoinError::MissingResult` instead of panicking across the FFI boundary.
    if let Some(f) = entry.f.take() {
        let result = panic::catch_unwind(AssertUnwindSafe(f));
        // The lock is only held for plain `Option` stores/takes, so a poisoned mutex still
        // guards consistent data; recover it rather than panicking here.
        let mut slot = entry
            .slot
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        *slot = Some(result);
    }
}
/// Calls `llam_yield`, presumably letting other runnable LLAM tasks make progress.
pub fn yield_now() {
    // SAFETY: argument-less FFI call with no Rust-side invariants to uphold.
    unsafe {
        sys::llam_yield();
    }
}
/// Calls `llam_task_safepoint` for the current task.
pub fn safepoint() {
    // SAFETY: argument-less FFI call with no Rust-side invariants to uphold.
    unsafe {
        sys::llam_task_safepoint();
    }
}
pub fn set_class(class: TaskClass) -> Result<()> {
let rc = unsafe { sys::llam_task_set_class(class.raw()) };
if rc == 0 {
Ok(())
} else {
Err(Error::last())
}
}
/// Returns the scheduling class of the task reported by `llam_current_task`.
///
/// `None` when there is no current task (null pointer), or when the runtime reports a class
/// code that [`TaskClass::from_raw`] does not recognise.
pub fn current_class() -> Option<TaskClass> {
    // SAFETY: argument-less FFI query.
    let current = unsafe { sys::llam_current_task() };
    if current.is_null() {
        return None;
    }
    // SAFETY: `current` is a non-null task pointer just returned by the runtime.
    let raw_class = unsafe { sys::llam_task_class(current) };
    TaskClass::from_raw(raw_class)
}
/// Returns the flag bits of the task reported by `llam_current_task`, or `0` if there is none.
pub fn current_flags() -> u32 {
    // SAFETY: argument-less FFI query.
    let current = unsafe { sys::llam_current_task() };
    if current.is_null() {
        return 0;
    }
    // SAFETY: `current` is a non-null task pointer just returned by the runtime.
    unsafe { sys::llam_task_flags(current) }
}
/// Returns the id of the task reported by `llam_current_task`, or `None` if there is none.
pub fn current_id() -> Option<u64> {
    // SAFETY: argument-less FFI query.
    let current = unsafe { sys::llam_current_task() };
    // SAFETY (closure): only evaluated when `current` is non-null.
    (!current.is_null()).then(|| unsafe { sys::llam_task_id(current) })
}
/// Returns the runtime's state name for the task reported by `llam_current_task`.
///
/// `None` if there is no current task or the runtime returns a null name pointer. Invalid UTF-8
/// is replaced lossily.
pub fn current_state_name() -> Option<String> {
    // SAFETY: argument-less FFI query.
    let current = unsafe { sys::llam_current_task() };
    if current.is_null() {
        return None;
    }
    // SAFETY: `current` is a non-null task pointer just returned by the runtime.
    let name_ptr = unsafe { sys::llam_task_state_name(current) };
    if name_ptr.is_null() {
        return None;
    }
    // SAFETY: non-null pointer from the runtime, assumed NUL-terminated and valid for this call
    // (NOTE(review): confirm lifetime with the C runtime).
    let name = unsafe { CStr::from_ptr(name_ptr) };
    Some(String::from(name.to_string_lossy()))
}
struct GroupEntry<F> {
f: Option<F>,
}
pub struct TaskGroup {
raw: *mut sys::llam_task_group_t,
}
unsafe impl Send for TaskGroup {}
impl TaskGroup {
pub fn new() -> Result<Self> {
let raw = unsafe { sys::llam_task_group_create() };
if raw.is_null() {
Err(Error::last())
} else {
Ok(Self { raw })
}
}
pub fn spawn<F>(&mut self, f: F) -> Result<()>
where
F: FnOnce() + Send + 'static,
{
self.spawn_with(SpawnOptions::new(), f)
}
pub fn spawn_with<F>(&mut self, opts: SpawnOptions, f: F) -> Result<()>
where
F: FnOnce() + Send + 'static,
{
let entry = Box::new(GroupEntry { f: Some(f) });
let arg = Box::into_raw(entry) as *mut c_void;
let task = unsafe {
sys::llam_task_group_spawn_ex(
self.raw,
group_trampoline::<F>,
arg,
&opts.raw,
mem::size_of_val(&opts.raw),
)
};
if task.is_null() {
unsafe {
drop(Box::from_raw(arg as *mut GroupEntry<F>));
}
Err(Error::last())
} else {
Ok(())
}
}
pub fn cancel(&self) -> Result<()> {
let rc = unsafe { sys::llam_task_group_cancel(self.raw) };
if rc == 0 {
Ok(())
} else {
Err(Error::last())
}
}
pub fn join(&mut self) -> Result<()> {
let rc = unsafe { sys::llam_task_group_join(self.raw) };
if rc == 0 {
Ok(())
} else {
Err(Error::last())
}
}
pub fn join_timeout(&mut self, timeout: Duration) -> Result<()> {
let rc = unsafe {
sys::llam_task_group_join_until(self.raw, crate::time::deadline_after(timeout))
};
if rc == 0 {
Ok(())
} else {
Err(Error::last())
}
}
}
impl Drop for TaskGroup {
fn drop(&mut self) {
if !self.raw.is_null() {
unsafe {
let _ = sys::llam_task_group_destroy(self.raw);
}
self.raw = ptr::null_mut();
}
}
}
pub struct TaskBatch<T> {
handles: Vec<JoinHandle<T>>,
}
impl<T> TaskBatch<T>
where
T: Send + 'static,
{
pub fn new() -> Self {
Self {
handles: Vec::new(),
}
}
pub fn with_capacity(capacity: usize) -> Self {
Self {
handles: Vec::with_capacity(capacity),
}
}
pub fn spawn<F>(&mut self, f: F) -> Result<()>
where
F: FnOnce() -> T + Send + 'static,
{
self.spawn_with(SpawnOptions::new(), f)
}
pub fn spawn_with<F>(&mut self, opts: SpawnOptions, f: F) -> Result<()>
where
F: FnOnce() -> T + Send + 'static,
{
self.handles.push(try_spawn_with(opts, f)?);
Ok(())
}
pub fn push(&mut self, handle: JoinHandle<T>) {
self.handles.push(handle);
}
pub fn len(&self) -> usize {
self.handles.len()
}
pub fn is_empty(&self) -> bool {
self.handles.is_empty()
}
pub fn join(mut self) -> std::result::Result<Vec<T>, JoinError> {
let mut results = Vec::with_capacity(self.handles.len());
for handle in self.handles.drain(..) {
results.push(handle.join()?);
}
Ok(results)
}
}
impl<T> Default for TaskBatch<T>
where
T: Send + 'static,
{
fn default() -> Self {
Self::new()
}
}
/// C-ABI entry point for tasks spawned through `TaskGroup::spawn_with`.
///
/// Reclaims the boxed `GroupEntry` and runs its closure (if present) under `catch_unwind`.
/// Any panic payload is discarded, so a panic in the closure does not unwind into the C runtime.
///
/// # Safety
///
/// `arg` must be the pointer produced by `Box::into_raw` for a `GroupEntry<F>`, and this
/// function must be called at most once for it.
unsafe extern "C" fn group_trampoline<F>(arg: *mut c_void)
where
    F: FnOnce() + Send + 'static,
{
    // SAFETY: guaranteed by the caller contract above.
    let mut owned = unsafe { Box::from_raw(arg.cast::<GroupEntry<F>>()) };
    if let Some(body) = owned.f.take() {
        // Group tasks have no per-task result channel, so the outcome is intentionally dropped.
        let _ = panic::catch_unwind(AssertUnwindSafe(body));
    }
}