1use async_trait::async_trait;
7use pingora::upstreams::peer::HttpPeer;
8use rand::seq::IndexedRandom;
9use std::collections::HashMap;
10use std::net::ToSocketAddrs;
11use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
12use std::sync::Arc;
13use std::time::Duration;
14use tokio::sync::RwLock;
15use tracing::{debug, error, info, trace, warn};
16
17use grapsus_common::{
18 errors::{GrapsusError, GrapsusResult},
19 types::{CircuitBreakerConfig, LoadBalancingAlgorithm},
20 CircuitBreaker, UpstreamId,
21};
22use grapsus_config::UpstreamConfig;
23
/// A single resolved backend endpoint that the load balancers select among.
#[derive(Debug, Clone)]
pub struct UpstreamTarget {
    /// Host part of the target (hostname or IP literal, without port).
    pub address: String,
    /// TCP port the target listens on.
    pub port: u16,
    /// Relative selection weight; `from_address` defaults this to 100.
    pub weight: u32,
}
41
42impl UpstreamTarget {
43 pub fn new(address: impl Into<String>, port: u16, weight: u32) -> Self {
45 Self {
46 address: address.into(),
47 port,
48 weight,
49 }
50 }
51
52 pub fn from_address(addr: &str) -> Option<Self> {
54 let parts: Vec<&str> = addr.rsplitn(2, ':').collect();
55 if parts.len() == 2 {
56 let port = parts[0].parse().ok()?;
57 let address = parts[1].to_string();
58 Some(Self {
59 address,
60 port,
61 weight: 100,
62 })
63 } else {
64 None
65 }
66 }
67
68 pub fn from_config(config: &grapsus_config::UpstreamTarget) -> Option<Self> {
70 Self::from_address(&config.address).map(|mut t| {
71 t.weight = config.weight;
72 t
73 })
74 }
75
76 pub fn full_address(&self) -> String {
78 format!("{}:{}", self.address, self.port)
79 }
80}
81
82pub mod adaptive;
88pub mod consistent_hash;
89pub mod health;
90pub mod inference_health;
91pub mod least_tokens;
92pub mod locality;
93pub mod maglev;
94pub mod p2c;
95pub mod peak_ewma;
96pub mod sticky_session;
97pub mod subset;
98pub mod weighted_least_conn;
99
100pub use adaptive::{AdaptiveBalancer, AdaptiveConfig};
102pub use consistent_hash::{ConsistentHashBalancer, ConsistentHashConfig};
103pub use health::{ActiveHealthChecker, HealthCheckRunner};
104pub use inference_health::InferenceHealthCheck;
105pub use least_tokens::{
106 LeastTokensQueuedBalancer, LeastTokensQueuedConfig, LeastTokensQueuedTargetStats,
107};
108pub use locality::{LocalityAwareBalancer, LocalityAwareConfig};
109pub use maglev::{MaglevBalancer, MaglevConfig};
110pub use p2c::{P2cBalancer, P2cConfig};
111pub use peak_ewma::{PeakEwmaBalancer, PeakEwmaConfig};
112pub use sticky_session::{StickySessionBalancer, StickySessionRuntimeConfig};
113pub use subset::{SubsetBalancer, SubsetConfig};
114pub use weighted_least_conn::{WeightedLeastConnBalancer, WeightedLeastConnConfig};
115
/// Per-request information handed to load balancers that need it
/// (e.g. IP hashing or sticky sessions).
#[derive(Debug, Clone)]
pub struct RequestContext {
    /// Remote peer address, when known.
    pub client_ip: Option<std::net::SocketAddr>,
    /// Request headers (name -> value).
    pub headers: HashMap<String, String>,
    /// Request path.
    pub path: String,
    /// HTTP method.
    pub method: String,
}
124
/// Strategy interface for picking an upstream target per request.
///
/// Implementations must tolerate concurrent calls: selection, health
/// reporting and result reporting may interleave from many tasks.
#[async_trait]
pub trait LoadBalancer: Send + Sync {
    /// Picks a target, optionally using request context (client IP, headers).
    /// Returns `GrapsusError::NoHealthyUpstream` when nothing is selectable.
    async fn select(&self, context: Option<&RequestContext>) -> GrapsusResult<TargetSelection>;

    /// Records the health state of a target, keyed by its `host:port` address.
    async fn report_health(&self, address: &str, healthy: bool);

    /// Lists the addresses currently considered healthy.
    async fn healthy_targets(&self) -> Vec<String>;

    /// Releases a previously returned selection (e.g. decrements an in-flight
    /// counter). No-op by default for stateless balancers.
    async fn release(&self, _selection: &TargetSelection) {
    }

    /// Reports the outcome of a request made against `selection`.
    /// No-op by default; latency-aware balancers override this.
    async fn report_result(
        &self,
        _selection: &TargetSelection,
        _success: bool,
        _latency: Option<Duration>,
    ) {
    }

