// dynamo_runtime/runtime.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! The [Runtime] module is the interface for [crate::component::Component]
5//! to access shared resources. These include thread pool, memory allocators and other shared resources.
6//!
7//! The [Runtime] holds the primary [`CancellationToken`] which can be used to terminate all attached
8//! [`crate::component::Component`].
9//!
10//! We expect in the future to offer topologically aware thread and memory resources, but for now the
11//! set of resources is limited to the thread pool and cancellation token.
12//!
13//! Notes: We will need to do an evaluation on what is fully public, what is pub(crate) and what is
14//! private; however, for now we are exposing most objects as fully public while the API is maturing.
15
16use super::utils::GracefulShutdownTracker;
17use crate::{
18    compute,
19    config::{self, RuntimeConfig},
20};
21
22use futures::Future;
23use once_cell::sync::OnceCell;
24use std::sync::{Arc, atomic::Ordering};
25use tokio::{signal, sync::Mutex, task::JoinHandle};
26
27pub use tokio_util::sync::CancellationToken;
28
/// Types of Tokio runtimes that can be used to construct a Dynamo [Runtime].
#[derive(Clone, Debug)]
enum RuntimeType {
    // Runtime owned by this process and shared via `Arc`; the Tokio runtime
    // is dropped (and shut down) when the last clone goes away.
    Shared(Arc<tokio::runtime::Runtime>),
    // Handle to a runtime owned elsewhere (e.g. the caller's existing Tokio
    // runtime); its lifetime is managed externally.
    External(tokio::runtime::Handle),
}
35
/// Local [Runtime] which provides access to shared resources local to the physical node/machine.
#[derive(Debug, Clone)]
pub struct Runtime {
    // Unique id for this runtime instance (UUID v4 rendered as a string).
    id: Arc<String>,
    // Primary/application thread pool.
    primary: RuntimeType,
    // Secondary pool for background tasks (e.g. etcd/NATS traffic).
    secondary: RuntimeType,
    // Root cancellation token; cancelling it terminates everything attached.
    cancellation_token: CancellationToken,
    // Child of `cancellation_token`; cancelled first during shutdown so
    // endpoints stop accepting new requests before connections are dropped.
    endpoint_shutdown_token: CancellationToken,
    // Tracks in-flight graceful endpoints so shutdown can wait for them.
    graceful_shutdown_tracker: Arc<GracefulShutdownTracker>,
    // CPU compute pool; None when disabled (compute_threads = 0), when pool
    // creation failed, or when constructed without a config.
    compute_pool: Option<Arc<compute::ComputePool>>,
    // Semaphore sized from the worker-thread count; presumably gates
    // concurrent `block_in_place` sections — None unless set via config.
    block_in_place_permits: Option<Arc<tokio::sync::Semaphore>>,
}
48
49impl Runtime {
50    fn new(runtime: RuntimeType, secondary: Option<RuntimeType>) -> anyhow::Result<Runtime> {
51        // worker id
52        let id = Arc::new(uuid::Uuid::new_v4().to_string());
53
54        // create a cancellation token
55        let cancellation_token = CancellationToken::new();
56
57        // create endpoint shutdown token as a child of the main token
58        let endpoint_shutdown_token = cancellation_token.child_token();
59
60        // secondary runtime for background ectd/nats tasks
61        let secondary = match secondary {
62            Some(secondary) => secondary,
63            None => {
64                tracing::debug!("Created secondary runtime with single thread");
65                RuntimeType::Shared(Arc::new(RuntimeConfig::single_threaded().create_runtime()?))
66            }
67        };
68
69        // Initialize compute pool with default config
70        // This will be properly configured when created from RuntimeConfig
71        let compute_pool = None;
72        let block_in_place_permits = None;
73
74        Ok(Runtime {
75            id,
76            primary: runtime,
77            secondary,
78            cancellation_token,
79            endpoint_shutdown_token,
80            graceful_shutdown_tracker: Arc::new(GracefulShutdownTracker::new()),
81            compute_pool,
82            block_in_place_permits,
83        })
84    }
85
86    fn new_with_config(
87        runtime: RuntimeType,
88        secondary: Option<RuntimeType>,
89        config: &RuntimeConfig,
90    ) -> anyhow::Result<Runtime> {
91        let mut rt = Self::new(runtime, secondary)?;
92
93        // Create compute pool from configuration
94        let compute_config = crate::compute::ComputeConfig {
95            num_threads: config.compute_threads,
96            stack_size: config.compute_stack_size,
97            thread_prefix: config.compute_thread_prefix.clone(),
98            pin_threads: false,
99        };
100
101        // Check if compute pool is explicitly disabled
102        if config.compute_threads == Some(0) {
103            tracing::info!("Compute pool disabled (compute_threads = 0)");
104        } else {
105            match crate::compute::ComputePool::new(compute_config) {
106                Ok(pool) => {
107                    rt.compute_pool = Some(Arc::new(pool));
108                    tracing::debug!(
109                        "Initialized compute pool with {} threads",
110                        rt.compute_pool.as_ref().unwrap().num_threads()
111                    );
112                }
113                Err(e) => {
114                    tracing::warn!(
115                        "Failed to create compute pool: {}. CPU-intensive operations will use spawn_blocking",
116                        e
117                    );
118                }
119            }
120        }
121
122        // Initialize block_in_place semaphore based on actual worker threads
123        let num_workers = config
124            .num_worker_threads
125            .unwrap_or_else(|| std::thread::available_parallelism().unwrap().get());
126        // Reserve at least one thread for async work
127        let permits = num_workers.saturating_sub(1).max(1);
128        rt.block_in_place_permits = Some(Arc::new(tokio::sync::Semaphore::new(permits)));
129        tracing::debug!(
130            "Initialized block_in_place permits: {} (from {} worker threads)",
131            permits,
132            num_workers
133        );
134
135        Ok(rt)
136    }
137
138    /// Initialize thread-local compute context on the current thread
139    /// This should be called on each Tokio worker thread
140    pub fn initialize_thread_local(&self) {
141        if let (Some(pool), Some(permits)) = (&self.compute_pool, &self.block_in_place_permits) {
142            crate::compute::thread_local::initialize_context(Arc::clone(pool), Arc::clone(permits));
143        }
144    }
145
    /// Initialize thread-local compute context on all worker threads using a barrier
    /// This ensures every worker thread has its thread-local context initialized
    ///
    /// NOTE(review): `tokio::task::spawn_blocking` schedules onto Tokio's
    /// dedicated blocking-thread pool, not the async worker threads. The
    /// barrier below therefore forces `num_workers` distinct blocking-pool
    /// threads into existence and initializes those — confirm this matches
    /// where the compute context is actually read.
    ///
    /// # Errors
    /// Fails if no worker threads are detected, or if any spawned
    /// initialization task panics (its `JoinError` is propagated via `?`).
    pub async fn initialize_all_thread_locals(&self) -> anyhow::Result<()> {
        if let (Some(pool), Some(permits)) = (&self.compute_pool, &self.block_in_place_permits) {
            // First, detect how many worker threads we actually have
            let num_workers = self.detect_worker_thread_count().await;

            if num_workers == 0 {
                return Err(anyhow::anyhow!("No worker threads detected"));
            }

            // Create a barrier that all threads must reach.
            // Each task parks at the barrier, so no thread can serve two
            // tasks — guaranteeing `num_workers` distinct threads participate.
            let barrier = Arc::new(std::sync::Barrier::new(num_workers));
            let init_pool = Arc::clone(pool);
            let init_permits = Arc::clone(permits);

            // Spawn exactly one blocking task per worker thread
            let mut handles = Vec::new();
            for i in 0..num_workers {
                let barrier_clone = Arc::clone(&barrier);
                let pool_clone = Arc::clone(&init_pool);
                let permits_clone = Arc::clone(&init_permits);

                let handle = tokio::task::spawn_blocking(move || {
                    // Wait at barrier - ensures all threads are participating
                    barrier_clone.wait();

                    // Now initialize thread-local storage
                    crate::compute::thread_local::initialize_context(pool_clone, permits_clone);

                    // Get thread ID for logging
                    let thread_id = std::thread::current().id();
                    tracing::trace!(
                        "Initialized thread-local compute context on thread {:?} (worker {})",
                        thread_id,
                        i
                    );
                });
                handles.push(handle);
            }

            // Wait for all tasks to complete
            for handle in handles {
                handle.await?;
            }

            tracing::info!(
                "Successfully initialized thread-local compute context on {} worker threads",
                num_workers
            );
        } else {
            tracing::debug!("No compute pool configured, skipping thread-local initialization");
        }
        Ok(())
    }
201
    /// Detect the number of worker threads in the runtime
    ///
    /// NOTE(review): the probes use `spawn_blocking`, which runs on Tokio's
    /// blocking-thread pool rather than the async worker threads, so this
    /// counts distinct blocking-pool threads touched by 100 probe tasks —
    /// a heuristic shaped by thread reuse, not the configured worker-thread
    /// count. Verify against the runtime configuration if an exact count is
    /// required.
    async fn detect_worker_thread_count(&self) -> usize {
        use parking_lot::Mutex;
        use std::collections::HashSet;

        // Thread-safe set of unique thread ids observed by the probes.
        let thread_ids = Arc::new(Mutex::new(HashSet::new()));
        let mut handles = Vec::new();

        // Spawn many blocking tasks to ensure we hit all threads
        // We use spawn_blocking because it runs on worker threads
        let num_probes = 100;
        for _ in 0..num_probes {
            let ids = Arc::clone(&thread_ids);
            let handle = tokio::task::spawn_blocking(move || {
                let thread_id = std::thread::current().id();
                ids.lock().insert(thread_id);
            });
            handles.push(handle);
        }

        // Wait for all probes to complete
        for handle in handles {
            // Probe failures are ignored; a lost probe only lowers the count.
            let _ = handle.await;
        }

        let count = thread_ids.lock().len();
        tracing::debug!("Detected {} worker threads in runtime", count);
        count
    }
231
232    pub fn from_current() -> anyhow::Result<Runtime> {
233        Runtime::from_handle(tokio::runtime::Handle::current())
234    }
235
236    pub fn from_handle(handle: tokio::runtime::Handle) -> anyhow::Result<Runtime> {
237        let primary = RuntimeType::External(handle.clone());
238        let secondary = RuntimeType::External(handle);
239        Runtime::new(primary, Some(secondary))
240    }
241
242    /// Create a [`Runtime`] instance from the settings
243    /// See [`config::RuntimeConfig::from_settings`]
244    pub fn from_settings() -> anyhow::Result<Runtime> {
245        let config = config::RuntimeConfig::from_settings()?;
246        let runtime = Arc::new(config.create_runtime()?);
247        let primary = RuntimeType::Shared(runtime.clone());
248        let secondary = RuntimeType::External(runtime.handle().clone());
249        Runtime::new_with_config(primary, Some(secondary), &config)
250    }
251
252    /// Create a [`Runtime`] with two single-threaded async tokio runtime
253    pub fn single_threaded() -> anyhow::Result<Runtime> {
254        let config = config::RuntimeConfig::single_threaded();
255        let owned = RuntimeType::Shared(Arc::new(config.create_runtime()?));
256        Runtime::new(owned, None)
257    }
258
259    /// Returns the unique identifier for the [`Runtime`]
260    pub fn id(&self) -> &str {
261        &self.id
262    }
263
264    /// Returns a [`tokio::runtime::Handle`] for the primary/application thread pool
265    pub fn primary(&self) -> tokio::runtime::Handle {
266        self.primary.handle()
267    }
268
269    /// Returns a [`tokio::runtime::Handle`] for the secondary/background thread pool
270    pub fn secondary(&self) -> tokio::runtime::Handle {
271        self.secondary.handle()
272    }
273
274    /// Access the primary [`CancellationToken`] for the [`Runtime`]
275    pub fn primary_token(&self) -> CancellationToken {
276        self.cancellation_token.clone()
277    }
278
279    /// Creates a child [`CancellationToken`] tied to the life-cycle of the [`Runtime`]'s endpoint shutdown token.
280    pub fn child_token(&self) -> CancellationToken {
281        self.endpoint_shutdown_token.child_token()
282    }
283
284    /// Get access to the graceful shutdown tracker
285    pub(crate) fn graceful_shutdown_tracker(&self) -> Arc<GracefulShutdownTracker> {
286        self.graceful_shutdown_tracker.clone()
287    }
288
289    /// Get access to the compute pool for CPU-intensive operations
290    ///
291    /// Returns None if the compute pool was not initialized (e.g., due to configuration error)
292    pub fn compute_pool(&self) -> Option<&Arc<crate::compute::ComputePool>> {
293        self.compute_pool.as_ref()
294    }
295
296    /// Shuts down the [`Runtime`] instance
297    pub fn shutdown(&self) {
298        tracing::info!("Runtime shutdown initiated");
299
300        // Spawn the shutdown coordination task BEFORE cancelling tokens
301        let tracker = self.graceful_shutdown_tracker.clone();
302        let main_token = self.cancellation_token.clone();
303        let endpoint_token = self.endpoint_shutdown_token.clone();
304
305        // Use the runtime handle to spawn the task
306        let handle = self.primary();
307        handle.spawn(async move {
308            // Phase 1: Cancel endpoint shutdown token to stop accepting new requests
309            tracing::info!("Phase 1: Cancelling endpoint shutdown token");
310            endpoint_token.cancel();
311
312            // Phase 2: Wait for all graceful endpoints to complete
313            tracing::info!("Phase 2: Waiting for graceful endpoints to complete");
314
315            let count = tracker.get_count();
316            tracing::info!("Active graceful endpoints: {}", count);
317
318            if count != 0 {
319                tracker.wait_for_completion().await;
320            }
321
322            // Phase 3: Now connections will be disconnected to NATS/ETCD by cancelling the main token
323            tracing::info!(
324                "Phase 3: All endpoints ended gracefully. Connections to NATS/ETCD will now be disconnected"
325            );
326            main_token.cancel();
327        });
328    }
329}
330
331impl RuntimeType {
332    /// Get [`tokio::runtime::Handle`] to runtime
333    pub fn handle(&self) -> tokio::runtime::Handle {
334        match self {
335            RuntimeType::External(rt) => rt.clone(),
336            RuntimeType::Shared(rt) => rt.handle().clone(),
337        }
338    }
339}