1use ahash::AHashMap;
16use arc_swap::ArcSwap;
17use async_trait::async_trait;
18use derive_more::Debug;
19use futures_util::FutureExt;
20use once_cell::sync::Lazy;
21use pingap_config::UpstreamConf;
22use pingap_core::{CommonServiceTask, ServiceTask};
23use pingap_core::{NotificationData, NotificationLevel, NotificationSender};
24use pingap_discovery::{
25 is_dns_discovery, is_docker_discovery, is_static_discovery,
26 new_dns_discover_backends, new_docker_discover_backends,
27 new_static_discovery, Discovery, TRANSPARENT_DISCOVERY,
28};
29use pingap_health::new_health_check;
30use pingora::lb::health_check::{HealthObserve, HealthObserveCallback};
31use pingora::lb::selection::{
32 BackendIter, BackendSelection, Consistent, RoundRobin,
33};
34use pingora::lb::Backend;
35use pingora::lb::{Backends, LoadBalancer};
36use pingora::protocols::l4::ext::TcpKeepalive;
37use pingora::protocols::ALPN;
38use pingora::proxy::Session;
39use pingora::upstreams::peer::{HttpPeer, Tracer, Tracing};
40use serde::{Deserialize, Serialize};
41use snafu::Snafu;
42use std::collections::HashMap;
43use std::sync::atomic::{AtomicI32, AtomicU32, Ordering};
44use std::sync::Arc;
45use std::time::{Duration, SystemTime};
46use tracing::{debug, error, info};
47
/// Category tag attached to every log line emitted by this module.
const LOG_CATEGORY: &str = "upstream";
49
/// Errors produced by the upstream module.
#[derive(Debug, Snafu)]
pub enum Error {
    /// Generic error carrying a category tag (e.g. "dns_discovery",
    /// "health", "new_upstream") plus a human-readable message.
    #[snafu(display("Common error, category: {category}, {message}"))]
    Common { message: String, category: String },
}
/// Module-local result alias defaulting the error type to [`Error`].
type Result<T, E = Error> = std::result::Result<T, E>;
56
/// Health-observe callback state: forwards backend health transitions of
/// the named upstream to a notification sender.
pub struct BackendObserveNotification {
    // Upstream name, interpolated into the notification message.
    name: String,
    // Channel used to deliver the notification.
    sender: Arc<NotificationSender>,
}
61
62#[async_trait]
63impl HealthObserve for BackendObserveNotification {
64 async fn observe(&self, backend: &Backend, healthy: bool) {
65 let addr = backend.addr.to_string();
66 let template = format!("upstream {}({addr}) becomes ", self.name);
67 let info = if healthy {
68 (NotificationLevel::Info, template + "healthy")
69 } else {
70 (NotificationLevel::Error, template + "unhealthy")
71 };
72
73 self.sender
74 .notify(NotificationData {
75 category: "backend_status".to_string(),
76 level: info.0,
77 title: "Upstream backend status changed".to_string(),
78 message: info.1,
79 })
80 .await;
81 }
82}
83
84fn new_observe(
85 name: &str,
86 sender: Option<Arc<NotificationSender>>,
87) -> Option<HealthObserveCallback> {
88 if let Some(sender) = sender {
89 Some(Box::new(BackendObserveNotification {
90 name: name.to_string(),
91 sender: sender.clone(),
92 }))
93 } else {
94 None
95 }
96}
97
// Load-balancing selection strategy backing an upstream.
enum SelectionLb {
    // Round-robin selection across the backend set.
    RoundRobin(Arc<LoadBalancer<RoundRobin>>),
    // Consistent-hash selection keyed by a request-derived value.
    Consistent(Arc<LoadBalancer<Consistent>>),
    // No backend set: proxy to the host named in the request itself.
    Transparent,
}
107
/// Counts the currently-open connections to an upstream.
#[derive(Clone, Debug)]
struct UpstreamPeerTracer {
    // Upstream name, used only for logging.
    name: String,
    // Shared live-connection counter: incremented on connect,
    // decremented on disconnect (see the `Tracing` impl).
    connected: Arc<AtomicI32>,
}
114
115impl UpstreamPeerTracer {
116 fn new(name: &str) -> Self {
117 Self {
118 name: name.to_string(),
119 connected: Arc::new(AtomicI32::new(0)),
120 }
121 }
122}
123
124impl Tracing for UpstreamPeerTracer {
125 fn on_connected(&self) {
126 debug!(
127 category = LOG_CATEGORY,
128 name = self.name,
129 "upstream peer connected"
130 );
131 self.connected.fetch_add(1, Ordering::Relaxed);
132 }
133 fn on_disconnected(&self) {
134 debug!(
135 category = LOG_CATEGORY,
136 name = self.name,
137 "upstream peer disconnected"
138 );
139 self.connected.fetch_sub(1, Ordering::Relaxed);
140 }
141 fn boxed_clone(&self) -> Box<dyn Tracing> {
142 Box::new(self.clone())
143 }
144}
145
/// A composed upstream: backend discovery, load balancing, health
/// checking, and the connection options applied to every peer.
#[derive(Debug)]
pub struct Upstream {
    /// Upstream name, used for logging and lookup.
    pub name: String,

