pub mod config;
pub mod error;
pub mod grpc;
pub mod prelude;
pub mod tls;
pub mod websocket;
use arc_swap::ArcSwap;
use bytes::Bytes;
use dashmap::DashMap;
use pingora::Error;
use pingora::http::ResponseHeader;
use pingora::proxy::Session;
use pingora::server::Server;
use std::any::{Any, TypeId};
use std::collections::HashMap;
use std::sync::Arc;
/// Heterogeneous storage holding at most one value per concrete type.
///
/// All operations take `&self`, so implementors provide their own interior
/// mutability; the `Send + Sync` supertraits let a store be shared across
/// threads. Values are handed back as `Arc<T>` so callers can keep them alive
/// independently of the store.
pub trait Context: Send + Sync {
    /// Stores `value`, keyed by its concrete type `T`.
    fn insert<T>(&self, value: T)
    where
        T: Any + Send + Sync;

    /// Returns a shared handle to the value stored for `T`, or `None`.
    fn get<T>(&self) -> Option<Arc<T>>
    where
        T: Any + Send + Sync;

    /// Removes the value stored for `T` and returns it, or `None`.
    fn remove<T>(&self) -> Option<Arc<T>>
    where
        T: Any + Send + Sync;
}
/// Application-wide, type-keyed storage.
///
/// Entries are keyed by the [`TypeId`] of the stored value. The map lives
/// behind an [`ArcSwap`], so readers obtain a whole-map snapshot. Cloning an
/// `AppContext` clones only the outer `Arc`, which means every clone observes
/// and mutates the same underlying map.
#[derive(Clone)]
pub struct AppContext {
    /// Atomically swappable snapshot of the type-keyed map, shared by all clones.
    data: Arc<ArcSwap<HashMap<TypeId, Arc<dyn Any + Send + Sync>>>>,
}

impl AppContext {
    /// Creates a context with no stored values.
    pub fn new() -> Self {
        let empty: HashMap<TypeId, Arc<dyn Any + Send + Sync>> = HashMap::new();
        let swap = ArcSwap::from_pointee(empty);
        Self {
            data: Arc::new(swap),
        }
    }
}

impl Default for AppContext {
    /// Same as [`AppContext::new`]: an empty context.
    fn default() -> Self {
        AppContext::new()
    }
}
impl Context for AppContext {
    /// Stores `value` under the type `T`, replacing any previous value of that type.
    ///
    /// The current map is copied, updated, and atomically published via
    /// `ArcSwap::rcu`. `rcu` may invoke its closure more than once when racing
    /// with other writers, so the erased value is held in an `Arc` and cloned
    /// per attempt instead of being moved into the map.
    fn insert<T: Any + Send + Sync>(&self, value: T) {
        let key = TypeId::of::<T>();
        let value = Arc::new(value) as Arc<dyn Any + Send + Sync>;
        self.data.rcu(move |old| {
            let mut next = (**old).clone();
            next.insert(key, Arc::clone(&value));
            next
        });
    }

    /// Returns the value stored for `T` in the current snapshot, if any.
    fn get<T: Any + Send + Sync>(&self) -> Option<Arc<T>> {
        let snapshot = self.data.load();
        let value = snapshot.get(&TypeId::of::<T>()).cloned()?;
        Arc::downcast::<T>(value).ok()
    }

    /// Removes the value stored for `T` and returns it, or `None` if absent.
    ///
    /// When no value of type `T` is present, this returns early without
    /// copying or republishing the map. Otherwise it performs the same
    /// copy-and-swap as `insert`.
    fn remove<T: Any + Send + Sync>(&self) -> Option<Arc<T>> {
        let key = TypeId::of::<T>();
        // Fast path: a read-only snapshot check avoids an O(n) map clone and a
        // spurious atomic write (which could force concurrent writers to retry)
        // when there is nothing to remove.
        if !self.data.load().contains_key(&key) {
            return None;
        }
        let mut removed: Option<Arc<dyn Any + Send + Sync>> = None;
        self.data.rcu(|old| {
            let mut next = (**old).clone();
            // Overwritten on every rcu attempt, so only the result of the
            // attempt that was actually published survives.
            removed = next.remove(&key);
            next
        });
        removed.and_then(|value| Arc::downcast::<T>(value).ok())
    }
}
/// Type-keyed storage backed by a concurrent [`DashMap`].
///
/// Entries are keyed by the [`TypeId`] of the stored value. Cloning a
/// `RequestContext` clones only the outer `Arc`, so all clones share the same
/// map.
#[derive(Clone)]
pub struct RequestContext {
    /// Concurrent type-keyed map, shared by all clones of this handle.
    data: Arc<DashMap<TypeId, Arc<dyn Any + Send + Sync>>>,
}

