dynamo_runtime/
service.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4// TODO - refactor this entire module
5//
6// we want to carry forward the concept of live vs ready for the components
7// we will want to associate the components cancellation token with the
8// component's "service state"
9
10use crate::{
11    DistributedRuntime, Result,
12    component::Component,
13    error,
14    metrics::{MetricsRegistry, prometheus_names, prometheus_names::nats_service},
15    traits::*,
16    transports::nats,
17    utils::stream,
18};
19
20use async_nats::Message;
21use async_stream::try_stream;
22use bytes::Bytes;
23use derive_getters::Dissolve;
24use futures::stream::{StreamExt, TryStreamExt};
25use prometheus;
26use serde::{Deserialize, Serialize, de::DeserializeOwned};
27use std::time::Duration;
28
/// Client for NATS request/reply calls and service-stats scrapes.
///
/// Wraps the runtime's [`nats::Client`]; see [`ServiceClient::unary`] and
/// [`ServiceClient::collect_services`].
pub struct ServiceClient {
    nats_client: nats::Client,
}
32
33impl ServiceClient {
34    pub fn new(nats_client: nats::Client) -> Self {
35        ServiceClient { nats_client }
36    }
37}
38
/// ServiceSet contains a collection of services with their endpoints and metrics,
/// as gathered by [`ServiceClient::collect_services`].
///
/// Structure:
/// - `ServiceSet`
///   - `services: Vec<ServiceInfo>`
///     - `name: String`
///     - `id: String`
///     - `version: String`
///     - `started: String`
///     - `endpoints: Vec<EndpointInfo>`
///       - `name: String`
///       - `subject: String`
///       - `data: Option<NatsStatsMetrics>`
///         - `average_processing_time: u64` (nanoseconds)
///         - `last_error: String`
///         - `num_errors: u64`
///         - `num_requests: u64`
///         - `processing_time: u64` (nanoseconds)
///         - `queue_group: String`
///         - `data: serde_json::Value` (custom stats)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServiceSet {
    services: Vec<ServiceInfo>,
}
64
/// One service instance as reported by the NATS service stats API.
///
/// This is an example JSON from `nats req '$SRV.STATS.dynamo_backend'`:
///
/// ```json
/// {
///   "type": "io.nats.micro.v1.stats_response",
///   "name": "dynamo_backend",
///   "id": "bdu7nA8tbhy9mEkxIWlkBA",
///   "version": "0.0.1",
///   "started": "2025-08-08T05:07:17.720783523Z",
///   "endpoints": [
///     {
///       "name": "dynamo_backend-generate-694d988806b92e39",
///       "subject": "dynamo_backend.generate-694d988806b92e39",
///       "num_requests": 0,
///       "num_errors": 0,
///       "processing_time": 0,
///       "average_processing_time": 0,
///       "last_error": "",
///       "data": {
///         "val": 10
///       },
///       "queue_group": "q"
///     }
///   ]
/// }
/// ```
///
/// The struct has no `type` field, so that key of the response is not captured.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServiceInfo {
    /// Service name, e.g. `dynamo_backend`.
    pub name: String,
    /// Identifier of this service instance.
    pub id: String,
    /// Service version string.
    pub version: String,
    /// Start time as reported by NATS (an RFC 3339 timestamp in the example above).
    pub started: String,
    /// Per-endpoint information and stats for this instance.
    pub endpoints: Vec<EndpointInfo>,
}
96
/// Each endpoint has name, subject, num_requests, num_errors, processing_time, average_processing_time, last_error, queue_group, and data.
///
/// On the wire, every field sits in one flat JSON object. `name` and `subject` map to
/// this struct. The remaining stats keys are collected into [`NatsStatsMetrics`] through
/// `#[serde(flatten)]`.
#[derive(Debug, Clone, Serialize, Deserialize, Dissolve)]
pub struct EndpointInfo {
    /// Endpoint name.
    pub name: String,
    /// NATS subject of the endpoint; [`EndpointInfo::id`] parses its hexadecimal suffix.
    pub subject: String,

