1use super::regex::RegexCapture;
16use ahash::AHashMap;
17use arc_swap::ArcSwap;
18use once_cell::sync::Lazy;
19use pingap_config::LocationConf;
20use pingap_core::{convert_headers, HttpHeader};
21use pingora::http::RequestHeader;
22use regex::Regex;
23use snafu::{ResultExt, Snafu};
24use std::collections::HashMap;
25use std::sync::atomic::{AtomicI32, AtomicU64, Ordering};
26use std::sync::Arc;
27use substring::Substring;
28use tracing::{debug, error};
29
/// Value of the `category` field attached to every log event emitted by this module.
const LOG_CATEGORY: &str = "location";
31
/// Errors raised while building a [`Location`] or enforcing its limits.
#[derive(Debug, Snafu)]
pub enum Error {
    /// A configuration value was rejected; `message` describes the problem
    /// (for example a missing name or a header that failed to convert).
    #[snafu(display("Invalid error {message}"))]
    Invalid { message: String },
    /// A host/path selector pattern could not be turned into a regex;
    /// `value` is the offending pattern.
    #[snafu(display("Regex value: {value}, {source}"))]
    Regex { value: String, source: regex::Error },
    /// The number of in-flight requests exceeded the configured `max`.
    #[snafu(display("Too Many Requests, max:{max}"))]
    TooManyRequest { max: i32 },
    /// The request body is larger than the configured `max` (in bytes).
    #[snafu(display("Request Entity Too Large, max:{max}"))]
    BodyTooLarge { max: usize },
}
/// Module-local result alias whose error type defaults to [`Error`].
type Result<T, E = Error> = std::result::Result<T, E>;
45
/// Path selector backed by a regex; built from a path config starting with `~`.
#[derive(Debug)]
struct RegexPath {
    /// Regex wrapper used to test the request path and collect captures.
    value: RegexCapture,
}
50
/// Path selector that matches when the request path starts with `value`.
#[derive(Debug)]
struct PrefixPath {
    /// The required path prefix (the trimmed path config, unchanged).
    value: String,
}
55
/// Path selector that matches only when the request path equals `value`;
/// built from a path config starting with `=`.
#[derive(Debug)]
struct EqualPath {
    /// The exact path to match (config value without the leading `=`).
    value: String,
}
60
/// The strategy used to match a request path against a location.
#[derive(Debug)]
enum PathSelector {
    /// Regex match (`~pattern`).
    RegexPath(RegexPath),
    /// Prefix match (any config without a `~` or `=` marker).
    PrefixPath(PrefixPath),
    /// Exact match (`=path`).
    EqualPath(EqualPath),
    /// No path configured; treated as matching every path.
    Empty,
}
73fn new_path_selector(path: &str) -> Result<PathSelector> {
87 let path = path.trim();
88 if path.is_empty() {
89 return Ok(PathSelector::Empty);
90 }
91 let first = path.chars().next().unwrap_or_default();
92 let last = path.substring(1, path.len()).trim();
93 let se = match first {
94 '~' => {
95 let re = RegexCapture::new(last).context(RegexSnafu {
96 value: last.to_string(),
97 })?;
98 PathSelector::RegexPath(RegexPath { value: re })
99 },
100 '=' => PathSelector::EqualPath(EqualPath {
101 value: last.to_string(),
102 }),
103 _ => {
104 PathSelector::PrefixPath(PrefixPath {
106 value: path.to_string(),
107 })
108 },
109 };
110
111 Ok(se)
112}
113
/// Host selector backed by a regex; built from a host config starting with `~`.
#[derive(Debug)]
struct RegexHost {
    /// Regex wrapper used to test the request host and collect captures.
    value: RegexCapture,
}
118
/// Host selector that matches when the request host equals `value`.
#[derive(Debug)]
struct EqualHost {
    /// The expected host; an empty value matches any host.
    value: String,
}
123
/// The strategy used to match a request host against a location.
#[derive(Debug)]
enum HostSelector {
    /// Regex match (`~pattern`).
    RegexHost(RegexHost),
    /// Exact match (any config without the `~` marker).
    EqualHost(EqualHost),
}
132
133fn new_host_selector(host: &str) -> Result<HostSelector> {
146 let host = host.trim();
147 if host.is_empty() {
148 return Ok(HostSelector::EqualHost(EqualHost {
149 value: host.to_string(),
150 }));
151 }
152 let first = host.chars().next().unwrap_or_default();
153 let last = host.substring(1, host.len()).trim();
154 let se = match first {
155 '~' => {
156 let re = RegexCapture::new(last).context(RegexSnafu {
157 value: last.to_string(),
158 })?;
159 HostSelector::RegexHost(RegexHost { value: re })
160 },
161 _ => {
162 HostSelector::EqualHost(EqualHost {
164 value: host.to_string(),
165 })
166 },
167 };
168
169 Ok(se)
170}
171
/// A routing rule: host/path selectors plus the per-location settings
/// (rewrite rule, headers, plugins, limits) built from a `LocationConf`.
#[derive(Debug)]
pub struct Location {
    /// Name the location was created with.
    pub name: String,

