use super::regex::RegexCapture;
use ahash::AHashMap;
use http::HeaderName;
use http::HeaderValue;
use pingap_config::Hashable;
use pingap_config::LocationConf;
use pingap_core::LocationInstance;
use pingap_core::new_internal_error;
use pingap_core::{HttpHeader, convert_headers};
use pingora::http::RequestHeader;
use regex::Regex;
use snafu::{ResultExt, Snafu};
use std::sync::Arc;
use std::sync::LazyLock;
use std::sync::atomic::{AtomicI32, AtomicU64, Ordering};
use std::time::Duration;
use tracing::{debug, error};
/// Value of the `category` field attached to every log event emitted by this module.
const LOG_CATEGORY: &str = "location";
/// Shared map of locations; presumably keyed by location name — confirm against callers.
pub type Locations = AHashMap<String, Arc<Location>>;
/// Errors produced while building a [`Location`] or while admitting a request to it.
#[derive(Debug, Snafu)]
pub enum Error {
/// Invalid configuration, e.g. an empty location name or an unparsable header entry.
#[snafu(display("Invalid error {message}"))]
Invalid { message: String },
/// A `~` host/path pattern (`value`) failed to compile.
#[snafu(display("Regex value: {value}, {source}"))]
Regex { value: String, source: regex::Error },
/// The in-flight request count exceeded the configured `max_processing`.
#[snafu(display("Too Many Requests, max:{max}"))]
TooManyRequest { max: i32 },
/// The request's `Content-Length` exceeded `client_max_body_size`.
#[snafu(display("Request Entity Too Large, max:{max}"))]
BodyTooLarge { max: usize },
}
/// Module-local result alias whose error type defaults to [`Error`].
type Result<T, E = Error> = std::result::Result<T, E>;
/// Point-in-time snapshot of a location's request counters (see `Location::stats`).
pub struct LocationStats {
/// Requests currently in flight (incremented by `on_request`, decremented by `on_response`).
pub processing: i32,
/// Total requests counted by `on_request` since the location was created.
pub accepted: u64,
}
/// Strategy used to match a request path, parsed from the location's `path` setting.
#[derive(Debug)]
enum PathSelector {
/// `~<regex>`: regex match; named captures are returned as variables.
Regex(RegexCapture),
/// Plain path: matches when the request path starts with it.
Prefix(String),
/// `=<path>`: matches only that exact path.
Equal(String),
/// Empty (or whitespace-only) path: matches every request path.
Any,
}
impl PathSelector {
    /// Parses a location path expression (surrounding whitespace ignored).
    ///
    /// * `~pattern` -> [`PathSelector::Regex`]
    /// * `=path` -> [`PathSelector::Equal`]
    /// * empty -> [`PathSelector::Any`]
    /// * anything else -> [`PathSelector::Prefix`]
    ///
    /// # Errors
    /// Returns [`Error::Regex`] when a `~` pattern fails to compile.
    fn new(path: &str) -> Result<Self> {
        let expr = path.trim();
        if let Some(pattern) = expr.strip_prefix('~') {
            let pattern = pattern.trim();
            return RegexCapture::new(pattern)
                .context(RegexSnafu { value: pattern })
                .map(PathSelector::Regex);
        }
        if let Some(exact) = expr.strip_prefix('=') {
            return Ok(PathSelector::Equal(exact.trim().to_string()));
        }
        // An empty expression cannot carry a `~` or `=` prefix, so checking it
        // last is equivalent to checking it first.
        if expr.is_empty() {
            Ok(PathSelector::Any)
        } else {
            Ok(PathSelector::Prefix(expr.to_string()))
        }
    }

