pingap_upstream/
upstream.rs

1// Copyright 2024-2025 Tree xie.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use ahash::AHashMap;
16use arc_swap::ArcSwap;
17use async_trait::async_trait;
18use derive_more::Debug;
19use futures_util::FutureExt;
20use once_cell::sync::Lazy;
21use pingap_config::UpstreamConf;
22use pingap_core::{CommonServiceTask, ServiceTask};
23use pingap_core::{NotificationData, NotificationLevel, NotificationSender};
24use pingap_discovery::{
25    is_dns_discovery, is_docker_discovery, is_static_discovery,
26    new_dns_discover_backends, new_docker_discover_backends,
27    new_static_discovery, Discovery, TRANSPARENT_DISCOVERY,
28};
29use pingap_health::new_health_check;
30use pingora::lb::health_check::{HealthObserve, HealthObserveCallback};
31use pingora::lb::selection::{
32    BackendIter, BackendSelection, Consistent, RoundRobin,
33};
34use pingora::lb::Backend;
35use pingora::lb::{Backends, LoadBalancer};
36use pingora::protocols::l4::ext::TcpKeepalive;
37use pingora::protocols::ALPN;
38use pingora::proxy::Session;
39use pingora::upstreams::peer::{HttpPeer, Tracer, Tracing};
40use serde::{Deserialize, Serialize};
41use snafu::Snafu;
42use std::collections::HashMap;
43use std::sync::atomic::{AtomicI32, AtomicU32, Ordering};
44use std::sync::Arc;
45use std::time::{Duration, SystemTime};
46use tracing::{debug, error, info};
47
/// Value of the `category` field attached to log records emitted by this module.
const LOG_CATEGORY: &str = "upstream";
49
50#[derive(Debug, Snafu)]
51pub enum Error {
52    #[snafu(display("Common error, category: {category}, {message}"))]
53    Common { message: String, category: String },
54}
55type Result<T, E = Error> = std::result::Result<T, E>;
56
57pub struct BackendObserveNotification {
58    name: String,
59    sender: Arc<NotificationSender>,
60}
61
62#[async_trait]
63impl HealthObserve for BackendObserveNotification {
64    async fn observe(&self, backend: &Backend, healthy: bool) {
65        let addr = backend.addr.to_string();
66        let template = format!("upstream {}({addr}) becomes ", self.name);
67        let info = if healthy {
68            (NotificationLevel::Info, template + "healthy")
69        } else {
70            (NotificationLevel::Error, template + "unhealthy")
71        };
72
73        self.sender
74            .notify(NotificationData {
75                category: "backend_status".to_string(),
76                level: info.0,
77                title: "Upstream backend status changed".to_string(),
78                message: info.1,
79            })
80            .await;
81    }
82}
83
84fn new_observe(
85    name: &str,
86    sender: Option<Arc<NotificationSender>>,
87) -> Option<HealthObserveCallback> {
88    if let Some(sender) = sender {
89        Some(Box::new(BackendObserveNotification {
90            name: name.to_string(),
91            sender: sender.clone(),
92        }))
93    } else {
94        None
95    }
96}
97
98// SelectionLb represents different load balancing strategies:
99// - RoundRobin: Distributes requests evenly across backends
100// - Consistent: Uses consistent hashing to map requests to backends
101// - Transparent: Passes requests through without load balancing
102enum SelectionLb {
103    RoundRobin(Arc<LoadBalancer<RoundRobin>>),
104    Consistent(Arc<LoadBalancer<Consistent>>),
105    Transparent,
106}
107
108// UpstreamPeerTracer tracks active connections to upstream servers
109#[derive(Clone, Debug)]
110struct UpstreamPeerTracer {
111    name: String,
112    connected: Arc<AtomicI32>, // Number of active connections
113}
114
115impl UpstreamPeerTracer {
116    fn new(name: &str) -> Self {
117        Self {
118            name: name.to_string(),
119            connected: Arc::new(AtomicI32::new(0)),
120        }
121    }
122}
123
124impl Tracing for UpstreamPeerTracer {
125    fn on_connected(&self) {
126        debug!(
127            category = LOG_CATEGORY,
128            name = self.name,
129            "upstream peer connected"
130        );
131        self.connected.fetch_add(1, Ordering::Relaxed);
132    }
133    fn on_disconnected(&self) {
134        debug!(
135            category = LOG_CATEGORY,
136            name = self.name,
137            "upstream peer disconnected"
138        );
139        self.connected.fetch_sub(1, Ordering::Relaxed);
140    }
141    fn boxed_clone(&self) -> Box<dyn Tracing> {
142        Box::new(self.clone())
143    }
144}
145
146#[derive(Debug)]
147/// Represents a group of backend servers and their configuration for load balancing and connection management
148pub struct Upstream {
149    /// Unique identifier for this upstream group
150    pub name: String,
151
152    /// Hash key used to detect configuration changes
153    pub key: String,
154
155    /// Load balancing hash strategy:
156    /// - "url": Hash based on request URL
157    /// - "ip": Hash based on client IP
158    /// - "header": Hash based on specific header value
159    /// - "cookie": Hash based on specific cookie value
160    /// - "query": Hash based on specific query parameter
161    hash: String,
162
163    /// Key to use with the hash strategy:
164    /// - For "header": Header name to use
165    /// - For "cookie": Cookie name to use
166    /// - For "query": Query parameter name to use
167    hash_key: String,
168
169    /// Whether to use TLS for connections to backend servers
170    tls: bool,
171
172    /// Server Name Indication value for TLS connections
173    /// Special value "$host" means use the request's Host header
174    sni: String,
175
176    /// Load balancing strategy implementation:
177    /// - RoundRobin: Distributes requests evenly
178    /// - Consistent: Uses consistent hashing
179    /// - Transparent: Direct passthrough
180    #[debug("lb")]
181    lb: SelectionLb,
182
183    /// Maximum time to wait for establishing a connection
184    connection_timeout: Option<Duration>,
185
186    /// Maximum time for the entire connection lifecycle
187    total_connection_timeout: Option<Duration>,
188
189    /// Maximum time to wait for reading data
190    read_timeout: Option<Duration>,
191
192    /// Maximum time a connection can be idle before being closed
193    idle_timeout: Option<Duration>,
194
195    /// Maximum time to wait for writing data
196    write_timeout: Option<Duration>,
197
198    /// Whether to verify TLS certificates from backend servers
199    verify_cert: Option<bool>,
200
201    /// Application Layer Protocol Negotiation settings (H1, H2, H2H1)
202    alpn: ALPN,
203
204    /// TCP keepalive configuration for maintaining persistent connections
205    tcp_keepalive: Option<TcpKeepalive>,
206
207    /// Size of TCP receive buffer in bytes
208    tcp_recv_buf: Option<usize>,
209
210    /// Whether to enable TCP Fast Open for reduced connection latency
211    tcp_fast_open: Option<bool>,
212
213    /// Tracer for monitoring active connections to this upstream
214    peer_tracer: Option<UpstreamPeerTracer>,
215
216    /// Generic tracer interface for connection monitoring
217    tracer: Option<Tracer>,
218
219    /// Counter for number of requests currently being processed by this upstream
220    processing: AtomicI32,
221}
222
223// Creates new backend servers based on discovery method (DNS/Docker/Static)
224fn new_backends(
225    discovery_category: &str,
226    discovery: &Discovery,
227) -> Result<Backends> {
228    let (result, category) = match discovery_category {
229        d if is_dns_discovery(d) => {
230            (new_dns_discover_backends(discovery), "dns_discovery")
231        },
232        d if is_docker_discovery(d) => {
233            (new_docker_discover_backends(discovery), "docker_discovery")
234        },
235        _ => (new_static_discovery(discovery), "static_discovery"),
236    };
237    result.map_err(|e| Error::Common {
238        category: category.to_string(),
239        message: e.to_string(),
240    })
241}
242
243// Gets the value to use for consistent hashing based on the hash strategy
244fn get_hash_value(
245    hash: &str,        // Hash strategy (url/ip/header/cookie/query)
246    hash_key: &str,    // Key to use for hash lookups
247    session: &Session, // Current request session
248    client_ip: &Option<String>, // Request context
249) -> String {
250    match hash {
251        "url" => session.req_header().uri.to_string(),
252        "ip" => {
253            if let Some(client_ip) = client_ip {
254                client_ip.to_string()
255            } else {
256                pingap_core::get_client_ip(session)
257            }
258        },
259        "header" => {
260            if let Some(value) = session.get_header(hash_key) {
261                value.to_str().unwrap_or_default().to_string()
262            } else {
263                "".to_string()
264            }
265        },
266        "cookie" => {
267            pingap_core::get_cookie_value(session.req_header(), hash_key)
268                .unwrap_or_default()
269                .to_string()
270        },
271        "query" => pingap_core::get_query_value(session.req_header(), hash_key)
272            .unwrap_or_default()
273            .to_string(),
274        // default: path
275        _ => session.req_header().uri.path().to_string(),
276    }
277}
278
279fn update_health_check_params<S>(
280    mut lb: LoadBalancer<S>,
281    name: &str,
282    conf: &UpstreamConf,
283    sender: Option<Arc<NotificationSender>>,
284) -> Result<LoadBalancer<S>>
285where
286    S: BackendSelection + 'static,
287    S::Iter: BackendIter,
288{
289    // For static discovery, perform immediate backend update
290    if is_static_discovery(&conf.guess_discovery()) {
291        lb.update()
292            .now_or_never()
293            .expect("static should not block")
294            .expect("static should not error");
295    }
296
297    // Set up health checking for the backends
298    let (health_check_conf, hc) = new_health_check(
299        name,
300        &conf.health_check.clone().unwrap_or_default(),
301        new_observe(name, sender),
302    )
303    .map_err(|e| Error::Common {
304        message: e.to_string(),
305        category: "health".to_string(),
306    })?;
307    // Configure health checking
308    lb.parallel_health_check = health_check_conf.parallel_check;
309    lb.set_health_check(hc);
310    lb.update_frequency = conf.update_frequency;
311    lb.health_check_frequency = Some(health_check_conf.check_frequency);
312    Ok(lb)
313}
314
315/// Creates a new load balancer instance based on the provided configuration
316///
317/// # Arguments
318/// * `name` - Name identifier for the upstream service
319/// * `conf` - Configuration for the upstream service
320///
321/// # Returns
322/// * `Result<(SelectionLb, String, String)>` - Returns the load balancer, hash strategy, and hash key
323fn new_load_balancer(
324    name: &str,
325    conf: &UpstreamConf,
326    sender: Option<Arc<NotificationSender>>,
327) -> Result<(SelectionLb, String, String)> {
328    // Validate that addresses are provided
329    if conf.addrs.is_empty() {
330        return Err(Error::Common {
331            category: "new_upstream".to_string(),
332            message: "upstream addrs is empty".to_string(),
333        });
334    }
335
336    // Determine the service discovery method
337    let discovery_category = conf.guess_discovery();
338    // For transparent discovery, return early with no load balancing
339    if discovery_category == TRANSPARENT_DISCOVERY {
340        return Ok((SelectionLb::Transparent, "".to_string(), "".to_string()));
341    }
342
343    let mut hash = "".to_string();
344    // Determine if TLS should be enabled based on SNI configuration
345    let tls = conf
346        .sni
347        .as_ref()
348        .map(|item| !item.is_empty())
349        .unwrap_or_default();
350
351    // Create backend servers using the configured addresses and discovery method
352    let mut discovery = Discovery::new(conf.addrs.clone())
353        .with_ipv4_only(conf.ipv4_only.unwrap_or_default())
354        .with_tls(tls)
355        .with_sender(sender.clone());
356    if let Some(dns_server) = &conf.dns_server {
357        discovery = discovery.with_dns_server(dns_server.clone());
358    }
359    if let Some(dns_domain) = &conf.dns_domain {
360        discovery = discovery.with_domain(dns_domain.clone());
361    }
362    if let Some(dns_search) = &conf.dns_search {
363        discovery = discovery.with_search(dns_search.clone());
364    }
365    let backends = new_backends(&discovery_category, &discovery)?;
366
367    // Parse the load balancing algorithm configuration
368    // Format: "algo:hash_type:hash_key" (e.g. "hash:cookie:session_id")
369    let algo_method = conf.algo.clone().unwrap_or_default();
370    let algo_params: Vec<&str> = algo_method.split(':').collect();
371    let mut hash_key = "".to_string();
372
373    // Create the appropriate load balancer based on the algorithm
374    let lb = match algo_params[0] {
375        // Consistent hashing load balancer
376        "hash" => {
377            // Parse hash type and key if provided
378            if algo_params.len() > 1 {
379                hash = algo_params[1].to_string();
380                if algo_params.len() > 2 {
381                    hash_key = algo_params[2].to_string();
382                }
383            }
384            let lb = update_health_check_params(
385                LoadBalancer::<Consistent>::from_backends(backends),
386                name,
387                conf,
388                sender,
389            )?;
390
391            SelectionLb::Consistent(Arc::new(lb))
392        },
393        // Round robin load balancer (default)
394        _ => {
395            let lb = update_health_check_params(
396                LoadBalancer::<RoundRobin>::from_backends(backends),
397                name,
398                conf,
399                sender,
400            )?;
401
402            SelectionLb::RoundRobin(Arc::new(lb))
403        },
404    };
405    Ok((lb, hash, hash_key))
406}
407
408impl Upstream {
409    /// Creates a new Upstream instance from the provided configuration
410    ///
411    /// # Arguments
412    /// * `name` - Name identifier for the upstream service
413    /// * `conf` - Configuration parameters for the upstream service
414    ///
415    /// # Returns
416    /// * `Result<Self>` - New Upstream instance or error if creation fails
417    pub fn new(
418        name: &str,
419        conf: &UpstreamConf,
420        sender: Option<Arc<NotificationSender>>,
421    ) -> Result<Self> {
422        let (lb, hash, hash_key) = new_load_balancer(name, conf, sender)?;
423        let key = conf.hash_key();
424        let sni = conf.sni.clone().unwrap_or_default();
425        let tls = !sni.is_empty();
426
427        let alpn = if let Some(alpn) = &conf.alpn {
428            match alpn.to_uppercase().as_str() {
429                "H2H1" => ALPN::H2H1,
430                "H2" => ALPN::H2,
431                _ => ALPN::H1,
432            }
433        } else {
434            ALPN::H1
435        };
436
437        let tcp_keepalive = if (conf.tcp_idle.is_some()
438            && conf.tcp_probe_count.is_some()
439            && conf.tcp_interval.is_some())
440            || conf.tcp_user_timeout.is_some()
441        {
442            Some(TcpKeepalive {
443                idle: conf.tcp_idle.unwrap_or_default(),
444                count: conf.tcp_probe_count.unwrap_or_default(),
445                interval: conf.tcp_interval.unwrap_or_default(),
446                #[cfg(target_os = "linux")]
447                user_timeout: conf.tcp_user_timeout.unwrap_or_default(),
448            })
449        } else {
450            None
451        };
452
453        let peer_tracer = if conf.enable_tracer.unwrap_or_default() {
454            Some(UpstreamPeerTracer::new(name))
455        } else {
456            None
457        };
458        let tracer = peer_tracer
459            .as_ref()
460            .map(|peer_tracer| Tracer(Box::new(peer_tracer.to_owned())));
461        let up = Self {
462            name: name.to_string(),
463            key,
464            tls,
465            sni,
466            hash,
467            hash_key,
468            lb,
469            alpn,
470            connection_timeout: conf.connection_timeout,
471            total_connection_timeout: conf.total_connection_timeout,
472            read_timeout: conf.read_timeout,
473            idle_timeout: conf.idle_timeout,
474            write_timeout: conf.write_timeout,
475            verify_cert: conf.verify_cert,
476            tcp_recv_buf: conf.tcp_recv_buf.map(|item| item.as_u64() as usize),
477            tcp_keepalive,
478            tcp_fast_open: conf.tcp_fast_open,
479            peer_tracer,
480            tracer,
481            processing: AtomicI32::new(0),
482        };
483        debug!(
484            category = LOG_CATEGORY,
485            name = up.name,
486            "new upstream: {up:?}"
487        );
488        Ok(up)
489    }
490
491    /// Creates and configures a new HTTP peer for handling requests
492    ///
493    /// # Arguments
494    /// * `session` - Current HTTP session containing request details
495    /// * `ctx` - Request context state
496    ///
497    /// # Returns
498    /// * `Option<HttpPeer>` - Configured HTTP peer if a healthy backend is available, None otherwise
499    ///
500    /// This method:
501    /// 1. Selects an appropriate backend using the configured load balancing strategy
502    /// 2. Increments the processing counter
503    /// 3. Creates and configures an HttpPeer with the connection settings
504    #[inline]
505    pub fn new_http_peer(
506        &self,
507        session: &Session,
508        client_ip: &Option<String>,
509    ) -> Option<HttpPeer> {
510        // Select a backend based on the load balancing strategy
511        let upstream = match &self.lb {
512            // For round-robin, use empty key since selection is sequential
513            SelectionLb::RoundRobin(lb) => lb.select(b"", 256),
514            // For consistent hashing, generate hash value from request details
515            SelectionLb::Consistent(lb) => {
516                let value = get_hash_value(
517                    &self.hash,
518                    &self.hash_key,
519                    session,
520                    client_ip,
521                );
522                lb.select(value.as_bytes(), 256)
523            },
524            // For transparent mode, no backend selection needed
525            SelectionLb::Transparent => None,
526        };
527        // Increment counter for requests being processed
528        self.processing.fetch_add(1, Ordering::Relaxed);
529
530        // Create HTTP peer based on load balancing mode
531        let p = if matches!(self.lb, SelectionLb::Transparent) {
532            // In transparent mode, use the request's host header
533            let host = pingap_core::get_host(session.req_header())?;
534            // Set SNI: either use host header ($host) or configured value
535            let sni = if self.sni == "$host" {
536                host.to_string()
537            } else {
538                self.sni.clone()
539            };
540            // use default port for transparent http/https
541            let port = if self.tls { 443 } else { 80 };
542            // Create peer with host:port, TLS settings, and SNI
543            Some(HttpPeer::new(format!("{host}:{port}"), self.tls, sni))
544        } else {
545            // For load balanced modes, create peer from selected backend
546            upstream.map(|upstream| {
547                HttpPeer::new(upstream, self.tls, self.sni.clone())
548            })
549        };
550
551        // Configure connection options for the peer
552        p.map(|mut p| {
553            // Set various timeout values
554            p.options.connection_timeout = self.connection_timeout;
555            p.options.total_connection_timeout = self.total_connection_timeout;
556            p.options.read_timeout = self.read_timeout;
557            p.options.idle_timeout = self.idle_timeout;
558            p.options.write_timeout = self.write_timeout;
559            // Configure TLS certificate verification if specified
560            if let Some(verify_cert) = self.verify_cert {
561                p.options.verify_cert = verify_cert;
562            }
563            // Set protocol negotiation settings
564            p.options.alpn = self.alpn.clone();
565            // Configure TCP-specific options
566            p.options.tcp_keepalive.clone_from(&self.tcp_keepalive);
567            p.options.tcp_recv_buf = self.tcp_recv_buf;
568            if let Some(tcp_fast_open) = self.tcp_fast_open {
569                p.options.tcp_fast_open = tcp_fast_open;
570            }
571            // Set connection tracing if enabled
572            p.options.tracer.clone_from(&self.tracer);
573            p
574        })
575    }
576
577    /// Returns the current number of active connections to this upstream
578    ///
579    /// # Returns
580    /// * `Option<i32>` - Number of active connections if tracking is enabled, None otherwise
581    #[inline]
582    pub fn connected(&self) -> Option<i32> {
583        self.peer_tracer
584            .as_ref()
585            .map(|tracer| tracer.connected.load(Ordering::Relaxed))
586    }
587
588    /// Returns the round-robin load balancer if configured
589    ///
590    /// # Returns
591    /// * `Option<Arc<LoadBalancer<RoundRobin>>>` - Round-robin load balancer if used, None otherwise
592    #[inline]
593    pub fn as_round_robin(&self) -> Option<Arc<LoadBalancer<RoundRobin>>> {
594        match &self.lb {
595            SelectionLb::RoundRobin(lb) => Some(lb.clone()),
596            _ => None,
597        }
598    }
599
600    /// Returns the consistent hash load balancer if configured
601    ///
602    /// # Returns
603    /// * `Option<Arc<LoadBalancer<Consistent>>>` - Consistent hash load balancer if used, None otherwise
604    #[inline]
605    pub fn as_consistent(&self) -> Option<Arc<LoadBalancer<Consistent>>> {
606        match &self.lb {
607            SelectionLb::Consistent(lb) => Some(lb.clone()),
608            _ => None,
609        }
610    }
611
612    /// Decrements and returns the number of requests being processed
613    ///
614    /// # Returns
615    /// * `i32` - Previous count of requests being processed
616    #[inline]
617    pub fn completed(&self) -> i32 {
618        self.processing.fetch_add(-1, Ordering::Relaxed)
619    }
620}
621
622type Upstreams = AHashMap<String, Arc<Upstream>>;
623static UPSTREAM_MAP: Lazy<ArcSwap<Upstreams>> =
624    Lazy::new(|| ArcSwap::from_pointee(AHashMap::new()));
625
626pub fn get_upstream(name: &str) -> Option<Arc<Upstream>> {
627    if name.is_empty() {
628        return None;
629    }
630    UPSTREAM_MAP.load().get(name).cloned()
631}
632
633#[derive(Debug, Clone, Serialize, Deserialize)]
634pub struct UpstreamHealthyStatus {
635    pub healthy: u32,
636    pub total: u32,
637    pub unhealthy_backends: Vec<String>,
638}
639
640/// Get the healthy status of all upstreams
641///
642/// # Returns
643/// * `HashMap<String, UpstreamHealthyStatus>` - Healthy status of all upstreams
644///
645/// This function iterates through all upstreams and checks their health status.
646pub fn get_upstream_healthy_status() -> HashMap<String, UpstreamHealthyStatus> {
647    let mut healthy_status = HashMap::new();
648    UPSTREAM_MAP.load().iter().for_each(|(k, v)| {
649        let mut total = 0;
650        let mut healthy = 0;
651        let mut unhealthy_backends = vec![];
652        if let Some(lb) = v.as_round_robin() {
653            let backends = lb.backends().get_backend();
654            total = backends.len();
655            backends.iter().for_each(|backend| {
656                if lb.backends().ready(backend) {
657                    healthy += 1;
658                } else {
659                    unhealthy_backends.push(backend.to_string());
660                }
661            });
662        } else if let Some(lb) = v.as_consistent() {
663            let backends = lb.backends().get_backend();
664            total = backends.len();
665            backends.iter().for_each(|backend| {
666                if lb.backends().ready(backend) {
667                    healthy += 1;
668                } else {
669                    unhealthy_backends.push(backend.to_string());
670                }
671            });
672        }
673        healthy_status.insert(
674            k.to_string(),
675            UpstreamHealthyStatus {
676                healthy,
677                total: total as u32,
678                unhealthy_backends,
679            },
680        );
681    });
682    healthy_status
683}
684
685/// Get the processing and connected status of all upstreams
686///
687/// # Returns
688/// * `HashMap<String, (i32, Option<i32>)>` - Processing and connected status of all upstreams
689pub fn get_upstreams_processing_connected(
690) -> HashMap<String, (i32, Option<i32>)> {
691    let mut processing_connected = HashMap::new();
692    UPSTREAM_MAP.load().iter().for_each(|(k, v)| {
693        let count = v.processing.load(Ordering::Relaxed);
694        let connected = v.connected();
695        processing_connected.insert(k.to_string(), (count, connected));
696    });
697    processing_connected
698}
699
700fn new_ahash_upstreams(
701    upstream_configs: &HashMap<String, UpstreamConf>,
702    sender: Option<Arc<NotificationSender>>,
703) -> Result<(Upstreams, Vec<String>)> {
704    let mut upstreams = AHashMap::new();
705    let mut updated_upstreams = vec![];
706    for (name, conf) in upstream_configs.iter() {
707        let key = conf.hash_key();
708        if let Some(found) = get_upstream(name) {
709            // not modified
710            if found.key == key {
711                upstreams.insert(name.to_string(), found);
712                continue;
713            }
714        }
715        let up = Arc::new(Upstream::new(name, conf, sender.clone())?);
716        upstreams.insert(name.to_string(), up);
717        updated_upstreams.push(name.to_string());
718    }
719    Ok((upstreams, updated_upstreams))
720}
721
722/// Initialize the upstreams
723///
724/// # Arguments
725/// * `upstream_configs` - The upstream configurations
726/// * `sender` - The notification sender
727///
728/// # Returns
729pub fn try_init_upstreams(
730    upstream_configs: &HashMap<String, UpstreamConf>,
731    sender: Option<Arc<NotificationSender>>,
732) -> Result<()> {
733    let (upstreams, _) = new_ahash_upstreams(upstream_configs, sender)?;
734    UPSTREAM_MAP.store(Arc::new(upstreams));
735    Ok(())
736}
737
738async fn run_health_check(up: &Arc<Upstream>) -> Result<()> {
739    if let Some(lb) = up.as_round_robin() {
740        lb.update().await.map_err(|e| Error::Common {
741            category: "run_health_check".to_string(),
742            message: e.to_string(),
743        })?;
744        lb.backends()
745            .run_health_check(lb.parallel_health_check)
746            .await;
747    } else if let Some(lb) = up.as_consistent() {
748        lb.update().await.map_err(|e| Error::Common {
749            category: "run_health_check".to_string(),
750            message: e.to_string(),
751        })?;
752        lb.backends()
753            .run_health_check(lb.parallel_health_check)
754            .await;
755    }
756    Ok(())
757}
758
759pub async fn try_update_upstreams(
760    upstream_configs: &HashMap<String, UpstreamConf>,
761    sender: Option<Arc<NotificationSender>>,
762) -> Result<Vec<String>> {
763    let (upstreams, updated_upstreams) =
764        new_ahash_upstreams(upstream_configs, sender)?;
765    for (name, up) in upstreams.iter() {
766        // no need to run health check if not new upstream
767        if !updated_upstreams.contains(name) {
768            continue;
769        }
770        // run health check before switch to new upstream
771        if let Err(e) = run_health_check(up).await {
772            error!(
773                category = LOG_CATEGORY,
774                error = %e,
775                upstream = name,
776                "update upstream health check fail"
777            );
778        }
779    }
780    UPSTREAM_MAP.store(Arc::new(upstreams));
781    Ok(updated_upstreams)
782}
783
784#[async_trait]
785impl ServiceTask for HealthCheckTask {
786    async fn run(&self) -> Option<bool> {
787        let check_count = self.count.fetch_add(1, Ordering::Relaxed);
788        // get upstream names
789        let upstreams = {
790            let mut upstreams = vec![];
791            for (name, up) in UPSTREAM_MAP.load().iter() {
792                // transparent ignore health check
793                if matches!(up.lb, SelectionLb::Transparent) {
794                    continue;
795                }
796                upstreams.push((name.to_string(), up.clone()));
797            }
798            upstreams
799        };
800        let interval = self.interval.as_secs();
801        // run health check for each upstream
802        let jobs = upstreams.into_iter().map(|(name, up)| {
803            let runtime = pingora_runtime::current_handle();
804            runtime.spawn(async move {
805                let check_frequency_matched = |frequency: u64| -> bool {
806                    let mut count = (frequency / interval) as u32;
807                    if frequency % interval != 0 {
808                        count += 1;
809                    }
810                    check_count % count == 0
811                };
812
813                // get update frequency(update service)
814                // and health check frequency
815                let (update_frequency, health_check_frequency) =
816                    if let Some(lb) = up.as_round_robin() {
817                        let update_frequency =
818                            lb.update_frequency.unwrap_or_default().as_secs();
819                        let health_check_frequency = lb
820                            .health_check_frequency
821                            .unwrap_or_default()
822                            .as_secs();
823                        (update_frequency, health_check_frequency)
824                    } else if let Some(lb) = up.as_consistent() {
825                        let update_frequency =
826                            lb.update_frequency.unwrap_or_default().as_secs();
827                        let health_check_frequency = lb
828                            .health_check_frequency
829                            .unwrap_or_default()
830                            .as_secs();
831                        (update_frequency, health_check_frequency)
832                    } else {
833                        (0, 0)
834                    };
835
836                // the first time should match
837                // update check
838                if check_count == 0
839                    || (update_frequency > 0
840                        && check_frequency_matched(update_frequency))
841                {
842                    let result = if let Some(lb) = up.as_round_robin() {
843                        lb.update().await
844                    } else if let Some(lb) = up.as_consistent() {
845                        lb.update().await
846                    } else {
847                        Ok(())
848                    };
849                    if let Err(e) = result {
850                        error!(
851                            category = LOG_CATEGORY,
852                            error = %e,
853                            name,
854                            "update backends fail"
855                        )
856                    } else {
857                        debug!(
858                            category = LOG_CATEGORY,
859                            name, "update backend success"
860                        );
861                    }
862                }
863
864                // health check
865                if !check_frequency_matched(health_check_frequency) {
866                    return;
867                }
868                let health_check_start_time = SystemTime::now();
869                if let Some(lb) = up.as_round_robin() {
870                    lb.backends()
871                        .run_health_check(lb.parallel_health_check)
872                        .await;
873                } else if let Some(lb) = up.as_consistent() {
874                    lb.backends()
875                        .run_health_check(lb.parallel_health_check)
876                        .await;
877                }
878                info!(
879                    category = LOG_CATEGORY,
880                    name,
881                    elapsed = format!(
882                        "{}ms",
883                        health_check_start_time
884                            .elapsed()
885                            .unwrap_or_default()
886                            .as_millis()
887                    ),
888                    "health check is done"
889                );
890            })
891        });
892        futures::future::join_all(jobs).await;
893
894        // each 10 times, check unhealthy upstreams
895        if check_count % 10 == 1 {
896            let current_unhealthy_upstreams =
897                self.unhealthy_upstreams.load().clone();
898            let mut notify_healthy_upstreams = vec![];
899            let mut unhealthy_upstreams = vec![];
900            for (name, status) in get_upstream_healthy_status().iter() {
901                if status.healthy == 0 {
902                    unhealthy_upstreams.push(name.to_string());
903                } else if current_unhealthy_upstreams.contains(name) {
904                    notify_healthy_upstreams.push(name.to_string());
905                }
906            }
907            let mut notify_unhealthy_upstreams = vec![];
908            for name in unhealthy_upstreams.iter() {
909                if !current_unhealthy_upstreams.contains(name) {
910                    notify_unhealthy_upstreams.push(name.to_string());
911                }
912            }
913            self.unhealthy_upstreams
914                .store(Arc::new(unhealthy_upstreams));
915            if let Some(sender) = &self.sender {
916                if !notify_unhealthy_upstreams.is_empty() {
917                    let data = NotificationData {
918                        category: "upstream_status".to_string(),
919                        title: "Upstream unhealthy".to_string(),
920                        message: notify_unhealthy_upstreams.join(", "),
921                        level: NotificationLevel::Error,
922                    };
923                    sender.notify(data).await;
924                }
925                if !notify_healthy_upstreams.is_empty() {
926                    let data = NotificationData {
927                        category: "upstream_status".to_string(),
928                        title: "Upstream healthy".to_string(),
929                        message: notify_healthy_upstreams.join(", "),
930                        ..Default::default()
931                    };
932                    sender.notify(data).await;
933                }
934            }
935        }
936        None
937    }
938    fn description(&self) -> String {
939        let count = UPSTREAM_MAP.load().len();
940        format!("upstream health check, upstream count: {count}")
941    }
942}
943
944struct HealthCheckTask {
945    interval: Duration,
946    count: AtomicU32,
947    sender: Option<Arc<NotificationSender>>,
948    unhealthy_upstreams: ArcSwap<Vec<String>>,
949}
950
951pub fn new_upstream_health_check_task(
952    interval: Duration,
953    sender: Option<Arc<NotificationSender>>,
954) -> CommonServiceTask {
955    let interval = interval.max(Duration::from_secs(10));
956    CommonServiceTask::new(
957        interval,
958        HealthCheckTask {
959            interval,
960            count: AtomicU32::new(0),
961            sender,
962            unhealthy_upstreams: ArcSwap::new(Arc::new(vec![])),
963        },
964    )
965}
966
967#[cfg(test)]
968mod tests {
969    use super::{
970        get_hash_value, new_backends, Upstream, UpstreamConf,
971        UpstreamPeerTracer,
972    };
973    use pingap_discovery::Discovery;
974    use pingora::protocols::ALPN;
975    use pingora::proxy::Session;
976    use pingora::upstreams::peer::Tracing;
977    use pretty_assertions::assert_eq;
978    use std::sync::atomic::Ordering;
979    use std::time::Duration;
980    use tokio_test::io::Builder;
981
982    #[test]
983    fn test_new_backends() {
984        let _ = new_backends(
985            "",
986            &Discovery::new(vec![
987                "192.168.1.1:8001 10".to_string(),
988                "192.168.1.2:8001".to_string(),
989            ]),
990        )
991        .unwrap();
992
993        let _ = new_backends(
994            "",
995            &Discovery::new(vec![
996                "192.168.1.1".to_string(),
997                "192.168.1.2:8001".to_string(),
998            ]),
999        )
1000        .unwrap();
1001
1002        let _ = new_backends(
1003            "dns",
1004            &Discovery::new(vec!["github.com".to_string()]),
1005        )
1006        .unwrap();
1007    }
1008    #[test]
1009    fn test_new_upstream() {
1010        let result = Upstream::new(
1011            "charts",
1012            &UpstreamConf {
1013                ..Default::default()
1014            },
1015            None,
1016        );
1017        assert_eq!(
1018            "Common error, category: new_upstream, upstream addrs is empty",
1019            result.err().unwrap().to_string()
1020        );
1021
1022        let up = Upstream::new(
1023            "charts",
1024            &UpstreamConf {
1025                addrs: vec!["192.168.1.1".to_string()],
1026                algo: Some("hash:cookie:user-id".to_string()),
1027                alpn: Some("h2".to_string()),
1028                connection_timeout: Some(Duration::from_secs(5)),
1029                total_connection_timeout: Some(Duration::from_secs(10)),
1030                read_timeout: Some(Duration::from_secs(3)),
1031                idle_timeout: Some(Duration::from_secs(30)),
1032                write_timeout: Some(Duration::from_secs(5)),
1033                tcp_idle: Some(Duration::from_secs(60)),
1034                tcp_probe_count: Some(100),
1035                tcp_interval: Some(Duration::from_secs(60)),
1036                tcp_recv_buf: Some(bytesize::ByteSize(1024)),
1037                ..Default::default()
1038            },
1039            None,
1040        )
1041        .unwrap();
1042
1043        assert_eq!("cookie", up.hash);
1044        assert_eq!("user-id", up.hash_key);
1045        assert_eq!(ALPN::H2.to_string(), up.alpn.to_string());
1046        assert_eq!("Some(5s)", format!("{:?}", up.connection_timeout));
1047        assert_eq!("Some(10s)", format!("{:?}", up.total_connection_timeout));
1048        assert_eq!("Some(3s)", format!("{:?}", up.read_timeout));
1049        assert_eq!("Some(30s)", format!("{:?}", up.idle_timeout));
1050        assert_eq!("Some(5s)", format!("{:?}", up.write_timeout));
1051        #[cfg(target_os = "linux")]
1052        assert_eq!(
1053            "Some(TcpKeepalive { idle: 60s, interval: 60s, count: 100, user_timeout: 0ns })",
1054            format!("{:?}", up.tcp_keepalive)
1055        );
1056        #[cfg(not(target_os = "linux"))]
1057        assert_eq!(
1058            "Some(TcpKeepalive { idle: 60s, interval: 60s, count: 100 })",
1059            format!("{:?}", up.tcp_keepalive)
1060        );
1061        assert_eq!("Some(1024)", format!("{:?}", up.tcp_recv_buf));
1062    }
1063    #[tokio::test]
1064    async fn test_get_hash_key_value() {
1065        let headers = [
1066            "Host: github.com",
1067            "Referer: https://github.com/",
1068            "User-Agent: pingap/0.1.1",
1069            "Cookie: deviceId=abc",
1070            "Accept: application/json",
1071            "X-Forwarded-For: 1.1.1.1",
1072        ]
1073        .join("\r\n");
1074        let input_header = format!(
1075            "GET /vicanso/pingap?id=1234 HTTP/1.1\r\n{headers}\r\n\r\n"
1076        );
1077        let mock_io = Builder::new().read(input_header.as_bytes()).build();
1078
1079        let mut session = Session::new_h1(Box::new(mock_io));
1080        session.read_request().await.unwrap();
1081
1082        assert_eq!(
1083            "/vicanso/pingap?id=1234",
1084            get_hash_value("url", "", &session, &None)
1085        );
1086
1087        assert_eq!("1.1.1.1", get_hash_value("ip", "", &session, &None));
1088        assert_eq!(
1089            "2.2.2.2",
1090            get_hash_value("ip", "", &session, &Some("2.2.2.2".to_string()))
1091        );
1092
1093        assert_eq!(
1094            "pingap/0.1.1",
1095            get_hash_value("header", "User-Agent", &session, &None)
1096        );
1097
1098        assert_eq!(
1099            "abc",
1100            get_hash_value("cookie", "deviceId", &session, &None)
1101        );
1102        assert_eq!("1234", get_hash_value("query", "id", &session, &None));
1103        assert_eq!(
1104            "/vicanso/pingap",
1105            get_hash_value("path", "", &session, &None)
1106        );
1107    }
1108    #[tokio::test]
1109    async fn test_upstream() {
1110        let headers = [
1111            "Host: github.com",
1112            "Referer: https://github.com/",
1113            "User-Agent: pingap/0.1.1",
1114            "Cookie: deviceId=abc",
1115            "Accept: application/json",
1116        ]
1117        .join("\r\n");
1118        let input_header =
1119            format!("GET /vicanso/pingap?size=1 HTTP/1.1\r\n{headers}\r\n\r\n");
1120        let mock_io = Builder::new().read(input_header.as_bytes()).build();
1121
1122        let mut session = Session::new_h1(Box::new(mock_io));
1123        session.read_request().await.unwrap();
1124        let up = Upstream::new(
1125            "upstreamname",
1126            &UpstreamConf {
1127                addrs: vec!["192.168.1.1:8001".to_string()],
1128                ..Default::default()
1129            },
1130            None,
1131        )
1132        .unwrap();
1133        assert_eq!(true, up.new_http_peer(&session, &None,).is_some());
1134        assert_eq!(true, up.as_round_robin().is_some());
1135    }
1136    #[test]
1137    fn test_upstream_peer_tracer() {
1138        let tracer = UpstreamPeerTracer::new("upstreamname");
1139        tracer.on_connected();
1140        assert_eq!(1, tracer.connected.load(Ordering::Relaxed));
1141        tracer.on_disconnected();
1142        assert_eq!(0, tracer.connected.load(Ordering::Relaxed));
1143    }
1144}