cardinal_plugins/runner/
mod.rs1use crate::container::PluginContainer;
2use crate::request_context::RequestContext;
3use async_trait::async_trait;
4use cardinal_base::context::CardinalContext;
5use cardinal_errors::CardinalError;
6use pingora::http::ResponseHeader;
7use pingora::proxy::Session;
8use std::collections::HashMap;
9use std::sync::Arc;
10
11#[derive(Debug, Clone, PartialEq, Eq)]
12pub enum MiddlewareResult {
13 Continue(HashMap<String, String>),
14 Responded,
15}
16
17#[async_trait]
18pub trait RequestMiddleware: Send + Sync + 'static {
19 async fn on_request(
20 &self,
21 session: &mut Session,
22 req_ctx: &mut RequestContext,
23 cardinal: Arc<CardinalContext>,
24 ) -> Result<MiddlewareResult, CardinalError>;
25}
26
27#[async_trait]
28pub trait ResponseMiddleware: Send + Sync + 'static {
29 async fn on_response(
30 &self,
31 session: &mut Session,
32 req_ctx: &mut RequestContext,
33 response: &mut ResponseHeader,
34 cardinal: Arc<CardinalContext>,
35 );
36}
37
/// Trait-object form of [`RequestMiddleware`].
///
/// `Send + Sync + 'static` repeat the trait's supertraits. They are kept because auto
/// traits listed on a `dyn` type are part of that type. Removing them could change the
/// type that existing `Box<..>`/`Arc<..>` holders name.
pub type DynRequestMiddleware = dyn RequestMiddleware + Send + Sync + 'static;
/// Trait-object form of [`ResponseMiddleware`].
///
/// The explicit `Send + Sync + 'static` bounds are kept for the same type-identity
/// reason as on `DynRequestMiddleware`.
pub type DynResponseMiddleware = dyn ResponseMiddleware + Send + Sync + 'static;
40
41#[derive(Clone)]
42pub struct PluginRunner {
43 context: Arc<CardinalContext>,
44 global_request: Arc<Vec<String>>,
45 global_response: Arc<Vec<String>>,
46}
47
48impl PluginRunner {
49 pub fn new(context: Arc<CardinalContext>) -> Self {
50 let global_request = context.config.server.global_request_middleware.clone();
51 let global_response = context.config.server.global_response_middleware.clone();
52
53 Self {
54 context,
55 global_request: Arc::new(global_request),
56 global_response: Arc::new(global_response),
57 }
58 }
59
60 fn global_request_filters(&self) -> &[String] {
61 &self.global_request
62 }
63
64 fn global_response_filters(&self) -> &[String] {
65 &self.global_response
66 }
67
68 pub async fn run_request_filters(
69 &self,
70 session: &mut Session,
71 req_ctx: &mut RequestContext,
72 ) -> Result<MiddlewareResult, CardinalError> {
73 let filter_container = self.context.get::<PluginContainer>().await?;
74 let mut resp_headers = HashMap::new();
75
76 for filter in self.global_request_filters() {
77 let run = filter_container
78 .run_request_filter(filter, session, req_ctx, self.context.clone())
79 .await?;
80
81 match run {
82 MiddlewareResult::Continue(middleware_resp_headers) => {
83 resp_headers.extend(middleware_resp_headers)
84 }
85 MiddlewareResult::Responded => return Ok(MiddlewareResult::Responded),
86 }
87 }
88
89 let backend = req_ctx.backend.clone(); let inbound_middleware = backend.get_inbound_middleware();
91 for middleware in inbound_middleware {
92 let run = filter_container
93 .run_request_filter(&middleware.name, session, req_ctx, self.context.clone())
94 .await?;
95
96 match run {
97 MiddlewareResult::Continue(middleware_resp_headers) => {
98 resp_headers.extend(middleware_resp_headers)
99 }
100 MiddlewareResult::Responded => return Ok(MiddlewareResult::Responded),
101 }
102 }
103
104 Ok(MiddlewareResult::Continue(resp_headers))
105 }
106
107 pub async fn run_response_filters(
108 &self,
109 session: &mut Session,
110 req_ctx: &mut RequestContext,
111 response: &mut ResponseHeader,
112 ) {
113 let filter_container = self.context.get::<PluginContainer>().await.unwrap();
114
115 for filter in self.global_response_filters() {
116 filter_container
117 .run_response_filter(filter, session, req_ctx, response, self.context.clone())
118 .await;
119 }
120
121 let backend = req_ctx.backend.clone(); let outbound_middleware = backend.get_outbound_middleware();
123 for middleware in outbound_middleware {
124 let middleware_name = &middleware.name;
125 filter_container
126 .run_response_filter(
127 middleware_name,
128 session,
129 req_ctx,
130 response,
131 self.context.clone(),
132 )
133 .await;
134 }
135 }
136}