// cardinal_plugins/container.rs

use crate::builtin::restricted_route_middleware::RestrictedRouteMiddleware;
use crate::request_context::RequestContext;
use crate::runner::{DynRequestMiddleware, DynResponseMiddleware, MiddlewareResult};
use cardinal_base::context::CardinalContext;
use cardinal_base::provider::Provider;
use cardinal_config::Plugin;
use cardinal_errors::CardinalError;
use cardinal_wasm_plugins::host::{HostFunctionBuilder, HostImportHandle};
use cardinal_wasm_plugins::plugin::WasmPlugin;
use cardinal_wasm_plugins::runner::{host_import_from_builder, ExecutionPhase, WasmRunner};
use cardinal_wasm_plugins::wasmer::{Function, FunctionEnv, Store};
use cardinal_wasm_plugins::{ResponseState, SharedExecutionContext};
use pingora::http::ResponseHeader;
use pingora::prelude::Session;
use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use tracing::{error, warn};

19pub enum PluginBuiltInType {
20    Inbound(Arc<DynRequestMiddleware>),
21    Outbound(Arc<DynResponseMiddleware>),
22}
23
24pub enum PluginHandler {
25    Builtin(PluginBuiltInType),
26    Wasm(Arc<WasmPlugin>),
27}
28
29pub struct PluginContainer {
30    plugins: HashMap<String, Arc<PluginHandler>>,
31    host_imports: Vec<HostImportHandle>,
32}
33
34impl PluginContainer {
35    pub fn new() -> Self {
36        Self {
37            plugins: HashMap::from_iter(Self::builtin_plugins()),
38            host_imports: Vec::new(),
39        }
40    }
41
42    pub fn new_empty() -> Self {
43        Self {
44            plugins: HashMap::new(),
45            host_imports: Vec::new(),
46        }
47    }
48
49    pub fn builtin_plugins() -> Vec<(String, Arc<PluginHandler>)> {
50        vec![(
51            "RestrictedRouteMiddleware".to_string(),
52            Arc::new(PluginHandler::Builtin(PluginBuiltInType::Inbound(
53                Arc::new(RestrictedRouteMiddleware),
54            ))),
55        )]
56    }
57
58    pub fn add_plugin(&mut self, name: String, plugin: PluginHandler) {
59        self.plugins.insert(name, Arc::new(plugin));
60    }
61
62    pub fn remove_plugin(&mut self, name: &str) {
63        self.plugins.remove(name);
64    }
65
66    pub fn add_host_function<F>(
67        &mut self,
68        namespace: impl Into<String>,
69        name: impl Into<String>,
70        builder: F,
71    ) where
72        F: Fn(&mut Store, &FunctionEnv<SharedExecutionContext>) -> Function + Send + Sync + 'static,
73    {
74        let builder: HostFunctionBuilder = Arc::new(builder);
75        let import = host_import_from_builder(namespace, name, builder);
76        self.host_imports.push(import);
77    }
78
79    pub fn extend_host_functions<I>(&mut self, functions: I)
80    where
81        I: IntoIterator<Item = HostImportHandle>,
82    {
83        self.host_imports.extend(functions);
84    }
85
86    fn host_imports(&self) -> Option<&[HostImportHandle]> {
87        if self.host_imports.is_empty() {
88            None
89        } else {
90            Some(&self.host_imports)
91        }
92    }
93
94    pub async fn run_request_filter(
95        &self,
96        name: &str,
97        session: &mut Session,
98        req_ctx: &mut RequestContext,
99    ) -> Result<MiddlewareResult, CardinalError> {
100        let plugin = self
101            .plugins
102            .get(name)
103            .ok_or_else(|| CardinalError::Other(format!("Plugin {name} does not exist")))?;
104
105        let (result, shared_ctx, force_respond) = match plugin.as_ref() {
106            PluginHandler::Builtin(builtin) => match builtin {
107                PluginBuiltInType::Inbound(filter) => {
108                    let result = filter
109                        .on_request(session, req_ctx, req_ctx.cardinal_context.clone())
110                        .await?;
111
112                    Ok((result, req_ctx.shared_context(), false))
113                }
114                PluginBuiltInType::Outbound(_) => Err(CardinalError::Other(format!(
115                    "The filter {name} is not a request filter"
116                ))),
117            },
118            PluginHandler::Wasm(wasm) => {
119                let runner = WasmRunner::new(wasm, ExecutionPhase::Inbound, self.host_imports());
120
121                let exec = runner.run(req_ctx.shared_context())?;
122                let result = if exec.should_continue {
123                    MiddlewareResult::Continue(HashMap::new())
124                } else {
125                    MiddlewareResult::Responded
126                };
127
128                Ok((
129                    result,
130                    exec.execution_context.clone(),
131                    !exec.should_continue,
132                ))
133            }
134        }?;
135
136        let (header_updates, response_snapshot) = Self::snapshot_execution_context(&shared_ctx);
137
138        Self::apply_request_header_updates(session, header_updates);
139
140        Ok(Self::finalize_request_result(result, response_snapshot, session, force_respond).await)
141    }
142
143    fn snapshot_execution_context(
144        shared_ctx: &SharedExecutionContext,
145    ) -> (Vec<(String, String)>, ResponseState) {
146        let guard = shared_ctx.read();
147
148        let request_headers: Vec<(String, String)> = guard
149            .request()
150            .headers()
151            .iter()
152            .filter_map(|(key, value)| {
153                value
154                    .to_str()
155                    .ok()
156                    .map(|v| (key.as_str().to_string(), v.to_string()))
157            })
158            .collect();
159
160        let response_state = guard.response().clone();
161        (request_headers, response_state)
162    }
163
164    fn apply_request_header_updates(session: &mut Session, header_updates: Vec<(String, String)>) {
165        if header_updates.is_empty() {
166            return;
167        }
168
169        for (key, val) in header_updates {
170            let _ = session.req_header_mut().insert_header(key, val);
171        }
172    }
173
174    async fn finalize_request_result(
175        result: MiddlewareResult,
176        response_snapshot: ResponseState,
177        session: &mut Session,
178        force_respond: bool,
179    ) -> MiddlewareResult {
180        if force_respond || response_snapshot.status_override().is_some() {
181            let state = Self::build_response_header(&response_snapshot);
182            return Self::respond_from_response_state(state, response_snapshot.status(), session)
183                .await;
184        }
185
186        match result {
187            MiddlewareResult::Responded => MiddlewareResult::Responded,
188            MiddlewareResult::Continue(mut headers) => {
189                let context_headers: HashMap<String, String> = response_snapshot
190                    .headers()
191                    .iter()
192                    .filter_map(|(key, value)| {
193                        value
194                            .to_str()
195                            .ok()
196                            .map(|v| (key.as_str().to_string(), v.to_string()))
197                    })
198                    .collect();
199
200                headers.extend(context_headers);
201                MiddlewareResult::Continue(headers)
202            }
203        }
204    }
205
206    pub async fn run_response_filter(
207        &self,
208        name: &str,
209        session: &mut Session,
210        req_ctx: &mut RequestContext,
211        response: &mut pingora::http::ResponseHeader,
212    ) {
213        let plugin = self
214            .plugins
215            .get(name)
216            .ok_or_else(|| CardinalError::Other(format!("Plugin {name} does not exist")));
217
218        if let Ok(plugin) = plugin {
219            match plugin.as_ref() {
220                PluginHandler::Builtin(builtin) => match builtin {
221                    PluginBuiltInType::Inbound(_) => {
222                        error!("The filter {name} is not a response filter");
223                    }
224                    PluginBuiltInType::Outbound(filter) => {
225                        filter
226                            .on_response(
227                                session,
228                                req_ctx,
229                                response,
230                                req_ctx.cardinal_context.clone(),
231                            )
232                            .await
233                    }
234                },
235                PluginHandler::Wasm(wasm) => {
236                    let runner =
237                        WasmRunner::new(wasm, ExecutionPhase::Outbound, self.host_imports());
238
239                    match runner.run(req_ctx.shared_context()) {
240                        Ok(exec) => {
241                            let snapshot = {
242                                let guard = exec.execution_context.read();
243                                guard.response().clone()
244                            };
245
246                            for (key, val) in snapshot.headers().iter() {
247                                let _ = response.insert_header(key.clone(), val.clone());
248                            }
249
250                            if let Some(status) = snapshot.status_override() {
251                                let _ = response.set_status(status);
252                            }
253                        }
254                        Err(e) => {
255                            error!("Failed to run plugin {}: {}", name, e);
256                        }
257                    }
258                }
259            }
260        }
261    }
262
263    pub fn build_response_header(response: &ResponseState) -> ResponseHeader {
264        let mut header = ResponseHeader::build(response.status(), None)
265            .expect("failed to build response header");
266
267        for (key, value) in response.headers().iter() {
268            let _ = header.insert_header(key.clone(), value.clone());
269        }
270
271        header
272    }
273
274    pub async fn respond_from_response_state(
275        response_header: ResponseHeader,
276        status: u16,
277        session: &mut Session,
278    ) -> MiddlewareResult {
279        let _ = session
280            .write_response_header(Box::new(response_header), false)
281            .await;
282        let _ = session.respond_error(status).await;
283
284        MiddlewareResult::Responded
285    }
286}
287
288impl Default for PluginContainer {
289    fn default() -> Self {
290        Self::new()
291    }
292}
293
294#[async_trait::async_trait]
295impl Provider for PluginContainer {
296    async fn provide(ctx: &CardinalContext) -> Result<Self, CardinalError> {
297        let preloaded_plugins = ctx.config.plugins.clone();
298        let mut plugin_container = PluginContainer::new();
299
300        for plugin in preloaded_plugins {
301            let plugin_name = plugin.name();
302            let plugin_exists = plugin_container.plugins.contains_key(plugin_name);
303
304            if plugin_exists {
305                warn!("Plugin {} already exists, skipping", plugin_name);
306                continue;
307            }
308
309            match plugin {
310                Plugin::Builtin(_) => continue,
311                Plugin::Wasm(wasm_config) => {
312                    let wasm_plugin = WasmPlugin::from_path(&wasm_config.path).map_err(|e| {
313                        CardinalError::Other(format!(
314                            "Failed to load plugin {}: {}",
315                            wasm_config.name, e
316                        ))
317                    })?;
318                    plugin_container.plugins.insert(
319                        wasm_config.name.clone(),
320                        Arc::new(PluginHandler::Wasm(Arc::new(wasm_plugin))),
321                    );
322                }
323            }
324        }
325
326        Ok(plugin_container)
327    }
328}