dynamo_runtime/component/
client.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use crate::pipeline::{
17    network::egress::push::{AddressedPushRouter, AddressedRequest, PushRouter},
18    AsyncEngine, Data, ManyOut, SingleIn,
19};
20use rand::Rng;
21use std::collections::HashMap;
22use std::sync::{
23    atomic::{AtomicU64, Ordering},
24    Arc,
25};
26use tokio::{net::unix::pipe::Receiver, sync::Mutex};
27
28use crate::{pipeline::async_trait, transports::etcd::WatchEvent, Error};
29
30use super::*;
31
/// Observable state of the endpoint map.
///
/// Each state has a nonce associated with it. The state is meant to be
/// emitted in a watch channel, so that the critical state transitions can be
/// observed.
///
/// NOTE(review): nothing in the visible part of this file constructs or reads
/// this enum yet; confirm whether it is still planned before relying on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MapState {
    /// The map is empty; value = nonce
    Empty(u64),

    /// The map is not-empty; values are (nonce, count)
    NonEmpty(u64, u64),

    /// The watcher has finished, no more events will be emitted
    Finished,
}
45
/// A change to the set of live endpoints.
///
/// NOTE(review): this enum is not referenced by the visible code. The payload
/// meanings below are inferred from the watcher in `Client::new`, which keys
/// its map by etcd key and stores an `i64` lease id — confirm before use.
#[derive(Debug, Clone, PartialEq, Eq)]
enum EndpointEvent {
    /// An endpoint was added or updated; presumably `(key, lease_id)`.
    Put(String, i64),