    /// Hash key of the source configuration; compared on reload to decide
    /// whether an existing instance can be reused.
    pub key: String,

    /// Upstream name from the configuration (empty when not set).
    pub upstream: String,

    /// Raw path configuration string (empty when not set).
    path: String,

    /// Parsed form of `path`, used for request path matching.
    path_selector: PathSelector,

    /// Host selectors parsed from the comma-separated host configuration;
    /// empty means no host restriction.
    hosts: Vec<HostSelector>,

    /// Optional rewrite rule: the regex applied to the request path and the
    /// replacement template.
    reg_rewrite: Option<(Regex, String)>,

    /// Headers converted from the `proxy_add_headers` configuration.
    /// NOTE(review): presumably appended to the upstream request — the
    /// consumer is not visible in this file.
    pub proxy_add_headers: Option<Vec<HttpHeader>>,

    /// Headers converted from the `proxy_set_headers` configuration.
    /// NOTE(review): presumably set (overwriting) on the upstream request —
    /// the consumer is not visible in this file.
    pub proxy_set_headers: Option<Vec<HttpHeader>>,

    /// Plugin names copied from the configuration.
    pub plugins: Option<Vec<String>>,

    /// Total number of requests counted by `add_processing`.
    accepted: AtomicU64,

    /// Number of requests currently in flight (incremented by
    /// `add_processing`, decremented by `sub_processing`).
    processing: AtomicI32,

    /// Upper bound for `processing`; `0` disables the limit.
    max_processing: i32,

    /// Whether gRPC-web support is enabled for this location.
    grpc_web: bool,

    /// Maximum accepted request body size in bytes; `0` disables the limit.
    client_max_body_size: usize,

