cardinal_plugins/runner/mod.rs

1use crate::container::PluginContainer;
2use crate::request_context::RequestContext;
3use async_trait::async_trait;
4use cardinal_base::context::CardinalContext;
5use cardinal_errors::CardinalError;
6use pingora::http::ResponseHeader;
7use pingora::proxy::Session;
8use std::collections::HashMap;
9use std::sync::Arc;
10
/// Outcome of a request middleware (and of a whole request chain).
///
/// * `Continue(headers)` — processing may proceed. The map holds headers the
///   middleware wants attached to the response; the runner merges these maps across
///   the chain.
/// * `Responded` — the middleware has handled the response itself. The runner stops
///   the chain as soon as it sees this variant.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MiddlewareResult {
    Continue(HashMap<String, String>),
    Responded,
}
16
17#[async_trait]
18pub trait RequestMiddleware: Send + Sync + 'static {
19    async fn on_request(
20        &self,
21        session: &mut Session,
22        req_ctx: &mut RequestContext,
23        cardinal: Arc<CardinalContext>,
24    ) -> Result<MiddlewareResult, CardinalError>;
25}
26
27#[async_trait]
28pub trait ResponseMiddleware: Send + Sync + 'static {
29    async fn on_response(
30        &self,
31        session: &mut Session,
32        req_ctx: &mut RequestContext,
33        response: &mut ResponseHeader,
34        cardinal: Arc<CardinalContext>,
35    );
36}
37
38pub type DynRequestMiddleware = dyn RequestMiddleware + Send + Sync + 'static;
39pub type DynResponseMiddleware = dyn ResponseMiddleware + Send + Sync + 'static;
40
41#[derive(Clone)]
42pub struct PluginRunner {
43    context: Arc<CardinalContext>,
44    global_request: Arc<Vec<String>>,
45    global_response: Arc<Vec<String>>,
46}
47
48impl PluginRunner {
49    pub fn new(context: Arc<CardinalContext>) -> Self {
50        let global_request = context.config.server.global_request_middleware.clone();
51        let global_response = context.config.server.global_response_middleware.clone();
52
53        Self {
54            context,
55            global_request: Arc::new(global_request),
56            global_response: Arc::new(global_response),
57        }
58    }
59
60    fn global_request_filters(&self) -> &[String] {
61        &self.global_request
62    }
63
64    fn global_response_filters(&self) -> &[String] {
65        &self.global_response
66    }
67
68    pub async fn run_request_filters(
69        &self,
70        session: &mut Session,
71        req_ctx: &mut RequestContext,
72    ) -> Result<MiddlewareResult, CardinalError> {
73        let filter_container = self.context.get::<PluginContainer>().await?;
74        let mut resp_headers = HashMap::new();
75
76        for filter in self.global_request_filters() {
77            let run = filter_container
78                .run_request_filter(filter, session, req_ctx, self.context.clone())
79                .await?;
80
81            match run {
82                MiddlewareResult::Continue(middleware_resp_headers) => {
83                    resp_headers.extend(middleware_resp_headers)
84                }
85                MiddlewareResult::Responded => return Ok(MiddlewareResult::Responded),
86            }
87        }
88
89        let backend = req_ctx.backend.clone(); // Cheap clone
90        let inbound_middleware = backend.get_inbound_middleware();
91        for middleware in inbound_middleware {
92            let run = filter_container
93                .run_request_filter(&middleware.name, session, req_ctx, self.context.clone())
94                .await?;
95
96            match run {
97                MiddlewareResult::Continue(middleware_resp_headers) => {
98                    resp_headers.extend(middleware_resp_headers)
99                }
100                MiddlewareResult::Responded => return Ok(MiddlewareResult::Responded),
101            }
102        }
103
104        Ok(MiddlewareResult::Continue(resp_headers))
105    }
106
107    pub async fn run_response_filters(
108        &self,
109        session: &mut Session,
110        req_ctx: &mut RequestContext,
111        response: &mut ResponseHeader,
112    ) {
113        let filter_container = self.context.get::<PluginContainer>().await.unwrap();
114
115        for filter in self.global_response_filters() {
116            filter_container
117                .run_response_filter(filter, session, req_ctx, response, self.context.clone())
118                .await;
119        }
120
121        let backend = req_ctx.backend.clone(); // Cheap clone
122        let outbound_middleware = backend.get_outbound_middleware();
123        for middleware in outbound_middleware {
124            let middleware_name = &middleware.name;
125            filter_container
126                .run_response_filter(
127                    middleware_name,
128                    session,
129                    req_ctx,
130                    response,
131                    self.context.clone(),
132                )
133                .await;
134        }
135    }
136}