1use crate::builtin::restricted_route_middleware::RestrictedRouteMiddleware;
2use crate::runner::{DynRequestMiddleware, DynResponseMiddleware, MiddlewareResult};
3use crate::utils::parse_query_string_multi;
4use cardinal_base::context::CardinalContext;
5use cardinal_base::destinations::container::DestinationWrapper;
6use cardinal_base::provider::Provider;
7use cardinal_config::Plugin;
8use cardinal_errors::CardinalError;
9use cardinal_wasm_plugins::plugin::WasmPlugin;
10use cardinal_wasm_plugins::runner::{HostFunctionBuilder, HostFunctionMap, WasmRunner};
11use cardinal_wasm_plugins::wasmer::{Function, FunctionEnv, Store};
12use cardinal_wasm_plugins::{ExecutionContext, ResponseState};
13use pingora::http::ResponseHeader;
14use pingora::prelude::Session;
15use std::collections::HashMap;
16use std::sync::Arc;
17use tracing::{error, warn};
18
/// A plugin compiled into the gateway, tagged by the pipeline phase it can run in.
pub enum PluginBuiltInType {
    /// Request-phase filter; only `PluginContainer::run_request_filter` will invoke it.
    Inbound(Arc<DynRequestMiddleware>),
    /// Response-phase filter; only `PluginContainer::run_response_filter` will invoke it.
    Outbound(Arc<DynResponseMiddleware>),
}
23
/// A registered plugin: either native middleware or a loaded WASM module.
pub enum PluginHandler {
    /// Native middleware compiled into the gateway.
    Builtin(PluginBuiltInType),
    /// A WASM module, executed through `WasmRunner` in both the request and response phases.
    Wasm(Arc<WasmPlugin>),
}
28
/// Registry of named plugins plus the host functions exposed to WASM plugins.
pub struct PluginContainer {
    // Plugins keyed by the name that `run_request_filter` / `run_response_filter` look up.
    plugins: HashMap<String, Arc<PluginHandler>>,
    // Host-function builders grouped by import namespace; handed to every `WasmRunner`
    // (as `None` while the map is empty).
    host_imports: HostFunctionMap,
}
33
34impl PluginContainer {
35 pub fn new() -> Self {
36 Self {
37 plugins: HashMap::from_iter(Self::builtin_plugins()),
38 host_imports: HashMap::new(),
39 }
40 }
41
42 pub fn new_empty() -> Self {
43 Self {
44 plugins: HashMap::new(),
45 host_imports: HashMap::new(),
46 }
47 }
48
49 pub fn builtin_plugins() -> Vec<(String, Arc<PluginHandler>)> {
50 vec![(
51 "RestrictedRouteMiddleware".to_string(),
52 Arc::new(PluginHandler::Builtin(PluginBuiltInType::Inbound(
53 Arc::new(RestrictedRouteMiddleware),
54 ))),
55 )]
56 }
57
58 pub fn add_plugin(&mut self, name: String, plugin: PluginHandler) {
59 self.plugins.insert(name, Arc::new(plugin));
60 }
61
62 pub fn remove_plugin(&mut self, name: &str) {
63 self.plugins.remove(name);
64 }
65
66 pub fn add_host_function<F>(
67 &mut self,
68 namespace: impl Into<String>,
69 name: impl Into<String>,
70 builder: F,
71 ) where
72 F: Fn(&mut Store, &FunctionEnv<ExecutionContext>) -> Function + Send + Sync + 'static,
73 {
74 let ns = namespace.into();
75 let host_entry = self.host_imports.entry(ns).or_default();
76 let builder: HostFunctionBuilder = Arc::new(builder);
77 host_entry.push((name.into(), builder));
78 }
79
80 pub fn extend_host_functions<I, S>(&mut self, namespace: S, functions: I)
81 where
82 I: IntoIterator<Item = (String, HostFunctionBuilder)>,
83 S: Into<String>,
84 {
85 let ns = namespace.into();
86 let host_entry = self.host_imports.entry(ns).or_default();
87 host_entry.extend(functions);
88 }
89
90 fn host_imports(&self) -> Option<&HostFunctionMap> {
91 if self.host_imports.is_empty() {
92 None
93 } else {
94 Some(&self.host_imports)
95 }
96 }
97
98 pub async fn run_request_filter(
99 &self,
100 name: &str,
101 session: &mut Session,
102 backend: Arc<DestinationWrapper>,
103 ctx: Arc<CardinalContext>,
104 ) -> Result<MiddlewareResult, CardinalError> {
105 let plugin = self
106 .plugins
107 .get(name)
108 .ok_or_else(|| CardinalError::Other(format!("Plugin {name} does not exist")))?;
109
110 match plugin.as_ref() {
111 PluginHandler::Builtin(builtin) => match builtin {
112 PluginBuiltInType::Inbound(filter) => {
113 filter.on_request(session, backend, ctx).await
114 }
115 PluginBuiltInType::Outbound(_) => Err(CardinalError::Other(format!(
116 "The filter {name} is not a request filter"
117 ))),
118 },
119 PluginHandler::Wasm(wasm) => {
120 let get_req_headers = session
121 .req_header()
122 .headers
123 .iter()
124 .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or_default().to_string()))
125 .collect();
126
127 let query =
128 parse_query_string_multi(session.req_header().uri.query().unwrap_or(""));
129
130 {
131 let runner = WasmRunner::new(wasm, self.host_imports());
132
133 let inbound_ctx = ExecutionContext::from_parts(
134 get_req_headers,
135 query,
136 None,
137 ResponseState::with_default_status(200),
138 );
139
140 let mut exec = runner.run(inbound_ctx)?;
141
142 if !exec.execution_context.req_headers().is_empty() {
143 for (key, val) in exec.execution_context.req_headers() {
144 let _ = session.req_header_mut().insert_header(key.to_string(), val);
145 }
146 }
147
148 if exec
149 .execution_context
150 .response()
151 .status_override()
152 .is_some()
153 {
154 exec.should_continue = false;
155 }
156
157 if exec.should_continue {
158 Ok(MiddlewareResult::Continue(
159 exec.execution_context.response().headers().clone(),
160 ))
161 } else {
162 let response_state = exec.execution_context.response();
163 let header_response = Self::build_response_header(response_state);
164 let _ = session
165 .write_response_header(Box::new(header_response), false)
166 .await;
167 let _ = session.respond_error(response_state.status()).await;
168 Ok(MiddlewareResult::Responded)
169 }
170 }
171 }
172 }
173 }
174
175 pub async fn run_response_filter(
176 &self,
177 name: &str,
178 session: &mut Session,
179 backend: Arc<DestinationWrapper>,
180 response: &mut pingora::http::ResponseHeader,
181 ctx: Arc<CardinalContext>,
182 ) {
183 let plugin = self
184 .plugins
185 .get(name)
186 .ok_or_else(|| CardinalError::Other(format!("Plugin {name} does not exist")));
187
188 if let Ok(plugin) = plugin {
189 match plugin.as_ref() {
190 PluginHandler::Builtin(builtin) => match builtin {
191 PluginBuiltInType::Inbound(_) => {
192 error!("The filter {name} is not a response filter");
193 }
194 PluginBuiltInType::Outbound(filter) => {
195 filter.on_response(session, backend, response, ctx).await
196 }
197 },
198 PluginHandler::Wasm(wasm) => {
199 let get_req_headers = session
200 .req_header()
201 .headers
202 .iter()
203 .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or_default().to_string()))
204 .collect();
205 let query =
206 parse_query_string_multi(session.req_header().uri.query().unwrap_or(""));
207
208 {
209 let runner = WasmRunner::new(wasm, self.host_imports());
210
211 let outbound_ctx = ExecutionContext::from_parts(
212 get_req_headers,
213 query,
214 None,
215 ResponseState::default(),
216 );
217
218 let exec = runner.run(outbound_ctx);
219
220 match &exec {
221 Ok(ex) => {
222 let response_state = ex.execution_context.response();
223
224 for (key, val) in response_state.headers() {
225 let _ =
226 response.insert_header(key.to_string(), val.to_string());
227 }
228
229 if let Some(status) = response_state.status_override() {
230 let _ = response.set_status(status);
231 }
232 }
233 Err(e) => {
234 error!("Failed to run plugin {}: {}", name, e);
235 }
236 }
237 }
238 }
239 }
240 }
241 }
242
243 fn build_response_header(response: &ResponseState) -> ResponseHeader {
244 let mut header = ResponseHeader::build(response.status(), None)
245 .expect("failed to build response header");
246
247 for (key, value) in response.headers() {
248 let _ = header.insert_header(key.to_string(), value.to_string());
249 }
250
251 header
252 }
253}
254
255impl Default for PluginContainer {
256 fn default() -> Self {
257 Self::new()
258 }
259}
260
261#[async_trait::async_trait]
262impl Provider for PluginContainer {
263 async fn provide(ctx: &CardinalContext) -> Result<Self, CardinalError> {
264 let preloaded_plugins = ctx.config.plugins.clone();
265 let mut plugin_container = PluginContainer::new();
266
267 for plugin in preloaded_plugins {
268 let plugin_name = plugin.name();
269 let plugin_exists = plugin_container.plugins.contains_key(plugin_name);
270
271 if plugin_exists {
272 warn!("Plugin {} already exists, skipping", plugin_name);
273 continue;
274 }
275
276 match plugin {
277 Plugin::Builtin(_) => continue,
278 Plugin::Wasm(wasm_config) => {
279 let wasm_plugin = WasmPlugin::from_path(&wasm_config.path).map_err(|e| {
280 CardinalError::Other(format!(
281 "Failed to load plugin {}: {}",
282 wasm_config.name, e
283 ))
284 })?;
285 plugin_container.plugins.insert(
286 wasm_config.name.clone(),
287 Arc::new(PluginHandler::Wasm(Arc::new(wasm_plugin))),
288 );
289 }
290 }
291 }
292
293 Ok(plugin_container)
294 }
295}