1use super::regex::RegexCapture;
16use ahash::AHashMap;
17use http::HeaderName;
18use http::HeaderValue;
19use pingap_config::Hashable;
20use pingap_config::LocationConf;
21use pingap_core::LocationInstance;
22use pingap_core::new_internal_error;
23use pingap_core::{HttpHeader, convert_headers};
24use pingora::http::RequestHeader;
25use regex::Regex;
26use snafu::{ResultExt, Snafu};
27use std::sync::Arc;
28use std::sync::LazyLock;
29use std::sync::atomic::{AtomicI32, AtomicU64, Ordering};
30use std::time::Duration;
31use tracing::{debug, error};
32
/// Value of the `category` field attached to the log records emitted by this
/// module.
const LOG_CATEGORY: &str = "location";

/// Shared [`Location`] instances indexed by a string key
/// (presumably the location name — confirm against the callers).
pub type Locations = AHashMap<String, Arc<Location>>;
36
/// Errors produced while building a [`Location`] or enforcing its limits.
#[derive(Debug, Snafu)]
pub enum Error {
    /// A configuration value is invalid; `message` carries the details.
    #[snafu(display("Invalid error {message}"))]
    Invalid { message: String },
    /// The pattern `value` could not be compiled as a regular expression.
    #[snafu(display("Regex value: {value}, {source}"))]
    Regex { value: String, source: regex::Error },
    /// The number of requests being processed exceeded the configured `max`.
    #[snafu(display("Too Many Requests, max:{max}"))]
    TooManyRequest { max: i32 },
    /// The declared request body size exceeded the configured `max`.
    #[snafu(display("Request Entity Too Large, max:{max}"))]
    BodyTooLarge { max: usize },
}
/// Result alias whose error type defaults to this module's [`Error`].
type Result<T, E = Error> = std::result::Result<T, E>;
50
/// Snapshot of the request counters of a location.
///
/// The type is plain data, so it derives the usual value traits: callers can
/// log it with `{:?}`, copy it and compare two snapshots.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LocationStats {
    /// Number of requests being processed when the snapshot was taken.
    pub processing: i32,
    /// Total number of requests accepted when the snapshot was taken.
    pub accepted: u64,
}
55
/// Strategy used to decide whether a request path belongs to a location.
#[derive(Debug)]
enum PathSelector {
    /// Regular expression match, built from a path rule starting with `~`.
    Regex(RegexCapture),
    /// Prefix match, built from a path rule without any marker.
    Prefix(String),
    /// Exact match, built from a path rule starting with `=`.
    Equal(String),
    /// Matches every path, built from an empty path rule.
    Any,
}
68impl PathSelector {
69 fn new(path: &str) -> Result<Self> {
83 let path = path.trim();
84 if path.is_empty() {
85 return Ok(PathSelector::Any);
86 }
87
88 if let Some(re_path) = path.strip_prefix('~') {
89 let re = RegexCapture::new(re_path.trim()).context(RegexSnafu {
90 value: re_path.trim(),
91 })?;
92 Ok(PathSelector::Regex(re))
93 } else if let Some(eq_path) = path.strip_prefix('=') {
94 Ok(PathSelector::Equal(eq_path.trim().to_string()))
95 } else {
96 Ok(PathSelector::Prefix(path.to_string()))
97 }
98 }
99 #[inline]
100 fn is_match(&self, path: &str) -> (bool, Option<AHashMap<String, String>>) {
101 match self {
102 PathSelector::Equal(value) => (value == path, None),
104 PathSelector::Regex(value) => value.captures(path),
106 PathSelector::Prefix(value) => (path.starts_with(value), None),
108 PathSelector::Any => (true, None),
110 }
111 }
112}
113
/// Strategy used to decide whether a request host belongs to a location.
#[derive(Debug)]
enum HostSelector {
    /// Regular expression match, built from a host rule starting with `~`.
    Regex(RegexCapture),
    /// Exact match against the stored host.
    Equal(String),
}
122impl HostSelector {
123 fn new(host: &str) -> Result<Self> {
136 let host = host.trim();
137 if let Some(re_host) = host.strip_prefix('~') {
138 let re = RegexCapture::new(re_host.trim()).context(RegexSnafu {
139 value: re_host.trim(),
140 })?;
141 Ok(HostSelector::Regex(re))
142 } else {
143 Ok(HostSelector::Equal(host.to_string()))
144 }
145 }
146 #[inline]
147 fn is_match(&self, host: &str) -> (bool, Option<AHashMap<String, String>>) {
148 match self {
149 HostSelector::Equal(value) => (value == host, None),
150 HostSelector::Regex(value) => value.captures(host),
151 }
152 }
153}
154
155static DEFAULT_PROXY_SET_HEADERS: LazyLock<Vec<HttpHeader>> =
161 LazyLock::new(|| {
162 convert_headers(&[
163 "x-real-ip:$remote_addr".to_string(),
164 "x-forwarded-for:$proxy_add_x_forwarded_for".to_string(),
165 "x-forwarded-proto:$scheme".to_string(),
166 "x-forwarded-host:$host".to_string(),
167 "x-forwarded-port:$server_port".to_string(),
168 ])
169 .expect("Failed to convert default proxy set headers")
170 });
171
/// A routing rule: host and path matching, optional path rewriting, headers
/// for the proxied request and request limits.
#[derive(Debug)]
pub struct Location {
    /// Name given to `Location::new`.
    pub name: Arc<str>,

