Skip to main content

omnia_wasi_http/host/
default_impl.rs

1use std::fmt::Display;
2use std::time::Duration;
3
4use anyhow::{Context, Result};
5use base64ct::{Base64, Encoding};
6use bytes::Bytes;
7use fromenv::FromEnv;
8use futures::Future;
9use http::header::{
10    CONNECTION, HOST, HeaderName, PROXY_AUTHENTICATE, PROXY_AUTHORIZATION, TRANSFER_ENCODING,
11    UPGRADE,
12};
13use http::{Request, Response};
14use http_body_util::BodyExt;
15use http_body_util::combinators::UnsyncBoxBody;
16use omnia::Backend;
17use tracing::instrument;
18use wasmtime::component::ResourceTable;
19use wasmtime_wasi::TrappableError;
20use wasmtime_wasi_http::WasiHttpCtx;
21use wasmtime_wasi_http::p3::bindings::http::types::ErrorCode;
22use wasmtime_wasi_http::p3::{self, RequestOptions, WasiHttpCtxView};
23
24pub type HttpResult<T> = Result<T, HttpError>;
25pub type HttpError = TrappableError<ErrorCode>;
26pub type FutureResult<T> = Box<dyn Future<Output = Result<T, ErrorCode>> + Send>;
27
/// Set of headers that are forbidden by `wasmtime-wasi-http`.
///
/// `send_request` removes every one of these names from the upstream response
/// before handing it back to the guest.
pub const FORBIDDEN_HEADERS: [HeaderName; 9] = [
    // names with predefined constants in `http::header`
    CONNECTION,
    HOST,
    PROXY_AUTHENTICATE,
    PROXY_AUTHORIZATION,
    TRANSFER_ENCODING,
    UPGRADE,
    // remaining names, built from lowercase static strings
    HeaderName::from_static("keep-alive"),
    HeaderName::from_static("proxy-connection"),
    HeaderName::from_static("http2-settings"),
];
40
41#[derive(Debug, Clone, FromEnv)]
42pub struct ConnectOptions {
43    #[env(from = "HTTP_ADDR", default = "http://localhost:8080")]
44    pub addr: String,
45    #[env(from = "HTTP_CONNECT_TIMEOUT_SECS", default = "10")]
46    pub connect_timeout_secs: u64,
47}
48
49impl omnia::FromEnv for ConnectOptions {
50    fn from_env() -> Result<Self> {
51        Self::from_env().finalize().context("issue loading connection options")
52    }
53}
54
55/// Reqwest-based HTTP hooks for outbound `wasi:http` requests.
56#[derive(Debug, Clone)]
57struct HttpHooks {
58    client: reqwest::Client,
59    connect_timeout: Duration,
60}
61
62/// Default implementation for `wasi:http`.
63#[derive(Debug, Clone)]
64pub struct HttpDefault {
65    hooks: HttpHooks,
66    ctx: WasiHttpCtx,
67}
68
69impl HttpDefault {
70    /// Produce a [`WasiHttpCtxView`] by splitting borrows on inner fields.
71    pub fn as_view<'a>(&'a mut self, table: &'a mut ResourceTable) -> WasiHttpCtxView<'a> {
72        WasiHttpCtxView {
73            hooks: &mut self.hooks,
74            ctx: &mut self.ctx,
75            table,
76        }
77    }
78}
79
80impl Backend for HttpDefault {
81    type ConnectOptions = ConnectOptions;
82
83    #[instrument]
84    async fn connect_with(options: Self::ConnectOptions) -> Result<Self> {
85        let connect_timeout = Duration::from_secs(options.connect_timeout_secs);
86
87        let builder = reqwest::Client::builder().connect_timeout(connect_timeout);
88
89        #[cfg(test)]
90        let builder = builder.no_proxy();
91
92        let client = builder.build().context("building HTTP client")?;
93        Ok(Self {
94            hooks: HttpHooks {
95                client,
96                connect_timeout,
97            },
98            ctx: WasiHttpCtx::default(),
99        })
100    }
101}
102
103impl p3::WasiHttpHooks for HttpHooks {
104    fn send_request(
105        &mut self, request: Request<UnsyncBoxBody<Bytes, ErrorCode>>,
106        _options: Option<RequestOptions>, fut: FutureResult<()>,
107    ) -> Box<
108        dyn Future<
109                Output = HttpResult<(Response<UnsyncBoxBody<Bytes, ErrorCode>>, FutureResult<()>)>,
110            > + Send,
111    > {
112        let shared_client = self.client.clone();
113        let connect_timeout = self.connect_timeout;
114
115        Box::new(async move {
116            let (mut parts, body) = request.into_parts();
117
118            // remove "Host" headers (`reqwest` adds its own)
119            parts.headers.remove(HOST);
120
121            // use a one-off client when a client certificate is required, otherwise
122            // reuse the shared client for connection pooling
123            let client = if let Some(encoded_cert) = parts.headers.remove("Client-Cert") {
124                tracing::debug!("using client certificate");
125                let encoded = encoded_cert.to_str().map_err(internal_err)?;
126                let bytes = Base64::decode_vec(encoded).map_err(internal_err)?;
127                let identity = reqwest::Identity::from_pem(&bytes).map_err(internal_err)?;
128                let builder =
129                    reqwest::Client::builder().connect_timeout(connect_timeout).identity(identity);
130
131                #[cfg(test)]
132                let builder = builder.no_proxy();
133
134                builder.build().map_err(reqwest_err)?
135            } else {
136                shared_client
137            };
138
139            let collected = body.collect().await.map_err(internal_err)?;
140
141            // make request
142            let url = parts.uri.to_string();
143            let resp = client
144                .request(parts.method, &url)
145                .headers(parts.headers)
146                .body(collected.to_bytes())
147                .send()
148                .await
149                .map_err(reqwest_err)?;
150
151            // process response
152            let converted: Response<reqwest::Body> = resp.into();
153            let (parts, body) = converted.into_parts();
154            let body = body.map_err(reqwest_err).boxed_unsync();
155            let mut response = Response::from_parts(parts, body);
156
157            // remove forbidden headers (disallowed by `wasmtime-wasi-http`)
158            let headers = response.headers_mut();
159            for header in &FORBIDDEN_HEADERS {
160                headers.remove(header);
161            }
162
163            Ok((response, fut))
164        })
165    }
166}
167
168fn internal_err(e: impl Display) -> ErrorCode {
169    ErrorCode::InternalError(Some(e.to_string()))
170}
171
172#[allow(clippy::needless_pass_by_value)]
173fn reqwest_err(e: reqwest::Error) -> ErrorCode {
174    if e.is_timeout() {
175        ErrorCode::ConnectionTimeout
176    } else if e.is_connect() {
177        ErrorCode::ConnectionRefused
178    } else if e.is_request() {
179        ErrorCode::HttpRequestUriInvalid
180    } else {
181        internal_err(e)
182    }
183}
184
#[cfg(test)]
mod tests {
    use std::pin::Pin;

