dynamo_runtime/distributed.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

pub use crate::component::Component;
use crate::transports::nats::DRTNatsClientPrometheusMetrics;
use crate::{
    ErrorContext, RuntimeCallback,
    component::{self, ComponentBuilder, Endpoint, InstanceSource, Namespace},
    discovery::DiscoveryClient,
    metrics::MetricsRegistry,
    service::ServiceClient,
    transports::{etcd, nats, tcp},
};

use super::utils::GracefulShutdownTracker;
use super::{Arc, DistributedRuntime, OK, OnceCell, Result, Runtime, SystemHealth, Weak, error};
use std::sync::OnceLock;

use derive_getters::Dissolve;
use std::collections::HashMap;
use tokio::sync::Mutex;
use tokio_util::sync::CancellationToken;

impl MetricsRegistry for DistributedRuntime {
    fn basename(&self) -> String {
        "".to_string() // The DRT has no basename of its own; hierarchy names begin at the Namespace level.
    }

    fn parent_hierarchy(&self) -> Vec<String> {
        vec![] // The DRT is the root of the hierarchy, so it has no parents.
    }
}

impl std::fmt::Debug for DistributedRuntime {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "DistributedRuntime")
    }
}

impl DistributedRuntime {
    pub async fn new(runtime: Runtime, config: DistributedConfig) -> Result<Self> {
        let (etcd_config, nats_config, is_static) = config.dissolve();

        let runtime_clone = runtime.clone();

        let etcd_client = if is_static {
            None
        } else {
            Some(etcd::Client::new(etcd_config.clone(), runtime_clone).await?)
        };

        let nats_client = nats_config.clone().connect().await?;

        // Read the runtime config to decide whether to start the system status server
        // (health and metrics endpoints).
        let config = crate::config::RuntimeConfig::from_settings().unwrap_or_default();
        // Create the child cancellation token up front, while `runtime` is still
        // accessible in this scope (it is moved into the struct below).
        let cancel_token = if config.system_server_enabled() {
            Some(runtime.child_token())
        } else {
            None
        };
        let starting_health_status = config.starting_health_status.clone();
        let use_endpoint_health_status = config.use_endpoint_health_status.clone();
        let health_endpoint_path = config.system_health_path.clone();
        let live_endpoint_path = config.system_live_path.clone();
        let system_health = Arc::new(std::sync::Mutex::new(SystemHealth::new(
            starting_health_status,
            use_endpoint_health_status,
            health_endpoint_path,
            live_endpoint_path,
        )));

        let nats_client_for_metrics = nats_client.clone();

        let distributed_runtime = Self {
            runtime,
            etcd_client,
            nats_client,
            tcp_server: Arc::new(OnceCell::new()),
            system_status_server: Arc::new(OnceLock::new()),
            component_registry: component::Registry::new(),
            is_static,
            instance_sources: Arc::new(Mutex::new(HashMap::new())),
            hierarchy_to_metricsregistry: Arc::new(std::sync::RwLock::new(HashMap::<
                String,
                crate::MetricsRegistryEntry,
            >::new())),
            system_health,
        };

        let nats_client_metrics = DRTNatsClientPrometheusMetrics::new(
            &distributed_runtime,
            nats_client_for_metrics.client().clone(),
        )?;
        let mut drt_hierarchies = distributed_runtime.parent_hierarchy();
        drt_hierarchies.push(distributed_runtime.hierarchy());
        // Register a callback to update NATS client metrics
        let nats_client_callback = Arc::new({
            let nats_client_clone = nats_client_metrics.clone();
            move || {
                nats_client_clone.set_from_client_stats();
                Ok(())
            }
        });
        distributed_runtime.register_metrics_callback(drt_hierarchies, nats_client_callback);

        // Initialize the uptime gauge in SystemHealth
        distributed_runtime
            .system_health
            .lock()
            .unwrap()
            .initialize_uptime_gauge(&distributed_runtime)?;

        // Handle system status server initialization
        if let Some(cancel_token) = cancel_token {
            // System server is enabled - start both the state and HTTP server
            let host = config.system_host.clone();
            let port = config.system_port;

            // Start system status server (it creates SystemStatusState internally)
            match crate::system_status_server::spawn_system_status_server(
                &host,
                port,
                cancel_token,
                Arc::new(distributed_runtime.clone()),
            )
            .await
            {
                Ok((addr, handle)) => {
                    tracing::info!("System status server started successfully on {}", addr);

                    // Store system status server information
                    let system_status_server_info =
                        crate::system_status_server::SystemStatusServerInfo::new(
                            addr,
                            Some(handle),
                        );

                    // Initialize the system_status_server field
                    distributed_runtime
                        .system_status_server
                        .set(Arc::new(system_status_server_info))
                        .expect("System status server info should only be set once");
                }
                Err(e) => {
                    tracing::error!("System status server startup failed: {}", e);
                }
            }
        } else {
            // System server HTTP is disabled, but uptime metrics are still being tracked via SystemHealth
            tracing::debug!(
                "System status server HTTP endpoints disabled, but uptime metrics are being tracked"
            );
        }

        Ok(distributed_runtime)
    }

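    /// Build a `DistributedRuntime` from environment-driven settings, connecting to
    /// etcd for discovery and to NATS for messaging.
    ///
    /// A minimal construction sketch, assuming an ambient tokio runtime and eliding
    /// error handling:
    ///
    /// ```ignore
    /// let rt = Runtime::from_current()?;
    /// let drt = DistributedRuntime::from_settings(rt).await?;
    /// let ns = drt.namespace("demo")?; // "demo" is a placeholder namespace name
    /// ```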
    pub async fn from_settings(runtime: Runtime) -> Result<Self> {
        let config = DistributedConfig::from_settings(false);
        Self::new(runtime, config).await
    }

    /// Call this if you are using static workers that do not need etcd-based discovery.
    pub async fn from_settings_without_discovery(runtime: Runtime) -> Result<Self> {
        let config = DistributedConfig::from_settings(true);
        Self::new(runtime, config).await
    }

    pub fn runtime(&self) -> &Runtime {
        &self.runtime
    }

    pub fn primary_token(&self) -> CancellationToken {
        self.runtime.primary_token()
    }

    /// The etcd lease all our components will be attached to.
    /// Not available for static workers.
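    ///
    /// A sketch of handling both modes:
    ///
    /// ```ignore
    /// match drt.primary_lease() {
    ///     Some(_lease) => { /* dynamic worker: components attach to this etcd lease */ }
    ///     None => { /* static worker: no etcd, so no lease */ }
    /// }
    /// ```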
    pub fn primary_lease(&self) -> Option<etcd::Lease> {
        self.etcd_client.as_ref().map(|c| c.primary_lease())
    }

    pub fn shutdown(&self) {
        self.runtime.shutdown();
    }

    /// Create a [`Namespace`]
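    ///
    /// Namespaces are the root of the name hierarchy (see `basename` above). A usage
    /// sketch, with a placeholder namespace name:
    ///
    /// ```ignore
    /// let ns = drt.namespace("my_models")?;
    /// ```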
    pub fn namespace(&self, name: impl Into<String>) -> Result<Namespace> {
        Namespace::new(self.clone(), name.into(), self.is_static)
    }

    // /// Create a [`Component`]
    // pub fn component(
    //     &self,
    //     name: impl Into<String>,
    //     namespace: impl Into<String>,
    // ) -> Result<Component> {
    //     Ok(ComponentBuilder::from_runtime(self.clone())
    //         .name(name.into())
    //         .namespace(namespace.into())
    //         .build()?)
    // }

    pub(crate) fn discovery_client(&self, namespace: impl Into<String>) -> DiscoveryClient {
        DiscoveryClient::new(
            namespace.into(),
            self.etcd_client
                .clone()
                .expect("Attempt to get discovery_client on static DistributedRuntime"),
        )
    }

    pub(crate) fn service_client(&self) -> ServiceClient {
        ServiceClient::new(self.nats_client.clone())
    }

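    /// Lazily create the shared TCP stream server on first call; subsequent calls
    /// return the same instance via `OnceCell::get_or_try_init`.
    ///
    /// A minimal sketch, assuming the default `ServerOptions` are acceptable:
    ///
    /// ```ignore
    /// let tcp = drt.tcp_server().await?;
    /// let tcp_again = drt.tcp_server().await?; // same Arc as `tcp`
    /// ```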
    pub async fn tcp_server(&self) -> Result<Arc<tcp::server::TcpStreamServer>> {
        Ok(self
            .tcp_server
            .get_or_try_init(async move {
                let options = tcp::server::ServerOptions::default();
                let server = tcp::server::TcpStreamServer::new(options).await?;
                OK(server)
            })
            .await?
            .clone())
    }

    pub fn nats_client(&self) -> nats::Client {
        self.nats_client.clone()
    }

    /// Get system status server information if available
    pub fn system_status_server_info(
        &self,
    ) -> Option<Arc<crate::system_status_server::SystemStatusServerInfo>> {
        self.system_status_server.get().cloned()
    }

    // todo(ryan): deprecate this as we move to Discovery traits and Component Identifiers
    pub fn etcd_client(&self) -> Option<etcd::Client> {
        self.etcd_client.clone()
    }

    pub fn child_token(&self) -> CancellationToken {
        self.runtime.child_token()
    }

    pub(crate) fn graceful_shutdown_tracker(&self) -> Arc<GracefulShutdownTracker> {
        self.runtime.graceful_shutdown_tracker()
    }

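    /// Shared map from `Endpoint` to its `InstanceSource`. Values are held as `Weak`
    /// references, so the map does not by itself keep an `InstanceSource` alive.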
    pub fn instance_sources(&self) -> Arc<Mutex<HashMap<Endpoint, Weak<InstanceSource>>>> {
        self.instance_sources.clone()
    }

    /// Add a Prometheus metric to a specific hierarchy's registry. Note that it is possible
    /// to register the same metric name multiple times, as long as the labels are different.
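    ///
    /// A usage sketch with the `prometheus` crate and a hypothetical hierarchy key:
    ///
    /// ```ignore
    /// let counter = prometheus::IntCounter::new("requests_total", "Total requests")?;
    /// drt.add_prometheus_metric("my_namespace", Box::new(counter.clone()))?;
    /// counter.inc(); // the registered collector observes the increment
    /// ```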
    pub fn add_prometheus_metric(
        &self,
        hierarchy: &str,
        prometheus_metric: Box<dyn prometheus::core::Collector>,
    ) -> anyhow::Result<()> {
        let mut registries = self.hierarchy_to_metricsregistry.write().unwrap();
        let entry = registries.entry(hierarchy.to_string()).or_default();

        // Try to register the metric
        entry
            .prometheus_registry
            .register(prometheus_metric)
            .map_err(|e| e.into())
    }

    /// Add a callback function to metrics registries for the given hierarchies
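    ///
    /// A sketch mirroring the NATS metrics callback registered in `new` (the closure
    /// body is a placeholder):
    ///
    /// ```ignore
    /// let callback: RuntimeCallback = Arc::new(move || {
    ///     // refresh gauges/counters here before they are scraped
    ///     Ok(())
    /// });
    /// drt.register_metrics_callback(vec![drt.hierarchy()], callback);
    /// ```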
    pub fn register_metrics_callback(&self, hierarchies: Vec<String>, callback: RuntimeCallback) {
        let mut registries = self.hierarchy_to_metricsregistry.write().unwrap();
        for hierarchy in hierarchies {
            registries
                .entry(hierarchy)
                .or_default()
                .add_callback(callback.clone());
        }
    }

    /// Execute all callbacks for a given hierarchy key and return their results
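    ///
    /// Callbacks are cloned under a brief read lock and then run without the lock held,
    /// so a callback may itself touch the registry without deadlocking. A sketch of
    /// surfacing failures (the hierarchy key is a placeholder):
    ///
    /// ```ignore
    /// for result in drt.execute_metrics_callbacks("my_namespace") {
    ///     if let Err(e) = result {
    ///         tracing::warn!("metrics callback failed: {}", e);
    ///     }
    /// }
    /// ```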
    pub fn execute_metrics_callbacks(&self, hierarchy: &str) -> Vec<anyhow::Result<()>> {
        // Clone callbacks while holding read lock (fast operation)
        let callbacks = {
            let registries = self.hierarchy_to_metricsregistry.read().unwrap();
            registries
                .get(hierarchy)
                .map(|entry| entry.runtime_callbacks.clone())
        }; // Read lock released here

        // Execute callbacks without holding the lock
        match callbacks {
            Some(callbacks) => callbacks.iter().map(|callback| callback()).collect(),
            None => Vec::new(),
        }
    }

    /// Get all registered hierarchy keys. Private because it is only used for testing.
    fn get_registered_hierarchies(&self) -> Vec<String> {
        let registries = self.hierarchy_to_metricsregistry.read().unwrap();
        registries.keys().cloned().collect()
    }
}

#[derive(Dissolve)]
pub struct DistributedConfig {
    pub etcd_config: etcd::ClientOptions,
    pub nats_config: nats::ClientOptions,
    pub is_static: bool,
}

impl DistributedConfig {
    pub fn from_settings(is_static: bool) -> DistributedConfig {
        DistributedConfig {
            etcd_config: etcd::ClientOptions::default(),
            nats_config: nats::ClientOptions::default(),
            is_static,
        }
    }

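    /// Configuration for short-lived CLI invocations: the same defaults as
    /// `from_settings(false)`, except that no etcd lease is attached, which suits
    /// processes that only query cluster state and never register components.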
    pub fn for_cli() -> DistributedConfig {
        let mut config = DistributedConfig {
            etcd_config: etcd::ClientOptions::default(),
            nats_config: nats::ClientOptions::default(),
            is_static: false,
        };

        config.etcd_config.attach_lease = false;

        config
    }
}

pub mod distributed_test_utils {
    //! Common test helper functions for DistributedRuntime tests
    // TODO: Use an in-memory DistributedRuntime for tests instead of the full runtime, when available.

    /// Helper function to create a DRT instance for integration-only tests.
    /// Uses `from_current` to leverage the existing tokio runtime.
    /// Note: settings are read from environment variables inside `DistributedRuntime::from_settings_without_discovery`.
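    ///
    /// A usage sketch inside an integration test (mirrors the tests below):
    ///
    /// ```ignore
    /// #[tokio::test]
    /// async fn my_test() {
    ///     let drt = create_test_drt_async().await;
    ///     assert!(drt.etcd_client().is_none()); // discovery is disabled for static workers
    /// }
    /// ```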
    #[cfg(feature = "integration")]
    pub async fn create_test_drt_async() -> crate::DistributedRuntime {
        let rt = crate::Runtime::from_current().unwrap();
        crate::DistributedRuntime::from_settings_without_discovery(rt)
            .await
            .unwrap()
    }
}

#[cfg(all(test, feature = "integration"))]
mod tests {
    use super::distributed_test_utils::create_test_drt_async;

    #[tokio::test]
    async fn test_drt_uptime_after_delay_system_disabled() {
        // Test uptime with system status server disabled
        temp_env::async_with_vars([("DYN_SYSTEM_ENABLED", Some("false"))], async {
            // Start a DRT
            let drt = create_test_drt_async().await;

            // Wait 50ms
            tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;

            // Check that uptime is 50+ ms
            let uptime = drt.system_health.lock().unwrap().uptime();
            assert!(
                uptime >= std::time::Duration::from_millis(50),
                "Expected uptime to be at least 50ms, but got {:?}",
                uptime
            );

            println!(
                "✓ DRT uptime test passed (system disabled): uptime = {:?}",
                uptime
            );
        })
        .await;
    }

    #[tokio::test]
    async fn test_drt_uptime_after_delay_system_enabled() {
        // Test uptime with system status server enabled
        temp_env::async_with_vars([("DYN_SYSTEM_ENABLED", Some("true"))], async {
            // Start a DRT
            let drt = create_test_drt_async().await;

            // Wait 50ms
            tokio::time::sleep(tokio::time::Duration::from_millis(50)).await;

            // Check that uptime is 50+ ms
            let uptime = drt.system_health.lock().unwrap().uptime();
            assert!(
                uptime >= std::time::Duration::from_millis(50),
                "Expected uptime to be at least 50ms, but got {:?}",
                uptime
            );

            println!(
                "✓ DRT uptime test passed (system enabled): uptime = {:?}",
                uptime
            );
        })
        .await;
    }
}