1use super::regex::RegexCapture;
16use ahash::AHashMap;
17use http::HeaderName;
18use http::HeaderValue;
19use pingap_config::Hashable;
20use pingap_config::LocationConf;
21use pingap_core::LocationInstance;
22use pingap_core::new_internal_error;
23use pingap_core::{HttpHeader, convert_headers};
24use pingora::http::RequestHeader;
25use regex::Regex;
26use snafu::{ResultExt, Snafu};
27use std::sync::Arc;
28use std::sync::LazyLock;
29use std::sync::atomic::{AtomicI32, AtomicU64, Ordering};
30use std::time::Duration;
31use tracing::{debug, error};
32
/// Value of the `category` field on every tracing event emitted by this module.
const LOG_CATEGORY: &str = "location";
34
35pub type Locations = AHashMap<String, Arc<Location>>;
36
37#[derive(Debug, Snafu)]
39pub enum Error {
40 #[snafu(display("Invalid error {message}"))]
41 Invalid { message: String },
42 #[snafu(display("Regex value: {value}, {source}"))]
43 Regex { value: String, source: regex::Error },
44 #[snafu(display("Too Many Requests, max:{max}"))]
45 TooManyRequest { max: i32 },
46 #[snafu(display("Request Entity Too Large, max:{max}"))]
47 BodyTooLarge { max: usize },
48}
49type Result<T, E = Error> = std::result::Result<T, E>;
50
/// Point-in-time snapshot of a location's request counters.
///
/// This is a plain value type, so it derives the usual traits. Callers can
/// log it (`Debug`), copy it cheaply, and compare snapshots.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct LocationStats {
    /// Number of requests currently in flight at the moment of the snapshot.
    pub processing: i32,
    /// Total number of requests counted by the location so far.
    pub accepted: u64,
}
55
56#[derive(Debug)]
62enum PathSelector {
63 Regex(RegexCapture),
64 Prefix(String),
65 Equal(String),
66 Any,
67}
68impl PathSelector {
69 fn new(path: &str) -> Result<Self> {
83 let path = path.trim();
84 if path.is_empty() {
85 return Ok(PathSelector::Any);
86 }
87
88 if let Some(re_path) = path.strip_prefix('~') {
89 let re = RegexCapture::new(re_path.trim()).context(RegexSnafu {
90 value: re_path.trim(),
91 })?;
92 Ok(PathSelector::Regex(re))
93 } else if let Some(eq_path) = path.strip_prefix('=') {
94 Ok(PathSelector::Equal(eq_path.trim().to_string()))
95 } else {
96 Ok(PathSelector::Prefix(path.to_string()))
97 }
98 }
99 #[inline]
100 fn is_match(&self, path: &str) -> (bool, Option<AHashMap<String, String>>) {
101 match self {
102 PathSelector::Equal(value) => (value == path, None),
104 PathSelector::Regex(value) => value.captures(path),
106 PathSelector::Prefix(value) => (path.starts_with(value), None),
108 PathSelector::Any => (true, None),
110 }
111 }
112}
113
114#[derive(Debug)]
118enum HostSelector {
119 Regex(RegexCapture),
120 Equal(String),
121}
122impl HostSelector {
123 fn new(host: &str) -> Result<Self> {
136 let host = host.trim();
137 if let Some(re_host) = host.strip_prefix('~') {
138 let re = RegexCapture::new(re_host.trim()).context(RegexSnafu {
139 value: re_host.trim(),
140 })?;
141 Ok(HostSelector::Regex(re))
142 } else {
143 Ok(HostSelector::Equal(host.to_string()))
144 }
145 }
146 #[inline]
147 fn is_match(&self, host: &str) -> (bool, Option<AHashMap<String, String>>) {
148 match self {
149 HostSelector::Equal(value) => (value == host, None),
150 HostSelector::Regex(value) => value.captures(host),
151 }
152 }
153}
154
155static DEFAULT_PROXY_SET_HEADERS: LazyLock<Vec<HttpHeader>> =
161 LazyLock::new(|| {
162 convert_headers(&[
163 "x-real-ip:$remote_addr".to_string(),
164 "x-forwarded-for:$proxy_add_x_forwarded_for".to_string(),
165 "x-forwarded-proto:$scheme".to_string(),
166 "x-forwarded-host:$host".to_string(),
167 "x-forwarded-port:$server_port".to_string(),
168 ])
169 .expect("Failed to convert default proxy set headers")
170 });
171
/// A location built from a `LocationConf`.
///
/// It carries host/path matching, an optional path rewrite, proxy headers,
/// and per-location request counters. Field order is kept stable because the
/// derived `Debug` output is logged by `Location::new`.
#[derive(Debug)]
pub struct Location {
    /// Location name; `Location::new` rejects an empty name.
    pub name: Arc<str>,

