Skip to main content

earl_protocol_http/
executor.rs

1use std::future::Future;
2use std::net::{IpAddr, SocketAddr};
3
4use anyhow::{Context, Result, bail};
5use reqwest::header::{CONTENT_TYPE, COOKIE, HeaderMap, HeaderName, HeaderValue, LOCATION};
6use url::Url;
7
8use earl_core::allowlist::ensure_url_allowed;
9use earl_core::{ExecutionContext, PreparedBody, PreparedMultipartPart, RawExecutionResult};
10
11use crate::PreparedHttpData;
12
13/// Execute a single HTTP request (with redirect following) and return the result.
14pub async fn execute_http_once_with_host_validator<F, Fut>(
15    http_data: &PreparedHttpData,
16    ctx: &ExecutionContext,
17    host_validator: &mut F,
18) -> Result<RawExecutionResult>
19where
20    F: FnMut(Url) -> Fut,
21    Fut: Future<Output = Result<Vec<IpAddr>>>,
22{
23    let mut method = http_data.method.clone();
24    let mut body = http_data.body.clone();
25    let mut url = http_data.url.clone();
26
27    for hop in 0..=ctx.transport.max_redirect_hops {
28        ensure_url_allowed(&url, &ctx.allow_rules)?;
29        let resolved_ips = host_validator(url.clone()).await?;
30        let client = build_http_client(ctx, &url, &resolved_ips)?;
31
32        let request = build_request(
33            &client,
34            &method,
35            &url,
36            &http_data.headers,
37            &http_data.cookies,
38            &http_data.query,
39            &body,
40        )?;
41        let response = request
42            .send()
43            .await
44            .with_context(|| format!("request execution failed for `{}`", url.as_str()))?;
45
46        if response.status().is_redirection() && ctx.transport.follow_redirects {
47            if hop >= ctx.transport.max_redirect_hops {
48                bail!(
49                    "maximum redirect hops reached ({})",
50                    ctx.transport.max_redirect_hops
51                );
52            }
53
54            let location = response
55                .headers()
56                .get(LOCATION)
57                .ok_or_else(|| anyhow::anyhow!("redirect response missing Location header"))?
58                .to_str()
59                .context("redirect Location header is not valid UTF-8")?
60                .to_string();
61
62            let new_url = url
63                .join(&location)
64                .with_context(|| format!("invalid redirect Location `{location}`"))?;
65
66            let status = response.status().as_u16();
67            if status == 303
68                || ((status == 301 || status == 302) && method == reqwest::Method::POST)
69            {
70                method = reqwest::Method::GET;
71                body = PreparedBody::Empty;
72            }
73            url = new_url;
74            continue;
75        }
76
77        let status = response.status().as_u16();
78        let content_type = response
79            .headers()
80            .get(CONTENT_TYPE)
81            .and_then(|v| v.to_str().ok())
82            .map(|v| v.to_string());
83
84        let body_bytes =
85            read_response_body_limited(response, ctx.transport.max_response_bytes).await?;
86
87        return Ok(RawExecutionResult {
88            status,
89            url: url.to_string(),
90            body: body_bytes,
91            content_type,
92        });
93    }
94
95    bail!("redirect handling failed unexpectedly")
96}
97
98fn build_request(
99    client: &reqwest::Client,
100    method: &reqwest::Method,
101    url: &Url,
102    headers: &[(String, String)],
103    cookies: &[(String, String)],
104    query: &[(String, String)],
105    body: &PreparedBody,
106) -> Result<reqwest::RequestBuilder> {
107    let mut builder = client.request(method.clone(), url.clone());
108
109    if !query.is_empty() {
110        builder = builder.query(query);
111    }
112
113    let mut header_map = HeaderMap::new();
114    for (name, value) in headers {
115        let header_name = HeaderName::from_bytes(name.as_bytes())
116            .with_context(|| format!("invalid header name `{name}`"))?;
117        let header_value = HeaderValue::from_str(value)
118            .with_context(|| format!("invalid header value for `{name}`"))?;
119        header_map.append(header_name, header_value);
120    }
121
122    if !cookies.is_empty() {
123        let cookie_value = cookies
124            .iter()
125            .map(|(k, v)| format!("{k}={v}"))
126            .collect::<Vec<_>>()
127            .join("; ");
128        header_map.insert(
129            COOKIE,
130            HeaderValue::from_str(&cookie_value).context("invalid cookie header value")?,
131        );
132    }
133
134    builder = builder.headers(header_map);
135
136    match body {
137        PreparedBody::Empty => {}
138        PreparedBody::Json(value) => {
139            builder = builder.json(value);
140        }
141        PreparedBody::Form(fields) => {
142            builder = builder.form(fields);
143        }
144        PreparedBody::Multipart(parts) => {
145            builder = builder.multipart(build_multipart(parts)?);
146        }
147        PreparedBody::RawBytes {
148            bytes,
149            content_type,
150        } => {
151            if let Some(content_type) = content_type {
152                builder = builder.header(CONTENT_TYPE, content_type);
153            }
154            builder = builder.body(bytes.clone());
155        }
156    }
157
158    Ok(builder)
159}
160
161fn build_http_client(
162    ctx: &ExecutionContext,
163    url: &Url,
164    resolved_ips: &[IpAddr],
165) -> Result<reqwest::Client> {
166    if resolved_ips.is_empty() {
167        bail!("host validation returned no resolved IP addresses");
168    }
169
170    let mut builder = reqwest::Client::builder()
171        .timeout(ctx.transport.timeout)
172        .redirect(reqwest::redirect::Policy::none())
173        .gzip(ctx.transport.compression)
174        .brotli(ctx.transport.compression)
175        .zstd(ctx.transport.compression)
176        .deflate(ctx.transport.compression);
177
178    if let Some(version) = ctx.transport.tls_min_version {
179        builder = builder.min_tls_version(version);
180    }
181
182    if let Some(proxy_url) = &ctx.transport.proxy_url {
183        let proxy = reqwest::Proxy::all(proxy_url)
184            .with_context(|| format!("invalid proxy URL `{proxy_url}`"))?;
185        builder = builder.proxy(proxy);
186    }
187
188    let host = url
189        .host_str()
190        .ok_or_else(|| anyhow::anyhow!("request URL missing host"))?;
191    let port = url
192        .port_or_known_default()
193        .ok_or_else(|| anyhow::anyhow!("request URL missing port"))?;
194
195    if !resolved_ips.is_empty() {
196        let addrs: Vec<SocketAddr> = resolved_ips
197            .iter()
198            .map(|ip| SocketAddr::new(*ip, port))
199            .collect();
200        builder = builder.resolve_to_addrs(host, &addrs);
201    }
202
203    builder
204        .build()
205        .context("failed constructing reqwest client")
206}
207
208async fn read_response_body_limited(
209    mut response: reqwest::Response,
210    limit: usize,
211) -> Result<Vec<u8>> {
212    let mut out = Vec::new();
213    while let Some(chunk) = response.chunk().await? {
214        if out.len().saturating_add(chunk.len()) > limit {
215            bail!("response body exceeded configured max_response_bytes ({limit} bytes)");
216        }
217        out.extend_from_slice(&chunk);
218    }
219    Ok(out)
220}
221
222use earl_core::{ProtocolExecutor, StreamChunk, StreamMeta, StreamingProtocolExecutor};
223use tokio::sync::mpsc;
224
225/// HTTP/GraphQL protocol executor.
226///
227/// Holds a host validator closure used for DNS resolution and SSRF protection.
228pub struct HttpExecutor<F> {
229    pub host_validator: F,
230}
231
232impl<F, Fut> ProtocolExecutor for HttpExecutor<F>
233where
234    F: FnMut(Url) -> Fut + Send,
235    Fut: Future<Output = Result<Vec<IpAddr>>> + Send,
236{
237    type PreparedData = PreparedHttpData;
238
239    async fn execute(
240        &mut self,
241        data: &PreparedHttpData,
242        ctx: &ExecutionContext,
243    ) -> Result<RawExecutionResult> {
244        execute_http_once_with_host_validator(data, ctx, &mut self.host_validator).await
245    }
246}
247
248/// Streaming HTTP executor — sends response chunks as they arrive.
249///
250/// Reuses the same connection setup (redirect following, SSRF validation,
251/// client building) as [`HttpExecutor`] but streams chunks through an
252/// `mpsc::Sender` instead of buffering the entire response body.
253pub struct HttpStreamExecutor<F> {
254    pub host_validator: F,
255}
256
257impl<F, Fut> StreamingProtocolExecutor for HttpStreamExecutor<F>
258where
259    F: FnMut(Url) -> Fut + Send,
260    Fut: Future<Output = Result<Vec<IpAddr>>> + Send,
261{
262    type PreparedData = PreparedHttpData;
263
264    async fn execute_stream(
265        &mut self,
266        data: &PreparedHttpData,
267        ctx: &ExecutionContext,
268        sender: mpsc::Sender<StreamChunk>,
269    ) -> anyhow::Result<StreamMeta> {
270        let mut method = data.method.clone();
271        let mut body = data.body.clone();
272        let mut url = data.url.clone();
273
274        for hop in 0..=ctx.transport.max_redirect_hops {
275            ensure_url_allowed(&url, &ctx.allow_rules)?;
276            let resolved_ips = (self.host_validator)(url.clone()).await?;
277            let client = build_http_client(ctx, &url, &resolved_ips)?;
278
279            let request = build_request(
280                &client,
281                &method,
282                &url,
283                &data.headers,
284                &data.cookies,
285                &data.query,
286                &body,
287            )?;
288            let response = request
289                .send()
290                .await
291                .with_context(|| format!("request execution failed for `{}`", url.as_str()))?;
292
293            if response.status().is_redirection() && ctx.transport.follow_redirects {
294                if hop >= ctx.transport.max_redirect_hops {
295                    bail!(
296                        "maximum redirect hops reached ({})",
297                        ctx.transport.max_redirect_hops
298                    );
299                }
300
301                let location = response
302                    .headers()
303                    .get(LOCATION)
304                    .ok_or_else(|| anyhow::anyhow!("redirect response missing Location header"))?
305                    .to_str()
306                    .context("redirect Location header is not valid UTF-8")?
307                    .to_string();
308
309                let new_url = url
310                    .join(&location)
311                    .with_context(|| format!("invalid redirect Location `{location}`"))?;
312
313                let status = response.status().as_u16();
314                if status == 303
315                    || ((status == 301 || status == 302) && method == reqwest::Method::POST)
316                {
317                    method = reqwest::Method::GET;
318                    body = PreparedBody::Empty;
319                }
320                url = new_url;
321                continue;
322            }
323
324            let status = response.status().as_u16();
325            let content_type = response
326                .headers()
327                .get(CONTENT_TYPE)
328                .and_then(|v| v.to_str().ok())
329                .map(|v| v.to_string());
330
331            // Detect SSE responses so we can parse individual events.
332            let is_sse = content_type
333                .as_deref()
334                .map(|ct| ct.starts_with("text/event-stream"))
335                .unwrap_or(false);
336
337            // Stream chunks instead of buffering the entire response body.
338            let mut response = response;
339            let mut total_bytes = 0usize;
340            let mut sse_parser = if is_sse {
341                Some(crate::sse::SseParser::new())
342            } else {
343                None
344            };
345            // Buffer for incomplete UTF-8 sequences at chunk boundaries (SSE only).
346            let mut utf8_buffer: Vec<u8> = Vec::new();
347            let max = ctx.transport.max_response_bytes;
348
349            while let Some(chunk) = response.chunk().await? {
350                if let Some(parser) = &mut sse_parser {
351                    utf8_buffer.extend_from_slice(&chunk);
352                    // Guard against unbounded buffer growth from invalid
353                    // UTF-8 or a server that never sends event boundaries.
354                    // An incomplete multi-byte sequence is at most 3 trailing
355                    // bytes; anything larger indicates bad data.
356                    if utf8_buffer.len() > max {
357                        bail!(
358                            "streaming response exceeded configured max_response_bytes ({max} bytes)"
359                        );
360                    }
361                    // Find the last valid UTF-8 boundary — bytes beyond it
362                    // are an incomplete multi-byte character.
363                    let valid_up_to = match std::str::from_utf8(&utf8_buffer) {
364                        Ok(_) => utf8_buffer.len(),
365                        Err(e) => e.valid_up_to(),
366                    };
367                    if valid_up_to == 0 {
368                        // No complete UTF-8 characters yet — wait for more data.
369                        continue;
370                    }
371                    let text = std::str::from_utf8(&utf8_buffer[..valid_up_to])
372                        .expect("validated UTF-8 boundary");
373                    let events = parser.feed(text);
374                    // Keep any incomplete trailing bytes for the next chunk.
375                    utf8_buffer.drain(..valid_up_to);
376                    for event in events {
377                        total_bytes = total_bytes.saturating_add(event.data.len());
378                        if total_bytes > max {
379                            bail!(
380                                "streaming response exceeded configured max_response_bytes ({max} bytes)"
381                            );
382                        }
383                        if sender
384                            .send(StreamChunk {
385                                data: event.data.into_bytes(),
386                                // SSE event data is extracted content — not
387                                // text/event-stream.  Leave content_type as
388                                // None so decode="auto" can probe the data.
389                                content_type: None,
390                            })
391                            .await
392                            .is_err()
393                        {
394                            return Ok(StreamMeta {
395                                status,
396                                url: url.to_string(),
397                            });
398                        }
399                    }
400                } else {
401                    total_bytes = total_bytes.saturating_add(chunk.len());
402                    if total_bytes > ctx.transport.max_response_bytes {
403                        bail!(
404                            "streaming response exceeded configured max_response_bytes ({} bytes)",
405                            ctx.transport.max_response_bytes
406                        );
407                    }
408                    if sender
409                        .send(StreamChunk {
410                            data: chunk.to_vec(),
411                            content_type: content_type.clone(),
412                        })
413                        .await
414                        .is_err()
415                    {
416                        // Receiver dropped — stop streaming gracefully.
417                        break;
418                    }
419                }
420            }
421
422            // Feed remaining UTF-8 bytes and flush trailing SSE event.
423            if let Some(mut parser) = sse_parser {
424                if let Ok(text) = std::str::from_utf8(&utf8_buffer)
425                    && !text.is_empty()
426                {
427                    for event in parser.feed(text) {
428                        total_bytes = total_bytes.saturating_add(event.data.len());
429                        if total_bytes > max {
430                            bail!(
431                                "streaming response exceeded configured max_response_bytes ({max} bytes)"
432                            );
433                        }
434                        let _ = sender
435                            .send(StreamChunk {
436                                data: event.data.into_bytes(),
437                                content_type: None,
438                            })
439                            .await;
440                    }
441                }
442
443                if let Some(event) = parser.flush() {
444                    total_bytes = total_bytes.saturating_add(event.data.len());
445                    if total_bytes > max {
446                        bail!(
447                            "streaming response exceeded configured max_response_bytes ({max} bytes)"
448                        );
449                    }
450                    let _ = sender
451                        .send(StreamChunk {
452                            data: event.data.into_bytes(),
453                            content_type: None,
454                        })
455                        .await;
456                }
457            }
458
459            return Ok(StreamMeta {
460                status,
461                url: url.to_string(),
462            });
463        }
464
465        bail!("redirect handling failed unexpectedly")
466    }
467}
468
469fn build_multipart(parts: &[PreparedMultipartPart]) -> Result<reqwest::multipart::Form> {
470    let mut form = reqwest::multipart::Form::new();
471    for part in parts {
472        let mut req_part = reqwest::multipart::Part::bytes(part.bytes.clone());
473        if let Some(content_type) = &part.content_type {
474            req_part = req_part.mime_str(content_type)?;
475        }
476        if let Some(filename) = &part.filename {
477            req_part = req_part.file_name(filename.clone());
478        }
479        form = form.part(part.name.clone(), req_part);
480    }
481    Ok(form)
482}