cardinal_plugins/
container.rs

1use crate::builtin::restricted_route_middleware::RestrictedRouteMiddleware;
2use crate::request_context::RequestContext;
3use crate::runner::{DynRequestMiddleware, DynResponseMiddleware, MiddlewareResult};
4use cardinal_base::context::CardinalContext;
5use cardinal_base::provider::Provider;
6use cardinal_config::Plugin;
7use cardinal_errors::CardinalError;
8use cardinal_wasm_plugins::host::{HostFunctionBuilder, HostImportHandle};
9use cardinal_wasm_plugins::plugin::WasmPlugin;
10use cardinal_wasm_plugins::runner::{host_import_from_builder, ExecutionPhase, WasmRunner};
11use cardinal_wasm_plugins::wasmer::{Function, FunctionEnv, Store};
12use cardinal_wasm_plugins::{ResponseState, SharedExecutionContext};
13use pingora::http::ResponseHeader;
14use pingora::prelude::Session;
15use std::collections::HashMap;
16use std::sync::Arc;
17use tracing::{error, warn};
18
19pub enum PluginBuiltInType {
20    Inbound(Arc<DynRequestMiddleware>),
21    Outbound(Arc<DynResponseMiddleware>),
22}
23
24pub enum PluginHandler {
25    Builtin(PluginBuiltInType),
26    Wasm(Arc<WasmPlugin>),
27}
28
29pub struct PluginContainer {
30    plugins: HashMap<String, Arc<PluginHandler>>,
31    host_imports: Vec<HostImportHandle>,
32}
33
34impl PluginContainer {
35    pub fn new() -> Self {
36        Self {
37            plugins: HashMap::from_iter(Self::builtin_plugins()),
38            host_imports: Vec::new(),
39        }
40    }
41
42    pub fn new_empty() -> Self {
43        Self {
44            plugins: HashMap::new(),
45            host_imports: Vec::new(),
46        }
47    }
48
49    pub fn builtin_plugins() -> Vec<(String, Arc<PluginHandler>)> {
50        vec![(
51            "RestrictedRouteMiddleware".to_string(),
52            Arc::new(PluginHandler::Builtin(PluginBuiltInType::Inbound(
53                Arc::new(RestrictedRouteMiddleware),
54            ))),
55        )]
56    }
57
58    pub fn add_plugin(&mut self, name: String, plugin: PluginHandler) {
59        self.plugins.insert(name, Arc::new(plugin));
60    }
61
62    pub fn remove_plugin(&mut self, name: &str) {
63        self.plugins.remove(name);
64    }
65
66    pub fn add_host_function<F>(
67        &mut self,
68        namespace: impl Into<String>,
69        name: impl Into<String>,
70        builder: F,
71    ) where
72        F: Fn(&mut Store, &FunctionEnv<SharedExecutionContext>) -> Function + Send + Sync + 'static,
73    {
74        let builder: HostFunctionBuilder = Arc::new(builder);
75        let import = host_import_from_builder(namespace, name, builder);
76        self.host_imports.push(import);
77    }
78
79    pub fn extend_host_functions<I>(&mut self, functions: I)
80    where
81        I: IntoIterator<Item = HostImportHandle>,
82    {
83        self.host_imports.extend(functions);
84    }
85
86    fn host_imports(&self) -> Option<&[HostImportHandle]> {
87        if self.host_imports.is_empty() {
88            None
89        } else {
90            Some(&self.host_imports)
91        }
92    }
93
94    pub async fn run_request_filter(
95        &self,
96        name: &str,
97        session: &mut Session,
98        req_ctx: &mut RequestContext,
99    ) -> Result<MiddlewareResult, CardinalError> {
100        let plugin = self
101            .plugins
102            .get(name)
103            .ok_or_else(|| CardinalError::Other(format!("Plugin {name} does not exist")))?;
104
105        match plugin.as_ref() {
106            PluginHandler::Builtin(builtin) => match builtin {
107                PluginBuiltInType::Inbound(filter) => {
108                    filter
109                        .on_request(session, req_ctx, req_ctx.cardinal_context.clone())
110                        .await
111                }
112                PluginBuiltInType::Outbound(_) => Err(CardinalError::Other(format!(
113                    "The filter {name} is not a request filter"
114                ))),
115            },
116            PluginHandler::Wasm(wasm) => {
117                let runner = WasmRunner::new(wasm, ExecutionPhase::Inbound, self.host_imports());
118
119                let exec = runner.run(req_ctx.shared_context())?;
120                let should_continue = exec.should_continue;
121
122                let (header_updates, response_snapshot) = {
123                    let guard = exec.execution_context.read();
124                    let request_headers: Vec<(String, String)> = guard
125                        .request()
126                        .headers()
127                        .iter()
128                        .filter_map(|(key, value)| {
129                            value
130                                .to_str()
131                                .ok()
132                                .map(|v| (key.as_str().to_string(), v.to_string()))
133                        })
134                        .collect();
135
136                    let response_state = guard.response().clone();
137                    (request_headers, response_state)
138                };
139
140                if !header_updates.is_empty() {
141                    for (key, val) in header_updates {
142                        let _ = session.req_header_mut().insert_header(key, val);
143                    }
144                }
145
146                if !should_continue || response_snapshot.status_override().is_some() {
147                    let header_response = Self::build_response_header(&response_snapshot);
148                    let _ = session
149                        .write_response_header(Box::new(header_response), false)
150                        .await;
151                    let _ = session.respond_error(response_snapshot.status()).await;
152                    Ok(MiddlewareResult::Responded)
153                } else {
154                    let headers: HashMap<String, String> = response_snapshot
155                        .headers()
156                        .iter()
157                        .filter_map(|(key, value)| {
158                            value
159                                .to_str()
160                                .ok()
161                                .map(|v| (key.as_str().to_string(), v.to_string()))
162                        })
163                        .collect();
164                    Ok(MiddlewareResult::Continue(headers))
165                }
166            }
167        }
168    }
169
170    pub async fn run_response_filter(
171        &self,
172        name: &str,
173        session: &mut Session,
174        req_ctx: &mut RequestContext,
175        response: &mut pingora::http::ResponseHeader,
176    ) {
177        let plugin = self
178            .plugins
179            .get(name)
180            .ok_or_else(|| CardinalError::Other(format!("Plugin {name} does not exist")));
181
182        if let Ok(plugin) = plugin {
183            match plugin.as_ref() {
184                PluginHandler::Builtin(builtin) => match builtin {
185                    PluginBuiltInType::Inbound(_) => {
186                        error!("The filter {name} is not a response filter");
187                    }
188                    PluginBuiltInType::Outbound(filter) => {
189                        filter
190                            .on_response(
191                                session,
192                                req_ctx,
193                                response,
194                                req_ctx.cardinal_context.clone(),
195                            )
196                            .await
197                    }
198                },
199                PluginHandler::Wasm(wasm) => {
200                    let runner =
201                        WasmRunner::new(wasm, ExecutionPhase::Outbound, self.host_imports());
202
203                    match runner.run(req_ctx.shared_context()) {
204                        Ok(exec) => {
205                            let snapshot = {
206                                let guard = exec.execution_context.read();
207                                guard.response().clone()
208                            };
209
210                            for (key, val) in snapshot.headers().iter() {
211                                let _ = response.insert_header(key.clone(), val.clone());
212                            }
213
214                            if let Some(status) = snapshot.status_override() {
215                                let _ = response.set_status(status);
216                            }
217                        }
218                        Err(e) => {
219                            error!("Failed to run plugin {}: {}", name, e);
220                        }
221                    }
222                }
223            }
224        }
225    }
226
227    fn build_response_header(response: &ResponseState) -> ResponseHeader {
228        let mut header = ResponseHeader::build(response.status(), None)
229            .expect("failed to build response header");
230
231        for (key, value) in response.headers().iter() {
232            let _ = header.insert_header(key.clone(), value.clone());
233        }
234
235        header
236    }
237}
238
239impl Default for PluginContainer {
240    fn default() -> Self {
241        Self::new()
242    }
243}
244
245#[async_trait::async_trait]
246impl Provider for PluginContainer {
247    async fn provide(ctx: &CardinalContext) -> Result<Self, CardinalError> {
248        let preloaded_plugins = ctx.config.plugins.clone();
249        let mut plugin_container = PluginContainer::new();
250
251        for plugin in preloaded_plugins {
252            let plugin_name = plugin.name();
253            let plugin_exists = plugin_container.plugins.contains_key(plugin_name);
254
255            if plugin_exists {
256                warn!("Plugin {} already exists, skipping", plugin_name);
257                continue;
258            }
259
260            match plugin {
261                Plugin::Builtin(_) => continue,
262                Plugin::Wasm(wasm_config) => {
263                    let wasm_plugin = WasmPlugin::from_path(&wasm_config.path).map_err(|e| {
264                        CardinalError::Other(format!(
265                            "Failed to load plugin {}: {}",
266                            wasm_config.name, e
267                        ))
268                    })?;
269                    plugin_container.plugins.insert(
270                        wasm_config.name.clone(),
271                        Arc::new(PluginHandler::Wasm(Arc::new(wasm_plugin))),
272                    );
273                }
274            }
275        }
276
277        Ok(plugin_container)
278    }
279}