bulwark_wasm_host/
plugin.rs

/// Host-side bindings generated from the `bulwark:plugin/host-api` WIT world.
///
/// Supplies the host API linker hook (`HostApi::add_to_linker`) and the interface types shared with guests
/// (e.g. `RequestInterface`, `ResponseInterface`, `DecisionInterface`, `IpInterface`). `async: true` makes the
/// generated bindings async.
#[doc(hidden)]
mod bulwark_host {
    wasmtime::component::bindgen!({
        world: "bulwark:plugin/host-api",
        async: true,
    });
}
8
/// Bindings generated from the `bulwark:plugin/handlers` WIT world.
///
/// Supplies the `Handlers` type the host uses to instantiate a plugin component and call its exported handler
/// functions (`on_init`, `on_request`, …). `async: true` makes the generated calls async.
#[doc(hidden)]
mod handlers {
    wasmtime::component::bindgen!({
        world: "bulwark:plugin/handlers",
        async: true,
    });
}
16
17use {
18    crate::{
19        ContextInstantiationError, PluginExecutionError, PluginInstantiationError, PluginLoadError,
20    },
21    async_trait::async_trait,
22    bulwark_config::ConfigSerializationError,
23    bulwark_host::{DecisionInterface, OutcomeInterface},
24    bulwark_wasm_sdk::{Decision, Outcome},
25    chrono::Utc,
26    redis::Commands,
27    std::str::FromStr,
28    std::{
29        collections::BTreeSet,
30        convert::From,
31        net::IpAddr,
32        ops::DerefMut,
33        path::Path,
34        sync::{Arc, Mutex, MutexGuard},
35    },
36    url::Url,
37    validator::Validate,
38    wasmtime::component::{Component, Linker},
39    wasmtime::{AsContextMut, Config, Engine, Store},
40    wasmtime_wasi::preview2::{
41        pipe::MemoryOutputPipe, HostOutputStream, StdoutStream, Table, WasiCtx, WasiCtxBuilder,
42        WasiView,
43    },
44};
45
46extern crate redis;
47
/// Wraps an [`IpAddr`] representing the remote IP for the incoming request.
///
/// In an architecture with proxies or load balancers in front of Bulwark, this IP will belong to the immediately
/// exterior proxy or load balancer rather than the IP address of the client that originated the request.
#[derive(Copy, Clone, Eq, PartialEq, Hash, PartialOrd, Ord)]
pub struct RemoteIP(pub IpAddr);
/// Wraps an [`IpAddr`] representing the forwarded IP for the incoming request.
///
/// In an architecture with proxies or load balancers in front of Bulwark, this IP will belong to the IP address
/// of the client that originated the request rather than the immediately exterior proxy or load balancer.
///
/// When present in a request's extensions, this is the value exposed to plugins as the client IP.
#[derive(Copy, Clone, Eq, PartialEq, Hash, PartialOrd, Ord)]
pub struct ForwardedIP(pub IpAddr);
60
// TODO: move the `From` conversions below into a dedicated `from.rs` module.
62
63impl From<Arc<bulwark_wasm_sdk::Request>> for bulwark_host::RequestInterface {
64    fn from(request: Arc<bulwark_wasm_sdk::Request>) -> Self {
65        bulwark_host::RequestInterface {
66            method: request.method().to_string(),
67            uri: request.uri().to_string(),
68            version: format!("{:?}", request.version()),
69            headers: request
70                .headers()
71                .iter()
72                .map(|(name, value)| (name.to_string(), value.as_bytes().to_vec()))
73                .collect(),
74            body_received: request.body().received,
75            chunk_start: request.body().start,
76            chunk_length: request.body().size,
77            end_of_stream: request.body().end_of_stream,
78            // TODO: figure out how to avoid the copy
79            chunk: request.body().content.clone(),
80        }
81    }
82}
83
84impl From<Arc<bulwark_wasm_sdk::Response>> for bulwark_host::ResponseInterface {
85    fn from(response: Arc<bulwark_wasm_sdk::Response>) -> Self {
86        bulwark_host::ResponseInterface {
87            // this unwrap should be okay since a non-zero u16 should always be coercible to u32
88            status: response.status().as_u16().try_into().unwrap(),
89            headers: response
90                .headers()
91                .iter()
92                .map(|(name, value)| (name.to_string(), value.as_bytes().to_vec()))
93                .collect(),
94            body_received: response.body().received,
95            chunk_start: response.body().start,
96            chunk_length: response.body().size,
97            end_of_stream: response.body().end_of_stream,
98            // TODO: figure out how to avoid the copy
99            chunk: response.body().content.clone(),
100        }
101    }
102}
103
104impl From<IpAddr> for bulwark_host::IpInterface {
105    fn from(ip: IpAddr) -> Self {
106        match ip {
107            IpAddr::V4(v4) => {
108                let octets = v4.octets();
109                bulwark_host::IpInterface::V4((octets[0], octets[1], octets[2], octets[3]))
110            }
111            IpAddr::V6(v6) => {
112                let segments = v6.segments();
113                bulwark_host::IpInterface::V6((
114                    segments[0],
115                    segments[1],
116                    segments[2],
117                    segments[3],
118                    segments[4],
119                    segments[5],
120                    segments[6],
121                    segments[7],
122                ))
123            }
124        }
125    }
126}
127
128impl From<DecisionInterface> for Decision {
129    fn from(decision: DecisionInterface) -> Self {
130        Decision {
131            accept: decision.accepted,
132            restrict: decision.restricted,
133            unknown: decision.unknown,
134        }
135    }
136}
137
138impl From<Decision> for DecisionInterface {
139    fn from(decision: Decision) -> Self {
140        DecisionInterface {
141            accepted: decision.accept,
142            restricted: decision.restrict,
143            unknown: decision.unknown,
144        }
145    }
146}
147
148impl From<Outcome> for OutcomeInterface {
149    fn from(outcome: Outcome) -> Self {
150        match outcome {
151            Outcome::Trusted => OutcomeInterface::Trusted,
152            Outcome::Accepted => OutcomeInterface::Accepted,
153            Outcome::Suspected => OutcomeInterface::Suspected,
154            Outcome::Restricted => OutcomeInterface::Restricted,
155        }
156    }
157}
158
/// The primary output of a [`PluginInstance`]'s execution. Combines a [`Decision`] and a list of tags together.
///
/// Both the output of individual plugins as well as the combined decision output of a group of plugins may be
/// represented by `DecisionComponents`. The latter is the result of applying Dempster-Shafer combination to each
/// `decision` value in a [`DecisionComponents`] list and then taking the union set of all `tags` lists and forming
/// a new [`DecisionComponents`] with both results.
///
/// `Default` yields `Decision::default()` and an empty tag list.
#[derive(Clone, Default)]
pub struct DecisionComponents {
    /// A `Decision` made by a plugin or a group of plugins
    pub decision: Decision,
    /// The tags applied by plugins to annotate a [`Decision`]
    pub tags: Vec<String>,
}
172
/// Wraps a Redis connection pool and a registry of predefined Lua scripts.
///
/// Request contexts hold it as an `Option<Arc<RedisInfo>>`, so it is shared rather than copied per request and
/// may be absent when Redis is not configured.
pub struct RedisInfo {
    /// The connection pool
    pub pool: r2d2::Pool<redis::Client>,
    /// A Lua script registry
    pub registry: ScriptRegistry,
}
180
/// A registry of predefined Lua scripts for execution within Redis.
///
/// Rate-limit scripts keep their state under keys prefixed with `bulwark:rl:`; circuit-breaker scripts use keys
/// prefixed with `bulwark:bk:`. In every script the caller's key is passed as `KEYS[1]`.
pub struct ScriptRegistry {
    /// Increments a Redis key's counter value if it has not yet expired.
    ///
    /// Uses the service's clock rather than Redis'. Uses Redis' TTL on a best-effort basis.
    /// `ARGV`: increment delta, expiration window, current timestamp. Returns `{ attempts, expiration }`.
    increment_rate_limit: redis::Script,
    /// Checks a Redis key's counter value if it has not yet expired.
    ///
    /// Uses the service's clock rather than Redis'. Uses Redis' TTL on a best-effort basis.
    /// `ARGV`: current timestamp. Returns `{ attempts, expiration }`, or nils if missing/expired.
    check_rate_limit: redis::Script,
    /// Increments a Redis key's counter value, corresponding to either success or failure, if it has not yet expired.
    ///
    /// Uses the service's clock rather than Redis'. Uses Redis' TTL on a best-effort basis.
    /// `ARGV`: success delta, failure delta, expiration window, current timestamp. Returns
    /// `{ generation, successes, failures, consec_successes, consec_failures, expiration }`.
    increment_breaker: redis::Script,
    /// Checks a Redis key's counter value, corresponding to either success or failure, if it has not yet expired.
    ///
    /// Uses the service's clock rather than Redis'. Uses Redis' TTL on a best-effort basis.
    /// Returns the same six values as `increment_breaker`, or nils if no generation is recorded.
    check_breaker: redis::Script,
}
200
201impl Default for ScriptRegistry {
202    fn default() -> ScriptRegistry {
203        ScriptRegistry {
204            // TODO: handle overflow errors by expiring everything on overflow and returning nil?
205            increment_rate_limit: redis::Script::new(
206                r#"
207                local counter_key = "bulwark:rl:" .. KEYS[1]
208                local increment_delta = tonumber(ARGV[1])
209                local expiration_window = tonumber(ARGV[2])
210                local timestamp = tonumber(ARGV[3])
211                local expiration_key = counter_key .. ":exp"
212                local expiration = tonumber(redis.call("get", expiration_key))
213                local next_expiration = timestamp + expiration_window
214                if not expiration or timestamp > expiration then
215                    redis.call("set", expiration_key, next_expiration)
216                    redis.call("set", counter_key, 0)
217                    redis.call("expireat", expiration_key, next_expiration + 1)
218                    redis.call("expireat", counter_key, next_expiration + 1)
219                    expiration = next_expiration
220                end
221                local attempts = redis.call("incrby", counter_key, increment_delta)
222                return { attempts, expiration }
223                "#,
224            ),
225            check_rate_limit: redis::Script::new(
226                r#"
227                local counter_key = "bulwark:rl:" .. KEYS[1]
228                local expiration_key = counter_key .. ":exp"
229                local timestamp = tonumber(ARGV[1])
230                local attempts = tonumber(redis.call("get", counter_key))
231                local expiration = nil
232                if attempts then
233                    expiration = tonumber(redis.call("get", expiration_key))
234                    if not expiration or timestamp > expiration then
235                        attempts = nil
236                        expiration = nil
237                    end
238                end
239                return { attempts, expiration }
240                "#,
241            ),
242            increment_breaker: redis::Script::new(
243                r#"
244                local generation_key = "bulwark:bk:g:" .. KEYS[1]
245                local success_key = "bulwark:bk:s:" .. KEYS[1]
246                local failure_key = "bulwark:bk:f:" .. KEYS[1]
247                local consec_success_key = "bulwark:bk:cs:" .. KEYS[1]
248                local consec_failure_key = "bulwark:bk:cf:" .. KEYS[1]
249                local success_delta = tonumber(ARGV[1])
250                local failure_delta = tonumber(ARGV[2])
251                local expiration_window = tonumber(ARGV[3])
252                local timestamp = tonumber(ARGV[4])
253                local expiration = timestamp + expiration_window
254                local generation = redis.call("incrby", generation_key, 1)
255                local successes = 0
256                local failures = 0
257                local consec_successes = 0
258                local consec_failures = 0
259                if success_delta > 0 then
260                    successes = redis.call("incrby", success_key, success_delta)
261                    failures = tonumber(redis.call("get", failure_key)) or 0
262                    consec_successes = redis.call("incrby", consec_success_key, success_delta)
263                    redis.call("set", consec_failure_key, 0)
264                    consec_failures = 0
265                else
266                    successes = tonumber(redis.call("get", success_key))
267                    failures = redis.call("incrby", failure_key, failure_delta) or 0
268                    redis.call("set", consec_success_key, 0)
269                    consec_successes = 0
270                    consec_failures = redis.call("incrby", consec_failure_key, failure_delta)
271                end
272                redis.call("expireat", generation_key, expiration + 1)
273                redis.call("expireat", success_key, expiration + 1)
274                redis.call("expireat", failure_key, expiration + 1)
275                redis.call("expireat", consec_success_key, expiration + 1)
276                redis.call("expireat", consec_failure_key, expiration + 1)
277                return { generation, successes, failures, consec_successes, consec_failures, expiration }
278                "#,
279            ),
280            check_breaker: redis::Script::new(
281                r#"
282                local generation_key = "bulwark:bk:g:" .. KEYS[1]
283                local success_key = "bulwark:bk:s:" .. KEYS[1]
284                local failure_key = "bulwark:bk:f:" .. KEYS[1]
285                local consec_success_key = "bulwark:bk:cs:" .. KEYS[1]
286                local consec_failure_key = "bulwark:bk:cf:" .. KEYS[1]
287                local generation = tonumber(redis.call("get", generation_key))
288                if not generation then
289                    return { nil, nil, nil, nil, nil, nil }
290                end
291                local successes = tonumber(redis.call("get", success_key)) or 0
292                local failures = tonumber(redis.call("get", failure_key)) or 0
293                local consec_successes = tonumber(redis.call("get", consec_success_key)) or 0
294                local consec_failures = tonumber(redis.call("get", consec_failure_key)) or 0
295                local expiration = tonumber(redis.call("expiretime", success_key)) - 1
296                return { generation, successes, failures, consec_successes, consec_failures, expiration }
297                "#,
298            ),
299        }
300    }
301}
302
/// The RequestContext provides a store of information that needs to cross the plugin sandbox boundary.
///
/// It is the data type of each [`PluginInstance`]'s WASM `Store`, so host functions called by the guest
/// receive it.
pub struct RequestContext {
    /// The WASI context that determines how things like stdio map to our buffers.
    wasi_ctx: WasiCtx,
    /// The WASI table that maps handles to resources.
    wasi_table: Table,
    /// Context values that will not be modified.
    read_only_ctx: ReadOnlyContext,
    /// Context values that will be mutated by the guest environment.
    guest_mut_ctx: GuestMutableContext,
    /// Context values that will be mutated by the host environment.
    host_mut_ctx: HostMutableContext,
    /// The standard I/O buffers used by WASI and captured for logging.
    stdio: PluginStdio,
}
318
319impl RequestContext {
320    /// Creates a new `RequestContext`.
321    ///
322    /// # Arguments
323    ///
324    /// * `plugin` - The [`Plugin`] and its associated configuration.
325    /// * `redis_info` - The Redis connection pool.
326    /// * `params` - A key-value map that plugins use to pass values within the context of a request.
327    ///     Any parameters captured by the router will be added to this before plugin execution.
328    /// * `request` - The [`Request`](bulwark_wasm_sdk::Request) that plugins will be operating on.
329    pub fn new(
330        plugin: Arc<Plugin>,
331        redis_info: Option<Arc<RedisInfo>>,
332        http_client: Arc<reqwest::blocking::Client>,
333        params: Arc<Mutex<bulwark_wasm_sdk::Map<String, bulwark_wasm_sdk::Value>>>,
334        request: Arc<bulwark_wasm_sdk::Request>,
335    ) -> Result<RequestContext, ContextInstantiationError> {
336        let stdio = PluginStdio::default();
337        let wasi_ctx = WasiCtxBuilder::new()
338            .stdout(stdio.stdout.clone())
339            .stderr(stdio.stderr.clone())
340            .build();
341        let client_ip = request
342            .extensions()
343            .get::<ForwardedIP>()
344            .map(|forwarded_ip| bulwark_host::IpInterface::from(forwarded_ip.0));
345
346        Ok(RequestContext {
347            wasi_ctx,
348            wasi_table: Table::new(),
349            read_only_ctx: ReadOnlyContext {
350                config: Arc::new(plugin.guest_config()?),
351                permissions: plugin.permissions(),
352                client_ip,
353                redis_info,
354                http_client,
355            },
356            guest_mut_ctx: GuestMutableContext {
357                receive_request_body: Arc::new(Mutex::new(false)),
358                receive_response_body: Arc::new(Mutex::new(false)),
359                params,
360                decision_components: DecisionComponents::default(),
361            },
362            host_mut_ctx: HostMutableContext::new(bulwark_host::RequestInterface::from(request)),
363            stdio,
364        })
365    }
366}
367
368impl WasiView for RequestContext {
369    fn table(&self) -> &Table {
370        &self.wasi_table
371    }
372
373    fn table_mut(&mut self) -> &mut Table {
374        &mut self.wasi_table
375    }
376
377    fn ctx(&self) -> &WasiCtx {
378        &self.wasi_ctx
379    }
380
381    fn ctx_mut(&mut self) -> &mut WasiCtx {
382        &mut self.wasi_ctx
383    }
384}
385
/// A single detection plugin; provides the interface between the WASM host and guest.
///
/// One `Plugin` may spawn many [`PluginInstance`]s, which will handle the incoming request data.
#[derive(Clone)]
pub struct Plugin {
    /// The plugin's identifier, returned by [`PluginInstance::plugin_reference`].
    reference: String,
    /// The plugin's configuration (guest config, permissions, weight).
    config: Arc<bulwark_config::Plugin>,
    /// The wasmtime engine this plugin's component was compiled with; each `Plugin` creates its own.
    engine: Engine,
    /// The compiled WASM component that instances are created from.
    component: Component,
}
396
397impl Plugin {
398    /// Creates and compiles a new [`Plugin`] from a [`String`] of
399    /// [WAT](https://webassembly.github.io/spec/core/text/index.html)-formatted WASM.
400    pub fn from_wat(
401        name: String,
402        wat: &str,
403        config: &bulwark_config::Plugin,
404    ) -> Result<Self, PluginLoadError> {
405        Self::from_component(
406            name,
407            config,
408            |engine| -> Result<Component, PluginLoadError> {
409                Ok(Component::new(engine, wat.as_bytes())?)
410            },
411        )
412    }
413
414    /// Creates and compiles a new [`Plugin`] from a byte slice of WASM.
415    ///
416    /// The bytes it expects are what you'd get if you read in a `*.wasm` file.
417    /// See [`Component::from_binary`].
418    pub fn from_bytes(
419        name: String,
420        bytes: &[u8],
421        config: &bulwark_config::Plugin,
422    ) -> Result<Self, PluginLoadError> {
423        Self::from_component(
424            name,
425            config,
426            |engine| -> Result<Component, PluginLoadError> {
427                Ok(Component::from_binary(engine, bytes)?)
428            },
429        )
430    }
431
432    /// Creates and compiles a new [`Plugin`] by reading in a file in either `*.wasm` or `*.wat` format.
433    ///
434    /// See [`Component::from_file`].
435    pub fn from_file(
436        path: impl AsRef<Path>,
437        config: &bulwark_config::Plugin,
438    ) -> Result<Self, PluginLoadError> {
439        let name = config.reference.clone();
440        Self::from_component(
441            name,
442            config,
443            |engine| -> Result<Component, PluginLoadError> {
444                Ok(Component::from_file(engine, &path)?)
445            },
446        )
447    }
448
449    /// Helper method for the other `from_*` functions.
450    fn from_component<F>(
451        reference: String,
452        config: &bulwark_config::Plugin,
453        mut get_component: F,
454    ) -> Result<Self, PluginLoadError>
455    where
456        F: FnMut(&Engine) -> Result<Component, PluginLoadError>,
457    {
458        let mut wasm_config = Config::new();
459        wasm_config.wasm_backtrace_details(wasmtime::WasmBacktraceDetails::Enable);
460        wasm_config.wasm_multi_memory(true);
461        wasm_config.wasm_component_model(true);
462        wasm_config.async_support(true);
463
464        let engine = Engine::new(&wasm_config)?;
465        let component = get_component(&engine)?;
466
467        Ok(Plugin {
468            reference,
469            config: Arc::new(config.clone()),
470            engine,
471            component,
472        })
473    }
474
475    /// Makes the guest's configuration available as serialized JSON bytes.
476    fn guest_config(&self) -> Result<Vec<u8>, ConfigSerializationError> {
477        // TODO: should guest config be required or optional?
478        self.config.config_to_json()
479    }
480
481    /// Makes the permissions the plugin has been granted available to the guest environment.
482    fn permissions(&self) -> bulwark_config::Permissions {
483        self.config.permissions.clone()
484    }
485}
486
/// A collection of values that will not change over the lifecycle of a request/response.
///
/// Populated once when a [`RequestContext`] is created.
struct ReadOnlyContext {
    /// Plugin-specific configuration. Stored as bytes and deserialized as JSON values by the SDK.
    ///
    /// There may be multiple instances of the same plugin with different values for this configuration
    /// causing the plugin behavior to be different. For instance, a plugin might define a pattern-matching
    /// algorithm in its code while reading the specific patterns to match from this configuration.
    config: Arc<Vec<u8>>,
    /// The set of permissions granted to a plugin.
    permissions: bulwark_config::Permissions,
    /// The IP address of the client that originated the request, if available.
    client_ip: Option<bulwark_host::IpInterface>,
    /// The Redis connection pool and its associated Lua scripts.
    redis_info: Option<Arc<RedisInfo>>,
    /// The HTTP client used to send outbound requests from plugins.
    http_client: Arc<reqwest::blocking::Client>,
}
504
/// A collection of values that the guest environment will mutate over the lifecycle of a request/response.
///
/// The body-receive flags are shared (via `Arc`) with the owning [`PluginInstance`] so the host can read what
/// the guest set.
#[derive(Clone, Default)]
struct GuestMutableContext {
    /// Whether this plugin instance expects to process a request body.
    receive_request_body: Arc<Mutex<bool>>,
    /// Whether this plugin instance expects to process a response body.
    receive_response_body: Arc<Mutex<bool>>,
    /// The `params` are a key-value map shared between all plugin instances for a single request.
    params: Arc<Mutex<bulwark_wasm_sdk::Map<String, bulwark_wasm_sdk::Value>>>,
    /// The plugin's decision and tags annotating it.
    decision_components: DecisionComponents,
}
517
/// A collection of values that the host environment will mutate over the lifecycle of a request/response.
///
/// Cloning shares the same underlying slots (every field is an `Arc<Mutex<_>>`), which lets a
/// [`PluginInstance`] update state that lives inside its WASM store.
#[derive(Clone)]
struct HostMutableContext {
    /// The HTTP request received from the exterior client.
    request: Arc<Mutex<bulwark_host::RequestInterface>>,
    /// The HTTP response received from the interior service.
    response: Arc<Mutex<Option<bulwark_host::ResponseInterface>>>,
    /// The combined decision of all plugins at the end of the request phase.
    ///
    /// Accessible to plugins in the response and feedback phases.
    combined_decision: Arc<Mutex<Option<bulwark_host::DecisionInterface>>>,
    /// The combined union set of all tags attached by plugins across all phases.
    combined_tags: Arc<Mutex<Option<Vec<String>>>>,
    /// The decision outcome after the decision has been checked against configured thresholds.
    outcome: Arc<Mutex<Option<bulwark_host::OutcomeInterface>>>,
}
534
535impl HostMutableContext {
536    fn new(request: bulwark_host::RequestInterface) -> Self {
537        HostMutableContext {
538            request: Arc::new(Mutex::new(request)),
539            response: Arc::new(Mutex::new(None)),
540            combined_decision: Arc::new(Mutex::new(None)),
541            combined_tags: Arc::new(Mutex::new(None)),
542            outcome: Arc::new(Mutex::new(None)),
543        }
544    }
545}
546
/// Allows the host to capture plugin standard IO and record it to the log.
///
/// Wraps an in-memory WASI output pipe; clones write to and read from the same buffer, which is what lets the
/// host read output the guest wrote through a clone.
#[derive(Clone)]
struct BufStdoutStream(MemoryOutputPipe);
550
551impl BufStdoutStream {
552    pub fn contents(&self) -> bytes::Bytes {
553        self.0.contents()
554    }
555
556    pub(crate) fn writer(&self) -> impl HostOutputStream {
557        self.0.clone()
558    }
559}
560
561impl Default for BufStdoutStream {
562    fn default() -> Self {
563        Self(MemoryOutputPipe::new(usize::MAX))
564    }
565}
566
567impl StdoutStream for BufStdoutStream {
568    fn stream(&self) -> Box<dyn HostOutputStream> {
569        Box::new(self.writer())
570    }
571
572    fn isatty(&self) -> bool {
573        false
574    }
575}
576
/// Wraps buffers to capture plugin stdio.
///
/// Cloning shares the underlying buffers, so a clone taken before execution can read the output afterwards.
#[derive(Clone, Default)]
pub struct PluginStdio {
    /// Captures the plugin's `stdout`.
    stdout: BufStdoutStream,
    /// Captures the plugin's `stderr`.
    stderr: BufStdoutStream,
}
583
584impl PluginStdio {
585    pub fn stdout_buffer(&self) -> Vec<u8> {
586        self.stdout.contents().to_vec()
587    }
588
589    pub fn stderr_buffer(&self) -> Vec<u8> {
590        self.stderr.contents().to_vec()
591    }
592}
593
/// An instance of a [`Plugin`], associated with a [`RequestContext`].
pub struct PluginInstance {
    /// A reference to the parent `Plugin` and its configuration.
    plugin: Arc<Plugin>,
    /// The WASM store that holds state associated with the incoming request.
    store: Store<RequestContext>,
    /// Generated bindings used to call the guest's exported handler functions.
    handlers: handlers::Handlers,
    /// Shared with the guest-mutable context: whether the guest wants the request body.
    receive_request_body: Arc<Mutex<bool>>,
    /// Shared with the guest-mutable context: whether the guest wants the response body.
    receive_response_body: Arc<Mutex<bool>>,
    /// All plugin-visible state that the host environment will mutate over the lifecycle of a request/response.
    host_mut_ctx: HostMutableContext,
    /// The buffers for `stdout` and `stderr` captured from the plugin.
    stdio: PluginStdio,
}
608
609impl PluginInstance {
610    /// Instantiates a [`Plugin`], creating a new `PluginInstance`.
611    ///
612    /// # Arguments
613    ///
614    /// * `plugin` - The plugin we are creating a `PluginInstance` for.
615    /// * `request_context` - The request context stores all of the state associated with an incoming request and its corresponding response.
616    pub async fn new(
617        plugin: Arc<Plugin>,
618        request_context: RequestContext,
619    ) -> Result<PluginInstance, PluginInstantiationError> {
620        // Clone the request/response body receive flags so we can provide them to the service layer.
621        let receive_request_body = request_context.guest_mut_ctx.receive_request_body.clone();
622        let receive_response_body = request_context.guest_mut_ctx.receive_response_body.clone();
623
624        // Clone the host mutable context so that we can make changes to the interior of our request context from the parent.
625        let host_mut_ctx = request_context.host_mut_ctx.clone();
626
627        // Clone the stdio so we can read the captured stdout and stderr buffers after execution has completed.
628        let stdio = request_context.stdio.clone();
629
630        // TODO: do we need to retain a reference to the linker value anywhere? explore how other wasm-based systems use it.
631        // convert from normal request struct to wasm request interface
632        let mut linker: Linker<RequestContext> = Linker::new(&plugin.engine);
633
634        wasmtime_wasi::preview2::command::add_to_linker(&mut linker)?;
635
636        let mut store = Store::new(&plugin.engine, request_context);
637        bulwark_host::HostApi::add_to_linker(&mut linker, |ctx: &mut RequestContext| ctx)?;
638
639        // We discard the instance for this because we only use the generated interface to make calls
640
641        let (handlers, _) =
642            handlers::Handlers::instantiate_async(&mut store, &plugin.component, &linker).await?;
643
644        Ok(PluginInstance {
645            plugin,
646            store,
647            handlers,
648            receive_request_body,
649            receive_response_body,
650            host_mut_ctx,
651            stdio,
652        })
653    }
654
655    /// Returns `stdout` and `stderr` captured during plugin execution.
656    pub fn stdio(&self) -> PluginStdio {
657        self.stdio.clone()
658    }
659
660    /// Returns whether this plugin instance expects to process a request body.
661    pub fn receive_request_body(&self) -> bool {
662        let receive_request_body = self.receive_request_body.lock().expect("poisoned mutex");
663        *receive_request_body
664    }
665
666    /// Returns whether this plugin instance expects to process a response body.
667    pub fn receive_response_body(&self) -> bool {
668        let receive_response_body = self.receive_response_body.lock().expect("poisoned mutex");
669        *receive_response_body
670    }
671
672    /// Returns the configured weight value for tuning [`Decision`] values.
673    pub fn weight(&self) -> f64 {
674        self.plugin.config.weight
675    }
676
677    /// Records a [`Request`](bulwark_wasm_sdk::Request) so that it will be accessible to the plugin guest
678    /// environment. Overwrites the existing `Request`.
679    pub fn record_request(&mut self, request: Arc<bulwark_wasm_sdk::Request>) {
680        let mut interior_request = self.host_mut_ctx.request.lock().expect("poisoned mutex");
681        *interior_request = bulwark_host::RequestInterface::from(request);
682    }
683
684    /// Records a [`Response`](bulwark_wasm_sdk::Response) so that it will be accessible to the plugin guest
685    /// environment.
686    pub fn record_response(&mut self, response: Arc<bulwark_wasm_sdk::Response>) {
687        let mut interior_response = self.host_mut_ctx.response.lock().expect("poisoned mutex");
688        *interior_response = Some(bulwark_host::ResponseInterface::from(response));
689    }
690
691    /// Records the combined [`Decision`], it's tags, and the associated [`Outcome`] so that they will be accessible
692    /// to the plugin guest environment.
693    pub fn record_combined_decision(
694        &mut self,
695        decision_components: &DecisionComponents,
696        outcome: Outcome,
697    ) {
698        let mut interior_decision = self
699            .host_mut_ctx
700            .combined_decision
701            .lock()
702            .expect("poisoned mutex");
703        *interior_decision = Some(decision_components.decision.into());
704        let mut interior_outcome = self.host_mut_ctx.outcome.lock().expect("poisoned mutex");
705        *interior_outcome = Some(outcome.into());
706    }
707
708    /// Returns the plugin's identifier.
709    pub fn plugin_reference(&self) -> String {
710        self.plugin.reference.clone()
711    }
712
713    /// Executes the guest's `init` function.
714    pub async fn handle_init(&mut self) -> Result<(), PluginExecutionError> {
715        let result = self
716            .handlers
717            .call_on_init(self.store.as_context_mut())
718            .await?;
719        match result {
720            Ok(_) => metrics::increment_counter!(
721                "plugin_on_init",
722                "ref" => self.plugin_reference(), "result" => "ok"
723            ),
724            Err(_) => metrics::increment_counter!(
725                "plugin_on_init",
726                "ref" => self.plugin_reference(), "result" => "error"
727            ),
728        }
729
730        Ok(())
731    }
732
733    /// Executes the guest's `on_request` function.
734    pub async fn handle_request(&mut self) -> Result<(), PluginExecutionError> {
735        let result = self
736            .handlers
737            .call_on_request(self.store.as_context_mut())
738            .await?;
739        match result {
740            Ok(_) => metrics::increment_counter!(
741                "plugin_on_request",
742                "ref" => self.plugin_reference(), "result" => "ok"
743            ),
744            Err(_) => metrics::increment_counter!(
745                "plugin_on_request",
746                "ref" => self.plugin_reference(), "result" => "error"
747            ),
748        }
749
750        Ok(())
751    }
752
753    /// Executes the guest's `on_request_decision` function.
754    pub async fn handle_request_decision(&mut self) -> Result<(), PluginExecutionError> {
755        let result = self
756            .handlers
757            .call_on_request_decision(self.store.as_context_mut())
758            .await?;
759        match result {
760            Ok(_) => metrics::increment_counter!(
761                "plugin_on_request_decision",
762                "ref" => self.plugin_reference(), "result" => "ok"
763            ),
764            Err(_) => metrics::increment_counter!(
765                "plugin_on_request_decision",
766                "ref" => self.plugin_reference(), "result" => "error"
767            ),
768        }
769
770        Ok(())
771    }
772
773    /// Executes the guest's `on_response_decision` function.
774    pub async fn handle_response_decision(&mut self) -> Result<(), PluginExecutionError> {
775        let result = self
776            .handlers
777            .call_on_response_decision(self.store.as_context_mut())
778            .await?;
779        match result {
780            Ok(_) => metrics::increment_counter!(
781                "plugin_on_request_body_decision",
782                "ref" => self.plugin_reference(), "result" => "ok"
783            ),
784            Err(_) => metrics::increment_counter!(
785                "plugin_on_request_body_decision",
786                "ref" => self.plugin_reference(), "result" => "error"
787            ),
788        }
789
790        Ok(())
791    }
792
793    /// Executes the guest's `on_request_body_decision` function.
794    pub async fn handle_request_body_decision(&mut self) -> Result<(), PluginExecutionError> {
795        let result = self
796            .handlers
797            .call_on_request_body_decision(self.store.as_context_mut())
798            .await?;
799        match result {
800            Ok(_) => metrics::increment_counter!(
801                "plugin_on_response_decision",
802                "ref" => self.plugin_reference(), "result" => "ok"
803            ),
804            Err(_) => metrics::increment_counter!(
805                "plugin_on_response_decision",
806                "ref" => self.plugin_reference(), "result" => "error"
807            ),
808        }
809
810        Ok(())
811    }
812
813    /// Executes the guest's `on_response_body_decision` function.
814    pub async fn handle_response_body_decision(&mut self) -> Result<(), PluginExecutionError> {
815        let result = self
816            .handlers
817            .call_on_response_body_decision(self.store.as_context_mut())
818            .await?;
819        match result {
820            Ok(_) => metrics::increment_counter!(
821                "plugin_on_response_body_decision",
822                "ref" => self.plugin_reference(), "result" => "ok"
823            ),
824            Err(_) => metrics::increment_counter!(
825                "plugin_on_response_body_decision",
826                "ref" => self.plugin_reference(), "result" => "error"
827            ),
828        }
829
830        Ok(())
831    }
832
833    /// Executes the guest's `on_decision_feedback` function.
834    pub async fn handle_decision_feedback(&mut self) -> Result<(), PluginExecutionError> {
835        let result = self
836            .handlers
837            .call_on_decision_feedback(self.store.as_context_mut())
838            .await?;
839        match result {
840            Ok(_) => metrics::increment_counter!(
841                "plugin_on_decision_feedback",
842                "ref" => self.plugin_reference(), "result" => "ok"
843            ),
844            Err(_) => metrics::increment_counter!(
845                "plugin_on_decision_feedback",
846                "ref" => self.plugin_reference(), "result" => "error"
847            ),
848        }
849
850        Ok(())
851    }
852
853    /// Returns the decision components from the [`RequestContext`].
854    pub fn decision(&mut self) -> DecisionComponents {
855        let ctx = self.store.data();
856
857        ctx.guest_mut_ctx.decision_components.clone()
858    }
859}
860
861#[async_trait]
862impl bulwark_host::HostApiImports for RequestContext {
863    /// Returns the guest environment's configuration value as serialized JSON.
864    async fn get_config(&mut self) -> Result<Vec<u8>, wasmtime::Error> {
865        Ok(self.read_only_ctx.config.to_vec())
866    }
867
868    /// Returns a named value from the request context's params.
869    ///
870    /// # Arguments
871    ///
872    /// * `key` - The key name corresponding to the param value.
873    async fn get_param_value(
874        &mut self,
875        key: String,
876    ) -> Result<Result<Vec<u8>, bulwark_host::ParamError>, wasmtime::Error> {
877        let params = self.guest_mut_ctx.params.lock().expect("poisoned mutex");
878        let value = params.get(&key).unwrap_or(&bulwark_wasm_sdk::Value::Null);
879        match serde_json::to_vec(value) {
880            Ok(bytes) => Ok(Ok(bytes)),
881            Err(err) => Ok(Err(bulwark_host::ParamError::Json(err.to_string()))),
882        }
883    }
884
885    /// Set a named value in the request context's params.
886    ///
887    /// # Arguments
888    ///
889    /// * `key` - The key name corresponding to the param value.
890    /// * `value` - The value to record. Values are serialized JSON.
891    async fn set_param_value(
892        &mut self,
893        key: String,
894        value: Vec<u8>,
895    ) -> Result<Result<(), bulwark_host::ParamError>, wasmtime::Error> {
896        let mut params = self.guest_mut_ctx.params.lock().expect("poisoned mutex");
897        match serde_json::from_slice(&value) {
898            Ok(value) => {
899                params.insert(key, value);
900                Ok(Ok(()))
901            }
902            Err(err) => Ok(Err(bulwark_host::ParamError::Json(err.to_string()))),
903        }
904    }
905
906    /// Returns a named environment variable value as bytes.
907    ///
908    /// # Arguments
909    ///
910    /// * `key` - The environment variable name. Case-sensitive.
911    async fn get_env_bytes(
912        &mut self,
913        key: String,
914    ) -> Result<Result<Vec<u8>, bulwark_host::EnvError>, wasmtime::Error> {
915        let allowed_env_vars = self
916            .read_only_ctx
917            .permissions
918            .env
919            .iter()
920            .cloned()
921            .collect::<BTreeSet<String>>();
922        if !allowed_env_vars.contains(&key) {
923            return Ok(Err(bulwark_host::EnvError::Permission(key)));
924        }
925        match std::env::var(&key) {
926            Ok(var) => Ok(Ok(var.as_bytes().to_vec())),
927            Err(err) => match err {
928                std::env::VarError::NotPresent => Ok(Err(bulwark_host::EnvError::Missing(key))),
929                std::env::VarError::NotUnicode(_) => {
930                    Ok(Err(bulwark_host::EnvError::NotUnicode(key)))
931                }
932            },
933        }
934    }
935
936    /// Returns the incoming request associated with the request context.
937    async fn get_request(&mut self) -> Result<bulwark_host::RequestInterface, wasmtime::Error> {
938        let request = self.host_mut_ctx.request.lock().expect("poisoned mutex");
939        Ok(request.clone())
940    }
941
942    /// Returns the response received from the interior service.
943    ///
944    /// If called from `on_request` or `on_request_decision`, it will return `None` since a response
945    /// is not yet available.
946    async fn get_response(
947        &mut self,
948    ) -> Result<Option<bulwark_host::ResponseInterface>, wasmtime::Error> {
949        let response: MutexGuard<Option<bulwark_host::ResponseInterface>> =
950            self.host_mut_ctx.response.lock().expect("poisoned mutex");
951        Ok(response.to_owned())
952    }
953
954    /// Determines whether the request body will be received by the plugin in the `on_request_body_decision` handler.
955    async fn receive_request_body(&mut self, body: bool) -> Result<(), wasmtime::Error> {
956        let mut receive_request_body = self
957            .guest_mut_ctx
958            .receive_request_body
959            .lock()
960            .expect("poisoned mutex");
961        *receive_request_body = body;
962        Ok(())
963    }
964
965    /// Determines whether the response body will be received by the plugin in the `on_response_body_decision` handler.
966    async fn receive_response_body(&mut self, body: bool) -> Result<(), wasmtime::Error> {
967        let mut receive_response_body = self
968            .guest_mut_ctx
969            .receive_response_body
970            .lock()
971            .expect("poisoned mutex");
972        *receive_response_body = body;
973        Ok(())
974    }
975
976    /// Returns the originating client's IP address, if available.
977    async fn get_client_ip(
978        &mut self,
979    ) -> Result<Option<bulwark_host::IpInterface>, wasmtime::Error> {
980        Ok(self.read_only_ctx.client_ip)
981    }
982
983    /// Begins an outbound request. Returns a request ID used by `add_request_header` and `set_request_body`.
984    ///
985    /// # Arguments
986    ///
987    /// * `method` - The HTTP method
988    /// * `uri` - The absolute URI of the resource to request
989    async fn send_request(
990        &mut self,
991        request: bulwark_host::RequestInterface,
992    ) -> Result<Result<bulwark_host::ResponseInterface, bulwark_host::HttpError>, wasmtime::Error>
993    {
994        Ok(
995            // Inner function to permit ? operator
996            || -> Result<bulwark_host::ResponseInterface, bulwark_host::HttpError> {
997                verify_http_domains(&self.read_only_ctx.permissions.http, &request.uri)?;
998
999                let method = reqwest::Method::from_str(&request.method)
1000                    .map_err(|_| bulwark_host::HttpError::InvalidMethod(request.method.clone()))?;
1001
1002                let mut builder = self.read_only_ctx.http_client.request(method, &request.uri);
1003                for (name, value) in request.headers {
1004                    builder = builder.header(name, value);
1005                }
1006
1007                if !request.end_of_stream {
1008                    return Err(bulwark_host::HttpError::UnavailableContent(
1009                        "the entire request body must be available".to_string(),
1010                    ));
1011                } else if request.chunk_start != 0 {
1012                    return Err(bulwark_host::HttpError::InvalidStart(
1013                        "chunk start must be 0".to_string(),
1014                    ));
1015                } else if request.chunk_length > 16384 {
1016                    return Err(bulwark_host::HttpError::ContentTooLarge(
1017                        "the entire request body must be 16384 bytes or less".to_string(),
1018                    ));
1019                }
1020
1021                builder = builder.body(request.chunk);
1022
1023                let response = builder
1024                    .send()
1025                    .map_err(|err| bulwark_host::HttpError::Transmit(err.to_string()))?;
1026                let status: u32 = response.status().as_u16() as u32;
1027                // need to read headers before body because retrieving body bytes will move the response
1028                let headers: Vec<(String, Vec<u8>)> = response
1029                    .headers()
1030                    .iter()
1031                    .map(|(name, value)| (name.to_string(), value.as_bytes().to_vec()))
1032                    .collect();
1033                let body = response.bytes().unwrap().to_vec();
1034                let content_length: u64 = body.len() as u64;
1035                Ok(bulwark_host::ResponseInterface {
1036                    status,
1037                    headers,
1038                    body_received: true,
1039                    chunk: body,
1040                    chunk_start: 0,
1041                    chunk_length: content_length,
1042                    end_of_stream: true,
1043                })
1044            }(),
1045        )
1046    }
1047
1048    /// Records the decision value the plugin wants to return.
1049    ///
1050    /// # Arguments
1051    ///
1052    /// * `decision` - The [`Decision`] output of the plugin.
1053    async fn set_decision(
1054        &mut self,
1055        decision: bulwark_host::DecisionInterface,
1056    ) -> Result<Result<(), bulwark_host::DecisionError>, wasmtime::Error> {
1057        let decision = Decision::from(decision);
1058        // Validate on both the guest and the host because we can't guarantee usage of the SDK.
1059        match decision.validate() {
1060            Ok(_) => {
1061                self.guest_mut_ctx.decision_components.decision = decision;
1062                Ok(Ok(()))
1063            }
1064            Err(err) => Ok(Err(bulwark_host::DecisionError::Invalid(err.to_string()))),
1065        }
1066    }
1067
1068    /// Records the tags the plugin wants to associate with its decision.
1069    ///
1070    /// # Arguments
1071    ///
1072    /// * `tags` - The list of tags to associate with a [`Decision`].
1073    async fn set_tags(&mut self, tags: Vec<String>) -> Result<(), wasmtime::Error> {
1074        self.guest_mut_ctx.decision_components.tags = tags;
1075        Ok(())
1076    }
1077
1078    /// Records additional tags the plugin wants to associate with its decision. Existing tags will be kept.
1079    ///
1080    /// # Arguments
1081    ///
1082    /// * `tags` - The list of tags to associate with a [`Decision`].
1083    async fn append_tags(&mut self, mut tags: Vec<String>) -> Result<Vec<String>, wasmtime::Error> {
1084        self.guest_mut_ctx
1085            .decision_components
1086            .tags
1087            .append(&mut tags);
1088        Ok(self.guest_mut_ctx.decision_components.tags.clone())
1089    }
1090
1091    /// Returns the combined decision, if available.
1092    ///
1093    /// Typically used in the feedback phase.
1094    async fn get_combined_decision(
1095        &mut self,
1096    ) -> Result<Option<bulwark_host::DecisionInterface>, wasmtime::Error> {
1097        let combined_decision: MutexGuard<Option<bulwark_host::DecisionInterface>> = self
1098            .host_mut_ctx
1099            .combined_decision
1100            .lock()
1101            .expect("poisoned mutex");
1102        Ok(combined_decision.to_owned())
1103    }
1104
1105    /// Returns the combined set of tags associated with a decision, if available.
1106    ///
1107    /// Typically used in the feedback phase.
1108    async fn get_combined_tags(&mut self) -> Result<Option<Vec<String>>, wasmtime::Error> {
1109        let combined_tags: MutexGuard<Option<Vec<String>>> = self
1110            .host_mut_ctx
1111            .combined_tags
1112            .lock()
1113            .expect("poisoned mutex");
1114        Ok(combined_tags.to_owned())
1115    }
1116
1117    /// Returns the outcome of the combined decision, if available.
1118    ///
1119    /// Typically used in the feedback phase.
1120    async fn get_outcome(
1121        &mut self,
1122    ) -> Result<Option<bulwark_host::OutcomeInterface>, wasmtime::Error> {
1123        let outcome: MutexGuard<Option<bulwark_host::OutcomeInterface>> =
1124            self.host_mut_ctx.outcome.lock().expect("poisoned mutex");
1125        Ok(outcome.to_owned())
1126    }
1127
1128    /// Returns the named state value retrieved from Redis.
1129    ///
1130    /// Also used to retrieve a counter value.
1131    ///
1132    /// # Arguments
1133    ///
1134    /// * `key` - The key name corresponding to the state value.
1135    async fn get_remote_state(
1136        &mut self,
1137        key: String,
1138    ) -> Result<Result<Vec<u8>, bulwark_host::StateError>, wasmtime::Error> {
1139        Ok(
1140            // Inner function to permit ? operator
1141            || -> Result<Vec<u8>, bulwark_host::StateError> {
1142                verify_remote_state_prefixes(&self.read_only_ctx.permissions.state, &key)?;
1143
1144                if let Some(redis_info) = self.read_only_ctx.redis_info.clone() {
1145                    let mut conn = redis_info
1146                        .pool
1147                        .get()
1148                        .map_err(|err| bulwark_host::StateError::Remote(err.to_string()))?;
1149
1150                    Ok(conn
1151                        .get(key)
1152                        .map_err(|err| bulwark_host::StateError::Remote(err.to_string()))?)
1153                } else {
1154                    Err(bulwark_host::StateError::Remote(
1155                        "no remote state configured".to_string(),
1156                    ))
1157                }
1158            }(),
1159        )
1160    }
1161
1162    /// Set a named value in Redis.
1163    ///
1164    /// # Arguments
1165    ///
1166    /// * `key` - The key name corresponding to the state value.
1167    /// * `value` - The value to record. Values are byte strings, but may be interpreted differently by Redis depending on context.
1168    async fn set_remote_state(
1169        &mut self,
1170        key: String,
1171        value: Vec<u8>,
1172    ) -> Result<Result<(), bulwark_host::StateError>, wasmtime::Error> {
1173        Ok(
1174            // Inner function to permit ? operator
1175            || -> Result<(), bulwark_host::StateError> {
1176                verify_remote_state_prefixes(&self.read_only_ctx.permissions.state, &key)?;
1177
1178                if let Some(redis_info) = self.read_only_ctx.redis_info.clone() {
1179                    let mut conn = redis_info
1180                        .pool
1181                        .get()
1182                        .map_err(|err| bulwark_host::StateError::Remote(err.to_string()))?;
1183
1184                    conn.set::<String, Vec<u8>, redis::Value>(key, value)
1185                        .map_err(|err| bulwark_host::StateError::Remote(err.to_string()))?;
1186                    Ok(())
1187                } else {
1188                    Err(bulwark_host::StateError::Remote(
1189                        "no remote state configured".to_string(),
1190                    ))
1191                }
1192            }(),
1193        )
1194    }
1195
1196    /// Increments a named counter in Redis.
1197    ///
1198    /// # Arguments
1199    ///
1200    /// * `key` - The key name corresponding to the state counter.
1201    async fn increment_remote_state(
1202        &mut self,
1203        key: String,
1204    ) -> Result<Result<i64, bulwark_host::StateError>, wasmtime::Error> {
1205        self.increment_remote_state_by(key, 1).await
1206    }
1207
1208    /// Increments a named counter in Redis by a specified delta value.
1209    ///
1210    /// # Arguments
1211    ///
1212    /// * `key` - The key name corresponding to the state counter.
1213    /// * `delta` - The amount to increase the counter by.
1214    async fn increment_remote_state_by(
1215        &mut self,
1216        key: String,
1217        delta: i64,
1218    ) -> Result<Result<i64, bulwark_host::StateError>, wasmtime::Error> {
1219        Ok(
1220            // Inner function to permit ? operator
1221            || -> Result<i64, bulwark_host::StateError> {
1222                verify_remote_state_prefixes(&self.read_only_ctx.permissions.state, &key)?;
1223
1224                if let Some(redis_info) = self.read_only_ctx.redis_info.clone() {
1225                    let mut conn = redis_info
1226                        .pool
1227                        .get()
1228                        .map_err(|err| bulwark_host::StateError::Remote(err.to_string()))?;
1229
1230                    Ok(conn
1231                        .incr(key, delta)
1232                        .map_err(|err| bulwark_host::StateError::Remote(err.to_string()))?)
1233                } else {
1234                    Err(bulwark_host::StateError::Remote(
1235                        "no remote state configured".to_string(),
1236                    ))
1237                }
1238            }(),
1239        )
1240    }
1241
1242    /// Sets an expiration on a named value in Redis.
1243    ///
1244    /// # Arguments
1245    ///
1246    /// * `key` - The key name corresponding to the state value.
1247    /// * `ttl` - The time-to-live for the value in seconds.
1248    async fn set_remote_ttl(
1249        &mut self,
1250        key: String,
1251        ttl: i64,
1252    ) -> Result<Result<(), bulwark_host::StateError>, wasmtime::Error> {
1253        Ok(
1254            // Inner function to permit ? operator
1255            || -> Result<(), bulwark_host::StateError> {
1256                verify_remote_state_prefixes(&self.read_only_ctx.permissions.state, &key)?;
1257
1258                if let Some(redis_info) = self.read_only_ctx.redis_info.clone() {
1259                    let mut conn = redis_info
1260                        .pool
1261                        .get()
1262                        .map_err(|err| bulwark_host::StateError::Remote(err.to_string()))?;
1263
1264                    conn.expire::<String, redis::Value>(key, ttl as usize)
1265                        .map_err(|err| bulwark_host::StateError::Remote(err.to_string()))?;
1266                    Ok(())
1267                } else {
1268                    Err(bulwark_host::StateError::Remote(
1269                        "no remote state configured".to_string(),
1270                    ))
1271                }
1272            }(),
1273        )
1274    }
1275
1276    /// Increments a rate limit, returning the number of attempts so far and the expiration time.
1277    ///
1278    /// The rate limiter is a counter over a period of time. At the end of the period, it will expire,
1279    /// beginning a new period. Window periods should be set to the longest amount of time that a client should
1280    /// be locked out for. The plugin is responsible for performing all rate-limiting logic with the counter
1281    /// value it receives.
1282    ///
1283    /// # Arguments
1284    ///
1285    /// * `key` - The key name corresponding to the state counter.
1286    /// * `delta` - The amount to increase the counter by.
1287    /// * `window` - How long each period should be in seconds.
1288    async fn increment_rate_limit(
1289        &mut self,
1290        key: String,
1291        delta: i64,
1292        window: i64,
1293    ) -> Result<Result<bulwark_host::RateInterface, bulwark_host::StateError>, wasmtime::Error>
1294    {
1295        Ok(
1296            // Inner function to permit ? operator
1297            || -> Result<bulwark_host::RateInterface, bulwark_host::StateError> {
1298                verify_remote_state_prefixes(&self.read_only_ctx.permissions.state, &key)?;
1299
1300                if let Some(redis_info) = self.read_only_ctx.redis_info.clone() {
1301                    let mut conn = redis_info
1302                        .pool
1303                        .get()
1304                        .map_err(|err| bulwark_host::StateError::Remote(err.to_string()))?;
1305                    let dt = Utc::now();
1306                    let timestamp: i64 = dt.timestamp();
1307                    let script = redis_info.registry.increment_rate_limit.clone();
1308                    // Invoke the script and map to our rate type
1309                    let (attempts, expiration) = script
1310                        .key(key)
1311                        .arg(delta)
1312                        .arg(window)
1313                        .arg(timestamp)
1314                        .invoke::<(i64, i64)>(conn.deref_mut())
1315                        .map_err(|err| bulwark_host::StateError::Remote(err.to_string()))?;
1316                    Ok(bulwark_host::RateInterface {
1317                        attempts,
1318                        expiration,
1319                    })
1320                } else {
1321                    Err(bulwark_host::StateError::Remote(
1322                        "no remote state configured".to_string(),
1323                    ))
1324                }
1325            }(),
1326        )
1327    }
1328
1329    /// Checks a rate limit, returning the number of attempts so far and the expiration time.
1330    ///
1331    /// See `increment_rate_limit`.
1332    ///
1333    /// # Arguments
1334    ///
1335    /// * `key` - The key name corresponding to the state counter.
1336    async fn check_rate_limit(
1337        &mut self,
1338        key: String,
1339    ) -> Result<Result<bulwark_host::RateInterface, bulwark_host::StateError>, wasmtime::Error>
1340    {
1341        Ok(
1342            // Inner function to permit ? operator
1343            || -> Result<bulwark_host::RateInterface, bulwark_host::StateError> {
1344                verify_remote_state_prefixes(&self.read_only_ctx.permissions.state, &key)?;
1345
1346                if let Some(redis_info) = self.read_only_ctx.redis_info.clone() {
1347                    let mut conn = redis_info
1348                        .pool
1349                        .get()
1350                        .map_err(|err| bulwark_host::StateError::Remote(err.to_string()))?;
1351                    let dt = Utc::now();
1352                    let timestamp: i64 = dt.timestamp();
1353                    let script = redis_info.registry.check_rate_limit.clone();
1354                    // Invoke the script and map to our rate type
1355                    let (attempts, expiration) = script
1356                        .key(key)
1357                        .arg(timestamp)
1358                        .invoke::<(i64, i64)>(conn.deref_mut())
1359                        .map_err(|err| bulwark_host::StateError::Remote(err.to_string()))?;
1360                    Ok(bulwark_host::RateInterface {
1361                        attempts,
1362                        expiration,
1363                    })
1364                } else {
1365                    Err(bulwark_host::StateError::Remote(
1366                        "no remote state configured".to_string(),
1367                    ))
1368                }
1369            }(),
1370        )
1371    }
1372
1373    /// Increments a circuit breaker, returning the generation count, success count, failure count,
1374    /// consecutive success count, consecutive failure count, and expiration time.
1375    ///
1376    /// The plugin is responsible for performing all circuit-breaking logic with the counter
1377    /// values it receives. The host environment does as little as possible to maximize how much
1378    /// control the plugin has over the behavior of the breaker.
1379    ///
1380    /// # Arguments
1381    ///
1382    /// * `key` - The key name corresponding to the state counter.
1383    /// * `success_delta` - The amount to increase the success counter by. Generally zero on failure.
1384    /// * `failure_delta` - The amount to increase the failure counter by. Generally zero on success.
1385    /// * `window` - How long each period should be in seconds.
1386    async fn increment_breaker(
1387        &mut self,
1388        key: String,
1389        success_delta: i64,
1390        failure_delta: i64,
1391        window: i64,
1392    ) -> Result<Result<bulwark_host::BreakerInterface, bulwark_host::StateError>, wasmtime::Error>
1393    {
1394        Ok(
1395            // Inner function to permit ? operator
1396            || -> Result<bulwark_host::BreakerInterface, bulwark_host::StateError> {
1397                verify_remote_state_prefixes(&self.read_only_ctx.permissions.state, &key)?;
1398
1399                if let Some(redis_info) = self.read_only_ctx.redis_info.clone() {
1400                    let mut conn = redis_info
1401                        .pool
1402                        .get()
1403                        .map_err(|err| bulwark_host::StateError::Remote(err.to_string()))?;
1404                    let dt = Utc::now();
1405                    let timestamp: i64 = dt.timestamp();
1406                    let script = redis_info.registry.increment_breaker.clone();
1407                    // Invoke the script and map to our breaker type
1408                    let (
1409                        generation,
1410                        successes,
1411                        failures,
1412                        consecutive_successes,
1413                        consecutive_failures,
1414                        expiration,
1415                    ) = script
1416                        .key(key)
1417                        .arg(success_delta)
1418                        .arg(failure_delta)
1419                        .arg(window)
1420                        .arg(timestamp)
1421                        .invoke::<(i64, i64, i64, i64, i64, i64)>(conn.deref_mut())
1422                        .map_err(|err| bulwark_host::StateError::Remote(err.to_string()))?;
1423                    Ok(bulwark_host::BreakerInterface {
1424                        generation,
1425                        successes,
1426                        failures,
1427                        consecutive_successes,
1428                        consecutive_failures,
1429                        expiration,
1430                    })
1431                } else {
1432                    Err(bulwark_host::StateError::Remote(
1433                        "no remote state configured".to_string(),
1434                    ))
1435                }
1436            }(),
1437        )
1438    }
1439
1440    /// Checks a circuit breaker, returning the generation count, success count, failure count,
1441    /// consecutive success count, consecutive failure count, and expiration time.
1442    ///
1443    /// See `increment_breaker`.
1444    ///
1445    /// # Arguments
1446    ///
1447    /// * `key` - The key name corresponding to the state counter.
1448    async fn check_breaker(
1449        &mut self,
1450        key: String,
1451    ) -> Result<Result<bulwark_host::BreakerInterface, bulwark_host::StateError>, wasmtime::Error>
1452    {
1453        Ok(
1454            // Inner function to permit ? operator
1455            || -> Result<bulwark_host::BreakerInterface, bulwark_host::StateError> {
1456                verify_remote_state_prefixes(&self.read_only_ctx.permissions.state, &key)?;
1457
1458                if let Some(redis_info) = self.read_only_ctx.redis_info.clone() {
1459                    let mut conn = redis_info
1460                        .pool
1461                        .get()
1462                        .map_err(|err| bulwark_host::StateError::Remote(err.to_string()))?;
1463                    let dt = Utc::now();
1464                    let timestamp: i64 = dt.timestamp();
1465                    let script = redis_info.registry.check_breaker.clone();
1466                    // Invoke the script and map to our breaker type
1467                    let (
1468                        generation,
1469                        successes,
1470                        failures,
1471                        consecutive_successes,
1472                        consecutive_failures,
1473                        expiration,
1474                    ) = script
1475                        .key(key)
1476                        .arg(timestamp)
1477                        .invoke::<(i64, i64, i64, i64, i64, i64)>(conn.deref_mut())
1478                        .map_err(|err| bulwark_host::StateError::Remote(err.to_string()))?;
1479                    Ok(bulwark_host::BreakerInterface {
1480                        generation,
1481                        successes,
1482                        failures,
1483                        consecutive_successes,
1484                        consecutive_failures,
1485                        expiration,
1486                    })
1487                } else {
1488                    Err(bulwark_host::StateError::Remote(
1489                        "no remote state configured".to_string(),
1490                    ))
1491                }
1492            }(),
1493        )
1494    }
1495}
1496
1497/// Ensures that access to any HTTP host has the appropriate permissions set.
1498fn verify_http_domains(
1499    // TODO: BTreeSet<String> instead, all the way up
1500    allowed_http_domains: &[String],
1501    uri: &str,
1502) -> Result<(), bulwark_host::HttpError> {
1503    let parsed_uri =
1504        Url::parse(uri).map_err(|_| bulwark_host::HttpError::InvalidUri(uri.to_string()))?;
1505    let requested_domain = parsed_uri
1506        .domain()
1507        .ok_or_else(|| bulwark_host::HttpError::InvalidUri(uri.to_string()))?;
1508    if !allowed_http_domains.contains(&requested_domain.to_string()) {
1509        return Err(bulwark_host::HttpError::Permission(uri.to_string()));
1510    }
1511    Ok(())
1512}
1513
1514/// Ensures that access to any remote state key has the appropriate permissions set.
1515fn verify_remote_state_prefixes(
1516    // TODO: BTreeSet<String> instead, all the way up
1517    allowed_key_prefixes: &[String],
1518    key: &str,
1519) -> Result<(), bulwark_host::StateError> {
1520    let key = key.to_string();
1521    if !allowed_key_prefixes
1522        .iter()
1523        .any(|prefix| key.starts_with(prefix))
1524    {
1525        return Err(bulwark_host::StateError::Permission(key));
1526    }
1527    Ok(())
1528}
1529
#[cfg(test)]
mod tests {
    use super::*;

