Skip to main content

dynamo_runtime/
service.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4// TODO - refactor this entire module
5//
6// we want to carry forward the concept of live vs ready for the components
7// we will want to associate the components cancellation token with the
8// component's "service state"
9
10use crate::{
11    DistributedRuntime,
12    component::Component,
13    metrics::{MetricsHierarchy, prometheus_names},
14    traits::*,
15    transports::nats,
16    utils::stream,
17};
18
19use anyhow::Result;
20use anyhow::anyhow as error;
21use async_nats::Message;
22use async_stream::try_stream;
23use bytes::Bytes;
24use derive_getters::Dissolve;
25use futures::stream::{StreamExt, TryStreamExt};
26use prometheus;
27use serde::{Deserialize, Serialize, de::DeserializeOwned};
28use std::time::Duration;
29
30pub struct ServiceClient {
31    nats_client: nats::Client,
32}
33
34impl ServiceClient {
35    pub fn new(nats_client: nats::Client) -> Self {
36        ServiceClient { nats_client }
37    }
38}
39
40/// ServiceSet contains a collection of services with their endpoints and metrics
41///
42/// Tree structure:
43/// Structure:
44/// - ServiceSet
45///   - services: Vec<ServiceInfo>
46///     - name: String
47///     - id: String
48///     - version: String
49///     - started: String
50///     - endpoints: Vec<EndpointInfo>
51///       - name: String
52///       - subject: String
53///       - data: Option<NatsStatsMetrics>
54///         - average_processing_time: f64
55///         - last_error: String
56///         - num_errors: u64
57///         - num_requests: u64
58///         - processing_time: u64
59///         - queue_group: String
60///         - data: serde_json::Value (custom stats)
61#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct ServiceSet {
63    services: Vec<ServiceInfo>,
64}
65
66/// This is a example JSON from `nats req '$SRV.STATS.dynamo_backend'`:
67/// {
68///   "type": "io.nats.micro.v1.stats_response",
69///   "name": "dynamo_backend",
70///   "id": "bdu7nA8tbhy9mEkxIWlkBA",
71///   "version": "0.0.1",
72///   "started": "2025-08-08T05:07:17.720783523Z",
73///   "endpoints": [
74///     {
75///       "name": "dynamo_backend-generate-694d988806b92e39",
76///       "subject": "dynamo_backend.generate-694d988806b92e39",
77///       "num_requests": 0,
78///       "num_errors": 0,
79///       "processing_time": 0,
80///       "average_processing_time": 0,
81///       "last_error": "",
82///       "data": {
83///         "val": 10
84///       },
85///       "queue_group": "q"
86///     }
87///   ]
88/// }
89#[derive(Debug, Clone, Serialize, Deserialize)]
90pub struct ServiceInfo {
91    pub name: String,
92    pub id: String,
93    pub version: String,
94    pub started: String,
95    pub endpoints: Vec<EndpointInfo>,
96}
97
98/// Each endpoint has name, subject, num_requests, num_errors, processing_time, average_processing_time, last_error, queue_group, and data
99#[derive(Debug, Clone, Serialize, Deserialize, Dissolve)]
100pub struct EndpointInfo {
101    pub name: String,
102    pub subject: String,
103
104    /// Extra fields that don't fit in EndpointInfo will be flattened into the Metrics struct.
105    #[serde(flatten)]
106    pub data: Option<NatsStatsMetrics>,
107}
108
109impl EndpointInfo {
110    pub fn id(&self) -> Result<i64> {
111        let id = self
112            .subject
113            .split('-')
114            .next_back()
115            .ok_or_else(|| error!("No id found in subject"))?;
116
117        i64::from_str_radix(id, 16).map_err(|e| error!("Invalid id format: {}", e))
118    }
119}
120
// TODO: This is _really_ close to the async_nats::service::Stats object,
// but it's missing a few fields like "name", so use a temporary struct
// for easy deserialization. Ideally, this type already exists or can
// be exposed in the library somewhere.
/// Per-endpoint stats returned by the NATS service API (`$SRV.STATS.<service_name>`).
///
/// `EndpointInfo` holds this via `#[serde(flatten)]`, so it is filled from the
/// endpoint object's fields other than `name` and `subject`.
///
/// Upstream definition, for reference:
/// <https://github.com/nats-io/nats.rs/blob/main/async-nats/src/service/endpoint.rs>
#[derive(Debug, Clone, Serialize, Deserialize, Dissolve)]
pub struct NatsStatsMetrics {
    // Standard NATS Stats Service API fields from $SRV.STATS.<service_name> requests
    /// Mean time spent per request.
    pub average_processing_time: u64, // in nanoseconds according to nats-io
    /// Most recent error text (the sample response shows `""` when there are no errors).
    pub last_error: String,
    /// Count of requests that ended in an error.
    pub num_errors: u64,
    /// Count of requests handled.
    pub num_requests: u64,
    /// Total time spent handling requests.
    pub processing_time: u64, // in nanoseconds according to nats-io
    /// Queue group the endpoint subscribes with.
    pub queue_group: String,
    // Field containing custom stats handler data
    /// Custom stats payload; use `NatsStatsMetrics::decode` to turn it into a typed value.
    pub data: serde_json::Value,
}
139
140impl NatsStatsMetrics {
141    pub fn decode<T: for<'de> Deserialize<'de>>(self) -> Result<T> {
142        serde_json::from_value(self.data).map_err(Into::into)
143    }
144}
145
146impl ServiceClient {
147    pub async fn unary(
148        &self,
149        subject: impl Into<String>,
150        payload: impl Into<Bytes>,
151    ) -> Result<Message> {
152        let response = self
153            .nats_client
154            .client()
155            .request(subject.into(), payload.into())
156            .await?;
157        Ok(response)
158    }
159
160    pub async fn collect_services(
161        &self,
162        service_name: &str,
163        timeout: Duration,
164    ) -> Result<ServiceSet> {
165        let sub = self.nats_client.scrape_service(service_name).await?;
166        if timeout.is_zero() {
167            tracing::warn!("collect_services: timeout is zero");
168        }
169        if timeout > Duration::from_secs(10) {
170            tracing::warn!("collect_services: timeout is greater than 10 seconds");
171        }
172        let deadline = tokio::time::Instant::now() + timeout;
173
174        let mut services = vec![];
175        let mut s = stream::until_deadline(sub, deadline);
176        while let Some(message) = s.next().await {
177            if message.payload.is_empty() {
178                // Expected while we wait for KV metrics in worker to start
179                tracing::trace!(service_name, "collect_services: empty payload from nats");
180                continue;
181            }
182            let info = serde_json::from_slice::<ServiceInfo>(&message.payload);
183            match info {
184                Ok(info) => services.push(info),
185                Err(err) => {
186                    let payload = String::from_utf8_lossy(&message.payload);
187                    tracing::debug!(%err, service_name, %payload, "error decoding service info");
188                }
189            }
190        }
191
192        Ok(ServiceSet { services })
193    }
194}
195
196impl ServiceSet {
197    pub fn into_endpoints(self) -> impl Iterator<Item = EndpointInfo> {
198        self.services
199            .into_iter()
200            .flat_map(|s| s.endpoints.into_iter())
201    }
202
203    /// Get a reference to the services in this ServiceSet
204    pub fn services(&self) -> &[ServiceInfo] {
205        &self.services
206    }
207}
208
#[cfg(test)]
mod tests {

