dynamo_runtime/
distributed.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4pub use crate::component::Component;
5use crate::transports::nats::DRTNatsClientPrometheusMetrics;
6use crate::{
7    ErrorContext, RuntimeCallback,
8    component::{self, ComponentBuilder, Endpoint, InstanceSource, Namespace},
9    discovery::DiscoveryClient,
10    metrics::MetricsRegistry,
11    service::ServiceClient,
12    transports::{etcd, nats, tcp},
13};
14
15use super::utils::GracefulShutdownTracker;
16use super::{Arc, DistributedRuntime, OK, OnceCell, Result, Runtime, SystemHealth, Weak, error};
17use std::sync::OnceLock;
18
19use derive_getters::Dissolve;
20use figment::error;
21use std::collections::HashMap;
22use tokio::sync::Mutex;
23use tokio_util::sync::CancellationToken;
24
25impl MetricsRegistry for DistributedRuntime {
26    fn basename(&self) -> String {
27        "".to_string() // drt has no basename. Basename only begins with the Namespace.
28    }
29
30    fn parent_hierarchy(&self) -> Vec<String> {
31        vec![] // drt is the root, so no parent hierarchy
32    }
33}
34
35impl std::fmt::Debug for DistributedRuntime {
36    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
37        write!(f, "DistributedRuntime")
38    }
39}
40
41impl DistributedRuntime {
42    pub async fn new(runtime: Runtime, config: DistributedConfig) -> Result<Self> {
43        let (etcd_config, nats_config, is_static) = config.dissolve();
44
45        let runtime_clone = runtime.clone();
46
47        let etcd_client = if is_static {
48            None
49        } else {
50            Some(etcd::Client::new(etcd_config.clone(), runtime_clone).await?)
51        };
52
53        let nats_client = nats_config.clone().connect().await?;
54
55        // Start system status server for health and metrics if enabled in configuration
56        let config = crate::config::RuntimeConfig::from_settings().unwrap_or_default();
57        // IMPORTANT: We must extract cancel_token from runtime BEFORE moving runtime into the struct below.
58        // This is because after moving, runtime is no longer accessible in this scope (ownership rules).
59        let cancel_token = if config.system_server_enabled() {
60            Some(runtime.clone().child_token())
61        } else {
62            None
63        };
64        let starting_health_status = config.starting_health_status.clone();
65        let use_endpoint_health_status = config.use_endpoint_health_status.clone();
66        let health_endpoint_path = config.system_health_path.clone();
67        let live_endpoint_path = config.system_live_path.clone();
68        let system_health = Arc::new(std::sync::Mutex::new(SystemHealth::new(
69            starting_health_status,
70            use_endpoint_health_status,
71            health_endpoint_path,
72            live_endpoint_path,
73        )));
74
75        let nats_client_for_metrics = nats_client.clone();
76
77        let distributed_runtime = Self {
78            runtime,
79            etcd_client,
80            nats_client,
81            tcp_server: Arc::new(OnceCell::new()),
82            system_status_server: Arc::new(OnceLock::new()),
83            component_registry: component::Registry::new(),
84            is_static,
85            instance_sources: Arc::new(Mutex::new(HashMap::new())),
86            hierarchy_to_metricsregistry: Arc::new(std::sync::RwLock::new(HashMap::<
87                String,
88                crate::MetricsRegistryEntry,
89            >::new())),
90            system_health,
91        };
92
93        let nats_client_metrics = DRTNatsClientPrometheusMetrics::new(
94            &distributed_runtime,
95            nats_client_for_metrics.client().clone(),
96        )?;
97        let mut drt_hierarchies = distributed_runtime.parent_hierarchy();
98        drt_hierarchies.push(distributed_runtime.hierarchy());
99        // Register a callback to update NATS client metrics
100        let nats_client_callback = Arc::new({
101            let nats_client_clone = nats_client_metrics.clone();
102            move || {
103                nats_client_clone.set_from_client_stats();
104                Ok(())
105            }
106        });
107        distributed_runtime.register_metrics_callback(drt_hierarchies, nats_client_callback);
108
109        // Initialize the uptime gauge in SystemHealth
110        distributed_runtime
111            .system_health
112            .lock()
113            .unwrap()
114            .initialize_uptime_gauge(&distributed_runtime)?;
115
116        // Handle system status server initialization
117        if let Some(cancel_token) = cancel_token {
118            // System server is enabled - start both the state and HTTP server
119            let host = config.system_host.clone();
120            let port = config.system_port;
121
122            // Start system status server (it creates SystemStatusState internally)
123            match crate::system_status_server::spawn_system_status_server(
124                &host,
125                port,
126                cancel_token,
127                Arc::new(distributed_runtime.clone()),
128            )
129            .await
130            {
131                Ok((addr, handle)) => {
132                    tracing::info!("System status server started successfully on {}", addr);
133
134                    // Store system status server information
135                    let system_status_server_info =
136                        crate::system_status_server::SystemStatusServerInfo::new(
137                            addr,
138                            Some(handle),
139                        );
140
141                    // Initialize the system_status_server field
142                    distributed_runtime
143                        .system_status_server
144                        .set(Arc::new(system_status_server_info))
145                        .expect("System status server info should only be set once");
146                }
147                Err(e) => {
148                    tracing::error!("System status server startup failed: {}", e);
149                }
150            }
151        } else {
152            // System server HTTP is disabled, but uptime metrics are still being tracked via SystemHealth
153            tracing::debug!(
154                "System status server HTTP endpoints disabled, but uptime metrics are being tracked"
155            );
156        }
157
158        // Start health check manager if enabled
159        if config.health_check_enabled {
160            let health_check_config = crate::health_check::HealthCheckConfig {
161                canary_wait_time: std::time::Duration::from_secs(config.canary_wait_time_secs),
162                request_timeout: std::time::Duration::from_secs(
163                    config.health_check_request_timeout_secs,
164                ),
165            };
166
167            // Start the health check manager (spawns per-endpoint monitoring tasks)
168            match crate::health_check::start_health_check_manager(
169                distributed_runtime.clone(),
170                Some(health_check_config),
171            )
172            .await
173            {
174                Ok(()) => tracing::info!(
175                    "Health check manager started (canary_wait_time: {}s, request_timeout: {}s)",
176                    config.canary_wait_time_secs,
177                    config.health_check_request_timeout_secs
178                ),
179                Err(e) => tracing::error!("Health check manager failed to start: {}", e),
180            }
181        }
182
183        Ok(distributed_runtime)
184    }
185
186    pub async fn from_settings(runtime: Runtime) -> Result<Self> {
187        let config = DistributedConfig::from_settings(false);
188        Self::new(runtime, config).await
189    }
190
191    // Call this if you are using static workers that do not need etcd-based discovery.
192    pub async fn from_settings_without_discovery(runtime: Runtime) -> Result<Self> {
193        let config = DistributedConfig::from_settings(true);
194        Self::new(runtime, config).await
195    }
196
197    pub fn runtime(&self) -> &Runtime {
198        &self.runtime
199    }
200
201    pub fn primary_token(&self) -> CancellationToken {
202        self.runtime.primary_token()
203    }
204
205    /// The etcd lease all our components will be attached to.
206    /// Not available for static workers.
207    pub fn primary_lease(&self) -> Option<etcd::Lease> {
208        self.etcd_client.as_ref().map(|c| c.primary_lease())
209    }
210
211    pub fn shutdown(&self) {
212        self.runtime.shutdown();
213    }
214
215    /// Create a [`Namespace`]
216    pub fn namespace(&self, name: impl Into<String>) -> Result<Namespace> {
217        Namespace::new(self.clone(), name.into(), self.is_static)
218    }
219
220    // /// Create a [`Component`]
221    // pub fn component(
222    //     &self,
223    //     name: impl Into<String>,
224    //     namespace: impl Into<String>,
225    // ) -> Result<Component> {
226    //     Ok(ComponentBuilder::from_runtime(self.clone())
227    //         .name(name.into())
228    //         .namespace(namespace.into())
229    //         .build()?)
230    // }
231
232    pub(crate) fn discovery_client(&self, namespace: impl Into<String>) -> DiscoveryClient {
233        DiscoveryClient::new(
234            namespace.into(),
235            self.etcd_client
236                .clone()
237                .expect("Attempt to get discovery_client on static DistributedRuntime"),
238        )
239    }
240
241    pub(crate) fn service_client(&self) -> ServiceClient {
242        ServiceClient::new(self.nats_client.clone())
243    }
244
245    pub async fn tcp_server(&self) -> Result<Arc<tcp::server::TcpStreamServer>> {
246        Ok(self
247            .tcp_server
248            .get_or_try_init(async move {
249                let options = tcp::server::ServerOptions::default();
250                let server = tcp::server::TcpStreamServer::new(options).await?;
251                OK(server)
252            })
253            .await?
254            .clone())
255    }
256
257    pub fn nats_client(&self) -> nats::Client {
258        self.nats_client.clone()
259    }
260
261    /// Get system status server information if available
262    pub fn system_status_server_info(
263        &self,
264    ) -> Option<Arc<crate::system_status_server::SystemStatusServerInfo>> {
265        self.system_status_server.get().cloned()
266    }
267
268    // todo(ryan): deprecate this as we move to Discovery traits and Component Identifiers
269    pub fn etcd_client(&self) -> Option<etcd::Client> {
270        self.etcd_client.clone()
271    }
272
273    pub fn child_token(&self) -> CancellationToken {
274        self.runtime.child_token()
275    }
276
277    pub(crate) fn graceful_shutdown_tracker(&self) -> Arc<GracefulShutdownTracker> {
278        self.runtime.graceful_shutdown_tracker()
279    }
280
281    pub fn instance_sources(&self) -> Arc<Mutex<HashMap<Endpoint, Weak<InstanceSource>>>> {
282        self.instance_sources.clone()
283    }
284
285    /// Add a Prometheus metric to a specific hierarchy's registry. Note that it is possible
286    /// to register the same metric name multiple times, as long as the labels are different.
287    pub fn add_prometheus_metric(
288        &self,
289        hierarchy: &str,
290        prometheus_metric: Box<dyn prometheus::core::Collector>,
291    ) -> anyhow::Result<()> {
292        let mut registries = self.hierarchy_to_metricsregistry.write().unwrap();
293        let entry = registries.entry(hierarchy.to_string()).or_default();
294
295        // Try to register the metric
296        entry
297            .prometheus_registry
298            .register(prometheus_metric)
299            .map_err(|e| e.into())
300    }
301
302    /// Add a callback function to metrics registries for the given hierarchies
303    pub fn register_metrics_callback(&self, hierarchies: Vec<String>, callback: RuntimeCallback) {
304        let mut registries = self.hierarchy_to_metricsregistry.write().unwrap();
305        for hierarchy in hierarchies {
306            registries
307                .entry(hierarchy)
308                .or_default()
309                .add_callback(callback.clone());
310        }
311    }
312
313    /// Execute all callbacks for a given hierarchy key and return their results
314    pub fn execute_metrics_callbacks(&self, hierarchy: &str) -> Vec<anyhow::Result<()>> {
315        // Clone callbacks while holding read lock (fast operation)
316        let callbacks = {
317            let registries = self.hierarchy_to_metricsregistry.read().unwrap();
318            registries
319                .get(hierarchy)
320                .map(|entry| entry.runtime_callbacks.clone())
321        }; // Read lock released here
322
323        // Execute callbacks without holding the lock
324        match callbacks {
325            Some(callbacks) => callbacks.iter().map(|callback| callback()).collect(),
326            None => Vec::new(),
327        }
328    }
329
330    /// Clear everything in etcd under a key.
331    /// todo: Remove as soon as we auto-delete the MDC.
332    pub async fn temp_clear_namespace(&self, name: &str) -> anyhow::Result<()> {
333        let Some(etcd_client) = self.etcd_client() else {
334            return Ok(()); // no etcd, nothing to clear
335        };
336        let kvs = etcd_client.kv_get_prefix(name).await?;
337        for kv in kvs {
338            etcd_client.kv_delete(kv.key(), None).await?;
339        }
340        Ok(())
341    }
342
343    /// Get all registered hierarchy keys. Private because it is only used for testing.
344    fn get_registered_hierarchies(&self) -> Vec<String> {
345        let registries = self.hierarchy_to_metricsregistry.read().unwrap();
346        registries.keys().cloned().collect()
347    }
348}
349
/// Transport options consumed by `DistributedRuntime::new`.
///
/// NOTE: field order matters. `#[derive(Dissolve)]` generates `dissolve()`, which returns the
/// fields as a tuple in declaration order, and `DistributedRuntime::new` destructures that
/// tuple as `(etcd_config, nats_config, is_static)`.
#[derive(Dissolve)]
pub struct DistributedConfig {
    /// Options for the etcd client; no etcd client is created when `is_static` is true.
    pub etcd_config: etcd::ClientOptions,
    /// Options used to connect the NATS client.
    pub nats_config: nats::ClientOptions,
    /// True for static workers that do not use etcd-based discovery.
    pub is_static: bool,
}
356
357impl DistributedConfig {
358    pub fn from_settings(is_static: bool) -> DistributedConfig {
359        DistributedConfig {
360            etcd_config: etcd::ClientOptions::default(),
361            nats_config: nats::ClientOptions::default(),
362            is_static,
363        }
364    }
365
366    pub fn for_cli() -> DistributedConfig {
367        let mut config = DistributedConfig {
368            etcd_config: etcd::ClientOptions::default(),
369            nats_config: nats::ClientOptions::default(),
370            is_static: false,
371        };
372
373        config.etcd_config.attach_lease = false;
374
375        config
376    }
377}
378
pub mod distributed_test_utils {
    //! Shared helpers for `DistributedRuntime` tests.
    // TODO: Use in-memory DistributedRuntime for tests instead of full runtime when available.

