pub mod compression;
pub mod grpc_web;
use async_trait::async_trait;
use bytes::Bytes;
use http::HeaderMap;
use once_cell::sync::OnceCell;
use pingora_error::Result;
use pingora_http::{RequestHeader, ResponseHeader};
use std::any::Any;
use std::any::TypeId;
use std::collections::HashMap;
use std::sync::Arc;
/// A per-request HTTP module: a bundle of filter hooks run against requests and responses.
///
/// Every hook has a no-op default, so an implementor only overrides the hooks it needs.
/// `as_any` / `as_any_mut` are mandatory so that a boxed module can be downcast back to
/// its concrete type (see `HttpModuleCtx::get` / `HttpModuleCtx::get_mut`).
#[async_trait]
pub trait HttpModule {
    /// Inspect or modify the request header. Default: no-op.
    async fn request_header_filter(&mut self, _request: &mut RequestHeader) -> Result<()> {
        Ok(())
    }

    /// Inspect or modify a piece of the request body.
    ///
    /// `_eos` is the end-of-stream flag passed by the caller. Default: no-op.
    async fn request_body_filter(
        &mut self,
        _chunk: &mut Option<Bytes>,
        _eos: bool,
    ) -> Result<()> {
        Ok(())
    }

    /// Inspect or modify the response header.
    ///
    /// `_eos` is the end-of-stream flag passed by the caller. Default: no-op.
    async fn response_header_filter(
        &mut self,
        _response: &mut ResponseHeader,
        _eos: bool,
    ) -> Result<()> {
        Ok(())
    }

    /// Inspect or modify a piece of the response body (synchronous hook).
    ///
    /// `_eos` is the end-of-stream flag passed by the caller. Default: no-op.
    fn response_body_filter(
        &mut self,
        _chunk: &mut Option<Bytes>,
        _eos: bool,
    ) -> Result<()> {
        Ok(())
    }

    /// Inspect or modify the response trailers (synchronous hook).
    ///
    /// A module may return `Some(bytes)` to hand an encoded form of the trailers back to the
    /// caller. Default: leaves the trailers alone and returns `None`.
    fn response_trailer_filter(
        &mut self,
        _trailer_map: &mut Option<Box<HeaderMap>>,
    ) -> Result<Option<Bytes>> {
        Ok(None)
    }

    /// Borrow this module as `&dyn Any` so it can be downcast to its concrete type.
    fn as_any(&self) -> &dyn Any;

    /// Borrow this module as `&mut dyn Any` so it can be downcast to its concrete type.
    fn as_any_mut(&mut self) -> &mut dyn Any;
}
/// An owned, thread-shareable, type-erased [`HttpModule`] instance.
pub type Module = Box<dyn HttpModule + Send + Sync + 'static>;
/// A factory that produces a fresh [`Module`] instance each time `init` is called.
pub trait HttpModuleBuilder {
    /// Execution priority of the modules this builder creates.
    ///
    /// Modules are sorted by descending order, so a higher value runs earlier.
    /// Defaults to `0`.
    fn order(&self) -> i16 {
        0
    }

    /// Create a new module instance.
    fn init(&self) -> Module;
}
/// An owned, thread-shareable, type-erased [`HttpModuleBuilder`].
pub type ModuleBuilder = Box<dyn HttpModuleBuilder + Send + Sync + 'static>;
pub struct HttpModules {
modules: Vec<ModuleBuilder>,
module_index: OnceCell<Arc<HashMap<TypeId, usize>>>,
}
impl HttpModules {
pub fn new() -> Self {
HttpModules {
modules: vec![],
module_index: OnceCell::new(),
}
}
pub fn add_module(&mut self, builder: ModuleBuilder) {
if self.module_index.get().is_some() {
panic!("cannot add module after ctx is already built")
}
self.modules.push(builder);
self.modules.sort_by_key(|m| -m.order());
}
pub fn build_ctx(&self) -> HttpModuleCtx {
let module_ctx: Vec<_> = self.modules.iter().map(|b| b.init()).collect();
let module_index = self
.module_index
.get_or_init(|| {
let mut module_index = HashMap::with_capacity(self.modules.len());
for (i, c) in module_ctx.iter().enumerate() {
let exist = module_index.insert(c.as_any().type_id(), i);
if exist.is_some() {
panic!("duplicated filters found")
}
}
Arc::new(module_index)
})
.clone();
HttpModuleCtx {
module_ctx,
module_index,
}
}
}
/// The per-request instances of every registered module, stored in execution order.
///
/// Each filter method runs the matching [`HttpModule`] hook on every module in turn and
/// stops at (and returns) the first error.
pub struct HttpModuleCtx {
    module_ctx: Vec<Module>,
    // Shared `TypeId` -> position in `module_ctx` lookup table.
    module_index: Arc<HashMap<TypeId, usize>>,
}

impl HttpModuleCtx {
    /// Create a ctx with no modules: every filter is a no-op and every lookup returns `None`.
    pub fn empty() -> Self {
        HttpModuleCtx {
            module_ctx: vec![],
            module_index: Arc::new(HashMap::new()),
        }
    }

