Skip to main content

dynamo_runtime/component/
client.rs

1// SPDX-FileCopyrightText: Copyright (c) 2024-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4use std::sync::atomic::{AtomicU64, Ordering};
5use std::{
6    collections::{HashMap, HashSet},
7    sync::{Arc, Mutex as StdMutex},
8    time::Duration,
9};
10
11use anyhow::Result;
12use arc_swap::ArcSwap;
13use dashmap::DashMap;
14use futures::StreamExt;
15use rand::Rng;
16
17use crate::component::{Endpoint, Instance};
18use crate::discovery::{DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId};
19use crate::traits::DistributedRuntimeProvider;
20
21/// Shared occupancy state for routing modes that track per-worker in-flight requests.
22#[derive(Debug, Default)]
23pub(crate) struct RoutingOccupancyState {
24    counts: DashMap<u64, AtomicU64>,
25    exact_selection_lock: tokio::sync::Mutex<()>,
26}
27
28impl RoutingOccupancyState {
29    pub(crate) fn increment(&self, instance_id: u64) {
30        self.counts
31            .entry(instance_id)
32            .or_insert_with(|| AtomicU64::new(0))
33            .fetch_add(1, Ordering::Relaxed);
34    }
35
36    pub(crate) async fn select_exact_min_and_increment(&self, instance_ids: &[u64]) -> Option<u64> {
37        let _guard = self.exact_selection_lock.lock().await;
38
39        let mut min_load = u64::MAX;
40        let mut selected = None;
41        let mut tie_count = 0usize;
42        let mut rng = rand::rng();
43        for &id in instance_ids {
44            let load = self.load(id);
45            if load < min_load {
46                min_load = load;
47                selected = Some(id);
48                tie_count = 1;
49                continue;
50            }
51
52            if load == min_load {
53                tie_count += 1;
54                // Reservoir sampling keeps tied minima uniform without allocating in this locked hot path.
55                if rng.random_range(0..tie_count) == 0 {
56                    selected = Some(id);
57                }
58            }
59        }
60
61        let id = selected?;
62        self.increment(id);
63        Some(id)
64    }
65
66    pub(crate) fn decrement(&self, instance_id: u64) {
67        if let Some(count) = self.counts.get(&instance_id) {
68            let _ = count.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
69                Some(current.saturating_sub(1))
70            });
71        }
72    }
73
74    pub(crate) fn load(&self, instance_id: u64) -> u64 {
75        self.counts
76            .get(&instance_id)
77            .map(|c| c.load(Ordering::Relaxed))
78            .unwrap_or(0)
79    }
80
81    pub(crate) fn retain(&self, instance_ids: &[u64]) {
82        let live: HashSet<u64> = instance_ids.iter().copied().collect();
83        self.counts.retain(|id, _| live.contains(id));
84    }
85}
86
87/// Get or create the shared routing occupancy state for an endpoint.
88pub(crate) async fn get_or_create_routing_occupancy_state(
89    endpoint: &Endpoint,
90) -> Arc<RoutingOccupancyState> {
91    let drt = endpoint.drt();
92    let registry = drt.routing_occupancy_states();
93    let mut registry = registry.lock().await;
94
95    if let Some(weak) = registry.get(endpoint) {
96        if let Some(state) = weak.upgrade() {
97            return state;
98        } else {
99            registry.remove(endpoint);
100        }
101    }
102
103    let state = Arc::new(RoutingOccupancyState::default());
104    registry.insert(endpoint.clone(), Arc::downgrade(&state));
105    state
106}
107
/// Default interval for periodic reconciliation of instance_avail with instance_source.
///
/// `Client::new` uses this value; `Client::with_reconcile_interval` accepts a
/// caller-chosen interval instead.
const DEFAULT_RECONCILE_INTERVAL: Duration = Duration::from_secs(5);
110
111/// Shared endpoint discovery state for a single endpoint query.
112///
113/// This wraps both the coalesced instance snapshot used for routing decisions
114/// and a raw, lossless per-subscriber event feed used by the response-stream
115/// cancellation watcher. Both outputs are driven by a single underlying
116/// discovery `list_and_watch` task so clients do not multiply control-plane
117/// watches.
118#[derive(Debug)]
119pub(crate) struct EndpointDiscoverySource {
120    instance_source: tokio::sync::watch::Receiver<Vec<Instance>>,
121    event_subscribers: StdMutex<Vec<tokio::sync::mpsc::UnboundedSender<DiscoveryEvent>>>,
122}
123
124impl EndpointDiscoverySource {
125    fn new(instance_source: tokio::sync::watch::Receiver<Vec<Instance>>) -> Self {
126        Self {
127            instance_source,
128            event_subscribers: StdMutex::new(Vec::new()),
129        }
130    }
131
132    fn instance_receiver(&self) -> tokio::sync::watch::Receiver<Vec<Instance>> {
133        self.instance_source.clone()
134    }
135
136    fn subscribe_events(&self) -> tokio::sync::mpsc::UnboundedReceiver<DiscoveryEvent> {
137        let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
138        self.event_subscribers.lock().unwrap().push(tx);
139        rx
140    }
141
142    fn broadcast_event(&self, event: &DiscoveryEvent) {
143        let subscribers = &mut *self.event_subscribers.lock().unwrap();
144        subscribers.retain(|tx| tx.send(event.clone()).is_ok());
145    }
146}
147
/// Sizes of the ID sets held by a routing snapshot.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct RoutingInstanceCounts {
    /// Number of discovered instance IDs.
    pub discovered: usize,
    /// Number of instance IDs currently eligible for routing.
    pub routable: usize,
    /// Number of instance IDs currently reported overloaded.
    pub overloaded: usize,
    /// Number of routable IDs that are not currently reported overloaded.
    pub free: usize,
}

