1use std::sync::atomic::{AtomicU64, Ordering};
5use std::{
6 collections::{HashMap, HashSet},
7 sync::{Arc, Mutex as StdMutex},
8 time::Duration,
9};
10
11use anyhow::Result;
12use arc_swap::ArcSwap;
13use dashmap::DashMap;
14use futures::StreamExt;
15use rand::Rng;
16
17use crate::component::{Endpoint, Instance};
18use crate::discovery::{DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId};
19use crate::traits::DistributedRuntimeProvider;
20
21#[derive(Debug, Default)]
23pub(crate) struct RoutingOccupancyState {
24 counts: DashMap<u64, AtomicU64>,
25 exact_selection_lock: tokio::sync::Mutex<()>,
26}
27
28impl RoutingOccupancyState {
29 pub(crate) fn increment(&self, instance_id: u64) {
30 self.counts
31 .entry(instance_id)
32 .or_insert_with(|| AtomicU64::new(0))
33 .fetch_add(1, Ordering::Relaxed);
34 }
35
36 pub(crate) async fn select_exact_min_and_increment(&self, instance_ids: &[u64]) -> Option<u64> {
37 let _guard = self.exact_selection_lock.lock().await;
38
39 let mut min_load = u64::MAX;
40 let mut selected = None;
41 let mut tie_count = 0usize;
42 let mut rng = rand::rng();
43 for &id in instance_ids {
44 let load = self.load(id);
45 if load < min_load {
46 min_load = load;
47 selected = Some(id);
48 tie_count = 1;
49 continue;
50 }
51
52 if load == min_load {
53 tie_count += 1;
54 if rng.random_range(0..tie_count) == 0 {
56 selected = Some(id);
57 }
58 }
59 }
60
61 let id = selected?;
62 self.increment(id);
63 Some(id)
64 }
65
66 pub(crate) fn decrement(&self, instance_id: u64) {
67 if let Some(count) = self.counts.get(&instance_id) {
68 let _ = count.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
69 Some(current.saturating_sub(1))
70 });
71 }
72 }
73
74 pub(crate) fn load(&self, instance_id: u64) -> u64 {
75 self.counts
76 .get(&instance_id)
77 .map(|c| c.load(Ordering::Relaxed))
78 .unwrap_or(0)
79 }
80
81 pub(crate) fn retain(&self, instance_ids: &[u64]) {
82 let live: HashSet<u64> = instance_ids.iter().copied().collect();
83 self.counts.retain(|id, _| live.contains(id));
84 }
85}
86
87pub(crate) async fn get_or_create_routing_occupancy_state(
89 endpoint: &Endpoint,
90) -> Arc<RoutingOccupancyState> {
91 let drt = endpoint.drt();
92 let registry = drt.routing_occupancy_states();
93 let mut registry = registry.lock().await;
94
95 if let Some(weak) = registry.get(endpoint) {
96 if let Some(state) = weak.upgrade() {
97 return state;
98 } else {
99 registry.remove(endpoint);
100 }
101 }
102
103 let state = Arc::new(RoutingOccupancyState::default());
104 registry.insert(endpoint.clone(), Arc::downgrade(&state));
105 state
106}
107
/// Default period after which a client re-reconciles its routing snapshot with discovery,
/// even if no discovery change was observed.
const DEFAULT_RECONCILE_INTERVAL: Duration = Duration::from_millis(5_000);
110
111#[derive(Debug)]
119pub(crate) struct EndpointDiscoverySource {
120 instance_source: tokio::sync::watch::Receiver<Vec<Instance>>,
121 event_subscribers: StdMutex<Vec<tokio::sync::mpsc::UnboundedSender<DiscoveryEvent>>>,
122}
123
124impl EndpointDiscoverySource {
125 fn new(instance_source: tokio::sync::watch::Receiver<Vec<Instance>>) -> Self {
126 Self {
127 instance_source,
128 event_subscribers: StdMutex::new(Vec::new()),
129 }
130 }
131
132 fn instance_receiver(&self) -> tokio::sync::watch::Receiver<Vec<Instance>> {
133 self.instance_source.clone()
134 }
135
136 fn subscribe_events(&self) -> tokio::sync::mpsc::UnboundedReceiver<DiscoveryEvent> {
137 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
138 self.event_subscribers.lock().unwrap().push(tx);
139 rx
140 }
141
142 fn broadcast_event(&self, event: &DiscoveryEvent) {
143 let subscribers = &mut *self.event_subscribers.lock().unwrap();
144 subscribers.retain(|tx| tx.send(event.clone()).is_ok());
145 }
146}
147
/// Sizes of the instance sets held by a [`RoutingInstances`] snapshot.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct RoutingInstanceCounts {
    /// Ids currently reported by discovery.
    pub discovered: usize,
    /// Ids eligible for routing (discovered, minus ids reported down since the last reconcile).
    pub routable: usize,
    /// Ids flagged overloaded; this may include ids that discovery has not reported yet.
    pub overloaded: usize,
    /// Routable ids that are not overloaded.
    pub free: usize,
}

/// Snapshot of which instance ids may receive traffic.
///
/// Every transition takes `&self` and returns a new snapshot. `free_ids` is always derived
/// as `routable_ids` without the ids in `overloaded_ids`, keeping `routable_ids` order.
#[derive(Clone, Debug)]
pub(crate) struct RoutingInstances {
    discovered_ids: Vec<u64>,
    routable_ids: Vec<u64>,
    overloaded_ids: HashSet<u64>,
    free_ids: Vec<u64>,
}