    /// Address-based variant of `report_result`. The default implementation
    /// only folds the outcome into health state and drops the latency sample.
    async fn report_result_with_latency(
        &self,
        address: &str,
        success: bool,
        _latency: Option<Duration>,
    ) {
        self.report_health(address, success).await;
    }
}
168
/// Result of a load-balancer selection.
#[derive(Debug, Clone)]
pub struct TargetSelection {
    /// Selected target in `host:port` form.
    pub address: String,
    /// Weight of the selected target (informational for most algorithms).
    pub weight: u32,
    /// Balancer-specific metadata to propagate with the request
    /// (e.g. a sticky-session cookie to set).
    pub metadata: HashMap<String, String>,
}
179
/// A group of upstream targets plus everything needed to pick one per
/// request: load balancing, per-target circuit breakers, TLS/HTTP options
/// and aggregate statistics.
pub struct UpstreamPool {
    /// Identifier of this upstream group.
    id: UpstreamId,
    /// All configured targets, healthy or not.
    targets: Vec<UpstreamTarget>,
    /// Selection strategy built from the configured algorithm.
    load_balancer: Arc<dyn LoadBalancer>,
    /// Connection and timeout settings applied to every created peer.
    pool_config: ConnectionPoolConfig,
    /// HTTP version negotiation options (ALPN, H2 pings).
    http_version: HttpVersionOptions,
    /// Whether to speak TLS to the upstream (set iff a `tls` section exists).
    tls_enabled: bool,
    /// Explicit SNI override; falls back to the target host when `None`.
    tls_sni: Option<String>,
    /// Full upstream TLS configuration (mTLS certs, verification flags).
    tls_config: Option<grapsus_config::UpstreamTlsConfig>,
    /// One circuit breaker per target, keyed by `host:port`.
    circuit_breakers: Arc<RwLock<HashMap<String, CircuitBreaker>>>,
    /// Aggregate counters for this pool.
    stats: Arc<PoolStats>,
}
203
/// Runtime connection-pool and timeout settings for an upstream, with all
/// durations already converted from the configuration's integer seconds.
pub struct ConnectionPoolConfig {
    /// Maximum simultaneous connections to the upstream.
    pub max_connections: usize,
    /// Maximum idle connections kept around for reuse.
    pub max_idle: usize,
    /// How long an idle connection may live before being closed.
    pub idle_timeout: Duration,
    /// Optional cap on a connection's total lifetime.
    pub max_lifetime: Option<Duration>,
    /// TCP connect timeout.
    pub connection_timeout: Duration,
    /// Per-request read timeout.
    pub read_timeout: Duration,
    /// Per-request write timeout.
    pub write_timeout: Duration,
}
228
/// HTTP version negotiation settings for upstream connections.
pub struct HttpVersionOptions {
    /// Minimum HTTP major version to negotiate.
    pub min_version: u8,
    /// Maximum HTTP major version to negotiate.
    pub max_version: u8,
    /// HTTP/2 keepalive ping interval; `Duration::ZERO` means disabled.
    pub h2_ping_interval: Duration,
    /// Maximum concurrent HTTP/2 streams per connection.
    pub max_h2_streams: usize,
}
240
241impl ConnectionPoolConfig {
242 pub fn from_config(
244 pool_config: &grapsus_config::ConnectionPoolConfig,
245 timeouts: &grapsus_config::UpstreamTimeouts,
246 ) -> Self {
247 Self {
248 max_connections: pool_config.max_connections,
249 max_idle: pool_config.max_idle,
250 idle_timeout: Duration::from_secs(pool_config.idle_timeout_secs),
251 max_lifetime: pool_config.max_lifetime_secs.map(Duration::from_secs),
252 connection_timeout: Duration::from_secs(timeouts.connect_secs),
253 read_timeout: Duration::from_secs(timeouts.read_secs),
254 write_timeout: Duration::from_secs(timeouts.write_secs),
255 }
256 }
257}
258
/// Aggregate lock-free counters for a pool; all fields only ever increase
/// and are updated with relaxed atomics.
#[derive(Default)]
pub struct PoolStats {
    /// Total peer-selection attempts (incremented at the start of selection).
    pub requests: AtomicU64,
    /// Selections that produced a usable peer.
    pub successes: AtomicU64,
    /// Selections that exhausted all attempts.
    pub failures: AtomicU64,
    /// Request retries. NOTE(review): not updated anywhere in this file —
    /// presumably incremented by retry logic elsewhere; confirm.
    pub retries: AtomicU64,
    /// Selections skipped because a target's circuit breaker was open.
    pub circuit_breaker_trips: AtomicU64,
}
275
/// Destination for request shadowing (mirrored traffic).
#[derive(Debug, Clone)]
pub struct ShadowTarget {
    /// URL scheme, `"http"` or `"https"`.
    pub scheme: String,
    /// Target host.
    pub host: String,
    /// Target port.
    pub port: u16,
    /// Optional SNI override for TLS shadow targets.
    pub sni: Option<String>,
}

