dynamo_runtime/component/
client.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::sync::Arc;
5use std::{collections::HashMap, time::Duration};
6
7use anyhow::Result;
8use arc_swap::ArcSwap;
9use futures::StreamExt;
10use tokio::net::unix::pipe::Receiver;
11
12use crate::{
13    component::{Endpoint, Instance},
14    pipeline::async_trait,
15    pipeline::{
16        AddressedPushRouter, AddressedRequest, AsyncEngine, Data, ManyOut, PushRouter, RouterMode,
17        SingleIn,
18    },
19    storage::key_value_store::{KeyValueStoreManager, WatchEvent},
20    traits::DistributedRuntimeProvider,
21    transports::etcd::Client as EtcdClient,
22};
23
/// Critical state of an instance map.
///
/// Each state carries a nonce. The intent is to emit the state on a watch
/// channel so that observers can follow the critical state transitions.
///
/// NOTE(review): this type is not referenced anywhere in this file's visible
/// code — confirm whether it is still needed.
enum MapState {
    /// The map is empty; the value is the nonce.
    Empty(u64),

    /// The map is non-empty; the values are `(nonce, count)`.
    NonEmpty(u64, u64),

    /// The watcher has finished; no more events will be emitted.
    Finished,
}
37
/// A single change to an endpoint registration.
///
/// NOTE(review): not referenced anywhere in this file's visible code.
enum EndpointEvent {
    /// An entry was written. The `String` is presumably the registration key
    /// and the `u64` an instance id — confirm against the producer.
    Put(String, u64),

