pingap_upstream/
upstream.rs

1// Copyright 2024-2025 Tree xie.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use ahash::AHashMap;
16use arc_swap::ArcSwap;
17use async_trait::async_trait;
18use derive_more::Debug;
19use futures_util::FutureExt;
20use once_cell::sync::Lazy;
21use pingap_config::UpstreamConf;
22use pingap_core::{CommonServiceTask, ServiceTask};
23use pingap_core::{NotificationData, NotificationLevel, NotificationSender};
24use pingap_discovery::{
25    is_dns_discovery, is_docker_discovery, is_static_discovery,
26    new_dns_discover_backends, new_docker_discover_backends,
27    new_static_discovery, Discovery, TRANSPARENT_DISCOVERY,
28};
29use pingap_health::new_health_check;
30use pingora::lb::health_check::{HealthObserve, HealthObserveCallback};
31use pingora::lb::selection::{
32    BackendIter, BackendSelection, Consistent, RoundRobin,
33};
34use pingora::lb::Backend;
35use pingora::lb::{Backends, LoadBalancer};
36use pingora::protocols::l4::ext::TcpKeepalive;
37use pingora::protocols::ALPN;
38use pingora::proxy::Session;
39use pingora::upstreams::peer::{HttpPeer, Tracer, Tracing};
40use serde::{Deserialize, Serialize};
41use snafu::Snafu;
42use std::collections::HashMap;
43use std::sync::atomic::{AtomicI32, AtomicU32, Ordering};
44use std::sync::Arc;
45use std::time::{Duration, SystemTime};
46use tracing::{debug, error, info};
47
/// Category tag attached to every tracing event emitted by this module.
const LOG_CATEGORY: &str = "upstream";

/// Errors produced while building upstreams or refreshing their backends.
#[derive(Debug, Snafu)]
pub enum Error {
    /// Generic failure. `category` names the stage that failed (for example
    /// "new_upstream", "health", "run_health_check" or one of the
    /// "*_discovery" kinds) and `message` carries the underlying error text.
    #[snafu(display("Common error, category: {category}, {message}"))]
    Common { message: String, category: String },
}
/// Module-local result alias whose error type defaults to [`Error`].
type Result<T, E = Error> = std::result::Result<T, E>;
56
/// Health observer that turns backend health results of one upstream into
/// "backend_status" notifications.
pub struct BackendObserveNotification {
    // Upstream name, embedded in the notification message.
    name: String,
    // Sender used to deliver the notification.
    sender: Arc<NotificationSender>,
}
61
62#[async_trait]
63impl HealthObserve for BackendObserveNotification {
64    async fn observe(&self, backend: &Backend, healthy: bool) {
65        let addr = backend.addr.to_string();
66        let template = format!("upstream {}({addr}) becomes ", self.name);
67        let info = if healthy {
68            (NotificationLevel::Info, template + "healthy")
69        } else {
70            (NotificationLevel::Error, template + "unhealthy")
71        };
72
73        self.sender
74            .notify(NotificationData {
75                category: "backend_status".to_string(),
76                level: info.0,
77                title: "Upstream backend status changed".to_string(),
78                message: info.1,
79            })
80            .await;
81    }
82}
83
84fn new_observe(
85    name: &str,
86    sender: Option<Arc<NotificationSender>>,
87) -> Option<HealthObserveCallback> {
88    if let Some(sender) = sender {
89        Some(Box::new(BackendObserveNotification {
90            name: name.to_string(),
91            sender: sender.clone(),
92        }))
93    } else {
94        None
95    }
96}
97
/// Backend selection strategy backing an [`Upstream`].
enum SelectionLb {
    /// Round-robin selection over the discovered backends.
    RoundRobin(Arc<LoadBalancer<RoundRobin>>),
    /// Consistent hashing; the hash input is derived from the request
    /// according to the upstream's `hash` / `hash_key` settings.
    Consistent(Arc<LoadBalancer<Consistent>>),
    /// No load balancer: the peer address is taken from the request's host
    /// (see `Upstream::new_http_peer`), and periodic health checks skip it.
    Transparent,
}
107
108// UpstreamPeerTracer tracks active connections to upstream servers
109#[derive(Clone, Debug)]
110struct UpstreamPeerTracer {
111    name: String,
112    connected: Arc<AtomicI32>, // Number of active connections
113}
114
115impl UpstreamPeerTracer {
116    fn new(name: &str) -> Self {
117        Self {
118            name: name.to_string(),
119            connected: Arc::new(AtomicI32::new(0)),
120        }
121    }
122}
123
124impl Tracing for UpstreamPeerTracer {
125    fn on_connected(&self) {
126        debug!(
127            category = LOG_CATEGORY,
128            name = self.name,
129            "upstream peer connected"
130        );
131        self.connected.fetch_add(1, Ordering::Relaxed);
132    }
133    fn on_disconnected(&self) {
134        debug!(
135            category = LOG_CATEGORY,
136            name = self.name,
137            "upstream peer disconnected"
138        );
139        self.connected.fetch_sub(1, Ordering::Relaxed);
140    }
141    fn boxed_clone(&self) -> Box<dyn Tracing> {
142        Box::new(self.clone())
143    }
144}
145
#[derive(Debug)]
/// A named group of backend servers plus the connection settings applied to
/// every `HttpPeer` created for it.
pub struct Upstream {
    /// Upstream name (its key in the global upstream map)
    pub name: String,

