use super::{
Error, get_hash_key, get_int_conf, get_int_conf_or_default,
get_plugin_factory, get_step_conf, get_str_conf,
};
use async_trait::async_trait;
use ctor::ctor;
use http::StatusCode;
use humantime::parse_duration;
use pingap_config::{PluginCategory, PluginConf};
use pingap_core::{
Ctx, HttpResponse, Inflight, Plugin, PluginStep, Rate, RequestPluginResult,
};
use pingap_core::{
get_client_ip, get_cookie_value, get_query_value, get_req_header_value,
};
use pingora::proxy::Session;
use std::borrow::Cow;
use std::sync::Arc;
use std::time::Duration;
use tracing::debug;
/// Module-local `Result` alias whose error type defaults to the plugin [`Error`].
type Result<T, E = Error> = std::result::Result<T, E>;
/// Selects which part of the request supplies the key that limits are
/// tracked under.
#[derive(Debug, PartialEq)]
pub enum LimitTag {
    /// Key is the client IP address.
    Ip,
    /// Key is the value of a request header.
    RequestHeader,
    /// Key is the value of a cookie.
    Cookie,
    /// Key is the value of a query-string parameter.
    Query,
}
/// Request limiter plugin state.
///
/// Exactly one of `inflight` / `rate` is populated by the
/// `TryFrom<&PluginConf>` conversion, depending on the configured `type`.
pub struct Limiter {
/// Which request attribute the limit key is read from.
tag: LimitTag,
/// Threshold a measured value must exceed for the request to be rejected.
/// In rate mode this is the configured `max` divided by the interval in
/// seconds (with the divisor floored at 1.0).
max: f64,
/// Name of the query parameter, header or cookie to read; not consulted
/// when `tag` is `LimitTag::Ip`.
key: String,
/// In-flight counter; `Some` only when the configured `type` is "inflight".
inflight: Option<Inflight>,
/// Rate estimator; `Some` for every `type` other than "inflight".
rate: Option<Rate>,
/// Plugin step at which `handle_request` performs the check; other steps
/// are skipped.
plugin_step: PluginStep,
/// Value produced by `get_hash_key` for the configuration; returned from
/// `config_key`.
hash_value: String,
/// Fraction in `0.0..=1.0` (configured `weight` / 100) applied to the
/// current window's samples; the previous window gets `1 - weight`.
weight: f64,
}
/// Builds a [`Limiter`] from a plugin configuration.
///
/// Configuration keys read here:
/// - `tag`: `"cookie"`, `"header"` or `"query"`; any other value selects the
///   client IP.
/// - `key`: name of the cookie / header / query parameter to read.
/// - `type`: `"inflight"` selects the in-flight counter; any other value
///   selects the rate estimator.
/// - `max`: the limit. In rate mode it is divided by the interval in seconds
///   (divisor floored at 1.0) so it can be compared with a per-second rate.
/// - `interval`: humantime duration of the rate window; 10 seconds when empty.
/// - `weight`: integer, default 50, clamped to `0..=100` and scaled to
///   `0.0..=1.0`.
/// - the plugin step (via `get_step_conf`), defaulting to the request step.
///
/// # Errors
///
/// Returns `Error::Invalid` when `interval` cannot be parsed, or when the
/// configured step is neither `Request` nor `ProxyUpstream`.
impl TryFrom<&PluginConf> for Limiter {
    type Error = Error;

