1use std::future::Future;
2use std::net::{IpAddr, SocketAddr};
3
4use anyhow::{Context, Result, bail};
5use reqwest::header::{CONTENT_TYPE, COOKIE, HeaderMap, HeaderName, HeaderValue, LOCATION};
6use url::Url;
7
8use earl_core::allowlist::ensure_url_allowed;
9use earl_core::{ExecutionContext, PreparedBody, PreparedMultipartPart, RawExecutionResult};
10
11use crate::PreparedHttpData;
12
13pub async fn execute_http_once_with_host_validator<F, Fut>(
15 http_data: &PreparedHttpData,
16 ctx: &ExecutionContext,
17 host_validator: &mut F,
18) -> Result<RawExecutionResult>
19where
20 F: FnMut(Url) -> Fut,
21 Fut: Future<Output = Result<Vec<IpAddr>>>,
22{
23 let mut method = http_data.method.clone();
24 let mut body = http_data.body.clone();
25 let mut url = http_data.url.clone();
26
27 for hop in 0..=ctx.transport.max_redirect_hops {
28 ensure_url_allowed(&url, &ctx.allow_rules)?;
29 let resolved_ips = host_validator(url.clone()).await?;
30 let client = build_http_client(ctx, &url, &resolved_ips)?;
31
32 let request = build_request(
33 &client,
34 &method,
35 &url,
36 &http_data.headers,
37 &http_data.cookies,
38 &http_data.query,
39 &body,
40 )?;
41 let response = request
42 .send()
43 .await
44 .with_context(|| format!("request execution failed for `{}`", url.as_str()))?;
45
46 if response.status().is_redirection() && ctx.transport.follow_redirects {
47 if hop >= ctx.transport.max_redirect_hops {
48 bail!(
49 "maximum redirect hops reached ({})",
50 ctx.transport.max_redirect_hops
51 );
52 }
53
54 let location = response
55 .headers()
56 .get(LOCATION)
57 .ok_or_else(|| anyhow::anyhow!("redirect response missing Location header"))?
58 .to_str()
59 .context("redirect Location header is not valid UTF-8")?
60 .to_string();
61
62 let new_url = url
63 .join(&location)
64 .with_context(|| format!("invalid redirect Location `{location}`"))?;
65
66 let status = response.status().as_u16();
67 if status == 303
68 || ((status == 301 || status == 302) && method == reqwest::Method::POST)
69 {
70 method = reqwest::Method::GET;
71 body = PreparedBody::Empty;
72 }
73 url = new_url;
74 continue;
75 }
76
77 let status = response.status().as_u16();
78 let content_type = response
79 .headers()
80 .get(CONTENT_TYPE)
81 .and_then(|v| v.to_str().ok())
82 .map(|v| v.to_string());
83
84 let body_bytes =
85 read_response_body_limited(response, ctx.transport.max_response_bytes).await?;
86
87 return Ok(RawExecutionResult {
88 status,
89 url: url.to_string(),
90 body: body_bytes,
91 content_type,
92 });
93 }
94
95 bail!("redirect handling failed unexpectedly")
96}
97
98fn build_request(
99 client: &reqwest::Client,
100 method: &reqwest::Method,
101 url: &Url,
102 headers: &[(String, String)],
103 cookies: &[(String, String)],
104 query: &[(String, String)],
105 body: &PreparedBody,
106) -> Result<reqwest::RequestBuilder> {
107 let mut builder = client.request(method.clone(), url.clone());
108
109 if !query.is_empty() {
110 builder = builder.query(query);
111 }
112
113 let mut header_map = HeaderMap::new();
114 for (name, value) in headers {
115 let header_name = HeaderName::from_bytes(name.as_bytes())
116 .with_context(|| format!("invalid header name `{name}`"))?;
117 let header_value = HeaderValue::from_str(value)
118 .with_context(|| format!("invalid header value for `{name}`"))?;
119 header_map.append(header_name, header_value);
120 }
121
122 if !cookies.is_empty() {
123 let cookie_value = cookies
124 .iter()
125 .map(|(k, v)| format!("{k}={v}"))
126 .collect::<Vec<_>>()
127 .join("; ");
128 header_map.insert(
129 COOKIE,
130 HeaderValue::from_str(&cookie_value).context("invalid cookie header value")?,
131 );
132 }
133
134 builder = builder.headers(header_map);
135
136 match body {
137 PreparedBody::Empty => {}
138 PreparedBody::Json(value) => {
139 builder = builder.json(value);
140 }
141 PreparedBody::Form(fields) => {
142 builder = builder.form(fields);
143 }
144 PreparedBody::Multipart(parts) => {
145 builder = builder.multipart(build_multipart(parts)?);
146 }
147 PreparedBody::RawBytes {
148 bytes,
149 content_type,
150 } => {
151 if let Some(content_type) = content_type {
152 builder = builder.header(CONTENT_TYPE, content_type);
153 }
154 builder = builder.body(bytes.clone());
155 }
156 }
157
158 Ok(builder)
159}
160
161fn build_http_client(
162 ctx: &ExecutionContext,
163 url: &Url,
164 resolved_ips: &[IpAddr],
165) -> Result<reqwest::Client> {
166 if resolved_ips.is_empty() {
167 bail!("host validation returned no resolved IP addresses");
168 }
169
170 let mut builder = reqwest::Client::builder()
171 .timeout(ctx.transport.timeout)
172 .redirect(reqwest::redirect::Policy::none())
173 .gzip(ctx.transport.compression)
174 .brotli(ctx.transport.compression)
175 .zstd(ctx.transport.compression)
176 .deflate(ctx.transport.compression);
177
178 if let Some(version) = ctx.transport.tls_min_version {
179 builder = builder.min_tls_version(version);
180 }
181
182 if let Some(proxy_url) = &ctx.transport.proxy_url {
183 let proxy = reqwest::Proxy::all(proxy_url)
184 .with_context(|| format!("invalid proxy URL `{proxy_url}`"))?;
185 builder = builder.proxy(proxy);
186 }
187
188 let host = url
189 .host_str()
190 .ok_or_else(|| anyhow::anyhow!("request URL missing host"))?;
191 let port = url
192 .port_or_known_default()
193 .ok_or_else(|| anyhow::anyhow!("request URL missing port"))?;
194
195 if !resolved_ips.is_empty() {
196 let addrs: Vec<SocketAddr> = resolved_ips
197 .iter()
198 .map(|ip| SocketAddr::new(*ip, port))
199 .collect();
200 builder = builder.resolve_to_addrs(host, &addrs);
201 }
202
203 builder
204 .build()
205 .context("failed constructing reqwest client")
206}
207
208async fn read_response_body_limited(
209 mut response: reqwest::Response,
210 limit: usize,
211) -> Result<Vec<u8>> {
212 let mut out = Vec::new();
213 while let Some(chunk) = response.chunk().await? {
214 if out.len().saturating_add(chunk.len()) > limit {
215 bail!("response body exceeded configured max_response_bytes ({limit} bytes)");
216 }
217 out.extend_from_slice(&chunk);
218 }
219 Ok(out)
220}
221
222use earl_core::{ProtocolExecutor, StreamChunk, StreamMeta, StreamingProtocolExecutor};
223use tokio::sync::mpsc;
224
/// HTTP executor that performs the request and buffers the full response body.
///
/// Its `ProtocolExecutor` implementation delegates to
/// `execute_http_once_with_host_validator`.
pub struct HttpExecutor<F> {
    /// Called with each URL about to be requested (the initial URL and every followed
    /// redirect target). It returns the IP addresses the connection is pinned to; an
    /// error or an empty list aborts the request.
    pub host_validator: F,
}
231
232impl<F, Fut> ProtocolExecutor for HttpExecutor<F>
233where
234 F: FnMut(Url) -> Fut + Send,
235 Fut: Future<Output = Result<Vec<IpAddr>>> + Send,
236{
237 type PreparedData = PreparedHttpData;
238
239 async fn execute(
240 &mut self,
241 data: &PreparedHttpData,
242 ctx: &ExecutionContext,
243 ) -> Result<RawExecutionResult> {
244 execute_http_once_with_host_validator(data, ctx, &mut self.host_validator).await
245 }
246}
247
/// HTTP executor that forwards the response body over a channel as it arrives instead
/// of buffering it; `text/event-stream` responses are split into SSE events.
pub struct HttpStreamExecutor<F> {
    /// Called with each URL about to be requested (the initial URL and every followed
    /// redirect target). It returns the IP addresses the connection is pinned to; an
    /// error or an empty list aborts the request.
    pub host_validator: F,
}
256
257impl<F, Fut> StreamingProtocolExecutor for HttpStreamExecutor<F>
258where
259 F: FnMut(Url) -> Fut + Send,
260 Fut: Future<Output = Result<Vec<IpAddr>>> + Send,
261{
262 type PreparedData = PreparedHttpData;
263
264 async fn execute_stream(
265 &mut self,
266 data: &PreparedHttpData,
267 ctx: &ExecutionContext,
268 sender: mpsc::Sender<StreamChunk>,
269 ) -> anyhow::Result<StreamMeta> {
270 let mut method = data.method.clone();
271 let mut body = data.body.clone();
272 let mut url = data.url.clone();
273
274 for hop in 0..=ctx.transport.max_redirect_hops {
275 ensure_url_allowed(&url, &ctx.allow_rules)?;
276 let resolved_ips = (self.host_validator)(url.clone()).await?;
277 let client = build_http_client(ctx, &url, &resolved_ips)?;
278
279 let request = build_request(
280 &client,
281 &method,
282 &url,
283 &data.headers,
284 &data.cookies,
285 &data.query,
286 &body,
287 )?;
288 let response = request
289 .send()
290 .await
291 .with_context(|| format!("request execution failed for `{}`", url.as_str()))?;
292
293 if response.status().is_redirection() && ctx.transport.follow_redirects {
294 if hop >= ctx.transport.max_redirect_hops {
295 bail!(
296 "maximum redirect hops reached ({})",
297 ctx.transport.max_redirect_hops
298 );
299 }
300
301 let location = response
302 .headers()
303 .get(LOCATION)
304 .ok_or_else(|| anyhow::anyhow!("redirect response missing Location header"))?
305 .to_str()
306 .context("redirect Location header is not valid UTF-8")?
307 .to_string();
308
309 let new_url = url
310 .join(&location)
311 .with_context(|| format!("invalid redirect Location `{location}`"))?;
312
313 let status = response.status().as_u16();
314 if status == 303
315 || ((status == 301 || status == 302) && method == reqwest::Method::POST)
316 {
317 method = reqwest::Method::GET;
318 body = PreparedBody::Empty;
319 }
320 url = new_url;
321 continue;
322 }
323
324 let status = response.status().as_u16();
325 let content_type = response
326 .headers()
327 .get(CONTENT_TYPE)
328 .and_then(|v| v.to_str().ok())
329 .map(|v| v.to_string());
330
331 let is_sse = content_type
333 .as_deref()
334 .map(|ct| ct.starts_with("text/event-stream"))
335 .unwrap_or(false);
336
337 let mut response = response;
339 let mut total_bytes = 0usize;
340 let mut sse_parser = if is_sse {
341 Some(crate::sse::SseParser::new())
342 } else {
343 None
344 };
345 let mut utf8_buffer: Vec<u8> = Vec::new();
347 let max = ctx.transport.max_response_bytes;
348
349 while let Some(chunk) = response.chunk().await? {
350 if let Some(parser) = &mut sse_parser {
351 utf8_buffer.extend_from_slice(&chunk);
352 if utf8_buffer.len() > max {
357 bail!(
358 "streaming response exceeded configured max_response_bytes ({max} bytes)"
359 );
360 }
361 let valid_up_to = match std::str::from_utf8(&utf8_buffer) {
364 Ok(_) => utf8_buffer.len(),
365 Err(e) => e.valid_up_to(),
366 };
367 if valid_up_to == 0 {
368 continue;
370 }
371 let text = std::str::from_utf8(&utf8_buffer[..valid_up_to])
372 .expect("validated UTF-8 boundary");
373 let events = parser.feed(text);
374 utf8_buffer.drain(..valid_up_to);
376 for event in events {
377 total_bytes = total_bytes.saturating_add(event.data.len());
378 if total_bytes > max {
379 bail!(
380 "streaming response exceeded configured max_response_bytes ({max} bytes)"
381 );
382 }
383 if sender
384 .send(StreamChunk {
385 data: event.data.into_bytes(),
386 content_type: None,
390 })
391 .await
392 .is_err()
393 {
394 return Ok(StreamMeta {
395 status,
396 url: url.to_string(),
397 });
398 }
399 }
400 } else {
401 total_bytes = total_bytes.saturating_add(chunk.len());
402 if total_bytes > ctx.transport.max_response_bytes {
403 bail!(
404 "streaming response exceeded configured max_response_bytes ({} bytes)",
405 ctx.transport.max_response_bytes
406 );
407 }
408 if sender
409 .send(StreamChunk {
410 data: chunk.to_vec(),
411 content_type: content_type.clone(),
412 })
413 .await
414 .is_err()
415 {
416 break;
418 }
419 }
420 }
421
422 if let Some(mut parser) = sse_parser {
424 if let Ok(text) = std::str::from_utf8(&utf8_buffer)
425 && !text.is_empty()
426 {
427 for event in parser.feed(text) {
428 total_bytes = total_bytes.saturating_add(event.data.len());
429 if total_bytes > max {
430 bail!(
431 "streaming response exceeded configured max_response_bytes ({max} bytes)"
432 );
433 }
434 let _ = sender
435 .send(StreamChunk {
436 data: event.data.into_bytes(),
437 content_type: None,
438 })
439 .await;
440 }
441 }
442
443 if let Some(event) = parser.flush() {
444 total_bytes = total_bytes.saturating_add(event.data.len());
445 if total_bytes > max {
446 bail!(
447 "streaming response exceeded configured max_response_bytes ({max} bytes)"
448 );
449 }
450 let _ = sender
451 .send(StreamChunk {
452 data: event.data.into_bytes(),
453 content_type: None,
454 })
455 .await;
456 }
457 }
458
459 return Ok(StreamMeta {
460 status,
461 url: url.to_string(),
462 });
463 }
464
465 bail!("redirect handling failed unexpectedly")
466 }
467}
468
469fn build_multipart(parts: &[PreparedMultipartPart]) -> Result<reqwest::multipart::Form> {
470 let mut form = reqwest::multipart::Form::new();
471 for part in parts {
472 let mut req_part = reqwest::multipart::Part::bytes(part.bytes.clone());
473 if let Some(content_type) = &part.content_type {
474 req_part = req_part.mime_str(content_type)?;
475 }
476 if let Some(filename) = &part.filename {
477 req_part = req_part.file_name(filename.clone());
478 }
479 form = form.part(part.name.clone(), req_part);
480 }
481 Ok(form)
482}