    /// Configuration hash (`UpstreamConf::hash_key`); on reload, an upstream
    /// whose key is unchanged is reused instead of being rebuilt
    pub key: String,

    /// Consistent-hash input strategy (only used with the `hash` algorithm):
    /// - "url": full request URI
    /// - "ip": client IP
    /// - "header": value of the header named by `hash_key`
    /// - "cookie": value of the cookie named by `hash_key`
    /// - "query": value of the query parameter named by `hash_key`
    /// - anything else (including empty): the request path
    hash: String,

    /// Name used by the "header", "cookie" and "query" strategies:
    /// - For "header": Header name to use
    /// - For "cookie": Cookie name to use
    /// - For "query": Query parameter name to use
    hash_key: String,

    /// Whether to use TLS for backend connections; true exactly when `sni`
    /// is non-empty
    tls: bool,

    /// Server Name Indication for TLS connections. In transparent mode the
    /// special value "$host" is replaced by the request's host
    sni: String,

    /// Backend selection strategy (round robin, consistent hash or
    /// transparent). `SelectionLb` has no `Debug` impl, so the derive_more
    /// attribute prints the literal text "lb" in its place
    #[debug("lb")]
    lb: SelectionLb,

    /// Passed to the peer's `PeerOptions::connection_timeout` (connect timeout)
    connection_timeout: Option<Duration>,

    /// Passed to `PeerOptions::total_connection_timeout`: a bound on the whole
    /// connection-establishment phase (presumably TCP connect plus TLS
    /// handshake, per pingora's docs), not on the connection's lifetime
    total_connection_timeout: Option<Duration>,

    /// Passed to the peer's `PeerOptions::read_timeout`
    read_timeout: Option<Duration>,

    /// Passed to the peer's `PeerOptions::idle_timeout` (presumably how long a
    /// pooled connection may stay idle; confirm against pingora docs)
    idle_timeout: Option<Duration>,

    /// Passed to the peer's `PeerOptions::write_timeout`
    write_timeout: Option<Duration>,

    /// When set, overrides `PeerOptions::verify_cert`; `None` keeps the default
    verify_cert: Option<bool>,

    /// ALPN for backend connections (H1, H2 or H2H1; H1 unless configured)
    alpn: ALPN,

    /// TCP keepalive; only set when idle time, probe count and interval are
    /// all configured
    tcp_keepalive: Option<TcpKeepalive>,

    /// Size of the TCP receive buffer in bytes
    tcp_recv_buf: Option<usize>,

    /// When set, overrides `PeerOptions::tcp_fast_open`
    tcp_fast_open: Option<bool>,

    /// Connection tracer, present only when `enable_tracer` is configured;
    /// source of the count returned by `connected()`
    peer_tracer: Option<UpstreamPeerTracer>,

    /// Boxed clone of `peer_tracer`, installed on every created peer
    tracer: Option<Tracer>,

