1use std::collections::HashMap;
2use std::net::SocketAddr;
3use std::sync::Arc;
4use std::time::Duration;
5
6use bytes::Bytes;
7use reqwest::dns::{Addrs, Name, Resolve, Resolving};
8
9use crate::dns::SafeDnsResolver;
10use crate::error::FetchError;
11use crate::policy::FetchPolicy;
12use crate::rate_limit::RateLimiter;
13use crate::url_check::{validate_url, ValidatedUrl};
14
/// An outbound HTTP request described with plain owned values.
///
/// `SafeClient::fetch` validates every field against its policy before anything is sent.
#[derive(Clone, Debug)]
pub struct FetchRequest {
    /// Target URL in string form; it is parsed and validated by `validate_url`.
    pub url: String,
    /// HTTP method name (for example `"GET"`); it is parsed into an `http::Method` before sending.
    pub method: String,
    /// Header name/value pairs attached to the initial request.
    pub headers: HashMap<String, String>,
    /// Optional request payload; `None` sends no body.
    pub body: Option<Vec<u8>>,
}
23
/// The fully buffered result of a fetch.
#[derive(Clone, Debug)]
pub struct FetchResponse {
    /// Numeric HTTP status code of the final response.
    pub status: u16,
    /// Response headers flattened into a map. `read_body_limited` stores a header value that is
    /// not valid visible ASCII as an empty string, and a repeated header name keeps one value.
    pub headers: HashMap<String, String>,
    /// Response payload bytes.
    pub body: Vec<u8>,
}
31
/// A resolver that carries a fixed list of socket addresses.
///
/// `SafeClient::build_client` installs one of these per `reqwest::Client`, filled with the
/// addresses returned by `SafeDnsResolver`.
struct PinnedResolver {
    /// Addresses handed back for every lookup, in this order.
    addrs: Vec<SocketAddr>,
}
38
39impl Resolve for PinnedResolver {
40 fn resolve(&self, _name: Name) -> Resolving {
41 let addrs: Vec<SocketAddr> = self.addrs.clone();
42 Box::pin(async move {
43 let iter: Addrs = Box::new(addrs.into_iter());
44 Ok(iter)
45 })
46 }
47}
48
49pub struct SafeClient {
51 policy: FetchPolicy,
52 dns_resolver: SafeDnsResolver,
53 rate_limiter: RateLimiter,
54}
55
56impl SafeClient {
57 pub fn new(policy: FetchPolicy) -> Self {
58 let dns_resolver = SafeDnsResolver::new(policy.deny_private_ips);
59 let rate_limiter = RateLimiter::new(
60 policy.max_requests_per_minute,
61 policy.max_concurrent_requests,
62 );
63
64 Self {
65 policy,
66 dns_resolver,
67 rate_limiter,
68 }
69 }
70
71 pub async fn fetch(&self, request: FetchRequest) -> Result<FetchResponse, FetchError> {
73 let validated = validate_url(&request.url)?;
74 self.policy.check_scheme(&validated.scheme)?;
75 self.policy.check_domain(&validated.host)?;
76 self.policy.check_method(&request.method)?;
77
78 if let Some(ref body) = request.body {
79 if body.len() > self.policy.max_request_body_bytes {
80 return Err(FetchError::RequestBodyTooLarge {
81 size: body.len(),
82 limit: self.policy.max_request_body_bytes,
83 });
84 }
85 }
86
87 let _permit = self.rate_limiter.acquire(&validated.host).await?;
88
89 let port = validated.url.port_or_known_default().unwrap_or(443);
90 let addrs = self.dns_resolver.resolve(&validated.host, port).await?;
91
92 self.execute_request(&request, &validated, addrs).await
93 }
94
95 fn build_client(&self, addrs: Vec<SocketAddr>) -> Result<reqwest::Client, FetchError> {
96 reqwest::Client::builder()
97 .dns_resolver(Arc::new(PinnedResolver { addrs }))
98 .connect_timeout(Duration::from_millis(self.policy.connect_timeout_ms))
99 .timeout(Duration::from_millis(self.policy.request_timeout_ms))
100 .redirect(reqwest::redirect::Policy::none())
101 .build()
102 .map_err(|e: reqwest::Error| FetchError::HttpError(e.to_string()))
103 }
104
105 async fn execute_request(
106 &self,
107 request: &FetchRequest,
108 validated: &ValidatedUrl,
109 addrs: Vec<SocketAddr>,
110 ) -> Result<FetchResponse, FetchError> {
111 let client = self.build_client(addrs)?;
112
113 let method: http::Method = request
114 .method
115 .parse()
116 .map_err(|_| FetchError::MethodNotAllowed(request.method.clone()))?;
117
118 let mut req_builder = client.request(method, validated.url.as_str());
119
120 for (key, value) in &request.headers {
121 req_builder = req_builder.header(key.as_str(), value.as_str());
122 }
123
124 if let Some(ref body) = request.body {
125 req_builder = req_builder.body(Bytes::from(body.clone()));
126 }
127
128 let mut current_url = validated.url.clone();
129 let mut redirects_followed: u8 = 0;
130 let mut response: reqwest::Response =
131 req_builder.send().await.map_err(classify_reqwest_error)?;
132
133 while response.status().is_redirection() {
134 redirects_followed += 1;
135 if redirects_followed > self.policy.max_redirects {
136 return Err(FetchError::TooManyRedirects {
137 limit: self.policy.max_redirects,
138 });
139 }
140
141 let location = response
142 .headers()
143 .get(http::header::LOCATION)
144 .and_then(|v: &http::HeaderValue| v.to_str().ok())
145 .ok_or_else(|| FetchError::HttpError("redirect without Location header".into()))?
146 .to_string();
147
148 let redirect_url = current_url
149 .join(&location)
150 .map_err(|e| FetchError::InvalidUrl(e.to_string()))?;
151
152 let redirect_validated = validate_url(redirect_url.as_str())?;
153 self.policy.check_scheme(&redirect_validated.scheme)?;
154 self.policy.check_domain(&redirect_validated.host)?;
155
156 let redirect_port = redirect_validated
157 .url
158 .port_or_known_default()
159 .unwrap_or(443);
160 let redirect_addrs = self
161 .dns_resolver
162 .resolve(&redirect_validated.host, redirect_port)
163 .await
164 .map_err(|e| match e {
165 FetchError::PrivateIpBlocked { resolved_ip, .. } => {
166 FetchError::RedirectToPrivateIp {
167 url: redirect_url.to_string(),
168 resolved_ip,
169 }
170 }
171 other => other,
172 })?;
173
174 let redirect_client = self.build_client(redirect_addrs)?;
175
176 current_url = redirect_validated.url.clone();
177 response = redirect_client
178 .get(redirect_validated.url.as_str())
179 .send()
180 .await
181 .map_err(classify_reqwest_error)?;
182 }
183
184 self.read_body_limited(response).await
185 }
186
187 async fn read_body_limited(
188 &self,
189 response: reqwest::Response,
190 ) -> Result<FetchResponse, FetchError> {
191 let status = response.status().as_u16();
192
193 let headers: HashMap<String, String> = response
194 .headers()
195 .iter()
196 .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
197 .collect();
198
199 if let Some(cl) = response.content_length() {
200 if cl as usize > self.policy.max_response_body_bytes {
201 return Err(FetchError::ResponseBodyTooLarge {
202 size: cl as usize,
203 limit: self.policy.max_response_body_bytes,
204 });
205 }
206 }
207
208 let body = response
209 .bytes()
210 .await
211 .map_err(|e| FetchError::HttpError(e.to_string()))?;
212
213 if body.len() > self.policy.max_response_body_bytes {
214 return Err(FetchError::ResponseBodyTooLarge {
215 size: body.len(),
216 limit: self.policy.max_response_body_bytes,
217 });
218 }
219
220 Ok(FetchResponse {
221 status,
222 headers,
223 body: body.to_vec(),
224 })
225 }
226}
227
228fn classify_reqwest_error(e: reqwest::Error) -> FetchError {
229 if e.is_connect() {
230 FetchError::ConnectionTimeout
231 } else if e.is_timeout() {
232 FetchError::RequestTimeout
233 } else {
234 FetchError::HttpError(e.to_string())
235 }
236}