impl RequestContext {
    /// Creates a context with no stored values.
    pub fn new() -> Self {
        let map: DashMap<TypeId, Arc<dyn Any + Send + Sync>> = DashMap::new();
        Self { data: Arc::new(map) }
    }
}

impl Default for RequestContext {
    /// Same as [`RequestContext::new`]: an empty context.
    fn default() -> Self {
        RequestContext::new()
    }
}
impl Context for RequestContext {
    /// Stores `value` under the type `T`; a previous value of that type is dropped.
    fn insert<T: Any + Send + Sync>(&self, value: T) {
        let erased: Arc<dyn Any + Send + Sync> = Arc::new(value);
        self.data.insert(TypeId::of::<T>(), erased);
    }

    /// Returns the value stored for `T`, if any.
    ///
    /// The shard guard returned by `DashMap::get` is released as soon as the
    /// `Arc` has been cloned out of it.
    fn get<T: Any + Send + Sync>(&self) -> Option<Arc<T>> {
        let erased = self
            .data
            .get(&TypeId::of::<T>())
            .map(|entry| Arc::clone(entry.value()))?;
        Arc::downcast::<T>(erased).ok()
    }

    /// Removes the value stored for `T` and returns it, or `None` if absent.
    fn remove<T: Any + Send + Sync>(&self) -> Option<Arc<T>> {
        self.data
            .remove(&TypeId::of::<T>())
            .and_then(|(_, erased)| Arc::downcast::<T>(erased).ok())
    }
}
#[async_trait::async_trait]
pub trait JokowayMiddleware: Send + Sync {
type CTX: Send + Sync + 'static;
fn name(&self) -> &'static str;
fn new_ctx(&self) -> Self::CTX;
fn order(&self) -> i16 {
0
}
async fn request_filter(
&self,
_session: &mut Session,
_ctx: &mut Self::CTX,
_app_ctx: &AppContext,
_request_ctx: &RequestContext,
) -> Result<bool, Box<Error>> {
Ok(false)
}
async fn upstream_response_filter(
&self,
_session: &mut Session,
_upstream_response: &mut ResponseHeader,
_ctx: &mut Self::CTX,
_app_ctx: &AppContext,
_request_ctx: &RequestContext,
) -> Result<(), Box<Error>> {
Ok(())
}
fn response_body_filter(
&self,
_session: &mut Session,
_body: &mut Option<Bytes>,
_end_of_stream: bool,
_ctx: &mut Self::CTX,
_app_ctx: &AppContext,
_request_ctx: &RequestContext,
) -> Result<Option<std::time::Duration>, Box<Error>> {
Ok(None)
}
async fn request_body_filter(
&self,
_session: &mut Session,
_body: &mut Option<Bytes>,
_end_of_stream: bool,
_ctx: &mut Self::CTX,
_app_ctx: &AppContext,
_request_ctx: &RequestContext,
) -> Result<(), Box<Error>> {
Ok(())
}
fn on_websocket_message(
&self,
_direction: crate::websocket::WebsocketDirection,
frame: crate::websocket::WsFrame,
_ctx: &mut Self::CTX,
_app_ctx: &AppContext,
_request_ctx: &RequestContext,
) -> crate::websocket::WebsocketMessageAction {
crate::websocket::WebsocketMessageAction::Forward(frame)
}
fn on_websocket_error(
&self,
_direction: crate::websocket::WebsocketDirection,
_error: crate::websocket::WebsocketError,
_ctx: &mut Self::CTX,
_app_ctx: &AppContext,
_request_ctx: &RequestContext,
) -> crate::websocket::WebsocketErrorAction {
crate::websocket::WebsocketErrorAction::PassThrough
}
fn on_grpc_message(
&self,
_direction: crate::grpc::GrpcDirection,
message: crate::grpc::GrpcMessage,
_ctx: &mut Self::CTX,
_app_ctx: &AppContext,
_request_ctx: &RequestContext,
) -> crate::grpc::GrpcMessageAction {
crate::grpc::GrpcMessageAction::Forward(message)
}
}
/// Type-erased counterpart of [`JokowayMiddleware`].
///
/// The associated context type is replaced by `dyn Any + Send + Sync`, so
/// middlewares with different context types can be stored together as
/// `Arc<dyn JokowayMiddlewareDyn>`. Implementations are expected to downcast
/// `ctx` back to the concrete type produced by [`new_ctx_dyn`](Self::new_ctx_dyn).
#[async_trait::async_trait]
pub trait JokowayMiddlewareDyn: Send + Sync {
    /// Identifier for this middleware.
    fn name(&self) -> &'static str;

