// dynamo_runtime/distributed.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4pub use crate::component::Component;
5use crate::storage::key_value_store::{EtcdStore, KeyValueStore, MemoryStore};
6use crate::transports::nats::DRTNatsClientPrometheusMetrics;
7use crate::{
8    ErrorContext, PrometheusUpdateCallback,
9    component::{self, ComponentBuilder, Endpoint, InstanceSource, Namespace},
10    discovery::DiscoveryClient,
11    metrics::MetricsRegistry,
12    service::ServiceClient,
13    transports::{etcd, nats, tcp},
14};
15
16use super::utils::GracefulShutdownTracker;
17use super::{Arc, DistributedRuntime, OK, OnceCell, Result, Runtime, SystemHealth, Weak, error};
18use std::sync::OnceLock;
19
20use derive_getters::Dissolve;
21use figment::error;
22use std::collections::HashMap;
23use tokio::sync::Mutex;
24use tokio_util::sync::CancellationToken;
25
26impl MetricsRegistry for DistributedRuntime {
27    fn basename(&self) -> String {
28        "".to_string() // drt has no basename. Basename only begins with the Namespace.
29    }
30
31    fn parent_hierarchy(&self) -> Vec<String> {
32        vec![] // drt is the root, so no parent hierarchy
33    }
34}
35
36impl std::fmt::Debug for DistributedRuntime {
37    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
38        write!(f, "DistributedRuntime")
39    }
40}
41
impl DistributedRuntime {
    /// Construct a new [`DistributedRuntime`] on top of `runtime`.
    ///
    /// * Connects the NATS client, and — unless `config.is_static` — an etcd
    ///   client; static mode uses an in-memory key-value store instead.
    /// * Registers NATS-client Prometheus metrics and an uptime gauge.
    /// * When enabled via runtime settings, spawns the system status HTTP
    ///   server and the health check manager.
    ///
    /// # Errors
    /// Fails if the etcd or NATS connection cannot be established, or if
    /// metric registration fails.
    pub async fn new(runtime: Runtime, config: DistributedConfig) -> Result<Self> {
        let (etcd_config, nats_config, is_static) = config.dissolve();

        let runtime_clone = runtime.clone();

        // Static workers skip etcd-based discovery entirely and keep
        // key-value data in process memory.
        let (etcd_client, store) = if is_static {
            let store: Arc<dyn KeyValueStore> = Arc::new(MemoryStore::new());
            (None, store)
        } else {
            let etcd_client = etcd::Client::new(etcd_config.clone(), runtime_clone).await?;
            let store: Arc<dyn KeyValueStore> = Arc::new(EtcdStore::new(etcd_client.clone()));

            (Some(etcd_client), store)
        };

        let nats_client = nats_config.clone().connect().await?;

        // Start system status server for health and metrics if enabled in configuration
        // NOTE: `config` is intentionally shadowed here — the DistributedConfig was
        // dissolved above; from here on `config` is the RuntimeConfig from settings.
        let config = crate::config::RuntimeConfig::from_settings().unwrap_or_default();
        // IMPORTANT: We must extract cancel_token from runtime BEFORE moving runtime into the struct below.
        // This is because after moving, runtime is no longer accessible in this scope (ownership rules).
        let cancel_token = if config.system_server_enabled() {
            Some(runtime.clone().child_token())
        } else {
            None
        };
        let starting_health_status = config.starting_health_status.clone();
        let use_endpoint_health_status = config.use_endpoint_health_status.clone();
        let health_endpoint_path = config.system_health_path.clone();
        let live_endpoint_path = config.system_live_path.clone();
        // SystemHealth is shared with the status server and guarded by a std (sync) Mutex;
        // it is never held across an await below.
        let system_health = Arc::new(std::sync::Mutex::new(SystemHealth::new(
            starting_health_status,
            use_endpoint_health_status,
            health_endpoint_path,
            live_endpoint_path,
        )));

        // Keep a handle for metrics wiring; `nats_client` itself is moved into the struct.
        let nats_client_for_metrics = nats_client.clone();

        let distributed_runtime = Self {
            runtime,
            etcd_client,
            store,
            nats_client,
            tcp_server: Arc::new(OnceCell::new()),
            system_status_server: Arc::new(OnceLock::new()),
            component_registry: component::Registry::new(),
            is_static,
            instance_sources: Arc::new(Mutex::new(HashMap::new())),
            hierarchy_to_metricsregistry: Arc::new(std::sync::RwLock::new(HashMap::<
                String,
                crate::MetricsRegistryEntry,
            >::new())),
            system_health,
        };

        // NATS metrics must be created after the struct exists because they are
        // registered against the DRT's own metrics registry.
        let nats_client_metrics = DRTNatsClientPrometheusMetrics::new(
            &distributed_runtime,
            nats_client_for_metrics.client().clone(),
        )?;
        // Register the callback at the DRT's own hierarchy (parents + self).
        let mut drt_hierarchies = distributed_runtime.parent_hierarchy();
        drt_hierarchies.push(distributed_runtime.hierarchy());
        // Register a callback to update NATS client metrics
        let nats_client_callback = Arc::new({
            let nats_client_clone = nats_client_metrics.clone();
            move || {
                nats_client_clone.set_from_client_stats();
                Ok(())
            }
        });
        distributed_runtime
            .register_prometheus_update_callback(drt_hierarchies, nats_client_callback);

        // Initialize the uptime gauge in SystemHealth
        distributed_runtime
            .system_health
            .lock()
            .unwrap()
            .initialize_uptime_gauge(&distributed_runtime)?;

        // Handle system status server initialization
        if let Some(cancel_token) = cancel_token {
            // System server is enabled - start both the state and HTTP server
            let host = config.system_host.clone();
            let port = config.system_port;

            // Start system status server (it creates SystemStatusState internally)
            // A startup failure is logged but does not fail DRT construction.
            match crate::system_status_server::spawn_system_status_server(
                &host,
                port,
                cancel_token,
                Arc::new(distributed_runtime.clone()),
            )
            .await
            {
                Ok((addr, handle)) => {
                    tracing::info!("System status server started successfully on {}", addr);

                    // Store system status server information
                    let system_status_server_info =
                        crate::system_status_server::SystemStatusServerInfo::new(
                            addr,
                            Some(handle),
                        );

                    // Initialize the system_status_server field
                    distributed_runtime
                        .system_status_server
                        .set(Arc::new(system_status_server_info))
                        .expect("System status server info should only be set once");
                }
                Err(e) => {
                    tracing::error!("System status server startup failed: {}", e);
                }
            }
        } else {
            // System server HTTP is disabled, but uptime metrics are still being tracked via SystemHealth
            tracing::debug!(
                "System status server HTTP endpoints disabled, but uptime metrics are being tracked"
            );
        }

        // Start health check manager if enabled
        if config.health_check_enabled {
            let health_check_config = crate::health_check::HealthCheckConfig {
                canary_wait_time: std::time::Duration::from_secs(config.canary_wait_time_secs),
                request_timeout: std::time::Duration::from_secs(
                    config.health_check_request_timeout_secs,
                ),
            };

            // Start the health check manager (spawns per-endpoint monitoring tasks)
            // As above, failure is logged but non-fatal.
            match crate::health_check::start_health_check_manager(
                distributed_runtime.clone(),
                Some(health_check_config),
            )
            .await
            {
                Ok(()) => tracing::info!(
                    "Health check manager started (canary_wait_time: {}s, request_timeout: {}s)",
                    config.canary_wait_time_secs,
                    config.health_check_request_timeout_secs
                ),
                Err(e) => tracing::error!("Health check manager failed to start: {}", e),
            }
        }

        Ok(distributed_runtime)
    }

