cardinal_plugins/
container.rs

1use crate::builtin::restricted_route_middleware::RestrictedRouteMiddleware;
2use crate::runner::{DynRequestMiddleware, DynResponseMiddleware, MiddlewareResult};
3use crate::utils::parse_query_string_multi;
4use cardinal_base::context::CardinalContext;
5use cardinal_base::destinations::container::DestinationWrapper;
6use cardinal_base::provider::Provider;
7use cardinal_config::Plugin;
8use cardinal_errors::CardinalError;
9use cardinal_wasm_plugins::plugin::WasmPlugin;
10use cardinal_wasm_plugins::runner::{
11    ExecutionType, HostFunctionBuilder, HostFunctionMap, WasmRunner,
12};
13use cardinal_wasm_plugins::{ExecutionContext, ExecutionRequest, ExecutionResponse};
14use pingora::prelude::Session;
15use std::collections::HashMap;
16use std::sync::Arc;
17use tracing::{error, warn};
18
19pub enum PluginBuiltInType {
20    Inbound(Arc<DynRequestMiddleware>),
21    Outbound(Arc<DynResponseMiddleware>),
22}
23
24pub enum PluginHandler {
25    Builtin(PluginBuiltInType),
26    Wasm(Arc<WasmPlugin>),
27}
28
29pub struct PluginContainer {
30    plugins: HashMap<String, Arc<PluginHandler>>,
31    host_imports: HostFunctionMap,
32}
33
34impl PluginContainer {
35    pub fn new() -> Self {
36        Self {
37            plugins: HashMap::from_iter(Self::builtin_plugins()),
38            host_imports: HashMap::new(),
39        }
40    }
41
42    pub fn new_empty() -> Self {
43        Self {
44            plugins: HashMap::new(),
45            host_imports: HashMap::new(),
46        }
47    }
48
49    pub fn builtin_plugins() -> Vec<(String, Arc<PluginHandler>)> {
50        vec![(
51            "RestrictedRouteMiddleware".to_string(),
52            Arc::new(PluginHandler::Builtin(PluginBuiltInType::Inbound(
53                Arc::new(RestrictedRouteMiddleware),
54            ))),
55        )]
56    }
57
58    pub fn add_plugin(&mut self, name: String, plugin: PluginHandler) {
59        self.plugins.insert(name, Arc::new(plugin));
60    }
61
62    pub fn remove_plugin(&mut self, name: &str) {
63        self.plugins.remove(name);
64    }
65
66    pub fn add_host_function<F>(
67        &mut self,
68        namespace: impl Into<String>,
69        name: impl Into<String>,
70        builder: F,
71    ) where
72        F: Fn(&mut cardinal_wasm_plugins::wasmer::Store) -> cardinal_wasm_plugins::wasmer::Function
73            + Send
74            + Sync
75            + 'static,
76    {
77        let ns = namespace.into();
78        let host_entry = self.host_imports.entry(ns).or_default();
79        let builder: HostFunctionBuilder = Arc::new(builder);
80        host_entry.push((name.into(), builder));
81    }
82
83    pub fn extend_host_functions<I, S>(&mut self, namespace: S, functions: I)
84    where
85        I: IntoIterator<Item = (String, HostFunctionBuilder)>,
86        S: Into<String>,
87    {
88        let ns = namespace.into();
89        let host_entry = self.host_imports.entry(ns).or_default();
90        host_entry.extend(functions);
91    }
92
93    fn host_imports(&self) -> Option<&HostFunctionMap> {
94        if self.host_imports.is_empty() {
95            None
96        } else {
97            Some(&self.host_imports)
98        }
99    }
100
101    pub async fn run_request_filter(
102        &self,
103        name: &str,
104        session: &mut Session,
105        backend: Arc<DestinationWrapper>,
106        ctx: Arc<CardinalContext>,
107    ) -> Result<MiddlewareResult, CardinalError> {
108        let plugin = self
109            .plugins
110            .get(name)
111            .ok_or_else(|| CardinalError::Other(format!("Plugin {name} does not exist")))?;
112
113        match plugin.as_ref() {
114            PluginHandler::Builtin(builtin) => match builtin {
115                PluginBuiltInType::Inbound(filter) => {
116                    filter.on_request(session, backend, ctx).await
117                }
118                PluginBuiltInType::Outbound(_) => Err(CardinalError::Other(format!(
119                    "The filter {name} is not a request filter"
120                ))),
121            },
122            PluginHandler::Wasm(wasm) => {
123                let get_req_headers = session
124                    .req_header()
125                    .headers
126                    .iter()
127                    .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or_default().to_string()))
128                    .collect();
129
130                let query =
131                    parse_query_string_multi(session.req_header().uri.query().unwrap_or(""));
132
133                {
134                    let runner = WasmRunner::new(wasm, ExecutionType::Inbound, self.host_imports());
135
136                    let inbound_ctx = ExecutionRequest {
137                        query,
138                        memory: None,
139                        req_headers: get_req_headers,
140                        body: None,
141                    };
142
143                    let exec = runner.run(ExecutionContext::Inbound(inbound_ctx))?;
144
145                    if exec.should_continue {
146                        Ok(MiddlewareResult::Continue)
147                    } else {
148                        let _ = session.respond_error(403).await;
149                        Ok(MiddlewareResult::Responded)
150                    }
151                }
152            }
153        }
154    }
155
156    pub async fn run_response_filter(
157        &self,
158        name: &str,
159        session: &mut Session,
160        backend: Arc<DestinationWrapper>,
161        response: &mut pingora::http::ResponseHeader,
162        ctx: Arc<CardinalContext>,
163    ) {
164        let plugin = self
165            .plugins
166            .get(name)
167            .ok_or_else(|| CardinalError::Other(format!("Plugin {name} does not exist")));
168
169        if let Ok(plugin) = plugin {
170            match plugin.as_ref() {
171                PluginHandler::Builtin(builtin) => match builtin {
172                    PluginBuiltInType::Inbound(_) => {
173                        error!("The filter {name} is not a response filter");
174                    }
175                    PluginBuiltInType::Outbound(filter) => {
176                        filter.on_response(session, backend, response, ctx).await
177                    }
178                },
179                PluginHandler::Wasm(wasm) => {
180                    let get_req_headers = session
181                        .req_header()
182                        .headers
183                        .iter()
184                        .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or_default().to_string()))
185                        .collect();
186                    let query =
187                        parse_query_string_multi(session.req_header().uri.query().unwrap_or(""));
188
189                    {
190                        let runner =
191                            WasmRunner::new(wasm, ExecutionType::Outbound, self.host_imports());
192
193                        let outbound_ctx = ExecutionResponse {
194                            memory: None,
195                            req_headers: get_req_headers,
196                            query,
197                            resp_headers: Default::default(),
198                            status: 0,
199                            body: None,
200                        };
201
202                        let exec = runner.run(ExecutionContext::Outbound(outbound_ctx));
203
204                        match &exec {
205                            Ok(ex) => {
206                                let outbound_resp = ex.execution_context.as_outbound().unwrap();
207                                for (key, val) in &outbound_resp.resp_headers {
208                                    let _ =
209                                        response.insert_header(key.to_string(), val.to_string());
210                                }
211
212                                if outbound_resp.status != 0 {
213                                    response.set_status(outbound_resp.status as u16).unwrap();
214                                }
215                            }
216                            Err(e) => {
217                                error!("Failed to run plugin {}: {}", name, e);
218                            }
219                        }
220                    }
221                }
222            }
223        }
224    }
225}
226
227impl Default for PluginContainer {
228    fn default() -> Self {
229        Self::new()
230    }
231}
232
233#[async_trait::async_trait]
234impl Provider for PluginContainer {
235    async fn provide(ctx: &CardinalContext) -> Result<Self, CardinalError> {
236        let preloaded_plugins = ctx.config.plugins.clone();
237        let mut plugin_container = PluginContainer::new();
238
239        for plugin in preloaded_plugins {
240            let plugin_name = plugin.name();
241            let plugin_exists = plugin_container.plugins.contains_key(plugin_name);
242
243            if plugin_exists {
244                warn!("Plugin {} already exists, skipping", plugin_name);
245                continue;
246            }
247
248            match plugin {
249                Plugin::Builtin(_) => continue,
250                Plugin::Wasm(wasm_config) => {
251                    let wasm_plugin = WasmPlugin::from_path(&wasm_config.path).map_err(|e| {
252                        CardinalError::Other(format!(
253                            "Failed to load plugin {}: {}",
254                            wasm_config.name, e
255                        ))
256                    })?;
257                    plugin_container.plugins.insert(
258                        wasm_config.name.clone(),
259                        Arc::new(PluginHandler::Wasm(Arc::new(wasm_plugin))),
260                    );
261                }
262            }
263        }
264
265        Ok(plugin_container)
266    }
267}