1use crate::cx::Cx;
10use crate::error::{Error, ErrorKind};
11use crate::security::authenticated::AuthenticatedSymbol;
12use crate::sync::Mutex;
13use crate::sync::OwnedMutexGuard;
14use crate::transport::sink::{SymbolSink, SymbolSinkExt};
15use crate::types::symbol::{ObjectId, Symbol};
16use crate::types::{RegionId, Time};
17use parking_lot::RwLock;
18use smallvec::{SmallVec, smallvec};
19use std::collections::HashMap;
20use std::sync::Arc;
21use std::sync::atomic::{AtomicU8, AtomicU32, AtomicU64, Ordering};
22
/// Map from an endpoint to the sink that delivers symbols to it.
///
/// Each sink sits behind its own `Arc<Mutex<..>>`, so a dispatcher can clone the handle out of
/// the map, release the map's lock, and then lock only that one sink while sending.
type EndpointSinkMap = HashMap<EndpointId, Arc<Mutex<Box<dyn SymbolSink>>>>;
24
/// Opaque numeric identifier of a routing endpoint.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct EndpointId(pub u64);

impl EndpointId {
    /// Wraps a raw `u64` as an [`EndpointId`].
    #[must_use]
    pub const fn new(id: u64) -> Self {
        EndpointId(id)
    }
}

impl std::fmt::Display for EndpointId {
    /// Renders the id as `Endpoint(<raw id>)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let EndpointId(raw) = *self;
        write!(f, "Endpoint({})", raw)
    }
}
46
/// Health/lifecycle state of an endpoint; stored inside `Endpoint` as its `u8` discriminant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub enum EndpointState {
    /// Fully operational; eligible for traffic.
    Healthy,

    /// Impaired but still eligible for traffic (see [`EndpointState::can_receive`]).
    Degraded,

    /// Not eligible for traffic.
    Unhealthy,

    /// Not eligible for new traffic, but still counted as available.
    Draining,

    /// The only state reported as not available.
    Removed,
}

impl EndpointState {
    /// Raw discriminant used for atomic storage.
    const fn as_u8(self) -> u8 {
        self as u8
    }

    /// Inverse of [`EndpointState::as_u8`]; any unrecognised byte decodes as `Removed`.
    fn from_u8(value: u8) -> Self {
        const HEALTHY: u8 = EndpointState::Healthy as u8;
        const DEGRADED: u8 = EndpointState::Degraded as u8;
        const UNHEALTHY: u8 = EndpointState::Unhealthy as u8;
        const DRAINING: u8 = EndpointState::Draining as u8;

        match value {
            HEALTHY => Self::Healthy,
            DEGRADED => Self::Degraded,
            UNHEALTHY => Self::Unhealthy,
            DRAINING => Self::Draining,
            _ => Self::Removed,
        }
    }

    /// Whether load balancers may pick an endpoint in this state (`Healthy` or `Degraded`).
    #[must_use]
    pub const fn can_receive(&self) -> bool {
        matches!(*self, EndpointState::Healthy | EndpointState::Degraded)
    }