    /// The entry identified by the `String` (presumably its key) was removed.
    Delete(String),
}
42
43#[derive(Clone, Debug)]
44pub struct Client {
45    // This is me
46    pub endpoint: Endpoint,
47    // These are the remotes I know about from watching etcd
48    pub instance_source: Arc<InstanceSource>,
49    // These are the instance source ids less those reported as down from sending rpc
50    instance_avail: Arc<ArcSwap<Vec<u64>>>,
51    // These are the instance source ids less those reported as busy (above threshold)
52    instance_free: Arc<ArcSwap<Vec<u64>>>,
53}
54
55#[derive(Clone, Debug)]
56pub enum InstanceSource {
57    Static,
58    Dynamic(tokio::sync::watch::Receiver<Vec<Instance>>),
59}
60
61impl Client {
62    // Client will only talk to a single static endpoint
63    pub(crate) async fn new_static(endpoint: Endpoint) -> Result<Self> {
64        Ok(Client {
65            endpoint,
66            instance_source: Arc::new(InstanceSource::Static),
67            instance_avail: Arc::new(ArcSwap::from(Arc::new(vec![]))),
68            instance_free: Arc::new(ArcSwap::from(Arc::new(vec![]))),
69        })
70    }
71
72    // Client with auto-discover instances using etcd
73    pub(crate) async fn new_dynamic(endpoint: Endpoint) -> Result<Self> {
74        tracing::debug!(
75            "Client::new_dynamic: Creating dynamic client for endpoint: {}",
76            endpoint.path()
77        );
78        const INSTANCE_REFRESH_PERIOD: Duration = Duration::from_secs(1);
79
80        let instance_source = Self::get_or_create_dynamic_instance_source(&endpoint).await?;
81        tracing::debug!(
82            "Client::new_dynamic: Got instance source for endpoint: {}",
83            endpoint.path()
84        );
85
86        let client = Client {
87            endpoint: endpoint.clone(),
88            instance_source: instance_source.clone(),
89            instance_avail: Arc::new(ArcSwap::from(Arc::new(vec![]))),
90            instance_free: Arc::new(ArcSwap::from(Arc::new(vec![]))),
91        };
92        tracing::debug!(
93            "Client::new_dynamic: Starting instance source monitor for endpoint: {}",
94            endpoint.path()
95        );
96        client.monitor_instance_source();
97        tracing::debug!(
98            "Client::new_dynamic: Successfully created dynamic client for endpoint: {}",
99            endpoint.path()
100        );
101        Ok(client)
102    }
103
104    pub fn path(&self) -> String {
105        self.endpoint.path()
106    }
107
108    /// The root etcd path we watch in etcd to discover new instances to route to.
109    pub fn etcd_root(&self) -> String {
110        self.endpoint.etcd_root()
111    }
112
113    /// Instances available from watching etcd
114    pub fn instances(&self) -> Vec<Instance> {
115        match self.instance_source.as_ref() {
116            InstanceSource::Static => vec![],
117            InstanceSource::Dynamic(watch_rx) => watch_rx.borrow().clone(),
118        }
119    }
120
121    pub fn instance_ids(&self) -> Vec<u64> {
122        self.instances().into_iter().map(|ep| ep.id()).collect()
123    }
124
125    pub fn instance_ids_avail(&self) -> arc_swap::Guard<Arc<Vec<u64>>> {
126        self.instance_avail.load()
127    }
128
129    pub fn instance_ids_free(&self) -> arc_swap::Guard<Arc<Vec<u64>>> {
130        self.instance_free.load()
131    }
132
133    /// Wait for at least one Instance to be available for this Endpoint
134    pub async fn wait_for_instances(&self) -> Result<Vec<Instance>> {
135        tracing::debug!(
136            "wait_for_instances: Starting wait for endpoint: {}",
137            self.endpoint.path()
138        );
139        let mut instances: Vec<Instance> = vec![];
140        if let InstanceSource::Dynamic(mut rx) = self.instance_source.as_ref().clone() {
141            // wait for there to be 1 or more endpoints
142            let mut iteration = 0;
143            loop {
144                instances = rx.borrow_and_update().to_vec();
145                tracing::debug!(
146                    "wait_for_instances: iteration={}, current_instance_count={}, endpoint={}",
147                    iteration,
148                    instances.len(),
149                    self.endpoint.path()
150                );
151                if instances.is_empty() {
152                    tracing::debug!(
153                        "wait_for_instances: No instances yet, waiting for change notification for endpoint: {}",
154                        self.endpoint.path()
155                    );
156                    rx.changed().await?;
157                    tracing::debug!(
158                        "wait_for_instances: Change notification received for endpoint: {}",
159                        self.endpoint.path()
160                    );
161                } else {
162                    tracing::info!(
163                        "wait_for_instances: Found {} instance(s) for endpoint: {}",
164                        instances.len(),
165                        self.endpoint.path()
166                    );
167                    break;
168                }
169                iteration += 1;
170            }
171        } else {
172            tracing::debug!(
173                "wait_for_instances: Static instance source, no dynamic discovery for endpoint: {}",
174                self.endpoint.path()
175            );
176        }
177        Ok(instances)
178    }
179
180    /// Is this component know at startup and not discovered via etcd?
181    pub fn is_static(&self) -> bool {
182        matches!(self.instance_source.as_ref(), InstanceSource::Static)
183    }
184
185    /// Mark an instance as down/unavailable
186    pub fn report_instance_down(&self, instance_id: u64) {
187        let filtered = self
188            .instance_ids_avail()
189            .iter()
190            .filter_map(|&id| if id == instance_id { None } else { Some(id) })
191            .collect::<Vec<_>>();
192        self.instance_avail.store(Arc::new(filtered));
193
194        tracing::debug!("inhibiting instance {instance_id}");
195    }
196
197    /// Update the set of free instances based on busy instance IDs
198    pub fn update_free_instances(&self, busy_instance_ids: &[u64]) {
199        let all_instance_ids = self.instance_ids();
200        let free_ids: Vec<u64> = all_instance_ids
201            .into_iter()
202            .filter(|id| !busy_instance_ids.contains(id))
203            .collect();
204        self.instance_free.store(Arc::new(free_ids));
205    }
206
207    /// Monitor the ETCD instance source and update instance_avail.
208    fn monitor_instance_source(&self) {
209        let cancel_token = self.endpoint.drt().primary_token();
210        let client = self.clone();
211        let endpoint_path = self.endpoint.path();
212        tracing::debug!(
213            "monitor_instance_source: Starting monitor for endpoint: {}",
214            endpoint_path
215        );
216        tokio::task::spawn(async move {
217            let mut rx = match client.instance_source.as_ref() {
218                InstanceSource::Static => {
219                    tracing::error!(
220                        "monitor_instance_source: Static instance source is not watchable"
221                    );
222                    return;
223                }
224                InstanceSource::Dynamic(rx) => rx.clone(),
225            };
226            let mut iteration = 0;
227            while !cancel_token.is_cancelled() {
228                let instance_ids: Vec<u64> = rx
229                    .borrow_and_update()
230                    .iter()
231                    .map(|instance| instance.id())
232                    .collect();
233
234                tracing::debug!(
235                    "monitor_instance_source: iteration={}, instance_count={}, instance_ids={:?}, endpoint={}",
236                    iteration,
237                    instance_ids.len(),
238                    instance_ids,
239                    endpoint_path
240                );
241
242                // TODO: this resets both tracked available and free instances
243                client.instance_avail.store(Arc::new(instance_ids.clone()));
244                client.instance_free.store(Arc::new(instance_ids.clone()));
245
246                tracing::debug!(
247                    "monitor_instance_source: instance source updated, endpoint={}",
248                    endpoint_path
249                );
250
251                if let Err(err) = rx.changed().await {
252                    tracing::error!(
253                        "monitor_instance_source: The Sender is dropped: {}, endpoint={}",
254                        err,
255                        endpoint_path
256                    );
257                    cancel_token.cancel();
258                }
259                iteration += 1;
260            }
261            tracing::debug!(
262                "monitor_instance_source: Monitor loop exiting for endpoint: {}",
263                endpoint_path
264            );
265        });
266    }
267
268    async fn get_or_create_dynamic_instance_source(
269        endpoint: &Endpoint,
270    ) -> Result<Arc<InstanceSource>> {
271        let drt = endpoint.drt();
272        let instance_sources = drt.instance_sources();
273        let mut instance_sources = instance_sources.lock().await;
274
275        tracing::debug!(
276            "get_or_create_dynamic_instance_source: Checking cache for endpoint: {}",
277            endpoint.path()
278        );
279
280        if let Some(instance_source) = instance_sources.get(endpoint) {
281            if let Some(instance_source) = instance_source.upgrade() {
282                tracing::debug!(
283                    "get_or_create_dynamic_instance_source: Found cached instance source for endpoint: {}",
284                    endpoint.path()
285                );
286                return Ok(instance_source);
287            } else {
288                tracing::debug!(
289                    "get_or_create_dynamic_instance_source: Cached instance source was dropped, removing for endpoint: {}",
290                    endpoint.path()
291                );
292                instance_sources.remove(endpoint);
293            }
294        }
295
296        tracing::debug!(
297            "get_or_create_dynamic_instance_source: Creating new instance source for endpoint: {}",
298            endpoint.path()
299        );
300
301        let discovery = drt.discovery();
302        let discovery_query = crate::discovery::DiscoveryQuery::Endpoint {
303            namespace: endpoint.component.namespace.name.clone(),
304            component: endpoint.component.name.clone(),
305            endpoint: endpoint.name.clone(),
306        };
307
308        tracing::debug!(
309            "get_or_create_dynamic_instance_source: Calling discovery.list_and_watch for query: {:?}",
310            discovery_query
311        );
312
313        let mut discovery_stream = discovery
314            .list_and_watch(discovery_query.clone(), None)
315            .await?;
316
317        tracing::debug!(
318            "get_or_create_dynamic_instance_source: Got discovery stream for query: {:?}",
319            discovery_query
320        );
321
322        let (watch_tx, watch_rx) = tokio::sync::watch::channel(vec![]);
323
324        let secondary = endpoint.component.drt.runtime().secondary().clone();
325
326        secondary.spawn(async move {
327            tracing::debug!("endpoint_watcher: Starting for discovery query: {:?}", discovery_query);
328            let mut map: HashMap<u64, Instance> = HashMap::new();
329            let mut event_count = 0;
330
331            loop {
332                let discovery_event = tokio::select! {
333                    _ = watch_tx.closed() => {
334                        tracing::debug!("endpoint_watcher: all watchers have closed; shutting down for discovery query: {:?}", discovery_query);
335                        break;
336                    }
337                    discovery_event = discovery_stream.next() => {
338                        tracing::debug!("endpoint_watcher: Received stream event for discovery query: {:?}", discovery_query);
339                        match discovery_event {
340                            Some(Ok(event)) => {
341                                tracing::debug!("endpoint_watcher: Got Ok event: {:?}", event);
342                                event
343                            },
344                            Some(Err(e)) => {
345                                tracing::error!("endpoint_watcher: discovery stream error: {}; shutting down for discovery query: {:?}", e, discovery_query);
346                                break;
347                            }
348                            None => {
349                                tracing::debug!("endpoint_watcher: watch stream has closed; shutting down for discovery query: {:?}", discovery_query);
350                                break;
351                            }
352                        }
353                    }
354                };
355
356                event_count += 1;
357                tracing::debug!("endpoint_watcher: Processing event #{} for discovery query: {:?}", event_count, discovery_query);
358
359                match discovery_event {
360                    crate::discovery::DiscoveryEvent::Added(discovery_instance) => {
361                        match discovery_instance {
362                            crate::discovery::DiscoveryInstance::Endpoint(instance) => {
363                                tracing::debug!(
364                                    "endpoint_watcher: Added endpoint instance_id={}, namespace={}, component={}, endpoint={}",
365                                    instance.instance_id,
366                                    instance.namespace,
367                                    instance.component,
368                                    instance.endpoint
369                                );
370                                map.insert(instance.instance_id, instance);
371                            }
372                            _ => {
373                                tracing::debug!("endpoint_watcher: Ignoring non-endpoint instance (Model, etc.) for discovery query: {:?}", discovery_query);
374                            }
375                        }
376                    }
377                    crate::discovery::DiscoveryEvent::Removed(instance_id) => {
378                        tracing::debug!(
379                            "endpoint_watcher: Removed instance_id={} for discovery query: {:?}",
380                            instance_id,
381                            discovery_query
382                        );
383                        map.remove(&instance_id);
384                    }
385                }
386
387                let instances: Vec<Instance> = map.values().cloned().collect();
388                tracing::debug!(
389                    "endpoint_watcher: Current map size={}, sending update for discovery query: {:?}",
390                    instances.len(),
391                    discovery_query
392                );
393
394                if watch_tx.send(instances).is_err() {
395                    tracing::debug!("endpoint_watcher: Unable to send watch updates; shutting down for discovery query: {:?}", discovery_query);
396                    break;
397                }
398            }
399
400            tracing::debug!("endpoint_watcher: Completed for discovery query: {:?}, total events processed: {}", discovery_query, event_count);
401            let _ = watch_tx.send(vec![]);
402        });
403
404        let instance_source = Arc::new(InstanceSource::Dynamic(watch_rx));
405        instance_sources.insert(endpoint.clone(), Arc::downgrade(&instance_source));
406        tracing::debug!(
407            "get_or_create_dynamic_instance_source: Successfully created and cached instance source for endpoint: {}",
408            endpoint.path()
409        );
410        Ok(instance_source)
411    }
412}