cardinal_plugins/runner/mod.rs

1use crate::plugin_executor::CardinalPluginExecutor;
2use crate::request_context::RequestContext;
3use async_trait::async_trait;
4use cardinal_base::context::CardinalContext;
5use cardinal_errors::internal::CardinalInternalError;
6use cardinal_errors::CardinalError;
7use pingora::http::ResponseHeader;
8use pingora::proxy::Session;
9use std::collections::HashMap;
10use std::sync::Arc;
11
/// Outcome of running one request middleware.
///
/// `PluginRunner::run_request_filters` inspects this value after each filter to decide
/// whether to keep walking the middleware chain.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MiddlewareResult {
    /// The middleware let the request through.
    ///
    /// The map holds header name/value pairs. The runner accumulates these across the
    /// chain (as `resp_headers`), with later entries replacing earlier ones for the same
    /// key. Presumably they are applied to the eventual response; confirm at the caller.
    Continue(HashMap<String, String>),
    /// The middleware has already produced a response; the remaining chain is skipped.
    Responded,
}
17
18#[async_trait]
19pub trait RequestMiddleware: Send + Sync + 'static {
20    async fn on_request(
21        &self,
22        session: &mut Session,
23        req_ctx: &mut RequestContext,
24        cardinal: Arc<CardinalContext>,
25    ) -> Result<MiddlewareResult, CardinalError>;
26}
27
28#[async_trait]
29pub trait ResponseMiddleware: Send + Sync + 'static {
30    async fn on_response(
31        &self,
32        session: &mut Session,
33        req_ctx: &mut RequestContext,
34        response: &mut ResponseHeader,
35        cardinal: Arc<CardinalContext>,
36    );
37}
38
39pub type DynRequestMiddleware = dyn RequestMiddleware + Send + Sync + 'static;
40pub type DynResponseMiddleware = dyn ResponseMiddleware + Send + Sync + 'static;
41
42#[derive(Clone)]
43pub struct PluginRunner {
44    global_request: Arc<Vec<String>>,
45    global_response: Arc<Vec<String>>,
46    plugin_executor: Arc<dyn CardinalPluginExecutor>,
47}
48
49impl PluginRunner {
50    pub fn new(
51        context: Arc<CardinalContext>,
52        plugin_executor: Arc<dyn CardinalPluginExecutor>,
53    ) -> Self {
54        let global_request = context.config.server.global_request_middleware.clone();
55        let global_response = context.config.server.global_response_middleware.clone();
56
57        Self {
58            global_request: Arc::new(global_request),
59            global_response: Arc::new(global_response),
60            plugin_executor,
61        }
62    }
63
64    fn global_request_filters(&self) -> &[String] {
65        &self.global_request
66    }
67
68    fn global_response_filters(&self) -> &[String] {
69        &self.global_response
70    }
71
72    pub async fn can_run(
73        &self,
74        filter: &str,
75        session: &mut Session,
76        req_ctx: &mut RequestContext,
77    ) -> Result<bool, CardinalError> {
78        self.plugin_executor
79            .can_run_plugin(filter, session, req_ctx)
80            .await
81            .map_err(|e| CardinalInternalError::RequestPluginError(format!("{e:?}")).into())
82    }
83
84    pub async fn run_request_filters(
85        &self,
86        session: &mut Session,
87        req_ctx: &mut RequestContext,
88    ) -> Result<MiddlewareResult, CardinalError> {
89        let mut resp_headers = HashMap::new();
90
91        for filter in self.global_request_filters() {
92            let can_run = self.can_run(filter, session, req_ctx).await?;
93
94            if !can_run {
95                continue;
96            }
97
98            let run = self
99                .plugin_executor
100                .run_request_filter(filter, session, req_ctx)
101                .await?;
102
103            match run {
104                MiddlewareResult::Continue(middleware_resp_headers) => {
105                    resp_headers.extend(middleware_resp_headers)
106                }
107                MiddlewareResult::Responded => return Ok(MiddlewareResult::Responded),
108            }
109        }
110
111        let backend = req_ctx.backend.clone(); // Cheap clone
112        let inbound_middleware = backend.get_inbound_middleware();
113        for middleware in inbound_middleware {
114            let middleware_name = &middleware.name;
115            let can_run = self.can_run(middleware_name, session, req_ctx).await?;
116
117            if !can_run {
118                continue;
119            }
120
121            let run = self
122                .plugin_executor
123                .run_request_filter(middleware_name, session, req_ctx)
124                .await?;
125
126            match run {
127                MiddlewareResult::Continue(middleware_resp_headers) => {
128                    resp_headers.extend(middleware_resp_headers)
129                }
130                MiddlewareResult::Responded => return Ok(MiddlewareResult::Responded),
131            }
132        }
133
134        Ok(MiddlewareResult::Continue(resp_headers))
135    }
136
137    pub async fn run_response_filters(
138        &self,
139        session: &mut Session,
140        req_ctx: &mut RequestContext,
141        response: &mut ResponseHeader,
142    ) {
143        for filter in self.global_response_filters() {
144            let can_run = self
145                .can_run(filter, session, req_ctx)
146                .await
147                .unwrap_or(false);
148
149            if !can_run {
150                continue;
151            }
152
153            let _ = self
154                .plugin_executor
155                .run_response_filter(filter, session, req_ctx, response)
156                .await;
157        }
158
159        let backend = req_ctx.backend.clone(); // Cheap clone
160        let outbound_middleware = backend.get_outbound_middleware();
161        for middleware in outbound_middleware {
162            let middleware_name = &middleware.name;
163
164            let can_run = self
165                .can_run(middleware_name, session, req_ctx)
166                .await
167                .unwrap_or(false);
168
169            if !can_run {
170                continue;
171            }
172
173            let _ = self
174                .plugin_executor
175                .run_response_filter(middleware_name, session, req_ctx, response)
176                .await;
177        }
178    }
179}