1use crate::builtin::restricted_route_middleware::RestrictedRouteMiddleware;
2use crate::request_context::RequestContext;
3use crate::runner::{DynRequestMiddleware, DynResponseMiddleware, MiddlewareResult};
4use cardinal_base::context::CardinalContext;
5use cardinal_base::provider::Provider;
6use cardinal_config::Plugin;
7use cardinal_errors::CardinalError;
8use cardinal_wasm_plugins::host::{HostFunctionBuilder, HostImportHandle};
9use cardinal_wasm_plugins::plugin::WasmPlugin;
10use cardinal_wasm_plugins::runner::{host_import_from_builder, ExecutionPhase, WasmRunner};
11use cardinal_wasm_plugins::wasmer::{Function, FunctionEnv, Store};
12use cardinal_wasm_plugins::{ResponseState, SharedExecutionContext};
13use pingora::http::ResponseHeader;
14use pingora::prelude::Session;
15use std::collections::HashMap;
16use std::sync::Arc;
17use tracing::{error, warn};
18
19pub enum PluginBuiltInType {
20 Inbound(Arc<DynRequestMiddleware>),
21 Outbound(Arc<DynResponseMiddleware>),
22}
23
24pub enum PluginHandler {
25 Builtin(PluginBuiltInType),
26 Wasm(Arc<WasmPlugin>),
27}
28
29pub struct PluginContainer {
30 plugins: HashMap<String, Arc<PluginHandler>>,
31 host_imports: Vec<HostImportHandle>,
32}
33
34impl PluginContainer {
35 pub fn new() -> Self {
36 Self {
37 plugins: HashMap::from_iter(Self::builtin_plugins()),
38 host_imports: Vec::new(),
39 }
40 }
41
42 pub fn new_empty() -> Self {
43 Self {
44 plugins: HashMap::new(),
45 host_imports: Vec::new(),
46 }
47 }
48
49 pub fn builtin_plugins() -> Vec<(String, Arc<PluginHandler>)> {
50 vec![(
51 "RestrictedRouteMiddleware".to_string(),
52 Arc::new(PluginHandler::Builtin(PluginBuiltInType::Inbound(
53 Arc::new(RestrictedRouteMiddleware),
54 ))),
55 )]
56 }
57
58 pub fn add_plugin(&mut self, name: String, plugin: PluginHandler) {
59 self.plugins.insert(name, Arc::new(plugin));
60 }
61
62 pub fn remove_plugin(&mut self, name: &str) {
63 self.plugins.remove(name);
64 }
65
66 pub fn add_host_function<F>(
67 &mut self,
68 namespace: impl Into<String>,
69 name: impl Into<String>,
70 builder: F,
71 ) where
72 F: Fn(&mut Store, &FunctionEnv<SharedExecutionContext>) -> Function + Send + Sync + 'static,
73 {
74 let builder: HostFunctionBuilder = Arc::new(builder);
75 let import = host_import_from_builder(namespace, name, builder);
76 self.host_imports.push(import);
77 }
78
79 pub fn extend_host_functions<I>(&mut self, functions: I)
80 where
81 I: IntoIterator<Item = HostImportHandle>,
82 {
83 self.host_imports.extend(functions);
84 }
85
86 fn host_imports(&self) -> Option<&[HostImportHandle]> {
87 if self.host_imports.is_empty() {
88 None
89 } else {
90 Some(&self.host_imports)
91 }
92 }
93
94 pub async fn run_request_filter(
95 &self,
96 name: &str,
97 session: &mut Session,
98 req_ctx: &mut RequestContext,
99 ) -> Result<MiddlewareResult, CardinalError> {
100 let plugin = self
101 .plugins
102 .get(name)
103 .ok_or_else(|| CardinalError::Other(format!("Plugin {name} does not exist")))?;
104
105 match plugin.as_ref() {
106 PluginHandler::Builtin(builtin) => match builtin {
107 PluginBuiltInType::Inbound(filter) => {
108 filter
109 .on_request(session, req_ctx, req_ctx.cardinal_context.clone())
110 .await
111 }
112 PluginBuiltInType::Outbound(_) => Err(CardinalError::Other(format!(
113 "The filter {name} is not a request filter"
114 ))),
115 },
116 PluginHandler::Wasm(wasm) => {
117 let runner = WasmRunner::new(wasm, ExecutionPhase::Inbound, self.host_imports());
118
119 let exec = runner.run(req_ctx.shared_context())?;
120 let should_continue = exec.should_continue;
121
122 let (header_updates, response_snapshot) = {
123 let guard = exec.execution_context.read();
124 let request_headers: Vec<(String, String)> = guard
125 .request()
126 .headers()
127 .iter()
128 .filter_map(|(key, value)| {
129 value
130 .to_str()
131 .ok()
132 .map(|v| (key.as_str().to_string(), v.to_string()))
133 })
134 .collect();
135
136 let response_state = guard.response().clone();
137 (request_headers, response_state)
138 };
139
140 if !header_updates.is_empty() {
141 for (key, val) in header_updates {
142 let _ = session.req_header_mut().insert_header(key, val);
143 }
144 }
145
146 if !should_continue || response_snapshot.status_override().is_some() {
147 let header_response = Self::build_response_header(&response_snapshot);
148 let _ = session
149 .write_response_header(Box::new(header_response), false)
150 .await;
151 let _ = session.respond_error(response_snapshot.status()).await;
152 Ok(MiddlewareResult::Responded)
153 } else {
154 let headers: HashMap<String, String> = response_snapshot
155 .headers()
156 .iter()
157 .filter_map(|(key, value)| {
158 value
159 .to_str()
160 .ok()
161 .map(|v| (key.as_str().to_string(), v.to_string()))
162 })
163 .collect();
164 Ok(MiddlewareResult::Continue(headers))
165 }
166 }
167 }
168 }
169
170 pub async fn run_response_filter(
171 &self,
172 name: &str,
173 session: &mut Session,
174 req_ctx: &mut RequestContext,
175 response: &mut pingora::http::ResponseHeader,
176 ) {
177 let plugin = self
178 .plugins
179 .get(name)
180 .ok_or_else(|| CardinalError::Other(format!("Plugin {name} does not exist")));
181
182 if let Ok(plugin) = plugin {
183 match plugin.as_ref() {
184 PluginHandler::Builtin(builtin) => match builtin {
185 PluginBuiltInType::Inbound(_) => {
186 error!("The filter {name} is not a response filter");
187 }
188 PluginBuiltInType::Outbound(filter) => {
189 filter
190 .on_response(
191 session,
192 req_ctx,
193 response,
194 req_ctx.cardinal_context.clone(),
195 )
196 .await
197 }
198 },
199 PluginHandler::Wasm(wasm) => {
200 let runner =
201 WasmRunner::new(wasm, ExecutionPhase::Outbound, self.host_imports());
202
203 match runner.run(req_ctx.shared_context()) {
204 Ok(exec) => {
205 let snapshot = {
206 let guard = exec.execution_context.read();
207 guard.response().clone()
208 };
209
210 for (key, val) in snapshot.headers().iter() {
211 let _ = response.insert_header(key.clone(), val.clone());
212 }
213
214 if let Some(status) = snapshot.status_override() {
215 let _ = response.set_status(status);
216 }
217 }
218 Err(e) => {
219 error!("Failed to run plugin {}: {}", name, e);
220 }
221 }
222 }
223 }
224 }
225 }
226
227 fn build_response_header(response: &ResponseState) -> ResponseHeader {
228 let mut header = ResponseHeader::build(response.status(), None)
229 .expect("failed to build response header");
230
231 for (key, value) in response.headers().iter() {
232 let _ = header.insert_header(key.clone(), value.clone());
233 }
234
235 header
236 }
237}
238
239impl Default for PluginContainer {
240 fn default() -> Self {
241 Self::new()
242 }
243}
244
245#[async_trait::async_trait]
246impl Provider for PluginContainer {
247 async fn provide(ctx: &CardinalContext) -> Result<Self, CardinalError> {
248 let preloaded_plugins = ctx.config.plugins.clone();
249 let mut plugin_container = PluginContainer::new();
250
251 for plugin in preloaded_plugins {
252 let plugin_name = plugin.name();
253 let plugin_exists = plugin_container.plugins.contains_key(plugin_name);
254
255 if plugin_exists {
256 warn!("Plugin {} already exists, skipping", plugin_name);
257 continue;
258 }
259
260 match plugin {
261 Plugin::Builtin(_) => continue,
262 Plugin::Wasm(wasm_config) => {
263 let wasm_plugin = WasmPlugin::from_path(&wasm_config.path).map_err(|e| {
264 CardinalError::Other(format!(
265 "Failed to load plugin {}: {}",
266 wasm_config.name, e
267 ))
268 })?;
269 plugin_container.plugins.insert(
270 wasm_config.name.clone(),
271 Arc::new(PluginHandler::Wasm(Arc::new(wasm_plugin))),
272 );
273 }
274 }
275 }
276
277 Ok(plugin_container)
278 }
279}