use super::TypeSchemaRegistry;
use std::cell::RefCell;
use std::future::Future;
use std::sync::{Arc, LazyLock};
// Task-local slot for the registry bound to the currently running Tokio task.
// It is populated through `.scope(...)` in `with_async_scope` and read with
// `try_with` in `current_registry`, which checks this slot first.
tokio::task_local! {
static CURRENT_SCHEMA_REGISTRY: Arc<TypeSchemaRegistry>;
}
// Thread-local slot for the registry installed by `SyncRegistryScope`.
// Starts as `None`; `SyncRegistryScope::enter` stores a registry here and the
// guard's `Drop` impl writes the previously stored value back.
thread_local! {
static SYNC_CURRENT_SCHEMA_REGISTRY: RefCell<Option<Arc<TypeSchemaRegistry>>> =
const { RefCell::new(None) };
}
/// Process-wide fallback registry, built lazily on first access with
/// `TypeSchemaRegistry::new_with_stdlib()` and shared by every caller that has
/// neither a task-local nor a thread-local registry installed.
static DEFAULT_SCHEMA_REGISTRY: LazyLock<Arc<TypeSchemaRegistry>> =
LazyLock::new(|| Arc::new(TypeSchemaRegistry::new_with_stdlib()));
/// Returns a shared handle to the process-wide default registry.
///
/// The registry is created on first use; every call hands back a clone of the
/// same `Arc`, so results are pointer-equal to one another.
pub fn default_registry() -> Arc<TypeSchemaRegistry> {
    // Explicit `Arc::clone` makes it clear only the reference count is bumped.
    Arc::clone(&*DEFAULT_SCHEMA_REGISTRY)
}
/// Guard that keeps a registry installed in the thread-local slot.
///
/// Created by `SyncRegistryScope::enter`; when the guard is dropped, the value
/// that was in the slot at the time of `enter` is written back.
// NOTE(review): the slot is per-thread, so the guard presumably must be dropped
// on the thread that created it, and nested guards in reverse order of
// creation — confirm against callers.
#[must_use = "the scope only lives as long as the guard is held"]
pub struct SyncRegistryScope {
/// Contents of the thread-local slot before this scope was entered
/// (`None` if nothing was installed).
prev: Option<Arc<TypeSchemaRegistry>>,
}
impl SyncRegistryScope {
    /// Installs `registry` as this thread's current registry.
    ///
    /// The value that was previously in the thread-local slot is captured in
    /// the returned guard so it can be put back when the guard is dropped.
    ///
    /// # Panics
    ///
    /// Panics if the thread-local slot is already borrowed, or if it is
    /// accessed while the thread-local is being destroyed.
    pub fn enter(registry: Arc<TypeSchemaRegistry>) -> Self {
        // Swap the new registry in and keep whatever was there before.
        let displaced = SYNC_CURRENT_SCHEMA_REGISTRY.with(|slot| slot.replace(Some(registry)));
        Self { prev: displaced }
    }
}
impl Drop for SyncRegistryScope {
    /// Restores the registry that was installed before this scope was entered.
    ///
    /// Restoration is best-effort: if the thread-local slot has already been
    /// torn down (the guard is being dropped during thread shutdown), there is
    /// nothing left to restore and the drop completes silently.
    fn drop(&mut self) {
        let restored = self.prev.take();
        // `try_with` instead of `with`: `with` panics when the thread-local is
        // being or has been destroyed, and a panic inside `Drop` can abort the
        // process if it happens while another panic is unwinding.
        let displaced = SYNC_CURRENT_SCHEMA_REGISTRY.try_with(|slot| slot.replace(restored));
        // The registry that was current inside the scope is released here,
        // after the `RefCell` borrow has ended, so a destructor that reads the
        // current registry cannot hit an outstanding mutable borrow.
        drop(displaced);
    }
}
/// Returns the registry that is in effect for the calling code.
///
/// Lookup order:
/// 1. the task-local registry, if called inside `with_async_scope`;
/// 2. the thread-local registry, if a `SyncRegistryScope` is active;
/// 3. the process-wide default registry.
pub fn current_registry() -> Arc<TypeSchemaRegistry> {
    CURRENT_SCHEMA_REGISTRY
        // `try_with` yields `Err` when no task-local scope is active.
        .try_with(Arc::clone)
        .ok()
        // Only consult the thread-local slot when no task-local value exists.
        .or_else(|| SYNC_CURRENT_SCHEMA_REGISTRY.with(|slot| slot.borrow().clone()))
        // Nothing scoped at all: fall back to the shared default.
        .unwrap_or_else(|| Arc::clone(&*DEFAULT_SCHEMA_REGISTRY))
}
/// Returns the registry that is in effect for the calling code, wrapped in
/// an `Option`.
///
/// The result is always `Some`: `current_registry` falls back to the
/// process-wide default when no scoped registry is installed.
pub fn try_current_registry() -> Option<Arc<TypeSchemaRegistry>> {
    let active = current_registry();
    Some(active)
}
/// Drives `fut` to completion with `registry` installed as the task-local
/// registry, and returns the future's output.
///
/// While `fut` is running, `current_registry` resolves to `registry` ahead of
/// any thread-local or default registry.
pub async fn with_async_scope<R>(
    registry: Arc<TypeSchemaRegistry>,
    fut: impl Future<Output = R>,
) -> R {
    // Wrap the future so the task-local value is set for each poll of `fut`.
    let scoped = CURRENT_SCHEMA_REGISTRY.scope(registry, fut);
    scoped.await
}
/// Test helper: installs a freshly built stdlib registry as this thread's
/// current registry and returns the guard that keeps it installed.
#[cfg(test)]
pub(crate) fn test_runtime_scope() -> SyncRegistryScope {
    SyncRegistryScope::enter(Arc::new(TypeSchemaRegistry::new_with_stdlib()))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::type_schema::FieldType;

    /// Nested sync scopes install their registry and restore the previous one
    /// when dropped, ending back at the process default.
    #[test]
    fn sync_scope_push_pop_restores_previous() {
        let r1 = Arc::new(TypeSchemaRegistry::new_with_stdlib());
        let r2 = Arc::new(TypeSchemaRegistry::new_with_stdlib());

        // With no scope active, the process default is returned.
        let baseline = current_registry();
        assert!(Arc::ptr_eq(&baseline, &DEFAULT_SCHEMA_REGISTRY));

        let outer = SyncRegistryScope::enter(r1.clone());
        assert!(Arc::ptr_eq(&current_registry(), &r1));
        {
            let inner = SyncRegistryScope::enter(r2.clone());
            assert!(Arc::ptr_eq(&current_registry(), &r2));
            drop(inner);
        }
        // Dropping the inner guard brings the outer registry back.
        assert!(Arc::ptr_eq(&current_registry(), &r1));

        drop(outer);
        assert!(Arc::ptr_eq(&current_registry(), &baseline));
    }

    /// Without any scope, repeated lookups return the same stdlib-populated
    /// default registry.
    #[test]
    fn current_registry_falls_back_to_process_default_without_scope() {
        let first = current_registry();
        let second = current_registry();
        assert!(Arc::ptr_eq(&first, &second));
        assert!(first.has_type("Row"));
        assert!(first.has_type("Option"));
        assert!(first.has_type("Result"));
    }

    /// The test helper installs a registry that knows the stdlib types.
    #[test]
    fn test_runtime_scope_installs_stdlib_registry() {
        let _guard = test_runtime_scope();
        let reg = current_registry();
        assert!(reg.has_type("Row"));
        assert!(reg.has_type("Option"));
        assert!(reg.has_type("Result"));
    }

    /// The task-local registry stays visible across yield points on a
    /// multi-threaded runtime, and nested async scopes restore the outer one.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn async_scope_survives_task_migration() {
        let registry = Arc::new(TypeSchemaRegistry::new_with_stdlib());
        let expected_id = registry.clone();
        with_async_scope(registry, async move {
            // Yield so the runtime has a chance to resume us on another worker.
            tokio::task::yield_now().await;
            let observed = current_registry();
            assert!(Arc::ptr_eq(&observed, &expected_id));

            let mut inner = TypeSchemaRegistry::new();
            inner.register_type("Inner", vec![("n".to_string(), FieldType::F64)]);
            let inner = Arc::new(inner);
            with_async_scope(inner.clone(), async {
                tokio::task::yield_now().await;
                assert!(Arc::ptr_eq(&current_registry(), &inner));
            })
            .await;

            // Leaving the nested scope exposes the outer registry again.
            assert!(Arc::ptr_eq(&current_registry(), &expected_id));
        })
        .await;
    }

    /// When both a sync scope and an async scope are active, the async
    /// (task-local) registry wins; the sync one is visible again afterwards.
    #[test]
    fn task_local_takes_precedence_over_thread_local() {
        let rt = tokio::runtime::Builder::new_current_thread()
            .build()
            .expect("current-thread runtime");
        let sync_reg = Arc::new(TypeSchemaRegistry::new_with_stdlib());
        let async_reg = Arc::new(TypeSchemaRegistry::new_with_stdlib());

        let _guard = SyncRegistryScope::enter(sync_reg.clone());
        assert!(Arc::ptr_eq(&current_registry(), &sync_reg));

        rt.block_on(async {
            with_async_scope(async_reg.clone(), async {
                assert!(Arc::ptr_eq(&current_registry(), &async_reg));
            })
            .await;
        });

        assert!(Arc::ptr_eq(&current_registry(), &sync_reg));
    }
}