    /// Build a DRT with etcd-based discovery, using default settings.
    pub async fn from_settings(runtime: Runtime) -> Result<Self> {
        let config = DistributedConfig::from_settings(false);
        Self::new(runtime, config).await
    }

    // Call this if you are using static workers that do not need etcd-based discovery.
    pub async fn from_settings_without_discovery(runtime: Runtime) -> Result<Self> {
        let config = DistributedConfig::from_settings(true);
        Self::new(runtime, config).await
    }

    /// The underlying [`Runtime`] this DRT was built on.
    pub fn runtime(&self) -> &Runtime {
        &self.runtime
    }

    /// The runtime's primary cancellation token.
    pub fn primary_token(&self) -> CancellationToken {
        self.runtime.primary_token()
    }

    /// The etcd lease all our components will be attached to.
    /// Not available for static workers.
    pub fn primary_lease(&self) -> Option<etcd::Lease> {
        self.etcd_client.as_ref().map(|c| c.primary_lease())
    }

    /// Trigger shutdown of the underlying runtime.
    pub fn shutdown(&self) {
        self.runtime.shutdown();
    }

    /// Create a [`Namespace`]
    pub fn namespace(&self, name: impl Into<String>) -> Result<Namespace> {
        Namespace::new(self.clone(), name.into(), self.is_static)
    }

