llam 0.1.2

Safe, Go-style Rust bindings for the LLAM runtime
use crate::error::{Error, Result};
use crate::sys;
use std::marker::PhantomData;
use std::os::raw::c_void;
use std::ptr;
use std::sync::Arc;

/// Shared owner of a raw LLAM task-local key.
///
/// Every clone of a `TaskLocalKey` points at the same `Inner` through an `Arc`;
/// the runtime key is deleted in `Inner`'s `Drop` impl once the last clone goes away.
struct Inner {
    // Raw runtime key; `LLAM_TASK_LOCAL_INVALID_KEY` means "nothing to delete".
    key: sys::llam_task_local_key_t,
}

impl Drop for Inner {
    /// Deletes the runtime key (if one was created), ignoring any runtime error.
    ///
    /// This does not drop `T` values still stored under the key in other tasks;
    /// the runtime has no value destructor hook.
    fn drop(&mut self) {
        // Disarm the handle first so the stored key can never be deleted twice.
        let key = std::mem::replace(&mut self.key, sys::LLAM_TASK_LOCAL_INVALID_KEY);
        if key == sys::LLAM_TASK_LOCAL_INVALID_KEY {
            return;
        }
        // SAFETY: `key` came from `llam_task_local_key_create` and this `Drop`
        // runs once for the single `Inner` that owns it.
        unsafe {
            let _ = sys::llam_task_local_key_delete(key);
        }
    }
}

/// Typed LLAM task-local key.
///
/// Values are local to the current managed LLAM task. The C runtime stores raw
/// pointers and has no value destructor hook, so callers should `take()` or
/// `clear()` task-local values before task exit when the value owns resources.
///
/// Cloning a `TaskLocalKey` is cheap. The clone is another handle to the *same*
/// runtime key, and the runtime key is deleted once the last clone is dropped.
pub struct TaskLocalKey<T> {
    inner: Arc<Inner>,
    // `fn() -> T` ties the key to `T` (covariantly) without claiming to own a `T`.
    _marker: PhantomData<fn() -> T>,
}

// Written by hand rather than derived: `#[derive(Clone)]` would add a `T: Clone`
// bound, yet cloning a key only bumps the `Arc` refcount and never touches a `T`.
impl<T> Clone for TaskLocalKey<T> {
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
            _marker: PhantomData,
        }
    }
}

/// Scope guard returned by `TaskLocalKey::bind`.
///
/// While the guard is alive, the bound value stays in the current task's slot.
/// Dropping the guard clears the slot and drops the stored value, ignoring errors.
/// Call `TaskLocalGuard::clear` instead to observe a failure.
#[must_use = "dropping the guard immediately clears the task-local value it just bound"]
pub struct TaskLocalGuard<'a, T> {
    key: &'a TaskLocalKey<T>,
    // Set to `false` once the slot has been cleared explicitly, so `Drop` skips it.
    active: bool,
}

// SAFETY: a key handle stores no `T` itself. Each managed task reads and writes only
// its own slot, so `T` values are never shared between tasks through the key.
// Requiring `T: Send` keeps types that must not cross threads out of keys that can be
// moved or shared across threads.
// NOTE(review): assumes the raw key and `llam_task_local_*` calls are usable from any
// OS thread — confirm against the LLAM runtime documentation.
unsafe impl<T> Send for TaskLocalKey<T> where T: Send {}
unsafe impl<T> Sync for TaskLocalKey<T> where T: Send {}

impl<T> TaskLocalKey<T> {
    pub fn new() -> Result<Self> {
        let mut key = sys::LLAM_TASK_LOCAL_INVALID_KEY;
        let rc = unsafe { sys::llam_task_local_key_create(&mut key) };
        if rc != 0 {
            return Err(Error::last());
        }
        Ok(Self {
            inner: Arc::new(Inner { key }),
            _marker: PhantomData,
        })
    }

    pub fn raw_key(&self) -> sys::llam_task_local_key_t {
        self.inner.key
    }

    pub fn set(&self, value: T) -> Result<()> {
        self.ensure_managed_task()?;
        let old = unsafe { sys::llam_task_local_get(self.inner.key) };
        let new_ptr = Box::into_raw(Box::new(value)) as *mut c_void;
        let rc = unsafe { sys::llam_task_local_set(self.inner.key, new_ptr) };
        if rc != 0 {
            unsafe {
                drop(Box::from_raw(new_ptr as *mut T));
            }
            return Err(Error::last());
        }
        if !old.is_null() {
            unsafe {
                drop(Box::from_raw(old as *mut T));
            }
        }
        Ok(())
    }

    pub fn bind(&self, value: T) -> Result<TaskLocalGuard<'_, T>> {
        self.set(value)?;
        Ok(TaskLocalGuard {
            key: self,
            active: true,
        })
    }

    pub fn with<R>(&self, value: T, f: impl FnOnce() -> R) -> Result<R> {
        let guard = self.bind(value)?;
        let result = f();
        drop(guard);
        Ok(result)
    }

    pub fn get_cloned(&self) -> Result<Option<T>>
    where
        T: Clone,
    {
        self.ensure_managed_task()?;
        let ptr = unsafe { sys::llam_task_local_get(self.inner.key) };
        if ptr.is_null() {
            Ok(None)
        } else {
            Ok(Some(unsafe { (*(ptr as *const T)).clone() }))
        }
    }

    pub fn take(&self) -> Result<Option<T>> {
        self.ensure_managed_task()?;
        let ptr = unsafe { sys::llam_task_local_get(self.inner.key) };
        let rc = unsafe { sys::llam_task_local_set(self.inner.key, ptr::null_mut()) };
        if rc != 0 {
            return Err(Error::last());
        }
        if ptr.is_null() {
            Ok(None)
        } else {
            Ok(Some(unsafe { *Box::from_raw(ptr as *mut T) }))
        }
    }

    pub fn clear(&self) -> Result<()> {
        let _ = self.take()?;
        Ok(())
    }

    fn ensure_managed_task(&self) -> Result<()> {
        let task = unsafe { sys::llam_current_task() };
        if task.is_null() {
            Err(Error::from_errno(libc::ENOTSUP))
        } else {
            Ok(())
        }
    }
}

impl<T> TaskLocalGuard<'_, T> {
    /// Consumes the guard, clearing the bound value now and reporting the outcome.
    ///
    /// Unlike simply dropping the guard, which discards failures, this returns the
    /// result of clearing the key.
    pub fn clear(mut self) -> Result<()> {
        let key = self.key;
        // Disarm so the guard's `Drop` does not attempt a second clear.
        self.active = false;
        key.clear()
    }
}

impl<T> Drop for TaskLocalGuard<'_, T> {
    /// Clears the bound value unless the guard was already disarmed.
    ///
    /// Any error from clearing is discarded, since `Drop` cannot report it.
    fn drop(&mut self) {
        let was_armed = std::mem::replace(&mut self.active, false);
        if was_armed {
            let _ = self.key.clear();
        }
    }
}