dynamo_runtime/
service.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
// TODO - refactor this entire module
//
// We want to carry forward the concept of live vs. ready for the components.
// We will want to associate each component's cancellation token with the
// component's "service state".
9
10use crate::{
11    DistributedRuntime,
12    component::Component,
13    metrics::{MetricsHierarchy, prometheus_names, prometheus_names::nats_service},
14    traits::*,
15    transports::nats,
16    utils::stream,
17};
18
19use anyhow::Result;
20use anyhow::anyhow as error;
21use async_nats::Message;
22use async_stream::try_stream;
23use bytes::Bytes;
24use derive_getters::Dissolve;
25use futures::stream::{StreamExt, TryStreamExt};
26use prometheus;
27use serde::{Deserialize, Serialize, de::DeserializeOwned};
28use std::time::Duration;
29
30pub struct ServiceClient {
31    nats_client: nats::Client,
32}
33
34impl ServiceClient {
35    pub fn new(nats_client: nats::Client) -> Self {
36        ServiceClient { nats_client }
37    }
38}
39
/// ServiceSet contains a collection of services with their endpoints and metrics.
///
/// Structure:
/// - `ServiceSet`
///   - `services: Vec<ServiceInfo>`
///     - `name: String`
///     - `id: String`
///     - `version: String`
///     - `started: String`
///     - `endpoints: Vec<EndpointInfo>`
///       - `name: String`
///       - `subject: String`
///       - `data: Option<NatsStatsMetrics>` (flattened into the endpoint object)
///         - `average_processing_time: u64`
///         - `last_error: String`
///         - `num_errors: u64`
///         - `num_requests: u64`
///         - `processing_time: u64`
///         - `queue_group: String`
///         - `data: serde_json::Value` (custom stats)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServiceSet {
    // Private: read access goes through the methods defined on `ServiceSet`.
    services: Vec<ServiceInfo>,
}
65
/// One service entry as returned by a NATS service stats request.
///
/// This is an example JSON from `nats req '$SRV.STATS.dynamo_backend'`:
///
/// ```json
/// {
///   "type": "io.nats.micro.v1.stats_response",
///   "name": "dynamo_backend",
///   "id": "bdu7nA8tbhy9mEkxIWlkBA",
///   "version": "0.0.1",
///   "started": "2025-08-08T05:07:17.720783523Z",
///   "endpoints": [
///     {
///       "name": "dynamo_backend-generate-694d988806b92e39",
///       "subject": "dynamo_backend.generate-694d988806b92e39",
///       "num_requests": 0,
///       "num_errors": 0,
///       "processing_time": 0,
///       "average_processing_time": 0,
///       "last_error": "",
///       "data": {
///         "val": 10
///       },
///       "queue_group": "q"
///     }
///   ]
/// }
/// ```
///
/// The `"type"` field of the response has no counterpart in this struct.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServiceInfo {
    /// Service name (`"name"` in the response).
    pub name: String,
    /// Service instance id (`"id"` in the response).
    pub id: String,
    /// Service version string (`"version"` in the response).
    pub version: String,
    /// Start time, kept as the raw string from the response (a timestamp in the example above).
    pub started: String,
    /// Endpoints reported for this service.
    pub endpoints: Vec<EndpointInfo>,
}
97
/// A single endpoint entry of a service stats response.
///
/// Each endpoint has name, subject, num_requests, num_errors, processing_time,
/// average_processing_time, last_error, queue_group, and data. Only `name` and `subject` are
/// stored directly on this struct; the remaining fields live in [`NatsStatsMetrics`].
#[derive(Debug, Clone, Serialize, Deserialize, Dissolve)]
pub struct EndpointInfo {
    /// Endpoint name (`"name"` in the response).
    pub name: String,
    /// Subject of the endpoint (`"subject"` in the response).
    pub subject: String,