    /// Ordering key relative to other middlewares; defaults to `0`.
    fn order(&self) -> i16 {
        0
    }

    /// Creates a fresh, boxed, type-erased context value.
    fn new_ctx_dyn(&self) -> Box<dyn Any + Send + Sync>;

    /// Type-erased form of [`JokowayMiddleware::request_filter`].
    async fn request_filter_dyn(
        &self,
        session: &mut Session,
        ctx: &mut (dyn Any + Send + Sync),
        app_ctx: &AppContext,
        request_ctx: &RequestContext,
    ) -> Result<bool, Box<Error>>;

    /// Type-erased form of [`JokowayMiddleware::upstream_response_filter`].
    async fn upstream_response_filter_dyn(
        &self,
        session: &mut Session,
        upstream_response: &mut ResponseHeader,
        ctx: &mut (dyn Any + Send + Sync),
        app_ctx: &AppContext,
        request_ctx: &RequestContext,
    ) -> Result<(), Box<Error>>;

    /// Type-erased form of [`JokowayMiddleware::response_body_filter`].
    fn response_body_filter_dyn(
        &self,
        session: &mut Session,
        body: &mut Option<Bytes>,
        end_of_stream: bool,
        ctx: &mut (dyn Any + Send + Sync),
        app_ctx: &AppContext,
        request_ctx: &RequestContext,
    ) -> Result<Option<std::time::Duration>, Box<Error>>;

    /// Type-erased form of [`JokowayMiddleware::request_body_filter`].
    async fn request_body_filter_dyn(
        &self,
        session: &mut Session,
        body: &mut Option<Bytes>,
        end_of_stream: bool,
        ctx: &mut (dyn Any + Send + Sync),
        app_ctx: &AppContext,
        request_ctx: &RequestContext,
    ) -> Result<(), Box<Error>>;

    /// Type-erased form of [`JokowayMiddleware::on_websocket_message`].
    fn on_websocket_message_dyn(
        &self,
        direction: crate::websocket::WebsocketDirection,
        frame: crate::websocket::WsFrame,
        ctx: &mut (dyn Any + Send + Sync),
        app_ctx: &AppContext,
        request_ctx: &RequestContext,
    ) -> crate::websocket::WebsocketMessageAction;

    /// Type-erased form of [`JokowayMiddleware::on_websocket_error`].
    fn on_websocket_error_dyn(
        &self,
        direction: crate::websocket::WebsocketDirection,
        error: crate::websocket::WebsocketError,
        ctx: &mut (dyn Any + Send + Sync),
        app_ctx: &AppContext,
        request_ctx: &RequestContext,
    ) -> crate::websocket::WebsocketErrorAction;

