use std::sync::atomic::{AtomicU64, Ordering};
use std::{
collections::{HashMap, HashSet},
sync::{Arc, Mutex as StdMutex},
time::Duration,
};
use anyhow::Result;
use arc_swap::ArcSwap;
use dashmap::DashMap;
use futures::StreamExt;
use rand::Rng;
use crate::component::{Endpoint, Instance};
use crate::discovery::{DiscoveryEvent, DiscoveryInstance, DiscoveryInstanceId};
use crate::traits::DistributedRuntimeProvider;
#[derive(Debug, Default)]
pub(crate) struct RoutingOccupancyState {
counts: DashMap<u64, AtomicU64>,
exact_selection_lock: tokio::sync::Mutex<()>,
}
impl RoutingOccupancyState {
pub(crate) fn increment(&self, instance_id: u64) {
self.counts
.entry(instance_id)
.or_insert_with(|| AtomicU64::new(0))
.fetch_add(1, Ordering::Relaxed);
}
pub(crate) async fn select_exact_min_and_increment(&self, instance_ids: &[u64]) -> Option<u64> {
let _guard = self.exact_selection_lock.lock().await;
let mut min_load = u64::MAX;
let mut selected = None;
let mut tie_count = 0usize;
let mut rng = rand::rng();
for &id in instance_ids {
let load = self.load(id);
if load < min_load {
min_load = load;
selected = Some(id);
tie_count = 1;
continue;
}
if load == min_load {
tie_count += 1;
if rng.random_range(0..tie_count) == 0 {
selected = Some(id);
}
}
}
let id = selected?;
self.increment(id);
Some(id)
}
pub(crate) fn decrement(&self, instance_id: u64) {
if let Some(count) = self.counts.get(&instance_id) {
let _ = count.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
Some(current.saturating_sub(1))
});
}
}
pub(crate) fn load(&self, instance_id: u64) -> u64 {
self.counts
.get(&instance_id)
.map(|c| c.load(Ordering::Relaxed))
.unwrap_or(0)
}
pub(crate) fn retain(&self, instance_ids: &[u64]) {
let live: HashSet<u64> = instance_ids.iter().copied().collect();
self.counts.retain(|id, _| live.contains(id));
}
}
pub(crate) async fn get_or_create_routing_occupancy_state(
endpoint: &Endpoint,
) -> Arc<RoutingOccupancyState> {
let drt = endpoint.drt();
let registry = drt.routing_occupancy_states();
let mut registry = registry.lock().await;
if let Some(weak) = registry.get(endpoint) {
if let Some(state) = weak.upgrade() {
return state;
} else {
registry.remove(endpoint);
}
}
let state = Arc::new(RoutingOccupancyState::default());
registry.insert(endpoint.clone(), Arc::downgrade(&state));
state
}
const DEFAULT_RECONCILE_INTERVAL: Duration = Duration::from_secs(5);
#[derive(Debug)]
pub(crate) struct EndpointDiscoverySource {
instance_source: tokio::sync::watch::Receiver<Vec<Instance>>,
event_subscribers: StdMutex<Vec<tokio::sync::mpsc::UnboundedSender<DiscoveryEvent>>>,
}
impl EndpointDiscoverySource {
fn new(instance_source: tokio::sync::watch::Receiver<Vec<Instance>>) -> Self {
Self {
instance_source,
event_subscribers: StdMutex::new(Vec::new()),
}
}
fn instance_receiver(&self) -> tokio::sync::watch::Receiver<Vec<Instance>> {
self.instance_source.clone()
}
fn subscribe_events(&self) -> tokio::sync::mpsc::UnboundedReceiver<DiscoveryEvent> {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
self.event_subscribers.lock().unwrap().push(tx);
rx
}
fn broadcast_event(&self, event: &DiscoveryEvent) {
let subscribers = &mut *self.event_subscribers.lock().unwrap();
subscribers.retain(|tx| tx.send(event.clone()).is_ok());
}
}
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct RoutingInstanceCounts {
pub discovered: usize,
pub routable: usize,
pub overloaded: usize,
pub free: usize,
}
#[derive(Clone, Debug)]
pub(crate) struct RoutingInstances {
discovered_ids: Vec<u64>,
routable_ids: Vec<u64>,
overloaded_ids: HashSet<u64>,
free_ids: Vec<u64>,
}
impl RoutingInstances {
fn new(discovered_ids: Vec<u64>) -> Self {
Self::from_parts(discovered_ids.clone(), discovered_ids, HashSet::new())
}
fn from_parts(
discovered_ids: Vec<u64>,
routable_ids: Vec<u64>,
overloaded_ids: HashSet<u64>,
) -> Self {
let free_ids = Self::derive_free_ids(&routable_ids, &overloaded_ids);
Self {
discovered_ids,
routable_ids,
overloaded_ids,
free_ids,
}
}
pub(crate) fn discovered_ids(&self) -> &[u64] {
&self.discovered_ids
}
pub(crate) fn routable_ids(&self) -> &[u64] {
&self.routable_ids
}
pub(crate) fn free_ids(&self) -> &[u64] {
&self.free_ids
}
pub(crate) fn counts(&self) -> RoutingInstanceCounts {
RoutingInstanceCounts {
discovered: self.discovered_ids.len(),
routable: self.routable_ids.len(),
overloaded: self.overloaded_ids.len(),
free: self.free_ids.len(),
}
}
pub(crate) fn is_overloaded(&self, instance_id: u64) -> bool {
self.overloaded_ids.contains(&instance_id)
}
fn overloaded_ids(&self) -> Option<HashSet<u64>> {
if self.overloaded_ids.is_empty() {
return None;
}
Some(self.overloaded_ids.clone())
}
fn reconcile_discovered(&self, discovered_ids: Vec<u64>) -> Self {
let old_discovered_ids = self.discovered_ids.iter().copied().collect::<HashSet<_>>();
let new_discovered_ids = discovered_ids.iter().copied().collect::<HashSet<_>>();
let mut overloaded_ids = self.overloaded_ids.clone();
overloaded_ids
.retain(|id| !old_discovered_ids.contains(id) || new_discovered_ids.contains(id));
Self::from_parts(discovered_ids.clone(), discovered_ids, overloaded_ids)
}
fn report_instance_down(&self, instance_id: u64) -> Self {
let routable_ids: Vec<u64> = self
.routable_ids
.iter()
.copied()
.filter(|id| *id != instance_id)
.collect();
Self::from_parts(
self.discovered_ids.clone(),
routable_ids,
self.overloaded_ids.clone(),
)
}
#[cfg(test)]
fn override_routable_ids(&self, routable_ids: Vec<u64>) -> Self {
Self::from_parts(
self.discovered_ids.clone(),
routable_ids,
self.overloaded_ids.clone(),
)
}
fn set_overloaded(&self, overloaded_ids: HashSet<u64>) -> Self {
Self::from_parts(
self.discovered_ids.clone(),
self.routable_ids.clone(),
overloaded_ids,
)
}
fn clear_overloaded_for_removed(&self, removed_ids: &HashSet<u64>) -> Self {
let mut overloaded_ids = self.overloaded_ids.clone();
overloaded_ids.retain(|id| !removed_ids.contains(id));
Self::from_parts(
self.discovered_ids.clone(),
self.routable_ids.clone(),
overloaded_ids,
)
}
fn derive_free_ids(routable_ids: &[u64], overloaded_ids: &HashSet<u64>) -> Vec<u64> {
if overloaded_ids.is_empty() {
return routable_ids.to_vec();
}
routable_ids
.iter()
.copied()
.filter(|id| !overloaded_ids.contains(id))
.collect()
}
}
#[derive(Debug)]
struct RoutingInstancesState {
snapshot: ArcSwap<RoutingInstances>,
update_lock: StdMutex<()>,
instance_avail_tx: tokio::sync::watch::Sender<Vec<u64>>,
instance_avail_rx: tokio::sync::watch::Receiver<Vec<u64>>,
}
impl RoutingInstancesState {
fn new(discovered_ids: Vec<u64>) -> Self {
let snapshot = RoutingInstances::new(discovered_ids);
let (instance_avail_tx, instance_avail_rx) =
tokio::sync::watch::channel(snapshot.routable_ids().to_vec());
Self {
snapshot: ArcSwap::from_pointee(snapshot),
update_lock: StdMutex::new(()),
instance_avail_tx,
instance_avail_rx,
}
}
fn snapshot(&self) -> arc_swap::Guard<Arc<RoutingInstances>> {
self.snapshot.load()
}
fn update(
&self,
update: impl FnOnce(&RoutingInstances) -> RoutingInstances,
publish_routable_ids: bool,
) -> Arc<RoutingInstances> {
let _guard = self.update_lock.lock().unwrap();
let current = self.snapshot.load();
let next = Arc::new(update(¤t));
self.snapshot.store(next.clone());
if publish_routable_ids {
self.publish_routable_ids(&next);
}
next
}
fn publish_routable_ids(&self, routing_instances: &RoutingInstances) {
let _ = self
.instance_avail_tx
.send(routing_instances.routable_ids().to_vec());
}
fn routable_ids(&self) -> Vec<u64> {
self.snapshot().routable_ids().to_vec()
}
#[cfg(test)]
fn free_ids(&self) -> Vec<u64> {
self.snapshot().free_ids.clone()
}
fn counts(&self) -> RoutingInstanceCounts {
self.snapshot().counts()
}
fn overloaded_ids(&self) -> Option<HashSet<u64>> {
self.snapshot().overloaded_ids()
}
fn instance_avail_watcher(&self) -> tokio::sync::watch::Receiver<Vec<u64>> {
self.instance_avail_rx.clone()
}
fn report_instance_down(&self, instance_id: u64) {
self.update(|current| current.report_instance_down(instance_id), true);
}
fn set_overloaded_instances(&self, overloaded_instance_ids: &[u64]) -> bool {
let overloaded_ids = overloaded_instance_ids
.iter()
.copied()
.collect::<HashSet<_>>();
let _guard = self.update_lock.lock().unwrap();
let current = self.snapshot.load();
if current.overloaded_ids == overloaded_ids {
return false;
}
let next = Arc::new(current.set_overloaded(overloaded_ids));
self.snapshot.store(next);
true
}
fn clear_overloaded_for_removed(&self, removed_instance_ids: &[u64]) {
if removed_instance_ids.is_empty() {
return;
}
let removed_ids = removed_instance_ids.iter().copied().collect::<HashSet<_>>();
self.update(
move |current| current.clear_overloaded_for_removed(&removed_ids),
false,
);
}
fn reconcile_discovered(&self, discovered_ids: Vec<u64>) -> Arc<RoutingInstances> {
self.update(
move |current| current.reconcile_discovered(discovered_ids),
true,
)
}
#[cfg(test)]
fn override_routable_ids(&self, ids: Vec<u64>) {
self.update(move |current| current.override_routable_ids(ids), true);
}
}
#[derive(Clone, Debug)]
pub struct Client {
pub endpoint: Endpoint,
endpoint_discovery_source: Arc<EndpointDiscoverySource>,
pub instance_source: Arc<tokio::sync::watch::Receiver<Vec<Instance>>>,
routing_instances: Arc<RoutingInstancesState>,
reconcile_interval: Duration,
}
impl Client {
pub(crate) async fn new(endpoint: Endpoint) -> Result<Self> {
Self::with_reconcile_interval(endpoint, DEFAULT_RECONCILE_INTERVAL).await
}
pub(crate) async fn with_reconcile_interval(
endpoint: Endpoint,
reconcile_interval: Duration,
) -> Result<Self> {
tracing::trace!(
"Client::new_dynamic: Creating dynamic client for endpoint: {}",
endpoint.id()
);
let endpoint_discovery_source =
Self::get_or_create_dynamic_discovery_source(&endpoint).await?;
let instance_source = Arc::new(endpoint_discovery_source.instance_receiver());
let initial_ids: Vec<u64> = instance_source
.borrow()
.iter()
.map(|instance| instance.id())
.collect();
let client = Client {
endpoint: endpoint.clone(),
endpoint_discovery_source,
instance_source: instance_source.clone(),
routing_instances: Arc::new(RoutingInstancesState::new(initial_ids)),
reconcile_interval,
};
client.monitor_instance_source();
Ok(client)
}
pub fn instances(&self) -> Vec<Instance> {
self.instance_source.borrow().clone()
}
pub fn instance_ids(&self) -> Vec<u64> {
self.instances().into_iter().map(|ep| ep.id()).collect()
}
pub fn instance_ids_avail(&self) -> Vec<u64> {
self.routing_instances.routable_ids()
}
#[cfg(test)]
pub(crate) fn instance_ids_free(&self) -> Vec<u64> {
self.routing_instances.free_ids()
}
pub(crate) fn routing_instances(&self) -> arc_swap::Guard<Arc<RoutingInstances>> {
self.routing_instances.snapshot()
}
pub fn routing_instance_counts(&self) -> RoutingInstanceCounts {
self.routing_instances.counts()
}
pub fn instance_avail_watcher(&self) -> tokio::sync::watch::Receiver<Vec<u64>> {
self.routing_instances.instance_avail_watcher()
}
pub(crate) fn subscribe_discovery_events(
&self,
) -> tokio::sync::mpsc::UnboundedReceiver<DiscoveryEvent> {
self.endpoint_discovery_source.subscribe_events()
}
pub async fn wait_for_instances(&self) -> Result<Vec<Instance>> {
tracing::trace!(
"wait_for_instances: Starting wait for endpoint: {}",
self.endpoint.id()
);
let mut rx = self.instance_source.as_ref().clone();
let mut instances: Vec<Instance>;
loop {
instances = rx.borrow_and_update().to_vec();
if instances.is_empty() {
rx.changed().await?;
} else {
tracing::info!(
"wait_for_instances: Found {} instance(s) for endpoint: {}",
instances.len(),
self.endpoint.id()
);
break;
}
}
Ok(instances)
}
pub fn report_instance_down(&self, instance_id: u64) {
self.routing_instances.report_instance_down(instance_id);
tracing::debug!("inhibiting instance {instance_id}");
}
pub fn set_overloaded_instances(&self, overloaded_instance_ids: &[u64]) -> bool {
self.routing_instances
.set_overloaded_instances(overloaded_instance_ids)
}
pub fn clear_overloaded_instances_for_removed(&self, removed_instance_ids: &[u64]) {
self.routing_instances
.clear_overloaded_for_removed(removed_instance_ids);
}
pub fn overloaded_instance_ids(&self) -> Option<HashSet<u64>> {
self.routing_instances.overloaded_ids()
}
fn monitor_instance_source(&self) {
let reconcile_interval = self.reconcile_interval;
let cancel_token = self.endpoint.drt().primary_token();
let client = self.clone();
let endpoint_id = self.endpoint.id();
tokio::task::spawn(async move {
let mut rx = client.instance_source.as_ref().clone();
while !cancel_token.is_cancelled() {
let instance_ids: Vec<u64> = rx
.borrow_and_update()
.iter()
.map(|instance| instance.id())
.collect();
let routing_instances = client.reconcile_discovered_instances(instance_ids);
let registry = client.endpoint.drt().routing_occupancy_states();
if let Ok(registry) = registry.try_lock()
&& let Some(weak) = registry.get(&client.endpoint)
&& let Some(state) = weak.upgrade()
{
state.retain(routing_instances.discovered_ids());
}
tokio::select! {
result = rx.changed() => {
if let Err(err) = result {
tracing::error!(
"monitor_instance_source: The Sender is dropped: {err}, endpoint={endpoint_id}",
);
cancel_token.cancel();
}
}
_ = tokio::time::sleep(reconcile_interval) => {
tracing::trace!(
"monitor_instance_source: periodic reconciliation for endpoint={endpoint_id}",
);
}
}
}
});
}
#[cfg(test)]
pub(crate) fn override_instance_avail(&self, ids: Vec<u64>) {
self.routing_instances.override_routable_ids(ids);
}
fn reconcile_discovered_instances(&self, discovered_ids: Vec<u64>) -> Arc<RoutingInstances> {
self.routing_instances.reconcile_discovered(discovered_ids)
}
async fn get_or_create_dynamic_discovery_source(
endpoint: &Endpoint,
) -> Result<Arc<EndpointDiscoverySource>> {
let drt = endpoint.drt();
let sources = drt.endpoint_discovery_sources();
let mut sources = sources.lock().await;
if let Some(source) = sources.get(endpoint) {
if let Some(source) = source.upgrade() {
return Ok(source);
} else {
sources.remove(endpoint);
}
}
let discovery = drt.discovery();
let discovery_query = crate::discovery::DiscoveryQuery::Endpoint {
namespace: endpoint.component.namespace.name.clone(),
component: endpoint.component.name.clone(),
endpoint: endpoint.name.clone(),
};
let mut discovery_stream = discovery
.list_and_watch(discovery_query.clone(), None)
.await?;
let (watch_tx, watch_rx) = tokio::sync::watch::channel(vec![]);
let discovery_source = Arc::new(EndpointDiscoverySource::new(watch_rx));
let secondary = endpoint.component.drt.runtime().secondary().clone();
let discovery_source_task = discovery_source.clone();
secondary.spawn(async move {
tracing::trace!("endpoint_watcher: Starting for discovery query: {:?}", discovery_query);
let mut map: HashMap<u64, Instance> = HashMap::new();
loop {
let discovery_event = tokio::select! {
_ = watch_tx.closed() => {
break;
}
discovery_event = discovery_stream.next() => {
match discovery_event {
Some(Ok(event)) => {
event
},
Some(Err(e)) => {
tracing::error!("endpoint_watcher: discovery stream error: {}; shutting down for discovery query: {:?}", e, discovery_query);
break;
}
None => {
break;
}
}
}
};
discovery_source_task.broadcast_event(&discovery_event);
match discovery_event {
DiscoveryEvent::Added(DiscoveryInstance::Endpoint(instance)) => {
map.insert(instance.instance_id, instance);
}
DiscoveryEvent::Added(_) => {}
DiscoveryEvent::Removed(id) => {
if let DiscoveryInstanceId::Endpoint(endpoint_id) = id {
map.remove(&endpoint_id.instance_id);
}
}
}
let instances: Vec<Instance> = map.values().cloned().collect();
if watch_tx.send(instances).is_err() {
break;
}
}
let _ = watch_tx.send(vec![]);
});
sources.insert(endpoint.clone(), Arc::downgrade(&discovery_source));
Ok(discovery_source)
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{DistributedRuntime, Runtime, distributed::DistributedConfig};
#[tokio::test]
async fn test_instance_reconciliation() {
const TEST_RECONCILE_INTERVAL: Duration = Duration::from_millis(100);
let rt = Runtime::from_current().unwrap();
let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
.await
.unwrap();
let ns = drt.namespace("test_reconciliation".to_string()).unwrap();
let component = ns.component("test_component".to_string()).unwrap();
let endpoint = component.endpoint("test_endpoint".to_string());
let client = Client::with_reconcile_interval(endpoint, TEST_RECONCILE_INTERVAL)
.await
.unwrap();
assert!(client.instance_ids_avail().is_empty());
client.override_instance_avail(vec![1, 2, 3]);
assert_eq!(client.instance_ids_avail(), vec![1u64, 2, 3]);
client.report_instance_down(2);
assert_eq!(client.instance_ids_avail(), vec![1u64, 3]);
tokio::time::sleep(TEST_RECONCILE_INTERVAL + Duration::from_millis(50)).await;
assert!(
client.instance_ids_avail().is_empty(),
"After reconciliation, instance_avail should match instance_source"
);
rt.shutdown();
}
#[tokio::test]
async fn test_report_instance_down() {
let rt = Runtime::from_current().unwrap();
let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
.await
.unwrap();
let ns = drt.namespace("test_report_down".to_string()).unwrap();
let component = ns.component("test_component".to_string()).unwrap();
let endpoint = component.endpoint("test_endpoint".to_string());
let client = endpoint.client().await.unwrap();
client.override_instance_avail(vec![1, 2, 3]);
assert_eq!(client.instance_ids_avail(), vec![1u64, 2, 3]);
client.report_instance_down(2);
let avail = client.instance_ids_avail();
assert!(avail.contains(&1), "Instance 1 should still be available");
assert!(
!avail.contains(&2),
"Instance 2 should be removed after report_instance_down"
);
assert!(avail.contains(&3), "Instance 3 should still be available");
rt.shutdown();
}
#[tokio::test]
async fn test_overloaded_instance_ids_returns_none_when_empty() {
let rt = Runtime::from_current().unwrap();
let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
.await
.unwrap();
let ns = drt.namespace("test_overloaded_ids".to_string()).unwrap();
let component = ns.component("test_component".to_string()).unwrap();
let endpoint = component.endpoint("test_endpoint".to_string());
let client = endpoint.client().await.unwrap();
assert_eq!(client.overloaded_instance_ids(), None);
assert!(client.set_overloaded_instances(&[7]));
assert_eq!(client.overloaded_instance_ids(), Some(HashSet::from([7])));
assert!(!client.set_overloaded_instances(&[7]));
assert!(client.set_overloaded_instances(&[]));
assert_eq!(client.overloaded_instance_ids(), None);
assert!(!client.set_overloaded_instances(&[]));
rt.shutdown();
}
#[tokio::test]
async fn test_instance_reconciliation_preserves_overloaded_existing_instances() {
const TEST_RECONCILE_INTERVAL: Duration = Duration::from_millis(50);
let rt = Runtime::from_current().unwrap();
let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
.await
.unwrap();
let ns = drt
.namespace("test_overloaded_reconciliation".to_string())
.unwrap();
let component = ns.component("test_component".to_string()).unwrap();
let endpoint = component.endpoint("test_endpoint".to_string());
let client = Client::with_reconcile_interval(endpoint.clone(), TEST_RECONCILE_INTERVAL)
.await
.unwrap();
endpoint.register_endpoint_instance().await.unwrap();
let instances = client.wait_for_instances().await.unwrap();
let worker_id = instances[0].id();
for _ in 0..10 {
if client.instance_ids_free().contains(&worker_id) {
break;
}
tokio::time::sleep(TEST_RECONCILE_INTERVAL).await;
}
assert!(
client.instance_ids_free().contains(&worker_id),
"worker should be free after initial discovery reconciliation"
);
client.set_overloaded_instances(&[worker_id]);
assert!(
client.instance_ids_free().is_empty(),
"worker should be overloaded before periodic reconciliation"
);
tokio::time::sleep(TEST_RECONCILE_INTERVAL + Duration::from_millis(50)).await;
assert!(
client.instance_ids_free().is_empty(),
"periodic reconciliation should not mark an existing overloaded worker free"
);
rt.shutdown();
}
#[tokio::test]
async fn test_report_instance_down_preserves_overloaded_state() {
const TEST_RECONCILE_INTERVAL: Duration = Duration::from_millis(50);
let rt = Runtime::from_current().unwrap();
let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
.await
.unwrap();
let ns = drt
.namespace("test_report_down_preserves_overloaded".to_string())
.unwrap();
let component = ns.component("test_component".to_string()).unwrap();
let endpoint = component.endpoint("test_endpoint".to_string());
let client = Client::with_reconcile_interval(endpoint.clone(), TEST_RECONCILE_INTERVAL)
.await
.unwrap();
endpoint.register_endpoint_instance().await.unwrap();
let instances = client.wait_for_instances().await.unwrap();
let worker_id = instances[0].id();
for _ in 0..10 {
if client.instance_ids_avail().contains(&worker_id) {
break;
}
tokio::time::sleep(TEST_RECONCILE_INTERVAL).await;
}
client.set_overloaded_instances(&[worker_id]);
client.report_instance_down(worker_id);
assert!(
!client.instance_ids_avail().contains(&worker_id),
"reported-down worker should leave routable availability"
);
assert_eq!(
client.routing_instance_counts().overloaded,
1,
"reported-down worker should remain overloaded while still discovered"
);
assert!(
client.instance_ids_free().is_empty(),
"reported-down overloaded worker should not become free"
);
endpoint.unregister_endpoint_instance().await.unwrap();
for _ in 0..10 {
if client.routing_instance_counts().overloaded == 0 {
break;
}
tokio::time::sleep(TEST_RECONCILE_INTERVAL).await;
}
assert_eq!(
client.routing_instance_counts().overloaded,
0,
"stable discovery removal should clear overloaded state"
);
rt.shutdown();
}
#[tokio::test]
async fn test_instance_reconciliation_prunes_removed_overloaded_instances() {
const TEST_RECONCILE_INTERVAL: Duration = Duration::from_millis(50);
let rt = Runtime::from_current().unwrap();
let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
.await
.unwrap();
let ns = drt
.namespace("test_removed_overloaded_cleanup".to_string())
.unwrap();
let component = ns.component("test_component".to_string()).unwrap();
let endpoint = component.endpoint("test_endpoint".to_string());
let client = Client::with_reconcile_interval(endpoint.clone(), TEST_RECONCILE_INTERVAL)
.await
.unwrap();
endpoint.register_endpoint_instance().await.unwrap();
let instances = client.wait_for_instances().await.unwrap();
let worker_id = instances[0].id();
client.set_overloaded_instances(&[worker_id]);
assert_eq!(client.routing_instance_counts().overloaded, 1);
assert!(client.instance_ids_free().is_empty());
endpoint.unregister_endpoint_instance().await.unwrap();
for _ in 0..10 {
if client.routing_instance_counts().overloaded == 0 {
break;
}
tokio::time::sleep(TEST_RECONCILE_INTERVAL).await;
}
assert_eq!(
client.routing_instance_counts().overloaded,
0,
"removed discovered workers should not remain in overloaded state"
);
rt.shutdown();
}
#[tokio::test]
async fn test_instance_ids_free_excludes_overloaded_new_instances() {
const TEST_RECONCILE_INTERVAL: Duration = Duration::from_millis(50);
let rt = Runtime::from_current().unwrap();
let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
.await
.unwrap();
let worker_id = drt.connection_id();
let ns = drt
.namespace("test_new_overloaded_reconciliation".to_string())
.unwrap();
let component = ns.component("test_component".to_string()).unwrap();
let endpoint = component.endpoint("test_endpoint".to_string());
let client = Client::with_reconcile_interval(endpoint.clone(), TEST_RECONCILE_INTERVAL)
.await
.unwrap();
client.set_overloaded_instances(&[worker_id]);
endpoint.register_endpoint_instance().await.unwrap();
let instances = client.wait_for_instances().await.unwrap();
assert_eq!(instances[0].id(), worker_id);
assert!(
client.instance_ids_free().is_empty(),
"newly discovered overloaded worker should not be free"
);
tokio::time::sleep(TEST_RECONCILE_INTERVAL + Duration::from_millis(50)).await;
assert!(
client.instance_ids_free().is_empty(),
"discovery reconciliation should not affect recomputed free workers"
);
rt.shutdown();
}
#[tokio::test]
async fn test_discovery_add_updates_free_without_overloaded_publish() {
const TEST_RECONCILE_INTERVAL: Duration = Duration::from_millis(50);
let rt = Runtime::from_current().unwrap();
let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
.await
.unwrap();
let ns = drt
.namespace("test_free_updates_on_discovery_add".to_string())
.unwrap();
let component = ns.component("test_component".to_string()).unwrap();
let endpoint = component.endpoint("test_endpoint".to_string());
let client = Client::with_reconcile_interval(endpoint.clone(), TEST_RECONCILE_INTERVAL)
.await
.unwrap();
endpoint.register_endpoint_instance().await.unwrap();
let instances = client.wait_for_instances().await.unwrap();
let worker_id = instances[0].id();
for _ in 0..10 {
if client.instance_ids_free().contains(&worker_id) {
break;
}
tokio::time::sleep(TEST_RECONCILE_INTERVAL).await;
}
assert_eq!(
client.instance_ids_free(),
vec![worker_id],
"newly discovered non-overloaded workers should appear free without an overload update"
);
rt.shutdown();
}
#[tokio::test]
async fn test_instance_avail_watcher() {
let rt = Runtime::from_current().unwrap();
let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
.await
.unwrap();
let ns = drt.namespace("test_watcher".to_string()).unwrap();
let component = ns.component("test_component".to_string()).unwrap();
let endpoint = component.endpoint("test_endpoint".to_string());
let client = endpoint.client().await.unwrap();
let watcher = client.instance_avail_watcher();
client.override_instance_avail(vec![1, 2, 3]);
client.report_instance_down(2);
let current = watcher.borrow().clone();
assert_eq!(current, vec![1, 3]);
rt.shutdown();
}
#[tokio::test]
async fn test_concurrent_select_and_increment() {
let state = Arc::new(RoutingOccupancyState::default());
let instance_ids: Vec<u64> = vec![100, 200, 300];
let num_requests = 90;
let mut handles = Vec::new();
for _ in 0..num_requests {
let state = state.clone();
let ids = instance_ids.clone();
handles.push(tokio::spawn(async move {
state.select_exact_min_and_increment(&ids).await
}));
}
for handle in handles {
handle.await.unwrap();
}
assert_eq!(state.load(100), 30);
assert_eq!(state.load(200), 30);
assert_eq!(state.load(300), 30);
}
#[tokio::test]
async fn test_select_exact_min_and_increment_randomizes_ties() {
let mut selected = [false; 3];
for _ in 0..120 {
let state = RoutingOccupancyState::default();
let picked = state
.select_exact_min_and_increment(&[10, 20, 30])
.await
.unwrap();
match picked {
10 => selected[0] = true,
20 => selected[1] = true,
30 => selected[2] = true,
_ => panic!("unexpected worker id: {picked}"),
}
}
let selected_count = selected.into_iter().filter(|seen| *seen).count();
assert!(
selected_count > 1,
"tie-breaking should not always select the first minimum-load worker"
);
}
#[tokio::test]
async fn test_connection_counts() {
let rt = Runtime::from_current().unwrap();
let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
.await
.unwrap();
let ns = drt.namespace("test_ll_counts".to_string()).unwrap();
let component = ns.component("test_component".to_string()).unwrap();
let endpoint = component.endpoint("test_endpoint".to_string());
let state1 = get_or_create_routing_occupancy_state(&endpoint).await;
let state2 = get_or_create_routing_occupancy_state(&endpoint).await;
let picked1 = state1
.select_exact_min_and_increment(&[10, 20, 30])
.await
.unwrap();
assert_eq!(state1.load(picked1), 1);
let picked2 = state1
.select_exact_min_and_increment(&[10, 20, 30])
.await
.unwrap();
assert_ne!(picked1, picked2);
assert_eq!(state2.load(10), state1.load(10));
assert_eq!(state2.load(20), state1.load(20));
assert_eq!(state2.load(30), state1.load(30));
state2.decrement(picked1);
assert_eq!(state1.load(picked1), if picked1 == picked2 { 1 } else { 0 });
rt.shutdown();
}
#[tokio::test]
async fn test_least_loaded_state_retain() {
let state = RoutingOccupancyState::default();
state.select_exact_min_and_increment(&[1, 2, 3]).await;
state.select_exact_min_and_increment(&[1, 2, 3]).await;
state.select_exact_min_and_increment(&[1, 2, 3]).await;
assert_eq!(state.load(1), 1);
assert_eq!(state.load(2), 1);
assert_eq!(state.load(3), 1);
state.retain(&[1, 3]);
assert_eq!(state.load(1), 1);
assert_eq!(state.load(2), 0);
assert_eq!(state.load(3), 1);
}
#[tokio::test]
async fn test_monitor_instance_source_cleans_up_removed_worker_counts() {
const TEST_RECONCILE_INTERVAL: Duration = Duration::from_millis(50);
let rt = Runtime::from_current().unwrap();
let drt = DistributedRuntime::new(rt.clone(), DistributedConfig::process_local())
.await
.unwrap();
let ns = drt.namespace("test_occupancy_cleanup".to_string()).unwrap();
let component = ns.component("test_component".to_string()).unwrap();
let endpoint = component.endpoint("test_endpoint".to_string());
let client = Client::with_reconcile_interval(endpoint.clone(), TEST_RECONCILE_INTERVAL)
.await
.unwrap();
endpoint.register_endpoint_instance().await.unwrap();
client.wait_for_instances().await.unwrap();
let worker_id = client.instance_ids_avail()[0];
let state = get_or_create_routing_occupancy_state(&endpoint).await;
state.increment(worker_id);
assert_eq!(state.load(worker_id), 1);
endpoint.unregister_endpoint_instance().await.unwrap();
for _ in 0..10 {
if state.load(worker_id) == 0 {
break;
}
tokio::time::sleep(TEST_RECONCILE_INTERVAL).await;
}
assert_eq!(state.load(worker_id), 0);
rt.shutdown();
}
}