cardinal_plugins/runner/mod.rs

1use crate::container::PluginContainer;
2use async_trait::async_trait;
3use cardinal_base::context::CardinalContext;
4use cardinal_base::destinations::container::DestinationWrapper;
5use cardinal_errors::CardinalError;
6use pingora::http::ResponseHeader;
7use pingora::proxy::Session;
8use std::sync::Arc;
9
/// Outcome reported by a request middleware.
///
/// The request runner inspects this after each middleware: `Continue` moves on
/// to the next one, while `Responded` makes the runner return immediately
/// without running any further request middleware.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum MiddlewareResult {
    /// Proceed with the next middleware in the chain.
    Continue,
    /// The middleware has taken over the request (presumably by writing a
    /// response itself); stop the request middleware chain here.
    Responded,
}
15
16#[async_trait]
17pub trait RequestMiddleware: Send + Sync + 'static {
18    async fn on_request(
19        &self,
20        session: &mut Session,
21        backend: Arc<DestinationWrapper>,
22        cardinal: Arc<CardinalContext>,
23    ) -> Result<MiddlewareResult, CardinalError>;
24}
25
26#[async_trait]
27pub trait ResponseMiddleware: Send + Sync + 'static {
28    async fn on_response(
29        &self,
30        session: &mut Session,
31        backend: Arc<DestinationWrapper>,
32        response: &mut ResponseHeader,
33        cardinal: Arc<CardinalContext>,
34    );
35}
36
37pub type DynRequestMiddleware = dyn RequestMiddleware + Send + Sync + 'static;
38pub type DynResponseMiddleware = dyn ResponseMiddleware + Send + Sync + 'static;
39
40#[derive(Clone)]
41pub struct PluginRunner {
42    context: Arc<CardinalContext>,
43    global_request: Arc<Vec<String>>,
44    global_response: Arc<Vec<String>>,
45}
46
47impl PluginRunner {
48    pub fn new(context: Arc<CardinalContext>) -> Self {
49        let global_request = context.config.server.global_request_middleware.clone();
50        let global_response = context.config.server.global_response_middleware.clone();
51
52        Self {
53            context,
54            global_request: Arc::new(global_request),
55            global_response: Arc::new(global_response),
56        }
57    }
58
59    fn global_request_filters(&self) -> &[String] {
60        &self.global_request
61    }
62
63    fn global_response_filters(&self) -> &[String] {
64        &self.global_response
65    }
66
67    pub async fn run_request_filters(
68        &self,
69        session: &mut Session,
70        backend: Arc<DestinationWrapper>,
71    ) -> Result<MiddlewareResult, CardinalError> {
72        let filter_container = self.context.get::<PluginContainer>().await?;
73
74        for filter in self.global_request_filters() {
75            let run = filter_container
76                .run_request_filter(filter, session, backend.clone(), self.context.clone())
77                .await?;
78            if let MiddlewareResult::Responded = run {
79                return Ok(MiddlewareResult::Responded);
80            }
81        }
82
83        let inbound_middleware = backend.get_inbound_middleware();
84        for middleware in inbound_middleware {
85            let run = filter_container
86                .run_request_filter(
87                    &middleware.name,
88                    session,
89                    backend.clone(),
90                    self.context.clone(),
91                )
92                .await?;
93            if let MiddlewareResult::Responded = run {
94                return Ok(MiddlewareResult::Responded);
95            }
96        }
97
98        Ok(MiddlewareResult::Continue)
99    }
100
101    pub async fn run_response_filters(
102        &self,
103        session: &mut Session,
104        backend: Arc<DestinationWrapper>,
105        response: &mut ResponseHeader,
106    ) {
107        let filter_container = self.context.get::<PluginContainer>().await.unwrap();
108
109        for filter in self.global_response_filters() {
110            filter_container
111                .run_response_filter(
112                    filter,
113                    session,
114                    backend.clone(),
115                    response,
116                    self.context.clone(),
117                )
118                .await;
119        }
120
121        let outbound_middleware = backend.get_outbound_middleware();
122        for middleware in outbound_middleware {
123            let middleware_name = &middleware.name;
124            filter_container
125                .run_response_filter(
126                    middleware_name,
127                    session,
128                    backend.clone(),
129                    response,
130                    self.context.clone(),
131                )
132                .await;
133        }
134    }
135}