1use std::fmt::Display;
2use std::time::Duration;
3
4use anyhow::{Context, Result};
5use base64ct::{Base64, Encoding};
6use bytes::Bytes;
7use fromenv::FromEnv;
8use futures::Future;
9use http::header::{
10 CONNECTION, HOST, HeaderName, PROXY_AUTHENTICATE, PROXY_AUTHORIZATION, TRANSFER_ENCODING,
11 UPGRADE,
12};
13use http::{Request, Response};
14use http_body_util::BodyExt;
15use http_body_util::combinators::UnsyncBoxBody;
16use omnia::Backend;
17use tracing::instrument;
18use wasmtime::component::ResourceTable;
19use wasmtime_wasi::TrappableError;
20use wasmtime_wasi_http::WasiHttpCtx;
21use wasmtime_wasi_http::p3::bindings::http::types::ErrorCode;
22use wasmtime_wasi_http::p3::{self, RequestOptions, WasiHttpCtxView};
23
24pub type HttpResult<T> = Result<T, HttpError>;
25pub type HttpError = TrappableError<ErrorCode>;
26pub type FutureResult<T> = Box<dyn Future<Output = Result<T, ErrorCode>> + Send>;
27
/// Header names that must not cross the host/guest boundary.
///
/// `send_request` strips every one of these from upstream responses before the response is
/// handed back to the guest. The set covers connection-level and proxy headers plus `Host`.
///
/// NOTE(review): this presumably mirrors wasmtime-wasi-http's default forbidden-header list;
/// keep the two in sync if that list changes.
pub const FORBIDDEN_HEADERS: [HeaderName; 9] = [
    CONNECTION,
    HOST,
    PROXY_AUTHENTICATE,
    PROXY_AUTHORIZATION,
    TRANSFER_ENCODING,
    UPGRADE,
    // Not provided as `http::header` constants, so built from static strings.
    HeaderName::from_static("keep-alive"),
    HeaderName::from_static("proxy-connection"),
    HeaderName::from_static("http2-settings"),
];
40
/// Connection settings for the default HTTP backend, loaded from environment variables.
#[derive(Debug, Clone, FromEnv)]
pub struct ConnectOptions {
    /// Address from `HTTP_ADDR` (default `http://localhost:8080`).
    ///
    /// NOTE(review): not read by `HttpDefault::connect_with` in the code visible here —
    /// confirm whether it is still needed.
    #[env(from = "HTTP_ADDR", default = "http://localhost:8080")]
    pub addr: String,
    /// Outbound connect timeout in whole seconds, from `HTTP_CONNECT_TIMEOUT_SECS`
    /// (default `10`).
    #[env(from = "HTTP_CONNECT_TIMEOUT_SECS", default = "10")]
    pub connect_timeout_secs: u64,
}
48
49impl omnia::FromEnv for ConnectOptions {
50 fn from_env() -> Result<Self> {
51 Self::from_env().finalize().context("issue loading connection options")
52 }
53}
54
/// Outgoing-request hooks backed by `reqwest`.
///
/// Holds the shared client together with the connect timeout. The timeout is kept so that a
/// one-off client (built when a request carries a `Client-Cert` header) uses the same setting.
#[derive(Debug, Clone)]
struct HttpHooks {
    /// Shared client used for requests that do not carry a `Client-Cert` header.
    client: reqwest::Client,
    /// Connect timeout applied to the shared client and to per-request clients.
    connect_timeout: Duration,
}
61
/// Default HTTP backend.
///
/// Owns the outgoing-request hooks and the `wasmtime_wasi_http` context that
/// [`HttpDefault::as_view`] lends out.
#[derive(Debug, Clone)]
pub struct HttpDefault {
    hooks: HttpHooks,
    ctx: WasiHttpCtx,
}
68
69impl HttpDefault {
70 pub fn as_view<'a>(&'a mut self, table: &'a mut ResourceTable) -> WasiHttpCtxView<'a> {
72 WasiHttpCtxView {
73 hooks: &mut self.hooks,
74 ctx: &mut self.ctx,
75 table,
76 }
77 }
78}
79
80impl Backend for HttpDefault {
81 type ConnectOptions = ConnectOptions;
82
83 #[instrument]
84 async fn connect_with(options: Self::ConnectOptions) -> Result<Self> {
85 let connect_timeout = Duration::from_secs(options.connect_timeout_secs);
86
87 let builder = reqwest::Client::builder().connect_timeout(connect_timeout);
88
89 #[cfg(test)]
90 let builder = builder.no_proxy();
91
92 let client = builder.build().context("building HTTP client")?;
93 Ok(Self {
94 hooks: HttpHooks {
95 client,
96 connect_timeout,
97 },
98 ctx: WasiHttpCtx::default(),
99 })
100 }
101}
102
103impl p3::WasiHttpHooks for HttpHooks {
104 fn send_request(
105 &mut self, request: Request<UnsyncBoxBody<Bytes, ErrorCode>>,
106 _options: Option<RequestOptions>, fut: FutureResult<()>,
107 ) -> Box<
108 dyn Future<
109 Output = HttpResult<(Response<UnsyncBoxBody<Bytes, ErrorCode>>, FutureResult<()>)>,
110 > + Send,
111 > {
112 let shared_client = self.client.clone();
113 let connect_timeout = self.connect_timeout;
114
115 Box::new(async move {
116 let (mut parts, body) = request.into_parts();
117
118 parts.headers.remove(HOST);
120
121 let client = if let Some(encoded_cert) = parts.headers.remove("Client-Cert") {
124 tracing::debug!("using client certificate");
125 let encoded = encoded_cert.to_str().map_err(internal_err)?;
126 let bytes = Base64::decode_vec(encoded).map_err(internal_err)?;
127 let identity = reqwest::Identity::from_pem(&bytes).map_err(internal_err)?;
128 let builder =
129 reqwest::Client::builder().connect_timeout(connect_timeout).identity(identity);
130
131 #[cfg(test)]
132 let builder = builder.no_proxy();
133
134 builder.build().map_err(reqwest_err)?
135 } else {
136 shared_client
137 };
138
139 let collected = body.collect().await.map_err(internal_err)?;
140
141 let url = parts.uri.to_string();
143 let resp = client
144 .request(parts.method, &url)
145 .headers(parts.headers)
146 .body(collected.to_bytes())
147 .send()
148 .await
149 .map_err(reqwest_err)?;
150
151 let converted: Response<reqwest::Body> = resp.into();
153 let (parts, body) = converted.into_parts();
154 let body = body.map_err(reqwest_err).boxed_unsync();
155 let mut response = Response::from_parts(parts, body);
156
157 let headers = response.headers_mut();
159 for header in &FORBIDDEN_HEADERS {
160 headers.remove(header);
161 }
162
163 Ok((response, fut))
164 })
165 }
166}
167
168fn internal_err(e: impl Display) -> ErrorCode {
169 ErrorCode::InternalError(Some(e.to_string()))
170}
171
172#[allow(clippy::needless_pass_by_value)]
173fn reqwest_err(e: reqwest::Error) -> ErrorCode {
174 if e.is_timeout() {
175 ErrorCode::ConnectionTimeout
176 } else if e.is_connect() {
177 ErrorCode::ConnectionRefused
178 } else if e.is_request() {
179 ErrorCode::HttpRequestUriInvalid
180 } else {
181 internal_err(e)
182 }
183}
184
#[cfg(test)]
mod tests {
    use std::pin::Pin;

