1use crate::builtin::restricted_route_middleware::RestrictedRouteMiddleware;
2use crate::runner::{DynRequestMiddleware, DynResponseMiddleware, MiddlewareResult};
3use crate::utils::parse_query_string_multi;
4use cardinal_base::context::CardinalContext;
5use cardinal_base::destinations::container::DestinationWrapper;
6use cardinal_base::provider::Provider;
7use cardinal_config::Plugin;
8use cardinal_errors::CardinalError;
9use cardinal_wasm_plugins::plugin::WasmPlugin;
10use cardinal_wasm_plugins::runner::{
11 ExecutionType, HostFunctionBuilder, HostFunctionMap, WasmRunner,
12};
13use cardinal_wasm_plugins::wasmer::{Function, FunctionEnv, Store};
14use cardinal_wasm_plugins::{ExecutionContext, ExecutionRequest, ExecutionResponse};
15use pingora::prelude::Session;
16use std::collections::HashMap;
17use std::sync::Arc;
18use tracing::{error, warn};
19
20pub enum PluginBuiltInType {
21 Inbound(Arc<DynRequestMiddleware>),
22 Outbound(Arc<DynResponseMiddleware>),
23}
24
25pub enum PluginHandler {
26 Builtin(PluginBuiltInType),
27 Wasm(Arc<WasmPlugin>),
28}
29
30pub struct PluginContainer {
31 plugins: HashMap<String, Arc<PluginHandler>>,
32 host_imports: HostFunctionMap,
33}
34
35impl PluginContainer {
36 pub fn new() -> Self {
37 Self {
38 plugins: HashMap::from_iter(Self::builtin_plugins()),
39 host_imports: HashMap::new(),
40 }
41 }
42
43 pub fn new_empty() -> Self {
44 Self {
45 plugins: HashMap::new(),
46 host_imports: HashMap::new(),
47 }
48 }
49
50 pub fn builtin_plugins() -> Vec<(String, Arc<PluginHandler>)> {
51 vec![(
52 "RestrictedRouteMiddleware".to_string(),
53 Arc::new(PluginHandler::Builtin(PluginBuiltInType::Inbound(
54 Arc::new(RestrictedRouteMiddleware),
55 ))),
56 )]
57 }
58
59 pub fn add_plugin(&mut self, name: String, plugin: PluginHandler) {
60 self.plugins.insert(name, Arc::new(plugin));
61 }
62
63 pub fn remove_plugin(&mut self, name: &str) {
64 self.plugins.remove(name);
65 }
66
67 pub fn add_host_function<F>(
68 &mut self,
69 namespace: impl Into<String>,
70 name: impl Into<String>,
71 builder: F,
72 ) where
73 F: Fn(&mut Store, &FunctionEnv<ExecutionContext>) -> Function + Send + Sync + 'static,
74 {
75 let ns = namespace.into();
76 let host_entry = self.host_imports.entry(ns).or_default();
77 let builder: HostFunctionBuilder = Arc::new(builder);
78 host_entry.push((name.into(), builder));
79 }
80
81 pub fn extend_host_functions<I, S>(&mut self, namespace: S, functions: I)
82 where
83 I: IntoIterator<Item = (String, HostFunctionBuilder)>,
84 S: Into<String>,
85 {
86 let ns = namespace.into();
87 let host_entry = self.host_imports.entry(ns).or_default();
88 host_entry.extend(functions);
89 }
90
91 fn host_imports(&self) -> Option<&HostFunctionMap> {
92 if self.host_imports.is_empty() {
93 None
94 } else {
95 Some(&self.host_imports)
96 }
97 }
98
99 pub async fn run_request_filter(
100 &self,
101 name: &str,
102 session: &mut Session,
103 backend: Arc<DestinationWrapper>,
104 ctx: Arc<CardinalContext>,
105 ) -> Result<MiddlewareResult, CardinalError> {
106 let plugin = self
107 .plugins
108 .get(name)
109 .ok_or_else(|| CardinalError::Other(format!("Plugin {name} does not exist")))?;
110
111 match plugin.as_ref() {
112 PluginHandler::Builtin(builtin) => match builtin {
113 PluginBuiltInType::Inbound(filter) => {
114 filter.on_request(session, backend, ctx).await
115 }
116 PluginBuiltInType::Outbound(_) => Err(CardinalError::Other(format!(
117 "The filter {name} is not a request filter"
118 ))),
119 },
120 PluginHandler::Wasm(wasm) => {
121 let get_req_headers = session
122 .req_header()
123 .headers
124 .iter()
125 .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or_default().to_string()))
126 .collect();
127
128 let query =
129 parse_query_string_multi(session.req_header().uri.query().unwrap_or(""));
130
131 {
132 let runner = WasmRunner::new(wasm, ExecutionType::Inbound, self.host_imports());
133
134 let inbound_ctx = ExecutionRequest {
135 query,
136 memory: None,
137 req_headers: get_req_headers,
138 body: None,
139 };
140
141 let exec = runner.run(ExecutionContext::Inbound(inbound_ctx))?;
142
143 if exec.should_continue {
144 Ok(MiddlewareResult::Continue)
145 } else {
146 let _ = session.respond_error(403).await;
147 Ok(MiddlewareResult::Responded)
148 }
149 }
150 }
151 }
152 }
153
154 pub async fn run_response_filter(
155 &self,
156 name: &str,
157 session: &mut Session,
158 backend: Arc<DestinationWrapper>,
159 response: &mut pingora::http::ResponseHeader,
160 ctx: Arc<CardinalContext>,
161 ) {
162 let plugin = self
163 .plugins
164 .get(name)
165 .ok_or_else(|| CardinalError::Other(format!("Plugin {name} does not exist")));
166
167 if let Ok(plugin) = plugin {
168 match plugin.as_ref() {
169 PluginHandler::Builtin(builtin) => match builtin {
170 PluginBuiltInType::Inbound(_) => {
171 error!("The filter {name} is not a response filter");
172 }
173 PluginBuiltInType::Outbound(filter) => {
174 filter.on_response(session, backend, response, ctx).await
175 }
176 },
177 PluginHandler::Wasm(wasm) => {
178 let get_req_headers = session
179 .req_header()
180 .headers
181 .iter()
182 .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or_default().to_string()))
183 .collect();
184 let query =
185 parse_query_string_multi(session.req_header().uri.query().unwrap_or(""));
186
187 {
188 let runner =
189 WasmRunner::new(wasm, ExecutionType::Outbound, self.host_imports());
190
191 let outbound_ctx = ExecutionResponse {
192 memory: None,
193 req_headers: get_req_headers,
194 query,
195 resp_headers: Default::default(),
196 status: 0,
197 body: None,
198 };
199
200 let exec = runner.run(ExecutionContext::Outbound(outbound_ctx));
201
202 match &exec {
203 Ok(ex) => {
204 let outbound_resp = ex.execution_context.as_outbound().unwrap();
205 for (key, val) in &outbound_resp.resp_headers {
206 let _ =
207 response.insert_header(key.to_string(), val.to_string());
208 }
209
210 if outbound_resp.status != 0 {
211 response.set_status(outbound_resp.status as u16).unwrap();
212 }
213 }
214 Err(e) => {
215 error!("Failed to run plugin {}: {}", name, e);
216 }
217 }
218 }
219 }
220 }
221 }
222 }
223}
224
225impl Default for PluginContainer {
226 fn default() -> Self {
227 Self::new()
228 }
229}
230
231#[async_trait::async_trait]
232impl Provider for PluginContainer {
233 async fn provide(ctx: &CardinalContext) -> Result<Self, CardinalError> {
234 let preloaded_plugins = ctx.config.plugins.clone();
235 let mut plugin_container = PluginContainer::new();
236
237 for plugin in preloaded_plugins {
238 let plugin_name = plugin.name();
239 let plugin_exists = plugin_container.plugins.contains_key(plugin_name);
240
241 if plugin_exists {
242 warn!("Plugin {} already exists, skipping", plugin_name);
243 continue;
244 }
245
246 match plugin {
247 Plugin::Builtin(_) => continue,
248 Plugin::Wasm(wasm_config) => {
249 let wasm_plugin = WasmPlugin::from_path(&wasm_config.path).map_err(|e| {
250 CardinalError::Other(format!(
251 "Failed to load plugin {}: {}",
252 wasm_config.name, e
253 ))
254 })?;
255 plugin_container.plugins.insert(
256 wasm_config.name.clone(),
257 Arc::new(PluginHandler::Wasm(Arc::new(wasm_plugin))),
258 );
259 }
260 }
261 }
262
263 Ok(plugin_container)
264 }
265}