dynamo_runtime/
runtime.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! The [Runtime] module is the interface for [crate::component::Component]
5//! to access shared resources. These include thread pool, memory allocators and other shared resources.
6//!
7//! The [Runtime] holds the primary [`CancellationToken`] which can be used to terminate all attached
8//! [`crate::component::Component`].
9//!
10//! We expect in the future to offer topologically aware thread and memory resources, but for now the
11//! set of resources is limited to the thread pool and cancellation token.
12//!
13//! Notes: We will need to do an evaluation on what is fully public, what is pub(crate) and what is
14//! private; however, for now we are exposing most objects as fully public while the API is maturing.
15
16use super::utils::GracefulShutdownTracker;
17use super::{Result, Runtime, RuntimeType, error};
18use crate::config::{self, RuntimeConfig};
19
20use futures::Future;
21use once_cell::sync::OnceCell;
22use std::sync::{Arc, atomic::Ordering};
23use tokio::{signal, sync::Mutex, task::JoinHandle};
24
25pub use tokio_util::sync::CancellationToken;
26
27impl Runtime {
28    fn new(runtime: RuntimeType, secondary: Option<RuntimeType>) -> Result<Runtime> {
29        // worker id
30        let id = Arc::new(uuid::Uuid::new_v4().to_string());
31
32        // create a cancellation token
33        let cancellation_token = CancellationToken::new();
34
35        // create endpoint shutdown token as a child of the main token
36        let endpoint_shutdown_token = cancellation_token.child_token();
37
38        // secondary runtime for background ectd/nats tasks
39        let secondary = match secondary {
40            Some(secondary) => secondary,
41            None => {
42                tracing::debug!("Created secondary runtime with single thread");
43                RuntimeType::Shared(Arc::new(RuntimeConfig::single_threaded().create_runtime()?))
44            }
45        };
46
47        // Initialize compute pool with default config
48        // This will be properly configured when created from RuntimeConfig
49        let compute_pool = None;
50        let block_in_place_permits = None;
51
52        Ok(Runtime {
53            id,
54            primary: runtime,
55            secondary,
56            cancellation_token,
57            endpoint_shutdown_token,
58            graceful_shutdown_tracker: Arc::new(GracefulShutdownTracker::new()),
59            compute_pool,
60            block_in_place_permits,
61        })
62    }
63
64    fn new_with_config(
65        runtime: RuntimeType,
66        secondary: Option<RuntimeType>,
67        config: &RuntimeConfig,
68    ) -> Result<Runtime> {
69        let mut rt = Self::new(runtime, secondary)?;
70
71        // Create compute pool from configuration
72        let compute_config = crate::compute::ComputeConfig {
73            num_threads: config.compute_threads,
74            stack_size: config.compute_stack_size,
75            thread_prefix: config.compute_thread_prefix.clone(),
76            pin_threads: false,
77        };
78
79        // Check if compute pool is explicitly disabled
80        if config.compute_threads == Some(0) {
81            tracing::info!("Compute pool disabled (compute_threads = 0)");
82        } else {
83            match crate::compute::ComputePool::new(compute_config) {
84                Ok(pool) => {
85                    rt.compute_pool = Some(Arc::new(pool));
86                    tracing::debug!(
87                        "Initialized compute pool with {} threads",
88                        rt.compute_pool.as_ref().unwrap().num_threads()
89                    );
90                }
91                Err(e) => {
92                    tracing::warn!(
93                        "Failed to create compute pool: {}. CPU-intensive operations will use spawn_blocking",
94                        e
95                    );
96                }
97            }
98        }
99
100        // Initialize block_in_place semaphore based on actual worker threads
101        let num_workers = config
102            .num_worker_threads
103            .unwrap_or_else(|| std::thread::available_parallelism().unwrap().get());
104        // Reserve at least one thread for async work
105        let permits = num_workers.saturating_sub(1).max(1);
106        rt.block_in_place_permits = Some(Arc::new(tokio::sync::Semaphore::new(permits)));
107        tracing::debug!(
108            "Initialized block_in_place permits: {} (from {} worker threads)",
109            permits,
110            num_workers
111        );
112
113        Ok(rt)
114    }
115
116    /// Initialize thread-local compute context on the current thread
117    /// This should be called on each Tokio worker thread
118    pub fn initialize_thread_local(&self) {
119        if let (Some(pool), Some(permits)) = (&self.compute_pool, &self.block_in_place_permits) {
120            crate::compute::thread_local::initialize_context(Arc::clone(pool), Arc::clone(permits));
121        }
122    }
123
124    /// Initialize thread-local compute context on all worker threads using a barrier
125    /// This ensures every worker thread has its thread-local context initialized
126    pub async fn initialize_all_thread_locals(&self) -> Result<()> {
127        if let (Some(pool), Some(permits)) = (&self.compute_pool, &self.block_in_place_permits) {
128            // First, detect how many worker threads we actually have
129            let num_workers = self.detect_worker_thread_count().await;
130
131            if num_workers == 0 {
132                return Err(anyhow::anyhow!("No worker threads detected"));
133            }
134
135            // Create a barrier that all threads must reach
136            let barrier = Arc::new(std::sync::Barrier::new(num_workers));
137            let init_pool = Arc::clone(pool);
138            let init_permits = Arc::clone(permits);
139
140            // Spawn exactly one blocking task per worker thread
141            let mut handles = Vec::new();
142            for i in 0..num_workers {
143                let barrier_clone = Arc::clone(&barrier);
144                let pool_clone = Arc::clone(&init_pool);
145                let permits_clone = Arc::clone(&init_permits);
146
147                let handle = tokio::task::spawn_blocking(move || {
148                    // Wait at barrier - ensures all threads are participating
149                    barrier_clone.wait();
150
151                    // Now initialize thread-local storage
152                    crate::compute::thread_local::initialize_context(pool_clone, permits_clone);
153
154                    // Get thread ID for logging
155                    let thread_id = std::thread::current().id();
156                    tracing::trace!(
157                        "Initialized thread-local compute context on thread {:?} (worker {})",
158                        thread_id,
159                        i
160                    );
161                });
162                handles.push(handle);
163            }
164
165            // Wait for all tasks to complete
166            for handle in handles {
167                handle.await?;
168            }
169
170            tracing::info!(
171                "Successfully initialized thread-local compute context on {} worker threads",
172                num_workers
173            );
174        } else {
175            tracing::debug!("No compute pool configured, skipping thread-local initialization");
176        }
177        Ok(())
178    }
179
180    /// Detect the number of worker threads in the runtime
181    async fn detect_worker_thread_count(&self) -> usize {
182        use parking_lot::Mutex;
183        use std::collections::HashSet;
184
185        let thread_ids = Arc::new(Mutex::new(HashSet::new()));
186        let mut handles = Vec::new();
187
188        // Spawn many blocking tasks to ensure we hit all threads
189        // We use spawn_blocking because it runs on worker threads
190        let num_probes = 100;
191        for _ in 0..num_probes {
192            let ids = Arc::clone(&thread_ids);
193            let handle = tokio::task::spawn_blocking(move || {
194                let thread_id = std::thread::current().id();
195                ids.lock().insert(thread_id);
196            });
197            handles.push(handle);
198        }
199
200        // Wait for all probes to complete
201        for handle in handles {
202            let _ = handle.await;
203        }
204
205        let count = thread_ids.lock().len();
206        tracing::debug!("Detected {} worker threads in runtime", count);
207        count
208    }
209
210    pub fn from_current() -> Result<Runtime> {
211        Runtime::from_handle(tokio::runtime::Handle::current())
212    }
213
214    pub fn from_handle(handle: tokio::runtime::Handle) -> Result<Runtime> {
215        let primary = RuntimeType::External(handle.clone());
216        let secondary = RuntimeType::External(handle);
217        Runtime::new(primary, Some(secondary))
218    }
219
220    /// Create a [`Runtime`] instance from the settings
221    /// See [`config::RuntimeConfig::from_settings`]
222    pub fn from_settings() -> Result<Runtime> {
223        let config = config::RuntimeConfig::from_settings()?;
224        let runtime = Arc::new(config.create_runtime()?);
225        let primary = RuntimeType::Shared(runtime.clone());
226        let secondary = RuntimeType::External(runtime.handle().clone());
227        Runtime::new_with_config(primary, Some(secondary), &config)
228    }
229
230    /// Create a [`Runtime`] with two single-threaded async tokio runtime
231    pub fn single_threaded() -> Result<Runtime> {
232        let config = config::RuntimeConfig::single_threaded();
233        let owned = RuntimeType::Shared(Arc::new(config.create_runtime()?));
234        Runtime::new(owned, None)
235    }
236
237    /// Returns the unique identifier for the [`Runtime`]
238    pub fn id(&self) -> &str {
239        &self.id
240    }
241
242    /// Returns a [`tokio::runtime::Handle`] for the primary/application thread pool
243    pub fn primary(&self) -> tokio::runtime::Handle {
244        self.primary.handle()
245    }
246
247    /// Returns a [`tokio::runtime::Handle`] for the secondary/background thread pool
248    pub fn secondary(&self) -> tokio::runtime::Handle {
249        self.secondary.handle()
250    }
251
252    /// Access the primary [`CancellationToken`] for the [`Runtime`]
253    pub fn primary_token(&self) -> CancellationToken {
254        self.cancellation_token.clone()
255    }
256
257    /// Creates a child [`CancellationToken`] tied to the life-cycle of the [`Runtime`]'s endpoint shutdown token.
258    pub fn child_token(&self) -> CancellationToken {
259        self.endpoint_shutdown_token.child_token()
260    }
261
262    /// Get access to the graceful shutdown tracker
263    pub(crate) fn graceful_shutdown_tracker(&self) -> Arc<GracefulShutdownTracker> {
264        self.graceful_shutdown_tracker.clone()
265    }
266
267    /// Get access to the compute pool for CPU-intensive operations
268    ///
269    /// Returns None if the compute pool was not initialized (e.g., due to configuration error)
270    pub fn compute_pool(&self) -> Option<&Arc<crate::compute::ComputePool>> {
271        self.compute_pool.as_ref()
272    }
273
274    /// Shuts down the [`Runtime`] instance
275    pub fn shutdown(&self) {
276        tracing::info!("Runtime shutdown initiated");
277
278        // Spawn the shutdown coordination task BEFORE cancelling tokens
279        let tracker = self.graceful_shutdown_tracker.clone();
280        let main_token = self.cancellation_token.clone();
281        let endpoint_token = self.endpoint_shutdown_token.clone();
282
283        // Use the runtime handle to spawn the task
284        let handle = self.primary();
285        handle.spawn(async move {
286            // Phase 1: Cancel endpoint shutdown token to stop accepting new requests
287            tracing::info!("Phase 1: Cancelling endpoint shutdown token");
288            endpoint_token.cancel();
289
290            // Phase 2: Wait for all graceful endpoints to complete
291            tracing::info!("Phase 2: Waiting for graceful endpoints to complete");
292
293            let count = tracker.get_count();
294            tracing::info!("Active graceful endpoints: {}", count);
295
296            if count != 0 {
297                tracker.wait_for_completion().await;
298            }
299
300            // Phase 3: Now connections will be disconnected to NATS/ETCD by cancelling the main token
301            tracing::info!(
302                "Phase 3: All endpoints ended gracefully. Connections to NATS/ETCD will now be disconnected"
303            );
304            main_token.cancel();
305        });
306    }
307}
308
309impl RuntimeType {
310    /// Get [`tokio::runtime::Handle`] to runtime
311    pub fn handle(&self) -> tokio::runtime::Handle {
312        match self {
313            RuntimeType::External(rt) => rt.clone(),
314            RuntimeType::Shared(rt) => rt.handle().clone(),
315        }
316    }
317}
318
319impl std::fmt::Debug for RuntimeType {
320    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
321        match self {
322            RuntimeType::External(_) => write!(f, "RuntimeType::External"),
323            RuntimeType::Shared(_) => write!(f, "RuntimeType::Shared"),
324        }
325    }
326}