    /// `conf.hash_key()` of the configuration this location was built from.
    /// NOTE(review): presumably used to detect configuration changes — confirm.
    pub key: String,

    /// Upstream name from the config; empty when none was configured.
    upstream: String,

    /// Raw `path` setting as configured (empty when unset).
    path: String,

    /// Compiled matcher for `path`.
    path_selector: PathSelector,

    /// Matchers for the comma-separated `host` setting; empty means any host.
    hosts: Vec<HostSelector>,

    /// Compiled rewrite pattern and its replacement template, if configured.
    reg_rewrite: Option<(Regex, String)>,

    /// Proxy headers as `(name, value, flag)`.
    /// `flag` is `true` only for entries from `proxy_add_headers`.
    /// `None` when no headers are configured.
    pub headers: Option<Vec<(HeaderName, HeaderValue, bool)>>,

    /// Plugin names copied from the config.
    pub plugins: Option<Vec<String>>,

    /// Total requests counted by `on_request`.
    accepted: AtomicU64,

    /// In-flight requests: incremented by `on_request`, decremented by
    /// `on_response`.
    processing: AtomicI32,

    /// Maximum in-flight requests; 0 disables the limit.
    max_processing: i32,

    /// Whether gRPC-Web support is enabled for this location.
    grpc_web: bool,

    /// Maximum accepted `Content-Length` in bytes; 0 disables the check.
    client_max_body_size: usize,

    /// Retry count copied from the config (semantics defined by the consumer).
    pub max_retries: Option<u8>,