impl ShadowTarget {
    /// Builds the full URL for `path`, omitting the port when it is the
    /// scheme's default (80 for http, 443 for https).
    /// Note: `path` is appended verbatim — callers pass a leading slash.
    pub fn build_url(&self, path: &str) -> String {
        let is_default_port =
            matches!((self.scheme.as_str(), self.port), ("http", 80) | ("https", 443));
        if is_default_port {
            format!("{}://{}{}", self.scheme, self.host, path)
        } else {
            format!("{}://{}:{}{}", self.scheme, self.host, self.port, path)
        }
    }
}
299
/// Plain-data snapshot of the pool configuration, in integer seconds,
/// suitable for serialization or admin/debug output.
#[derive(Debug, Clone)]
pub struct PoolConfigSnapshot {
    /// Maximum simultaneous connections.
    pub max_connections: usize,
    /// Maximum idle connections kept for reuse.
    pub max_idle: usize,
    /// Idle connection timeout in seconds.
    pub idle_timeout_secs: u64,
    /// Optional connection lifetime cap in seconds.
    pub max_lifetime_secs: Option<u64>,
    /// TCP connect timeout in seconds.
    pub connection_timeout_secs: u64,
    /// Read timeout in seconds.
    pub read_timeout_secs: u64,
    /// Write timeout in seconds.
    pub write_timeout_secs: u64,
}
318
319struct RoundRobinBalancer {
321 targets: Vec<UpstreamTarget>,
322 current: AtomicUsize,
323 health_status: Arc<RwLock<HashMap<String, bool>>>,
324}
325
326impl RoundRobinBalancer {
327 fn new(targets: Vec<UpstreamTarget>) -> Self {
328 let mut health_status = HashMap::new();
329 for target in &targets {
330 health_status.insert(target.full_address(), true);
331 }
332
333 Self {
334 targets,
335 current: AtomicUsize::new(0),
336 health_status: Arc::new(RwLock::new(health_status)),
337 }
338 }
339}
340
341#[async_trait]
342impl LoadBalancer for RoundRobinBalancer {
343 async fn select(&self, _context: Option<&RequestContext>) -> GrapsusResult<TargetSelection> {
344 trace!(
345 total_targets = self.targets.len(),
346 algorithm = "round_robin",
347 "Selecting upstream target"
348 );
349
350 let health = self.health_status.read().await;
351 let healthy_targets: Vec<_> = self
352 .targets
353 .iter()
354 .filter(|t| *health.get(&t.full_address()).unwrap_or(&true))
355 .collect();
356
357 if healthy_targets.is_empty() {
358 warn!(
359 total_targets = self.targets.len(),
360 algorithm = "round_robin",
361 "No healthy upstream targets available"
362 );
363 return Err(GrapsusError::NoHealthyUpstream);
364 }
365
366 let index = self.current.fetch_add(1, Ordering::Relaxed) % healthy_targets.len();
367 let target = healthy_targets[index];
368
369 trace!(
370 selected_target = %target.full_address(),
371 healthy_count = healthy_targets.len(),
372 index = index,
373 algorithm = "round_robin",
374 "Selected target via round robin"
375 );
376
377 Ok(TargetSelection {
378 address: target.full_address(),
379 weight: target.weight,
380 metadata: HashMap::new(),
381 })
382 }
383
384 async fn report_health(&self, address: &str, healthy: bool) {
385 trace!(
386 target = %address,
387 healthy = healthy,
388 algorithm = "round_robin",
389 "Updating target health status"
390 );
391 self.health_status
392 .write()
393 .await
394 .insert(address.to_string(), healthy);
395 }
396
397 async fn healthy_targets(&self) -> Vec<String> {
398 self.health_status
399 .read()
400 .await
401 .iter()
402 .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
403 .collect()
404 }
405}
406
407struct RandomBalancer {
409 targets: Vec<UpstreamTarget>,
410 health_status: Arc<RwLock<HashMap<String, bool>>>,
411}
412
413impl RandomBalancer {
414 fn new(targets: Vec<UpstreamTarget>) -> Self {
415 let mut health_status = HashMap::new();
416 for target in &targets {
417 health_status.insert(target.full_address(), true);
418 }
419
420 Self {
421 targets,
422 health_status: Arc::new(RwLock::new(health_status)),
423 }
424 }
425}
426
427#[async_trait]
428impl LoadBalancer for RandomBalancer {
429 async fn select(&self, _context: Option<&RequestContext>) -> GrapsusResult<TargetSelection> {
430 use rand::seq::SliceRandom;
431
432 trace!(
433 total_targets = self.targets.len(),
434 algorithm = "random",
435 "Selecting upstream target"
436 );
437
438 let health = self.health_status.read().await;
439 let healthy_targets: Vec<_> = self
440 .targets
441 .iter()
442 .filter(|t| *health.get(&t.full_address()).unwrap_or(&true))
443 .collect();
444
445 if healthy_targets.is_empty() {
446 warn!(
447 total_targets = self.targets.len(),
448 algorithm = "random",
449 "No healthy upstream targets available"
450 );
451 return Err(GrapsusError::NoHealthyUpstream);
452 }
453
454 let mut rng = rand::rng();
455 let target = healthy_targets
456 .choose(&mut rng)
457 .ok_or(GrapsusError::NoHealthyUpstream)?;
458
459 trace!(
460 selected_target = %target.full_address(),
461 healthy_count = healthy_targets.len(),
462 algorithm = "random",
463 "Selected target via random selection"
464 );
465
466 Ok(TargetSelection {
467 address: target.full_address(),
468 weight: target.weight,
469 metadata: HashMap::new(),
470 })
471 }
472
473 async fn report_health(&self, address: &str, healthy: bool) {
474 trace!(
475 target = %address,
476 healthy = healthy,
477 algorithm = "random",
478 "Updating target health status"
479 );
480 self.health_status
481 .write()
482 .await
483 .insert(address.to_string(), healthy);
484 }
485
486 async fn healthy_targets(&self) -> Vec<String> {
487 self.health_status
488 .read()
489 .await
490 .iter()
491 .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
492 .collect()
493 }
494}
495
496struct LeastConnectionsBalancer {
498 targets: Vec<UpstreamTarget>,
499 connections: Arc<RwLock<HashMap<String, usize>>>,
500 health_status: Arc<RwLock<HashMap<String, bool>>>,
501}
502
503impl LeastConnectionsBalancer {
504 fn new(targets: Vec<UpstreamTarget>) -> Self {
505 let mut health_status = HashMap::new();
506 let mut connections = HashMap::new();
507
508 for target in &targets {
509 let addr = target.full_address();
510 health_status.insert(addr.clone(), true);
511 connections.insert(addr, 0);
512 }
513
514 Self {
515 targets,
516 connections: Arc::new(RwLock::new(connections)),
517 health_status: Arc::new(RwLock::new(health_status)),
518 }
519 }
520}
521
522#[async_trait]
523impl LoadBalancer for LeastConnectionsBalancer {
524 async fn select(&self, _context: Option<&RequestContext>) -> GrapsusResult<TargetSelection> {
525 trace!(
526 total_targets = self.targets.len(),
527 algorithm = "least_connections",
528 "Selecting upstream target"
529 );
530
531 let health = self.health_status.read().await;
532 let conns = self.connections.read().await;
533
534 let mut best_target = None;
535 let mut min_connections = usize::MAX;
536
537 for target in &self.targets {
538 let addr = target.full_address();
539 if !*health.get(&addr).unwrap_or(&true) {
540 trace!(
541 target = %addr,
542 algorithm = "least_connections",
543 "Skipping unhealthy target"
544 );
545 continue;
546 }
547
548 let conn_count = *conns.get(&addr).unwrap_or(&0);
549 trace!(
550 target = %addr,
551 connections = conn_count,
552 "Evaluating target connection count"
553 );
554 if conn_count < min_connections {
555 min_connections = conn_count;
556 best_target = Some(target);
557 }
558 }
559
560 match best_target {
561 Some(target) => {
562 trace!(
563 selected_target = %target.full_address(),
564 connections = min_connections,
565 algorithm = "least_connections",
566 "Selected target with fewest connections"
567 );
568 Ok(TargetSelection {
569 address: target.full_address(),
570 weight: target.weight,
571 metadata: HashMap::new(),
572 })
573 }
574 None => {
575 warn!(
576 total_targets = self.targets.len(),
577 algorithm = "least_connections",
578 "No healthy upstream targets available"
579 );
580 Err(GrapsusError::NoHealthyUpstream)
581 }
582 }
583 }
584
585 async fn report_health(&self, address: &str, healthy: bool) {
586 trace!(
587 target = %address,
588 healthy = healthy,
589 algorithm = "least_connections",
590 "Updating target health status"
591 );
592 self.health_status
593 .write()
594 .await
595 .insert(address.to_string(), healthy);
596 }
597
598 async fn healthy_targets(&self) -> Vec<String> {
599 self.health_status
600 .read()
601 .await
602 .iter()
603 .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
604 .collect()
605 }
606}
607
/// Weighted round-robin balancer state.
///
/// `weights[i]` corresponds to `targets[i]`. Unlike the balancers with `new`
/// constructors, `health_status` starts empty here; missing entries default
/// to healthy via `unwrap_or(&true)` in `select`.
struct WeightedBalancer {
    targets: Vec<UpstreamTarget>,
    weights: Vec<u32>,
    current_index: AtomicUsize,
    health_status: Arc<RwLock<HashMap<String, bool>>>,
}
615
616#[async_trait]
617impl LoadBalancer for WeightedBalancer {
618 async fn select(&self, _context: Option<&RequestContext>) -> GrapsusResult<TargetSelection> {
619 trace!(
620 total_targets = self.targets.len(),
621 algorithm = "weighted",
622 "Selecting upstream target"
623 );
624
625 let health = self.health_status.read().await;
626 let healthy_indices: Vec<_> = self
627 .targets
628 .iter()
629 .enumerate()
630 .filter(|(_, t)| *health.get(&t.full_address()).unwrap_or(&true))
631 .map(|(i, _)| i)
632 .collect();
633
634 if healthy_indices.is_empty() {
635 warn!(
636 total_targets = self.targets.len(),
637 algorithm = "weighted",
638 "No healthy upstream targets available"
639 );
640 return Err(GrapsusError::NoHealthyUpstream);
641 }
642
643 let idx = self.current_index.fetch_add(1, Ordering::Relaxed) % healthy_indices.len();
644 let target_idx = healthy_indices[idx];
645 let target = &self.targets[target_idx];
646 let weight = self.weights.get(target_idx).copied().unwrap_or(1);
647
648 trace!(
649 selected_target = %target.full_address(),
650 weight = weight,
651 healthy_count = healthy_indices.len(),
652 algorithm = "weighted",
653 "Selected target via weighted round robin"
654 );
655
656 Ok(TargetSelection {
657 address: target.full_address(),
658 weight,
659 metadata: HashMap::new(),
660 })
661 }
662
663 async fn report_health(&self, address: &str, healthy: bool) {
664 trace!(
665 target = %address,
666 healthy = healthy,
667 algorithm = "weighted",
668 "Updating target health status"
669 );
670 self.health_status
671 .write()
672 .await
673 .insert(address.to_string(), healthy);
674 }
675
676 async fn healthy_targets(&self) -> Vec<String> {
677 self.health_status
678 .read()
679 .await
680 .iter()
681 .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
682 .collect()
683 }
684}
685
/// Hashes the client IP onto the healthy target list so a given client keeps
/// hitting the same backend while the healthy set is stable.
struct IpHashBalancer {
    targets: Vec<UpstreamTarget>,
    // Starts empty; missing entries default to healthy in `select`.
    health_status: Arc<RwLock<HashMap<String, bool>>>,
}
691
#[async_trait]
impl LoadBalancer for IpHashBalancer {
    /// Maps the client IP onto a healthy target via `DefaultHasher` + modulo.
    ///
    /// NOTE(review): with no context or client IP the hash is 0, so all such
    /// requests land on the first healthy target — confirm intended. Because
    /// the modulus is over the *current* healthy list, any health change
    /// remaps many clients (plain modulo hashing, not consistent hashing —
    /// see `ConsistentHashBalancer`/`MaglevBalancer` for stable affinity).
    async fn select(&self, context: Option<&RequestContext>) -> GrapsusResult<TargetSelection> {
        trace!(
            total_targets = self.targets.len(),
            algorithm = "ip_hash",
            "Selecting upstream target"
        );

        let health = self.health_status.read().await;
        let healthy_targets: Vec<_> = self
            .targets
            .iter()
            .filter(|t| *health.get(&t.full_address()).unwrap_or(&true))
            .collect();

        if healthy_targets.is_empty() {
            warn!(
                total_targets = self.targets.len(),
                algorithm = "ip_hash",
                "No healthy upstream targets available"
            );
            return Err(GrapsusError::NoHealthyUpstream);
        }

        // NOTE(review): DefaultHasher's algorithm is unspecified by std, so
        // affinity is not guaranteed across binaries or toolchain upgrades.
        let (hash, client_ip_str) = if let Some(ctx) = context {
            if let Some(ip) = &ctx.client_ip {
                use std::hash::{Hash, Hasher};
                let mut hasher = std::collections::hash_map::DefaultHasher::new();
                ip.hash(&mut hasher);
                (hasher.finish(), Some(ip.to_string()))
            } else {
                (0, None)
            }
        } else {
            (0, None)
        };

        let idx = (hash as usize) % healthy_targets.len();
        let target = healthy_targets[idx];

        trace!(
            selected_target = %target.full_address(),
            client_ip = client_ip_str.as_deref().unwrap_or("unknown"),
            hash = hash,
            index = idx,
            healthy_count = healthy_targets.len(),
            algorithm = "ip_hash",
            "Selected target via IP hash"
        );

        Ok(TargetSelection {
            address: target.full_address(),
            weight: target.weight,
            metadata: HashMap::new(),
        })
    }