    /// Requests currently in flight: incremented by `new_http_peer`,
    /// decremented by `completed`
    processing: AtomicI32,
}
222
223// Creates new backend servers based on discovery method (DNS/Docker/Static)
224fn new_backends(
225    discovery_category: &str,
226    discovery: &Discovery,
227) -> Result<Backends> {
228    let (result, category) = match discovery_category {
229        d if is_dns_discovery(d) => {
230            (new_dns_discover_backends(discovery), "dns_discovery")
231        },
232        d if is_docker_discovery(d) => {
233            (new_docker_discover_backends(discovery), "docker_discovery")
234        },
235        _ => (new_static_discovery(discovery), "static_discovery"),
236    };
237    result.map_err(|e| Error::Common {
238        category: category.to_string(),
239        message: e.to_string(),
240    })
241}
242
243// Gets the value to use for consistent hashing based on the hash strategy
244fn get_hash_value(
245    hash: &str,        // Hash strategy (url/ip/header/cookie/query)
246    hash_key: &str,    // Key to use for hash lookups
247    session: &Session, // Current request session
248    client_ip: &Option<String>, // Request context
249) -> String {
250    match hash {
251        "url" => session.req_header().uri.to_string(),
252        "ip" => {
253            if let Some(client_ip) = client_ip {
254                client_ip.to_string()
255            } else {
256                pingap_core::get_client_ip(session)
257            }
258        },
259        "header" => {
260            if let Some(value) = session.get_header(hash_key) {
261                value.to_str().unwrap_or_default().to_string()
262            } else {
263                "".to_string()
264            }
265        },
266        "cookie" => {
267            pingap_core::get_cookie_value(session.req_header(), hash_key)
268                .unwrap_or_default()
269                .to_string()
270        },
271        "query" => pingap_core::get_query_value(session.req_header(), hash_key)
272            .unwrap_or_default()
273            .to_string(),
274        // default: path
275        _ => session.req_header().uri.path().to_string(),
276    }
277}
278
279fn update_health_check_params<S>(
280    mut lb: LoadBalancer<S>,
281    name: &str,
282    conf: &UpstreamConf,
283    sender: Option<Arc<NotificationSender>>,
284) -> Result<LoadBalancer<S>>
285where
286    S: BackendSelection + 'static,
287    S::Iter: BackendIter,
288{
289    // For static discovery, perform immediate backend update
290    if is_static_discovery(&conf.guess_discovery()) {
291        lb.update()
292            .now_or_never()
293            .expect("static should not block")
294            .expect("static should not error");
295    }
296
297    // Set up health checking for the backends
298    let (health_check_conf, hc) = new_health_check(
299        name,
300        &conf.health_check.clone().unwrap_or_default(),
301        new_observe(name, sender),
302    )
303    .map_err(|e| Error::Common {
304        message: e.to_string(),
305        category: "health".to_string(),
306    })?;
307    // Configure health checking
308    lb.parallel_health_check = health_check_conf.parallel_check;
309    lb.set_health_check(hc);
310    lb.update_frequency = conf.update_frequency;
311    lb.health_check_frequency = Some(health_check_conf.check_frequency);
312    Ok(lb)
313}
314
315/// Creates a new load balancer instance based on the provided configuration
316///
317/// # Arguments
318/// * `name` - Name identifier for the upstream service
319/// * `conf` - Configuration for the upstream service
320///
321/// # Returns
322/// * `Result<(SelectionLb, String, String)>` - Returns the load balancer, hash strategy, and hash key
323fn new_load_balancer(
324    name: &str,
325    conf: &UpstreamConf,
326    sender: Option<Arc<NotificationSender>>,
327) -> Result<(SelectionLb, String, String)> {
328    // Validate that addresses are provided
329    if conf.addrs.is_empty() {
330        return Err(Error::Common {
331            category: "new_upstream".to_string(),
332            message: "upstream addrs is empty".to_string(),
333        });
334    }
335
336    // Determine the service discovery method
337    let discovery_category = conf.guess_discovery();
338    // For transparent discovery, return early with no load balancing
339    if discovery_category == TRANSPARENT_DISCOVERY {
340        return Ok((SelectionLb::Transparent, "".to_string(), "".to_string()));
341    }
342
343    let mut hash = "".to_string();
344    // Determine if TLS should be enabled based on SNI configuration
345    let tls = conf
346        .sni
347        .as_ref()
348        .map(|item| !item.is_empty())
349        .unwrap_or_default();
350
351    // Create backend servers using the configured addresses and discovery method
352    let discovery = Discovery::new(conf.addrs.clone())
353        .with_ipv4_only(conf.ipv4_only.unwrap_or_default())
354        .with_tls(tls)
355        .with_sender(sender.clone());
356    let backends = new_backends(&discovery_category, &discovery)?;
357
358    // Parse the load balancing algorithm configuration
359    // Format: "algo:hash_type:hash_key" (e.g. "hash:cookie:session_id")
360    let algo_method = conf.algo.clone().unwrap_or_default();
361    let algo_params: Vec<&str> = algo_method.split(':').collect();
362    let mut hash_key = "".to_string();
363
364    // Create the appropriate load balancer based on the algorithm
365    let lb = match algo_params[0] {
366        // Consistent hashing load balancer
367        "hash" => {
368            // Parse hash type and key if provided
369            if algo_params.len() > 1 {
370                hash = algo_params[1].to_string();
371                if algo_params.len() > 2 {
372                    hash_key = algo_params[2].to_string();
373                }
374            }
375            let lb = update_health_check_params(
376                LoadBalancer::<Consistent>::from_backends(backends),
377                name,
378                conf,
379                sender,
380            )?;
381
382            SelectionLb::Consistent(Arc::new(lb))
383        },
384        // Round robin load balancer (default)
385        _ => {
386            let lb = update_health_check_params(
387                LoadBalancer::<RoundRobin>::from_backends(backends),
388                name,
389                conf,
390                sender,
391            )?;
392
393            SelectionLb::RoundRobin(Arc::new(lb))
394        },
395    };
396    Ok((lb, hash, hash_key))
397}
398
399impl Upstream {
400    /// Creates a new Upstream instance from the provided configuration
401    ///
402    /// # Arguments
403    /// * `name` - Name identifier for the upstream service
404    /// * `conf` - Configuration parameters for the upstream service
405    ///
406    /// # Returns
407    /// * `Result<Self>` - New Upstream instance or error if creation fails
408    pub fn new(
409        name: &str,
410        conf: &UpstreamConf,
411        sender: Option<Arc<NotificationSender>>,
412    ) -> Result<Self> {
413        let (lb, hash, hash_key) = new_load_balancer(name, conf, sender)?;
414        let key = conf.hash_key();
415        let sni = conf.sni.clone().unwrap_or_default();
416        let tls = !sni.is_empty();
417
418        let alpn = if let Some(alpn) = &conf.alpn {
419            match alpn.to_uppercase().as_str() {
420                "H2H1" => ALPN::H2H1,
421                "H2" => ALPN::H2,
422                _ => ALPN::H1,
423            }
424        } else {
425            ALPN::H1
426        };
427
428        let tcp_keepalive = if conf.tcp_idle.is_some()
429            && conf.tcp_probe_count.is_some()
430            && conf.tcp_interval.is_some()
431        {
432            Some(TcpKeepalive {
433                idle: conf.tcp_idle.unwrap_or_default(),
434                count: conf.tcp_probe_count.unwrap_or_default(),
435                interval: conf.tcp_interval.unwrap_or_default(),
436                #[cfg(target_os = "linux")]
437                user_timeout: Duration::from_secs(0),
438            })
439        } else {
440            None
441        };
442
443        let peer_tracer = if conf.enable_tracer.unwrap_or_default() {
444            Some(UpstreamPeerTracer::new(name))
445        } else {
446            None
447        };
448        let tracer = peer_tracer
449            .as_ref()
450            .map(|peer_tracer| Tracer(Box::new(peer_tracer.to_owned())));
451        let up = Self {
452            name: name.to_string(),
453            key,
454            tls,
455            sni,
456            hash,
457            hash_key,
458            lb,
459            alpn,
460            connection_timeout: conf.connection_timeout,
461            total_connection_timeout: conf.total_connection_timeout,
462            read_timeout: conf.read_timeout,
463            idle_timeout: conf.idle_timeout,
464            write_timeout: conf.write_timeout,
465            verify_cert: conf.verify_cert,
466            tcp_recv_buf: conf.tcp_recv_buf.map(|item| item.as_u64() as usize),
467            tcp_keepalive,
468            tcp_fast_open: conf.tcp_fast_open,
469            peer_tracer,
470            tracer,
471            processing: AtomicI32::new(0),
472        };
473        debug!(
474            category = LOG_CATEGORY,
475            name = up.name,
476            "new upstream: {up:?}"
477        );
478        Ok(up)
479    }
480
481    /// Creates and configures a new HTTP peer for handling requests
482    ///
483    /// # Arguments
484    /// * `session` - Current HTTP session containing request details
485    /// * `ctx` - Request context state
486    ///
487    /// # Returns
488    /// * `Option<HttpPeer>` - Configured HTTP peer if a healthy backend is available, None otherwise
489    ///
490    /// This method:
491    /// 1. Selects an appropriate backend using the configured load balancing strategy
492    /// 2. Increments the processing counter
493    /// 3. Creates and configures an HttpPeer with the connection settings
494    #[inline]
495    pub fn new_http_peer(
496        &self,
497        session: &Session,
498        client_ip: &Option<String>,
499    ) -> Option<HttpPeer> {
500        // Select a backend based on the load balancing strategy
501        let upstream = match &self.lb {
502            // For round-robin, use empty key since selection is sequential
503            SelectionLb::RoundRobin(lb) => lb.select(b"", 256),
504            // For consistent hashing, generate hash value from request details
505            SelectionLb::Consistent(lb) => {
506                let value = get_hash_value(
507                    &self.hash,
508                    &self.hash_key,
509                    session,
510                    client_ip,
511                );
512                lb.select(value.as_bytes(), 256)
513            },
514            // For transparent mode, no backend selection needed
515            SelectionLb::Transparent => None,
516        };
517        // Increment counter for requests being processed
518        self.processing.fetch_add(1, Ordering::Relaxed);
519
520        // Create HTTP peer based on load balancing mode
521        let p = if matches!(self.lb, SelectionLb::Transparent) {
522            // In transparent mode, use the request's host header
523            let host = pingap_core::get_host(session.req_header())?;
524            // Set SNI: either use host header ($host) or configured value
525            let sni = if self.sni == "$host" {
526                host.to_string()
527            } else {
528                self.sni.clone()
529            };
530            // use default port for transparent http/https
531            let port = if self.tls { 443 } else { 80 };
532            // Create peer with host:port, TLS settings, and SNI
533            Some(HttpPeer::new(format!("{host}:{port}"), self.tls, sni))
534        } else {
535            // For load balanced modes, create peer from selected backend
536            upstream.map(|upstream| {
537                HttpPeer::new(upstream, self.tls, self.sni.clone())
538            })
539        };
540
541        // Configure connection options for the peer
542        p.map(|mut p| {
543            // Set various timeout values
544            p.options.connection_timeout = self.connection_timeout;
545            p.options.total_connection_timeout = self.total_connection_timeout;
546            p.options.read_timeout = self.read_timeout;
547            p.options.idle_timeout = self.idle_timeout;
548            p.options.write_timeout = self.write_timeout;
549            // Configure TLS certificate verification if specified
550            if let Some(verify_cert) = self.verify_cert {
551                p.options.verify_cert = verify_cert;
552            }
553            // Set protocol negotiation settings
554            p.options.alpn = self.alpn.clone();
555            // Configure TCP-specific options
556            p.options.tcp_keepalive.clone_from(&self.tcp_keepalive);
557            p.options.tcp_recv_buf = self.tcp_recv_buf;
558            if let Some(tcp_fast_open) = self.tcp_fast_open {
559                p.options.tcp_fast_open = tcp_fast_open;
560            }
561            // Set connection tracing if enabled
562            p.options.tracer.clone_from(&self.tracer);
563            p
564        })
565    }
566
567    /// Returns the current number of active connections to this upstream
568    ///
569    /// # Returns
570    /// * `Option<i32>` - Number of active connections if tracking is enabled, None otherwise
571    #[inline]
572    pub fn connected(&self) -> Option<i32> {
573        self.peer_tracer
574            .as_ref()
575            .map(|tracer| tracer.connected.load(Ordering::Relaxed))
576    }
577
578    /// Returns the round-robin load balancer if configured
579    ///
580    /// # Returns
581    /// * `Option<Arc<LoadBalancer<RoundRobin>>>` - Round-robin load balancer if used, None otherwise
582    #[inline]
583    pub fn as_round_robin(&self) -> Option<Arc<LoadBalancer<RoundRobin>>> {
584        match &self.lb {
585            SelectionLb::RoundRobin(lb) => Some(lb.clone()),
586            _ => None,
587        }
588    }
589
590    /// Returns the consistent hash load balancer if configured
591    ///
592    /// # Returns
593    /// * `Option<Arc<LoadBalancer<Consistent>>>` - Consistent hash load balancer if used, None otherwise
594    #[inline]
595    pub fn as_consistent(&self) -> Option<Arc<LoadBalancer<Consistent>>> {
596        match &self.lb {
597            SelectionLb::Consistent(lb) => Some(lb.clone()),
598            _ => None,
599        }
600    }
601
602    /// Decrements and returns the number of requests being processed
603    ///
604    /// # Returns
605    /// * `i32` - Previous count of requests being processed
606    #[inline]
607    pub fn completed(&self) -> i32 {
608        self.processing.fetch_add(-1, Ordering::Relaxed)
609    }
610}
611
612type Upstreams = AHashMap<String, Arc<Upstream>>;
613static UPSTREAM_MAP: Lazy<ArcSwap<Upstreams>> =
614    Lazy::new(|| ArcSwap::from_pointee(AHashMap::new()));
615
616pub fn get_upstream(name: &str) -> Option<Arc<Upstream>> {
617    if name.is_empty() {
618        return None;
619    }
620    UPSTREAM_MAP.load().get(name).cloned()
621}
622
/// Health summary of one upstream, as produced by
/// `get_upstream_healthy_status`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpstreamHealthyStatus {
    /// Number of backends currently reported ready.
    pub healthy: u32,
    /// Total number of backends (0 for transparent upstreams).
    pub total: u32,
    /// Display form of every backend that is not ready.
    pub unhealthy_backends: Vec<String>,
}
629
630/// Get the healthy status of all upstreams
631///
632/// # Returns
633/// * `HashMap<String, UpstreamHealthyStatus>` - Healthy status of all upstreams
634///
635/// This function iterates through all upstreams and checks their health status.
636pub fn get_upstream_healthy_status() -> HashMap<String, UpstreamHealthyStatus> {
637    let mut healthy_status = HashMap::new();
638    UPSTREAM_MAP.load().iter().for_each(|(k, v)| {
639        let mut total = 0;
640        let mut healthy = 0;
641        let mut unhealthy_backends = vec![];
642        if let Some(lb) = v.as_round_robin() {
643            let backends = lb.backends().get_backend();
644            total = backends.len();
645            backends.iter().for_each(|backend| {
646                if lb.backends().ready(backend) {
647                    healthy += 1;
648                } else {
649                    unhealthy_backends.push(backend.to_string());
650                }
651            });
652        } else if let Some(lb) = v.as_consistent() {
653            let backends = lb.backends().get_backend();
654            total = backends.len();
655            backends.iter().for_each(|backend| {
656                if lb.backends().ready(backend) {
657                    healthy += 1;
658                } else {
659                    unhealthy_backends.push(backend.to_string());
660                }
661            });
662        }
663        healthy_status.insert(
664            k.to_string(),
665            UpstreamHealthyStatus {
666                healthy,
667                total: total as u32,
668                unhealthy_backends,
669            },
670        );
671    });
672    healthy_status
673}
674
675/// Get the processing and connected status of all upstreams
676///
677/// # Returns
678/// * `HashMap<String, (i32, Option<i32>)>` - Processing and connected status of all upstreams
679pub fn get_upstreams_processing_connected(
680) -> HashMap<String, (i32, Option<i32>)> {
681    let mut processing_connected = HashMap::new();
682    UPSTREAM_MAP.load().iter().for_each(|(k, v)| {
683        let count = v.processing.load(Ordering::Relaxed);
684        let connected = v.connected();
685        processing_connected.insert(k.to_string(), (count, connected));
686    });
687    processing_connected
688}
689
690fn new_ahash_upstreams(
691    upstream_configs: &HashMap<String, UpstreamConf>,
692    sender: Option<Arc<NotificationSender>>,
693) -> Result<(Upstreams, Vec<String>)> {
694    let mut upstreams = AHashMap::new();
695    let mut updated_upstreams = vec![];
696    for (name, conf) in upstream_configs.iter() {
697        let key = conf.hash_key();
698        if let Some(found) = get_upstream(name) {
699            // not modified
700            if found.key == key {
701                upstreams.insert(name.to_string(), found);
702                continue;
703            }
704        }
705        let up = Arc::new(Upstream::new(name, conf, sender.clone())?);
706        upstreams.insert(name.to_string(), up);
707        updated_upstreams.push(name.to_string());
708    }
709    Ok((upstreams, updated_upstreams))
710}
711
712/// Initialize the upstreams
713///
714/// # Arguments
715/// * `upstream_configs` - The upstream configurations
716/// * `sender` - The notification sender
717///
718/// # Returns
719pub fn try_init_upstreams(
720    upstream_configs: &HashMap<String, UpstreamConf>,
721    sender: Option<Arc<NotificationSender>>,
722) -> Result<()> {
723    let (upstreams, _) = new_ahash_upstreams(upstream_configs, sender)?;
724    UPSTREAM_MAP.store(Arc::new(upstreams));
725    Ok(())
726}
727
728async fn run_health_check(up: &Arc<Upstream>) -> Result<()> {
729    if let Some(lb) = up.as_round_robin() {
730        lb.update().await.map_err(|e| Error::Common {
731            category: "run_health_check".to_string(),
732            message: e.to_string(),
733        })?;
734        lb.backends()
735            .run_health_check(lb.parallel_health_check)
736            .await;
737    } else if let Some(lb) = up.as_consistent() {
738        lb.update().await.map_err(|e| Error::Common {
739            category: "run_health_check".to_string(),
740            message: e.to_string(),
741        })?;
742        lb.backends()
743            .run_health_check(lb.parallel_health_check)
744            .await;
745    }
746    Ok(())
747}
748
749pub async fn try_update_upstreams(
750    upstream_configs: &HashMap<String, UpstreamConf>,
751    sender: Option<Arc<NotificationSender>>,
752) -> Result<Vec<String>> {
753    let (upstreams, updated_upstreams) =
754        new_ahash_upstreams(upstream_configs, sender)?;
755    for (name, up) in upstreams.iter() {
756        // no need to run health check if not new upstream
757        if !updated_upstreams.contains(name) {
758            continue;
759        }
760        if let Err(e) = run_health_check(up).await {
761            error!(
762                category = LOG_CATEGORY,
763                error = %e,
764                upstream = name,
765                "update upstream health check fail"
766            );
767        }
768    }
769    UPSTREAM_MAP.store(Arc::new(upstreams));
770    Ok(updated_upstreams)
771}
772
773#[async_trait]
774impl ServiceTask for HealthCheckTask {
775    async fn run(&self) -> Option<bool> {
776        let check_count = self.count.fetch_add(1, Ordering::Relaxed);
777        // get upstream names
778        let upstreams = {
779            let mut upstreams = vec![];
780            for (name, up) in UPSTREAM_MAP.load().iter() {
781                // transparent ignore health check
782                if matches!(up.lb, SelectionLb::Transparent) {
783                    continue;
784                }
785                upstreams.push((name.to_string(), up.clone()));
786            }
787            upstreams
788        };
789        let interval = self.interval.as_secs();
790        // run health check for each upstream
791        let jobs = upstreams.into_iter().map(|(name, up)| {
792            let runtime = pingora_runtime::current_handle();
793            runtime.spawn(async move {
794                let check_frequency_matched = |frequency: u64| -> bool {
795                    let mut count = (frequency / interval) as u32;
796                    if frequency % interval != 0 {
797                        count += 1;
798                    }
799                    check_count % count == 0
800                };
801
802                // get update frequency(update service)
803                // and health check frequency
804                let (update_frequency, health_check_frequency) =
805                    if let Some(lb) = up.as_round_robin() {
806                        let update_frequency =
807                            lb.update_frequency.unwrap_or_default().as_secs();
808                        let health_check_frequency = lb
809                            .health_check_frequency
810                            .unwrap_or_default()
811                            .as_secs();
812                        (update_frequency, health_check_frequency)
813                    } else if let Some(lb) = up.as_consistent() {
814                        let update_frequency =
815                            lb.update_frequency.unwrap_or_default().as_secs();
816                        let health_check_frequency = lb
817                            .health_check_frequency
818                            .unwrap_or_default()
819                            .as_secs();
820                        (update_frequency, health_check_frequency)
821                    } else {
822                        (0, 0)
823                    };
824
825                // the first time should match
826                // update check
827                if check_count == 0
828                    || (update_frequency > 0
829                        && check_frequency_matched(update_frequency))
830                {
831                    let result = if let Some(lb) = up.as_round_robin() {
832                        lb.update().await
833                    } else if let Some(lb) = up.as_consistent() {
834                        lb.update().await
835                    } else {
836                        Ok(())
837                    };
838                    if let Err(e) = result {
839                        error!(
840                            category = LOG_CATEGORY,
841                            error = %e,
842                            name,
843                            "update backends fail"
844                        )
845                    } else {
846                        debug!(
847                            category = LOG_CATEGORY,
848                            name, "update backend success"
849                        );
850                    }
851                }
852
853                // health check
854                if !check_frequency_matched(health_check_frequency) {
855                    return;
856                }
857                let health_check_start_time = SystemTime::now();
858                if let Some(lb) = up.as_round_robin() {
859                    lb.backends()
860                        .run_health_check(lb.parallel_health_check)
861                        .await;
862                } else if let Some(lb) = up.as_consistent() {
863                    lb.backends()
864                        .run_health_check(lb.parallel_health_check)
865                        .await;
866                }
867                info!(
868                    category = LOG_CATEGORY,
869                    name,
870                    elapsed = format!(
871                        "{}ms",
872                        health_check_start_time
873                            .elapsed()
874                            .unwrap_or_default()
875                            .as_millis()
876                    ),
877                    "health check is done"
878                );
879            })
880        });
881        futures::future::join_all(jobs).await;
882
883        // each 10 times, check unhealthy upstreams
884        if check_count % 10 == 1 {
885            let current_unhealthy_upstreams =
886                self.unhealthy_upstreams.load().clone();
887            let mut notify_healthy_upstreams = vec![];
888            let mut unhealthy_upstreams = vec![];
889            for (name, status) in get_upstream_healthy_status().iter() {
890                if status.healthy == 0 {
891                    unhealthy_upstreams.push(name.to_string());
892                } else if current_unhealthy_upstreams.contains(name) {
893                    notify_healthy_upstreams.push(name.to_string());
894                }
895            }
896            let mut notify_unhealthy_upstreams = vec![];
897            for name in unhealthy_upstreams.iter() {
898                if !current_unhealthy_upstreams.contains(name) {
899                    notify_unhealthy_upstreams.push(name.to_string());
900                }
901            }
902            self.unhealthy_upstreams
903                .store(Arc::new(unhealthy_upstreams));
904            if let Some(sender) = &self.sender {
905                if !notify_unhealthy_upstreams.is_empty() {
906                    let data = NotificationData {
907                        category: "upstream_status".to_string(),
908                        title: "Upstream unhealthy".to_string(),
909                        message: notify_unhealthy_upstreams.join(", "),
910                        level: NotificationLevel::Error,
911                    };
912                    sender.notify(data).await;
913                }
914                if !notify_healthy_upstreams.is_empty() {
915                    let data = NotificationData {
916                        category: "upstream_status".to_string(),
917                        title: "Upstream healthy".to_string(),
918                        message: notify_healthy_upstreams.join(", "),
919                        ..Default::default()
920                    };
921                    sender.notify(data).await;
922                }
923            }
924        }
925        None
926    }
927    fn description(&self) -> String {
928        let count = UPSTREAM_MAP.load().len();
929        format!("upstream health check, upstream count: {count}")
930    }
931}
932
933struct HealthCheckTask {
934    interval: Duration,
935    count: AtomicU32,
936    sender: Option<Arc<NotificationSender>>,
937    unhealthy_upstreams: ArcSwap<Vec<String>>,
938}
939
940pub fn new_upstream_health_check_task(
941    interval: Duration,
942    sender: Option<Arc<NotificationSender>>,
943) -> CommonServiceTask {
944    let interval = interval.max(Duration::from_secs(10));
945    CommonServiceTask::new(
946        interval,
947        HealthCheckTask {
948            interval,
949            count: AtomicU32::new(0),
950            sender,
951            unhealthy_upstreams: ArcSwap::new(Arc::new(vec![])),
952        },
953    )
954}
955
#[cfg(test)]
mod tests {
    use super::{
        get_hash_value, new_backends, Upstream, UpstreamConf,
        UpstreamPeerTracer,
    };
    use pingap_discovery::Discovery;
    use pingora::protocols::ALPN;
    use pingora::proxy::Session;
    use pingora::upstreams::peer::Tracing;
    use pretty_assertions::assert_eq;
    use std::sync::atomic::Ordering;
    use std::time::Duration;
    use tokio_test::io::Builder;