    /// Hash of the upstream configuration; compared on reload to decide
    /// whether the instance can be reused.
    pub key: String,

    /// Consistent-hash strategy: "url", "ip", "header", "cookie",
    /// "query", or anything else for the request path (the default).
    hash: String,

    /// Parameter for the header/cookie/query hash strategies.
    hash_key: String,

    /// Whether to connect to backends over TLS (derived from `sni`).
    tls: bool,

    /// SNI to present; "$host" means use the request's Host.
    sni: String,

    /// Selected load-balancing implementation.
    #[debug("lb")]
    lb: SelectionLb,

    /// TCP connect timeout.
    connection_timeout: Option<Duration>,

    /// Total timeout covering the whole connection setup.
    total_connection_timeout: Option<Duration>,

    /// Read timeout for upstream responses.
    read_timeout: Option<Duration>,

    /// How long an idle connection may be kept for reuse.
    idle_timeout: Option<Duration>,

    /// Write timeout for upstream requests.
    write_timeout: Option<Duration>,

    /// Whether to verify the upstream TLS certificate.
    verify_cert: Option<bool>,

    /// ALPN protocols to offer (H1, H2, or H2H1).
    alpn: ALPN,

    /// TCP keepalive parameters; only set when idle, probe count, and
    /// interval are all configured (see `Upstream::new`).
    tcp_keepalive: Option<TcpKeepalive>,

    /// TCP receive buffer size in bytes.
    tcp_recv_buf: Option<usize>,

    /// Whether to enable TCP fast open.
    tcp_fast_open: Option<bool>,

    /// Live-connection tracer; present only when tracing is enabled.
    peer_tracer: Option<UpstreamPeerTracer>,

    /// Boxed clone of `peer_tracer` attached to every created peer.
    tracer: Option<Tracer>,