    /// Retry time window copied from the config (semantics defined by the
    /// consumer).
    pub max_retry_window: Option<Duration>,
}
245
246fn format_headers(
254 values: &Option<Vec<String>>,
255) -> Result<Option<Vec<HttpHeader>>> {
256 if let Some(header_values) = values {
257 let arr =
258 convert_headers(header_values).map_err(|err| Error::Invalid {
259 message: err.to_string(),
260 })?;
261 Ok(Some(arr))
262 } else {
263 Ok(None)
264 }
265}
266
267fn get_content_length(header: &RequestHeader) -> Option<usize> {
269 if let Some(content_length) =
270 header.headers.get(http::header::CONTENT_LENGTH)
271 {
272 if let Ok(size) =
273 content_length.to_str().unwrap_or_default().parse::<usize>()
274 {
275 return Some(size);
276 }
277 }
278 None
279}
280
281impl Location {
282 pub fn new(name: &str, conf: &LocationConf) -> Result<Location> {
285 if name.is_empty() {
286 return Err(Error::Invalid {
287 message: "Name is required".to_string(),
288 });
289 }
290 let key = conf.hash_key();
291 let upstream = conf.upstream.clone().unwrap_or_default();
292 let mut reg_rewrite = None;
293 if let Some(value) = &conf.rewrite {
295 let mut arr: Vec<&str> = value.split(' ').collect();
296 if arr.len() == 1 && arr[0].contains("$") {
297 arr.push(arr[0]);
298 arr[0] = ".*";
299 }
300
301 let value = if arr.len() == 2 { arr[1] } else { "" };
302 if let Ok(re) = Regex::new(arr[0]) {
303 reg_rewrite = Some((re, value.to_string()));
304 }
305 }
306
307 let hosts = conf
308 .host
309 .as_deref()
310 .unwrap_or("")
311 .split(',')
312 .map(str::trim)
313 .filter(|s| !s.is_empty())
314 .map(HostSelector::new)
315 .collect::<Result<Vec<_>>>()?;
316
317 let path = conf.path.clone().unwrap_or_default();
318 let mut headers: Vec<(HeaderName, HeaderValue, bool)> = vec![];
319 if conf.enable_reverse_proxy_headers.unwrap_or_default() {
320 for (name, value) in DEFAULT_PROXY_SET_HEADERS.iter() {
321 headers.push((name.clone(), value.clone(), false));
322 }
323 }
324 if let Some(proxy_set_headers) =
325 format_headers(&conf.proxy_set_headers)?
326 {
327 for (name, value) in proxy_set_headers.iter() {
328 headers.push((name.clone(), value.clone(), false));
329 }
330 }
331 if let Some(proxy_add_headers) =
332 format_headers(&conf.proxy_add_headers)?
333 {
334 for (name, value) in proxy_add_headers.iter() {
335 headers.push((name.clone(), value.clone(), true));
336 }
337 }
338
339 let location = Location {
340 name: name.into(),
341 key,
342 path_selector: PathSelector::new(&path)?,
343 path,
344 hosts,
345 upstream,
346 reg_rewrite,
347 plugins: conf.plugins.clone(),
348 accepted: AtomicU64::new(0),
349 processing: AtomicI32::new(0),
350 max_processing: conf.max_processing.unwrap_or_default(),
351 grpc_web: conf.grpc_web.unwrap_or_default(),
352 headers: if headers.is_empty() {
353 None
354 } else {
355 Some(headers)
356 },
357 client_max_body_size: conf
360 .client_max_body_size
361 .unwrap_or_default()
362 .as_u64() as usize,
363 max_retries: conf.max_retries,
367 max_retry_window: conf.max_retry_window,
368 };
369 debug!(
370 category = LOG_CATEGORY,
371 location = format!("{location:?}"),
372 "create a new location"
373 );
374
375 Ok(location)
376 }
377
378 #[inline]
381 pub fn support_grpc_web(&self) -> bool {
382 self.grpc_web
383 }
384
385 #[inline]
397 pub fn validate_content_length(
398 &self,
399 header: &RequestHeader,
400 ) -> Result<()> {
401 if self.client_max_body_size == 0 {
402 return Ok(());
403 }
404 if get_content_length(header).unwrap_or_default()
405 > self.client_max_body_size
406 {
407 return Err(Error::BodyTooLarge {
408 max: self.client_max_body_size,
409 });
410 }
411
412 Ok(())
413 }
414
415 #[inline]
420 pub fn match_host_path(
421 &self,
422 host: &str,
423 path: &str,
424 ) -> (bool, Option<AHashMap<String, String>>) {
425 let mut capture_values = None;
430 if !self.path.is_empty() {
431 let (matched, captures) = self.path_selector.is_match(path);
432 if !matched {
433 return (false, None);
434 }
435 capture_values = captures;
436 }
437
438 if self.hosts.is_empty() {
440 return (true, capture_values);
441 }
442
443 let matched = self.hosts.iter().any(|host_selector| {
444 let (matched, captures) = host_selector.is_match(host);
445 if let Some(captures) = captures {
446 if let Some(values) = capture_values.as_mut() {
447 values.extend(captures);
448 } else {
449 capture_values = Some(captures);
450 }
451 }
452 matched
453 });
454
455 (matched, capture_values)
456 }
457
458 pub fn stats(&self) -> LocationStats {
459 LocationStats {
460 processing: self.processing.load(Ordering::Relaxed),
461 accepted: self.accepted.load(Ordering::Relaxed),
462 }
463 }
464}
465
466impl LocationInstance for Location {
467 fn name(&self) -> &str {
468 self.name.as_ref()
469 }
470 fn headers(&self) -> Option<&Vec<(HeaderName, HeaderValue, bool)>> {
471 self.headers.as_ref()
472 }
473 fn client_body_size_limit(&self) -> usize {
474 self.client_max_body_size
475 }
476 fn upstream(&self) -> &str {
477 self.upstream.as_ref()
478 }
479 fn on_response(&self) {
480 self.processing.fetch_sub(1, Ordering::Relaxed);
481 }
482 fn on_request(&self) -> pingora::Result<(u64, i32)> {
498 let accepted = self.accepted.fetch_add(1, Ordering::Relaxed) + 1;
499 let processing = self.processing.fetch_add(1, Ordering::Relaxed) + 1;
500 if self.max_processing != 0 && processing > self.max_processing {
501 let err = Error::TooManyRequest {
502 max: self.max_processing,
503 };
504 return Err(new_internal_error(429, err));
505 }
506 Ok((accepted, processing))
507 }
508 #[inline]
533 fn rewrite(
534 &self,
535 header: &mut RequestHeader,
536 mut variables: Option<AHashMap<String, String>>,
537 ) -> (bool, Option<AHashMap<String, String>>) {
538 let Some((re, value)) = &self.reg_rewrite else {
539 return (false, variables);
540 };
541
542 let mut replace_value = value.to_string();
543
544 if let Some(vars) = &variables {
545 for (k, v) in vars.iter() {
546 replace_value = replace_value.replace(k, v);
547 }
548 }
549
550 let path = header.uri.path();
551
552 let mut new_path = if re.to_string() == ".*" {
553 replace_value
554 } else {
555 re.replace(path, replace_value).to_string()
556 };
557
558 if path == new_path {
559 return (false, variables);
560 }
561
562 if let Some(captures) = re.captures(path) {
563 for name in re.capture_names().flatten() {
564 if let Some(match_value) = captures.name(name) {
565 let values = variables.get_or_insert_with(AHashMap::new);
566 values.insert(
567 name.to_string(),
568 match_value.as_str().to_string(),
569 );
570 }
571 }
572 }
573
574 if let Some(query) = header.uri.query() {
576 new_path = format!("{new_path}?{query}");
577 }
578 debug!(category = LOG_CATEGORY, new_path, "rewrite path");
579
580 if let Err(e) =
582 new_path.parse::<http::Uri>().map(|uri| header.set_uri(uri))
583 {
584 error!(category = LOG_CATEGORY, error = %e, location = self.name.as_ref(), "new path parse fail");
585 }
586
587 (true, variables)
588 }
589}
590
#[cfg(test)]
mod tests {
    use super::*;
    use bytesize::ByteSize;
    use pingap_config::LocationConf;
    use pingora::http::RequestHeader;
    use pingora::proxy::Session;
    use pretty_assertions::assert_eq;
    use tokio_test::io::Builder;