    /// Copied from the configuration flag of the same name; its consumer is
    /// not visible in this file.
    pub enable_reverse_proxy_headers: bool,
}
236
237fn format_headers(
245 values: &Option<Vec<String>>,
246) -> Result<Option<Vec<HttpHeader>>> {
247 if let Some(header_values) = values {
248 let arr =
249 convert_headers(header_values).map_err(|err| Error::Invalid {
250 message: err.to_string(),
251 })?;
252 Ok(Some(arr))
253 } else {
254 Ok(None)
255 }
256}
257
258fn get_content_length(header: &RequestHeader) -> Option<usize> {
260 if let Some(content_length) =
261 header.headers.get(http::header::CONTENT_LENGTH)
262 {
263 if let Ok(size) =
264 content_length.to_str().unwrap_or_default().parse::<usize>()
265 {
266 return Some(size);
267 }
268 }
269 None
270}
271
272impl Location {
273 pub fn new(name: &str, conf: &LocationConf) -> Result<Location> {
276 if name.is_empty() {
277 return Err(Error::Invalid {
278 message: "Name is required".to_string(),
279 });
280 }
281 let key = conf.hash_key();
282 let upstream = conf.upstream.clone().unwrap_or_default();
283 let mut reg_rewrite = None;
284 if let Some(value) = &conf.rewrite {
286 let mut arr: Vec<&str> = value.split(' ').collect();
287 if arr.len() == 1 && arr[0].contains("$") {
288 arr.push(arr[0]);
289 arr[0] = ".*";
290 }
291
292 let value = if arr.len() == 2 { arr[1] } else { "" };
293 if let Ok(re) = Regex::new(arr[0]) {
294 reg_rewrite = Some((re, value.to_string()));
295 }
296 }
297 let mut hosts = vec![];
298 for item in conf.host.clone().unwrap_or_default().split(',') {
299 let host = item.trim().to_string();
300 if host.is_empty() {
301 continue;
302 }
303 hosts.push(new_host_selector(&host)?);
304 }
305
306 let path = conf.path.clone().unwrap_or_default();
307
308 let location = Location {
309 name: name.to_string(),
310 key,
311 path_selector: new_path_selector(&path)?,
312 path,
313 hosts,
314 upstream,
315 reg_rewrite,
316 plugins: conf.plugins.clone(),
317 accepted: AtomicU64::new(0),
318 processing: AtomicI32::new(0),
319 max_processing: conf.max_processing.unwrap_or_default(),
320 grpc_web: conf.grpc_web.unwrap_or_default(),
321 proxy_add_headers: format_headers(&conf.proxy_add_headers)?,
322 proxy_set_headers: format_headers(&conf.proxy_set_headers)?,
323 client_max_body_size: conf
324 .client_max_body_size
325 .unwrap_or_default()
326 .as_u64() as usize,
327 enable_reverse_proxy_headers: conf
328 .enable_reverse_proxy_headers
329 .unwrap_or_default(),
330 };
331 debug!(
332 category = LOG_CATEGORY,
333 location = format!("{location:?}"),
334 "create a new location"
335 );
336
337 Ok(location)
338 }
339
340 #[inline]
343 pub fn support_grpc_web(&self) -> bool {
344 self.grpc_web
345 }
346
347 #[inline]
359 pub fn validate_content_length(
360 &self,
361 header: &RequestHeader,
362 ) -> Result<()> {
363 if self.client_max_body_size == 0 {
364 return Ok(());
365 }
366 if get_content_length(header).unwrap_or_default()
367 > self.client_max_body_size
368 {
369 return Err(Error::BodyTooLarge {
370 max: self.client_max_body_size,
371 });
372 }
373
374 Ok(())
375 }
376
377 #[inline]
381 pub fn client_body_size_limit(&self, payload_size: usize) -> Result<()> {
382 if self.client_max_body_size == 0 {
383 return Ok(());
384 }
385 if payload_size > self.client_max_body_size {
386 return Err(Error::BodyTooLarge {
387 max: self.client_max_body_size,
388 });
389 }
390 Ok(())
391 }
392
393 #[inline]
409 pub fn add_processing(&self) -> Result<(u64, i32)> {
410 let accepted = self.accepted.fetch_add(1, Ordering::Relaxed) + 1;
411 let processing = self.processing.fetch_add(1, Ordering::Relaxed) + 1;
412 if self.max_processing != 0 && processing > self.max_processing {
413 return Err(Error::TooManyRequest {
414 max: self.max_processing,
415 });
416 }
417 Ok((accepted, processing))
418 }
419
420 #[inline]
425 pub fn sub_processing(&self) {
426 self.processing.fetch_sub(1, Ordering::Relaxed);
427 }
428
429 #[inline]
434 pub fn match_host_path(
435 &self,
436 host: &str,
437 path: &str,
438 ) -> (bool, Option<Vec<(String, String)>>) {
439 let mut variables: Vec<(String, String)> = vec![];
441
442 if !self.path.is_empty() {
444 let matched = match &self.path_selector {
445 PathSelector::EqualPath(EqualPath { value }) => value == path,
447 PathSelector::RegexPath(RegexPath { value }) => {
449 let (matched, value) = value.captures(path);
450 if let (true, Some(vars)) = (matched, value) {
451 variables.extend(vars);
452 }
453 matched
454 },
455 PathSelector::PrefixPath(PrefixPath { value }) => {
457 path.starts_with(value)
458 },
459 PathSelector::Empty => true,
461 };
462 if !matched {
464 return (false, None);
465 }
466 }
467
468 if self.hosts.is_empty() {
470 return (true, None);
471 }
472
473 let matched = self.hosts.iter().any(|item| match item {
474 HostSelector::RegexHost(RegexHost { value }) => {
478 let (matched, value) = value.captures(host);
479 if let (true, Some(vars)) = (matched, value) {
480 variables.extend(vars);
481 }
482 matched
483 },
484 HostSelector::EqualHost(EqualHost { value }) => {
488 if value.is_empty() {
489 return true;
490 }
491 value == host
492 },
493 });
494 if variables.is_empty() {
495 return (matched, None);
496 }
497
498 (matched, Some(variables))
500 }
501
502 #[inline]
527 pub fn rewrite(
528 &self,
529 header: &mut RequestHeader,
530 variables: Option<&AHashMap<String, String>>,
531 ) -> bool {
532 if let Some((re, value)) = &self.reg_rewrite {
533 let mut replace_value = value.to_string();
534 if let Some(variables) = variables {
536 for (k, v) in variables.iter() {
537 replace_value = replace_value.replace(k, v);
538 }
539 }
540 let path = header.uri.path();
541 let mut new_path = if re.to_string() == ".*" {
542 replace_value
543 } else {
544 re.replace(path, replace_value).to_string()
545 };
546 if path == new_path {
547 return false;
548 }
549 if let Some(query) = header.uri.query() {
551 new_path = format!("{new_path}?{query}");
552 }
553 debug!(category = LOG_CATEGORY, new_path, "rewrite path");
554 if let Err(e) =
556 new_path.parse::<http::Uri>().map(|uri| header.set_uri(uri))
557 {
558 error!(category = LOG_CATEGORY, error = %e, location = self.name, "new path parse fail");
559 }
560 return true;
561 }
562 false
563 }
564}
565
/// Map from location name to its shared instance.
type Locations = AHashMap<String, Arc<Location>>;
/// Process-wide registry of locations, created empty on first access and
/// replaced as a whole by `try_init_locations`.
static LOCATION_MAP: Lazy<ArcSwap<Locations>> =
    Lazy::new(|| ArcSwap::from_pointee(AHashMap::new()));
