use std::cell::RefCell;
use std::collections::HashMap;
use std::fmt::Display;
use std::future::Future;
use std::panic::AssertUnwindSafe;
#[cfg(not(target_family = "wasm"))]
use std::pin::pin;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, LazyLock, OnceLock, Weak};
#[cfg(not(target_family = "wasm"))]
use std::task::{Context, Waker};
use futures::FutureExt;
use reqwest::Client;
use tokio::runtime::{Builder as TokioRuntimeBuilder, Handle as TokioRuntimeHandle, Runtime as TokioRuntime};
use tokio::sync::oneshot;
use tokio::task::JoinHandle;
use tracing::debug;
#[cfg(not(target_family = "wasm"))]
use tracing::info;
use super::XetCommon;
use crate::config::XetConfig;
use crate::error::RuntimeError;
#[cfg(feature = "fd-track")]
use crate::fd_diagnostics::{report_fd_count, track_fd_scope};
#[cfg(not(target_family = "wasm"))]
use crate::logging::SystemMonitor;
#[cfg(not(target_family = "wasm"))]
use crate::utils::ClosureGuard as CallbackGuard;
/// Prefix for names of threads spawned by the owned tokio runtime (`hf-xet-0`, `hf-xet-1`, ...).
const THREADPOOL_THREAD_ID_PREFIX: &str = "hf-xet";

/// Stack size, in bytes, for threads spawned by the owned tokio runtime.
const THREADPOOL_STACK_SIZE: usize = 8_000_000;

/// Upper bound applied to the auto-detected async worker count (not to an explicit
/// `TOKIO_WORKER_THREADS` override).
#[cfg(not(target_family = "wasm"))]
const THREADPOOL_MAX_ASYNC_THREADS: usize = 32;
/// Chooses the number of async worker threads for the owned tokio runtime.
///
/// A positive integer in `TOKIO_WORKER_THREADS` is used verbatim. Any other value is logged
/// as invalid and ignored. Without a usable override, the available parallelism (1 if it
/// cannot be queried) is clamped to `[2, THREADPOOL_MAX_ASYNC_THREADS]`.
#[cfg(not(target_family = "wasm"))]
fn get_num_tokio_worker_threads() -> usize {
    use std::num::NonZeroUsize;

    use tracing::warn;

    if let Ok(raw) = std::env::var("TOKIO_WORKER_THREADS") {
        // Zero is rejected along with anything that does not parse as usize.
        if let Some(n) = raw.parse::<usize>().ok().filter(|&count| count > 0) {
            info!("Using {n} async threads from TOKIO_WORKER_THREADS");
            return n;
        }
        warn!(
            value = %raw,
            "Invalid TOKIO_WORKER_THREADS; must be a positive integer. Falling back to auto."
        );
    }

    let n = std::thread::available_parallelism()
        .map_or(1, NonZeroUsize::get)
        .clamp(2, THREADPOOL_MAX_ASYNC_THREADS);
    info!("Using {n} async threads for tokio runtime");
    n
}
#[inline]
pub fn check_sigint_shutdown() -> Result<(), RuntimeError> {
if XetRuntime::current_if_exists()
.map(|rt| rt.in_sigint_shutdown())
.unwrap_or(false)
{
Err(RuntimeError::KeyboardInterrupt)
} else {
Ok(())
}
}
/// How a `XetRuntime` relates to the tokio runtime it drives.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RuntimeMode {
    /// The `XetRuntime` built its own tokio thread pool and owns it.
    Owned,
    /// The `XetRuntime` wraps a tokio runtime supplied from outside.
    External,
}
/// Shared slot holding an owned tokio runtime. It is emptied when the runtime is shut down,
/// discarded, or dropped.
type OwnedRuntimeCell = Arc<std::sync::RwLock<Option<Arc<TokioRuntime>>>>;