    /// Marks `address` healthy or unhealthy for future selections.
    async fn report_health(&self, address: &str, healthy: bool) {
        trace!(
            target = %address,
            healthy = healthy,
            algorithm = "ip_hash",
            "Updating target health status"
        );
        self.health_status
            .write()
            .await
            .insert(address.to_string(), healthy);
    }

    /// Returns the addresses currently flagged healthy.
    async fn healthy_targets(&self) -> Vec<String> {
        self.health_status
            .read()
            .await
            .iter()
            .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
            .collect()
    }
}
773
774impl UpstreamPool {
    /// Builds a pool from its configuration: parses targets, constructs the
    /// load balancer, converts timeouts, and creates one circuit breaker per
    /// target.
    ///
    /// Returns `GrapsusError::Config` when no target address parses.
    pub async fn new(config: UpstreamConfig) -> GrapsusResult<Self> {
        let id = UpstreamId::new(&config.id);

        info!(
            upstream_id = %config.id,
            target_count = config.targets.len(),
            algorithm = ?config.load_balancing,
            "Creating upstream pool"
        );

        // Targets that fail to parse are silently dropped here; the empty
        // check below is the only guard against a fully-invalid target list.
        let targets: Vec<UpstreamTarget> = config
            .targets
            .iter()
            .filter_map(UpstreamTarget::from_config)
            .collect();

        if targets.is_empty() {
            error!(
                upstream_id = %config.id,
                "No valid upstream targets configured"
            );
            return Err(GrapsusError::Config {
                message: "No valid upstream targets".to_string(),
                source: None,
            });
        }

        for target in &targets {
            debug!(
                upstream_id = %config.id,
                target = %target.full_address(),
                weight = target.weight,
                "Registered upstream target"
            );
        }

        debug!(
            upstream_id = %config.id,
            algorithm = ?config.load_balancing,
            "Creating load balancer"
        );
        let load_balancer = Self::create_load_balancer(&config.load_balancing, &targets, &config)?;

        debug!(
            upstream_id = %config.id,
            max_connections = config.connection_pool.max_connections,
            max_idle = config.connection_pool.max_idle,
            idle_timeout_secs = config.connection_pool.idle_timeout_secs,
            connect_timeout_secs = config.timeouts.connect_secs,
            read_timeout_secs = config.timeouts.read_secs,
            write_timeout_secs = config.timeouts.write_secs,
            "Creating connection pool configuration"
        );
        let pool_config =
            ConnectionPoolConfig::from_config(&config.connection_pool, &config.timeouts);

        // A configured ping interval of 0 means "disabled"; represented as
        // Duration::ZERO and checked in `create_peer`.
        let http_version = HttpVersionOptions {
            min_version: config.http_version.min_version,
            max_version: config.http_version.max_version,
            h2_ping_interval: if config.http_version.h2_ping_interval_secs > 0 {
                Duration::from_secs(config.http_version.h2_ping_interval_secs)
            } else {
                Duration::ZERO
            },
            max_h2_streams: config.http_version.max_h2_streams,
        };

        // TLS is enabled purely by the presence of the `tls` config section.
        let tls_enabled = config.tls.is_some();
        let tls_sni = config.tls.as_ref().and_then(|t| t.sni.clone());
        let tls_config = config.tls.clone();

        if let Some(ref tls) = tls_config {
            if tls.client_cert.is_some() {
                info!(
                    upstream_id = %config.id,
                    "mTLS enabled for upstream (client certificate configured)"
                );
            }
        }

        if http_version.max_version >= 2 && tls_enabled {
            info!(
                upstream_id = %config.id,
                "HTTP/2 enabled for upstream (via ALPN)"
            );
        }

        // One breaker per target address, each starting with default config.
        let mut circuit_breakers = HashMap::new();
        for target in &targets {
            trace!(
                upstream_id = %config.id,
                target = %target.full_address(),
                "Initializing circuit breaker for target"
            );
            circuit_breakers.insert(
                target.full_address(),
                CircuitBreaker::new(CircuitBreakerConfig::default()),
            );
        }

        let pool = Self {
            id: id.clone(),
            targets,
            load_balancer,
            pool_config,
            http_version,
            tls_enabled,
            tls_sni,
            tls_config,
            circuit_breakers: Arc::new(RwLock::new(circuit_breakers)),
            stats: Arc::new(PoolStats::default()),
        };

        info!(
            upstream_id = %id,
            target_count = pool.targets.len(),
            "Upstream pool created successfully"
        );

        Ok(pool)
    }
904
905 fn create_load_balancer(
907 algorithm: &LoadBalancingAlgorithm,
908 targets: &[UpstreamTarget],
909 config: &UpstreamConfig,
910 ) -> GrapsusResult<Arc<dyn LoadBalancer>> {
911 let balancer: Arc<dyn LoadBalancer> = match algorithm {
912 LoadBalancingAlgorithm::RoundRobin => {
913 Arc::new(RoundRobinBalancer::new(targets.to_vec()))
914 }
915 LoadBalancingAlgorithm::LeastConnections => {
916 Arc::new(LeastConnectionsBalancer::new(targets.to_vec()))
917 }
918 LoadBalancingAlgorithm::Weighted => {
919 let weights: Vec<u32> = targets.iter().map(|t| t.weight).collect();
920 Arc::new(WeightedBalancer {
921 targets: targets.to_vec(),
922 weights,
923 current_index: AtomicUsize::new(0),
924 health_status: Arc::new(RwLock::new(HashMap::new())),
925 })
926 }
927 LoadBalancingAlgorithm::IpHash => Arc::new(IpHashBalancer {
928 targets: targets.to_vec(),
929 health_status: Arc::new(RwLock::new(HashMap::new())),
930 }),
931 LoadBalancingAlgorithm::Random => Arc::new(RandomBalancer::new(targets.to_vec())),
932 LoadBalancingAlgorithm::ConsistentHash => Arc::new(ConsistentHashBalancer::new(
933 targets.to_vec(),
934 ConsistentHashConfig::default(),
935 )),
936 LoadBalancingAlgorithm::PowerOfTwoChoices => {
937 Arc::new(P2cBalancer::new(targets.to_vec(), P2cConfig::default()))
938 }
939 LoadBalancingAlgorithm::Adaptive => Arc::new(AdaptiveBalancer::new(
940 targets.to_vec(),
941 AdaptiveConfig::default(),
942 )),
943 LoadBalancingAlgorithm::LeastTokensQueued => Arc::new(LeastTokensQueuedBalancer::new(
944 targets.to_vec(),
945 LeastTokensQueuedConfig::default(),
946 )),
947 LoadBalancingAlgorithm::Maglev => Arc::new(MaglevBalancer::new(
948 targets.to_vec(),
949 MaglevConfig::default(),
950 )),
951 LoadBalancingAlgorithm::LocalityAware => Arc::new(LocalityAwareBalancer::new(
952 targets.to_vec(),
953 LocalityAwareConfig::default(),
954 )),
955 LoadBalancingAlgorithm::PeakEwma => Arc::new(PeakEwmaBalancer::new(
956 targets.to_vec(),
957 PeakEwmaConfig::default(),
958 )),
959 LoadBalancingAlgorithm::DeterministicSubset => Arc::new(SubsetBalancer::new(
960 targets.to_vec(),
961 SubsetConfig::default(),
962 )),
963 LoadBalancingAlgorithm::WeightedLeastConnections => {
964 Arc::new(WeightedLeastConnBalancer::new(
965 targets.to_vec(),
966 WeightedLeastConnConfig::default(),
967 ))
968 }
969 LoadBalancingAlgorithm::Sticky => {
970 let sticky_config = config.sticky_session.as_ref().ok_or_else(|| {
972 GrapsusError::Config {
973 message: format!(
974 "Upstream '{}' uses Sticky algorithm but no sticky_session config provided",
975 config.id
976 ),
977 source: None,
978 }
979 })?;
980
981 let runtime_config = StickySessionRuntimeConfig::from_config(sticky_config);
983
984 let fallback = Self::create_load_balancer_inner(&sticky_config.fallback, targets)?;
986
987 info!(
988 upstream_id = %config.id,
989 cookie_name = %runtime_config.cookie_name,
990 cookie_ttl_secs = runtime_config.cookie_ttl_secs,
991 fallback_algorithm = ?sticky_config.fallback,
992 "Creating sticky session balancer"
993 );
994
995 Arc::new(StickySessionBalancer::new(
996 targets.to_vec(),
997 runtime_config,
998 fallback,
999 ))
1000 }
1001 };
1002 Ok(balancer)
1003 }
1004
    /// Constructs a balancer for any non-sticky algorithm using that
    /// algorithm's default configuration. Also serves as the factory for the
    /// sticky-session *fallback* balancer, which is why `Sticky` itself is
    /// rejected here (a sticky fallback would nest/recursively require
    /// another fallback).
    fn create_load_balancer_inner(
        algorithm: &LoadBalancingAlgorithm,
        targets: &[UpstreamTarget],
    ) -> GrapsusResult<Arc<dyn LoadBalancer>> {
        let balancer: Arc<dyn LoadBalancer> = match algorithm {
            LoadBalancingAlgorithm::RoundRobin => {
                Arc::new(RoundRobinBalancer::new(targets.to_vec()))
            }
            LoadBalancingAlgorithm::LeastConnections => {
                Arc::new(LeastConnectionsBalancer::new(targets.to_vec()))
            }
            LoadBalancingAlgorithm::Weighted => {
                let weights: Vec<u32> = targets.iter().map(|t| t.weight).collect();
                // Health map starts empty; unknown addresses default to
                // healthy inside the balancer's `select`.
                Arc::new(WeightedBalancer {
                    targets: targets.to_vec(),
                    weights,
                    current_index: AtomicUsize::new(0),
                    health_status: Arc::new(RwLock::new(HashMap::new())),
                })
            }
            LoadBalancingAlgorithm::IpHash => Arc::new(IpHashBalancer {
                targets: targets.to_vec(),
                health_status: Arc::new(RwLock::new(HashMap::new())),
            }),
            LoadBalancingAlgorithm::Random => Arc::new(RandomBalancer::new(targets.to_vec())),
            LoadBalancingAlgorithm::ConsistentHash => Arc::new(ConsistentHashBalancer::new(
                targets.to_vec(),
                ConsistentHashConfig::default(),
            )),
            LoadBalancingAlgorithm::PowerOfTwoChoices => {
                Arc::new(P2cBalancer::new(targets.to_vec(), P2cConfig::default()))
            }
            LoadBalancingAlgorithm::Adaptive => Arc::new(AdaptiveBalancer::new(
                targets.to_vec(),
                AdaptiveConfig::default(),
            )),
            LoadBalancingAlgorithm::LeastTokensQueued => Arc::new(LeastTokensQueuedBalancer::new(
                targets.to_vec(),
                LeastTokensQueuedConfig::default(),
            )),
            LoadBalancingAlgorithm::Maglev => Arc::new(MaglevBalancer::new(
                targets.to_vec(),
                MaglevConfig::default(),
            )),
            LoadBalancingAlgorithm::LocalityAware => Arc::new(LocalityAwareBalancer::new(
                targets.to_vec(),
                LocalityAwareConfig::default(),
            )),
            LoadBalancingAlgorithm::PeakEwma => Arc::new(PeakEwmaBalancer::new(
                targets.to_vec(),
                PeakEwmaConfig::default(),
            )),
            LoadBalancingAlgorithm::DeterministicSubset => Arc::new(SubsetBalancer::new(
                targets.to_vec(),
                SubsetConfig::default(),
            )),
            LoadBalancingAlgorithm::WeightedLeastConnections => {
                Arc::new(WeightedLeastConnBalancer::new(
                    targets.to_vec(),
                    WeightedLeastConnConfig::default(),
                ))
            }
            LoadBalancingAlgorithm::Sticky => {
                return Err(GrapsusError::Config {
                    message: "Sticky algorithm cannot be used as fallback for sticky sessions"
                        .to_string(),
                    source: None,
                });
            }
        };
        Ok(balancer)
    }
1079
    /// Selects a peer and returns it together with any balancer-supplied
    /// metadata (e.g. sticky-session cookie instructions).
    ///
    /// Retries up to `2 * targets.len()` times, skipping targets whose
    /// circuit breaker is open. NOTE(review): when `select` itself keeps
    /// failing (e.g. `NoHealthyUpstream`), the loop retries immediately with
    /// no backoff and logs a warning per attempt — confirm this is intended.
    pub async fn select_peer_with_metadata(
        &self,
        context: Option<&RequestContext>,
    ) -> GrapsusResult<(HttpPeer, HashMap<String, String>)> {
        let request_num = self.stats.requests.fetch_add(1, Ordering::Relaxed) + 1;

        trace!(
            upstream_id = %self.id,
            request_num = request_num,
            target_count = self.targets.len(),
            "Starting peer selection with metadata"
        );

        let mut attempts = 0;
        let max_attempts = self.targets.len() * 2;

        while attempts < max_attempts {
            attempts += 1;

            trace!(
                upstream_id = %self.id,
                attempt = attempts,
                max_attempts = max_attempts,
                "Attempting to select peer"
            );

            let selection = match self.load_balancer.select(context).await {
                Ok(s) => s,
                Err(e) => {
                    warn!(
                        upstream_id = %self.id,
                        attempt = attempts,
                        error = %e,
                        "Load balancer selection failed"
                    );
                    continue;
                }
            };

            trace!(
                upstream_id = %self.id,
                target = %selection.address,
                attempt = attempts,
                "Load balancer selected target"
            );

            // Breaker read lock is scoped to this iteration; an open breaker
            // consumes an attempt and re-runs selection.
            let breakers = self.circuit_breakers.read().await;
            if let Some(breaker) = breakers.get(&selection.address) {
                if !breaker.is_closed() {
                    debug!(
                        upstream_id = %self.id,
                        target = %selection.address,
                        attempt = attempts,
                        "Circuit breaker is open, skipping target"
                    );
                    self.stats
                        .circuit_breaker_trips
                        .fetch_add(1, Ordering::Relaxed);
                    continue;
                }
            }

            trace!(
                upstream_id = %self.id,
                target = %selection.address,
                "Creating peer for upstream (Pingora handles connection reuse)"
            );
            let peer = self.create_peer(&selection)?;

            debug!(
                upstream_id = %self.id,
                target = %selection.address,
                attempt = attempts,
                metadata_keys = ?selection.metadata.keys().collect::<Vec<_>>(),
                "Selected upstream peer with metadata"
            );

            // `successes` counts successful *selections*, not request outcomes.
            self.stats.successes.fetch_add(1, Ordering::Relaxed);
            return Ok((peer, selection.metadata));
        }

        self.stats.failures.fetch_add(1, Ordering::Relaxed);
        error!(
            upstream_id = %self.id,
            attempts = attempts,
            max_attempts = max_attempts,
            "Failed to select upstream after max attempts"
        );
        Err(GrapsusError::upstream(
            self.id.to_string(),
            "Failed to select upstream after max attempts",
        ))
    }
1180
1181 pub async fn select_peer(&self, context: Option<&RequestContext>) -> GrapsusResult<HttpPeer> {
1183 self.select_peer_with_metadata(context)
1185 .await
1186 .map(|(peer, _)| peer)
1187 }
1188
    /// Turns a balancer selection into a configured Pingora `HttpPeer`:
    /// resolves the address, applies pool timeouts, TCP keepalive, ALPN and
    /// TLS/mTLS options.
    ///
    /// NOTE(review): `to_socket_addrs` performs blocking DNS resolution and
    /// this runs on the async selection path — confirm targets are normally
    /// IP literals or that blocking here is acceptable.
    fn create_peer(&self, selection: &TargetSelection) -> GrapsusResult<HttpPeer> {
        // SNI: explicit override, else the host part of the selected address.
        let sni_hostname = self.tls_sni.clone().unwrap_or_else(|| {
            selection
                .address
                .split(':')
                .next()
                .unwrap_or(&selection.address)
                .to_string()
        });

        // Only the first resolved address is used; both failure paths are
        // marked retryable so the caller's retry logic can try again.
        let resolved_address = selection
            .address
            .to_socket_addrs()
            .map_err(|e| {
                error!(
                    upstream = %self.id,
                    address = %selection.address,
                    error = %e,
                    "Failed to resolve upstream address"
                );
                GrapsusError::Upstream {
                    upstream: self.id.to_string(),
                    message: format!("DNS resolution failed for {}: {}", selection.address, e),
                    retryable: true,
                    source: None,
                }
            })?
            .next()
            .ok_or_else(|| {
                error!(
                    upstream = %self.id,
                    address = %selection.address,
                    "No addresses returned from DNS resolution"
                );
                GrapsusError::Upstream {
                    upstream: self.id.to_string(),
                    message: format!("No addresses for {}", selection.address),
                    retryable: true,
                    source: None,
                }
            })?;

        let mut peer = HttpPeer::new(resolved_address, self.tls_enabled, sni_hostname.clone());

        // Pool-level timeouts from the converted configuration.
        peer.options.idle_timeout = Some(self.pool_config.idle_timeout);

        peer.options.connection_timeout = Some(self.pool_config.connection_timeout);
        // Hard-coded cap on the full connect (TCP + TLS) handshake.
        peer.options.total_connection_timeout = Some(Duration::from_secs(10));

        peer.options.read_timeout = Some(self.pool_config.read_timeout);
        peer.options.write_timeout = Some(self.pool_config.write_timeout);

        // Fixed keepalive policy: probe after 60s idle, every 10s, 3 misses.
        peer.options.tcp_keepalive = Some(pingora::protocols::TcpKeepalive {
            idle: Duration::from_secs(60),
            interval: Duration::from_secs(10),
            count: 3,
            #[cfg(target_os = "linux")]
            user_timeout: Duration::from_secs(60),
        });

        if self.tls_enabled {
            // ALPN from the configured HTTP version window:
            // min 2 -> H2 only; max 2 -> offer H2 then H1; otherwise H1 only.
            // NOTE(review): the `(1, 2)` pattern is subsumed by `(_, 2)` and
            // is redundant (harmless).
            let alpn = match (self.http_version.min_version, self.http_version.max_version) {
                (2, _) => {
                    pingora::upstreams::peer::ALPN::H2
                }
                (1, 2) | (_, 2) => {
                    pingora::upstreams::peer::ALPN::H2H1
                }
                _ => {
                    pingora::upstreams::peer::ALPN::H1
                }
            };
            peer.options.alpn = alpn;

            if let Some(ref tls_config) = self.tls_config {
                // Deliberately insecure mode for dev/self-signed upstreams;
                // loudly logged on every peer creation.
                if tls_config.insecure_skip_verify {
                    peer.options.verify_cert = false;
                    peer.options.verify_hostname = false;
                    warn!(
                        upstream_id = %self.id,
                        target = %selection.address,
                        "TLS certificate verification DISABLED (insecure_skip_verify=true)"
                    );
                }

                if let Some(ref sni) = tls_config.sni {
                    peer.options.alternative_cn = Some(sni.clone());
                    trace!(
                        upstream_id = %self.id,
                        target = %selection.address,
                        alternative_cn = %sni,
                        "Set alternative CN for TLS verification"
                    );
                }

                // mTLS: both cert and key must be present; loading failure is
                // a hard error rather than a silent downgrade to plain TLS.
                if let (Some(cert_path), Some(key_path)) =
                    (&tls_config.client_cert, &tls_config.client_key)
                {
                    match crate::tls::load_client_cert_key(cert_path, key_path) {
                        Ok(cert_key) => {
                            peer.client_cert_key = Some(cert_key);
                            info!(
                                upstream_id = %self.id,
                                target = %selection.address,
                                cert_path = ?cert_path,
                                "mTLS client certificate configured"
                            );
                        }
                        Err(e) => {
                            error!(
                                upstream_id = %self.id,
                                target = %selection.address,
                                error = %e,
                                "Failed to load mTLS client certificate"
                            );
                            return Err(GrapsusError::Tls {
                                message: format!("Failed to load client certificate: {}", e),
                                source: None,
                            });
                        }
                    }
                }
            }

            trace!(
                upstream_id = %self.id,
                target = %selection.address,
                alpn = ?peer.options.alpn,
                min_version = self.http_version.min_version,
                max_version = self.http_version.max_version,
                verify_cert = peer.options.verify_cert,
                verify_hostname = peer.options.verify_hostname,
                "Configured ALPN and TLS options for HTTP version negotiation"
            );
        }

        // H2 pings only when H2 is possible and the interval is non-zero
        // (zero means "disabled" — see `UpstreamPool::new`).
        if self.http_version.max_version >= 2 {
            if !self.http_version.h2_ping_interval.is_zero() {
                peer.options.h2_ping_interval = Some(self.http_version.h2_ping_interval);
                trace!(
                    upstream_id = %self.id,
                    target = %selection.address,
                    h2_ping_interval_secs = self.http_version.h2_ping_interval.as_secs(),
                    "Configured H2 ping interval"
                );
            }
        }

        trace!(
            upstream_id = %self.id,
            target = %selection.address,
            tls = self.tls_enabled,
            sni = %sni_hostname,
            idle_timeout_secs = self.pool_config.idle_timeout.as_secs(),
            http_max_version = self.http_version.max_version,
            "Created peer with Pingora connection pooling enabled"
        );

        Ok(peer)
    }
1377
1378 pub async fn report_result(&self, target: &str, success: bool) {
1385 trace!(
1386 upstream_id = %self.id,
1387 target = %target,
1388 success = success,
1389 "Reporting connection result"
1390 );
1391
1392 if success {
1393 if let Some(breaker) = self.circuit_breakers.read().await.get(target) {
1394 breaker.record_success();
1395 trace!(
1396 upstream_id = %self.id,
1397 target = %target,
1398 "Recorded success in circuit breaker"
1399 );
1400 }
1401 self.load_balancer.report_health(target, true).await;
1402 } else {
1403 let breaker_opened =
1404 if let Some(breaker) = self.circuit_breakers.read().await.get(target) {
1405 let opened = breaker.record_failure();
1406 debug!(
1407 upstream_id = %self.id,
1408 target = %target,
1409 circuit_breaker_opened = opened,
1410 "Recorded failure in circuit breaker"
1411 );
1412 opened
1413 } else {
1414 false
1415 };
1416
1417 if breaker_opened {
1422 self.load_balancer.report_health(target, false).await;
1423 }
1424
1425 self.stats.failures.fetch_add(1, Ordering::Relaxed);
1426 warn!(
1427 upstream_id = %self.id,
1428 target = %target,
1429 circuit_breaker_opened = breaker_opened,
1430 "Connection failure reported for target"
1431 );
1432 }
1433 }
1434
1435 pub async fn report_result_with_latency(
1443 &self,
1444 target: &str,
1445 success: bool,
1446 latency: Option<Duration>,
1447 ) {
1448 trace!(
1449 upstream_id = %self.id,
1450 target = %target,
1451 success = success,
1452 latency_ms = latency.map(|l| l.as_millis() as u64),
1453 "Reporting result with latency for adaptive LB"
1454 );
1455
1456 if success {
1458 if let Some(breaker) = self.circuit_breakers.read().await.get(target) {
1459 breaker.record_success();
1460 }
1461 self.load_balancer
1463 .report_result_with_latency(target, true, latency)
1464 .await;
1465 } else {
1466 let breaker_opened =
1467 if let Some(breaker) = self.circuit_breakers.read().await.get(target) {
1468 breaker.record_failure()
1469 } else {
1470 false
1471 };
1472 self.stats.failures.fetch_add(1, Ordering::Relaxed);
1473
1474 if breaker_opened {
1479 self.load_balancer
1480 .report_result_with_latency(target, false, latency)
1481 .await;
1482 }
1483 }
1484 }
1485
1486 pub fn stats(&self) -> &PoolStats {
1488 &self.stats
1489 }
1490
1491 pub fn id(&self) -> &UpstreamId {
1493 &self.id
1494 }
1495
1496 pub fn target_count(&self) -> usize {
1498 self.targets.len()
1499 }
1500
1501 pub fn pool_config(&self) -> PoolConfigSnapshot {
1503 PoolConfigSnapshot {
1504 max_connections: self.pool_config.max_connections,
1505 max_idle: self.pool_config.max_idle,
1506 idle_timeout_secs: self.pool_config.idle_timeout.as_secs(),
1507 max_lifetime_secs: self.pool_config.max_lifetime.map(|d| d.as_secs()),
1508 connection_timeout_secs: self.pool_config.connection_timeout.as_secs(),
1509 read_timeout_secs: self.pool_config.read_timeout.as_secs(),
1510 write_timeout_secs: self.pool_config.write_timeout.as_secs(),
1511 }
1512 }
1513
1514 pub async fn has_healthy_targets(&self) -> bool {
1518 let healthy = self.load_balancer.healthy_targets().await;
1519 !healthy.is_empty()
1520 }
1521
1522 pub async fn select_shadow_target(
1527 &self,
1528 context: Option<&RequestContext>,
1529 ) -> GrapsusResult<ShadowTarget> {
1530 let selection = self.load_balancer.select(context).await?;
1532
1533 let breakers = self.circuit_breakers.read().await;
1535 if let Some(breaker) = breakers.get(&selection.address) {
1536 if !breaker.is_closed() {
1537 return Err(GrapsusError::upstream(
1538 self.id.to_string(),
1539 "Circuit breaker is open for shadow target",
1540 ));
1541 }
1542 }
1543
1544 let (host, port) = if selection.address.contains(':') {
1546 let parts: Vec<&str> = selection.address.rsplitn(2, ':').collect();
1547 if parts.len() == 2 {
1548 (
1549 parts[1].to_string(),
1550 parts[0]
1551 .parse::<u16>()
1552 .unwrap_or(if self.tls_enabled { 443 } else { 80 }),
1553 )
1554 } else {
1555 (
1556 selection.address.clone(),
1557 if self.tls_enabled { 443 } else { 80 },
1558 )
1559 }
1560 } else {
1561 (
1562 selection.address.clone(),
1563 if self.tls_enabled { 443 } else { 80 },
1564 )
1565 };
1566
1567 Ok(ShadowTarget {
1568 scheme: if self.tls_enabled { "https" } else { "http" }.to_string(),
1569 host,
1570 port,
1571 sni: self.tls_sni.clone(),
1572 })
1573 }
1574
1575 pub fn is_tls_enabled(&self) -> bool {
1577 self.tls_enabled
1578 }
1579
1580 pub async fn shutdown(&self) {
1584 info!(
1585 upstream_id = %self.id,
1586 target_count = self.targets.len(),
1587 total_requests = self.stats.requests.load(Ordering::Relaxed),
1588 total_successes = self.stats.successes.load(Ordering::Relaxed),
1589 total_failures = self.stats.failures.load(Ordering::Relaxed),
1590 "Shutting down upstream pool"
1591 );
1592 debug!(upstream_id = %self.id, "Upstream pool shutdown complete");
1594 }
1595}