cardinal_plugins/
container.rs1use crate::builtin::restricted_route_middleware::RestrictedRouteMiddleware;
2use crate::request_context::RequestContext;
3use crate::runner::{DynRequestMiddleware, DynResponseMiddleware, MiddlewareResult};
4use cardinal_base::context::CardinalContext;
5use cardinal_base::provider::Provider;
6use cardinal_config::Plugin;
7use cardinal_errors::CardinalError;
8use cardinal_wasm_plugins::plugin::WasmPlugin;
9use cardinal_wasm_plugins::runner::{HostFunctionBuilder, HostFunctionMap, WasmRunner};
10use cardinal_wasm_plugins::wasmer::{Function, FunctionEnv, Store};
11use cardinal_wasm_plugins::{ResponseState, SharedExecutionContext};
12use pingora::http::ResponseHeader;
13use pingora::prelude::Session;
14use std::collections::HashMap;
15use std::sync::Arc;
16use tracing::{error, warn};
17
18pub enum PluginBuiltInType {
19 Inbound(Arc<DynRequestMiddleware>),
20 Outbound(Arc<DynResponseMiddleware>),
21}
22
23pub enum PluginHandler {
24 Builtin(PluginBuiltInType),
25 Wasm(Arc<WasmPlugin>),
26}
27
28pub struct PluginContainer {
29 plugins: HashMap<String, Arc<PluginHandler>>,
30 host_imports: HostFunctionMap,
31}
32
33impl PluginContainer {
34 pub fn new() -> Self {
35 Self {
36 plugins: HashMap::from_iter(Self::builtin_plugins()),
37 host_imports: HashMap::new(),
38 }
39 }
40
41 pub fn new_empty() -> Self {
42 Self {
43 plugins: HashMap::new(),
44 host_imports: HashMap::new(),
45 }
46 }
47
48 pub fn builtin_plugins() -> Vec<(String, Arc<PluginHandler>)> {
49 vec![(
50 "RestrictedRouteMiddleware".to_string(),
51 Arc::new(PluginHandler::Builtin(PluginBuiltInType::Inbound(
52 Arc::new(RestrictedRouteMiddleware),
53 ))),
54 )]
55 }
56
57 pub fn add_plugin(&mut self, name: String, plugin: PluginHandler) {
58 self.plugins.insert(name, Arc::new(plugin));
59 }
60
61 pub fn remove_plugin(&mut self, name: &str) {
62 self.plugins.remove(name);
63 }
64
65 pub fn add_host_function<F>(
66 &mut self,
67 namespace: impl Into<String>,
68 name: impl Into<String>,
69 builder: F,
70 ) where
71 F: Fn(&mut Store, &FunctionEnv<SharedExecutionContext>) -> Function + Send + Sync + 'static,
72 {
73 let ns = namespace.into();
74 let host_entry = self.host_imports.entry(ns).or_default();
75 let builder: HostFunctionBuilder = Arc::new(builder);
76 host_entry.push((name.into(), builder));
77 }
78
79 pub fn extend_host_functions<I, S>(&mut self, namespace: S, functions: I)
80 where
81 I: IntoIterator<Item = (String, HostFunctionBuilder)>,
82 S: Into<String>,
83 {
84 let ns = namespace.into();
85 let host_entry = self.host_imports.entry(ns).or_default();
86 host_entry.extend(functions);
87 }
88
89 fn host_imports(&self) -> Option<&HostFunctionMap> {
90 if self.host_imports.is_empty() {
91 None
92 } else {
93 Some(&self.host_imports)
94 }
95 }
96
97 pub async fn run_request_filter(
98 &self,
99 name: &str,
100 session: &mut Session,
101 req_ctx: &mut RequestContext,
102 ctx: Arc<CardinalContext>,
103 ) -> Result<MiddlewareResult, CardinalError> {
104 let plugin = self
105 .plugins
106 .get(name)
107 .ok_or_else(|| CardinalError::Other(format!("Plugin {name} does not exist")))?;
108
109 match plugin.as_ref() {
110 PluginHandler::Builtin(builtin) => match builtin {
111 PluginBuiltInType::Inbound(filter) => {
112 filter.on_request(session, req_ctx, ctx).await
113 }
114 PluginBuiltInType::Outbound(_) => Err(CardinalError::Other(format!(
115 "The filter {name} is not a request filter"
116 ))),
117 },
118 PluginHandler::Wasm(wasm) => {
119 let runner = WasmRunner::new(wasm, self.host_imports());
120
121 let exec = runner.run(req_ctx.shared_context())?;
122
123 if !exec.execution_context.req_headers().is_empty() {
124 for (key, val) in exec.execution_context.req_headers() {
125 let _ = session.req_header_mut().insert_header(key.to_string(), val);
126 }
127 }
128
129 let response_state = exec.execution_context.response();
130 if !exec.should_continue || response_state.status_override().is_some() {
131 let header_response = Self::build_response_header(response_state);
132 let _ = session
133 .write_response_header(Box::new(header_response), false)
134 .await;
135 let _ = session.respond_error(response_state.status()).await;
136 Ok(MiddlewareResult::Responded)
137 } else {
138 Ok(MiddlewareResult::Continue(response_state.headers().clone()))
139 }
140 }
141 }
142 }
143
144 pub async fn run_response_filter(
145 &self,
146 name: &str,
147 session: &mut Session,
148 req_ctx: &mut RequestContext,
149 response: &mut pingora::http::ResponseHeader,
150 ctx: Arc<CardinalContext>,
151 ) {
152 let plugin = self
153 .plugins
154 .get(name)
155 .ok_or_else(|| CardinalError::Other(format!("Plugin {name} does not exist")));
156
157 if let Ok(plugin) = plugin {
158 match plugin.as_ref() {
159 PluginHandler::Builtin(builtin) => match builtin {
160 PluginBuiltInType::Inbound(_) => {
161 error!("The filter {name} is not a response filter");
162 }
163 PluginBuiltInType::Outbound(filter) => {
164 filter.on_response(session, req_ctx, response, ctx).await
165 }
166 },
167 PluginHandler::Wasm(wasm) => {
168 let runner = WasmRunner::new(wasm, self.host_imports());
169
170 let exec = runner.run(req_ctx.shared_context());
171
172 match &exec {
173 Ok(ex) => {
174 let response_state = ex.execution_context.response();
175
176 for (key, val) in response_state.headers() {
177 let _ = response.insert_header(key.to_string(), val.to_string());
178 }
179
180 if let Some(status) = response_state.status_override() {
181 let _ = response.set_status(status);
182 }
183 }
184 Err(e) => {
185 error!("Failed to run plugin {}: {}", name, e);
186 }
187 }
188 }
189 }
190 }
191 }
192
193 fn build_response_header(response: &ResponseState) -> ResponseHeader {
194 let mut header = ResponseHeader::build(response.status(), None)
195 .expect("failed to build response header");
196
197 for (key, value) in response.headers() {
198 let _ = header.insert_header(key.to_string(), value.to_string());
199 }
200
201 header
202 }
203}
204
205impl Default for PluginContainer {
206 fn default() -> Self {
207 Self::new()
208 }
209}
210
211#[async_trait::async_trait]
212impl Provider for PluginContainer {
213 async fn provide(ctx: &CardinalContext) -> Result<Self, CardinalError> {
214 let preloaded_plugins = ctx.config.plugins.clone();
215 let mut plugin_container = PluginContainer::new();
216
217 for plugin in preloaded_plugins {
218 let plugin_name = plugin.name();
219 let plugin_exists = plugin_container.plugins.contains_key(plugin_name);
220
221 if plugin_exists {
222 warn!("Plugin {} already exists, skipping", plugin_name);
223 continue;
224 }
225
226 match plugin {
227 Plugin::Builtin(_) => continue,
228 Plugin::Wasm(wasm_config) => {
229 let wasm_plugin = WasmPlugin::from_path(&wasm_config.path).map_err(|e| {
230 CardinalError::Other(format!(
231 "Failed to load plugin {}: {}",
232 wasm_config.name, e
233 ))
234 })?;
235 plugin_container.plugins.insert(
236 wasm_config.name.clone(),
237 Arc::new(PluginHandler::Wasm(Arc::new(wasm_plugin))),
238 );
239 }
240 }
241 }
242
243 Ok(plugin_container)
244 }
245}