1use super::Error;
16use crate::backend_circuit_state::{
17 BackendCircuitStates, CircuitBreakerConfig,
18};
19use crate::backend_stats::{BackendStats, WindowStats};
20use crate::hash_strategy::HashStrategy;
21use crate::peer_tracer::UpstreamPeerTracer;
22use crate::{LOG_TARGET, UpstreamProvider, Upstreams};
23use ahash::AHashMap;
24use arc_swap::ArcSwap;
25use async_trait::async_trait;
26use derive_more::Debug;
27use futures_util::FutureExt;
28use http::StatusCode;
29use pingap_config::Hashable;
30use pingap_config::UpstreamConf;
31use pingap_core::UpstreamInstance;
32use pingap_core::{
33 BackgroundTask, BackgroundTaskService, Error as ServiceError,
34};
35use pingap_core::{NotificationData, NotificationLevel, NotificationSender};
36use pingap_discovery::{
37 Discovery, TRANSPARENT_DISCOVERY, is_dns_discovery, is_docker_discovery,
38 is_static_discovery, new_dns_discover_backends,
39 new_docker_discover_backends, new_static_discovery,
40};
41use pingap_health::new_health_check;
42use pingora::lb::Backend;
43use pingora::lb::health_check::{HealthObserve, HealthObserveCallback};
44use pingora::lb::selection::{
45 BackendIter, BackendSelection, Consistent, RoundRobin,
46};
47use pingora::lb::{Backends, LoadBalancer};
48use pingora::protocols::ALPN;
49use pingora::protocols::l4::ext::TcpKeepalive;
50use pingora::proxy::Session;
51use pingora::upstreams::peer::{HttpPeer, Tracer};
52use serde::{Deserialize, Serialize};
53use std::collections::HashMap;
54use std::sync::Arc;
55use std::sync::atomic::{AtomicI32, Ordering};
56use std::time::{Duration, Instant};
57use tracing::{debug, error, info};
58
59type Result<T, E = Error> = std::result::Result<T, E>;
60
/// Sends a notification whenever a backend's health status changes.
pub struct BackendObserveNotification {
    // Upstream name included in the notification message.
    name: String,
    // Channel used to deliver notifications.
    sender: Arc<NotificationSender>,
}

impl BackendObserveNotification {
    /// Creates a notifier for the named upstream.
    pub fn new(name: String, sender: Arc<NotificationSender>) -> Self {
        Self { name, sender }
    }
}
71
72#[async_trait]
73impl HealthObserve for BackendObserveNotification {
74 async fn observe(&self, backend: &Backend, healthy: bool) {
75 let addr = backend.addr.to_string();
76 let template = format!("upstream {}({addr}) becomes ", self.name);
77 let info = if healthy {
78 (NotificationLevel::Info, template + "healthy")
79 } else {
80 (NotificationLevel::Error, template + "unhealthy")
81 };
82
83 self.sender
84 .notify(NotificationData {
85 category: "backend_status".to_string(),
86 level: info.0,
87 title: "Upstream backend status changed".to_string(),
88 message: info.1,
89 })
90 .await;
91 }
92}
93
/// Backend selection strategy for an upstream.
enum SelectionLb {
    // Rotates through backends in order.
    RoundRobin(LoadBalancer<RoundRobin>),
    // Consistent hashing; `hash` decides which request attribute is
    // hashed to pick a backend.
    Consistent {
        lb: LoadBalancer<Consistent>,
        hash: HashStrategy,
    },
    // No backend selection; the peer is derived from the request host.
    Transparent,
}
106
107impl SelectionLb {
108 fn get_health_frequency(&self) -> (u64, u64) {
109 match self {
110 SelectionLb::RoundRobin(lb) => (
111 lb.update_frequency.unwrap_or_default().as_secs(),
112 lb.health_check_frequency.unwrap_or_default().as_secs(),
113 ),
114 SelectionLb::Consistent { lb, .. } => (
115 lb.update_frequency.unwrap_or_default().as_secs(),
116 lb.health_check_frequency.unwrap_or_default().as_secs(),
117 ),
118 SelectionLb::Transparent => (0, 0),
119 }
120 }
121 async fn update(&self) -> pingora::Result<()> {
122 match self {
123 SelectionLb::RoundRobin(lb) => lb.update().await,
124 SelectionLb::Consistent { lb, .. } => lb.update().await,
125 SelectionLb::Transparent => Ok(()),
126 }
127 }
128 async fn run_health_check(&self) {
129 match self {
130 SelectionLb::RoundRobin(lb) => {
131 lb.backends()
132 .run_health_check(lb.parallel_health_check)
133 .await
134 },
135 SelectionLb::Consistent { lb, .. } => {
136 lb.backends()
137 .run_health_check(lb.parallel_health_check)
138 .await
139 },
140 SelectionLb::Transparent => (),
141 }
142 }
143}
144
/// A configured upstream: its backend selection strategy plus all
/// connection options applied when building peers.
#[derive(Debug)]
pub struct Upstream {
    /// Name of this upstream.
    pub name: Arc<str>,