    /// An endpoint was removed; presumably identified by its key.
    Delete(String),
}
50
/// Client for issuing requests to the live instances of an [`Endpoint`].
///
/// `T` is the request payload type and `U` the response item type carried by
/// the underlying [`PushRouter`].
///
/// Cloning a `Client` yields a handle that shares the same round-robin
/// counter (it is behind an [`Arc`]) and observes the same watch channel of
/// endpoint ids.
#[derive(Clone)]
pub struct Client<T: Data, U: Data> {
    /// The endpoint this client targets; used to build paths and per-instance subjects.
    endpoint: Endpoint,
    /// Router used to push addressed requests and obtain the response stream.
    router: PushRouter<T, U>,
    /// Latest snapshot of live endpoint ids (lease ids) published by the watcher task
    /// spawned in `Client::new`.
    watch_rx: tokio::sync::watch::Receiver<Vec<i64>>,
    /// Monotonic counter used by `round_robin` to select the next endpoint.
    counter: Arc<AtomicU64>,
}
58
59impl<T, U> Client<T, U>
60where
61    T: Data + Serialize,
62    U: Data + for<'de> Deserialize<'de>,
63{
64    pub(crate) async fn new(endpoint: Endpoint) -> Result<Self> {
65        let router = AddressedPushRouter::new(
66            endpoint.component.drt.nats_client.client().clone(),
67            endpoint.component.drt.tcp_server().await?,
68        )?;
69
70        // create live endpoint watcher
71        let prefix_watcher = endpoint
72            .component
73            .drt
74            .etcd_client
75            .kv_get_and_watch_prefix(endpoint.etcd_path())
76            .await?;
77
78        let (prefix, _watcher, mut kv_event_rx) = prefix_watcher.dissolve();
79
80        let (watch_tx, watch_rx) = tokio::sync::watch::channel(vec![]);
81
82        let secondary = endpoint.component.drt.runtime.secondary().clone();
83
84        // this task should be included in the registry
85        // currently this is created once per client, but this object/task should only be instantiated
86        // once per worker/instance
87        secondary.spawn(async move {
88            tracing::debug!("Starting endpoint watcher for prefix: {}", prefix);
89            let mut map = HashMap::new();
90
91            loop {
92                let kv_event = tokio::select! {
93                    _ = watch_tx.closed() => {
94                        tracing::debug!("all watchers have closed; shutting down endpoint watcher for prefix: {}", prefix);
95                        break;
96                    }
97                    kv_event = kv_event_rx.recv() => {
98                        match kv_event {
99                            Some(kv_event) => kv_event,
100                            None => {
101                                tracing::debug!("watch stream has closed; shutting down endpoint watcher for prefix: {}", prefix);
102                                break;
103                            }
104                        }
105                    }
106                };
107
108                match kv_event {
109                    WatchEvent::Put(kv) => {
110                        let key = String::from_utf8(kv.key().to_vec());
111                        let val = serde_json::from_slice::<ComponentEndpointInfo>(kv.value());
112                        if let (Ok(key), Ok(val)) = (key, val) {
113                            map.insert(key.clone(), val.lease_id);
114                        } else {
115                            tracing::error!("Unable to parse put endpoint event; shutting down endpoint watcher for prefix: {}", prefix);
116                            break;
117                        }
118                    }
119                    WatchEvent::Delete(kv) => {
120                        match String::from_utf8(kv.key().to_vec()) {
121                            Ok(key) => { map.remove(&key); }
122                            Err(_) => {
123                                tracing::error!("Unable to parse delete endpoint event; shutting down endpoint watcher for prefix: {}", prefix);
124                                break;
125                            }
126                        }
127                    }
128                }
129
130                let endpoint_ids: Vec<i64> = map.values().cloned().collect();
131
132                if watch_tx.send(endpoint_ids).is_err() {
133                    tracing::debug!("Unable to send watch updates; shutting down endpoint watcher for prefix: {}", prefix);
134                    break;
135                }
136
137            }
138
139            tracing::debug!("Completed endpoint watcher for prefix: {}", prefix);
140            let _ = watch_tx.send(vec![]);
141        });
142
143        Ok(Client {
144            endpoint,
145            router,
146            watch_rx,
147            counter: Arc::new(AtomicU64::new(0)),
148        })
149    }
150
151    /// String identifying `<namespace>/<component>/<endpoint>`
152    pub fn path(&self) -> String {
153        self.endpoint.path()
154    }
155
156    /// String identifying `<namespace>/component/<component>/<endpoint>`
157    pub fn etcd_path(&self) -> String {
158        self.endpoint.etcd_path()
159    }
160
161    pub fn endpoint_ids(&self) -> &tokio::sync::watch::Receiver<Vec<i64>> {
162        &self.watch_rx
163    }
164
165    /// Wait for at least one [`Endpoint`] to be available
166    pub async fn wait_for_endpoints(&self) -> Result<()> {
167        let mut rx = self.watch_rx.clone();
168        // wait for there to be 1 or more endpoints
169        loop {
170            if rx.borrow_and_update().is_empty() {
171                rx.changed().await?;
172            } else {
173                break;
174            }
175        }
176
177        Ok(())
178    }
179
180    /// Issue a request to the next available endpoint in a round-robin fashion
181    pub async fn round_robin(&self, request: SingleIn<T>) -> Result<ManyOut<U>> {
182        let counter = self.counter.fetch_add(1, Ordering::Relaxed);
183
184        let endpoint_id = {
185            let endpoints = self.watch_rx.borrow();
186            let count = endpoints.len();
187            if count == 0 {
188                return Err(error!(
189                    "no endpoints found for endpoint {:?}",
190                    self.endpoint.etcd_path()
191                ));
192            }
193            let offset = counter % count as u64;
194            endpoints[offset as usize]
195        };
196
197        let subject = self.endpoint.subject_to(endpoint_id);
198        let request = request.map(|req| AddressedRequest::new(req, subject));
199
200        self.router.generate(request).await
201    }
202
203    /// Issue a request to a random endpoint
204    pub async fn random(&self, request: SingleIn<T>) -> Result<ManyOut<U>> {
205        let endpoint_id = {
206            let endpoints = self.watch_rx.borrow();
207            let count = endpoints.len();
208            if count == 0 {
209                return Err(error!(
210                    "no endpoints found for endpoint {:?}",
211                    self.endpoint.etcd_path()
212                ));
213            }
214            let counter = rand::thread_rng().gen::<u64>();
215            let offset = counter % count as u64;
216            endpoints[offset as usize]
217        };
218
219        let subject = self.endpoint.subject_to(endpoint_id);
220        let request = request.map(|req| AddressedRequest::new(req, subject));
221
222        self.router.generate(request).await
223    }
224
225    /// Issue a request to a specific endpoint
226    pub async fn direct(&self, request: SingleIn<T>, endpoint_id: i64) -> Result<ManyOut<U>> {
227        let found = {
228            let endpoints = self.watch_rx.borrow();
229            endpoints.contains(&endpoint_id)
230        };
231
232        if !found {
233            return Err(error!(
234                "endpoint_id={} not found for endpoint {:?}",
235                endpoint_id,
236                self.endpoint.etcd_path()
237            ));
238        }
239
240        let subject = self.endpoint.subject_to(endpoint_id);
241        let request = request.map(|req| AddressedRequest::new(req, subject));
242
243        self.router.generate(request).await
244    }
245}
246
247#[async_trait]
248impl<T, U> AsyncEngine<SingleIn<T>, ManyOut<U>, Error> for Client<T, U>
249where
250    T: Data + Serialize,
251    U: Data + for<'de> Deserialize<'de>,
252{
253    async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
254        self.random(request).await
255    }
256}