pingap_upstream/upstream.rs
// Copyright 2024-2025 Tree xie.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
15use super::Error;
16use crate::backend_circuit_state::{
17    BackendCircuitStates, CircuitBreakerConfig,
18};
19use crate::backend_stats::{BackendStats, WindowStats};
20use crate::hash_strategy::HashStrategy;
21use crate::peer_tracer::UpstreamPeerTracer;
22use crate::{LOG_TARGET, UpstreamProvider, Upstreams};
23use ahash::AHashMap;
24use arc_swap::ArcSwap;
25use async_trait::async_trait;
26use derive_more::Debug;
27use futures_util::FutureExt;
28use http::StatusCode;
29use pingap_config::Hashable;
30use pingap_config::UpstreamConf;
31use pingap_core::UpstreamInstance;
32use pingap_core::{
33    BackgroundTask, BackgroundTaskService, Error as ServiceError,
34};
35use pingap_core::{NotificationData, NotificationLevel, NotificationSender};
36use pingap_discovery::{
37    Discovery, TRANSPARENT_DISCOVERY, is_dns_discovery, is_docker_discovery,
38    is_static_discovery, new_dns_discover_backends,
39    new_docker_discover_backends, new_static_discovery,
40};
41use pingap_health::new_health_check;
42use pingora::lb::Backend;
43use pingora::lb::health_check::{HealthObserve, HealthObserveCallback};
44use pingora::lb::selection::{
45    BackendIter, BackendSelection, Consistent, RoundRobin,
46};
47use pingora::lb::{Backends, LoadBalancer};
48use pingora::protocols::ALPN;
49use pingora::protocols::l4::ext::TcpKeepalive;
50use pingora::proxy::Session;
51use pingora::upstreams::peer::{HttpPeer, Tracer};
52use serde::{Deserialize, Serialize};
53use std::collections::HashMap;
54use std::sync::Arc;
55use std::sync::atomic::{AtomicI32, Ordering};
56use std::time::{Duration, Instant};
57use tracing::{debug, error, info};
58
59type Result<T, E = Error> = std::result::Result<T, E>;
60
61pub struct BackendObserveNotification {
62    name: String,
63    sender: Arc<NotificationSender>,
64}
65
66impl BackendObserveNotification {
67    pub fn new(name: String, sender: Arc<NotificationSender>) -> Self {
68        Self { name, sender }
69    }
70}
71
72#[async_trait]
73impl HealthObserve for BackendObserveNotification {
74    async fn observe(&self, backend: &Backend, healthy: bool) {
75        let addr = backend.addr.to_string();
76        let template = format!("upstream {}({addr}) becomes ", self.name);
77        let info = if healthy {
78            (NotificationLevel::Info, template + "healthy")
79        } else {
80            (NotificationLevel::Error, template + "unhealthy")
81        };
82
83        self.sender
84            .notify(NotificationData {
85                category: "backend_status".to_string(),
86                level: info.0,
87                title: "Upstream backend status changed".to_string(),
88                message: info.1,
89            })
90            .await;
91    }
92}
93
94// SelectionLb represents different load balancing strategies:
95// - RoundRobin: Distributes requests evenly across backends
96// - Consistent: Uses consistent hashing to map requests to backends
97// - Transparent: Passes requests through without load balancing
98enum SelectionLb {
99    RoundRobin(LoadBalancer<RoundRobin>),
100    Consistent {
101        lb: LoadBalancer<Consistent>,
102        hash: HashStrategy,
103    },
104    Transparent,
105}
106
107impl SelectionLb {
108    fn get_health_frequency(&self) -> (u64, u64) {
109        match self {
110            SelectionLb::RoundRobin(lb) => (
111                lb.update_frequency.unwrap_or_default().as_secs(),
112                lb.health_check_frequency.unwrap_or_default().as_secs(),
113            ),
114            SelectionLb::Consistent { lb, .. } => (
115                lb.update_frequency.unwrap_or_default().as_secs(),
116                lb.health_check_frequency.unwrap_or_default().as_secs(),
117            ),
118            SelectionLb::Transparent => (0, 0),
119        }
120    }
121    async fn update(&self) -> pingora::Result<()> {
122        match self {
123            SelectionLb::RoundRobin(lb) => lb.update().await,
124            SelectionLb::Consistent { lb, .. } => lb.update().await,
125            SelectionLb::Transparent => Ok(()),
126        }
127    }
128    async fn run_health_check(&self) {
129        match self {
130            SelectionLb::RoundRobin(lb) => {
131                lb.backends()
132                    .run_health_check(lb.parallel_health_check)
133                    .await
134            },
135            SelectionLb::Consistent { lb, .. } => {
136                lb.backends()
137                    .run_health_check(lb.parallel_health_check)
138                    .await
139            },
140            SelectionLb::Transparent => (),
141        }
142    }
143}
144
145#[derive(Debug)]
146/// Represents a group of backend servers and their configuration for load balancing and connection management
147pub struct Upstream {
148    /// Unique identifier for this upstream group
149    pub name: Arc<str>,
150
151    /// Hash key used to detect configuration changes
152    pub key: String,
153
154    /// Whether to use TLS for connections to backend servers
155    tls: bool,
156
157    /// Server Name Indication value for TLS connections
158    /// Special value "$host" means use the request's Host header
159    sni: String,
160
161    /// Load balancing strategy implementation:
162    /// - RoundRobin: Distributes requests evenly
163    /// - Consistent: Uses consistent hashing
164    /// - Transparent: Direct passthrough
165    #[debug("lb")]
166    lb: SelectionLb,
167
168    /// Maximum time to wait for establishing a connection
169    connection_timeout: Option<Duration>,
170
171    /// Maximum time for the entire connection lifecycle
172    total_connection_timeout: Option<Duration>,
173
174    /// Maximum time to wait for reading data
175    read_timeout: Option<Duration>,
176
177    /// Maximum time a connection can be idle before being closed
178    idle_timeout: Option<Duration>,
179
180    /// Maximum time to wait for writing data
181    write_timeout: Option<Duration>,
182
183    /// Whether to verify TLS certificates from backend servers
184    verify_cert: Option<bool>,
185
186    /// Application Layer Protocol Negotiation settings (H1, H2, H2H1)
187    alpn: ALPN,
188
189    /// TCP keepalive configuration for maintaining persistent connections
190    tcp_keepalive: Option<TcpKeepalive>,
191
192    /// Size of TCP receive buffer in bytes
193    tcp_recv_buf: Option<usize>,
194
195    /// Whether to enable TCP Fast Open for reduced connection latency
196    tcp_fast_open: Option<bool>,
197
198    /// Tracer for monitoring active connections to this upstream
199    peer_tracer: Option<UpstreamPeerTracer>,
200
201    /// Generic tracer interface for connection monitoring
202    tracer: Option<Tracer>,
203
204    /// Counter for number of requests currently being processed by this upstream
205    processing: AtomicI32,
206
207    /// Backend stats, success and fail count
208    #[debug("backend_stats")]
209    backend_stats: Option<BackendStats>,
210
211    /// Circuit breaker states
212    #[debug("circuit_breaker_states")]
213    circuit_breaker_states: Option<BackendCircuitStates>,
214}
215
216// Creates new backend servers based on discovery method (DNS/Docker/Static)
217fn new_backends(
218    discovery_category: &str,
219    discovery: &Discovery,
220) -> Result<Backends> {
221    let (result, category) = match discovery_category {
222        d if is_dns_discovery(d) => {
223            (new_dns_discover_backends(discovery), "dns_discovery")
224        },
225        d if is_docker_discovery(d) => {
226            (new_docker_discover_backends(discovery), "docker_discovery")
227        },
228        _ => (new_static_discovery(discovery), "static_discovery"),
229    };
230    result.map_err(|e| Error::Common {
231        category: category.to_string(),
232        message: e.to_string(),
233    })
234}
235
236fn update_health_check_params<S>(
237    mut lb: LoadBalancer<S>,
238    name: &str,
239    conf: &UpstreamConf,
240    sender: Option<Arc<NotificationSender>>,
241) -> Result<LoadBalancer<S>>
242where
243    S: BackendSelection + 'static,
244    S::Iter: BackendIter,
245{
246    let mut update_frequency = if let Some(value) = conf.update_frequency {
247        Some(value)
248    } else {
249        Some(Duration::from_secs(60))
250    };
251    // For static discovery, perform immediate backend update
252    if is_static_discovery(&conf.guess_discovery()) {
253        update_frequency = None;
254        lb.update()
255            .now_or_never()
256            .expect("static should not block")
257            .expect("static should not error");
258    }
259
260    let observe: Option<HealthObserveCallback> = if let Some(sender) = sender {
261        Some(Box::new(BackendObserveNotification::new(
262            name.to_string(),
263            sender.clone(),
264        )))
265    } else {
266        None
267    };
268
269    // Set up health checking for the backends
270    let (health_check_conf, hc) = new_health_check(
271        name,
272        &conf.health_check.clone().unwrap_or_default(),
273        observe,
274    )
275    .map_err(|e| Error::Common {
276        message: e.to_string(),
277        category: "health".to_string(),
278    })?;
279    // Configure health checking
280    lb.parallel_health_check = health_check_conf.parallel_check;
281    lb.set_health_check(hc);
282    lb.update_frequency = update_frequency;
283    lb.health_check_frequency = Some(health_check_conf.check_frequency);
284    Ok(lb)
285}
286
287/// Creates a new load balancer instance based on the provided configuration
288///
289/// # Arguments
290/// * `name` - Name identifier for the upstream service
291/// * `conf` - Configuration for the upstream service
292///
293/// # Returns
294/// * `Result<SelectionLb>` - Returns the load balancer
295fn new_load_balancer(
296    name: &str,
297    conf: &UpstreamConf,
298    sender: Option<Arc<NotificationSender>>,
299) -> Result<SelectionLb> {
300    // Validate that addresses are provided
301    if conf.addrs.is_empty() {
302        return Err(Error::Common {
303            category: "new_upstream".to_string(),
304            message: "upstream addrs is empty".to_string(),
305        });
306    }
307
308    // Determine the service discovery method
309    let discovery_category = conf.guess_discovery();
310    // For transparent discovery, return early with no load balancing
311    if discovery_category == TRANSPARENT_DISCOVERY {
312        return Ok(SelectionLb::Transparent);
313    }
314
315    // Determine if TLS should be enabled based on SNI configuration
316    let tls = conf
317        .sni
318        .as_ref()
319        .map(|item| !item.is_empty())
320        .unwrap_or_default();
321
322    // Create backend servers using the configured addresses and discovery method
323    let mut discovery = Discovery::new(conf.addrs.clone())
324        .with_ipv4_only(conf.ipv4_only.unwrap_or_default())
325        .with_tls(tls)
326        .with_sender(sender.clone());
327    if let Some(dns_server) = &conf.dns_server {
328        discovery = discovery.with_dns_server(dns_server.clone());
329    }
330    if let Some(dns_domain) = &conf.dns_domain {
331        discovery = discovery.with_domain(dns_domain.clone());
332    }
333    if let Some(dns_search) = &conf.dns_search {
334        discovery = discovery.with_search(dns_search.clone());
335    }
336    let backends = new_backends(&discovery_category, &discovery)?;
337
338    // Parse the load balancing algorithm configuration
339    // Format: "algo:hash_type:hash_key" (e.g. "hash:cookie:session_id")
340    // let algo_method = conf.algo.clone().unwrap_or_default();
341    let algo_method = conf.algo.as_deref().unwrap_or("round_robin");
342
343    let parts: Vec<&str> = algo_method.split(':').collect();
344    if parts.first() == Some(&"hash") && parts.len() >= 2 {
345        let hash_type = parts[1];
346        let hash_key = parts.get(2).copied().unwrap_or_default();
347        let lb = update_health_check_params(
348            LoadBalancer::<Consistent>::from_backends(backends),
349            name,
350            conf,
351            sender,
352        )?;
353        Ok(SelectionLb::Consistent {
354            lb,
355            hash: HashStrategy::from((hash_type, hash_key)),
356        })
357    } else {
358        // Default to RoundRobin
359        let lb = update_health_check_params(
360            LoadBalancer::<RoundRobin>::from_backends(backends),
361            name,
362            conf,
363            sender,
364        )?;
365        Ok(SelectionLb::RoundRobin(lb))
366    }
367}
368
369#[derive(Debug, Clone, Default)]
370pub struct UpstreamStats {
371    pub processing: i32,
372    pub connected: Option<i32>,
373    pub backend_stats: HashMap<String, WindowStats>,
374}
375
376impl Upstream {
377    /// Creates a new Upstream instance from the provided configuration
378    ///
379    /// # Arguments
380    /// * `name` - Name identifier for the upstream service
381    /// * `conf` - Configuration parameters for the upstream service
382    ///
383    /// # Returns
384    /// * `Result<Self>` - New Upstream instance or error if creation fails
385    pub fn new(
386        name: &str,
387        conf: &UpstreamConf,
388        sender: Option<Arc<NotificationSender>>,
389    ) -> Result<Self> {
390        let lb = new_load_balancer(name, conf, sender)?;
391        let key = conf.hash_key();
392        let sni = conf.sni.clone().unwrap_or_default();
393        let tls = !sni.is_empty();
394
395        let alpn = if let Some(alpn) = &conf.alpn {
396            match alpn.to_uppercase().as_str() {
397                "H2H1" => ALPN::H2H1,
398                "H2" => ALPN::H2,
399                _ => ALPN::H1,
400            }
401        } else {
402            ALPN::H1
403        };
404
405        let tcp_keepalive = if (conf.tcp_idle.is_some()
406            && conf.tcp_probe_count.is_some()
407            && conf.tcp_interval.is_some())
408            || conf.tcp_user_timeout.is_some()
409        {
410            Some(TcpKeepalive {
411                idle: conf.tcp_idle.unwrap_or_default(),
412                count: conf.tcp_probe_count.unwrap_or_default(),
413                interval: conf.tcp_interval.unwrap_or_default(),
414                #[cfg(target_os = "linux")]
415                user_timeout: conf.tcp_user_timeout.unwrap_or_default(),
416            })
417        } else {
418            None
419        };
420
421        let peer_tracer = if conf.enable_tracer.unwrap_or_default() {
422            Some(UpstreamPeerTracer::new(name))
423        } else {
424            None
425        };
426        let failure_status_codes = conf
427            .backend_failure_status_code
428            .clone()
429            .unwrap_or_default()
430            .split(",")
431            .flat_map(|code| code.trim().parse::<u16>().ok())
432            .collect::<Vec<u16>>();
433        let tracer = peer_tracer
434            .as_ref()
435            .map(|peer_tracer| Tracer(Box::new(peer_tracer.to_owned())));
436        let circuit_break_max_consecutive_failures = conf
437            .circuit_break_max_consecutive_failures
438            .unwrap_or_default();
439        let circuit_break_max_failure_percent =
440            conf.circuit_break_max_failure_percent.unwrap_or_default();
441        let circuit_breaker_states = if circuit_break_max_consecutive_failures
442            > 0
443            || circuit_break_max_failure_percent > 0
444        {
445            Some(BackendCircuitStates::new(CircuitBreakerConfig {
446                max_consecutive_failures:
447                    circuit_break_max_consecutive_failures,
448                max_failure_percent: circuit_break_max_failure_percent as f64,
449                min_requests_threshold: conf
450                    .circuit_break_min_requests_threshold
451                    .unwrap_or(10),
452                half_open_consecutive_success_threshold: conf
453                    .circuit_break_half_open_consecutive_success_threshold
454                    .unwrap_or(5),
455                open_duration: conf
456                    .circuit_break_open_duration
457                    .unwrap_or(Duration::from_secs(10)),
458            }))
459        } else {
460            None
461        };
462
463        let up = Self {
464            name: name.into(),
465            key,
466            tls,
467            sni,
468            lb,
469            alpn,
470            connection_timeout: conf.connection_timeout,
471            total_connection_timeout: conf.total_connection_timeout,
472            read_timeout: conf.read_timeout,
473            idle_timeout: conf.idle_timeout.or(Some(Duration::from_secs(60))),
474            write_timeout: conf.write_timeout,
475            verify_cert: conf.verify_cert,
476            tcp_recv_buf: conf.tcp_recv_buf.map(|item| item.as_u64() as usize),
477            tcp_keepalive,
478            tcp_fast_open: conf.tcp_fast_open,
479            peer_tracer,
480            tracer,
481            processing: AtomicI32::new(0),
482            backend_stats: if conf.enable_backend_stats.unwrap_or_default() {
483                Some(BackendStats::new(
484                    conf.backend_stats_interval
485                        .unwrap_or_else(|| Duration::from_secs(60)),
486                    failure_status_codes,
487                ))
488            } else {
489                None
490            },
491            circuit_breaker_states,
492        };
493        debug!(
494            target: LOG_TARGET,
495            name = up.name.as_ref(),
496            "new upstream: {up:?}"
497        );
498        Ok(up)
499    }
500
501    #[inline]
502    fn accept_backend(&self, backend: &Backend, healthy: bool) -> bool {
503        // 1. If the backend is marked unhealthy by the health check, always reject it.
504        if !healthy {
505            return false;
506        }
507        // if circuit breaking is not configured, accept any healthy backend.
508        let Some(states) = &self.circuit_breaker_states else {
509            return true;
510        };
511        states.is_backend_acceptable(&backend.addr.to_string())
512    }
513
514    /// Creates and configures a new HTTP peer for handling requests
515    ///
516    /// # Arguments
517    /// * `session` - Current HTTP session containing request details
518    /// * `ctx` - Request context state
519    ///
520    /// # Returns
521    /// * `Option<HttpPeer>` - Configured HTTP peer if a healthy backend is available, None otherwise
522    ///
523    /// This method:
524    /// 1. Selects an appropriate backend using the configured load balancing strategy
525    /// 2. Increments the processing counter
526    /// 3. Creates and configures an HttpPeer with the connection settings
527    #[inline]
528    pub fn new_http_peer(
529        &self,
530        session: &Session,
531        client_ip: &Option<String>,
532    ) -> Option<HttpPeer> {
533        // Select a backend based on the load balancing strategy
534        let upstream = match &self.lb {
535            // For round-robin, use empty key since selection is sequential
536            SelectionLb::RoundRobin(lb) => {
537                lb.select_with(b"", 4, |backend, healthy| {
538                    self.accept_backend(backend, healthy)
539                })
540            },
541            // For consistent hashing, generate hash value from request details
542            SelectionLb::Consistent { lb, hash } => {
543                let value = hash.get_value(session, client_ip);
544                lb.select_with(value.as_bytes(), 4, |backend, healthy| {
545                    self.accept_backend(backend, healthy)
546                })
547            },
548            // For transparent mode, no backend selection needed
549            SelectionLb::Transparent => None,
550        };
551        // Increment counter for requests being processed
552        self.processing.fetch_add(1, Ordering::Relaxed);
553
554        // Create HTTP peer based on load balancing mode
555        let p = if matches!(self.lb, SelectionLb::Transparent) {
556            // In transparent mode, use the request's host header
557            let host = pingap_core::get_host(session.req_header())?;
558            // Set SNI: either use host header ($host) or configured value
559            let sni = if self.sni == "$host" {
560                host.to_string()
561            } else {
562                self.sni.clone()
563            };
564            // use default port for transparent http/https
565            let port = if self.tls { 443 } else { 80 };
566            // Create peer with host:port, TLS settings, and SNI
567            Some(HttpPeer::new(format!("{host}:{port}"), self.tls, sni))
568        } else {
569            // For load balanced modes, create peer from selected backend
570            upstream.map(|upstream| {
571                HttpPeer::new(upstream, self.tls, self.sni.clone())
572            })
573        };
574
575        // Configure connection options for the peer
576        p.map(|mut p| {
577            // Set various timeout values
578            p.options.connection_timeout = self.connection_timeout;
579            p.options.total_connection_timeout = self.total_connection_timeout;
580            p.options.read_timeout = self.read_timeout;
581            p.options.idle_timeout = self.idle_timeout;
582            p.options.write_timeout = self.write_timeout;
583            // Configure TLS certificate verification if specified
584            if let Some(verify_cert) = self.verify_cert {
585                p.options.verify_cert = verify_cert;
586            }
587            // Set protocol negotiation settings
588            p.options.alpn = self.alpn.clone();
589            // Configure TCP-specific options
590            p.options.tcp_keepalive.clone_from(&self.tcp_keepalive);
591            p.options.tcp_recv_buf = self.tcp_recv_buf;
592            if let Some(tcp_fast_open) = self.tcp_fast_open {
593                p.options.tcp_fast_open = tcp_fast_open;
594            }
595            // Set connection tracing if enabled
596            p.options.tracer.clone_from(&self.tracer);
597            p
598        })
599    }
600
601    /// Returns the backends of the upstream
602    ///
603    /// # Returns
604    /// * `Option<&Backends>` - The backends of the upstream if available, None otherwise
605    #[inline]
606    pub fn get_backends(&self) -> Option<&Backends> {
607        match &self.lb {
608            SelectionLb::RoundRobin(lb) => Some(lb.backends()),
609            SelectionLb::Consistent { lb, .. } => Some(lb.backends()),
610            SelectionLb::Transparent => None,
611        }
612    }
613
614    pub async fn run_health_check(&self) -> Result<()> {
615        self.lb.update().await.map_err(|e| Error::Common {
616            category: "run_health_check".to_string(),
617            message: e.to_string(),
618        })?;
619        self.lb.run_health_check().await;
620
621        Ok(())
622    }
623    pub fn is_transparent(&self) -> bool {
624        matches!(self.lb, SelectionLb::Transparent)
625    }
626    pub fn connected(&self) -> Option<i32> {
627        self.peer_tracer.as_ref().map(|tracer| tracer.connected())
628    }
629
630    pub fn stats(&self) -> UpstreamStats {
631        let Some(backends) = self.get_backends() else {
632            return UpstreamStats::default();
633        };
634        UpstreamStats {
635            processing: self.processing.load(Ordering::Relaxed),
636            connected: self
637                .peer_tracer
638                .as_ref()
639                .map(|tracer| tracer.connected()),
640            backend_stats: self
641                .backend_stats
642                .as_ref()
643                .map(|backend_stats| backend_stats.get_all_stats(backends))
644                .unwrap_or_default(),
645        }
646    }
647}
648
649impl UpstreamInstance for Upstream {
650    /// Decrements and returns the number of requests being processed
651    ///
652    /// # Returns
653    /// * `i32` - Previous count of requests being processed
654    fn completed(&self) -> i32 {
655        self.processing.fetch_add(-1, Ordering::Relaxed)
656    }
657    fn on_transport_failure(&self, address: &str) {
658        let Some(backend_stats) = &self.backend_stats else {
659            return;
660        };
661        debug!(target: LOG_TARGET, address, "on_transport_failure");
662        backend_stats.on_transport_failure(address);
663        if let Some(circuit_breaker_states) = &self.circuit_breaker_states {
664            circuit_breaker_states.update_state_after_request(
665                address,
666                true,
667                backend_stats,
668            );
669        }
670    }
671    fn on_response(&self, address: &str, status: StatusCode) {
672        let Some(backend_stats) = &self.backend_stats else {
673            return;
674        };
675        debug!(target: LOG_TARGET, address, status = status.to_string(), "on_response");
676        let is_request_failure = backend_stats.on_response(address, status);
677        if let Some(circuit_breaker_states) = &self.circuit_breaker_states {
678            circuit_breaker_states.update_state_after_request(
679                address,
680                is_request_failure,
681                backend_stats,
682            );
683        }
684    }
685}
686
687#[derive(Debug, Clone, Serialize, Deserialize)]
688pub struct UpstreamHealthyStatus {
689    pub healthy: u32,
690    pub total: u32,
691    pub unhealthy_backends: Vec<String>,
692}
693
694pub fn new_ahash_upstreams(
695    upstream_configs: &HashMap<String, UpstreamConf>,
696    upstream_provider: Arc<dyn UpstreamProvider>,
697    sender: Option<Arc<NotificationSender>>,
698) -> Result<(Upstreams, Vec<String>)> {
699    let mut upstreams = AHashMap::new();
700    let mut updated_upstreams = vec![];
701    for (name, conf) in upstream_configs.iter() {
702        let key = conf.hash_key();
703        if let Some(found) = upstream_provider.get(name) {
704            // not modified
705            if found.key == key {
706                upstreams.insert(name.to_string(), found);
707                continue;
708            }
709        }
710        let up = Arc::new(Upstream::new(name, conf, sender.clone())?);
711        upstreams.insert(name.to_string(), up);
712        updated_upstreams.push(name.to_string());
713    }
714    Ok((upstreams, updated_upstreams))
715}
716
717#[async_trait]
718impl BackgroundTask for HealthCheckTask {
719    async fn execute(&self, check_count: u32) -> Result<bool, ServiceError> {
720        // get upstream names
721        let mut upstreams = self.upstream_provider.list();
722        upstreams.retain(|(_, up)| !up.is_transparent());
723        let interval = self.interval.as_secs();
724        // run health check for each upstream
725        let jobs = upstreams.into_iter().map(|(name, up)| {
726            let runtime = pingora_runtime::current_handle();
727            runtime.spawn(async move {
728                let check_frequency_matched = |frequency: u64| -> bool {
729                    let mut count = (frequency / interval) as u32;
730                    if !frequency.is_multiple_of(interval) {
731                        count += 1;
732                    }
733                    check_count.is_multiple_of(count)
734                };
735
736                // get update frequency(update service)
737                // and health check frequency
738                let (update_frequency, health_check_frequency) =
739                    up.lb.get_health_frequency();
740
741                // the first time should match
742                // update check
743                if check_count == 0
744                    || (update_frequency > 0
745                        && check_frequency_matched(update_frequency))
746                {
747                    let update_backend_start_time = Instant::now();
748                    let result = up.lb.update().await;
749                    if let Err(e) = result {
750                        error!(
751                            target: LOG_TARGET,
752                            error = %e,
753                            name,
754                            "update backends fail"
755                        )
756                    } else {
757                        info!(
758                            target: LOG_TARGET,
759                            name,
760                            elapsed = format!(
761                                "{}ms",
762                                update_backend_start_time.elapsed().as_millis()
763                            ),
764                            "update backend success"
765                        );
766                    }
767                }
768
769                // health check
770                if !check_frequency_matched(health_check_frequency) {
771                    return;
772                }
773                let health_check_start_time = Instant::now();
774                up.lb.run_health_check().await;
775                info!(
776                    target: LOG_TARGET,
777                    name,
778                    elapsed = format!(
779                        "{}ms",
780                        health_check_start_time.elapsed().as_millis()
781                    ),
782                    "health check is done"
783                );
784            })
785        });
786        futures::future::join_all(jobs).await;
787
788        // each 10 times, check unhealthy upstreams
789        if check_count % 10 == 1 {
790            let current_unhealthy_upstreams =
791                self.unhealthy_upstreams.load().clone();
792            let mut notify_healthy_upstreams = vec![];
793            let mut unhealthy_upstreams = vec![];
794            for (name, status) in self.upstream_provider.healthy_status().iter()
795            {
796                if status.healthy == 0 {
797                    unhealthy_upstreams.push(name.to_string());
798                } else if current_unhealthy_upstreams.contains(name) {
799                    notify_healthy_upstreams.push(name.to_string());
800                }
801            }
802            let mut notify_unhealthy_upstreams = vec![];
803            for name in unhealthy_upstreams.iter() {
804                if !current_unhealthy_upstreams.contains(name) {
805                    notify_unhealthy_upstreams.push(name.to_string());
806                }
807            }
808            self.unhealthy_upstreams
809                .store(Arc::new(unhealthy_upstreams));
810            if let Some(sender) = &self.sender {
811                if !notify_unhealthy_upstreams.is_empty() {
812                    let data = NotificationData {
813                        category: "upstream_status".to_string(),
814                        title: "Upstream unhealthy".to_string(),
815                        message: notify_unhealthy_upstreams.join(", "),
816                        level: NotificationLevel::Error,
817                    };
818                    sender.notify(data).await;
819                }
820                if !notify_healthy_upstreams.is_empty() {
821                    let data = NotificationData {
822                        category: "upstream_status".to_string(),
823                        title: "Upstream healthy".to_string(),
824                        message: notify_healthy_upstreams.join(", "),
825                        ..Default::default()
826                    };
827                    sender.notify(data).await;
828                }
829            }
830        }
831        Ok(true)
832    }
833}
834
// Background task that periodically health-checks every upstream exposed by
// the provider and notifies when an upstream's healthy status changes.
struct HealthCheckTask {
    // interval between health check rounds; also handed to the background
    // service as its schedule when the task is created
    interval: Duration,
    // optional channel used to emit "Upstream healthy"/"Upstream unhealthy"
    // notifications; when `None`, status changes are not reported
    sender: Option<Arc<NotificationSender>>,
    // names of upstreams currently considered unhealthy, swapped atomically
    // each status round so the next round can diff against it and only
    // notify on transitions
    unhealthy_upstreams: ArcSwap<Vec<String>>,
    // source of the upstream list and per-upstream healthy status
    upstream_provider: Arc<dyn UpstreamProvider>,
}
841
842pub fn new_upstream_health_check_task(
843    upstream_provider: Arc<dyn UpstreamProvider>,
844    interval: Duration,
845    sender: Option<Arc<NotificationSender>>,
846) -> BackgroundTaskService {
847    let task = Box::new(HealthCheckTask {
848        interval,
849        sender,
850        unhealthy_upstreams: ArcSwap::new(Arc::new(vec![])),
851        upstream_provider,
852    });
853    let name = "upstream_health_check";
854    let mut service =
855        BackgroundTaskService::new_single(name, interval, name, task);
856    service.set_immediately(true);
857    service
858}
859
#[cfg(test)]
mod tests {
    // Unit tests covering backend discovery parsing, upstream construction,
    // provider-level stats/healthy-status helpers, upstream map rebuilding,
    // and load balancer frequency configuration.
    use super::{
        Upstream, UpstreamConf, UpstreamProvider, new_backends,
        new_load_balancer,
    };
    use crate::new_ahash_upstreams;
    use pingap_core::UpstreamInstance;
    use pingap_discovery::Discovery;
    use pingora::protocols::ALPN;
    use pingora::proxy::Session;
    use pretty_assertions::assert_eq;
    use std::collections::HashMap;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicI32, Ordering};
    use std::time::Duration;
    use tokio_test::io::Builder;

    // Minimal `UpstreamProvider` wrapping a single upstream, used to
    // exercise provider-level helpers (`get_all_stats`, `healthy_status`,
    // `new_ahash_upstreams`) in isolation.
    struct TmpProvider {
        upstream: Arc<Upstream>,
    }

    impl UpstreamProvider for TmpProvider {
        // Returns the wrapped upstream only when the requested name matches.
        fn get(&self, name: &str) -> Option<Arc<Upstream>> {
            if name == self.upstream.name.as_ref() {
                return Some(self.upstream.clone());
            }
            None
        }
        // Lists the single wrapped upstream as a (name, upstream) pair.
        fn list(&self) -> Vec<(String, Arc<Upstream>)> {
            vec![(
                self.upstream.name.as_ref().to_string(),
                self.upstream.clone(),
            )]
        }
    }

    // Static discovery accepts "addr[:port] [weight]" entries (port and
    // weight optional); "dns" discovery accepts plain hostnames.
    #[test]
    fn test_new_backends() {
        // static discovery: explicit port, optional trailing weight ("10")
        let _ = new_backends(
            "",
            &Discovery::new(vec![
                "192.168.1.1:8001 10".to_string(),
                "192.168.1.2:8001".to_string(),
            ]),
        )
        .unwrap();

        // static discovery: an address without a port is also accepted
        let _ = new_backends(
            "",
            &Discovery::new(vec![
                "192.168.1.1".to_string(),
                "192.168.1.2:8001".to_string(),
            ]),
        )
        .unwrap();

        // dns discovery: hostnames instead of socket addresses
        let _ = new_backends(
            "dns",
            &Discovery::new(vec!["github.com".to_string()]),
        )
        .unwrap();
    }
    // `Upstream::new` must reject an empty address list, and must map the
    // configured alpn / timeouts / TCP keepalive settings onto the instance.
    #[test]
    fn test_new_upstream() {
        // empty addrs is a configuration error
        let result = Upstream::new(
            "charts",
            &UpstreamConf {
                ..Default::default()
            },
            None,
        );
        assert_eq!(
            "Common error, category: new_upstream, upstream addrs is empty",
            result.err().unwrap().to_string()
        );

        // a fully populated configuration
        let up = Upstream::new(
            "charts",
            &UpstreamConf {
                addrs: vec!["192.168.1.1".to_string()],
                algo: Some("hash:cookie:user-id".to_string()),
                alpn: Some("h2".to_string()),
                connection_timeout: Some(Duration::from_secs(5)),
                total_connection_timeout: Some(Duration::from_secs(10)),
                read_timeout: Some(Duration::from_secs(3)),
                idle_timeout: Some(Duration::from_secs(30)),
                write_timeout: Some(Duration::from_secs(5)),
                tcp_idle: Some(Duration::from_secs(60)),
                tcp_probe_count: Some(100),
                tcp_interval: Some(Duration::from_secs(60)),
                tcp_recv_buf: Some(bytesize::ByteSize(1024)),
                ..Default::default()
            },
            None,
        )
        .unwrap();

        assert_eq!(ALPN::H2.to_string(), up.alpn.to_string());
        assert_eq!("Some(5s)", format!("{:?}", up.connection_timeout));
        assert_eq!("Some(10s)", format!("{:?}", up.total_connection_timeout));
        assert_eq!("Some(3s)", format!("{:?}", up.read_timeout));
        assert_eq!("Some(30s)", format!("{:?}", up.idle_timeout));
        assert_eq!("Some(5s)", format!("{:?}", up.write_timeout));
        // TcpKeepalive's Debug output differs between linux (which has a
        // user_timeout field) and other targets
        #[cfg(target_os = "linux")]
        assert_eq!(
            "Some(TcpKeepalive { idle: 60s, interval: 60s, count: 100, user_timeout: 0ns })",
            format!("{:?}", up.tcp_keepalive)
        );
        #[cfg(not(target_os = "linux"))]
        assert_eq!(
            "Some(TcpKeepalive { idle: 60s, interval: 60s, count: 100 })",
            format!("{:?}", up.tcp_keepalive)
        );
        assert_eq!("Some(1024)", format!("{:?}", up.tcp_recv_buf));
    }

    // End-to-end-ish check against a mock HTTP/1 session: `completed()`
    // should report the current processing count and decrement it by one,
    // and a peer should be selectable for the request.
    #[tokio::test]
    async fn test_upstream() {
        let headers = [
            "Host: github.com",
            "Referer: https://github.com/",
            "User-Agent: pingap/0.1.1",
            "Cookie: deviceId=abc",
            "Accept: application/json",
        ]
        .join("\r\n");
        let input_header =
            format!("GET /vicanso/pingap?size=1 HTTP/1.1\r\n{headers}\r\n\r\n");
        // feed the raw request bytes through a mock IO object
        let mock_io = Builder::new().read(input_header.as_bytes()).build();

        let mut session = Session::new_h1(Box::new(mock_io));
        session.read_request().await.unwrap();
        let up = Upstream::new(
            "upstreamname",
            &UpstreamConf {
                addrs: vec!["192.168.1.1:8001".to_string()],
                ..Default::default()
            },
            None,
        )
        .unwrap();
        // simulate 10 in-flight requests
        up.processing.fetch_add(10, Ordering::Relaxed);
        let value = up.processing.load(Ordering::Relaxed);
        // completed() returns the processing count before decrementing it
        assert_eq!(value, up.completed());
        assert_eq!(value - 1, up.processing.load(Ordering::Relaxed));
        assert_eq!(true, up.new_http_peer(&session, &None,).is_some());
    }

    // The provider's aggregated stats should expose each upstream's
    // processing counter under its name.
    #[test]
    fn test_get_upstreams_processing_connected() {
        let mut tmp_upstream = Upstream::new(
            "test",
            &UpstreamConf {
                addrs: vec!["127.0.0.1:5001".to_string()],
                ..Default::default()
            },
            None,
        )
        .unwrap();
        tmp_upstream.processing = AtomicI32::new(10);
        let upstream = Arc::new(tmp_upstream);

        let upstream_provider = Arc::new(TmpProvider { upstream });

        let stat = upstream_provider.get_all_stats();

        assert_eq!(1, stat.len());
        assert_eq!(10, stat.get("test").unwrap().processing);
    }

    // Without a configured health check, backends are reported healthy —
    // for both the default (round robin) and consistent-hash selections.
    #[test]
    fn test_get_upstream_healthy_status() {
        let tmp_upstream = Upstream::new(
            "test",
            &UpstreamConf {
                addrs: vec!["127.0.0.1:5001".to_string()],
                ..Default::default()
            },
            None,
        )
        .unwrap();
        let upstream_provider = Arc::new(TmpProvider {
            upstream: Arc::new(tmp_upstream),
        });
        let status = upstream_provider.healthy_status();
        assert_eq!(1, status.len());
        // no health check, so is healthy
        assert_eq!(1, status.get("test").unwrap().healthy);

        // same check with a consistent-hash (hash:ip) upstream
        let tmp_upstream = Upstream::new(
            "ip",
            &UpstreamConf {
                addrs: vec!["127.0.0.1:5001".to_string()],
                algo: Some("hash:ip".to_string()),
                ..Default::default()
            },
            None,
        )
        .unwrap();
        let upstream_provider = Arc::new(TmpProvider {
            upstream: Arc::new(tmp_upstream),
        });
        let status = upstream_provider.healthy_status();
        assert_eq!(1, status.len());
        // no health check, so is healthy
        assert_eq!(1, status.get("ip").unwrap().healthy);
    }

    // `new_ahash_upstreams` should reuse an unchanged upstream, rebuild it
    // when its address changes, and replace it when the name changes.
    #[test]
    fn test_new_ahash_upstreams() {
        let mut tmp_upstream = Upstream::new(
            "test",
            &UpstreamConf {
                addrs: vec!["127.0.0.1:5001".to_string()],
                ..Default::default()
            },
            None,
        )
        .unwrap();
        tmp_upstream.processing = AtomicI32::new(10);
        let upstream = Arc::new(tmp_upstream);
        let upstream_provider = Arc::new(TmpProvider { upstream });

        // identical config: nothing should be reported as updated
        let mut upstream_configs = HashMap::new();
        upstream_configs.insert(
            "test".to_string(),
            UpstreamConf {
                addrs: vec!["127.0.0.1:5001".to_string()],
                ..Default::default()
            },
        );
        let (upstreams, updated_upstreams) = new_ahash_upstreams(
            &upstream_configs,
            upstream_provider.clone(),
            None,
        )
        .unwrap();
        assert_eq!(0, updated_upstreams.len());
        assert_eq!(1, upstreams.len());
        assert_eq!(true, upstreams.contains_key("test"));

        // new upstream address
        let mut upstream_configs = HashMap::new();
        upstream_configs.insert(
            "test".to_string(),
            UpstreamConf {
                addrs: vec!["127.0.0.1:5002".to_string()],
                ..Default::default()
            },
        );
        let (upstreams, updated_upstreams) = new_ahash_upstreams(
            &upstream_configs,
            upstream_provider.clone(),
            None,
        )
        .unwrap();
        assert_eq!(1, updated_upstreams.len());
        assert_eq!(1, upstreams.len());
        assert_eq!(true, upstreams.contains_key("test"));

        // new upstream, remove old upstream
        let mut upstream_configs = HashMap::new();
        upstream_configs.insert(
            "test1".to_string(),
            UpstreamConf {
                addrs: vec!["127.0.0.1:5001".to_string()],
                ..Default::default()
            },
        );
        let (upstreams, updated_upstreams) =
            new_ahash_upstreams(&upstream_configs, upstream_provider, None)
                .unwrap();

        assert_eq!(1, updated_upstreams.len());
        assert_eq!(1, upstreams.len());
        assert_eq!(false, upstreams.contains_key("test"));
        assert_eq!(true, upstreams.contains_key("test1"));
    }

    // Update frequency comes from the conf; health-check frequency is parsed
    // from the health check URL's `check_frequency` query parameter — for
    // both round-robin and consistent-hash balancers.
    #[test]
    fn test_selection_load_balancer() {
        let round_robin = new_load_balancer(
            "test",
            &UpstreamConf {
                discovery: Some("dns".to_string()),
                addrs: vec!["127.0.0.1:3000".to_string()],
                update_frequency: Some(Duration::from_secs(5)),
                health_check: Some(
                    "http://127.0.0.1:3000/?check_frequency=3s".to_string(),
                ),
                ..Default::default()
            },
            None,
        )
        .unwrap();
        let (update_frequency, health_check_frequency) =
            round_robin.get_health_frequency();
        assert_eq!(5, update_frequency);
        assert_eq!(3, health_check_frequency);

        // consistent hashing (hash:ip) variant
        let consistent = new_load_balancer(
            "test",
            &UpstreamConf {
                discovery: Some("dns".to_string()),
                algo: Some("hash:ip".to_string()),
                addrs: vec!["127.0.0.1:3000".to_string()],
                update_frequency: Some(Duration::from_secs(10)),
                health_check: Some(
                    "http://127.0.0.1:3000/?check_frequency=2s".to_string(),
                ),
                ..Default::default()
            },
            None,
        )
        .unwrap();
        let (update_frequency, health_check_frequency) =
            consistent.get_health_frequency();
        assert_eq!(10, update_frequency);
        assert_eq!(2, health_check_frequency);
    }
}