    use http::header::{AUTHORIZATION, CONTENT_TYPE};
    use http::{Method, StatusCode};
    use http_body_util::{Empty, Full};
    use p3::WasiHttpHooks;
    use wiremock::matchers::{body_string, header, method};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    use super::*;

    /// Body type exchanged with `send_request` in both directions.
    type TestBody = UnsyncBoxBody<Bytes, ErrorCode>;

    impl HttpDefault {
        /// Runs the outbound hook with a no-op I/O future and awaits its result.
        async fn handle(
            &mut self, request: Request<TestBody>,
        ) -> HttpResult<(Response<TestBody>, FutureResult<()>)> {
            Pin::from(self.hooks.send_request(request, None, Box::new(async { Ok(()) }))).await
        }
    }

    /// Connects an `HttpDefault` with an empty address and a 10 second connect timeout.
    async fn test_client() -> HttpDefault {
        let options = ConnectOptions { addr: String::new(), connect_timeout_secs: 10 };
        HttpDefault::connect_with(options).await.unwrap()
    }

    /// Boxed request body with no content.
    fn empty_body() -> TestBody {
        Empty::new().map_err(internal_err).boxed_unsync()
    }

    /// Boxed request body holding exactly `text`.
    fn full_body(text: &'static str) -> TestBody {
        Full::new(Bytes::from(text)).map_err(internal_err).boxed_unsync()
    }

