// dynamo_runtime/component/client.rs

// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
3
4use crate::pipeline::{
5    AddressedPushRouter, AddressedRequest, AsyncEngine, Data, ManyOut, PushRouter, RouterMode,
6    SingleIn,
7};
8use arc_swap::ArcSwap;
9use std::collections::HashMap;
10use std::sync::Arc;
11use tokio::net::unix::pipe::Receiver;
12
13use crate::{
14    pipeline::async_trait,
15    transports::etcd::{Client as EtcdClient, WatchEvent},
16};
17
18use super::*;
19
/// Lifecycle states of a watched instance map.
///
/// The two "live" states carry a nonce so that, when emitted on a watch channel, observers
/// can distinguish successive critical state transitions. `Finished` carries no nonce.
///
/// NOTE(review): this type is not referenced anywhere in this file; confirm whether it is
/// still needed before relying on it.
enum MapState {
    /// The map is empty; value = nonce
    Empty(u64),

    /// The map is not-empty; values are (nonce, count)
    NonEmpty(u64, u64),

    /// The watcher has finished, no more events will be emitted
    Finished,
}
33
/// A change to a single endpoint entry.
///
/// NOTE(review): this type is not referenced anywhere in this file; the etcd watcher in
/// `Client::get_or_create_dynamic_instance_source` matches on `WatchEvent` directly.
/// Confirm whether it is still needed.
enum EndpointEvent {
    /// An endpoint entry was created or updated.
    /// Presumably (etcd key, instance id) — TODO confirm against any external user.
    Put(String, i64),
    /// An endpoint entry was removed; presumably carries the etcd key.
    Delete(String),
}
38
39#[derive(Clone, Debug)]
40pub struct Client {
41    // This is me
42    pub endpoint: Endpoint,
43    // These are the remotes I know about from watching etcd
44    pub instance_source: Arc<InstanceSource>,
45    // These are the instance source ids less those reported as down from sending rpc
46    instance_avail: Arc<ArcSwap<Vec<i64>>>,
47    // These are the instance source ids less those reported as busy (above threshold)
48    instance_free: Arc<ArcSwap<Vec<i64>>>,
49}
50
51#[derive(Clone, Debug)]
52pub enum InstanceSource {
53    Static,
54    Dynamic(tokio::sync::watch::Receiver<Vec<Instance>>),
55}
56
57impl Client {
58    // Client will only talk to a single static endpoint
59    pub(crate) async fn new_static(endpoint: Endpoint) -> Result<Self> {
60        Ok(Client {
61            endpoint,
62            instance_source: Arc::new(InstanceSource::Static),
63            instance_avail: Arc::new(ArcSwap::from(Arc::new(vec![]))),
64            instance_free: Arc::new(ArcSwap::from(Arc::new(vec![]))),
65        })
66    }
67
68    // Client with auto-discover instances using etcd
69    pub(crate) async fn new_dynamic(endpoint: Endpoint) -> Result<Self> {
70        const INSTANCE_REFRESH_PERIOD: Duration = Duration::from_secs(1);
71
72        // create live endpoint watcher
73        let Some(etcd_client) = &endpoint.component.drt.etcd_client else {
74            anyhow::bail!("Attempt to create a dynamic client on a static endpoint");
75        };
76
77        let instance_source =
78            Self::get_or_create_dynamic_instance_source(etcd_client, &endpoint).await?;
79
80        let client = Client {
81            endpoint,
82            instance_source: instance_source.clone(),
83            instance_avail: Arc::new(ArcSwap::from(Arc::new(vec![]))),
84            instance_free: Arc::new(ArcSwap::from(Arc::new(vec![]))),
85        };
86        client.monitor_instance_source();
87        Ok(client)
88    }
89
90    pub fn path(&self) -> String {
91        self.endpoint.path()
92    }
93
94    /// The root etcd path we watch in etcd to discover new instances to route to.
95    pub fn etcd_root(&self) -> String {
96        self.endpoint.etcd_root()
97    }
98
99    /// Instances available from watching etcd
100    pub fn instances(&self) -> Vec<Instance> {
101        match self.instance_source.as_ref() {
102            InstanceSource::Static => vec![],
103            InstanceSource::Dynamic(watch_rx) => watch_rx.borrow().clone(),
104        }
105    }
106
107    pub fn instance_ids(&self) -> Vec<i64> {
108        self.instances().into_iter().map(|ep| ep.id()).collect()
109    }
110
111    pub fn instance_ids_avail(&self) -> arc_swap::Guard<Arc<Vec<i64>>> {
112        self.instance_avail.load()
113    }
114
115    pub fn instance_ids_free(&self) -> arc_swap::Guard<Arc<Vec<i64>>> {
116        self.instance_free.load()
117    }
118
119    /// Wait for at least one Instance to be available for this Endpoint
120    pub async fn wait_for_instances(&self) -> Result<Vec<Instance>> {
121        let mut instances: Vec<Instance> = vec![];
122        if let InstanceSource::Dynamic(mut rx) = self.instance_source.as_ref().clone() {
123            // wait for there to be 1 or more endpoints
124            loop {
125                instances = rx.borrow_and_update().to_vec();
126                if instances.is_empty() {
127                    rx.changed().await?;
128                } else {
129                    break;
130                }
131            }
132        }
133        Ok(instances)
134    }
135
136    /// Is this component know at startup and not discovered via etcd?
137    pub fn is_static(&self) -> bool {
138        matches!(self.instance_source.as_ref(), InstanceSource::Static)
139    }
140
141    /// Mark an instance as down/unavailable
142    pub fn report_instance_down(&self, instance_id: i64) {
143        let filtered = self
144            .instance_ids_avail()
145            .iter()
146            .filter_map(|&id| if id == instance_id { None } else { Some(id) })
147            .collect::<Vec<_>>();
148        self.instance_avail.store(Arc::new(filtered));
149
150        tracing::debug!("inhibiting instance {instance_id}");
151    }
152
153    /// Update the set of free instances based on busy instance IDs
154    pub fn update_free_instances(&self, busy_instance_ids: &[i64]) {
155        let all_instance_ids = self.instance_ids();
156        let free_ids: Vec<i64> = all_instance_ids
157            .into_iter()
158            .filter(|id| !busy_instance_ids.contains(id))
159            .collect();
160        self.instance_free.store(Arc::new(free_ids));
161    }
162
163    /// Monitor the ETCD instance source and update instance_avail.
164    fn monitor_instance_source(&self) {
165        let cancel_token = self.endpoint.drt().primary_token();
166        let client = self.clone();
167        tokio::task::spawn(async move {
168            let mut rx = match client.instance_source.as_ref() {
169                InstanceSource::Static => {
170                    tracing::error!("Static instance source is not watchable");
171                    return;
172                }
173                InstanceSource::Dynamic(rx) => rx.clone(),
174            };
175            while !cancel_token.is_cancelled() {
176                let instance_ids: Vec<i64> = rx
177                    .borrow_and_update()
178                    .iter()
179                    .map(|instance| instance.id())
180                    .collect();
181
182                // TODO: this resets both tracked available and free instances
183                client.instance_avail.store(Arc::new(instance_ids.clone()));
184                client.instance_free.store(Arc::new(instance_ids));
185
186                tracing::debug!("instance source updated");
187
188                if let Err(err) = rx.changed().await {
189                    tracing::error!("The Sender is dropped: {}", err);
190                    cancel_token.cancel();
191                }
192            }
193        });
194    }
195
196    async fn get_or_create_dynamic_instance_source(
197        etcd_client: &EtcdClient,
198        endpoint: &Endpoint,
199    ) -> Result<Arc<InstanceSource>> {
200        let drt = endpoint.drt();
201        let instance_sources = drt.instance_sources();
202        let mut instance_sources = instance_sources.lock().await;
203
204        if let Some(instance_source) = instance_sources.get(endpoint) {
205            if let Some(instance_source) = instance_source.upgrade() {
206                return Ok(instance_source);
207            } else {
208                instance_sources.remove(endpoint);
209            }
210        }
211
212        let prefix_watcher = etcd_client
213            .kv_get_and_watch_prefix(endpoint.etcd_root())
214            .await?;
215
216        let (prefix, _watcher, mut kv_event_rx) = prefix_watcher.dissolve();
217
218        let (watch_tx, watch_rx) = tokio::sync::watch::channel(vec![]);
219
220        let secondary = endpoint.component.drt.runtime.secondary().clone();
221
222        // this task should be included in the registry
223        // currently this is created once per client, but this object/task should only be instantiated
224        // once per worker/instance
225        secondary.spawn(async move {
226            tracing::debug!("Starting endpoint watcher for prefix: {}", prefix);
227            let mut map = HashMap::new();
228
229            loop {
230                let kv_event = tokio::select! {
231                    _ = watch_tx.closed() => {
232                        tracing::debug!("all watchers have closed; shutting down endpoint watcher for prefix: {prefix}");
233                        break;
234                    }
235                    kv_event = kv_event_rx.recv() => {
236                        match kv_event {
237                            Some(kv_event) => kv_event,
238                            None => {
239                                tracing::debug!("watch stream has closed; shutting down endpoint watcher for prefix: {prefix}");
240                                break;
241                            }
242                        }
243                    }
244                };
245
246                match kv_event {
247                    WatchEvent::Put(kv) => {
248                        let key = String::from_utf8(kv.key().to_vec());
249                        let val = serde_json::from_slice::<Instance>(kv.value());
250                        if let (Ok(key), Ok(val)) = (key, val) {
251                            map.insert(key.clone(), val);
252                        } else {
253                            tracing::error!("Unable to parse put endpoint event; shutting down endpoint watcher for prefix: {prefix}");
254                            break;
255                        }
256                    }
257                    WatchEvent::Delete(kv) => {
258                        match String::from_utf8(kv.key().to_vec()) {
259                            Ok(key) => { map.remove(&key); }
260                            Err(_) => {
261                                tracing::error!("Unable to parse delete endpoint event; shutting down endpoint watcher for prefix: {}", prefix);
262                                break;
263                            }
264                        }
265                    }
266                }
267
268                let instances: Vec<Instance> = map.values().cloned().collect();
269
270                if watch_tx.send(instances).is_err() {
271                    tracing::debug!("Unable to send watch updates; shutting down endpoint watcher for prefix: {}", prefix);
272                    break;
273                }
274
275            }
276
277            tracing::debug!("Completed endpoint watcher for prefix: {prefix}");
278            let _ = watch_tx.send(vec![]);
279        });
280
281        let instance_source = Arc::new(InstanceSource::Dynamic(watch_rx));
282        instance_sources.insert(endpoint.clone(), Arc::downgrade(&instance_source));
283        Ok(instance_source)
284    }
285}