    /// `format_headers` parses "Name: value" strings; names come back
    /// lowercased.
    #[test]
    fn test_format_headers() {
        let headers = format_headers(&Some(vec![
            "Content-Type: application/json".to_string(),
        ]))
        .unwrap();
        assert_eq!(
            r###"Some([("content-type", "application/json")])"###,
            format!("{headers:?}")
        );
    }

    /// Each path syntax (blank, `~`, `=`, plain) maps to its selector variant.
    #[test]
    fn test_new_path_selector() {
        let selector = PathSelector::new("").unwrap();
        assert_eq!(true, matches!(selector, PathSelector::Any));

        let selector = PathSelector::new("~/api").unwrap();
        assert_eq!(true, matches!(selector, PathSelector::Regex(_)));

        let selector = PathSelector::new("=/api").unwrap();
        assert_eq!(true, matches!(selector, PathSelector::Equal(_)));

        let selector = PathSelector::new("/api").unwrap();
        assert_eq!(true, matches!(selector, PathSelector::Prefix(_)));
    }

    /// Host and path matching across the supported selector forms.
    #[test]
    fn test_path_host_select_location() {
        let upstream_name = "charts";

        // No host and no path: everything matches.
        let lo = Location::new(
            "lo",
            &LocationConf {
                upstream: Some(upstream_name.to_string()),
                ..Default::default()
            },
        )
        .unwrap();
        assert_eq!(true, lo.match_host_path("pingap", "/api").0);
        assert_eq!(true, lo.match_host_path("", "").0);

        // Comma-separated exact hosts.
        let lo = Location::new(
            "lo",
            &LocationConf {
                upstream: Some(upstream_name.to_string()),
                host: Some("test.com,pingap".to_string()),
                ..Default::default()
            },
        )
        .unwrap();
        assert_eq!(true, lo.match_host_path("pingap", "/api").0);
        assert_eq!(true, lo.match_host_path("pingap", "").0);
        assert_eq!(false, lo.match_host_path("", "/api").0);

        // Unanchored regex path.
        let lo = Location::new(
            "lo",
            &LocationConf {
                upstream: Some(upstream_name.to_string()),
                path: Some("~/users".to_string()),
                ..Default::default()
            },
        )
        .unwrap();
        assert_eq!(true, lo.match_host_path("", "/api/users").0);
        assert_eq!(true, lo.match_host_path("", "/users").0);
        assert_eq!(false, lo.match_host_path("", "/api").0);

        // Anchored regex path.
        let lo = Location::new(
            "lo",
            &LocationConf {
                upstream: Some(upstream_name.to_string()),
                path: Some("~^/api".to_string()),
                ..Default::default()
            },
        )
        .unwrap();
        assert_eq!(true, lo.match_host_path("", "/api/users").0);
        assert_eq!(false, lo.match_host_path("", "/users").0);
        assert_eq!(true, lo.match_host_path("", "/api").0);

        // Prefix path.
        let lo = Location::new(
            "lo",
            &LocationConf {
                upstream: Some(upstream_name.to_string()),
                path: Some("/api".to_string()),
                ..Default::default()
            },
        )
        .unwrap();
        assert_eq!(true, lo.match_host_path("", "/api/users").0);
        assert_eq!(false, lo.match_host_path("", "/users").0);
        assert_eq!(true, lo.match_host_path("", "/api").0);

        // Exact path.
        let lo = Location::new(
            "lo",
            &LocationConf {
                upstream: Some(upstream_name.to_string()),
                path: Some("=/api".to_string()),
                ..Default::default()
            },
        )
        .unwrap();
        assert_eq!(false, lo.match_host_path("", "/api/users").0);
        assert_eq!(false, lo.match_host_path("", "/users").0);
        assert_eq!(true, lo.match_host_path("", "/api").0);
    }