    use http::header::{AUTHORIZATION, CONTENT_TYPE};
    use http::{Method, StatusCode};
    use http_body_util::{Empty, Full};
    use p3::WasiHttpHooks;
    use wiremock::matchers::{body_string, header, method};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    use super::*;

    /// Builds an `HttpDefault` with a 10-second connect timeout; `addr` is left empty.
    async fn test_client() -> HttpDefault {
        let options = ConnectOptions {
            addr: String::new(),
            connect_timeout_secs: 10,
        };
        HttpDefault::connect_with(options).await.unwrap()
    }

    /// Duplicate guest `Host` headers are dropped: the upstream sees exactly one `Host`,
    /// addressed to the mock server, and the response body is passed through.
    #[tokio::test]
    async fn multiple_host_headers() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .respond_with(ResponseTemplate::new(200).set_body_string("Hello, World!"))
            .mount(&server)
            .await;

        let request = Request::get(server.uri())
            .header(HOST, "localhost-1")
            .header(HOST, "localhost-2")
            .body(Empty::new().map_err(internal_err).boxed_unsync())
            .unwrap();

        let result = test_client().await.handle(request).await;
        assert!(result.is_ok());

        let (response, _) = result.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        let body = response.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(body, Bytes::from("Hello, World!"));

        let requests = server.received_requests().await.expect("should have requests");
        assert_eq!(requests.len(), 1);