impl RoutingInstances {
    /// Snapshot where every discovered id is routable and nothing is overloaded.
    fn new(discovered_ids: Vec<u64>) -> Self {
        let routable_ids = discovered_ids.clone();
        Self::from_parts(discovered_ids, routable_ids, HashSet::new())
    }

    /// Assembles a snapshot from its independent sets and derives `free_ids`.
    fn from_parts(
        discovered_ids: Vec<u64>,
        routable_ids: Vec<u64>,
        overloaded_ids: HashSet<u64>,
    ) -> Self {
        Self {
            free_ids: Self::derive_free_ids(&routable_ids, &overloaded_ids),
            discovered_ids,
            routable_ids,
            overloaded_ids,
        }
    }

    /// Ids currently reported by discovery.
    pub(crate) fn discovered_ids(&self) -> &[u64] {
        self.discovered_ids.as_slice()
    }

    /// Ids currently eligible for routing.
    pub(crate) fn routable_ids(&self) -> &[u64] {
        self.routable_ids.as_slice()
    }

    /// Routable ids that are not overloaded.
    pub(crate) fn free_ids(&self) -> &[u64] {
        self.free_ids.as_slice()
    }

    /// Sizes of each set in this snapshot.
    pub(crate) fn counts(&self) -> RoutingInstanceCounts {
        let Self {
            discovered_ids,
            routable_ids,
            overloaded_ids,
            free_ids,
        } = self;
        RoutingInstanceCounts {
            discovered: discovered_ids.len(),
            routable: routable_ids.len(),
            overloaded: overloaded_ids.len(),
            free: free_ids.len(),
        }
    }

    /// Whether `instance_id` is flagged overloaded.
    pub(crate) fn is_overloaded(&self, instance_id: u64) -> bool {
        self.overloaded_ids.contains(&instance_id)
    }

    /// Copy of the overloaded set, or `None` when it is empty.
    fn overloaded_ids(&self) -> Option<HashSet<u64>> {
        (!self.overloaded_ids.is_empty()).then(|| self.overloaded_ids.clone())
    }

    /// Resets routable ids to `discovered_ids`, which undoes any reported-down state.
    ///
    /// An overloaded id is dropped only if it was discovered before and is missing now.
    /// Overloaded ids that were never discovered are kept, so they apply once they appear.
    fn reconcile_discovered(&self, discovered_ids: Vec<u64>) -> Self {
        let still_discovered: HashSet<u64> = discovered_ids.iter().copied().collect();
        let vanished: HashSet<u64> = self
            .discovered_ids
            .iter()
            .copied()
            .filter(|id| !still_discovered.contains(id))
            .collect();
        let overloaded_ids = self.overloaded_ids.difference(&vanished).copied().collect();

        let routable_ids = discovered_ids.clone();
        Self::from_parts(discovered_ids, routable_ids, overloaded_ids)
    }

    /// Removes `instance_id` from the routable ids; discovered and overloaded sets are kept.
    fn report_instance_down(&self, instance_id: u64) -> Self {
        let mut routable_ids = self.routable_ids.clone();
        routable_ids.retain(|&id| id != instance_id);

        Self::from_parts(
            self.discovered_ids.clone(),
            routable_ids,
            self.overloaded_ids.clone(),
        )
    }

    /// Test-only: replaces routable ids outright.
    #[cfg(test)]
    fn override_routable_ids(&self, routable_ids: Vec<u64>) -> Self {
        let discovered_ids = self.discovered_ids.clone();
        let overloaded_ids = self.overloaded_ids.clone();
        Self::from_parts(discovered_ids, routable_ids, overloaded_ids)
    }

    /// Replaces the overloaded set.
    fn set_overloaded(&self, overloaded_ids: HashSet<u64>) -> Self {
        let discovered_ids = self.discovered_ids.clone();
        let routable_ids = self.routable_ids.clone();
        Self::from_parts(discovered_ids, routable_ids, overloaded_ids)
    }

    /// Drops every id in `removed_ids` from the overloaded set.
    fn clear_overloaded_for_removed(&self, removed_ids: &HashSet<u64>) -> Self {
        let overloaded_ids = self.overloaded_ids.difference(removed_ids).copied().collect();
        Self::from_parts(
            self.discovered_ids.clone(),
            self.routable_ids.clone(),
            overloaded_ids,
        )
    }