    /// Whether the endpoint still exists, i.e. is in any state other than `Removed`.
    #[must_use]
    pub const fn is_available(&self) -> bool {
        !matches!(*self, EndpointState::Removed)
    }
}
94
95#[derive(Debug)]
97pub struct Endpoint {
98 pub id: EndpointId,
100
101 pub address: String,
103
104 state: AtomicU8,
106
107 pub weight: u32,
109
110 pub region: Option<RegionId>,
112
113 pub active_connections: AtomicU32,
115
116 pub symbols_sent: AtomicU64,
118
119 pub failures: AtomicU64,
121
122 pub last_success: AtomicU64,
124
125 pub last_failure: AtomicU64,
127
128 pub metadata: HashMap<String, String>,
130}
131
132impl Endpoint {
133 pub fn new(id: EndpointId, address: impl Into<String>) -> Self {
135 Self {
136 id,
137 address: address.into(),
138 state: AtomicU8::new(EndpointState::Healthy.as_u8()),
139 weight: 100,
140 region: None,
141 active_connections: AtomicU32::new(0),
142 symbols_sent: AtomicU64::new(0),
143 failures: AtomicU64::new(0),
144 last_success: AtomicU64::new(0),
145 last_failure: AtomicU64::new(0),
146 metadata: HashMap::new(),
147 }
148 }
149
150 #[must_use]
152 pub fn with_weight(mut self, weight: u32) -> Self {
153 self.weight = weight;
154 self
155 }
156
157 #[must_use]
159 pub fn with_region(mut self, region: RegionId) -> Self {
160 self.region = Some(region);
161 self
162 }
163
164 #[must_use]
166 pub fn with_state(self, state: EndpointState) -> Self {
167 self.state.store(state.as_u8(), Ordering::Relaxed);
168 self
169 }
170
171 #[must_use]
173 pub fn state(&self) -> EndpointState {
174 EndpointState::from_u8(self.state.load(Ordering::Relaxed))
175 }
176
177 pub fn set_state(&self, state: EndpointState) {
179 self.state.store(state.as_u8(), Ordering::Relaxed);
180 }
181
182 pub fn record_success(&self, now: Time) {
184 self.symbols_sent.fetch_add(1, Ordering::Relaxed);
185 self.last_success.store(now.as_nanos(), Ordering::Relaxed);
186 }
187
188 pub fn record_failure(&self, now: Time) {
190 self.failures.fetch_add(1, Ordering::Relaxed);
191 self.last_failure.store(now.as_nanos(), Ordering::Relaxed);
192 }
193
194 pub fn acquire_connection(&self) {
196 self.active_connections.fetch_add(1, Ordering::Relaxed);
197 }
198
199 pub fn release_connection(&self) {
201 let _ =
202 self.active_connections
203 .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
204 Some(current.saturating_sub(1))
205 });
206 }
207
208 #[must_use]
210 pub fn connection_count(&self) -> u32 {
211 self.active_connections.load(Ordering::Relaxed)
212 }
213
214 #[must_use]
216 #[allow(clippy::cast_precision_loss)]
217 pub fn failure_rate(&self) -> f64 {
218 let sent = self.symbols_sent.load(Ordering::Relaxed);
219 let failures = self.failures.load(Ordering::Relaxed);
220 let total = sent + failures;
221 if total == 0 {
222 0.0
223 } else {
224 failures as f64 / total as f64
225 }
226 }
227
228 pub fn acquire_connection_guard(&self) -> ConnectionGuard<'_> {
232 self.acquire_connection();
233 ConnectionGuard { endpoint: self }
234 }
235}
236
237pub struct ConnectionGuard<'a> {
239 endpoint: &'a Endpoint,
240}
241
242impl Drop for ConnectionGuard<'_> {
243 fn drop(&mut self) {
244 self.endpoint.release_connection();
245 }
246}
247
/// Policy a `LoadBalancer` uses to pick among endpoints whose state `can_receive()`.
///
/// The descriptions below are for single selection (`LoadBalancer::select`). Multi-selection
/// (`select_n`) specially handles `RoundRobin`, `Random`, `LeastConnections` and
/// `WeightedLeastConnections`. Every other strategy returns the first `n` eligible endpoints.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum LoadBalanceStrategy {
    /// Cycle through eligible endpoints with a shared counter (the default).
    #[default]
    RoundRobin,

    /// Round-robin where each endpoint's share is proportional to its `weight`.
    WeightedRoundRobin,

    /// Pick the endpoint with the fewest active connections.
    LeastConnections,

    /// Pick the endpoint with the lowest `connections / max(weight, 1)` ratio.
    WeightedLeastConnections,

    /// Pseudo-random pick driven by an internal LCG (not cryptographically secure).
    Random,

    /// Pick by object id modulo the number of eligible endpoints; round-robin without an id.
    HashBased,

    /// Always pick the first eligible endpoint in slice order.
    FirstAvailable,
}
277
278#[derive(Debug)]
280pub struct LoadBalancer {
281 strategy: LoadBalanceStrategy,
283
284 rr_counter: AtomicU64,
286
287 random_seed: AtomicU64,
289}
290
291impl LoadBalancer {
292 const LCG_MULTIPLIER: u64 = 6_364_136_223_846_793_005;
293 const LCG_INCREMENT: u64 = 1;
294 const RANDOM_FLOYD_SMALL_N_MAX: usize = 8;
295
296 #[inline]
297 fn next_lcg(seed: u64) -> u64 {
298 seed.wrapping_mul(Self::LCG_MULTIPLIER)
299 .wrapping_add(Self::LCG_INCREMENT)
300 }
301
302 #[inline]
303 fn compare_weighted_load(a: &Endpoint, b: &Endpoint) -> std::cmp::Ordering {
304 let a_conn = u128::from(a.connection_count());
305 let b_conn = u128::from(b.connection_count());
306 let a_weight = u128::from(a.weight.max(1));
307 let b_weight = u128::from(b.weight.max(1));
308 (a_conn * b_weight).cmp(&(b_conn * a_weight))
309 }
310
311 #[must_use]
313 pub fn new(strategy: LoadBalanceStrategy) -> Self {
314 Self {
315 strategy,
316 rr_counter: AtomicU64::new(0),
317 random_seed: AtomicU64::new(0),
318 }
319 }
320
321 pub fn select<'a>(
323 &self,
324 endpoints: &'a [Arc<Endpoint>],
325 object_id: Option<ObjectId>,
326 ) -> Option<&'a Arc<Endpoint>> {
327 if matches!(self.strategy, LoadBalanceStrategy::Random) {
328 return self.select_random_single_without_materializing(endpoints);
329 }
330
331 let available: SmallVec<[&Arc<Endpoint>; 8]> = endpoints
332 .iter()
333 .filter(|e| e.state().can_receive())
334 .collect();
335
336 if available.is_empty() {
337 return None;
338 }
339
340 match self.strategy {
341 LoadBalanceStrategy::RoundRobin => {
342 let idx =
343 (self.rr_counter.fetch_add(1, Ordering::Relaxed) as usize) % available.len();
344 Some(available[idx])
345 }
346
347 LoadBalanceStrategy::WeightedRoundRobin => {
348 let total_weight: u64 = available.iter().map(|e| u64::from(e.weight)).sum();
349 if total_weight == 0 {
350 return Some(available[0]);
351 }
352
353 let counter = self.rr_counter.fetch_add(1, Ordering::Relaxed);
354 let target = counter % total_weight;
355
356 let mut cumulative = 0u64;
357 for endpoint in &available {
358 cumulative += u64::from(endpoint.weight);
359 if target < cumulative {
360 return Some(endpoint);
361 }
362 }
363 Some(*available.last().unwrap())
364 }
365
366 LoadBalanceStrategy::LeastConnections => {
367 available.into_iter().min_by_key(|e| e.connection_count())
368 }
369
370 LoadBalanceStrategy::WeightedLeastConnections => available
371 .into_iter()
372 .min_by(|a, b| Self::compare_weighted_load(a, b)),
373
374 LoadBalanceStrategy::Random => {
375 self.select_random_single_without_materializing(endpoints)
378 }
379
380 LoadBalanceStrategy::HashBased => object_id.map_or_else(
381 || {
382 let idx = (self.rr_counter.fetch_add(1, Ordering::Relaxed) as usize)
384 % available.len();
385 Some(available[idx])
386 },
387 |oid| {
388 let hash = oid.as_u128() as usize;
389 Some(available[hash % available.len()])
390 },
391 ),
392 LoadBalanceStrategy::FirstAvailable => Some(available[0]),
393 }
394 }
395
396 pub fn select_n<'a>(
398 &self,
399 endpoints: &'a [Arc<Endpoint>],
400 n: usize,
401 _object_id: Option<ObjectId>,
402 ) -> Vec<&'a Arc<Endpoint>> {
403 if n == 0 {
404 return Vec::new();
405 }
406
407 if matches!(self.strategy, LoadBalanceStrategy::Random) && n == 1 {
408 return self
409 .select_random_single_without_materializing(endpoints)
410 .into_iter()
411 .collect();
412 }
413
414 if matches!(self.strategy, LoadBalanceStrategy::Random)
415 && n <= Self::RANDOM_FLOYD_SMALL_N_MAX
416 {
417 if let Some(selected) = self.select_n_random_small_without_materializing(endpoints, n) {
418 return selected;
419 }
420 }
421
422 let mut available: Vec<&Arc<Endpoint>> = Vec::with_capacity(endpoints.len());
425 for endpoint in endpoints {
426 if endpoint.state().can_receive() {
427 available.push(endpoint);
428 }
429 }
430
431 if available.is_empty() {
432 return Vec::new();
433 }
434
435 if n >= available.len() {
436 return available;
437 }
438
439 match self.strategy {
440 LoadBalanceStrategy::RoundRobin => {
441 let start = self.rr_counter.fetch_add(n as u64, Ordering::Relaxed) as usize;
442 let len = available.len();
443 (0..n).map(|i| available[(start + i) % len]).collect()
444 }
445
446 LoadBalanceStrategy::Random => {
447 let mut seed = self.random_seed.fetch_add(n as u64, Ordering::Relaxed);
450 let len = available.len();
451 let count = n.min(len);
452
453 for i in 0..count {
454 seed = Self::next_lcg(seed);
456 let range = len - i;
458 let offset = (seed as usize) % range;
459 let swap_idx = i + offset;
460 available.swap(i, swap_idx);
461 }
462 available.truncate(count);
463 available
464 }
465 LoadBalanceStrategy::LeastConnections => {
466 let mut candidates = available;
467 candidates.sort_by_key(|e| e.connection_count());
470 candidates.truncate(n);
471 candidates
472 }
473 LoadBalanceStrategy::WeightedLeastConnections => {
474 let mut candidates = available;
475 candidates.sort_by(|a, b| Self::compare_weighted_load(a, b));
476 candidates.truncate(n);
477 candidates
478 }
479
480 _ => available.into_iter().take(n).collect(),
482 }
483 }
484
485 fn select_random_single_without_materializing<'a>(
491 &self,
492 endpoints: &'a [Arc<Endpoint>],
493 ) -> Option<&'a Arc<Endpoint>> {
494 let mut seed = self.random_seed.fetch_add(1, Ordering::Relaxed);
495 let mut selected: Option<&Arc<Endpoint>> = None;
496 let mut healthy_seen = 0usize;
497
498 for endpoint in endpoints {
499 if !endpoint.state().can_receive() {
500 continue;
501 }
502
503 healthy_seen += 1;
504 seed = Self::next_lcg(seed);
506 if (seed as usize).is_multiple_of(healthy_seen) {
507 selected = Some(endpoint);
508 }
509 }
510
511 selected
512 }
513
514 fn select_n_random_small_without_materializing<'a>(
523 &self,
524 endpoints: &'a [Arc<Endpoint>],
525 n: usize,
526 ) -> Option<Vec<&'a Arc<Endpoint>>> {
527 if n == 0 {
528 return Some(Vec::new());
529 }
530 if endpoints.is_empty() {
531 return None;
532 }
533
534 let mut healthy_count = 0usize;
535 for endpoint in endpoints {
536 if endpoint.state().can_receive() {
537 healthy_count += 1;
538 }
539 }
540 if healthy_count == 0 {
541 return Some(Vec::new());
542 }
543 if healthy_count <= n {
544 let mut all_healthy = Vec::with_capacity(healthy_count);
545 for endpoint in endpoints {
546 if endpoint.state().can_receive() {
547 all_healthy.push(endpoint);
548 }
549 }
550 return Some(all_healthy);
551 }
552
553 let mut seed = self.random_seed.fetch_add(n as u64, Ordering::Relaxed);
554 let mut selected_indices = SmallVec::<[usize; Self::RANDOM_FLOYD_SMALL_N_MAX]>::new();
555 selected_indices.reserve_exact(n);
556
557 for j in (healthy_count - n)..healthy_count {
558 seed = Self::next_lcg(seed);
559 let candidate = (seed as usize) % (j + 1);
560 if selected_indices.contains(&candidate) {
561 selected_indices.push(j);
562 } else {
563 selected_indices.push(candidate);
564 }
565 }
566
567 for i in 0..n {
568 seed = Self::next_lcg(seed);
569 let swap = i + ((seed as usize) % (n - i));
570 selected_indices.swap(i, swap);
571 }
572
573 let mut selected =
574 SmallVec::<[Option<&Arc<Endpoint>>; Self::RANDOM_FLOYD_SMALL_N_MAX]>::new();
575 selected.resize(n, None);
576
577 let mut rank_to_output_pos =
578 SmallVec::<[(usize, usize); Self::RANDOM_FLOYD_SMALL_N_MAX]>::with_capacity(n);
579 for (output_pos, &rank) in selected_indices.iter().enumerate() {
580 rank_to_output_pos.push((rank, output_pos));
581 }
582 rank_to_output_pos.sort_unstable_by_key(|&(rank, _)| rank);
583
584 let mut healthy_rank = 0usize;
585 let mut rank_cursor = 0usize;
586 let mut next_target_rank = rank_to_output_pos.first().map(|&(rank, _)| rank);
587
588 for endpoint in endpoints {
589 if !endpoint.state().can_receive() {
590 continue;
591 }
592 while next_target_rank == Some(healthy_rank) {
593 let output_pos = rank_to_output_pos[rank_cursor].1;
594 selected[output_pos] = Some(endpoint);
595 rank_cursor += 1;
596 if rank_cursor == n {
597 break;
598 }
599 next_target_rank = Some(rank_to_output_pos[rank_cursor].0);
600 }
601 if rank_cursor == n {
602 break;
603 }
604 healthy_rank += 1;
605 }
606
607 if rank_cursor != n {
608 return None;
609 }
610
611 Some(selected.into_iter().flatten().collect())
612 }
613}
614
615#[derive(Debug, Clone)]
621pub struct RoutingEntry {
622 pub endpoints: Vec<Arc<Endpoint>>,
624
625 pub load_balancer: Arc<LoadBalancer>,
627
628 pub priority: u32,
630
631 pub ttl: Option<Time>,
633
634 pub created_at: Time,
636}
637
638impl RoutingEntry {
639 #[must_use]
641 pub fn new(endpoints: Vec<Arc<Endpoint>>, created_at: Time) -> Self {
642 Self {
643 endpoints,
644 load_balancer: Arc::new(LoadBalancer::new(LoadBalanceStrategy::RoundRobin)),
645 priority: 100,
646 ttl: None,
647 created_at,
648 }
649 }
650
651 #[must_use]
653 pub fn with_strategy(mut self, strategy: LoadBalanceStrategy) -> Self {
654 self.load_balancer = Arc::new(LoadBalancer::new(strategy));
655 self
656 }
657
658 #[must_use]
660 pub fn with_priority(mut self, priority: u32) -> Self {
661 self.priority = priority;
662 self
663 }
664
665 #[must_use]
667 pub fn with_ttl(mut self, ttl: Time) -> Self {
668 self.ttl = Some(ttl);
669 self
670 }
671
672 #[must_use]
674 pub fn is_expired(&self, now: Time) -> bool {
675 self.ttl.is_some_and(|ttl| {
676 let expiry = self.created_at.saturating_add_nanos(ttl.as_nanos());
677 now > expiry
678 })
679 }
680
681 #[must_use]
683 pub fn select_endpoint(&self, object_id: Option<ObjectId>) -> Option<Arc<Endpoint>> {
684 self.load_balancer
685 .select(&self.endpoints, object_id)
686 .cloned()
687 }
688
689 #[must_use]
691 pub fn select_endpoints(&self, n: usize, object_id: Option<ObjectId>) -> Vec<Arc<Endpoint>> {
692 self.load_balancer
693 .select_n(&self.endpoints, n, object_id)
694 .into_iter()
695 .cloned()
696 .collect()
697 }
698}
699
700#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
702pub enum RouteKey {
703 Object(ObjectId),
705
706 Region(RegionId),
708
709 ObjectAndRegion(ObjectId, RegionId),
711
712 Default,
714}
715
716impl RouteKey {
717 #[must_use]
719 pub fn object(oid: ObjectId) -> Self {
720 Self::Object(oid)
721 }
722
723 #[must_use]
725 pub fn region(rid: RegionId) -> Self {
726 Self::Region(rid)
727 }
728}
729
730#[derive(Debug)]
732pub struct RoutingTable {
733 routes: RwLock<HashMap<RouteKey, RoutingEntry>>,
735
736 default_route: RwLock<Option<RoutingEntry>>,
738
739 endpoints: RwLock<HashMap<EndpointId, Arc<Endpoint>>>,
741}
742
743impl RoutingTable {
744 #[must_use]
746 pub fn new() -> Self {
747 Self {
748 routes: RwLock::new(HashMap::new()),
749 default_route: RwLock::new(None),
750 endpoints: RwLock::new(HashMap::new()),
751 }
752 }
753
754 pub fn register_endpoint(&self, endpoint: Endpoint) -> Arc<Endpoint> {
756 let id = endpoint.id;
757 let arc = Arc::new(endpoint);
758 self.endpoints.write().insert(id, arc.clone());
759 arc
760 }
761
762 #[must_use]
764 pub fn get_endpoint(&self, id: EndpointId) -> Option<Arc<Endpoint>> {
765 self.endpoints.read().get(&id).cloned()
766 }
767
768 pub fn update_endpoint_state(&self, id: EndpointId, state: EndpointState) -> bool {
770 self.endpoints.read().get(&id).is_some_and(|endpoint| {
771 endpoint.set_state(state);
772 true
773 })
774 }
775
776 pub fn add_route(&self, key: RouteKey, entry: RoutingEntry) {
778 if key == RouteKey::Default {
779 *self.default_route.write() = Some(entry);
780 } else {
781 self.routes.write().insert(key, entry);
782 }
783 }
784
785 pub fn remove_route(&self, key: &RouteKey) -> bool {
787 if *key == RouteKey::Default {
788 let mut default = self.default_route.write();
789 let had_route = default.is_some();
790 *default = None;
791 had_route
792 } else {
793 self.routes.write().remove(key).is_some()
794 }
795 }
796
797 #[must_use]
799 pub fn lookup(&self, key: &RouteKey) -> Option<RoutingEntry> {
800 if let Some(entry) = self.routes.read().get(key) {
802 return Some(entry.clone());
803 }
804
805 if let RouteKey::ObjectAndRegion(oid, rid) = key {
807 if let Some(entry) = self.routes.read().get(&RouteKey::Object(*oid)) {
809 return Some(entry.clone());
810 }
811 if let Some(entry) = self.routes.read().get(&RouteKey::Region(*rid)) {
813 return Some(entry.clone());
814 }
815 }
816
817 self.default_route.read().clone()
819 }
820
821 #[must_use]
826 pub fn lookup_without_default(&self, key: &RouteKey) -> Option<RoutingEntry> {
827 if let Some(entry) = self.routes.read().get(key) {
828 return Some(entry.clone());
829 }
830
831 if let RouteKey::ObjectAndRegion(oid, rid) = key {
832 if let Some(entry) = self.routes.read().get(&RouteKey::Object(*oid)) {
833 return Some(entry.clone());
834 }
835 if let Some(entry) = self.routes.read().get(&RouteKey::Region(*rid)) {
836 return Some(entry.clone());
837 }
838 }
839
840 None
841 }
842
843 pub fn prune_expired(&self, now: Time) -> usize {
845 let mut routes = self.routes.write();
846 let before = routes.len();
847 routes.retain(|_, entry| !entry.is_expired(now));
848 before - routes.len()
849 }
850
851 #[must_use]
853 pub fn healthy_endpoints(&self) -> Vec<Arc<Endpoint>> {
854 self.endpoints
855 .read()
856 .values()
857 .filter(|e| e.state() == EndpointState::Healthy)
858 .cloned()
859 .collect()
860 }
861
862 #[must_use]
864 pub fn route_count(&self) -> usize {
865 let routes = self.routes.read().len();
866 let default = usize::from(self.default_route.read().is_some());
867 routes + default
868 }
869}
870
871impl Default for RoutingTable {
872 fn default() -> Self {
873 Self::new()
874 }
875}
876
/// Outcome of routing a symbol: the chosen endpoint and how it was found.
#[derive(Debug, Clone)]
pub struct RouteResult {
    /// Endpoint the symbol should be sent to.
    pub endpoint: Arc<Endpoint>,