    /// `new_backends` accepts static address lists (with and without ports,
    /// including a trailing `10` on the first entry) and the `dns` discovery
    /// type.
    #[test]
    fn test_new_backends() {
        let _ = new_backends(
            "",
            &Discovery::new(vec![
                "192.168.1.1:8001 10".to_string(),
                "192.168.1.2:8001".to_string(),
            ]),
        )
        .unwrap();

        let _ = new_backends(
            "",
            &Discovery::new(vec![
                "192.168.1.1".to_string(),
                "192.168.1.2:8001".to_string(),
            ]),
        )
        .unwrap();

        let _ = new_backends(
            "dns",
            &Discovery::new(vec!["github.com".to_string()]),
        )
        .unwrap();
    }

    /// `Upstream::new` rejects an empty address list and carries every
    /// configured option over to the resulting upstream.
    #[test]
    fn test_new_upstream() {
        let result = Upstream::new("charts", &UpstreamConf::default(), None);
        assert_eq!(
            "Common error, category: new_upstream, upstream addrs is empty",
            result.err().unwrap().to_string()
        );

        let up = Upstream::new(
            "charts",
            &UpstreamConf {
                addrs: vec!["192.168.1.1".to_string()],
                algo: Some("hash:cookie:user-id".to_string()),
                alpn: Some("h2".to_string()),
                connection_timeout: Some(Duration::from_secs(5)),
                total_connection_timeout: Some(Duration::from_secs(10)),
                read_timeout: Some(Duration::from_secs(3)),
                idle_timeout: Some(Duration::from_secs(30)),
                write_timeout: Some(Duration::from_secs(5)),
                tcp_idle: Some(Duration::from_secs(60)),
                tcp_probe_count: Some(100),
                tcp_interval: Some(Duration::from_secs(60)),
                tcp_recv_buf: Some(bytesize::ByteSize(1024)),
                ..Default::default()
            },
            None,
        )
        .unwrap();

        assert_eq!("cookie", up.hash);
        assert_eq!("user-id", up.hash_key);
        assert_eq!(ALPN::H2.to_string(), up.alpn.to_string());
        assert_eq!("Some(5s)", format!("{:?}", up.connection_timeout));
        assert_eq!("Some(10s)", format!("{:?}", up.total_connection_timeout));
        assert_eq!("Some(3s)", format!("{:?}", up.read_timeout));
        assert_eq!("Some(30s)", format!("{:?}", up.idle_timeout));
        assert_eq!("Some(5s)", format!("{:?}", up.write_timeout));
        // The Debug output of TcpKeepalive has an extra field on Linux.
        #[cfg(target_os = "linux")]
        assert_eq!(
            "Some(TcpKeepalive { idle: 60s, interval: 60s, count: 100, user_timeout: 0ns })",
            format!("{:?}", up.tcp_keepalive)
        );
        #[cfg(not(target_os = "linux"))]
        assert_eq!(
            "Some(TcpKeepalive { idle: 60s, interval: 60s, count: 100 })",
            format!("{:?}", up.tcp_keepalive)
        );
        assert_eq!("Some(1024)", format!("{:?}", up.tcp_recv_buf));
    }

