use crate::shape_graph::{ShapeId, ShapeTransitionTable};
use std::cell::RefCell;
use std::future::Future;
use std::sync::{Arc, LazyLock, Mutex};
/// Lock-protected shape state that can be installed as the "current" table.
///
/// Bundles a [`ShapeTransitionTable`] with a log of `(ShapeId, ShapeId)` pairs.
/// The log presumably records shape transitions; nothing in this file writes to it.
/// Each piece sits behind its own [`Mutex`], so callers lock only what they touch.
pub struct ShapeTableHandle {
    table: Mutex<ShapeTransitionTable>,
    transition_log: Mutex<Vec<(ShapeId, ShapeId)>>,
}

impl ShapeTableHandle {
    /// Creates an empty handle already wrapped in an [`Arc`].
    ///
    /// The result is ready to be passed to `SyncShapeTableScope::enter` or
    /// `with_async_shape_table_scope`.
    pub fn new() -> Arc<Self> {
        let handle = Self::default();
        Arc::new(handle)
    }

    /// Returns the mutex guarding the transition table.
    #[inline]
    pub fn table(&self) -> &Mutex<ShapeTransitionTable> {
        &self.table
    }

    /// Returns the mutex guarding the recorded `(ShapeId, ShapeId)` pairs.
    #[inline]
    pub fn transition_log(&self) -> &Mutex<Vec<(ShapeId, ShapeId)>> {
        &self.transition_log
    }
}

impl Default for ShapeTableHandle {
    /// Builds an unshared handle from `ShapeTransitionTable::new()` and an empty log.
    fn default() -> Self {
        let table = Mutex::new(ShapeTransitionTable::new());
        let transition_log = Mutex::new(Vec::new());
        Self {
            table,
            transition_log,
        }
    }
}
// Storage slots consulted by `try_current_shape_table`, in this priority order:
// task-local, then thread-local, then the process-wide default.
tokio::task_local! {
/// Handle installed by `with_async_shape_table_scope` for the duration of one future.
/// Because it is task-local, it travels with the task across worker threads.
static CURRENT_SHAPE_TABLE: Arc<ShapeTableHandle>;
}
thread_local! {
/// Handle installed by `SyncShapeTableScope::enter` on this OS thread.
/// `None` when no sync scope is active. The `const` initialiser avoids lazy-init overhead.
static SYNC_CURRENT_SHAPE_TABLE: RefCell<Option<Arc<ShapeTableHandle>>> =
const { RefCell::new(None) };
}
/// Process-wide fallback table, created on first access.
/// Used when neither a task-local nor a thread-local scope is active.
static DEFAULT_SHAPE_TABLE: LazyLock<Arc<ShapeTableHandle>> =
LazyLock::new(ShapeTableHandle::new);
/// RAII guard that installs a [`ShapeTableHandle`] as this thread's current table.
///
/// While the guard is alive, `current_shape_table()` on this thread returns the
/// installed handle. The exception is an active task-local scope from
/// `with_async_shape_table_scope`, which takes precedence. Dropping the guard
/// restores whatever handle was installed before [`SyncShapeTableScope::enter`].
///
/// Guards must be dropped in reverse order of creation (LIFO). Dropping an outer
/// guard before an inner one leaves a stale handle installed.
///
/// The guard is deliberately neither `Send` nor `Sync`. It restores a
/// *thread-local* slot. Dropping it on a different thread from the one that
/// created it would have two bad effects:
/// - it would overwrite that other thread's slot;
/// - it would leave the creating thread pointing at the wrong table.
///
/// A typical way this happens is holding the guard across an `.await` on a
/// multi-thread runtime.
#[must_use = "the scope only lives as long as the guard is held"]
pub struct SyncShapeTableScope {
    /// Handle that was current before `enter`; put back on drop.
    prev: Option<Arc<ShapeTableHandle>>,
    /// Raw-pointer marker that opts the guard out of `Send`/`Sync`.
    _not_send: std::marker::PhantomData<*const ()>,
}

impl SyncShapeTableScope {
    /// Installs `handle` as this thread's current shape table.
    ///
    /// Returns a guard that restores the previous one when dropped.
    ///
    /// # Panics
    ///
    /// Panics if this thread's thread-local storage for the slot is being, or has
    /// been, destroyed (see `std::thread::LocalKey::with`).
    pub fn enter(handle: Arc<ShapeTableHandle>) -> Self {
        let prev = SYNC_CURRENT_SHAPE_TABLE.with(|cell| cell.borrow_mut().replace(handle));
        Self {
            prev,
            _not_send: std::marker::PhantomData,
        }
    }
}