    /// `routable_ids` in order, without any id contained in `overloaded_ids`.
    fn derive_free_ids(routable_ids: &[u64], overloaded_ids: &HashSet<u64>) -> Vec<u64> {
        routable_ids
            .iter()
            .filter(|id| !overloaded_ids.contains(*id))
            .copied()
            .collect()
    }
}
283
284#[derive(Debug)]
285struct RoutingInstancesState {
286 snapshot: ArcSwap<RoutingInstances>,
287 update_lock: StdMutex<()>,
288 instance_avail_tx: tokio::sync::watch::Sender<Vec<u64>>,
289 instance_avail_rx: tokio::sync::watch::Receiver<Vec<u64>>,
290}
291
292impl RoutingInstancesState {
293 fn new(discovered_ids: Vec<u64>) -> Self {
294 let snapshot = RoutingInstances::new(discovered_ids);
295 let (instance_avail_tx, instance_avail_rx) =
296 tokio::sync::watch::channel(snapshot.routable_ids().to_vec());
297 Self {
298 snapshot: ArcSwap::from_pointee(snapshot),
299 update_lock: StdMutex::new(()),
300 instance_avail_tx,
301 instance_avail_rx,
302 }
303 }
304
305 fn snapshot(&self) -> arc_swap::Guard<Arc<RoutingInstances>> {
306 self.snapshot.load()
307 }
308
309 fn update(
310 &self,
311 update: impl FnOnce(&RoutingInstances) -> RoutingInstances,
312 publish_routable_ids: bool,
313 ) -> Arc<RoutingInstances> {
314 let _guard = self.update_lock.lock().unwrap();
315 let current = self.snapshot.load();
316 let next = Arc::new(update(¤t));
317 self.snapshot.store(next.clone());
318 if publish_routable_ids {
319 self.publish_routable_ids(&next);
320 }
321 next
322 }
323
324 fn publish_routable_ids(&self, routing_instances: &RoutingInstances) {
325 let _ = self
326 .instance_avail_tx
327 .send(routing_instances.routable_ids().to_vec());
328 }
329
330 fn routable_ids(&self) -> Vec<u64> {
331 self.snapshot().routable_ids().to_vec()
332 }
333
334 #[cfg(test)]
335 fn free_ids(&self) -> Vec<u64> {
336 self.snapshot().free_ids.clone()
337 }
338
339 fn counts(&self) -> RoutingInstanceCounts {
340 self.snapshot().counts()
341 }
342
343 fn overloaded_ids(&self) -> Option<HashSet<u64>> {
344 self.snapshot().overloaded_ids()
345 }
346
347 fn instance_avail_watcher(&self) -> tokio::sync::watch::Receiver<Vec<u64>> {
348 self.instance_avail_rx.clone()
349 }
350
351 fn report_instance_down(&self, instance_id: u64) {
352 self.update(|current| current.report_instance_down(instance_id), true);
353 }
354
355 fn set_overloaded_instances(&self, overloaded_instance_ids: &[u64]) -> bool {
356 let overloaded_ids = overloaded_instance_ids
357 .iter()
358 .copied()
359 .collect::<HashSet<_>>();
360 let _guard = self.update_lock.lock().unwrap();
361 let current = self.snapshot.load();
362 if current.overloaded_ids == overloaded_ids {
363 return false;
364 }
365
366 let next = Arc::new(current.set_overloaded(overloaded_ids));
367 self.snapshot.store(next);
368 true
369 }
370
371 fn clear_overloaded_for_removed(&self, removed_instance_ids: &[u64]) {
372 if removed_instance_ids.is_empty() {
373 return;
374 }
375
376 let removed_ids = removed_instance_ids.iter().copied().collect::<HashSet<_>>();
377 self.update(
378 move |current| current.clear_overloaded_for_removed(&removed_ids),
379 false,
380 );
381 }
382
383 fn reconcile_discovered(&self, discovered_ids: Vec<u64>) -> Arc<RoutingInstances> {
384 self.update(
385 move |current| current.reconcile_discovered(discovered_ids),
386 true,
387 )
388 }
389
390 #[cfg(test)]
391 fn override_routable_ids(&self, ids: Vec<u64>) {
392 self.update(move |current| current.override_routable_ids(ids), true);
393 }
394}
395
/// Client handle for a single endpoint.
///
/// Tracks the endpoint's discovered instances through a discovery source that is shared per
/// endpoint via the runtime registry. From those instances it derives the routable,
/// overloaded and free instance sets.
#[derive(Clone, Debug)]
pub struct Client {
    pub endpoint: Endpoint,
    // Shared discovery feed; also hands out raw discovery-event subscriptions.
    endpoint_discovery_source: Arc<EndpointDiscoverySource>,
    // Watch receiver carrying the current discovered instance list.
    pub instance_source: Arc<tokio::sync::watch::Receiver<Vec<Instance>>>,
    // Routable / overloaded / free snapshot derived from discovery.
    routing_instances: Arc<RoutingInstancesState>,
    // Upper bound between reconciliations in `monitor_instance_source`.
    reconcile_interval: Duration,
}

impl Client {
    /// Creates a client for `endpoint` using `DEFAULT_RECONCILE_INTERVAL`.
    pub(crate) async fn new(endpoint: Endpoint) -> Result<Self> {
        Self::with_reconcile_interval(endpoint, DEFAULT_RECONCILE_INTERVAL).await
    }

    /// Creates a client whose background task reconciles at least every `reconcile_interval`.
    ///
    /// Reuses (or creates) the endpoint's shared discovery source, seeds the routing snapshot
    /// from the instances visible right now, and spawns `monitor_instance_source`.
    ///
    /// # Errors
    /// Propagates errors from `get_or_create_dynamic_discovery_source`.
    pub(crate) async fn with_reconcile_interval(
        endpoint: Endpoint,
        reconcile_interval: Duration,
    ) -> Result<Self> {
        tracing::trace!(
            "Client::new_dynamic: Creating dynamic client for endpoint: {}",
            endpoint.id()
        );
        let endpoint_discovery_source =
            Self::get_or_create_dynamic_discovery_source(&endpoint).await?;
        let instance_source = Arc::new(endpoint_discovery_source.instance_receiver());

        // Seed from whatever is already published; this may be empty for a fresh source.
        let initial_ids: Vec<u64> = instance_source
            .borrow()
            .iter()
            .map(|instance| instance.id())
            .collect();
        let client = Client {
            endpoint: endpoint.clone(),
            endpoint_discovery_source,
            instance_source: instance_source.clone(),
            routing_instances: Arc::new(RoutingInstancesState::new(initial_ids)),
            reconcile_interval,
        };
        client.monitor_instance_source();
        Ok(client)
    }