        // Exactly one `Host` reached the server, and it names the real target rather than
        // either guest-supplied value.
        assert_eq!(requests[0].headers.get_all(HOST).iter().count(), 1);
        assert!(requests[0].headers.get(HOST).unwrap().to_str().unwrap().starts_with("127.0.0.1:"));
    }

    /// A POST body is forwarded intact (the mock only matches on it) and the upstream status
    /// is returned to the caller.
    #[tokio::test]
    async fn post_with_body() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(body_string("test body"))
            .respond_with(ResponseTemplate::new(201).set_body_string("Created"))
            .mount(&server)
            .await;

        let request = Request::post(server.uri())
            .body(Full::new(Bytes::from("test body")).map_err(internal_err).boxed_unsync())
            .unwrap();

        let result = test_client().await.handle(request).await;
        assert!(result.is_ok());

        let (response, _) = result.unwrap();
        assert_eq!(response.status(), StatusCode::CREATED);
    }

    /// Ordinary request headers (a custom one and `Authorization`) are forwarded; the mock
    /// answers 200 only when both arrive.
    #[tokio::test]
    async fn custom_headers() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(header("X-Custom-Header", "custom-value"))
            .and(header(AUTHORIZATION, "Bearer token123"))
            .respond_with(ResponseTemplate::new(200))
            .mount(&server)
            .await;

        let request = Request::get(server.uri())
            .header("X-Custom-Header", "custom-value")
            .header(AUTHORIZATION, "Bearer token123")
            .body(Empty::new().map_err(internal_err).boxed_unsync())
            .unwrap();

        let result = test_client().await.handle(request).await;
        assert!(result.is_ok());

        let (response, _) = result.unwrap();
        assert_eq!(response.status(), StatusCode::OK);
    }

    /// Forbidden response headers (`Connection`, `Transfer-Encoding`, `Upgrade`) are stripped
    /// while ordinary ones (`Content-Type`, a custom header) are kept.
    #[tokio::test]
    async fn permitted_headers() {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .respond_with(
                ResponseTemplate::new(200)
                    .insert_header(CONNECTION, "keep-alive")
                    .insert_header(TRANSFER_ENCODING, "chunked")
                    .insert_header(UPGRADE, "websocket")
                    .insert_header(CONTENT_TYPE, "application/json")
                    .insert_header("X-Safe-Header", "safe-value"),
            )
            .mount(&server)
            .await;

        let request = Request::get(server.uri())
            .body(Empty::new().map_err(internal_err).boxed_unsync())
            .unwrap();

        let result = test_client().await.handle(request).await;
        assert!(result.is_ok());

        let (response, _) = result.unwrap();
        let headers = response.headers();

        assert_eq!(headers.get(CONTENT_TYPE).unwrap().to_str().unwrap(), "application/json");
        assert_eq!(headers.get("X-Safe-Header").unwrap().to_str().unwrap(), "safe-value");

        assert!(!headers.contains_key(CONNECTION));
        assert!(!headers.contains_key(TRANSFER_ENCODING));
        assert!(!headers.contains_key(UPGRADE));
    }

    /// A URI that cannot be turned into an absolute URL makes the request fail.
    #[tokio::test]
    async fn invalid_uri() {
        let body = Full::new(Bytes::from("")).map_err(internal_err).boxed_unsync();
        let request =
            Request::builder().method(Method::GET).uri("not-a-valid-uri").body(body).unwrap();

        let result = test_client().await.handle(request).await;
        assert!(result.is_err());
    }

    /// Sending to a local port makes the request fail; assumes nothing is listening on
    /// port 59999.
    #[tokio::test]
    async fn connection_refused() {
        let request = Request::get("http://localhost:59999/test")
            .body(Empty::new().map_err(internal_err).boxed_unsync())
            .unwrap();

        let result = test_client().await.handle(request).await;
        assert!(result.is_err());
    }

    /// A `Client-Cert` value that is not valid base64 is rejected.
    #[tokio::test]
    async fn client_cert_base64() {
        let server = MockServer::start().await;
        Mock::given(method("GET")).respond_with(ResponseTemplate::new(200)).mount(&server).await;

        let request = Request::get(server.uri())
            .header("Client-Cert", "not-valid-base64!!!")
            .body(Empty::new().map_err(internal_err).boxed_unsync())
            .unwrap();

        let result = test_client().await.handle(request).await;
        assert!(result.is_err());
    }

    /// A `Client-Cert` value that decodes from base64 but is not a PEM identity is rejected.
    #[tokio::test]
    async fn client_cert_pem() {
        let server = MockServer::start().await;
        Mock::given(method("GET")).respond_with(ResponseTemplate::new(200)).mount(&server).await;

        let invalid_pem = "invalid pem content";
        let encoded = Base64::encode_string(invalid_pem.as_bytes());
        let request = Request::get(server.uri())
            .header("Client-Cert", encoded)
            .body(Empty::new().map_err(internal_err).boxed_unsync())
            .unwrap();

        let result = test_client().await.handle(request).await;
        assert!(result.is_err());
    }

    impl HttpDefault {
        /// Test helper: drives `send_request` with no options and an already-resolved `fut`,
        /// then awaits the result.
        async fn handle(
            &mut self,
            request: Request<UnsyncBoxBody<Bytes, ErrorCode>>,
        ) -> HttpResult<(Response<UnsyncBoxBody<Bytes, ErrorCode>>, FutureResult<()>)> {
            let boxed = self.hooks.send_request(request, None, Box::new(async { Ok(()) }));
            // `send_request` returns an unpinned `Box<dyn Future>`; pin it so it can be awaited.
            Pin::from(boxed).await
        }
    }
}