    // /// Create a [`Component`]
    // pub fn component(
    //     &self,
    //     name: impl Into<String>,
    //     namespace: impl Into<String>,
    // ) -> Result<Component> {
    //     Ok(ComponentBuilder::from_runtime(self.clone())
    //         .name(name.into())
    //         .namespace(namespace.into())
    //         .build()?)
    // }

    /// Discovery client scoped to `namespace`.
    ///
    /// # Panics
    /// Panics on a static DRT, which has no etcd client.
    pub(crate) fn discovery_client(&self, namespace: impl Into<String>) -> DiscoveryClient {
        DiscoveryClient::new(
            namespace.into(),
            self.etcd_client
                .clone()
                .expect("Attempt to get discovery_client on static DistributedRuntime"),
        )
    }

    /// Service client backed by this DRT's NATS connection.
    pub(crate) fn service_client(&self) -> ServiceClient {
        ServiceClient::new(self.nats_client.clone())
    }

    /// Lazily create (on first call) and return the shared TCP stream server.
    /// Concurrent callers all receive the same instance.
    pub async fn tcp_server(&self) -> Result<Arc<tcp::server::TcpStreamServer>> {
        Ok(self
            .tcp_server
            .get_or_try_init(async move {
                let options = tcp::server::ServerOptions::default();
                let server = tcp::server::TcpStreamServer::new(options).await?;
                // NOTE(review): `OK` is imported from `super` — presumably a re-export
                // of `Ok`; confirm in the parent module.
                OK(server)
            })
            .await?
            .clone())
    }

    /// A clone of the NATS client handle.
    pub fn nats_client(&self) -> nats::Client {
        self.nats_client.clone()
    }

    /// Get system status server information if available
    pub fn system_status_server_info(
        &self,
    ) -> Option<Arc<crate::system_status_server::SystemStatusServerInfo>> {
        self.system_status_server.get().cloned()
    }

    // todo(ryan): deprecate this as we move to Discovery traits and Component Identifiers
    pub fn etcd_client(&self) -> Option<etcd::Client> {
        self.etcd_client.clone()
    }

    /// An interface to store things. Will eventually replace `etcd_client`.
    /// Currently does key-value, but will grow to include whatever we need to store.
    pub fn store(&self) -> Arc<dyn KeyValueStore> {
        self.store.clone()
    }

    /// A child cancellation token derived from the runtime.
    pub fn child_token(&self) -> CancellationToken {
        self.runtime.child_token()
    }

    /// Tracker used to coordinate graceful shutdown of in-flight work.
    pub(crate) fn graceful_shutdown_tracker(&self) -> Arc<GracefulShutdownTracker> {
        self.runtime.graceful_shutdown_tracker()
    }

    /// Shared map of endpoint -> instance source (weak refs; entries may be dead).
    pub fn instance_sources(&self) -> Arc<Mutex<HashMap<Endpoint, Weak<InstanceSource>>>> {
        self.instance_sources.clone()
    }

    /// Add a Prometheus metric to a specific hierarchy's registry. Note that it is possible
    /// to register the same metric name multiple times, as long as the labels are different.
    pub fn add_prometheus_metric(
        &self,
        hierarchy: &str,
        prometheus_metric: Box<dyn prometheus::core::Collector>,
    ) -> anyhow::Result<()> {
        let mut registries = self.hierarchy_to_metricsregistry.write().unwrap();
        // `or_default` creates the registry entry on first use of a hierarchy.
        let entry = registries.entry(hierarchy.to_string()).or_default();

        // Try to register the metric
        entry
            .prometheus_registry
            .register(prometheus_metric)
            .map_err(|e| e.into())
    }

    /// Add a Prometheus update callback to the given hierarchies
    /// TODO: rename this to register_callback, once we move the the MetricsRegistry trait
    ///       out of the runtime, and make it into a composed module.
    pub fn register_prometheus_update_callback(
        &self,
        hierarchies: Vec<String>,
        callback: PrometheusUpdateCallback,
    ) {
        let mut registries = self.hierarchy_to_metricsregistry.write().unwrap();
        for hierarchy in &hierarchies {
            registries
                .entry(hierarchy.clone())
                .or_default()
                .add_prometheus_update_callback(callback.clone());
        }
    }