    /// `get_hash_value` extracts the hash key from the url, client ip
    /// (X-Forwarded-For or an explicitly supplied ip), a header, a cookie,
    /// a query parameter and the path.
    #[tokio::test]
    async fn test_get_hash_key_value() {
        let headers = [
            "Host: github.com",
            "Referer: https://github.com/",
            "User-Agent: pingap/0.1.1",
            "Cookie: deviceId=abc",
            "Accept: application/json",
            "X-Forwarded-For: 1.1.1.1",
        ]
        .join("\r\n");
        let input_header = format!(
            "GET /vicanso/pingap?id=1234 HTTP/1.1\r\n{headers}\r\n\r\n"
        );
        let mock_io = Builder::new().read(input_header.as_bytes()).build();

        let mut session = Session::new_h1(Box::new(mock_io));
        session.read_request().await.unwrap();

        assert_eq!(
            "/vicanso/pingap?id=1234",
            get_hash_value("url", "", &session, &None)
        );

        assert_eq!("1.1.1.1", get_hash_value("ip", "", &session, &None));
        assert_eq!(
            "2.2.2.2",
            get_hash_value("ip", "", &session, &Some("2.2.2.2".to_string()))
        );

        assert_eq!(
            "pingap/0.1.1",
            get_hash_value("header", "User-Agent", &session, &None)
        );

        assert_eq!(
            "abc",
            get_hash_value("cookie", "deviceId", &session, &None)
        );
        assert_eq!("1234", get_hash_value("query", "id", &session, &None));
        assert_eq!(
            "/vicanso/pingap",
            get_hash_value("path", "", &session, &None)
        );
    }

