omnia_wasi_http/host/
default_impl.rs1use std::fmt::Display;
2
3use anyhow::{Context, Result};
4use base64ct::{Base64, Encoding};
5use bytes::Bytes;
6use fromenv::FromEnv;
7use futures::Future;
8use http::header::{
9 CONNECTION, HOST, HeaderName, PROXY_AUTHENTICATE, PROXY_AUTHORIZATION, TRANSFER_ENCODING,
10 UPGRADE,
11};
12use http::{Request, Response};
13use http_body_util::BodyExt;
14use http_body_util::combinators::UnsyncBoxBody;
15use omnia::Backend;
16use tracing::instrument;
17use wasmtime_wasi::TrappableError;
18use wasmtime_wasi_http::p3::bindings::http::types::ErrorCode;
19use wasmtime_wasi_http::p3::{self, RequestOptions};
20
21pub type HttpResult<T> = Result<T, HttpError>;
22pub type HttpError = TrappableError<ErrorCode>;
23pub type FutureResult<T> = Box<dyn Future<Output = Result<T, ErrorCode>> + Send>;
24
25pub const FORBIDDEN_HEADERS: [HeaderName; 9] = [
27 CONNECTION,
28 HOST,
29 PROXY_AUTHENTICATE,
30 PROXY_AUTHORIZATION,
31 TRANSFER_ENCODING,
32 UPGRADE,
33 HeaderName::from_static("keep-alive"),
34 HeaderName::from_static("proxy-connection"),
35 HeaderName::from_static("http2-settings"),
36];
37
38#[derive(Debug, Clone, FromEnv)]
39pub struct ConnectOptions {
40 #[env(from = "HTTP_ADDR", default = "http://localhost:8080")]
41 pub addr: String,
42}
43
44impl omnia::FromEnv for ConnectOptions {
45 fn from_env() -> Result<Self> {
46 Self::from_env().finalize().context("issue loading connection options")
47 }
48}
49
/// Default outbound HTTP backend.
///
/// Stateless: each guest request is forwarded with a `reqwest` client built for that request.
#[derive(Debug, Clone)]
pub struct HttpDefault;
53
54impl Backend for HttpDefault {
55 type ConnectOptions = ConnectOptions;
56
57 #[instrument]
58 async fn connect_with(options: Self::ConnectOptions) -> Result<Self> {
59 Ok(Self)
60 }
61}
62
63impl p3::WasiHttpCtx for HttpDefault {
64 fn send_request(
65 &mut self, request: Request<UnsyncBoxBody<Bytes, ErrorCode>>,
66 _options: Option<RequestOptions>, fut: FutureResult<()>,
67 ) -> Box<
68 dyn Future<
69 Output = HttpResult<(Response<UnsyncBoxBody<Bytes, ErrorCode>>, FutureResult<()>)>,
70 > + Send,
71 > {
72 Box::new(async move {
73 let (mut parts, body) = request.into_parts();
74
75 let values = parts.headers.get_all(HOST).iter().cloned().collect::<Vec<_>>();
77 if values.len() > 1 {
78 parts.headers.remove(HOST);
79 for v in values.into_iter().skip(1) {
80 parts.headers.append(HOST, v);
81 }
82 }
83
84 let mut builder = reqwest::Client::builder();
86
87 if let Some(encoded_cert) = parts.headers.remove("Client-Cert") {
89 tracing::debug!("using client certificate");
90 let encoded = encoded_cert.to_str().map_err(internal_err)?;
91 let bytes = Base64::decode_vec(encoded).map_err(internal_err)?;
92 let identity = reqwest::Identity::from_pem(&bytes).map_err(internal_err)?;
93 builder = builder.identity(identity);
94 }
95
96 #[cfg(test)]
98 let builder = builder.no_proxy();
99 let client = builder.build().map_err(reqwest_err)?;
100
101 let collected = body.collect().await.map_err(internal_err)?;
102
103 let resp = client
105 .request(parts.method, parts.uri.to_string())
106 .headers(parts.headers)
107 .body(collected.to_bytes())
108 .send()
109 .await
110 .map_err(reqwest_err)?;
111
112 let converted: Response<reqwest::Body> = resp.into();
114 let (parts, body) = converted.into_parts();
115 let body = body.map_err(reqwest_err).boxed_unsync();
116 let mut response = Response::from_parts(parts, body);
117
118 let headers = response.headers_mut();
120 for header in &FORBIDDEN_HEADERS {
121 headers.remove(header);
122 }
123
124 Ok((response, fut))
125 })
126 }
127}
128
129fn internal_err(e: impl Display) -> ErrorCode {
130 ErrorCode::InternalError(Some(e.to_string()))
131}
132
133#[allow(clippy::needless_pass_by_value)]
134fn reqwest_err(e: reqwest::Error) -> ErrorCode {
135 if e.is_timeout() {
136 ErrorCode::ConnectionTimeout
137 } else if e.is_connect() {
138 ErrorCode::ConnectionRefused
139 } else if e.is_request() {
140 ErrorCode::HttpRequestUriInvalid
141 } else {
142 internal_err(e)
143 }
144}
145
#[cfg(test)]
mod tests {
    use http::header::{AUTHORIZATION, CONTENT_TYPE};
    use http::{Method, StatusCode};
    use http_body_util::{Empty, Full};
    use p3::WasiHttpCtx;
    use wiremock::matchers::{body_string, header, method};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    use super::*;

    impl HttpDefault {
        /// Test shim: calls `send_request` with no options and an already-completed body
        /// future, then awaits the boxed result.
        async fn handle(
            &mut self, request: Request<UnsyncBoxBody<Bytes, ErrorCode>>,
        ) -> HttpResult<(Response<UnsyncBoxBody<Bytes, ErrorCode>>, FutureResult<()>)> {
            let pending = self.send_request(request, None, Box::new(async { Ok(()) }));
            Box::into_pin(pending).await
        }
    }

