use super::{
Error, get_hash_key, get_plugin_factory, get_str_conf, get_str_slice_conf,
};
use async_trait::async_trait;
use bytes::Bytes;
use ctor::ctor;
use http::StatusCode;
use pingap_config::PluginConf;
use pingap_core::{Ctx, HttpResponse, Plugin, PluginStep, RequestPluginResult};
use pingora::proxy::Session;
use regex::Regex;
use std::borrow::Cow;
use std::sync::Arc;
use tracing::debug;
/// Local result alias whose error type defaults to `super::Error`.
type Result<T, E = Error> = std::result::Result<T, E>;
/// Plugin that admits or rejects requests by matching their `User-Agent`
/// header against a list of regular expressions.
pub struct UaRestriction {
    /// Step at which the plugin acts (constructed as `PluginStep::Request`).
    plugin_step: PluginStep,
    /// Compiled patterns taken from the `ua_list` configuration entry.
    ua_list: Vec<Regex>,
    /// Value of the `type` configuration entry: `"deny"` rejects matching
    /// agents; any other value admits only matching agents.
    restriction_category: String,
    /// Response (status 403) sent back when a request is rejected.
    forbidden_resp: HttpResponse,
    /// Configuration hash, exposed through `config_key`.
    hash_value: String,
}
impl TryFrom<&PluginConf> for UaRestriction {
type Error = Error;
fn try_from(value: &PluginConf) -> Result<Self> {
let hash_value = get_hash_key(value);
let mut ua_list = vec![];
for item in get_str_slice_conf(value, "ua_list").iter() {
let reg = Regex::new(item).map_err(|e| Error::Invalid {
category: "regex".to_string(),
message: e.to_string(),
})?;
ua_list.push(reg);
}
let mut message = get_str_conf(value, "message");
if message.is_empty() {
message = "Request is forbidden".to_string();
}
let params = Self {
hash_value,
plugin_step: PluginStep::Request,
ua_list,
restriction_category: get_str_conf(value, "type"),
forbidden_resp: HttpResponse {
status: StatusCode::FORBIDDEN,
body: Bytes::from(message),
..Default::default()
},
};
Ok(params)
}
}
impl UaRestriction {
pub fn new(params: &PluginConf) -> Result<Self> {
debug!(
params = params.to_string(),
"new user agent restriction plugin"
);
Self::try_from(params)
}
}
#[async_trait]
impl Plugin for UaRestriction {
    /// Returns the configuration hash computed by `get_hash_key` when the
    /// plugin was built.
    #[inline]
    fn config_key(&self) -> Cow<'_, str> {
        Cow::Borrowed(&self.hash_value)
    }

    /// Admits or rejects the request based on its `User-Agent` header.
    ///
    /// Returns `Skipped` for any step other than `self.plugin_step`. Otherwise
    /// the header is checked against every pattern in `ua_list`:
    /// - with `type = "deny"`, a match yields the 403 `forbidden_resp`;
    /// - with any other type, only a match is allowed through.
    ///
    /// A request without a `User-Agent` header matches nothing.
    #[inline]
    async fn handle_request(
        &self,
        step: PluginStep,
        session: &mut Session,
        _ctx: &mut Ctx,
    ) -> pingora::Result<RequestPluginResult> {
        if step != self.plugin_step {
            return Ok(RequestPluginResult::Skipped);
        }
        // `HeaderValue::to_str` rejects any byte outside visible ASCII; falling
        // back to "" would let a client dodge a deny list just by appending
        // such a byte. Decode lossily so the patterns still see the header.
        let found = session
            .get_header(http::header::USER_AGENT)
            .is_some_and(|value| {
                let ua = String::from_utf8_lossy(value.as_bytes());
                self.ua_list.iter().any(|pattern| pattern.is_match(&ua))
            });
        let allow = if self.restriction_category == "deny" {
            !found
        } else {
            found
        };
        if allow {
            Ok(RequestPluginResult::Continue)
        } else {
            Ok(RequestPluginResult::Respond(self.forbidden_resp.clone()))
        }
    }
}
/// Registers the `ua_restriction` plugin constructor with the plugin factory
/// at program start-up (via `#[ctor]`).
#[ctor]
fn init() {
    get_plugin_factory().register("ua_restriction", |params| {
        let plugin = UaRestriction::new(params)?;
        Ok(Arc::new(plugin))
    });
}
#[cfg(test)]
mod tests {
    use super::*;
    use http::StatusCode;
    use pingap_config::PluginConf;
    use pingap_core::{Ctx, PluginStep};
    use pingora::proxy::Session;
    use pretty_assertions::assert_eq;
    use tokio_test::io::Builder;