    /// Extra fields that don't fit in EndpointInfo will be flattened into the Metrics struct.
    /// NOTE(review): presumably `None` when those stats keys are missing or malformed
    /// (serde's flatten-into-`Option` behavior) — confirm before relying on it.
    #[serde(flatten)]
    pub data: Option<NatsStatsMetrics>,
}
107
108impl EndpointInfo {
109    pub fn id(&self) -> Result<i64> {
110        let id = self
111            .subject
112            .split('-')
113            .next_back()
114            .ok_or_else(|| error!("No id found in subject"))?;
115
116        i64::from_str_radix(id, 16).map_err(|e| error!("Invalid id format: {}", e))
117    }
118}
119
// TODO: This is _really_ close to the async_nats::service::Stats object,
// but it's missing a few fields like "name", so use a temporary struct
// for easy deserialization. Ideally, this type already exists or can
// be exposed in the library somewhere.
/// Per-endpoint stats structure returned from the NATS service API
/// (`$SRV.STATS.<service_name>` requests).
///
/// See <https://github.com/nats-io/nats.rs/blob/main/async-nats/src/service/endpoint.rs>.
#[derive(Debug, Clone, Serialize, Deserialize, Dissolve)]
pub struct NatsStatsMetrics {
    // Standard NATS Stats Service API fields from $SRV.STATS.<service_name> requests
    /// Average processing time per request.
    pub average_processing_time: u64, // in nanoseconds according to nats-io
    /// Last error message reported for the endpoint (an empty string in the example response).
    pub last_error: String,
    /// Error count reported by NATS for this endpoint.
    pub num_errors: u64,
    /// Request count reported by NATS for this endpoint.
    pub num_requests: u64,
    /// Total processing time for the endpoint.
    pub processing_time: u64, // in nanoseconds according to nats-io
    /// Queue group reported for the endpoint (`"q"` in the example response).
    pub queue_group: String,
    // Field containing custom stats handler data
    /// Custom stats payload; decode it into a concrete type with [`NatsStatsMetrics::decode`].
    pub data: serde_json::Value,
}
138
139impl NatsStatsMetrics {
140    pub fn decode<T: for<'de> Deserialize<'de>>(self) -> Result<T> {
141        serde_json::from_value(self.data).map_err(Into::into)
142    }
143}
144
145impl ServiceClient {
146    pub async fn unary(
147        &self,
148        subject: impl Into<String>,
149        payload: impl Into<Bytes>,
150    ) -> Result<Message> {
151        let response = self
152            .nats_client
153            .client()
154            .request(subject.into(), payload.into())
155            .await?;
156        Ok(response)
157    }
158
159    pub async fn collect_services(
160        &self,
161        service_name: &str,
162        timeout: Duration,
163    ) -> Result<ServiceSet> {
164        let sub = self.nats_client.scrape_service(service_name).await?;
165        if timeout.is_zero() {
166            tracing::warn!("collect_services: timeout is zero");
167        }
168        if timeout > Duration::from_secs(10) {
169            tracing::warn!("collect_services: timeout is greater than 10 seconds");
170        }
171        let deadline = tokio::time::Instant::now() + timeout;
172
173        let mut services = vec![];
174        let mut s = stream::until_deadline(sub, deadline);
175        while let Some(message) = s.next().await {
176            if message.payload.is_empty() {
177                // Expected while we wait for KV metrics in worker to start
178                tracing::trace!(service_name, "collect_services: empty payload from nats");
179                continue;
180            }
181            let info = serde_json::from_slice::<ServiceInfo>(&message.payload);
182            match info {
183                Ok(info) => services.push(info),
184                Err(err) => {
185                    let payload = String::from_utf8_lossy(&message.payload);
186                    tracing::debug!(%err, service_name, %payload, "error decoding service info");
187                }
188            }
189        }
190
191        Ok(ServiceSet { services })
192    }
193}
194
195impl ServiceSet {
196    pub fn into_endpoints(self) -> impl Iterator<Item = EndpointInfo> {
197        self.services
198            .into_iter()
199            .flat_map(|s| s.endpoints.into_iter())
200    }
201
202    /// Get a reference to the services in this ServiceSet
203    pub fn services(&self) -> &[ServiceInfo] {
204        &self.services
205    }
206}
207
#[cfg(test)]
mod tests {

    use super::*;

    /// Builds an endpoint fixture. Only the name, subject and custom stats payload vary;
    /// the remaining NATS stats are fixed.
    fn endpoint(name: &str, subject: &str, custom: serde_json::Value) -> EndpointInfo {
        EndpointInfo {
            name: name.to_string(),
            subject: subject.to_string(),
            data: Some(NatsStatsMetrics {
                average_processing_time: 100_000, // 0.1ms = 100,000 nanoseconds
                last_error: "none".to_string(),
                num_errors: 0,
                num_requests: 10,
                processing_time: 100,
                queue_group: "group1".to_string(),
                data: custom,
            }),
        }
    }

