1use crate::builtin::restricted_route_middleware::RestrictedRouteMiddleware;
2use crate::request_context::RequestContext;
3use crate::runner::{DynRequestMiddleware, DynResponseMiddleware, MiddlewareResult};
4use cardinal_base::context::CardinalContext;
5use cardinal_base::provider::Provider;
6use cardinal_config::Plugin;
7use cardinal_errors::CardinalError;
8use cardinal_wasm_plugins::host::{HostFunctionBuilder, HostImportHandle};
9use cardinal_wasm_plugins::plugin::WasmPlugin;
10use cardinal_wasm_plugins::runner::{host_import_from_builder, ExecutionPhase, WasmRunner};
11use cardinal_wasm_plugins::wasmer::{Function, FunctionEnv, Store};
12use cardinal_wasm_plugins::{ResponseState, SharedExecutionContext};
13use pingora::http::ResponseHeader;
14use pingora::prelude::Session;
15use std::collections::HashMap;
16use std::sync::Arc;
17use tracing::{error, warn};
18
/// A native (non-Wasm) middleware, tagged by the phase it is registered for.
pub enum PluginBuiltInType {
    /// Request-phase filter; `PluginContainer::run_request_filter` calls its `on_request`.
    Inbound(Arc<DynRequestMiddleware>),
    /// Response-phase filter; `PluginContainer::run_response_filter` calls its `on_response`.
    Outbound(Arc<DynResponseMiddleware>),
}
23
/// How a registered plugin is executed.
pub enum PluginHandler {
    /// A native Rust middleware; see [`PluginBuiltInType`] for the phase it runs in.
    Builtin(PluginBuiltInType),
    /// A WebAssembly module, executed through `WasmRunner` in either phase.
    Wasm(Arc<WasmPlugin>),
}
28
/// Registry of named plugins plus the host functions made available to Wasm plugins.
pub struct PluginContainer {
    /// Registered plugins, looked up by name when a request/response filter is run.
    plugins: HashMap<String, Arc<PluginHandler>>,
    /// Host imports handed to every `WasmRunner` this container creates
    /// (passed as `None` when the list is empty).
    host_imports: Vec<HostImportHandle>,
}
33
34impl PluginContainer {
35 pub fn new() -> Self {
36 Self {
37 plugins: HashMap::from_iter(Self::builtin_plugins()),
38 host_imports: Vec::new(),
39 }
40 }
41
42 pub fn new_empty() -> Self {
43 Self {
44 plugins: HashMap::new(),
45 host_imports: Vec::new(),
46 }
47 }
48
49 pub fn builtin_plugins() -> Vec<(String, Arc<PluginHandler>)> {
50 vec![(
51 "RestrictedRouteMiddleware".to_string(),
52 Arc::new(PluginHandler::Builtin(PluginBuiltInType::Inbound(
53 Arc::new(RestrictedRouteMiddleware),
54 ))),
55 )]
56 }
57
58 pub fn add_plugin(&mut self, name: String, plugin: PluginHandler) {
59 self.plugins.insert(name, Arc::new(plugin));
60 }
61
62 pub fn remove_plugin(&mut self, name: &str) {
63 self.plugins.remove(name);
64 }
65
66 pub fn add_host_function<F>(
67 &mut self,
68 namespace: impl Into<String>,
69 name: impl Into<String>,
70 builder: F,
71 ) where
72 F: Fn(&mut Store, &FunctionEnv<SharedExecutionContext>) -> Function + Send + Sync + 'static,
73 {
74 let builder: HostFunctionBuilder = Arc::new(builder);
75 let import = host_import_from_builder(namespace, name, builder);
76 self.host_imports.push(import);
77 }
78
79 pub fn extend_host_functions<I>(&mut self, functions: I)
80 where
81 I: IntoIterator<Item = HostImportHandle>,
82 {
83 self.host_imports.extend(functions);
84 }
85
86 fn host_imports(&self) -> Option<&[HostImportHandle]> {
87 if self.host_imports.is_empty() {
88 None
89 } else {
90 Some(&self.host_imports)
91 }
92 }
93
94 pub async fn run_request_filter(
95 &self,
96 name: &str,
97 session: &mut Session,
98 req_ctx: &mut RequestContext,
99 ) -> Result<MiddlewareResult, CardinalError> {
100 let plugin = self
101 .plugins
102 .get(name)
103 .ok_or_else(|| CardinalError::Other(format!("Plugin {name} does not exist")))?;
104
105 let (result, shared_ctx, force_respond) = match plugin.as_ref() {
106 PluginHandler::Builtin(builtin) => match builtin {
107 PluginBuiltInType::Inbound(filter) => {
108 let result = filter
109 .on_request(session, req_ctx, req_ctx.cardinal_context.clone())
110 .await?;
111
112 Ok((result, req_ctx.shared_context(), false))
113 }
114 PluginBuiltInType::Outbound(_) => Err(CardinalError::Other(format!(
115 "The filter {name} is not a request filter"
116 ))),
117 },
118 PluginHandler::Wasm(wasm) => {
119 let runner = WasmRunner::new(wasm, ExecutionPhase::Inbound, self.host_imports());
120
121 let exec = runner.run(req_ctx.shared_context())?;
122 let result = if exec.should_continue {
123 MiddlewareResult::Continue(HashMap::new())
124 } else {
125 MiddlewareResult::Responded
126 };
127
128 Ok((
129 result,
130 exec.execution_context.clone(),
131 !exec.should_continue,
132 ))
133 }
134 }?;
135
136 let (header_updates, response_snapshot) = Self::snapshot_execution_context(&shared_ctx);
137
138 Self::apply_request_header_updates(session, header_updates);
139
140 Ok(Self::finalize_request_result(result, response_snapshot, session, force_respond).await)
141 }
142
143 fn snapshot_execution_context(
144 shared_ctx: &SharedExecutionContext,
145 ) -> (Vec<(String, String)>, ResponseState) {
146 let guard = shared_ctx.read();
147
148 let request_headers: Vec<(String, String)> = guard
149 .request()
150 .headers()
151 .iter()
152 .filter_map(|(key, value)| {
153 value
154 .to_str()
155 .ok()
156 .map(|v| (key.as_str().to_string(), v.to_string()))
157 })
158 .collect();
159
160 let response_state = guard.response().clone();
161 (request_headers, response_state)
162 }
163
164 fn apply_request_header_updates(session: &mut Session, header_updates: Vec<(String, String)>) {
165 if header_updates.is_empty() {
166 return;
167 }
168
169 for (key, val) in header_updates {
170 let _ = session.req_header_mut().insert_header(key, val);
171 }
172 }
173
174 async fn finalize_request_result(
175 result: MiddlewareResult,
176 response_snapshot: ResponseState,
177 session: &mut Session,
178 force_respond: bool,
179 ) -> MiddlewareResult {
180 if force_respond || response_snapshot.status_override().is_some() {
181 let state = Self::build_response_header(&response_snapshot);
182 return Self::respond_from_response_state(state, response_snapshot.status(), session)
183 .await;
184 }
185
186 match result {
187 MiddlewareResult::Responded => MiddlewareResult::Responded,
188 MiddlewareResult::Continue(mut headers) => {
189 let context_headers: HashMap<String, String> = response_snapshot
190 .headers()
191 .iter()
192 .filter_map(|(key, value)| {
193 value
194 .to_str()
195 .ok()
196 .map(|v| (key.as_str().to_string(), v.to_string()))
197 })
198 .collect();
199
200 headers.extend(context_headers);
201 MiddlewareResult::Continue(headers)
202 }
203 }
204 }
205
206 pub async fn run_response_filter(
207 &self,
208 name: &str,
209 session: &mut Session,
210 req_ctx: &mut RequestContext,
211 response: &mut pingora::http::ResponseHeader,
212 ) {
213 let plugin = self
214 .plugins
215 .get(name)
216 .ok_or_else(|| CardinalError::Other(format!("Plugin {name} does not exist")));
217
218 if let Ok(plugin) = plugin {
219 match plugin.as_ref() {
220 PluginHandler::Builtin(builtin) => match builtin {
221 PluginBuiltInType::Inbound(_) => {
222 error!("The filter {name} is not a response filter");
223 }
224 PluginBuiltInType::Outbound(filter) => {
225 filter
226 .on_response(
227 session,
228 req_ctx,
229 response,
230 req_ctx.cardinal_context.clone(),
231 )
232 .await
233 }
234 },
235 PluginHandler::Wasm(wasm) => {
236 let runner =
237 WasmRunner::new(wasm, ExecutionPhase::Outbound, self.host_imports());
238
239 match runner.run(req_ctx.shared_context()) {
240 Ok(exec) => {
241 let snapshot = {
242 let guard = exec.execution_context.read();
243 guard.response().clone()
244 };
245
246 for (key, val) in snapshot.headers().iter() {
247 let _ = response.insert_header(key.clone(), val.clone());
248 }
249
250 if let Some(status) = snapshot.status_override() {
251 let _ = response.set_status(status);
252 }
253 }
254 Err(e) => {
255 error!("Failed to run plugin {}: {}", name, e);
256 }
257 }
258 }
259 }
260 }
261 }
262
263 pub fn build_response_header(response: &ResponseState) -> ResponseHeader {
264 let mut header = ResponseHeader::build(response.status(), None)
265 .expect("failed to build response header");
266
267 for (key, value) in response.headers().iter() {
268 let _ = header.insert_header(key.clone(), value.clone());
269 }
270
271 header
272 }
273
274 pub async fn respond_from_response_state(
275 response_header: ResponseHeader,
276 status: u16,
277 session: &mut Session,
278 ) -> MiddlewareResult {
279 let _ = session
280 .write_response_header(Box::new(response_header), false)
281 .await;
282 let _ = session.respond_error(status).await;
283
284 MiddlewareResult::Responded
285 }
286}
287
288impl Default for PluginContainer {
289 fn default() -> Self {
290 Self::new()
291 }
292}
293
294#[async_trait::async_trait]
295impl Provider for PluginContainer {
296 async fn provide(ctx: &CardinalContext) -> Result<Self, CardinalError> {
297 let preloaded_plugins = ctx.config.plugins.clone();
298 let mut plugin_container = PluginContainer::new();
299
300 for plugin in preloaded_plugins {
301 let plugin_name = plugin.name();
302 let plugin_exists = plugin_container.plugins.contains_key(plugin_name);
303
304 if plugin_exists {
305 warn!("Plugin {} already exists, skipping", plugin_name);
306 continue;
307 }
308
309 match plugin {
310 Plugin::Builtin(_) => continue,
311 Plugin::Wasm(wasm_config) => {
312 let wasm_plugin = WasmPlugin::from_path(&wasm_config.path).map_err(|e| {
313 CardinalError::Other(format!(
314 "Failed to load plugin {}: {}",
315 wasm_config.name, e
316 ))
317 })?;
318 plugin_container.plugins.insert(
319 wasm_config.name.clone(),
320 Arc::new(PluginHandler::Wasm(Arc::new(wasm_plugin))),
321 );
322 }
323 }
324 }
325
326 Ok(plugin_container)
327 }
328}