    /// Hash key computed from the configuration (`conf.hash_key()`).
    pub key: String,

    /// Upstream name from the configuration; empty when none is configured.
    upstream: String,

    /// Raw path rule from the configuration; empty when none is configured.
    path: String,

    /// Selector parsed from `path`.
    path_selector: PathSelector,

    /// Selectors parsed from the comma separated host rule. When the list is
    /// empty, `match_host_path` accepts any host.
    hosts: Vec<HostSelector>,

    /// Rewrite rule: the pattern and the replacement applied by `rewrite`.
    reg_rewrite: Option<(Regex, String)>,

    /// Headers as `(name, value, flag)`. The flag is `false` for the default
    /// reverse proxy headers and `proxy_set_headers`, and `true` for
    /// `proxy_add_headers`. `None` when no header is configured.
    pub headers: Option<Vec<(HeaderName, HeaderValue, bool)>>,

    /// Plugin names copied from the configuration.
    pub plugins: Option<Vec<String>>,

    /// Total number of requests counted by `on_request`.
    accepted: AtomicU64,

    /// Requests in flight: incremented by `on_request`, decremented by
    /// `on_response`.
    processing: AtomicI32,

    /// Limit checked against `processing` in `on_request`; 0 disables it.
    max_processing: i32,

    /// Value returned by `support_grpc_web`.
    grpc_web: bool,

    /// Largest accepted `Content-Length` in bytes; 0 disables the check.
    client_max_body_size: usize,

    /// Copied from the configuration; not interpreted in this file.
    pub max_retries: Option<u8>,