    /// Tests `path` against this selector.
    ///
    /// Only the regex variant can yield captured variables; every other
    /// variant returns `None` for them.
    #[inline]
    fn is_match(&self, path: &str) -> (bool, Option<AHashMap<String, String>>) {
        let matched = match self {
            PathSelector::Regex(re) => return re.captures(path),
            PathSelector::Any => true,
            PathSelector::Equal(expected) => expected.as_str() == path,
            PathSelector::Prefix(prefix) => path.starts_with(prefix.as_str()),
        };
        (matched, None)
    }
}
/// Strategy used to match the request host, parsed from one entry of the
/// comma-separated `host` setting.
#[derive(Debug)]
enum HostSelector {
/// `~<regex>`: regex match; named captures are returned as variables.
Regex(RegexCapture),
/// Any other entry: exact, case-sensitive string comparison.
Equal(String),
}
impl HostSelector {
    /// Parses one host entry (surrounding whitespace ignored).
    ///
    /// A leading `~` makes the remainder a regex; anything else is matched
    /// exactly.
    ///
    /// # Errors
    /// Returns [`Error::Regex`] when a `~` pattern fails to compile.
    fn new(host: &str) -> Result<Self> {
        let entry = host.trim();
        match entry.strip_prefix('~') {
            Some(pattern) => {
                let pattern = pattern.trim();
                let compiled =
                    RegexCapture::new(pattern).context(RegexSnafu { value: pattern })?;
                Ok(HostSelector::Regex(compiled))
            }
            None => Ok(HostSelector::Equal(entry.to_string())),
        }
    }