    /// Hash of the configuration, used to detect configuration changes.
    pub key: String,

    // Whether TLS is used for backend connections (derived from `sni`).
    tls: bool,

    // SNI for TLS connections; "$host" means use the request host.
    sni: String,

    #[debug("lb")]
    // Backend selection strategy (round robin / consistent hash /
    // transparent).
    lb: SelectionLb,

    // Timeout for establishing a connection.
    connection_timeout: Option<Duration>,

    // Timeout for the whole connection setup.
    total_connection_timeout: Option<Duration>,

    // Read timeout for upstream responses.
    read_timeout: Option<Duration>,

    // Idle timeout; defaults to 60s when not configured (see `new`).
    idle_timeout: Option<Duration>,

    // Write timeout for upstream requests.
    write_timeout: Option<Duration>,

    // Whether to verify the upstream TLS certificate.
    verify_cert: Option<bool>,

    // ALPN protocols offered to the upstream.
    alpn: ALPN,

    // TCP keepalive parameters, if configured.
    tcp_keepalive: Option<TcpKeepalive>,

    // TCP receive buffer size in bytes.
    tcp_recv_buf: Option<usize>,

    // Whether TCP fast open is enabled.
    tcp_fast_open: Option<bool>,

    // Tracks per-peer connection counts when tracing is enabled.
    peer_tracer: Option<UpstreamPeerTracer>,

    // Pingora tracer wrapping `peer_tracer`.
    tracer: Option<Tracer>,

    // Number of requests currently in flight through this upstream.
    processing: AtomicI32,

    #[debug("backend_stats")]
    // Per-backend request statistics, when enabled.
    backend_stats: Option<BackendStats>,