/// Immutable routing snapshot.
///
/// Every mutator takes `&self` and returns a new snapshot; `free_ids` is
/// always recomputed from `routable_ids` and `overloaded_ids`.
#[derive(Clone, Debug)]
pub(crate) struct RoutingInstances {
    discovered_ids: Vec<u64>,
    routable_ids: Vec<u64>,
    overloaded_ids: HashSet<u64>,
    free_ids: Vec<u64>,
}

impl RoutingInstances {
    /// Snapshot in which every discovered ID is routable and none is overloaded.
    fn new(discovered_ids: Vec<u64>) -> Self {
        let routable_ids = discovered_ids.clone();
        Self::from_parts(discovered_ids, routable_ids, HashSet::new())
    }

    /// Assemble a snapshot from its three independent sets, deriving `free_ids`.
    fn from_parts(
        discovered_ids: Vec<u64>,
        routable_ids: Vec<u64>,
        overloaded_ids: HashSet<u64>,
    ) -> Self {
        Self {
            free_ids: Self::derive_free_ids(&routable_ids, &overloaded_ids),
            discovered_ids,
            routable_ids,
            overloaded_ids,
        }
    }

    /// All discovered instance IDs.
    pub(crate) fn discovered_ids(&self) -> &[u64] {
        self.discovered_ids.as_slice()
    }

    /// Instance IDs currently eligible for routing.
    pub(crate) fn routable_ids(&self) -> &[u64] {
        self.routable_ids.as_slice()
    }

    /// Routable instance IDs that are not reported overloaded.
    pub(crate) fn free_ids(&self) -> &[u64] {
        self.free_ids.as_slice()
    }

    /// Sizes of the four ID sets.
    pub(crate) fn counts(&self) -> RoutingInstanceCounts {
        let discovered = self.discovered_ids.len();
        let routable = self.routable_ids.len();
        let overloaded = self.overloaded_ids.len();
        let free = self.free_ids.len();
        RoutingInstanceCounts {
            discovered,
            routable,
            overloaded,
            free,
        }
    }

    /// Whether `instance_id` is in the overloaded set.
    pub(crate) fn is_overloaded(&self, instance_id: u64) -> bool {
        self.overloaded_ids.contains(&instance_id)
    }

    /// Copy of the overloaded set, or `None` when it is empty.
    fn overloaded_ids(&self) -> Option<HashSet<u64>> {
        let any_overloaded = !self.overloaded_ids.is_empty();
        any_overloaded.then(|| self.overloaded_ids.clone())
    }

    /// Replace the discovered set and reset the routable set to match it.
    ///
    /// An overloaded mark is dropped only for an ID that was discovered
    /// before and is missing from `discovered_ids`; marks for IDs that were
    /// never discovered are kept.
    fn reconcile_discovered(&self, discovered_ids: Vec<u64>) -> Self {
        let previously_seen: HashSet<u64> = self.discovered_ids.iter().copied().collect();
        let now_seen: HashSet<u64> = discovered_ids.iter().copied().collect();
        let overloaded_ids: HashSet<u64> = self
            .overloaded_ids
            .iter()
            .copied()
            .filter(|id| now_seen.contains(id) || !previously_seen.contains(id))
            .collect();

        let routable_ids = discovered_ids.clone();
        Self::from_parts(discovered_ids, routable_ids, overloaded_ids)
    }

    /// Remove `instance_id` from the routable set; discovered and overloaded
    /// sets are carried over unchanged.
    fn report_instance_down(&self, instance_id: u64) -> Self {
        let mut routable_ids = self.routable_ids.clone();
        routable_ids.retain(|&id| id != instance_id);

        Self::from_parts(
            self.discovered_ids.clone(),
            routable_ids,
            self.overloaded_ids.clone(),
        )
    }

    /// Test-only: swap in an arbitrary routable set.
    #[cfg(test)]
    fn override_routable_ids(&self, routable_ids: Vec<u64>) -> Self {
        // Going through `from_parts` recomputes `free_ids` for the new
        // routable set rather than copying the previous value.
        let discovered_ids = self.discovered_ids.clone();
        let overloaded_ids = self.overloaded_ids.clone();
        Self::from_parts(discovered_ids, routable_ids, overloaded_ids)
    }

    /// Replace the overloaded set wholesale.
    fn set_overloaded(&self, overloaded_ids: HashSet<u64>) -> Self {
        let discovered_ids = self.discovered_ids.clone();
        let routable_ids = self.routable_ids.clone();
        Self::from_parts(discovered_ids, routable_ids, overloaded_ids)
    }

    /// Drop the overloaded mark of every ID listed in `removed_ids`.
    fn clear_overloaded_for_removed(&self, removed_ids: &HashSet<u64>) -> Self {
        let overloaded_ids: HashSet<u64> = self
            .overloaded_ids
            .difference(removed_ids)
            .copied()
            .collect();
        Self::from_parts(
            self.discovered_ids.clone(),
            self.routable_ids.clone(),
            overloaded_ids,
        )
    }