    /// Key whose route produced `endpoint` (`RouteKey::Default` when the fallback was used).
    pub matched_key: RouteKey,

    /// `true` when the default route was used instead of an object route.
    pub is_fallback: bool,
}
893
894#[derive(Debug)]
896pub struct SymbolRouter {
897 table: Arc<RoutingTable>,
899
900 allow_fallback: bool,
902
903 prefer_local: bool,
905
906 local_region: Option<RegionId>,
908}
909
910impl SymbolRouter {
911 pub fn new(table: Arc<RoutingTable>) -> Self {
913 Self {
914 table,
915 allow_fallback: true,
916 prefer_local: false,
917 local_region: None,
918 }
919 }
920
921 #[must_use]
923 pub fn without_fallback(mut self) -> Self {
924 self.allow_fallback = false;
925 self
926 }
927
928 #[must_use]
930 pub fn with_local_preference(mut self, region: RegionId) -> Self {
931 self.prefer_local = true;
932 self.local_region = Some(region);
933 self
934 }
935
936 pub fn route(&self, symbol: &Symbol) -> Result<RouteResult, RoutingError> {
938 let object_id = symbol.object_id();
939 let primary_key = RouteKey::Object(object_id);
940
941 if let Some(entry) = self.table.lookup_without_default(&primary_key) {
942 if let Some(endpoint) = entry.select_endpoint(Some(object_id)) {
943 if self.prefer_local {
945 if let Some(local) = self.local_region {
946 if endpoint.region == Some(local) {
947 }
949 }
950 }
951
952 return Ok(RouteResult {
953 endpoint,
954 matched_key: primary_key,
955 is_fallback: false,
956 });
957 }
958 }
959
960 if self.allow_fallback {
961 let fallback_key = RouteKey::Default;
962 if let Some(entry) = self.table.lookup(&fallback_key) {
963 if let Some(endpoint) = entry.select_endpoint(Some(object_id)) {
964 return Ok(RouteResult {
965 endpoint,
966 matched_key: fallback_key,
967 is_fallback: true,
968 });
969 }
970 }
971 }
972
973 Err(RoutingError::NoRoute {
974 object_id,
975 reason: "No matching route and no default route configured".into(),
976 })
977 }
978
979 pub fn route_multicast(
981 &self,
982 symbol: &Symbol,
983 count: usize,
984 ) -> Result<Vec<RouteResult>, RoutingError> {
985 let object_id = symbol.object_id();
986
987 let key = RouteKey::Object(object_id);
988 let (entry, matched_key, is_fallback) =
989 if let Some(entry) = self.table.lookup_without_default(&key) {
990 (entry, key, false)
991 } else if self.allow_fallback {
992 let fallback_key = RouteKey::Default;
993 let fallback =
994 self.table
995 .lookup(&fallback_key)
996 .ok_or_else(|| RoutingError::NoRoute {
997 object_id,
998 reason: "No route for multicast".into(),
999 })?;
1000 (fallback, fallback_key, true)
1001 } else {
1002 return Err(RoutingError::NoRoute {
1003 object_id,
1004 reason: "No route for multicast".into(),
1005 });
1006 };
1007
1008 let endpoints = entry.select_endpoints(count, Some(object_id));
1010
1011 if endpoints.is_empty() {
1012 return Err(RoutingError::NoHealthyEndpoints { object_id });
1013 }
1014
1015 let results: Vec<_> = endpoints
1016 .into_iter()
1017 .map(|endpoint| RouteResult {
1018 endpoint,
1019 matched_key: matched_key.clone(),
1020 is_fallback,
1021 })
1022 .collect();
1023
1024 Ok(results)
1025 }
1026
1027 #[must_use]
1029 pub fn table(&self) -> &Arc<RoutingTable> {
1030 &self.table
1031 }
1032}
1033
/// How a `SymbolDispatcher` fans a symbol out to endpoints.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum DispatchStrategy {
    /// Send to the single endpoint chosen by the router (the default).
    #[default]
    Unicast,

    /// Send to up to `count` endpoints chosen by the router's multicast routing.
    Multicast {
        /// Maximum number of endpoints to target; 0 sends nothing.
        count: usize,
    },

    /// Send to every registered endpoint whose state is exactly `Healthy`.
    Broadcast,

    /// Send to `Healthy` endpoints one at a time until `required` deliveries succeed.
    QuorumCast {
        /// Number of successful deliveries needed.
        required: usize,
    },
}
1060
1061#[derive(Debug)]
1063pub struct DispatchResult {
1064 pub successes: usize,
1066
1067 pub failures: usize,
1069
1070 pub sent_to: SmallVec<[EndpointId; 4]>,
1072
1073 pub failed_endpoints: SmallVec<[(EndpointId, DispatchError); 4]>,
1075
1076 pub duration: Time,
1078}
1079
1080impl DispatchResult {
1081 #[must_use]
1083 pub fn all_succeeded(&self) -> bool {
1084 self.failures == 0 && self.successes > 0
1085 }
1086
1087 #[must_use]
1089 pub fn any_succeeded(&self) -> bool {
1090 self.successes > 0
1091 }
1092
1093 #[must_use]
1095 pub fn quorum_reached(&self, required: usize) -> bool {
1096 self.successes >= required
1097 }
1098}
1099
1100#[derive(Debug, Clone)]
1106pub struct DispatchConfig {
1107 pub default_strategy: DispatchStrategy,
1109
1110 pub timeout: Time,
1112
1113 pub max_retries: u32,
1115
1116 pub retry_delay: Time,
1118
1119 pub fail_fast: bool,
1121
1122 pub max_concurrent: u32,
1124}
1125
1126impl Default for DispatchConfig {
1127 fn default() -> Self {
1128 Self {
1129 default_strategy: DispatchStrategy::Unicast,
1130 timeout: Time::from_secs(5),
1131 max_retries: 3,
1132 retry_delay: Time::from_millis(100),
1133 fail_fast: false,
1134 max_concurrent: 100,
1135 }
1136 }
1137}
1138
1139pub struct SymbolDispatcher {
1141 router: Arc<SymbolRouter>,
1143
1144 config: DispatchConfig,
1146
1147 active_dispatches: AtomicU32,
1149
1150 total_dispatched: AtomicU64,
1152
1153 total_failures: AtomicU64,
1155
1156 sinks: RwLock<EndpointSinkMap>,
1158}
1159
1160impl std::fmt::Debug for SymbolDispatcher {
1161 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1162 f.debug_struct("SymbolDispatcher")
1163 .field("router", &self.router)
1164 .field("config", &self.config)
1165 .field("active_dispatches", &self.active_dispatches)
1166 .field("total_dispatched", &self.total_dispatched)
1167 .field("total_failures", &self.total_failures)
1168 .field(
1169 "sinks",
1170 &format_args!("<{} sinks>", self.sinks.read().len()),
1171 )
1172 .finish()
1173 }
1174}
1175
1176struct DispatchGuard<'a> {
1178 dispatcher: &'a SymbolDispatcher,
1179}
1180
1181impl Drop for DispatchGuard<'_> {
1182 fn drop(&mut self) {
1183 self.dispatcher
1184 .active_dispatches
1185 .fetch_sub(1, Ordering::Release);
1186 }
1187}
1188
1189impl SymbolDispatcher {
1190 #[must_use]
1192 pub fn new(router: Arc<SymbolRouter>, config: DispatchConfig) -> Self {
1193 Self {
1194 router,
1195 config,
1196 active_dispatches: AtomicU32::new(0),
1197 total_dispatched: AtomicU64::new(0),
1198 total_failures: AtomicU64::new(0),
1199 sinks: RwLock::new(HashMap::new()),
1200 }
1201 }
1202
1203 pub fn add_sink(&self, endpoint: EndpointId, sink: Box<dyn SymbolSink>) {
1205 self.sinks
1206 .write()
1207 .insert(endpoint, Arc::new(Mutex::new(sink)));
1208 }
1209
1210 pub async fn dispatch(
1212 &self,
1213 cx: &Cx,
1214 symbol: AuthenticatedSymbol,
1215 ) -> Result<DispatchResult, DispatchError> {
1216 self.dispatch_with_strategy(cx, symbol, self.config.default_strategy)
1217 .await
1218 }
1219
1220 pub async fn dispatch_with_strategy(
1222 &self,
1223 cx: &Cx,
1224 symbol: AuthenticatedSymbol,
1225 strategy: DispatchStrategy,
1226 ) -> Result<DispatchResult, DispatchError> {
1227 let active = self.active_dispatches.fetch_add(1, Ordering::AcqRel);
1229 if active >= self.config.max_concurrent {
1230 self.active_dispatches.fetch_sub(1, Ordering::Release);
1231 return Err(DispatchError::Overloaded);
1232 }
1233
1234 let _guard = DispatchGuard { dispatcher: self };
1236
1237 let result = match strategy {
1238 DispatchStrategy::Unicast => self.dispatch_unicast(cx, symbol).await,
1239 DispatchStrategy::Multicast { count } => {
1240 self.dispatch_multicast(cx, &symbol, count).await
1241 }
1242 DispatchStrategy::Broadcast => self.dispatch_broadcast(cx, &symbol).await,
1243 DispatchStrategy::QuorumCast { required } => {
1244 self.dispatch_quorum(cx, &symbol, required).await
1245 }
1246 };
1247
1248 match &result {
1252 Ok(r) => {
1253 self.total_dispatched
1254 .fetch_add(r.successes as u64, Ordering::Relaxed);
1255 self.total_failures
1256 .fetch_add(r.failures as u64, Ordering::Relaxed);
1257 }
1258 Err(_) => {
1259 self.total_failures.fetch_add(1, Ordering::Relaxed);
1260 }
1261 }
1262
1263 result
1264 }
1265
1266 #[allow(clippy::unused_async)]
1268 async fn dispatch_unicast(
1269 &self,
1270 cx: &Cx,
1271 symbol: AuthenticatedSymbol,
1272 ) -> Result<DispatchResult, DispatchError> {
1273 let route = self.router.route(symbol.symbol())?;
1274
1275 let sink = {
1277 let sinks = self.sinks.read();
1278 sinks.get(&route.endpoint.id).cloned()
1279 };
1280
1281 let _guard = route.endpoint.acquire_connection_guard();
1282
1283 let result = if let Some(sink) = sink {
1284 let send_result = match OwnedMutexGuard::lock(sink, cx).await {
1285 Ok(mut guard) => guard
1286 .send(symbol)
1287 .await
1288 .map_err(|_| DispatchError::SendFailed {
1289 endpoint: route.endpoint.id,
1290 reason: "Send failed".into(),
1291 }),
1292 Err(_) => Err(DispatchError::Timeout),
1293 };
1294
1295 match send_result {
1296 Ok(()) => {
1297 route.endpoint.record_success(Time::ZERO);
1298 Ok(DispatchResult {
1299 successes: 1,
1300 failures: 0,
1301 sent_to: smallvec![route.endpoint.id],
1302 failed_endpoints: SmallVec::new(),
1303 duration: Time::ZERO,
1304 })
1305 }
1306 Err(err) => {
1307 route.endpoint.record_failure(Time::ZERO);
1308 Err(err)
1309 }
1310 }
1311 } else {
1312 route.endpoint.record_success(Time::ZERO);
1314 Ok(DispatchResult {
1315 successes: 1,
1316 failures: 0,
1317 sent_to: smallvec![route.endpoint.id],
1318 failed_endpoints: SmallVec::new(),
1319 duration: Time::ZERO,
1320 })
1321 };
1322
1323 result
1325 }
1326
1327 #[allow(clippy::unused_async)]
1329 async fn dispatch_multicast(
1330 &self,
1331 cx: &Cx,
1332 symbol: &AuthenticatedSymbol,
1333 count: usize,
1334 ) -> Result<DispatchResult, DispatchError> {
1335 if count == 0 {
1336 return Ok(DispatchResult {
1337 successes: 0,
1338 failures: 0,
1339 sent_to: SmallVec::new(),
1340 failed_endpoints: SmallVec::new(),
1341 duration: Time::ZERO,
1342 });
1343 }
1344
1345 let routes = match self.router.route_multicast(symbol.symbol(), count) {
1347 Ok(routes) => routes,
1348 Err(RoutingError::NoHealthyEndpoints { object_id }) => {
1349 return Err(DispatchError::RoutingFailed(
1350 RoutingError::NoHealthyEndpoints { object_id },
1351 ));
1352 }
1353 Err(e) => return Err(DispatchError::RoutingFailed(e)),
1354 };
1355
1356 let mut successes = 0;
1358 let mut failures = 0;
1359 let mut sent_to = SmallVec::<[EndpointId; 4]>::new();
1360 let mut failed = SmallVec::<[(EndpointId, DispatchError); 4]>::new();
1361
1362 for route in routes {
1363 let endpoint = route.endpoint;
1364 let _guard = endpoint.acquire_connection_guard();
1365
1366 let success = if let Some(sink) = {
1368 let sinks = self.sinks.read();
1369 sinks.get(&endpoint.id).cloned()
1370 } {
1371 match OwnedMutexGuard::lock(sink, cx).await {
1372 Ok(mut guard) => {
1373 let guard: &mut Box<dyn SymbolSink> = &mut guard;
1374 guard.send(symbol.clone()).await.is_ok()
1375 }
1376 Err(_) => false,
1377 }
1378 } else {
1379 true
1381 };
1382
1383 if success {
1386 endpoint.record_success(Time::ZERO);
1387 successes += 1;
1388 sent_to.push(endpoint.id);
1389 } else {
1390 endpoint.record_failure(Time::ZERO);
1391 failures += 1;
1392 failed.push((
1393 endpoint.id,
1394 DispatchError::SendFailed {
1395 endpoint: endpoint.id,
1396 reason: "Send failed".into(),
1397 },
1398 ));
1399 }
1400 }
1401
1402 Ok(DispatchResult {
1403 successes,
1404 failures,
1405 sent_to,
1406 failed_endpoints: failed,
1407 duration: Time::ZERO,
1408 })
1409 }
1410
1411 #[allow(clippy::unused_async)]
1413 async fn dispatch_broadcast(
1414 &self,
1415 cx: &Cx,
1416 symbol: &AuthenticatedSymbol,
1417 ) -> Result<DispatchResult, DispatchError> {
1418 let endpoints = self.router.table().healthy_endpoints();
1419
1420 if endpoints.is_empty() {
1421 return Err(DispatchError::NoEndpoints);
1422 }
1423
1424 let mut successes = 0;
1425 let mut failures = 0;
1426 let mut sent_to = SmallVec::<[EndpointId; 4]>::new();
1427 let mut failed = SmallVec::<[(EndpointId, DispatchError); 4]>::new();
1428
1429 for route in endpoints {
1430 let _guard = route.acquire_connection_guard();
1431
1432 let success = if let Some(sink) = {
1434 let sinks = self.sinks.read();
1435 sinks.get(&route.id).cloned()
1436 } {
1437 match OwnedMutexGuard::lock(sink, cx).await {
1438 Ok(mut guard) => {
1439 let guard: &mut Box<dyn SymbolSink> = &mut guard;
1440 guard.send(symbol.clone()).await.is_ok()
1441 }
1442 Err(_) => false,
1443 }
1444 } else {
1445 true
1447 };
1448
1449 if success {
1450 route.record_success(Time::ZERO);
1451 successes += 1;
1452 sent_to.push(route.id);
1453 } else {
1454 route.record_failure(Time::ZERO);
1455 failures += 1;
1456 failed.push((
1457 route.id,
1458 DispatchError::SendFailed {
1459 endpoint: route.id,
1460 reason: "Send failed".into(),
1461 },
1462 ));
1463 }
1464 }
1465
1466 Ok(DispatchResult {
1467 successes,
1468 failures,
1469 sent_to,
1470 failed_endpoints: failed,
1471 duration: Time::ZERO,
1472 })
1473 }
1474
1475 #[allow(clippy::unused_async)]
1477 async fn dispatch_quorum(
1478 &self,
1479 cx: &Cx,
1480 symbol: &AuthenticatedSymbol,
1481 required: usize,
1482 ) -> Result<DispatchResult, DispatchError> {
1483 let endpoints = self.router.table().healthy_endpoints();
1484
1485 if endpoints.len() < required {
1486 return Err(DispatchError::InsufficientEndpoints {
1487 available: endpoints.len(),
1488 required,
1489 });
1490 }
1491
1492 let mut successes = 0;
1493 let mut failures = 0;
1494 let mut sent_to = SmallVec::<[EndpointId; 4]>::new();
1495 let mut failed = SmallVec::<[(EndpointId, DispatchError); 4]>::new();
1496
1497 for route in endpoints {
1498 if successes >= required {
1499 break;
1500 }
1501
1502 let _guard = route.acquire_connection_guard();
1503
1504 let success = if let Some(sink) = {
1505 let sinks = self.sinks.read();
1506 sinks.get(&route.id).cloned()
1507 } {
1508 match OwnedMutexGuard::lock(sink, cx).await {
1509 Ok(mut guard) => {
1510 let guard: &mut Box<dyn SymbolSink> = &mut guard;
1511 guard.send(symbol.clone()).await.is_ok()
1512 }
1513 Err(_) => false,
1514 }
1515 } else {
1516 true
1517 };
1518
1519 if success {
1520 route.record_success(Time::ZERO);
1521 successes += 1;
1522 sent_to.push(route.id);
1523 } else {
1524 route.record_failure(Time::ZERO);
1525 failures += 1;
1526 failed.push((
1527 route.id,
1528 DispatchError::SendFailed {
1529 endpoint: route.id,
1530 reason: "Send failed".into(),
1531 },
1532 ));
1533 }
1534 }
1535
1536 if successes < required {
1537 return Err(DispatchError::QuorumNotReached {
1538 achieved: successes,
1539 required,
1540 });
1541 }
1542
1543 Ok(DispatchResult {
1544 successes,
1545 failures,
1546 sent_to,
1547 failed_endpoints: failed,
1548 duration: Time::ZERO,
1549 })
1550 }
1551
1552 #[must_use]
1554 pub fn stats(&self) -> DispatcherStats {
1555 DispatcherStats {
1556 active_dispatches: self.active_dispatches.load(Ordering::Relaxed),
1557 total_dispatched: self.total_dispatched.load(Ordering::Relaxed),
1558 total_failures: self.total_failures.load(Ordering::Relaxed),
1559 }
1560 }
1561}
1562
/// Point-in-time snapshot of a dispatcher's counters.
///
/// `Default` yields an all-zero snapshot, and `PartialEq`/`Eq` allow two
/// snapshots to be compared directly (for example in tests or change detection).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct DispatcherStats {
    /// Value of the dispatcher's `active_dispatches` counter at snapshot time.
    pub active_dispatches: u32,

    /// Value of the dispatcher's `total_dispatched` counter at snapshot time.
    pub total_dispatched: u64,

    /// Value of the dispatcher's `total_failures` counter at snapshot time.
    pub total_failures: u64,
}
1575
/// Reasons a symbol could not be routed to an endpoint.
#[derive(Debug, Clone)]
pub enum RoutingError {
    /// No usable route was found for an object.
    NoRoute {
        /// The object that could not be routed.
        object_id: ObjectId,
        /// Human-readable explanation; included in the `Display` output.
        reason: String,
    },