    #[debug("circuit_breaker_states")]
    // Per-backend circuit breaker state, when enabled.
    circuit_breaker_states: Option<BackendCircuitStates>,
}
215
216fn new_backends(
218 discovery_category: &str,
219 discovery: &Discovery,
220) -> Result<Backends> {
221 let (result, category) = match discovery_category {
222 d if is_dns_discovery(d) => {
223 (new_dns_discover_backends(discovery), "dns_discovery")
224 },
225 d if is_docker_discovery(d) => {
226 (new_docker_discover_backends(discovery), "docker_discovery")
227 },
228 _ => (new_static_discovery(discovery), "static_discovery"),
229 };
230 result.map_err(|e| Error::Common {
231 category: category.to_string(),
232 message: e.to_string(),
233 })
234}
235
236fn update_health_check_params<S>(
237 mut lb: LoadBalancer<S>,
238 name: &str,
239 conf: &UpstreamConf,
240 sender: Option<Arc<NotificationSender>>,
241) -> Result<LoadBalancer<S>>
242where
243 S: BackendSelection + 'static,
244 S::Iter: BackendIter,
245{
246 let mut update_frequency = if let Some(value) = conf.update_frequency {
247 Some(value)
248 } else {
249 Some(Duration::from_secs(60))
250 };
251 if is_static_discovery(&conf.guess_discovery()) {
253 update_frequency = None;
254 lb.update()
255 .now_or_never()
256 .expect("static should not block")
257 .expect("static should not error");
258 }
259
260 let observe: Option<HealthObserveCallback> = if let Some(sender) = sender {
261 Some(Box::new(BackendObserveNotification::new(
262 name.to_string(),
263 sender.clone(),
264 )))
265 } else {
266 None
267 };
268
269 let (health_check_conf, hc) = new_health_check(
271 name,
272 &conf.health_check.clone().unwrap_or_default(),
273 observe,
274 )
275 .map_err(|e| Error::Common {
276 message: e.to_string(),
277 category: "health".to_string(),
278 })?;
279 lb.parallel_health_check = health_check_conf.parallel_check;
281 lb.set_health_check(hc);
282 lb.update_frequency = update_frequency;
283 lb.health_check_frequency = Some(health_check_conf.check_frequency);
284 Ok(lb)
285}
286
287fn new_load_balancer(
296 name: &str,
297 conf: &UpstreamConf,
298 sender: Option<Arc<NotificationSender>>,
299) -> Result<SelectionLb> {
300 if conf.addrs.is_empty() {
302 return Err(Error::Common {
303 category: "new_upstream".to_string(),
304 message: "upstream addrs is empty".to_string(),
305 });
306 }
307
308 let discovery_category = conf.guess_discovery();
310 if discovery_category == TRANSPARENT_DISCOVERY {
312 return Ok(SelectionLb::Transparent);
313 }
314
315 let tls = conf
317 .sni
318 .as_ref()
319 .map(|item| !item.is_empty())
320 .unwrap_or_default();
321
322 let mut discovery = Discovery::new(conf.addrs.clone())
324 .with_ipv4_only(conf.ipv4_only.unwrap_or_default())
325 .with_tls(tls)
326 .with_sender(sender.clone());
327 if let Some(dns_server) = &conf.dns_server {
328 discovery = discovery.with_dns_server(dns_server.clone());
329 }
330 if let Some(dns_domain) = &conf.dns_domain {
331 discovery = discovery.with_domain(dns_domain.clone());
332 }
333 if let Some(dns_search) = &conf.dns_search {
334 discovery = discovery.with_search(dns_search.clone());
335 }
336 let backends = new_backends(&discovery_category, &discovery)?;
337
338 let algo_method = conf.algo.as_deref().unwrap_or("round_robin");
342
343 let parts: Vec<&str> = algo_method.split(':').collect();
344 if parts.first() == Some(&"hash") && parts.len() >= 2 {
345 let hash_type = parts[1];
346 let hash_key = parts.get(2).copied().unwrap_or_default();
347 let lb = update_health_check_params(
348 LoadBalancer::<Consistent>::from_backends(backends),
349 name,
350 conf,
351 sender,
352 )?;
353 Ok(SelectionLb::Consistent {
354 lb,
355 hash: HashStrategy::from((hash_type, hash_key)),
356 })
357 } else {
358 let lb = update_health_check_params(
360 LoadBalancer::<RoundRobin>::from_backends(backends),
361 name,
362 conf,
363 sender,
364 )?;
365 Ok(SelectionLb::RoundRobin(lb))
366 }
367}
368
/// Snapshot of runtime statistics for an upstream.
#[derive(Debug, Clone, Default)]
pub struct UpstreamStats {
    /// Number of requests currently in flight.
    pub processing: i32,
    /// Number of connections, when peer tracing is enabled.
    pub connected: Option<i32>,
    /// Windowed statistics keyed by backend address.
    pub backend_stats: HashMap<String, WindowStats>,
}
375
376impl Upstream {
377 pub fn new(
386 name: &str,
387 conf: &UpstreamConf,
388 sender: Option<Arc<NotificationSender>>,
389 ) -> Result<Self> {
390 let lb = new_load_balancer(name, conf, sender)?;
391 let key = conf.hash_key();
392 let sni = conf.sni.clone().unwrap_or_default();
393 let tls = !sni.is_empty();
394
395 let alpn = if let Some(alpn) = &conf.alpn {
396 match alpn.to_uppercase().as_str() {
397 "H2H1" => ALPN::H2H1,
398 "H2" => ALPN::H2,
399 _ => ALPN::H1,
400 }
401 } else {
402 ALPN::H1
403 };
404
405 let tcp_keepalive = if (conf.tcp_idle.is_some()
406 && conf.tcp_probe_count.is_some()
407 && conf.tcp_interval.is_some())
408 || conf.tcp_user_timeout.is_some()
409 {
410 Some(TcpKeepalive {
411 idle: conf.tcp_idle.unwrap_or_default(),
412 count: conf.tcp_probe_count.unwrap_or_default(),
413 interval: conf.tcp_interval.unwrap_or_default(),
414 #[cfg(target_os = "linux")]
415 user_timeout: conf.tcp_user_timeout.unwrap_or_default(),
416 })
417 } else {
418 None
419 };
420
421 let peer_tracer = if conf.enable_tracer.unwrap_or_default() {
422 Some(UpstreamPeerTracer::new(name))
423 } else {
424 None
425 };
426 let failure_status_codes = conf
427 .backend_failure_status_code
428 .clone()
429 .unwrap_or_default()
430 .split(",")
431 .flat_map(|code| code.trim().parse::<u16>().ok())
432 .collect::<Vec<u16>>();
433 let tracer = peer_tracer
434 .as_ref()
435 .map(|peer_tracer| Tracer(Box::new(peer_tracer.to_owned())));
436 let circuit_break_max_consecutive_failures = conf
437 .circuit_break_max_consecutive_failures
438 .unwrap_or_default();
439 let circuit_break_max_failure_percent =
440 conf.circuit_break_max_failure_percent.unwrap_or_default();
441 let circuit_breaker_states = if circuit_break_max_consecutive_failures
442 > 0
443 || circuit_break_max_failure_percent > 0
444 {
445 Some(BackendCircuitStates::new(CircuitBreakerConfig {
446 max_consecutive_failures:
447 circuit_break_max_consecutive_failures,
448 max_failure_percent: circuit_break_max_failure_percent as f64,
449 min_requests_threshold: conf
450 .circuit_break_min_requests_threshold
451 .unwrap_or(10),
452 half_open_consecutive_success_threshold: conf
453 .circuit_break_half_open_consecutive_success_threshold
454 .unwrap_or(5),
455 open_duration: conf
456 .circuit_break_open_duration
457 .unwrap_or(Duration::from_secs(10)),
458 }))
459 } else {
460 None
461 };
462
463 let up = Self {
464 name: name.into(),
465 key,
466 tls,
467 sni,
468 lb,
469 alpn,
470 connection_timeout: conf.connection_timeout,
471 total_connection_timeout: conf.total_connection_timeout,
472 read_timeout: conf.read_timeout,
473 idle_timeout: conf.idle_timeout.or(Some(Duration::from_secs(60))),
474 write_timeout: conf.write_timeout,
475 verify_cert: conf.verify_cert,
476 tcp_recv_buf: conf.tcp_recv_buf.map(|item| item.as_u64() as usize),
477 tcp_keepalive,
478 tcp_fast_open: conf.tcp_fast_open,
479 peer_tracer,
480 tracer,
481 processing: AtomicI32::new(0),
482 backend_stats: if conf.enable_backend_stats.unwrap_or_default() {
483 Some(BackendStats::new(
484 conf.backend_stats_interval
485 .unwrap_or_else(|| Duration::from_secs(60)),
486 failure_status_codes,
487 ))
488 } else {
489 None
490 },
491 circuit_breaker_states,
492 };
493 debug!(
494 target: LOG_TARGET,
495 name = up.name.as_ref(),
496 "new upstream: {up:?}"
497 );
498 Ok(up)
499 }
500
501 #[inline]
502 fn accept_backend(&self, backend: &Backend, healthy: bool) -> bool {
503 if !healthy {
505 return false;
506 }
507 let Some(states) = &self.circuit_breaker_states else {
509 return true;
510 };
511 states.is_backend_acceptable(&backend.addr.to_string())
512 }
513
514 #[inline]
528 pub fn new_http_peer(
529 &self,
530 session: &Session,
531 client_ip: &Option<String>,
532 ) -> Option<HttpPeer> {
533 let upstream = match &self.lb {
535 SelectionLb::RoundRobin(lb) => {
537 lb.select_with(b"", 4, |backend, healthy| {
538 self.accept_backend(backend, healthy)
539 })
540 },
541 SelectionLb::Consistent { lb, hash } => {
543 let value = hash.get_value(session, client_ip);
544 lb.select_with(value.as_bytes(), 4, |backend, healthy| {
545 self.accept_backend(backend, healthy)
546 })
547 },
548 SelectionLb::Transparent => None,
550 };
551 self.processing.fetch_add(1, Ordering::Relaxed);
553
554 let p = if matches!(self.lb, SelectionLb::Transparent) {
556 let host = pingap_core::get_host(session.req_header())?;
558 let sni = if self.sni == "$host" {
560 host.to_string()
561 } else {
562 self.sni.clone()
563 };
564 let port = if self.tls { 443 } else { 80 };
566 Some(HttpPeer::new(format!("{host}:{port}"), self.tls, sni))
568 } else {
569 upstream.map(|upstream| {
571 HttpPeer::new(upstream, self.tls, self.sni.clone())
572 })
573 };
574
575 p.map(|mut p| {
577 p.options.connection_timeout = self.connection_timeout;
579 p.options.total_connection_timeout = self.total_connection_timeout;
580 p.options.read_timeout = self.read_timeout;
581 p.options.idle_timeout = self.idle_timeout;
582 p.options.write_timeout = self.write_timeout;
583 if let Some(verify_cert) = self.verify_cert {
585 p.options.verify_cert = verify_cert;
586 }
587 p.options.alpn = self.alpn.clone();
589 p.options.tcp_keepalive.clone_from(&self.tcp_keepalive);
591 p.options.tcp_recv_buf = self.tcp_recv_buf;
592 if let Some(tcp_fast_open) = self.tcp_fast_open {
593 p.options.tcp_fast_open = tcp_fast_open;
594 }
595 p.options.tracer.clone_from(&self.tracer);
597 p
598 })
599 }
600
601 #[inline]
606 pub fn get_backends(&self) -> Option<&Backends> {
607 match &self.lb {
608 SelectionLb::RoundRobin(lb) => Some(lb.backends()),
609 SelectionLb::Consistent { lb, .. } => Some(lb.backends()),
610 SelectionLb::Transparent => None,
611 }
612 }
613
614 pub async fn run_health_check(&self) -> Result<()> {
615 self.lb.update().await.map_err(|e| Error::Common {
616 category: "run_health_check".to_string(),
617 message: e.to_string(),
618 })?;
619 self.lb.run_health_check().await;
620
621 Ok(())
622 }
623 pub fn is_transparent(&self) -> bool {
624 matches!(self.lb, SelectionLb::Transparent)
625 }
626 pub fn connected(&self) -> Option<i32> {
627 self.peer_tracer.as_ref().map(|tracer| tracer.connected())
628 }
629
630 pub fn stats(&self) -> UpstreamStats {
631 let Some(backends) = self.get_backends() else {
632 return UpstreamStats::default();
633 };
634 UpstreamStats {
635 processing: self.processing.load(Ordering::Relaxed),
636 connected: self
637 .peer_tracer
638 .as_ref()
639 .map(|tracer| tracer.connected()),
640 backend_stats: self
641 .backend_stats
642 .as_ref()
643 .map(|backend_stats| backend_stats.get_all_stats(backends))
644 .unwrap_or_default(),
645 }
646 }
647}
648
649impl UpstreamInstance for Upstream {
650 fn completed(&self) -> i32 {
655 self.processing.fetch_add(-1, Ordering::Relaxed)
656 }
657 fn on_transport_failure(&self, address: &str) {
658 let Some(backend_stats) = &self.backend_stats else {
659 return;
660 };
661 debug!(target: LOG_TARGET, address, "on_transport_failure");
662 backend_stats.on_transport_failure(address);
663 if let Some(circuit_breaker_states) = &self.circuit_breaker_states {
664 circuit_breaker_states.update_state_after_request(
665 address,
666 true,
667 backend_stats,
668 );
669 }
670 }
671 fn on_response(&self, address: &str, status: StatusCode) {
672 let Some(backend_stats) = &self.backend_stats else {
673 return;
674 };
675 debug!(target: LOG_TARGET, address, status = status.to_string(), "on_response");
676 let is_request_failure = backend_stats.on_response(address, status);
677 if let Some(circuit_breaker_states) = &self.circuit_breaker_states {
678 circuit_breaker_states.update_state_after_request(
679 address,
680 is_request_failure,
681 backend_stats,
682 );
683 }
684 }
685}
686
/// Health summary of an upstream's backends.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpstreamHealthyStatus {
    /// Number of healthy backends.
    pub healthy: u32,
    /// Total number of backends.
    pub total: u32,
    /// Addresses of the unhealthy backends.
    pub unhealthy_backends: Vec<String>,
}
693
694pub fn new_ahash_upstreams(
695 upstream_configs: &HashMap<String, UpstreamConf>,
696 upstream_provider: Arc<dyn UpstreamProvider>,
697 sender: Option<Arc<NotificationSender>>,
698) -> Result<(Upstreams, Vec<String>)> {
699 let mut upstreams = AHashMap::new();
700 let mut updated_upstreams = vec![];
701 for (name, conf) in upstream_configs.iter() {
702 let key = conf.hash_key();
703 if let Some(found) = upstream_provider.get(name) {
704 if found.key == key {
706 upstreams.insert(name.to_string(), found);
707 continue;
708 }
709 }
710 let up = Arc::new(Upstream::new(name, conf, sender.clone())?);
711 upstreams.insert(name.to_string(), up);
712 updated_upstreams.push(name.to_string());
713 }
714 Ok((upstreams, updated_upstreams))
715}
716
#[async_trait]
impl BackgroundTask for HealthCheckTask {
    /// Runs one tick of the health check task.
    ///
    /// For each non-transparent upstream, refreshes its backends when its
    /// update frequency lines up with this tick, and runs its health
    /// check when its check frequency lines up. On every tick where
    /// `check_count % 10 == 1`, it also diffs the set of fully-unhealthy
    /// upstreams against the previous snapshot and sends notifications
    /// for upstreams that changed state.
    async fn execute(&self, check_count: u32) -> Result<bool, ServiceError> {
        let mut upstreams = self.upstream_provider.list();
        // Transparent upstreams have no backends to update or check.
        upstreams.retain(|(_, up)| !up.is_transparent());
        let interval = self.interval.as_secs();
        // NOTE(review): `interval` is 0 when the configured interval is
        // below one second, which would make the divisions below panic —
        // confirm the task is never created with a sub-second interval.
        let jobs = upstreams.into_iter().map(|(name, up)| {
            let runtime = pingora_runtime::current_handle();
            // Each upstream is processed in its own spawned task.
            runtime.spawn(async move {
                // True when this tick lines up with `frequency` (seconds),
                // rounding the tick count up so a frequency shorter than
                // one interval still fires on every tick.
                let check_frequency_matched = |frequency: u64| -> bool {
                    let mut count = (frequency / interval) as u32;
                    if !frequency.is_multiple_of(interval) {
                        count += 1;
                    }
                    check_count.is_multiple_of(count)
                };

                let (update_frequency, health_check_frequency) =
                    up.lb.get_health_frequency();

                // Refresh backends on the first tick and whenever the
                // update frequency matches this tick.
                if check_count == 0
                    || (update_frequency > 0
                        && check_frequency_matched(update_frequency))
                {
                    let update_backend_start_time = Instant::now();
                    let result = up.lb.update().await;
                    if let Err(e) = result {
                        error!(
                            target: LOG_TARGET,
                            error = %e,
                            name,
                            "update backends fail"
                        )
                    } else {
                        info!(
                            target: LOG_TARGET,
                            name,
                            elapsed = format!(
                                "{}ms",
                                update_backend_start_time.elapsed().as_millis()
                            ),
                            "update backend success"
                        );
                    }
                }

                // Skip the health check when this tick does not match the
                // upstream's health check frequency.
                if !check_frequency_matched(health_check_frequency) {
                    return;
                }
                let health_check_start_time = Instant::now();
                up.lb.run_health_check().await;
                info!(
                    target: LOG_TARGET,
                    name,
                    elapsed = format!(
                        "{}ms",
                        health_check_start_time.elapsed().as_millis()
                    ),
                    "health check is done"
                );
            })
        });
        futures::future::join_all(jobs).await;

        // Every 10th tick: compare the fully-unhealthy upstream set with
        // the previous snapshot and notify on transitions.
        if check_count % 10 == 1 {
            let current_unhealthy_upstreams =
                self.unhealthy_upstreams.load().clone();
            let mut notify_healthy_upstreams = vec![];
            let mut unhealthy_upstreams = vec![];
            for (name, status) in self.upstream_provider.healthy_status().iter()
            {
                if status.healthy == 0 {
                    // No healthy backend left: the upstream is down.
                    unhealthy_upstreams.push(name.to_string());
                } else if current_unhealthy_upstreams.contains(name) {
                    // Previously down, now recovered.
                    notify_healthy_upstreams.push(name.to_string());
                }
            }
            let mut notify_unhealthy_upstreams = vec![];
            for name in unhealthy_upstreams.iter() {
                if !current_unhealthy_upstreams.contains(name) {
                    // Newly down since the last snapshot.
                    notify_unhealthy_upstreams.push(name.to_string());
                }
            }
            self.unhealthy_upstreams
                .store(Arc::new(unhealthy_upstreams));
            if let Some(sender) = &self.sender {
                if !notify_unhealthy_upstreams.is_empty() {
                    let data = NotificationData {
                        category: "upstream_status".to_string(),
                        title: "Upstream unhealthy".to_string(),
                        message: notify_unhealthy_upstreams.join(", "),
                        level: NotificationLevel::Error,
                    };
                    sender.notify(data).await;
                }
                if !notify_healthy_upstreams.is_empty() {
                    let data = NotificationData {
                        category: "upstream_status".to_string(),
                        title: "Upstream healthy".to_string(),
                        message: notify_healthy_upstreams.join(", "),
                        ..Default::default()
                    };
                    sender.notify(data).await;
                }
            }
        }
        Ok(true)
    }
}
834
/// Background task that periodically health-checks all upstreams.
struct HealthCheckTask {
    // Tick interval of the task.
    interval: Duration,
    // Optional channel for health change notifications.
    sender: Option<Arc<NotificationSender>>,
    // Snapshot of upstream names that were fully unhealthy on the last
    // notification pass.
    unhealthy_upstreams: ArcSwap<Vec<String>>,
    // Source of the upstream instances to check.
    upstream_provider: Arc<dyn UpstreamProvider>,
}
841
842pub fn new_upstream_health_check_task(
843 upstream_provider: Arc<dyn UpstreamProvider>,
844 interval: Duration,
845 sender: Option<Arc<NotificationSender>>,
846) -> BackgroundTaskService {
847 let task = Box::new(HealthCheckTask {
848 interval,
849 sender,
850 unhealthy_upstreams: ArcSwap::new(Arc::new(vec![])),
851 upstream_provider,
852 });
853 let name = "upstream_health_check";
854 let mut service =
855 BackgroundTaskService::new_single(name, interval, name, task);
856 service.set_immediately(true);
857 service
858}
859
860#[cfg(test)]
861mod tests {
862 use super::{
863 Upstream, UpstreamConf, UpstreamProvider, new_backends,
864 new_load_balancer,
865 };
866 use crate::new_ahash_upstreams;
867 use pingap_core::UpstreamInstance;
868 use pingap_discovery::Discovery;
869 use pingora::protocols::ALPN;
870 use pingora::proxy::Session;
871 use pretty_assertions::assert_eq;
872 use std::collections::HashMap;
873 use std::sync::Arc;
874 use std::sync::atomic::{AtomicI32, Ordering};
875 use std::time::Duration;
876 use tokio_test::io::Builder;
877
    // Minimal in-memory provider serving exactly one upstream.
    struct TmpProvider {
        upstream: Arc<Upstream>,
    }

    impl UpstreamProvider for TmpProvider {
        fn get(&self, name: &str) -> Option<Arc<Upstream>> {
            if name == self.upstream.name.as_ref() {
                return Some(self.upstream.clone());
            }
            None
        }
        fn list(&self) -> Vec<(String, Arc<Upstream>)> {
            vec![(
                self.upstream.name.as_ref().to_string(),
                self.upstream.clone(),
            )]
        }
    }
896
    #[test]
    fn test_new_backends() {
        // Static discovery with and without weight/port annotations.
        let _ = new_backends(
            "",
            &Discovery::new(vec![
                "192.168.1.1:8001 10".to_string(),
                "192.168.1.2:8001".to_string(),
            ]),
        )
        .unwrap();

        let _ = new_backends(
            "",
            &Discovery::new(vec![
                "192.168.1.1".to_string(),
                "192.168.1.2:8001".to_string(),
            ]),
        )
        .unwrap();

        // DNS discovery from a hostname.
        let _ = new_backends(
            "dns",
            &Discovery::new(vec!["github.com".to_string()]),
        )
        .unwrap();
    }
    #[test]
    fn test_new_upstream() {
        // An empty address list must be rejected.
        let result = Upstream::new(
            "charts",
            &UpstreamConf {
                ..Default::default()
            },
            None,
        );
        assert_eq!(
            "Common error, category: new_upstream, upstream addrs is empty",
            result.err().unwrap().to_string()
        );

        // A fully configured upstream: verify that every option is
        // carried over onto the instance.
        let up = Upstream::new(
            "charts",
            &UpstreamConf {
                addrs: vec!["192.168.1.1".to_string()],
                algo: Some("hash:cookie:user-id".to_string()),
                alpn: Some("h2".to_string()),
                connection_timeout: Some(Duration::from_secs(5)),
                total_connection_timeout: Some(Duration::from_secs(10)),
                read_timeout: Some(Duration::from_secs(3)),
                idle_timeout: Some(Duration::from_secs(30)),
                write_timeout: Some(Duration::from_secs(5)),
                tcp_idle: Some(Duration::from_secs(60)),
                tcp_probe_count: Some(100),
                tcp_interval: Some(Duration::from_secs(60)),
                tcp_recv_buf: Some(bytesize::ByteSize(1024)),
                ..Default::default()
            },
            None,
        )
        .unwrap();

        assert_eq!(ALPN::H2.to_string(), up.alpn.to_string());
        assert_eq!("Some(5s)", format!("{:?}", up.connection_timeout));
        assert_eq!("Some(10s)", format!("{:?}", up.total_connection_timeout));
        assert_eq!("Some(3s)", format!("{:?}", up.read_timeout));
        assert_eq!("Some(30s)", format!("{:?}", up.idle_timeout));
        assert_eq!("Some(5s)", format!("{:?}", up.write_timeout));
        // `user_timeout` only exists on linux builds.
        #[cfg(target_os = "linux")]
        assert_eq!(
            "Some(TcpKeepalive { idle: 60s, interval: 60s, count: 100, user_timeout: 0ns })",
            format!("{:?}", up.tcp_keepalive)
        );
        #[cfg(not(target_os = "linux"))]
        assert_eq!(
            "Some(TcpKeepalive { idle: 60s, interval: 60s, count: 100 })",
            format!("{:?}", up.tcp_keepalive)
        );
        assert_eq!("Some(1024)", format!("{:?}", up.tcp_recv_buf));
    }
976
    #[tokio::test]
    async fn test_upstream() {
        // Build a mock HTTP/1 session from raw request bytes.
        let headers = [
            "Host: github.com",
            "Referer: https://github.com/",
            "User-Agent: pingap/0.1.1",
            "Cookie: deviceId=abc",
            "Accept: application/json",
        ]
        .join("\r\n");
        let input_header =
            format!("GET /vicanso/pingap?size=1 HTTP/1.1\r\n{headers}\r\n\r\n");
        let mock_io = Builder::new().read(input_header.as_bytes()).build();

        let mut session = Session::new_h1(Box::new(mock_io));
        session.read_request().await.unwrap();
        let up = Upstream::new(
            "upstreamname",
            &UpstreamConf {
                addrs: vec!["192.168.1.1:8001".to_string()],
                ..Default::default()
            },
            None,
        )
        .unwrap();
        // `completed` returns the previous count and decrements it.
        up.processing.fetch_add(10, Ordering::Relaxed);
        let value = up.processing.load(Ordering::Relaxed);
        assert_eq!(value, up.completed());
        assert_eq!(value - 1, up.processing.load(Ordering::Relaxed));
        assert_eq!(true, up.new_http_peer(&session, &None,).is_some());
    }
1008
    #[test]
    fn test_get_upstreams_processing_connected() {
        let mut tmp_upstream = Upstream::new(
            "test",
            &UpstreamConf {
                addrs: vec!["127.0.0.1:5001".to_string()],
                ..Default::default()
            },
            None,
        )
        .unwrap();
        // Simulate 10 in-flight requests.
        tmp_upstream.processing = AtomicI32::new(10);
        let upstream = Arc::new(tmp_upstream);

        let upstream_provider = Arc::new(TmpProvider { upstream });

        let stat = upstream_provider.get_all_stats();

        assert_eq!(1, stat.len());
        assert_eq!(10, stat.get("test").unwrap().processing);
    }
1030
    #[test]
    fn test_get_upstream_healthy_status() {
        // Round robin upstream: one backend, reported healthy.
        let tmp_upstream = Upstream::new(
            "test",
            &UpstreamConf {
                addrs: vec!["127.0.0.1:5001".to_string()],
                ..Default::default()
            },
            None,
        )
        .unwrap();
        let upstream_provider = Arc::new(TmpProvider {
            upstream: Arc::new(tmp_upstream),
        });
        let status = upstream_provider.healthy_status();
        assert_eq!(1, status.len());
        assert_eq!(1, status.get("test").unwrap().healthy);

        // Consistent hash upstream: same expectation.
        let tmp_upstream = Upstream::new(
            "ip",
            &UpstreamConf {
                addrs: vec!["127.0.0.1:5001".to_string()],
                algo: Some("hash:ip".to_string()),
                ..Default::default()
            },
            None,
        )
        .unwrap();
        let upstream_provider = Arc::new(TmpProvider {
            upstream: Arc::new(tmp_upstream),
        });
        let status = upstream_provider.healthy_status();
        assert_eq!(1, status.len());
        assert_eq!(1, status.get("ip").unwrap().healthy);
    }
1068
    #[test]
    fn test_new_ahash_upstreams() {
        let mut tmp_upstream = Upstream::new(
            "test",
            &UpstreamConf {
                addrs: vec!["127.0.0.1:5001".to_string()],
                ..Default::default()
            },
            None,
        )
        .unwrap();
        tmp_upstream.processing = AtomicI32::new(10);
        let upstream = Arc::new(tmp_upstream);
        let upstream_provider = Arc::new(TmpProvider { upstream });

        // Same config hash: the existing instance is reused, nothing is
        // reported as updated.
        let mut upstream_configs = HashMap::new();
        upstream_configs.insert(
            "test".to_string(),
            UpstreamConf {
                addrs: vec!["127.0.0.1:5001".to_string()],
                ..Default::default()
            },
        );
        let (upstreams, updated_upstreams) = new_ahash_upstreams(
            &upstream_configs,
            upstream_provider.clone(),
            None,
        )
        .unwrap();
        assert_eq!(0, updated_upstreams.len());
        assert_eq!(1, upstreams.len());
        assert_eq!(true, upstreams.contains_key("test"));

        // Changed address: the upstream is rebuilt and reported updated.
        let mut upstream_configs = HashMap::new();
        upstream_configs.insert(
            "test".to_string(),
            UpstreamConf {
                addrs: vec!["127.0.0.1:5002".to_string()],
                ..Default::default()
            },
        );
        let (upstreams, updated_upstreams) = new_ahash_upstreams(
            &upstream_configs,
            upstream_provider.clone(),
            None,
        )
        .unwrap();
        assert_eq!(1, updated_upstreams.len());
        assert_eq!(1, upstreams.len());
        assert_eq!(true, upstreams.contains_key("test"));

        // New name: a fresh upstream is created; the old one is dropped.
        let mut upstream_configs = HashMap::new();
        upstream_configs.insert(
            "test1".to_string(),
            UpstreamConf {
                addrs: vec!["127.0.0.1:5001".to_string()],
                ..Default::default()
            },
        );
        let (upstreams, updated_upstreams) =
            new_ahash_upstreams(&upstream_configs, upstream_provider, None)
                .unwrap();

        assert_eq!(1, updated_upstreams.len());
        assert_eq!(1, upstreams.len());
        assert_eq!(false, upstreams.contains_key("test"));
        assert_eq!(true, upstreams.contains_key("test1"));
    }
1139
    #[test]
    fn test_selection_load_balancer() {
        // Round robin variant reports configured frequencies in seconds.
        let round_robin = new_load_balancer(
            "test",
            &UpstreamConf {
                discovery: Some("dns".to_string()),
                addrs: vec!["127.0.0.1:3000".to_string()],
                update_frequency: Some(Duration::from_secs(5)),
                health_check: Some(
                    "http://127.0.0.1:3000/?check_frequency=3s".to_string(),
                ),
                ..Default::default()
            },
            None,
        )
        .unwrap();
        let (update_frequency, health_check_frequency) =
            round_robin.get_health_frequency();
        assert_eq!(5, update_frequency);
        assert_eq!(3, health_check_frequency);

        // Consistent hash variant behaves the same way.
        let consistent = new_load_balancer(
            "test",
            &UpstreamConf {
                discovery: Some("dns".to_string()),
                algo: Some("hash:ip".to_string()),
                addrs: vec!["127.0.0.1:3000".to_string()],
                update_frequency: Some(Duration::from_secs(10)),
                health_check: Some(
                    "http://127.0.0.1:3000/?check_frequency=2s".to_string(),
                ),
                ..Default::default()
            },
            None,
        )
        .unwrap();
        let (update_frequency, health_check_frequency) =
            consistent.get_health_frequency();
        assert_eq!(10, update_frequency);
        assert_eq!(2, health_check_frequency);
    }
1181}