1mod strategies;
39
40use anyhow::Result;
41use derive_more::{AsRef, Deref, Display, From};
42use nutype::nutype;
43use std::cmp::Ordering as CmpOrdering;
44use std::sync::Arc;
45use std::sync::atomic::{AtomicUsize, Ordering};
46use tracing::{debug, info};
47
48use crate::config::BackendSelectionStrategy;
49use crate::pool::DeadpoolConnectionProvider;
50use crate::types::{BackendId, ClientId, ServerName};
51use strategies::{LeastLoaded, WeightedRoundRobin};
52
/// Runtime state of the configured backend-selection strategy.
///
/// Built from a [`BackendSelectionStrategy`] in `BackendSelector::with_strategy`.
#[derive(Debug)]
enum SelectionStrategy {
    /// Weighted round-robin; each backend's weight is its pool `max_size()`.
    WeightedRoundRobin(WeightedRoundRobin),
    /// Picks the candidate with the lowest pending / `max_size()` ratio.
    LeastLoaded(LeastLoaded),
}
59
/// Ratio of in-flight commands to pool capacity for a backend.
///
/// `0.0` means idle. The value is not clamped: more pending commands than
/// pool slots yields a ratio above `1.0`.
#[derive(Debug, Clone, Copy, PartialEq, Display, From, AsRef, Deref)]
pub struct LoadRatio(f64);
65
66impl LoadRatio {
67 pub const MAX: Self = Self(f64::MAX);
69
70 pub const MIN: Self = Self(0.0);
72
73 #[inline]
75 #[must_use]
76 pub const fn new(ratio: f64) -> Self {
77 Self(ratio)
78 }
79
80 #[inline]
82 #[must_use]
83 pub const fn get(&self) -> f64 {
84 self.0
85 }
86}
87
88impl PartialOrd for LoadRatio {
89 fn partial_cmp(&self, other: &Self) -> Option<CmpOrdering> {
90 self.0.partial_cmp(&other.0)
91 }
92}
93
94#[derive(Debug, Clone, Display, From, AsRef, Deref)]
96#[display("PendingCount({})", "_0.load(Ordering::Relaxed)")]
97pub struct PendingCount(Arc<AtomicUsize>);
98
99impl PartialEq for PendingCount {
101 fn eq(&self, other: &Self) -> bool {
102 self.get() == other.get()
103 }
104}
105
106impl PartialEq<usize> for PendingCount {
107 fn eq(&self, other: &usize) -> bool {
108 self.get() == *other
109 }
110}
111
// NOTE(review): `eq` compares two separate relaxed loads, so under concurrent
// updates even `x == x` can observe different values; the `Eq` reflexivity
// guarantee only holds while the counter is not being modified.
impl Eq for PendingCount {}
113
114impl PendingCount {
115 #[inline]
117 #[must_use]
118 pub fn new() -> Self {
119 Self(Arc::new(AtomicUsize::new(0)))
120 }
121
122 #[inline]
124 pub fn increment(&self) {
125 self.0.fetch_add(1, Ordering::Relaxed);
126 }
127
128 #[inline]
130 pub fn decrement(&self) {
131 self.0.fetch_sub(1, Ordering::Relaxed);
132 }
133
134 #[inline]
136 #[must_use]
137 pub fn get(&self) -> usize {
138 self.0.load(Ordering::Relaxed)
139 }
140}
141
142impl Default for PendingCount {
143 fn default() -> Self {
144 Self::new()
145 }
146}
147
148#[derive(Debug, Clone, Display, From, AsRef, Deref)]
150#[display("StatefulCount({})", "_0.load(Ordering::Relaxed)")]
151pub struct StatefulCount(Arc<AtomicUsize>);
152
153impl PartialEq for StatefulCount {
155 fn eq(&self, other: &Self) -> bool {
156 self.get() == other.get()
157 }
158}
159
160impl PartialEq<usize> for StatefulCount {
161 fn eq(&self, other: &usize) -> bool {
162 self.get() == *other
163 }
164}
165
// NOTE(review): `eq` compares two separate relaxed loads, so under concurrent
// updates even `x == x` can observe different values; the `Eq` reflexivity
// guarantee only holds while the counter is not being modified.
impl Eq for StatefulCount {}
167
168impl StatefulCount {
169 #[inline]
171 #[must_use]
172 pub fn new() -> Self {
173 Self(Arc::new(AtomicUsize::new(0)))
174 }
175
176 #[inline]
178 #[must_use]
179 pub fn get(&self) -> usize {
180 self.0.load(Ordering::Relaxed)
181 }
182
183 pub fn try_acquire(&self, max_stateful: usize) -> bool {
187 let mut current = self.0.load(Ordering::Acquire);
188 loop {
189 if current >= max_stateful {
190 return false;
191 }
192
193 match self.0.compare_exchange_weak(
194 current,
195 current + 1,
196 Ordering::AcqRel,
197 Ordering::Acquire,
198 ) {
199 Ok(_) => return true,
200 Err(actual) => current = actual,
201 }
202 }
203 }
204
205 pub fn release(&self) -> Result<usize, usize> {
209 self.0
210 .fetch_update(Ordering::AcqRel, Ordering::Acquire, |current| {
211 if current == 0 {
212 None
213 } else {
214 Some(current - 1)
215 }
216 })
217 }
218}
219
220impl Default for StatefulCount {
221 fn default() -> Self {
222 Self::new()
223 }
224}
225
/// Number of backends registered with a `BackendSelector`.
#[nutype(derive(
    Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Display, From, AsRef
))]
pub struct BackendCount(usize);
231
232impl PartialEq<usize> for BackendCount {
233 fn eq(&self, other: &usize) -> bool {
234 self.into_inner() == *other
235 }
236}
237
238impl PartialOrd<usize> for BackendCount {
239 fn partial_cmp(&self, other: &usize) -> Option<CmpOrdering> {
240 self.into_inner().partial_cmp(other)
241 }
242}
243
244impl BackendCount {
245 pub fn zero() -> Self {
247 Self::new(0)
248 }
249
250 #[inline]
252 #[must_use]
253 pub fn get(&self) -> usize {
254 self.into_inner()
255 }
256}
257
/// Sum of backend weights, where each backend's weight is its pool `max_size()`.
#[nutype(derive(
    Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Display, From, AsRef
))]
pub struct TotalWeight(usize);
263
264impl PartialEq<usize> for TotalWeight {
265 fn eq(&self, other: &usize) -> bool {
266 self.into_inner() == *other
267 }
268}
269
270impl PartialOrd<usize> for TotalWeight {
271 fn partial_cmp(&self, other: &usize) -> Option<CmpOrdering> {
272 self.into_inner().partial_cmp(other)
273 }
274}
275
276impl TotalWeight {
277 pub fn zero() -> Self {
279 Self::new(0)
280 }
281
282 #[inline]
284 #[must_use]
285 pub fn get(&self) -> usize {
286 self.into_inner()
287 }
288}
289
/// Percentage of total weight attributed to one backend (e.g. `25.0` for 25%).
#[nutype(derive(Debug, Clone, Copy, PartialEq, Display, From, AsRef))]
pub struct TrafficShare(f64);
293
294impl TrafficShare {
295 #[inline]
297 #[must_use]
298 pub fn get(&self) -> f64 {
299 self.into_inner()
300 }
301
302 #[inline]
304 #[must_use]
305 pub fn from_weight(max_connections: usize, total_weight: TotalWeight) -> Self {
306 if total_weight.get() > 0 {
307 Self::new((max_connections as f64 / total_weight.get() as f64) * 100.0)
308 } else {
309 Self::new(0.0)
310 }
311 }
312}
313
/// A registered backend together with its routing bookkeeping.
#[derive(Debug, Clone)]
struct BackendInfo {
    /// Identifier callers use to refer to this backend.
    id: BackendId,
    /// Human-readable server name, used in log messages.
    name: ServerName,
    /// Connection pool; its `max_size()` is the backend's weight / capacity.
    provider: DeadpoolConnectionProvider,
    /// Commands routed here that have not been marked complete yet.
    pending_count: PendingCount,
    /// Stateful-session slots currently held on this backend.
    stateful_count: StatefulCount,
    /// Priority tier; when availability filtering applies, only the lowest
    /// available tier is considered.
    tier: u8,
}
330
331impl BackendInfo {
332 #[must_use]
336 fn load_ratio(&self) -> LoadRatio {
337 let max_conns = self.provider.max_size() as f64;
338 if max_conns > 0.0 {
339 let pending = self.pending_count.get() as f64;
340 LoadRatio::new(pending / max_conns)
341 } else {
342 LoadRatio::MAX
343 }
344 }
345}
346
/// Chooses which backend server handles each command.
///
/// Backends are registered with [`BackendSelector::add_backend`] and chosen
/// according to the configured strategy (weighted round-robin by default, see
/// [`BackendSelector::new`]). Per-backend pending-command and stateful-slot
/// counters are tracked alongside each backend.
#[derive(Debug)]
pub struct BackendSelector {
    /// Registered backends, in registration order.
    backends: Vec<BackendInfo>,
    /// Active selection strategy and its state.
    strategy: SelectionStrategy,
}
391
392impl Default for BackendSelector {
393 fn default() -> Self {
394 Self::new()
395 }
396}
397
398impl BackendSelector {
399 #[inline]
403 fn find_backend(&self, backend_id: BackendId) -> Option<&BackendInfo> {
404 self.backends.iter().find(|b| b.id == backend_id)
405 }
406
407 #[inline]
412 #[must_use]
413 pub fn get_tier(&self, backend_id: BackendId) -> Option<u8> {
414 self.find_backend(backend_id).map(|b| b.tier)
415 }
416
417 #[must_use]
419 pub fn new() -> Self {
420 Self::with_strategy(BackendSelectionStrategy::WeightedRoundRobin)
421 }
422
423 #[must_use]
425 pub fn with_strategy(strategy: BackendSelectionStrategy) -> Self {
426 let selection_strategy = match strategy {
427 BackendSelectionStrategy::WeightedRoundRobin => {
428 SelectionStrategy::WeightedRoundRobin(WeightedRoundRobin::new(0))
429 }
430 BackendSelectionStrategy::LeastLoaded => {
431 SelectionStrategy::LeastLoaded(LeastLoaded::new())
432 }
433 };
434
435 Self {
436 backends: Vec::with_capacity(4),
438 strategy: selection_strategy,
439 }
440 }
441
442 pub fn add_backend(
450 &mut self,
451 backend_id: BackendId,
452 name: ServerName,
453 provider: DeadpoolConnectionProvider,
454 tier: u8,
455 ) {
456 let max_connections = provider.max_size();
457
458 match &mut self.strategy {
460 SelectionStrategy::WeightedRoundRobin(wrr) => {
461 let old_weight = TotalWeight::new(wrr.total_weight());
462 let new_weight = TotalWeight::new(old_weight.get() + max_connections);
463 wrr.set_total_weight(new_weight.get());
464
465 let traffic_share = TrafficShare::from_weight(max_connections, new_weight);
467
468 info!(
469 "Added backend {:?} ({}) tier {} with {} connections - will receive {:.1}% of traffic (total weight: {} -> {}) [weighted round-robin]",
470 backend_id,
471 name,
472 tier,
473 max_connections,
474 traffic_share.get(),
475 old_weight,
476 new_weight
477 );
478 }
479 SelectionStrategy::LeastLoaded(_) => {
480 info!(
481 "Added backend {:?} ({}) tier {} with {} connections [least-loaded strategy]",
482 backend_id, name, tier, max_connections
483 );
484 }
485 }
486
487 self.backends.push(BackendInfo {
488 id: backend_id,
489 name,
490 provider,
491 pending_count: PendingCount::new(),
492 stateful_count: StatefulCount::new(),
493 tier,
494 });
495 }
496
497 fn select_backend(
507 &self,
508 availability: Option<&crate::cache::ArticleAvailability>,
509 ) -> Option<&BackendInfo> {
510 if self.backends.is_empty() {
511 return None;
512 }
513
514 let is_available =
516 |backend: &&BackendInfo| availability.is_none_or(|avail| avail.should_try(backend.id));
517
518 let should_apply_tier_filtering = availability.is_some();
521
522 let lowest_available_tier = if should_apply_tier_filtering {
524 self.backends
525 .iter()
526 .filter(|b| is_available(b))
527 .map(|b| b.tier)
528 .min()?
529 } else {
530 0 };
532
533 let tier_filter = |backend: &&BackendInfo| {
535 if should_apply_tier_filtering {
536 backend.tier == lowest_available_tier && is_available(backend)
537 } else {
538 is_available(backend)
539 }
540 };
541
542 match &self.strategy {
543 SelectionStrategy::WeightedRoundRobin(wrr) => {
544 let tier_total_weight: usize = self
546 .backends
547 .iter()
548 .filter(tier_filter)
549 .map(|b| b.provider.max_size())
550 .sum();
551
552 if tier_total_weight == 0 {
553 return None;
555 }
556
557 let tier_position = wrr.select_with_weight(tier_total_weight)?;
560
561 self.backends
563 .iter()
564 .filter(tier_filter)
565 .scan(0, |cumulative, backend| {
566 *cumulative += backend.provider.max_size();
567 Some((*cumulative, backend))
568 })
569 .find(|(cumulative_weight, _)| tier_position < *cumulative_weight)
570 .map(|(_, backend)| backend)
571 .or_else(|| {
572 self.backends.iter().find(tier_filter)
574 })
575 }
576 SelectionStrategy::LeastLoaded(_) => {
577 self.backends.iter().filter(tier_filter).min_by(|a, b| {
579 a.load_ratio()
580 .partial_cmp(&b.load_ratio())
581 .unwrap_or(std::cmp::Ordering::Equal)
582 })
583 }
584 }
585 }
586
587 pub fn route_command(&self, _client_id: ClientId, _command: &str) -> Result<BackendId> {
590 self.route_command_with_availability(_client_id, _command, None)
591 }
592
593 pub fn route_command_with_availability(
595 &self,
596 _client_id: ClientId,
597 _command: &str,
598 availability: Option<&crate::cache::ArticleAvailability>,
599 ) -> Result<BackendId> {
600 let backend = self.select_backend(availability).ok_or_else(|| {
601 anyhow::anyhow!(
602 "No backends available for routing (total backends: {})",
603 self.backends.len()
604 )
605 })?;
606
607 backend.pending_count.increment();
609
610 debug!(
611 "Selected backend {:?} ({}) for command",
612 backend.id, backend.name
613 );
614
615 Ok(backend.id)
616 }
617
618 pub fn complete_command(&self, backend_id: BackendId) {
620 if let Some(backend) = self.find_backend(backend_id) {
621 backend.pending_count.decrement();
622 }
623 }
624
625 pub fn mark_backend_pending(&self, backend_id: BackendId) {
628 if let Some(backend) = self.find_backend(backend_id) {
629 backend.pending_count.increment();
630 }
631 }
632
633 #[must_use]
635 pub fn backend_provider(&self, backend_id: BackendId) -> Option<&DeadpoolConnectionProvider> {
636 self.find_backend(backend_id).map(|b| &b.provider)
637 }
638
639 #[must_use]
641 #[inline]
642 pub fn backend_count(&self) -> BackendCount {
643 BackendCount::new(self.backends.len())
644 }
645
646 #[must_use]
649 #[inline]
650 pub fn total_weight(&self) -> TotalWeight {
651 match &self.strategy {
652 SelectionStrategy::WeightedRoundRobin(wrr) => TotalWeight::new(wrr.total_weight()),
653 SelectionStrategy::LeastLoaded(_) => {
654 TotalWeight::new(self.backends.iter().map(|b| b.provider.max_size()).sum())
656 }
657 }
658 }
659
660 #[must_use]
665 pub fn backend_load(&self, backend_id: BackendId) -> Option<PendingCount> {
666 self.find_backend(backend_id)
667 .map(|b| b.pending_count.clone())
668 }
669
670 pub fn try_acquire_stateful(&self, backend_id: BackendId) -> bool {
674 if let Some(backend) = self.find_backend(backend_id) {
675 let max_connections = backend.provider.max_size();
677
678 let max_stateful = max_connections.saturating_sub(1);
680
681 let acquired = backend.stateful_count.try_acquire(max_stateful);
683
684 if acquired {
685 debug!(
686 "Backend {:?} ({}) acquired stateful slot: {}/{}",
687 backend_id,
688 backend.name,
689 backend.stateful_count.get(),
690 max_stateful
691 );
692 } else {
693 debug!(
694 "Backend {:?} ({}) stateful limit reached: {}/{}",
695 backend_id,
696 backend.name,
697 backend.stateful_count.get(),
698 max_stateful
699 );
700 }
701
702 acquired
703 } else {
704 false
705 }
706 }
707
708 pub fn release_stateful(&self, backend_id: BackendId) {
710 if let Some(backend) = self.find_backend(backend_id) {
711 match backend.stateful_count.release() {
713 Ok(prev) => {
714 debug!(
715 "Backend {:?} ({}) released stateful slot: {}/{}",
716 backend_id,
717 backend.name,
718 prev - 1,
719 backend.provider.max_size().saturating_sub(1)
720 );
721 }
722 Err(0) => {
723 debug!(
724 "Backend {:?} ({}) release_stateful called when count already 0",
725 backend_id, backend.name
726 );
727 }
728 Err(other) => unreachable!(
729 "Unexpected error in release: got Err({other}), expected only Err(0)"
730 ),
731 }
732 }
733 }
734
735 #[must_use]
740 pub fn stateful_count(&self, backend_id: BackendId) -> Option<StatefulCount> {
741 self.find_backend(backend_id)
742 .map(|b| b.stateful_count.clone())
743 }
744
745 #[must_use]
749 pub fn backend_load_ratio(&self, backend_id: BackendId) -> Option<LoadRatio> {
750 self.find_backend(backend_id).map(|b| b.load_ratio())
751 }
752
753 #[must_use]
758 pub fn backend_traffic_share(&self, backend_id: BackendId) -> Option<TrafficShare> {
759 self.find_backend(backend_id).map(|b| {
760 let total = self.total_weight();
761 TrafficShare::from_weight(b.provider.max_size(), total)
762 })
763 }
764}