    fn try_from(value: &PluginConf) -> Result<Self> {
        let hash_value = get_hash_key(value);
        let plugin_step = get_step_conf(value, PluginStep::Request);

        let tag = match get_str_conf(value, "tag").as_str() {
            "cookie" => LimitTag::Cookie,
            "header" => LimitTag::RequestHeader,
            "query" => LimitTag::Query,
            _ => LimitTag::Ip,
        };

        let interval = get_str_conf(value, "interval");
        let interval = if interval.is_empty() {
            Duration::from_secs(10)
        } else {
            parse_duration(&interval).map_err(|e| Error::Invalid {
                category: PluginCategory::Limit.to_string(),
                message: e.to_string(),
            })?
        };

        // Reject unsupported steps before allocating any limiter state.
        // Interval parsing stays first so its error keeps precedence.
        if !matches!(
            plugin_step,
            PluginStep::Request | PluginStep::ProxyUpstream
        ) {
            return Err(Error::Invalid {
                category: PluginCategory::Limit.to_string(),
                message: "Limit plugin should be executed at request or proxy upstream step".to_string(),
            });
        }

        let configured_max = get_int_conf(value, "max") as f64;
        let (max, inflight, rate) = if get_str_conf(value, "type") == "inflight"
        {
            (configured_max, Some(Inflight::new()), None)
        } else {
            // Rate mode compares against a per-second value, so scale the
            // per-interval limit; the divisor never drops below one second.
            (
                configured_max / interval.as_secs_f64().max(1.0),
                None,
                Some(Rate::new(interval)),
            )
        };

        let weight = get_int_conf_or_default(value, "weight", 50).clamp(0, 100)
            as f64
            / 100.0;

        Ok(Self {
            hash_value,
            tag,
            key: get_str_conf(value, "key"),
            max,
            inflight,
            rate,
            plugin_step,
            weight,
        })
    }
}
impl Limiter {
pub fn new(params: &PluginConf) -> Result<Self> {
debug!(params = params.to_string(), "new limit plugin");
Self::try_from(params)
}
pub fn incr(&self, session: &Session, ctx: &mut Ctx) -> Result<()> {
let key = match self.tag {
LimitTag::Query => {
get_query_value(session.req_header(), &self.key)
.unwrap_or_default()
.to_string()
},
LimitTag::RequestHeader => {
get_req_header_value(session.req_header(), &self.key)
.unwrap_or_default()
.to_string()
},
LimitTag::Cookie => {
get_cookie_value(session.req_header(), &self.key)
.unwrap_or_default()
.to_string()
},
_ => {
let ip = ctx
.conn
.client_ip
.get_or_insert_with(|| get_client_ip(session));
ip.to_string()
},
};
if key.is_empty() {
return Ok(());
}
let value = if let Some(rate) = &self.rate {
rate.observe(&key, 1); if self.weight > 0.0 {
rate.rate_with(&key, |rate_info| {
let prev =
rate_info.prev_samples as f64 * (1. - self.weight);
let curr = rate_info.curr_samples as f64 * self.weight;
(prev + curr) / rate_info.interval.as_secs_f64()
})
} else {
rate.rate(&key) }
} else if let Some(inflight) = &self.inflight {
let (guard, value) = inflight.incr(&key, 1);
ctx.state.guard = Some(guard);
value as f64
} else {
0.0
};
if value > self.max {
return Err(Error::Exceed {
category: PluginCategory::Limit.to_string(),
max: self.max,
value,
});
}
Ok(())
}
}
#[async_trait]
impl Plugin for Limiter {
    /// Returns the configuration hash stored when the limiter was built.
    #[inline]
    fn config_key(&self) -> Cow<'_, str> {
        Cow::Borrowed(self.hash_value.as_str())
    }

    /// Applies the limit when `step` matches the configured plugin step.
    ///
    /// Returns `Skipped` for any other step, `Continue` when the request is
    /// within the limit, and a `429 Too Many Requests` response whose body
    /// is the limiter error text when `incr` fails.
    #[inline]
    async fn handle_request(
        &self,
        step: PluginStep,
        session: &mut Session,
        ctx: &mut Ctx,
    ) -> pingora::Result<RequestPluginResult> {
        if step != self.plugin_step {
            return Ok(RequestPluginResult::Skipped);
        }

        let outcome = match self.incr(session, ctx) {
            Ok(()) => RequestPluginResult::Continue,
            Err(e) => RequestPluginResult::Respond(HttpResponse {
                status: StatusCode::TOO_MANY_REQUESTS,
                body: e.to_string().into(),
                ..Default::default()
            }),
        };
        Ok(outcome)
    }
}
/// Registers the limiter under the plugin name "limit" with the plugin
/// factory. Carries the `#[ctor]` attribute.
#[ctor]
fn init() {
    let factory = get_plugin_factory();
    factory.register("limit", |params| {
        let limiter = Limiter::new(params)?;
        Ok(Arc::new(limiter))
    });
}
#[cfg(test)]
mod tests {
    use super::*;
    use http::StatusCode;
    use pingap_config::PluginConf;
    use pingap_core::{Ctx, PluginStep};
    use pingora::proxy::Session;
    use pretty_assertions::assert_eq;
    use std::time::Duration;
    use tokio_test::io::Builder;

    /// Parses a TOML snippet into a plugin configuration, panicking when the
    /// snippet is invalid.
    fn plugin_conf(toml_text: &str) -> PluginConf {
        toml::from_str::<PluginConf>(toml_text).unwrap()
    }

    /// Feeds `raw_request` through a mocked HTTP/1 connection and returns the
    /// session after its request has been read.
    async fn session_from(raw_request: &str) -> Session {
        let mock_io = Builder::new().read(raw_request.as_bytes()).build();
        let mut session = Session::new_h1(Box::new(mock_io));
        session.read_request().await.unwrap();
        session
    }

    /// Session carrying a cookie, a custom header, a query parameter and a
    /// forwarded-for chain, so every key source has something to read.
    async fn new_session() -> Session {
        let headers = [
            "Host: github.com",
            "Referer: https://github.com/",
            "User-Agent: pingap/0.1.1",
            "Cookie: deviceId=abc",
            "Accept: application/json",
            "X-Uuid: 138q71",
            "X-Forwarded-For: 1.1.1.1, 192.168.1.2",
        ]
        .join("\r\n");
        let input_header =
            format!("GET /vicanso/pingap?key=1 HTTP/1.1\r\n{headers}\r\n\r\n");
        session_from(&input_header).await
    }

    /// Session with only a forwarded-for header, used by the limit tests.
    async fn forwarded_ip_session() -> Session {
        let headers = ["X-Forwarded-For: 1.1.1.1"].join("\r\n");
        let input_header =
            format!("GET /vicanso/pingap?size=1 HTTP/1.1\r\n{headers}\r\n\r\n");
        session_from(&input_header).await
    }