    /// Number of requests currently being processed by this upstream.
    processing: AtomicI32,
}
222
223fn new_backends(
225 discovery_category: &str,
226 discovery: &Discovery,
227) -> Result<Backends> {
228 let (result, category) = match discovery_category {
229 d if is_dns_discovery(d) => {
230 (new_dns_discover_backends(discovery), "dns_discovery")
231 },
232 d if is_docker_discovery(d) => {
233 (new_docker_discover_backends(discovery), "docker_discovery")
234 },
235 _ => (new_static_discovery(discovery), "static_discovery"),
236 };
237 result.map_err(|e| Error::Common {
238 category: category.to_string(),
239 message: e.to_string(),
240 })
241}
242
243fn get_hash_value(
245 hash: &str, hash_key: &str, session: &Session, client_ip: &Option<String>, ) -> String {
250 match hash {
251 "url" => session.req_header().uri.to_string(),
252 "ip" => {
253 if let Some(client_ip) = client_ip {
254 client_ip.to_string()
255 } else {
256 pingap_core::get_client_ip(session)
257 }
258 },
259 "header" => {
260 if let Some(value) = session.get_header(hash_key) {
261 value.to_str().unwrap_or_default().to_string()
262 } else {
263 "".to_string()
264 }
265 },
266 "cookie" => {
267 pingap_core::get_cookie_value(session.req_header(), hash_key)
268 .unwrap_or_default()
269 .to_string()
270 },
271 "query" => pingap_core::get_query_value(session.req_header(), hash_key)
272 .unwrap_or_default()
273 .to_string(),
274 _ => session.req_header().uri.path().to_string(),
276 }
277}
278
/// Applies health-check settings from the configuration to a freshly
/// built load balancer and returns the configured balancer.
fn update_health_check_params<S>(
    mut lb: LoadBalancer<S>,
    name: &str,
    conf: &UpstreamConf,
    sender: Option<Arc<NotificationSender>>,
) -> Result<LoadBalancer<S>>
where
    S: BackendSelection + 'static,
    S::Iter: BackendIter,
{
    // Static discovery resolves synchronously, so run the initial
    // backend update immediately; `now_or_never` relies on that — the
    // expects document the invariant that it neither blocks nor fails.
    if is_static_discovery(&conf.guess_discovery()) {
        lb.update()
            .now_or_never()
            .expect("static should not block")
            .expect("static should not error");
    }

    // Build the health check, wiring in the optional notification
    // observer so backend status flips get reported.
    let (health_check_conf, hc) = new_health_check(
        name,
        &conf.health_check.clone().unwrap_or_default(),
        new_observe(name, sender),
    )
    .map_err(|e| Error::Common {
        message: e.to_string(),
        category: "health".to_string(),
    })?;
    lb.parallel_health_check = health_check_conf.parallel_check;
    lb.set_health_check(hc);
    lb.update_frequency = conf.update_frequency;
    lb.health_check_frequency = Some(health_check_conf.check_frequency);
    Ok(lb)
}
314
/// Builds the selection strategy for an upstream from its configuration.
/// Returns the strategy plus the parsed hash method and hash key (both
/// empty unless the algorithm is `hash:<method>[:<key>]`).
fn new_load_balancer(
    name: &str,
    conf: &UpstreamConf,
    sender: Option<Arc<NotificationSender>>,
) -> Result<(SelectionLb, String, String)> {
    // An upstream without addresses cannot balance anything.
    if conf.addrs.is_empty() {
        return Err(Error::Common {
            category: "new_upstream".to_string(),
            message: "upstream addrs is empty".to_string(),
        });
    }

    let discovery_category = conf.guess_discovery();
    // Transparent upstreams proxy to the request host directly; no
    // backends and no hash parameters apply.
    if discovery_category == TRANSPARENT_DISCOVERY {
        return Ok((SelectionLb::Transparent, "".to_string(), "".to_string()));
    }

    let mut hash = "".to_string();
    // TLS is enabled whenever a non-empty SNI is configured.
    let tls = conf
        .sni
        .as_ref()
        .map(|item| !item.is_empty())
        .unwrap_or_default();

    let discovery = Discovery::new(conf.addrs.clone())
        .with_ipv4_only(conf.ipv4_only.unwrap_or_default())
        .with_tls(tls)
        .with_sender(sender.clone());
    let backends = new_backends(&discovery_category, &discovery)?;

    // Algorithm format: "<algo>[:<hash method>[:<hash key>]]",
    // e.g. "hash:cookie:user-id".
    let algo_method = conf.algo.clone().unwrap_or_default();
    let algo_params: Vec<&str> = algo_method.split(':').collect();
    let mut hash_key = "".to_string();

    let lb = match algo_params[0] {
        "hash" => {
            if algo_params.len() > 1 {
                hash = algo_params[1].to_string();
                if algo_params.len() > 2 {
                    hash_key = algo_params[2].to_string();
                }
            }
            let lb = update_health_check_params(
                LoadBalancer::<Consistent>::from_backends(backends),
                name,
                conf,
                sender,
            )?;

            SelectionLb::Consistent(Arc::new(lb))
        },
        // Anything else (including an unset algorithm) falls back to
        // round robin.
        _ => {
            let lb = update_health_check_params(
                LoadBalancer::<RoundRobin>::from_backends(backends),
                name,
                conf,
                sender,
            )?;

            SelectionLb::RoundRobin(Arc::new(lb))
        },
    };
    Ok((lb, hash, hash_key))
}
398
impl Upstream {
    /// Creates an upstream from its configuration. Fails when the
    /// address list is empty or discovery/health-check setup fails.
    pub fn new(
        name: &str,
        conf: &UpstreamConf,
        sender: Option<Arc<NotificationSender>>,
    ) -> Result<Self> {
        let (lb, hash, hash_key) = new_load_balancer(name, conf, sender)?;
        let key = conf.hash_key();
        let sni = conf.sni.clone().unwrap_or_default();
        let tls = !sni.is_empty();

        // ALPN defaults to HTTP/1 when unset or unrecognized.
        let alpn = if let Some(alpn) = &conf.alpn {
            match alpn.to_uppercase().as_str() {
                "H2H1" => ALPN::H2H1,
                "H2" => ALPN::H2,
                _ => ALPN::H1,
            }
        } else {
            ALPN::H1
        };

        // Keepalive is only enabled when all three parameters are set.
        let tcp_keepalive = if conf.tcp_idle.is_some()
            && conf.tcp_probe_count.is_some()
            && conf.tcp_interval.is_some()
        {
            Some(TcpKeepalive {
                idle: conf.tcp_idle.unwrap_or_default(),
                count: conf.tcp_probe_count.unwrap_or_default(),
                interval: conf.tcp_interval.unwrap_or_default(),
                #[cfg(target_os = "linux")]
                user_timeout: Duration::from_secs(0),
            })
        } else {
            None
        };

        let peer_tracer = if conf.enable_tracer.unwrap_or_default() {
            Some(UpstreamPeerTracer::new(name))
        } else {
            None
        };
        // The boxed tracer clone shares the same counter as peer_tracer.
        let tracer = peer_tracer
            .as_ref()
            .map(|peer_tracer| Tracer(Box::new(peer_tracer.to_owned())));
        let up = Self {
            name: name.to_string(),
            key,
            tls,
            sni,
            hash,
            hash_key,
            lb,
            alpn,
            connection_timeout: conf.connection_timeout,
            total_connection_timeout: conf.total_connection_timeout,
            read_timeout: conf.read_timeout,
            idle_timeout: conf.idle_timeout,
            write_timeout: conf.write_timeout,
            verify_cert: conf.verify_cert,
            tcp_recv_buf: conf.tcp_recv_buf.map(|item| item.as_u64() as usize),
            tcp_keepalive,
            tcp_fast_open: conf.tcp_fast_open,
            peer_tracer,
            tracer,
            processing: AtomicI32::new(0),
        };
        debug!(
            category = LOG_CATEGORY,
            name = up.name,
            "new upstream: {up:?}"
        );
        Ok(up)
    }