    /// Named captures from both the host and the path regex are returned.
    #[test]
    fn test_match_host_path_variables() {
        let lo = Location::new(
            "lo",
            &LocationConf {
                upstream: Some("charts".to_string()),
                host: Some("~(?<name>.+).npmtrend.com".to_string()),
                path: Some("~/(?<route>.+)/(.*)".to_string()),
                ..Default::default()
            },
        )
        .unwrap();
        let (matched, variables) =
            lo.match_host_path("charts.npmtrend.com", "/users/123");
        assert_eq!(true, matched);
        let variables = variables.unwrap();
        assert_eq!("users", variables.get("route").unwrap());
        assert_eq!("charts", variables.get("name").unwrap());
    }

    /// Rewrite keeps the query string and exposes named groups. A path the
    /// pattern does not match is left untouched.
    #[test]
    fn test_rewrite_path() {
        let upstream_name = "charts";

        let lo = Location::new(
            "lo",
            &LocationConf {
                upstream: Some(upstream_name.to_string()),
                rewrite: Some("^/users/(?<upstream>.*?)/(.*)$ /$2".to_string()),
                ..Default::default()
            },
        )
        .unwrap();
        let mut req_header =
            RequestHeader::build("GET", b"/users/rest/me?abc=1", None).unwrap();
        let (matched, variables) = lo.rewrite(&mut req_header, None);
        assert_eq!(true, matched);
        assert_eq!(r#"Some({"upstream": "rest"})"#, format!("{:?}", variables));
        assert_eq!("/me?abc=1", req_header.uri.to_string());

        let mut req_header =
            RequestHeader::build("GET", b"/api/me?abc=1", None).unwrap();
        let (matched, variables) = lo.rewrite(&mut req_header, None);
        assert_eq!(false, matched);
        assert_eq!(None, variables);
        assert_eq!("/api/me?abc=1", req_header.uri.to_string());
    }

    /// `Content-Length` is read from a parsed HTTP/1.1 request.
    #[tokio::test]
    async fn test_get_content_length() {
        let headers = ["Content-Length: 123"].join("\r\n");
        let input_header =
            format!("GET /vicanso/pingap?size=1 HTTP/1.1\r\n{headers}\r\n\r\n");
        let mock_io = Builder::new().read(input_header.as_bytes()).build();
        let mut session = Session::new_h1(Box::new(mock_io));
        session.read_request().await.unwrap();
        assert_eq!(get_content_length(session.req_header()), Some(123));
    }

    /// A missing length passes; a declared length over the limit is rejected.
    #[test]
    fn test_validate_content_length() {
        let lo = Location::new(
            "lo",
            &LocationConf {
                client_max_body_size: Some(ByteSize(10)),
                ..Default::default()
            },
        )
        .unwrap();
        let mut req_header =
            RequestHeader::build("GET", b"/users/me?abc=1", None).unwrap();
        assert_eq!(true, lo.validate_content_length(&req_header).is_ok());

        req_header
            .append_header(
                http::header::CONTENT_LENGTH,
                http::HeaderValue::from_str("20").unwrap(),
            )
            .unwrap();
        assert_eq!(
            "Request Entity Too Large, max:10",
            lo.validate_content_length(&req_header)
                .err()
                .unwrap()
                .to_string()
        );
    }
}