    /// Runs the limiter at the request step with a fresh context.
    async fn run_request_step(
        limiter: &Limiter,
        session: &mut Session,
    ) -> RequestPluginResult {
        limiter
            .handle_request(PluginStep::Request, session, &mut Ctx::default())
            .await
            .unwrap()
    }

    /// Asserts that `result` is a `429 Too Many Requests` response.
    fn assert_too_many_requests(result: RequestPluginResult) {
        let RequestPluginResult::Respond(resp) = result else {
            panic!("result is not Respond");
        };
        assert_eq!(StatusCode::TOO_MANY_REQUESTS, resp.status);
    }

    /// Counts one request from the fully populated session and asserts that
    /// an in-flight guard was stored in the context.
    async fn assert_guard_is_stored(limiter: &Limiter) {
        let mut ctx = Ctx::default();
        let session = new_session().await;
        limiter.incr(&session, &mut ctx).unwrap();
        assert!(ctx.state.guard.is_some());
    }

    #[test]
    fn test_limit_params() {
        let conf = plugin_conf(
            r###"
type = "inflight"
tag = "cookie"
key = "deviceId"
max = 10
"###,
        );
        let params = Limiter::try_from(&conf).unwrap();
        assert_eq!("request", params.plugin_step.to_string());
        assert!(params.inflight.is_some());
        assert_eq!(LimitTag::Cookie, params.tag);
        assert_eq!("deviceId", params.key);

        // A step other than request / proxy upstream must be rejected.
        let conf = plugin_conf(
            r###"
step = "response"
type = "inflight"
tag = "cookie"
key = "deviceId"
max = 10
"###,
        );
        let result = Limiter::try_from(&conf);
        assert_eq!(
            "Plugin limit invalid, message: Limit plugin should be executed at request or proxy upstream step",
            result.err().unwrap().to_string()
        );
    }

    #[tokio::test]
    async fn test_new_cookie_limiter() {
        let conf = plugin_conf(
            r###"
type = "inflight"
tag = "cookie"
key = "deviceId"
max = 10
"###,
        );
        let limiter = Limiter::new(&conf).unwrap();
        assert_eq!(LimitTag::Cookie, limiter.tag);
        assert_guard_is_stored(&limiter).await;
    }

    #[tokio::test]
    async fn test_new_req_header_limiter() {
        let conf = plugin_conf(
            r###"
type = "inflight"
tag = "header"
key = "X-Uuid"
max = 10
"###,
        );
        let limiter = Limiter::new(&conf).unwrap();
        assert_eq!(LimitTag::RequestHeader, limiter.tag);
        assert_guard_is_stored(&limiter).await;
    }

    #[tokio::test]
    async fn test_new_query_limiter() {
        let conf = plugin_conf(
            r###"
type = "inflight"
tag = "query"
key = "key"
max = 10
"###,
        );
        let limiter = Limiter::new(&conf).unwrap();
        assert_eq!(LimitTag::Query, limiter.tag);
        assert_guard_is_stored(&limiter).await;
    }

    #[tokio::test]
    async fn test_new_ip_limiter() {
        let conf = plugin_conf(
            r###"
type = "inflight"
max = 10
"###,
        );
        let limiter = Limiter::new(&conf).unwrap();
        assert_eq!(LimitTag::Ip, limiter.tag);
        assert_guard_is_stored(&limiter).await;
    }

    #[tokio::test]
    async fn test_inflight_limit() {
        // With max = 0 the very first in-flight request is over the limit.
        let conf = plugin_conf(
            r###"
type = "inflight"
max = 0
"###,
        );
        let limiter = Limiter::new(&conf).unwrap();
        let mut session = forwarded_ip_session().await;
        let result = run_request_step(&limiter, &mut session).await;
        assert_too_many_requests(result);

        // With max = 1 a single in-flight request is allowed.
        let conf = plugin_conf(
            r###"
type = "inflight"
max = 1
"###,
        );
        let limiter = Limiter::new(&conf).unwrap();
        let result = run_request_step(&limiter, &mut session).await;
        assert!(result == RequestPluginResult::Continue);
    }

    #[tokio::test]
    async fn test_rate_limit() {
        let conf = plugin_conf(
            r###"
type = "rate"
max = 1
interval = "1s"
"###,
        );
        let limiter = Limiter::new(&conf).unwrap();
        let mut session = forwarded_ip_session().await;

        // First request in the window passes.
        let result = run_request_step(&limiter, &mut session).await;
        assert!(result == RequestPluginResult::Continue);

        // Second request in the same window; its outcome is not checked.
        let _ = run_request_step(&limiter, &mut session).await;

        // After the window rolls over, the two earlier requests push the
        // measured value over the limit.
        tokio::time::sleep(Duration::from_secs(1)).await;
        let result = run_request_step(&limiter, &mut session).await;
        assert_too_many_requests(result);

        // One more window later the measured value is back within the limit.
        tokio::time::sleep(Duration::from_secs(1)).await;
        let result = run_request_step(&limiter, &mut session).await;
        assert!(result == RequestPluginResult::Continue);
    }
}