    /// No healthy endpoint was available for an object.
    NoHealthyEndpoints {
        /// The object whose route had no healthy endpoint.
        object_id: ObjectId,
    },

    /// The routing table contains no entries.
    EmptyTable,
}
1600
1601impl std::fmt::Display for RoutingError {
1602 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1603 match self {
1604 Self::NoRoute { object_id, reason } => {
1605 write!(f, "no route for object {object_id:?}: {reason}")
1606 }
1607 Self::NoHealthyEndpoints { object_id } => {
1608 write!(f, "no healthy endpoints for object {object_id:?}")
1609 }
1610 Self::EmptyTable => write!(f, "routing table is empty"),
1611 }
1612 }
1613}
1614
1615impl std::error::Error for RoutingError {}
1616
1617impl From<RoutingError> for Error {
1618 fn from(e: RoutingError) -> Self {
1619 Self::new(ErrorKind::RoutingFailed).with_message(e.to_string())
1620 }
1621}
/// Reasons a symbol dispatch could not be completed.
#[derive(Debug, Clone)]
pub enum DispatchError {
    /// Routing failed; wraps the underlying [`RoutingError`].
    RoutingFailed(RoutingError),

    /// Sending to a specific endpoint failed.
    SendFailed {
        /// The endpoint the send was attempted on.
        endpoint: EndpointId,
        /// Human-readable failure description.
        reason: String,
    },