    /// Builds a `service1` instance fixture with the given id and endpoints.
    fn service(id: &str, endpoints: Vec<EndpointInfo>) -> ServiceInfo {
        ServiceInfo {
            name: "service1".to_string(),
            id: id.to_string(),
            version: "1.0".to_string(),
            started: "2021-01-01".to_string(),
            endpoints,
        }
    }

    #[test]
    fn test_service_set() {
        let services = vec![
            service(
                "1",
                vec![
                    endpoint("endpoint1", "subject1", serde_json::json!({"key": "value1"})),
                    endpoint("endpoint2-foo", "subject2", serde_json::json!({"key": "value1"})),
                ],
            ),
            service(
                "2",
                vec![
                    endpoint("endpoint1", "subject1", serde_json::json!({"key": "value1"})),
                    endpoint("endpoint2-bar", "subject2", serde_json::json!({"key": "value2"})),
                ],
            ),
        ];

        let service_set = ServiceSet { services };
        assert_eq!(service_set.services().len(), 2);

        let names: Vec<String> = service_set
            .into_endpoints()
            .filter(|e| e.name.starts_with("endpoint2"))
            .map(|e| e.name)
            .collect();

        // Endpoints are flattened in service order, so the first service's endpoint leads.
        assert_eq!(names, ["endpoint2-foo", "endpoint2-bar"]);
    }

    #[test]
    fn test_endpoint_info_from_stats_response() {
        // Same shape as the `$SRV.STATS` example documented on `ServiceInfo`.
        let json = r#"{
            "type": "io.nats.micro.v1.stats_response",
            "name": "dynamo_backend",
            "id": "bdu7nA8tbhy9mEkxIWlkBA",
            "version": "0.0.1",
            "started": "2025-08-08T05:07:17.720783523Z",
            "endpoints": [
                {
                    "name": "dynamo_backend-generate-694d988806b92e39",
                    "subject": "dynamo_backend.generate-694d988806b92e39",
                    "num_requests": 0,
                    "num_errors": 0,
                    "processing_time": 0,
                    "average_processing_time": 0,
                    "last_error": "",
                    "data": { "val": 10 },
                    "queue_group": "q"
                }
            ]
        }"#;

        let info: ServiceInfo = serde_json::from_str(json).unwrap();
        assert_eq!(info.name, "dynamo_backend");
        assert_eq!(info.endpoints.len(), 1);

        let endpoint = &info.endpoints[0];
        assert_eq!(endpoint.id().unwrap(), 0x694d988806b92e39);