    /// Selects a backend (or, for transparent upstreams, the request
    /// host) and builds an `HttpPeer` carrying all configured
    /// connection options. Returns `None` when no backend is available.
    #[inline]
    pub fn new_http_peer(
        &self,
        session: &Session,
        client_ip: &Option<String>,
    ) -> Option<HttpPeer> {
        let upstream = match &self.lb {
            SelectionLb::RoundRobin(lb) => lb.select(b"", 256),
            // Consistent hashing: derive the key from the request.
            SelectionLb::Consistent(lb) => {
                let value = get_hash_value(
                    &self.hash,
                    &self.hash_key,
                    session,
                    client_ip,
                );
                lb.select(value.as_bytes(), 256)
            },
            SelectionLb::Transparent => None,
        };
        // NOTE(review): the in-flight counter is incremented even when
        // selection fails and `None` is returned below — confirm callers
        // pair every new_http_peer call with completed().
        self.processing.fetch_add(1, Ordering::Relaxed);

        let p = if matches!(self.lb, SelectionLb::Transparent) {
            // Transparent: target the host from the request itself.
            let host = pingap_core::get_host(session.req_header())?;
            // "$host" means present the request host as SNI.
            let sni = if self.sni == "$host" {
                host.to_string()
            } else {
                self.sni.clone()
            };
            let port = if self.tls { 443 } else { 80 };
            Some(HttpPeer::new(format!("{host}:{port}"), self.tls, sni))
        } else {
            upstream.map(|upstream| {
                HttpPeer::new(upstream, self.tls, self.sni.clone())
            })
        };

        // Copy all connection options onto the peer.
        p.map(|mut p| {
            p.options.connection_timeout = self.connection_timeout;
            p.options.total_connection_timeout = self.total_connection_timeout;
            p.options.read_timeout = self.read_timeout;
            p.options.idle_timeout = self.idle_timeout;
            p.options.write_timeout = self.write_timeout;
            if let Some(verify_cert) = self.verify_cert {
                p.options.verify_cert = verify_cert;
            }
            p.options.alpn = self.alpn.clone();
            p.options.tcp_keepalive.clone_from(&self.tcp_keepalive);
            p.options.tcp_recv_buf = self.tcp_recv_buf;
            if let Some(tcp_fast_open) = self.tcp_fast_open {
                p.options.tcp_fast_open = tcp_fast_open;
            }
            p.options.tracer.clone_from(&self.tracer);
            p
        })
    }

    /// Current open-connection count, or `None` when tracing is off.
    #[inline]
    pub fn connected(&self) -> Option<i32> {
        self.peer_tracer
            .as_ref()
            .map(|tracer| tracer.connected.load(Ordering::Relaxed))
    }

    /// Returns the round-robin balancer, if that is the active strategy.
    #[inline]
    pub fn as_round_robin(&self) -> Option<Arc<LoadBalancer<RoundRobin>>> {
        match &self.lb {
            SelectionLb::RoundRobin(lb) => Some(lb.clone()),
            _ => None,
        }
    }

    /// Returns the consistent-hash balancer, if that is the active
    /// strategy.
    #[inline]
    pub fn as_consistent(&self) -> Option<Arc<LoadBalancer<Consistent>>> {
        match &self.lb {
            SelectionLb::Consistent(lb) => Some(lb.clone()),
            _ => None,
        }
    }

    /// Decrements the in-flight request counter and returns the value
    /// it held before the decrement (`fetch_add` returns the old value).
    #[inline]
    pub fn completed(&self) -> i32 {
        self.processing.fetch_add(-1, Ordering::Relaxed)
    }
}
611
/// Map of upstream name to shared upstream instance.
type Upstreams = AHashMap<String, Arc<Upstream>>;
// Global upstream registry, atomically swappable on config reload.
static UPSTREAM_MAP: Lazy<ArcSwap<Upstreams>> =
    Lazy::new(|| ArcSwap::from_pointee(AHashMap::new()));
