use super::utils::GracefulShutdownTracker;
use crate::{
compute,
config::{self, RuntimeConfig},
};
use futures::Future;
use once_cell::sync::OnceCell;
use std::{
mem::ManuallyDrop,
sync::{Arc, atomic::Ordering},
};
use tokio::{signal, sync::Mutex, task::JoinHandle};
pub use tokio_util::sync::CancellationToken;
/// Ownership mode of the Tokio runtime backing a `Runtime`.
#[derive(Clone, Debug)]
enum RuntimeType {
    /// Runtime created and owned by this crate. Wrapped in `ManuallyDrop`
    /// so the custom `Drop` impl below can decide whether to drop it
    /// in place or hand it to `shutdown_background` (see `impl Drop`).
    Shared(Arc<ManuallyDrop<tokio::runtime::Runtime>>),
    /// Handle to a runtime owned elsewhere (e.g. the caller's); never
    /// shut down by this crate.
    External(tokio::runtime::Handle),
}
/// Handle to the async execution environment: a primary and secondary
/// Tokio runtime plus shutdown tokens and optional compute resources.
///
/// Cloning is cheap — all state is behind `Arc`s and cancellation tokens.
#[derive(Debug, Clone)]
pub struct Runtime {
    /// Unique identifier for this runtime instance (UUID v4 string).
    id: Arc<String>,
    /// Runtime on which application tasks are spawned.
    primary: RuntimeType,
    /// Auxiliary runtime; may alias `primary` (see `from_handle`).
    secondary: RuntimeType,
    /// Root token; cancelled last during shutdown (phase 3).
    cancellation_token: CancellationToken,
    /// Child of the root token; cancelled first so endpoints drain before
    /// backend connections are severed (see `Runtime::shutdown`).
    endpoint_shutdown_token: CancellationToken,
    /// Tracks endpoints participating in graceful shutdown.
    graceful_shutdown_tracker: Arc<GracefulShutdownTracker>,
    /// Optional dedicated pool for CPU-bound work; `None` when disabled
    /// or when construction went through `new` rather than `new_with_config`.
    compute_pool: Option<Arc<compute::ComputePool>>,
    /// Limits how many tasks may enter `block_in_place` concurrently;
    /// populated only by `new_with_config`.
    block_in_place_permits: Option<Arc<tokio::sync::Semaphore>>,
}
impl Runtime {
fn new(runtime: RuntimeType, secondary: Option<RuntimeType>) -> anyhow::Result<Runtime> {
let id = Arc::new(uuid::Uuid::new_v4().to_string());
let cancellation_token = CancellationToken::new();
let endpoint_shutdown_token = cancellation_token.child_token();
let secondary = match secondary {
Some(secondary) => secondary,
None => {
tracing::debug!("Created secondary runtime with single thread");
RuntimeType::Shared(Arc::new(ManuallyDrop::new(
RuntimeConfig::single_threaded().create_runtime()?,
)))
}
};
let compute_pool = None;
let block_in_place_permits = None;
Ok(Runtime {
id,
primary: runtime,
secondary,
cancellation_token,
endpoint_shutdown_token,
graceful_shutdown_tracker: Arc::new(GracefulShutdownTracker::new()),
compute_pool,
block_in_place_permits,
})
}
fn new_with_config(
runtime: RuntimeType,
secondary: Option<RuntimeType>,
config: &RuntimeConfig,
) -> anyhow::Result<Runtime> {
let mut rt = Self::new(runtime, secondary)?;
let compute_config = crate::compute::ComputeConfig {
num_threads: config.compute_threads,
stack_size: config.compute_stack_size,
thread_prefix: config.compute_thread_prefix.clone(),
pin_threads: false,
};
if config.compute_threads == Some(0) {
tracing::info!("Compute pool disabled (compute_threads = 0)");
} else {
match crate::compute::ComputePool::new(compute_config) {
Ok(pool) => {
rt.compute_pool = Some(Arc::new(pool));
tracing::debug!(
"Initialized compute pool with {} threads",
rt.compute_pool.as_ref().unwrap().num_threads()
);
}
Err(e) => {
tracing::warn!(
"Failed to create compute pool: {}. CPU-intensive operations will use spawn_blocking",
e
);
}
}
}
let num_workers = config
.num_worker_threads
.unwrap_or_else(|| std::thread::available_parallelism().unwrap().get());
let permits = num_workers.saturating_sub(1).max(1);
rt.block_in_place_permits = Some(Arc::new(tokio::sync::Semaphore::new(permits)));
tracing::debug!(
"Initialized block_in_place permits: {} (from {} worker threads)",
permits,
num_workers
);
Ok(rt)
}
pub fn initialize_thread_local(&self) {
if let (Some(pool), Some(permits)) = (&self.compute_pool, &self.block_in_place_permits) {
crate::compute::thread_local::initialize_context(Arc::clone(pool), Arc::clone(permits));
}
}
pub async fn initialize_all_thread_locals(&self) -> anyhow::Result<()> {
if let (Some(pool), Some(permits)) = (&self.compute_pool, &self.block_in_place_permits) {
let num_workers = self.detect_worker_thread_count().await;
if num_workers == 0 {
return Err(anyhow::anyhow!("No worker threads detected"));
}
let barrier = Arc::new(std::sync::Barrier::new(num_workers));
let init_pool = Arc::clone(pool);
let init_permits = Arc::clone(permits);
let mut handles = Vec::new();
for i in 0..num_workers {
let barrier_clone = Arc::clone(&barrier);
let pool_clone = Arc::clone(&init_pool);
let permits_clone = Arc::clone(&init_permits);
let handle = tokio::task::spawn_blocking(move || {
barrier_clone.wait();
crate::compute::thread_local::initialize_context(pool_clone, permits_clone);
let thread_id = std::thread::current().id();
tracing::trace!(
"Initialized thread-local compute context on thread {:?} (worker {})",
thread_id,
i
);
});
handles.push(handle);
}
for handle in handles {
handle.await?;
}
tracing::info!(
"Successfully initialized thread-local compute context on {} worker threads",
num_workers
);
} else {
tracing::debug!("No compute pool configured, skipping thread-local initialization");
}
Ok(())
}
async fn detect_worker_thread_count(&self) -> usize {
use parking_lot::Mutex;
use std::collections::HashSet;
let thread_ids = Arc::new(Mutex::new(HashSet::new()));
let mut handles = Vec::new();
let num_probes = 100;
for _ in 0..num_probes {
let ids = Arc::clone(&thread_ids);
let handle = tokio::task::spawn_blocking(move || {
let thread_id = std::thread::current().id();
ids.lock().insert(thread_id);
});
handles.push(handle);
}
for handle in handles {
let _ = handle.await;
}
let count = thread_ids.lock().len();
tracing::debug!("Detected {} worker threads in runtime", count);
count
}
pub fn from_current() -> anyhow::Result<Runtime> {
Runtime::from_handle(tokio::runtime::Handle::current())
}
pub fn from_handle(handle: tokio::runtime::Handle) -> anyhow::Result<Runtime> {
let primary = RuntimeType::External(handle.clone());
let secondary = RuntimeType::External(handle);
Runtime::new(primary, Some(secondary))
}
pub fn from_settings() -> anyhow::Result<Runtime> {
let config = config::RuntimeConfig::from_settings()?;
let runtime = Arc::new(ManuallyDrop::new(config.create_runtime()?));
let primary = RuntimeType::Shared(runtime.clone());
let secondary = RuntimeType::External(runtime.handle().clone());
Runtime::new_with_config(primary, Some(secondary), &config)
}
pub fn single_threaded() -> anyhow::Result<Runtime> {
let config = config::RuntimeConfig::single_threaded();
let owned = RuntimeType::Shared(Arc::new(ManuallyDrop::new(config.create_runtime()?)));
Runtime::new(owned, None)
}
pub fn id(&self) -> &str {
&self.id
}
pub fn primary(&self) -> tokio::runtime::Handle {
self.primary.handle()
}
pub fn secondary(&self) -> tokio::runtime::Handle {
self.secondary.handle()
}
pub fn primary_token(&self) -> CancellationToken {
self.cancellation_token.clone()
}
pub fn child_token(&self) -> CancellationToken {
self.endpoint_shutdown_token.child_token()
}
pub(crate) fn graceful_shutdown_tracker(&self) -> Arc<GracefulShutdownTracker> {
self.graceful_shutdown_tracker.clone()
}
pub fn compute_pool(&self) -> Option<&Arc<crate::compute::ComputePool>> {
self.compute_pool.as_ref()
}
pub fn shutdown(&self) {
tracing::info!("Runtime shutdown initiated");
let tracker = self.graceful_shutdown_tracker.clone();
let main_token = self.cancellation_token.clone();
let endpoint_token = self.endpoint_shutdown_token.clone();
let handle = self.primary();
handle.spawn(async move {
tracing::info!("Phase 1: Cancelling endpoint shutdown token");
endpoint_token.cancel();
tracing::info!("Phase 2: Waiting for graceful endpoints to complete");
let count = tracker.get_count();
tracing::info!("Active graceful endpoints: {}", count);
if count != 0 {
tracker.wait_for_completion().await;
}
tracing::info!(
"Phase 3: All endpoints ended gracefully. Connections to backend services will now be disconnected"
);
main_token.cancel();
});
}
}
impl RuntimeType {
    /// Returns a `Handle` to the underlying Tokio runtime, regardless of
    /// whether it is owned (`Shared`) or borrowed (`External`).
    pub fn handle(&self) -> tokio::runtime::Handle {
        match self {
            RuntimeType::Shared(owned) => owned.handle().clone(),
            RuntimeType::External(handle) => handle.clone(),
        }
    }
}
impl Drop for RuntimeType {
    /// Tears down an owned runtime exactly once, choosing a drop strategy
    /// that is safe for the calling context.
    fn drop(&mut self) {
        match self {
            // Borrowed handles are owned elsewhere; nothing to do.
            RuntimeType::External(_) => {}
            RuntimeType::Shared(arc) => {
                // Only the last clone tears the runtime down: `Arc::get_mut`
                // returns `None` while other strong references exist.
                let Some(md_runtime) = Arc::get_mut(arc) else {
                    return;
                };
                if tokio::runtime::Handle::try_current().is_ok() {
                    // We are inside a Tokio runtime context, where dropping a
                    // runtime synchronously is not allowed; detach it via
                    // `shutdown_background` instead.
                    // SAFETY: `get_mut` succeeded, so this is the sole
                    // reference, and the `ManuallyDrop` slot is never read
                    // again after `take` — the value is dropped exactly once
                    // by `shutdown_background`.
                    let tokio_runtime = unsafe { ManuallyDrop::take(md_runtime) };
                    tokio_runtime.shutdown_background();
                } else {
                    // SAFETY: sole reference (see above); the inner runtime
                    // is dropped exactly once, here.
                    unsafe { ManuallyDrop::drop(md_runtime) };
                }
            }
        }
    }
}