    /// `routable_ids` minus `overloaded_ids`, preserving the routable order.
    fn derive_free_ids(routable_ids: &[u64], overloaded_ids: &HashSet<u64>) -> Vec<u64> {
        let mut free_ids = routable_ids.to_vec();
        if !overloaded_ids.is_empty() {
            free_ids.retain(|id| !overloaded_ids.contains(id));
        }
        free_ids
    }
}
283
284#[derive(Debug)]
285struct RoutingInstancesState {
286    snapshot: ArcSwap<RoutingInstances>,
287    update_lock: StdMutex<()>,
288    instance_avail_tx: tokio::sync::watch::Sender<Vec<u64>>,
289    instance_avail_rx: tokio::sync::watch::Receiver<Vec<u64>>,
290}
291
292impl RoutingInstancesState {
293    fn new(discovered_ids: Vec<u64>) -> Self {
294        let snapshot = RoutingInstances::new(discovered_ids);
295        let (instance_avail_tx, instance_avail_rx) =
296            tokio::sync::watch::channel(snapshot.routable_ids().to_vec());
297        Self {
298            snapshot: ArcSwap::from_pointee(snapshot),
299            update_lock: StdMutex::new(()),
300            instance_avail_tx,
301            instance_avail_rx,
302        }
303    }
304
305    fn snapshot(&self) -> arc_swap::Guard<Arc<RoutingInstances>> {
306        self.snapshot.load()
307    }
308
309    fn update(
310        &self,
311        update: impl FnOnce(&RoutingInstances) -> RoutingInstances,
312        publish_routable_ids: bool,
313    ) -> Arc<RoutingInstances> {
314        let _guard = self.update_lock.lock().unwrap();
315        let current = self.snapshot.load();
316        let next = Arc::new(update(&current));
317        self.snapshot.store(next.clone());
318        if publish_routable_ids {
319            self.publish_routable_ids(&next);
320        }
321        next
322    }
323
324    fn publish_routable_ids(&self, routing_instances: &RoutingInstances) {
325        let _ = self
326            .instance_avail_tx
327            .send(routing_instances.routable_ids().to_vec());
328    }
329
330    fn routable_ids(&self) -> Vec<u64> {
331        self.snapshot().routable_ids().to_vec()
332    }
333
334    #[cfg(test)]
335    fn free_ids(&self) -> Vec<u64> {
336        self.snapshot().free_ids.clone()
337    }
338
339    fn counts(&self) -> RoutingInstanceCounts {
340        self.snapshot().counts()
341    }
342
343    fn overloaded_ids(&self) -> Option<HashSet<u64>> {
344        self.snapshot().overloaded_ids()
345    }
346
347    fn instance_avail_watcher(&self) -> tokio::sync::watch::Receiver<Vec<u64>> {
348        self.instance_avail_rx.clone()
349    }
350
351    fn report_instance_down(&self, instance_id: u64) {
352        self.update(|current| current.report_instance_down(instance_id), true);
353    }
354
355    fn set_overloaded_instances(&self, overloaded_instance_ids: &[u64]) -> bool {
356        let overloaded_ids = overloaded_instance_ids
357            .iter()
358            .copied()
359            .collect::<HashSet<_>>();
360        let _guard = self.update_lock.lock().unwrap();
361        let current = self.snapshot.load();
362        if current.overloaded_ids == overloaded_ids {
363            return false;
364        }
365
366        let next = Arc::new(current.set_overloaded(overloaded_ids));
367        self.snapshot.store(next);
368        true
369    }
370
371    fn clear_overloaded_for_removed(&self, removed_instance_ids: &[u64]) {
372        if removed_instance_ids.is_empty() {
373            return;
374        }
375
376        let removed_ids = removed_instance_ids.iter().copied().collect::<HashSet<_>>();
377        self.update(
378            move |current| current.clear_overloaded_for_removed(&removed_ids),
379            false,
380        );
381    }
382
383    fn reconcile_discovered(&self, discovered_ids: Vec<u64>) -> Arc<RoutingInstances> {
384        self.update(
385            move |current| current.reconcile_discovered(discovered_ids),
386            true,
387        )
388    }
389
390    #[cfg(test)]
391    fn override_routable_ids(&self, ids: Vec<u64>) {
392        self.update(move |current| current.override_routable_ids(ids), true);
393    }
394}
395
/// Client for one endpoint.
///
/// Cloning is cheap with respect to shared state: the discovery source, the
/// instance receiver and the routing state are all held behind `Arc`.
#[derive(Clone, Debug)]
pub struct Client {
    // This is me
    pub endpoint: Endpoint,
    // Shared endpoint discovery source backing both snapshots and raw events.
    endpoint_discovery_source: Arc<EndpointDiscoverySource>,
    // These are the remotes I know about from watching key-value store
    pub instance_source: Arc<tokio::sync::watch::Receiver<Vec<Instance>>>,
    // Routing snapshot holder. Each snapshot is immutable; free IDs are derived
    // from the routable IDs minus the overloaded IDs.
    routing_instances: Arc<RoutingInstancesState>,
    /// Interval for periodic reconciliation of instance_avail with instance_source.
    /// This ensures instances removed via `report_instance_down` are eventually restored.
    reconcile_interval: Duration,
}
410
411impl Client {
    /// Create a client that reconciles at `DEFAULT_RECONCILE_INTERVAL`.
    ///
    /// Delegates to `with_reconcile_interval` and returns whatever error it returns.
    // Client with auto-discover instances using key-value store
    pub(crate) async fn new(endpoint: Endpoint) -> Result<Self> {
        Self::with_reconcile_interval(endpoint, DEFAULT_RECONCILE_INTERVAL).await
    }
416
417    /// Create a client with a custom reconcile interval.
418    /// The reconcile interval controls how often `instance_avail` is reset to match
419    /// `instance_source`, restoring any instances removed via `report_instance_down`.
420    pub(crate) async fn with_reconcile_interval(
421        endpoint: Endpoint,
422        reconcile_interval: Duration,
423    ) -> Result<Self> {
424        tracing::trace!(
425            "Client::new_dynamic: Creating dynamic client for endpoint: {}",
426            endpoint.id()
427        );
428        let endpoint_discovery_source =
429            Self::get_or_create_dynamic_discovery_source(&endpoint).await?;
430        let instance_source = Arc::new(endpoint_discovery_source.instance_receiver());
431
432        // Seed instance_avail from the current instance_source snapshot so that
433        // callers who proceed immediately after wait_for_instances (which reads
434        // instance_source directly) will also find instances in instance_avail
435        // (which is read by the routing methods like random/round_robin).
436        let initial_ids: Vec<u64> = instance_source
437            .borrow()
438            .iter()
439            .map(|instance| instance.id())
440            .collect();
441        let client = Client {
442            endpoint: endpoint.clone(),
443            endpoint_discovery_source,
444            instance_source: instance_source.clone(),
445            routing_instances: Arc::new(RoutingInstancesState::new(initial_ids)),
446            reconcile_interval,
447        };
448        client.monitor_instance_source();
449        Ok(client)
450    }
451
    /// Instances available from watching key-value store
    ///
    /// Returns a clone of the value currently held by `instance_source`.
    pub fn instances(&self) -> Vec<Instance> {
        self.instance_source.borrow().clone()
    }
456
457    pub fn instance_ids(&self) -> Vec<u64> {
458        self.instances().into_iter().map(|ep| ep.id()).collect()
459    }
460
    /// Copy of the routable instance IDs in the current routing snapshot.
    pub fn instance_ids_avail(&self) -> Vec<u64> {
        self.routing_instances.routable_ids()
    }
464
    /// Test-only: copy of the free instance IDs in the current routing snapshot.
    #[cfg(test)]
    pub(crate) fn instance_ids_free(&self) -> Vec<u64> {
        self.routing_instances.free_ids()
    }
469
    /// Load the current routing snapshot.
    pub(crate) fn routing_instances(&self) -> arc_swap::Guard<Arc<RoutingInstances>> {
        self.routing_instances.snapshot()
    }
473
    /// Sizes of the discovered, routable, overloaded and free ID sets in the
    /// current routing snapshot.
    pub fn routing_instance_counts(&self) -> RoutingInstanceCounts {
        self.routing_instances.counts()
    }
477
    /// Get a watcher for available instance IDs
    ///
    /// The returned receiver is a clone of the routing state's availability
    /// channel receiver, which carries the published routable IDs.
    pub fn instance_avail_watcher(&self) -> tokio::sync::watch::Receiver<Vec<u64>> {
        self.routing_instances.instance_avail_watcher()
    }
482
    /// Subscribe to raw discovery events for this endpoint.
    ///
    /// Unlike `instance_source`, this feed does not coalesce remove→add pairs,
    /// so consumers can react to every removal event exactly once.
    ///
    /// Each call registers a new unbounded channel with the shared discovery
    /// source and returns its receiving end.
    pub(crate) fn subscribe_discovery_events(
        &self,
    ) -> tokio::sync::mpsc::UnboundedReceiver<DiscoveryEvent> {
        self.endpoint_discovery_source.subscribe_events()
    }
492
493    /// Wait for at least one Instance to be available for this Endpoint
494    pub async fn wait_for_instances(&self) -> Result<Vec<Instance>> {
495        tracing::trace!(
496            "wait_for_instances: Starting wait for endpoint: {}",
497            self.endpoint.id()
498        );
499        let mut rx = self.instance_source.as_ref().clone();
500        // wait for there to be 1 or more endpoints
501        let mut instances: Vec<Instance>;
502        loop {
503            instances = rx.borrow_and_update().to_vec();
504            if instances.is_empty() {
505                rx.changed().await?;
506            } else {
507                tracing::info!(
508                    "wait_for_instances: Found {} instance(s) for endpoint: {}",
509                    instances.len(),
510                    self.endpoint.id()
511                );
512                break;
513            }
514        }
515        Ok(instances)
516    }
517
    /// Mark an instance as down/unavailable
    ///
    /// Removes `instance_id` from the routable IDs and publishes the new
    /// routable list; the discovered IDs are left untouched.
    pub fn report_instance_down(&self, instance_id: u64) {
        self.routing_instances.report_instance_down(instance_id);
        tracing::debug!("inhibiting instance {instance_id}");
    }
523
    /// Replace the set of overloaded instances reported by the worker monitor.
    /// Returns true when this changes the routing snapshot.
    ///
    /// Passing the same set as the current one returns false and stores nothing.
    pub fn set_overloaded_instances(&self, overloaded_instance_ids: &[u64]) -> bool {
        self.routing_instances
            .set_overloaded_instances(overloaded_instance_ids)
    }
530
    /// Drop the overloaded mark of every ID in `removed_instance_ids`.
    ///
    /// An empty slice is a no-op. The routable IDs are not republished.
    pub fn clear_overloaded_instances_for_removed(&self, removed_instance_ids: &[u64]) {
        self.routing_instances
            .clear_overloaded_for_removed(removed_instance_ids);
    }
535
    /// Copy of the currently overloaded instance IDs, or `None` when no
    /// instance is marked overloaded.
    pub fn overloaded_instance_ids(&self) -> Option<HashSet<u64>> {
        self.routing_instances.overloaded_ids()
    }
539
    /// Monitor the key-value instance source and update instance_avail.
    ///
    /// This function also performs periodic reconciliation: if `instance_source` hasn't
    /// changed for `reconcile_interval`, we reset `instance_avail` to match
    /// `instance_source`. This ensures instances removed via `report_instance_down`
    /// are eventually restored even if the discovery source doesn't emit updates.
    ///
    /// The work runs in a detached task spawned with `tokio::task::spawn`; this
    /// function returns immediately.
    fn monitor_instance_source(&self) {
        let reconcile_interval = self.reconcile_interval;
        let cancel_token = self.endpoint.drt().primary_token();
        // NOTE(review): the task owns a full clone of the client, including its
        // handle on the shared discovery source, for as long as it runs.
        let client = self.clone();
        let endpoint_id = self.endpoint.id();
        tokio::task::spawn(async move {
            let mut rx = client.instance_source.as_ref().clone();
            // The token is only checked here, at the top of each iteration, so
            // after cancellation the task exits on its next wake-up (a change
            // notification or the reconcile timer), not immediately.
            while !cancel_token.is_cancelled() {
                // Read the current instance list and mark it as seen.
                let instance_ids: Vec<u64> = rx
                    .borrow_and_update()
                    .iter()
                    .map(|instance| instance.id())
                    .collect();

                let routing_instances = client.reconcile_discovered_instances(instance_ids);

                // Clean up stale occupancy counters for instances that no longer exist.
                // `try_lock` makes this best-effort: when the registry is busy the
                // cleanup is skipped for this iteration.
                let registry = client.endpoint.drt().routing_occupancy_states();
                if let Ok(registry) = registry.try_lock()
                    && let Some(weak) = registry.get(&client.endpoint)
                    && let Some(state) = weak.upgrade()
                {
                    state.retain(routing_instances.discovered_ids());
                }

                // Sleep until the source changes or the reconcile interval
                // elapses, whichever comes first; both paths loop back to
                // the reconciliation above.
                tokio::select! {
                    result = rx.changed() => {
                        if let Err(err) = result {
                            tracing::error!(
                                "monitor_instance_source: The Sender is dropped: {err}, endpoint={endpoint_id}",
                            );
                            // NOTE(review): this cancels the token obtained from
                            // `primary_token()`; confirm that its scope is meant
                            // to be affected by a single endpoint's watcher ending.
                            cancel_token.cancel();
                        }
                    }
                    _ = tokio::time::sleep(reconcile_interval) => {
                        tracing::trace!(
                            "monitor_instance_source: periodic reconciliation for endpoint={endpoint_id}",
                        );
                    }
                }
            }
        });
    }
589
    /// Override routable IDs for testing. This allows creating an inconsistency
    /// between `instance_ids_avail()` and `instances()` to simulate downed workers.
    ///
    /// The override lasts only until the next reconciliation, which resets the
    /// routable IDs to the discovered IDs.
    #[cfg(test)]
    pub(crate) fn override_instance_avail(&self, ids: Vec<u64>) {
        self.routing_instances.override_routable_ids(ids);
    }
596
    /// Reconcile the routing state with a fresh list of discovered IDs and
    /// return the snapshot that was stored as a result.
    fn reconcile_discovered_instances(&self, discovered_ids: Vec<u64>) -> Arc<RoutingInstances> {
        self.routing_instances.reconcile_discovered(discovered_ids)
    }
600
601    async fn get_or_create_dynamic_discovery_source(
602        endpoint: &Endpoint,
603    ) -> Result<Arc<EndpointDiscoverySource>> {
604        let drt = endpoint.drt();
605        let sources = drt.endpoint_discovery_sources();
606        let mut sources = sources.lock().await;
607
608        if let Some(source) = sources.get(endpoint) {
609            if let Some(source) = source.upgrade() {
610                return Ok(source);
611            } else {
612                sources.remove(endpoint);
613            }
614        }
615
616        let discovery = drt.discovery();
617        let discovery_query = crate::discovery::DiscoveryQuery::Endpoint {
618            namespace: endpoint.component.namespace.name.clone(),
619            component: endpoint.component.name.clone(),
620            endpoint: endpoint.name.clone(),
621        };
622
623        let mut discovery_stream = discovery
624            .list_and_watch(discovery_query.clone(), None)
625            .await?;
626        let (watch_tx, watch_rx) = tokio::sync::watch::channel(vec![]);
627        let discovery_source = Arc::new(EndpointDiscoverySource::new(watch_rx));
628
629        let secondary = endpoint.component.drt.runtime().secondary().clone();
630        let discovery_source_task = discovery_source.clone();
631
632        secondary.spawn(async move {
633            tracing::trace!("endpoint_watcher: Starting for discovery query: {:?}", discovery_query);
634            let mut map: HashMap<u64, Instance> = HashMap::new();
635
636            loop {
637                let discovery_event = tokio::select! {
638                    _ = watch_tx.closed() => {
639                        break;
640                    }
641                    discovery_event = discovery_stream.next() => {
642                        match discovery_event {
643                            Some(Ok(event)) => {
644                                event
645                            },
646                            Some(Err(e)) => {
647                                tracing::error!("endpoint_watcher: discovery stream error: {}; shutting down for discovery query: {:?}", e, discovery_query);
648                                break;
649                            }
650                            None => {
651                                break;
652                            }
653                        }
654                    }
655                };
656
657                discovery_source_task.broadcast_event(&discovery_event);
658
659                match discovery_event {
660                    DiscoveryEvent::Added(DiscoveryInstance::Endpoint(instance)) => {
661                        map.insert(instance.instance_id, instance);
662                    }
663                    DiscoveryEvent::Added(_) => {}
664                    DiscoveryEvent::Removed(id) => {
665                        if let DiscoveryInstanceId::Endpoint(endpoint_id) = id {
666                            map.remove(&endpoint_id.instance_id);
667                        }
668                    }
669                }
670
671                let instances: Vec<Instance> = map.values().cloned().collect();
672                if watch_tx.send(instances).is_err() {
673                    break;
674                }
675            }
676            let _ = watch_tx.send(vec![]);
677        });
678
679        sources.insert(endpoint.clone(), Arc::downgrade(&discovery_source));
680        Ok(discovery_source)
681    }
682}
683
684#[cfg(test)]
685mod tests {
686    use super::*;
687    use crate::{DistributedRuntime, Runtime, distributed::DistributedConfig};
688
    /// Test that periodic reconciliation resets instance_avail to match
    /// instance_source, discarding both a manual override and the effect of
    /// report_instance_down.
    #[tokio::test]
    async fn test_instance_reconciliation() {
        const TEST_RECONCILE_INTERVAL: Duration = Duration::from_millis(100);

        let rt = Runtime::from_current().unwrap();
        // Use process_local config to avoid needing etcd/nats
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt.namespace("test_reconciliation".to_string()).unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());

        // Use a short reconcile interval for faster tests
        let client = Client::with_reconcile_interval(endpoint, TEST_RECONCILE_INTERVAL)
            .await
            .unwrap();

        // Initially, instance_avail should be empty (no registered instances)
        assert!(client.instance_ids_avail().is_empty());

        // For this test, we'll directly manipulate instance_avail and verify reconciliation
        // Store some test IDs
        client.override_instance_avail(vec![1, 2, 3]);

        assert_eq!(client.instance_ids_avail(), vec![1u64, 2, 3]);

        // Simulate report_instance_down removing instance 2
        client.report_instance_down(2);
        assert_eq!(client.instance_ids_avail(), vec![1u64, 3]);

        // Wait for reconciliation interval + buffer
        // The monitor_instance_source will reset instance_avail to match instance_source
        // Since instance_source is empty, after reconciliation instance_avail should be empty
        tokio::time::sleep(TEST_RECONCILE_INTERVAL + Duration::from_millis(50)).await;

        // After reconciliation, instance_avail should match instance_source (which is empty)
        assert!(
            client.instance_ids_avail().is_empty(),
            "After reconciliation, instance_avail should match instance_source"
        );

        rt.shutdown();
    }
