Skip to main content

dynamo_runtime/
runtime.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! The [Runtime] module is the interface for [crate::component::Component]
5//! to access shared resources. These include thread pool, memory allocators and other shared resources.
6//!
7//! The [Runtime] holds the primary [`CancellationToken`] which can be used to terminate all attached
8//! [`crate::component::Component`].
9//!
10//! We expect in the future to offer topologically aware thread and memory resources, but for now the
11//! set of resources is limited to the thread pool and cancellation token.
12//!
13//! Notes: We will need to do an evaluation on what is fully public, what is pub(crate) and what is
14//! private; however, for now we are exposing most objects as fully public while the API is maturing.
15
16use super::utils::GracefulShutdownTracker;
17use crate::{
18    compute,
19    config::{self, RuntimeConfig},
20};
21
22use futures::Future;
23use once_cell::sync::OnceCell;
24use std::{
25    mem::ManuallyDrop,
26    sync::{Arc, atomic::Ordering},
27};
28use tokio::{signal, sync::Mutex, task::JoinHandle};
29
30pub use tokio_util::sync::CancellationToken;
31
32/// Types of Tokio runtimes that can be used to construct a Dynamo [Runtime].
33#[derive(Clone, Debug)]
34enum RuntimeType {
35    Shared(Arc<ManuallyDrop<tokio::runtime::Runtime>>),
36    External(tokio::runtime::Handle),
37}
38
39/// Local [Runtime] which provides access to shared resources local to the physical node/machine.
40#[derive(Debug, Clone)]
41pub struct Runtime {
42    id: Arc<String>,
43    primary: RuntimeType,
44    secondary: RuntimeType,
45    cancellation_token: CancellationToken,
46    endpoint_shutdown_token: CancellationToken,
47    graceful_shutdown_tracker: Arc<GracefulShutdownTracker>,
48    compute_pool: Option<Arc<compute::ComputePool>>,
49    block_in_place_permits: Option<Arc<tokio::sync::Semaphore>>,
50}
51
52impl Runtime {
53    fn new(runtime: RuntimeType, secondary: Option<RuntimeType>) -> anyhow::Result<Runtime> {
54        // Initialise NVTX toggle once from environment (no-op when feature is off)
55        crate::nvtx::init();
56
57        // worker id
58        let id = Arc::new(uuid::Uuid::new_v4().to_string());
59
60        // create a cancellation token
61        let cancellation_token = CancellationToken::new();
62
63        // create endpoint shutdown token as a child of the main token
64        let endpoint_shutdown_token = cancellation_token.child_token();
65
66        // secondary runtime for background ectd/nats tasks
67        let secondary = match secondary {
68            Some(secondary) => secondary,
69            None => {
70                tracing::debug!("Created secondary runtime with single thread");
71                RuntimeType::Shared(Arc::new(ManuallyDrop::new(
72                    RuntimeConfig::single_threaded().create_runtime()?,
73                )))
74            }
75        };
76
77        // Initialize compute pool with default config
78        // This will be properly configured when created from RuntimeConfig
79        let compute_pool = None;
80        let block_in_place_permits = None;
81
82        Ok(Runtime {
83            id,
84            primary: runtime,
85            secondary,
86            cancellation_token,
87            endpoint_shutdown_token,
88            graceful_shutdown_tracker: Arc::new(GracefulShutdownTracker::new()),
89            compute_pool,
90            block_in_place_permits,
91        })
92    }
93
94    fn new_with_config(
95        runtime: RuntimeType,
96        secondary: Option<RuntimeType>,
97        config: &RuntimeConfig,
98    ) -> anyhow::Result<Runtime> {
99        let mut rt = Self::new(runtime, secondary)?;
100
101        // Create compute pool from configuration
102        let compute_config = crate::compute::ComputeConfig {
103            num_threads: config.compute_threads,
104            stack_size: config.compute_stack_size,
105            thread_prefix: config.compute_thread_prefix.clone(),
106            pin_threads: false,
107        };
108
109        // Check if compute pool is explicitly disabled
110        if config.compute_threads == Some(0) {
111            tracing::info!("Compute pool disabled (compute_threads = 0)");
112        } else {
113            match crate::compute::ComputePool::new(compute_config) {
114                Ok(pool) => {
115                    rt.compute_pool = Some(Arc::new(pool));
116                    tracing::debug!(
117                        "Initialized compute pool with {} threads",
118                        rt.compute_pool.as_ref().unwrap().num_threads()
119                    );
120                }
121                Err(e) => {
122                    tracing::warn!(
123                        "Failed to create compute pool: {}. CPU-intensive operations will use spawn_blocking",
124                        e
125                    );
126                }
127            }
128        }
129
130        // Initialize block_in_place semaphore based on actual worker threads
131        let num_workers = config
132            .num_worker_threads
133            .unwrap_or_else(|| std::thread::available_parallelism().unwrap().get());
134        // Reserve at least one thread for async work
135        let permits = num_workers.saturating_sub(1).max(1);
136        rt.block_in_place_permits = Some(Arc::new(tokio::sync::Semaphore::new(permits)));
137        tracing::debug!(
138            "Initialized block_in_place permits: {} (from {} worker threads)",
139            permits,
140            num_workers
141        );
142
143        Ok(rt)
144    }
145
146    /// Initialize thread-local compute context on the current thread
147    /// This should be called on each Tokio worker thread
148    pub fn initialize_thread_local(&self) {
149        if let (Some(pool), Some(permits)) = (&self.compute_pool, &self.block_in_place_permits) {
150            crate::compute::thread_local::initialize_context(Arc::clone(pool), Arc::clone(permits));
151        }
152        // Name this worker thread in the Nsight Systems timeline (no-op when nvtx feature is off)
153        let thread_name = std::thread::current()
154            .name()
155            .map(|n| n.to_string())
156            .unwrap_or_else(|| format!("tokio-worker-{:?}", std::thread::current().id()));
157        crate::nvtx::name_current_thread_impl(&thread_name);
158    }
159
160    /// Initialize thread-local compute context on all worker threads using a barrier
161    /// This ensures every worker thread has its thread-local context initialized
162    pub async fn initialize_all_thread_locals(&self) -> anyhow::Result<()> {
163        if let (Some(pool), Some(permits)) = (&self.compute_pool, &self.block_in_place_permits) {
164            // First, detect how many worker threads we actually have
165            let num_workers = self.detect_worker_thread_count().await;
166
167            if num_workers == 0 {
168                return Err(anyhow::anyhow!("No worker threads detected"));
169            }
170
171            // Create a barrier that all threads must reach
172            let barrier = Arc::new(std::sync::Barrier::new(num_workers));
173            let init_pool = Arc::clone(pool);
174            let init_permits = Arc::clone(permits);
175
176            // Spawn exactly one blocking task per worker thread
177            let mut handles = Vec::new();
178            for i in 0..num_workers {
179                let barrier_clone = Arc::clone(&barrier);
180                let pool_clone = Arc::clone(&init_pool);
181                let permits_clone = Arc::clone(&init_permits);
182
183                let handle = tokio::task::spawn_blocking(move || {
184                    // Wait at barrier - ensures all threads are participating
185                    barrier_clone.wait();
186
187                    // Now initialize thread-local storage
188                    crate::compute::thread_local::initialize_context(pool_clone, permits_clone);
189
190                    // Get thread ID for logging
191                    let thread_id = std::thread::current().id();
192                    tracing::trace!(
193                        "Initialized thread-local compute context on thread {:?} (worker {})",
194                        thread_id,
195                        i
196                    );
197                });
198                handles.push(handle);
199            }
200
201            // Wait for all tasks to complete
202            for handle in handles {
203                handle.await?;
204            }
205
206            tracing::info!(
207                "Successfully initialized thread-local compute context on {} worker threads",
208                num_workers
209            );
210        } else {
211            tracing::debug!("No compute pool configured, skipping thread-local initialization");
212        }
213        Ok(())
214    }
215
216    /// Detect the number of worker threads in the runtime
217    async fn detect_worker_thread_count(&self) -> usize {
218        use parking_lot::Mutex;
219        use std::collections::HashSet;
220
221        let thread_ids = Arc::new(Mutex::new(HashSet::new()));
222        let mut handles = Vec::new();
223
224        // Spawn many blocking tasks to ensure we hit all threads
225        // We use spawn_blocking because it runs on worker threads
226        let num_probes = 100;
227        for _ in 0..num_probes {
228            let ids = Arc::clone(&thread_ids);
229            let handle = tokio::task::spawn_blocking(move || {
230                let thread_id = std::thread::current().id();
231                ids.lock().insert(thread_id);
232            });
233            handles.push(handle);
234        }
235
236        // Wait for all probes to complete
237        for handle in handles {
238            let _ = handle.await;
239        }
240
241        let count = thread_ids.lock().len();
242        tracing::debug!("Detected {count} worker threads in runtime");
243        count
244    }
245
246    pub fn from_current() -> anyhow::Result<Runtime> {
247        Runtime::from_handle(tokio::runtime::Handle::current())
248    }
249
250    pub fn from_handle(handle: tokio::runtime::Handle) -> anyhow::Result<Runtime> {
251        let primary = RuntimeType::External(handle.clone());
252        let secondary = RuntimeType::External(handle);
253        Runtime::new(primary, Some(secondary))
254    }
255
256    /// Create a [`Runtime`] instance from the settings
257    /// See [`config::RuntimeConfig::from_settings`]
258    pub fn from_settings() -> anyhow::Result<Runtime> {
259        let config = config::RuntimeConfig::from_settings()?;
260        let runtime = Arc::new(ManuallyDrop::new(config.create_runtime()?));
261        let primary = RuntimeType::Shared(runtime.clone());
262        let secondary = RuntimeType::External(runtime.handle().clone());
263        Runtime::new_with_config(primary, Some(secondary), &config)
264    }
265
266    /// Create a [`Runtime`] with two single-threaded async tokio runtime
267    pub fn single_threaded() -> anyhow::Result<Runtime> {
268        let config = config::RuntimeConfig::single_threaded();
269        let owned = RuntimeType::Shared(Arc::new(ManuallyDrop::new(config.create_runtime()?)));
270        Runtime::new(owned, None)
271    }
272
273    /// Returns the unique identifier for the [`Runtime`]
274    pub fn id(&self) -> &str {
275        &self.id
276    }
277
278    /// Returns a [`tokio::runtime::Handle`] for the primary/application thread pool
279    pub fn primary(&self) -> tokio::runtime::Handle {
280        self.primary.handle()
281    }
282
283    /// Returns a [`tokio::runtime::Handle`] for the secondary/background thread pool
284    pub fn secondary(&self) -> tokio::runtime::Handle {
285        self.secondary.handle()
286    }
287
288    /// Access the primary [`CancellationToken`] for the [`Runtime`]
289    pub fn primary_token(&self) -> CancellationToken {
290        self.cancellation_token.clone()
291    }
292
293    /// Creates a child [`CancellationToken`] tied to the life-cycle of the [`Runtime`]'s endpoint shutdown token.
294    pub fn child_token(&self) -> CancellationToken {
295        self.endpoint_shutdown_token.child_token()
296    }
297
298    /// Get access to the graceful shutdown tracker
299    pub(crate) fn graceful_shutdown_tracker(&self) -> Arc<GracefulShutdownTracker> {
300        self.graceful_shutdown_tracker.clone()
301    }
302
303    /// Get access to the compute pool for CPU-intensive operations
304    ///
305    /// Returns None if the compute pool was not initialized (e.g., due to configuration error)
306    pub fn compute_pool(&self) -> Option<&Arc<crate::compute::ComputePool>> {
307        self.compute_pool.as_ref()
308    }
309
310    /// Shuts down the [`Runtime`] instance
311    pub fn shutdown(&self) {
312        tracing::info!("Runtime shutdown initiated");
313
314        // Spawn the shutdown coordination task BEFORE cancelling tokens
315        let tracker = self.graceful_shutdown_tracker.clone();
316        let main_token = self.cancellation_token.clone();
317        let endpoint_token = self.endpoint_shutdown_token.clone();
318
319        // Use the runtime handle to spawn the task
320        let handle = self.primary();
321        handle.spawn(async move {
322            // Phase 1: Cancel endpoint shutdown token to stop accepting new requests
323            tracing::info!("Phase 1: Cancelling endpoint shutdown token");
324            endpoint_token.cancel();
325
326            // Phase 2: Wait for all graceful endpoints to complete
327            tracing::info!("Phase 2: Waiting for graceful endpoints to complete");
328
329            let count = tracker.get_count();
330            tracing::info!("Active graceful endpoints: {count}");
331
332            if count != 0 {
333                tracker.wait_for_completion().await;
334            }
335
336            // Phase 3: Now connections will be disconnected to backend services (e.g. NATS/ETCD) by cancelling the main token
337            tracing::info!(
338                "Phase 3: All endpoints ended gracefully. Connections to backend services will now be disconnected"
339            );
340            main_token.cancel();
341        });
342    }
343}
344
345impl RuntimeType {
346    /// Get [`tokio::runtime::Handle`] to runtime
347    pub fn handle(&self) -> tokio::runtime::Handle {
348        match self {
349            RuntimeType::External(rt) => rt.clone(),
350            RuntimeType::Shared(rt) => rt.handle().clone(),
351        }
352    }
353}
354
355/// Handle dropping a tokio runtime from an async context.
356///
357/// When used from the Python bindings the runtime will be dropped from (I think) Python's asyncio.
358/// Tokio does not allow this and will panic. That panic prevents logging from printing it's last
359/// messages, which makes knowing what went wrong very difficult.
360///
361/// This is the panic:
362/// > pyo3_runtime.PanicException: Cannot drop a runtime in a context where blocking is not allowed.
363/// > This happens when a runtime is dropped from within an asynchronous context.
364///
365/// Hence we wrap the runtime in a ManuallyDrop and use tokio's alternative shutdown if we detect
366/// that we are inside an async runtime.
367impl Drop for RuntimeType {
368    fn drop(&mut self) {
369        match self {
370            RuntimeType::External(_) => {}
371            RuntimeType::Shared(arc) => {
372                let Some(md_runtime) = Arc::get_mut(arc) else {
373                    // Only drop if we are the only owner of the shared pointer, meaning
374                    // one strong count and no weak count.
375                    return;
376                };
377                if tokio::runtime::Handle::try_current().is_ok() {
378                    // We are inside an async runtime.
379                    let tokio_runtime = unsafe { ManuallyDrop::take(md_runtime) };
380                    tokio_runtime.shutdown_background();
381                } else {
382                    // We are not inside an async context, dropping the runtime is safe.
383                    //
384                    // We never reach this case. I'm not sure why, something about the interaction
385                    // with pyo3 and Python lifetimes.
386                    //
387                    // Process is gone so doesn't really matter, but TODO now that we realize it.
388                    unsafe { ManuallyDrop::drop(md_runtime) };
389                }
390            }
391        }
392    }
393}