Skip to main content

omnia_wasi_http/host/
default_impl.rs

1use std::fmt::Display;
2
3use anyhow::{Context, Result};
4use base64ct::{Base64, Encoding};
5use bytes::Bytes;
6use fromenv::FromEnv;
7use futures::Future;
8use http::header::{
9    CONNECTION, HOST, HeaderName, PROXY_AUTHENTICATE, PROXY_AUTHORIZATION, TRANSFER_ENCODING,
10    UPGRADE,
11};
12use http::{Request, Response};
13use http_body_util::BodyExt;
14use http_body_util::combinators::UnsyncBoxBody;
15use omnia::Backend;
16use tracing::instrument;
17use wasmtime_wasi::TrappableError;
18use wasmtime_wasi_http::p3::bindings::http::types::ErrorCode;
19use wasmtime_wasi_http::p3::{self, RequestOptions};
20
21pub type HttpResult<T> = Result<T, HttpError>;
22pub type HttpError = TrappableError<ErrorCode>;
23pub type FutureResult<T> = Box<dyn Future<Output = Result<T, ErrorCode>> + Send>;
24
25/// Set of headers that are forbidden by by `wasmtime-wasi-http`.
26pub const FORBIDDEN_HEADERS: [HeaderName; 9] = [
27    CONNECTION,
28    HOST,
29    PROXY_AUTHENTICATE,
30    PROXY_AUTHORIZATION,
31    TRANSFER_ENCODING,
32    UPGRADE,
33    HeaderName::from_static("keep-alive"),
34    HeaderName::from_static("proxy-connection"),
35    HeaderName::from_static("http2-settings"),
36];
37
38#[derive(Debug, Clone, FromEnv)]
39pub struct ConnectOptions {
40    #[env(from = "HTTP_ADDR", default = "http://localhost:8080")]
41    pub addr: String,
42}
43
44impl omnia::FromEnv for ConnectOptions {
45    fn from_env() -> Result<Self> {
46        Self::from_env().finalize().context("issue loading connection options")
47    }
48}
49
50/// Default implementation for `wasi:http`.
51#[derive(Debug, Clone)]
52pub struct HttpDefault;
53
54impl Backend for HttpDefault {
55    type ConnectOptions = ConnectOptions;
56
57    #[instrument]
58    async fn connect_with(options: Self::ConnectOptions) -> Result<Self> {
59        Ok(Self)
60    }
61}
62
63impl p3::WasiHttpCtx for HttpDefault {
64    fn send_request(
65        &mut self, request: Request<UnsyncBoxBody<Bytes, ErrorCode>>,
66        _options: Option<RequestOptions>, fut: FutureResult<()>,
67    ) -> Box<
68        dyn Future<
69                Output = HttpResult<(Response<UnsyncBoxBody<Bytes, ErrorCode>>, FutureResult<()>)>,
70            > + Send,
71    > {
72        Box::new(async move {
73            let (mut parts, body) = request.into_parts();
74
75            // dedupe "Host" headers (keep the one added by wasmtime/wasip3?)
76            let values = parts.headers.get_all(HOST).iter().cloned().collect::<Vec<_>>();
77            if values.len() > 1 {
78                parts.headers.remove(HOST);
79                for v in values.into_iter().skip(1) {
80                    parts.headers.append(HOST, v);
81                }
82            }
83
84            // build client
85            let mut builder = reqwest::Client::builder();
86
87            // check for "Client-Cert" header
88            if let Some(encoded_cert) = parts.headers.remove("Client-Cert") {
89                tracing::debug!("using client certificate");
90                let encoded = encoded_cert.to_str().map_err(internal_err)?;
91                let bytes = Base64::decode_vec(encoded).map_err(internal_err)?;
92                let identity = reqwest::Identity::from_pem(&bytes).map_err(internal_err)?;
93                builder = builder.identity(identity);
94            }
95
96            // disable system proxy in tests to avoid macOS issues
97            #[cfg(test)]
98            let builder = builder.no_proxy();
99            let client = builder.build().map_err(reqwest_err)?;
100
101            let collected = body.collect().await.map_err(internal_err)?;
102
103            // make request
104            let resp = client
105                .request(parts.method, parts.uri.to_string())
106                .headers(parts.headers)
107                .body(collected.to_bytes())
108                .send()
109                .await
110                .map_err(reqwest_err)?;
111
112            // process response
113            let converted: Response<reqwest::Body> = resp.into();
114            let (parts, body) = converted.into_parts();
115            let body = body.map_err(reqwest_err).boxed_unsync();
116            let mut response = Response::from_parts(parts, body);
117
118            // remove forbidden headers (disallowed by `wasmtime-wasi-http`)
119            let headers = response.headers_mut();
120            for header in &FORBIDDEN_HEADERS {
121                headers.remove(header);
122            }
123
124            Ok((response, fut))
125        })
126    }
127}
128
129fn internal_err(e: impl Display) -> ErrorCode {
130    ErrorCode::InternalError(Some(e.to_string()))
131}
132
133#[allow(clippy::needless_pass_by_value)]
134fn reqwest_err(e: reqwest::Error) -> ErrorCode {
135    if e.is_timeout() {
136        ErrorCode::ConnectionTimeout
137    } else if e.is_connect() {
138        ErrorCode::ConnectionRefused
139    } else if e.is_request() {
140        ErrorCode::HttpRequestUriInvalid
141    } else {
142        internal_err(e)
143    }
144}
145
146#[cfg(test)]
147mod tests {
148    use std::pin::Pin;
149
150    use http::header::{AUTHORIZATION, CONTENT_TYPE};
151    use http::{Method, StatusCode};
152    use http_body_util::{Empty, Full};
153    use p3::WasiHttpCtx;
154    use wiremock::matchers::{body_string, header, method};
155    use wiremock::{Mock, MockServer, ResponseTemplate};
156
157    use super::*;
158
159    #[tokio::test]
160    async fn multiple_host_headers() {
161        let server = MockServer::start().await;
162        Mock::given(method("GET"))
163            .respond_with(ResponseTemplate::new(200).set_body_string("Hello, World!"))
164            .mount(&server)
165            .await;
166
167        let request = Request::get(server.uri())
168            .header(HOST, "localhost-1")
169            .header(HOST, "localhost-2")
170            .body(Empty::new().map_err(internal_err).boxed_unsync())
171            .unwrap();
172
173        let result = HttpDefault.handle(request).await;
174        assert!(result.is_ok());
175
176        // check response
177        let (response, _) = result.unwrap();
178        assert_eq!(response.status(), StatusCode::OK);
179
180        // check body
181        let body = response.into_body().collect().await.unwrap().to_bytes();
182        assert_eq!(body, Bytes::from("Hello, World!"));
183
184        // check received request
185        let requests = server.received_requests().await.expect("should have requests");
186        assert_eq!(requests.len(), 1);
187
188        assert_eq!(requests[0].headers.get_all(HOST).iter().count(), 1);
189        assert_eq!(requests[0].headers.get(HOST).unwrap().to_str().unwrap(), "localhost-2");
190    }
191
192    #[tokio::test]
193    async fn post_with_body() {
194        let server = MockServer::start().await;
195        Mock::given(method("POST"))
196            .and(body_string("test body"))
197            .respond_with(ResponseTemplate::new(201).set_body_string("Created"))
198            .mount(&server)
199            .await;
200
201        let request = Request::post(server.uri())
202            .body(Full::new(Bytes::from("test body")).map_err(internal_err).boxed_unsync())
203            .unwrap();
204
205        let result = HttpDefault.handle(request).await;
206        assert!(result.is_ok());
207
208        let (response, _) = result.unwrap();
209        assert_eq!(response.status(), StatusCode::CREATED);
210    }
211
212    #[tokio::test]
213    async fn custom_headers() {
214        let server = MockServer::start().await;
215        Mock::given(method("GET"))
216            .and(header("X-Custom-Header", "custom-value"))
217            .and(header(AUTHORIZATION, "Bearer token123"))
218            .respond_with(ResponseTemplate::new(200))
219            .mount(&server)
220            .await;
221
222        let request = Request::get(server.uri())
223            .header("X-Custom-Header", "custom-value")
224            .header(AUTHORIZATION, "Bearer token123")
225            .body(Empty::new().map_err(internal_err).boxed_unsync())
226            .unwrap();
227
228        let result = HttpDefault.handle(request).await;
229        assert!(result.is_ok());
230
231        let (response, _) = result.unwrap();
232        assert_eq!(response.status(), StatusCode::OK);
233    }
234
235    #[tokio::test]
236    async fn permitted_headers() {
237        let server = MockServer::start().await;
238        Mock::given(method("GET"))
239            .respond_with(
240                ResponseTemplate::new(200)
241                    .insert_header(CONNECTION, "keep-alive")
242                    .insert_header(TRANSFER_ENCODING, "chunked")
243                    .insert_header(UPGRADE, "websocket")
244                    .insert_header(CONTENT_TYPE, "application/json")
245                    .insert_header("X-Safe-Header", "safe-value"),
246            )
247            .mount(&server)
248            .await;
249
250        let request = Request::get(server.uri())
251            .body(Empty::new().map_err(internal_err).boxed_unsync())
252            .unwrap();
253
254        let result = HttpDefault.handle(request).await;
255        assert!(result.is_ok());
256
257        // check response
258        let (response, _) = result.unwrap();
259        let headers = response.headers();
260
261        // permitted headers are preserved
262        assert_eq!(headers.get(CONTENT_TYPE).unwrap().to_str().unwrap(), "application/json");
263        assert_eq!(headers.get("X-Safe-Header").unwrap().to_str().unwrap(), "safe-value");
264
265        // verify forbidden headers are removed
266        assert!(!headers.contains_key(CONNECTION));
267        assert!(!headers.contains_key(TRANSFER_ENCODING));
268        assert!(!headers.contains_key(UPGRADE));
269    }
270
271    #[tokio::test]
272    async fn invalid_uri() {
273        let body = Full::new(Bytes::from("")).map_err(internal_err).boxed_unsync();
274        let request =
275            Request::builder().method(Method::GET).uri("not-a-valid-uri").body(body).unwrap();
276
277        let result = HttpDefault.handle(request).await;
278        assert!(result.is_err());
279    }
280
281    #[tokio::test]
282    async fn connection_refused() {
283        let request = Request::get("http://localhost:59999/test")
284            .body(Empty::new().map_err(internal_err).boxed_unsync())
285            .unwrap();
286
287        let result = HttpDefault.handle(request).await;
288        assert!(result.is_err());
289    }
290
291    #[tokio::test]
292    async fn client_cert_base64() {
293        let server = MockServer::start().await;
294        Mock::given(method("GET")).respond_with(ResponseTemplate::new(200)).mount(&server).await;
295
296        let request = Request::get(server.uri())
297            .header("Client-Cert", "not-valid-base64!!!")
298            .body(Empty::new().map_err(internal_err).boxed_unsync())
299            .unwrap();
300
301        let result = HttpDefault.handle(request).await;
302        assert!(result.is_err());
303    }
304
305    #[tokio::test]
306    async fn client_cert_pem() {
307        let server = MockServer::start().await;
308        Mock::given(method("GET")).respond_with(ResponseTemplate::new(200)).mount(&server).await;
309
310        let invalid_pem = "invalid pem content";
311        let encoded = Base64::encode_string(invalid_pem.as_bytes());
312        let request = Request::get(server.uri())
313            .header("Client-Cert", encoded)
314            .body(Empty::new().map_err(internal_err).boxed_unsync())
315            .unwrap();
316
317        let result = HttpDefault.handle(request).await;
318        assert!(result.is_err());
319    }
320
321    // Mock `wasip3::proxy::wasi::http::handler::handle` method
322    impl HttpDefault {
323        async fn handle(
324            &mut self, request: Request<UnsyncBoxBody<Bytes, ErrorCode>>,
325        ) -> HttpResult<(Response<UnsyncBoxBody<Bytes, ErrorCode>>, FutureResult<()>)> {
326            let boxed = self.send_request(request, None, Box::new(async { Ok(()) }));
327            Pin::from(boxed).await
328        }
329    }
330}