cardinal_plugins/
container.rs

1use crate::builtin::restricted_route_middleware::RestrictedRouteMiddleware;
2use crate::runner::{DynRequestMiddleware, DynResponseMiddleware, MiddlewareResult};
3use crate::utils::parse_query_string_multi;
4use cardinal_base::context::CardinalContext;
5use cardinal_base::destinations::container::DestinationWrapper;
6use cardinal_base::provider::Provider;
7use cardinal_config::Plugin;
8use cardinal_errors::CardinalError;
9use cardinal_wasm_plugins::plugin::WasmPlugin;
10use cardinal_wasm_plugins::runner::{HostFunctionBuilder, HostFunctionMap, WasmRunner};
11use cardinal_wasm_plugins::wasmer::{Function, FunctionEnv, Store};
12use cardinal_wasm_plugins::{ExecutionContext, ResponseState};
13use pingora::http::ResponseHeader;
14use pingora::prelude::Session;
15use std::collections::HashMap;
16use std::sync::Arc;
17use tracing::{error, warn};
18
/// A built-in (native Rust) plugin, tagged by the proxy phase it runs in.
pub enum PluginBuiltInType {
    /// Request-phase middleware; invoked through `on_request` by
    /// `PluginContainer::run_request_filter`.
    Inbound(Arc<DynRequestMiddleware>),
    /// Response-phase middleware; invoked through `on_response` by
    /// `PluginContainer::run_response_filter`.
    Outbound(Arc<DynResponseMiddleware>),
}
23
/// How a registered plugin is executed.
pub enum PluginHandler {
    /// A native plugin implemented in Rust (see [`PluginBuiltInType`]).
    Builtin(PluginBuiltInType),
    /// A WebAssembly module, executed through `WasmRunner` for each invocation.
    Wasm(Arc<WasmPlugin>),
}
28
/// Registry of plugins keyed by name, plus the host functions made available to
/// WASM plugins run through this container.
pub struct PluginContainer {
    // Registered plugins, keyed by the name passed to `run_request_filter` /
    // `run_response_filter`.
    plugins: HashMap<String, Arc<PluginHandler>>,
    // Host function builders grouped by import namespace; handed to every
    // `WasmRunner` this container creates.
    host_imports: HostFunctionMap,
}
33
34impl PluginContainer {
35    pub fn new() -> Self {
36        Self {
37            plugins: HashMap::from_iter(Self::builtin_plugins()),
38            host_imports: HashMap::new(),
39        }
40    }
41
42    pub fn new_empty() -> Self {
43        Self {
44            plugins: HashMap::new(),
45            host_imports: HashMap::new(),
46        }
47    }
48
49    pub fn builtin_plugins() -> Vec<(String, Arc<PluginHandler>)> {
50        vec![(
51            "RestrictedRouteMiddleware".to_string(),
52            Arc::new(PluginHandler::Builtin(PluginBuiltInType::Inbound(
53                Arc::new(RestrictedRouteMiddleware),
54            ))),
55        )]
56    }
57
58    pub fn add_plugin(&mut self, name: String, plugin: PluginHandler) {
59        self.plugins.insert(name, Arc::new(plugin));
60    }
61
62    pub fn remove_plugin(&mut self, name: &str) {
63        self.plugins.remove(name);
64    }
65
66    pub fn add_host_function<F>(
67        &mut self,
68        namespace: impl Into<String>,
69        name: impl Into<String>,
70        builder: F,
71    ) where
72        F: Fn(&mut Store, &FunctionEnv<ExecutionContext>) -> Function + Send + Sync + 'static,
73    {
74        let ns = namespace.into();
75        let host_entry = self.host_imports.entry(ns).or_default();
76        let builder: HostFunctionBuilder = Arc::new(builder);
77        host_entry.push((name.into(), builder));
78    }
79
80    pub fn extend_host_functions<I, S>(&mut self, namespace: S, functions: I)
81    where
82        I: IntoIterator<Item = (String, HostFunctionBuilder)>,
83        S: Into<String>,
84    {
85        let ns = namespace.into();
86        let host_entry = self.host_imports.entry(ns).or_default();
87        host_entry.extend(functions);
88    }
89
90    fn host_imports(&self) -> Option<&HostFunctionMap> {
91        if self.host_imports.is_empty() {
92            None
93        } else {
94            Some(&self.host_imports)
95        }
96    }
97
98    pub async fn run_request_filter(
99        &self,
100        name: &str,
101        session: &mut Session,
102        backend: Arc<DestinationWrapper>,
103        ctx: Arc<CardinalContext>,
104    ) -> Result<MiddlewareResult, CardinalError> {
105        let plugin = self
106            .plugins
107            .get(name)
108            .ok_or_else(|| CardinalError::Other(format!("Plugin {name} does not exist")))?;
109
110        match plugin.as_ref() {
111            PluginHandler::Builtin(builtin) => match builtin {
112                PluginBuiltInType::Inbound(filter) => {
113                    filter.on_request(session, backend, ctx).await
114                }
115                PluginBuiltInType::Outbound(_) => Err(CardinalError::Other(format!(
116                    "The filter {name} is not a request filter"
117                ))),
118            },
119            PluginHandler::Wasm(wasm) => {
120                let get_req_headers = session
121                    .req_header()
122                    .headers
123                    .iter()
124                    .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or_default().to_string()))
125                    .collect();
126
127                let query =
128                    parse_query_string_multi(session.req_header().uri.query().unwrap_or(""));
129
130                {
131                    let runner = WasmRunner::new(wasm, self.host_imports());
132
133                    let inbound_ctx = ExecutionContext::from_parts(
134                        get_req_headers,
135                        query,
136                        None,
137                        ResponseState::with_default_status(200),
138                    );
139
140                    let mut exec = runner.run(inbound_ctx)?;
141
142                    if !exec.execution_context.req_headers().is_empty() {
143                        for (key, val) in exec.execution_context.req_headers() {
144                            let _ = session.req_header_mut().insert_header(key.to_string(), val);
145                        }
146                    }
147
148                    if exec
149                        .execution_context
150                        .response()
151                        .status_override()
152                        .is_some()
153                    {
154                        exec.should_continue = false;
155                    }
156
157                    if exec.should_continue {
158                        Ok(MiddlewareResult::Continue(
159                            exec.execution_context.response().headers().clone(),
160                        ))
161                    } else {
162                        let response_state = exec.execution_context.response();
163                        let header_response = Self::build_response_header(response_state);
164                        let _ = session
165                            .write_response_header(Box::new(header_response), false)
166                            .await;
167                        let _ = session.respond_error(response_state.status()).await;
168                        Ok(MiddlewareResult::Responded)
169                    }
170                }
171            }
172        }
173    }
174
175    pub async fn run_response_filter(
176        &self,
177        name: &str,
178        session: &mut Session,
179        backend: Arc<DestinationWrapper>,
180        response: &mut pingora::http::ResponseHeader,
181        ctx: Arc<CardinalContext>,
182    ) {
183        let plugin = self
184            .plugins
185            .get(name)
186            .ok_or_else(|| CardinalError::Other(format!("Plugin {name} does not exist")));
187
188        if let Ok(plugin) = plugin {
189            match plugin.as_ref() {
190                PluginHandler::Builtin(builtin) => match builtin {
191                    PluginBuiltInType::Inbound(_) => {
192                        error!("The filter {name} is not a response filter");
193                    }
194                    PluginBuiltInType::Outbound(filter) => {
195                        filter.on_response(session, backend, response, ctx).await
196                    }
197                },
198                PluginHandler::Wasm(wasm) => {
199                    let get_req_headers = session
200                        .req_header()
201                        .headers
202                        .iter()
203                        .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or_default().to_string()))
204                        .collect();
205                    let query =
206                        parse_query_string_multi(session.req_header().uri.query().unwrap_or(""));
207
208                    {
209                        let runner = WasmRunner::new(wasm, self.host_imports());
210
211                        let outbound_ctx = ExecutionContext::from_parts(
212                            get_req_headers,
213                            query,
214                            None,
215                            ResponseState::default(),
216                        );
217
218                        let exec = runner.run(outbound_ctx);
219
220                        match &exec {
221                            Ok(ex) => {
222                                let response_state = ex.execution_context.response();
223
224                                for (key, val) in response_state.headers() {
225                                    let _ =
226                                        response.insert_header(key.to_string(), val.to_string());
227                                }
228
229                                if let Some(status) = response_state.status_override() {
230                                    let _ = response.set_status(status);
231                                }
232                            }
233                            Err(e) => {
234                                error!("Failed to run plugin {}: {}", name, e);
235                            }
236                        }
237                    }
238                }
239            }
240        }
241    }
242
243    fn build_response_header(response: &ResponseState) -> ResponseHeader {
244        let mut header = ResponseHeader::build(response.status(), None)
245            .expect("failed to build response header");
246
247        for (key, value) in response.headers() {
248            let _ = header.insert_header(key.to_string(), value.to_string());
249        }
250
251        header
252    }
253}
254
255impl Default for PluginContainer {
256    fn default() -> Self {
257        Self::new()
258    }
259}
260
261#[async_trait::async_trait]
262impl Provider for PluginContainer {
263    async fn provide(ctx: &CardinalContext) -> Result<Self, CardinalError> {
264        let preloaded_plugins = ctx.config.plugins.clone();
265        let mut plugin_container = PluginContainer::new();
266
267        for plugin in preloaded_plugins {
268            let plugin_name = plugin.name();
269            let plugin_exists = plugin_container.plugins.contains_key(plugin_name);
270
271            if plugin_exists {
272                warn!("Plugin {} already exists, skipping", plugin_name);
273                continue;
274            }
275
276            match plugin {
277                Plugin::Builtin(_) => continue,
278                Plugin::Wasm(wasm_config) => {
279                    let wasm_plugin = WasmPlugin::from_path(&wasm_config.path).map_err(|e| {
280                        CardinalError::Other(format!(
281                            "Failed to load plugin {}: {}",
282                            wasm_config.name, e
283                        ))
284                    })?;
285                    plugin_container.plugins.insert(
286                        wasm_config.name.clone(),
287                        Arc::new(PluginHandler::Wasm(Arc::new(wasm_plugin))),
288                    );
289                }
290            }
291        }
292
293        Ok(plugin_container)
294    }
295}