    #[tokio::test]
    async fn multiple_host_headers() {
        let server = MockServer::start().await;
        let reply = ResponseTemplate::new(200).set_body_string("Hello, World!");
        Mock::given(method("GET")).respond_with(reply).mount(&server).await;

        let request = Request::get(server.uri())
            .header(HOST, "localhost-1")
            .header(HOST, "localhost-2")
            .body(empty_body())
            .unwrap();

        let result = test_client().await.handle(request).await;
        assert!(result.is_ok());

        // the mock's status and body come back unchanged
        let (response, _) = result.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
        let body = response.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, Bytes::from("Hello, World!"));

        // one request arrived, with a single `Host` pointing at the mock server
        // instead of either guest-supplied value
        let requests = server.received_requests().await.expect("should have requests");
        assert_eq!(requests.len(), 1);
        let received = &requests[0];
        assert_eq!(received.headers.get_all(HOST).iter().count(), 1);
        assert!(received.headers.get(HOST).unwrap().to_str().unwrap().starts_with("127.0.0.1:"));
    }

    #[tokio::test]
    async fn post_with_body() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(body_string("test body"))
            .respond_with(ResponseTemplate::new(201).set_body_string("Created"))
            .mount(&server)
            .await;

        let request = Request::post(server.uri()).body(full_body("test body")).unwrap();

        let result = test_client().await.handle(request).await;
        assert!(result.is_ok());
        let (response, _) = result.unwrap();
        assert_eq!(response.status(), StatusCode::CREATED);
    }

    #[tokio::test]
    async fn custom_headers() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(header("X-Custom-Header", "custom-value"))
            .and(header(AUTHORIZATION, "Bearer token123"))
            .respond_with(ResponseTemplate::new(200))
            .mount(&server)
            .await;

        let request = Request::get(server.uri())
            .header("X-Custom-Header", "custom-value")
            .header(AUTHORIZATION, "Bearer token123")
            .body(empty_body())
            .unwrap();

        let result = test_client().await.handle(request).await;
        assert!(result.is_ok());
        let (response, _) = result.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn permitted_headers() {
        let server = MockServer::start().await;
        let reply = ResponseTemplate::new(200)
            .insert_header(CONNECTION, "keep-alive")
            .insert_header(TRANSFER_ENCODING, "chunked")
            .insert_header(UPGRADE, "websocket")
            .insert_header(CONTENT_TYPE, "application/json")
            .insert_header("X-Safe-Header", "safe-value");
        Mock::given(method("GET")).respond_with(reply).mount(&server).await;

        let request = Request::get(server.uri()).body(empty_body()).unwrap();

        let result = test_client().await.handle(request).await;
        assert!(result.is_ok());
        let (response, _) = result.unwrap();
        let headers = response.headers();

        // ordinary headers survive
        assert_eq!(headers.get(CONTENT_TYPE).unwrap().to_str().unwrap(), "application/json");
        assert_eq!(headers.get("X-Safe-Header").unwrap().to_str().unwrap(), "safe-value");

        // forbidden headers are stripped
        for forbidden in [CONNECTION, TRANSFER_ENCODING, UPGRADE] {
            assert!(!headers.contains_key(forbidden));
        }
    }

    #[tokio::test]
    async fn invalid_uri() {
        let request = Request::builder()
            .method(Method::GET)
            .uri("not-a-valid-uri")
            .body(full_body(""))
            .unwrap();

        let result = test_client().await.handle(request).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn connection_refused() {
        let request = Request::get("http://localhost:59999/test").body(empty_body()).unwrap();

        let result = test_client().await.handle(request).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn client_cert_base64() {
        let server = MockServer::start().await;
        Mock::given(method("GET")).respond_with(ResponseTemplate::new(200)).mount(&server).await;

        let request = Request::get(server.uri())
            .header("Client-Cert", "not-valid-base64!!!")
            .body(empty_body())
            .unwrap();

        let result = test_client().await.handle(request).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn client_cert_pem() {
        let server = MockServer::start().await;
        Mock::given(method("GET")).respond_with(ResponseTemplate::new(200)).mount(&server).await;

        // valid base64, but the decoded bytes are not a PEM identity
        let encoded = Base64::encode_string("invalid pem content".as_bytes());
        let request = Request::get(server.uri())
            .header("Client-Cert", encoded)
            .body(empty_body())
            .unwrap();

        let result = test_client().await.handle(request).await;
        assert!(result.is_err());
    }
}