cardinal_plugins/
container.rs

1use crate::builtin::restricted_route_middleware::RestrictedRouteMiddleware;
2use crate::runner::{DynRequestMiddleware, DynResponseMiddleware, MiddlewareResult};
3use crate::utils::parse_query_string_multi;
4use cardinal_base::context::CardinalContext;
5use cardinal_base::destinations::container::DestinationWrapper;
6use cardinal_base::provider::Provider;
7use cardinal_config::Plugin;
8use cardinal_errors::CardinalError;
9use cardinal_wasm_plugins::plugin::WasmPlugin;
10use cardinal_wasm_plugins::runner::{
11    ExecutionType, HostFunctionBuilder, HostFunctionMap, WasmRunner,
12};
13use cardinal_wasm_plugins::wasmer::{Function, FunctionEnv, Store};
14use cardinal_wasm_plugins::{ExecutionContext, ExecutionRequest, ExecutionResponse};
15use pingora::prelude::Session;
16use std::collections::HashMap;
17use std::sync::Arc;
18use tracing::{error, warn};
19
20pub enum PluginBuiltInType {
21    Inbound(Arc<DynRequestMiddleware>),
22    Outbound(Arc<DynResponseMiddleware>),
23}
24
25pub enum PluginHandler {
26    Builtin(PluginBuiltInType),
27    Wasm(Arc<WasmPlugin>),
28}
29
30pub struct PluginContainer {
31    plugins: HashMap<String, Arc<PluginHandler>>,
32    host_imports: HostFunctionMap,
33}
34
35impl PluginContainer {
36    pub fn new() -> Self {
37        Self {
38            plugins: HashMap::from_iter(Self::builtin_plugins()),
39            host_imports: HashMap::new(),
40        }
41    }
42
43    pub fn new_empty() -> Self {
44        Self {
45            plugins: HashMap::new(),
46            host_imports: HashMap::new(),
47        }
48    }
49
50    pub fn builtin_plugins() -> Vec<(String, Arc<PluginHandler>)> {
51        vec![(
52            "RestrictedRouteMiddleware".to_string(),
53            Arc::new(PluginHandler::Builtin(PluginBuiltInType::Inbound(
54                Arc::new(RestrictedRouteMiddleware),
55            ))),
56        )]
57    }
58
59    pub fn add_plugin(&mut self, name: String, plugin: PluginHandler) {
60        self.plugins.insert(name, Arc::new(plugin));
61    }
62
63    pub fn remove_plugin(&mut self, name: &str) {
64        self.plugins.remove(name);
65    }
66
67    pub fn add_host_function<F>(
68        &mut self,
69        namespace: impl Into<String>,
70        name: impl Into<String>,
71        builder: F,
72    ) where
73        F: Fn(&mut Store, &FunctionEnv<ExecutionContext>) -> Function + Send + Sync + 'static,
74    {
75        let ns = namespace.into();
76        let host_entry = self.host_imports.entry(ns).or_default();
77        let builder: HostFunctionBuilder = Arc::new(builder);
78        host_entry.push((name.into(), builder));
79    }
80
81    pub fn extend_host_functions<I, S>(&mut self, namespace: S, functions: I)
82    where
83        I: IntoIterator<Item = (String, HostFunctionBuilder)>,
84        S: Into<String>,
85    {
86        let ns = namespace.into();
87        let host_entry = self.host_imports.entry(ns).or_default();
88        host_entry.extend(functions);
89    }
90
91    fn host_imports(&self) -> Option<&HostFunctionMap> {
92        if self.host_imports.is_empty() {
93            None
94        } else {
95            Some(&self.host_imports)
96        }
97    }
98
99    pub async fn run_request_filter(
100        &self,
101        name: &str,
102        session: &mut Session,
103        backend: Arc<DestinationWrapper>,
104        ctx: Arc<CardinalContext>,
105    ) -> Result<MiddlewareResult, CardinalError> {
106        let plugin = self
107            .plugins
108            .get(name)
109            .ok_or_else(|| CardinalError::Other(format!("Plugin {name} does not exist")))?;
110
111        match plugin.as_ref() {
112            PluginHandler::Builtin(builtin) => match builtin {
113                PluginBuiltInType::Inbound(filter) => {
114                    filter.on_request(session, backend, ctx).await
115                }
116                PluginBuiltInType::Outbound(_) => Err(CardinalError::Other(format!(
117                    "The filter {name} is not a request filter"
118                ))),
119            },
120            PluginHandler::Wasm(wasm) => {
121                let get_req_headers = session
122                    .req_header()
123                    .headers
124                    .iter()
125                    .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or_default().to_string()))
126                    .collect();
127
128                let query =
129                    parse_query_string_multi(session.req_header().uri.query().unwrap_or(""));
130
131                {
132                    let runner = WasmRunner::new(wasm, ExecutionType::Inbound, self.host_imports());
133
134                    let inbound_ctx = ExecutionRequest {
135                        query,
136                        memory: None,
137                        req_headers: get_req_headers,
138                        body: None,
139                    };
140
141                    let exec = runner.run(ExecutionContext::Inbound(inbound_ctx))?;
142
143                    if exec.should_continue {
144                        Ok(MiddlewareResult::Continue)
145                    } else {
146                        let _ = session.respond_error(403).await;
147                        Ok(MiddlewareResult::Responded)
148                    }
149                }
150            }
151        }
152    }
153
154    pub async fn run_response_filter(
155        &self,
156        name: &str,
157        session: &mut Session,
158        backend: Arc<DestinationWrapper>,
159        response: &mut pingora::http::ResponseHeader,
160        ctx: Arc<CardinalContext>,
161    ) {
162        let plugin = self
163            .plugins
164            .get(name)
165            .ok_or_else(|| CardinalError::Other(format!("Plugin {name} does not exist")));
166
167        if let Ok(plugin) = plugin {
168            match plugin.as_ref() {
169                PluginHandler::Builtin(builtin) => match builtin {
170                    PluginBuiltInType::Inbound(_) => {
171                        error!("The filter {name} is not a response filter");
172                    }
173                    PluginBuiltInType::Outbound(filter) => {
174                        filter.on_response(session, backend, response, ctx).await
175                    }
176                },
177                PluginHandler::Wasm(wasm) => {
178                    let get_req_headers = session
179                        .req_header()
180                        .headers
181                        .iter()
182                        .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or_default().to_string()))
183                        .collect();
184                    let query =
185                        parse_query_string_multi(session.req_header().uri.query().unwrap_or(""));
186
187                    {
188                        let runner =
189                            WasmRunner::new(wasm, ExecutionType::Outbound, self.host_imports());
190
191                        let outbound_ctx = ExecutionResponse {
192                            memory: None,
193                            req_headers: get_req_headers,
194                            query,
195                            resp_headers: Default::default(),
196                            status: 0,
197                            body: None,
198                        };
199
200                        let exec = runner.run(ExecutionContext::Outbound(outbound_ctx));
201
202                        match &exec {
203                            Ok(ex) => {
204                                let outbound_resp = ex.execution_context.as_outbound().unwrap();
205                                for (key, val) in &outbound_resp.resp_headers {
206                                    let _ =
207                                        response.insert_header(key.to_string(), val.to_string());
208                                }
209
210                                if outbound_resp.status != 0 {
211                                    response.set_status(outbound_resp.status as u16).unwrap();
212                                }
213                            }
214                            Err(e) => {
215                                error!("Failed to run plugin {}: {}", name, e);
216                            }
217                        }
218                    }
219                }
220            }
221        }
222    }
223}
224
225impl Default for PluginContainer {
226    fn default() -> Self {
227        Self::new()
228    }
229}
230
231#[async_trait::async_trait]
232impl Provider for PluginContainer {
233    async fn provide(ctx: &CardinalContext) -> Result<Self, CardinalError> {
234        let preloaded_plugins = ctx.config.plugins.clone();
235        let mut plugin_container = PluginContainer::new();
236
237        for plugin in preloaded_plugins {
238            let plugin_name = plugin.name();
239            let plugin_exists = plugin_container.plugins.contains_key(plugin_name);
240
241            if plugin_exists {
242                warn!("Plugin {} already exists, skipping", plugin_name);
243                continue;
244            }
245
246            match plugin {
247                Plugin::Builtin(_) => continue,
248                Plugin::Wasm(wasm_config) => {
249                    let wasm_plugin = WasmPlugin::from_path(&wasm_config.path).map_err(|e| {
250                        CardinalError::Other(format!(
251                            "Failed to load plugin {}: {}",
252                            wasm_config.name, e
253                        ))
254                    })?;
255                    plugin_container.plugins.insert(
256                        wasm_config.name.clone(),
257                        Arc::new(PluginHandler::Wasm(Arc::new(wasm_plugin))),
258                    );
259                }
260            }
261        }
262
263        Ok(plugin_container)
264    }
265}