cardinal_plugins/runner/
mod.rs1use crate::container::PluginContainer;
2use async_trait::async_trait;
3use cardinal_base::context::CardinalContext;
4use cardinal_base::destinations::container::DestinationWrapper;
5use cardinal_errors::CardinalError;
6use pingora::http::ResponseHeader;
7use pingora::proxy::Session;
8use std::collections::HashMap;
9use std::sync::Arc;
10
/// Outcome reported by a request middleware to `PluginRunner`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MiddlewareResult {
    /// Keep running the chain. The map holds string key/value pairs that
    /// `PluginRunner` accumulates (as headers) across every middleware that continues.
    Continue(HashMap<String, String>),
    /// The middleware has already responded; `PluginRunner` stops the chain
    /// immediately and discards any accumulated pairs.
    Responded,
}
16
17#[async_trait]
18pub trait RequestMiddleware: Send + Sync + 'static {
19 async fn on_request(
20 &self,
21 session: &mut Session,
22 backend: Arc<DestinationWrapper>,
23 cardinal: Arc<CardinalContext>,
24 ) -> Result<MiddlewareResult, CardinalError>;
25}
26
27#[async_trait]
28pub trait ResponseMiddleware: Send + Sync + 'static {
29 async fn on_response(
30 &self,
31 session: &mut Session,
32 backend: Arc<DestinationWrapper>,
33 response: &mut ResponseHeader,
34 cardinal: Arc<CardinalContext>,
35 );
36}
37
38pub type DynRequestMiddleware = dyn RequestMiddleware + Send + Sync + 'static;
39pub type DynResponseMiddleware = dyn ResponseMiddleware + Send + Sync + 'static;
40
41#[derive(Clone)]
42pub struct PluginRunner {
43 context: Arc<CardinalContext>,
44 global_request: Arc<Vec<String>>,
45 global_response: Arc<Vec<String>>,
46}
47
48impl PluginRunner {
49 pub fn new(context: Arc<CardinalContext>) -> Self {
50 let global_request = context.config.server.global_request_middleware.clone();
51 let global_response = context.config.server.global_response_middleware.clone();
52
53 Self {
54 context,
55 global_request: Arc::new(global_request),
56 global_response: Arc::new(global_response),
57 }
58 }
59
60 fn global_request_filters(&self) -> &[String] {
61 &self.global_request
62 }
63
64 fn global_response_filters(&self) -> &[String] {
65 &self.global_response
66 }
67
68 pub async fn run_request_filters(
69 &self,
70 session: &mut Session,
71 backend: Arc<DestinationWrapper>,
72 ) -> Result<MiddlewareResult, CardinalError> {
73 let filter_container = self.context.get::<PluginContainer>().await?;
74 let mut resp_headers = HashMap::new();
75
76 for filter in self.global_request_filters() {
77 let run = filter_container
78 .run_request_filter(filter, session, backend.clone(), self.context.clone())
79 .await?;
80
81 match run {
82 MiddlewareResult::Continue(middleware_resp_headers) => {
83 resp_headers.extend(middleware_resp_headers)
84 }
85 MiddlewareResult::Responded => return Ok(MiddlewareResult::Responded),
86 }
87 }
88
89 let inbound_middleware = backend.get_inbound_middleware();
90 for middleware in inbound_middleware {
91 let run = filter_container
92 .run_request_filter(
93 &middleware.name,
94 session,
95 backend.clone(),
96 self.context.clone(),
97 )
98 .await?;
99
100 match run {
101 MiddlewareResult::Continue(middleware_resp_headers) => {
102 resp_headers.extend(middleware_resp_headers)
103 }
104 MiddlewareResult::Responded => return Ok(MiddlewareResult::Responded),
105 }
106 }
107
108 Ok(MiddlewareResult::Continue(resp_headers))
109 }
110
111 pub async fn run_response_filters(
112 &self,
113 session: &mut Session,
114 backend: Arc<DestinationWrapper>,
115 response: &mut ResponseHeader,
116 ) {
117 let filter_container = self.context.get::<PluginContainer>().await.unwrap();
118
119 for filter in self.global_response_filters() {
120 filter_container
121 .run_response_filter(
122 filter,
123 session,
124 backend.clone(),
125 response,
126 self.context.clone(),
127 )
128 .await;
129 }
130
131 let outbound_middleware = backend.get_outbound_middleware();
132 for middleware in outbound_middleware {
133 let middleware_name = &middleware.name;
134 filter_container
135 .run_response_filter(
136 middleware_name,
137 session,
138 backend.clone(),
139 response,
140 self.context.clone(),
141 )
142 .await;
143 }
144 }
145}