use super::{
Error, get_bool_conf, get_hash_key, get_plugin_factory, get_str_conf,
};
use async_trait::async_trait;
use ctor::ctor;
use pingap_config::PluginConf;
use pingap_core::{Ctx, Plugin, PluginStep, RequestPluginResult};
use pingora::proxy::Session;
use smallvec::SmallVec;
use std::borrow::Cow;
use std::sync::Arc;
use tracing::debug;
/// Module-local result alias whose error type defaults to the plugin `Error`.
type Result<T, E = Error> = std::result::Result<T, E>;

/// Plugin that rewrites a request's `Accept-Encoding` header so that it only
/// lists encodings taken from the configured set.
pub struct AcceptEncoding {
    /// Request step at which the plugin acts; any other step is skipped.
    plugin_step: PluginStep,
    /// Configured encodings, in the order they should appear in the
    /// rewritten header.
    encodings: Vec<String>,
    /// When true, the rewritten header keeps only the first matching
    /// configured encoding.
    only_one_encoding: bool,
    /// Hash of the plugin configuration, reported as its config key.
    hash_value: String,
}
impl TryFrom<&PluginConf> for AcceptEncoding {
    type Error = Error;

    /// Builds the plugin from its configuration.
    ///
    /// `encodings` is read as a comma-separated list. Every entry is trimmed,
    /// empty entries are dropped, and the configured order is preserved.
    /// `only_one_encoding` is read as a boolean. The plugin is always bound to
    /// `PluginStep::EarlyRequest`.
    fn try_from(value: &PluginConf) -> Result<Self> {
        let hash_value = get_hash_key(value);
        let only_one_encoding = get_bool_conf(value, "only_one_encoding");
        let encodings: Vec<String> = get_str_conf(value, "encodings")
            .split(',')
            .map(str::trim)
            .filter(|entry| !entry.is_empty())
            .map(str::to_string)
            .collect();
        Ok(Self {
            plugin_step: PluginStep::EarlyRequest,
            hash_value,
            only_one_encoding,
            encodings,
        })
    }
}
impl AcceptEncoding {
    /// Creates the plugin from `params`.
    ///
    /// The raw configuration is logged at debug level first. All parsing is
    /// delegated to the `TryFrom<&PluginConf>` implementation.
    ///
    /// # Errors
    /// Returns whatever error that conversion reports.
    pub fn new(params: &PluginConf) -> Result<Self> {
        debug!(params = params.to_string(), "new accept encoding plugin");
        let plugin = Self::try_from(params)?;
        Ok(plugin)
    }
}
#[async_trait]
impl Plugin for AcceptEncoding {
#[inline]
fn config_key(&self) -> Cow<'_, str> {
Cow::Borrowed(&self.hash_value)
}
#[inline]
async fn handle_request(
&self,
step: PluginStep,
session: &mut Session,
_ctx: &mut Ctx,
) -> pingora::Result<RequestPluginResult> {
if step != self.plugin_step {
return Ok(RequestPluginResult::Skipped);
}
let header = session.req_header_mut();
let Some(accept_encoding) =
header.headers.get(http::header::ACCEPT_ENCODING)
else {
return Ok(RequestPluginResult::Skipped);
};
let accept_encoding = accept_encoding.to_str().unwrap_or_default();
let mut new_accept_encodings: SmallVec<[String; 3]> =
SmallVec::with_capacity(self.encodings.len());
for encoding in self.encodings.iter() {
if self.only_one_encoding && !new_accept_encodings.is_empty() {
break;
}
if accept_encoding.contains(encoding) {
new_accept_encodings.push(encoding.to_string());
}
}
if new_accept_encodings.is_empty() {
header.remove_header(&http::header::ACCEPT_ENCODING);
} else {
let _ = header.insert_header(
http::header::ACCEPT_ENCODING,
new_accept_encodings.join(", "),
);
}
Ok(RequestPluginResult::Continue)
}
}
/// Registers the `accept_encoding` plugin constructor with the plugin
/// factory. It runs at program start-up via `#[ctor]`.
#[ctor]
fn init() {
    get_plugin_factory().register("accept_encoding", |params| {
        let plugin = AcceptEncoding::new(params)?;
        Ok(Arc::new(plugin))
    });
}
#[cfg(test)]
mod tests {
    use super::{AcceptEncoding, Plugin};
    use pingap_config::PluginConf;
    use pingap_core::{Ctx, PluginStep};
    use pingora::modules::http::HttpModules;
    use pingora::proxy::Session;
    use pretty_assertions::assert_eq;
    use tokio_test::io::Builder;

    /// Builds the plugin from an inline TOML configuration.
    fn new_plugin(conf: &str) -> AcceptEncoding {
        AcceptEncoding::try_from(&toml::from_str::<PluginConf>(conf).unwrap())
            .unwrap()
    }

    /// Runs `plugin` at the early-request step against a mock HTTP/1.1
    /// request carrying the given `Accept-Encoding` value, and returns the
    /// session so the rewritten headers can be inspected.
    async fn run_plugin(
        plugin: &AcceptEncoding,
        accept_encoding: &str,
    ) -> Session {
        let input_header = format!(
            "GET /vicanso/pingap?size=1 HTTP/1.1\r\nAccept-Encoding: {accept_encoding}\r\n\r\n"
        );
        let mock_io = Builder::new().read(input_header.as_bytes()).build();
        let mut session = Session::new_h1_with_modules(
            Box::new(mock_io),
            &HttpModules::new(),
        );
        session.read_request().await.unwrap();
        let _ = plugin
            .handle_request(
                PluginStep::EarlyRequest,
                &mut session,
                &mut Ctx::default(),
            )
            .await
            .unwrap();
        session
    }

    /// Reads the request's current `Accept-Encoding` header, if any.
    fn accept_encoding_of(session: &Session) -> Option<&str> {
        session
            .req_header()
            .headers
            .get("Accept-Encoding")
            .map(|value| value.to_str().unwrap())
    }

    #[test]
    fn test_accept_encoding_params() {
        let params = new_plugin(
            r###"
encodings = "zstd, br, gzip"
only_one_encoding = true
"###,
        );
        assert_eq!("zstd,br,gzip", params.encodings.join(","));
        assert!(params.only_one_encoding);
    }

    #[tokio::test]
    async fn test_accept_encoding_only_one() {
        let plugin = new_plugin(
            r###"
encodings = "zstd, br, gzip"
only_one_encoding = true
"###,
        );
        let session = run_plugin(&plugin, "gzip, zstd").await;
        assert_eq!(Some("zstd"), accept_encoding_of(&session));
    }

    #[tokio::test]
    async fn test_accept_encoding() {
        let plugin = new_plugin(
            r###"
encodings = "zstd, br, gzip"
"###,
        );
        // Every configured encoding is accepted; configuration order wins.
        let session = run_plugin(&plugin, "gzip, br, zstd").await;
        assert_eq!(Some("zstd, br, gzip"), accept_encoding_of(&session));

        // No configured encoding is accepted, so the header is removed.
        let session = run_plugin(&plugin, "snappy").await;
        assert_eq!(None, accept_encoding_of(&session));
    }
}