// dynamo_runtime/component/client.rs
// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
// SPDX-License-Identifier: Apache-2.0
use std::sync::Arc;
use std::{collections::HashMap, time::Duration};

use anyhow::Result;
use arc_swap::ArcSwap;
use futures::StreamExt;
// NOTE(review): `Receiver` does not appear to be referenced by this bare name
// anywhere below (watch receivers are written fully qualified as
// `tokio::sync::watch::Receiver`) — confirm it is unused before removing.
use tokio::net::unix::pipe::Receiver;

// NOTE(review): `DiscoveryInstanceId` is likewise never named directly in this
// file — verify against the rest of the crate before pruning.
use crate::discovery::{DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId};
use crate::{
    component::{Endpoint, Instance},
    pipeline::async_trait,
    pipeline::{
        AddressedPushRouter, AddressedRequest, AsyncEngine, Data, ManyOut, PushRouter, RouterMode,
        SingleIn,
    },
    traits::DistributedRuntimeProvider,
    transports::etcd::Client as EtcdClient,
};

/// Default interval for periodic reconciliation of instance_avail with instance_source
const DEFAULT_RECONCILE_INTERVAL: Duration = Duration::from_secs(5);
26
/// Client for a dynamically discovered set of remote instances serving an [`Endpoint`].
///
/// Discovery is driven by a key-value-store watch (`instance_source`); two
/// derived ID sets (`instance_avail`, `instance_free`) are layered on top and
/// periodically reset to match the discovery snapshot (see `monitor_instance_source`).
#[derive(Clone, Debug)]
pub struct Client {
    /// This is me — the endpoint this client issues requests against.
    pub endpoint: Endpoint,
    /// These are the remotes I know about from watching the key-value store.
    pub instance_source: Arc<tokio::sync::watch::Receiver<Vec<Instance>>>,
    /// These are the instance source ids less those reported as down from sending rpc.
    instance_avail: Arc<ArcSwap<Vec<u64>>>,
    /// These are the instance source ids less those reported as busy (above threshold).
    instance_free: Arc<ArcSwap<Vec<u64>>>,
    /// Watch sender for available instance IDs (for sending updates).
    instance_avail_tx: Arc<tokio::sync::watch::Sender<Vec<u64>>>,
    /// Watch receiver for available instance IDs (for cloning to external subscribers).
    instance_avail_rx: tokio::sync::watch::Receiver<Vec<u64>>,
    /// Interval for periodic reconciliation of instance_avail with instance_source.
    /// This ensures instances removed via `report_instance_down` are eventually restored.
    reconcile_interval: Duration,
}
45
46impl Client {
47    // Client with auto-discover instances using key-value store
48    pub(crate) async fn new(endpoint: Endpoint) -> Result<Self> {
49        Self::with_reconcile_interval(endpoint, DEFAULT_RECONCILE_INTERVAL).await
50    }
51
52    /// Create a client with a custom reconcile interval.
53    /// The reconcile interval controls how often `instance_avail` is reset to match
54    /// `instance_source`, restoring any instances removed via `report_instance_down`.
55    pub(crate) async fn with_reconcile_interval(
56        endpoint: Endpoint,
57        reconcile_interval: Duration,
58    ) -> Result<Self> {
59        tracing::trace!(
60            "Client::new_dynamic: Creating dynamic client for endpoint: {}",
61            endpoint.id()
62        );
63        let instance_source = Self::get_or_create_dynamic_instance_source(&endpoint).await?;
64
65        // Seed instance_avail from the current instance_source snapshot so that
66        // callers who proceed immediately after wait_for_instances (which reads
67        // instance_source directly) will also find instances in instance_avail
68        // (which is read by the routing methods like random/round_robin).
69        let initial_ids: Vec<u64> = instance_source
70            .borrow()
71            .iter()
72            .map(|instance| instance.id())
73            .collect();
74        let (avail_tx, avail_rx) = tokio::sync::watch::channel(initial_ids.clone());
75        let client = Client {
76            endpoint: endpoint.clone(),
77            instance_source: instance_source.clone(),
78            instance_avail: Arc::new(ArcSwap::from(Arc::new(initial_ids.clone()))),
79            instance_free: Arc::new(ArcSwap::from(Arc::new(initial_ids))),
80            instance_avail_tx: Arc::new(avail_tx),
81            instance_avail_rx: avail_rx,
82            reconcile_interval,
83        };
84        client.monitor_instance_source();
85        Ok(client)
86    }
87
88    /// Instances available from watching key-value store
89    pub fn instances(&self) -> Vec<Instance> {
90        self.instance_source.borrow().clone()
91    }
92
93    pub fn instance_ids(&self) -> Vec<u64> {
94        self.instances().into_iter().map(|ep| ep.id()).collect()
95    }
96
97    pub fn instance_ids_avail(&self) -> arc_swap::Guard<Arc<Vec<u64>>> {
98        self.instance_avail.load()
99    }
100
101    pub fn instance_ids_free(&self) -> arc_swap::Guard<Arc<Vec<u64>>> {
102        self.instance_free.load()
103    }
104
105    /// Get a watcher for available instance IDs
106    pub fn instance_avail_watcher(&self) -> tokio::sync::watch::Receiver<Vec<u64>> {
107        self.instance_avail_rx.clone()
108    }
109
110    /// Wait for at least one Instance to be available for this Endpoint
111    pub async fn wait_for_instances(&self) -> Result<Vec<Instance>> {
112        tracing::trace!(
113            "wait_for_instances: Starting wait for endpoint: {}",
114            self.endpoint.id()
115        );
116        let mut rx = self.instance_source.as_ref().clone();
117        // wait for there to be 1 or more endpoints
118        let mut instances: Vec<Instance>;
119        loop {
120            instances = rx.borrow_and_update().to_vec();
121            if instances.is_empty() {
122                rx.changed().await?;
123            } else {
124                tracing::info!(
125                    "wait_for_instances: Found {} instance(s) for endpoint: {}",
126                    instances.len(),
127                    self.endpoint.id()
128                );
129                break;
130            }
131        }
132        Ok(instances)
133    }
134
135    /// Mark an instance as down/unavailable
136    pub fn report_instance_down(&self, instance_id: u64) {
137        let filtered = self
138            .instance_ids_avail()
139            .iter()
140            .filter_map(|&id| if id == instance_id { None } else { Some(id) })
141            .collect::<Vec<_>>();
142        self.instance_avail.store(Arc::new(filtered.clone()));
143
144        // Notify watch channel subscribers about the change
145        let _ = self.instance_avail_tx.send(filtered);
146
147        tracing::debug!("inhibiting instance {instance_id}");
148    }
149
150    /// Update the set of free instances based on busy instance IDs
151    pub fn update_free_instances(&self, busy_instance_ids: &[u64]) {
152        let all_instance_ids = self.instance_ids();
153        let free_ids: Vec<u64> = all_instance_ids
154            .into_iter()
155            .filter(|id| !busy_instance_ids.contains(id))
156            .collect();
157        self.instance_free.store(Arc::new(free_ids));
158    }
159
160    /// Monitor the key-value instance source and update instance_avail.
161    ///
162    /// This function also performs periodic reconciliation: if `instance_source` hasn't
163    /// changed for `reconcile_interval`, we reset `instance_avail` to match
164    /// `instance_source`. This ensures instances removed via `report_instance_down`
165    /// are eventually restored even if the discovery source doesn't emit updates.
166    fn monitor_instance_source(&self) {
167        let reconcile_interval = self.reconcile_interval;
168        let cancel_token = self.endpoint.drt().primary_token();
169        let client = self.clone();
170        let endpoint_id = self.endpoint.id();
171        tokio::task::spawn(async move {
172            let mut rx = client.instance_source.as_ref().clone();
173            while !cancel_token.is_cancelled() {
174                let instance_ids: Vec<u64> = rx
175                    .borrow_and_update()
176                    .iter()
177                    .map(|instance| instance.id())
178                    .collect();
179
180                // TODO: this resets both tracked available and free instances
181                client.instance_avail.store(Arc::new(instance_ids.clone()));
182                client.instance_free.store(Arc::new(instance_ids.clone()));
183
184                // Send update to watch channel subscribers
185                let _ = client.instance_avail_tx.send(instance_ids);
186
187                tokio::select! {
188                    result = rx.changed() => {
189                        if let Err(err) = result {
190                            tracing::error!(
191                                "monitor_instance_source: The Sender is dropped: {err}, endpoint={endpoint_id}",
192                            );
193                            cancel_token.cancel();
194                        }
195                    }
196                    _ = tokio::time::sleep(reconcile_interval) => {
197                        tracing::trace!(
198                            "monitor_instance_source: periodic reconciliation for endpoint={endpoint_id}",
199                        );
200                    }
201                }
202            }
203        });
204    }
205
206    async fn get_or_create_dynamic_instance_source(
207        endpoint: &Endpoint,
208    ) -> Result<Arc<tokio::sync::watch::Receiver<Vec<Instance>>>> {
209        let drt = endpoint.drt();
210        let instance_sources = drt.instance_sources();
211        let mut instance_sources = instance_sources.lock().await;
212
213        if let Some(instance_source) = instance_sources.get(endpoint) {
214            if let Some(instance_source) = instance_source.upgrade() {
215                return Ok(instance_source);
216            } else {
217                instance_sources.remove(endpoint);
218            }
219        }
220
221        let discovery = drt.discovery();
222        let discovery_query = crate::discovery::DiscoveryQuery::Endpoint {
223            namespace: endpoint.component.namespace.name.clone(),
224            component: endpoint.component.name.clone(),
225            endpoint: endpoint.name.clone(),
226        };
227
228        let mut discovery_stream = discovery
229            .list_and_watch(discovery_query.clone(), None)
230            .await?;
231        let (watch_tx, watch_rx) = tokio::sync::watch::channel(vec![]);
232
233        let secondary = endpoint.component.drt.runtime().secondary().clone();
234
235        secondary.spawn(async move {
236            tracing::trace!("endpoint_watcher: Starting for discovery query: {:?}", discovery_query);
237            let mut map: HashMap<u64, Instance> = HashMap::new();
238
239            loop {
240                let discovery_event = tokio::select! {
241                    _ = watch_tx.closed() => {
242                        break;
243                    }
244                    discovery_event = discovery_stream.next() => {
245                        match discovery_event {
246                            Some(Ok(event)) => {
247                                event
248                            },
249                            Some(Err(e)) => {
250                                tracing::error!("endpoint_watcher: discovery stream error: {}; shutting down for discovery query: {:?}", e, discovery_query);
251                                break;
252                            }
253                            None => {
254                                break;
255                            }
256                        }
257                    }
258                };
259
260                match discovery_event {
261                    DiscoveryEvent::Added(discovery_instance) => {
262                        if let DiscoveryInstance::Endpoint(instance) = discovery_instance {
263
264                                map.insert(instance.instance_id, instance);
265                        }
266                    }
267                    DiscoveryEvent::Removed(id) => {
268                        map.remove(&id.instance_id());
269                    }
270                }
271
272                let instances: Vec<Instance> = map.values().cloned().collect();
273                if watch_tx.send(instances).is_err() {
274                    break;
275                }
276            }
277            let _ = watch_tx.send(vec![]);
278        });
279
280        let instance_source = Arc::new(watch_rx);
281        instance_sources.insert(endpoint.clone(), Arc::downgrade(&instance_source));
282        Ok(instance_source)
283    }
284}
285
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{DistributedRuntime, Runtime, distributed::DistributedConfig};

    /// Test that instances removed via report_instance_down are restored after
    /// the reconciliation interval elapses.
    #[tokio::test]
    async fn test_instance_reconciliation() {
        const TEST_RECONCILE_INTERVAL: Duration = Duration::from_millis(100);

        let rt = Runtime::from_current().unwrap();
        // Use process_local config to avoid needing etcd/nats
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt.namespace("test_reconciliation".to_string()).unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());

        // Use a short reconcile interval for faster tests
        let client = Client::with_reconcile_interval(endpoint, TEST_RECONCILE_INTERVAL)
            .await
            .unwrap();

        // Initially, instance_avail should be empty (no registered instances)
        assert!(client.instance_ids_avail().is_empty());

        // For this test, we'll directly manipulate instance_avail and verify reconciliation
        // Store some test IDs
        // NOTE(review): the background reconcile task can fire between this store and
        // the asserts below, resetting instance_avail to empty — a potential flake if
        // the test is descheduled for longer than TEST_RECONCILE_INTERVAL. Confirm
        // whether a deterministic hook is preferable.
        client.instance_avail.store(Arc::new(vec![1, 2, 3]));

        assert_eq!(**client.instance_ids_avail(), vec![1u64, 2, 3]);

        // Simulate report_instance_down removing instance 2
        client.report_instance_down(2);
        assert_eq!(**client.instance_ids_avail(), vec![1u64, 3]);

        // Wait for reconciliation interval + buffer
        // The monitor_instance_source will reset instance_avail to match instance_source
        // Since instance_source is empty, after reconciliation instance_avail should be empty
        tokio::time::sleep(TEST_RECONCILE_INTERVAL + Duration::from_millis(50)).await;

        // After reconciliation, instance_avail should match instance_source (which is empty)
        assert!(
            client.instance_ids_avail().is_empty(),
            "After reconciliation, instance_avail should match instance_source"
        );

        rt.shutdown();
    }

    /// Test that report_instance_down correctly removes an instance from instance_avail.
    #[tokio::test]
    async fn test_report_instance_down() {
        let rt = Runtime::from_current().unwrap();
        // Use process_local config to avoid needing etcd/nats
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt.namespace("test_report_down".to_string()).unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());

        let client = endpoint.client().await.unwrap();

        // Manually set up instance_avail with test instances
        // NOTE(review): this client uses the default 5s reconcile interval, so the
        // background reset should not interfere within this test's runtime.
        client.instance_avail.store(Arc::new(vec![1, 2, 3]));
        assert_eq!(**client.instance_ids_avail(), vec![1u64, 2, 3]);

        // Report instance 2 as down
        client.report_instance_down(2);

        // Verify instance 2 is removed
        let avail = client.instance_ids_avail();
        assert!(avail.contains(&1), "Instance 1 should still be available");
        assert!(
            !avail.contains(&2),
            "Instance 2 should be removed after report_instance_down"
        );
        assert!(avail.contains(&3), "Instance 3 should still be available");

        rt.shutdown();
    }

    /// Test that instance_avail_watcher receives updates when instances change.
    #[tokio::test]
    async fn test_instance_avail_watcher() {
        let rt = Runtime::from_current().unwrap();
        // Use process_local config to avoid needing etcd/nats
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt.namespace("test_watcher".to_string()).unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());

        let client = endpoint.client().await.unwrap();
        let watcher = client.instance_avail_watcher();

        // Set initial instances
        // (stores to the ArcSwap only; the watch channel is updated below by
        // report_instance_down)
        client.instance_avail.store(Arc::new(vec![1, 2, 3]));

        // Report instance down - this should notify the watcher
        client.report_instance_down(2);

        // The watcher should receive the update
        // Note: We need to check if changed() was signaled
        let current = watcher.borrow().clone();
        assert_eq!(current, vec![1, 3]);

        rt.shutdown();
    }
}
399}