    /// Tests `host` against this selector.
    ///
    /// Only the regex variant can yield captured variables.
    #[inline]
    fn is_match(&self, host: &str) -> (bool, Option<AHashMap<String, String>>) {
        match self {
            HostSelector::Regex(compiled) => compiled.captures(host),
            HostSelector::Equal(expected) => (expected.as_str() == host, None),
        }
    }
}
/// Headers set on proxied requests when `enable_reverse_proxy_headers` is on.
///
/// The `$…` values are presumably template variables expanded per request
/// elsewhere in the proxy — not visible in this module. Parsed once, lazily;
/// the literals are fixed, so a parse failure is a programming error.
static DEFAULT_PROXY_SET_HEADERS: LazyLock<Vec<HttpHeader>> = LazyLock::new(|| {
    let defaults = [
        "x-real-ip:$remote_addr",
        "x-forwarded-for:$proxy_add_x_forwarded_for",
        "x-forwarded-proto:$scheme",
        "x-forwarded-host:$host",
        "x-forwarded-port:$server_port",
    ]
    .map(String::from);
    convert_headers(&defaults).expect("Failed to convert default proxy set headers")
});
/// A routing rule: host/path matchers plus per-location proxy settings and counters.
#[derive(Debug)]
pub struct Location {
/// Location name (must be non-empty).
pub name: Arc<str>,
/// Hash of the configuration this location was built from (`LocationConf::hash_key`).
pub key: String,
/// Upstream name; empty when none is configured.
upstream: String,
/// Raw configured path; empty means "any path".
path: String,
/// Matcher built from `path`.
path_selector: PathSelector,
/// Host matchers; empty means "any host".
hosts: Vec<HostSelector>,
/// Compiled rewrite regex and its replacement template, if configured.
reg_rewrite: Option<(Regex, String)>,
/// Proxy headers as (name, value, flag); flag is `true` for `proxy_add_headers`
/// entries and `false` for default/`proxy_set_headers` entries.
pub headers: Option<Vec<(HeaderName, HeaderValue, bool)>>,
/// Plugin names configured for this location.
pub plugins: Option<Vec<String>>,
/// Total requests counted by `on_request`.
accepted: AtomicU64,
/// Requests currently in flight.
processing: AtomicI32,
/// In-flight limit; 0 disables the check.
max_processing: i32,
/// Whether gRPC-Web support is enabled.
grpc_web: bool,
/// Maximum allowed `Content-Length` in bytes; 0 disables the check.
client_max_body_size: usize,
/// Copied from config; not used in this file — presumably read by the proxy retry logic.
pub max_retries: Option<u8>,
/// Copied from config; not used in this file — presumably read by the proxy retry logic.
pub max_retry_window: Option<Duration>,
}
fn format_headers(
values: &Option<Vec<String>>,
) -> Result<Option<Vec<HttpHeader>>> {
if let Some(header_values) = values {
let arr =
convert_headers(header_values).map_err(|err| Error::Invalid {
message: err.to_string(),
})?;
Ok(Some(arr))
} else {
Ok(None)
}
}
/// Reads the request's `Content-Length` header as a `usize`.
///
/// Returns `None` when the header is absent, is not a valid string, or does
/// not parse as a `usize`.
fn get_content_length(header: &RequestHeader) -> Option<usize> {
    let raw = header.headers.get(http::header::CONTENT_LENGTH)?;
    raw.to_str().ok()?.parse::<usize>().ok()
}
impl Location {
pub fn new(name: &str, conf: &LocationConf) -> Result<Location> {
if name.is_empty() {
return Err(Error::Invalid {
message: "Name is required".to_string(),
});
}
let key = conf.hash_key();
let upstream = conf.upstream.clone().unwrap_or_default();
let mut reg_rewrite = None;
if let Some(value) = &conf.rewrite {
let mut arr: Vec<&str> = value.split(' ').collect();
if arr.len() == 1 && arr[0].contains("$") {
arr.push(arr[0]);
arr[0] = ".*";
}
let value = if arr.len() == 2 { arr[1] } else { "" };
if let Ok(re) = Regex::new(arr[0]) {
reg_rewrite = Some((re, value.to_string()));
}
}
let hosts = conf
.host
.as_deref()
.unwrap_or("")
.split(',')
.map(str::trim)
.filter(|s| !s.is_empty())
.map(HostSelector::new)
.collect::<Result<Vec<_>>>()?;
let path = conf.path.clone().unwrap_or_default();
let mut headers: Vec<(HeaderName, HeaderValue, bool)> = vec![];
if conf.enable_reverse_proxy_headers.unwrap_or_default() {
for (name, value) in DEFAULT_PROXY_SET_HEADERS.iter() {
headers.push((name.clone(), value.clone(), false));
}
}
if let Some(proxy_set_headers) =
format_headers(&conf.proxy_set_headers)?
{
for (name, value) in proxy_set_headers.iter() {
headers.push((name.clone(), value.clone(), false));
}
}
if let Some(proxy_add_headers) =
format_headers(&conf.proxy_add_headers)?
{
for (name, value) in proxy_add_headers.iter() {
headers.push((name.clone(), value.clone(), true));
}
}
let location = Location {
name: name.into(),
key,
path_selector: PathSelector::new(&path)?,
path,
hosts,
upstream,
reg_rewrite,
plugins: conf.plugins.clone(),
accepted: AtomicU64::new(0),
processing: AtomicI32::new(0),
max_processing: conf.max_processing.unwrap_or_default(),
grpc_web: conf.grpc_web.unwrap_or_default(),
headers: if headers.is_empty() {
None
} else {
Some(headers)
},
client_max_body_size: conf
.client_max_body_size
.unwrap_or_default()
.as_u64() as usize,
max_retries: conf.max_retries,
max_retry_window: conf.max_retry_window,
};
debug!(
category = LOG_CATEGORY,
location = format!("{location:?}"),
"create a new location"
);
Ok(location)
}
#[inline]
pub fn support_grpc_web(&self) -> bool {
self.grpc_web
}
#[inline]
pub fn validate_content_length(
&self,
header: &RequestHeader,
) -> Result<()> {
if self.client_max_body_size == 0 {
return Ok(());
}
if get_content_length(header).unwrap_or_default()
> self.client_max_body_size
{
return Err(Error::BodyTooLarge {
max: self.client_max_body_size,
});
}
Ok(())
}
#[inline]
pub fn match_host_path(
&self,
host: &str,
path: &str,
) -> (bool, Option<AHashMap<String, String>>) {
let mut capture_values = None;
if !self.path.is_empty() {
let (matched, captures) = self.path_selector.is_match(path);
if !matched {
return (false, None);
}
capture_values = captures;
}
if self.hosts.is_empty() {
return (true, capture_values);
}
let matched = self.hosts.iter().any(|host_selector| {
let (matched, captures) = host_selector.is_match(host);
if let Some(captures) = captures {
if let Some(values) = capture_values.as_mut() {
values.extend(captures);
} else {
capture_values = Some(captures);
}
}
matched
});
(matched, capture_values)
}
pub fn stats(&self) -> LocationStats {
LocationStats {
processing: self.processing.load(Ordering::Relaxed),
accepted: self.accepted.load(Ordering::Relaxed),
}
}
}
impl LocationInstance for Location {
fn name(&self) -> &str {
self.name.as_ref()
}
fn headers(&self) -> Option<&Vec<(HeaderName, HeaderValue, bool)>> {
self.headers.as_ref()
}
fn client_body_size_limit(&self) -> usize {
self.client_max_body_size
}
fn upstream(&self) -> &str {
self.upstream.as_ref()
}
fn on_response(&self) {
self.processing.fetch_sub(1, Ordering::Relaxed);
}
fn on_request(&self) -> pingora::Result<(u64, i32)> {
let accepted = self.accepted.fetch_add(1, Ordering::Relaxed) + 1;
let processing = self.processing.fetch_add(1, Ordering::Relaxed) + 1;
if self.max_processing != 0 && processing > self.max_processing {
let err = Error::TooManyRequest {
max: self.max_processing,
};
return Err(new_internal_error(429, err));
}
Ok((accepted, processing))
}
#[inline]
fn rewrite(
&self,
header: &mut RequestHeader,
mut variables: Option<AHashMap<String, String>>,
) -> (bool, Option<AHashMap<String, String>>) {
let Some((re, value)) = &self.reg_rewrite else {
return (false, variables);
};
let mut replace_value = value.to_string();
if let Some(vars) = &variables {
for (k, v) in vars.iter() {
replace_value = replace_value.replace(k, v);
}
}
let path = header.uri.path();
let mut new_path = if re.to_string() == ".*" {
replace_value
} else {
re.replace(path, replace_value).to_string()
};
if path == new_path {
return (false, variables);
}
if let Some(captures) = re.captures(path) {
for name in re.capture_names().flatten() {
if let Some(match_value) = captures.name(name) {
let values = variables.get_or_insert_with(AHashMap::new);
values.insert(
name.to_string(),
match_value.as_str().to_string(),
);
}
}
}
if let Some(query) = header.uri.query() {
new_path = format!("{new_path}?{query}");
}
debug!(category = LOG_CATEGORY, new_path, "rewrite path");
if let Err(e) =
new_path.parse::<http::Uri>().map(|uri| header.set_uri(uri))
{
error!(category = LOG_CATEGORY, error = %e, location = self.name.as_ref(), "new path parse fail");
}
(true, variables)
}
}
// Unit tests for header parsing, path/host selectors, host+path matching,
// path rewriting and request body-size validation.
#[cfg(test)]
mod tests {
use super::*;
use bytesize::ByteSize;
use pingap_config::LocationConf;
use pingora::http::RequestHeader;
use pingora::proxy::Session;
use pretty_assertions::assert_eq;
use tokio_test::io::Builder;
// `format_headers` turns "Name: value" strings into (lowercased name, value) pairs.
#[test]
fn test_format_headers() {
let headers = format_headers(&Some(vec![
"Content-Type: application/json".to_string(),
]))
.unwrap();
assert_eq!(
r###"Some([("content-type", "application/json")])"###,
format!("{headers:?}")
);
}
// Each path prefix form ("", "~", "=", plain) maps to the expected selector variant.
#[test]
fn test_new_path_selector() {
let selector = PathSelector::new("").unwrap();
assert_eq!(true, matches!(selector, PathSelector::Any));
let selector = PathSelector::new("~/api").unwrap();
assert_eq!(true, matches!(selector, PathSelector::Regex(_)));
let selector = PathSelector::new("=/api").unwrap();
assert_eq!(true, matches!(selector, PathSelector::Equal(_)));
let selector = PathSelector::new("/api").unwrap();
assert_eq!(true, matches!(selector, PathSelector::Prefix(_)));
}
// Covers: no host/path (match all), host list, unanchored and anchored regex,
// prefix and exact path matching.
#[test]
fn test_path_host_select_location() {
let upstream_name = "charts";
let lo = Location::new(
"lo",
&LocationConf {
upstream: Some(upstream_name.to_string()),
..Default::default()
},
)
.unwrap();
assert_eq!(true, lo.match_host_path("pingap", "/api").0);
assert_eq!(true, lo.match_host_path("", "").0);
let lo = Location::new(
"lo",
&LocationConf {
upstream: Some(upstream_name.to_string()),
host: Some("test.com,pingap".to_string()),
..Default::default()
},
)
.unwrap();
assert_eq!(true, lo.match_host_path("pingap", "/api").0);
assert_eq!(true, lo.match_host_path("pingap", "").0);
assert_eq!(false, lo.match_host_path("", "/api").0);
let lo = Location::new(
"lo",
&LocationConf {
upstream: Some(upstream_name.to_string()),
path: Some("~/users".to_string()),
..Default::default()
},
)
.unwrap();
assert_eq!(true, lo.match_host_path("", "/api/users").0);
assert_eq!(true, lo.match_host_path("", "/users").0);
assert_eq!(false, lo.match_host_path("", "/api").0);
let lo = Location::new(
"lo",
&LocationConf {
upstream: Some(upstream_name.to_string()),
path: Some("~^/api".to_string()),
..Default::default()
},
)
.unwrap();
assert_eq!(true, lo.match_host_path("", "/api/users").0);
assert_eq!(false, lo.match_host_path("", "/users").0);
assert_eq!(true, lo.match_host_path("", "/api").0);
let lo = Location::new(
"lo",
&LocationConf {
upstream: Some(upstream_name.to_string()),
path: Some("/api".to_string()),
..Default::default()
},
)
.unwrap();
assert_eq!(true, lo.match_host_path("", "/api/users").0);
assert_eq!(false, lo.match_host_path("", "/users").0);
assert_eq!(true, lo.match_host_path("", "/api").0);
let lo = Location::new(
"lo",
&LocationConf {
upstream: Some(upstream_name.to_string()),
path: Some("=/api".to_string()),
..Default::default()
},
)
.unwrap();
assert_eq!(false, lo.match_host_path("", "/api/users").0);
assert_eq!(false, lo.match_host_path("", "/users").0);
assert_eq!(true, lo.match_host_path("", "/api").0);
}
// Named captures from both the host regex and the path regex are returned together.
#[test]
fn test_match_host_path_variables() {
let lo = Location::new(
"lo",
&LocationConf {
upstream: Some("charts".to_string()),
host: Some("~(?<name>.+).npmtrend.com".to_string()),
path: Some("~/(?<route>.+)/(.*)".to_string()),
..Default::default()
},
)
.unwrap();
let (matched, variables) =
lo.match_host_path("charts.npmtrend.com", "/users/123");
assert_eq!(true, matched);
let variables = variables.unwrap();
assert_eq!("users", variables.get("route").unwrap());
assert_eq!("charts", variables.get("name").unwrap());
}
// Rewrite keeps the query string and exports named groups; non-matching paths are untouched.
#[test]
fn test_rewrite_path() {
let upstream_name = "charts";
let lo = Location::new(
"lo",
&LocationConf {
upstream: Some(upstream_name.to_string()),
rewrite: Some("^/users/(?<upstream>.*?)/(.*)$ /$2".to_string()),
..Default::default()
},
)
.unwrap();
let mut req_header =
RequestHeader::build("GET", b"/users/rest/me?abc=1", None).unwrap();
let (matched, variables) = lo.rewrite(&mut req_header, None);
assert_eq!(true, matched);
assert_eq!(r#"Some({"upstream": "rest"})"#, format!("{:?}", variables));
assert_eq!("/me?abc=1", req_header.uri.to_string());
let mut req_header =
RequestHeader::build("GET", b"/api/me?abc=1", None).unwrap();
let (matched, variables) = lo.rewrite(&mut req_header, None);
assert_eq!(false, matched);
assert_eq!(None, variables);
assert_eq!("/api/me?abc=1", req_header.uri.to_string());
}
// Content-Length is read from a real parsed HTTP/1.1 request header.
#[tokio::test]
async fn test_get_content_length() {
let headers = ["Content-Length: 123"].join("\r\n");
let input_header =
format!("GET /vicanso/pingap?size=1 HTTP/1.1\r\n{headers}\r\n\r\n");
let mock_io = Builder::new().read(input_header.as_bytes()).build();
let mut session = Session::new_h1(Box::new(mock_io));
session.read_request().await.unwrap();
assert_eq!(get_content_length(session.req_header()), Some(123));
}
// Missing Content-Length passes; a value above the limit is rejected.
#[test]
fn test_validate_content_length() {
let lo = Location::new(
"lo",
&LocationConf {
client_max_body_size: Some(ByteSize(10)),
..Default::default()
},
)
.unwrap();
let mut req_header =
RequestHeader::build("GET", b"/users/me?abc=1", None).unwrap();
assert_eq!(true, lo.validate_content_length(&req_header).is_ok());
req_header
.append_header(
http::header::CONTENT_LENGTH,
http::HeaderValue::from_str("20").unwrap(),
)
.unwrap();
assert_eq!(
"Request Entity Too Large, max:10",
lo.validate_content_length(&req_header)
.err()
.unwrap()
.to_string()
);
}
}