569
570pub fn get_location(name: &str) -> Option<Arc<Location>> {
578 if name.is_empty() {
579 return None;
580 }
581 LOCATION_MAP.load().get(name).cloned()
582}
583
584pub fn get_locations_processing() -> HashMap<String, i32> {
589 let mut processing = HashMap::new();
590 LOCATION_MAP.load().iter().for_each(|(k, v)| {
591 processing.insert(k.to_string(), v.processing.load(Ordering::Relaxed));
592 });
593 processing
594}
595
596pub fn try_init_locations(
599 location_configs: &HashMap<String, LocationConf>,
600) -> Result<Vec<String>> {
601 let mut locations = AHashMap::new();
602 let mut updated_locations = vec![];
603 for (name, conf) in location_configs.iter() {
604 if let Some(found) = get_location(name) {
605 if found.key == conf.hash_key() {
606 locations.insert(name.to_string(), found);
607 continue;
608 }
609 }
610 updated_locations.push(name.clone());
611 let lo = Location::new(name, conf)?;
612 locations.insert(name.to_string(), Arc::new(lo));
613 }
614 LOCATION_MAP.store(Arc::new(locations));
615 Ok(updated_locations)
616}
617
#[cfg(test)]
mod tests {
    use super::*;
    use bytesize::ByteSize;
    use pingap_config::LocationConf;
    use pingora::http::RequestHeader;
    use pingora::proxy::Session;
    use pretty_assertions::assert_eq;
    use tokio_test::io::Builder;