    /// Fields of the endpoint object other than `name` and `subject`, flattened by serde into
    /// a [`NatsStatsMetrics`].
    // NOTE(review): presumably `None` when the stats fields are absent from the payload —
    // confirm against serde's `flatten` + `Option` behavior.
    #[serde(flatten)]
    pub data: Option<NatsStatsMetrics>,
}
108
109impl EndpointInfo {
110    pub fn id(&self) -> Result<i64> {
111        let id = self
112            .subject
113            .split('-')
114            .next_back()
115            .ok_or_else(|| error!("No id found in subject"))?;
116
117        i64::from_str_radix(id, 16).map_err(|e| error!("Invalid id format: {}", e))
118    }
119}
120
121// TODO: This is _really_ close to the async_nats::service::Stats object,
122// but it's missing a few fields like "name", so use a temporary struct
123// for easy deserialization. Ideally, this type already exists or can
124// be exposed in the library somewhere.
/// Stats structure returned from the NATS service API.
///
/// See <https://github.com/nats-io/nats.rs/blob/main/async-nats/src/service/endpoint.rs>.
#[derive(Debug, Clone, Serialize, Deserialize, Dissolve)]
pub struct NatsStatsMetrics {
    // Standard NATS Stats Service API fields from $SRV.STATS.<service_name> requests
    /// Average processing time (`"average_processing_time"` in the response).
    pub average_processing_time: u64, // in nanoseconds according to nats-io
    /// Last error text (`"last_error"` in the response; an empty string in the sample payload).
    pub last_error: String,
    /// Number of errors (`"num_errors"` in the response).
    pub num_errors: u64,
    /// Number of requests (`"num_requests"` in the response).
    pub num_requests: u64,
    /// Total processing time (`"processing_time"` in the response).
    pub processing_time: u64, // in nanoseconds according to nats-io
    /// Queue group of the endpoint (`"queue_group"` in the response).
    pub queue_group: String,
    // Field containing custom stats handler data
    /// Custom stats payload, kept as untyped JSON.
    pub data: serde_json::Value,
}
139
140impl NatsStatsMetrics {
141    pub fn decode<T: for<'de> Deserialize<'de>>(self) -> Result<T> {
142        serde_json::from_value(self.data).map_err(Into::into)
143    }
144}
145
146impl ServiceClient {
147    pub async fn unary(
148        &self,
149        subject: impl Into<String>,
150        payload: impl Into<Bytes>,
151    ) -> Result<Message> {
152        let response = self
153            .nats_client
154            .client()
155            .request(subject.into(), payload.into())
156            .await?;
157        Ok(response)
158    }
159
160    pub async fn collect_services(
161        &self,
162        service_name: &str,
163        timeout: Duration,
164    ) -> Result<ServiceSet> {
165        let sub = self.nats_client.scrape_service(service_name).await?;
166        if timeout.is_zero() {
167            tracing::warn!("collect_services: timeout is zero");
168        }
169        if timeout > Duration::from_secs(10) {
170            tracing::warn!("collect_services: timeout is greater than 10 seconds");
171        }
172        let deadline = tokio::time::Instant::now() + timeout;
173
174        let mut services = vec![];
175        let mut s = stream::until_deadline(sub, deadline);
176        while let Some(message) = s.next().await {
177            if message.payload.is_empty() {
178                // Expected while we wait for KV metrics in worker to start
179                tracing::trace!(service_name, "collect_services: empty payload from nats");
180                continue;
181            }
182            let info = serde_json::from_slice::<ServiceInfo>(&message.payload);
183            match info {
184                Ok(info) => services.push(info),
185                Err(err) => {
186                    let payload = String::from_utf8_lossy(&message.payload);
187                    tracing::debug!(%err, service_name, %payload, "error decoding service info");
188                }
189            }
190        }
191
192        Ok(ServiceSet { services })
193    }
194}
195
196impl ServiceSet {
197    pub fn into_endpoints(self) -> impl Iterator<Item = EndpointInfo> {
198        self.services
199            .into_iter()
200            .flat_map(|s| s.endpoints.into_iter())
201    }
202
203    /// Get a reference to the services in this ServiceSet
204    pub fn services(&self) -> &[ServiceInfo] {
205        &self.services
206    }
207}
208
#[cfg(test)]
mod tests {