    /// Execute all Prometheus update callbacks for a given hierarchy and return their results
    pub fn execute_prometheus_update_callbacks(&self, hierarchy: &str) -> Vec<anyhow::Result<()>> {
        // Clone callbacks while holding read lock (fast operation)
        let callbacks = {
            let registries = self.hierarchy_to_metricsregistry.read().unwrap();
            registries
                .get(hierarchy)
                .map(|entry| entry.prometheus_update_callbacks.clone())
        }; // Read lock released here

        // Execute callbacks without holding the lock
        match callbacks {
            Some(callbacks) => callbacks.iter().map(|callback| callback()).collect(),
            None => Vec::new(),
        }
    }

    /// Add a Prometheus exposition text callback that returns Prometheus text for the given hierarchies
    pub fn register_prometheus_expfmt_callback(
        &self,
        hierarchies: Vec<String>,
        callback: crate::PrometheusExpositionFormatCallback,
    ) {
        let mut registries = self.hierarchy_to_metricsregistry.write().unwrap();
        for hierarchy in &hierarchies {
            registries
                .entry(hierarchy.clone())
                .or_default()
                .add_prometheus_expfmt_callback(callback.clone());
        }
    }

    /// Get all registered hierarchy keys. Private because it is only used for testing.
    fn get_registered_hierarchies(&self) -> Vec<String> {
        let registries = self.hierarchy_to_metricsregistry.read().unwrap();
        registries.keys().cloned().collect()
    }
}
370
/// Bootstrap configuration consumed (via `dissolve`) by [`DistributedRuntime::new`].
#[derive(Dissolve)]
pub struct DistributedConfig {
    /// etcd client options used for discovery; unused when `is_static` is true.
    pub etcd_config: etcd::ClientOptions,
    /// NATS client options used for messaging.
    pub nats_config: nats::ClientOptions,
    /// When true, skip etcd-based discovery and use an in-memory key-value store.
    pub is_static: bool,
}
377
378impl DistributedConfig {
379    pub fn from_settings(is_static: bool) -> DistributedConfig {
380        DistributedConfig {
381            etcd_config: etcd::ClientOptions::default(),
382            nats_config: nats::ClientOptions::default(),
383            is_static,
384        }
385    }
386
387    pub fn for_cli() -> DistributedConfig {
388        let mut config = DistributedConfig {
389            etcd_config: etcd::ClientOptions::default(),
390            nats_config: nats::ClientOptions::default(),
391            is_static: false,
392        };
393
394        config.etcd_config.attach_lease = false;
395
396        config
397    }
398}
399
pub mod distributed_test_utils {
    //! Common test helper functions for DistributedRuntime tests
    // TODO: Use in-memory DistributedRuntime for tests instead of full runtime when available.

    /// Build a DRT instance for integration-only tests.
    ///
    /// Attaches to the already-running tokio runtime via `Runtime::from_current`.
    /// Settings are read from environment variables inside
    /// `DistributedRuntime::from_settings_without_discovery`.
    #[cfg(feature = "integration")]
    pub async fn create_test_drt_async() -> crate::DistributedRuntime {
        let runtime = crate::Runtime::from_current().unwrap();
        let drt = crate::DistributedRuntime::from_settings_without_discovery(runtime).await;
        drt.unwrap()
    }
}
415
#[cfg(all(test, feature = "integration"))]
mod tests {
    use super::distributed_test_utils::create_test_drt_async;

    /// Shared body for the uptime tests: start a DRT with `DYN_SYSTEM_ENABLED`
    /// set to `system_enabled`, wait 50ms, and assert the reported uptime is at
    /// least 50ms. `label` only affects the success message.
    async fn assert_uptime_after_delay(system_enabled: &str, label: &str) {
        temp_env::async_with_vars([("DYN_SYSTEM_ENABLED", Some(system_enabled))], async move {
            // Start a DRT
            let drt = create_test_drt_async().await;

            // Wait 50ms
            tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;

            // Check that uptime is 50+ ms
            let uptime = drt.system_health.lock().unwrap().uptime();
            assert!(
                uptime >= std::time::Duration::from_millis(50),
                "Expected uptime to be at least 50ms, but got {:?}",
                uptime
            );

            println!("✓ DRT uptime test passed ({}): uptime = {:?}", label, uptime);
        })
        .await;
    }

    #[tokio::test]
    async fn test_drt_uptime_after_delay_system_disabled() {
        // Test uptime with system status server disabled
        assert_uptime_after_delay("false", "system disabled").await;
    }

    #[tokio::test]
    async fn test_drt_uptime_after_delay_system_enabled() {
        // Test uptime with system status server enabled
        assert_uptime_after_delay("true", "system enabled").await;
    }
}
471}