// cardinal_plugins/runner/mod.rs

1use crate::container::PluginContainer;
2use async_trait::async_trait;
3use cardinal_base::context::CardinalContext;
4use cardinal_base::destinations::container::DestinationWrapper;
5use cardinal_errors::CardinalError;
6use pingora::http::ResponseHeader;
7use pingora::proxy::Session;
8use std::collections::HashMap;
9use std::sync::Arc;
10
/// Outcome reported by a single request middleware.
///
/// `PluginRunner::run_request_filters` merges the headers carried by every
/// `Continue` into one map and stops the chain at the first `Responded`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum MiddlewareResult {
    /// Keep running the chain. The map holds extra headers the middleware wants
    /// to contribute (accumulated by the runner as `resp_headers`).
    Continue(HashMap<String, String>),
    /// The middleware has already produced a response; remaining middleware
    /// are skipped.
    Responded,
}
16
/// Middleware hook for the request side of the proxy pipeline.
///
/// Implementors get mutable access to the live `Session`, the resolved backend
/// and the shared `CardinalContext`. The return value tells the chain whether
/// to continue (optionally contributing headers via
/// `MiddlewareResult::Continue`) or stop (`MiddlewareResult::Responded`).
/// NOTE(review): how `PluginContainer::run_request_filter` maps an `Err` from
/// this hook is not visible here — confirm in the container.
#[async_trait]
pub trait RequestMiddleware: Send + Sync + 'static {
    async fn on_request(
        &self,
        session: &mut Session,
        backend: Arc<DestinationWrapper>,
        cardinal: Arc<CardinalContext>,
    ) -> Result<MiddlewareResult, CardinalError>;
}
26
/// Middleware hook for the response side of the proxy pipeline.
///
/// Implementors may mutate the outgoing `ResponseHeader` (and the `Session`).
/// The method returns `()`, so a response middleware can neither report an
/// error nor stop the remaining response middleware from running.
#[async_trait]
pub trait ResponseMiddleware: Send + Sync + 'static {
    async fn on_response(
        &self,
        session: &mut Session,
        backend: Arc<DestinationWrapper>,
        response: &mut ResponseHeader,
        cardinal: Arc<CardinalContext>,
    );
}
37
// Trait-object aliases for storing middleware behind pointers.
// The `Send + Sync + 'static` bounds repeat the traits' supertraits. They are
// redundant but kept: `dyn Trait` and `dyn Trait + Send + Sync` are distinct
// types, so dropping them would change what these aliases name.

/// Trait-object form of [`RequestMiddleware`].
pub type DynRequestMiddleware = dyn RequestMiddleware + Send + Sync + 'static;
/// Trait-object form of [`ResponseMiddleware`].
pub type DynResponseMiddleware = dyn ResponseMiddleware + Send + Sync + 'static;
40
/// Drives the global and per-backend request/response middleware chains.
///
/// The global middleware name lists are copied out of `context.config.server`
/// when the runner is built. They sit behind `Arc`s so cloning the runner is
/// cheap.
#[derive(Clone)]
pub struct PluginRunner {
    // Shared application context; also used to resolve the `PluginContainer`.
    context: Arc<CardinalContext>,
    // Request middleware names run for every backend, in configured order.
    global_request: Arc<Vec<String>>,
    // Response middleware names run for every backend, in configured order.
    global_response: Arc<Vec<String>>,
}
47
48impl PluginRunner {
49    pub fn new(context: Arc<CardinalContext>) -> Self {
50        let global_request = context.config.server.global_request_middleware.clone();
51        let global_response = context.config.server.global_response_middleware.clone();
52
53        Self {
54            context,
55            global_request: Arc::new(global_request),
56            global_response: Arc::new(global_response),
57        }
58    }
59
60    fn global_request_filters(&self) -> &[String] {
61        &self.global_request
62    }
63
64    fn global_response_filters(&self) -> &[String] {
65        &self.global_response
66    }
67
68    pub async fn run_request_filters(
69        &self,
70        session: &mut Session,
71        backend: Arc<DestinationWrapper>,
72    ) -> Result<MiddlewareResult, CardinalError> {
73        let filter_container = self.context.get::<PluginContainer>().await?;
74        let mut resp_headers = HashMap::new();
75
76        for filter in self.global_request_filters() {
77            let run = filter_container
78                .run_request_filter(filter, session, backend.clone(), self.context.clone())
79                .await?;
80
81            match run {
82                MiddlewareResult::Continue(middleware_resp_headers) => {
83                    resp_headers.extend(middleware_resp_headers)
84                }
85                MiddlewareResult::Responded => return Ok(MiddlewareResult::Responded),
86            }
87        }
88
89        let inbound_middleware = backend.get_inbound_middleware();
90        for middleware in inbound_middleware {
91            let run = filter_container
92                .run_request_filter(
93                    &middleware.name,
94                    session,
95                    backend.clone(),
96                    self.context.clone(),
97                )
98                .await?;
99
100            match run {
101                MiddlewareResult::Continue(middleware_resp_headers) => {
102                    resp_headers.extend(middleware_resp_headers)
103                }
104                MiddlewareResult::Responded => return Ok(MiddlewareResult::Responded),
105            }
106        }
107
108        Ok(MiddlewareResult::Continue(resp_headers))
109    }
110
111    pub async fn run_response_filters(
112        &self,
113        session: &mut Session,
114        backend: Arc<DestinationWrapper>,
115        response: &mut ResponseHeader,
116    ) {
117        let filter_container = self.context.get::<PluginContainer>().await.unwrap();
118
119        for filter in self.global_response_filters() {
120            filter_container
121                .run_response_filter(
122                    filter,
123                    session,
124                    backend.clone(),
125                    response,
126                    self.context.clone(),
127                )
128                .await;
129        }
130
131        let outbound_middleware = backend.get_outbound_middleware();
132        for middleware in outbound_middleware {
133            let middleware_name = &middleware.name;
134            filter_container
135                .run_response_filter(
136                    middleware_name,
137                    session,
138                    backend.clone(),
139                    response,
140                    self.context.clone(),
141                )
142                .await;
143        }
144    }
145}