    use super::*;

    /// Builds an endpoint fixture; the stats are the same for every endpoint except for the
    /// value stored under `"key"` in the custom data.
    fn endpoint(name: &str, subject: &str, custom_value: &str) -> EndpointInfo {
        EndpointInfo {
            name: name.to_string(),
            subject: subject.to_string(),
            data: Some(NatsStatsMetrics {
                average_processing_time: 100_000, // 0.1ms = 100,000 nanoseconds
                last_error: "none".to_string(),
                num_errors: 0,
                num_requests: 10,
                processing_time: 100,
                queue_group: "group1".to_string(),
                data: serde_json::json!({ "key": custom_value }),
            }),
        }
    }

    /// Builds a service fixture that differs from its siblings only by id and endpoints.
    fn service(id: &str, endpoints: Vec<EndpointInfo>) -> ServiceInfo {
        ServiceInfo {
            name: "service1".to_string(),
            id: id.to_string(),
            version: "1.0".to_string(),
            started: "2021-01-01".to_string(),
            endpoints,
        }
    }

    #[test]
    fn test_service_set() {
        let service_set = ServiceSet {
            services: vec![
                service(
                    "1",
                    vec![
                        endpoint("endpoint1", "subject1", "value1"),
                        endpoint("endpoint2-foo", "subject2", "value1"),
                    ],
                ),
                service(
                    "2",
                    vec![
                        endpoint("endpoint1", "subject1", "value1"),
                        endpoint("endpoint2-bar", "subject2", "value2"),
                    ],
                ),
            ],
        };

        // One "endpoint2*" endpoint per service is expected to survive the filter.
        let matching: Vec<EndpointInfo> = service_set
            .into_endpoints()
            .filter(|candidate| candidate.name.starts_with("endpoint2"))
            .collect();

        assert_eq!(matching.len(), 2);
    }
}
297
/// Prometheus metrics for component service statistics (ordered to match NatsStatsMetrics).
///
/// ⚠️  IMPORTANT: These Prometheus Gauges are COPIES of NATS data, not live references!
///
/// How it works:
/// 1. NATS provides source data via NatsStatsMetrics
/// 2. Metrics callbacks read current NATS values and update these Prometheus Gauges
/// 3. Prometheus scrapes these Gauge values (snapshots, not live data)
///
/// Flow: NATS Service → NatsStatsMetrics (Counters) → Metrics Callback → Prometheus Gauge
///
/// Note: These are snapshots updated when execute_prometheus_update_callbacks() is called.
///
/// Note: Metrics with `_total` names use IntGauge because we copy counter values
/// from underlying services rather than incrementing directly.
#[derive(Debug, Clone)]
pub struct ComponentNatsServerPrometheusMetrics {
    /// Average processing time in milliseconds (maps to: average_processing_time)
    pub service_processing_ms_avg: prometheus::Gauge,
    /// Total errors across all endpoints (maps to: num_errors)
    pub service_errors_total: prometheus::IntGauge,
    /// Total requests across all endpoints (maps to: num_requests)
    pub service_requests_total: prometheus::IntGauge,
    /// Total processing time in milliseconds (maps to: processing_time)
    pub service_processing_ms_total: prometheus::IntGauge,
    /// Number of active services (derived from ServiceSet.services)
    pub service_active_services: prometheus::IntGauge,
    /// Number of active endpoints (derived from ServiceInfo.endpoints)
    pub service_active_endpoints: prometheus::IntGauge,
}
327
328impl ComponentNatsServerPrometheusMetrics {
329    /// Create new ComponentServiceMetrics using Component's DistributedRuntime's Prometheus constructors
330    pub fn new(component: &Component) -> Result<Self> {
331        let service_name = component.service_name();
332
333        // Build labels: service_name first, then component's labels
334        let mut labels_vec = vec![("service_name", service_name.as_str())];
335
336        // Add component's labels (convert from (String, String) to (&str, &str))
337        for (key, value) in component.labels() {
338            labels_vec.push((key.as_str(), value.as_str()));
339        }
340
341        let labels: &[(&str, &str)] = &labels_vec;
342
343        let service_processing_ms_avg = component.metrics().create_gauge(
344            nats_service::PROCESSING_MS_AVG,
345            "Average processing time across all component endpoints in milliseconds",
346            labels,
347        )?;
348
349        let service_errors_total = component.metrics().create_intgauge(
350            nats_service::ERRORS_TOTAL,
351            "Total number of errors across all component endpoints",
352            labels,
353        )?;
354
355        let service_requests_total = component.metrics().create_intgauge(
356            nats_service::REQUESTS_TOTAL,
357            "Total number of requests across all component endpoints",
358            labels,
359        )?;
360
361        let service_processing_ms_total = component.metrics().create_intgauge(
362            nats_service::PROCESSING_MS_TOTAL,
363            "Total processing time across all component endpoints in milliseconds",
364            labels,
365        )?;
366
367        let service_active_services = component.metrics().create_intgauge(
368            nats_service::ACTIVE_SERVICES,
369            "Number of active services in this component",
370            labels,
371        )?;
372
373        let service_active_endpoints = component.metrics().create_intgauge(
374            nats_service::ACTIVE_ENDPOINTS,
375            "Number of active endpoints across all services",
376            labels,
377        )?;
378
379        Ok(Self {
380            service_processing_ms_avg,
381            service_errors_total,
382            service_requests_total,
383            service_processing_ms_total,
384            service_active_services,
385            service_active_endpoints,
386        })
387    }
388
389    /// Update metrics from scraped ServiceSet data
390    pub fn update_from_service_set(&self, service_set: &ServiceSet) {
391        // Variables ordered to match NatsStatsMetrics fields
392        let mut processing_time_samples = 0u64; // for average_processing_time calculation
393        let mut total_errors = 0u64; // maps to: num_errors
394        let mut total_requests = 0u64; // maps to: num_requests
395        let mut total_processing_time_nanos = 0u64; // maps to: processing_time (nanoseconds from NATS)
396        let mut endpoint_count = 0u64; // for derived metrics
397
398        let service_count = service_set.services().len() as i64;
399
400        for service in service_set.services() {
401            for endpoint in &service.endpoints {
402                endpoint_count += 1;
403
404                if let Some(ref stats) = endpoint.data {
405                    total_errors += stats.num_errors;
406                    total_requests += stats.num_requests;
407                    total_processing_time_nanos += stats.processing_time;
408
409                    if stats.num_requests > 0 {
410                        processing_time_samples += 1;
411                    }
412                }
413            }
414        }
415
416        // Update metrics (ordered to match NatsStatsMetrics fields)
417        // Calculate average processing time in milliseconds (maps to: average_processing_time)
418        if processing_time_samples > 0 && total_requests > 0 {
419            let avg_time_nanos = total_processing_time_nanos as f64 / total_requests as f64;
420            let avg_time_ms = avg_time_nanos / 1_000_000.0; // Convert nanoseconds to milliseconds
421            self.service_processing_ms_avg.set(avg_time_ms);
422        } else {
423            self.service_processing_ms_avg.set(0.0);
424        }
425
426        self.service_errors_total.set(total_errors as i64); // maps to: num_errors
427        self.service_requests_total.set(total_requests as i64); // maps to: num_requests
428        self.service_processing_ms_total
429            .set((total_processing_time_nanos / 1_000_000) as i64); // maps to: processing_time (converted to milliseconds)
430        self.service_active_services.set(service_count); // derived from ServiceSet.services
431        self.service_active_endpoints.set(endpoint_count as i64); // derived from ServiceInfo.endpoints
432    }
433
434    /// Reset all metrics to zero. Useful when no data is available or to clear stale values.
435    pub fn reset_to_zeros(&self) {
436        self.service_processing_ms_avg.set(0.0);
437        self.service_errors_total.set(0);
438        self.service_requests_total.set(0);
439        self.service_processing_ms_total.set(0);
440        self.service_active_services.set(0);
441        self.service_active_endpoints.set(0);
442    }
443}