    /// Boxed error type shared by the helpers and tests in this module.
    type TestResult<T> = Result<T, Box<dyn std::error::Error>>;

    /// WASI preview1 reactor adapter used to turn the core wasm test fixtures into components.
    const WASI_REACTOR_ADAPTER: &[u8] =
        include_bytes!("../tests/wasi_snapshot_preview1.reactor.wasm");

    /// Encodes the core wasm module `wasm_bytes` as a validated component, using
    /// `adapter_bytes` as the `wasi_snapshot_preview1` adapter.
    fn adapt_wasm_output(wasm_bytes: &[u8], adapter_bytes: &[u8]) -> TestResult<Vec<u8>> {
        let encoder = wit_component::ComponentEncoder::default()
            .module(wasm_bytes)?
            .validate(true)
            .adapter("wasi_snapshot_preview1", adapter_bytes)?;
        let component = encoder.encode()?;
        Ok(component.to_vec())
    }

    /// Adapts the core wasm fixture `wasm_bytes` into a component and loads it as a [`Plugin`]
    /// named `name` using the default plugin configuration.
    fn load_fixture_plugin(name: &str, wasm_bytes: &[u8]) -> TestResult<Arc<Plugin>> {
        let component = adapt_wasm_output(wasm_bytes, WASI_REACTOR_ADAPTER)?;
        let plugin = Plugin::from_bytes(
            name.to_string(),
            &component,
            &bulwark_config::Plugin::default(),
        )?;
        Ok(Arc::new(plugin))
    }

