cardinal_plugins/runner/mod.rs

1use crate::plugin_executor::CardinalPluginExecutor;
2use crate::request_context::RequestContext;
3use async_trait::async_trait;
4use cardinal_base::context::CardinalContext;
5use cardinal_errors::CardinalError;
6use pingora::http::ResponseHeader;
7use pingora::proxy::Session;
8use std::collections::HashMap;
9use std::sync::Arc;
10
/// Outcome of running one request middleware.
///
/// `PluginRunner::run_request_filters` reacts to the two variants as follows:
/// - `Continue(map)`: the string key/value pairs are merged into the map the
///   runner accumulates and eventually returns (the runner names it
///   `resp_headers`), and the next middleware runs.
/// - `Responded`: the runner stops immediately and returns `Responded`
///   without running any remaining middleware.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MiddlewareResult {
    /// Keep going; carries key/value pairs to merge into the accumulated map.
    Continue(HashMap<String, String>),
    /// The middleware has already responded; stop the chain.
    Responded,
}
16
/// A middleware hook invoked for an incoming request.
///
/// Implementors must be `Send + Sync + 'static` (required by the supertraits).
/// NOTE(review): `PluginRunner` in this file dispatches by middleware name
/// through `CardinalPluginExecutor` rather than calling this trait directly;
/// presumably the executor adapts implementors of this trait — confirm.
#[async_trait]
pub trait RequestMiddleware: Send + Sync + 'static {
    /// Processes the request held by `session`.
    ///
    /// * `session` – the proxy session for the current request (mutable).
    /// * `req_ctx` – per-request context, mutable so the middleware can record state.
    /// * `cardinal` – shared handle to the global `CardinalContext`.
    ///
    /// Returns `MiddlewareResult::Continue(map)` to let processing proceed, or
    /// `MiddlewareResult::Responded` to stop the chain (this is how
    /// `PluginRunner::run_request_filters` interprets the variants). An `Err`
    /// is propagated by the runner as a failure of the whole chain.
    async fn on_request(
        &self,
        session: &mut Session,
        req_ctx: &mut RequestContext,
        cardinal: Arc<CardinalContext>,
    ) -> Result<MiddlewareResult, CardinalError>;
}
26
/// A middleware hook invoked on the response headed back to the client.
///
/// Implementors must be `Send + Sync + 'static` (required by the supertraits).
#[async_trait]
pub trait ResponseMiddleware: Send + Sync + 'static {
    /// Inspects or modifies `response` before it is sent.
    ///
    /// * `session` – the proxy session for the current request (mutable).
    /// * `req_ctx` – per-request context (mutable).
    /// * `response` – the upstream response header, mutable so it can be rewritten.
    /// * `cardinal` – shared handle to the global `CardinalContext`.
    ///
    /// Returns `()`: a response middleware has no way to signal failure or to
    /// stop later middleware through this signature.
    async fn on_response(
        &self,
        session: &mut Session,
        req_ctx: &mut RequestContext,
        response: &mut ResponseHeader,
        cardinal: Arc<CardinalContext>,
    );
}
37
// Trait-object aliases for the two middleware traits.
// The `+ Send + Sync + 'static` repeats the supertrait bounds, but spelling the
// auto traits makes this a distinct trait-object type from bare
// `dyn RequestMiddleware`; do not drop them without checking every user of
// these aliases.
pub type DynRequestMiddleware = dyn RequestMiddleware + Send + Sync + 'static;
pub type DynResponseMiddleware = dyn ResponseMiddleware + Send + Sync + 'static;
40
/// Runs the configured middleware chains for a request and its response,
/// dispatching each middleware by name through a `CardinalPluginExecutor`.
///
/// Every field is behind an `Arc`, so cloning a runner only bumps reference
/// counts; it does not copy the name lists.
#[derive(Clone)]
pub struct PluginRunner {
    /// Names of request middleware run for every request, before the
    /// backend's own inbound middleware. Copied from
    /// `config.server.global_request_middleware` in `new`.
    global_request: Arc<Vec<String>>,
    /// Names of response middleware run for every response, before the
    /// backend's own outbound middleware. Copied from
    /// `config.server.global_response_middleware` in `new`.
    global_response: Arc<Vec<String>>,
    /// Executes a named middleware against the session / request context.
    plugin_executor: Arc<dyn CardinalPluginExecutor>,
}
47
48impl PluginRunner {
49    pub fn new(
50        context: Arc<CardinalContext>,
51        plugin_executor: Arc<dyn CardinalPluginExecutor>,
52    ) -> Self {
53        let global_request = context.config.server.global_request_middleware.clone();
54        let global_response = context.config.server.global_response_middleware.clone();
55
56        Self {
57            global_request: Arc::new(global_request),
58            global_response: Arc::new(global_response),
59            plugin_executor,
60        }
61    }
62
63    fn global_request_filters(&self) -> &[String] {
64        &self.global_request
65    }
66
67    fn global_response_filters(&self) -> &[String] {
68        &self.global_response
69    }
70
71    pub async fn run_request_filters(
72        &self,
73        session: &mut Session,
74        req_ctx: &mut RequestContext,
75    ) -> Result<MiddlewareResult, CardinalError> {
76        let mut resp_headers = HashMap::new();
77
78        for filter in self.global_request_filters() {
79            let run = self
80                .plugin_executor
81                .run_request_filter(filter, session, req_ctx)
82                .await?;
83
84            match run {
85                MiddlewareResult::Continue(middleware_resp_headers) => {
86                    resp_headers.extend(middleware_resp_headers)
87                }
88                MiddlewareResult::Responded => return Ok(MiddlewareResult::Responded),
89            }
90        }
91
92        let backend = req_ctx.backend.clone(); // Cheap clone
93        let inbound_middleware = backend.get_inbound_middleware();
94        for middleware in inbound_middleware {
95            let run = self
96                .plugin_executor
97                .run_request_filter(&middleware.name, session, req_ctx)
98                .await?;
99
100            match run {
101                MiddlewareResult::Continue(middleware_resp_headers) => {
102                    resp_headers.extend(middleware_resp_headers)
103                }
104                MiddlewareResult::Responded => return Ok(MiddlewareResult::Responded),
105            }
106        }
107
108        Ok(MiddlewareResult::Continue(resp_headers))
109    }
110
111    pub async fn run_response_filters(
112        &self,
113        session: &mut Session,
114        req_ctx: &mut RequestContext,
115        response: &mut ResponseHeader,
116    ) {
117        for filter in self.global_response_filters() {
118            let _ = self
119                .plugin_executor
120                .run_response_filter(filter, session, req_ctx, response)
121                .await;
122        }
123
124        let backend = req_ctx.backend.clone(); // Cheap clone
125        let outbound_middleware = backend.get_outbound_middleware();
126        for middleware in outbound_middleware {
127            let middleware_name = &middleware.name;
128            let _ = self
129                .plugin_executor
130                .run_response_filter(middleware_name, session, req_ctx, response)
131                .await;
132        }
133    }
134}