    /// Snapshot of the currently discovered instances.
    pub fn instances(&self) -> Vec<Instance> {
        self.instance_source.borrow().clone()
    }

    /// Ids of the currently discovered instances.
    pub fn instance_ids(&self) -> Vec<u64> {
        self.instances().into_iter().map(|ep| ep.id()).collect()
    }

    /// Routable ids: discovered and not reported down since the last reconciliation.
    pub fn instance_ids_avail(&self) -> Vec<u64> {
        self.routing_instances.routable_ids()
    }

    /// Test-only: routable ids that are not overloaded.
    #[cfg(test)]
    pub(crate) fn instance_ids_free(&self) -> Vec<u64> {
        self.routing_instances.free_ids()
    }

    /// Current routing snapshot (discovered / routable / overloaded / free sets).
    pub(crate) fn routing_instances(&self) -> arc_swap::Guard<Arc<RoutingInstances>> {
        self.routing_instances.snapshot()
    }

    /// Sizes of each set in the current routing snapshot.
    pub fn routing_instance_counts(&self) -> RoutingInstanceCounts {
        self.routing_instances.counts()
    }

    /// Watch receiver for routable ids.
    ///
    /// Updated by report-down and by reconciliation; overload changes do not publish here.
    pub fn instance_avail_watcher(&self) -> tokio::sync::watch::Receiver<Vec<u64>> {
        self.routing_instances.instance_avail_watcher()
    }

    /// Subscribes to raw discovery events for this endpoint.
    ///
    /// Events broadcast before the call are not replayed; callers presumably combine this
    /// with `instances()` for the initial state (TODO confirm).
    pub(crate) fn subscribe_discovery_events(
        &self,
    ) -> tokio::sync::mpsc::UnboundedReceiver<DiscoveryEvent> {
        self.endpoint_discovery_source.subscribe_events()
    }

    /// Waits until at least one instance is discovered, then returns the discovered list.
    ///
    /// # Errors
    /// Fails if the discovery watch sender is dropped while the list is still empty.
    pub async fn wait_for_instances(&self) -> Result<Vec<Instance>> {
        tracing::trace!(
            "wait_for_instances: Starting wait for endpoint: {}",
            self.endpoint.id()
        );
        let mut rx = self.instance_source.as_ref().clone();
        let mut instances: Vec<Instance>;
        loop {
            instances = rx.borrow_and_update().to_vec();
            if instances.is_empty() {
                rx.changed().await?;
            } else {
                tracing::info!(
                    "wait_for_instances: Found {} instance(s) for endpoint: {}",
                    instances.len(),
                    self.endpoint.id()
                );
                break;
            }
        }
        Ok(instances)
    }

    /// Removes `instance_id` from the routable set until the next reconciliation.
    ///
    /// The instance's overloaded flag, if any, is left untouched.
    pub fn report_instance_down(&self, instance_id: u64) {
        self.routing_instances.report_instance_down(instance_id);
        tracing::debug!("inhibiting instance {instance_id}");
    }

    /// Replaces the overloaded set with `overloaded_instance_ids`.
    ///
    /// Returns `true` if the set changed.
    pub fn set_overloaded_instances(&self, overloaded_instance_ids: &[u64]) -> bool {
        self.routing_instances
            .set_overloaded_instances(overloaded_instance_ids)
    }

    /// Drops the given ids from the overloaded set (no-op for an empty slice).
    pub fn clear_overloaded_instances_for_removed(&self, removed_instance_ids: &[u64]) {
        self.routing_instances
            .clear_overloaded_for_removed(removed_instance_ids);
    }

    /// Current overloaded ids, or `None` when the set is empty.
    pub fn overloaded_instance_ids(&self) -> Option<HashSet<u64>> {
        self.routing_instances.overloaded_ids()
    }

