dynamo_runtime/component/
client.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15
16use crate::pipeline::{
17    network::egress::push::{AddressedPushRouter, AddressedRequest, PushRouter},
18    AsyncEngine, Data, ManyOut, SingleIn,
19};
20use rand::Rng;
21use std::collections::HashMap;
22use std::sync::{
23    atomic::{AtomicU64, Ordering},
24    Arc,
25};
26use tokio::{net::unix::pipe::Receiver, sync::Mutex};
27
28use crate::{pipeline::async_trait, transports::etcd::WatchEvent, Error};
29
30use super::*;
31
/// Observable state of the endpoint map.
///
/// Each state has a nonce associated with it. The state is intended to be
/// emitted in a watch channel, so that the critical state transitions can be
/// observed.
// NOTE(review): nothing in this file constructs or reads `MapState` yet.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MapState {
    /// The map is empty; value = nonce
    Empty(u64),

    /// The map is non-empty; values are (nonce, count)
    NonEmpty(u64, u64),

    /// The watcher has finished; no more events will be emitted
    Finished,
}
45
/// A change to the set of known endpoints.
// NOTE(review): nothing in this file constructs or reads `EndpointEvent` yet;
// the field meanings below are presumed from the watcher task in
// `Client::new_dynamic` (etcd key -> lease id) and should be confirmed.
#[derive(Debug, Clone, PartialEq, Eq)]
enum EndpointEvent {
    /// An endpoint was added or updated; presumably (key, lease id).
    Put(String, i64),
    /// An endpoint was removed; presumably identified by its key.
    Delete(String),
}
50
/// Strategy used by a dynamic [`Client`] to choose the endpoint that receives a request.
///
/// The mode is ignored by clients with a static endpoint source.
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq)]
pub enum RouterMode {
    /// Pick one of the currently known endpoints at random (the default).
    #[default]
    Random,
    /// Cycle through the currently known endpoints using a shared request counter.
    RoundRobin,
    //KV,
    //
    /// Always and only go to the given endpoint ID.
    // TODO: Is this useful?
    Direct(i64),
}
62
63#[derive(Clone)]
64pub struct Client<T: Data, U: Data> {
65    endpoint: Endpoint,
66    router: PushRouter<T, U>,
67    counter: Arc<AtomicU64>,
68    endpoints: EndpointSource,
69    router_mode: RouterMode,
70}
71
72#[derive(Clone, Debug)]
73enum EndpointSource {
74    Static,
75    Dynamic(tokio::sync::watch::Receiver<Vec<i64>>),
76}
77
78impl<T, U> Client<T, U>
79where
80    T: Data + Serialize,
81    U: Data + for<'de> Deserialize<'de>,
82{
83    // Client will only talk to a single static endpoint
84    pub(crate) async fn new_static(endpoint: Endpoint) -> Result<Self> {
85        Ok(Client {
86            router: router(&endpoint).await?,
87            endpoint,
88            counter: Arc::new(AtomicU64::new(0)),
89            endpoints: EndpointSource::Static,
90            router_mode: Default::default(),
91        })
92    }
93
94    // Client with auto-discover endpoints using etcd
95    pub(crate) async fn new_dynamic(endpoint: Endpoint) -> Result<Self> {
96        // create live endpoint watcher
97        let Some(etcd_client) = &endpoint.component.drt.etcd_client else {
98            anyhow::bail!("Attempt to create a dynamic client on a static endpoint");
99        };
100        let prefix_watcher = etcd_client
101            .kv_get_and_watch_prefix(endpoint.etcd_path())
102            .await?;
103
104        let (prefix, _watcher, mut kv_event_rx) = prefix_watcher.dissolve();
105
106        let (watch_tx, watch_rx) = tokio::sync::watch::channel(vec![]);
107
108        let secondary = endpoint.component.drt.runtime.secondary().clone();
109
110        // this task should be included in the registry
111        // currently this is created once per client, but this object/task should only be instantiated
112        // once per worker/instance
113        secondary.spawn(async move {
114            tracing::debug!("Starting endpoint watcher for prefix: {}", prefix);
115            let mut map = HashMap::new();
116
117            loop {
118                let kv_event = tokio::select! {
119                    _ = watch_tx.closed() => {
120                        tracing::debug!("all watchers have closed; shutting down endpoint watcher for prefix: {}", prefix);
121                        break;
122                    }
123                    kv_event = kv_event_rx.recv() => {
124                        match kv_event {
125                            Some(kv_event) => kv_event,
126                            None => {
127                                tracing::debug!("watch stream has closed; shutting down endpoint watcher for prefix: {}", prefix);
128                                break;
129                            }
130                        }
131                    }
132                };
133
134                match kv_event {
135                    WatchEvent::Put(kv) => {
136                        let key = String::from_utf8(kv.key().to_vec());
137                        let val = serde_json::from_slice::<ComponentEndpointInfo>(kv.value());
138                        if let (Ok(key), Ok(val)) = (key, val) {
139                            map.insert(key.clone(), val.lease_id);
140                        } else {
141                            tracing::error!("Unable to parse put endpoint event; shutting down endpoint watcher for prefix: {}", prefix);
142                            break;
143                        }
144                    }
145                    WatchEvent::Delete(kv) => {
146                        match String::from_utf8(kv.key().to_vec()) {
147                            Ok(key) => { map.remove(&key); }
148                            Err(_) => {
149                                tracing::error!("Unable to parse delete endpoint event; shutting down endpoint watcher for prefix: {}", prefix);
150                                break;
151                            }
152                        }
153                    }
154                }
155
156                let endpoint_ids: Vec<i64> = map.values().cloned().collect();
157
158                if watch_tx.send(endpoint_ids).is_err() {
159                    tracing::debug!("Unable to send watch updates; shutting down endpoint watcher for prefix: {}", prefix);
160                    break;
161                }
162
163            }
164
165            tracing::debug!("Completed endpoint watcher for prefix: {}", prefix);
166            let _ = watch_tx.send(vec![]);
167        });
168
169        Ok(Client {
170            router: router(&endpoint).await?,
171            endpoint,
172            counter: Arc::new(AtomicU64::new(0)),
173            endpoints: EndpointSource::Dynamic(watch_rx),
174            router_mode: Default::default(),
175        })
176    }
177
178    /// String identifying `<namespace>/<component>/<endpoint>`
179    pub fn path(&self) -> String {
180        self.endpoint.path()
181    }
182
183    /// String identifying `<namespace>/component/<component>/<endpoint>`
184    pub fn etcd_path(&self) -> String {
185        self.endpoint.etcd_path()
186    }
187
188    pub fn endpoint_ids(&self) -> Vec<i64> {
189        match &self.endpoints {
190            EndpointSource::Static => vec![0],
191            EndpointSource::Dynamic(watch_rx) => watch_rx.borrow().clone(),
192        }
193    }
194
195    pub fn set_router_mode(&mut self, mode: RouterMode) {
196        self.router_mode = mode
197    }
198
199    /// Wait for at least one [`Endpoint`] to be available
200    pub async fn wait_for_endpoints(&self) -> Result<()> {
201        if let EndpointSource::Dynamic(mut rx) = self.endpoints.clone() {
202            // wait for there to be 1 or more endpoints
203            loop {
204                if rx.borrow_and_update().is_empty() {
205                    rx.changed().await?;
206                } else {
207                    break;
208                }
209            }
210        }
211        Ok(())
212    }
213
214    /// Is this component know at startup and not discovered via etcd?
215    pub fn is_static(&self) -> bool {
216        matches!(self.endpoints, EndpointSource::Static)
217    }
218
219    /// Issue a request to the next available endpoint in a round-robin fashion
220    pub async fn round_robin(&self, request: SingleIn<T>) -> Result<ManyOut<U>> {
221        let counter = self.counter.fetch_add(1, Ordering::Relaxed);
222
223        let endpoint_id = {
224            let endpoints = self.endpoint_ids();
225            let count = endpoints.len();
226            if count == 0 {
227                return Err(error!(
228                    "no endpoints found for endpoint {:?}",
229                    self.endpoint.etcd_path()
230                ));
231            }
232            let offset = counter % count as u64;
233            endpoints[offset as usize]
234        };
235        tracing::trace!("round robin router selected {endpoint_id}");
236
237        let subject = self.endpoint.subject_to(endpoint_id);
238        let request = request.map(|req| AddressedRequest::new(req, subject));
239
240        self.router.generate(request).await
241    }
242
243    /// Issue a request to a random endpoint
244    pub async fn random(&self, request: SingleIn<T>) -> Result<ManyOut<U>> {
245        let endpoint_id = {
246            let endpoints = self.endpoint_ids();
247            let count = endpoints.len();
248            if count == 0 {
249                return Err(error!(
250                    "no endpoints found for endpoint {:?}",
251                    self.endpoint.etcd_path()
252                ));
253            }
254            let counter = rand::rng().random::<u64>();
255            let offset = counter % count as u64;
256            endpoints[offset as usize]
257        };
258        tracing::trace!("random router selected {endpoint_id}");
259
260        let subject = self.endpoint.subject_to(endpoint_id);
261        let request = request.map(|req| AddressedRequest::new(req, subject));
262
263        self.router.generate(request).await
264    }
265
266    /// Issue a request to a specific endpoint
267    pub async fn direct(&self, request: SingleIn<T>, endpoint_id: i64) -> Result<ManyOut<U>> {
268        let found = {
269            let endpoints = self.endpoint_ids();
270            endpoints.contains(&endpoint_id)
271        };
272
273        if !found {
274            return Err(error!(
275                "endpoint_id={} not found for endpoint {:?}",
276                endpoint_id,
277                self.endpoint.etcd_path()
278            ));
279        }
280
281        let subject = self.endpoint.subject_to(endpoint_id);
282        let request = request.map(|req| AddressedRequest::new(req, subject));
283
284        self.router.generate(request).await
285    }
286
287    pub async fn r#static(&self, request: SingleIn<T>) -> Result<ManyOut<U>> {
288        let subject = self.endpoint.subject();
289        tracing::debug!("static got subject: {subject}");
290        let request = request.map(|req| AddressedRequest::new(req, subject));
291        tracing::debug!("router generate");
292        self.router.generate(request).await
293    }
294}
295
296async fn router(endpoint: &Endpoint) -> Result<Arc<AddressedPushRouter>> {
297    AddressedPushRouter::new(
298        endpoint.component.drt.nats_client.client().clone(),
299        endpoint.component.drt.tcp_server().await?,
300    )
301}
302
303#[async_trait]
304impl<T, U> AsyncEngine<SingleIn<T>, ManyOut<U>, Error> for Client<T, U>
305where
306    T: Data + Serialize,
307    U: Data + for<'de> Deserialize<'de>,
308{
309    async fn generate(&self, request: SingleIn<T>) -> Result<ManyOut<U>, Error> {
310        match &self.endpoints {
311            EndpointSource::Static => self.r#static(request).await,
312            EndpointSource::Dynamic(_) => match self.router_mode {
313                RouterMode::Random => self.random(request).await,
314                RouterMode::RoundRobin => self.round_robin(request).await,
315                RouterMode::Direct(endpoint_id) => self.direct(request, endpoint_id).await,
316            },
317        }
318    }
319}