dynamo_runtime/
service.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16// TODO - refactor this entire module
17//
// We want to carry forward the concept of live vs. ready for the components.
// We will also want to associate each component's cancellation token with the
// component's "service state".
21
22use crate::{error, transports::nats, utils::stream, Result};
23
24use async_nats::Message;
25use async_stream::try_stream;
26use bytes::Bytes;
27use derive_getters::Dissolve;
28use futures::stream::{StreamExt, TryStreamExt};
29use serde::{de::DeserializeOwned, Deserialize, Serialize};
30use std::time::Duration;
31
32pub struct ServiceClient {
33    nats_client: nats::Client,
34}
35
36impl ServiceClient {
37    pub fn new(nats_client: nats::Client) -> Self {
38        ServiceClient { nats_client }
39    }
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
43pub struct ServiceSet {
44    services: Vec<ServiceInfo>,
45}
46
47#[derive(Debug, Clone, Serialize, Deserialize)]
48pub struct ServiceInfo {
49    pub name: String,
50    pub id: String,
51    pub version: String,
52    pub started: String,
53    pub endpoints: Vec<EndpointInfo>,
54}
55
56#[derive(Debug, Clone, Serialize, Deserialize, Dissolve)]
57pub struct EndpointInfo {
58    pub name: String,
59    pub subject: String,
60
61    #[serde(flatten)]
62    pub data: Option<Metrics>,
63}
64
65impl EndpointInfo {
66    pub fn id(&self) -> Result<i64> {
67        let id = self
68            .subject
69            .split('-')
70            .next_back()
71            .ok_or_else(|| error!("No id found in subject"))?;
72
73        i64::from_str_radix(id, 16).map_err(|e| error!("Invalid id format: {}", e))
74    }
75}
76
77// TODO: This is _really_ close to the async_nats::service::Stats object,
78// but it's missing a few fields like "name", so use a temporary struct
79// for easy deserialization. Ideally, this type already exists or can
80// be exposed in the library somewhere.
81/// Stats structure returned from NATS service API
82#[derive(Debug, Clone, Serialize, Deserialize, Dissolve)]
83pub struct Metrics {
84    // Standard NATS Service API fields
85    pub average_processing_time: f64,
86    pub last_error: String,
87    pub num_errors: u64,
88    pub num_requests: u64,
89    pub processing_time: u64,
90    pub queue_group: String,
91    // Field containing custom stats handler data
92    pub data: serde_json::Value,
93}
94
95impl Metrics {
96    pub fn decode<T: for<'de> Deserialize<'de>>(self) -> Result<T> {
97        serde_json::from_value(self.data).map_err(Into::into)
98    }
99}
100
101impl ServiceClient {
102    pub async fn unary(
103        &self,
104        subject: impl Into<String>,
105        payload: impl Into<Bytes>,
106    ) -> Result<Message> {
107        let response = self
108            .nats_client
109            .client()
110            .request(subject.into(), payload.into())
111            .await?;
112        Ok(response)
113    }
114
115    pub async fn collect_services(
116        &self,
117        service_name: &str,
118        timeout: Duration,
119    ) -> Result<ServiceSet> {
120        let sub = self.nats_client.scrape_service(service_name).await?;
121        if timeout.is_zero() {
122            tracing::warn!("collect_services: timeout is zero");
123        }
124        if timeout > Duration::from_secs(10) {
125            tracing::warn!("collect_services: timeout is greater than 10 seconds");
126        }
127        let deadline = tokio::time::Instant::now() + timeout;
128
129        let services: Vec<ServiceInfo> = stream::until_deadline(sub, deadline)
130            .map(|message| serde_json::from_slice::<ServiceInfo>(&message.payload))
131            .filter_map(|info| async move {
132                match info {
133                    Ok(info) => Some(info),
134                    Err(e) => {
135                        log::debug!("error decoding service info: {:?}", e);
136                        None
137                    }
138                }
139            })
140            .collect()
141            .await;
142
143        Ok(ServiceSet { services })
144    }
145}
146
147impl ServiceSet {
148    pub fn into_endpoints(self) -> impl Iterator<Item = EndpointInfo> {
149        self.services
150            .into_iter()
151            .flat_map(|s| s.endpoints.into_iter())
152    }
153}
154
#[cfg(test)]
mod tests {

    use super::*;

    /// Builds the metrics fixture; only the custom `key` value varies.
    fn metrics(value: &str) -> Metrics {
        Metrics {
            average_processing_time: 0.1,
            last_error: "none".to_string(),
            num_errors: 0,
            num_requests: 10,
            processing_time: 100,
            queue_group: "group1".to_string(),
            data: serde_json::json!({ "key": value }),
        }
    }

    /// Builds an endpoint fixture carrying metrics with the given custom value.
    fn endpoint(name: &str, subject: &str, value: &str) -> EndpointInfo {
        EndpointInfo {
            name: name.to_string(),
            subject: subject.to_string(),
            data: Some(metrics(value)),
        }
    }

    /// Builds a `service1` instance fixture with the given id and endpoints.
    fn service(id: &str, endpoints: Vec<EndpointInfo>) -> ServiceInfo {
        ServiceInfo {
            name: "service1".to_string(),
            id: id.to_string(),
            version: "1.0".to_string(),
            started: "2021-01-01".to_string(),
            endpoints,
        }
    }

    #[test]
    fn test_service_set() {
        // Two instances of the same service, each exposing `endpoint1` plus
        // one endpoint whose name starts with `endpoint2`.
        let services = vec![
            service(
                "1",
                vec![
                    endpoint("endpoint1", "subject1", "value1"),
                    endpoint("endpoint2-foo", "subject2", "value1"),
                ],
            ),
            service(
                "2",
                vec![
                    endpoint("endpoint1", "subject1", "value1"),
                    endpoint("endpoint2-bar", "subject2", "value2"),
                ],
            ),
        ];

        let service_set = ServiceSet { services };

        let matching: Vec<_> = service_set
            .into_endpoints()
            .filter(|e| e.name.starts_with("endpoint2"))
            .collect();

        assert_eq!(matching.len(), 2);
    }
}