    /// Copied from the configuration; not interpreted in this file.
    pub max_retry_window: Option<Duration>,
}
245
246fn format_headers(
254 values: &Option<Vec<String>>,
255) -> Result<Option<Vec<HttpHeader>>> {
256 if let Some(header_values) = values {
257 let arr =
258 convert_headers(header_values).map_err(|err| Error::Invalid {
259 message: err.to_string(),
260 })?;
261 Ok(Some(arr))
262 } else {
263 Ok(None)
264 }
265}
266
267fn get_content_length(header: &RequestHeader) -> Option<usize> {
269 if let Some(content_length) =
270 header.headers.get(http::header::CONTENT_LENGTH)
271 && let Ok(size) =
272 content_length.to_str().unwrap_or_default().parse::<usize>()
273 {
274 return Some(size);
275 }
276 None
277}
278
279impl Location {
280 pub fn new(name: &str, conf: &LocationConf) -> Result<Location> {
283 if name.is_empty() {
284 return Err(Error::Invalid {
285 message: "Name is required".to_string(),
286 });
287 }
288 let key = conf.hash_key();
289 let upstream = conf.upstream.clone().unwrap_or_default();
290 let mut reg_rewrite = None;
291 if let Some(value) = &conf.rewrite {
293 let mut arr: Vec<&str> = value.split(' ').collect();
294 if arr.len() == 1 && arr[0].contains("$") {
295 arr.push(arr[0]);
296 arr[0] = ".*";
297 }
298
299 let value = if arr.len() == 2 { arr[1] } else { "" };
300 if let Ok(re) = Regex::new(arr[0]) {
301 reg_rewrite = Some((re, value.to_string()));
302 }
303 }
304
305 let hosts = conf
306 .host
307 .as_deref()
308 .unwrap_or("")
309 .split(',')
310 .map(str::trim)
311 .filter(|s| !s.is_empty())
312 .map(HostSelector::new)
313 .collect::<Result<Vec<_>>>()?;
314
315 let path = conf.path.clone().unwrap_or_default();
316 let mut headers: Vec<(HeaderName, HeaderValue, bool)> = vec![];
317 if conf.enable_reverse_proxy_headers.unwrap_or_default() {
318 for (name, value) in DEFAULT_PROXY_SET_HEADERS.iter() {
319 headers.push((name.clone(), value.clone(), false));
320 }
321 }
322 if let Some(proxy_set_headers) =
323 format_headers(&conf.proxy_set_headers)?
324 {
325 for (name, value) in proxy_set_headers.iter() {
326 headers.push((name.clone(), value.clone(), false));
327 }
328 }
329 if let Some(proxy_add_headers) =
330 format_headers(&conf.proxy_add_headers)?
331 {
332 for (name, value) in proxy_add_headers.iter() {
333 headers.push((name.clone(), value.clone(), true));
334 }
335 }
336
337 let location = Location {
338 name: name.into(),
339 key,
340 path_selector: PathSelector::new(&path)?,
341 path,
342 hosts,
343 upstream,
344 reg_rewrite,
345 plugins: conf.plugins.clone(),
346 accepted: AtomicU64::new(0),
347 processing: AtomicI32::new(0),
348 max_processing: conf.max_processing.unwrap_or_default(),
349 grpc_web: conf.grpc_web.unwrap_or_default(),
350 headers: if headers.is_empty() {
351 None
352 } else {
353 Some(headers)
354 },
355 client_max_body_size: conf
358 .client_max_body_size
359 .unwrap_or_default()
360 .as_u64() as usize,
361 max_retries: conf.max_retries,
365 max_retry_window: conf.max_retry_window,
366 };
367 debug!(
368 category = LOG_CATEGORY,
369 location = format!("{location:?}"),
370 "create a new location"
371 );
372
373 Ok(location)
374 }
375
376 #[inline]
379 pub fn support_grpc_web(&self) -> bool {
380 self.grpc_web
381 }
382
383 #[inline]
395 pub fn validate_content_length(
396 &self,
397 header: &RequestHeader,
398 ) -> Result<()> {
399 if self.client_max_body_size == 0 {
400 return Ok(());
401 }
402 if get_content_length(header).unwrap_or_default()
403 > self.client_max_body_size
404 {
405 return Err(Error::BodyTooLarge {
406 max: self.client_max_body_size,
407 });
408 }
409
410 Ok(())
411 }
412
413 #[inline]
418 pub fn match_host_path(
419 &self,
420 host: &str,
421 path: &str,
422 ) -> (bool, Option<AHashMap<String, String>>) {
423 let mut capture_values = None;
428 if !self.path.is_empty() {
429 let (matched, captures) = self.path_selector.is_match(path);
430 if !matched {
431 return (false, None);
432 }
433 capture_values = captures;
434 }
435
436 if self.hosts.is_empty() {
438 return (true, capture_values);
439 }
440
441 let matched = self.hosts.iter().any(|host_selector| {
442 let (matched, captures) = host_selector.is_match(host);
443 if let Some(captures) = captures {
444 if let Some(values) = capture_values.as_mut() {
445 values.extend(captures);
446 } else {
447 capture_values = Some(captures);
448 }
449 }
450 matched
451 });
452
453 (matched, capture_values)
454 }
455
456 pub fn stats(&self) -> LocationStats {
457 LocationStats {
458 processing: self.processing.load(Ordering::Relaxed),
459 accepted: self.accepted.load(Ordering::Relaxed),
460 }
461 }
462}
463
464impl LocationInstance for Location {
465 fn name(&self) -> &str {
466 self.name.as_ref()
467 }
468 fn headers(&self) -> Option<&Vec<(HeaderName, HeaderValue, bool)>> {
469 self.headers.as_ref()
470 }
471 fn client_body_size_limit(&self) -> usize {
472 self.client_max_body_size
473 }
474 fn upstream(&self) -> &str {
475 self.upstream.as_ref()
476 }
477 fn on_response(&self) {
478 self.processing.fetch_sub(1, Ordering::Relaxed);
479 }
480 fn on_request(&self) -> pingora::Result<(u64, i32)> {
496 let accepted = self.accepted.fetch_add(1, Ordering::Relaxed) + 1;
497 let processing = self.processing.fetch_add(1, Ordering::Relaxed) + 1;
498 if self.max_processing != 0 && processing > self.max_processing {
499 let err = Error::TooManyRequest {
500 max: self.max_processing,
501 };
502 return Err(new_internal_error(429, err));
503 }
504 Ok((accepted, processing))
505 }
506 #[inline]
531 fn rewrite(
532 &self,
533 header: &mut RequestHeader,
534 mut variables: Option<AHashMap<String, String>>,
535 ) -> (bool, Option<AHashMap<String, String>>) {
536 let Some((re, value)) = &self.reg_rewrite else {
537 return (false, variables);
538 };
539
540 let mut replace_value = value.to_string();
541
542 if let Some(vars) = &variables {
543 for (k, v) in vars.iter() {
544 replace_value = replace_value.replace(k, v);
545 }
546 }
547
548 let path = header.uri.path();
549
550 let mut new_path = if re.to_string() == ".*" {
551 replace_value
552 } else {
553 re.replace(path, replace_value).to_string()
554 };
555
556 if path == new_path {
557 return (false, variables);
558 }
559
560 if let Some(captures) = re.captures(path) {
561 for name in re.capture_names().flatten() {
562 if let Some(match_value) = captures.name(name) {
563 let values = variables.get_or_insert_with(AHashMap::new);
564 values.insert(
565 name.to_string(),
566 match_value.as_str().to_string(),
567 );
568 }
569 }
570 }
571
572 if let Some(query) = header.uri.query() {
574 new_path = format!("{new_path}?{query}");
575 }
576 debug!(category = LOG_CATEGORY, new_path, "rewrite path");
577
578 if let Err(e) =
580 new_path.parse::<http::Uri>().map(|uri| header.set_uri(uri))
581 {
582 error!(category = LOG_CATEGORY, error = %e, location = self.name.as_ref(), "new path parse fail");
583 }
584
585 (true, variables)
586 }
587}
588
#[cfg(test)]
mod tests {
    use super::*;
    use bytesize::ByteSize;
    use pingap_config::LocationConf;
    use pingora::http::RequestHeader;
    use pingora::proxy::Session;
    use pretty_assertions::assert_eq;
    use tokio_test::io::Builder;