735
    /// Test that report_instance_down correctly removes an instance from instance_avail.
    ///
    /// Only the reported ID is removed; the other overridden IDs stay available.
    #[tokio::test]
    async fn test_report_instance_down() {
        let rt = Runtime::from_current().unwrap();
        // Use process_local config to avoid needing etcd/nats
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt.namespace("test_report_down".to_string()).unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());

        // Client with the default reconcile interval.
        let client = endpoint.client().await.unwrap();

        // Manually set up instance_avail with test instances
        client.override_instance_avail(vec![1, 2, 3]);
        assert_eq!(client.instance_ids_avail(), vec![1u64, 2, 3]);

        // Report instance 2 as down
        client.report_instance_down(2);

        // Verify instance 2 is removed
        let avail = client.instance_ids_avail();
        assert!(avail.contains(&1), "Instance 1 should still be available");
        assert!(
            !avail.contains(&2),
            "Instance 2 should be removed after report_instance_down"
        );
        assert!(avail.contains(&3), "Instance 3 should still be available");

        rt.shutdown();
    }
768
    /// Test that overloaded_instance_ids reports an empty set as `None`, and
    /// that set_overloaded_instances returns true only when the set changes.
    #[tokio::test]
    async fn test_overloaded_instance_ids_returns_none_when_empty() {
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt.namespace("test_overloaded_ids".to_string()).unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());
        let client = endpoint.client().await.unwrap();

        // Nothing has been reported overloaded yet.
        assert_eq!(client.overloaded_instance_ids(), None);

        // First report changes the snapshot; repeating it does not.
        assert!(client.set_overloaded_instances(&[7]));
        assert_eq!(client.overloaded_instance_ids(), Some(HashSet::from([7])));
        assert!(!client.set_overloaded_instances(&[7]));

        // Clearing is a change once; clearing an already empty set is not.
        assert!(client.set_overloaded_instances(&[]));
        assert_eq!(client.overloaded_instance_ids(), None);
        assert!(!client.set_overloaded_instances(&[]));

        rt.shutdown();
    }
