use crate::error::{Error, Result};
use crate::sys;
use std::marker::PhantomData;
use std::os::raw::c_void;
use std::ptr;
use std::sync::Arc;
/// Owner of the raw `llam` task-local key shared by every [`TaskLocalKey`] clone.
///
/// It lives behind an `Arc`, so the key is deleted once the last handle is gone.
struct Inner {
    /// Raw key handle produced by `llam_task_local_key_create`.
    key: sys::llam_task_local_key_t,
}

impl Drop for Inner {
    fn drop(&mut self) {
        // An invalid key was never created, so there is nothing to delete.
        if self.key == sys::LLAM_TASK_LOCAL_INVALID_KEY {
            return;
        }
        // SAFETY: the key came from a successful `llam_task_local_key_create`
        // (see `TaskLocalKey::new`). Drop runs once, so it is deleted at most once.
        // The result is ignored on purpose: a destructor cannot report failure.
        // NOTE(review): nothing here frees values still stored under this key
        // in other tasks; confirm whether llam runs destructors on delete.
        let _ = unsafe { sys::llam_task_local_key_delete(self.key) };
        self.key = sys::LLAM_TASK_LOCAL_INVALID_KEY;
    }
}
/// Typed handle to an `llam` task-local storage key holding values of type `T`.
///
/// Values are stored as `Box<T>` pointers through `llam_task_local_set` and
/// read back with `llam_task_local_get` (presumably one slot per managed task).
/// Handles are cheap to clone. Clones share the same key through an [`Arc`],
/// and the key is deleted when the last handle is dropped.
pub struct TaskLocalKey<T> {
    inner: Arc<Inner>,
    // `fn() -> T` ties the handle to `T` without claiming to own a `T`.
    _marker: PhantomData<fn() -> T>,
}

// Implemented by hand: `#[derive(Clone)]` would require `T: Clone`, but cloning
// a handle only bumps the `Arc` refcount and never touches a `T` value.
impl<T> Clone for TaskLocalKey<T> {
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
            _marker: PhantomData,
        }
    }
}
/// RAII guard returned by [`TaskLocalKey::bind`].
///
/// When dropped, it clears the key's value through [`TaskLocalKey::clear`],
/// unless [`TaskLocalGuard::clear`] was already called.
#[must_use = "dropping the guard immediately clears the task-local value it bound"]
pub struct TaskLocalGuard<'a, T> {
    /// Key whose value is cleared when the guard goes away.
    key: &'a TaskLocalKey<T>,
    /// Set to `false` once the value was cleared explicitly, so `Drop` does nothing.
    active: bool,
}
// SAFETY: a handle is only a shared key id (`Arc<Inner>`). `T` values are
// stored as raw pointers through `llam_task_local_set` and reached only via the
// caller's own slot, so requiring `T: Send` covers a value being touched from
// whichever thread its task runs on.
// NOTE(review): this assumes the llam runtime never lets two tasks observe the
// same slot concurrently — confirm against the runtime's semantics.
unsafe impl<T> Sync for TaskLocalKey<T> where T: Send {}
unsafe impl<T> Send for TaskLocalKey<T> where T: Send {}
impl<T> TaskLocalKey<T> {
pub fn new() -> Result<Self> {
let mut key = sys::LLAM_TASK_LOCAL_INVALID_KEY;
let rc = unsafe { sys::llam_task_local_key_create(&mut key) };
if rc != 0 {
return Err(Error::last());
}
Ok(Self {
inner: Arc::new(Inner { key }),
_marker: PhantomData,
})
}
pub fn raw_key(&self) -> sys::llam_task_local_key_t {
self.inner.key
}
pub fn set(&self, value: T) -> Result<()> {
self.ensure_managed_task()?;
let old = unsafe { sys::llam_task_local_get(self.inner.key) };
let new_ptr = Box::into_raw(Box::new(value)) as *mut c_void;
let rc = unsafe { sys::llam_task_local_set(self.inner.key, new_ptr) };
if rc != 0 {
unsafe {
drop(Box::from_raw(new_ptr as *mut T));
}
return Err(Error::last());
}
if !old.is_null() {
unsafe {
drop(Box::from_raw(old as *mut T));
}
}
Ok(())
}
pub fn bind(&self, value: T) -> Result<TaskLocalGuard<'_, T>> {
self.set(value)?;
Ok(TaskLocalGuard {
key: self,
active: true,
})
}
pub fn with<R>(&self, value: T, f: impl FnOnce() -> R) -> Result<R> {
let guard = self.bind(value)?;
let result = f();
drop(guard);
Ok(result)
}
pub fn get_cloned(&self) -> Result<Option<T>>
where
T: Clone,
{
self.ensure_managed_task()?;
let ptr = unsafe { sys::llam_task_local_get(self.inner.key) };
if ptr.is_null() {
Ok(None)
} else {
Ok(Some(unsafe { (*(ptr as *const T)).clone() }))
}
}
pub fn take(&self) -> Result<Option<T>> {
self.ensure_managed_task()?;
let ptr = unsafe { sys::llam_task_local_get(self.inner.key) };
let rc = unsafe { sys::llam_task_local_set(self.inner.key, ptr::null_mut()) };
if rc != 0 {
return Err(Error::last());
}
if ptr.is_null() {
Ok(None)
} else {
Ok(Some(unsafe { *Box::from_raw(ptr as *mut T) }))
}
}
pub fn clear(&self) -> Result<()> {
let _ = self.take()?;
Ok(())
}
fn ensure_managed_task(&self) -> Result<()> {
let task = unsafe { sys::llam_current_task() };
if task.is_null() {
Err(Error::from_errno(libc::ENOTSUP))
} else {
Ok(())
}
}
}
impl<T> TaskLocalGuard<'_, T> {
    /// Clears the bound value now and returns the result of
    /// [`TaskLocalKey::clear`], which the `Drop` impl would otherwise discard.
    pub fn clear(mut self) -> Result<()> {
        // Disarm first, so dropping `self` at the end of this call does not
        // clear the slot a second time.
        self.active = false;
        let key = self.key;
        key.clear()
    }
}
impl<T> Drop for TaskLocalGuard<'_, T> {
    /// Clears the bound value unless [`TaskLocalGuard::clear`] already did so.
    fn drop(&mut self) {
        if !self.active {
            return;
        }
        // Errors are ignored: a destructor has no way to report them.
        let _ = self.key.clear();
        self.active = false;
    }
}