Skip to main content

obeli_sk_wasm_workers/
http_hooks.rs

1use crate::component_logger::ComponentLogger;
2use crate::http_request_policy::{HttpRequestPolicy, PolicyError};
3use concepts::storage::LogLevel;
4use concepts::storage::http_client_trace::{RequestTrace, ResponseTrace};
5use concepts::time::ClockFn;
6use tokio::sync::oneshot;
7use tracing::Instrument;
8use wasmtime_wasi_http::p2::body::HyperOutgoingBody;
9use wasmtime_wasi_http::p2::types::{HostFutureIncomingResponse, OutgoingRequestConfig};
10use wasmtime_wasi_http::p2::{HttpResult, WasiHttpHooks, default_send_request_handler};
11
/// HTTP client traces collected for outgoing requests.
///
/// Each entry pairs the [`RequestTrace`] captured when a request is sent with the
/// receiving half of a oneshot channel on which the matching [`ResponseTrace`] is
/// delivered once the request finishes or is rejected.
pub type HttpClientTracesContainer = Vec<(RequestTrace, oneshot::Receiver<ResponseTrace>)>;
13
/// The TOML config section type for error messages.
///
/// The `Display` output of a variant is the section name spliced into the
/// `[[<section>.allowed_host]]` header of the configuration snippet that is
/// suggested when an HTTP request is denied.
#[derive(Clone, Copy, Debug, derive_more::Display)]
pub enum ConfigSectionHint {
    /// Displayed as `activity_js`.
    #[display("activity_js")]
    ActivityJs,
    /// Displayed as `activity_wasm`.
    #[display("activity_wasm")]
    ActivityWasm,
    /// Displayed as `webhook_endpoint_js`.
    #[display("webhook_endpoint_js")]
    WebhookEndpointJs,
    /// Displayed as `webhook_endpoint_wasm`.
    #[display("webhook_endpoint_wasm")]
    WebhookEndpointWasm,
}
26
/// State used by the [`WasiHttpHooks`] implementation that intercepts outgoing
/// HTTP requests: it records traces, enforces the HTTP request policy and logs
/// policy violations.
pub(crate) struct HttpHooks {
    /// Clock used to timestamp request (`sent_at`) and response (`finished_at`) traces.
    pub(crate) clock_fn: Box<dyn ClockFn>,
    /// Traces accumulated so far; `send_request` pushes one entry per call.
    pub(crate) http_client_traces: HttpClientTracesContainer,
    /// Policy applied to every outgoing request before it is sent.
    pub(crate) http_policy: HttpRequestPolicy,
    /// Logger that receives a warning when the policy rejects a request; its span
    /// is the parent of the `send_request` tracing span.
    pub(crate) component_logger: ComponentLogger,
    /// The TOML config section type for error messages
    pub(crate) config_section_hint: ConfigSectionHint,
}
35
/// Render the shortest host pattern that identifies `scheme://host:port` in the
/// TOML snippet.
///
/// HTTPS is the implied scheme, so it is never written out; a port is written
/// only when it is not the default for its scheme:
/// - `https`, port 443 -> `foo`
/// - `https`, any other port -> `foo:8080`
/// - `http`, port 80 -> `http://bar`
/// - `http`, any other port -> `http://bar:8080`
/// - any other scheme -> `scheme://host:port`, always in full
fn format_host_pattern(scheme: &str, host: &str, port: u16) -> String {
    match (scheme, port) {
        ("https", 443) => host.to_owned(),
        ("https", _) => format!("{host}:{port}"),
        ("http", 80) => format!("http://{host}"),
        // Covers `http` on a non-default port as well as every other scheme:
        // both are spelled out as a complete `scheme://host:port`.
        _ => format!("{scheme}://{host}:{port}"),
    }
}
50
51/// Generate a TOML config snippet to help users fix denied HTTP requests.
52fn generate_toml_snippet(
53    err: &PolicyError,
54    config_section_hint: ConfigSectionHint,
55) -> Option<String> {
56    if let PolicyError::RequestDenied {
57        method,
58        scheme,
59        host,
60        port,
61    } = err
62    {
63        let pattern = format_host_pattern(scheme, host, *port);
64        Some(format!(
65            "{err}\n\
66             To allow this request, add the following to your configuration:\n\n\
67             [[{section}.allowed_host]]\n\
68             pattern = \"{pattern}\"\n\
69             methods = [\"{method}\"]",
70            section = config_section_hint,
71            method = method.as_str()
72        ))
73    } else {
74        None
75    }
76}
77
78impl WasiHttpHooks for HttpHooks {
79    fn send_request(
80        &mut self,
81        mut request: hyper::Request<HyperOutgoingBody>,
82        config: OutgoingRequestConfig,
83    ) -> HttpResult<HostFutureIncomingResponse> {
84        // Prepare request trace & channel
85        let req = RequestTrace {
86            sent_at: self.clock_fn.now(),
87            uri: request.uri().to_string(),
88            method: request.method().to_string(),
89        };
90        let (resp_trace_tx, resp_trace_rx) = oneshot::channel();
91        self.http_client_traces.push((req, resp_trace_rx));
92
93        // Apply HTTP policy (allowlist + placeholder replacement in headers and query params)
94        let http_policy_res = self.http_policy.apply(&mut request);
95        if let Err(err) = http_policy_res {
96            // Generate a helpful TOML snippet for the user
97            let log_msg = generate_toml_snippet(&err, self.config_section_hint)
98                .unwrap_or_else(|| err.to_string());
99            self.component_logger.log(LogLevel::Warn, log_msg); // Append to execution's logs table
100            let _ = resp_trace_tx.send(ResponseTrace {
101                finished_at: self.clock_fn.now(),
102                status: Err(err.to_string()),
103            });
104            let err = wasmtime_wasi_http::p2::bindings::http::types::ErrorCode::from(err);
105            return Err(err.into());
106        }
107
108        let span = tracing::info_span!(parent: &self.component_logger.span, "send_request",
109            otel.name = format!("send_request {} {}", request.method(), request.uri()),
110            method = %request.method(),
111            uri = %request.uri(),
112        );
113        let clock_fn = self.clock_fn.clone_box();
114        let http_policy = self.http_policy.clone();
115        span.in_scope(|| tracing::debug!("Sending {request:?}"));
116        let handle = wasmtime_wasi::runtime::spawn(
117            async move {
118                http_policy.apply_body_replacement(&mut request).await;
119                let resp_result: Result<
120                    wasmtime_wasi_http::p2::types::IncomingResponse,
121                    wasmtime_wasi_http::p2::bindings::http::types::ErrorCode,
122                > = default_send_request_handler(request, config).await;
123                tracing::debug!("Got response {resp_result:?}");
124                let _ = resp_trace_tx.send(ResponseTrace {
125                    finished_at: clock_fn.now(),
126                    status: resp_result
127                        .as_ref()
128                        .map(|resp| resp.resp.status().as_u16())
129                        .map_err(std::string::ToString::to_string),
130                });
131                Ok(resp_result)
132            }
133            .instrument(span),
134        );
135        Ok(HostFutureIncomingResponse::pending(handle))
136    }
137}