    /// Spawns the background task that keeps the routing snapshot in sync with discovery.
    ///
    /// Each iteration does three things:
    /// 1. Reconciles the routing snapshot against the discovered ids, which also undoes any
    ///    `report_instance_down`.
    /// 2. Prunes this endpoint's occupancy counts to those ids.
    /// 3. Waits for a discovery change or `reconcile_interval`, whichever comes first.
    ///
    /// The loop checks the runtime's primary token only after each wake-up. If the discovery
    /// watch sender is dropped, the task cancels that *primary* token itself.
    /// NOTE(review): that presumably shuts down the whole runtime, not just this client —
    /// confirm this fail-fast behavior is intended.
    ///
    /// NOTE(review): the task owns a clone of `self`, so its `Arc<EndpointDiscoverySource>`
    /// stays alive (and the registry's weak entry stays upgradable) until the task exits,
    /// even after every user-held client is dropped.
    fn monitor_instance_source(&self) {
        let reconcile_interval = self.reconcile_interval;
        let cancel_token = self.endpoint.drt().primary_token();
        let client = self.clone();
        let endpoint_id = self.endpoint.id();
        tokio::task::spawn(async move {
            let mut rx = client.instance_source.as_ref().clone();
            while !cancel_token.is_cancelled() {
                let instance_ids: Vec<u64> = rx
                    .borrow_and_update()
                    .iter()
                    .map(|instance| instance.id())
                    .collect();

                let routing_instances = client.reconcile_discovered_instances(instance_ids);

                // Best effort: `try_lock` never blocks. If the registry is busy, or no live
                // occupancy state exists for this endpoint, pruning waits for the next iteration.
                let registry = client.endpoint.drt().routing_occupancy_states();
                if let Ok(registry) = registry.try_lock()
                    && let Some(weak) = registry.get(&client.endpoint)
                    && let Some(state) = weak.upgrade()
                {
                    state.retain(routing_instances.discovered_ids());
                }

                tokio::select! {
                    result = rx.changed() => {
                        if let Err(err) = result {
                            tracing::error!(
                                "monitor_instance_source: The Sender is dropped: {err}, endpoint={endpoint_id}",
                            );
                            cancel_token.cancel();
                        }
                    }
                    _ = tokio::time::sleep(reconcile_interval) => {
                        tracing::trace!(
                            "monitor_instance_source: periodic reconciliation for endpoint={endpoint_id}",
                        );
                    }
                }
            }
        });
    }

    /// Test-only: replaces the routable ids (until the next reconciliation overwrites them).
    #[cfg(test)]
    pub(crate) fn override_instance_avail(&self, ids: Vec<u64>) {
        self.routing_instances.override_routable_ids(ids);
    }

    /// Reconciles the routing snapshot with `discovered_ids` and returns the new snapshot.
    fn reconcile_discovered_instances(&self, discovered_ids: Vec<u64>) -> Arc<RoutingInstances> {
        self.routing_instances.reconcile_discovered(discovered_ids)
    }

