dynamo_runtime/component/
client.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::sync::Arc;
5use std::{collections::HashMap, time::Duration};
6
7use anyhow::Result;
8use arc_swap::ArcSwap;
9use futures::StreamExt;
10use tokio::net::unix::pipe::Receiver;
11
12use crate::discovery::{DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId};
13use crate::{
14    component::{Endpoint, Instance},
15    pipeline::async_trait,
16    pipeline::{
17        AddressedPushRouter, AddressedRequest, AsyncEngine, Data, ManyOut, PushRouter, RouterMode,
18        SingleIn,
19    },
20    traits::DistributedRuntimeProvider,
21    transports::etcd::Client as EtcdClient,
22};
23
24#[derive(Clone, Debug)]
25pub struct Client {
26    // This is me
27    pub endpoint: Endpoint,
28    // These are the remotes I know about from watching key-value store
29    pub instance_source: Arc<tokio::sync::watch::Receiver<Vec<Instance>>>,
30    // These are the instance source ids less those reported as down from sending rpc
31    instance_avail: Arc<ArcSwap<Vec<u64>>>,
32    // These are the instance source ids less those reported as busy (above threshold)
33    instance_free: Arc<ArcSwap<Vec<u64>>>,
34    // Watch sender for available instance IDs (for sending updates)
35    instance_avail_tx: Arc<tokio::sync::watch::Sender<Vec<u64>>>,
36    // Watch receiver for available instance IDs (for cloning to external subscribers)
37    instance_avail_rx: tokio::sync::watch::Receiver<Vec<u64>>,
38}
39
40impl Client {
41    // Client with auto-discover instances using key-value store
42    pub(crate) async fn new(endpoint: Endpoint) -> Result<Self> {
43        tracing::trace!(
44            "Client::new_dynamic: Creating dynamic client for endpoint: {}",
45            endpoint.id()
46        );
47        let instance_source = Self::get_or_create_dynamic_instance_source(&endpoint).await?;
48
49        let (avail_tx, avail_rx) = tokio::sync::watch::channel(vec![]);
50        let client = Client {
51            endpoint: endpoint.clone(),
52            instance_source: instance_source.clone(),
53            instance_avail: Arc::new(ArcSwap::from(Arc::new(vec![]))),
54            instance_free: Arc::new(ArcSwap::from(Arc::new(vec![]))),
55            instance_avail_tx: Arc::new(avail_tx),
56            instance_avail_rx: avail_rx,
57        };
58        client.monitor_instance_source();
59        Ok(client)
60    }
61
62    /// Instances available from watching key-value store
63    pub fn instances(&self) -> Vec<Instance> {
64        self.instance_source.borrow().clone()
65    }
66
67    pub fn instance_ids(&self) -> Vec<u64> {
68        self.instances().into_iter().map(|ep| ep.id()).collect()
69    }
70
71    pub fn instance_ids_avail(&self) -> arc_swap::Guard<Arc<Vec<u64>>> {
72        self.instance_avail.load()
73    }
74
75    pub fn instance_ids_free(&self) -> arc_swap::Guard<Arc<Vec<u64>>> {
76        self.instance_free.load()
77    }
78
79    /// Get a watcher for available instance IDs
80    pub fn instance_avail_watcher(&self) -> tokio::sync::watch::Receiver<Vec<u64>> {
81        self.instance_avail_rx.clone()
82    }
83
84    /// Wait for at least one Instance to be available for this Endpoint
85    pub async fn wait_for_instances(&self) -> Result<Vec<Instance>> {
86        tracing::trace!(
87            "wait_for_instances: Starting wait for endpoint: {}",
88            self.endpoint.id()
89        );
90        let mut rx = self.instance_source.as_ref().clone();
91        // wait for there to be 1 or more endpoints
92        let mut instances: Vec<Instance>;
93        loop {
94            instances = rx.borrow_and_update().to_vec();
95            if instances.is_empty() {
96                rx.changed().await?;
97            } else {
98                tracing::info!(
99                    "wait_for_instances: Found {} instance(s) for endpoint: {}",
100                    instances.len(),
101                    self.endpoint.id()
102                );
103                break;
104            }
105        }
106        Ok(instances)
107    }
108
109    /// Mark an instance as down/unavailable
110    pub fn report_instance_down(&self, instance_id: u64) {
111        let filtered = self
112            .instance_ids_avail()
113            .iter()
114            .filter_map(|&id| if id == instance_id { None } else { Some(id) })
115            .collect::<Vec<_>>();
116        self.instance_avail.store(Arc::new(filtered.clone()));
117
118        // Notify watch channel subscribers about the change
119        let _ = self.instance_avail_tx.send(filtered);
120
121        tracing::debug!("inhibiting instance {instance_id}");
122    }
123
124    /// Update the set of free instances based on busy instance IDs
125    pub fn update_free_instances(&self, busy_instance_ids: &[u64]) {
126        let all_instance_ids = self.instance_ids();
127        let free_ids: Vec<u64> = all_instance_ids
128            .into_iter()
129            .filter(|id| !busy_instance_ids.contains(id))
130            .collect();
131        self.instance_free.store(Arc::new(free_ids));
132    }
133
134    /// Monitor the key-value instance source and update instance_avail.
135    fn monitor_instance_source(&self) {
136        let cancel_token = self.endpoint.drt().primary_token();
137        let client = self.clone();
138        let endpoint_id = self.endpoint.id();
139        tokio::task::spawn(async move {
140            let mut rx = client.instance_source.as_ref().clone();
141            while !cancel_token.is_cancelled() {
142                let instance_ids: Vec<u64> = rx
143                    .borrow_and_update()
144                    .iter()
145                    .map(|instance| instance.id())
146                    .collect();
147
148                // TODO: this resets both tracked available and free instances
149                client.instance_avail.store(Arc::new(instance_ids.clone()));
150                client.instance_free.store(Arc::new(instance_ids.clone()));
151
152                // Send update to watch channel subscribers
153                let _ = client.instance_avail_tx.send(instance_ids);
154
155                if let Err(err) = rx.changed().await {
156                    tracing::error!(
157                        "monitor_instance_source: The Sender is dropped: {err}, endpoint={endpoint_id}",
158                    );
159                    cancel_token.cancel();
160                }
161            }
162        });
163    }
164
165    async fn get_or_create_dynamic_instance_source(
166        endpoint: &Endpoint,
167    ) -> Result<Arc<tokio::sync::watch::Receiver<Vec<Instance>>>> {
168        let drt = endpoint.drt();
169        let instance_sources = drt.instance_sources();
170        let mut instance_sources = instance_sources.lock().await;
171
172        if let Some(instance_source) = instance_sources.get(endpoint) {
173            if let Some(instance_source) = instance_source.upgrade() {
174                return Ok(instance_source);
175            } else {
176                instance_sources.remove(endpoint);
177            }
178        }
179
180        let discovery = drt.discovery();
181        let discovery_query = crate::discovery::DiscoveryQuery::Endpoint {
182            namespace: endpoint.component.namespace.name.clone(),
183            component: endpoint.component.name.clone(),
184            endpoint: endpoint.name.clone(),
185        };
186
187        let mut discovery_stream = discovery
188            .list_and_watch(discovery_query.clone(), None)
189            .await?;
190        let (watch_tx, watch_rx) = tokio::sync::watch::channel(vec![]);
191
192        let secondary = endpoint.component.drt.runtime().secondary().clone();
193
194        secondary.spawn(async move {
195            tracing::trace!("endpoint_watcher: Starting for discovery query: {:?}", discovery_query);
196            let mut map: HashMap<u64, Instance> = HashMap::new();
197
198            loop {
199                let discovery_event = tokio::select! {
200                    _ = watch_tx.closed() => {
201                        break;
202                    }
203                    discovery_event = discovery_stream.next() => {
204                        match discovery_event {
205                            Some(Ok(event)) => {
206                                event
207                            },
208                            Some(Err(e)) => {
209                                tracing::error!("endpoint_watcher: discovery stream error: {}; shutting down for discovery query: {:?}", e, discovery_query);
210                                break;
211                            }
212                            None => {
213                                break;
214                            }
215                        }
216                    }
217                };
218
219                match discovery_event {
220                    DiscoveryEvent::Added(discovery_instance) => {
221                        if let DiscoveryInstance::Endpoint(instance) = discovery_instance {
222
223                                map.insert(instance.instance_id, instance);
224                        }
225                    }
226                    DiscoveryEvent::Removed(id) => {
227                        map.remove(&id.instance_id());
228                    }
229                }
230
231                let instances: Vec<Instance> = map.values().cloned().collect();
232                if watch_tx.send(instances).is_err() {
233                    break;
234                }
235            }
236            let _ = watch_tx.send(vec![]);
237        });
238
239        let instance_source = Arc::new(watch_rx);
240        instance_sources.insert(endpoint.clone(), Arc::downgrade(&instance_source));
241        Ok(instance_source)
242    }
243}