cardinal_plugins/
container.rs

1use crate::builtin::restricted_route_middleware::RestrictedRouteMiddleware;
2use crate::runner::{DynRequestMiddleware, DynResponseMiddleware, MiddlewareResult};
3use crate::utils::parse_query_string_multi;
4use cardinal_base::context::CardinalContext;
5use cardinal_base::destinations::container::DestinationWrapper;
6use cardinal_base::provider::Provider;
7use cardinal_config::Plugin;
8use cardinal_errors::CardinalError;
9use cardinal_wasm_plugins::plugin::WasmPlugin;
10use cardinal_wasm_plugins::runner::{ExecutionType, WasmRunner};
11use cardinal_wasm_plugins::{ExecutionContext, ExecutionRequest, ExecutionResponse};
12use pingora::prelude::Session;
13use std::collections::HashMap;
14use std::sync::Arc;
15use tracing::{error, warn};
16
17pub enum PluginBuiltInType {
18    Inbound(Arc<DynRequestMiddleware>),
19    Outbound(Arc<DynResponseMiddleware>),
20}
21
22pub enum PluginHandler {
23    Builtin(PluginBuiltInType),
24    Wasm(Arc<WasmPlugin>),
25}
26
27pub struct PluginContainer {
28    plugins: HashMap<String, Arc<PluginHandler>>,
29}
30
31impl PluginContainer {
32    pub fn new() -> Self {
33        Self {
34            plugins: HashMap::from_iter(Self::builtin_plugins()),
35        }
36    }
37
38    pub fn new_empty() -> Self {
39        Self {
40            plugins: HashMap::new(),
41        }
42    }
43
44    pub fn builtin_plugins() -> Vec<(String, Arc<PluginHandler>)> {
45        vec![(
46            "RestrictedRouteMiddleware".to_string(),
47            Arc::new(PluginHandler::Builtin(PluginBuiltInType::Inbound(
48                Arc::new(RestrictedRouteMiddleware),
49            ))),
50        )]
51    }
52
53    pub fn add_plugin(&mut self, name: String, plugin: PluginHandler) {
54        self.plugins.insert(name, Arc::new(plugin));
55    }
56
57    pub fn remove_plugin(&mut self, name: &str) {
58        self.plugins.remove(name);
59    }
60
61    pub async fn run_request_filter(
62        &self,
63        name: &str,
64        session: &mut Session,
65        backend: Arc<DestinationWrapper>,
66        ctx: Arc<CardinalContext>,
67    ) -> Result<MiddlewareResult, CardinalError> {
68        let plugin = self
69            .plugins
70            .get(name)
71            .ok_or_else(|| CardinalError::Other(format!("Plugin {} does not exist", name)))?;
72
73        match plugin.as_ref() {
74            PluginHandler::Builtin(builtin) => match builtin {
75                PluginBuiltInType::Inbound(filter) => {
76                    filter.on_request(session, backend, ctx).await
77                }
78                PluginBuiltInType::Outbound(_) => Err(CardinalError::Other(format!(
79                    "The filter {} is not a request filter",
80                    name
81                ))),
82            },
83            PluginHandler::Wasm(wasm) => {
84                let get_req_headers = session
85                    .req_header()
86                    .headers
87                    .iter()
88                    .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or_default().to_string()))
89                    .collect();
90
91                let query =
92                    parse_query_string_multi(session.req_header().uri.query().unwrap_or(""));
93
94                {
95                    let runner = WasmRunner::new(wasm, ExecutionType::Inbound);
96
97                    let inbound_ctx = ExecutionRequest {
98                        query,
99                        memory: None,
100                        req_headers: get_req_headers,
101                        body: None,
102                    };
103
104                    let exec = runner.run(ExecutionContext::Inbound(inbound_ctx))?;
105
106                    if exec.should_continue {
107                        Ok(MiddlewareResult::Continue)
108                    } else {
109                        let _ = session.respond_error(403).await;
110                        Ok(MiddlewareResult::Responded)
111                    }
112                }
113            }
114        }
115    }
116
117    pub async fn run_response_filter(
118        &self,
119        name: &str,
120        session: &mut Session,
121        backend: Arc<DestinationWrapper>,
122        response: &mut pingora::http::ResponseHeader,
123        ctx: Arc<CardinalContext>,
124    ) {
125        let plugin = self
126            .plugins
127            .get(name)
128            .ok_or_else(|| CardinalError::Other(format!("Plugin {} does not exist", name)));
129
130        if let Ok(plugin) = plugin {
131            match plugin.as_ref() {
132                PluginHandler::Builtin(builtin) => match builtin {
133                    PluginBuiltInType::Inbound(_) => {
134                        error!("The filter {} is not a response filter", name);
135                    }
136                    PluginBuiltInType::Outbound(filter) => {
137                        filter.on_response(session, backend, response, ctx).await
138                    }
139                },
140                PluginHandler::Wasm(wasm) => {
141                    let get_req_headers = session
142                        .req_header()
143                        .headers
144                        .iter()
145                        .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or_default().to_string()))
146                        .collect();
147                    let query =
148                        parse_query_string_multi(session.req_header().uri.query().unwrap_or(""));
149
150                    {
151                        let runner = WasmRunner::new(wasm, ExecutionType::Outbound);
152
153                        let outbound_ctx = ExecutionResponse {
154                            memory: None,
155                            req_headers: get_req_headers,
156                            query,
157                            resp_headers: Default::default(),
158                            status: 0,
159                            body: None,
160                        };
161
162                        let exec = runner.run(ExecutionContext::Outbound(outbound_ctx));
163
164                        match &exec {
165                            Ok(ex) => {
166                                let outbound_resp = ex.execution_context.as_outbound().unwrap();
167                                for (key, val) in &outbound_resp.resp_headers {
168                                    let _ =
169                                        response.insert_header(key.to_string(), val.to_string());
170                                }
171
172                                response.set_status(outbound_resp.status as u16).unwrap();
173                            }
174                            Err(e) => {
175                                error!("Failed to run plugin {}: {}", name, e);
176                            }
177                        }
178                    }
179                }
180            }
181        }
182    }
183}
184
185impl Default for PluginContainer {
186    fn default() -> Self {
187        Self::new()
188    }
189}
190
191#[async_trait::async_trait]
192impl Provider for PluginContainer {
193    async fn provide(ctx: &CardinalContext) -> Result<Self, CardinalError> {
194        let preloaded_plugins = ctx.config.plugins.clone();
195        let mut plugin_container = PluginContainer::new();
196
197        for plugin in preloaded_plugins {
198            let plugin_name = plugin.name();
199            let plugin_exists = plugin_container.plugins.contains_key(plugin_name);
200
201            if plugin_exists {
202                warn!("Plugin {} already exists, skipping", plugin_name);
203                continue;
204            }
205
206            match plugin {
207                Plugin::Builtin(_) => continue,
208                Plugin::Wasm(wasm_config) => {
209                    let wasm_plugin = WasmPlugin::from_path(&wasm_config.path).map_err(|e| {
210                        CardinalError::Other(format!(
211                            "Failed to load plugin {}: {}",
212                            wasm_config.name, e
213                        ))
214                    })?;
215                    plugin_container.plugins.insert(
216                        wasm_config.name.clone(),
217                        Arc::new(PluginHandler::Wasm(Arc::new(wasm_plugin))),
218                    );
219                }
220            }
221        }
222
223        Ok(plugin_container)
224    }
225}