    /// The dispatcher reported itself as overloaded.
    Overloaded,

    /// No endpoints were available to dispatch to.
    NoEndpoints,

    /// Fewer endpoints were available than the dispatch requires.
    InsufficientEndpoints {
        /// Number of endpoints that were available.
        available: usize,
        /// Number of endpoints the dispatch required.
        required: usize,
    },

    /// Fewer sends succeeded than the quorum requires.
    QuorumNotReached {
        /// Number of sends that succeeded.
        achieved: usize,
        /// Number of successful sends the quorum required.
        required: usize,
    },

    /// The dispatch did not complete in time.
    Timeout,
}
1661
1662impl std::fmt::Display for DispatchError {
1663 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
1664 match self {
1665 Self::RoutingFailed(e) => write!(f, "routing failed: {e}"),
1666 Self::SendFailed { endpoint, reason } => {
1667 write!(f, "send to {endpoint} failed: {reason}")
1668 }
1669 Self::Overloaded => write!(f, "dispatcher overloaded"),
1670 Self::NoEndpoints => write!(f, "no endpoints available"),
1671 Self::InsufficientEndpoints {
1672 available,
1673 required,
1674 } => {
1675 write!(
1676 f,
1677 "insufficient endpoints: {available} available, {required} required"
1678 )
1679 }
1680 Self::QuorumNotReached { achieved, required } => {
1681 write!(f, "quorum not reached: {achieved} of {required} required")
1682 }
1683 Self::Timeout => write!(f, "dispatch timeout"),
1684 }
1685 }
1686}
1687
1688impl std::error::Error for DispatchError {}
1689
1690impl From<RoutingError> for DispatchError {
1691 fn from(e: RoutingError) -> Self {
1692 Self::RoutingFailed(e)
1693 }
1694}
1695
1696impl From<DispatchError> for Error {
1697 fn from(e: DispatchError) -> Self {
1698 match e {
1699 DispatchError::RoutingFailed(_) => {
1700 Self::new(ErrorKind::RoutingFailed).with_message(e.to_string())
1701 }
1702 DispatchError::QuorumNotReached { .. } => {
1703 Self::new(ErrorKind::QuorumNotReached).with_message(e.to_string())
1704 }
1705 _ => Self::new(ErrorKind::DispatchFailed).with_message(e.to_string()),
1706 }
1707 }
1708}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;

    /// Builds an endpoint with the given id and a `node-{id}:8080` address.
    fn test_endpoint(id: u64) -> Endpoint {
        Endpoint::new(EndpointId(id), format!("node-{id}:8080"))
    }

    #[test]
    fn test_endpoint_state() {
        assert!(EndpointState::Healthy.can_receive());
        assert!(EndpointState::Degraded.can_receive());
        assert!(!EndpointState::Unhealthy.can_receive());
        assert!(!EndpointState::Draining.can_receive());
        assert!(!EndpointState::Removed.can_receive());

        // Every state except `Removed` counts as available.
        assert!(EndpointState::Healthy.is_available());
        assert!(EndpointState::Degraded.is_available());
        assert!(EndpointState::Unhealthy.is_available());
        assert!(EndpointState::Draining.is_available());
        assert!(!EndpointState::Removed.is_available());
    }

    #[test]
    fn test_endpoint_statistics() {
        let endpoint = test_endpoint(1);

        endpoint.record_success(Time::from_secs(1));
        endpoint.record_success(Time::from_secs(2));
        endpoint.record_failure(Time::from_secs(3));

        assert_eq!(endpoint.symbols_sent.load(Ordering::Relaxed), 2);
        assert_eq!(endpoint.failures.load(Ordering::Relaxed), 1);

        // One failure out of three recorded attempts: roughly 0.333.
        let rate = endpoint.failure_rate();
        assert!(rate > 0.3 && rate < 0.34);
    }

    #[test]
    fn test_load_balancer_round_robin() {
        let lb = LoadBalancer::new(LoadBalanceStrategy::RoundRobin);

        let endpoints: Vec<Arc<Endpoint>> = (1..=3).map(|i| Arc::new(test_endpoint(i))).collect();

        let e1 = lb.select(&endpoints, None);
        let e2 = lb.select(&endpoints, None);
        let e3 = lb.select(&endpoints, None);
        let e4 = lb.select(&endpoints, None);

        // Three endpoints are visited in order, then selection wraps around.
        assert_eq!(e1.unwrap().id, EndpointId(1));
        assert_eq!(e2.unwrap().id, EndpointId(2));
        assert_eq!(e3.unwrap().id, EndpointId(3));
        assert_eq!(e4.unwrap().id, EndpointId(1));
    }

    #[test]
    fn test_load_balancer_least_connections() {
        let lb = LoadBalancer::new(LoadBalanceStrategy::LeastConnections);

        let e1 = Arc::new(test_endpoint(1));
        let e2 = Arc::new(test_endpoint(2));
        let e3 = Arc::new(test_endpoint(3));

        e1.active_connections.store(5, Ordering::Relaxed);
        e2.active_connections.store(2, Ordering::Relaxed);
        e3.active_connections.store(10, Ordering::Relaxed);

        let endpoints = vec![e1, e2.clone(), e3];

        // e2 has the fewest active connections.
        let selected = lb.select(&endpoints, None).unwrap();
        assert_eq!(selected.id, e2.id);
    }

    #[test]
    fn test_load_balancer_weighted_least_connections() {
        let lb = LoadBalancer::new(LoadBalanceStrategy::WeightedLeastConnections);

        let e1 = Arc::new(test_endpoint(1).with_weight(1));
        let e2 = Arc::new(test_endpoint(2).with_weight(4));
        let e3 = Arc::new(test_endpoint(3).with_weight(2));

        e1.active_connections.store(2, Ordering::Relaxed);
        e2.active_connections.store(4, Ordering::Relaxed);
        e3.active_connections.store(3, Ordering::Relaxed);

        // Connections per unit of weight: e1 = 2.0, e2 = 1.0, e3 = 1.5.
        let endpoints = vec![e1, e2.clone(), e3];
        let selected = lb.select(&endpoints, None).unwrap();
        assert_eq!(selected.id, e2.id);
    }

    #[test]
    fn test_load_balancer_hash_based() {
        let lb = LoadBalancer::new(LoadBalanceStrategy::HashBased);

        let endpoints: Vec<Arc<Endpoint>> = (1..=3).map(|i| Arc::new(test_endpoint(i))).collect();

        let oid = ObjectId::new_for_test(42);

        // The same object id must map to the same endpoint on every call.
        let s1 = lb.select(&endpoints, Some(oid));
        let s2 = lb.select(&endpoints, Some(oid));
        assert_eq!(s1.unwrap().id, s2.unwrap().id);
    }

    #[test]
    fn test_load_balancer_random_select_n_returns_unique_healthy() {
        let lb = LoadBalancer::new(LoadBalanceStrategy::Random);
        let endpoints: Vec<Arc<Endpoint>> = (0..10)
            .map(|i| {
                let endpoint = test_endpoint(i);
                if i % 3 == 0 {
                    Arc::new(endpoint.with_state(EndpointState::Unhealthy))
                } else {
                    Arc::new(endpoint)
                }
            })
            .collect();

        let selected = lb.select_n(&endpoints, 3, None);
        assert_eq!(selected.len(), 3);
        assert!(
            selected
                .iter()
                .all(|endpoint| endpoint.state().can_receive())
        );

        let unique_ids: HashSet<_> = selected.iter().map(|endpoint| endpoint.id).collect();
        assert_eq!(unique_ids.len(), selected.len());
    }

    #[test]
    fn test_load_balancer_random_select_n_returns_all_healthy_when_n_large() {
        let lb = LoadBalancer::new(LoadBalanceStrategy::Random);
        let endpoints = vec![
            Arc::new(test_endpoint(1).with_state(EndpointState::Healthy)),
            Arc::new(test_endpoint(2).with_state(EndpointState::Unhealthy)),
            Arc::new(test_endpoint(3).with_state(EndpointState::Degraded)),
            Arc::new(test_endpoint(4).with_state(EndpointState::Draining)),
            Arc::new(test_endpoint(5).with_state(EndpointState::Healthy)),
        ];

        let selected = lb.select_n(&endpoints, 16, None);
        let selected_ids: Vec<_> = selected.iter().map(|endpoint| endpoint.id).collect();
        assert_eq!(
            selected_ids,
            vec![EndpointId::new(1), EndpointId::new(3), EndpointId::new(5)]
        );
    }

    #[test]
    fn test_load_balancer_random_select_n_single_matches_select_sequence() {
        let lb_select = LoadBalancer::new(LoadBalanceStrategy::Random);
        let lb_select_n = LoadBalancer::new(LoadBalanceStrategy::Random);
        let endpoints: Vec<Arc<Endpoint>> = (0..8)
            .map(|i| {
                let endpoint = test_endpoint(i);
                if i % 4 == 0 {
                    Arc::new(endpoint.with_state(EndpointState::Unhealthy))
                } else {
                    Arc::new(endpoint)
                }
            })
            .collect();

        for _ in 0..64 {
            let selected = lb_select
                .select(&endpoints, None)
                .map(|endpoint| endpoint.id);
            let selected_n = lb_select_n
                .select_n(&endpoints, 1, None)
                .first()
                .map(|endpoint| endpoint.id);
            assert_eq!(selected, selected_n);
        }
    }

    #[test]
    fn test_load_balancer_random_select_single_is_uniform_over_healthy() {
        let lb = LoadBalancer::new(LoadBalanceStrategy::Random);
        let endpoints = vec![
            Arc::new(test_endpoint(0).with_state(EndpointState::Healthy)),
            Arc::new(test_endpoint(100).with_state(EndpointState::Unhealthy)),
            Arc::new(test_endpoint(1).with_state(EndpointState::Healthy)),
            Arc::new(test_endpoint(101).with_state(EndpointState::Draining)),
            Arc::new(test_endpoint(2).with_state(EndpointState::Healthy)),
        ];

        let mut counts = [0usize; 3];
        for _ in 0..3000 {
            let selected = lb.select_n(&endpoints, 1, None);
            assert_eq!(selected.len(), 1);
            let id = selected[0].id;
            if id == EndpointId::new(0) {
                counts[0] += 1;
            } else if id == EndpointId::new(1) {
                counts[1] += 1;
            } else if id == EndpointId::new(2) {
                counts[2] += 1;
            } else {
                panic!("selected unhealthy endpoint: {id:?}");
            }
        }

        assert_eq!(counts.iter().sum::<usize>(), 3000);
        for count in counts {
            // Three healthy endpoints over 3000 draws: ~1000 each, +/-100 for noise.
            assert!((900..=1100).contains(&count), "non-uniform count: {count}");
        }
    }

    #[test]
    fn test_load_balancer_random_select_n_small_all_healthy_is_unique() {
        let lb = LoadBalancer::new(LoadBalanceStrategy::Random);
        let endpoints: Vec<Arc<Endpoint>> = (0..16).map(|i| Arc::new(test_endpoint(i))).collect();

        for _ in 0..64 {
            let selected = lb.select_n(&endpoints, 3, None);
            assert_eq!(selected.len(), 3);
            assert!(
                selected
                    .iter()
                    .all(|endpoint| endpoint.state().can_receive())
            );
            let unique_ids: HashSet<_> = selected.iter().map(|endpoint| endpoint.id).collect();
            assert_eq!(unique_ids.len(), selected.len());
        }
    }

    #[test]
    fn test_load_balancer_weighted_least_connections_select_n_uses_weights() {
        let lb = LoadBalancer::new(LoadBalanceStrategy::WeightedLeastConnections);

        let e1 = Arc::new(test_endpoint(1).with_weight(1));
        let e2 = Arc::new(test_endpoint(2).with_weight(4));
        let e3 = Arc::new(test_endpoint(3).with_weight(2));
        let e4 = Arc::new(test_endpoint(4).with_weight(2));

        e1.active_connections.store(4, Ordering::Relaxed);
        e2.active_connections.store(4, Ordering::Relaxed);
        e3.active_connections.store(4, Ordering::Relaxed);
        e4.active_connections.store(1, Ordering::Relaxed);

        // Connections per unit of weight: e1 = 4.0, e2 = 1.0, e3 = 2.0, e4 = 0.5.
        let endpoints = vec![e1, e2.clone(), e3, e4.clone()];
        let selected = lb.select_n(&endpoints, 2, None);
        let selected_ids: Vec<_> = selected.iter().map(|endpoint| endpoint.id).collect();
        assert_eq!(selected_ids, vec![e4.id, e2.id]);
    }

    #[test]
    fn test_routing_table_basic() {
        let table = RoutingTable::new();

        let _e1 = table.register_endpoint(test_endpoint(1));
        let e2 = table.register_endpoint(test_endpoint(2));

        assert!(table.get_endpoint(EndpointId(1)).is_some());
        assert!(table.get_endpoint(EndpointId(999)).is_none());

        let entry = RoutingEntry::new(vec![e2], Time::ZERO);
        table.add_route(RouteKey::Default, entry);

        assert_eq!(table.route_count(), 1);
    }

    #[test]
    fn test_routing_table_lookup() {
        let table = RoutingTable::new();

        let e1 = table.register_endpoint(test_endpoint(1));
        let e2 = table.register_endpoint(test_endpoint(2));

        let default = RoutingEntry::new(vec![e1], Time::ZERO);
        table.add_route(RouteKey::Default, default);

        let oid = ObjectId::new_for_test(42);
        let specific = RoutingEntry::new(vec![e2], Time::ZERO);
        table.add_route(RouteKey::Object(oid), specific);

        let found = table.lookup(&RouteKey::Object(oid));
        assert!(found.is_some());

        // Only `Default` and `Object(oid)` are registered, so an unknown object
        // can only be satisfied by the `Default` route.
        let other_oid = ObjectId::new_for_test(999);
        let found = table.lookup(&RouteKey::Object(other_oid));
        assert!(found.is_some());
    }

    #[test]
    fn test_routing_entry_ttl() {
        // Created at t=100s with a 60s TTL.
        let entry = RoutingEntry::new(vec![], Time::from_secs(100)).with_ttl(Time::from_secs(60));

        assert!(!entry.is_expired(Time::from_secs(150)));
        assert!(entry.is_expired(Time::from_secs(170)));
    }

    #[test]
    fn test_routing_table_prune() {
        let table = RoutingTable::new();

        let e1 = table.register_endpoint(test_endpoint(1));

        let entry1 =
            RoutingEntry::new(vec![e1.clone()], Time::from_secs(0)).with_ttl(Time::from_secs(10));
        let entry2 = RoutingEntry::new(vec![e1], Time::from_secs(0)).with_ttl(Time::from_secs(100));

        table.add_route(RouteKey::Object(ObjectId::new_for_test(1)), entry1);
        table.add_route(RouteKey::Object(ObjectId::new_for_test(2)), entry2);

        assert_eq!(table.route_count(), 2);

        // At t=50s only the 10s-TTL entry has expired.
        let pruned = table.prune_expired(Time::from_secs(50));
        assert_eq!(pruned, 1);
        assert_eq!(table.route_count(), 1);
    }

    #[test]
    fn test_symbol_router() {
        let table = Arc::new(RoutingTable::new());
        let e1 = table.register_endpoint(test_endpoint(1));

        let entry = RoutingEntry::new(vec![e1], Time::ZERO);
        table.add_route(RouteKey::Default, entry);

        let router = SymbolRouter::new(table);

        let symbol = Symbol::new_for_test(1, 0, 0, &[1, 2, 3]);
        let result = router.route(&symbol);

        assert!(result.is_ok());
        assert_eq!(result.unwrap().endpoint.id, EndpointId(1));
    }

    #[test]
    fn test_symbol_router_without_fallback() {
        let table = Arc::new(RoutingTable::new());
        let e1 = table.register_endpoint(test_endpoint(1));

        let entry = RoutingEntry::new(vec![e1], Time::ZERO);
        table.add_route(RouteKey::Default, entry);

        let router = SymbolRouter::new(table).without_fallback();

        let symbol = Symbol::new_for_test(1, 0, 0, &[1, 2, 3]);
        let result = router.route(&symbol);

        assert!(
            result.is_err(),
            "without_fallback should reject default-only route"
        );
    }

    #[test]
    fn test_symbol_router_failover() {
        let table = Arc::new(RoutingTable::new());

        let primary =
            table.register_endpoint(test_endpoint(1).with_state(EndpointState::Unhealthy));
        let backup = table.register_endpoint(test_endpoint(2).with_state(EndpointState::Healthy));

        let entry = RoutingEntry::new(vec![primary, backup.clone()], Time::ZERO)
            .with_strategy(LoadBalanceStrategy::FirstAvailable);
        table.add_route(RouteKey::Default, entry);

        let router = SymbolRouter::new(table);
        let symbol = Symbol::new_for_test(1, 0, 0, &[1, 2, 3]);
        let result = router.route(&symbol).expect("route");

        assert_eq!(result.endpoint.id, backup.id);
    }

    #[test]
    fn test_symbol_router_multicast() {
        let table = Arc::new(RoutingTable::new());
        let e1 = table.register_endpoint(test_endpoint(1));
        let e2 = table.register_endpoint(test_endpoint(2));
        let e3 = table.register_endpoint(test_endpoint(3));

        let entry = RoutingEntry::new(vec![e1, e2, e3], Time::ZERO);
        table.add_route(RouteKey::Default, entry);

        let router = SymbolRouter::new(table);

        let symbol = Symbol::new_for_test(1, 0, 0, &[1, 2, 3]);
        let results = router.route_multicast(&symbol, 2);

        assert!(results.is_ok());
        assert_eq!(results.unwrap().len(), 2);
    }

    #[test]
    fn test_dispatch_result_quorum() {
        let result = DispatchResult {
            successes: 3,
            failures: 1,
            sent_to: smallvec![EndpointId(1), EndpointId(2), EndpointId(3)],
            failed_endpoints: SmallVec::new(),
            duration: Time::ZERO,
        };

        assert!(result.quorum_reached(2));
        assert!(result.quorum_reached(3));
        assert!(!result.quorum_reached(4));
        assert!(result.any_succeeded());
        // One failure was recorded, so not every send succeeded.
        assert!(!result.all_succeeded());
    }

    #[test]
    fn dispatch_result_unicast_stays_inline() {
        let result = DispatchResult {
            successes: 1,
            failures: 0,
            sent_to: smallvec![EndpointId(7)],
            failed_endpoints: SmallVec::new(),
            duration: Time::ZERO,
        };

        assert!(!result.sent_to.spilled());
        assert!(!result.failed_endpoints.spilled());
    }

    #[test]
    fn test_endpoint_connections() {
        let endpoint = test_endpoint(1);

        assert_eq!(endpoint.connection_count(), 0);

        endpoint.acquire_connection();
        endpoint.acquire_connection();
        assert_eq!(endpoint.connection_count(), 2);

        endpoint.release_connection();
        assert_eq!(endpoint.connection_count(), 1);
    }

    #[test]
    fn test_endpoint_release_connection_saturates() {
        let endpoint = test_endpoint(1);
        endpoint.release_connection();
        assert_eq!(endpoint.connection_count(), 0);
    }

    #[test]
    fn test_routing_table_updates_endpoint_state() {
        let table = RoutingTable::new();
        let endpoint = table.register_endpoint(test_endpoint(9));
        assert_eq!(endpoint.state(), EndpointState::Healthy);
        assert!(table.update_endpoint_state(EndpointId(9), EndpointState::Draining));
        assert_eq!(endpoint.state(), EndpointState::Draining);
        assert!(!table.update_endpoint_state(EndpointId(999), EndpointState::Healthy));
    }

    #[test]
    fn test_routing_error_display() {
        let oid = ObjectId::new_for_test(42);

        let no_route = RoutingError::NoRoute {
            object_id: oid,
            reason: "test".into(),
        };
        assert!(no_route.to_string().contains("no route"));

        let no_healthy = RoutingError::NoHealthyEndpoints { object_id: oid };
        assert!(no_healthy.to_string().contains("healthy"));
    }

    #[test]
    fn test_dispatch_error_display() {
        let overloaded = DispatchError::Overloaded;
        assert!(overloaded.to_string().contains("overloaded"));

        let quorum = DispatchError::QuorumNotReached {
            achieved: 2,
            required: 3,
        };
        assert!(quorum.to_string().contains("quorum"));
        assert!(quorum.to_string().contains('2'));
        assert!(quorum.to_string().contains('3'));
    }

    #[test]
    fn endpoint_id_debug_display() {
        let id = EndpointId::new(42);
        assert!(format!("{id:?}").contains("42"));
        assert_eq!(id.to_string(), "Endpoint(42)");
    }

    #[test]
    fn endpoint_id_clone_copy_eq() {
        let id = EndpointId::new(7);
        let id2 = id;
        assert_eq!(id, id2);
    }

    #[test]
    fn endpoint_id_ord_hash() {
        let a = EndpointId::new(1);
        let b = EndpointId::new(2);
        assert!(a < b);

        let mut set = HashSet::new();
        set.insert(a);
        set.insert(b);
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn endpoint_state_debug_clone_copy_eq() {
        let s = EndpointState::Healthy;
        let s2 = s;
        assert_eq!(s, s2);
        assert!(format!("{s:?}").contains("Healthy"));
    }

    #[test]
    fn endpoint_state_as_u8_roundtrip() {
        let states = [
            EndpointState::Healthy,
            EndpointState::Degraded,
            EndpointState::Unhealthy,
            EndpointState::Draining,
            EndpointState::Removed,
        ];
        for &s in &states {
            assert_eq!(EndpointState::from_u8(s.as_u8()), s);
        }
    }

    #[test]
    fn endpoint_state_from_u8_invalid() {
        let s = EndpointState::from_u8(255);
        assert_eq!(s, EndpointState::Removed);
    }

    #[test]
    fn endpoint_debug() {
        let ep = Endpoint::new(EndpointId::new(1), "addr:80");
        let dbg = format!("{ep:?}");
        assert!(dbg.contains("Endpoint"));
    }

    #[test]
    fn endpoint_with_weight_region() {
        let region = RegionId::new_for_test(1, 0);
        let ep = Endpoint::new(EndpointId::new(5), "host:80")
            .with_weight(200)
            .with_region(region);
        assert_eq!(ep.weight, 200);
        assert_eq!(ep.region, Some(region));
    }

    #[test]
    fn endpoint_with_state_setter() {
        let ep = Endpoint::new(EndpointId::new(1), "h:80").with_state(EndpointState::Draining);
        assert_eq!(ep.state(), EndpointState::Draining);
        ep.set_state(EndpointState::Healthy);
        assert_eq!(ep.state(), EndpointState::Healthy);
    }

    #[test]
    fn endpoint_failure_rate_zero() {
        let ep = Endpoint::new(EndpointId::new(1), "h:80");
        assert!((ep.failure_rate() - 0.0).abs() < f64::EPSILON);
    }

    #[test]
    fn endpoint_connection_guard_drops() {
        let ep = Endpoint::new(EndpointId::new(1), "h:80");
        {
            let _guard = ep.acquire_connection_guard();
            assert_eq!(ep.connection_count(), 1);
        }
        assert_eq!(ep.connection_count(), 0);
    }

    #[test]
    fn load_balance_strategy_debug_clone_copy_eq_default() {
        let s = LoadBalanceStrategy::default();
        assert_eq!(s, LoadBalanceStrategy::RoundRobin);
        let s2 = s;
        assert_eq!(s, s2);
        assert!(format!("{s:?}").contains("RoundRobin"));
    }

    #[test]
    fn route_key_debug_clone_eq_ord_hash() {
        let oid = ObjectId::new_for_test(1);
        let k1 = RouteKey::Object(oid);
        let k2 = k1.clone();
        assert_eq!(k1, k2);
        assert!(format!("{k1:?}").contains("Object"));
        assert!(k1 <= k2);

        let mut set = HashSet::new();
        set.insert(k1);
        set.insert(RouteKey::Default);
        assert_eq!(set.len(), 2);
    }

    #[test]
    fn route_key_constructors() {
        let oid = ObjectId::new_for_test(1);
        let rid = RegionId::new_for_test(2, 0);
        assert_eq!(RouteKey::object(oid), RouteKey::Object(oid));
        assert_eq!(RouteKey::region(rid), RouteKey::Region(rid));
    }

    #[test]
    fn dispatch_strategy_debug_clone_copy_eq_default() {
        let s = DispatchStrategy::default();
        assert_eq!(s, DispatchStrategy::Unicast);
        let s2 = s;
        assert_eq!(s, s2);
        assert!(format!("{s:?}").contains("Unicast"));
    }

    #[test]
    fn dispatch_config_debug_clone_default() {
        let cfg = DispatchConfig::default();
        let cfg2 = cfg;
        assert_eq!(cfg2.max_retries, 3);
        assert!(format!("{cfg2:?}").contains("DispatchConfig"));
    }

    #[test]
    fn dispatcher_stats_debug() {
        let stats = DispatcherStats {
            active_dispatches: 0,
            total_dispatched: 100,
            total_failures: 5,
        };
        let dbg = format!("{stats:?}");
        assert!(dbg.contains("100"));
    }

    #[test]
    fn routing_error_debug_clone() {
        let err = RoutingError::EmptyTable;
        // Actually exercise `Clone` (a plain `let` binding would only move).
        let err2 = err.clone();
        assert!(format!("{err2:?}").contains("EmptyTable"));
        assert_eq!(format!("{err:?}"), format!("{err2:?}"));
    }

    #[test]
    fn routing_error_display_all_variants() {
        let oid = ObjectId::new_for_test(1);
        let e1 = RoutingError::NoRoute {
            object_id: oid,
            reason: "gone".into(),
        };
        assert!(e1.to_string().contains("no route"));
        assert!(e1.to_string().contains("gone"));

        let e2 = RoutingError::NoHealthyEndpoints { object_id: oid };
        assert!(e2.to_string().contains("healthy"));

        let e3 = RoutingError::EmptyTable;
        assert!(e3.to_string().contains("empty"));
    }

    #[test]
    fn routing_error_trait() {
        let err: Box<dyn std::error::Error> = Box::new(RoutingError::EmptyTable);
        assert!(!err.to_string().is_empty());
    }

    #[test]
    fn dispatch_error_debug_clone() {
        let err = DispatchError::Timeout;
        // Actually exercise `Clone` (a plain `let` binding would only move).
        let err2 = err.clone();
        assert!(format!("{err2:?}").contains("Timeout"));
        assert_eq!(format!("{err:?}"), format!("{err2:?}"));
    }

    #[test]
    fn dispatch_error_display_all_variants() {
        let e1 = DispatchError::RoutingFailed(RoutingError::EmptyTable);
        assert!(e1.to_string().contains("routing failed"));

        let e2 = DispatchError::SendFailed {
            endpoint: EndpointId::new(3),
            reason: "down".into(),
        };
        assert!(e2.to_string().contains("send"));

        let e3 = DispatchError::NoEndpoints;
        assert!(e3.to_string().contains("no endpoints"));

        let e4 = DispatchError::InsufficientEndpoints {
            available: 1,
            required: 3,
        };
        assert!(e4.to_string().contains("insufficient"));

        let e5 = DispatchError::Timeout;
        assert!(e5.to_string().contains("timeout"));
    }

    #[test]
    fn dispatch_error_from_routing_error() {
        let re = RoutingError::EmptyTable;
        let de = DispatchError::from(re);
        assert!(matches!(de, DispatchError::RoutingFailed(_)));
    }

    #[test]
    fn dispatch_error_trait() {
        let err: Box<dyn std::error::Error> = Box::new(DispatchError::Timeout);
        assert!(!err.to_string().is_empty());
    }

    #[test]
    fn routing_entry_with_priority() {
        let entry = RoutingEntry::new(vec![], Time::ZERO).with_priority(10);
        assert_eq!(entry.priority, 10);
    }

    #[test]
    fn routing_entry_select_endpoint_empty() {
        let entry = RoutingEntry::new(vec![], Time::ZERO);
        assert!(entry.select_endpoint(None).is_none());
    }

    #[test]
    fn load_balancer_debug() {
        let lb = LoadBalancer::new(LoadBalanceStrategy::Random);
        assert!(format!("{lb:?}").contains("Random"));
    }

    #[test]
    fn routing_table_debug() {
        let table = RoutingTable::new();
        assert!(format!("{table:?}").contains("RoutingTable"));
    }
}