impl Drop for SyncShapeTableScope {
    fn drop(&mut self) {
        let prev = self.prev.take();
        // Use `try_with` rather than `with`. If the guard is dropped during
        // thread-local teardown (for example, because it was stored in another
        // thread-local), the slot may already be gone and `with` would panic
        // inside `drop`. In that case there is nothing left to restore into, so
        // skip it; `prev` is simply dropped.
        let _ = SYNC_CURRENT_SHAPE_TABLE.try_with(|cell| {
            *cell.borrow_mut() = prev;
        });
    }
}
/// Returns the shape table that is current for the calling context.
///
/// Resolution order:
/// 1. the task-local handle installed by `with_async_shape_table_scope`;
/// 2. the thread-local handle installed by `SyncShapeTableScope::enter`;
/// 3. the process-wide default table, created lazily on first use.
///
/// Because of step 3 this always returns `Some`. The `Option` return type is kept
/// so existing callers continue to compile.
pub fn try_current_shape_table() -> Option<Arc<ShapeTableHandle>> {
    if let Ok(handle) = CURRENT_SHAPE_TABLE.try_with(Arc::clone) {
        return Some(handle);
    }
    // Use `try_with` rather than `with`. During thread-local teardown the slot may
    // already be destroyed, and `with` would panic. Treat that case as "no sync
    // scope" and fall through to the default.
    let sync_scoped = SYNC_CURRENT_SHAPE_TABLE
        .try_with(|cell| cell.borrow().clone())
        .ok()
        .flatten();
    Some(sync_scoped.unwrap_or_else(|| Arc::clone(&*DEFAULT_SHAPE_TABLE)))
}
pub fn current_shape_table() -> Arc<ShapeTableHandle> {
try_current_shape_table().expect(
"no current ShapeTransitionTable is active; wrap execution in \
shape_graph_current::with_async_shape_table_scope or hold a \
SyncShapeTableScope",
)
}
/// Runs `fut` with `handle` installed as the task-local current shape table.
///
/// Inside `fut`, `current_shape_table()` resolves to `handle`. Task-local scopes
/// take precedence over any `SyncShapeTableScope` on the polling thread. The
/// binding is carried by the future itself, so it survives the task moving
/// between worker threads. A nested call shadows the outer handle only for the
/// duration of the inner future.
pub async fn with_async_shape_table_scope<R>(
    handle: Arc<ShapeTableHandle>,
    fut: impl Future<Output = R>,
) -> R {
    let scoped = CURRENT_SHAPE_TABLE.scope(handle, fut);
    scoped.await
}
#[cfg(test)]
mod tests {
    use super::*;

    // Sync guards nest correctly. After both are dropped, the lookup returns to
    // whatever was current beforehand (the process default here).
    #[test]
    fn sync_scope_push_pop_restores_previous() {
        let h1 = ShapeTableHandle::new();
        let h2 = ShapeTableHandle::new();
        let baseline = try_current_shape_table().expect("default fallback is always present");
        let outer = SyncShapeTableScope::enter(h1.clone());
        assert!(Arc::ptr_eq(&current_shape_table(), &h1));
        {
            let inner = SyncShapeTableScope::enter(h2.clone());
            assert!(Arc::ptr_eq(&current_shape_table(), &h2));
            drop(inner);
        }
        assert!(Arc::ptr_eq(&current_shape_table(), &h1));
        drop(outer);
        assert!(Arc::ptr_eq(&current_shape_table(), &baseline));
    }

    // Without any scope, repeated lookups return the same process-wide default.
    #[test]
    fn try_current_falls_back_to_process_default_without_scope() {
        let first = try_current_shape_table().expect("default fallback");
        let second = try_current_shape_table().expect("default fallback stable");
        assert!(Arc::ptr_eq(&first, &second));
    }

    // The task-local binding must survive yields on a multi-thread runtime, where
    // the task may resume on another worker. Nested scopes must shadow and then
    // restore the outer binding.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn async_scope_survives_task_migration() {
        let handle = ShapeTableHandle::new();
        let expected = handle.clone();
        with_async_shape_table_scope(handle, async move {
            tokio::task::yield_now().await;
            let observed = current_shape_table();
            assert!(Arc::ptr_eq(&observed, &expected));
            let inner = ShapeTableHandle::new();
            with_async_shape_table_scope(inner.clone(), async {
                tokio::task::yield_now().await;
                assert!(Arc::ptr_eq(&current_shape_table(), &inner));
            })
            .await;
            assert!(Arc::ptr_eq(&current_shape_table(), &expected));
        })
        .await;
    }

    // An active task-local scope wins over a thread-local guard held on the same
    // thread. The thread-local value is visible again once the async scope ends.
    #[test]
    fn task_local_takes_precedence_over_thread_local() {
        let rt = tokio::runtime::Builder::new_current_thread()
            .build()
            .expect("current-thread runtime");
        let sync_handle = ShapeTableHandle::new();
        let async_handle = ShapeTableHandle::new();
        let _guard = SyncShapeTableScope::enter(sync_handle.clone());
        assert!(Arc::ptr_eq(&current_shape_table(), &sync_handle));
        rt.block_on(async {
            with_async_shape_table_scope(async_handle.clone(), async {
                assert!(Arc::ptr_eq(&current_shape_table(), &async_handle));
            })
            .await;
        });
        assert!(Arc::ptr_eq(&current_shape_table(), &sync_handle));
    }
}