/// Internal description of where a `XetRuntime`'s tokio runtime comes from.
#[cfg_attr(target_family = "wasm", allow(dead_code))]
#[derive(Debug)]
enum RuntimeBackend {
    /// Wraps a tokio runtime owned elsewhere. `handle_id` is `Some` only for instances that
    /// were recorded in the external-runtime registry.
    External {
        handle_id: Option<tokio::runtime::Id>,
    },
    /// Owns a tokio runtime built by `XetRuntime`, stored in a shared cell.
    OwnedThreadPool {
        runtime: OwnedRuntimeCell,
    },
}
/// Minimal drop guard for wasm targets (other targets import `ClosureGuard` from
/// `crate::utils`). It runs the stored callback exactly once when the guard goes out of scope.
#[cfg(target_family = "wasm")]
struct CallbackGuard<F>
where
    F: FnOnce(),
{
    callback: Option<F>,
}

#[cfg(target_family = "wasm")]
impl<F> CallbackGuard<F>
where
    F: FnOnce(),
{
    /// Arms a guard that will invoke `callback` on drop.
    fn new(callback: F) -> Self {
        let callback = Some(callback);
        Self { callback }
    }
}

#[cfg(target_family = "wasm")]
impl<F> Drop for CallbackGuard<F>
where
    F: FnOnce(),
{
    fn drop(&mut self) {
        // `take` leaves `None` behind, so the callback can never run twice.
        let Some(on_drop) = self.callback.take() else {
            return;
        };
        on_drop();
    }
}
/// Wrapper around the tokio runtime used by xet. The runtime is either a thread pool this
/// type builds and owns (`XetRuntime::new*`) or a caller-supplied runtime
/// (`XetRuntime::from_external*`). The wrapper also carries per-runtime state: config,
/// `XetCommon`, and SIGINT status.
#[derive(Debug)]
pub struct XetRuntime {
    /// Where the tokio runtime comes from; for owned pools, the cell holding it.
    backend: RuntimeBackend,
    /// Handle used to spawn and block on work. It is set once during construction and cleared in `Drop`.
    handle_ref: OnceLock<TokioRuntimeHandle>,
    /// Number of threads currently blocked in `external_run_async_task` / `bridge_sync`.
    external_executor_count: AtomicUsize,
    /// Set by `perform_sigint_shutdown`. Once set, the `bridge_*` entry points refuse new work.
    sigint_shutdown: AtomicBool,
    /// Shared per-runtime state (e.g. backs `get_or_create_reqwest_client`).
    common: XetCommon,
    /// Configuration this runtime was created with (defaults for `from_external`).
    config: Arc<XetConfig>,
    /// Process monitor. It exists only when `config.system_monitor.enabled` is set and the
    /// monitor started successfully, and it is stopped on SIGINT shutdown.
    #[cfg(not(target_family = "wasm"))]
    system_monitor: Option<SystemMonitor>,
}
thread_local! {
    /// Runtime bound to the current thread, tagged with the process id that recorded it. A
    /// reference carried into a different process (presumably via `fork()`) is ignored by readers.
    static THREAD_RUNTIME_REF: RefCell<Option<(u32, Weak<XetRuntime>)>> = const { RefCell::new(None) };
}

/// `XetRuntime`s attached to caller-owned tokio runtimes, keyed by tokio runtime id. Entries
/// are weak, so the registry never keeps a runtime alive.
static EXTERNAL_RUNTIME_REGISTRY: LazyLock<std::sync::RwLock<HashMap<tokio::runtime::Id, Weak<XetRuntime>>>> =
    LazyLock::new(Default::default);