615
616pub fn get_upstream(name: &str) -> Option<Arc<Upstream>> {
617 if name.is_empty() {
618 return None;
619 }
620 UPSTREAM_MAP.load().get(name).cloned()
621}
622
/// Snapshot of one upstream's backend health.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UpstreamHealthyStatus {
    /// Number of backends currently considered healthy.
    pub healthy: u32,
    /// Total number of backends.
    pub total: u32,
    /// Addresses of the backends that are not ready.
    pub unhealthy_backends: Vec<String>,
}
629
630pub fn get_upstream_healthy_status() -> HashMap<String, UpstreamHealthyStatus> {
637 let mut healthy_status = HashMap::new();
638 UPSTREAM_MAP.load().iter().for_each(|(k, v)| {
639 let mut total = 0;
640 let mut healthy = 0;
641 let mut unhealthy_backends = vec![];
642 if let Some(lb) = v.as_round_robin() {
643 let backends = lb.backends().get_backend();
644 total = backends.len();
645 backends.iter().for_each(|backend| {
646 if lb.backends().ready(backend) {
647 healthy += 1;
648 } else {
649 unhealthy_backends.push(backend.to_string());
650 }
651 });
652 } else if let Some(lb) = v.as_consistent() {
653 let backends = lb.backends().get_backend();
654 total = backends.len();
655 backends.iter().for_each(|backend| {
656 if lb.backends().ready(backend) {
657 healthy += 1;
658 } else {
659 unhealthy_backends.push(backend.to_string());
660 }
661 });
662 }
663 healthy_status.insert(
664 k.to_string(),
665 UpstreamHealthyStatus {
666 healthy,
667 total: total as u32,
668 unhealthy_backends,
669 },
670 );
671 });
672 healthy_status
673}
674
675pub fn get_upstreams_processing_connected(
680) -> HashMap<String, (i32, Option<i32>)> {
681 let mut processing_connected = HashMap::new();
682 UPSTREAM_MAP.load().iter().for_each(|(k, v)| {
683 let count = v.processing.load(Ordering::Relaxed);
684 let connected = v.connected();
685 processing_connected.insert(k.to_string(), (count, connected));
686 });
687 processing_connected
688}
689
690fn new_ahash_upstreams(
691 upstream_configs: &HashMap<String, UpstreamConf>,
692 sender: Option<Arc<NotificationSender>>,
693) -> Result<(Upstreams, Vec<String>)> {
694 let mut upstreams = AHashMap::new();
695 let mut updated_upstreams = vec![];
696 for (name, conf) in upstream_configs.iter() {
697 let key = conf.hash_key();
698 if let Some(found) = get_upstream(name) {
699 if found.key == key {
701 upstreams.insert(name.to_string(), found);
702 continue;
703 }
704 }
705 let up = Arc::new(Upstream::new(name, conf, sender.clone())?);
706 upstreams.insert(name.to_string(), up);
707 updated_upstreams.push(name.to_string());
708 }
709 Ok((upstreams, updated_upstreams))
710}
711
712pub fn try_init_upstreams(
720 upstream_configs: &HashMap<String, UpstreamConf>,
721 sender: Option<Arc<NotificationSender>>,
722) -> Result<()> {
723 let (upstreams, _) = new_ahash_upstreams(upstream_configs, sender)?;
724 UPSTREAM_MAP.store(Arc::new(upstreams));
725 Ok(())
726}
727
728async fn run_health_check(up: &Arc<Upstream>) -> Result<()> {
729 if let Some(lb) = up.as_round_robin() {
730 lb.update().await.map_err(|e| Error::Common {
731 category: "run_health_check".to_string(),
732 message: e.to_string(),
733 })?;
734 lb.backends()
735 .run_health_check(lb.parallel_health_check)
736 .await;
737 } else if let Some(lb) = up.as_consistent() {
738 lb.update().await.map_err(|e| Error::Common {
739 category: "run_health_check".to_string(),
740 message: e.to_string(),
741 })?;
742 lb.backends()
743 .run_health_check(lb.parallel_health_check)
744 .await;
745 }
746 Ok(())
747}
748
749pub async fn try_update_upstreams(
750 upstream_configs: &HashMap<String, UpstreamConf>,
751 sender: Option<Arc<NotificationSender>>,
752) -> Result<Vec<String>> {
753 let (upstreams, updated_upstreams) =
754 new_ahash_upstreams(upstream_configs, sender)?;
755 for (name, up) in upstreams.iter() {
756 if !updated_upstreams.contains(name) {
758 continue;
759 }
760 if let Err(e) = run_health_check(up).await {
761 error!(
762 category = LOG_CATEGORY,
763 error = %e,
764 upstream = name,
765 "update upstream health check fail"
766 );
767 }
768 }
769 UPSTREAM_MAP.store(Arc::new(upstreams));
770 Ok(updated_upstreams)
771}
772
#[async_trait]
impl ServiceTask for HealthCheckTask {
    /// One scheduler tick: refresh discovery and/or run health checks
    /// for every non-transparent upstream (each frequency is honored by
    /// tick arithmetic), then every 10th tick report upstreams whose
    /// overall health flipped.
    async fn run(&self) -> Option<bool> {
        // Monotonic tick counter (value before increment).
        let check_count = self.count.fetch_add(1, Ordering::Relaxed);
        // Snapshot the registry; transparent upstreams have no backends
        // to check and are skipped.
        let upstreams = {
            let mut upstreams = vec![];
            for (name, up) in UPSTREAM_MAP.load().iter() {
                if matches!(up.lb, SelectionLb::Transparent) {
                    continue;
                }
                upstreams.push((name.to_string(), up.clone()));
            }
            upstreams
        };
        let interval = self.interval.as_secs();
        // One spawned task per upstream so slow checks don't serialize.
        let jobs = upstreams.into_iter().map(|(name, up)| {
            let runtime = pingora_runtime::current_handle();
            runtime.spawn(async move {
                // True when `frequency` (seconds) is due on this tick:
                // a frequency spans ceil(frequency / interval) ticks.
                // NOTE(review): a zero frequency would make `count` 0
                // and panic on the modulo below; frequencies appear to
                // always be positive for non-transparent upstreams —
                // confirm.
                let check_frequency_matched = |frequency: u64| -> bool {
                    let mut count = (frequency / interval) as u32;
                    if frequency % interval != 0 {
                        count += 1;
                    }
                    check_count % count == 0
                };

                // Pull the per-upstream frequencies from whichever
                // balancer variant is active.
                let (update_frequency, health_check_frequency) =
                    if let Some(lb) = up.as_round_robin() {
                        let update_frequency =
                            lb.update_frequency.unwrap_or_default().as_secs();
                        let health_check_frequency = lb
                            .health_check_frequency
                            .unwrap_or_default()
                            .as_secs();
                        (update_frequency, health_check_frequency)
                    } else if let Some(lb) = up.as_consistent() {
                        let update_frequency =
                            lb.update_frequency.unwrap_or_default().as_secs();
                        let health_check_frequency = lb
                            .health_check_frequency
                            .unwrap_or_default()
                            .as_secs();
                        (update_frequency, health_check_frequency)
                    } else {
                        (0, 0)
                    };

                // Refresh discovery on the first tick and whenever the
                // update frequency is due.
                if check_count == 0
                    || (update_frequency > 0
                        && check_frequency_matched(update_frequency))
                {
                    let result = if let Some(lb) = up.as_round_robin() {
                        lb.update().await
                    } else if let Some(lb) = up.as_consistent() {
                        lb.update().await
                    } else {
                        Ok(())
                    };
                    if let Err(e) = result {
                        error!(
                            category = LOG_CATEGORY,
                            error = %e,
                            name,
                            "update backends fail"
                        )
                    } else {
                        debug!(
                            category = LOG_CATEGORY,
                            name, "update backend success"
                        );
                    }
                }

                // Skip the health check when its frequency is not due.
                if !check_frequency_matched(health_check_frequency) {
                    return;
                }
                let health_check_start_time = SystemTime::now();
                if let Some(lb) = up.as_round_robin() {
                    lb.backends()
                        .run_health_check(lb.parallel_health_check)
                        .await;
                } else if let Some(lb) = up.as_consistent() {
                    lb.backends()
                        .run_health_check(lb.parallel_health_check)
                        .await;
                }
                info!(
                    category = LOG_CATEGORY,
                    name,
                    elapsed = format!(
                        "{}ms",
                        health_check_start_time
                            .elapsed()
                            .unwrap_or_default()
                            .as_millis()
                    ),
                    "health check is done"
                );
            })
        });
        futures::future::join_all(jobs).await;

        // Every 10th tick: diff the set of fully-unhealthy upstreams
        // against the previous snapshot and notify on transitions.
        if check_count % 10 == 1 {
            let current_unhealthy_upstreams =
                self.unhealthy_upstreams.load().clone();
            let mut notify_healthy_upstreams = vec![];
            let mut unhealthy_upstreams = vec![];
            for (name, status) in get_upstream_healthy_status().iter() {
                if status.healthy == 0 {
                    unhealthy_upstreams.push(name.to_string());
                } else if current_unhealthy_upstreams.contains(name) {
                    // Previously unhealthy, now has healthy backends.
                    notify_healthy_upstreams.push(name.to_string());
                }
            }
            // Newly unhealthy: present now but not in the old snapshot.
            let mut notify_unhealthy_upstreams = vec![];
            for name in unhealthy_upstreams.iter() {
                if !current_unhealthy_upstreams.contains(name) {
                    notify_unhealthy_upstreams.push(name.to_string());
                }
            }
            self.unhealthy_upstreams
                .store(Arc::new(unhealthy_upstreams));
            if let Some(sender) = &self.sender {
                if !notify_unhealthy_upstreams.is_empty() {
                    let data = NotificationData {
                        category: "upstream_status".to_string(),
                        title: "Upstream unhealthy".to_string(),
                        message: notify_unhealthy_upstreams.join(", "),
                        level: NotificationLevel::Error,
                    };
                    sender.notify(data).await;
                }
                if !notify_healthy_upstreams.is_empty() {
                    let data = NotificationData {
                        category: "upstream_status".to_string(),
                        title: "Upstream healthy".to_string(),
                        message: notify_healthy_upstreams.join(", "),
                        ..Default::default()
                    };
                    sender.notify(data).await;
                }
            }
        }
        None
    }
    /// Human-readable task description for service listings.
    fn description(&self) -> String {
        let count = UPSTREAM_MAP.load().len();
        format!("upstream health check, upstream count: {count}")
    }
}
932
/// State for the periodic upstream health-check service task.
struct HealthCheckTask {
    // Tick interval of the scheduler (min 10s, see constructor).
    interval: Duration,
    // Tick counter; drives per-upstream frequency arithmetic.
    count: AtomicU32,
    // Optional channel for health-transition notifications.
    sender: Option<Arc<NotificationSender>>,
    // Names of upstreams that were fully unhealthy on the last report,
    // used to detect transitions.
    unhealthy_upstreams: ArcSwap<Vec<String>>,
}
939
940pub fn new_upstream_health_check_task(
941 interval: Duration,
942 sender: Option<Arc<NotificationSender>>,
943) -> CommonServiceTask {
944 let interval = interval.max(Duration::from_secs(10));
945 CommonServiceTask::new(
946 interval,
947 HealthCheckTask {
948 interval,
949 count: AtomicU32::new(0),
950 sender,
951 unhealthy_upstreams: ArcSwap::new(Arc::new(vec![])),
952 },
953 )
954}
955
#[cfg(test)]
mod tests {
    use super::{
        get_hash_value, new_backends, Upstream, UpstreamConf,
        UpstreamPeerTracer,
    };
    use pingap_discovery::Discovery;
    use pingora::protocols::ALPN;
    use pingora::proxy::Session;
    use pingora::upstreams::peer::Tracing;
    use pretty_assertions::assert_eq;
    use std::sync::atomic::Ordering;
    use std::time::Duration;
    use tokio_test::io::Builder;

