cardinal_plugins/
container.rs

1use crate::builtin::restricted_route_middleware::RestrictedRouteMiddleware;
2use crate::request_context::RequestContext;
3use crate::runner::{DynRequestMiddleware, DynResponseMiddleware, MiddlewareResult};
4use cardinal_base::context::CardinalContext;
5use cardinal_base::provider::Provider;
6use cardinal_config::Plugin;
7use cardinal_errors::CardinalError;
8use cardinal_wasm_plugins::plugin::WasmPlugin;
9use cardinal_wasm_plugins::runner::{HostFunctionBuilder, HostFunctionMap, WasmRunner};
10use cardinal_wasm_plugins::wasmer::{Function, FunctionEnv, Store};
11use cardinal_wasm_plugins::{ResponseState, SharedExecutionContext};
12use pingora::http::ResponseHeader;
13use pingora::prelude::Session;
14use std::collections::HashMap;
15use std::sync::Arc;
16use tracing::{error, warn};
17
18pub enum PluginBuiltInType {
19    Inbound(Arc<DynRequestMiddleware>),
20    Outbound(Arc<DynResponseMiddleware>),
21}
22
23pub enum PluginHandler {
24    Builtin(PluginBuiltInType),
25    Wasm(Arc<WasmPlugin>),
26}
27
28pub struct PluginContainer {
29    plugins: HashMap<String, Arc<PluginHandler>>,
30    host_imports: HostFunctionMap,
31}
32
33impl PluginContainer {
34    pub fn new() -> Self {
35        Self {
36            plugins: HashMap::from_iter(Self::builtin_plugins()),
37            host_imports: HashMap::new(),
38        }
39    }
40
41    pub fn new_empty() -> Self {
42        Self {
43            plugins: HashMap::new(),
44            host_imports: HashMap::new(),
45        }
46    }
47
48    pub fn builtin_plugins() -> Vec<(String, Arc<PluginHandler>)> {
49        vec![(
50            "RestrictedRouteMiddleware".to_string(),
51            Arc::new(PluginHandler::Builtin(PluginBuiltInType::Inbound(
52                Arc::new(RestrictedRouteMiddleware),
53            ))),
54        )]
55    }
56
57    pub fn add_plugin(&mut self, name: String, plugin: PluginHandler) {
58        self.plugins.insert(name, Arc::new(plugin));
59    }
60
61    pub fn remove_plugin(&mut self, name: &str) {
62        self.plugins.remove(name);
63    }
64
65    pub fn add_host_function<F>(
66        &mut self,
67        namespace: impl Into<String>,
68        name: impl Into<String>,
69        builder: F,
70    ) where
71        F: Fn(&mut Store, &FunctionEnv<SharedExecutionContext>) -> Function + Send + Sync + 'static,
72    {
73        let ns = namespace.into();
74        let host_entry = self.host_imports.entry(ns).or_default();
75        let builder: HostFunctionBuilder = Arc::new(builder);
76        host_entry.push((name.into(), builder));
77    }
78
79    pub fn extend_host_functions<I, S>(&mut self, namespace: S, functions: I)
80    where
81        I: IntoIterator<Item = (String, HostFunctionBuilder)>,
82        S: Into<String>,
83    {
84        let ns = namespace.into();
85        let host_entry = self.host_imports.entry(ns).or_default();
86        host_entry.extend(functions);
87    }
88
89    fn host_imports(&self) -> Option<&HostFunctionMap> {
90        if self.host_imports.is_empty() {
91            None
92        } else {
93            Some(&self.host_imports)
94        }
95    }
96
97    pub async fn run_request_filter(
98        &self,
99        name: &str,
100        session: &mut Session,
101        req_ctx: &mut RequestContext,
102        ctx: Arc<CardinalContext>,
103    ) -> Result<MiddlewareResult, CardinalError> {
104        let plugin = self
105            .plugins
106            .get(name)
107            .ok_or_else(|| CardinalError::Other(format!("Plugin {name} does not exist")))?;
108
109        match plugin.as_ref() {
110            PluginHandler::Builtin(builtin) => match builtin {
111                PluginBuiltInType::Inbound(filter) => {
112                    filter.on_request(session, req_ctx, ctx).await
113                }
114                PluginBuiltInType::Outbound(_) => Err(CardinalError::Other(format!(
115                    "The filter {name} is not a request filter"
116                ))),
117            },
118            PluginHandler::Wasm(wasm) => {
119                let runner = WasmRunner::new(wasm, self.host_imports());
120
121                let exec = runner.run(req_ctx.shared_context())?;
122
123                if !exec.execution_context.req_headers().is_empty() {
124                    for (key, val) in exec.execution_context.req_headers() {
125                        let _ = session.req_header_mut().insert_header(key.to_string(), val);
126                    }
127                }
128
129                let response_state = exec.execution_context.response();
130                if !exec.should_continue || response_state.status_override().is_some() {
131                    let header_response = Self::build_response_header(response_state);
132                    let _ = session
133                        .write_response_header(Box::new(header_response), false)
134                        .await;
135                    let _ = session.respond_error(response_state.status()).await;
136                    Ok(MiddlewareResult::Responded)
137                } else {
138                    Ok(MiddlewareResult::Continue(response_state.headers().clone()))
139                }
140            }
141        }
142    }
143
144    pub async fn run_response_filter(
145        &self,
146        name: &str,
147        session: &mut Session,
148        req_ctx: &mut RequestContext,
149        response: &mut pingora::http::ResponseHeader,
150        ctx: Arc<CardinalContext>,
151    ) {
152        let plugin = self
153            .plugins
154            .get(name)
155            .ok_or_else(|| CardinalError::Other(format!("Plugin {name} does not exist")));
156
157        if let Ok(plugin) = plugin {
158            match plugin.as_ref() {
159                PluginHandler::Builtin(builtin) => match builtin {
160                    PluginBuiltInType::Inbound(_) => {
161                        error!("The filter {name} is not a response filter");
162                    }
163                    PluginBuiltInType::Outbound(filter) => {
164                        filter.on_response(session, req_ctx, response, ctx).await
165                    }
166                },
167                PluginHandler::Wasm(wasm) => {
168                    let runner = WasmRunner::new(wasm, self.host_imports());
169
170                    let exec = runner.run(req_ctx.shared_context());
171
172                    match &exec {
173                        Ok(ex) => {
174                            let response_state = ex.execution_context.response();
175
176                            for (key, val) in response_state.headers() {
177                                let _ = response.insert_header(key.to_string(), val.to_string());
178                            }
179
180                            if let Some(status) = response_state.status_override() {
181                                let _ = response.set_status(status);
182                            }
183                        }
184                        Err(e) => {
185                            error!("Failed to run plugin {}: {}", name, e);
186                        }
187                    }
188                }
189            }
190        }
191    }
192
193    fn build_response_header(response: &ResponseState) -> ResponseHeader {
194        let mut header = ResponseHeader::build(response.status(), None)
195            .expect("failed to build response header");
196
197        for (key, value) in response.headers() {
198            let _ = header.insert_header(key.to_string(), value.to_string());
199        }
200
201        header
202    }
203}
204
205impl Default for PluginContainer {
206    fn default() -> Self {
207        Self::new()
208    }
209}
210
211#[async_trait::async_trait]
212impl Provider for PluginContainer {
213    async fn provide(ctx: &CardinalContext) -> Result<Self, CardinalError> {
214        let preloaded_plugins = ctx.config.plugins.clone();
215        let mut plugin_container = PluginContainer::new();
216
217        for plugin in preloaded_plugins {
218            let plugin_name = plugin.name();
219            let plugin_exists = plugin_container.plugins.contains_key(plugin_name);
220
221            if plugin_exists {
222                warn!("Plugin {} already exists, skipping", plugin_name);
223                continue;
224            }
225
226            match plugin {
227                Plugin::Builtin(_) => continue,
228                Plugin::Wasm(wasm_config) => {
229                    let wasm_plugin = WasmPlugin::from_path(&wasm_config.path).map_err(|e| {
230                        CardinalError::Other(format!(
231                            "Failed to load plugin {}: {}",
232                            wasm_config.name, e
233                        ))
234                    })?;
235                    plugin_container.plugins.insert(
236                        wasm_config.name.clone(),
237                        Arc::new(PluginHandler::Wasm(Arc::new(wasm_plugin))),
238                    );
239                }
240            }
241        }
242
243        Ok(plugin_container)
244    }
245}