    /// Build a [`crate::DistributedRuntime`] for integration-only tests.
    ///
    /// Attaches to the ambient tokio runtime via `Runtime::from_current` and builds the DRT
    /// without etcd-based discovery. Settings are read from environment variables inside
    /// `DistributedRuntime::from_settings_without_discovery`.
    ///
    /// # Panics
    ///
    /// Panics (via `unwrap`) if `Runtime::from_current` or DRT construction returns an error.
    #[cfg(feature = "integration")]
    pub async fn create_test_drt_async() -> crate::DistributedRuntime {
        let current = crate::Runtime::from_current().unwrap();
        let built = crate::DistributedRuntime::from_settings_without_discovery(current).await;
        built.unwrap()
    }
}
394
#[cfg(all(test, feature = "integration"))]
mod tests {
    use super::distributed_test_utils::create_test_drt_async;

    /// Start a DRT with `DYN_SYSTEM_ENABLED` set to `system_enabled`, wait 50ms, and assert
    /// that the uptime reported by `SystemHealth` is at least that delay.
    ///
    /// `label` only appears in the success message printed at the end.
    async fn assert_uptime_after_delay(system_enabled: &'static str, label: &'static str) {
        temp_env::async_with_vars([("DYN_SYSTEM_ENABLED", Some(system_enabled))], async {
            let drt = create_test_drt_async().await;

            let delay = std::time::Duration::from_millis(50);
            tokio::time::sleep(delay).await;

            let uptime = drt.system_health.lock().unwrap().uptime();
            assert!(
                uptime >= delay,
                "Expected uptime to be at least 50ms, but got {:?}",
                uptime
            );

            println!("✓ DRT uptime test passed ({}): uptime = {:?}", label, uptime);
        })
        .await;
    }

    #[tokio::test]
    async fn test_drt_uptime_after_delay_system_disabled() {
        // Uptime is tracked even when the system status server is disabled.
        assert_uptime_after_delay("false", "system disabled").await;
    }

    #[tokio::test]
    async fn test_drt_uptime_after_delay_system_enabled() {
        assert_uptime_after_delay("true", "system enabled").await;
    }
}