    /// Returns the endpoint's shared discovery source.
    ///
    /// If no live source is registered, this creates one and spawns its watcher task. The
    /// registry stores only a `Weak`, and its lock is held for the whole call (including
    /// `list_and_watch`), so concurrent callers for the same endpoint cannot start
    /// duplicate watches.
    ///
    /// The watcher task, spawned on the secondary runtime, works as follows:
    /// - It folds discovery events into an `instance_id -> Instance` map.
    /// - It rebroadcasts every raw event to subscribers.
    /// - After every event it publishes the map's values. Their order follows `HashMap`
    ///   iteration order, which is unspecified.
    /// - It stops when all receivers are gone, the stream ends, or the stream errors, and
    ///   publishes an empty list on the way out.
    ///
    /// NOTE(review): the task keeps `discovery_source_task` (a strong `Arc` to the source),
    /// and the source holds a receiver. So `watch_tx.closed()` cannot fire while the task
    /// runs, and the source is never freed through the weak registry.
    ///
    /// # Errors
    /// Propagates the error from `list_and_watch`.
    async fn get_or_create_dynamic_discovery_source(
        endpoint: &Endpoint,
    ) -> Result<Arc<EndpointDiscoverySource>> {
        let drt = endpoint.drt();
        let sources = drt.endpoint_discovery_sources();
        let mut sources = sources.lock().await;

        if let Some(source) = sources.get(endpoint) {
            if let Some(source) = source.upgrade() {
                return Ok(source);
            } else {
                sources.remove(endpoint);
            }
        }

        let discovery = drt.discovery();
        let discovery_query = crate::discovery::DiscoveryQuery::Endpoint {
            namespace: endpoint.component.namespace.name.clone(),
            component: endpoint.component.name.clone(),
            endpoint: endpoint.name.clone(),
        };

        let mut discovery_stream = discovery
            .list_and_watch(discovery_query.clone(), None)
            .await?;
        let (watch_tx, watch_rx) = tokio::sync::watch::channel(vec![]);
        let discovery_source = Arc::new(EndpointDiscoverySource::new(watch_rx));

        let secondary = endpoint.component.drt.runtime().secondary().clone();
        let discovery_source_task = discovery_source.clone();

        secondary.spawn(async move {
            tracing::trace!("endpoint_watcher: Starting for discovery query: {:?}", discovery_query);
            let mut map: HashMap<u64, Instance> = HashMap::new();

            loop {
                let discovery_event = tokio::select! {
                    // Every receiver dropped: nobody is listening any more.
                    _ = watch_tx.closed() => {
                        break;
                    }
                    discovery_event = discovery_stream.next() => {
                        match discovery_event {
                            Some(Ok(event)) => {
                                event
                            },
                            Some(Err(e)) => {
                                tracing::error!("endpoint_watcher: discovery stream error: {}; shutting down for discovery query: {:?}", e, discovery_query);
                                break;
                            }
                            None => {
                                break;
                            }
                        }
                    }
                };

                discovery_source_task.broadcast_event(&discovery_event);

                // Only endpoint instances affect the map; other kinds are ignored. The list
                // below is still republished after every event.
                match discovery_event {
                    DiscoveryEvent::Added(DiscoveryInstance::Endpoint(instance)) => {
                        map.insert(instance.instance_id, instance);
                    }
                    DiscoveryEvent::Added(_) => {}
                    DiscoveryEvent::Removed(id) => {
                        if let DiscoveryInstanceId::Endpoint(endpoint_id) = id {
                            map.remove(&endpoint_id.instance_id);
                        }
                    }
                }

                let instances: Vec<Instance> = map.values().cloned().collect();
                if watch_tx.send(instances).is_err() {
                    break;
                }
            }
            // Leave watchers with an empty list rather than a stale one.
            let _ = watch_tx.send(vec![]);
        });

        sources.insert(endpoint.clone(), Arc::downgrade(&discovery_source));
        Ok(discovery_source)
    }
}
683
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{DistributedRuntime, Runtime, distributed::DistributedConfig};

    /// With nothing registered, a reconcile tick resets the routable set to the (empty)
    /// discovered set, discarding both the test override and the reported-down state.
    #[tokio::test]
    async fn test_instance_reconciliation() {
        const TEST_RECONCILE_INTERVAL: Duration = Duration::from_millis(100);

        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt.namespace("test_reconciliation".to_string()).unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());

        let client = Client::with_reconcile_interval(endpoint, TEST_RECONCILE_INTERVAL)
            .await
            .unwrap();

        assert!(client.instance_ids_avail().is_empty());

        // NOTE(review): the override and the asserts below race with the background reconcile
        // (100 ms interval). If a tick lands in between, the assert sees `[]`.
        client.override_instance_avail(vec![1, 2, 3]);

        assert_eq!(client.instance_ids_avail(), vec![1u64, 2, 3]);

        client.report_instance_down(2);
        assert_eq!(client.instance_ids_avail(), vec![1u64, 3]);

        // Sleep past one reconcile interval so the monitor has reconciled at least once.
        tokio::time::sleep(TEST_RECONCILE_INTERVAL + Duration::from_millis(50)).await;

        assert!(
            client.instance_ids_avail().is_empty(),
            "After reconciliation, instance_avail should match instance_source"
        );

        rt.shutdown();
    }

    /// `report_instance_down` removes only the reported id from the routable set.
    #[tokio::test]
    async fn test_report_instance_down() {
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt.namespace("test_report_down".to_string()).unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());

        let client = endpoint.client().await.unwrap();

        client.override_instance_avail(vec![1, 2, 3]);
        assert_eq!(client.instance_ids_avail(), vec![1u64, 2, 3]);

        client.report_instance_down(2);

        let avail = client.instance_ids_avail();
        assert!(avail.contains(&1), "Instance 1 should still be available");
        assert!(
            !avail.contains(&2),
            "Instance 2 should be removed after report_instance_down"
        );
        assert!(avail.contains(&3), "Instance 3 should still be available");

        rt.shutdown();
    }

    /// An empty overloaded set reads back as `None`, and `set_overloaded_instances` reports
    /// whether the set actually changed.
    #[tokio::test]
    async fn test_overloaded_instance_ids_returns_none_when_empty() {
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt.namespace("test_overloaded_ids".to_string()).unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());
        let client = endpoint.client().await.unwrap();

        assert_eq!(client.overloaded_instance_ids(), None);

        assert!(client.set_overloaded_instances(&[7]));
        assert_eq!(client.overloaded_instance_ids(), Some(HashSet::from([7])));
        assert!(!client.set_overloaded_instances(&[7]));

        assert!(client.set_overloaded_instances(&[]));
        assert_eq!(client.overloaded_instance_ids(), None);
        assert!(!client.set_overloaded_instances(&[]));

        rt.shutdown();
    }

    /// A worker that stays registered keeps its overloaded flag across periodic
    /// reconciliation, so it never reappears in the free set.
    #[tokio::test]
    async fn test_instance_reconciliation_preserves_overloaded_existing_instances() {
        const TEST_RECONCILE_INTERVAL: Duration = Duration::from_millis(50);

        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt
            .namespace("test_overloaded_reconciliation".to_string())
            .unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());

        let client = Client::with_reconcile_interval(endpoint.clone(), TEST_RECONCILE_INTERVAL)
            .await
            .unwrap();
        endpoint.register_endpoint_instance().await.unwrap();
        let instances = client.wait_for_instances().await.unwrap();
        let worker_id = instances[0].id();

        // Poll until the monitor has reconciled the newly discovered worker into the snapshot.
        for _ in 0..10 {
            if client.instance_ids_free().contains(&worker_id) {
                break;
            }
            tokio::time::sleep(TEST_RECONCILE_INTERVAL).await;
        }
        assert!(
            client.instance_ids_free().contains(&worker_id),
            "worker should be free after initial discovery reconciliation"
        );

        client.set_overloaded_instances(&[worker_id]);
        assert!(
            client.instance_ids_free().is_empty(),
            "worker should be overloaded before periodic reconciliation"
        );

        tokio::time::sleep(TEST_RECONCILE_INTERVAL + Duration::from_millis(50)).await;

        assert!(
            client.instance_ids_free().is_empty(),
            "periodic reconciliation should not mark an existing overloaded worker free"
        );

        rt.shutdown();
    }

    /// Reporting an overloaded worker down removes it from the routable set but keeps it
    /// overloaded. The flag clears only once discovery removes the worker.
    #[tokio::test]
    async fn test_report_instance_down_preserves_overloaded_state() {
        const TEST_RECONCILE_INTERVAL: Duration = Duration::from_millis(50);

        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt
            .namespace("test_report_down_preserves_overloaded".to_string())
            .unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());

        let client = Client::with_reconcile_interval(endpoint.clone(), TEST_RECONCILE_INTERVAL)
            .await
            .unwrap();
        endpoint.register_endpoint_instance().await.unwrap();
        let instances = client.wait_for_instances().await.unwrap();
        let worker_id = instances[0].id();

        for _ in 0..10 {
            if client.instance_ids_avail().contains(&worker_id) {
                break;
            }
            tokio::time::sleep(TEST_RECONCILE_INTERVAL).await;
        }

        client.set_overloaded_instances(&[worker_id]);
        client.report_instance_down(worker_id);

        assert!(
            !client.instance_ids_avail().contains(&worker_id),
            "reported-down worker should leave routable availability"
        );
        assert_eq!(
            client.routing_instance_counts().overloaded,
            1,
            "reported-down worker should remain overloaded while still discovered"
        );
        assert!(
            client.instance_ids_free().is_empty(),
            "reported-down overloaded worker should not become free"
        );

        endpoint.unregister_endpoint_instance().await.unwrap();
        for _ in 0..10 {
            if client.routing_instance_counts().overloaded == 0 {
                break;
            }
            tokio::time::sleep(TEST_RECONCILE_INTERVAL).await;
        }

        assert_eq!(
            client.routing_instance_counts().overloaded,
            0,
            "stable discovery removal should clear overloaded state"
        );

        rt.shutdown();
    }

    /// Unregistering an overloaded worker drops it from the overloaded set on reconciliation.
    #[tokio::test]
    async fn test_instance_reconciliation_prunes_removed_overloaded_instances() {
        const TEST_RECONCILE_INTERVAL: Duration = Duration::from_millis(50);

        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt
            .namespace("test_removed_overloaded_cleanup".to_string())
            .unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());

        let client = Client::with_reconcile_interval(endpoint.clone(), TEST_RECONCILE_INTERVAL)
            .await
            .unwrap();
        endpoint.register_endpoint_instance().await.unwrap();
        let instances = client.wait_for_instances().await.unwrap();
        let worker_id = instances[0].id();

        client.set_overloaded_instances(&[worker_id]);
        assert_eq!(client.routing_instance_counts().overloaded, 1);
        assert!(client.instance_ids_free().is_empty());

        endpoint.unregister_endpoint_instance().await.unwrap();
        for _ in 0..10 {
            if client.routing_instance_counts().overloaded == 0 {
                break;
            }
            tokio::time::sleep(TEST_RECONCILE_INTERVAL).await;
        }

        assert_eq!(
            client.routing_instance_counts().overloaded,
            0,
            "removed discovered workers should not remain in overloaded state"
        );

        rt.shutdown();
    }

    /// An id flagged overloaded before it is discovered stays out of the free set once it
    /// appears.
    #[tokio::test]
    async fn test_instance_ids_free_excludes_overloaded_new_instances() {
        const TEST_RECONCILE_INTERVAL: Duration = Duration::from_millis(50);

        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        // The registered instance is expected to use the connection id (asserted below).
        let worker_id = drt.connection_id();
        let ns = drt
            .namespace("test_new_overloaded_reconciliation".to_string())
            .unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());

        let client = Client::with_reconcile_interval(endpoint.clone(), TEST_RECONCILE_INTERVAL)
            .await
            .unwrap();
        client.set_overloaded_instances(&[worker_id]);

        endpoint.register_endpoint_instance().await.unwrap();
        let instances = client.wait_for_instances().await.unwrap();
        assert_eq!(instances[0].id(), worker_id);
        assert!(
            client.instance_ids_free().is_empty(),
            "newly discovered overloaded worker should not be free"
        );

        tokio::time::sleep(TEST_RECONCILE_INTERVAL + Duration::from_millis(50)).await;

        assert!(
            client.instance_ids_free().is_empty(),
            "discovery reconciliation should not affect recomputed free workers"
        );

        rt.shutdown();
    }

    /// A newly discovered, non-overloaded worker becomes free through reconciliation alone,
    /// without any call to `set_overloaded_instances`.
    #[tokio::test]
    async fn test_discovery_add_updates_free_without_overloaded_publish() {
        const TEST_RECONCILE_INTERVAL: Duration = Duration::from_millis(50);

        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt
            .namespace("test_free_updates_on_discovery_add".to_string())
            .unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());

        let client = Client::with_reconcile_interval(endpoint.clone(), TEST_RECONCILE_INTERVAL)
            .await
            .unwrap();
        endpoint.register_endpoint_instance().await.unwrap();
        let instances = client.wait_for_instances().await.unwrap();
        let worker_id = instances[0].id();

        for _ in 0..10 {
            if client.instance_ids_free().contains(&worker_id) {
                break;
            }
            tokio::time::sleep(TEST_RECONCILE_INTERVAL).await;
        }

        assert_eq!(
            client.instance_ids_free(),
            vec![worker_id],
            "newly discovered non-overloaded workers should appear free without an overload update"
        );

        rt.shutdown();
    }

    /// The availability watcher reflects the routable ids after an override and a report-down.
    #[tokio::test]
    async fn test_instance_avail_watcher() {
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt.namespace("test_watcher".to_string()).unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());

        let client = endpoint.client().await.unwrap();
        let watcher = client.instance_avail_watcher();

        client.override_instance_avail(vec![1, 2, 3]);

        client.report_instance_down(2);

        let current = watcher.borrow().clone();
        assert_eq!(current, vec![1, 3]);

        rt.shutdown();
    }

    /// 90 concurrent selections over 3 idle ids must end exactly balanced (30 each), because
    /// `select_exact_min_and_increment` serializes on its selection lock.
    #[tokio::test]
    async fn test_concurrent_select_and_increment() {
        let state = Arc::new(RoutingOccupancyState::default());
        let instance_ids: Vec<u64> = vec![100, 200, 300];
        let num_requests = 90;

        let mut handles = Vec::new();
        for _ in 0..num_requests {
            let state = state.clone();
            let ids = instance_ids.clone();
            handles.push(tokio::spawn(async move {
                state.select_exact_min_and_increment(&ids).await
            }));
        }

        for handle in handles {
            handle.await.unwrap();
        }

        assert_eq!(state.load(100), 30);
        assert_eq!(state.load(200), 30);
        assert_eq!(state.load(300), 30);
    }

    /// With all loads tied, repeated selections on fresh states should not always pick the
    /// same id. Always getting the same id has probability 3 * 3^-120.
    #[tokio::test]
    async fn test_select_exact_min_and_increment_randomizes_ties() {
        let mut selected = [false; 3];

        for _ in 0..120 {
            let state = RoutingOccupancyState::default();
            let picked = state
                .select_exact_min_and_increment(&[10, 20, 30])
                .await
                .unwrap();
            match picked {
                10 => selected[0] = true,
                20 => selected[1] = true,
                30 => selected[2] = true,
                _ => panic!("unexpected worker id: {picked}"),
            }
        }

        let selected_count = selected.into_iter().filter(|seen| *seen).count();
        assert!(
            selected_count > 1,
            "tie-breaking should not always select the first minimum-load worker"
        );
    }

    /// Two lookups for the same endpoint return the same shared occupancy state.
    #[tokio::test]
    async fn test_connection_counts() {
        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt.namespace("test_ll_counts".to_string()).unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());

        let state1 = get_or_create_routing_occupancy_state(&endpoint).await;
        let state2 = get_or_create_routing_occupancy_state(&endpoint).await;

        let picked1 = state1
            .select_exact_min_and_increment(&[10, 20, 30])
            .await
            .unwrap();
        assert_eq!(state1.load(picked1), 1);

        let picked2 = state1
            .select_exact_min_and_increment(&[10, 20, 30])
            .await
            .unwrap();
        assert_ne!(picked1, picked2);

        // Reads through `state2` see increments made through `state1`.
        assert_eq!(state2.load(10), state1.load(10));
        assert_eq!(state2.load(20), state1.load(20));
        assert_eq!(state2.load(30), state1.load(30));

        // Because of the `assert_ne!` above, the condition is always false, so 0 is expected.
        state2.decrement(picked1);
        assert_eq!(state1.load(picked1), if picked1 == picked2 { 1 } else { 0 });

        rt.shutdown();
    }

    /// Three selections over three idle ids give each id a load of one. `retain` then
    /// removes the dropped id's counter.
    #[tokio::test]
    async fn test_least_loaded_state_retain() {
        let state = RoutingOccupancyState::default();

        state.select_exact_min_and_increment(&[1, 2, 3]).await;
        state.select_exact_min_and_increment(&[1, 2, 3]).await;
        state.select_exact_min_and_increment(&[1, 2, 3]).await;
        assert_eq!(state.load(1), 1);
        assert_eq!(state.load(2), 1);
        assert_eq!(state.load(3), 1);

        state.retain(&[1, 3]);

        assert_eq!(state.load(1), 1);
        assert_eq!(state.load(2), 0);
        assert_eq!(state.load(3), 1);
    }

    /// The background monitor prunes occupancy counts for workers that leave discovery.
    #[tokio::test]
    async fn test_monitor_instance_source_cleans_up_removed_worker_counts() {
        const TEST_RECONCILE_INTERVAL: Duration = Duration::from_millis(50);

        let rt = Runtime::from_current().unwrap();
        let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
            .await
            .unwrap();
        let ns = drt.namespace("test_occupancy_cleanup".to_string()).unwrap();
        let component = ns.component("test_component".to_string()).unwrap();
        let endpoint = component.endpoint("test_endpoint".to_string());

        let client = Client::with_reconcile_interval(endpoint.clone(), TEST_RECONCILE_INTERVAL)
            .await
            .unwrap();
        endpoint.register_endpoint_instance().await.unwrap();
        client.wait_for_instances().await.unwrap();

        let worker_id = client.instance_ids_avail()[0];
        // Holding `state` keeps the registry's weak entry upgradable, so the monitor can prune it.
        let state = get_or_create_routing_occupancy_state(&endpoint).await;
        state.increment(worker_id);
        assert_eq!(state.load(worker_id), 1);

        endpoint.unregister_endpoint_instance().await.unwrap();

        for _ in 0..10 {
            if state.load(worker_id) == 0 {
                break;
            }
            tokio::time::sleep(TEST_RECONCILE_INTERVAL).await;
        }

        assert_eq!(state.load(worker_id), 0);

        rt.shutdown();
    }
}