    /// Upstream name shared by the fixtures that set one.
    const UPSTREAM: &str = "charts";

    /// Builds a location named `lo`, panicking when the config is rejected.
    fn build_location(conf: LocationConf) -> Location {
        Location::new("lo", &conf).unwrap()
    }

    /// Returns only the "matched" flag of `match_host_path`.
    fn is_match(lo: &Location, host: &str, path: &str) -> bool {
        lo.match_host_path(host, path).0
    }

    #[test]
    fn test_format_headers() {
        let raw = Some(vec!["Content-Type: application/json".to_string()]);
        let headers = format_headers(&raw).unwrap();
        assert_eq!(
            r###"Some([("content-type", "application/json")])"###,
            format!("{headers:?}")
        );
    }

    #[test]
    fn test_new_path_selector() {
        // Empty config: matches everything.
        assert!(matches!(
            new_path_selector("").unwrap(),
            PathSelector::Empty
        ));
        // `~` marker: regex.
        assert!(matches!(
            new_path_selector("~/api").unwrap(),
            PathSelector::RegexPath(_)
        ));
        // `=` marker: exact.
        assert!(matches!(
            new_path_selector("=/api").unwrap(),
            PathSelector::EqualPath(_)
        ));
        // No marker: prefix.
        assert!(matches!(
            new_path_selector("/api").unwrap(),
            PathSelector::PrefixPath(_)
        ));
    }

    #[test]
    fn test_path_host_select_location() {
        // No path and no host: everything matches.
        let lo = build_location(LocationConf {
            upstream: Some(UPSTREAM.to_string()),
            ..Default::default()
        });
        assert!(is_match(&lo, "pingap", "/api"));
        assert!(is_match(&lo, "", ""));

        // Host list only.
        let lo = build_location(LocationConf {
            upstream: Some(UPSTREAM.to_string()),
            host: Some("test.com,pingap".to_string()),
            ..Default::default()
        });
        assert!(is_match(&lo, "pingap", "/api"));
        assert!(is_match(&lo, "pingap", ""));
        assert!(!is_match(&lo, "", "/api"));

        // Unanchored regex path.
        let lo = build_location(LocationConf {
            upstream: Some(UPSTREAM.to_string()),
            path: Some("~/users".to_string()),
            ..Default::default()
        });
        assert!(is_match(&lo, "", "/api/users"));
        assert!(is_match(&lo, "", "/users"));
        assert!(!is_match(&lo, "", "/api"));

        // Anchored regex path.
        let lo = build_location(LocationConf {
            upstream: Some(UPSTREAM.to_string()),
            path: Some("~^/api".to_string()),
            ..Default::default()
        });
        assert!(is_match(&lo, "", "/api/users"));
        assert!(!is_match(&lo, "", "/users"));
        assert!(is_match(&lo, "", "/api"));

        // Prefix path.
        let lo = build_location(LocationConf {
            upstream: Some(UPSTREAM.to_string()),
            path: Some("/api".to_string()),
            ..Default::default()
        });
        assert!(is_match(&lo, "", "/api/users"));
        assert!(!is_match(&lo, "", "/users"));
        assert!(is_match(&lo, "", "/api"));

        // Exact path.
        let lo = build_location(LocationConf {
            upstream: Some(UPSTREAM.to_string()),
            path: Some("=/api".to_string()),
            ..Default::default()
        });
        assert!(!is_match(&lo, "", "/api/users"));
        assert!(!is_match(&lo, "", "/users"));
        assert!(is_match(&lo, "", "/api"));
    }