    /// Builds the location "lo" for the upstream "charts", setting the host
    /// and path rules only when they are provided.
    fn new_location(host: Option<&str>, path: Option<&str>) -> Location {
        let mut conf = LocationConf {
            upstream: Some("charts".to_string()),
            ..Default::default()
        };
        if let Some(host) = host {
            conf.host = Some(host.to_string());
        }
        if let Some(path) = path {
            conf.path = Some(path.to_string());
        }
        Location::new("lo", &conf).unwrap()
    }

    /// Checks `match_host_path` for every `(host, path, expected)` case.
    fn assert_cases(lo: &Location, cases: &[(&str, &str, bool)]) {
        for &(host, path, expected) in cases {
            assert_eq!(
                expected,
                lo.match_host_path(host, path).0,
                "host: {host:?}, path: {path:?}"
            );
        }
    }

    #[test]
    fn test_format_headers() {
        let input = Some(vec![String::from("Content-Type: application/json")]);
        let parsed = format_headers(&input).unwrap();
        assert_eq!(
            r#"Some([("content-type", "application/json")])"#,
            format!("{parsed:?}")
        );
    }

    #[test]
    fn test_new_path_selector() {
        assert!(matches!(PathSelector::new("").unwrap(), PathSelector::Any));
        assert!(matches!(
            PathSelector::new("~/api").unwrap(),
            PathSelector::Regex(_)
        ));
        assert!(matches!(
            PathSelector::new("=/api").unwrap(),
            PathSelector::Equal(_)
        ));
        assert!(matches!(
            PathSelector::new("/api").unwrap(),
            PathSelector::Prefix(_)
        ));
    }

