Skip to main content

agent_fetch/
client.rs

1use std::collections::HashMap;
2use std::net::SocketAddr;
3use std::sync::Arc;
4use std::time::Duration;
5
6use bytes::Bytes;
7use reqwest::dns::{Addrs, Name, Resolve, Resolving};
8
9use crate::dns::SafeDnsResolver;
10use crate::error::FetchError;
11use crate::policy::FetchPolicy;
12use crate::rate_limit::RateLimiter;
13use crate::url_check::{validate_url, ValidatedUrl};
14
/// Describes one outbound HTTP request for the safe client to perform.
#[derive(Clone, Debug)]
pub struct FetchRequest {
    /// Target URL as text; it is validated before any network activity happens.
    pub url: String,
    /// HTTP method name (for example `"GET"`); checked against the policy and
    /// then parsed into an `http::Method`.
    pub method: String,
    /// Header name/value pairs added to the initial outgoing request.
    pub headers: HashMap<String, String>,
    /// Optional request payload; its length is checked against the policy's
    /// request-body limit before sending.
    pub body: Option<Vec<u8>>,
}
23
/// What the safe client hands back after a request completes.
#[derive(Clone, Debug)]
pub struct FetchResponse {
    /// Numeric HTTP status code of the final response.
    pub status: u16,
    /// Response headers flattened into a name → value map.
    pub headers: HashMap<String, String>,
    /// Raw response body bytes.
    pub body: Vec<u8>,
}
31
32/// Custom DNS resolver that pins connections to pre-validated IP addresses.
33/// This defeats DNS rebinding attacks by resolving once through our safe resolver
34/// and then feeding those addresses to reqwest.
35struct PinnedResolver {
36    addrs: Vec<SocketAddr>,
37}
38
39impl Resolve for PinnedResolver {
40    fn resolve(&self, _name: Name) -> Resolving {
41        let addrs: Vec<SocketAddr> = self.addrs.clone();
42        Box::pin(async move {
43            let iter: Addrs = Box::new(addrs.into_iter());
44            Ok(iter)
45        })
46    }
47}
48
49/// The safe HTTP client that enforces all policies.
50pub struct SafeClient {
51    policy: FetchPolicy,
52    dns_resolver: SafeDnsResolver,
53    rate_limiter: RateLimiter,
54}
55
56impl SafeClient {
57    pub fn new(policy: FetchPolicy) -> Self {
58        let dns_resolver = SafeDnsResolver::new(policy.deny_private_ips);
59        let rate_limiter = RateLimiter::new(
60            policy.max_requests_per_minute,
61            policy.max_concurrent_requests,
62        );
63
64        Self {
65            policy,
66            dns_resolver,
67            rate_limiter,
68        }
69    }
70
71    /// Execute a fetch request through the full validation pipeline.
72    pub async fn fetch(&self, request: FetchRequest) -> Result<FetchResponse, FetchError> {
73        let validated = validate_url(&request.url)?;
74        self.policy.check_scheme(&validated.scheme)?;
75        self.policy.check_domain(&validated.host)?;
76        self.policy.check_method(&request.method)?;
77
78        if let Some(ref body) = request.body {
79            if body.len() > self.policy.max_request_body_bytes {
80                return Err(FetchError::RequestBodyTooLarge {
81                    size: body.len(),
82                    limit: self.policy.max_request_body_bytes,
83                });
84            }
85        }
86
87        let _permit = self.rate_limiter.acquire(&validated.host).await?;
88
89        let port = validated.url.port_or_known_default().unwrap_or(443);
90        let addrs = self.dns_resolver.resolve(&validated.host, port).await?;
91
92        self.execute_request(&request, &validated, addrs).await
93    }
94
95    fn build_client(&self, addrs: Vec<SocketAddr>) -> Result<reqwest::Client, FetchError> {
96        reqwest::Client::builder()
97            .dns_resolver(Arc::new(PinnedResolver { addrs }))
98            .connect_timeout(Duration::from_millis(self.policy.connect_timeout_ms))
99            .timeout(Duration::from_millis(self.policy.request_timeout_ms))
100            .redirect(reqwest::redirect::Policy::none())
101            .build()
102            .map_err(|e: reqwest::Error| FetchError::HttpError(e.to_string()))
103    }
104
105    async fn execute_request(
106        &self,
107        request: &FetchRequest,
108        validated: &ValidatedUrl,
109        addrs: Vec<SocketAddr>,
110    ) -> Result<FetchResponse, FetchError> {
111        let client = self.build_client(addrs)?;
112
113        let method: http::Method = request
114            .method
115            .parse()
116            .map_err(|_| FetchError::MethodNotAllowed(request.method.clone()))?;
117
118        let mut req_builder = client.request(method, validated.url.as_str());
119
120        for (key, value) in &request.headers {
121            req_builder = req_builder.header(key.as_str(), value.as_str());
122        }
123
124        if let Some(ref body) = request.body {
125            req_builder = req_builder.body(Bytes::from(body.clone()));
126        }
127
128        let mut current_url = validated.url.clone();
129        let mut redirects_followed: u8 = 0;
130        let mut response: reqwest::Response =
131            req_builder.send().await.map_err(classify_reqwest_error)?;
132
133        while response.status().is_redirection() {
134            redirects_followed += 1;
135            if redirects_followed > self.policy.max_redirects {
136                return Err(FetchError::TooManyRedirects {
137                    limit: self.policy.max_redirects,
138                });
139            }
140
141            let location = response
142                .headers()
143                .get(http::header::LOCATION)
144                .and_then(|v: &http::HeaderValue| v.to_str().ok())
145                .ok_or_else(|| FetchError::HttpError("redirect without Location header".into()))?
146                .to_string();
147
148            let redirect_url = current_url
149                .join(&location)
150                .map_err(|e| FetchError::InvalidUrl(e.to_string()))?;
151
152            let redirect_validated = validate_url(redirect_url.as_str())?;
153            self.policy.check_scheme(&redirect_validated.scheme)?;
154            self.policy.check_domain(&redirect_validated.host)?;
155
156            let redirect_port = redirect_validated
157                .url
158                .port_or_known_default()
159                .unwrap_or(443);
160            let redirect_addrs = self
161                .dns_resolver
162                .resolve(&redirect_validated.host, redirect_port)
163                .await
164                .map_err(|e| match e {
165                    FetchError::PrivateIpBlocked { resolved_ip, .. } => {
166                        FetchError::RedirectToPrivateIp {
167                            url: redirect_url.to_string(),
168                            resolved_ip,
169                        }
170                    }
171                    other => other,
172                })?;
173
174            let redirect_client = self.build_client(redirect_addrs)?;
175
176            current_url = redirect_validated.url.clone();
177            response = redirect_client
178                .get(redirect_validated.url.as_str())
179                .send()
180                .await
181                .map_err(classify_reqwest_error)?;
182        }
183
184        self.read_body_limited(response).await
185    }
186
187    async fn read_body_limited(
188        &self,
189        response: reqwest::Response,
190    ) -> Result<FetchResponse, FetchError> {
191        let status = response.status().as_u16();
192
193        let headers: HashMap<String, String> = response
194            .headers()
195            .iter()
196            .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
197            .collect();
198
199        if let Some(cl) = response.content_length() {
200            if cl as usize > self.policy.max_response_body_bytes {
201                return Err(FetchError::ResponseBodyTooLarge {
202                    size: cl as usize,
203                    limit: self.policy.max_response_body_bytes,
204                });
205            }
206        }
207
208        let body = response
209            .bytes()
210            .await
211            .map_err(|e| FetchError::HttpError(e.to_string()))?;
212
213        if body.len() > self.policy.max_response_body_bytes {
214            return Err(FetchError::ResponseBodyTooLarge {
215                size: body.len(),
216                limit: self.policy.max_response_body_bytes,
217            });
218        }
219
220        Ok(FetchResponse {
221            status,
222            headers,
223            body: body.to_vec(),
224        })
225    }
226}
227
228fn classify_reqwest_error(e: reqwest::Error) -> FetchError {
229    if e.is_connect() {
230        FetchError::ConnectionTimeout
231    } else if e.is_timeout() {
232        FetchError::RequestTimeout
233    } else {
234        FetchError::HttpError(e.to_string())
235    }
236}