    /// Borrow the module instance whose concrete type is `T`, if one was registered.
    pub fn get<T: 'static>(&self) -> Option<&T> {
        let idx = self.module_index.get(&TypeId::of::<T>())?;
        let ctx = &self.module_ctx[*idx];
        Some(
            ctx.as_any()
                .downcast_ref::<T>()
                .expect("type should always match"),
        )
    }

    /// Mutably borrow the module instance whose concrete type is `T`, if one was registered.
    pub fn get_mut<T: 'static>(&mut self) -> Option<&mut T> {
        let idx = self.module_index.get(&TypeId::of::<T>())?;
        let ctx = &mut self.module_ctx[*idx];
        Some(
            ctx.as_any_mut()
                .downcast_mut::<T>()
                .expect("type should always match"),
        )
    }

    /// Run every module's `request_header_filter` in order.
    pub async fn request_header_filter(&mut self, req: &mut RequestHeader) -> Result<()> {
        for filter in self.module_ctx.iter_mut() {
            filter.request_header_filter(req).await?;
        }
        Ok(())
    }

    /// Run every module's `request_body_filter` in order.
    pub async fn request_body_filter(
        &mut self,
        body: &mut Option<Bytes>,
        end_of_stream: bool,
    ) -> Result<()> {
        for filter in self.module_ctx.iter_mut() {
            filter.request_body_filter(body, end_of_stream).await?;
        }
        Ok(())
    }

    /// Run every module's `response_header_filter` in order.
    pub async fn response_header_filter(
        &mut self,
        resp: &mut ResponseHeader,
        end_of_stream: bool,
    ) -> Result<()> {
        for filter in self.module_ctx.iter_mut() {
            filter.response_header_filter(resp, end_of_stream).await?;
        }
        Ok(())
    }

    /// Run every module's `response_body_filter` in order.
    pub fn response_body_filter(
        &mut self,
        body: &mut Option<Bytes>,
        end_of_stream: bool,
    ) -> Result<()> {
        for filter in self.module_ctx.iter_mut() {
            filter.response_body_filter(body, end_of_stream)?;
        }
        Ok(())
    }

    /// Run every module's `response_trailer_filter` in order.
    ///
    /// Each module sees `trailers` as left by the modules before it. If more than one module
    /// returns encoded bytes, the bytes from the last such module are returned and earlier
    /// ones are discarded.
    pub fn response_trailer_filter(
        &mut self,
        trailers: &mut Option<Box<HeaderMap>>,
    ) -> Result<Option<Bytes>> {
        let mut encoded = None;
        for filter in self.module_ctx.iter_mut() {
            if let Some(buf) = filter.response_trailer_filter(trailers)? {
                encoded = Some(buf);
            }
        }
        Ok(encoded)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Stamps `my-filter: 1` on every request header.
    struct MyModule;

    #[async_trait]
    impl HttpModule for MyModule {
        async fn request_header_filter(&mut self, header: &mut RequestHeader) -> Result<()> {
            header.insert_header("my-filter", "1")
        }

        fn as_any(&self) -> &dyn Any {
            self
        }

        fn as_any_mut(&mut self) -> &mut dyn Any {
            self
        }
    }

    /// High-priority builder for [`MyModule`].
    struct MyModuleBuilder;

    impl HttpModuleBuilder for MyModuleBuilder {
        fn init(&self) -> Module {
            Box::new(MyModule)
        }

        fn order(&self) -> i16 {
            1
        }
    }

    /// Overwrites `my-filter` with "2" if it is already present, otherwise stamps
    /// `my-other-filter: 1`.
    struct MyOtherModule;

    #[async_trait]
    impl HttpModule for MyOtherModule {
        async fn request_header_filter(&mut self, header: &mut RequestHeader) -> Result<()> {
            let (name, value) = match header.headers.get("my-filter") {
                Some(_) => ("my-filter", "2"),
                None => ("my-other-filter", "1"),
            };
            header.insert_header(name, value)
        }

        fn as_any(&self) -> &dyn Any {
            self
        }

        fn as_any_mut(&mut self) -> &mut dyn Any {
            self
        }
    }

    /// Low-priority builder for [`MyOtherModule`].
    struct MyOtherModuleBuilder;

    impl HttpModuleBuilder for MyOtherModuleBuilder {
        fn init(&self) -> Module {
            Box::new(MyOtherModule)
        }

        fn order(&self) -> i16 {
            -1
        }
    }

    #[test]
    fn test_module_get() {
        let mut registry = HttpModules::new();
        registry.add_module(Box::new(MyModuleBuilder));
        registry.add_module(Box::new(MyOtherModuleBuilder));
        let mut modules_ctx = registry.build_ctx();

        // Registered module types resolve; an unregistered type does not.
        assert!(modules_ctx.get::<MyModule>().is_some());
        assert!(modules_ctx.get::<MyOtherModule>().is_some());
        assert!(modules_ctx.get::<usize>().is_none());

        assert!(modules_ctx.get_mut::<MyModule>().is_some());
        assert!(modules_ctx.get_mut::<MyOtherModule>().is_some());
        assert!(modules_ctx.get_mut::<usize>().is_none());
    }

    #[tokio::test]
    async fn test_module_filter() {
        // Register the low-priority module first to show order() decides execution order.
        let mut registry = HttpModules::new();
        registry.add_module(Box::new(MyOtherModuleBuilder));
        registry.add_module(Box::new(MyModuleBuilder));
        let mut modules_ctx = registry.build_ctx();

        let mut header = RequestHeader::build("Get", b"/", None).unwrap();
        modules_ctx.request_header_filter(&mut header).await.unwrap();

        // MyModule (order 1) ran first, so MyOtherModule (order -1) saw `my-filter`.
        assert_eq!(header.headers.get("my-filter").unwrap(), "2");
        assert!(header.headers.get("my-other-filter").is_none());
    }
}