Skip to main content

earl_protocol_http/
executor.rs

1use std::future::Future;
2use std::net::{IpAddr, SocketAddr};
3
4use anyhow::{Context, Result, bail};
5use reqwest::header::{CONTENT_TYPE, COOKIE, HeaderMap, HeaderName, HeaderValue, LOCATION};
6use url::Url;
7
8use earl_core::allowlist::ensure_url_allowed;
9use earl_core::{ExecutionContext, PreparedBody, PreparedMultipartPart, RawExecutionResult};
10
11use crate::PreparedHttpData;
12
13/// Execute a single HTTP request (with redirect following) and return the result.
14pub async fn execute_http_once_with_host_validator<F, Fut>(
15    http_data: &PreparedHttpData,
16    ctx: &ExecutionContext,
17    host_validator: &mut F,
18) -> Result<RawExecutionResult>
19where
20    F: FnMut(Url) -> Fut,
21    Fut: Future<Output = Result<Vec<IpAddr>>>,
22{
23    let mut method = http_data.method.clone();
24    let mut body = http_data.body.clone();
25    let mut url = http_data.url.clone();
26
27    for hop in 0..=ctx.transport.max_redirect_hops {
28        ensure_url_allowed(&url, &ctx.allow_rules)?;
29        let resolved_ips = host_validator(url.clone()).await?;
30        let client = build_http_client(ctx, &url, &resolved_ips)?;
31
32        let request = build_request(
33            &client,
34            &method,
35            &url,
36            &http_data.headers,
37            &http_data.cookies,
38            &http_data.query,
39            &body,
40        )?;
41        let response = request
42            .send()
43            .await
44            .with_context(|| format!("request execution failed for `{}`", url.as_str()))?;
45
46        if response.status().is_redirection() && ctx.transport.follow_redirects {
47            if hop >= ctx.transport.max_redirect_hops {
48                bail!(
49                    "maximum redirect hops reached ({})",
50                    ctx.transport.max_redirect_hops
51                );
52            }
53
54            let location = response
55                .headers()
56                .get(LOCATION)
57                .ok_or_else(|| anyhow::anyhow!("redirect response missing Location header"))?
58                .to_str()
59                .context("redirect Location header is not valid UTF-8")?
60                .to_string();
61
62            let new_url = url
63                .join(&location)
64                .with_context(|| format!("invalid redirect Location `{location}`"))?;
65
66            let status = response.status().as_u16();
67            if status == 303
68                || ((status == 301 || status == 302) && method == reqwest::Method::POST)
69            {
70                method = reqwest::Method::GET;
71                body = PreparedBody::Empty;
72            }
73            url = new_url;
74            continue;
75        }
76
77        let status = response.status().as_u16();
78        let content_type = response
79            .headers()
80            .get(CONTENT_TYPE)
81            .and_then(|v| v.to_str().ok())
82            .map(|v| v.to_string());
83
84        let body_bytes =
85            read_response_body_limited(response, ctx.transport.max_response_bytes).await?;
86
87        return Ok(RawExecutionResult {
88            status,
89            url: url.to_string(),
90            body: body_bytes,
91            content_type,
92        });
93    }
94
95    bail!("redirect handling failed unexpectedly")
96}
97
98fn build_request(
99    client: &reqwest::Client,
100    method: &reqwest::Method,
101    url: &Url,
102    headers: &[(String, String)],
103    cookies: &[(String, String)],
104    query: &[(String, String)],
105    body: &PreparedBody,
106) -> Result<reqwest::RequestBuilder> {
107    let mut builder = client.request(method.clone(), url.clone());
108
109    if !query.is_empty() {
110        builder = builder.query(query);
111    }
112
113    let mut header_map = HeaderMap::new();
114    for (name, value) in headers {
115        let header_name = HeaderName::from_bytes(name.as_bytes())
116            .with_context(|| format!("invalid header name `{name}`"))?;
117        let header_value = HeaderValue::from_str(value)
118            .with_context(|| format!("invalid header value for `{name}`"))?;
119        header_map.append(header_name, header_value);
120    }
121
122    if !cookies.is_empty() {
123        let cookie_value = cookies
124            .iter()
125            .map(|(k, v)| format!("{k}={v}"))
126            .collect::<Vec<_>>()
127            .join("; ");
128        header_map.insert(
129            COOKIE,
130            HeaderValue::from_str(&cookie_value).context("invalid cookie header value")?,
131        );
132    }
133
134    builder = builder.headers(header_map);
135
136    match body {
137        PreparedBody::Empty => {}
138        PreparedBody::Json(value) => {
139            builder = builder.json(value);
140        }
141        PreparedBody::Form(fields) => {
142            builder = builder.form(fields);
143        }
144        PreparedBody::Multipart(parts) => {
145            builder = builder.multipart(build_multipart(parts)?);
146        }
147        PreparedBody::RawBytes {
148            bytes,
149            content_type,
150        } => {
151            if let Some(content_type) = content_type {
152                builder = builder.header(CONTENT_TYPE, content_type);
153            }
154            builder = builder.body(bytes.clone());
155        }
156    }
157
158    Ok(builder)
159}
160
161fn build_http_client(
162    ctx: &ExecutionContext,
163    url: &Url,
164    resolved_ips: &[IpAddr],
165) -> Result<reqwest::Client> {
166    if resolved_ips.is_empty() {
167        bail!("host validation returned no resolved IP addresses");
168    }
169
170    let mut builder = reqwest::Client::builder()
171        .timeout(ctx.transport.timeout)
172        .redirect(reqwest::redirect::Policy::none())
173        .gzip(ctx.transport.compression)
174        .brotli(ctx.transport.compression)
175        .zstd(ctx.transport.compression)
176        .deflate(ctx.transport.compression);
177
178    if let Some(version) = ctx.transport.tls_min_version {
179        builder = builder.min_tls_version(version);
180    }
181
182    if let Some(proxy_url) = &ctx.transport.proxy_url {
183        let proxy = reqwest::Proxy::all(proxy_url)
184            .with_context(|| format!("invalid proxy URL `{proxy_url}`"))?;
185        builder = builder.proxy(proxy);
186    }
187
188    let host = url
189        .host_str()
190        .ok_or_else(|| anyhow::anyhow!("request URL missing host"))?;
191    let port = url
192        .port_or_known_default()
193        .ok_or_else(|| anyhow::anyhow!("request URL missing port"))?;
194
195    if !resolved_ips.is_empty() {
196        let addrs: Vec<SocketAddr> = resolved_ips
197            .iter()
198            .map(|ip| SocketAddr::new(*ip, port))
199            .collect();
200        builder = builder.resolve_to_addrs(host, &addrs);
201    }
202
203    builder
204        .build()
205        .context("failed constructing reqwest client")
206}
207
208async fn read_response_body_limited(
209    mut response: reqwest::Response,
210    limit: usize,
211) -> Result<Vec<u8>> {
212    let mut out = Vec::new();
213    while let Some(chunk) = response.chunk().await? {
214        if out.len().saturating_add(chunk.len()) > limit {
215            bail!("response body exceeded configured max_response_bytes ({limit} bytes)");
216        }
217        out.extend_from_slice(&chunk);
218    }
219    Ok(out)
220}
221
222use earl_core::ProtocolExecutor;
223
/// HTTP/GraphQL protocol executor.
///
/// Holds a host validator closure used for DNS resolution and SSRF protection.
/// The type parameter `F` is left unbounded here; the bounds it must satisfy
/// are stated on the `ProtocolExecutor` implementation.
pub struct HttpExecutor<F> {
    /// Callback invoked with each URL before it is requested. It returns the
    /// IP addresses the request is allowed to connect to, or an error to
    /// reject the URL.
    pub host_validator: F,
}
230
231impl<F, Fut> ProtocolExecutor for HttpExecutor<F>
232where
233    F: FnMut(Url) -> Fut + Send,
234    Fut: Future<Output = Result<Vec<IpAddr>>> + Send,
235{
236    type PreparedData = PreparedHttpData;
237
238    async fn execute(
239        &mut self,
240        data: &PreparedHttpData,
241        ctx: &ExecutionContext,
242    ) -> Result<RawExecutionResult> {
243        execute_http_once_with_host_validator(data, ctx, &mut self.host_validator).await
244    }
245}
246
247fn build_multipart(parts: &[PreparedMultipartPart]) -> Result<reqwest::multipart::Form> {
248    let mut form = reqwest::multipart::Form::new();
249    for part in parts {
250        let mut req_part = reqwest::multipart::Part::bytes(part.bytes.clone());
251        if let Some(content_type) = &part.content_type {
252            req_part = req_part.mime_str(content_type)?;
253        }
254        if let Some(filename) = &part.filename {
255            req_part = req_part.file_name(filename.clone());
256        }
257        form = form.part(part.name.clone(), req_part);
258    }
259    Ok(form)
260}