    /// Parses `conf` as TOML and builds a plugin from it, panicking on error.
    fn new_plugin(conf: &str) -> UaRestriction {
        UaRestriction::new(&toml::from_str::<PluginConf>(conf).unwrap())
            .unwrap()
    }

    /// Builds an HTTP/1.1 session whose request carries `user_agent` as its
    /// `User-Agent` header, or no such header when `None`.
    async fn new_session(user_agent: Option<&str>) -> Session {
        let mut input_header =
            "GET /vicanso/pingap?size=1 HTTP/1.1\r\n".to_string();
        if let Some(ua) = user_agent {
            input_header.push_str(&format!("User-Agent: {ua}\r\n"));
        }
        input_header.push_str("\r\n");
        let mock_io = Builder::new().read(input_header.as_bytes()).build();
        let mut session = Session::new_h1(Box::new(mock_io));
        session.read_request().await.unwrap();
        session
    }

    /// Runs the plugin's request step for a request with `user_agent`.
    async fn run(
        plugin: &UaRestriction,
        user_agent: Option<&str>,
    ) -> RequestPluginResult {
        let mut session = new_session(user_agent).await;
        plugin
            .handle_request(
                PluginStep::Request,
                &mut session,
                &mut Ctx::default(),
            )
            .await
            .unwrap()
    }

    /// Asserts that `result` is a 403 response and returns its body.
    fn expect_forbidden(result: RequestPluginResult) -> Bytes {
        let RequestPluginResult::Respond(resp) = result else {
            panic!("result is not Respond");
        };
        assert_eq!(StatusCode::FORBIDDEN, resp.status);
        resp.body.clone()
    }

    #[test]
    fn test_ua_restriction_params() {
        let params = UaRestriction::try_from(
            &toml::from_str::<PluginConf>(
                r###"
ua_list = [
    "go-http-client/1.1", # Blocks/allows exact UA string
    "(Twitterspider)/(\\d+)\\.(\\d+)" # Blocks/allows Twitterspider with version numbers
]
type = "deny" # This config will block these user agents
"###,
            )
            .unwrap(),
        )
        .unwrap();
        assert_eq!("request", params.plugin_step.to_string());
        assert_eq!(
            r#"go-http-client/1.1,(Twitterspider)/(\d+)\.(\d+)"#,
            params
                .ua_list
                .iter()
                .map(|item| item.to_string())
                .collect::<Vec<String>>()
                .join(",")
        );
        assert_eq!("deny", params.restriction_category);
        // No `message` configured: the default body is used.
        assert_eq!(StatusCode::FORBIDDEN, params.forbidden_resp.status);
        assert_eq!(
            Bytes::from("Request is forbidden"),
            params.forbidden_resp.body
        );
    }

    #[test]
    fn test_ua_restriction_invalid_regex() {
        let result = UaRestriction::try_from(
            &toml::from_str::<PluginConf>(r#"ua_list = ["(unclosed"]"#)
                .unwrap(),
        );
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_ua_restriction_deny() {
        let deny = new_plugin(
            r###"
ua_list = [
    "go-http-client/1.1",
    "(Twitterspider)/(\\d+)\\.(\\d+)"
]
type = "deny"
"###,
        );
        assert!(matches!(
            run(&deny, Some("pingap/1.0")).await,
            RequestPluginResult::Continue
        ));
        expect_forbidden(run(&deny, Some("go-http-client/1.1")).await);
        expect_forbidden(run(&deny, Some("Twitterspider/1.1")).await);
        // No User-Agent header: nothing matches, so a deny list lets it pass.
        assert!(matches!(
            run(&deny, None).await,
            RequestPluginResult::Continue
        ));
    }

    #[tokio::test]
    async fn test_ua_restriction_allow() {
        let allow = new_plugin(
            r###"
ua_list = ["^pingap/\\d+\\.\\d+$"]
type = "allow"
message = "Bad user agent"
"###,
        );
        assert!(matches!(
            run(&allow, Some("pingap/1.0")).await,
            RequestPluginResult::Continue
        ));
        assert_eq!(
            Bytes::from("Bad user agent"),
            expect_forbidden(run(&allow, Some("curl/8.0")).await)
        );
        // No User-Agent header: nothing matches, so an allow list rejects it.
        expect_forbidden(run(&allow, None).await);
    }
}