792
    /// Test that periodic reconciliation keeps the overloaded mark of an
    /// instance that is still discovered, so it does not become free again.
    #[tokio::test]
    async fn test_instance_reconciliation_preserves_overloaded_existing_instances() {
        const TEST_RECONCILE_INTERVAL: Duration = Duration::from_millis(50);

        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt
            .namespace("test_overloaded_reconciliation".to_string())
            .unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());

        let client = Client::with_reconcile_interval(endpoint.clone(), TEST_RECONCILE_INTERVAL)
            .await
            .unwrap();
        // Register one real instance so that discovery reports a worker.
        endpoint.register_endpoint_instance().await.unwrap();
        let instances = client.wait_for_instances().await.unwrap();
        let worker_id = instances[0].id();

        // wait_for_instances reads instance_source directly; poll until the
        // monitor task has reconciled the worker into the routing snapshot.
        for _ in 0..10 {
            if client.instance_ids_free().contains(&worker_id) {
                break;
            }
            tokio::time::sleep(TEST_RECONCILE_INTERVAL).await;
        }
        assert!(
            client.instance_ids_free().contains(&worker_id),
            "worker should be free after initial discovery reconciliation"
        );

        client.set_overloaded_instances(&[worker_id]);
        assert!(
            client.instance_ids_free().is_empty(),
            "worker should be overloaded before periodic reconciliation"
        );

        // Let at least one periodic reconciliation run.
        tokio::time::sleep(TEST_RECONCILE_INTERVAL + Duration::from_millis(50)).await;

        assert!(
            client.instance_ids_free().is_empty(),
            "periodic reconciliation should not mark an existing overloaded worker free"
        );

        rt.shutdown();
    }