    // Static discovery should accept weighted, plain, and port-less
    // addresses; the "dns" category should accept hostnames.
    #[test]
    fn test_new_backends() {
        let _ = new_backends(
            "",
            &Discovery::new(vec![
                "192.168.1.1:8001 10".to_string(),
                "192.168.1.2:8001".to_string(),
            ]),
        )
        .unwrap();

        let _ = new_backends(
            "",
            &Discovery::new(vec![
                "192.168.1.1".to_string(),
                "192.168.1.2:8001".to_string(),
            ]),
        )
        .unwrap();

        let _ = new_backends(
            "dns",
            &Discovery::new(vec!["github.com".to_string()]),
        )
        .unwrap();
    }
    // Empty addrs must fail; a full config must parse hash algorithm,
    // ALPN, timeouts, and keepalive as expected.
    #[test]
    fn test_new_upstream() {
        let result = Upstream::new(
            "charts",
            &UpstreamConf {
                ..Default::default()
            },
            None,
        );
        assert_eq!(
            "Common error, category: new_upstream, upstream addrs is empty",
            result.err().unwrap().to_string()
        );

        let up = Upstream::new(
            "charts",
            &UpstreamConf {
                addrs: vec!["192.168.1.1".to_string()],
                algo: Some("hash:cookie:user-id".to_string()),
                alpn: Some("h2".to_string()),
                connection_timeout: Some(Duration::from_secs(5)),
                total_connection_timeout: Some(Duration::from_secs(10)),
                read_timeout: Some(Duration::from_secs(3)),
                idle_timeout: Some(Duration::from_secs(30)),
                write_timeout: Some(Duration::from_secs(5)),
                tcp_idle: Some(Duration::from_secs(60)),
                tcp_probe_count: Some(100),
                tcp_interval: Some(Duration::from_secs(60)),
                tcp_recv_buf: Some(bytesize::ByteSize(1024)),
                ..Default::default()
            },
            None,
        )
        .unwrap();

        assert_eq!("cookie", up.hash);
        assert_eq!("user-id", up.hash_key);
        assert_eq!(ALPN::H2.to_string(), up.alpn.to_string());
        assert_eq!("Some(5s)", format!("{:?}", up.connection_timeout));
        assert_eq!("Some(10s)", format!("{:?}", up.total_connection_timeout));
        assert_eq!("Some(3s)", format!("{:?}", up.read_timeout));
        assert_eq!("Some(30s)", format!("{:?}", up.idle_timeout));
        assert_eq!("Some(5s)", format!("{:?}", up.write_timeout));
        // TcpKeepalive's Debug output differs per platform (linux adds
        // user_timeout).
        #[cfg(target_os = "linux")]
        assert_eq!(
            "Some(TcpKeepalive { idle: 60s, interval: 60s, count: 100, user_timeout: 0ns })",
            format!("{:?}", up.tcp_keepalive)
        );
        #[cfg(not(target_os = "linux"))]
        assert_eq!(
            "Some(TcpKeepalive { idle: 60s, interval: 60s, count: 100 })",
            format!("{:?}", up.tcp_keepalive)
        );
        assert_eq!("Some(1024)", format!("{:?}", up.tcp_recv_buf));
    }
    // Each hash strategy should derive its key from the right part of
    // the request (uri, ip, header, cookie, query, path).
    #[tokio::test]
    async fn test_get_hash_key_value() {
        let headers = [
            "Host: github.com",
            "Referer: https://github.com/",
            "User-Agent: pingap/0.1.1",
            "Cookie: deviceId=abc",
            "Accept: application/json",
            "X-Forwarded-For: 1.1.1.1",
        ]
        .join("\r\n");
        let input_header = format!(
            "GET /vicanso/pingap?id=1234 HTTP/1.1\r\n{headers}\r\n\r\n"
        );
        let mock_io = Builder::new().read(input_header.as_bytes()).build();

        let mut session = Session::new_h1(Box::new(mock_io));
        session.read_request().await.unwrap();

        assert_eq!(
            "/vicanso/pingap?id=1234",
            get_hash_value("url", "", &session, &None)
        );

        // Without an explicit client ip the value comes from the
        // X-Forwarded-For header; an explicit ip takes precedence.
        assert_eq!("1.1.1.1", get_hash_value("ip", "", &session, &None));
        assert_eq!(
            "2.2.2.2",
            get_hash_value("ip", "", &session, &Some("2.2.2.2".to_string()))
        );

        assert_eq!(
            "pingap/0.1.1",
            get_hash_value("header", "User-Agent", &session, &None)
        );

        assert_eq!(
            "abc",
            get_hash_value("cookie", "deviceId", &session, &None)
        );
        assert_eq!("1234", get_hash_value("query", "id", &session, &None));
        assert_eq!(
            "/vicanso/pingap",
            get_hash_value("path", "", &session, &None)
        );
    }
    // A default (round-robin) upstream should yield a peer for a
    // request and expose its balancer via as_round_robin.
    #[tokio::test]
    async fn test_upstream() {
        let headers = [
            "Host: github.com",
            "Referer: https://github.com/",
            "User-Agent: pingap/0.1.1",
            "Cookie: deviceId=abc",
            "Accept: application/json",
        ]
        .join("\r\n");
        let input_header =
            format!("GET /vicanso/pingap?size=1 HTTP/1.1\r\n{headers}\r\n\r\n");
        let mock_io = Builder::new().read(input_header.as_bytes()).build();

        let mut session = Session::new_h1(Box::new(mock_io));
        session.read_request().await.unwrap();
        let up = Upstream::new(
            "upstreamname",
            &UpstreamConf {
                addrs: vec!["192.168.1.1:8001".to_string()],
                ..Default::default()
            },
            None,
        )
        .unwrap();
        assert_eq!(true, up.new_http_peer(&session, &None,).is_some());
        assert_eq!(true, up.as_round_robin().is_some());
    }
    // Connect/disconnect callbacks should adjust the counter by one.
    #[test]
    fn test_upstream_peer_tracer() {
        let tracer = UpstreamPeerTracer::new("upstreamname");
        tracer.on_connected();
        assert_eq!(1, tracer.connected.load(Ordering::Relaxed));
        tracer.on_disconnected();
        assert_eq!(0, tracer.connected.load(Ordering::Relaxed));
    }
}
1133}