impl XetRuntime {
#[inline]
pub fn current() -> Arc<Self> {
if let Some(rt) = Self::current_if_exists() {
return rt;
}
let Ok(tokio_rt) = TokioRuntimeHandle::try_current() else {
panic!("ThreadPool::current() called before ThreadPool::new() or on thread outside of current runtime.");
};
Self::from_external(tokio_rt)
}
#[inline]
pub fn current_if_exists() -> Option<Arc<Self>> {
let maybe_rt = THREAD_RUNTIME_REF.with_borrow(|rt| {
rt.as_ref().and_then(|(pid, weak)| {
if *pid == std::process::id() {
weak.upgrade()
} else {
None
}
})
});
if let Some(rt) = maybe_rt {
return Some(rt);
}
if let Ok(handle) = TokioRuntimeHandle::try_current() {
if let Ok(reg) = EXTERNAL_RUNTIME_REGISTRY.read()
&& let Some(weak) = reg.get(&handle.id())
&& let Some(rt) = weak.upgrade()
{
return Some(rt);
}
Some(Self::from_external(handle))
} else {
None
}
}
pub fn new() -> Result<Arc<Self>, RuntimeError> {
Self::new_with_config(XetConfig::new())
}
pub fn new_with_config(config: XetConfig) -> Result<Arc<Self>, RuntimeError> {
#[cfg(feature = "fd-track")]
let _fd_scope = track_fd_scope("XetRuntime::new_with_config");
let runtime = Arc::new(std::sync::RwLock::new(None));
let rt = Arc::new(Self {
backend: RuntimeBackend::OwnedThreadPool {
runtime: runtime.clone(),
},
handle_ref: OnceLock::new(),
external_executor_count: 0.into(),
sigint_shutdown: false.into(),
common: XetCommon::new(&config),
#[cfg(not(target_family = "wasm"))]
system_monitor: config
.system_monitor
.enabled
.then(|| {
SystemMonitor::follow_process(
config.system_monitor.sample_interval,
config.system_monitor.log_path.clone(),
)
.ok()
})
.flatten(),
config: Arc::new(config),
});
let rt_weak = Arc::downgrade(&rt);
let pid = std::process::id();
let set_threadlocal_reference = move || {
THREAD_RUNTIME_REF.set(Some((pid, rt_weak.clone())));
};
let thread_id = AtomicUsize::new(0);
let get_thread_name = move || {
let id = thread_id.fetch_add(1, Ordering::Relaxed);
format!("{THREADPOOL_THREAD_ID_PREFIX}-{id}")
};
let mut tokio_rt_builder = {
#[cfg(not(target_family = "wasm"))]
{
TokioRuntimeBuilder::new_multi_thread()
}
#[cfg(target_family = "wasm")]
{
TokioRuntimeBuilder::new_current_thread()
}
};
#[cfg(not(target_family = "wasm"))]
{
tokio_rt_builder.worker_threads(get_num_tokio_worker_threads());
}
let tokio_rt = tokio_rt_builder
.thread_name_fn(get_thread_name) .on_thread_start(set_threadlocal_reference) .thread_stack_size(THREADPOOL_STACK_SIZE) .thread_keep_alive(std::time::Duration::from_millis(100)) .enable_all() .build()
.map_err(RuntimeError::RuntimeInit)?;
let handle = tokio_rt.handle().clone();
let tokio_rt = Arc::new(tokio_rt);
*runtime.write().unwrap() = Some(tokio_rt); rt.handle_ref.set(handle).unwrap();
#[cfg(feature = "fd-track")]
report_fd_count("XetRuntime::new_with_config complete");
Ok(rt)
}
#[cfg(not(target_family = "wasm"))]
pub fn from_validated_external(
rt_handle: TokioRuntimeHandle,
config: XetConfig,
) -> Result<Arc<Self>, RuntimeError> {
if !Self::handle_meets_requirements(&rt_handle) {
return Err(RuntimeError::InvalidRuntime(
"supplied tokio handle does not meet requirements \
(missing drivers or wrong flavor)"
.into(),
));
}
Self::from_external_with_config(rt_handle, config)
}
pub fn from_external_with_config(
rt_handle: TokioRuntimeHandle,
config: XetConfig,
) -> Result<Arc<Self>, RuntimeError> {
#[cfg(feature = "fd-track")]
let _fd_scope = track_fd_scope("XetRuntime::from_external_with_config");
let id = rt_handle.id();
let mut reg = EXTERNAL_RUNTIME_REGISTRY.write()?;
if let Some(existing) = reg.get(&id)
&& existing.upgrade().is_some()
{
return Err(RuntimeError::ExternalAlreadyAttached(id));
}
let rt = Arc::new(Self {
backend: RuntimeBackend::External { handle_id: Some(id) },
handle_ref: rt_handle.into(),
external_executor_count: 0.into(),
sigint_shutdown: false.into(),
common: XetCommon::new(&config),
#[cfg(not(target_family = "wasm"))]
system_monitor: config
.system_monitor
.enabled
.then(|| {
SystemMonitor::follow_process(
config.system_monitor.sample_interval,
config.system_monitor.log_path.clone(),
)
.ok()
})
.flatten(),
config: Arc::new(config),
});
reg.insert(id, Arc::downgrade(&rt));
#[cfg(feature = "fd-track")]
report_fd_count("XetRuntime::from_external_with_config complete");
Ok(rt)
}
pub fn from_external(rt_handle: TokioRuntimeHandle) -> Arc<Self> {
let config = XetConfig::new();
Arc::new(Self {
backend: RuntimeBackend::External { handle_id: None },
handle_ref: rt_handle.into(),
external_executor_count: 0.into(),
sigint_shutdown: false.into(),
common: XetCommon::new(&config),
#[cfg(not(target_family = "wasm"))]
system_monitor: config
.system_monitor
.enabled
.then(|| {
SystemMonitor::follow_process(
config.system_monitor.sample_interval,
config.system_monitor.log_path.clone(),
)
.ok()
})
.flatten(),
config: Arc::new(config),
})
}
#[inline]
pub fn handle(&self) -> TokioRuntimeHandle {
self.handle_ref.get().expect("Not initialized with handle set.").clone()
}
#[inline]
pub fn common(&self) -> &XetCommon {
&self.common
}
pub fn get_or_create_reqwest_client<F>(tag: String, f: F) -> crate::error::Result<Client>
where
F: FnOnce() -> std::result::Result<Client, reqwest::Error>,
{
if let Some(rt) = Self::current_if_exists() {
rt.common().get_or_create_reqwest_client(tag, f)
} else {
Ok(f()?)
}
}
#[inline]
pub fn num_worker_threads(&self) -> usize {
self.handle().metrics().num_workers()
}
#[inline]
pub fn external_executor_count(&self) -> usize {
self.external_executor_count.load(Ordering::SeqCst)
}
pub fn perform_sigint_shutdown(&self) {
#[cfg(feature = "fd-track")]
let _fd_scope = track_fd_scope("XetRuntime::perform_sigint_shutdown");
self.sigint_shutdown.store(true, Ordering::SeqCst);
if cfg!(debug_assertions) {
eprintln!("SIGINT detected, shutting down.");
}
let Some(runtime_cell) = self.runtime_cell_if_owned() else {
#[cfg(not(target_family = "wasm"))]
if let Some(monitor) = &self.system_monitor {
let _ = monitor.stop();
}
return;
};
let maybe_runtime = runtime_cell.write().expect("cancel_all called recursively.").take();
let Some(runtime) = maybe_runtime else {
eprintln!("WARNING: perform_sigint_shutdown called on runtime that has already been shut down.");
#[cfg(not(target_family = "wasm"))]
if let Some(monitor) = &self.system_monitor {
let _ = monitor.stop();
}
return;
};
drop(runtime);
#[cfg(not(target_family = "wasm"))]
if let Some(monitor) = &self.system_monitor {
let _ = monitor.stop();
}
}
pub fn discard_runtime(&self) {
let Some(runtime_cell) = self.runtime_cell_if_owned() else {
return;
};
let Ok(mut rt_lock) = runtime_cell.write() else {
return;
};
let Some(runtime) = rt_lock.take() else {
return;
};
std::mem::forget(runtime);
}
pub fn in_sigint_shutdown(&self) -> bool {
self.sigint_shutdown.load(Ordering::SeqCst)
}
fn check_sigint(&self) -> Result<(), RuntimeError> {
if self.in_sigint_shutdown() {
Err(RuntimeError::KeyboardInterrupt)
} else {
Ok(())
}
}
pub fn external_run_async_task<F>(&self, future: F) -> Result<F::Output, RuntimeError>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
self.external_executor_count.fetch_add(1, Ordering::SeqCst);
let _executor_count_guard = CallbackGuard::new(|| {
self.external_executor_count.fetch_sub(1, Ordering::SeqCst);
});
self.handle().block_on(async move {
self.handle().spawn(future).await.map_err(RuntimeError::from)
})
}
pub fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
debug!("threadpool: spawn called, {}", self);
self.handle().spawn(future)
}
pub async fn bridge_async<T, F>(&self, task_name: &'static str, fut: F) -> Result<T, RuntimeError>
where
F: Future<Output = T> + Send + 'static,
T: Send + 'static,
{
self.check_sigint()?;
match &self.backend {
RuntimeBackend::External { .. } => Ok(fut.await),
RuntimeBackend::OwnedThreadPool { .. } => self.bridge_to_owned(task_name, fut).await,
}
}
pub fn bridge_sync<F>(&self, future: F) -> Result<F::Output, RuntimeError>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
self.check_sigint()?;
if matches!(self.backend, RuntimeBackend::External { .. }) {
return Err(RuntimeError::InvalidRuntime(
"bridge_sync() cannot be called on an External-mode runtime; \
use the async API instead"
.into(),
));
}
self.external_executor_count.fetch_add(1, Ordering::SeqCst);
let _executor_count_guard = CallbackGuard::new(|| {
self.external_executor_count.fetch_sub(1, Ordering::SeqCst);
});
let spawn_handle = self.handle();
self.handle()
.block_on(async move { spawn_handle.spawn(future).await.map_err(RuntimeError::from) })
}
async fn bridge_to_owned<T, F>(&self, task_name: &'static str, fut: F) -> Result<T, RuntimeError>
where
F: Future<Output = T> + Send + 'static,
T: Send + 'static,
{
let (tx, rx) = oneshot::channel();
self.spawn(async move {
let result = AssertUnwindSafe(fut).catch_unwind().await;
let _ = tx.send(result);
});
match rx.await {
Ok(Ok(value)) => Ok(value),
Ok(Err(panic_payload)) => {
let msg = if let Some(s) = panic_payload.downcast_ref::<&str>() {
format!("{task_name}: {s}")
} else if let Some(s) = panic_payload.downcast_ref::<String>() {
format!("{task_name}: {s}")
} else {
format!("{task_name}: <unknown panic>")
};
Err(RuntimeError::TaskPanic(msg))
},
Err(_) => Err(RuntimeError::TaskCanceled(task_name.to_string())),
}
}
#[inline]
fn runtime_cell_if_owned(&self) -> Option<&OwnedRuntimeCell> {
match &self.backend {
RuntimeBackend::OwnedThreadPool { runtime } => Some(runtime),
RuntimeBackend::External { .. } => None,
}
}
pub fn spawn_blocking<F, R>(self: &Arc<Self>, f: F) -> JoinHandle<R>
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
let rt_weak = Arc::downgrade(self);
self.handle().spawn_blocking(move || {
let pid = std::process::id();
THREAD_RUNTIME_REF.set(Some((pid, rt_weak)));
f()
})
}
#[inline]
pub fn config(&self) -> &Arc<XetConfig> {
&self.config
}
#[inline]
pub fn mode(&self) -> RuntimeMode {
match &self.backend {
RuntimeBackend::External { .. } => RuntimeMode::External,
RuntimeBackend::OwnedThreadPool { .. } => RuntimeMode::Owned,
}
}
#[cfg(target_family = "wasm")]
pub fn handle_meets_requirements(_handle: &TokioRuntimeHandle) -> bool {
true
}
#[cfg(not(target_family = "wasm"))]
pub fn handle_meets_requirements(handle: &TokioRuntimeHandle) -> bool {
if matches!(handle.runtime_flavor(), tokio::runtime::RuntimeFlavor::CurrentThread) {
return false;
}
let _guard = handle.enter();
let waker = Waker::noop();
let mut cx = Context::from_waker(waker);
let has_time = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
let mut sleep = pin!(tokio::time::sleep(std::time::Duration::ZERO));
let _ = sleep.as_mut().poll(&mut cx);
}))
.is_ok();
let has_io = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
let mut bind = pin!(tokio::net::TcpListener::bind("127.0.0.1:0"));
let _ = bind.as_mut().poll(&mut cx);
}))
.is_ok();
has_time && has_io
}
}
impl Drop for XetRuntime {
    /// Releases the runtime's resources.
    ///
    /// * Registered external runtime: removes its registry entry, but only if that entry is
    ///   dead. A concurrent `from_external_with_config` may already have replaced the dead
    ///   entry with a new live runtime for the same tokio id, and that one must stay registered.
    /// * Owned runtime: takes the tokio runtime out of its cell and shuts it down. Inside a
    ///   tokio context it is shut down in the background, because tokio panics when a runtime
    ///   is dropped there. Otherwise it waits up to 5 seconds. A poisoned lock is recovered
    ///   so that shutdown still happens.
    fn drop(&mut self) {
        #[cfg(feature = "fd-track")]
        let _fd_scope = track_fd_scope("XetRuntime::drop");

        self.handle_ref.take();

        match &self.backend {
            RuntimeBackend::External { handle_id: Some(id) } => {
                // Our own strong count is already 0 here, so a live entry belongs to someone else.
                if let Ok(mut reg) = EXTERNAL_RUNTIME_REGISTRY.write()
                    && reg.get(id).is_some_and(|weak| weak.strong_count() == 0)
                {
                    reg.remove(id);
                }
            },
            RuntimeBackend::External { handle_id: None } => {},
            RuntimeBackend::OwnedThreadPool { runtime } => {
                let in_async_context = TokioRuntimeHandle::try_current().is_ok();
                let taken = runtime
                    .write()
                    .unwrap_or_else(std::sync::PoisonError::into_inner)
                    .take();
                if let Some(rt_arc) = taken
                    && let Ok(rt) = Arc::try_unwrap(rt_arc)
                {
                    if in_async_context {
                        rt.shutdown_background();
                    } else {
                        rt.shutdown_timeout(std::time::Duration::from_secs(5));
                    }
                }
            },
        }
    }
}
impl Display for XetRuntime {
    /// Summarizes the tokio runtime's worker count, live task count and global queue depth.
    ///
    /// For an owned runtime the lock is only *tried*, so formatting never blocks behind a
    /// writer such as a concurrent shutdown. Placeholder text is printed when the lock is
    /// busy or the runtime has already been taken out.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let metrics = if let RuntimeBackend::OwnedThreadPool { runtime } = &self.backend {
            let Ok(slot) = runtime.try_read() else {
                return write!(f, "Locked Tokio Runtime.");
            };
            match &*slot {
                Some(tokio_rt) => tokio_rt.metrics(),
                None => {
                    return write!(f, "Terminated Tokio Runtime Handle; cancel_all_and_shutdown called.");
                },
            }
        } else {
            self.handle().metrics()
        };

        write!(
            f,
            "pool: num_workers: {:?}, num_alive_tasks: {:?}, global_queue_depth: {:?}",
            metrics.num_workers(),
            metrics.num_alive_tasks(),
            metrics.global_queue_depth()
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_get_or_create_reqwest_client_returns_client() {
        let result =
            XetRuntime::get_or_create_reqwest_client("test".to_string(), || reqwest::Client::builder().build());
        assert!(result.is_ok());
    }

    #[test]
    fn test_spawn_blocking_sets_current_runtime() {
        let rt = XetRuntime::new().expect("Failed to create runtime");
        let rt_clone = rt.clone();
        let jh = rt.spawn_blocking(move || {
            let current = XetRuntime::current();
            Arc::ptr_eq(&current, &rt_clone)
        });
        let same = rt.bridge_sync(async { jh.await.unwrap() }).unwrap();
        assert!(same);
    }

    #[test]
    fn test_current_if_exists_sees_external_runtime_config() {
        let tokio_rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
        let mut config = XetConfig::new();
        config.data.default_cas_endpoint = "https://test-endpoint.example.com".into();
        let xet_rt = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), config).unwrap();
        tokio_rt.block_on(async {
            let found = XetRuntime::current_if_exists().expect("should find a runtime");
            assert!(Arc::ptr_eq(&found, &xet_rt), "must be the same XetRuntime instance");
            assert_eq!(found.config().data.default_cas_endpoint, "https://test-endpoint.example.com");
        });
        drop(xet_rt);
        tokio_rt.block_on(async {
            let found = XetRuntime::current_if_exists().expect("should still find a runtime");
            assert_ne!(found.config().data.default_cas_endpoint, "https://test-endpoint.example.com");
        });
    }

    #[test]
    fn test_bridge_async_owned_mode_runs_on_pool() {
        let rt = XetRuntime::new().unwrap();
        assert_eq!(rt.mode(), RuntimeMode::Owned);
        let result = rt.bridge_sync(async {
            let inner_rt = XetRuntime::new().unwrap();
            inner_rt.bridge_async("test", async { 42 }).await.unwrap()
        });
        assert_eq!(result.unwrap(), 42);
    }

    #[test]
    fn test_bridge_async_external_mode_runs_directly() {
        let tokio_rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
        let xet_rt = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), XetConfig::new()).unwrap();
        assert_eq!(xet_rt.mode(), RuntimeMode::External);
        let result = tokio_rt.block_on(async { xet_rt.bridge_async("test", async { 99 }).await.unwrap() });
        assert_eq!(result, 99);
    }

    #[test]
    fn test_bridge_sync_owned_mode() {
        let rt = XetRuntime::new().unwrap();
        assert_eq!(rt.mode(), RuntimeMode::Owned);
        let result = rt.bridge_sync(async { 123 }).unwrap();
        assert_eq!(result, 123);
    }

    #[test]
    fn test_bridge_sync_from_spawn_blocking_owned_mode() {
        let rt = XetRuntime::new().unwrap();
        let rt_clone = rt.clone();
        let jh = rt.spawn_blocking(move || rt_clone.bridge_sync(async { 456 }).unwrap());
        let result = rt.bridge_sync(async { jh.await.unwrap() }).unwrap();
        assert_eq!(result, 456);
    }

    #[test]
    fn test_bridge_sync_external_mode_returns_error() {
        let tokio_rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
        let xet_rt = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), XetConfig::new()).unwrap();
        assert_eq!(xet_rt.mode(), RuntimeMode::External);
        let result = xet_rt.bridge_sync(async { 789 });
        assert!(matches!(result, Err(RuntimeError::InvalidRuntime(_))));
    }

    #[cfg(not(target_family = "wasm"))]
    #[test]
    fn test_handle_meets_requirements_multi_thread_all() {
        let rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
        assert!(XetRuntime::handle_meets_requirements(rt.handle()));
    }

    #[cfg(not(target_family = "wasm"))]
    #[test]
    fn test_handle_meets_requirements_current_thread_rejected() {
        let rt = tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap();
        assert!(!XetRuntime::handle_meets_requirements(rt.handle()));
    }

    #[cfg(not(target_family = "wasm"))]
    #[test]
    fn test_handle_meets_requirements_no_drivers_rejected() {
        let rt = tokio::runtime::Builder::new_multi_thread().build().unwrap();
        assert!(!XetRuntime::handle_meets_requirements(rt.handle()));
    }

    #[cfg(not(target_family = "wasm"))]
    #[test]
    fn test_from_validated_external_accepts_valid_handle() {
        let tokio_rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
        let xet_rt = XetRuntime::from_validated_external(tokio_rt.handle().clone(), XetConfig::new()).unwrap();
        assert_eq!(xet_rt.mode(), RuntimeMode::External);
    }

    #[cfg(not(target_family = "wasm"))]
    #[test]
    fn test_from_validated_external_rejects_current_thread_runtime() {
        let tokio_rt = tokio::runtime::Builder::new_current_thread().enable_all().build().unwrap();
        let result = XetRuntime::from_validated_external(tokio_rt.handle().clone(), XetConfig::new());
        assert!(matches!(result, Err(RuntimeError::InvalidRuntime(_))));
    }

    #[cfg(not(target_family = "wasm"))]
    #[test]
    fn test_from_validated_external_rejects_runtime_without_drivers() {
        let tokio_rt = tokio::runtime::Builder::new_multi_thread().build().unwrap();
        let result = XetRuntime::from_validated_external(tokio_rt.handle().clone(), XetConfig::new());
        assert!(matches!(result, Err(RuntimeError::InvalidRuntime(_))));
    }

    #[test]
    fn test_bridge_async_owned_mode_catches_panic() {
        let rt = XetRuntime::new().unwrap();
        let rt2 = rt.clone();
        let result = rt.bridge_sync(async move {
            rt2.bridge_async("panic_test", async {
                panic!("intentional test panic");
            })
            .await
        });
        let err = result.unwrap().unwrap_err();
        assert!(matches!(err, RuntimeError::TaskPanic(_)));
    }

    #[test]
    fn test_from_external_with_config_duplicate_handle_fails() {
        let tokio_rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
        let _first = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), XetConfig::new()).unwrap();
        let second = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), XetConfig::new());
        assert!(
            matches!(second, Err(RuntimeError::ExternalAlreadyAttached(_))),
            "expected ExternalAlreadyAttached for duplicate handle, got: {second:?}"
        );
    }

    #[test]
    fn test_from_external_with_config_reuse_handle_after_drop() {
        let tokio_rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
        let first = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), XetConfig::new()).unwrap();
        drop(first);
        let second = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), XetConfig::new());
        assert!(second.is_ok(), "expected Ok after previous XetRuntime was dropped, got: {second:?}");
    }

    #[test]
    fn test_from_external_with_config_distinct_handles_both_succeed() {
        let rt_a = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
        let rt_b = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
        let xet_a = XetRuntime::from_external_with_config(rt_a.handle().clone(), XetConfig::new());
        let xet_b = XetRuntime::from_external_with_config(rt_b.handle().clone(), XetConfig::new());
        assert!(xet_a.is_ok());
        assert!(xet_b.is_ok());
    }

    /// After SIGINT shutdown of an owned runtime, the flag is visible and new sync work is refused.
    #[test]
    fn test_perform_sigint_shutdown_rejects_new_work() {
        let rt = XetRuntime::new().unwrap();
        assert!(!rt.in_sigint_shutdown());
        rt.perform_sigint_shutdown();
        assert!(rt.in_sigint_shutdown());
        assert!(matches!(rt.bridge_sync(async { 1 }), Err(RuntimeError::KeyboardInterrupt)));
    }

    /// `check_sigint_shutdown` resolves a registered external runtime from inside its tokio context.
    #[test]
    fn test_check_sigint_shutdown_tracks_registered_external_runtime() {
        let tokio_rt = tokio::runtime::Builder::new_multi_thread().enable_all().build().unwrap();
        let xet_rt = XetRuntime::from_external_with_config(tokio_rt.handle().clone(), XetConfig::new()).unwrap();
        tokio_rt.block_on(async {
            assert!(check_sigint_shutdown().is_ok());
        });
        xet_rt.perform_sigint_shutdown();
        tokio_rt.block_on(async {
            assert!(matches!(check_sigint_shutdown(), Err(RuntimeError::KeyboardInterrupt)));
        });
    }
}