        let stats = endpoint
            .data
            .as_ref()
            .expect("stats fields should be flattened into `data`");
        assert_eq!(stats.num_requests, 0);
        assert_eq!(stats.queue_group, "q");
        assert_eq!(stats.data["val"], 10);
    }

    #[test]
    fn test_endpoint_id_rejects_non_hex_suffix() {
        let endpoint = endpoint("endpoint1", "subject1", serde_json::Value::Null);
        assert!(endpoint.id().is_err());
    }
}
296
/// Prometheus metrics for component service statistics (ordered to match [`NatsStatsMetrics`])
///
/// ⚠️  IMPORTANT: These Prometheus Gauges are COPIES of NATS data, not live references!
///
/// How it works:
/// 1. NATS provides source data via [`NatsStatsMetrics`]
/// 2. Metrics callbacks read current NATS values and update these Prometheus Gauges
///    (through `update_from_service_set`, or `reset_to_zeros` to clear them)
/// 3. Prometheus scrapes these Gauge values (snapshots, not live data)
///
/// Flow: NATS Service → NatsStatsMetrics (Counters) → Metrics Callback → Prometheus Gauge
/// Note: These are snapshots updated when execute_metrics_callbacks() is called.
#[derive(Debug, Clone)]
pub struct ComponentNatsServerPrometheusMetrics {
    /// Average processing time per request in milliseconds.
    /// It is derived from the summed `processing_time` divided by the summed
    /// `num_requests`, not read from `average_processing_time` directly.
    pub service_avg_processing_ms: prometheus::Gauge,
    /// Total errors across all endpoints (maps to: num_errors)
    pub service_total_errors: prometheus::IntGauge,
    /// Total requests across all endpoints (maps to: num_requests)
    pub service_total_requests: prometheus::IntGauge,
    /// Total processing time in milliseconds (maps to: processing_time, converted from nanoseconds)
    pub service_total_processing_ms: prometheus::IntGauge,
    /// Number of active services (derived from ServiceSet.services)
    pub service_active_services: prometheus::IntGauge,
    /// Number of active endpoints (derived from ServiceInfo.endpoints)
    pub service_active_endpoints: prometheus::IntGauge,
}
323
324impl ComponentNatsServerPrometheusMetrics {
325    /// Create new ComponentServiceMetrics using Component's DistributedRuntime's Prometheus constructors
326    pub fn new(component: &Component) -> Result<Self> {
327        let service_name = component.service_name();
328
329        // Build labels: service_name first, then component's labels
330        let mut labels_vec = vec![("service_name", service_name.as_str())];
331
332        // Add component's labels (convert from (String, String) to (&str, &str))
333        for (key, value) in component.labels() {
334            labels_vec.push((key.as_str(), value.as_str()));
335        }
336
337        let labels: &[(&str, &str)] = &labels_vec;
338
339        let service_avg_processing_ms = component.create_gauge(
340            nats_service::AVG_PROCESSING_MS,
341            "Average processing time across all component endpoints in milliseconds",
342            labels,
343        )?;
344
345        let service_total_errors = component.create_intgauge(
346            nats_service::TOTAL_ERRORS,
347            "Total number of errors across all component endpoints",
348            labels,
349        )?;
350
351        let service_total_requests = component.create_intgauge(
352            nats_service::TOTAL_REQUESTS,
353            "Total number of requests across all component endpoints",
354            labels,
355        )?;
356
357        let service_total_processing_ms = component.create_intgauge(
358            nats_service::TOTAL_PROCESSING_MS,
359            "Total processing time across all component endpoints in milliseconds",
360            labels,
361        )?;
362
363        let service_active_services = component.create_intgauge(
364            nats_service::ACTIVE_SERVICES,
365            "Number of active services in this component",
366            labels,
367        )?;
368
369        let service_active_endpoints = component.create_intgauge(
370            nats_service::ACTIVE_ENDPOINTS,
371            "Number of active endpoints across all services",
372            labels,
373        )?;
374
375        Ok(Self {
376            service_avg_processing_ms,
377            service_total_errors,
378            service_total_requests,
379            service_total_processing_ms,
380            service_active_services,
381            service_active_endpoints,
382        })
383    }
384
385    /// Update metrics from scraped ServiceSet data
386    pub fn update_from_service_set(&self, service_set: &ServiceSet) {
387        // Variables ordered to match NatsStatsMetrics fields
388        let mut processing_time_samples = 0u64; // for average_processing_time calculation
389        let mut total_errors = 0u64; // maps to: num_errors
390        let mut total_requests = 0u64; // maps to: num_requests
391        let mut total_processing_time_nanos = 0u64; // maps to: processing_time (nanoseconds from NATS)
392        let mut endpoint_count = 0u64; // for derived metrics
393
394        let service_count = service_set.services().len() as i64;
395
396        for service in service_set.services() {
397            for endpoint in &service.endpoints {
398                endpoint_count += 1;
399
400                if let Some(ref stats) = endpoint.data {
401                    total_errors += stats.num_errors;
402                    total_requests += stats.num_requests;
403                    total_processing_time_nanos += stats.processing_time;
404
405                    if stats.num_requests > 0 {
406                        processing_time_samples += 1;
407                    }
408                }
409            }
410        }
411
412        // Update metrics (ordered to match NatsStatsMetrics fields)
413        // Calculate average processing time in milliseconds (maps to: average_processing_time)
414        if processing_time_samples > 0 && total_requests > 0 {
415            let avg_time_nanos = total_processing_time_nanos as f64 / total_requests as f64;
416            let avg_time_ms = avg_time_nanos / 1_000_000.0; // Convert nanoseconds to milliseconds
417            self.service_avg_processing_ms.set(avg_time_ms);
418        } else {
419            self.service_avg_processing_ms.set(0.0);
420        }
421
422        self.service_total_errors.set(total_errors as i64); // maps to: num_errors
423        self.service_total_requests.set(total_requests as i64); // maps to: num_requests
424        self.service_total_processing_ms
425            .set((total_processing_time_nanos / 1_000_000) as i64); // maps to: processing_time (converted to milliseconds)
426        self.service_active_services.set(service_count); // derived from ServiceSet.services
427        self.service_active_endpoints.set(endpoint_count as i64); // derived from ServiceInfo.endpoints
428    }
429
430    /// Reset all metrics to zero. Useful when no data is available or to clear stale values.
431    pub fn reset_to_zeros(&self) {
432        self.service_avg_processing_ms.set(0.0);
433        self.service_total_errors.set(0);
434        self.service_total_requests.set(0);
435        self.service_total_processing_ms.set(0);
436        self.service_active_services.set(0);
437        self.service_active_endpoints.set(0);
438    }
439}