use crate::cx::Cx;
use crate::error::{Error, ErrorKind};
use crate::security::authenticated::AuthenticatedSymbol;
use crate::sync::Mutex;
use crate::sync::OwnedMutexGuard;
use crate::transport::sink::{SymbolSink, SymbolSinkExt};
use crate::types::symbol::{ObjectId, Symbol};
use crate::types::{RegionId, Time};
use parking_lot::RwLock;
use smallvec::{SmallVec, smallvec};
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::{AtomicU8, AtomicU32, AtomicU64, Ordering};
type EndpointSinkMap = HashMap<EndpointId, Arc<Mutex<Box<dyn SymbolSink>>>>;
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct EndpointId(pub u64);
impl EndpointId {
#[must_use]
pub const fn new(id: u64) -> Self {
Self(id)
}
}
impl std::fmt::Display for EndpointId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "Endpoint({})", self.0)
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum EndpointState {
Healthy,
Degraded,
Unhealthy,
Draining,
Removed,
}
impl EndpointState {
const fn as_u8(self) -> u8 {
self as u8
}
fn from_u8(value: u8) -> Self {
match value {
x if x == Self::Healthy as u8 => Self::Healthy,
x if x == Self::Degraded as u8 => Self::Degraded,
x if x == Self::Unhealthy as u8 => Self::Unhealthy,
x if x == Self::Draining as u8 => Self::Draining,
_ => Self::Removed,
}
}
#[must_use]
pub const fn can_receive(&self) -> bool {
matches!(self, Self::Healthy | Self::Degraded)
}
#[must_use]
pub const fn is_available(&self) -> bool {
!matches!(self, Self::Removed)
}
}
#[derive(Debug)]
pub struct Endpoint {
pub id: EndpointId,
pub address: String,
state: AtomicU8,
pub weight: u32,
pub region: Option<RegionId>,
pub active_connections: AtomicU32,
pub symbols_sent: AtomicU64,
pub failures: AtomicU64,
pub last_success: AtomicU64,
pub last_failure: AtomicU64,
pub metadata: HashMap<String, String>,
}
impl Endpoint {
pub fn new(id: EndpointId, address: impl Into<String>) -> Self {
Self {
id,
address: address.into(),
state: AtomicU8::new(EndpointState::Healthy.as_u8()),
weight: 100,
region: None,
active_connections: AtomicU32::new(0),
symbols_sent: AtomicU64::new(0),
failures: AtomicU64::new(0),
last_success: AtomicU64::new(0),
last_failure: AtomicU64::new(0),
metadata: HashMap::new(),
}
}
#[must_use]
pub fn with_weight(mut self, weight: u32) -> Self {
self.weight = weight;
self
}
#[must_use]
pub fn with_region(mut self, region: RegionId) -> Self {
self.region = Some(region);
self
}
#[must_use]
pub fn with_state(self, state: EndpointState) -> Self {
self.state.store(state.as_u8(), Ordering::Relaxed);
self
}
#[must_use]
pub fn state(&self) -> EndpointState {
EndpointState::from_u8(self.state.load(Ordering::Relaxed))
}
pub fn set_state(&self, state: EndpointState) {
self.state.store(state.as_u8(), Ordering::Relaxed);
}
pub fn record_success(&self, now: Time) {
self.symbols_sent.fetch_add(1, Ordering::Relaxed);
self.last_success.store(now.as_nanos(), Ordering::Relaxed);
}
pub fn record_failure(&self, now: Time) {
self.failures.fetch_add(1, Ordering::Relaxed);
self.last_failure.store(now.as_nanos(), Ordering::Relaxed);
}
pub fn acquire_connection(&self) {
self.active_connections.fetch_add(1, Ordering::Relaxed);
}
pub fn release_connection(&self) {
let _ =
self.active_connections
.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
Some(current.saturating_sub(1))
});
}
#[must_use]
pub fn connection_count(&self) -> u32 {
self.active_connections.load(Ordering::Relaxed)
}
#[must_use]
#[allow(clippy::cast_precision_loss)]
pub fn failure_rate(&self) -> f64 {
let sent = self.symbols_sent.load(Ordering::Relaxed);
let failures = self.failures.load(Ordering::Relaxed);
let total = sent + failures;
if total == 0 {
0.0
} else {
failures as f64 / total as f64
}
}
pub fn acquire_connection_guard(&self) -> ConnectionGuard<'_> {
self.acquire_connection();
ConnectionGuard { endpoint: self }
}
}
pub struct ConnectionGuard<'a> {
endpoint: &'a Endpoint,
}
impl Drop for ConnectionGuard<'_> {
fn drop(&mut self) {
self.endpoint.release_connection();
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum LoadBalanceStrategy {
#[default]
RoundRobin,
WeightedRoundRobin,
LeastConnections,
WeightedLeastConnections,
Random,
HashBased,
FirstAvailable,
}
#[derive(Debug)]
pub struct LoadBalancer {
strategy: LoadBalanceStrategy,
rr_counter: AtomicU64,
random_seed: AtomicU64,
}
impl LoadBalancer {
const LCG_MULTIPLIER: u64 = 6_364_136_223_846_793_005;
const LCG_INCREMENT: u64 = 1;
const RANDOM_FLOYD_SMALL_N_MAX: usize = 8;
#[inline]
fn next_lcg(seed: u64) -> u64 {
seed.wrapping_mul(Self::LCG_MULTIPLIER)
.wrapping_add(Self::LCG_INCREMENT)
}
#[inline]
fn compare_weighted_load(a: &Endpoint, b: &Endpoint) -> std::cmp::Ordering {
let a_conn = u64::from(a.connection_count());
let b_conn = u64::from(b.connection_count());
let a_weight = u64::from(a.weight.max(1));
let b_weight = u64::from(b.weight.max(1));
(a_conn * b_weight).cmp(&(b_conn * a_weight))
}
#[inline]
fn select_ranked_prefix<'a, F>(
available: Vec<&'a Arc<Endpoint>>,
n: usize,
mut cmp: F,
) -> Vec<&'a Arc<Endpoint>>
where
F: FnMut(&(usize, &'a Arc<Endpoint>), &(usize, &'a Arc<Endpoint>)) -> std::cmp::Ordering,
{
if n == 0 || available.is_empty() {
return Vec::new();
}
if n == 1 {
let mut best_idx = 0;
let mut best_ep = available[0];
for (i, ep) in available.into_iter().enumerate().skip(1) {
if cmp(&(i, ep), &(best_idx, best_ep)) == std::cmp::Ordering::Less {
best_idx = i;
best_ep = ep;
}
}
return vec![best_ep];
}
let mut ranked: Vec<(usize, &Arc<Endpoint>)> = available.into_iter().enumerate().collect();
if n < ranked.len() {
ranked.select_nth_unstable_by(n, |a, b| cmp(a, b));
ranked.truncate(n);
}
ranked.sort_by(|a, b| cmp(a, b));
ranked.into_iter().map(|(_, endpoint)| endpoint).collect()
}
#[inline]
fn weighted_endpoint_span_for_slot(available: &[&Arc<Endpoint>], slot: u64) -> (usize, u64) {
let mut cumulative = 0u64;
for (idx, endpoint) in available.iter().enumerate() {
cumulative += u64::from(endpoint.weight);
if slot < cumulative {
return (idx, cumulative);
}
}
let last_index = available.len().saturating_sub(1);
(last_index, cumulative)
}
fn select_n_weighted_round_robin<'a>(
&self,
available: &[&'a Arc<Endpoint>],
n: usize,
) -> Vec<&'a Arc<Endpoint>> {
let len = available.len();
let total_weight: u64 = available
.iter()
.map(|endpoint| u64::from(endpoint.weight))
.sum();
if total_weight == 0 {
let counter = self.rr_counter.fetch_add(n as u64, Ordering::Relaxed);
let start = counter as usize;
return (0..n).map(|i| available[(start + i) % len]).collect();
}
loop {
let counter = self.rr_counter.load(Ordering::Relaxed);
let mut selected = Vec::with_capacity(n);
let mut selected_indices = SmallVec::<[usize; 16]>::new();
let mut slot = counter % total_weight;
let mut consumed_slots = 0u64;
while consumed_slots < total_weight {
let (idx, block_end) = Self::weighted_endpoint_span_for_slot(available, slot);
let span = block_end - slot;
if !selected_indices.contains(&idx) {
selected_indices.push(idx);
selected.push(available[idx]);
if selected.len() == n {
consumed_slots += 1;
break;
}
}
consumed_slots += span;
slot = if block_end == total_weight {
0
} else {
block_end
};
}
if selected.len() < n {
let fallback_start = counter as usize % len;
for offset in 0..len {
let idx = (fallback_start + offset) % len;
if selected_indices.contains(&idx) {
continue;
}
selected.push(available[idx]);
if selected.len() >= n {
break;
}
}
}
let next_counter = counter.saturating_add(consumed_slots.max(1));
if self
.rr_counter
.compare_exchange_weak(counter, next_counter, Ordering::Relaxed, Ordering::Relaxed)
.is_ok()
{
return selected;
}
}
}
#[must_use]
pub fn new(strategy: LoadBalanceStrategy) -> Self {
Self {
strategy,
rr_counter: AtomicU64::new(0),
random_seed: AtomicU64::new(0),
}
}
#[allow(clippy::too_many_lines)]
pub fn select<'a>(
&self,
endpoints: &'a [Arc<Endpoint>],
object_id: Option<ObjectId>,
) -> Option<&'a Arc<Endpoint>> {
if endpoints.is_empty() {
return None;
}
match self.strategy {
LoadBalanceStrategy::Random => {
self.select_random_single_without_materializing(endpoints)
}
LoadBalanceStrategy::LeastConnections => {
let mut best = None;
let mut best_count = u32::MAX;
for ep in endpoints {
if ep.state().can_receive() {
let count = ep.connection_count();
if count < best_count {
best_count = count;
best = Some(ep);
if count == 0 {
break;
}
}
}
}
best
}
LoadBalanceStrategy::WeightedLeastConnections => {
let mut best = None;
let mut best_score = None;
for ep in endpoints {
if ep.state().can_receive() {
let count = u64::from(ep.connection_count());
let weight = u64::from(ep.weight.max(1));
let is_better = match best_score {
None => true,
Some((best_count_u64, best_weight_u64)) => {
(count * best_weight_u64) < (best_count_u64 * weight)
}
};
if is_better {
best_score = Some((count, weight));
best = Some(ep);
if count == 0 {
break;
}
}
}
}
best
}
LoadBalanceStrategy::RoundRobin => {
let count = endpoints.iter().filter(|e| e.state().can_receive()).count();
if count == 0 {
return None;
}
let target = (self.rr_counter.fetch_add(1, Ordering::Relaxed) as usize) % count;
endpoints
.iter()
.filter(|e| e.state().can_receive())
.nth(target)
.or_else(|| endpoints.iter().find(|e| e.state().can_receive()))
}
LoadBalanceStrategy::WeightedRoundRobin => {
let total_weight: u64 = endpoints
.iter()
.filter(|e| e.state().can_receive())
.map(|e| u64::from(e.weight))
.sum();
if total_weight == 0 {
return endpoints.iter().find(|e| e.state().can_receive());
}
let counter = self.rr_counter.fetch_add(1, Ordering::Relaxed);
let target = counter % total_weight;
let mut cumulative = 0u64;
for endpoint in endpoints {
if endpoint.state().can_receive() {
cumulative += u64::from(endpoint.weight);
if target < cumulative {
return Some(endpoint);
}
}
}
endpoints.iter().rfind(|e| e.state().can_receive())
}
LoadBalanceStrategy::HashBased => {
let count = endpoints.iter().filter(|e| e.state().can_receive()).count();
if count == 0 {
return None;
}
object_id.map_or_else(
|| {
let idx =
(self.rr_counter.fetch_add(1, Ordering::Relaxed) as usize) % count;
endpoints
.iter()
.filter(|e| e.state().can_receive())
.nth(idx)
.or_else(|| endpoints.iter().find(|e| e.state().can_receive()))
},
|oid| {
let hash = oid.as_u128() as usize;
let idx = hash % count;
endpoints
.iter()
.filter(|e| e.state().can_receive())
.nth(idx)
.or_else(|| endpoints.iter().find(|e| e.state().can_receive()))
},
)
}
LoadBalanceStrategy::FirstAvailable => {
endpoints.iter().find(|e| e.state().can_receive())
}
}
}
#[allow(clippy::too_many_lines)]
pub fn select_n<'a>(
&self,
endpoints: &'a [Arc<Endpoint>],
n: usize,
object_id: Option<ObjectId>,
) -> Vec<&'a Arc<Endpoint>> {
if n == 0 {
return Vec::new();
}
if n == 1 {
match self.strategy {
LoadBalanceStrategy::Random => {
return self
.select_random_single_without_materializing(endpoints)
.into_iter()
.collect();
}
LoadBalanceStrategy::LeastConnections => {
let mut best = None;
let mut best_count = u32::MAX;
for ep in endpoints {
if ep.state().can_receive() {
let count = ep.connection_count();
if count < best_count {
best_count = count;
best = Some(ep);
if count == 0 {
break;
}
}
}
}
return best.into_iter().collect();
}
LoadBalanceStrategy::WeightedLeastConnections => {
let mut best = None;
let mut best_score = None;
for ep in endpoints {
if ep.state().can_receive() {
let count = u64::from(ep.connection_count());
let weight = u64::from(ep.weight.max(1));
let is_better = match best_score {
None => true,
Some((best_count_u64, best_weight_u64)) => {
(count * best_weight_u64) < (best_count_u64 * weight)
}
};
if is_better {
best_score = Some((count, weight));
best = Some(ep);
if count == 0 {
break;
}
}
}
}
return best.into_iter().collect();
}
_ => {}
}
}
if matches!(self.strategy, LoadBalanceStrategy::Random)
&& n <= Self::RANDOM_FLOYD_SMALL_N_MAX
{
if let Some(selected) = self.select_n_random_small_without_materializing(endpoints, n) {
return selected;
}
}
if n <= 16 {
match self.strategy {
LoadBalanceStrategy::LeastConnections => {
let mut top_n =
smallvec::SmallVec::<[(usize, u32, &'a Arc<Endpoint>); 16]>::new();
for (idx, ep) in endpoints.iter().enumerate() {
if ep.state().can_receive() {
let count = ep.connection_count();
if top_n.len() == n {
let last = &top_n[n - 1];
if last.1 == 0 {
break;
}
if count > last.1 || (count == last.1 && idx > last.0) {
continue;
}
}
let mut insert_pos = top_n.len();
for i in 0..top_n.len() {
if count < top_n[i].1 || (count == top_n[i].1 && idx < top_n[i].0) {
insert_pos = i;
break;
}
}
if insert_pos < n {
top_n.insert(insert_pos, (idx, count, ep));
if top_n.len() > n {
top_n.pop();
}
}
}
}
return top_n.into_iter().map(|(_, _, ep)| ep).collect();
}
LoadBalanceStrategy::WeightedLeastConnections => {
let mut top_n =
smallvec::SmallVec::<[(usize, u64, u64, &'a Arc<Endpoint>); 16]>::new();
for (idx, ep) in endpoints.iter().enumerate() {
if ep.state().can_receive() {
let count = u64::from(ep.connection_count());
let weight = u64::from(ep.weight.max(1));
if top_n.len() == n {
let last = &top_n[n - 1];
if last.1 == 0 {
break;
}
let (other_idx, other_count, other_weight, _) = *last;
let is_better = (count * other_weight) < (other_count * weight)
|| ((count * other_weight) == (other_count * weight)
&& idx < other_idx);
if !is_better {
continue;
}
}
let mut insert_pos = top_n.len();
for i in 0..top_n.len() {
let (other_idx, other_count, other_weight, _) = top_n[i];
let is_better = (count * other_weight) < (other_count * weight)
|| ((count * other_weight) == (other_count * weight)
&& idx < other_idx);
if is_better {
insert_pos = i;
break;
}
}
if insert_pos < n {
top_n.insert(insert_pos, (idx, count, weight, ep));
if top_n.len() > n {
top_n.pop();
}
}
}
}
return top_n.into_iter().map(|(_, _, _, ep)| ep).collect();
}
_ => {}
}
}
let mut available: Vec<&Arc<Endpoint>> = Vec::with_capacity(endpoints.len());
for endpoint in endpoints {
if endpoint.state().can_receive() {
available.push(endpoint);
}
}
if available.is_empty() {
return Vec::new();
}
let count = n.min(available.len());
match self.strategy {
LoadBalanceStrategy::RoundRobin => {
let start = self.rr_counter.fetch_add(count as u64, Ordering::Relaxed) as usize;
let len = available.len();
(0..count).map(|i| available[(start + i) % len]).collect()
}
LoadBalanceStrategy::Random => {
let mut seed = self.random_seed.fetch_add(count as u64, Ordering::Relaxed);
let len = available.len();
for i in 0..count {
seed = Self::next_lcg(seed);
let range = len - i;
let offset = (seed as usize) % range;
let swap_idx = i + offset;
available.swap(i, swap_idx);
}
available.truncate(count);
available
}
LoadBalanceStrategy::LeastConnections => {
Self::select_ranked_prefix(available, count, |a, b| {
a.1.connection_count()
.cmp(&b.1.connection_count())
.then(a.0.cmp(&b.0))
})
}
LoadBalanceStrategy::WeightedLeastConnections => {
Self::select_ranked_prefix(available, count, |a, b| {
Self::compare_weighted_load(a.1, b.1).then(a.0.cmp(&b.0))
})
}
LoadBalanceStrategy::HashBased => {
let start_idx = object_id.map_or_else(
|| self.rr_counter.fetch_add(count as u64, Ordering::Relaxed) as usize,
|oid| oid.as_u128() as usize,
);
let len = available.len();
(0..count)
.map(|i| available[(start_idx + i) % len])
.collect()
}
LoadBalanceStrategy::WeightedRoundRobin => {
self.select_n_weighted_round_robin(&available, count)
}
LoadBalanceStrategy::FirstAvailable => available.into_iter().take(count).collect(),
}
}
fn select_random_single_without_materializing<'a>(
&self,
endpoints: &'a [Arc<Endpoint>],
) -> Option<&'a Arc<Endpoint>> {
if endpoints.is_empty() {
return None;
}
let mut seed = self.random_seed.fetch_add(1, Ordering::Relaxed);
let total = endpoints.len();
let max_attempts = total.min(64);
for _ in 0..max_attempts {
seed = Self::next_lcg(seed);
let idx = (seed as usize) % total;
if endpoints[idx].state().can_receive() {
return Some(&endpoints[idx]);
}
}
endpoints.iter().find(|ep| ep.state().can_receive())
}
fn select_n_random_small_without_materializing<'a>(
&self,
endpoints: &'a [Arc<Endpoint>],
n: usize,
) -> Option<Vec<&'a Arc<Endpoint>>> {
if n == 0 {
return Some(Vec::new());
}
let total = endpoints.len();
if total == 0 {
return None;
}
let mut seed = self.random_seed.fetch_add(n as u64, Ordering::Relaxed);
let mut selected = SmallVec::<[usize; Self::RANDOM_FLOYD_SMALL_N_MAX]>::new();
let max_attempts = n * 4 + 16;
let mut attempts = 0;
while selected.len() < n {
if attempts >= max_attempts {
return None; }
attempts += 1;
seed = Self::next_lcg(seed);
let idx = (seed as usize) % total;
if !endpoints[idx].state().can_receive() {
continue;
}
if selected.contains(&idx) {
continue;
}
selected.push(idx);
}
Some(selected.into_iter().map(|i| &endpoints[i]).collect())
}
}
#[derive(Debug, Clone)]
pub struct RoutingEntry {
pub endpoints: Vec<Arc<Endpoint>>,
pub load_balancer: Arc<LoadBalancer>,
pub priority: u32,
pub ttl: Option<Time>,
pub created_at: Time,
}
impl RoutingEntry {
#[must_use]
pub fn new(endpoints: Vec<Arc<Endpoint>>, created_at: Time) -> Self {
Self {
endpoints,
load_balancer: Arc::new(LoadBalancer::new(LoadBalanceStrategy::RoundRobin)),
priority: 100,
ttl: None,
created_at,
}
}
#[must_use]
pub fn with_strategy(mut self, strategy: LoadBalanceStrategy) -> Self {
self.load_balancer = Arc::new(LoadBalancer::new(strategy));
self
}
#[must_use]
pub fn with_priority(mut self, priority: u32) -> Self {
self.priority = priority;
self
}
#[must_use]
pub fn with_ttl(mut self, ttl: Time) -> Self {
self.ttl = Some(ttl);
self
}
#[must_use]
pub fn is_expired(&self, now: Time) -> bool {
self.ttl.is_some_and(|ttl| {
let expiry = self.created_at.saturating_add_nanos(ttl.as_nanos());
now >= expiry
})
}
#[must_use]
pub fn select_endpoint(&self, object_id: Option<ObjectId>) -> Option<Arc<Endpoint>> {
self.load_balancer
.select(&self.endpoints, object_id)
.cloned()
}
#[must_use]
pub fn select_endpoints(&self, n: usize, object_id: Option<ObjectId>) -> Vec<Arc<Endpoint>> {
self.load_balancer
.select_n(&self.endpoints, n, object_id)
.into_iter()
.cloned()
.collect()
}
}
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum RouteKey {
Object(ObjectId),
Region(RegionId),
ObjectAndRegion(ObjectId, RegionId),
Default,
}
impl RouteKey {
#[must_use]
pub fn object(oid: ObjectId) -> Self {
Self::Object(oid)
}
#[must_use]
pub fn region(rid: RegionId) -> Self {
Self::Region(rid)
}
}
#[derive(Debug)]
pub struct RoutingTable {
routes: RwLock<HashMap<RouteKey, RoutingEntry>>,
default_route: RwLock<Option<RoutingEntry>>,
endpoints: RwLock<HashMap<EndpointId, Arc<Endpoint>>>,
}
impl RoutingTable {
#[must_use]
pub fn new() -> Self {
Self {
routes: RwLock::new(HashMap::new()),
default_route: RwLock::new(None),
endpoints: RwLock::new(HashMap::new()),
}
}
pub fn register_endpoint(&self, endpoint: Endpoint) -> Arc<Endpoint> {
let id = endpoint.id;
let arc = Arc::new(endpoint);
self.endpoints.write().insert(id, arc.clone());
arc
}
#[must_use]
pub fn get_endpoint(&self, id: EndpointId) -> Option<Arc<Endpoint>> {
self.endpoints.read().get(&id).cloned()
}
pub fn update_endpoint_state(&self, id: EndpointId, state: EndpointState) -> bool {
self.endpoints.read().get(&id).is_some_and(|endpoint| {
endpoint.set_state(state);
true
})
}
pub fn add_route(&self, key: RouteKey, entry: RoutingEntry) {
if key == RouteKey::Default {
*self.default_route.write() = Some(entry);
} else {
self.routes.write().insert(key, entry);
}
}
pub fn remove_route(&self, key: &RouteKey) -> bool {
if *key == RouteKey::Default {
let mut default = self.default_route.write();
let had_route = default.is_some();
*default = None;
had_route
} else {
self.routes.write().remove(key).is_some()
}
}
#[must_use]
pub fn lookup(&self, key: &RouteKey) -> Option<RoutingEntry> {
if let Some(entry) = self.routes.read().get(key) {
return Some(entry.clone());
}
if let RouteKey::ObjectAndRegion(oid, rid) = key {
if let Some(entry) = self.routes.read().get(&RouteKey::Object(*oid)) {
return Some(entry.clone());
}
if let Some(entry) = self.routes.read().get(&RouteKey::Region(*rid)) {
return Some(entry.clone());
}
}
self.default_route.read().clone()
}
#[must_use]
pub fn lookup_without_default(&self, key: &RouteKey) -> Option<RoutingEntry> {
if let Some(entry) = self.routes.read().get(key) {
return Some(entry.clone());
}
if let RouteKey::ObjectAndRegion(oid, rid) = key {
if let Some(entry) = self.routes.read().get(&RouteKey::Object(*oid)) {
return Some(entry.clone());
}
if let Some(entry) = self.routes.read().get(&RouteKey::Region(*rid)) {
return Some(entry.clone());
}
}
None
}
pub fn prune_expired(&self, now: Time) -> usize {
let mut routes = self.routes.write();
let before = routes.len();
routes.retain(|_, entry| !entry.is_expired(now));
let mut pruned = before - routes.len();
drop(routes);
let mut default = self.default_route.write();
if default.as_ref().is_some_and(|entry| entry.is_expired(now)) {
*default = None;
pruned += 1;
}
drop(default);
pruned
}
#[must_use]
pub fn dispatchable_endpoints(&self) -> Vec<Arc<Endpoint>> {
let mut endpoints = self
.endpoints
.read()
.values()
.filter(|endpoint| endpoint.state().can_receive())
.cloned()
.collect::<Vec<_>>();
endpoints.sort_unstable_by_key(|endpoint| endpoint.id);
endpoints
}
#[must_use]
pub fn route_count(&self) -> usize {
let routes = self.routes.read().len();
let default = usize::from(self.default_route.read().is_some());
routes + default
}
}
impl Default for RoutingTable {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug, Clone)]
pub struct RouteResult {
pub endpoint: Arc<Endpoint>,
pub matched_key: RouteKey,
pub is_fallback: bool,
}
#[derive(Debug)]
pub struct SymbolRouter {
table: Arc<RoutingTable>,
allow_fallback: bool,
prefer_local: bool,
local_region: Option<RegionId>,
}
impl SymbolRouter {
pub fn new(table: Arc<RoutingTable>) -> Self {
Self {
table,
allow_fallback: true,
prefer_local: false,
local_region: None,
}
}
#[must_use]
pub fn without_fallback(mut self) -> Self {
self.allow_fallback = false;
self
}
#[must_use]
pub fn with_local_preference(mut self, region: RegionId) -> Self {
self.prefer_local = true;
self.local_region = Some(region);
self
}
fn local_candidates(&self, entry: &RoutingEntry) -> Vec<Arc<Endpoint>> {
if !self.prefer_local {
return Vec::new();
}
let Some(local) = self.local_region else {
return Vec::new();
};
entry
.endpoints
.iter()
.filter(|endpoint| endpoint.region == Some(local) && endpoint.state().can_receive())
.cloned()
.collect()
}
fn select_preferred_endpoint(
&self,
entry: &RoutingEntry,
object_id: ObjectId,
) -> Option<Arc<Endpoint>> {
let local = self.local_candidates(entry);
if !local.is_empty() {
return entry.load_balancer.select(&local, Some(object_id)).cloned();
}
entry.select_endpoint(Some(object_id))
}
fn select_preferred_endpoints(
&self,
entry: &RoutingEntry,
object_id: ObjectId,
count: usize,
) -> Vec<Arc<Endpoint>> {
let local = self.local_candidates(entry);
if local.is_empty() {
return entry.select_endpoints(count, Some(object_id));
}
let local_take = local.len().min(count);
let mut selected = entry
.load_balancer
.select_n(&local, local_take, Some(object_id))
.into_iter()
.cloned()
.collect::<Vec<_>>();
if selected.len() >= count {
return selected;
}
let Some(local_region) = self.local_region else {
return entry.select_endpoints(count, Some(object_id));
};
let non_local = entry
.endpoints
.iter()
.filter(|endpoint| {
endpoint.region != Some(local_region) && endpoint.state().can_receive()
})
.cloned()
.collect::<Vec<_>>();
let remaining = count - selected.len();
let mut tail = entry
.load_balancer
.select_n(&non_local, remaining, Some(object_id))
.into_iter()
.cloned()
.collect::<Vec<_>>();
selected.append(&mut tail);
selected
}
pub fn route(&self, symbol: &Symbol) -> Result<RouteResult, RoutingError> {
let object_id = symbol.object_id();
let primary_key = RouteKey::Object(object_id);
let primary_entry = self.table.lookup_without_default(&primary_key);
if let Some(entry) = primary_entry.as_ref() {
if let Some(endpoint) = self.select_preferred_endpoint(entry, object_id) {
return Ok(RouteResult {
endpoint,
matched_key: primary_key,
is_fallback: false,
});
}
}
if self.allow_fallback {
let fallback_key = RouteKey::Default;
if let Some(entry) = self.table.lookup(&fallback_key) {
if let Some(endpoint) = entry.select_endpoint(Some(object_id)) {
return Ok(RouteResult {
endpoint,
matched_key: fallback_key,
is_fallback: true,
});
}
return Err(RoutingError::NoHealthyEndpoints { object_id });
}
}
if primary_entry.is_some() {
return Err(RoutingError::NoHealthyEndpoints { object_id });
}
Err(RoutingError::NoRoute {
object_id,
reason: "No matching route and no default route configured".into(),
})
}
pub fn route_multicast(
&self,
symbol: &Symbol,
count: usize,
) -> Result<Vec<RouteResult>, RoutingError> {
let object_id = symbol.object_id();
let key = RouteKey::Object(object_id);
let (entry, matched_key, is_fallback) =
if let Some(entry) = self.table.lookup_without_default(&key) {
(entry, key, false)
} else if self.allow_fallback {
let fallback_key = RouteKey::Default;
let fallback =
self.table
.lookup(&fallback_key)
.ok_or_else(|| RoutingError::NoRoute {
object_id,
reason: "No route for multicast".into(),
})?;
(fallback, fallback_key, true)
} else {
return Err(RoutingError::NoRoute {
object_id,
reason: "No route for multicast".into(),
});
};
let endpoints = self.select_preferred_endpoints(&entry, object_id, count);
if endpoints.is_empty() {
return Err(RoutingError::NoHealthyEndpoints { object_id });
}
let results: Vec<_> = endpoints
.into_iter()
.map(|endpoint| RouteResult {
endpoint,
matched_key: matched_key.clone(),
is_fallback,
})
.collect();
Ok(results)
}
#[must_use]
pub fn table(&self) -> &Arc<RoutingTable> {
&self.table
}
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum DispatchStrategy {
#[default]
Unicast,
Multicast {
count: usize,
},
Broadcast,
QuorumCast {
required: usize,
},
}
#[derive(Debug)]
pub struct DispatchResult {
pub successes: usize,
pub failures: usize,
pub sent_to: SmallVec<[EndpointId; 4]>,
pub failed_endpoints: SmallVec<[(EndpointId, DispatchError); 4]>,
pub duration: Time,
}
impl DispatchResult {
#[must_use]
pub fn all_succeeded(&self) -> bool {
self.failures == 0 && self.successes > 0
}
#[must_use]
pub fn any_succeeded(&self) -> bool {
self.successes > 0
}
#[must_use]
pub fn quorum_reached(&self, required: usize) -> bool {
self.successes >= required
}
}
#[derive(Debug, Clone)]
pub struct DispatchConfig {
pub default_strategy: DispatchStrategy,
pub timeout: Time,
pub max_retries: u32,
pub retry_delay: Time,
pub fail_fast: bool,
pub max_concurrent: u32,
}
impl Default for DispatchConfig {
fn default() -> Self {
Self {
default_strategy: DispatchStrategy::Unicast,
timeout: Time::from_secs(5),
max_retries: 3,
retry_delay: Time::from_millis(100),
fail_fast: false,
max_concurrent: 100,
}
}
}
pub struct SymbolDispatcher {
router: Arc<SymbolRouter>,
config: DispatchConfig,
active_dispatches: AtomicU32,
total_dispatched: AtomicU64,
total_failures: AtomicU64,
sinks: RwLock<EndpointSinkMap>,
}
impl std::fmt::Debug for SymbolDispatcher {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("SymbolDispatcher")
.field("router", &self.router)
.field("config", &self.config)
.field("active_dispatches", &self.active_dispatches)
.field("total_dispatched", &self.total_dispatched)
.field("total_failures", &self.total_failures)
.field(
"sinks",
&format_args!("<{} sinks>", self.sinks.read().len()),
)
.finish()
}
}
struct DispatchGuard<'a> {
dispatcher: &'a SymbolDispatcher,
}
impl Drop for DispatchGuard<'_> {
fn drop(&mut self) {
self.dispatcher
.active_dispatches
.fetch_sub(1, Ordering::Release);
}
}
impl SymbolDispatcher {
#[must_use]
pub fn new(router: Arc<SymbolRouter>, config: DispatchConfig) -> Self {
Self {
router,
config,
active_dispatches: AtomicU32::new(0),
total_dispatched: AtomicU64::new(0),
total_failures: AtomicU64::new(0),
sinks: RwLock::new(HashMap::new()),
}
}
pub fn add_sink(&self, endpoint: EndpointId, sink: Box<dyn SymbolSink>) {
self.sinks
.write()
.insert(endpoint, Arc::new(Mutex::new(sink)));
}
fn send_failed(endpoint: EndpointId) -> DispatchError {
DispatchError::SendFailed {
endpoint,
reason: "Send failed".into(),
}
}
async fn send_to_endpoint(
&self,
cx: &Cx,
endpoint: EndpointId,
symbol: AuthenticatedSymbol,
) -> Result<(), DispatchError> {
let sink = {
let sinks = self.sinks.read();
sinks.get(&endpoint).cloned()
};
let Some(sink) = sink else {
return Ok(());
};
if cx.checkpoint().is_err() {
return Err(DispatchError::Cancelled);
}
match OwnedMutexGuard::lock(sink, cx).await {
Ok(mut guard) => {
let guard: &mut Box<dyn SymbolSink> = &mut guard;
match guard.send(symbol).await {
Ok(()) => Ok(()),
Err(crate::transport::error::SinkError::Cancelled) => {
Err(DispatchError::Cancelled)
}
Err(crate::transport::error::SinkError::Io { source })
if source.kind() == std::io::ErrorKind::Interrupted
&& cx.checkpoint().is_err() =>
{
Err(DispatchError::Cancelled)
}
Err(_) => Err(Self::send_failed(endpoint)),
}
}
Err(crate::sync::LockError::Cancelled) => Err(DispatchError::Cancelled),
Err(_) => Err(DispatchError::Timeout),
}
}
pub async fn dispatch(
&self,
cx: &Cx,
symbol: AuthenticatedSymbol,
) -> Result<DispatchResult, DispatchError> {
self.dispatch_with_strategy(cx, symbol, self.config.default_strategy)
.await
}
pub async fn dispatch_with_strategy(
&self,
cx: &Cx,
symbol: AuthenticatedSymbol,
strategy: DispatchStrategy,
) -> Result<DispatchResult, DispatchError> {
let active = self.active_dispatches.fetch_add(1, Ordering::AcqRel);
if active >= self.config.max_concurrent {
self.active_dispatches.fetch_sub(1, Ordering::Release);
return Err(DispatchError::Overloaded);
}
let _guard = DispatchGuard { dispatcher: self };
let result = match strategy {
DispatchStrategy::Unicast => self.dispatch_unicast(cx, symbol).await,
DispatchStrategy::Multicast { count } => {
self.dispatch_multicast(cx, &symbol, count).await
}
DispatchStrategy::Broadcast => self.dispatch_broadcast(cx, &symbol).await,
DispatchStrategy::QuorumCast { required } => {
self.dispatch_quorum(cx, &symbol, required).await
}
};
match &result {
Ok(r) => {
self.total_dispatched
.fetch_add(r.successes as u64, Ordering::Relaxed);
self.total_failures
.fetch_add(r.failures as u64, Ordering::Relaxed);
}
Err(_) => {
self.total_failures.fetch_add(1, Ordering::Relaxed);
}
}
result
}
#[allow(clippy::unused_async)]
async fn dispatch_unicast(
&self,
cx: &Cx,
symbol: AuthenticatedSymbol,
) -> Result<DispatchResult, DispatchError> {
let route = self.router.route(symbol.symbol())?;
let _guard = route.endpoint.acquire_connection_guard();
match self.send_to_endpoint(cx, route.endpoint.id, symbol).await {
Ok(()) => {
route.endpoint.record_success(Time::ZERO);
Ok(DispatchResult {
successes: 1,
failures: 0,
sent_to: smallvec![route.endpoint.id],
failed_endpoints: SmallVec::new(),
duration: Time::ZERO,
})
}
Err(DispatchError::Cancelled) => Err(DispatchError::Cancelled),
Err(err) => {
route.endpoint.record_failure(Time::ZERO);
Err(err)
}
}
}
#[allow(clippy::unused_async)]
async fn dispatch_multicast(
&self,
cx: &Cx,
symbol: &AuthenticatedSymbol,
count: usize,
) -> Result<DispatchResult, DispatchError> {
if count == 0 {
return Ok(DispatchResult {
successes: 0,
failures: 0,
sent_to: SmallVec::new(),
failed_endpoints: SmallVec::new(),
duration: Time::ZERO,
});
}
let routes = match self.router.route_multicast(symbol.symbol(), count) {
Ok(routes) => routes,
Err(RoutingError::NoHealthyEndpoints { object_id }) => {
return Err(DispatchError::RoutingFailed(
RoutingError::NoHealthyEndpoints { object_id },
));
}
Err(e) => return Err(DispatchError::RoutingFailed(e)),
};
let mut successes = 0;
let mut failures = 0;
let mut sent_to = SmallVec::<[EndpointId; 4]>::new();
let mut failed = SmallVec::<[(EndpointId, DispatchError); 4]>::new();
for route in routes {
if cx.checkpoint().is_err() {
return Err(DispatchError::Cancelled);
}
let endpoint = route.endpoint;
let _guard = endpoint.acquire_connection_guard();
match self.send_to_endpoint(cx, endpoint.id, symbol.clone()).await {
Ok(()) => {
endpoint.record_success(Time::ZERO);
successes += 1;
sent_to.push(endpoint.id);
}
Err(DispatchError::Cancelled) => return Err(DispatchError::Cancelled),
Err(err) => {
endpoint.record_failure(Time::ZERO);
failures += 1;
failed.push((endpoint.id, err));
}
}
}
Ok(DispatchResult {
successes,
failures,
sent_to,
failed_endpoints: failed,
duration: Time::ZERO,
})
}
#[allow(clippy::unused_async)]
async fn dispatch_broadcast(
&self,
cx: &Cx,
symbol: &AuthenticatedSymbol,
) -> Result<DispatchResult, DispatchError> {
let endpoints = self.router.table().dispatchable_endpoints();
if endpoints.is_empty() {
return Err(DispatchError::NoEndpoints);
}
let mut successes = 0;
let mut failures = 0;
let mut sent_to = SmallVec::<[EndpointId; 4]>::new();
let mut failed = SmallVec::<[(EndpointId, DispatchError); 4]>::new();
for route in endpoints {
if cx.checkpoint().is_err() {
return Err(DispatchError::Cancelled);
}
let _guard = route.acquire_connection_guard();
match self.send_to_endpoint(cx, route.id, symbol.clone()).await {
Ok(()) => {
route.record_success(Time::ZERO);
successes += 1;
sent_to.push(route.id);
}
Err(DispatchError::Cancelled) => return Err(DispatchError::Cancelled),
Err(err) => {
route.record_failure(Time::ZERO);
failures += 1;
failed.push((route.id, err));
}
}
}
Ok(DispatchResult {
successes,
failures,
sent_to,
failed_endpoints: failed,
duration: Time::ZERO,
})
}
#[allow(clippy::unused_async)]
async fn dispatch_quorum(
&self,
cx: &Cx,
symbol: &AuthenticatedSymbol,
required: usize,
) -> Result<DispatchResult, DispatchError> {
let endpoints = self.router.table().dispatchable_endpoints();
if endpoints.len() < required {
return Err(DispatchError::InsufficientEndpoints {
available: endpoints.len(),
required,
});
}
let mut successes = 0;
let mut failures = 0;
let mut sent_to = SmallVec::<[EndpointId; 4]>::new();
let mut failed = SmallVec::<[(EndpointId, DispatchError); 4]>::new();
for route in endpoints {
if cx.checkpoint().is_err() {
return Err(DispatchError::Cancelled);
}
if successes >= required {
break;
}
let _guard = route.acquire_connection_guard();
match self.send_to_endpoint(cx, route.id, symbol.clone()).await {
Ok(()) => {
route.record_success(Time::ZERO);
successes += 1;
sent_to.push(route.id);
}
Err(DispatchError::Cancelled) => return Err(DispatchError::Cancelled),
Err(err) => {
route.record_failure(Time::ZERO);
failures += 1;
failed.push((route.id, err));
}
}
}
if successes < required {
return Err(DispatchError::QuorumNotReached {
achieved: successes,
required,
});
}
Ok(DispatchResult {
successes,
failures,
sent_to,
failed_endpoints: failed,
duration: Time::ZERO,
})
}
#[must_use]
pub fn stats(&self) -> DispatcherStats {
DispatcherStats {
active_dispatches: self.active_dispatches.load(Ordering::Relaxed),
total_dispatched: self.total_dispatched.load(Ordering::Relaxed),
total_failures: self.total_failures.load(Ordering::Relaxed),
}
}
}
#[derive(Debug, Clone)]
pub struct DispatcherStats {
pub active_dispatches: u32,
pub total_dispatched: u64,
pub total_failures: u64,
}
#[derive(Debug, Clone)]
pub enum RoutingError {
NoRoute {
object_id: ObjectId,
reason: String,
},
NoHealthyEndpoints {
object_id: ObjectId,
},
EmptyTable,
}
impl std::fmt::Display for RoutingError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::NoRoute { object_id, reason } => {
write!(f, "no route for object {object_id:?}: {reason}")
}
Self::NoHealthyEndpoints { object_id } => {
write!(f, "no healthy endpoints for object {object_id:?}")
}
Self::EmptyTable => write!(f, "routing table is empty"),
}
}
}
impl std::error::Error for RoutingError {}
impl From<RoutingError> for Error {
fn from(e: RoutingError) -> Self {
Self::new(ErrorKind::RoutingFailed).with_message(e.to_string())
}
}
#[derive(Debug, Clone)]
pub enum DispatchError {
RoutingFailed(RoutingError),
SendFailed {
endpoint: EndpointId,
reason: String,
},
Overloaded,
NoEndpoints,
InsufficientEndpoints {
available: usize,
required: usize,
},
QuorumNotReached {
achieved: usize,
required: usize,
},
Timeout,
Cancelled,
}
impl std::fmt::Display for DispatchError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::RoutingFailed(e) => write!(f, "routing failed: {e}"),
Self::SendFailed { endpoint, reason } => {
write!(f, "send to {endpoint} failed: {reason}")
}
Self::Overloaded => write!(f, "dispatcher overloaded"),
Self::NoEndpoints => write!(f, "no endpoints available"),
Self::InsufficientEndpoints {
available,
required,
} => {
write!(
f,
"insufficient endpoints: {available} available, {required} required"
)
}
Self::QuorumNotReached { achieved, required } => {
write!(f, "quorum not reached: {achieved} of {required} required")
}
Self::Timeout => write!(f, "dispatch timeout"),
Self::Cancelled => write!(f, "dispatch cancelled"),
}
}
}
impl std::error::Error for DispatchError {}
impl From<RoutingError> for DispatchError {
fn from(e: RoutingError) -> Self {
Self::RoutingFailed(e)
}
}
impl From<DispatchError> for Error {
fn from(e: DispatchError) -> Self {
match e {
DispatchError::RoutingFailed(_) => {
Self::new(ErrorKind::RoutingFailed).with_message(e.to_string())
}
DispatchError::QuorumNotReached { .. } => {
Self::new(ErrorKind::QuorumNotReached).with_message(e.to_string())
}
_ => Self::new(ErrorKind::DispatchFailed).with_message(e.to_string()),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::Cx;
use crate::security::authenticated::AuthenticatedSymbol;
use crate::security::tag::AuthenticationTag;
use crate::transport::error::SinkError;
use crate::types::{Symbol, SymbolId, SymbolKind};
use futures_lite::future;
use serde_json::json;
use std::collections::HashSet;
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};
fn test_endpoint(id: u64) -> Endpoint {
Endpoint::new(EndpointId(id), format!("node-{id}:8080"))
}
fn test_authenticated_symbol(esi: u32) -> AuthenticatedSymbol {
let id = SymbolId::new_for_test(1, 0, esi);
let symbol = Symbol::new(id, vec![esi as u8], SymbolKind::Source);
AuthenticatedSymbol::new_verified(symbol, AuthenticationTag::zero())
}
fn scrub_endpoint_region(region: Option<RegionId>) -> Option<&'static str> {
let _ = region?;
Some("<region>")
}
fn scrub_route_key(key: &RouteKey) -> &'static str {
match key {
RouteKey::Object(_) => "object:<object>",
RouteKey::Region(_) => "region:<region>",
RouteKey::ObjectAndRegion(_, _) => "object+region:<object>:<region>",
RouteKey::Default => "default",
}
}
fn routing_entry_snapshot(entry: &RoutingEntry) -> serde_json::Value {
json!({
"strategy": format!("{:?}", entry.load_balancer.strategy),
"priority": entry.priority,
"ttl_ms": entry.ttl.map(Time::as_millis),
"endpoint_ids": entry
.endpoints
.iter()
.map(|endpoint| endpoint.id.to_string())
.collect::<Vec<_>>(),
})
}
fn routing_table_snapshot(table: &RoutingTable) -> serde_json::Value {
let mut endpoints = table.endpoints.read().values().cloned().collect::<Vec<_>>();
endpoints.sort_unstable_by_key(|endpoint| endpoint.id);
let mut routes = table
.routes
.read()
.iter()
.map(|(key, entry)| (key.clone(), entry.clone()))
.collect::<Vec<_>>();
routes.sort_unstable_by(|(left, _), (right, _)| left.cmp(right));
json!({
"route_count": table.route_count(),
"dispatchable_endpoint_ids": table
.dispatchable_endpoints()
.into_iter()
.map(|endpoint| endpoint.id.to_string())
.collect::<Vec<_>>(),
"endpoints": endpoints
.into_iter()
.map(|endpoint| json!({
"id": endpoint.id.to_string(),
"address": endpoint.address,
"state": format!("{:?}", endpoint.state()),
"weight": endpoint.weight,
"region": scrub_endpoint_region(endpoint.region),
}))
.collect::<Vec<_>>(),
"default_route": table
.default_route
.read()
.as_ref()
.map(routing_entry_snapshot),
"routes": routes
.into_iter()
.map(|(key, entry)| json!({
"key": scrub_route_key(&key),
"entry": routing_entry_snapshot(&entry),
}))
.collect::<Vec<_>>(),
})
}
struct InterruptedSink;
impl SymbolSink for InterruptedSink {
fn poll_send(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_symbol: AuthenticatedSymbol,
) -> Poll<Result<(), SinkError>> {
Poll::Ready(Err(SinkError::Io {
source: io::Error::new(io::ErrorKind::Interrupted, "synthetic interrupt"),
}))
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), SinkError>> {
Poll::Ready(Ok(()))
}
fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), SinkError>> {
Poll::Ready(Ok(()))
}
fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), SinkError>> {
Poll::Ready(Ok(()))
}
}
struct CancellingInterruptedSink {
cancel_cx: Cx,
}
impl SymbolSink for CancellingInterruptedSink {
fn poll_send(
self: Pin<&mut Self>,
_cx: &mut Context<'_>,
_symbol: AuthenticatedSymbol,
) -> Poll<Result<(), SinkError>> {
self.cancel_cx.set_cancel_requested(true);
Poll::Ready(Err(SinkError::Io {
source: io::Error::new(io::ErrorKind::Interrupted, "cancelled"),
}))
}
fn poll_flush(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), SinkError>> {
Poll::Ready(Ok(()))
}
fn poll_close(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), SinkError>> {
Poll::Ready(Ok(()))
}
fn poll_ready(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Result<(), SinkError>> {
Poll::Ready(Ok(()))
}
}
#[test]
fn test_endpoint_state() {
assert!(EndpointState::Healthy.can_receive());
assert!(EndpointState::Degraded.can_receive());
assert!(!EndpointState::Unhealthy.can_receive());
assert!(!EndpointState::Draining.can_receive());
assert!(!EndpointState::Removed.can_receive());
assert!(EndpointState::Healthy.is_available());
assert!(!EndpointState::Removed.is_available());
}
#[test]
fn test_endpoint_statistics() {
let endpoint = test_endpoint(1);
endpoint.record_success(Time::from_secs(1));
endpoint.record_success(Time::from_secs(2));
endpoint.record_failure(Time::from_secs(3));
assert_eq!(endpoint.symbols_sent.load(Ordering::Relaxed), 2);
assert_eq!(endpoint.failures.load(Ordering::Relaxed), 1);
let rate = endpoint.failure_rate();
assert!(rate > 0.3 && rate < 0.34);
}
#[test]
fn test_load_balancer_round_robin() {
let lb = LoadBalancer::new(LoadBalanceStrategy::RoundRobin);
let endpoints: Vec<Arc<Endpoint>> = (1..=3).map(|i| Arc::new(test_endpoint(i))).collect();
let e1 = lb.select(&endpoints, None);
let e2 = lb.select(&endpoints, None);
let e3 = lb.select(&endpoints, None);
let e4 = lb.select(&endpoints, None);
assert_eq!(e1.unwrap().id, EndpointId(1));
assert_eq!(e2.unwrap().id, EndpointId(2));
assert_eq!(e3.unwrap().id, EndpointId(3));
assert_eq!(e4.unwrap().id, EndpointId(1));
}
#[test]
fn test_load_balancer_least_connections() {
let lb = LoadBalancer::new(LoadBalanceStrategy::LeastConnections);
let e1 = Arc::new(test_endpoint(1));
let e2 = Arc::new(test_endpoint(2));
let e3 = Arc::new(test_endpoint(3));
e1.active_connections.store(5, Ordering::Relaxed);
e2.active_connections.store(2, Ordering::Relaxed);
e3.active_connections.store(10, Ordering::Relaxed);
let endpoints = vec![e1, e2.clone(), e3];
let selected = lb.select(&endpoints, None).unwrap();
assert_eq!(selected.id, e2.id); }
#[test]
fn test_load_balancer_weighted_least_connections() {
let lb = LoadBalancer::new(LoadBalanceStrategy::WeightedLeastConnections);
let e1 = Arc::new(test_endpoint(1).with_weight(1));
let e2 = Arc::new(test_endpoint(2).with_weight(4));
let e3 = Arc::new(test_endpoint(3).with_weight(2));
e1.active_connections.store(2, Ordering::Relaxed); e2.active_connections.store(4, Ordering::Relaxed); e3.active_connections.store(3, Ordering::Relaxed);
let endpoints = vec![e1, e2.clone(), e3];
let selected = lb.select(&endpoints, None).unwrap();
assert_eq!(selected.id, e2.id);
}
#[test]
fn test_load_balancer_hash_based() {
let lb = LoadBalancer::new(LoadBalanceStrategy::HashBased);
let endpoints: Vec<Arc<Endpoint>> = (1..=3).map(|i| Arc::new(test_endpoint(i))).collect();
let oid = ObjectId::new_for_test(42);
let s1 = lb.select(&endpoints, Some(oid));
let s2 = lb.select(&endpoints, Some(oid));
assert_eq!(s1.unwrap().id, s2.unwrap().id);
}
#[test]
fn test_load_balancer_random_select_n_returns_unique_healthy() {
let lb = LoadBalancer::new(LoadBalanceStrategy::Random);
let endpoints: Vec<Arc<Endpoint>> = (0..10)
.map(|i| {
let endpoint = test_endpoint(i);
if i % 3 == 0 {
Arc::new(endpoint.with_state(EndpointState::Unhealthy))
} else {
Arc::new(endpoint)
}
})
.collect();
let selected = lb.select_n(&endpoints, 3, None);
assert_eq!(selected.len(), 3);
assert!(
selected
.iter()
.all(|endpoint| endpoint.state().can_receive())
);
let unique_ids: HashSet<_> = selected.iter().map(|endpoint| endpoint.id).collect();
assert_eq!(unique_ids.len(), selected.len());
}
#[test]
fn test_load_balancer_random_select_n_returns_all_healthy_when_n_large() {
let lb = LoadBalancer::new(LoadBalanceStrategy::Random);
let endpoints = vec![
Arc::new(test_endpoint(1).with_state(EndpointState::Healthy)),
Arc::new(test_endpoint(2).with_state(EndpointState::Unhealthy)),
Arc::new(test_endpoint(3).with_state(EndpointState::Degraded)),
Arc::new(test_endpoint(4).with_state(EndpointState::Draining)),
Arc::new(test_endpoint(5).with_state(EndpointState::Healthy)),
];
let selected = lb.select_n(&endpoints, 16, None);
let mut selected_ids: Vec<_> = selected.iter().map(|endpoint| endpoint.id).collect();
selected_ids.sort();
assert_eq!(
selected_ids,
vec![EndpointId::new(1), EndpointId::new(3), EndpointId::new(5)]
);
}
#[test]
fn test_load_balancer_random_select_n_single_matches_select_sequence() {
let lb_select = LoadBalancer::new(LoadBalanceStrategy::Random);
let lb_select_n = LoadBalancer::new(LoadBalanceStrategy::Random);
let endpoints: Vec<Arc<Endpoint>> = (0..8)
.map(|i| {
let endpoint = test_endpoint(i);
if i % 4 == 0 {
Arc::new(endpoint.with_state(EndpointState::Unhealthy))
} else {
Arc::new(endpoint)
}
})
.collect();
for _ in 0..64 {
let selected = lb_select
.select(&endpoints, None)
.map(|endpoint| endpoint.id);
let selected_n = lb_select_n
.select_n(&endpoints, 1, None)
.first()
.map(|endpoint| endpoint.id);
assert_eq!(selected, selected_n);
}
}
#[test]
fn test_load_balancer_random_select_single_is_uniform_over_healthy() {
let lb = LoadBalancer::new(LoadBalanceStrategy::Random);
let endpoints = vec![
Arc::new(test_endpoint(0).with_state(EndpointState::Healthy)),
Arc::new(test_endpoint(100).with_state(EndpointState::Unhealthy)),
Arc::new(test_endpoint(1).with_state(EndpointState::Healthy)),
Arc::new(test_endpoint(101).with_state(EndpointState::Draining)),
Arc::new(test_endpoint(2).with_state(EndpointState::Healthy)),
];
let mut counts = [0usize; 3];
for _ in 0..3000 {
let selected = lb.select_n(&endpoints, 1, None);
assert_eq!(selected.len(), 1);
let id = selected[0].id;
if id == EndpointId::new(0) {
counts[0] += 1;
} else if id == EndpointId::new(1) {
counts[1] += 1;
} else if id == EndpointId::new(2) {
counts[2] += 1;
} else {
panic!("selected unhealthy endpoint: {id:?}"); }
}
assert_eq!(counts.iter().sum::<usize>(), 3000);
for count in counts {
assert!((900..=1100).contains(&count), "non-uniform count: {count}");
}
}
#[test]
fn test_load_balancer_random_select_n_small_all_healthy_is_unique() {
let lb = LoadBalancer::new(LoadBalanceStrategy::Random);
let endpoints: Vec<Arc<Endpoint>> = (0..16).map(|i| Arc::new(test_endpoint(i))).collect();
for _ in 0..64 {
let selected = lb.select_n(&endpoints, 3, None);
assert_eq!(selected.len(), 3);
assert!(
selected
.iter()
.all(|endpoint| endpoint.state().can_receive())
);
let unique_ids: HashSet<_> = selected.iter().map(|endpoint| endpoint.id).collect();
assert_eq!(unique_ids.len(), selected.len());
}
}
#[test]
fn test_load_balancer_weighted_least_connections_select_n_uses_weights() {
let lb = LoadBalancer::new(LoadBalanceStrategy::WeightedLeastConnections);
let e1 = Arc::new(test_endpoint(1).with_weight(1));
let e2 = Arc::new(test_endpoint(2).with_weight(4));
let e3 = Arc::new(test_endpoint(3).with_weight(2));
let e4 = Arc::new(test_endpoint(4).with_weight(2));
e1.active_connections.store(4, Ordering::Relaxed); e2.active_connections.store(4, Ordering::Relaxed); e3.active_connections.store(4, Ordering::Relaxed); e4.active_connections.store(1, Ordering::Relaxed);
let endpoints = vec![e1, e2.clone(), e3, e4.clone()];
let selected = lb.select_n(&endpoints, 2, None);
let selected_ids: Vec<_> = selected.iter().map(|endpoint| endpoint.id).collect();
assert_eq!(selected_ids, vec![e4.id, e2.id]);
}
#[test]
fn test_load_balancer_least_connections_select_n_preserves_input_order_on_ties() {
let lb = LoadBalancer::new(LoadBalanceStrategy::LeastConnections);
let e1 = Arc::new(test_endpoint(1));
let e2 = Arc::new(test_endpoint(2));
let e3 = Arc::new(test_endpoint(3));
let e4 = Arc::new(test_endpoint(4));
e1.active_connections.store(2, Ordering::Relaxed);
e2.active_connections.store(2, Ordering::Relaxed);
e3.active_connections.store(2, Ordering::Relaxed);
e4.active_connections.store(5, Ordering::Relaxed);
let endpoints = vec![e1.clone(), e2.clone(), e3.clone(), e4];
let selected = lb.select_n(&endpoints, 3, None);
let selected_ids: Vec<_> = selected.iter().map(|endpoint| endpoint.id).collect();
assert_eq!(selected_ids, vec![e1.id, e2.id, e3.id]);
}
#[test]
fn test_load_balancer_weighted_least_connections_select_n_preserves_input_order_on_ties() {
let lb = LoadBalancer::new(LoadBalanceStrategy::WeightedLeastConnections);
let e1 = Arc::new(test_endpoint(1).with_weight(1));
let e2 = Arc::new(test_endpoint(2).with_weight(2));
let e3 = Arc::new(test_endpoint(3).with_weight(3));
let e4 = Arc::new(test_endpoint(4).with_weight(1));
e1.active_connections.store(3, Ordering::Relaxed); e2.active_connections.store(6, Ordering::Relaxed); e3.active_connections.store(9, Ordering::Relaxed); e4.active_connections.store(7, Ordering::Relaxed);
let endpoints = vec![e1.clone(), e2.clone(), e3.clone(), e4];
let selected = lb.select_n(&endpoints, 3, None);
let selected_ids: Vec<_> = selected.iter().map(|endpoint| endpoint.id).collect();
assert_eq!(selected_ids, vec![e1.id, e2.id, e3.id]);
}
#[test]
fn test_load_balancer_weighted_round_robin_select_n_honors_weight_ring() {
let lb = LoadBalancer::new(LoadBalanceStrategy::WeightedRoundRobin);
let heavy = Arc::new(test_endpoint(1).with_weight(5));
let medium = Arc::new(test_endpoint(2).with_weight(1));
let light = Arc::new(test_endpoint(3).with_weight(1));
let endpoints = vec![heavy.clone(), medium.clone(), light.clone()];
let first: Vec<_> = lb
.select_n(&endpoints, 2, None)
.into_iter()
.map(|endpoint| endpoint.id)
.collect();
let second: Vec<_> = lb
.select_n(&endpoints, 2, None)
.into_iter()
.map(|endpoint| endpoint.id)
.collect();
let third: Vec<_> = lb
.select_n(&endpoints, 2, None)
.into_iter()
.map(|endpoint| endpoint.id)
.collect();
assert_eq!(first, vec![heavy.id, medium.id]);
assert_eq!(second, vec![light.id, heavy.id]);
assert_eq!(third, vec![heavy.id, medium.id]);
}
#[test]
fn test_load_balancer_weighted_round_robin_select_n_handles_extreme_weight_skew() {
let lb = LoadBalancer::new(LoadBalanceStrategy::WeightedRoundRobin);
let heavy = Arc::new(test_endpoint(1).with_weight(u32::MAX));
let light = Arc::new(test_endpoint(2).with_weight(1));
let endpoints = vec![heavy.clone(), light.clone()];
let selected: Vec<_> = lb
.select_n(&endpoints, 2, None)
.into_iter()
.map(|endpoint| endpoint.id)
.collect();
assert_eq!(selected, vec![heavy.id, light.id]);
}
#[test]
fn test_routing_table_basic() {
let table = RoutingTable::new();
let _e1 = table.register_endpoint(test_endpoint(1));
let e2 = table.register_endpoint(test_endpoint(2));
assert!(table.get_endpoint(EndpointId(1)).is_some());
assert!(table.get_endpoint(EndpointId(999)).is_none());
let entry = RoutingEntry::new(vec![e2], Time::ZERO);
table.add_route(RouteKey::Default, entry);
assert_eq!(table.route_count(), 1);
}
#[test]
fn test_routing_table_lookup() {
let table = RoutingTable::new();
let e1 = table.register_endpoint(test_endpoint(1));
let e2 = table.register_endpoint(test_endpoint(2));
let default = RoutingEntry::new(vec![e1], Time::ZERO);
table.add_route(RouteKey::Default, default);
let oid = ObjectId::new_for_test(42);
let specific = RoutingEntry::new(vec![e2], Time::ZERO);
table.add_route(RouteKey::Object(oid), specific);
let found = table.lookup(&RouteKey::Object(oid));
assert!(found.is_some());
let other_oid = ObjectId::new_for_test(999);
let found = table.lookup(&RouteKey::Object(other_oid));
assert!(found.is_some()); }
#[test]
fn test_routing_entry_ttl() {
let entry = RoutingEntry::new(vec![], Time::from_secs(100)).with_ttl(Time::from_secs(60));
assert!(!entry.is_expired(Time::from_secs(150)));
assert!(entry.is_expired(Time::from_secs(160)));
assert!(entry.is_expired(Time::from_secs(170)));
}
#[test]
fn test_routing_table_prune() {
let table = RoutingTable::new();
let e1 = table.register_endpoint(test_endpoint(1));
let entry1 =
RoutingEntry::new(vec![e1.clone()], Time::from_secs(0)).with_ttl(Time::from_secs(10));
let entry2 = RoutingEntry::new(vec![e1], Time::from_secs(0)).with_ttl(Time::from_secs(100));
table.add_route(RouteKey::Object(ObjectId::new_for_test(1)), entry1);
table.add_route(RouteKey::Object(ObjectId::new_for_test(2)), entry2);
assert_eq!(table.route_count(), 2);
let pruned = table.prune_expired(Time::from_secs(50));
assert_eq!(pruned, 1);
assert_eq!(table.route_count(), 1);
}
#[test]
fn test_routing_table_prune_includes_default_route() {
let table = RoutingTable::new();
let e1 = table.register_endpoint(test_endpoint(1));
let default_entry =
RoutingEntry::new(vec![e1], Time::from_secs(0)).with_ttl(Time::from_secs(10));
table.add_route(RouteKey::Default, default_entry);
assert_eq!(table.route_count(), 1);
let pruned = table.prune_expired(Time::from_secs(50));
assert_eq!(pruned, 1);
assert_eq!(table.route_count(), 0);
}
#[test]
fn test_symbol_router() {
let table = Arc::new(RoutingTable::new());
let e1 = table.register_endpoint(test_endpoint(1));
let entry = RoutingEntry::new(vec![e1], Time::ZERO);
table.add_route(RouteKey::Default, entry);
let router = SymbolRouter::new(table);
let symbol = Symbol::new_for_test(1, 0, 0, &[1, 2, 3]);
let result = router.route(&symbol);
assert!(result.is_ok());
assert_eq!(result.unwrap().endpoint.id, EndpointId(1));
}
#[test]
fn test_symbol_router_without_fallback() {
let table = Arc::new(RoutingTable::new());
let e1 = table.register_endpoint(test_endpoint(1));
let entry = RoutingEntry::new(vec![e1], Time::ZERO);
table.add_route(RouteKey::Default, entry);
let router = SymbolRouter::new(table).without_fallback();
let symbol = Symbol::new_for_test(1, 0, 0, &[1, 2, 3]);
let result = router.route(&symbol);
assert!(
result.is_err(),
"without_fallback should reject default-only route"
);
}
#[test]
fn test_symbol_router_failover() {
let table = Arc::new(RoutingTable::new());
let primary =
table.register_endpoint(test_endpoint(1).with_state(EndpointState::Unhealthy));
let backup = table.register_endpoint(test_endpoint(2).with_state(EndpointState::Healthy));
let entry = RoutingEntry::new(vec![primary, backup.clone()], Time::ZERO)
.with_strategy(LoadBalanceStrategy::FirstAvailable);
table.add_route(RouteKey::Default, entry);
let router = SymbolRouter::new(table);
let symbol = Symbol::new_for_test(1, 0, 0, &[1, 2, 3]);
let result = router.route(&symbol).expect("route");
assert_eq!(result.endpoint.id, backup.id);
}
#[test]
fn test_symbol_router_object_route_with_only_unhealthy_endpoints_returns_no_healthy() {
let table = Arc::new(RoutingTable::new());
let object_id = ObjectId::new_for_test(77);
let unhealthy =
table.register_endpoint(test_endpoint(1).with_state(EndpointState::Unhealthy));
let entry = RoutingEntry::new(vec![unhealthy], Time::ZERO)
.with_strategy(LoadBalanceStrategy::FirstAvailable);
table.add_route(RouteKey::Object(object_id), entry);
let router = SymbolRouter::new(table);
let symbol = Symbol::new_for_test(77, 0, 0, &[1, 2, 3]);
let result = router.route(&symbol);
assert!(matches!(
result,
Err(RoutingError::NoHealthyEndpoints { object_id: oid }) if oid == object_id
));
}
#[test]
fn test_symbol_router_unhealthy_default_route_returns_no_healthy() {
let table = Arc::new(RoutingTable::new());
let object_id = ObjectId::new_for_test(88);
let unhealthy =
table.register_endpoint(test_endpoint(1).with_state(EndpointState::Unhealthy));
let entry = RoutingEntry::new(vec![unhealthy], Time::ZERO)
.with_strategy(LoadBalanceStrategy::FirstAvailable);
table.add_route(RouteKey::Default, entry);
let router = SymbolRouter::new(table);
let symbol = Symbol::new_for_test(88, 0, 0, &[1, 2, 3]);
let result = router.route(&symbol);
assert!(matches!(
result,
Err(RoutingError::NoHealthyEndpoints { object_id: oid }) if oid == object_id
));
}
#[test]
fn test_symbol_router_without_any_route_still_returns_no_route() {
let table = Arc::new(RoutingTable::new());
let router = SymbolRouter::new(table);
let object_id = ObjectId::new_for_test(99);
let symbol = Symbol::new_for_test(99, 0, 0, &[1, 2, 3]);
let result = router.route(&symbol);
assert!(matches!(
result,
Err(RoutingError::NoRoute { object_id: oid, .. }) if oid == object_id
));
}
#[test]
fn test_symbol_router_local_preference_unicast() {
let table = Arc::new(RoutingTable::new());
let local_region = RegionId::new_for_test(7, 0);
let remote_region = RegionId::new_for_test(8, 0);
let remote = table.register_endpoint(
test_endpoint(1)
.with_region(remote_region)
.with_state(EndpointState::Healthy),
);
let local = table.register_endpoint(
test_endpoint(2)
.with_region(local_region)
.with_state(EndpointState::Healthy),
);
let object_id = ObjectId::new_for_test(42);
let entry = RoutingEntry::new(vec![remote, local.clone()], Time::ZERO)
.with_strategy(LoadBalanceStrategy::FirstAvailable);
table.add_route(RouteKey::Object(object_id), entry);
let router = SymbolRouter::new(table).with_local_preference(local_region);
let symbol = Symbol::new_for_test(42, 0, 0, &[1, 2, 3]);
let result = router.route(&symbol).expect("route with local preference");
assert_eq!(result.endpoint.id, local.id);
assert!(!result.is_fallback);
}
#[test]
fn test_symbol_router_multicast() {
let table = Arc::new(RoutingTable::new());
let e1 = table.register_endpoint(test_endpoint(1));
let e2 = table.register_endpoint(test_endpoint(2));
let e3 = table.register_endpoint(test_endpoint(3));
let entry = RoutingEntry::new(vec![e1, e2, e3], Time::ZERO);
table.add_route(RouteKey::Default, entry);
let router = SymbolRouter::new(table);
let symbol = Symbol::new_for_test(1, 0, 0, &[1, 2, 3]);
let results = router.route_multicast(&symbol, 2);
assert!(results.is_ok());
assert_eq!(results.unwrap().len(), 2);
}
#[test]
fn test_symbol_router_multicast_weighted_round_robin_respects_weights_across_calls() {
let table = Arc::new(RoutingTable::new());
let heavy = table.register_endpoint(test_endpoint(1).with_weight(5));
let medium = table.register_endpoint(test_endpoint(2).with_weight(1));
let light = table.register_endpoint(test_endpoint(3).with_weight(1));
let object_id = ObjectId::new_for_test(77);
let entry = RoutingEntry::new(
vec![heavy.clone(), medium.clone(), light.clone()],
Time::ZERO,
)
.with_strategy(LoadBalanceStrategy::WeightedRoundRobin);
table.add_route(RouteKey::Object(object_id), entry);
let router = SymbolRouter::new(table);
let symbol = Symbol::new_for_test(77, 0, 0, &[7, 7]);
let first: Vec<_> = router
.route_multicast(&symbol, 2)
.expect("first weighted multicast")
.into_iter()
.map(|route| route.endpoint.id)
.collect();
let second: Vec<_> = router
.route_multicast(&symbol, 2)
.expect("second weighted multicast")
.into_iter()
.map(|route| route.endpoint.id)
.collect();
let third: Vec<_> = router
.route_multicast(&symbol, 2)
.expect("third weighted multicast")
.into_iter()
.map(|route| route.endpoint.id)
.collect();
assert_eq!(first, vec![heavy.id, medium.id]);
assert_eq!(second, vec![light.id, heavy.id]);
assert_eq!(third, vec![heavy.id, medium.id]);
}
#[test]
fn test_symbol_router_local_preference_multicast_fills_local_first() {
let table = Arc::new(RoutingTable::new());
let local_region = RegionId::new_for_test(11, 0);
let remote_region = RegionId::new_for_test(12, 0);
let local_a = table.register_endpoint(
test_endpoint(1)
.with_region(local_region)
.with_state(EndpointState::Healthy),
);
let remote = table.register_endpoint(
test_endpoint(2)
.with_region(remote_region)
.with_state(EndpointState::Healthy),
);
let local_b = table.register_endpoint(
test_endpoint(3)
.with_region(local_region)
.with_state(EndpointState::Healthy),
);
let object_id = ObjectId::new_for_test(9);
let entry = RoutingEntry::new(vec![local_a.clone(), remote, local_b.clone()], Time::ZERO)
.with_strategy(LoadBalanceStrategy::RoundRobin);
table.add_route(RouteKey::Object(object_id), entry);
let router = SymbolRouter::new(table).with_local_preference(local_region);
let symbol = Symbol::new_for_test(9, 0, 0, &[9]);
let multicast_routes = router
.route_multicast(&symbol, 2)
.expect("multicast with local preference");
let selected: HashSet<_> = multicast_routes
.into_iter()
.map(|route| route.endpoint.id)
.collect();
let expected: HashSet<_> = [local_a.id, local_b.id].into_iter().collect();
assert_eq!(selected, expected);
}
#[test]
fn test_dispatch_result_quorum() {
let result = DispatchResult {
successes: 3,
failures: 1,
sent_to: smallvec![EndpointId(1), EndpointId(2), EndpointId(3)],
failed_endpoints: SmallVec::new(),
duration: Time::ZERO,
};
assert!(result.quorum_reached(2));
assert!(result.quorum_reached(3));
assert!(!result.quorum_reached(4));
assert!(result.any_succeeded());
assert!(!result.all_succeeded()); }
#[test]
fn dispatch_result_unicast_stays_inline() {
let result = DispatchResult {
successes: 1,
failures: 0,
sent_to: smallvec![EndpointId(7)],
failed_endpoints: SmallVec::new(),
duration: Time::ZERO,
};
assert!(!result.sent_to.spilled());
assert!(!result.failed_endpoints.spilled());
}
#[test]
fn test_endpoint_connections() {
let endpoint = test_endpoint(1);
assert_eq!(endpoint.connection_count(), 0);
endpoint.acquire_connection();
endpoint.acquire_connection();
assert_eq!(endpoint.connection_count(), 2);
endpoint.release_connection();
assert_eq!(endpoint.connection_count(), 1);
}
#[test]
fn test_endpoint_release_connection_saturates() {
let endpoint = test_endpoint(1);
endpoint.release_connection();
assert_eq!(endpoint.connection_count(), 0);
}
#[test]
fn test_routing_table_updates_endpoint_state() {
let table = RoutingTable::new();
let endpoint = table.register_endpoint(test_endpoint(9));
assert_eq!(endpoint.state(), EndpointState::Healthy);
assert!(table.update_endpoint_state(EndpointId(9), EndpointState::Draining));
assert_eq!(endpoint.state(), EndpointState::Draining);
assert!(!table.update_endpoint_state(EndpointId(999), EndpointState::Healthy));
}
#[test]
fn test_routing_table_dispatchable_endpoints_include_degraded_in_id_order() {
let table = RoutingTable::new();
let degraded =
table.register_endpoint(test_endpoint(3).with_state(EndpointState::Degraded));
let healthy = table.register_endpoint(test_endpoint(1).with_state(EndpointState::Healthy));
let _unhealthy =
table.register_endpoint(test_endpoint(2).with_state(EndpointState::Unhealthy));
let ids: Vec<_> = table
.dispatchable_endpoints()
.into_iter()
.map(|endpoint| endpoint.id)
.collect();
assert_eq!(ids, vec![healthy.id, degraded.id]);
}
#[test]
fn test_symbol_dispatcher_broadcast_uses_dispatchable_endpoints_in_id_order() {
let table = Arc::new(RoutingTable::new());
let degraded =
table.register_endpoint(test_endpoint(3).with_state(EndpointState::Degraded));
let healthy_a =
table.register_endpoint(test_endpoint(1).with_state(EndpointState::Healthy));
let healthy_b =
table.register_endpoint(test_endpoint(2).with_state(EndpointState::Healthy));
let router = Arc::new(SymbolRouter::new(table));
let dispatcher = SymbolDispatcher::new(router, DispatchConfig::default());
let cx: Cx = Cx::for_testing();
let result = future::block_on(dispatcher.dispatch_with_strategy(
&cx,
test_authenticated_symbol(7),
DispatchStrategy::Broadcast,
))
.expect("broadcast dispatch should succeed");
let sent_to: Vec<_> = result.sent_to.into_iter().collect();
assert_eq!(sent_to, vec![healthy_a.id, healthy_b.id, degraded.id]);
}
#[test]
fn test_symbol_dispatcher_quorum_uses_lowest_dispatchable_ids_first() {
let table = Arc::new(RoutingTable::new());
let degraded =
table.register_endpoint(test_endpoint(3).with_state(EndpointState::Degraded));
let healthy_a =
table.register_endpoint(test_endpoint(1).with_state(EndpointState::Healthy));
let healthy_b =
table.register_endpoint(test_endpoint(2).with_state(EndpointState::Healthy));
let router = Arc::new(SymbolRouter::new(table));
let dispatcher = SymbolDispatcher::new(router, DispatchConfig::default());
let cx: Cx = Cx::for_testing();
let result = future::block_on(dispatcher.dispatch_with_strategy(
&cx,
test_authenticated_symbol(8),
DispatchStrategy::QuorumCast { required: 2 },
))
.expect("quorum dispatch should succeed");
let sent_to: Vec<_> = result.sent_to.iter().copied().collect();
assert_eq!(sent_to, vec![healthy_a.id, healthy_b.id]);
assert_eq!(result.successes, 2);
assert_eq!(result.failures, 0);
assert!(result.quorum_reached(2));
assert!(!sent_to.contains(°raded.id));
}
#[test]
fn test_symbol_dispatcher_unicast_interrupted_io_without_cancel_stays_send_failure() {
let table = Arc::new(RoutingTable::new());
let endpoint = table.register_endpoint(test_endpoint(41));
table.add_route(
RouteKey::Default,
RoutingEntry::new(vec![endpoint.clone()], Time::ZERO),
);
let router = Arc::new(SymbolRouter::new(table));
let dispatcher = SymbolDispatcher::new(router, DispatchConfig::default());
dispatcher.add_sink(endpoint.id, Box::new(InterruptedSink));
let cx: Cx = Cx::for_testing();
let result = future::block_on(dispatcher.dispatch_with_strategy(
&cx,
test_authenticated_symbol(41),
DispatchStrategy::Unicast,
));
assert!(matches!(
result,
Err(DispatchError::SendFailed {
endpoint: failed_endpoint,
..
}) if failed_endpoint == endpoint.id
));
assert_eq!(endpoint.failures.load(Ordering::Relaxed), 1);
assert!(!cx.is_cancel_requested());
}
#[test]
fn test_symbol_dispatcher_broadcast_mid_send_cancel_returns_cancelled() {
let table = Arc::new(RoutingTable::new());
let endpoint = table.register_endpoint(test_endpoint(52));
let router = Arc::new(SymbolRouter::new(table));
let dispatcher = SymbolDispatcher::new(router, DispatchConfig::default());
let cx: Cx = Cx::for_testing();
dispatcher.add_sink(
endpoint.id,
Box::new(CancellingInterruptedSink {
cancel_cx: cx.clone(),
}),
);
let result = future::block_on(dispatcher.dispatch_with_strategy(
&cx,
test_authenticated_symbol(52),
DispatchStrategy::Broadcast,
));
assert!(matches!(result, Err(DispatchError::Cancelled)));
assert_eq!(endpoint.failures.load(Ordering::Relaxed), 0);
assert!(cx.is_cancel_requested());
}
#[test]
fn test_routing_error_display() {
let oid = ObjectId::new_for_test(42);
let no_route = RoutingError::NoRoute {
object_id: oid,
reason: "test".into(),
};
assert!(no_route.to_string().contains("no route"));
let no_healthy = RoutingError::NoHealthyEndpoints { object_id: oid };
assert!(no_healthy.to_string().contains("healthy"));
}
#[test]
fn test_dispatch_error_display() {
let overloaded = DispatchError::Overloaded;
assert!(overloaded.to_string().contains("overloaded"));
let quorum = DispatchError::QuorumNotReached {
achieved: 2,
required: 3,
};
assert!(quorum.to_string().contains("quorum"));
assert!(quorum.to_string().contains('2'));
assert!(quorum.to_string().contains('3'));
}
#[test]
fn endpoint_id_debug_display() {
let id = EndpointId::new(42);
assert!(format!("{id:?}").contains("42"));
assert_eq!(id.to_string(), "Endpoint(42)");
}
#[test]
fn endpoint_id_clone_copy_eq() {
let id = EndpointId::new(7);
let id2 = id;
assert_eq!(id, id2);
}
#[test]
fn endpoint_id_ord_hash() {
let a = EndpointId::new(1);
let b = EndpointId::new(2);
assert!(a < b);
let mut set = HashSet::new();
set.insert(a);
set.insert(b);
assert_eq!(set.len(), 2);
}
#[test]
fn endpoint_state_debug_clone_copy_eq() {
let s = EndpointState::Healthy;
let s2 = s;
assert_eq!(s, s2);
assert!(format!("{s:?}").contains("Healthy"));
}
#[test]
fn endpoint_state_as_u8_roundtrip() {
let states = [
EndpointState::Healthy,
EndpointState::Degraded,
EndpointState::Unhealthy,
EndpointState::Draining,
EndpointState::Removed,
];
for &s in &states {
assert_eq!(EndpointState::from_u8(s.as_u8()), s);
}
}
#[test]
fn endpoint_state_from_u8_invalid() {
let s = EndpointState::from_u8(255);
assert_eq!(s, EndpointState::Removed);
}
#[test]
fn endpoint_debug() {
let ep = Endpoint::new(EndpointId::new(1), "addr:80");
let dbg = format!("{ep:?}");
assert!(dbg.contains("Endpoint"));
}
#[test]
fn endpoint_with_weight_region() {
let region = RegionId::new_for_test(1, 0);
let ep = Endpoint::new(EndpointId::new(5), "host:80")
.with_weight(200)
.with_region(region);
assert_eq!(ep.weight, 200);
assert_eq!(ep.region, Some(region));
}
#[test]
fn endpoint_with_state_setter() {
let ep = Endpoint::new(EndpointId::new(1), "h:80").with_state(EndpointState::Draining);
assert_eq!(ep.state(), EndpointState::Draining);
ep.set_state(EndpointState::Healthy);
assert_eq!(ep.state(), EndpointState::Healthy);
}
#[test]
fn endpoint_failure_rate_zero() {
let ep = Endpoint::new(EndpointId::new(1), "h:80");
assert!((ep.failure_rate() - 0.0).abs() < f64::EPSILON);
}
#[test]
fn endpoint_connection_guard_drops() {
let ep = Endpoint::new(EndpointId::new(1), "h:80");
{
let _guard = ep.acquire_connection_guard();
assert_eq!(ep.connection_count(), 1);
}
assert_eq!(ep.connection_count(), 0);
}
#[test]
fn load_balance_strategy_debug_clone_copy_eq_default() {
let s = LoadBalanceStrategy::default();
assert_eq!(s, LoadBalanceStrategy::RoundRobin);
let s2 = s;
assert_eq!(s, s2);
assert!(format!("{s:?}").contains("RoundRobin"));
}
#[test]
fn route_key_debug_clone_eq_ord_hash() {
let oid = ObjectId::new_for_test(1);
let k1 = RouteKey::Object(oid);
let k2 = k1.clone();
assert_eq!(k1, k2);
assert!(format!("{k1:?}").contains("Object"));
assert!(k1 <= k2);
let mut set = HashSet::new();
set.insert(k1);
set.insert(RouteKey::Default);
assert_eq!(set.len(), 2);
}
#[test]
fn route_key_constructors() {
let oid = ObjectId::new_for_test(1);
let rid = RegionId::new_for_test(2, 0);
assert_eq!(RouteKey::object(oid), RouteKey::Object(oid));
assert_eq!(RouteKey::region(rid), RouteKey::Region(rid));
}
#[test]
fn dispatch_strategy_debug_clone_copy_eq_default() {
let s = DispatchStrategy::default();
assert_eq!(s, DispatchStrategy::Unicast);
let s2 = s;
assert_eq!(s, s2);
assert!(format!("{s:?}").contains("Unicast"));
}
#[test]
fn dispatch_config_debug_clone_default() {
let cfg = DispatchConfig::default();
let cfg2 = cfg;
assert_eq!(cfg2.max_retries, 3);
assert!(format!("{cfg2:?}").contains("DispatchConfig"));
}
#[test]
fn dispatcher_stats_debug() {
let stats = DispatcherStats {
active_dispatches: 0,
total_dispatched: 100,
total_failures: 5,
};
let dbg = format!("{stats:?}");
assert!(dbg.contains("100"));
}
#[test]
fn routing_error_debug_clone() {
let err = RoutingError::EmptyTable;
let err2 = err;
assert!(format!("{err2:?}").contains("EmptyTable"));
}
#[test]
fn routing_error_display_all_variants() {
let oid = ObjectId::new_for_test(1);
let e1 = RoutingError::NoRoute {
object_id: oid,
reason: "gone".into(),
};
assert!(e1.to_string().contains("no route"));
assert!(e1.to_string().contains("gone"));
let e2 = RoutingError::NoHealthyEndpoints { object_id: oid };
assert!(e2.to_string().contains("healthy"));
let e3 = RoutingError::EmptyTable;
assert!(e3.to_string().contains("empty"));
}
#[test]
fn routing_error_trait() {
let err: Box<dyn std::error::Error> = Box::new(RoutingError::EmptyTable);
assert!(!err.to_string().is_empty());
}
#[test]
fn dispatch_error_debug_clone() {
let err = DispatchError::Timeout;
let err2 = err;
assert!(format!("{err2:?}").contains("Timeout"));
}
#[test]
fn dispatch_error_display_all_variants() {
let e1 = DispatchError::RoutingFailed(RoutingError::EmptyTable);
assert!(e1.to_string().contains("routing failed"));
let e2 = DispatchError::SendFailed {
endpoint: EndpointId::new(3),
reason: "down".into(),
};
assert!(e2.to_string().contains("send"));
let e3 = DispatchError::NoEndpoints;
assert!(e3.to_string().contains("no endpoints"));
let e4 = DispatchError::InsufficientEndpoints {
available: 1,
required: 3,
};
assert!(e4.to_string().contains("insufficient"));
let e5 = DispatchError::Timeout;
assert!(e5.to_string().contains("timeout"));
}
#[test]
fn dispatch_error_from_routing_error() {
let re = RoutingError::EmptyTable;
let de = DispatchError::from(re);
assert!(matches!(de, DispatchError::RoutingFailed(_)));
}
#[test]
fn dispatch_error_trait() {
let err: Box<dyn std::error::Error> = Box::new(DispatchError::Timeout);
assert!(!err.to_string().is_empty());
}
#[test]
fn routing_entry_with_priority() {
let entry = RoutingEntry::new(vec![], Time::ZERO).with_priority(10);
assert_eq!(entry.priority, 10);
}
#[test]
fn routing_entry_select_endpoint_empty() {
let entry = RoutingEntry::new(vec![], Time::ZERO);
assert!(entry.select_endpoint(None).is_none());
}
#[test]
fn load_balancer_debug() {
let lb = LoadBalancer::new(LoadBalanceStrategy::Random);
assert!(format!("{lb:?}").contains("Random"));
}
#[test]
fn routing_table_debug() {
let table = RoutingTable::new();
assert!(format!("{table:?}").contains("RoutingTable"));
}
#[test]
fn routing_table_state_snapshot_scrubbed() {
let table = RoutingTable::new();
let region = RegionId::new_for_test(9, 2);
let object_id = ObjectId::new_for_test(44);
let primary = table.register_endpoint(
test_endpoint(1)
.with_weight(200)
.with_region(region)
.with_state(EndpointState::Healthy),
);
let backup = table.register_endpoint(
test_endpoint(2)
.with_weight(50)
.with_state(EndpointState::Degraded),
);
let draining = table.register_endpoint(
test_endpoint(3)
.with_weight(10)
.with_state(EndpointState::Draining),
);
table.add_route(
RouteKey::Default,
RoutingEntry::new(vec![backup.clone()], Time::ZERO)
.with_priority(90)
.with_strategy(LoadBalanceStrategy::FirstAvailable),
);
table.add_route(
RouteKey::Object(object_id),
RoutingEntry::new(vec![primary, backup], Time::ZERO)
.with_priority(10)
.with_ttl(Time::from_secs(30))
.with_strategy(LoadBalanceStrategy::WeightedRoundRobin),
);
table.add_route(
RouteKey::Region(region),
RoutingEntry::new(vec![draining], Time::ZERO)
.with_priority(40)
.with_strategy(LoadBalanceStrategy::RoundRobin),
);
insta::assert_json_snapshot!(
"routing_table_state_scrubbed",
routing_table_snapshot(&table)
);
}
}