    #[test]
    fn test_path_host_select_location() {
        // neither host nor path rule
        assert_cases(
            &new_location(None, None),
            &[("pingap", "/api", true), ("", "", true)],
        );

        // host list only
        assert_cases(
            &new_location(Some("test.com,pingap"), None),
            &[
                ("pingap", "/api", true),
                ("pingap", "", true),
                ("", "/api", false),
            ],
        );

        // regular expression path
        assert_cases(
            &new_location(None, Some("~/users")),
            &[
                ("", "/api/users", true),
                ("", "/users", true),
                ("", "/api", false),
            ],
        );

        // anchored regular expression path
        assert_cases(
            &new_location(None, Some("~^/api")),
            &[
                ("", "/api/users", true),
                ("", "/users", false),
                ("", "/api", true),
            ],
        );

        // prefix path
        assert_cases(
            &new_location(None, Some("/api")),
            &[
                ("", "/api/users", true),
                ("", "/users", false),
                ("", "/api", true),
            ],
        );

        // exact path
        assert_cases(
            &new_location(None, Some("=/api")),
            &[
                ("", "/api/users", false),
                ("", "/users", false),
                ("", "/api", true),
            ],
        );
    }

    #[test]
    fn test_match_host_path_variables() {
        let lo = new_location(
            Some("~(?<name>.+).npmtrend.com"),
            Some("~/(?<route>.+)/(.*)"),
        );
        let (matched, captured) =
            lo.match_host_path("charts.npmtrend.com", "/users/123");
        assert!(matched);
        let captured = captured.unwrap();
        assert_eq!("users", captured.get("route").unwrap());
        assert_eq!("charts", captured.get("name").unwrap());
    }

    #[test]
    fn test_rewrite_path() {
        let conf = LocationConf {
            upstream: Some("charts".to_string()),
            rewrite: Some("^/users/(?<upstream>.*?)/(.*)$ /$2".to_string()),
            ..Default::default()
        };
        let lo = Location::new("lo", &conf).unwrap();

        // the path matches the rule: rewritten, named capture exported
        let mut matching =
            RequestHeader::build("GET", b"/users/rest/me?abc=1", None).unwrap();
        let (rewritten, variables) = lo.rewrite(&mut matching, None);
        assert!(rewritten);
        assert_eq!(r#"Some({"upstream": "rest"})"#, format!("{:?}", variables));
        assert_eq!("/me?abc=1", matching.uri.to_string());

        // the path does not match the rule: left untouched
        let mut other =
            RequestHeader::build("GET", b"/api/me?abc=1", None).unwrap();
        let (rewritten, variables) = lo.rewrite(&mut other, None);
        assert!(!rewritten);
        assert_eq!(None, variables);
        assert_eq!("/api/me?abc=1", other.uri.to_string());
    }

    #[tokio::test]
    async fn test_get_content_length() {
        let header_lines = ["Content-Length: 123"].join("\r\n");
        let raw_request = format!(
            "GET /vicanso/pingap?size=1 HTTP/1.1\r\n{header_lines}\r\n\r\n"
        );
        let io = Builder::new().read(raw_request.as_bytes()).build();
        let mut session = Session::new_h1(Box::new(io));
        session.read_request().await.unwrap();
        assert_eq!(get_content_length(session.req_header()), Some(123));
    }

    #[test]
    fn test_validate_content_length() {
        let conf = LocationConf {
            client_max_body_size: Some(ByteSize(10)),
            ..Default::default()
        };
        let lo = Location::new("lo", &conf).unwrap();

        // no Content-Length header: accepted
        let mut req_header =
            RequestHeader::build("GET", b"/users/me?abc=1", None).unwrap();
        assert!(lo.validate_content_length(&req_header).is_ok());

        // Content-Length above the limit: rejected
        req_header
            .append_header(
                http::header::CONTENT_LENGTH,
                http::HeaderValue::from_str("20").unwrap(),
            )
            .unwrap();
        let err = lo.validate_content_length(&req_header).err().unwrap();
        assert_eq!("Request Entity Too Large, max:10", err.to_string());
    }
}