sentinel_proxy/upstream/mod.rs

//! Upstream pool management module for Sentinel proxy
//!
//! This module handles upstream server pools, load balancing, health checking,
//! connection pooling, and retry logic with circuit breakers.

use async_trait::async_trait;
use pingora::upstreams::peer::HttpPeer;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::RwLock;
use tracing::{debug, info};

use sentinel_common::{
    errors::{SentinelError, SentinelResult},
    types::{CircuitBreakerConfig, LoadBalancingAlgorithm, RetryPolicy},
    CircuitBreaker, UpstreamId,
};
use sentinel_config::{HealthCheck as HealthCheckConfig, UpstreamConfig};

// ============================================================================
// Internal Upstream Target Type
// ============================================================================

/// Internal upstream target representation for load balancers
///
/// This is a simplified representation used internally by load balancers,
/// separate from the user-facing config UpstreamTarget.
#[derive(Debug, Clone)]
pub struct UpstreamTarget {
    /// Target IP address or hostname
    pub address: String,
    /// Target port
    pub port: u16,
    /// Weight for weighted load balancing
    pub weight: u32,
}

impl UpstreamTarget {
    /// Create a new upstream target
    pub fn new(address: impl Into<String>, port: u16, weight: u32) -> Self {
        Self {
            address: address.into(),
            port,
            weight,
        }
    }

    /// Create from a "host:port" string with default weight
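    ///
    /// # Examples
    ///
    /// A quick sketch of the expected parsing behaviour (illustrative only; the
    /// `use` path assumes this module is re-exported as shown):
    ///
    /// ```ignore
    /// use sentinel_proxy::upstream::UpstreamTarget;
    ///
    /// let t = UpstreamTarget::from_address("10.0.0.1:8080").expect("valid address");
    /// assert_eq!(t.address, "10.0.0.1");
    /// assert_eq!(t.port, 8080);
    /// assert_eq!(t.weight, 100); // default weight
    ///
    /// // A bare hostname with no port is rejected.
    /// assert!(UpstreamTarget::from_address("backend-1").is_none());
    /// ```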
    pub fn from_address(addr: &str) -> Option<Self> {
        let parts: Vec<&str> = addr.rsplitn(2, ':').collect();
        if parts.len() == 2 {
            let port = parts[0].parse().ok()?;
            let address = parts[1].to_string();
            Some(Self {
                address,
                port,
                weight: 100,
            })
        } else {
            None
        }
    }

    /// Convert from config UpstreamTarget
    pub fn from_config(config: &sentinel_config::UpstreamTarget) -> Option<Self> {
        Self::from_address(&config.address).map(|mut t| {
            t.weight = config.weight;
            t
        })
    }

    /// Get the full address string
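    ///
    /// For example, `UpstreamTarget::new("10.0.0.1", 8080, 100).full_address()`
    /// returns `"10.0.0.1:8080"`.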
    pub fn full_address(&self) -> String {
        format!("{}:{}", self.address, self.port)
    }
}

// ============================================================================
// Load Balancing
// ============================================================================

// Load balancing algorithm implementations
pub mod adaptive;
pub mod consistent_hash;
pub mod p2c;

// Re-export commonly used types from sub-modules
pub use adaptive::{AdaptiveBalancer, AdaptiveConfig};
pub use consistent_hash::{ConsistentHashBalancer, ConsistentHashConfig};
pub use p2c::{P2cBalancer, P2cConfig};

/// Request context for load balancer decisions
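///
/// A minimal sketch of how a caller might populate this for a hash-based
/// balancer (the field values here are purely illustrative):
///
/// ```ignore
/// use std::collections::HashMap;
///
/// let ctx = RequestContext {
///     client_ip: "203.0.113.7:51820".parse().ok(),
///     headers: HashMap::new(),
///     path: "/api/v1/orders".to_string(),
///     method: "GET".to_string(),
/// };
/// ```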
#[derive(Debug, Clone)]
pub struct RequestContext {
    pub client_ip: Option<std::net::SocketAddr>,
    pub headers: HashMap<String, String>,
    pub path: String,
    pub method: String,
}

/// Load balancer trait for different algorithms
#[async_trait]
pub trait LoadBalancer: Send + Sync {
    /// Select next upstream target
    async fn select(&self, context: Option<&RequestContext>) -> SentinelResult<TargetSelection>;

    /// Report target health status
    async fn report_health(&self, address: &str, healthy: bool);

    /// Get all healthy targets
    async fn healthy_targets(&self) -> Vec<String>;

    /// Release connection (for connection tracking)
    async fn release(&self, _selection: &TargetSelection) {
        // Default implementation - no-op
    }

    /// Report request result (for adaptive algorithms)
    async fn report_result(
        &self,
        _selection: &TargetSelection,
        _success: bool,
        _latency: Option<Duration>,
    ) {
        // Default implementation - no-op
    }
}
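
// Implementors only need to provide `select`, `report_health`, and
// `healthy_targets`; `release` and `report_result` default to no-ops, so
// connection-tracking and adaptive balancers opt in by overriding them.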

/// Selected upstream target
#[derive(Debug, Clone)]
pub struct TargetSelection {
    /// Target address
    pub address: String,
    /// Target weight
    pub weight: u32,
    /// Target metadata
    pub metadata: HashMap<String, String>,
}

/// Upstream pool managing multiple backend servers
pub struct UpstreamPool {
    /// Pool identifier
    id: UpstreamId,
    /// Configured targets
    targets: Vec<UpstreamTarget>,
    /// Load balancer implementation
    load_balancer: Arc<dyn LoadBalancer>,
    /// Health checker
    health_checker: Option<Arc<UpstreamHealthChecker>>,
    /// Connection pool
    connection_pool: Arc<ConnectionPool>,
    /// Circuit breakers per target
    circuit_breakers: Arc<RwLock<HashMap<String, CircuitBreaker>>>,
    /// Retry policy
    retry_policy: Option<RetryPolicy>,
    /// Pool statistics
    stats: Arc<PoolStats>,
}

/// Health checker for upstream targets
///
/// Performs active health checking on upstream targets to determine
/// their availability for load balancing.
pub struct UpstreamHealthChecker {
    /// Check configuration
    config: HealthCheckConfig,
    /// Health status per target
    health_status: Arc<RwLock<HashMap<String, TargetHealthStatus>>>,
    /// Handles for the background health check tasks
    check_handles: Arc<RwLock<Vec<tokio::task::JoinHandle<()>>>>,
}

impl UpstreamHealthChecker {
    /// Create a new health checker
    pub fn new(config: HealthCheckConfig) -> Self {
        Self {
            config,
            health_status: Arc::new(RwLock::new(HashMap::new())),
            check_handles: Arc::new(RwLock::new(Vec::new())),
        }
    }
}

/// Health status for an upstream target
#[derive(Debug, Clone)]
struct TargetHealthStatus {
    /// Is target healthy
    healthy: bool,
    /// Consecutive successes
    consecutive_successes: u32,
    /// Consecutive failures
    consecutive_failures: u32,
    /// Last check time
    last_check: Instant,
    /// Last successful check
    last_success: Option<Instant>,
    /// Last error message
    last_error: Option<String>,
}

/// Connection pool for upstream connections
pub struct ConnectionPool {
    /// Pool configuration
    max_connections: usize,
    max_idle: usize,
    idle_timeout: Duration,
    max_lifetime: Option<Duration>,
    /// Active connections per target
    connections: Arc<RwLock<HashMap<String, Vec<PooledConnection>>>>,
    /// Connection statistics
    stats: Arc<ConnectionPoolStats>,
}

impl ConnectionPool {
    /// Create a new connection pool
    pub fn new(
        max_connections: usize,
        max_idle: usize,
        idle_timeout: Duration,
        max_lifetime: Option<Duration>,
    ) -> Self {
        Self {
            max_connections,
            max_idle,
            idle_timeout,
            max_lifetime,
            connections: Arc::new(RwLock::new(HashMap::new())),
            stats: Arc::new(ConnectionPoolStats::default()),
        }
    }

    /// Acquire a connection from the pool
    pub async fn acquire(&self, _address: &str) -> SentinelResult<Option<HttpPeer>> {
        // TODO: Implement actual connection pooling logic
        // For now, return None to always create new connections
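        //
        // The intended flow (a sketch, not yet wired up) would roughly be:
        //   1. look up the idle list for `address` in `self.connections`,
        //   2. drop entries older than `idle_timeout` / `max_lifetime`,
        //   3. hand back the first idle `PooledConnection`, marking it `in_use`
        //      and bumping `stats.reused`,
        //   4. otherwise return `Ok(None)` so the caller dials a new peer.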
        Ok(None)
    }

    /// Close all connections in the pool
    pub async fn close_all(&self) {
        let mut connections = self.connections.write().await;
        connections.clear();
    }
}

/// Pooled connection wrapper
struct PooledConnection {
    /// The actual connection/peer
    peer: HttpPeer,
    /// Creation time
    created: Instant,
    /// Last used time
    last_used: Instant,
    /// Is currently in use
    in_use: bool,
}

/// Connection pool statistics
#[derive(Default)]
struct ConnectionPoolStats {
    /// Total connections created
    created: AtomicU64,
    /// Total connections reused
    reused: AtomicU64,
    /// Total connections closed
    closed: AtomicU64,
    /// Current active connections
    active: AtomicU64,
    /// Current idle connections
    idle: AtomicU64,
}

// CircuitBreaker is imported from sentinel_common

/// Pool statistics
#[derive(Default)]
pub struct PoolStats {
    /// Total requests
    pub requests: AtomicU64,
    /// Successful requests
    pub successes: AtomicU64,
    /// Failed requests
    pub failures: AtomicU64,
    /// Retried requests
    pub retries: AtomicU64,
    /// Circuit breaker trips
    pub circuit_breaker_trips: AtomicU64,
}

/// Round-robin load balancer
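///
/// Cycles through the healthy targets with an atomic counter; it is also used
/// as a stand-in for the `Random` algorithm (see `create_load_balancer`).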
struct RoundRobinBalancer {
    targets: Vec<UpstreamTarget>,
    current: AtomicUsize,
    health_status: Arc<RwLock<HashMap<String, bool>>>,
}

impl RoundRobinBalancer {
    fn new(targets: Vec<UpstreamTarget>) -> Self {
        let mut health_status = HashMap::new();
        for target in &targets {
            health_status.insert(target.full_address(), true);
        }

        Self {
            targets,
            current: AtomicUsize::new(0),
            health_status: Arc::new(RwLock::new(health_status)),
        }
    }
}

#[async_trait]
impl LoadBalancer for RoundRobinBalancer {
    async fn select(&self, _context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
        let health = self.health_status.read().await;
        let healthy_targets: Vec<_> = self
            .targets
            .iter()
            .filter(|t| *health.get(&t.full_address()).unwrap_or(&true))
            .collect();

        if healthy_targets.is_empty() {
            return Err(SentinelError::NoHealthyUpstream);
        }

        let index = self.current.fetch_add(1, Ordering::Relaxed) % healthy_targets.len();
        let target = healthy_targets[index];

        Ok(TargetSelection {
            address: target.full_address(),
            weight: target.weight,
            metadata: HashMap::new(),
        })
    }

    async fn report_health(&self, address: &str, healthy: bool) {
        self.health_status
            .write()
            .await
            .insert(address.to_string(), healthy);
    }

    async fn healthy_targets(&self) -> Vec<String> {
        self.health_status
            .read()
            .await
            .iter()
            .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
            .collect()
    }
}

/// Least connections load balancer
struct LeastConnectionsBalancer {
    targets: Vec<UpstreamTarget>,
    connections: Arc<RwLock<HashMap<String, usize>>>,
    health_status: Arc<RwLock<HashMap<String, bool>>>,
}

impl LeastConnectionsBalancer {
    fn new(targets: Vec<UpstreamTarget>) -> Self {
        let mut health_status = HashMap::new();
        let mut connections = HashMap::new();

        for target in &targets {
            let addr = target.full_address();
            health_status.insert(addr.clone(), true);
            connections.insert(addr, 0);
        }

        Self {
            targets,
            connections: Arc::new(RwLock::new(connections)),
            health_status: Arc::new(RwLock::new(health_status)),
        }
    }
}

#[async_trait]
impl LoadBalancer for LeastConnectionsBalancer {
    async fn select(&self, _context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
        let health = self.health_status.read().await;
        // Take the write lock so the winner's count can be bumped atomically
        // with the selection decision.
        let mut conns = self.connections.write().await;

        let mut best_target = None;
        let mut min_connections = usize::MAX;

        for target in &self.targets {
            let addr = target.full_address();
            if !*health.get(&addr).unwrap_or(&true) {
                continue;
            }

            let conn_count = *conns.get(&addr).unwrap_or(&0);
            if conn_count < min_connections {
                min_connections = conn_count;
                best_target = Some(target);
            }
        }

        best_target
            .map(|target| {
                let address = target.full_address();
                // Record the in-flight connection so later selections see it;
                // `release` decrements it again.
                *conns.entry(address.clone()).or_insert(0) += 1;
                TargetSelection {
                    address,
                    weight: target.weight,
                    metadata: HashMap::new(),
                }
            })
            .ok_or(SentinelError::NoHealthyUpstream)
    }

    async fn release(&self, selection: &TargetSelection) {
        // Balance out the increment made in `select` when the caller is done
        // with the connection.
        if let Some(count) = self.connections.write().await.get_mut(&selection.address) {
            *count = count.saturating_sub(1);
        }
    }

    async fn report_health(&self, address: &str, healthy: bool) {
        self.health_status
            .write()
            .await
            .insert(address.to_string(), healthy);
    }

    async fn healthy_targets(&self) -> Vec<String> {
        self.health_status
            .read()
            .await
            .iter()
            .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
            .collect()
    }
}

/// Weighted load balancer
struct WeightedBalancer {
    targets: Vec<UpstreamTarget>,
    weights: Vec<u32>,
    current_index: AtomicUsize,
    health_status: Arc<RwLock<HashMap<String, bool>>>,
}

#[async_trait]
impl LoadBalancer for WeightedBalancer {
    async fn select(&self, _context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
        let health = self.health_status.read().await;
        let healthy_indices: Vec<_> = self
            .targets
            .iter()
            .enumerate()
            .filter(|(_, t)| *health.get(&t.full_address()).unwrap_or(&true))
            .map(|(i, _)| i)
            .collect();

        if healthy_indices.is_empty() {
            return Err(SentinelError::NoHealthyUpstream);
        }

        // Weighted round-robin: walk an atomic counter through the cumulative
        // weight space of the healthy targets so each one is selected in
        // proportion to its configured weight.
        let total_weight: u32 = healthy_indices
            .iter()
            .map(|&i| self.weights.get(i).copied().unwrap_or(1).max(1))
            .sum();
        let mut point =
            (self.current_index.fetch_add(1, Ordering::Relaxed) as u32) % total_weight;

        let mut target_idx = healthy_indices[0];
        for &i in &healthy_indices {
            let weight = self.weights.get(i).copied().unwrap_or(1).max(1);
            if point < weight {
                target_idx = i;
                break;
            }
            point -= weight;
        }
        let target = &self.targets[target_idx];

        Ok(TargetSelection {
            address: target.full_address(),
            weight: self.weights.get(target_idx).copied().unwrap_or(1),
            metadata: HashMap::new(),
        })
    }

    async fn report_health(&self, address: &str, healthy: bool) {
        self.health_status
            .write()
            .await
            .insert(address.to_string(), healthy);
    }

    async fn healthy_targets(&self) -> Vec<String> {
        self.health_status
            .read()
            .await
            .iter()
            .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
            .collect()
    }
}

/// IP hash load balancer
struct IpHashBalancer {
    targets: Vec<UpstreamTarget>,
    health_status: Arc<RwLock<HashMap<String, bool>>>,
}

#[async_trait]
impl LoadBalancer for IpHashBalancer {
    async fn select(&self, context: Option<&RequestContext>) -> SentinelResult<TargetSelection> {
        let health = self.health_status.read().await;
        let healthy_targets: Vec<_> = self
            .targets
            .iter()
            .filter(|t| *health.get(&t.full_address()).unwrap_or(&true))
            .collect();

        if healthy_targets.is_empty() {
            return Err(SentinelError::NoHealthyUpstream);
        }

        // Hash the client IP to select a target
        let hash = if let Some(ctx) = context {
            if let Some(ip) = &ctx.client_ip {
                use std::hash::{Hash, Hasher};
                let mut hasher = std::collections::hash_map::DefaultHasher::new();
                ip.hash(&mut hasher);
                hasher.finish()
            } else {
                0
            }
        } else {
            0
        };
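
        // Note: requests without a client IP all hash to 0 (always the first
        // healthy target), and the mapping shifts whenever the healthy set
        // changes size; prefer the consistent-hash balancer when stickiness
        // across membership changes matters.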

        let idx = (hash as usize) % healthy_targets.len();
        let target = healthy_targets[idx];

        Ok(TargetSelection {
            address: target.full_address(),
            weight: target.weight,
            metadata: HashMap::new(),
        })
    }

    async fn report_health(&self, address: &str, healthy: bool) {
        self.health_status
            .write()
            .await
            .insert(address.to_string(), healthy);
    }

    async fn healthy_targets(&self) -> Vec<String> {
        self.health_status
            .read()
            .await
            .iter()
            .filter_map(|(addr, &healthy)| if healthy { Some(addr.clone()) } else { None })
            .collect()
    }
}

impl UpstreamPool {
    /// Create new upstream pool from configuration
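    ///
    /// Typically called once per configured upstream. A rough usage sketch,
    /// assuming `cfg` is an `UpstreamConfig` already parsed by `sentinel_config`:
    ///
    /// ```ignore
    /// let pool = UpstreamPool::new(cfg).await?;
    /// let peer = pool.select_peer(None).await?;
    /// ```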
    pub async fn new(config: UpstreamConfig) -> SentinelResult<Self> {
        let id = UpstreamId::new(&config.id);

        // Convert config targets to internal targets
        let targets: Vec<UpstreamTarget> = config
            .targets
            .iter()
            .filter_map(UpstreamTarget::from_config)
            .collect();

        if targets.is_empty() {
            return Err(SentinelError::Config {
                message: "No valid upstream targets".to_string(),
                source: None,
            });
        }

        // Create load balancer
        let load_balancer = Self::create_load_balancer(&config.load_balancing, &targets)?;

        // Create health checker if configured
        let health_checker = config
            .health_check
            .as_ref()
            .map(|hc_config| Arc::new(UpstreamHealthChecker::new(hc_config.clone())));

        // Create connection pool
        let connection_pool = Arc::new(ConnectionPool::new(
            config.connection_pool.max_connections,
            config.connection_pool.max_idle,
            Duration::from_secs(config.connection_pool.idle_timeout_secs),
            config
                .connection_pool
                .max_lifetime_secs
                .map(Duration::from_secs),
        ));

        // Initialize circuit breakers for each target
        let mut circuit_breakers = HashMap::new();
        for target in &targets {
            circuit_breakers.insert(
                target.full_address(),
                CircuitBreaker::new(CircuitBreakerConfig::default()),
            );
        }

        let pool = Self {
            id,
            targets,
            load_balancer,
            health_checker,
            connection_pool,
            circuit_breakers: Arc::new(RwLock::new(circuit_breakers)),
            retry_policy: None,
            stats: Arc::new(PoolStats::default()),
        };

        Ok(pool)
    }

    /// Create load balancer based on algorithm
    fn create_load_balancer(
        algorithm: &LoadBalancingAlgorithm,
        targets: &[UpstreamTarget],
    ) -> SentinelResult<Arc<dyn LoadBalancer>> {
        // Pre-mark every target healthy so `healthy_targets()` reports the full
        // set before any health reports arrive.
        let initial_health: HashMap<String, bool> = targets
            .iter()
            .map(|t| (t.full_address(), true))
            .collect();

        let balancer: Arc<dyn LoadBalancer> = match algorithm {
            LoadBalancingAlgorithm::RoundRobin => {
                Arc::new(RoundRobinBalancer::new(targets.to_vec()))
            }
            LoadBalancingAlgorithm::LeastConnections => {
                Arc::new(LeastConnectionsBalancer::new(targets.to_vec()))
            }
            LoadBalancingAlgorithm::Weighted => {
                let weights: Vec<u32> = targets.iter().map(|t| t.weight).collect();
                Arc::new(WeightedBalancer {
                    targets: targets.to_vec(),
                    weights,
                    current_index: AtomicUsize::new(0),
                    health_status: Arc::new(RwLock::new(initial_health)),
                })
            }
            LoadBalancingAlgorithm::IpHash => Arc::new(IpHashBalancer {
                targets: targets.to_vec(),
                health_status: Arc::new(RwLock::new(initial_health)),
            }),
            // Random currently falls back to round-robin selection.
            LoadBalancingAlgorithm::Random => {
                Arc::new(RoundRobinBalancer::new(targets.to_vec()))
            }
            LoadBalancingAlgorithm::ConsistentHash => Arc::new(ConsistentHashBalancer::new(
                targets.to_vec(),
                ConsistentHashConfig::default(),
            )),
            LoadBalancingAlgorithm::PowerOfTwoChoices => {
                Arc::new(P2cBalancer::new(targets.to_vec(), P2cConfig::default()))
            }
            LoadBalancingAlgorithm::Adaptive => Arc::new(AdaptiveBalancer::new(
                targets.to_vec(),
                AdaptiveConfig::default(),
            )),
        };
        Ok(balancer)
    }

    /// Select next upstream peer
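    ///
    /// Retries up to `2 * targets.len()` times, skipping targets whose circuit
    /// breaker is open. A rough usage sketch (error handling elided; `ctx` and
    /// `target_addr` are assumed to come from the surrounding request handling):
    ///
    /// ```ignore
    /// let peer = pool.select_peer(Some(&ctx)).await?;
    /// // ... proxy the request through `peer`, then feed the outcome back:
    /// pool.report_result(&target_addr, true).await;
    /// ```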
    pub async fn select_peer(&self, context: Option<&RequestContext>) -> SentinelResult<HttpPeer> {
        self.stats.requests.fetch_add(1, Ordering::Relaxed);

        let mut attempts = 0;
        let max_attempts = self.targets.len() * 2;

        while attempts < max_attempts {
            attempts += 1;

            let selection = self.load_balancer.select(context).await?;

            // Check circuit breaker
            let breakers = self.circuit_breakers.read().await;
            if let Some(breaker) = breakers.get(&selection.address) {
                if !breaker.is_closed().await {
                    debug!(
                        target = %selection.address,
                        "Circuit breaker is open, skipping target"
                    );
                    continue;
                }
            }

            // Try to get connection from pool
            if let Some(peer) = self.connection_pool.acquire(&selection.address).await? {
                debug!(target = %selection.address, "Reusing pooled connection");
                return Ok(peer);
            }

            // Create new connection
            debug!(target = %selection.address, "Creating new connection");
            let peer = self.create_peer(&selection)?;

            self.stats.successes.fetch_add(1, Ordering::Relaxed);
            return Ok(peer);
        }

        self.stats.failures.fetch_add(1, Ordering::Relaxed);
        Err(SentinelError::upstream(
            &self.id.to_string(),
            "Failed to select upstream after max attempts",
        ))
    }

    /// Create new peer connection
    fn create_peer(&self, selection: &TargetSelection) -> SentinelResult<HttpPeer> {
        let peer = HttpPeer::new(&selection.address, false, String::new());
        Ok(peer)
    }

    /// Report connection result for a target
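    ///
    /// Feeds both the per-target circuit breaker and the load balancer's health
    /// view, so a run of failures opens the breaker and also drops the target
    /// out of rotation until a subsequent success marks it healthy again.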
    pub async fn report_result(&self, target: &str, success: bool) {
        if success {
            if let Some(breaker) = self.circuit_breakers.read().await.get(target) {
                breaker.record_success().await;
            }
            self.load_balancer.report_health(target, true).await;
        } else {
            if let Some(breaker) = self.circuit_breakers.read().await.get(target) {
                breaker.record_failure().await;
            }
            self.load_balancer.report_health(target, false).await;
            self.stats.failures.fetch_add(1, Ordering::Relaxed);
        }
    }

    /// Get pool statistics
    pub fn stats(&self) -> &PoolStats {
        &self.stats
    }

    /// Shutdown the pool
    pub async fn shutdown(&self) {
        info!("Shutting down upstream pool: {}", self.id);
        self.connection_pool.close_all().await;
    }
}