    /// A single-address upstream with default settings yields an HTTP peer
    /// and uses the round-robin balancer.
    #[tokio::test]
    async fn test_upstream() {
        let headers = [
            "Host: github.com",
            "Referer: https://github.com/",
            "User-Agent: pingap/0.1.1",
            "Cookie: deviceId=abc",
            "Accept: application/json",
        ]
        .join("\r\n");
        let input_header =
            format!("GET /vicanso/pingap?size=1 HTTP/1.1\r\n{headers}\r\n\r\n");
        let mock_io = Builder::new().read(input_header.as_bytes()).build();

        let mut session = Session::new_h1(Box::new(mock_io));
        session.read_request().await.unwrap();
        let up = Upstream::new(
            "upstreamname",
            &UpstreamConf {
                addrs: vec!["192.168.1.1:8001".to_string()],
                ..Default::default()
            },
            None,
        )
        .unwrap();
        assert!(up.new_http_peer(&session, &None).is_some());
        assert!(up.as_round_robin().is_some());
    }

    /// The peer tracer increments its connection count on connect and
    /// decrements it on disconnect.
    #[test]
    fn test_upstream_peer_tracer() {
        let tracer = UpstreamPeerTracer::new("upstreamname");
        tracer.on_connected();
        assert_eq!(1, tracer.connected.load(Ordering::Relaxed));
        tracer.on_disconnected();
        assert_eq!(0, tracer.connected.load(Ordering::Relaxed));
    }
}