    #[test]
    fn test_wasm_execution() -> TestResult<()> {
        let plugin = load_fixture_plugin(
            "bulwark-blank-slate.wasm",
            include_bytes!("../tests/bulwark_blank_slate.wasm"),
        )?;

        let request = http::Request::builder()
            .method("GET")
            .uri("/")
            .version(http::Version::HTTP_11)
            .body(bulwark_wasm_sdk::NO_BODY)?;
        let request_context = RequestContext::new(
            Arc::clone(&plugin),
            None,
            Arc::new(reqwest::blocking::Client::new()),
            Arc::new(Mutex::new(bulwark_wasm_sdk::Map::new())),
            Arc::new(request),
        )?;
        let mut instance = tokio_test::block_on(PluginInstance::new(plugin, request_context))?;

        // The decision is read without invoking any request handler first.
        let outcome = instance.decision();
        assert_eq!(outcome.decision.accept, 0.0);
        assert_eq!(outcome.decision.restrict, 0.0);
        assert_eq!(outcome.decision.unknown, 1.0);
        assert_eq!(outcome.tags, Vec::<&str>::new());

        Ok(())
    }

    #[test]
    fn test_wasm_logic() -> TestResult<()> {
        let plugin = load_fixture_plugin(
            "bulwark-evil-bit.wasm",
            include_bytes!("../tests/bulwark_evil_bit.wasm"),
        )?;

        // Baseline request: no `Evil` header, so the plugin is expected to stay neutral.
        let typical_request = http::Request::builder()
            .method("POST")
            .uri("/example")
            .version(http::Version::HTTP_11)
            .header("Content-Type", "application/json")
            .body(bulwark_wasm_sdk::UNAVAILABLE_BODY)?;
        let typical_context = RequestContext::new(
            Arc::clone(&plugin),
            None,
            Arc::new(reqwest::blocking::Client::new()),
            Arc::new(Mutex::new(bulwark_wasm_sdk::Map::new())),
            Arc::new(typical_request),
        )?;
        let mut typical_instance =
            tokio_test::block_on(PluginInstance::new(Arc::clone(&plugin), typical_context))?;
        tokio_test::block_on(typical_instance.handle_request_decision())?;

        let typical = typical_instance.decision();
        assert_eq!(typical.decision.accept, 0.0);
        assert_eq!(typical.decision.restrict, 0.0);
        assert_eq!(typical.decision.unknown, 1.0);
        assert_eq!(typical.tags, Vec::<&str>::new());

        // Same request with `Evil: true` set; the plugin should fully restrict and tag it.
        let evil_request = http::Request::builder()
            .method("POST")
            .uri("/example")
            .version(http::Version::HTTP_11)
            .header("Content-Type", "application/json")
            .header("Evil", "true")
            .body(bulwark_wasm_sdk::UNAVAILABLE_BODY)?;
        let evil_context = RequestContext::new(
            Arc::clone(&plugin),
            None,
            Arc::new(reqwest::blocking::Client::new()),
            Arc::new(Mutex::new(bulwark_wasm_sdk::Map::new())),
            Arc::new(evil_request),
        )?;
        let mut evil_instance =
            tokio_test::block_on(PluginInstance::new(plugin, evil_context))?;
        tokio_test::block_on(evil_instance.handle_request_decision())?;

        let evil = evil_instance.decision();
        assert_eq!(evil.decision.accept, 0.0);
        assert_eq!(evil.decision.restrict, 1.0);
        assert_eq!(evil.decision.unknown, 0.0);
        assert_eq!(evil.tags, vec!["evil"]);

        Ok(())
    }
}