    /// Boxed empty request body.
    fn empty_body() -> UnsyncBoxBody<Bytes, ErrorCode> {
        Empty::new().map_err(internal_err).boxed_unsync()
    }

    /// Boxed request body containing exactly `data`.
    fn full_body(data: &'static str) -> UnsyncBoxBody<Bytes, ErrorCode> {
        Full::new(Bytes::from(data)).map_err(internal_err).boxed_unsync()
    }

    /// Starts a mock server that answers every `GET` with `200 OK` and an empty body.
    async fn ok_get_server() -> MockServer {
        let server = MockServer::start().await;
        Mock::given(method("GET")).respond_with(ResponseTemplate::new(200)).mount(&server).await;
        server
    }

    /// Two `Host` headers on the way in; exactly one (the last) reaches the server.
    #[tokio::test]
    async fn multiple_host_headers() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .respond_with(ResponseTemplate::new(200).set_body_string("Hello, World!"))
            .mount(&server)
            .await;

        let request = Request::get(server.uri())
            .header(HOST, "localhost-1")
            .header(HOST, "localhost-2")
            .body(empty_body())
            .unwrap();

        let result = HttpDefault.handle(request).await;
        assert!(result.is_ok());
        let (response, _) = result.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let body = response.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, Bytes::from("Hello, World!"));

        let requests = server.received_requests().await.expect("should have requests");
        assert_eq!(requests.len(), 1);
        let received = &requests[0];
        assert_eq!(received.headers.get_all(HOST).iter().count(), 1);
        assert_eq!(received.headers.get(HOST).unwrap().to_str().unwrap(), "localhost-2");
    }

    /// The buffered request body is delivered intact; the mock only matches the exact body.
    #[tokio::test]
    async fn post_with_body() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(body_string("test body"))
            .respond_with(ResponseTemplate::new(201).set_body_string("Created"))
            .mount(&server)
            .await;

        let request = Request::post(server.uri()).body(full_body("test body")).unwrap();

        let result = HttpDefault.handle(request).await;
        assert!(result.is_ok());
        let (response, _) = result.unwrap();
        assert_eq!(response.status(), StatusCode::CREATED);
    }

    /// Ordinary request headers are forwarded unchanged.
    #[tokio::test]
    async fn custom_headers() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(header("X-Custom-Header", "custom-value"))
            .and(header(AUTHORIZATION, "Bearer token123"))
            .respond_with(ResponseTemplate::new(200))
            .mount(&server)
            .await;

        let request = Request::get(server.uri())
            .header("X-Custom-Header", "custom-value")
            .header(AUTHORIZATION, "Bearer token123")
            .body(empty_body())
            .unwrap();

        let result = HttpDefault.handle(request).await;
        assert!(result.is_ok());
        let (response, _) = result.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    /// Forbidden response headers are stripped while other headers survive.
    #[tokio::test]
    async fn permitted_headers() {
        let server = MockServer::start().await;
        let template = ResponseTemplate::new(200)
            .insert_header(CONNECTION, "keep-alive")
            .insert_header(TRANSFER_ENCODING, "chunked")
            .insert_header(UPGRADE, "websocket")
            .insert_header(CONTENT_TYPE, "application/json")
            .insert_header("X-Safe-Header", "safe-value");
        Mock::given(method("GET")).respond_with(template).mount(&server).await;

        let request = Request::get(server.uri()).body(empty_body()).unwrap();

        let result = HttpDefault.handle(request).await;
        assert!(result.is_ok());
        let (response, _) = result.unwrap();
        let headers = response.headers();

        // Kept.
        assert_eq!(headers.get(CONTENT_TYPE).unwrap().to_str().unwrap(), "application/json");
        assert_eq!(headers.get("X-Safe-Header").unwrap().to_str().unwrap(), "safe-value");

        // Stripped.
        assert!(!headers.contains_key(CONNECTION));
        assert!(!headers.contains_key(TRANSFER_ENCODING));
        assert!(!headers.contains_key(UPGRADE));
    }

    /// A URI that is not an absolute URL is rejected.
    #[tokio::test]
    async fn invalid_uri() {
        let request = Request::builder()
            .method(Method::GET)
            .uri("not-a-valid-uri")
            .body(full_body(""))
            .unwrap();

        let result = HttpDefault.handle(request).await;
        assert!(result.is_err());
    }

    /// Nothing is expected to listen on this port, so the request fails.
    #[tokio::test]
    async fn connection_refused() {
        let request = Request::get("http://localhost:59999/test").body(empty_body()).unwrap();

        let result = HttpDefault.handle(request).await;
        assert!(result.is_err());
    }

    /// A `Client-Cert` value that is not valid base64 fails the request.
    #[tokio::test]
    async fn client_cert_base64() {
        let server = ok_get_server().await;

        let request = Request::get(server.uri())
            .header("Client-Cert", "not-valid-base64!!!")
            .body(empty_body())
            .unwrap();

        let result = HttpDefault.handle(request).await;
        assert!(result.is_err());
    }

    /// A `Client-Cert` value that decodes but is not a PEM identity fails the request.
    #[tokio::test]
    async fn client_cert_pem() {
        let server = ok_get_server().await;

        let encoded = Base64::encode_string("invalid pem content".as_bytes());
        let request =
            Request::get(server.uri()).header("Client-Cert", encoded).body(empty_body()).unwrap();

        let result = HttpDefault.handle(request).await;
        assert!(result.is_err());
    }
}