    use super::*;

    /// Stats response taken from the `ServiceInfo` documentation example.
    const EXAMPLE_STATS_RESPONSE: &str = r#"{
        "type": "io.nats.micro.v1.stats_response",
        "name": "dynamo_backend",
        "id": "bdu7nA8tbhy9mEkxIWlkBA",
        "version": "0.0.1",
        "started": "2025-08-08T05:07:17.720783523Z",
        "endpoints": [
            {
                "name": "dynamo_backend-generate-694d988806b92e39",
                "subject": "dynamo_backend.generate-694d988806b92e39",
                "num_requests": 0,
                "num_errors": 0,
                "processing_time": 0,
                "average_processing_time": 0,
                "last_error": "",
                "data": { "val": 10 },
                "queue_group": "q"
            }
        ]
    }"#;

    /// Builds an endpoint with fixed stats; only the custom `data` payload varies.
    fn endpoint(name: &str, subject: &str, custom: serde_json::Value) -> EndpointInfo {
        EndpointInfo {
            name: name.to_string(),
            subject: subject.to_string(),
            data: Some(NatsStatsMetrics {
                average_processing_time: 100_000, // 0.1ms = 100,000 nanoseconds
                last_error: "none".to_string(),
                num_errors: 0,
                num_requests: 10,
                processing_time: 100,
                queue_group: "group1".to_string(),
                data: custom,
            }),
        }
    }

    /// Builds a `service1` v1.0 instance with the given id and endpoints.
    fn service(id: &str, endpoints: Vec<EndpointInfo>) -> ServiceInfo {
        ServiceInfo {
            name: "service1".to_string(),
            id: id.to_string(),
            version: "1.0".to_string(),
            started: "2021-01-01".to_string(),
            endpoints,
        }
    }

    #[test]
    fn test_service_set() {
        let services = vec![
            service(
                "1",
                vec![
                    endpoint("endpoint1", "subject1", serde_json::json!({"key": "value1"})),
                    endpoint("endpoint2-foo", "subject2", serde_json::json!({"key": "value1"})),
                ],
            ),
            service(
                "2",
                vec![
                    endpoint("endpoint1", "subject1", serde_json::json!({"key": "value1"})),
                    endpoint("endpoint2-bar", "subject2", serde_json::json!({"key": "value2"})),
                ],
            ),
        ];

        let service_set = ServiceSet { services };
        assert_eq!(service_set.services().len(), 2);

        // Endpoints are flattened across services in order, so both matches appear.
        let names: Vec<String> = service_set
            .into_endpoints()
            .filter(|e| e.name.starts_with("endpoint2"))
            .map(|e| e.name)
            .collect();

        assert_eq!(names, vec!["endpoint2-foo", "endpoint2-bar"]);
    }

    #[test]
    fn test_deserialize_stats_response() {
        let info: ServiceInfo = serde_json::from_str(EXAMPLE_STATS_RESPONSE)
            .expect("example stats response should deserialize");
        assert_eq!(info.name, "dynamo_backend");
        assert_eq!(info.endpoints.len(), 1);

        let endpoint = info.endpoints.into_iter().next().expect("one endpoint");
        assert_eq!(endpoint.id().expect("hex id suffix"), 0x694d_9888_06b9_2e39);

        let stats = endpoint
            .data
            .expect("stats fields should flatten into NatsStatsMetrics");
        assert_eq!(stats.num_requests, 0);
        assert_eq!(stats.queue_group, "q");

        let custom: serde_json::Value = stats.decode().expect("custom data should decode");
        assert_eq!(custom, serde_json::json!({"val": 10}));
    }

    #[test]
    fn test_endpoint_id_rejects_non_hex_suffix() {
        let e = endpoint("endpoint1", "ns.generate-xyz", serde_json::Value::Null);
        assert!(e.id().is_err());
    }
}