    #[test]
    fn test_match_host_path_variables() {
        let lo = build_location(LocationConf {
            upstream: Some("charts".to_string()),
            host: Some("~(?<name>.+).npmtrend.com".to_string()),
            path: Some("~/(?<route>.+)/(.*)".to_string()),
            ..Default::default()
        });
        let (matched, variables) =
            lo.match_host_path("charts.npmtrend.com", "/users/123");
        assert!(matched);
        // Path captures come first, host captures after.
        let expected = vec![
            ("route".to_string(), "users".to_string()),
            ("name".to_string(), "charts".to_string()),
        ];
        assert_eq!(Some(expected), variables);
    }

    #[test]
    fn test_rewrite_path() {
        let lo = build_location(LocationConf {
            upstream: Some(UPSTREAM.to_string()),
            rewrite: Some("^/users/(.*)$ /$1".to_string()),
            ..Default::default()
        });

        // Matching path: rewritten, query string preserved.
        let mut req_header =
            RequestHeader::build("GET", b"/users/me?abc=1", None).unwrap();
        assert!(lo.rewrite(&mut req_header, None));
        assert_eq!("/me?abc=1", req_header.uri.to_string());

        // Non-matching path: left untouched.
        let mut req_header =
            RequestHeader::build("GET", b"/api/me?abc=1", None).unwrap();
        assert!(!lo.rewrite(&mut req_header, None));
        assert_eq!("/api/me?abc=1", req_header.uri.to_string());
    }

    #[test]
    fn test_client_body_size_limit() {
        let lo = build_location(LocationConf {
            upstream: Some(UPSTREAM.to_string()),
            rewrite: Some("^/users/(.*)$ /$1".to_string()),
            plugins: Some(vec!["test:mock".to_string()]),
            client_max_body_size: Some(ByteSize(10)),
            ..Default::default()
        });

        assert!(lo.client_body_size_limit(2).is_ok());

        let too_large = lo.client_body_size_limit(20);
        assert_eq!(
            "Request Entity Too Large, max:10",
            too_large.err().unwrap().to_string()
        );
    }

    #[tokio::test]
    async fn test_get_content_length() {
        let headers = ["Content-Length: 123"].join("\r\n");
        let input_header =
            format!("GET /vicanso/pingap?size=1 HTTP/1.1\r\n{headers}\r\n\r\n");
        let mock_io = Builder::new().read(input_header.as_bytes()).build();
        let mut session = Session::new_h1(Box::new(mock_io));
        session.read_request().await.unwrap();
        assert_eq!(get_content_length(session.req_header()), Some(123));
    }

    #[test]
    fn test_location_processing() {
        let lo = build_location(LocationConf::default());

        let (accepted, processing) = lo.add_processing().unwrap();
        assert_eq!(1, accepted);
        assert_eq!(1, processing);

        // Finishing the request lowers only the in-flight counter.
        lo.sub_processing();
        assert_eq!(1, lo.accepted.load(Ordering::Relaxed));
        assert_eq!(0, lo.processing.load(Ordering::Relaxed));
    }

    #[test]
    fn test_validate_content_length() {
        let lo = build_location(LocationConf {
            client_max_body_size: Some(ByteSize(10)),
            ..Default::default()
        });

        // No Content-Length header: accepted.
        let mut req_header =
            RequestHeader::build("GET", b"/users/me?abc=1", None).unwrap();
        assert!(lo.validate_content_length(&req_header).is_ok());

        // Declared length above the limit: rejected.
        req_header
            .append_header(
                http::header::CONTENT_LENGTH,
                http::HeaderValue::from_str("20").unwrap(),
            )
            .unwrap();
        let rejected = lo.validate_content_length(&req_header);
        assert_eq!(
            "Request Entity Too Large, max:10",
            rejected.err().unwrap().to_string()
        );
    }
}