    /// Type-erased form of [`JokowayMiddleware::on_grpc_message`].
    fn on_grpc_message_dyn(
        &self,
        direction: crate::grpc::GrpcDirection,
        message: crate::grpc::GrpcMessage,
        ctx: &mut (dyn Any + Send + Sync),
        app_ctx: &AppContext,
        request_ctx: &RequestContext,
    ) -> crate::grpc::GrpcMessageAction;
}
#[async_trait::async_trait]
impl<T: JokowayMiddleware> JokowayMiddlewareDyn for T {
fn name(&self) -> &'static str {
JokowayMiddleware::name(self)
}
fn order(&self) -> i16 {
JokowayMiddleware::order(self)
}
fn new_ctx_dyn(&self) -> Box<dyn Any + Send + Sync> {
Box::new(self.new_ctx())
}
async fn request_filter_dyn(
&self,
session: &mut Session,
ctx: &mut (dyn Any + Send + Sync),
app_ctx: &AppContext,
request_ctx: &RequestContext,
) -> Result<bool, Box<Error>> {
let ctx = ctx.downcast_mut::<T::CTX>().ok_or_else(|| {
Error::explain(pingora::ErrorType::InternalError, "Invalid context type")
})?;
self.request_filter(session, ctx, app_ctx, request_ctx)
.await
}
async fn upstream_response_filter_dyn(
&self,
session: &mut Session,
upstream_response: &mut ResponseHeader,
ctx: &mut (dyn Any + Send + Sync),
app_ctx: &AppContext,
request_ctx: &RequestContext,
) -> Result<(), Box<Error>> {
let ctx = ctx.downcast_mut::<T::CTX>().ok_or_else(|| {
Error::explain(pingora::ErrorType::InternalError, "Invalid context type")
})?;
self.upstream_response_filter(session, upstream_response, ctx, app_ctx, request_ctx)
.await
}
fn response_body_filter_dyn(
&self,
session: &mut Session,
body: &mut Option<Bytes>,
end_of_stream: bool,
ctx: &mut (dyn Any + Send + Sync),
app_ctx: &AppContext,
request_ctx: &RequestContext,
) -> Result<Option<std::time::Duration>, Box<Error>> {
let ctx = ctx.downcast_mut::<T::CTX>().ok_or_else(|| {
Error::explain(pingora::ErrorType::InternalError, "Invalid context type")
})?;
self.response_body_filter(session, body, end_of_stream, ctx, app_ctx, request_ctx)
}
async fn request_body_filter_dyn(
&self,
session: &mut Session,
body: &mut Option<Bytes>,
end_of_stream: bool,
ctx: &mut (dyn Any + Send + Sync),
app_ctx: &AppContext,
request_ctx: &RequestContext,
) -> Result<(), Box<Error>> {
let ctx = ctx.downcast_mut::<T::CTX>().ok_or_else(|| {
Error::explain(pingora::ErrorType::InternalError, "Invalid context type")
})?;
self.request_body_filter(session, body, end_of_stream, ctx, app_ctx, request_ctx)
.await
}
fn on_websocket_message_dyn(
&self,
direction: crate::websocket::WebsocketDirection,
frame: crate::websocket::WsFrame,
ctx: &mut (dyn Any + Send + Sync),
app_ctx: &AppContext,
request_ctx: &RequestContext,
) -> crate::websocket::WebsocketMessageAction {
let ctx = ctx
.downcast_mut::<T::CTX>()
.expect("Invalid context type for JokowayMiddleware");
self.on_websocket_message(direction, frame, ctx, app_ctx, request_ctx)
}
fn on_websocket_error_dyn(
&self,
direction: crate::websocket::WebsocketDirection,
error: crate::websocket::WebsocketError,
ctx: &mut (dyn Any + Send + Sync),
app_ctx: &AppContext,
request_ctx: &RequestContext,
) -> crate::websocket::WebsocketErrorAction {
let ctx = ctx
.downcast_mut::<T::CTX>()
.expect("Invalid context type for JokowayMiddleware");
self.on_websocket_error(direction, error, ctx, app_ctx, request_ctx)
}
fn on_grpc_message_dyn(
&self,
direction: crate::grpc::GrpcDirection,
message: crate::grpc::GrpcMessage,
ctx: &mut (dyn Any + Send + Sync),
app_ctx: &AppContext,
request_ctx: &RequestContext,
) -> crate::grpc::GrpcMessageAction {
let ctx = ctx
.downcast_mut::<T::CTX>()
.expect("Invalid context type for JokowayMiddleware");
self.on_grpc_message(direction, message, ctx, app_ctx, request_ctx)
}
}
/// A plugin hook with mutable access to the pingora [`Server`], the shared
/// [`AppContext`] and the list of type-erased middlewares.
///
/// NOTE(review): presumably invoked once during server setup — confirm with
/// the caller.
pub trait JokowayExtension: Send + Sync {
    /// Ordering key relative to other extensions; defaults to `0`.
    /// NOTE(review): confirm with the host whether lower values run first.
    fn order(&self) -> i16 {
        0
    }

    /// Lets the extension configure the server, seed the application context,
    /// or add entries to `_middlewares`. The default does nothing.
    ///
    /// # Errors
    /// Implementations report setup failures as a boxed `std::error::Error`.
    fn init(
        &self,
        _server: &mut Server,
        _app_ctx: &mut AppContext,
        _middlewares: &mut Vec<Arc<dyn JokowayMiddlewareDyn>>,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        Ok(())
    }
}