840
841    #[tokio::test]
842    async fn test_report_instance_down_preserves_overloaded_state() {
843        const TEST_RECONCILE_INTERVAL: Duration = Duration::from_millis(50);
844
845        let rt = Runtime::from_current().unwrap();
846        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
847            .await
848            .unwrap();
849        let ns = drt
850            .namespace("test_report_down_preserves_overloaded".to_string())
851            .unwrap();
852        let component = ns.component("test_component".to_string()).unwrap();
853        let endpoint = component.endpoint("test_endpoint".to_string());
854
855        let client = Client::with_reconcile_interval(endpoint.clone(), TEST_RECONCILE_INTERVAL)
856            .await
857            .unwrap();
858        endpoint.register_endpoint_instance().await.unwrap();
859        let instances = client.wait_for_instances().await.unwrap();
860        let worker_id = instances[0].id();
861
862        for _ in 0..10 {
863            if client.instance_ids_avail().contains(&worker_id) {
864                break;
865            }
866            tokio::time::sleep(TEST_RECONCILE_INTERVAL).await;
867        }
868
869        client.set_overloaded_instances(&[worker_id]);
870        client.report_instance_down(worker_id);
871
872        assert!(
873            !client.instance_ids_avail().contains(&worker_id),
874            "reported-down worker should leave routable availability"
875        );
876        assert_eq!(
877            client.routing_instance_counts().overloaded,
878            1,
879            "reported-down worker should remain overloaded while still discovered"
880        );
881        assert!(
882            client.instance_ids_free().is_empty(),
883            "reported-down overloaded worker should not become free"
884        );
885
886        endpoint.unregister_endpoint_instance().await.unwrap();
887        for _ in 0..10 {
888            if client.routing_instance_counts().overloaded == 0 {
889                break;
890            }
891            tokio::time::sleep(TEST_RECONCILE_INTERVAL).await;
892        }
893
894        assert_eq!(
895            client.routing_instance_counts().overloaded,
896            0,
897            "stable discovery removal should clear overloaded state"
898        );
899
900        rt.shutdown();
901    }
902
903    #[tokio::test]
904    async fn test_instance_reconciliation_prunes_removed_overloaded_instances() {
905        const TEST_RECONCILE_INTERVAL: Duration = Duration::from_millis(50);
906
907        let rt = Runtime::from_current().unwrap();
908        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
909            .await
910            .unwrap();
911        let ns = drt
912            .namespace("test_removed_overloaded_cleanup".to_string())
913            .unwrap();
914        let component = ns.component("test_component".to_string()).unwrap();
915        let endpoint = component.endpoint("test_endpoint".to_string());
916
917        let client = Client::with_reconcile_interval(endpoint.clone(), TEST_RECONCILE_INTERVAL)
918            .await
919            .unwrap();
920        endpoint.register_endpoint_instance().await.unwrap();
921        let instances = client.wait_for_instances().await.unwrap();
922        let worker_id = instances[0].id();
923
924        client.set_overloaded_instances(&[worker_id]);
925        assert_eq!(client.routing_instance_counts().overloaded, 1);
926        assert!(client.instance_ids_free().is_empty());
927
928        endpoint.unregister_endpoint_instance().await.unwrap();
929        for _ in 0..10 {
930            if client.routing_instance_counts().overloaded == 0 {
931                break;
932            }
933            tokio::time::sleep(TEST_RECONCILE_INTERVAL).await;
934        }
935
936        assert_eq!(
937            client.routing_instance_counts().overloaded,
938            0,
939            "removed discovered workers should not remain in overloaded state"
940        );
941
942        rt.shutdown();
943    }
944
945    #[tokio::test]
946    async fn test_instance_ids_free_excludes_overloaded_new_instances() {
947        const TEST_RECONCILE_INTERVAL: Duration = Duration::from_millis(50);
948
949        let rt = Runtime::from_current().unwrap();
950        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
951            .await
952            .unwrap();
953        let worker_id = drt.connection_id();
954        let ns = drt
955            .namespace("test_new_overloaded_reconciliation".to_string())
956            .unwrap();
957        let component = ns.component("test_component".to_string()).unwrap();
958        let endpoint = component.endpoint("test_endpoint".to_string());
959
960        let client = Client::with_reconcile_interval(endpoint.clone(), TEST_RECONCILE_INTERVAL)
961            .await
962            .unwrap();
963        client.set_overloaded_instances(&[worker_id]);
964
965        endpoint.register_endpoint_instance().await.unwrap();
966        let instances = client.wait_for_instances().await.unwrap();
967        assert_eq!(instances[0].id(), worker_id);
968        assert!(
969            client.instance_ids_free().is_empty(),
970            "newly discovered overloaded worker should not be free"
971        );
972
973        tokio::time::sleep(TEST_RECONCILE_INTERVAL + Duration::from_millis(50)).await;
974
975        assert!(
976            client.instance_ids_free().is_empty(),
977            "discovery reconciliation should not affect recomputed free workers"
978        );
979
980        rt.shutdown();
981    }
982
983    #[tokio::test]
984    async fn test_discovery_add_updates_free_without_overloaded_publish() {
985        const TEST_RECONCILE_INTERVAL: Duration = Duration::from_millis(50);
986
987        let rt = Runtime::from_current().unwrap();
988        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
989            .await
990            .unwrap();
991        let ns = drt
992            .namespace("test_free_updates_on_discovery_add".to_string())
993            .unwrap();
994        let component = ns.component("test_component".to_string()).unwrap();
995        let endpoint = component.endpoint("test_endpoint".to_string());
996
997        let client = Client::with_reconcile_interval(endpoint.clone(), TEST_RECONCILE_INTERVAL)
998            .await
999            .unwrap();
1000        endpoint.register_endpoint_instance().await.unwrap();
1001        let instances = client.wait_for_instances().await.unwrap();
1002        let worker_id = instances[0].id();
1003
1004        for _ in 0..10 {
1005            if client.instance_ids_free().contains(&worker_id) {
1006                break;
1007            }
1008            tokio::time::sleep(TEST_RECONCILE_INTERVAL).await;
1009        }
1010
1011        assert_eq!(
1012            client.instance_ids_free(),
1013            vec![worker_id],
1014            "newly discovered non-overloaded workers should appear free without an overload update"
1015        );
1016
1017        rt.shutdown();
1018    }
1019
1020    /// Test that instance_avail_watcher receives updates when instances change.
1021    #[tokio::test]
1022    async fn test_instance_avail_watcher() {
1023        let rt = Runtime::from_current().unwrap();
1024        // Use process_local config to avoid needing etcd/nats
1025        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
1026            .await
1027            .unwrap();
1028        let ns = drt.namespace("test_watcher".to_string()).unwrap();
1029        let component = ns.component("test_component".to_string()).unwrap();
1030        let endpoint = component.endpoint("test_endpoint".to_string());
1031
1032        let client = endpoint.client().await.unwrap();
1033        let watcher = client.instance_avail_watcher();
1034
1035        // Set initial instances
1036        client.override_instance_avail(vec![1, 2, 3]);
1037
1038        // Report instance down - this should notify the watcher
1039        client.report_instance_down(2);
1040
1041        // The watcher should receive the update
1042        // Note: We need to check if changed() was signaled
1043        let current = watcher.borrow().clone();
1044        assert_eq!(current, vec![1, 3]);
1045
1046        rt.shutdown();
1047    }
1048
1049    /// Test that concurrent select_and_increment distributes load correctly.
1050    #[tokio::test]
1051    async fn test_concurrent_select_and_increment() {
1052        let state = Arc::new(RoutingOccupancyState::default());
1053        let instance_ids: Vec<u64> = vec![100, 200, 300];
1054        let num_requests = 90;
1055
1056        let mut handles = Vec::new();
1057        for _ in 0..num_requests {
1058            let state = state.clone();
1059            let ids = instance_ids.clone();
1060            handles.push(tokio::spawn(async move {
1061                state.select_exact_min_and_increment(&ids).await
1062            }));
1063        }
1064
1065        for handle in handles {
1066            handle.await.unwrap();
1067        }
1068
1069        assert_eq!(state.load(100), 30);
1070        assert_eq!(state.load(200), 30);
1071        assert_eq!(state.load(300), 30);
1072    }
1073
1074    #[tokio::test]
1075    async fn test_select_exact_min_and_increment_randomizes_ties() {
1076        let mut selected = [false; 3];
1077
1078        for _ in 0..120 {
1079            let state = RoutingOccupancyState::default();
1080            let picked = state
1081                .select_exact_min_and_increment(&[10, 20, 30])
1082                .await
1083                .unwrap();
1084            match picked {
1085                10 => selected[0] = true,
1086                20 => selected[1] = true,
1087                30 => selected[2] = true,
1088                _ => panic!("unexpected worker id: {picked}"),
1089            }
1090        }
1091
1092        let selected_count = selected.into_iter().filter(|seen| *seen).count();
1093        assert!(
1094            selected_count > 1,
1095            "tie-breaking should not always select the first minimum-load worker"
1096        );
1097    }
1098
1099    #[tokio::test]
1100    async fn test_connection_counts() {
1101        let rt = Runtime::from_current().unwrap();
1102        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
1103            .await
1104            .unwrap();
1105        let ns = drt.namespace("test_ll_counts".to_string()).unwrap();
1106        let component = ns.component("test_component".to_string()).unwrap();
1107        let endpoint = component.endpoint("test_endpoint".to_string());
1108
1109        let state1 = get_or_create_routing_occupancy_state(&endpoint).await;
1110        let state2 = get_or_create_routing_occupancy_state(&endpoint).await;
1111
1112        let picked1 = state1
1113            .select_exact_min_and_increment(&[10, 20, 30])
1114            .await
1115            .unwrap();
1116        assert_eq!(state1.load(picked1), 1);
1117
1118        let picked2 = state1
1119            .select_exact_min_and_increment(&[10, 20, 30])
1120            .await
1121            .unwrap();
1122        assert_ne!(picked1, picked2);
1123
1124        // state2 should see the same counts (same underlying Arc)
1125        assert_eq!(state2.load(10), state1.load(10));
1126        assert_eq!(state2.load(20), state1.load(20));
1127        assert_eq!(state2.load(30), state1.load(30));
1128
1129        state2.decrement(picked1);
1130        assert_eq!(state1.load(picked1), if picked1 == picked2 { 1 } else { 0 });
1131
1132        rt.shutdown();
1133    }
1134
1135    #[tokio::test]
1136    async fn test_least_loaded_state_retain() {
1137        let state = RoutingOccupancyState::default();
1138
1139        // Add some connections
1140        state.select_exact_min_and_increment(&[1, 2, 3]).await;
1141        state.select_exact_min_and_increment(&[1, 2, 3]).await;
1142        state.select_exact_min_and_increment(&[1, 2, 3]).await;
1143        // Each instance should have 1 connection
1144        assert_eq!(state.load(1), 1);
1145        assert_eq!(state.load(2), 1);
1146        assert_eq!(state.load(3), 1);
1147
1148        // Retain only instances 1 and 3 (instance 2 was removed)
1149        state.retain(&[1, 3]);
1150
1151        assert_eq!(state.load(1), 1);
1152        assert_eq!(state.load(2), 0);
1153        assert_eq!(state.load(3), 1);
1154    }
1155
1156    #[tokio::test]
1157    async fn test_monitor_instance_source_cleans_up_removed_worker_counts() {
1158        const TEST_RECONCILE_INTERVAL: Duration = Duration::from_millis(50);
1159
1160        let rt = Runtime::from_current().unwrap();
1161        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
1162            .await
1163            .unwrap();
1164        let ns = drt.namespace("test_occupancy_cleanup".to_string()).unwrap();
1165        let component = ns.component("test_component".to_string()).unwrap();
1166        let endpoint = component.endpoint("test_endpoint".to_string());
1167
1168        let client = Client::with_reconcile_interval(endpoint.clone(), TEST_RECONCILE_INTERVAL)
1169            .await
1170            .unwrap();
1171        endpoint.register_endpoint_instance().await.unwrap();
1172        client.wait_for_instances().await.unwrap();
1173
1174        let worker_id = client.instance_ids_avail()[0];
1175        let state = get_or_create_routing_occupancy_state(&endpoint).await;
1176        state.increment(worker_id);
1177        assert_eq!(state.load(worker_id), 1);
1178
1179        endpoint.unregister_endpoint_instance().await.unwrap();
1180
1181        for _ in 0..10 {
1182            if state.load(worker_id) == 0 {
1183                break;
1184            }
1185            tokio::time::sleep(TEST_RECONCILE_INTERVAL).await;
1186        }
1187
1188        assert_eq!(state.load(worker_id), 0);
1189
1190        rt.shutdown();
1191    }
1192}