use super::ComputePool;
use std::cell::RefCell;
use std::sync::Arc;
use tokio::sync::Semaphore;
// This thread's compute context: `None` until `initialize_context` runs on the
// thread, `Some(..)` afterwards. Every thread gets its own independent slot.
// The `RefCell` gives interior mutability so the free functions in this module
// can both read and overwrite the slot through `LocalKey::with`'s shared access.
// The `const { .. }` initializer lets std set the slot up without a lazy-init check.
thread_local! {
static COMPUTE_CONTEXT: RefCell<Option<ComputeContext>> = const { RefCell::new(None) };
}
/// Compute state installed on a thread by [`initialize_context`].
///
/// Cloning is cheap: both fields are `Arc`s, so a clone only bumps reference counts
/// and keeps sharing the same pool and semaphore.
#[derive(Clone)]
pub struct ComputeContext {
/// Shared pool handle; [`get_pool`] hands out a new `Arc` to it.
pub pool: Arc<ComputePool>,
/// Semaphore from which [`try_acquire_block_permit`] takes owned permits without waiting.
/// NOTE(review): presumably this caps how many threads may block in place at once — confirm.
pub block_in_place_permits: Arc<Semaphore>,
}
/// Installs (or replaces) the compute context for the calling thread.
///
/// After this call, [`with_context`], [`get_pool`], [`try_acquire_block_permit`] and
/// [`has_compute_context`] observe `pool` and `permits` on this thread only; the context
/// lives in a thread-local, so other threads are unaffected.
///
/// Calling it again on the same thread replaces the previous context. The previous
/// context is dropped only *after* the thread-local's `RefCell` borrow has been released.
/// If that drop releases the last `Arc<ComputePool>`, the pool's destructor therefore runs
/// with the slot unborrowed. It may then query this thread's context instead of hitting a
/// `RefCell` borrow panic.
///
/// # Panics
///
/// Panics if called from inside a [`with_context`] closure on the same thread, because the
/// slot is still immutably borrowed there. Also panics if called while thread-locals are
/// being destroyed.
pub fn initialize_context(pool: Arc<ComputePool>, permits: Arc<Semaphore>) {
    let context = ComputeContext {
        pool,
        block_in_place_permits: permits,
    };
    // `replace` hands back the old value instead of dropping it in place under `borrow_mut`.
    let previous = COMPUTE_CONTEXT.with(|slot| slot.replace(Some(context)));
    // Drop the old context only now, outside of any `RefCell` borrow.
    drop(previous);
}
/// Runs `f` against the calling thread's compute context, if one has been installed.
///
/// Returns `Some(f(ctx))` when [`initialize_context`] has run on this thread. Otherwise it
/// returns `None`, and `f` is never called.
///
/// The thread-local slot stays immutably borrowed while `f` runs. `f` may therefore nest
/// further read-only helpers such as [`get_pool`], but it must not call
/// [`initialize_context`], which would panic on the outstanding borrow.
pub fn with_context<F, R>(f: F) -> Option<R>
where
    F: FnOnce(&ComputeContext) -> R,
{
    COMPUTE_CONTEXT.with(|slot| {
        let installed = slot.borrow();
        installed.as_ref().map(f)
    })
}
pub fn try_acquire_block_permit() -> Result<tokio::sync::OwnedSemaphorePermit, &'static str> {
with_context(|ctx| {
ctx.block_in_place_permits
.clone()
.try_acquire_owned()
.map_err(|_| "No permits available")
})
.ok_or("No compute context on this thread")?
}
/// Returns a new handle to the calling thread's compute pool.
///
/// Returns `None` if no context has been installed on this thread.
pub fn get_pool() -> Option<Arc<ComputePool>> {
    with_context(|context| Arc::clone(&context.pool))
}
pub fn has_compute_context() -> bool {
with_context(|_| ()).is_some()
}
/// Panics unless the calling thread has a compute context installed.
///
/// This is a diagnostic guard. Per its message, compute macros fall back to inline
/// execution when no context is present, so a missing context would otherwise go
/// unnoticed.
///
/// # Panics
///
/// Panics with a "Thread-local compute context not initialized!" message when
/// [`has_compute_context`] returns `false`. Because of `#[track_caller]`, the reported
/// panic location is the caller's line, not this helper's.
#[track_caller]
pub fn assert_compute_context() {
    assert!(
        has_compute_context(),
        "Thread-local compute context not initialized! \
         Compute macros will fall back to inline execution. \
         Call Runtime::initialize_thread_local() on worker threads."
    );
}
#[cfg(test)]
mod tests {
    use super::*;

    // Nothing in these tests calls `initialize_context`, so the thread-local slot on the
    // test thread is expected to be empty.

    #[test]
    fn test_uninitialized_context() {
        assert!(get_pool().is_none());
        // Pin *which* error is reported, so a missing context cannot be confused with an
        // exhausted semaphore.
        assert_eq!(
            try_acquire_block_permit().err(),
            Some("No compute context on this thread")
        );
        assert!(!has_compute_context());
        assert!(with_context(|_| ()).is_none());
    }

    #[test]
    #[should_panic(expected = "Thread-local compute context not initialized")]
    fn test_assert_compute_context_panics() {
        assert_compute_context();
    }
}